// @(#)root/tmva/pymva $Id$
// Authors: Omar Zapata, Lorenzo Moneta, Sergei Gleyzer 2015

/*****************************************************************************
 * Project: TMVA - a Root-integrated toolkit for multivariate data analysis *
 * Package: TMVA                                                             *
 * Class  : MethodPyAdaBoost                                                 *
 * Web    : http://oproject.org                                              *
 *                                                                           *
 * Description:                                                              *
 *      AdaBoostClassifier method from the scikit-learn Python package       *
 *                                                                           *
 *****************************************************************************/

#ifndef ROOT_TMVA_MethodPyAdaBoost
#define ROOT_TMVA_MethodPyAdaBoost

//////////////////////////////////////////////////////////////////////////
//                                                                      //
// MethodPyAdaBoost                                                     //
//                                                                      //
//////////////////////////////////////////////////////////////////////////

#include "TMVA/PyMethodBase.h"
#include "TString.h"

namespace TMVA {

   class Factory;
   class Reader;
   class DataSetManager;
   class Types;

   class MethodPyAdaBoost : public PyMethodBase {

   public:

      MethodPyAdaBoost(const TString &jobName,
                       const TString &methodTitle,
                       DataSetInfo &theData,
                       const TString &theOption = "");

      MethodPyAdaBoost(DataSetInfo &dsi,
                       const TString &theWeightFile);

      ~MethodPyAdaBoost();

      void Train();
      void Init();
      void DeclareOptions();
      void ProcessOptions();

      // create ranking
      const Ranking *CreateRanking();

      Bool_t HasAnalysisType(Types::EAnalysisType type, UInt_t numberClasses, UInt_t numberTargets);

      // performs classifier testing
      virtual void TestClassification();

      Double_t GetMvaValue(Double_t *errLower = 0, Double_t *errUpper = 0);
      std::vector<Double_t> GetMvaValues(Long64_t firstEvt = 0, Long64_t lastEvt = -1, Bool_t logProgress = false);
      std::vector<Float_t>& GetMulticlassValues();

      virtual void ReadModelFromFile();

      using MethodBase::ReadWeightsFromStream;

      // the actual "weights"
      virtual void AddWeightsXMLTo(void * /*parent*/) const {} // = 0;
      virtual void ReadWeightsFromXML(void * /*wghtnode*/) {}  // = 0;
      virtual void ReadWeightsFromStream(std::istream &) {}    // = 0; backward compatibility

   private:

      DataSetManager *fDataSetManager;

      friend class Factory;
      friend class Reader;

   protected:

      std::vector<Double_t> mvaValues;
      std::vector<Float_t> classValues;

      UInt_t fNvars;               // number of variables
      UInt_t fNoutputs;            // number of outputs
      TString fFilenameClassifier; // path to serialized classifier (default in `weights` folder)

      // AdaBoost options (an illustrative booking sketch is given at the end of this file)

      PyObject* pBaseEstimator;
      TString fBaseEstimator; // object, optional (default=DecisionTreeClassifier)
      // The base estimator from which the boosted ensemble is built.
      // Support for sample weighting is required, as well as proper
      // `classes_` and `n_classes_` attributes.

      PyObject* pNestimators;
      Int_t fNestimators; // integer, optional (default=50)
      // The maximum number of estimators at which boosting is terminated.
      // In case of perfect fit, the learning procedure is stopped early.

      PyObject* pLearningRate;
      Double_t fLearningRate; // float, optional (default=1.)
      // Learning rate shrinks the contribution of each classifier by
      // ``learning_rate``. There is a trade-off between ``learning_rate``
      // and ``n_estimators``.

      PyObject* pAlgorithm;
      TString fAlgorithm; // {'SAMME', 'SAMME.R'}, optional (default='SAMME.R')
      // If 'SAMME.R' then use the SAMME.R real boosting algorithm.
      // ``base_estimator`` must support calculation of class probabilities.
      // If 'SAMME' then use the SAMME discrete boosting algorithm.
      // The SAMME.R algorithm typically converges faster than SAMME,
      // achieving a lower test error with fewer boosting iterations.
      PyObject* pRandomState;
      TString fRandomState; // int, RandomState instance or None, optional (default=None)
      // If int, random_state is the seed used by the random number generator;
      // If RandomState instance, random_state is the random number generator;
      // If None, the random number generator is the RandomState instance used by `np.random`.

      // get help message text
      void GetHelpMessage() const;

      ClassDef(MethodPyAdaBoost, 0)
   };

} // namespace TMVA

#endif // ROOT_TMVA_MethodPyAdaBoost
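
// Illustrative booking sketch for this method. It assumes a ROOT build in which
// TMVA::Factory::BookMethod accepts a DataLoader; the option-string keys used here
// (NEstimators, LearningRate, Algorithm) are assumptions modelled on the data
// members above. The authoritative names are the ones registered in
// MethodPyAdaBoost::DeclareOptions() (MethodPyAdaBoost.cxx).
//
//   #include "TMVA/DataLoader.h"
//   #include "TMVA/Factory.h"
//   #include "TMVA/PyMethodBase.h"
//   #include "TMVA/Types.h"
//
//   void BookPyAdaBoost(TMVA::Factory *factory, TMVA::DataLoader *dataloader)
//   {
//      // Start the embedded Python interpreter before booking any PyMVA method.
//      TMVA::PyMethodBase::PyInitialize();
//
//      // Book the scikit-learn AdaBoostClassifier wrapper; the option string is
//      // parsed by ProcessOptions() and mapped onto the Python-side classifier.
//      factory->BookMethod(dataloader, TMVA::Types::kPyAdaBoost, "PyAdaBoost",
//                          "!V:NEstimators=100:LearningRate=0.5:Algorithm=SAMME.R");
//   }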