#include <RAT/PositionTimeLikelihood.hh>
#include <RAT/DS/FitResult.hh>
#include <RAT/DS/FitVertex.hh>
#include <RAT/DS/Entry.hh>
#include <RAT/Optimiser.hh>
#include <RAT/PMTSelector.hh>
#include <RAT/PDF.hh>
#include <RAT/Log.hh>

using namespace RAT;
using namespace RAT::Methods;
using namespace RAT::DS;

#include <cmath>
#include <cstddef>
#include <sstream>
#include <string>

using std::vector;
using std::string;

void
PositionTimeLikelihood::DefaultSeed()
{
  FitVertex vertex;
  // Initialise the SeedResult at the centre and with arbitrary
  // 3000 mm errors in all directions (note units are mm and ns)
  const TVector3 seedPosition( 0.0, 0.0, 0.0 );
  const ROOT::Math::XYZVectorF seedPositionError( 3000.0, 3000.0, 3000.0 );
  // 230 ns is roughly the peak from ET1D and GV1D PDFs
  const double seedTime = 230.0;
  const double seedTimeError = 100.0;
  vertex.SetPosition( seedPosition, false, true );
  vertex.SetPositionErrors( seedPositionError, false, true );
  vertex.SetTime( seedTime, false, true );
  vertex.SetTimeErrors( seedTimeError, false, true );
  fSeedResult.SetVertex( 0, vertex );
}

// Integer-parameter hook. The only recognised name is "nhit_cut", which
// stores the minimum number of selected hits that GetBestFit requires.
// Any other parameter name throws Method::MethodFailError.
void PositionTimeLikelihood::SetI( const std::string& param, const int value ) {
  if ( param != "nhit_cut" )
    throw Method::MethodFailError( "Unknown positionTimeLikelihood parameter " + param );
  fNHitCut = value;
}

// Fit the event.
// fFitResult is reset and, when there is PMT data, initialised by
// CopySeedToResult() (presumably from the seed). Hits are then selected
// relative to that vertex, and the optimiser maximises operator() (the summed
// log-likelihood).
//
// Returns fFitResult, carrying:
//   - the "SelectedNHit" and "LogL" figures of merit;
//   - the optimised position/time and their positive and negative errors.
// If there is no PMT data at all, the freshly reset result is returned
// without fitting.
//
// Throws MethodFailError when:
//   - no hits survive selection; or
//   - fewer than fNHitCut hits are selected (the cut is skipped when
//     fNHitCut <= 0).
FitResult
PositionTimeLikelihood::GetBestFit()
{
  fFitResult.Reset();
  if( fPMTData.empty() )
    return fFitResult;
  CopySeedToResult();

  SelectPMTData( fFitResult.GetVertex( 0 ) );
  const std::size_t selectedNHit = fSelectedPMTData.size();
  fFitResult.SetFOM( "SelectedNHit", static_cast<double>( selectedNHit ) );
  if( selectedNHit == 0 )
    throw MethodFailError( "PositionTimeLikelihood: No hits to fit" );

  // Minimum selected-nhit requirement, configured via SetI("nhit_cut", ...).
  // Checking fNHitCut > 0 first keeps the comparison well defined. Comparing
  // the unsigned size() directly with a negative signed cut would convert the
  // cut to a huge unsigned value and reject every event.
  if( fNHitCut > 0 && selectedNHit < static_cast<std::size_t>( fNHitCut ) ) {
    std::ostringstream msg;
    msg << "PositionTimeLikelihood: Number of hits below threshold of " << fNHitCut;
    throw MethodFailError( msg.str() );
  }

  fOptimiser->SetComponent( this );
  // Call the designated optimiser and record the maximised log-likelihood.
  const double logL = fOptimiser->Maximise();
  fFitResult.SetFOM( "LogL", logL );
  // Copy the optimiser's best-fit parameters and errors into fFitResult.
  SetParams( fOptimiser->GetParams() );
  SetPositiveErrors( fOptimiser->GetPositiveErrors() );
  SetNegativeErrors( fOptimiser->GetNegativeErrors() );
  return fFitResult;
}

