#ifndef __JRECONSTRUCTION__JMUONPREFIT__
#define __JRECONSTRUCTION__JMUONPREFIT__

#include <iostream>
#include <iomanip>
#include <vector>
#include <algorithm>

#include "km3net-dataformat/online/JDAQEvent.hh"
#include "km3net-dataformat/online/JDAQTimeslice.hh"
#include "km3net-dataformat/definitions/fitparameters.hh"

#include "JTrigger/JHitR1.hh"
#include "JTrigger/JBuildL0.hh"
#include "JTrigger/JBuildL2.hh"
#include "JTrigger/JAlgorithm.hh"
#include "JTrigger/JMatch1D.hh"
#include "JTrigger/JMatch3B.hh"

#include "JFit/JLine1Z.hh"
#include "JFit/JLine1ZEstimator.hh"
#include "JFit/JMatrixNZ.hh"
#include "JFit/JVectorNZ.hh"
#include "JFit/JFitToolkit.hh"

#include "JReconstruction/JEvt.hh"
#include "JReconstruction/JEvtToolkit.hh"
#include "JReconstruction/JMuonPrefitParameters_t.hh"

#include "JMath/JConstants.hh"
#include "JTools/JPermutation.hh"

#include "JLang/JPredicate.hh"

#include "JDetector/JModuleRouter.hh"
#include "JDetector/JDetectorToolkit.hh"

#include "JGeometry3D/JOmega3D.hh"
#include "JGeometry3D/JRotation3D.hh"

#include "Jeep/JMessage.hh"


/**
 * \author mdejong, gmaggi, azegarelli
 */

namespace JRECONSTRUCTION {}
namespace JPP { using namespace JRECONSTRUCTION; }

namespace JRECONSTRUCTION {

  using JDETECTOR::JModuleRouter;
  using JFIT::JLine1Z;
  using JFIT::JEstimator;
  using JFIT::JMatrixNZ;
  using JFIT::JVectorNZ;
  using JTRIGGER::JMatch3B;
  using JTRIGGER::JMatch1D;
  using JGEOMETRY3D::JOmega3D;


