#include "CascadePdf.hh"

#include <cmath>
#include <iostream>


/*!
 * Bisection search for the x at which an increasing function f reaches y.
 *
 * The bracket [xmin, xmax] is halved repeatedly: whenever f(mid) > y the
 * upper edge moves down to mid, otherwise (including f(mid) == y) the lower
 * edge moves up to mid. The search stops as soon as the bracket is narrower
 * than tol, and the midpoint of that final bracket is returned. If y is above
 * every value f takes on the bracket the result approaches xmax; if it is
 * below every value, it approaches xmin.
 *
 * \param f    callable double(double), increasing in x
 * \param y    target value
 * \param xmin lower edge of the search bracket
 * \param xmax upper edge of the search bracket
 * \param tol  stop once xmax - xmin < tol
 * \return     midpoint of the final bracket
 */
template<typename F>
double find_value( F& f , double y , double xmin=-1000, double xmax = 2000, double tol = 0.2 )
{
  for ( ;; )
  {
    const double mid = ( xmin + xmax ) / 2;
    if ( xmax - xmin < tol ) return mid;

    if ( f( mid ) > y ) xmax = mid;   // target lies in the lower half
    else                xmin = mid;   // target lies in the upper half
  }
}

/*!
 *  A PMT paired with the times of the hits recorded on it.
 *  hit_times stays empty for a PMT without hits.
 */
struct PmtSum
{
  Pmt            pmt;        ///< the photomultiplier (stored by value)
  vector<double> hit_times;  ///< times of the hits on this PMT
};

/*! 
 *  CascadeLikelihood class
 */
struct CascadeLikelihood
{
  CascadePdf pdf;

  double L_hit, L_unhit, L_time; /*! Likelihoods for different components */

  vector<PmtSum> hit_pmts, unhit_pmts, all_pmts;
  vector<Hit> hits;

  bool use_first_hit_pdf; /*! Use first hit probability for hit times */
  double v; /*! Verbosity */

  void set_nsamples( uint samples ){
    pdf.set_nsamples( samples );
  }

  void set_background( double rate )
  {
    pdf.set_background( rate );
  }

  CascadeLikelihood( string direct_pdf  = "/pbs/throng/km3net/src/Jpp/master/data/J14p.dat",
		     string scatter_pdf = "/pbs/throng/km3net/src/Jpp/master/data/J13p.dat",
		     bool first_hit_prob = true,
                     double sigma_blur   = 0, // sigma of the blurring of the pdf [ns]
                     uint elong_steps    = 10, // number of shower elongation sampling points
                     double R_bg = 0, // background rate [Hz]
                     double t_start = -100, // start of hit time residual window 
                     double t_end = 900,
                     double energy_fraction_sample = 0.0,
                     double verbosity = 0 ) : pdf( direct_pdf,
						   scatter_pdf,
						   sigma_blur, // blur
						   elong_steps, // elong steps
						   R_bg, //background rate
						   t_start,
						   t_end, // end of hit time residual window
						   energy_fraction_sample ), 
					      use_first_hit_pdf( first_hit_prob ),
					      v( verbosity ) 
  {}  

  /*! 
   *  Set event. Create arrays of PmtSum containing the pmts and hit times.
   *  Should be called before evaluation.
   */

  void set_event( vector<Pmt>& pmts, vector<Hit>& hits_ )
  {
    map<int, PmtSum > M;

    for ( auto& p : pmts ) 
    {
      p.flag = false;
      
      PmtSum p_temp;
      p_temp.pmt = p;

      M[p.id] = p_temp;
    }

    for ( auto& h : hits_ ) 
    {
      M[h.pmt_id].pmt.flag = true;
      M[h.pmt_id].hit_times.push_back( h.t );
    }


    hit_pmts.clear();
    unhit_pmts.clear();
    all_pmts.clear();

    hits = hits_;

    foreach_map( id, pmt, M ) {
      all_pmts.push_back(pmt);
      (pmt.pmt.flag? hit_pmts : unhit_pmts) . push_back( pmt );
    }

    if ( v > 0 ) 
    {
      cout << "CascadeLikelihood::set_event" << endl;
      cout << "hits: " << hits_.size() << endl;
      cout << "hit_pmts: " << hit_pmts.size() << endl;
      cout << "unhit_pmts: " << unhit_pmts.size() << endl;
      cout << "all_pmts: " << all_pmts.size() << endl;
    }
  
  }
  
