#include <RAT/CAENBits.hh>
#include <RAT/DropoutEvaluator.hh>
#include <RAT/DU/TrigBits.hh>
#include <RAT/DB.hh>
#include <TH1D.h>
#include <TFile.h>
#include <TF1.h>
#include <TGraph.h>
#include <TSpectrum.h>
#include <algorithm>

using std::vector;
using std::map;
namespace RAT {


    // Construct the processor with its default configuration.
    //
    // Results are saved to "Dropout.ratdb" in a table named "DROPOUT". No
    // histogram file is written unless a filename is supplied through SetS,
    // and the histogram output format defaults to "root".
    DropoutEvaluator::DropoutEvaluator() : Processor("DropoutEvaluator"){
      // None of the integer parameters has been overridden yet; BeginOfRun
      // consults these flags to decide whether to read values from the DB.
      fBaselineSet = false;
      fTemperatureSet = false;
      fIterationsSet = false;

      // Destination of the fit results.
      fDBName = "Dropout.ratdb";
      fDBTableName = "DROPOUT";

      // Optional histogram dump, disabled until a filename is given.
      fWriteToFile = false;
      fOutputType = "root";
    }

    // Set an integer-valued processor parameter.
    //
    // Recognised parameter names:
    //   "temperature"                                   - forwarded to the fitter
    //   "iterations" / "niterations" / "num_iterations" - forwarded to the fitter
    //   "baseline samples" / "baseline_samples"         - stored in fBaselineSamples
    //
    // Each accepted parameter also raises the matching "set" flag so that
    // BeginOfRun does not replace the value with the one from the database.
    // Throws ParamUnknown for any other parameter name.
    void DropoutEvaluator::SetI(const std::string& param, const int value) {
        const bool is_temperature = (param == "temperature");
        const bool is_iterations = (param == "iterations")
                                || (param == "niterations")
                                || (param == "num_iterations");
        const bool is_baseline = (param == "baseline samples")
                              || (param == "baseline_samples");

        if (is_temperature) {
            fTemperatureSet = true;
            fFitter.SetTemperature(value);
            return;
        }
        if (is_iterations) {
            fIterationsSet = true;
            fFitter.SetIterations(value);
            return;
        }
        if (is_baseline) {
            fBaselineSet = true;
            fBaselineSamples = value;
            return;
        }
        throw ParamUnknown(param);
    }

    // Set a string-valued processor parameter.
    //
    // Recognised parameter names:
    //   "filename" / "file_name" / "file name" - histogram output file; giving
    //                                            one also enables file writing
    //   "output_type" / "output type"          - "root" or "json"; any other
    //                                            value is rejected with a warning
    //                                            and the current type is kept
    //
    // Throws ParamUnknown for any other parameter name.
    void DropoutEvaluator::SetS( const std::string& param, const std::string& value ) {
        if (param == "filename" || param == "file_name" || param == "file name") {
            fWriteToFile = true;
            fFilename = value;
        } else if (param == "output_type" || param == "output type") {
            if(value != "json" && value != "root") {
                // The two literals are concatenated by the compiler, so the
                // separating space has to be written explicitly.
                warn << "Given output type (" << value <<") not known. "
                        "Using default output type ("<< fOutputType << ").\n";
            } else{
                fOutputType = value;
            }
        } else {
            throw ParamUnknown(param);
        }
    }

    // Fill in any integer parameter that was not set explicitly through SetI
    // with the value stored in the "DropoutEvaluator" database table.
    void DropoutEvaluator::BeginOfRun(DS::Run&){
        DBLinkPtr defaults = DB::Get()->GetLink("DropoutEvaluator", "DropoutEvaluator");

        if (!fBaselineSet) {
            const int n_samples = defaults->GetI("baseline_samples");
            fBaselineSamples = static_cast<unsigned int>(n_samples);
        }

        if (!fTemperatureSet) {
            const int temperature = defaults->GetI("starting_temperature");
            fFitter.SetTemperature(static_cast<unsigned int>(temperature));
        }

        if (!fIterationsSet) {
            const int iterations = defaults->GetI("sampling_iterations");
            fFitter.SetIterations(static_cast<unsigned int>(iterations));
        }
    }

    // Run Event() on every triggered event held by the entry.
    //
    // The per-event result is deliberately not inspected: this always
    // returns OK, whatever the individual events returned.
    Processor::Result DropoutEvaluator::DSEvent(DS::Run&, DS::Entry& ds) {
        size_t iEV = 0;
        while (iEV < ds.GetEVCount()) {
            Event(ds, ds.GetEV(iEV));
            ++iEV;
        }
        return OK;
    }

