#include <RAT/PartialFitter.hh>
#include <RAT/MethodFactory.hh>
#include <RAT/PDFFactory.hh>
#include <RAT/OptimiserFactory.hh>
#include <RAT/ClassifierFactory.hh>
#include <RAT/PMTSelectorFactory.hh>
#include <RAT/Method.hh>
#include <RAT/Classifier.hh>
#include <RAT/OptimisedMethod.hh>
#include <RAT/PDFMethod.hh>
#include <RAT/SeededMethod.hh>
#include <RAT/SelectorMethod.hh>
#include <RAT/Optimiser.hh>
#include <RAT/SeededClassifier.hh>
#include <RAT/FitterPMT.hh>
#include <RAT/MetaInformation.hh>
#include <RAT/Log.hh>
#include <RAT/DB.hh>
#include <RAT/ListHelp.hh>
#include <RAT/DS/FitResult.hh>
#include <RAT/DS/Entry.hh>
using namespace RAT;
using namespace RAT::Methods;
using namespace RAT::PDFs;
using namespace RAT::Optimisers;
using namespace RAT::PMTSelectors;
using namespace RAT::DS;
using namespace RAT::Classifiers;

#include <sstream>
using namespace std;

#include <TStopwatch.h>
using namespace ROOT;

PartialFitter::PartialFitter()
  : Processor("partialFitter")
{
  // Create every fit component from its factory. All of these objects are owned by this
  // processor and are deleted in the destructor.
  fQuadSeed = MethodFactory::Get()->GetMethod( "quad" );
  fPositionTime = MethodFactory::Get()->GetMethod( "positionTimeLikelihood" );
  fPowell = OptimiserFactory::Get()->GetOptimiser( "powell" );
  fNullSelector = PMTSelectorFactory::Get()->GetPMTSelector( "null" );
  fET1D = PDFFactory::Get()->GetPDF( "partialET1D" );
  fEnergySeed = MethodFactory::Get()->GetMethod( "energyRThetaFunctional" );
  fPartialEnergy = MethodFactory::Get()->GetMethod( "partialEnergy" );

  // Assemble the position/time likelihood: powell optimiser, null PMT selector, partialET1D PDF.
  // NOTE(review): the casts are unchecked; they assume the factory returned a method that
  // implements all three interfaces.
  OptimisedMethod* const optimisedPositionTime = dynamic_cast< OptimisedMethod* >( fPositionTime );
  optimisedPositionTime->SetOptimiser( fPowell );
  SelectorMethod* const selectorPositionTime = dynamic_cast< SelectorMethod* >( fPositionTime );
  selectorPositionTime->AddPMTSelector( fNullSelector );
  PDFMethod* const pdfPositionTime = dynamic_cast< PDFMethod* >( fPositionTime );
  pdfPositionTime->SetPDF( fET1D );

  // Event() only runs quad when the calibrated-PMT count is strictly greater than this value.
  fCutOff = 3.0;
}

/// Destroys every component created in the constructor; the processor owns all of them.
/// NOTE(review): fPositionTime is handed pointers to fPowell, fNullSelector and fET1D in the
/// constructor (presumably stored). It is deleted before them, so those objects are still alive
/// while it is destroyed. Keep this order unless that dependency is confirmed not to exist.
PartialFitter::~PartialFitter()
{
  delete fQuadSeed;
  delete fPositionTime;
  delete fPowell;
  delete fNullSelector;
  delete fET1D;
  delete fEnergySeed;
  delete fPartialEnergy;
}

/// Start-of-run hook: forwards BeginOfRun to every owned fit component, in the same order the
/// components are created in the constructor.
/// @param run the run that is starting
void
PartialFitter::BeginOfRun( DS::Run& run )
{
  /// First call the begin of run functions of every component
  fQuadSeed->BeginOfRun( run );
  fPositionTime->BeginOfRun( run );
  fPowell->BeginOfRun( run );
  fNullSelector->BeginOfRun( run );
  fET1D->BeginOfRun( run );
  fEnergySeed->BeginOfRun( run );
  fPartialEnergy->BeginOfRun( run );
}

