#ifndef __JFIT__JMESTIMATOR__
#define __JFIT__JMESTIMATOR__

#include <cmath>


/**
 * \file
 * Maximum likelihood estimator (M-estimators).
 * \author mdejong
 */
// Declare namespace JFIT up front so that namespace JPP can import all of its names.
namespace JFIT {}
// Make every JFIT name reachable through namespace JPP as well.
namespace JPP { using namespace JFIT; }

namespace JFIT {

  /**
   * Interface for maximum likelihood estimator (M-estimator).
   *
   * An M-estimator is characterised by a function rho of the deviation z
   * and its derivative psi(z) = d rho(z) / dz.
   */
  struct JMEstimator {
    /**
     * Virtual destructor.
     *
     * Allows deletion of a concrete estimator through a base-class pointer,
     * e.g. the objects returned by getMEstimator.
     */
    virtual ~JMEstimator() = default;


    /**
     * Get maximum likelihood estimate.
     *
     * \param  z           deviation
     * \return             value of rho(z)
     */
    virtual double getRho(const double z) const = 0;


    /**
     * Get derivative of maximum likelihood estimate.
     *
     * \param  z           deviation
     * \return             value of psi(z) = d rho(z) / dz
     */
    virtual double getPsi(const double z) const = 0;
  };


  /**
   * Normal M-estimator.
   *
   * This estimator is based on a Gaussian PDF and produces the standard chi2
   * (up to a factor 1/2): rho(z) = z^2 / 2 and psi(z) = z.
   */
  struct JMEstimatorNormal :
    public JMEstimator
  {
    virtual double getRho(const double z) const override { return 0.5*z*z; }
    virtual double getPsi(const double z) const override { return z; }
  };


  /**
   * Lorentzian M-estimator.
   *
   * This estimator produces a logarithmic dependence for large deviations:
   * rho(z) = log(1 + z^2 / 2) and psi(z) = z / (1 + z^2 / 2).
   */
  struct JMEstimatorLorentzian :
    public JMEstimator
  {
    // log1p keeps full precision for small deviations, where 1 + z^2/2 rounds towards 1.
    virtual double getRho(const double z) const override { return std::log1p(0.5*z*z); }
    virtual double getPsi(const double z) const override { return z / (1.0 + 0.5*z*z); }
  };


  /**
   * Linear M-estimator.
   *
   * This estimator produces a linear dependence for large deviations:
   * rho(z) = sqrt(1 + z^2 / 2) - 1 and psi(z) = (z / 2) / sqrt(1 + z^2 / 2).
   */
  struct JMEstimatorLinear :
    public JMEstimator
  {
    virtual double getRho(const double z) const override { return std::sqrt(1.0 + 0.5*z*z) - 1.0; }
    virtual double getPsi(const double z) const override { return 0.5 * z / std::sqrt(1.0 + 0.5*z*z); }
  };


  /**
   * Null M-estimator.
   *
   * This is not an estimator at all, but an object that just returns whatever it is given
   * (rho(z) = z, hence psi(z) = 1).\n
   * It is introduced so that the user can directly access the likelihood calculated by JRegressor<JEnergy>.
   */
  struct JMEstimatorNull :
    public JMEstimator
  {
    virtual double getRho(const double z) const override { return z; }
    virtual double getPsi(const double /*z*/) const override { return 1.0; }
  };


  /**
   * Tukey's biweight M-estimator.
   *
   * This estimator produces a redescending dependence for large deviations.
   * With u = z / k:
   * - for |z| <  k: rho(z) = k^2/6 * (1 - (1 - u^2)^3) and psi(z) = z * (1 - u^2)^2;
   * - for |z| >= k: rho(z) = k^2/6 (constant) and psi(z) = 0.
   */
  struct JMEstimatorTukey :
    public JMEstimator
  {
    /**
     * Constructor.
     *
     * \param  k           cut-off on the absolute deviation |z|;
     *                     deviations with |z| >= k contribute a constant rho and zero psi
     */
    JMEstimatorTukey(const double k) :
      k(k)
    {}

    virtual double getRho(const double z) const override
    {
      const double w = 0.5 * k*k / 3.0;    // value of rho beyond the cut-off

      if (std::fabs(z) < k) {

        const double u = z/k;
        const double v = 1.0 - u*u;

        return w * (1.0 - v*v*v);
      }

      return w;
    }

    virtual double getPsi(const double z) const override
    {
      if (std::fabs(z) < k) {

        const double u = z/k;
        const double v = 1.0 - u*u;

        return z * v*v;
      }

      return 0.0;
    }

    double k;
  };


  /**
   * Normal M-estimator with background.
   *
   * rho(z) = -log(exp(-z^2/2) + p), i.e. the negative logarithm of an (unnormalised)
   * Gaussian term plus a constant background p; psi(z) = z * exp(-z^2/2) / (exp(-z^2/2) + p).
   * For p = 0 this reduces exactly to JMEstimatorNormal.
   */
  struct JMEstimatorNormalWithBackground :
    public JMEstimator
  {
    /**
     * Constructor.
     *
     * \param  p           background probability (assumed to be non-negative)
     */
    JMEstimatorNormalWithBackground(const double p) :
      p(p)
    {}

    virtual double getRho(const double z) const override
    {
      if (p == 0.0) {
        // Without background rho is exactly z^2/2; evaluating -log(exp(-z^2/2)) would
        // yield +inf for large |z|, where the exponential underflows to zero.
        return 0.5*z*z;
      }

      const double w = std::exp(-0.5*z*z);

      return -std::log(w + p);
    }

    virtual double getPsi(const double z) const override
    {
      if (p == 0.0) {
        // Without background psi is exactly z; z * w / (w + p) would be 0/0 = NaN
        // for large |z|, where the exponential underflows to zero.
        return z;
      }

      const double w = std::exp(-0.5*z*z);

      return z * w / (w + p);
    }

    double p;
  };


  /**
   * Definition of the various M-Estimators available to use.
   */
  enum JMEstimator_t {
    EM_NORMAL                 = 0,
    EM_LORENTZIAN             = 1,
    EM_LINEAR                 = 2,
    EM_NULL                   = 3,
    EM_TUKEY                  = 4,
    EM_NORMALWITHBACKGROUND   = 5
  };


  /**
   * Get M-Estimator.
   *
   * Note that for M-estimators with an additional parameter, defaults are used
   * (EM_TUKEY: k = 5; EM_NORMALWITHBACKGROUND: p = 1e-5).
   *
   * \param  type        type (see JMEstimator_t)
   * \return             pointer to newly created M-Estimator, or nullptr for an unknown type;
   *                     the caller owns the object and is responsible for deleting it
   */
  inline JMEstimator* getMEstimator(const int type)
  {
    switch (type) {

    case EM_NORMAL:
      return new JMEstimatorNormal();

    case EM_LORENTZIAN:
      return new JMEstimatorLorentzian();

    case EM_LINEAR:
      return new JMEstimatorLinear();

    case EM_NULL:
      return new JMEstimatorNull();

    case EM_TUKEY:
      return new JMEstimatorTukey(5.0);

    case EM_NORMALWITHBACKGROUND:
      return new JMEstimatorNormalWithBackground(1.0e-5);

    default:
      // nullptr needs no extra header, whereas NULL is not guaranteed to be declared by <cmath>.
      return nullptr;
    }
  }
}

#endif