/* This file is part of MAUS: http://micewww.pp.rl.ac.uk:8080/projects/maus
 *
 * MAUS is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * MAUS is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with MAUS.  If not, see <http://www.gnu.org/licenses/>.
 *
 */

#include "src/map/MapCppTrackerTOFCombinedFit/MapCppTrackerTOFCombinedFit.hh"

#include <sstream>

#include "src/common_cpp/API/PyWrapMapBase.hh"
#include "src/common_cpp/Recon/Kalman/MAUSTrackWrapper.hh"
#include "src/common_cpp/Recon/Global/MaterialModelAxialLookup.hh"
#include "TMatrixD.h"

#include "Geant4/G4Material.hh"

namespace MAUS {

PyMODINIT_FUNC init_MapCppTrackerTOFCombinedFit(void) {
  // Python module entry point: registers MapCppTrackerTOFCombinedFit through the
  // generic map wrapper. The four trailing arguments are passed empty
  // (presumably the class/_birth/_process/_death docstrings - TODO confirm
  // against PyWrapMapBase).
  typedef PyWrapMapBase<MapCppTrackerTOFCombinedFit> Wrapper;
  Wrapper::PyWrapMapBaseModInit("MapCppTrackerTOFCombinedFit", "", "", "", "");
}

MapCppTrackerTOFCombinedFit::MapCppTrackerTOFCombinedFit()
  : MapBase<Data>("MapCppTrackerTOFCombinedFit"),
    _radius_cut(5.0), _tof_resolution(.1),
    _has_tof(false),
    _helical_track_fitter(NULL) {
  // The detector z positions are only known once _birth() has read the geometry.
  // Give them a defined value until then rather than leaving them indeterminate.
  // (Assigned here rather than in the initializer list to stay independent of
  // the member declaration order in the header.)
  _tof0_z = 0.;
  _tof1_z = 0.;
  _tku_z = 0.;
  _tkd_z = 0.;
}

MapCppTrackerTOFCombinedFit::~MapCppTrackerTOFCombinedFit() {
  // Release the Kalman fitter in case _death() was never called. _death() resets
  // the pointer to NULL after deleting it, and deleting NULL is a no-op, so this
  // never double-frees after a normal _birth()/_death() cycle.
  delete _helical_track_fitter;
}

void MapCppTrackerTOFCombinedFit::_birth(const std::string& argJsonConfigDocument) {

  // Pull out the global settings
  if ( !Globals::HasInstance() )
      GlobalsManager::InitialiseGlobals(argJsonConfigDocument);

  Json::Value* json = Globals::GetConfigurationCards();
  _radius_cut          	= (*json)["SciFiHelicalRadiusLimit"].asDouble();
  _tof_resolution	= .1; // [ns], could to smarter? Currently conservative (TODO)

  // Check that the TOFs are in the geometry, extract position of TOF1
  std::vector<const MiceModule*> tof_modules = Globals::GetReconstructionMiceModules()->
                                           findModulesByPropertyString("SensitiveDetector", "TOF");

  _has_tof = tof_modules.size() > 0;
  if ( !_has_tof )
      Squeak::mout(Squeak::warning) << "Tracker-TOF Combined Fit Mapper is initialised but "
                                    << "there is no TOF in the reconstruction geometry";

  bool found0(false), found1(false);
  for (const MiceModule* parent : tof_modules) {
    if ( parent->propertyInt("Station") == 0 ) {
      _tof0_z = parent->globalPosition().z();
      found0 = true;
    } else if ( parent->propertyInt("Station") == 1 ) {
      _tof1_z = parent->globalPosition().z();
      found1 = true;
    }
  }
  if ( !found0 || !found1 )
      throw Exceptions::Exception(Exceptions::nonRecoverable,
	"MapCppTrackerTOFCombinedFit::_birth()", "Could not locate the TOF1 Detetector");

  // Set the initial TOF1 position as the front of the detector
  double lower, upper;
  MaterialModelAxialLookup::GetBounds(_tof0_z, lower, upper);
  _tof0_z = lower;
  MaterialModelAxialLookup::GetBounds(_tof1_z, lower, upper);
  _tof1_z = lower;

  // Set up final track fit (Kalman filter)
  HelicalPropagator* helical_prop = new HelicalPropagator(Globals::GetSciFiGeometryHelper());
  helical_prop->SetCorrectPz(true);
  helical_prop->SetIncludeMCS(true);
  helical_prop->SetSubtractELoss(true);
  _helical_track_fitter = new Kalman::TrackFit(helical_prop);

  // Each measurement plane has a unique alignment and rotation,
  // they all need their own measurement object.
  SciFiTrackerMap& geo_map = Globals::GetSciFiGeometryHelper()->GeometryMap();
  for (std::pair<int, SciFiTrackerGeometry> tracker_geo : geo_map) {
    int tracker_const = (tracker_geo.first == 0 ? -1 : 1);
    for (std::pair<int, SciFiPlaneGeometry> plane_geo : tracker_geo.second.Planes) {

      int id = plane_geo.first * tracker_const;
      _helical_track_fitter->AddMeasurement(id,
                                             new MAUS::SciFiHelicalMeasurements(plane_geo.second));

      if ( id == -1 )
          _tku_z = plane_geo.second.GlobalPosition.Z();
      if ( id == 1 )
          _tkd_z = plane_geo.second.GlobalPosition.Z();
    }
  }
}

void MapCppTrackerTOFCombinedFit::_death() {
  // Release the Kalman fitter created in _birth(). Deleting a NULL pointer is a
  // no-op, so this is safe when _birth() never created one; the pointer is then
  // reset so that a repeated call cannot delete it twice.
  delete _helical_track_fitter;
  _helical_track_fitter = NULL;
}

void MapCppTrackerTOFCombinedFit::_process(Data* data) const {

  // If there is no TOF module, skip
  if ( !_has_tof )
      return;

  // Get the spill, check that it contains physics
  Spill* spill = data->GetSpill();
  if ( !spill || spill->GetDaqEventType() != "physics_event" )
      return;

  if ( !spill->GetReconEvents() ) {
    std::cout << "No recon events found\n";
    return;
  }

  // Loop over the reconstructed events
  for (ReconEvent* recon_event : *spill->GetReconEvents()) {
    if ( !recon_event )
        continue;

    // Check that there is a TOF event and a SciFi event
    TOFEvent* tof_event = recon_event->GetTOFEvent();
    SciFiEvent* scifi_event = recon_event->GetSciFiEvent();
    if ( !tof_event || !scifi_event )
        continue;

    // Check which trackers hold tracks (tracker() == 0 is TKU, otherwise TKD)
    SciFiTrackPArray tracks = scifi_event->scifitracks();
    bool has_tku(false), has_tkd(false);
    for (SciFiTrack* track : tracks) {
      if ( track->tracker() )
          has_tkd = true;
      else
          has_tku = true;
    }
    if ( !has_tku && !has_tkd )
        continue;

    // Calculate the momentum of the muon from its time-of-flight TOF0->1
    double tof_p, tof_dp;
    if ( !calculate_tof_momentum(tof_event, tof_p, tof_dp) )
        continue;

    // Propagate the momentum to the necessary tracker(s)
    double start(_tof1_z);
    double tku_p(tof_p), tku_dp(tof_dp), tkd_p(tof_p), tkd_dp(tof_dp);
    if ( has_tku ) {
      if ( !propagate_momentum(tku_p, tku_dp, start, _tku_z) )
          continue;
    }
    if ( has_tkd ) {
      if ( has_tku ) {
        // Continue from TKU rather than propagating again from TOF1
        start = _tku_z;
        tkd_p = tku_p;
        tkd_dp = tku_dp;
      }
      if ( !propagate_momentum(tkd_p, tkd_dp, start, _tkd_z) )  // TODO ? (marginal)
          continue;
    }

    // Refit the tracks that need refitting and merge in the TOF momentum.
    // Replaced tracks are only deleted once the event holds the new array, so an
    // unexpected exception never leaves the event pointing at deleted tracks.
    SciFiTrackPArray new_tracks;
    SciFiTrackPArray replaced_tracks;
    new_tracks.reserve(tracks.size());
    for (SciFiTrack* track : tracks) {
      SciFiSeed* seed = track->scifi_seed();
      if ( !seed ) {
        // Without a seed the track can be neither classified nor refit: keep it as is
        new_tracks.push_back(track);
        continue;
      }

      const double tk_p = track->tracker() ? tkd_p : tku_p;
      const double tk_dp = track->tracker() ? tkd_dp : tku_dp;
      SciFiTrack* result = track;  // Track stored in the event for this slot

      try {
        // Poor helical fits (and duds) are refit with a TOF-corrected seed (flag 1),
        // straight fits are always refit as helices (flag 2), good fits are kept (0)
        int refit_flag = 0;
        if ( seed->getAlgorithm() ) {
          if ( needs_refitting(seed) || track->is_dud() )
              refit_flag = 1;
        } else {
          refit_flag = 2;
        }

        if ( refit_flag ) {
          // Keep the original seed so that it can be restored if the refit fails
          TMatrixD old_vector = seed->getStateVector();
          TMatrixD old_covariance = seed->getCovariance();
          try {
            if ( refit_flag == 1 )
                update_seed(seed, tk_p);
            else
                update_straight_seed(seed, tk_p);
            SciFiTrack* new_track = track_fit_helix(seed);
            new_track->SetWasRefit(refit_flag);
            result = new_track;
          } catch (...) {
            seed->setStateVector(old_vector);
            seed->setCovariance(old_covariance);
            throw;
          }
        }

        // Integrate the TOF information into the fit
        update_track(result, tk_p, tk_dp);
      } catch (Exceptions::Exception& e) {
        // On a failed refit the original track is kept rather than dropped (and leaked)
        std::cerr << "TOF information merge failed: " << e.what() << std::endl;
      }

      new_tracks.push_back(result);
      if ( result != track )
          replaced_tracks.push_back(track);
    }

    scifi_event->set_scifitracks(new_tracks);
    for (SciFiTrack* old_track : replaced_tracks)
        delete old_track;
  }
}

bool MapCppTrackerTOFCombinedFit::calculate_tof_momentum(TOFEvent* tof_event,
                                                         double& tof_p, double& tof_dp) const {

  // Need exactly one space point in each of TOF0 and TOF1 (and a space point
  // container at all) to form an unambiguous time-of-flight
  TOFEventSpacePoint* spacepoints = tof_event->GetTOFEventSpacePointPtr();
  if ( !spacepoints ||
       spacepoints->GetTOF0SpacePointArraySize() != 1 ||
       spacepoints->GetTOF1SpacePointArraySize() != 1 )
      return false;

  // Get the time-of-flight
  TOFSpacePoint tof0 = spacepoints->GetTOF0SpacePointArrayElement(0);
  TOFSpacePoint tof1 = spacepoints->GetTOF1SpacePointArrayElement(0);
  const double delta_t = tof1.GetTime() - tof0.GetTime();  // [ns]

  // A non-positive (or NaN) time-of-flight is unphysical and is divided by below
  if ( !(delta_t > 0.) )
      return false;

  // Reconstruct the velocity, skip if too close to c or unphysical
  const double beta = (_tof1_z - _tof0_z) / (delta_t * CLHEP::c_light);  // c_light in [mm/ns]
  if ( beta > .98 || beta <= 0 )
      return false;

  // Reconstruct the momentum, assuming muon species
  const double gamma = 1.0 / std::sqrt(1.0 - beta*beta);
  tof_p = beta*gamma*Recon::Constants::MuonMass;  // [MeV/c]

  // Propagate the timing resolution: with p = m*beta*gamma and beta = L/(c*t),
  // |dp/dt| = p*gamma^2/t
  tof_dp = tof_p*gamma*gamma*_tof_resolution/delta_t;  // [MeV/c]

  return true;
}

bool MapCppTrackerTOFCombinedFit::propagate_momentum(double& p, double& dp,
                                                     double start, double end) const {

  // Work with the total energy under the muon mass hypothesis
  const double mass = Recon::Constants::MuonMass;
  double total_energy = std::sqrt(mass*mass + p*p);

  // dE = (p/E) dp : translate the momentum uncertainty into an energy uncertainty
  double energy_error = p*dp/total_energy;
  double straggling2 = 0.;

  // Walk through the regions of the axial material lookup, one region per pass,
  // stopping at `end` or once the muon has lost all its kinetic energy
  MaterialModelAxialLookup lookup(0, 0, start);
  double z = start;
  for (;;) {
    double region_lower, region_upper;
    lookup.SetMaterial(0, 0, z);
    MaterialModelAxialLookup::GetBounds(z, region_lower, region_upper);

    // Clip the current material region to the [start, end] interval
    const double step_begin = (region_lower < start) ? start : region_lower;
    z = (region_upper > end) ? end : region_upper;
    const double step_length = z - step_begin;

    // Both rates are evaluated at the energy on entry to the step
    const double loss_rate = lookup.dEdx(total_energy, mass, +1);
    const double strag_rate = lookup.estrag2(total_energy, mass, +1);
    total_energy += loss_rate*step_length;
    straggling2 += strag_rate*step_length;

    // Step just past the region's upper bound so the next lookup presumably
    // lands in the following region
    z += MaterialModelAxialLookup::GetZTolerance();
    if ( !(z < end && total_energy > mass) )
        break;
  }

  // If the energy is smaller than the muon mass, the muon has stopped, cannot proceed
  if ( total_energy <= mass )
      return false;

  p = std::sqrt(total_energy*total_energy - mass*mass);

  // Add straggling in quadrature, then convert back with dp = (E/p) dE
  energy_error = std::sqrt(energy_error*energy_error + straggling2);
  dp = total_energy*energy_error/p;

  return true;
}

bool MapCppTrackerTOFCombinedFit::needs_refitting(MAUS::SciFiSeed* seed) const {
  // A helical seed is flagged for refitting when its pattern-recognition radius
  // is below _radius_cut (read from SciFiHelicalRadiusLimit in _birth()).
  // NOTE(review): a cut on the pz uncertainty may be more appropriate; a poor fit
  // can also yield a very large radius (e.g. > 150), which this test does not catch.
  SciFiHelicalPRTrack* helix = static_cast<SciFiHelicalPRTrack*>(seed->getPRTrackTobject());
  const double radius = helix->get_R();
  return radius < _radius_cut;
}

void MapCppTrackerTOFCombinedFit::update_seed(MAUS::SciFiSeed* seed, double tof_p) const {

  // Replace the curvature term (index 4) of the existing seed with one derived from
  // the TOF momentum, keeping the sign already carried by the seed.
  // NOTE(review): the helical state vector is assumed to be (x, px, y, py, q/pz),
  // consistent with update_straight_seed() zeroing indices 1 and 3 - confirm.
  TMatrixD vector = seed->getStateVector();

  // The TOF measures the total momentum whereas the curvature term is 1/pz, so
  // remove the transverse momentum carried by the seed. If the seed pt is not
  // smaller than the TOF momentum (unphysical), fall back to pz = p.
  const double pt2 = vector(1, 0)*vector(1, 0) + vector(3, 0)*vector(3, 0);
  const double p2 = tof_p*tof_p;
  const double pz = (p2 > pt2) ? std::sqrt(p2 - pt2) : tof_p;

  if (vector(4, 0) >= 0.0) {
    vector(4, 0) =  1.0 / pz;
  } else {
    vector(4, 0) = -1.0 / pz;
  }

  seed->setStateVector(vector);
}

void MapCppTrackerTOFCombinedFit::update_straight_seed(MAUS::SciFiSeed* seed, double tof_p) const {

  // Build a 5-parameter helical seed from a straight-track seed: keep the x and y
  // entries (indices 0 and 2), zero indices 1 and 3, and set the curvature term
  // from the TOF momentum with a tracker-dependent sign
  TMatrixD vector = seed->getStateVector();

  TMatrixD new_vector(5, 1);
  new_vector(0, 0) = vector(0, 0);
  new_vector(1, 0) = 0.0;
  new_vector(2, 0) = vector(2, 0);
  new_vector(3, 0) = 0.0;
  new_vector(4, 0) = (seed->getTracker() == 0 ? 1.0 : -1.0) / tof_p;
  seed->setStateVector(new_vector);

  // Fresh diagonal covariance: fixed variances of 100 on the first four
  // parameters and 2.5e-7 on the curvature term
  const double kParameterVariance = 100.0;
  const double kCurvatureVariance = 2.5e-7;
  TMatrixD new_covariance(5, 5);
  new_covariance.Zero();
  for (int i = 0; i < 4; ++i)
    new_covariance(i, i) = kParameterVariance;
  new_covariance(4, 4) = kCurvatureVariance;

  seed->setCovariance(new_covariance);
}

SciFiTrack* MapCppTrackerTOFCombinedFit::track_fit_helix(SciFiSeed* seed) const {

  // Use the new seed to build and fit (filter + smooth) a helical Kalman track.
  // NOTE(review): the PR track is cast to a helix even for seeds coming from
  // straight-track fits; this presumably relies on BuildTrack/ConvertToSciFiTrack
  // only using members common to all PR tracks - confirm.
  SciFiHelicalPRTrack* helical = static_cast<SciFiHelicalPRTrack*>(seed->getPRTrackTobject());

  Kalman::Track data_track = BuildTrack(helical, Globals::GetSciFiGeometryHelper(), 5);
  Kalman::State kalman_seed(seed->getStateVector(), seed->getCovariance());

  _helical_track_fitter->SetTrack(data_track);
  _helical_track_fitter->SetSeed(kalman_seed);

  _helical_track_fitter->Filter(false);
  _helical_track_fitter->Smooth(false);

  // Convert the fit result into a new SciFiTrack linked to the (updated) seed;
  // the caller takes ownership of the returned track
  SciFiTrack* track = ConvertToSciFiTrack(_helical_track_fitter,
                                          Globals::GetSciFiGeometryHelper(), helical);
  track->set_scifi_seed_tobject(seed);

  return track;
}

bool MapCppTrackerTOFCombinedFit::update_track(MAUS::SciFiTrack* track,
                                               double tof_p, double tof_dp) const {

  // Get the momentum at the reference plane (station 1, plane 0). If several track
  // points match, the last one wins.
  MAUS::SciFiTrackPointPArray tps = track->scifitrackpoints();
  double tk_p(0.), tk_dp(0.), tk_pz(0.);
  bool found = false;
  for (const MAUS::SciFiTrackPoint* tpoint : tps) {
    if ( !tpoint->plane() && tpoint->station() == 1 ) {
      tk_p = tpoint->mom().mag();
      tk_pz = tpoint->mom().z();
      tk_dp = tpoint->mom_error().mag();
      found = true;
    }
  }

  // The inverse-variance weighted mean needs strictly positive uncertainties, and
  // the rescaling below divides by the tracker momentum: otherwise NaN/inf would be
  // written into every track point
  if ( !found || !(tof_dp > 0.) || !(tk_dp > 0.) || !(tk_p > 0.) )
      return false;

  // Inverse-variance weighted average of the TOF and tracker momenta
  const double sum_of_weights = 1./(tof_dp*tof_dp) + 1./(tk_dp*tk_dp);
  const double weighted_p = (tof_p/(tof_dp*tof_dp) + tk_p/(tk_dp*tk_dp))/sum_of_weights;

  // Get the momentum correction.
  // NOTE(review): only pz is shifted, by the same amount at every plane, so |p| at
  // the reference plane only approximately equals weighted_p (pt is unchanged).
  const double factor = weighted_p/tk_p;
  const double corr = (factor - 1.)*tk_pz;

  // Loop over the Kalman track points, update the momentum vectors
  for (MAUS::SciFiTrackPoint* tpoint : tps) {
    MAUS::ThreeVector new_mom = tpoint->mom();
    new_mom.SetZ(tpoint->mom().z() + corr);
    tpoint->set_mom(new_mom);
  }

  return true;
}
} // ~namespace MAUS