Processor::Result
PartialFitter::DSEvent( DS::Run& run,
                        DS::Entry& ds )
{
  // Fit each EV in the entry independently. Event() logs its own failures, and one EV failing
  // must not stop the others being fitted, so its per-EV result is deliberately ignored and
  // the entry as a whole always reports OK.
  size_t evIndex = 0;
  while( evIndex < ds.GetEVCount() )
    {
      Event( run, ds.GetEV( evIndex ) );
      ++evIndex;
    }
  return OK;
}

Processor::Result
PartialFitter::Event( DS::Run& run,
                      DS::EV& ev )
{

  /// PartialFitter Logic:
  ///  This is a simple overview of the logic that is expressed in the next 80 lines
  ///
  ///  1. Fit the position & time as the seedResult from  quad.
  ///      if nhits < fNhitCutoff then abort the fit and return no FitResult. \n
  ///  2. Fit the position & time as the partialResult using the positionTimeLikelihood, powell, null and
  ///      seedResult as seed.
  ///     If this fails then abort the fit and return no FitResult
  ///  3. Generate energy seed using energyRThetaFunctional and partialResult as seed (or seedResult if the
  ///      positionTimeLikelihood result is invalid)
  ///  4. Fit the energy using partialEnergy as the method, energySeed and partialResult (or seedResult)
  ///      as seeds.
  ///     If this fails then abort the fit and return no FitResult
  TStopwatch timer;
  timer.Start( true );

  const size_t currentPass = MetaInformation::Get()->GetCurrentPass();
  ev.AddFitterPass( currentPass ); // Ensure the EV knows fitters were run for this pass
  ev.AddClassifierPass( currentPass ); // Ensure the EV knows classifiers were run for this pass

  vector<FitterPMT> pmtData;
  for( size_t iPMTCal =0; iPMTCal < ev.GetCalPMTs().GetCount(); iPMTCal++ )
    pmtData.push_back( FitterPMT( ev.GetCalPMTs().GetPMT( iPMTCal ) ) );

  /// Initialize the seed
  fQuadSeed->SetEventData( pmtData, &ev, &run );
  // Run the seed method
  FitResult seedResult;
  try
    {
      if( ev.GetCalPMTs().GetCount() > fCutOff )
        seedResult = fQuadSeed->GetBestFit();
      else
        {
          warn << "PartialFitter::Event: insufficient points for quad to run, exiting" << newline;
          return Processor::FAIL;
        }
    }
  catch( Method::MethodFailError& error ) { warn << "PartialFitter::Event: Seed failed " << error.what() << ", continuing." << newline; }

  // Now initialise the position time method
  fPositionTime->SetEventData( pmtData, &ev, &run );
  dynamic_cast< SeededMethod* >( fPositionTime )->DefaultSeed();
  dynamic_cast< SeededMethod* >( fPositionTime )->SetSeed( seedResult );
  /// Run the position time method
  FitResult partialResult;
  try
    {
      FitResult positionResult = fPositionTime->GetBestFit();
      partialResult.SetVertex(0, positionResult.GetVertex(0));
      SetFOMs(partialResult, positionResult, "Position");
    }
  catch( Method::MethodFailError& error )
    {
      warn << "PartialFitter::Event: Main method failed " << error.what() << ", exiting" << newline;
      return Processor::FAIL;
    }

  // Now initialise the energy seed
  FitResult energySeed;
  fEnergySeed->SetEventData( pmtData, &ev, &run );
  dynamic_cast< SeededMethod* >( fEnergySeed )->DefaultSeed();
  if( partialResult.GetValid() )
    dynamic_cast< SeededMethod* >( fEnergySeed )->SetSeed( partialResult );
  else if( seedResult.GetValid() )
    dynamic_cast< SeededMethod* >( fEnergySeed )->SetSeed( seedResult );
  else
    warn << "PartialFitter::Event: No seed for the energy lookup method." << newline;
  // Run the energy functional method
  try
    {
      energySeed = fEnergySeed->GetBestFit();
    }
  catch( Method::MethodFailError& error )
    {
      warn << "PartialFitter::Event: Energy seed method failed " << error.what() << ", exiting" << newline;
      return Processor::FAIL;
    }

  // Now initialise the energy method
  fPartialEnergy->SetEventData( pmtData, &ev, &run );
  // Seed position
  dynamic_cast< SeededMethod* >( fPartialEnergy )->DefaultSeed();
  if( partialResult.GetValid() )
    dynamic_cast< SeededMethod* >( fPartialEnergy )->SetSeed( partialResult );
  else if( seedResult.GetValid() )
    dynamic_cast< SeededMethod* >( fPartialEnergy )->SetSeed( seedResult );
  else
    warn << "PartialFitter::Event: No seed for the energy lookup method." << newline;
  // Seed energy
  dynamic_cast< SeededMethod* >( fPartialEnergy )->SetSeed( energySeed );
  // Run the energy functional method
  try
    {
      FitVertex vertex = partialResult.GetVertex(0);
      FitResult energyResult = fPartialEnergy->GetBestFit();
      FitVertex energyVertex = energyResult.GetVertex(0);
      vertex.SetEnergy( energyVertex.GetEnergy(), energyVertex.ValidEnergy() );
      vertex.SetPositiveEnergyError( energyVertex.GetPositiveEnergyError(), energyVertex.ValidEnergy() );
      vertex.SetNegativeEnergyError( energyVertex.GetNegativeEnergyError(), energyVertex.ValidEnergy() );
      partialResult.SetVertex( 0, vertex );
      SetFOMs(partialResult, energyResult, "Energy");
    }
  catch( Method::MethodFailError& error )
    {
      warn << "PartialFitter::Event: Energy seed method failed " << error.what() << ", exiting" << newline;
      return Processor::FAIL;
    }

  timer.Stop();
  partialResult.SetExecutionTime( timer.RealTime() );

  return SetResult( partialResult, ev, currentPass);
}


