#include <RAT/EnergyRThetaFunctional.hh>
#include <RAT/DS/FitVertex.hh>
#include <RAT/DS/FitResult.hh>
#include <RAT/DU/Utility.hh>
#include <RAT/DB.hh>
#include <RAT/Log.hh>
#include <RAT/ChannelEfficiency.hh>

#include <CLHEP/Units/PhysicalConstants.h>

#include <TVector3.h>

#include <vector>

using namespace RAT;
using namespace RAT::Methods;
using namespace std;

void EnergyRThetaFunctional::BeginOfRun(DS::Run&) {

  DB* db = DB::Get();
  if(fIndex.empty())
    {
      try
        {
          fIndex = db->GetLink("GEO", "inner_av")->GetS("material");
        }
      catch(DBNotFoundError& e1)
        {
          try
            {
              fIndex = db->GetLink("GEO", "inner_av")->GetS("material_top");
            }
          catch(DBNotFoundError& e2)
            {
              // Other fitters warn, prefer to die if the material hasn't been coordinated yet
              Log::Die("EnergyRThetaFunctional: No material available for inner_av material or material_top");
            }
        }
    }

  // Now we have a material index set, load the correct database table
  DBLinkPtr dbLink = db->GetLink("FIT_ENERGY_RTHETA_FUNCTIONAL", "labppo_scintillator"); // default
  DBLinkGroup grp = db->GetLinkGroup("FIT_ENERGY_RTHETA_FUNCTIONAL");

  // Check to see if our material is available, if not use the default labppo_scintillator
  for(DBLinkGroup::iterator it = grp.begin(); it != grp.end(); ++it)
    {
      if(fIndex == it->first)
        {
          dbLink = db->GetLink("FIT_ENERGY_RTHETA_FUNCTIONAL", fIndex);
          break;
        }
    }

  fRPiecewise = dbLink->GetD("r_piecewise");
  fRCutoff = dbLink->GetD("r_cutoff");
  fEnergies = dbLink->GetDArray("h_energies");
  fHAtEnergy = dbLink->GetDArray("h_at_energy");
  fThetas = dbLink->GetDArray("thetas");
  fF1Pol0 = dbLink->GetDArray("f1_pol0");
  fF1Pol1 = dbLink->GetDArray("f1_pol1");
  fF1Pol2 = dbLink->GetDArray("f1_pol2");
  fF1Pol3 = dbLink->GetDArray("f1_pol3");
  fF2Constant = dbLink->GetDArray("f2_x_const");
  fF2Pol0 = dbLink->GetDArray("f2_pol0");
  fF2Pol1 = dbLink->GetDArray("f2_pol1");
  fF2Pol2 = dbLink->GetDArray("f2_pol2");
  fF2Pol3 = dbLink->GetDArray("f2_pol3");
  fCoorChanEff = dbLink->GetD("chan_eff");
  fCoorActiveChannels = dbLink->GetD("active_channels");

  ApplyDetectorStateCorrection();
}


void EnergyRThetaFunctional::DefaultSeed()
{
  DS::FitVertex seedVertex;
  // Use central position
  // Set validity to valse, fixed (i.e. not set by a fitter) to true
  seedVertex.SetPosition(TVector3(0, 0, 0), false, true);
  fSeedResult.SetVertex(0, seedVertex);
}


namespace
{
  // Evaluate c0 + c1*x + c2*x^2 + c3*x^3. Terms are written in the same order as
  // the former inline expressions so floating-point results are unchanged.
  double EvaluateCubic(double c0, double c1, double c2, double c3, double x)
  {
    return c0 + c1 * x + c2 * x * x + c3 * x * x * x;
  }
}