  /**
   * Wrapper class to make pre-fit of muon trajectory.
   *
   * The JMuonPrefit fit is used to generate start values for subsequent fits (usually JMuonSimplex and JMuonGandalf).\n
   * To this end, a scan of directions is made and the time and transverse positions of the track are fitted for each direction (JFIT::JEstimator<JLine1Z>).\n
   * The directions are spaced by the parameter JMuonPrefitParameters_t::gridAngle_deg.\n
   * This angle corresponds to the envisaged angular accuracy of the result.\n
   * To increase the probability that one of the results is less than this angle away from the correct value,
   * multiple start values should be considered (JMuonPrefitParameters_t::numberOfPrefits).\n
   * Note that the CPU time scales with the inverse of the square of this angle.\n
   * The chi-squared is based on the time residuals.\n
   */
  struct JMuonPrefit :
    public JMuonPrefitParameters_t,
    public JEstimator<JLine1Z>
  {
    typedef JEstimator<JLine1Z>                  JEstimator_t;
    typedef JTRIGGER::JHitR1                     hit_type;
    typedef std::vector<hit_type>                buffer_type;

    using JEstimator_t::operator();

    /**
     * Constructor.
     *
     * The set of trial directions is built from JMuonPrefitParameters_t::gridAngle_deg (converted to radians).
     * The module router is stored by reference and must outlive this object.
     *
     * \param  parameters    parameters
     * \param  router        module router
     * \param  debug         debug level
     */
    JMuonPrefit(const JMuonPrefitParameters_t& parameters,
                const JModuleRouter&           router,
                const int                      debug = 0) :
      JMuonPrefitParameters_t(parameters),
      router(router),
      omega (parameters.gridAngle_deg * JMATH::PI/180.0),
      debug (debug)
    {
      configure();
    }


    /**
     * Constructor.
     *
     * The module router is stored by reference and must outlive this object.
     *
     * \param  parameters    parameters
     * \param  router        module router
     * \param  omega         trial directions
     * \param  debug         debug level
     */
    JMuonPrefit(const JMuonPrefitParameters_t& parameters,
                const JModuleRouter&           router,
                const JOmega3D&                omega,
                const int                      debug = 0) :
      JMuonPrefitParameters_t(parameters),
      router(router),
      omega (omega),
      debug (debug)
    {
      configure();
    }


    /**
     * Fit function.
     *
     * Builds L1 hits from the event (JBuildL2 with multiplicity 2, TMaxLocal_ns and ctMin).
     * It then reduces them to unique optical modules and clusterizes them with a 3D causality
     * criterion (JMatch3B).
     * If useL0 is set, it also builds L0 hits and keeps only those causally related to the L1 cluster.
     * The resulting hit sets are passed to operator()(const buffer_type&, const buffer_type&).
     *
     * \param  event         event
     * \return               fit results
     */
    JEvt operator()(const KM3NETDAQ::JDAQEvent& event)
    {
      using namespace std;
      using namespace JPP;

      const JBuildL0<hit_type> buildL0;
      const JBuildL2<hit_type> buildL2(JL2Parameters(2, TMaxLocal_ns, ctMin));

      buffer_type dataL0;
      buffer_type dataL1;

      buildL2(event, router, !useL0, back_inserter(dataL1));

      // 3D cluster of unique optical modules

      const JMatch3B<hit_type> match3B(roadWidth_m, TMaxLocal_ns);

      sort(dataL1.begin(), dataL1.end(), hit_type::compare);

      buffer_type::iterator __end = dataL1.end();

      __end = unique(dataL1.begin(), __end, equal_to<JDAQModuleIdentifier>());

      __end = clusterizeWeight(dataL1.begin(), __end, match3B);

      dataL1.erase(__end, dataL1.end());


      if (useL0) {

        buildL0(event, router, true, back_inserter(dataL0));

        // keep only L0 hits that match at least one hit of the L1 cluster;
        // rejected hits are swapped to the back and erased afterwards (order is not preserved)

        __end = dataL0.end();

        for (buffer_type::iterator i = dataL0.begin(); i != __end; ) {

          if (match3B.count(*i, dataL1.begin(), dataL1.end()) != 0)
            ++i;
          else
            swap(*i, *--__end);
        }

        dataL0.erase(__end, dataL0.end());
      }

      return (*this)(dataL0, dataL1);
    }


    /**
     * Fit function.
     *
     * For each trial direction of omega:
     *  - the L1 hits are rotated by the corresponding rotation;
     *  - at most NMaxHits of them are kept (earliest according to cmz);
     *  - they are clusterized with JMatch1D;
     *  - optionally, compatible L0 hits on other modules are added;
     *  - the time and transverse position are fitted (JEstimator<JLine1Z>).
     *
     * Outliers are handled in one of two ways:
     *  - exhaustively (all combinations with up to numberOfOutliers hits removed),
     *    if the number of hits does not exceed factoryLimit;
     *  - otherwise iteratively, by switching off the hit with the largest chi2 contribution.
     *
     * The results are sorted by quality (qualitySorter).
     * If numberOfPrefits > 0, only the best numberOfPrefits results are kept, plus up to numberOfDZMax
     * additional results with JFit::getDZ <= DZMax when there are not enough of those among the best ones.
     *
     * \param  dataL0        L0 hit data
     * \param  dataL1        L1 hit data
     * \return               fit results
     */
    JEvt operator()(const buffer_type& dataL0,
                    const buffer_type& dataL1)
    {
      using namespace std;
      using namespace JPP;

      const double STANDARD_DEVIATIONS    =   3.0;                              // [unit]
      const double HIT_OFF                =   1.0e3 * sigma_ns * sigma_ns;      // [ns^2]

      const JMatch1D<hit_type> match1D(roadWidth_m, TMaxLocal_ns);

      // The work buffer must be able to hold all L1 and all L0 hits of a single direction.
      // It has to be resized (not merely reserved): hits are copied into it through data.begin()
      // and accessed with data[k] below, both of which require the elements to exist.

      data.resize(dataL0.size() + dataL1.size());

      JEvt out;

      for (JOmega3D_t::const_iterator dir = omega.begin(); dir != omega.end(); ++dir) {

        const JRotation3D R(*dir);

        // L1 hits rotated by R (presumably aligning the trial direction with the z-axis, as the JLine1Z fit assumes)

        buffer_type::iterator __end = copy(dataL1.begin(), dataL1.end(), data.begin());

        for (buffer_type::iterator i = data.begin(); i != __end; ++i) {
          i->rotate(R);
        }


        // reduce data: keep the NMaxHits earliest hits according to cmz.
        // Only [data.begin(), __end) holds valid hits for this direction; elements beyond __end
        // are left-overs from previous directions and must not take part in the selection.

        if (distance(data.begin(), __end) > NMaxHits) {

          buffer_type::iterator __p = data.begin();

          advance(__p, NMaxHits);

          partial_sort(data.begin(), __p, __end, cmz);

          __end = __p;
        }


        // 1D cluster

        __end = clusterizeWeight(data.begin(), __end, match1D);

        if (useL0) {

          // append L0 hits on modules not yet in the cluster that are compatible (match1D) with it

          buffer_type::iterator p = __end;                                     // begin L0 data
          buffer_type::iterator q = copy(dataL0.begin(), dataL0.end(), p);     // end   L0 data

          for (buffer_type::iterator i = p; i != q; ++i) {

            if (find_if(data.begin(), __end, make_predicate(&hit_type::getModuleID, i->getModuleID())) == __end) {

              i->rotate(R);

              if (match1D.count(*i, data.begin(), __end) != 0) {
                *p = *i;
                ++p;
              }
            }
          }

          __end = clusterize(__end, p, match1D);
        }


        if (distance(data.begin(), __end) <= NUMBER_OF_PARAMETERS) {
          continue;
        }


        // 1D fit

        JLine1Z  tz;
        double   chi2 = numeric_limits<double>::max();             // sentinel: no valid fit (yet)
        int      NDF  = distance(data.begin(), __end) - NUMBER_OF_PARAMETERS;
        int      N    = getCount(data.begin(), __end);


        if (distance(data.begin(), __end) <= factoryLimit) {

          // exhaustive search over all combinations with 0 .. number_of_outliers hits removed

          int number_of_outliers = numberOfOutliers;

          if (number_of_outliers > NDF - 1) {
            number_of_outliers = NDF - 1;
          }

          double ymin = numeric_limits<double>::max();

          buffer_type::iterator __end1 = __end;

          for (int n = 0; n <= number_of_outliers; ++n, --__end1) {

            sort(data.begin(), __end, hit_type::compare);

            do {
              /*
              if (getNumberOfStrings(router, data.begin(), __end1) < 2) {
                continue;
              }
              */
              try {

                (*this)(data.begin(), __end1);

                V.set(*this, data.begin(), __end1, gridAngle_deg, sigma_ns);
                Y.set(*this, data.begin(), __end1);

                V.invert();

                double y = getChi2(Y, V);

                if (y <= -(STANDARD_DEVIATIONS * STANDARD_DEVIATIONS)) {

                  WARNING(endl << "chi2(1) " << y << endl);

                } else {

                  if (y < 0.0) {
                    y = 0.0;
                  }

                  if (y < ymin) {
                    ymin = y;
                    tz   = *this;
                    chi2 = ymin;
                    NDF  = distance(data.begin(), __end1) - NUMBER_OF_PARAMETERS;
                    N    = getCount(data.begin(), __end1);
                  }
                }
              }
              catch(const JException&) {
                // the fit failed for this combination of hits; deliberately ignored, try the next combination
              }

            } while (next_permutation(data.begin(), __end1, __end, hit_type::compare));

            // a solution with one more hit removed must improve chi2 by more than STANDARD_DEVIATIONS^2
            ymin -= STANDARD_DEVIATIONS * STANDARD_DEVIATIONS;
          }

        } else {

          // iterative outlier removal: the hit with the largest chi2 contribution is switched off
          // (V.update(k, HIT_OFF)) and the fit is updated, until all contributions are below STANDARD_DEVIATIONS^2

          const int number_of_outliers = NDF - 1;

          try {

            (*this)(data.begin(), __end);

            V.set(*this, data.begin(), __end, gridAngle_deg, sigma_ns);
            Y.set(*this, data.begin(), __end);

            V.invert();

            for (int n = 0; n <= number_of_outliers; ++n) {

              double ymax =  0.0;
              int    k    = -1;

              for (size_t i = 0; i != Y.size(); ++i) {

                const double y = getChi2(Y, V, i);

                if (y > ymax) {
                  ymax = y;
                  k    = static_cast<int>(i);
                }
              }

              if (ymax < STANDARD_DEVIATIONS * STANDARD_DEVIATIONS) {
                break;
              }

              V.update(k, HIT_OFF);

              this->update(data.begin(), __end, V);

              Y.set(*this, data.begin(), __end);

              NDF -= 1;
              N   -= getCount(data[k]);
            }

            chi2 = getChi2(Y, V);
            tz   = *this;
          }
          catch(const JException&) {
            // the fit failed for this direction; chi2 keeps its sentinel value, so no result is stored
          }
        }

        if (chi2 != numeric_limits<double>::max()) {

          tz.rotate_back(R);

          out.push_back(getFit(JHistory(JMUONPREFIT), tz, *dir, getQuality(chi2, N, NDF), NDF));
        }
      }

      const JMuonPrefitParameters_t& parameters = static_cast<const JMuonPrefitParameters_t&>(*this);

      if (parameters.numberOfPrefits > 0) {

        // apply default sorter

        JEvt::iterator __end = out.end();

        if (parameters.numberOfPrefits < out.size()) {

          advance(__end = out.begin(), parameters.numberOfPrefits);

          partial_sort(out.begin(), __end, out.end(), qualitySorter);

        } else {

          sort(out.begin(), __end, qualitySorter);
        }

        // add downward pointing solutions if available but not yet sufficient
        // (signed arithmetic, so that a surplus yields a negative value instead of an unsigned wrap-around)

        const int nz = (static_cast<int>(parameters.numberOfDZMax) -
                        static_cast<int>(count_if(out.begin(), __end, make_predicate(&JFit::getDZ, parameters.DZMax, JComparison::le()))));

        if (nz > 0) {

          JEvt::iterator __p = __end;
          JEvt::iterator __q = __end = partition(__p, out.end(), make_predicate(&JFit::getDZ, parameters.DZMax, JComparison::le()));

          if (nz < distance(__p, __q)) {

            advance(__end = __p, nz);

            partial_sort(__p, __end, __q, qualitySorter);

          } else {

            sort(__p, __end, qualitySorter);
          }
        }

        out.erase(__end, out.end());

      } else {

        sort(out.begin(), out.end(), qualitySorter);
      }

      return out;
    }


    /**
     * Auxiliary data structure for sorting of hits.
     */
    static const struct cmz {
      /**
       * Sort hits according to times corrected for position along the z-axis.
       *
       * The ordering key is t * c - z.
       *
       * \param  first             first  hit
       * \param  second            second hit
       * \return                   true if first hit earlier than second hit; else false
       */
      template<class T>
      inline bool operator()(const T& first, const T& second) const
      {
        using namespace JPP;

        return (first .getT() * getSpeedOfLight()  -  first .getZ()  <
                second.getT() * getSpeedOfLight()  -  second.getZ());
      }
    } cmz;


    const JModuleRouter& router;     // module router (not owned; must outlive this object)
    JOmega3D             omega;      // trial directions
    int                  debug;      // debug level

  private:
    /**
     * Configure internal buffer(s).
     *
     * Reserves capacity in the hit buffer for the number of PMTs plus the number of modules of the detector.
     */
    void configure()
    {
      using namespace JPP;

      data.reserve(getNumberOfPMTs   (router.getReference())  +
                   getNumberOfModules(router.getReference()));
    }

    buffer_type data;     // work buffer for the hits of the current direction
    JMatrixNZ   V;        // covariance matrix of the time residuals
    JVectorNZ   Y;        // time residuals
  };
}

#endif