    // Accumulate baseline statistics from one triggered event.
    //
    // Returns FAIL when the event carries no digitiser data. Events without
    // the PulseGT trigger bit are skipped (OK is returned, nothing is stored).
    // For every digitised trace of an accepted event, the average and the
    // max-min spread of the first fBaselineSamples samples are appended to
    // fBaselineValues / fRangeValues, keyed by the trace ID as reported by
    // the digitiser.
    Processor::Result DropoutEvaluator::Event(DS::Entry&, DS::EV& ev) {
        if (!ev.DigitiserExists()) {
            return FAIL;
        }

        // Keep only events whose trigger word has the PulseGT bit set.
        const int pgt_mask = 1 << DU::TrigBits::PulseGT;
        const int trig_word = ev.GetTrigType();
        if ((trig_word & pgt_mask) == 0) {
            return OK;
        }

        DS::Digitiser& caen = ev.GetDigitiser();
        std::vector<UShort_t> ids = caen.GetIDs();

        for (size_t idx = 0; idx < ids.size(); ++idx) {
            // Results are stored under the ID exactly as reported...
            const UShort_t store_key = ids[idx];
            // ...but IDs below 10 are multiplied by 10 before querying the
            // waveform. NOTE(review): presumably this selects a different
            // trace variant of the same input - confirm against Digitiser.
            UShort_t lookup_id = store_key;
            if (lookup_id < 10) {
                lookup_id *= 10;
            }

            // Statistics over the leading (baseline) region of the trace.
            const double baseline = caen.Average(lookup_id, 0, fBaselineSamples);
            const double max = caen.Max(lookup_id, 0, fBaselineSamples);
            const double min = caen.Min(lookup_id, 0, fBaselineSamples);

            // Ignore traces whose baseline region reaches 4095 or above.
            if (max >= 4095) {
                continue;
            }
            fBaselineValues[store_key].push_back(baseline);
            fRangeValues[store_key].push_back(max - min);
        }
        return OK;
    }

    // Fit the accumulated baselines and store the results.
    //
    // Requires data for both the NH100Lo and NH20Lo traces; if either is
    // missing a warning is emitted and nothing is written. For each of the
    // two traces the baseline-spread window is extracted, the fitter is
    // prepared and run, and the fitted parameters, normalisation and chi2 are
    // stored in the output table under "<N100|N20><Field>" keys. The table is
    // saved to fDBName, and the fitted/input curves are optionally written
    // to a ROOT or JSON file.
    //
    // TODO Generalize this to work with potentially many traces.
    void DropoutEvaluator::EndOfRun(DS::Run& run) {
        const bool have_n100 = fBaselineValues.count(NH100Lo) != 0;
        const bool have_n20 = fBaselineValues.count(NH20Lo) != 0;
        if (!have_n100 || !have_n20) {
            warn << "Data not available for performing dropout fit.\n";
            return;
        }

        map<int, vector<double> > fit_results;
        map<int, vector<double> > fit_yvals;
        map<int, vector<double> > data_yvals;
        map<int, vector<double> > data_xvals;

        DBTable table(fDBTableName);
        table.SetI("version", 2);
        table.SetPassNumber(-1);
        table.SetRunRange(run.GetRunID(), run.GetRunID());

        // Pass 0 handles N100, pass 1 handles N20.
        for (int iTrace = 0; iTrace < 2; ++iTrace) {
            const bool is_n20 = (iTrace == 1);
            const int id = is_n20 ? NH20Lo : NH100Lo;
            const std::string prefix = is_n20 ? "N20" : "N100";

            const vector<double> bounds = ExtractWidthBounds(id);
            if (bounds.size() != 2) {
                // An error message should already have been emitted
                continue;
            }

            PrepFitter(id, bounds[0], bounds[1]);
            fFitter.PerformFit();

            // Parameter layout: 0 = rate, 1 = location, 2 = separation, 3 = sigma
            vector<double> params = fFitter.GetParams();
            vector<double> best_fit = fFitter.DropoutModel(fFitter.xax,
                                                           params[0],
                                                           params[1],
                                                           params[2],
                                                           params[3]);

            table.SetD(prefix + "Rate", params[0]);
            table.SetD(prefix + "Location", params[1]);
            table.SetD(prefix + "Separation", params[2]);
            table.SetD(prefix + "Sigma", params[3]);
            table.SetD(prefix + "Normalization", fFitter.fNormalization);
            table.SetD(prefix + "Chi2", fFitter.Chi2(fFitter.yvals, best_fit));

            fit_results[id] = params;
            data_yvals[id] = fFitter.yvals;
            data_xvals[id] = fFitter.xax;
            fit_yvals[id] = best_fit;
        }

        // Now store the results in a ratdb file
        table.SaveAs(fDBName);

        // If desired write the histograms to a file
        if (!fWriteToFile) {
            return;
        }
        if (fOutputType == "root") {
            WriteHistogramsToROOTFile(data_xvals, data_yvals, fit_yvals);
        } else if (fOutputType == "json") {
            WriteHistogramsToJSONFile(data_xvals, data_yvals, fit_yvals);
        }
    }