Processor::Result
PartialFitter::SetResult( DS::FitResult& partialResult,
                          DS::EV& ev,
                          size_t currentPass)
{
  // Store the complete fit result for this pass under the "partialFitter" name.
  ev.SetFitResult( currentPass, "partialFitter", partialResult );

  // Register vertex 0 as the default fit vertex. No validity check is made here. The partial
  // fitter only fits position, time and energy, so the direction and its errors are filled
  // with DS::INVALID and flagged invalid.
  DS::FitVertex defaultVertex = partialResult.GetVertex( 0 );
  const TVector3 invalidDirection( DS::INVALID, DS::INVALID, DS::INVALID );
  defaultVertex.SetDirection( invalidDirection, false );
  defaultVertex.SetDirectionErrors( invalidDirection, false );
  ev.SetDefaultFitVertex( "partialFitter", defaultVertex );

  return Processor::OK;
}

/// End-of-run hook: forwards EndOfRun to every owned fit component, in the same order as
/// BeginOfRun.
/// @param run the run that is ending
void
PartialFitter::EndOfRun( DS::Run& run )
{
  // Call the end of run functions of every component
  fQuadSeed->EndOfRun( run );
  fPositionTime->EndOfRun( run );
  fPowell->EndOfRun( run );
  fNullSelector->EndOfRun( run );
  fET1D->EndOfRun( run );
  fEnergySeed->EndOfRun( run );
  fPartialEnergy->EndOfRun( run );
}


void
PartialFitter::SetFOMs( DS::FitResult& fitResult,
                        const DS::FitResult& partialResult,
                        const string& prefix )
{
  // Copy every figure of merit from the source result (partialResult) into fitResult,
  // storing each under the name prefix + original name (e.g. "Position" + name).
  const vector<string> names = partialResult.GetFOMNames();
  for( size_t iName = 0; iName < names.size(); ++iName )
    {
      const string& name = names[iName];
      fitResult.SetFOM( prefix + name, partialResult.GetFOM( name ) );
    }
}