#include <algorithm>
#include <iomanip>
#include <iostream>
#include <iterator>
#include <limits>
#include <map>
#include <vector>

#include "TROOT.h"
#include "TFile.h"
#include "TH1D.h"
#include "TH2D.h"

#include "km3net-dataformat/definitions/module_status.hh"

#include "JDB/JSupport.hh"
#include "JDB/JAHRS.hh"
#include "JDB/JAHRSCalibration_t.hh"
#include "JDB/JAHRSToolkit.hh"

#include "JSupport/JMultipleFileScanner.hh"
#include "JSupport/JFileRecorder.hh"
#include "JSupport/JMeta.hh"

#include "JROOT/JManager.hh"
#include "JFit/JSimplex.hh"
#include "JLang/JComparator.hh"
#include "JSystem/JStat.hh"

#include "JDetector/JDetector.hh"
#include "JDetector/JDetectorToolkit.hh"
#include "JDetector/JModuleRouter.hh"
#include "JDetector/JStringRouter.hh"
#include "JDetector/JCompass.hh"

#include "JCompass/JHit.hh"
#include "JCompass/JModel.hh"
#include "JCompass/JEvt.hh"
#include "JCompass/JEvtToolkit.hh"
#include "JCompass/JSupport.hh"

#include "Jeep/JProperties.hh"
#include "Jeep/JContainer.hh"
#include "Jeep/JPrint.hh"
#include "Jeep/JParser.hh"
#include "Jeep/JMessage.hh"


/**
 * \file
 *
 * Program to calibrate in situ AHRS.
 * \author mdejong
 */