    // Determine the window of baseline spreads (max - min) to accept for a trace.
    //
    // The spreads recorded for `id` are histogrammed (100 bins over 0-100),
    // peaks are located with TSpectrum, and a gaussian with height and mean
    // fixed to the lowest-lying peak is fitted to obtain a width.
    //
    // Returns {peak - 2*sigma, peak + 2*sigma}, or an empty vector (after
    // emitting a warning) if no peak was found.
    vector<double> DropoutEvaluator::ExtractWidthBounds(int id) {
        TSpectrum s;
        TF1 f("range_hist_gaus","gaus", 0, 100);
        TH1D h_range("baseline_range", "baseline_range", 100, 0, 100);
        vector<double> ret;

        // Look the trace up once rather than on every iteration.
        const vector<unsigned int>& ranges = fRangeValues[id];
        for(size_t i=0; i < ranges.size(); i++) {
            h_range.Fill(ranges.at(i));
        }

        // The number of peaks should be somewhere between 2 and say 5;
        // we need to find the smallest one, that will correspond to
        // events with no nhit bumps in them.
        // Search() returns a signed count, so keep it signed: a non-positive
        // result must not be reinterpreted as a huge unsigned peak count.
        const int n_peaks = s.Search(&h_range);
        if(n_peaks <= 0) {
            warn << "DropoutEvaluator::Failed to find peak in CAEN width distribution, failing.\n";
            return ret;
        }

        // Pick the peak with the lowest x position (first one wins on ties).
        int peak_num = 0;
        for(int iPeak=1; iPeak < n_peaks; iPeak++) {
            if(s.GetPositionX()[iPeak] < s.GetPositionX()[peak_num]) {
                peak_num = iPeak;
            }
        }
        const double first_peak = s.GetPositionX()[peak_num];
        const double peak_height = s.GetPositionY()[peak_num];

        // Only the width is left free in the fit.
        f.FixParameter(0, peak_height);
        f.FixParameter(1, first_peak);
        h_range.Fit(&f, "BQN0", "", 0, 0);
        // TODO check convergence

        // A gaussian depends on sigma only through sigma^2, so the fit may
        // converge on a negative width. Use the magnitude so that min <= max.
        double range_sigma = f.GetParameter(2);
        if(range_sigma < 0) {
            range_sigma = -range_sigma;
        }

        ret.push_back(first_peak - 2.0*range_sigma);
        ret.push_back(first_peak + 2.0*range_sigma);
        return ret;
    }

    // Load the fitter with the baseline distribution of one trace.
    //
    // id       - trace whose accumulated baselines are used
    // min, max - inclusive window on the baseline spread; only entries whose
    //            spread lies inside it contribute
    //
    // The selected baselines are histogrammed over 3000-4096 ADC counts and
    // normalised to unit area (the original area is kept in
    // fFitter.fNormalization). Peaks found by TSpectrum seed the fit
    // parameters when at least two distinct peaks exist. Finally the
    // non-empty bins are handed to the fitter as (xax, yvals).
    void DropoutEvaluator::PrepFitter(const int& id, const double& min, const double& max) {
        const int BINNING_FACTOR = 2; // How many CAEN ADC values should be in a single histogram bin
        const int SEED_PEAK_WIDTH = 10;
        TSpectrum s;
        TH1D baseline_histogram("n100_baseline", "n100_baseline", 1096/BINNING_FACTOR, 3000, 4096);

        for(size_t i =0; i < fBaselineValues[id].size(); i++) {
            double val = fRangeValues[id][i];
            if(val <= max && val >= min) {
                baseline_histogram.Fill(fBaselineValues[id][i]);
            }
        }
        // NOTE(review): if no entry passes the selection the integral is 0 and
        // the scale factor below is not finite - confirm how the fitter should
        // be fed in that case before changing this.
        fFitter.fNormalization = baseline_histogram.Integral("width");
        baseline_histogram.Scale(1.0/fFitter.fNormalization);

        // Search() returns a signed count; clamp it so that a non-positive
        // result cannot turn into a huge unsigned size.
        const int n_found = s.Search(&baseline_histogram, SEED_PEAK_WIDTH/BINNING_FACTOR);
        const unsigned int n_peaks = (n_found > 0) ? static_cast<unsigned int>(n_found) : 0u;

        // Collect the peak positions and remember the highest-lying one
        // (stored as an int, i.e. truncated, as the seed for the position).
        vector<double> peaks(n_peaks);
        int first_peak = -1;
        for(unsigned int iPeak=0; iPeak < n_peaks; iPeak++) {
            double posx = s.GetPositionX()[iPeak];
            peaks[iPeak] = posx;
            if( posx > first_peak || first_peak < 0) {
                first_peak = posx;
            }
        }

        // Mean spacing between neighbouring peaks. It is only defined for two
        // or more peaks; with fewer it stays 0 so that no seeding happens,
        // instead of relying on an unsigned underflow or a 0/0 NaN.
        std::sort(peaks.begin(), peaks.end());
        double avg_diff = 0;
        for(unsigned int iPeak=1; iPeak < n_peaks; iPeak++) {
            avg_diff += peaks[iPeak] - peaks[iPeak - 1];
        }
        if(n_peaks > 1) {
            avg_diff /= (n_peaks-1);
        }

        // TODO have parameter positions not be hardcoded like this
        // 0 = rate, 1 = position, 2 = separation, 3 = sigma
        if(avg_diff > 0) {
            fFitter.SetParam(0, (n_peaks > 1) ? n_peaks/2 : 1);

            fFitter.SetParam(1, first_peak);
            fFitter.SetError(1, n_peaks*avg_diff);

            fFitter.SetParam(2, avg_diff);
            fFitter.SetError(2, 0);

            fFitter.SetParam(3, avg_diff/2.0);
        }

        // Give the fitter the x,y vals for the histogram (empty bins skipped)
        vector<double> xax;
        vector<double> yax;
        for(int i=1; i<=baseline_histogram.GetNbinsX(); i++) {
            double y = baseline_histogram.GetBinContent(i);
            if(y != 0) {
                yax.push_back(y);
                xax.push_back(baseline_histogram.GetBinCenter(i));
            }
        }
        fFitter.xax = xax;
        fFitter.yvals = yax;
    }