  /*! 
   *  Evaluate the likelihood for a track/shower hypothesis
   */
  double eval(Trk& trk)
  {

    L_unhit = L_hit = L_time = 0;

    for( auto& pmt : unhit_pmts ) L_unhit -= pdf.Ntot( trk, pmt.pmt );
    for( auto& pmt : hit_pmts   ) L_hit   += log( 1 - exp ( - pdf.Ntot (trk, pmt.pmt )));
    for( auto& hit : hits       ) L_time  += log( 1e-12 + pdf.dP1_dt( trk, hit ) );

    return L_unhit + L_hit + L_time;
  }

  /*! 
   *  Evaluate the likelihood for a double track/shower hypothesis
   */
  double eval_double(Trk& trk1, Trk& trk2, bool amplitude = false )
  {
    L_unhit = L_hit = L_time = 0;

    if ( amplitude ) {

      for( auto& pmt : unhit_pmts ) L_unhit -= pdf.Ntot( trk1, pmt.pmt ) + pdf.Ntot( trk2, pmt.pmt );

      for( auto& pmt : hit_pmts   ) L_hit   += log( 1+1e-12 - exp ( - pdf.Ntot (trk1, pmt.pmt ) - pdf.Ntot (trk2, pmt.pmt ) ) );

    }

    for( auto& hit : hits       )
    { 
      if ( use_first_hit_pdf )	L_time  += log( 1e-12 + pdf.dP1_dt_double( Paramst(trk1, hit), Paramst(trk2, hit) ) );
      else L_time  += log( 1+1e-12 - exp( - pdf.dN_dt( trk1, hit ) - pdf.dN_dt( trk2, hit ) ) );
    }

    return L_unhit + L_hit + L_time;
  }



  struct PmtLikelihood : public Paramst 
  {
    PmtLikelihood( Paramst& p ) : Paramst(p) {}

    double L_hit, L_unhit, L_time, L;
  };




  void gen( vector<Hit>& r, const Trk& trk, const Pmt& pmt, bool first_hit_only = false, bool force_hit = false, bool approx = false )
  {
    static Hit hit;
    
    uint hit_count_cap = 100;

    Params p( trk, pmt );
    const double Ntot = pdf.Ntot(p);          if (Ntot==0) return;

    uint n = 1;
    if( !force_hit ){
      n = gRandom->Poisson( Ntot );           if (n==0)    return;
    }
    
    pmt.dress( hit ); // set pos, dir and ids
    
    double amp = 1;
    if( n > hit_count_cap ){ amp = n/hit_count_cap; n = hit_count_cap; }

    if ( first_hit_only ) 
    {
      auto f     = [&] ( double t ) -> double { return pdf.P1( Paramst( p,t )); };
      hit.t      = p.d / v_light + find_value( f, gRandom->Rndm() );
      hit.a = amp*n;
      r.push_back(hit);
    } else {
      auto f = [&] (double t) -> double { return pdf.N( Paramst( p,t ))/Ntot; };
      for (int i = 0 ; i<n; i++) {

	hit.t = p.d / v_light + find_value( f, gRandom->Rndm() );
	hit.id = r.size()+1;
	hit.a = amp;
	r.push_back(hit);
      }
    } 
  }
  

  vector<Hit> generate( const Trk& trk,  
                        const vector<Pmt>& pmts, 
                        bool first_hit_only = true )
  {
    vector<Hit> r;
    for( auto& pmt : pmts ) gen( r, trk, pmt, first_hit_only );
    return r;
  }


  vector<Hit> generate( const Trk& trk, 
                        const map<int, Pmt*> pmts , 
                        bool first_hit_only = true ) 
   {
    vector<Hit> r;
    for(auto& p : pmts ) gen ( r, trk, *p.second , first_hit_only );
    
    return r;
   } 

  vector<Hit> generate_single_pmt( const Trk& trk, 
				   const Pmt pmt , 
				   const uint n_sim = 10000,
				   bool first_hit_only = true,
				   bool approx = false) 
   {
    vector<Hit> r;
    while( r.size() != n_sim ){
      gen ( r, trk, pmt, first_hit_only, true, approx);
      if( r.size()%(n_sim/10) == 0){
	cout << "r.size(): " << r.size() << endl;
      }
    }
    
    return r;
   } 
};