// Estimate the event energy from the seed position and the PMT hit pattern.
// Throws MethodFailError if the seed radius exceeds r_cutoff or the lookup
// tables are too short to interpolate. Returns an empty result when there are
// no PMT hits.
DS::FitResult EnergyRThetaFunctional::GetBestFit()
{
  // Reset result, check PMT hits exist, copy seed to result
  // methods present for all (seeded) fitters
  fFitResult.Reset();
  if( fPMTData.empty() )
    return fFitResult;
  CopySeedToResult();

  // Outline:
  //  1. Count hit PMTs (with non-zero calibration status) in each segment of the
  //     segmentor provided by DU::Utility (no origin is set here) and form
  //     H = -sum_i N_i ln(1 - n_i / N_i), with N_i PMTs and n_i hits per segment.
  //  2. Divide H by the expected ratio H(r, theta) / H(0): evaluate the cubic
  //     functional form (f1 up to r_piecewise, f2 beyond) at the two theta table
  //     entries around the event and linearly interpolate in theta.
  //  3. Convert the resulting effective central H to an energy with the
  //     piecewise-linear H-vs-energy table.

  DS::FitVertex fitVertex = fSeedResult.GetVertex(0);
  const TVector3 eventPosition = fitVertex.GetPosition();
  const double rEvent = eventPosition.Mag();
  const double thetaEvent = eventPosition.Theta();

  // If the radius is above cutoff then skip the event
  if(rEvent > fRCutoff)
    throw MethodFailError("EnergyRTheta: Cannot fit event above cutoff radius");

  // Interpolation below needs two adjacent entries in each table
  if(fThetas.size() < 2 || fEnergies.size() < 2)
    throw MethodFailError("EnergyRTheta: theta and energy tables need at least two entries");

  // Always use the default number of segments
  DU::Segmentor segmentor = DU::Utility::Get()->GetSegmentor();

  vector<unsigned int> rawPMTSegmentIDs = segmentor.GetSegmentIDs();
  vector<unsigned int> rawPMTSegmentPopulations = segmentor.GetSegmentPopulations();
  vector<unsigned int> hitPMTSegmentPopulations (rawPMTSegmentPopulations.size(), 0);

  const DU::PMTCalStatus& pmtCalStatus = DU::Utility::Get()->GetPMTCalStatus();
  for(size_t iHit = 0; iHit < fPMTData.size(); iHit++)
  {
    const int pmtID = fPMTData[iHit].GetID();
    // Only count hits on channels with a non-zero calibration status
    const double fractionGoodCells = pmtCalStatus.GetChannelStatus(pmtID);
    if (fractionGoodCells)
      hitPMTSegmentPopulations[rawPMTSegmentIDs[pmtID]]++;
  }

  double hParameter = 0.0;
  for(size_t i = 0; i < rawPMTSegmentPopulations.size(); i++)
  {
    double numberOfRawPMTs = static_cast<double>(rawPMTSegmentPopulations[i]);
    double numberOfHitPMTs = static_cast<double>(hitPMTSegmentPopulations[i]);
    if(numberOfRawPMTs == 0)
      {
        // Cannot add anything for the H Parameter for this sector.
        // The segmentor is not re-centred on the event, so this is highly
        // unlikely unless whole crates are off.
        // May wish to normalise H by number of used segments in future.
        continue;
      }
    else if(numberOfHitPMTs == numberOfRawPMTs)
    {
      // A fully-hit segment would give log(0); treat it as one PMT short of saturation
      warn << "EnergyRThetaFunctional::GetBestFit: (warning) encountered saturated segment with the same number of hit PMTs as actual PMTs (" << numberOfRawPMTs << ") ... correcting and continuing \n";
      numberOfHitPMTs = numberOfRawPMTs - 1;
    }
    hParameter -= (numberOfRawPMTs * log(1.0 - (numberOfHitPMTs / numberOfRawPMTs)));
  }

  // Pick the two adjacent theta table entries to interpolate between.
  // Theta() lies in [0, pi], so the raw bin index can reach fThetas.size();
  // clamp it so that both bin1 and bin2 = bin1 + 1 are valid indices (events in
  // the last bin use the final two entries).
  const int nThetas = static_cast<int>(fThetas.size());
  int bin1 = static_cast<int>(thetaEvent / (CLHEP::pi / nThetas));
  if(bin1 < 0)
    bin1 = 0;
  if(bin1 > nThetas - 2)
    bin1 = nThetas - 2;
  const int bin2 = bin1 + 1;

  // Have an H parameter, first scale by the expected ratio at this R, theta
  double hScaling1 = 0.0;
  double hScaling2 = 0.0;
  if(rEvent <= fRPiecewise)
    {
      // Use the f1 parameters (a cubic in r)
      hScaling1 = EvaluateCubic(fF1Pol0[bin1], fF1Pol1[bin1], fF1Pol2[bin1], fF1Pol3[bin1], rEvent);
      hScaling2 = EvaluateCubic(fF1Pol0[bin2], fF1Pol1[bin2], fF1Pol2[bin2], fF1Pol3[bin2], rEvent);
    }
  else
    {
      // Use the f2 parameters (a cubic in (r - c) rather than r);
      // the offsets c for the two bins are expected to be equal
      hScaling1 = EvaluateCubic(fF2Pol0[bin1], fF2Pol1[bin1], fF2Pol2[bin1], fF2Pol3[bin1],
                                rEvent - fF2Constant[bin1]);
      hScaling2 = EvaluateCubic(fF2Pol0[bin2], fF2Pol1[bin2], fF2Pol2[bin2], fF2Pol3[bin2],
                                rEvent - fF2Constant[bin2]);
    }

  // Interpolate the two scalings, adjust the hParameter accordingly
  const double hScaling = hScaling1 + (hScaling2 - hScaling1) * (thetaEvent - fThetas[bin1]) / (fThetas[bin2]-fThetas[bin1]);
  hParameter /= hScaling;

  // And find the energy from H vs E lookup
  bool valid = false; // Unless H value is within range
  double energy = DS::INVALID;
  for(size_t iEnergy = 1; iEnergy < fEnergies.size(); iEnergy++ )
    {
      if( hParameter >= fHAtEnergy[iEnergy-1] && hParameter < fHAtEnergy[iEnergy] )
        {
          valid = true;
          double hPerEnergy = (fHAtEnergy[iEnergy] - fHAtEnergy[iEnergy-1]) / (fEnergies[iEnergy] - fEnergies[iEnergy-1]);
          energy = (hParameter - fHAtEnergy[iEnergy-1]) / hPerEnergy + fEnergies[iEnergy-1];
          break;
        }
    }
  // If out of range, still estimate the energy by extrapolating the last two lookup points
  if(!valid)
    {
      size_t iEnergy = fEnergies.size()-1;
      double hPerEnergy = (fHAtEnergy[iEnergy] - fHAtEnergy[iEnergy-1]) / (fEnergies[iEnergy] - fEnergies[iEnergy-1]);
      energy = (hParameter - fHAtEnergy[iEnergy-1]) / hPerEnergy + fEnergies[iEnergy-1];
    }

  // Set the result and return it
  fitVertex.SetEnergy(energy, true);
  fitVertex.SetEnergyErrors(1.0); // FIXME: set something along lines of sqrt(Nhits)?
  fFitResult.SetVertex(0, fitVertex);
  return fFitResult;

}

