/*******************************************************************************
 * \file
 * Calculate the total hit probability for paths with a given number of scatterings.

  The calculation is based on an input file containing (what is assumed to be) a
  representative sample of JPhotonPaths.
  The JScatteringModel used to generate the paths must also be supplied.

  The total probability density integral is approximated by importance sampling:
  paths are drawn from per-vertex position distributions that are built from
  the input path ensemble.
*******************************************************************************/
// C++ standard library
#include<algorithm>
#include<cmath>
#include<cstdio>
#include<cstdlib>
#include<iomanip>
#include<iostream>
#include<limits>
#include<sstream>
#include<vector>

// ROOT
#include "TRandom3.h"
#include "TFile.h"
#include "TH1F.h"
#include "TPolyLine3D.h"
#include "TPolyMarker3D.h"
#include "TAxis3D.h"
#include "TView.h"
#include "TView3D.h"
#include "TCanvas.h"
#include "TPad.h"

// JPP
#include "Jeep/JParser.hh"
#include "JMarkov/JPhotonPath.hh"
#include "JMarkov/JPhotonPathReader.hh"
#include "JMarkov/JPhotonPathWriter.hh"
#include "JMarkov/JScatteringModel.hh"

// namespaces
using namespace std ;

// Undefine the make_field macro defined in JParser.hh, so that the
// JPARSER::make_field(T, string) function can be called directly
// (with a help text as second argument).
#undef make_field
using namespace JPP;