    // Write the input and fitted distributions to a ROOT file as TGraphs.
    //
    // x  - per-trace x values (bin centres)
    // y  - per-trace input values
    // yp - per-trace fitted values
    //
    // The file named by fFilename is recreated. For each trace in `x` two
    // graphs are written, "<name>_fit_graph" and "<name>_input_graph", where
    // <name> is "N20" for the NH20Lo trace and "N100" otherwise. If the file
    // cannot be opened a warning is emitted and nothing is written.
    void DropoutEvaluator::WriteHistogramsToROOTFile(map<int, vector<double> >& x,
                                                     map<int, vector<double> >& y,
                                                     map<int, vector<double> >& yp) {
        TFile f(fFilename.c_str(),"RECREATE");
        if(f.IsZombie() || !f.IsOpen()) {
            // Without this check the graphs would be written to whatever
            // directory happens to be current instead of the requested file.
            warn << "DropoutEvaluator: could not open " << fFilename
                 << " for writing, histograms not saved.\n";
            return;
        }

        for(map<int, vector<double> >::iterator it=x.begin(); it!=x.end(); ++it) {
            const int id = it->first;
            const std::string name = (id == NH20Lo) ? "N20" : "N100";
            const vector<double>& xs = it->second;
            const vector<double>& inputs = y[id];
            const vector<double>& fits = yp[id];

            // Automatic objects: released at the end of each iteration even
            // if something in between throws.
            TGraph g_original;
            TGraph g_fit;
            for(size_t i=0; i < xs.size(); i++) {
                g_fit.SetPoint(i, xs[i], fits[i]);
                g_original.SetPoint(i, xs[i], inputs[i]);
            }

            const std::string fit_name = name + "_fit_graph";
            g_fit.SetName(fit_name.c_str());
            g_fit.SetTitle(fit_name.c_str());
            g_fit.Write();

            const std::string input_name = name + "_input_graph";
            g_original.SetName(input_name.c_str());
            g_original.SetTitle(input_name.c_str());
            g_original.Write();
        }
        f.Close();
    }

    // Write the input and fitted distributions to fFilename via a DBTable
    // named "DropoutValues".
    //
    // x  - per-trace x values
    // y  - per-trace input values
    // yp - per-trace fitted values
    //
    // For each trace in `x` three arrays are stored: "<name>_xax",
    // "<name>_yvals" and "<name>_fit_vals", where <name> is "N20" for the
    // NH20Lo trace and "N100" otherwise.
    void DropoutEvaluator::WriteHistogramsToJSONFile(map<int, vector<double> >& x,
                                                     map<int, vector<double> >& y,
                                                     map<int, vector<double> >& yp) {
        DBTable output_table("DropoutValues");

        map<int, vector<double> >::iterator trace = x.begin();
        while (trace != x.end()) {
            const int id = trace->first;
            const std::string prefix = (id == NH20Lo) ? "N20" : "N100";

            output_table.SetDArray(prefix + "_xax", trace->second);
            output_table.SetDArray(prefix + "_yvals", y[id]);
            output_table.SetDArray(prefix + "_fit_vals", yp[id]);

            ++trace;
        }

        output_table.SaveAs(fFilename);
    }
} // namespace RAT