void EnergyRThetaFunctional::ApplyDetectorStateCorrection(){

  const DU::PMTCalStatus& PMTCalStatus = DU::Utility::Get()->GetPMTCalStatus();
  const DU::PMTInfo& pmtInfo = DU::Utility::Get()->GetPMTInfo();
  const DU::ChanHWStatus& channelHardwareStatus = DU::Utility::Get()->GetChanHWStatus();

  //Get the fraction of active channels to correct for detector state
  int numChannels = DB::Get()->GetLink( "PMT_DQXX" )->GetIArray( "dqch" ).size();
  unsigned int totalActiveChannels = 0;
  for( size_t lcn = 0; lcn < numChannels; lcn++ ){
    if (channelHardwareStatus.IsEnabled()){
      if (!channelHardwareStatus.IsDAQEnabled(lcn))
        continue;
      if( pmtInfo.GetType( lcn ) != DU::PMTInfo::NORMAL &&
          pmtInfo.GetType( lcn ) != DU::PMTInfo::HQE)
        continue;
      double fractionGoodCells = PMTCalStatus.GetChannelStatus( static_cast<int>(lcn) );
      if( fractionGoodCells )
        totalActiveChannels++;
    }
  }
  double chanNumCorr = static_cast<double>(totalActiveChannels)/fCoorActiveChannels;

  //Get the channel efficiency correction
  double chanEff = ChannelEfficiency::Get()->GetAverageEfficiency();
  double chanEffCorr = chanEff/fCoorChanEff;

  //Apply both corrections
  for(int i = 0; i < fHAtEnergy.size(); i++){
    fHAtEnergy[i] *= chanNumCorr*chanEffCorr;
  }
}