int main( int argc, char** argv ) {
  cout << "JMarkovPathIntegrator" << endl 
       << "Written by Martijn Jongen" << endl 
       << endl ; 

  gRandom->SetSeed(0) ;

  string ifname_paths ;
  string ifname_model ;
  string ofname ;
  int nsamples ;
  int nscat    ;
  int nbins    ;

  try {
    JParser<string> zap ; // this argument parser can handle strings    
    zap["f"] = make_field(ifname_paths,"input file name (binary file containing JPhotonPaths)") ;
    zap["m"] = make_field(ifname_model,"input file name (root file containing JScatteringModel)") ;
    zap["o"] = make_field(ofname,"OPTIONAL file name for root output") = "" ;
    zap["n"] = make_field(nscat,"number of scatterings") ;
    zap["N"] = make_field(nsamples,"number of samples") = 10000 ;
    zap["nbins"] = make_field(nbins,"nbins of vertex distribution histograms") = 1000 ;

    if (zap.read(argc, argv) != 0) {
      return 1 ;
    }
  }
  catch(const exception &error) {
    cout << error.what() << endl ;
    cout << "Type '" << argv[0] << " -h!' to display the command-line options." << endl ;
    exit(1) ;
  }

  cout << "SETTINGS: " << endl ;
  cout << "  Input file (paths) = '" << ifname_paths << "'" << endl ;
  cout << "  Input file (model) = '" << ifname_model << "'" << endl ;
  cout << "  Output file        = '" << ofname << "'" << endl ;
  cout << "  nsamples           = " << nsamples << endl ;
  cout << "  nscat              = " << nscat << endl ;
  cout << "  nbins              = " << nbins << endl ;
  cout << endl ;

  int nvert = nscat + 2 ; // the number of vertices in each path

  // read in the scatteringmodel
  cout << "Loading scattering model" << endl ;
  JMARKOV::JScatteringModel* sm ;
  // open the root file
  TFile* f = new TFile(ifname_model.c_str(),"read") ;
  if( f->IsZombie() ) {
    cerr << "Could not open file '" << ifname_model << "'." << endl ;
    exit(1) ;
  }
  sm = (JMARKOV::JScatteringModel*) f->Get("JMARKOV::JScatteringModel") ;
  if( sm == NULL ) {
    cerr << "Could not read JScatteringModel from file!" << endl ;
    exit(1) ;
  }
  cout << "JScatteringModel loaded" << endl
       << "  absorption length = " << sm->getAbsorptionLength() << " m" << endl 
       << "  scattering length = " << sm->getScatteringLength() << " m" << endl
       << endl ;


  // read the paths from the file
  JMARKOV::JPhotonPathReader reader ;
  reader.open(ifname_paths.c_str()) ;
  if( !reader.is_open() ) {
    cerr << "FATAL ERROR: unable to open input file '" << ifname_paths << "'." << endl ;
    exit(1) ;
  }
  int nread = 0 ;
  JMARKOV::JPhotonPath* p = NULL ;
  cout << "Reading file" << endl ;
  vector<JMARKOV::JPhotonPath> paths ;
  while( reader.hasNext() ) {
    p = reader.next() ;
    if( p->size() == (unsigned int)nvert )
      paths.push_back(*p) ;
    else 
      cout << p->size() << endl ;
    ++nread ;
  }
  reader.close() ;

  int npaths = paths.size() ;
  if( npaths == 0 ) {
    cerr << "FATAL ERROR: No JPhotonPaths with nscat=" << nscat << " in '" << ifname_paths << "'." << endl ;
    exit(1) ;
  }
  cout << "Done reading file. Selected " << npaths << " / " << nread << " JPhotonPaths." << endl ;
  cout << endl ;

  // find the minimal and maximal values of each coordinate of each vertex
  JMARKOV::JPhotonPath minvals(nscat) ;
  JMARKOV::JPhotonPath maxvals(nscat) ;
  double INF = 1.0/0.0 ;
  double MININF = -1.0/0.0 ;
  // initialize to +- infinity
  for( int i=0 ; i<nvert ; ++i ) {
    minvals[i] = JGEOMETRY3D::JPosition3D(    INF,    INF,    INF ) ;
    maxvals[i] = JGEOMETRY3D::JPosition3D( MININF, MININF, MININF ) ;
  }

  // loop over the paths once, to get the minimal and maximal values of x, y and z for each vertex
  for( vector<JMARKOV::JPhotonPath>::iterator p=paths.begin() ; p!=paths.end() ; ++p ) {
    // loop over the vertices
    for( int i=0; i<nvert; ++i ) {
      minvals[i] = JGEOMETRY3D::JPosition3D( min(minvals[i].getX(),p->at(i).getX()),
				             min(minvals[i].getY(),p->at(i).getY()), 
					     min(minvals[i].getZ(),p->at(i).getZ())  ) ;  
      maxvals[i] = JGEOMETRY3D::JPosition3D( max(maxvals[i].getX(),p->at(i).getX()),
				             max(maxvals[i].getY(),p->at(i).getY()), 
					     max(maxvals[i].getZ(),p->at(i).getZ())  ) ;  
    }
  }
  cout << endl ;

  cout << "Allocating histograms." << endl ;
  // allocate histograms for the vertex positions
  vector<TH1F*> hX(nvert) ; // x positions of vertices
  vector<TH1F*> hY(nvert) ; // y positions of vertices
  vector<TH1F*> hZ(nvert) ; // z positions of vertices
  for( int n=0; n<nvert; ++n ) {
    char hname[200] ;
    char htitle[200] ;

    // we multiply the min/max values by this factor to ensure that no 
    // entries end up in the overflow bin
    double fac = 1.001 ; 

    sprintf(hname, "hX_nscat_%i_vertex_%i", nscat, n ) ;
    sprintf(htitle, "X of vertex %i (nscat = %i)", n, nscat ) ;
    hX[n] = new TH1F(hname,htitle,nbins,fac*minvals[n].getX(),fac*maxvals[n].getX()) ;

    sprintf(hname, "hY_nscat_%i_vertex_%i", nscat, n ) ;
    sprintf(htitle, "Y of vertex %i (nscat = %i)", n, nscat ) ;
    hY[n] = new TH1F(hname,htitle,nbins,fac*minvals[n].getY(),fac*maxvals[n].getY()) ;

    sprintf(hname, "hZ_nscat_%i_vertex_%i", nscat, n ) ;
    sprintf(htitle, "Z of vertex %i (nscat = %i)", n, nscat ) ;
    hZ[n] = new TH1F(hname,htitle,nbins,fac*minvals[n].getZ(),fac*maxvals[n].getZ()) ;
  }
  cout << endl ;

  // fill the histograms
  cout << "Filling histograms" << endl ;
  for( vector<JMARKOV::JPhotonPath>::iterator p=paths.begin() ; p!=paths.end() ; ++p ) {
    // loop over the vertices
    for( int n=0; n<nvert; ++n ) {
      hX[n]->Fill( p->at(n).getX() ) ;
      hY[n]->Fill( p->at(n).getY() ) ;
      hZ[n]->Fill( p->at(n).getZ() ) ;
    }
  }
  cout << endl ;

  // normalize the histograms to 1
  cout << "Normalizing histograms." << endl ;
  for( int n=0 ; n<nvert ; ++n ) {
    hX[n]->Scale( 1.0 / hX[n]->Integral() ) ;
    hY[n]->Scale( 1.0 / hY[n]->Integral() ) ;
    hZ[n]->Scale( 1.0 / hZ[n]->Integral() ) ;
  }
  cout << endl ;

  // #tmp temporarily save the paths with extremely high contributions
  // to a file
  JMARKOV::JPhotonPathWriter writer ;
  writer.open("high_contribution_paths.paths") ;

  // MC-integrate the path probability, using the vertex distributions as
  // our sampling function
  cout << "Computing scattering probability" << endl ;
  double Pscat = 0 ;
  double Pscat_err ;
  vector<double> integralConts(nsamples) ;
  vector<double> rhos(nsamples) ;
  vector<double> ws(nsamples) ;
  TH1F* hRhos ;
  TH1F* hWs ;
  TH1F* hIntegralConts ;
  TH1F* hIntegralConts_zoom ;

  // define test path
  JMARKOV::JPhotonPath testpath(nscat) ;
  // starting vertex
  testpath[0] = JGEOMETRY3D::JPosition3D( hX[0]->GetMean(1),
					  hY[0]->GetMean(1),
					  hZ[0]->GetMean(1) ) ;
  // ending vertex
  testpath[nvert-1] = JGEOMETRY3D::JPosition3D( hX[nvert-1]->GetMean(1),
						hY[nvert-1]->GetMean(1),
						hZ[nvert-1]->GetMean(1) ) ;

  // integrate over the free vertices
  for( int i=0 ; i<nsamples; ++i ) {
    double w = 1 ; // contribution to the integral
    // loop over all vertices except the first and last one
    for( int n=1 ; n<nvert-1 ; ++n ) {
      double x = hX[n]->GetRandom() ;
      double y = hY[n]->GetRandom() ;
      double z = hZ[n]->GetRandom() ;
      Int_t xbin = hX[n]->GetXaxis()->FindBin(x) ;
      Int_t ybin = hY[n]->GetXaxis()->FindBin(y) ;
      Int_t zbin = hZ[n]->GetXaxis()->FindBin(z) ;
      double wx = hX[n]->GetBinContent(xbin) / hX[n]->GetBinWidth(xbin) ;
      double wy = hY[n]->GetBinContent(ybin) / hY[n]->GetBinWidth(ybin) ;
      double wz = hZ[n]->GetBinContent(zbin) / hZ[n]->GetBinWidth(zbin) ;
      w *= wx * wy * wz ;
      testpath[n] = JGEOMETRY3D::JPosition3D(x,y,z) ;
    }
    double rho = sm->getRho(testpath)  ; 
    rhos[i] = rho ;
    ws[i] = w ;
    integralConts[i] = rho/w ;
    Pscat += rho / w ;

    if( rho/w > 0.001 ) {
      /*cout << "Found high rho/w = " << rho/w << endl ;
      cout << "rho = " << rho << ", w = " << w << endl ;
      cout << "Rerunning calculation of rho in verbose mode" << endl ;
      sm->getRho(testpath,true) ;
      cout << "Path: " << endl ;
      for( JMARKOV::JPhotonPath::iterator it=testpath.begin() ; it!=testpath.end() ; ++it ) {
	cout << "( " << it->getX() << ", " << it->getY() << ", " << it->getZ() << " ) " ;
      }
      cout << endl ;
      cout << endl ;*/
      writer.put( testpath ) ;
    }
  }
  Pscat /= nsamples ;
  writer.close() ;

  // statistical characteristics of the integral contributions
  double sigma = 0  ; // standard deviation
  {
    double max ;
    char hname[200] ;
    char htitle[300] ;

    // histogram of path probability densities (whole range)
    max = *(max_element( rhos.begin(), rhos.end() )) ;
    sprintf( hname, "hRhos_nscat%i", nscat ) ;
    sprintf( htitle, "Path probability density for nscat = %i", nscat ) ;
    hRhos = new TH1F(hname,htitle,100,0,1.01*max) ;

    // histogram of path probability densities (whole range)
    max = *(max_element( ws.begin(), ws.end() )) ;
    sprintf( hname, "hWs_nscat%i", nscat ) ;
    sprintf( htitle, "Path weights for nscat = %i", nscat ) ;
    hWs = new TH1F(hname,htitle,100,0,1.01*max) ;

    // histogram of integral contributions (whole range)
    max = *(max_element( integralConts.begin(), integralConts.end() )) ;
    sprintf( hname, "hIntegralConts_nscat%i", nscat ) ;
    sprintf( htitle, "Contributions to the integral for nscat = %i", nscat ) ;
    hIntegralConts = new TH1F(hname,htitle,100,0,1.01*max) ;

    // histogram of integral contributions (up to 10*mean)
    max = 10*Pscat ;
    sprintf( hname, "hIntegralConts_nscat%i_zoom", nscat ) ;
    sprintf( htitle, "Contributions to the integral for nscat = %i (zoom)", nscat ) ;
    hIntegralConts_zoom = new TH1F(hname,htitle,100,0,1.01*max) ;

    for( int i=0; i<nsamples; ++i ) {
      hRhos->Fill( rhos[i] ) ;
      hWs->Fill( ws[i] ) ;
      hIntegralConts->Fill( integralConts[i] ) ;
      hIntegralConts_zoom->Fill( integralConts[i] ) ;
    }

    // get sample variance
    double var = 0 ;
    for( int i=0; i<nsamples; ++i ) {
      var += (integralConts[i]-Pscat) * (integralConts[i]-Pscat) ;
    }
    var /= nsamples ;
    // this is now the estimated width of the distribution
    sigma = sqrt(var) ;
    cout << "sigma = " << sigma << " (standard deviation of integral contributions)" << endl ;
  }
  Pscat_err = sigma / sqrt(nsamples) ;
  cout << "Pscat(" << nscat << " scatterings) = " << Pscat << endl ;
  cout << "Error estimate 1: sigma/sqrt(N) = " 
       << sigma << " / sqrt(" << nsamples << ") = " 
       << Pscat_err << endl ;
  // write some nice grep-able output
  cout << "MPI_PSCAT " << Pscat << endl 
       << "MPI_ERROR " << Pscat_err << endl
       << endl ;

  // optionally write the histograms to a root file
  if( ofname != "" ) {
    TFile* fout = new TFile(ofname.c_str(),"recreate") ;
    char dname[200] ;
    sprintf( dname, "nscat%i", nscat ) ;
    gDirectory->mkdir(dname)->cd() ;
    
    hRhos->Write() ;
    hWs->Write() ;
    hIntegralConts->Write() ;
    hIntegralConts_zoom->Write() ;

    for( int n=0; n<nvert; ++n ) {
      hX[n]->Write() ;
      hY[n]->Write() ;
      hZ[n]->Write() ;
    }
    fout->cd() ;
    fout->Close() ;
    delete fout ;
    cout << "Output written to '" << ofname << "'." << endl ;
    cout << endl ;
  }

  cout << "Done!" << endl ;
  return 0 ;
}