int main(int argc, char **argv)
{
  using namespace std;
  using namespace JPP;

  JMultipleFileScanner_t inputFile;
  counter_type    numberOfEvents;
  JFileRecorder<JTYPELIST<JEvt, JMeta, TH1D, TH2D>::typelist>  outputFile;
  string          detectorFile;
  long long int   Tmax_s           = 600;                                // time window data collection [s]
  int             mestimator       = EM_LORENTZIAN;                      // M-estimator fit
  double          sigma_deg        = 1.0;                                // resolution [deg]
  double          stdev            = numeric_limits<double>::max();      // number of standard deviations
  int             numberOfOutliers = 0;
  string          ahrsFile;
  bool            overwriteDetector;
  int             debug;

  try {

    JProperties properties;

    properties.insert(gmake_property(Tmax_s));
    properties.insert(gmake_property(mestimator));
    properties.insert(gmake_property(sigma_deg));
    properties.insert(gmake_property(stdev));
    properties.insert(gmake_property(numberOfOutliers));

    JParser<> zap("Program to calibrate in situ AHRS.");
    
    zap['f'] = make_field(inputFile,    "output of JConvertDB -q ahrs");
    zap['n'] = make_field(numberOfEvents)                                    = JLimit::max();
    zap['a'] = make_field(detectorFile);
    zap['@'] = make_field(properties)                                        = JPARSER::initialised();
    zap['c'] = make_field(ahrsFile,     "output of JAHRSCalibration");
    zap['A'] = make_field(overwriteDetector);
    zap['o'] = make_field(outputFile)    = "compass.root";
    zap['d'] = make_field(debug)         = 2;

    zap(argc, argv);
  }
  catch(const exception &error) {
    FATAL(error.what() << endl);
  }


  JDetector detector;

  try {
    load(detectorFile, detector);
  }
  catch(const JException& error) {
    FATAL(error);
  }

  const floor_range   range = getRangeOfFloors(detector);

  const JModuleRouter router(detector);
  const JStringRouter string(detector);
  

  map<int, JAverage<JQuaternion3D> > compass;                     // output compass calibration


  const JAHRSCalibration_t calibration(ahrsFile.c_str());
  const JAHRSValidity      is_valid;


  JSimplex<JModel> simplex;

  JSimplex<JModel>::MAXIMUM_ITERATIONS = 10000;

  simplex.debug = debug;

  const JChi2 getChi2(mestimator);


  typedef JManager<int, TH1D>  JManager_t;

  JManager_t H0(new TH1D("%.twist", NULL, 100,  0.0,   5.0));
  JManager_t H1(new TH1D("%.swing", NULL, 250,  0.0,   2.5));
  JManager_t HN(new TH1D("%.count", NULL, 100, -0.5,  99.5));

  TH2D  h2("h2", NULL,
	   string.size(), -0.5, string.size() - 0.5,
	   range.getLength() + 1, range.getLowerLimit() - 0.5, range.getUpperLimit() + 0.5);

  for (Int_t i = 1; i <= h2.GetXaxis()->GetNbins(); ++i) {
    h2.GetXaxis()->SetBinLabel(i, MAKE_CSTRING(string.at(i-1)));
  }
  for (Int_t i = 1; i <= h2.GetYaxis()->GetNbins(); ++i) {
    h2.GetYaxis()->SetBinLabel(i, MAKE_CSTRING(i-1));
  }

  TH2D* h1 = (TH2D*) h2.Clone("h1");


  outputFile.open();

  outputFile.put(JMeta(argc, argv));

  counter_type counter = 0;
  
  for (JMultipleFileScanner_t::const_iterator file_name = inputFile.begin(); file_name != inputFile.end(); ++file_name) {

    STATUS("processing file " << *file_name << endl);
      
    map<int, vector<JAHRS> > data;                                // AHRS data per string
    
    for (JMultipleFileScanner<JAHRS> in(*file_name); in.hasNext() && counter != numberOfEvents; ++counter) {

      const JAHRS* parameters = in.next();

      if (is_valid(*parameters) && router.hasModule(parameters->DOMID)) {
	data[router.getModule(parameters->DOMID).getString()].push_back(*parameters);
      }
    }
    
    for (map<int, vector<JAHRS> >::iterator i = data.begin(); i != data.end(); ++i) {

      sort(i->second.begin(), i->second.end(), make_comparator(&JAHRS::UNIXTIME));
    
      for (vector<JAHRS>::const_iterator p = i->second.begin(); p != i->second.end(); ) {

	long long int t1 = p->UNIXTIME;
	long long int t2 = t1;

	vector<JHit> buffer;                                      // calibrated quaternion data

	for ( ; p != i->second.end() && p->UNIXTIME < t1 + Tmax_s * 1000; t2 = (p++)->UNIXTIME) {

	  if (calibration.has(p->DOMID)) {

	    const JModule& module = router.getModule(p->DOMID);

	    if (module.getFloor() != 0 && !module.has(COMPASS_DISABLE)) {

	      const JCompass compass(*p, calibration.get(p->DOMID));

	      const JQuaternion3D Q = module.getQuaternion() * compass.getQuaternion();

	      buffer.push_back(JHit(p->DOMID, module.getZ(), Q, sigma_deg));
	    }
	  }
	}

	if (buffer.size() > JModel::NUMBER_OF_PARAMETERS) {

	  for (vector<JHit>::const_iterator hit = buffer.begin(); hit != buffer.end(); ++hit) {

	    const JLocation& location = router.getModule(hit->getID());

	    h1->Fill((double) string.getIndex(location.getString()), (double) location.getFloor());
	  }
	  
	  JModel result(buffer.begin(), buffer.end());            // prefit

	  vector<JHit>::iterator __end = buffer.end();

	  for (int ns = 0; ns != numberOfOutliers; ++ns) {        // outlier removal

	    double                 xmax = 0.0;
	    vector<JHit>::iterator out  = __end;
	  
	    for (vector<JHit>::iterator hit = buffer.begin(); hit != __end; ++hit) {

	      const JQuaternion3D Q1 = result(hit->getZ());       // fitted
	      const JQuaternion3D Q2 = hit->getQuaternion();      // measured

	      const double x = getAngle(Q1, Q2);

	      if (x > xmax) {
		xmax = x;
		out  = hit;
	      }
	    }

	    if (xmax > stdev * sigma_deg) {

	      const JLocation& location = router.getModule(out->getID());

	      h2.Fill((double) string.getIndex(location.getString()), (double) location.getFloor());

	      if (debug >= debug_t) {
	      
		const JQuaternion3D Q1 = result(out->getZ());     // fitted
		const JQuaternion3D Q2 = out->getQuaternion();    // measured

		const JQuaternion3D::decomposition q1(Q1, JVector3Z_t);
		const JQuaternion3D::decomposition q2(Q2, JVector3Z_t);

		cout << "remove "  << location                     << ' '
		     << FIXED(5,2) << getAngle(Q1,Q2)              << ' '
		     << FIXED(5,2) << getAngle(q1.twist, q2.twist) << ' '
		     << FIXED(5,2) << getAngle(q1.swing, q2.swing) << endl;
	      }
	    
	      swap(*out, *--__end);

	      result = JModel(buffer.begin(), __end);             // refit

	    } else {

	      break;
	    }
	  }

	  simplex.value = result;                                 // start value

	  simplex.step.resize(4);

	  simplex.step[0] = JModel(JQuaternion3X(5.0e-1 * PI / 180.0), JQuaternion3D::getIdentity());
	  simplex.step[1] = JModel(JQuaternion3Y(5.0e-1 * PI / 180.0), JQuaternion3D::getIdentity());
	  simplex.step[2] = JModel(JQuaternion3Z(5.0e-1 * PI / 180.0), JQuaternion3D::getIdentity());
	  simplex.step[3] = JModel(JQuaternion3D::getIdentity(), JQuaternion3Z(5.0e-2 * PI / 180.0));

	  const double chi2 = simplex(getChi2, buffer.begin(), __end);
	  const int    ndf  = distance(buffer.begin(), __end) * 4 - simplex.step.size();

	  result = simplex.value;                                 // final value


	  outputFile.put(getEvt(JHead(t1, t2, i->first, ndf, chi2), result));


	  for (vector<JHit>::const_iterator hit = buffer.begin(); hit != buffer.end(); ++hit) {

	    const JQuaternion3D Q1 = result(hit->getZ());         // fitted
	    const JQuaternion3D Q2 = hit->getQuaternion();        // measured

	    const JQuaternion3D::decomposition q1(Q1, JVector3Z_t);
	    const JQuaternion3D::decomposition q2(Q2, JVector3Z_t);

	    compass[hit->getID()].put(Q1 * Q2.getConjugate());
	  
	    H0[hit->getID()]->Fill(getAngle(q1.twist, q2.twist));
	    H1[hit->getID()]->Fill(getAngle(q1.swing, q2.swing));
	  }

	  map<int, int> count;

	  for (vector<JHit>::const_iterator hit = buffer.begin(); hit != __end; ++hit) {
	    count[hit->getID()] += 1;
	  }

	  for (map<int, int>::const_iterator i = count.begin(); i != count.end(); ++i) {
	    HN[i->first]->Fill(i->second);
	  }
	}
      }
    }
  }

  
  h2.Divide(h1);

  outputFile.put(h2);

  for (JManager_t* p : { &H0, &H1, &HN }) {
    for (JManager_t::iterator i = p->begin(); i != p->end(); ++i) {
      outputFile.put(*(i->second));
    }
  }

  outputFile.close();


  if (overwriteDetector) {

    NOTICE("Store calibration data on file " << detectorFile << endl);

    if (detector.setToLatestVersion()) {
      NOTICE("Set detector version to " << detector.getVersion() << endl);
    }

    detector.comment.add(JMeta(argc, argv));

    for (map<int, JAverage<JQuaternion3D> >::const_iterator i = compass.begin(); i != compass.end(); ++i) {

      JModule& module = detector[router.getIndex(i->first)];

      JQuaternion3D Q(i->second * module.getQuaternion());

      module.setQuaternion(Q.normalise());
    }

    store(detectorFile, detector);
  }
}