/// Version number of this program, printed at start-up.
static const int MPG_VERSION = 2 ;

/*******************************************************************************
 * \file
 * Markov Chain Monte Carlo (MCMC) generation of photon paths.
 *
 * The generated paths are saved to a file. Optionally, the scattering model
 * used to generate them is also saved to a ROOT file.
*******************************************************************************/
#include<iostream>
#include<sstream>
#include<iomanip>
#include<vector>
#include<cmath>
#include<cstdlib>
#include<string>

#include "TRandom3.h"
#include "TH1F.h"
#include "TH2F.h"
#include "TFile.h"
#include "TTree.h"
#include "TBranch.h"
#include "TGraph.h"

// JPP includes
#include "Jeep/JParser.hh"
#include "JMarkov/JMarkovPathGenerator.hh"
#include "JMarkov/JPhotonPathWriter.hh"
#include "JPhysics/KM3NeT.hh"
//#include "JMarkov/JScatteringModel.hh"

// Remove the make_field macro defined in JParser.hh so that the
// JPARSER::make_field(T,string) function can be called directly,
// with an explicit help text as its second argument.
#undef make_field
using namespace JPP;

using namespace JMARKOV ;
using namespace std ;

//ClassImp(ScatteringModel)


int main( int argc, char** argv ) {
  cout << "JMarkovPathGenerator version " << MPG_VERSION << endl 
       << "Written by Martijn Jongen" << endl 
       << endl ; 
  cout << "Type '" << argv[0] << " -h!' to display the command-line options." << endl ;
  cout << endl ;

  string outfile = "out.paths"  ; // output file
  string smoutfile = ""        ; // output file to save scattering model to
  double lHG = 60.241          ; // scattering length for Henyey-Greenstein model
  double g   = 0.924           ; // parameter g in Heyney-Greenstein model
  double lR  = 294.118         ; // scattering length for Rayleigh scattering
  double a   = 0.853           ; // a parameter for Rayleigh scatteing
  double lA  = 50              ; // absorption length
  double d = 37                ; // distance to target
  int npaths = 100             ; // number of paths to simulate
  int nscat = 2                ; // number of scatterings
  int interval = 1000          ; // number of MCMC steps between saved paths
  int burnIn = 50000           ; // number of MCMC steps for initialization
  double stepsize = 10         ; // size of the MCMC steps
  bool sourceNB                ;
  double target_zenith         ;

  try {
    JParser<string> zap ; // this argument parser can handle strings    
    zap["o"] = make_field(outfile,"output file name") ;
    zap["O"] = make_field(smoutfile,"OPTIONAL: output file name for scattering model") ;
    zap["-lH"] = make_field(lHG,"scattering length in m for Henyey-Greenstein scattering") ;
    zap["g"] = make_field(g,"parameter g for Henyey-Greenstein function") ;
    zap["R"] = make_field(lR,"scattering length in m for Rayleigh scattering") ;
    zap["a"] = make_field(a,"parameter a for Rayleigh scattering") ;
    zap["A"] = make_field(lA,"absorption length in m") ;
    zap["d"] = make_field(d,"distance between source and target in m") ;
    zap["n"] = make_field(npaths,"number of paths to generate") ;
    zap["N"] = make_field(nscat,"number of scatterings") ;
    zap["i"] = make_field(interval,"number of MCMC steps to take between saving paths") ;
    zap["b"] = make_field(burnIn,"number of burn-in steps") ;
    zap["s"] = make_field(stepsize,"step size for the MCMC steps") ;
    zap["sourceNB"] = make_field(sourceNB,"Use a nanobeacon profile as source (currently a uniform distribution in a 45 degree cone around the positive z-direction") ;
    zap["target_zenith"] = make_field(target_zenith,"[degrees] OPTIONAL: set to  use a realistic PMT acceptance") = -1 ;

    if (zap.read(argc, argv) != 0) {
      return 1 ;
    }
  }
  catch(const exception &error) {
    // ignore exceptions
  }

  // inform user of the settings
  cout << "output file name   = '" << outfile << "'." << endl ;
  cout << "absorption length  = " << lA << " m" << endl ;
  cout << "distance to target = " << d  << " m" << endl ;
  cout << "npaths             = " << npaths << endl ;
  cout << "nscat              = " << nscat << endl ;
  cout << "interval           = " << interval << " steps" << endl ;
  cout << "burn-in            = " << burnIn << " steps" << endl ;
  cout << "step size          = " << stepsize << " m" << endl ;
  cout << endl ;

  // Henyey-Greenstein scattering model
  JScatteringModel smHG ;
  cout << "Henyey-Greenstein scattering length = " << lHG << " m" << endl ;
  cout << "g = " << g << endl ;
  smHG.setScatteringLength(lHG) ;
  smHG.setScatteringProfileHG(g) ;
  cout << endl ;

  // Rayleigh scattering model
  JScatteringModel smR ;
  cout << "Rayleigh scattering length = " << lR << " m" << endl ;
  cout << "a = " << a << endl ;
  smR.setScatteringLength(lR) ;
  smR.setScatteringProfileRayleigh(a) ;
  cout << endl ;

  // combined scattering model
  cout << "Combining HG and Rayleigh scattering into a single effective model." << endl ;
  JScatteringModel sm ;
  setEffectiveScatteringModel( smHG, smR, sm ) ;
  cout << "Effective scattering length = " << sm.getScatteringLength() << " m" << endl ;
  sm.setAbsorptionLength(lA) ;

  cout << "Integral over scattering profile = " << sm.hscat->Integral("width")*2*M_PI << " (should be 1)" << endl ;
  cout << endl ;

  // set source distribution
  if( sourceNB ) {
    cout << "Setting source distribution to a preliminary approximation of a nanobeacon profile. Light is emitted uniformly within a cone around the positive z-axis." << endl ;
    // we oversample each bin
    const int n = 100 ;
    for( Int_t xbin=1; xbin<=sm.hsource->GetNbinsX() ; ++xbin ) {
      double val = 0 ;
      double xmin = sm.hsource->GetXaxis()->GetBinLowEdge(bin) ;
      double xmax = sm.hsource->GetXaxis()->GetBinUpEdge(bin) ;
      for( int i=0 ; i<n ; ++i ) {
	double ct = xmin + (xmax-xmin)*(i+0.5)/n ;
	if( ct >= sqrt(0.5) ) val += 1 ;
      }
      val /= n ;
      for( Int_t ybin=1 ; ybin<=sm.hsource->GetNbinsY() ; ++ybin ) {
	Int_t bin = sm.hsource->GetBin(xbin,ybin) ;
	sm.hsource->SetBinContent(bin,val) ;
      }
    }
    sm.hsource->Scale( sqrt(2)/(2*M_PI*(sqrt(2)-1)) ) ;
    cout << endl ;
  }

  cout << "Integral over source profile = " << sm.hsource->Integral("width") << " (should be 1)" << endl ;
  cout << endl ;

  // set PMT efficiency
  if( target_zenith >= 0 ) {
    cout << "Setting target to a KM3NeT PMT." << endl 
	 << "Its orientation is rotated " << target_zenith << " degrees w.r.t. the negative z-axis" << endl
	 << "(the rotation is in the yz-plane)" << endl ;
    target_zenith *= M_PI/180 ; // convert to radians
    JGEOMETRY3D::JVersor3D pmtdir( sin(target_zenith), 0, cos(target_zenith) ) ;
    for( Int_t xbin=1; xbin<=sm.htarget->GetXaxis()->GetNbins() ; ++xbin ) {
      double ct = sm.htarget->GetXaxis()->GetBinCenter(xbin) ;
      double theta = acos(ct) ;
      for( Int_t ybin=1; ybin<=sm.htarget->GetYaxis()->GetNbins() ; ++ybin ) {
	Int_t bin = sm.htarget->GetBin(xbin,ybin) ;
	double phi = sm.htarget->GetYaxis()->GetBinCenter(ybin) ;
	
	JGEOMETRY3D::JVersor3D testdir( cos(phi)*sin(theta), 
					sin(phi)*sin(theta), 
					cos(theta) ) ;
	double effct = -testdir.getDot(pmtdir) ;
	double val = KM3NET::getAngularAcceptance(effct) ;
	// safety catch: I never want my probability to be exactly 0
	// so I just set it to a really low number!
	if( val == 0 ) val = 1e-10 ;
	sm.htarget->SetBinContent(bin,val) ;
      }
    }
    double maxval = sm.htarget->GetBinContent( sm.htarget->GetMaximumBin() ) ;
    sm.htarget->Scale(1.0/maxval) ;
    cout << endl ;
  }

  // create output root file containing the scattering model
  if( smoutfile != "" ) {
    cout << "Saving scattering model to '" << smoutfile << "'." << endl ;
    TFile* fout = new TFile(smoutfile.c_str(),"recreate") ;
    sm.Write() ;
    // we also save the source, scattering and target histograms to a separate
    // folder, so that we can view them immediately from a TBrowser
    fout->mkdir("ScatteringModel_ingredients")->cd() ;
    sm.hsource->Write() ;
    sm.hscat->Write() ;
    sm.htarget->Write() ;
    fout->cd() ;
    fout->Close() ;
    cout << endl ;
  }

  // generate an ensemble of paths
  cout << "Generating ensemble" << endl ;
  JMarkovPathGenerator mpg ;
  vector<JPhotonPath> ensemble = mpg.generateEnsemble( npaths, nscat, sm, burnIn, interval, d, stepsize ) ;
  cout << "Done generating ensemble." << endl ;
  double acceptance = mpg.getFractionAccepted() ;
  cout << 100*acceptance << "% of steps were accepted" << endl
       << "(as a rule of thumb, ~23% is optimal for high-dimensional spaces)" << endl 
       << endl ;

  // write the generated photon paths to a file 
  cout << "Writing generated ensemble to '" << outfile << "'." << endl ;
  JMARKOV::JPhotonPathWriter writer ;
  writer.open(outfile.c_str()) ;
  for( vector<JPhotonPath>::iterator it=ensemble.begin() ; it!=ensemble.end() ; ++it ) {
    writer.put( *it ) ;
  }
  writer.close() ;
  cout << endl ;

  cout << "Done!" << endl ;

  return 0 ;
}