// Objective function called by the optimiser.
// The trial parameters [x, y, z, t] are first written into fFitResult's
// vertex via SetParams. The return value is the sum, over every selected PMT
// hit, of log(P), where P is fPDF's probability for that hit given the
// current vertex.
// NOTE(review): a zero probability makes the sum -inf. Confirm that fPDF
// never returns 0 or that the optimiser tolerates it.
double PositionTimeLikelihood::operator()( const std::vector<double>& params )
{
  SetParams( params );
  double sumLogProbability = 0.0;
  for( const FitterPMT& hit : fSelectedPMTData )
    sumLogProbability += log( fPDF->GetProbability( hit, fFitResult.GetVertex( 0 ) ) );
  return sumLogProbability;
}

// Current vertex of fFitResult in the optimiser's parameter layout:
// [x, y, z, t] (mm, ns). The ordering must match SetParams.
vector<double>
PositionTimeLikelihood::GetParams() const
{
  const FitVertex current = fFitResult.GetVertex( 0 );
  vector<double> packed( 4 );
  packed[0] = current.GetPosition().x();
  packed[1] = current.GetPosition().y();
  packed[2] = current.GetPosition().z();
  packed[3] = current.GetTime();
  return packed;
}

// Positive (upper) errors of fFitResult's vertex in the parameter layout
// [x, y, z, t]. The ordering must match SetPositiveErrors.
vector<double>
PositionTimeLikelihood::GetPositiveErrors() const
{
  const FitVertex current = fFitResult.GetVertex( 0 );
  vector<double> upper( 4 );
  upper[0] = current.GetPositivePositionError().x();
  upper[1] = current.GetPositivePositionError().y();
  upper[2] = current.GetPositivePositionError().z();
  upper[3] = current.GetPositiveTimeError();
  return upper;
}

// Negative (lower) errors of fFitResult's vertex in the parameter layout
// [x, y, z, t]. The ordering must match SetNegativeErrors.
vector<double>
PositionTimeLikelihood::GetNegativeErrors() const
{
  const FitVertex current = fFitResult.GetVertex( 0 );
  vector<double> lower( 4 );
  lower[0] = current.GetNegativePositionError().x();
  lower[1] = current.GetNegativePositionError().y();
  lower[2] = current.GetNegativePositionError().z();
  lower[3] = current.GetNegativeTimeError();
  return lower;
}

void
PositionTimeLikelihood::SetParams( const std::vector<double>& params )
{
  FitVertex vertex = fFitResult.GetVertex(0);
  vertex.SetPosition( TVector3( params[0], params[1], params[2] ), fOptimiser->GetValid() );
  vertex.SetTime( params[3], fOptimiser->GetValid() );
  fFitResult.SetVertex( 0, vertex );
}

void
PositionTimeLikelihood::SetPositiveErrors( const std::vector<double>& errors )
{
  FitVertex vertex = fFitResult.GetVertex(0);
  vertex.SetPositivePositionError( ROOT::Math::XYZVectorF(errors[0], errors[1], errors[2]), fOptimiser->GetValid() );
  vertex.SetPositiveTimeError( errors[3], fOptimiser->GetValid() );
  fFitResult.SetVertex( 0, vertex );
}

void
PositionTimeLikelihood::SetNegativeErrors( const std::vector<double>& errors )
{
  FitVertex vertex = fFitResult.GetVertex(0);
  vertex.SetNegativePositionError( ROOT::Math::XYZVectorF(errors[0], errors[1], errors[2]), fOptimiser->GetValid() );
  vertex.SetNegativeTimeError( errors[3], fOptimiser->GetValid() );
  fFitResult.SetVertex( 0, vertex );
}