/* This file is part of MAUS: http://micewww.pp.rl.ac.uk:8080/projects/maus
 *
 * MAUS is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * MAUS is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with MAUS.  If not, see <http://www.gnu.org/licenses/>.
 *
 */

#include <cstddef>
#include <map>
#include <sstream>
#include <string>
#include <vector>

#include "src/common_cpp/Utils/JsonWrapper.hh"
#include "src/common_cpp/Utils/CppErrorHandler.hh"
#include "src/common_cpp/Utils/Exception.hh"
#include "src/legacy/Interface/dataCards.hh"
#include "src/common_cpp/API/PyWrapMapBase.hh"
#include "src/map/MapCppCuts/MapCppCuts.hh"

namespace MAUS {

    // Help texts handed to PyWrapMapBase<MapCppCuts>::PyWrapMapBaseModInit
    // by init_MapCppCuts below (class, birth, process and death, in that order).
    std::string class_docstring =
        "MapCppCuts is the mapper for the MAUS Cuts class.\n";

    std::string birth_docstring =
        "Checks if the right configuration is passed to the processor.\n";

    // Two adjacent literals are concatenated by the compiler into one string.
    std::string process_docstring =
        "Set up event(s) with some cut_event data and\n"
        "check if mapper sets the correct cut values.\n";

    std::string death_docstring =
        "Does nothing.\n";

    // Python module initialisation hook for _MapCppCuts: delegates to
    // PyWrapMapBase<MapCppCuts>::PyWrapMapBaseModInit with the mapper's name
    // and the four docstrings defined above.
    PyMODINIT_FUNC init_MapCppCuts(void) {
        PyWrapMapBase<MapCppCuts>::PyWrapMapBaseModInit(
            "MapCppCuts",
            class_docstring,
            birth_docstring,
            process_docstring,
            death_docstring);
    }

    // Default constructor: forwards the mapper name "MapCppCuts" to
    // MapBase<Data>. Cut thresholds are assigned later, in _birth().
    MapCppCuts::MapCppCuts() : MapBase<Data>("MapCppCuts") {}

    // Destructor: performs no explicit cleanup.
    MapCppCuts::~MapCppCuts() {}

void MapCppCuts::_birth(const std::string& argJsonConfigDocument) {
    // Run-start hook: parse the JSON configuration document and cache every
    // cut threshold from its "recon_cuts" object in member variables.
    // Nothing here catches errors. NOTE(review): malformed JSON or a missing or
    // mistyped key is presumably reported by JsonWrapper throwing; confirm this.
    _configJSON = JsonWrapper::StringToJson(argJsonConfigDocument);
    _set_Cut_params = JsonWrapper::GetProperty(_configJSON, "recon_cuts",
                                               JsonWrapper::objectValue);

    // Time-of-flight window (cut 2 in _process)
    _min_tof = JsonWrapper::GetProperty(_set_Cut_params, "min_tof",
                                        JsonWrapper::realValue).asDouble();
    _max_tof = JsonWrapper::GetProperty(_set_Cut_params, "max_tof",
                                        JsonWrapper::realValue).asDouble();

    // Upstream momentum window (cut 5)
    _min_mom_us = JsonWrapper::GetProperty(_set_Cut_params, "min_mom_US",
                                           JsonWrapper::realValue).asDouble();
    _max_mom_us = JsonWrapper::GetProperty(_set_Cut_params, "max_mom_US",
                                           JsonWrapper::realValue).asDouble();

    // Downstream momentum window (cut 11)
    _min_mom_ds = JsonWrapper::GetProperty(_set_Cut_params, "min_mom_DS",
                                           JsonWrapper::realValue).asDouble();
    _max_mom_ds = JsonWrapper::GetProperty(_set_Cut_params, "max_mom_DS",
                                           JsonWrapper::realValue).asDouble();

    // Momentum-loss window (cut 6)
    _min_mom_loss = JsonWrapper::GetProperty(_set_Cut_params, "min_mom_loss",
                                             JsonWrapper::realValue).asDouble();
    _max_mom_loss = JsonWrapper::GetProperty(_set_Cut_params, "max_mom_loss",
                                             JsonWrapper::realValue).asDouble();

    // Reconstructed-mass window (cut 8)
    _min_mass = JsonWrapper::GetProperty(_set_Cut_params, "min_mass",
                                         JsonWrapper::realValue).asDouble();
    _max_mass = JsonWrapper::GetProperty(_set_Cut_params, "max_mass",
                                         JsonWrapper::realValue).asDouble();

    // Track p-value threshold (cut 7)
    _good_pval = JsonWrapper::GetProperty(_set_Cut_params, "good_pval",
                                          JsonWrapper::realValue).asDouble();
}

void MapCppCuts::_process(MAUS::Data *data) const {
    // Evaluate the twelve reconstruction cuts for every ReconEvent in the
    // spill and store the boolean results in that event's Cuts object.
    // Layout of the stored vector (true = condition satisfied):
    //    0 single TOF0 space point     1 single TOF1 space point
    //    2 TOF1-TOF0 time within [_min_tof, _max_tof]
    //    3 exactly one SciFi track     4 sole helical PR track is in tracker 0
    //                                    with 5 points
    //    5 US momentum window          6 momentum-loss window
    //    7 track p-value > _good_pval  8 mass window
    //    9 logical AND of cuts 0-8
    //   10 sole helical PR track is in tracker 1 with 5 points
    //   11 DS momentum window
    // Throws MAUS::Exceptions::Exception if data is NULL.
    if (!data) {
        throw MAUS::Exceptions::Exception(Exceptions::recoverable,
                "data was NULL", "MapCppCuts::_process");
    }

    // Skip spills that are missing, carry no DAQ data, or are not physics events.
    Spill *spill = data->GetSpill();
    if (spill == NULL)
        return;
    if (spill->GetDAQData() == NULL)
        return;
    if (spill->GetDaqEventType() != "physics_event")
        return;

    ReconEventPArray *events = spill->GetReconEvents();
    if (events == NULL)
        return;

    for (size_t i = 0; i < events->size(); ++i) {
        MAUS::ReconEvent* event = (*events)[i];
        if (event == NULL)
            continue;

        // Attach an empty Cuts container if the event has none yet.
        // NOTE(review): the event presumably takes ownership of it; confirm.
        if (event->GetCutEvent() == NULL) {
            event->SetCutEvent(new Cuts());
        }

        // Every helper below is a const member function. Each one is evaluated
        // exactly once and its result reused.
        std::vector<bool> all_cut_values(12, false);

        // Cuts 0, 1, 2. Get_TimeOfFlight returns 0 when no single TOF0/TOF1
        // pair exists.
        const double tof = Get_TimeOfFlight(event);
        all_cut_values[0] = Get_single_TOF0_hit_cut(event);
        all_cut_values[1] = Get_single_TOF1_hit_cut(event);
        all_cut_values[2] = GetTOF_hitTimes_cut(tof, _min_tof, _max_tof);

        // Cuts 3, 4
        all_cut_values[3] = Get_single_track_cut(event);
        all_cut_values[4] = Get_station_hits_cut_US(event);

        // Upstream momentum. NOTE(review): an earlier comment here said "station 1,
        // nearest the absorber", but process_US_mom/process_DS_mom select
        // station 5, plane 2. Confirm which station is intended.
        const double US_mom = process_US_mom(event);

        // Cuts 5, 6, 7, 8
        all_cut_values[5] = Get_US_mom_cut(US_mom, _min_mom_us, _max_mom_us);
        all_cut_values[6] = Get_momentum_loss_cut(tof, _min_mom_loss,
                                                  _max_mom_loss, US_mom);
        all_cut_values[7] = Get_p_value_cut(event, _good_pval);
        all_cut_values[8] = Get_mass_cut(tof, _min_mass, _max_mass, US_mom);

        // Cut 9: combined upstream selection
        all_cut_values[9] = Get_good_particle_cut(all_cut_values[0],
            all_cut_values[1], all_cut_values[2], all_cut_values[3],
            all_cut_values[4], all_cut_values[5], all_cut_values[6],
            all_cut_values[7], all_cut_values[8]);

        // Downstream cuts 10, 11
        all_cut_values[10] = Get_station_hits_cut_DS(event);
        const double DS_mom = process_DS_mom(event);
        all_cut_values[11] = Get_DS_mom_cut(DS_mom, _min_mom_ds, _max_mom_ds);

        // Pass the vector to the event
        event->GetCutEvent()->SetCutStore(all_cut_values);
    } // fetching events
} // process

// GET TIME OF FLIGHT
double MapCppCuts::Get_TimeOfFlight(MAUS::ReconEvent* event) const {
    // Returns TOF1 time minus TOF0 time for the event's single TOF0 and
    // single TOF1 space points. Returns 0 (used by callers as "unavailable")
    // when there is no TOF event, or when either station does not have
    // exactly one space point.
    MAUS::TOFEvent * tof_event = event->GetTOFEvent();
    if (tof_event == NULL)
        return 0;

    MAUS::TOFEventSpacePoint space_points = tof_event->GetTOFEventSpacePoint();
    const bool one_each =
        space_points.GetTOF0SpacePointArray().size() == 1 &&
        space_points.GetTOF1SpacePointArray().size() == 1;
    if (!one_each)
        return 0;

    MAUS::TOFSpacePoint tof0_point = space_points.GetTOF0SpacePointArray()[0];
    MAUS::TOFSpacePoint tof1_point = space_points.GetTOF1SpacePointArray()[0];
    const double t0 = tof0_point.GetTime();
    const double t1 = tof1_point.GetTime();
    return t1 - t0;
}

/* GET MOMENTUM
std::vector<double> MapCppCuts::Process_momentum(MAUS::ReconEvent* event) const {

     int tracker, station;
     double TKU_plane1_px, TKU_plane1_py, TKU_plane1_pz, TKU_plane1_p;
     double TKU_plane2_px, TKU_plane2_py, TKU_plane2_pz, TKU_plane2_p;
     double TKU_plane3_px, TKU_plane3_py, TKU_plane3_pz, TKU_plane3_p;
     double TKU_plane4_px, TKU_plane4_py, TKU_plane4_pz, TKU_plane4_p;
     double TKU_plane5_px, TKU_plane5_py, TKU_plane5_pz, TKU_plane5_p;

     MAUS::SciFiEvent * MyEvent = event->GetSciFiEvent();
     std::vector<double> all_station_mom(5, 0);
     if (MyEvent == NULL) return all_station_mom;

     MAUS::ThreeVector momentum;
     std::vector<MAUS::SciFiTrack*> tracks = MyEvent->scifitracks();
     std::vector<MAUS::SciFiTrack*>::iterator tr_iter;
     std::vector<MAUS::SciFiTrackPoint*>::iterator tp_iter;

     // for (tr_iter = tracks.begin(); tr_iter != tracks.end(); tr_iter++) {
     if (tracks.size() != 0) {
         tr_iter = tracks.begin();
         std::vector<MAUS::SciFiTrackPoint*> tr_points = (*tr_iter)->scifitrackpoints();
         if (tr_points.size() != 0) {
         // for (tp_iter = tr_points.begin(); tp_iter != tr_points.end(); tp_iter++) {
             tp_iter = tr_points.begin();
	     MAUS::SciFiTrackPoint* point = (*tp_iter);
	     tracker = point->tracker();
             station = point->station();
	     momentum = point->mom();
	     if (tracker == 0) {
	         if (station == 1) {
		     TKU_plane1_px = momentum.x();
		     TKU_plane1_py = momentum.y();
		     TKU_plane1_pz = momentum.z();
		     TKU_plane1_p = sqrt(TKU_plane1_px*TKU_plane1_px
		         + TKU_plane1_py*TKU_plane1_py
                	 + TKU_plane1_pz*TKU_plane1_pz);
                     all_station_mom[0] = TKU_plane1_p;
                 } else if (station == 2) {
        	     TKU_plane2_px = momentum.x();
	             TKU_plane2_py = momentum.y();
	             TKU_plane2_pz = momentum.z();
		     TKU_plane2_p = sqrt(TKU_plane2_px*TKU_plane2_px
        	         + TKU_plane2_py*TKU_plane2_py
                	 + TKU_plane2_pz*TKU_plane2_pz);
	             all_station_mom[1] = TKU_plane2_p;
	         } else if (station == 3) {
	             TKU_plane3_px = momentum.x();
        	     TKU_plane3_py = momentum.y();
	             TKU_plane3_pz = momentum.z();
        	     TKU_plane3_p = sqrt(TKU_plane3_px*TKU_plane3_px
                         + TKU_plane3_py*TKU_plane3_py
                         + TKU_plane3_pz*TKU_plane3_pz);
		     all_station_mom[2] = TKU_plane3_p;
		 } else if (station == 4) {
		     TKU_plane4_px = momentum.x();
		     TKU_plane4_py = momentum.y();
	             TKU_plane4_pz = momentum.z();
	             TKU_plane4_p = sqrt(TKU_plane4_px*TKU_plane4_px
        	         + TKU_plane4_py*TKU_plane4_py
                	 + TKU_plane4_pz*TKU_plane4_pz);
	             all_station_mom[3] = TKU_plane4_p;
        	 } else if (station == 5) {
		     TKU_plane5_px = momentum.x();
	             TKU_plane5_py = momentum.y();
		     TKU_plane5_pz = momentum.z();
		     TKU_plane5_p = sqrt(TKU_plane5_px*TKU_plane5_px
        	         + TKU_plane5_py*TKU_plane5_py
                         + TKU_plane5_pz*TKU_plane5_pz);
	             all_station_mom[4] = TKU_plane5_p;
                 }
	     } // tracker station
         } // trackpoint iterator
     } // track iterator
     return all_station_mom;
}
*/

// GET US MOMENTUM
double MapCppCuts::process_US_mom(MAUS::ReconEvent* event) const {
    // Returns the total momentum |p| of the SciFi track point with
    // tracker == 0, station == 5 and plane == 2. All points of all tracks
    // are scanned and each match overwrites the previous one, so the last
    // match wins. Returns 0 when the event has no SciFi event or no point matches.
    double mom_US = 0;

    MAUS::SciFiEvent * sf_event = event->GetSciFiEvent();
    if (sf_event == NULL)
        return mom_US;

    std::vector<MAUS::SciFiTrack*> tracks = sf_event->scifitracks();
    for (std::vector<MAUS::SciFiTrack*>::iterator trk = tracks.begin();
         trk != tracks.end(); ++trk) {
        std::vector<MAUS::SciFiTrackPoint*> points = (*trk)->scifitrackpoints();
        for (std::vector<MAUS::SciFiTrackPoint*>::iterator pt = points.begin();
             pt != points.end(); ++pt) {
            MAUS::SciFiTrackPoint* point = *pt;
            const bool wanted = point->tracker() == 0 &&
                                point->plane() == 2 &&
                                point->station() == 5;
            if (!wanted)
                continue;

            MAUS::ThreeVector p = point->mom();
            const double px = p.x();
            const double py = p.y();
            const double pz = p.z();
            mom_US = sqrt(px*px + py*py + pz*pz);
        }
    }
    return mom_US;
}

// GET DS MOMENTUM
double MapCppCuts::process_DS_mom(MAUS::ReconEvent* event) const {
    // Returns the total momentum |p| of the SciFi track point with
    // tracker == 1, station == 5 and plane == 2. All points of all tracks
    // are scanned and each match overwrites the previous one, so the last
    // match wins. Returns 0 when the event has no SciFi event or no point matches.
    double mom_DS = 0;

    MAUS::SciFiEvent * sf_event = event->GetSciFiEvent();
    if (sf_event == NULL)
        return mom_DS;

    std::vector<MAUS::SciFiTrack*> tracks = sf_event->scifitracks();
    for (std::vector<MAUS::SciFiTrack*>::iterator trk = tracks.begin();
         trk != tracks.end(); ++trk) {
        std::vector<MAUS::SciFiTrackPoint*> points = (*trk)->scifitrackpoints();
        for (std::vector<MAUS::SciFiTrackPoint*>::iterator pt = points.begin();
             pt != points.end(); ++pt) {
            MAUS::SciFiTrackPoint* point = *pt;
            const bool wanted = point->tracker() == 1 &&
                                point->plane() == 2 &&
                                point->station() == 5;
            if (!wanted)
                continue;

            MAUS::ThreeVector p = point->mom();
            const double px = p.x();
            const double py = p.y();
            const double pz = p.z();
            mom_DS = sqrt(px*px + py*py + pz*pz);
        }
    }
    return mom_DS;
}

// CUT CHECK FUNCTIONS
bool MapCppCuts::Get_single_TOF0_hit_cut(MAUS::ReconEvent* event) const {
    // True when the event's TOF0 space-point array has exactly one entry.
    // False when it has any other count, or when there is no TOF event.
    MAUS::TOFEvent * tof_event = event->GetTOFEvent();
    if (tof_event == NULL)
        return false;

    MAUS::TOFEventSpacePoint space_points = tof_event->GetTOFEventSpacePoint();
    return space_points.GetTOF0SpacePointArray().size() == 1;
}

bool MapCppCuts::Get_single_TOF1_hit_cut(MAUS::ReconEvent* event) const {
    // True when the event's TOF1 space-point array has exactly one entry.
    // False when it has any other count, or when there is no TOF event.
    MAUS::TOFEvent * tof_event = event->GetTOFEvent();
    if (tof_event == NULL)
        return false;

    MAUS::TOFEventSpacePoint space_points = tof_event->GetTOFEventSpacePoint();
    return space_points.GetTOF1SpacePointArray().size() == 1;
}

bool MapCppCuts::GetTOF_hitTimes_cut(double dt, double min_t,
    double max_t) const {
    // True when the time of flight dt lies in the closed interval
    // [min_t, max_t]. A NaN dt fails both comparisons and yields false.
    return dt >= min_t && dt <= max_t;
}

bool MapCppCuts::Get_single_track_cut(MAUS::ReconEvent* event) const {
    // True when the event's SciFi track collection holds exactly one track.
    // False when it holds any other count, or when there is no SciFi event.
    // NOTE(review): process_US_mom/process_DS_mom search this same collection
    // for both tracker 0 and tracker 1, so the count appears to include both
    // trackers. Confirm that an upstream-only count was not intended.
    MAUS::SciFiEvent * sf_event = event->GetSciFiEvent();
    if (sf_event == NULL)
        return false;

    std::vector<MAUS::SciFiTrack*> tracks = sf_event->scifitracks();
    return tracks.size() == 1;
}

bool MapCppCuts::Get_station_hits_cut_DS(MAUS::ReconEvent* event) const {
    // True when the event's helical pattern-recognition track collection holds
    // exactly one track, that track has get_tracker() == 1, and it has
    // get_num_points() == 5. False otherwise, including when there is no SciFi event.
    // NOTE(review): Get_station_hits_cut_US searches this same collection for
    // tracker 0 under the same size() == 1 condition. An event with one track
    // in each tracker therefore fails both cuts, and cut 4 and cut 10 can
    // never both be true. Confirm whether a per-tracker count was intended.
    MAUS::SciFiEvent * sf_event = event->GetSciFiEvent();
    if (sf_event == NULL)
        return false;

    std::vector<MAUS::SciFiHelicalPRTrack*> pr_tracks = sf_event->helicalprtracks();
    if (pr_tracks.size() != 1)
        return false;

    MAUS::SciFiHelicalPRTrack* only_track = pr_tracks.front();
    return only_track->get_tracker() == 1 && only_track->get_num_points() == 5;
}

bool MapCppCuts::Get_station_hits_cut_US(MAUS::ReconEvent* event) const {
    // True when the event's helical pattern-recognition track collection holds
    // exactly one track, that track has get_tracker() == 0, and it has
    // get_num_points() == 5. False otherwise, including when there is no SciFi event.
    // NOTE(review): Get_station_hits_cut_DS searches this same collection for
    // tracker 1 under the same size() == 1 condition, so an event with a track
    // in each tracker fails this cut. Confirm whether a per-tracker count was intended.
    MAUS::SciFiEvent * sf_event = event->GetSciFiEvent();
    if (sf_event == NULL)
        return false;

    std::vector<MAUS::SciFiHelicalPRTrack*> pr_tracks = sf_event->helicalprtracks();
    if (pr_tracks.size() != 1)
        return false;

    MAUS::SciFiHelicalPRTrack* only_track = pr_tracks.front();
    return only_track->get_tracker() == 0 && only_track->get_num_points() == 5;
}

bool MapCppCuts::Get_US_mom_cut(double mom, double min_mom, double max_mom) const {
    // True when the upstream momentum lies in the closed interval
    // [min_mom, max_mom].
    return mom >= min_mom && mom <= max_mom;
}

bool MapCppCuts::Get_DS_mom_cut(double mom, double min_mom, double max_mom) const {
    // True when the downstream momentum lies in the closed interval
    // [min_mom, max_mom].
    return mom >= min_mom && mom <= max_mom;
}

bool MapCppCuts::Get_momentum_loss_cut(double dt, double min_mom_loss,
    double max_mom_loss, double mom) const {
    // Compares the TOF-derived beta*gamma with the tracker momentum plus an
    // allowed loss, both divided by a fixed mass:
    //   (mom + min_mom_loss)/m <= beta*gamma <= (mom + max_mom_loss)/m
    // where beta = reference_dt / dt.
    // Returns false when dt == 0 (the "no TOF" value from Get_TimeOfFlight)
    // or when beta falls outside [0, 1].
    if (dt == 0)
        return false;

    // NOTE(review): presumably the muon mass in MeV/c^2. Confirm.
    const double mass_hyp = 105.6583715;
    // The dt that gives beta == 1. NOTE(review): units presumably ns. Confirm.
    const double reference_dt = 25.48;
    const double beta = reference_dt / dt;
    if (beta > 1.0 || beta < 0)
        return false;

    const double lower = (mom + min_mom_loss) / mass_hyp;
    const double upper = (mom + max_mom_loss) / mass_hyp;
    const double gamma = 1.0 / (sqrt(1.0 - beta * beta));
    const double beta_gamma = beta * gamma;

    return beta_gamma >= lower && beta_gamma <= upper;
}

bool MapCppCuts::Get_p_value_cut(MAUS::ReconEvent* event, double good_pval) const {
    // True when the event has exactly one SciFi track and that track's
    // P_value() is strictly greater than good_pval. False otherwise,
    // including when there is no SciFi event.
    // Initialised to false so that every early-exit path returns a defined
    // result. Reading an uninitialised bool is undefined behaviour.
    bool willCut = false;

    MAUS::SciFiEvent * MyEvent = event->GetSciFiEvent();
    if (MyEvent == NULL) return willCut;

    std::vector<MAUS::SciFiTrack*> tracks = MyEvent->scifitracks();
    // NOTE(review): this collection is searched for both tracker 0 and
    // tracker 1 by process_US_mom/process_DS_mom, so size() == 1 appears to
    // count tracks from both trackers. Confirm this is intended.
    if (tracks.size() == 1 && tracks[0]->P_value() > good_pval) {
        willCut = true;
    }

    return willCut;
}

bool MapCppCuts::Get_mass_cut(double dt, double min_mass, double max_mass,
    double mom) const {
    // Reconstructs a mass as (mom + mom_corr) / (beta*gamma), with
    // beta = reference_dt / dt. Returns true when that mass lies in
    // [min_mass, max_mass].
    // Returns false when dt == 0 (the "no TOF" value from Get_TimeOfFlight)
    // or when beta falls outside [0, 1].
    if (dt == 0)
        return false;

    // NOTE(review): presumably a momentum correction added to the tracker
    // momentum (e.g. for losses between the TOFs and the tracker). Confirm.
    const double mom_corr = 18.82;
    // The dt that gives beta == 1. NOTE(review): units presumably ns. Confirm.
    const double reference_dt = 25.48;
    const double beta = reference_dt / dt;
    if (beta > 1.0 || beta < 0)
        return false;

    const double gamma = 1.0 / (sqrt(1.0 - beta * beta));
    const double beta_gamma = beta * gamma;
    const double reco_mass = (mom + mom_corr) / beta_gamma;

    return reco_mass >= min_mass && reco_mass <= max_mass;
}

bool MapCppCuts::Get_good_particle_cut(bool single_tof0, bool single_tof1, bool good_tof,
    bool single_track, bool station_hits, bool good_mom, bool good_momLoss,
    bool good_pval, bool good_massval) const {
    // True only when all nine upstream cut results are true (logical AND).
    return single_tof0 && single_tof1 && good_tof &&
           single_track && station_hits && good_mom &&
           good_momLoss && good_pval && good_massval;
}

void MapCppCuts::_death() {
    // End-of-run hook: intentionally empty (matches death_docstring, "Does nothing.").
}
} // namespace MAUS