// @(#)root/tmva/pymva $Id$
// Authors: Omar Zapata, Lorenzo Moneta, Sergei Gleyzer 2015

/**********************************************************************************
 * Project: TMVA - a Root-integrated toolkit for multivariate data analysis      *
 * Package: TMVA                                                                  *
 * Class  : MethodPyGTB                                                           *
 * Web    : http://oproject.org                                                   *
 *                                                                                *
 * Description:                                                                   *
 *      scikit-learn Package GradientBoostingClassifier method based on python   *
 *                                                                                *
 **********************************************************************************/

#ifndef ROOT_TMVA_MethodPyGTB
#define ROOT_TMVA_MethodPyGTB

//////////////////////////////////////////////////////////////////////////
//                                                                      //
// MethodPyGTB                                                          //
//                                                                      //
//////////////////////////////////////////////////////////////////////////

#include "TMVA/PyMethodBase.h"

namespace TMVA {

   class Factory;
   class Reader;
   class DataSetManager;
   class Types;

   class MethodPyGTB : public PyMethodBase {

   public :

      MethodPyGTB(const TString &jobName,
                  const TString &methodTitle,
                  DataSetInfo &theData,
                  const TString &theOption = "");

      MethodPyGTB(DataSetInfo &dsi,
                  const TString &theWeightFile);

      ~MethodPyGTB(void);

      void Train();
      void Init();
      void DeclareOptions();
      void ProcessOptions();

      const Ranking *CreateRanking();

      Bool_t HasAnalysisType(Types::EAnalysisType type, UInt_t numberClasses, UInt_t numberTargets);

      virtual void TestClassification();

      Double_t GetMvaValue(Double_t *errLower = 0, Double_t *errUpper = 0);
      std::vector<Double_t> GetMvaValues(Long64_t firstEvt = 0, Long64_t lastEvt = -1, Bool_t logProgress = false);
      std::vector<Float_t>& GetMulticlassValues();

      virtual void ReadModelFromFile();

      using MethodBase::ReadWeightsFromStream;

      // the actual "weights"
      virtual void AddWeightsXMLTo(void * /* parent */ ) const {}   // = 0;
      virtual void ReadWeightsFromXML(void * /*wghtnode*/) {}       // = 0;
      virtual void ReadWeightsFromStream(std::istream &) {}         // = 0; backward compatibility

   private :
      DataSetManager *fDataSetManager;

      friend class Factory;
      friend class Reader;

   protected:
      std::vector<Double_t> mvaValues;
      std::vector<Float_t> classValues;

      UInt_t fNvars;                 // number of variables
      UInt_t fNoutputs;              // number of outputs
      TString fFilenameClassifier;   // Path to serialized classifier (default in `weights` folder)

      // GTB options

      PyObject* pLoss;
      TString fLoss; // {'deviance', 'exponential'}, optional (default='deviance')
      // loss function to be optimized. 'deviance' refers to
      // deviance (= logistic regression) for classification
      // with probabilistic outputs. For loss 'exponential' gradient
      // boosting recovers the AdaBoost algorithm.

      PyObject* pLearningRate;
      Double_t fLearningRate; // float, optional (default=0.1)
      // learning rate shrinks the contribution of each tree by `learning_rate`.
      // There is a trade-off between learning_rate and n_estimators.

      PyObject* pNestimators;
      Int_t fNestimators; // integer, optional (default=100)
      // The number of boosting stages to perform.

      PyObject* pSubsample;
      Double_t fSubsample; // float, optional (default=1.0)
      // The fraction of samples to be used for fitting the individual base
      // learners. If smaller than 1.0 this results in Stochastic Gradient
      // Boosting. `subsample` interacts with the parameter `n_estimators`.
      // Choosing `subsample < 1.0` leads to a reduction of variance
      // and an increase in bias.

      PyObject* pMinSamplesSplit;
      Int_t fMinSamplesSplit; // integer, optional (default=2)
      // The minimum number of samples required to split an internal node.

      PyObject* pMinSamplesLeaf;
      Int_t fMinSamplesLeaf; // integer, optional (default=1)
      // The minimum number of samples required to be at a leaf node.
      PyObject* pMinWeightFractionLeaf;
      Double_t fMinWeightFractionLeaf; // float, optional (default=0.)
      // The minimum weighted fraction of the input samples required to be at a leaf node.

      PyObject* pMaxDepth;
      Int_t fMaxDepth; // integer, optional (default=3)
      // maximum depth of the individual regression estimators. The maximum
      // depth limits the number of nodes in the tree. Tune this parameter
      // for best performance; the best value depends on the interaction
      // of the input variables.
      // Ignored if ``max_leaf_nodes`` is not None.

      PyObject* pInit;
      TString fInit; // BaseEstimator, None, optional (default=None)
      // An estimator object that is used to compute the initial
      // predictions. ``init`` has to provide ``fit`` and ``predict``.
      // If None it uses ``loss.init_estimator``.

      PyObject* pRandomState;
      TString fRandomState; // int, RandomState instance or None, optional (default=None)
      // If int, random_state is the seed used by the random number generator;
      // If RandomState instance, random_state is the random number generator;
      // If None, the random number generator is the RandomState instance used
      // by `np.random`.

      PyObject* pMaxFeatures;
      TString fMaxFeatures; // int, float, string or None, optional (default="auto")
      // The number of features to consider when looking for the best split:
      // - If int, then consider `max_features` features at each split.
      // - If float, then `max_features` is a percentage and
      //   `int(max_features * n_features)` features are considered at each split.
      // - If "auto", then `max_features=sqrt(n_features)`.
      // - If "sqrt", then `max_features=sqrt(n_features)`.
      // - If "log2", then `max_features=log2(n_features)`.
      // - If None, then `max_features=n_features`.
      // Note: the search for a split does not stop until at least one
      // valid partition of the node samples is found, even if it requires to
      // effectively inspect more than ``max_features`` features.
      // Note: this parameter is tree-specific.

      PyObject* pVerbose;
      Int_t fVerbose; // Controls the verbosity of the tree building process.

      PyObject* pMaxLeafNodes;
      TString fMaxLeafNodes; // int or None, optional (default=None)
      // Grow trees with ``max_leaf_nodes`` in best-first fashion.
      // Best nodes are defined as relative reduction in impurity.
      // If None then unlimited number of leaf nodes.
      // If not None then ``max_depth`` will be ignored.

      PyObject* pWarmStart;
      Bool_t fWarmStart; // bool, optional (default=False)
      // When set to ``True``, reuse the solution of the previous call to fit
      // and add more estimators to the ensemble, otherwise, just fit a whole
      // new forest.

      // get help message text
      void GetHelpMessage() const;

      ClassDef(MethodPyGTB, 0)
   };

} // namespace TMVA

#endif // ROOT_TMVA_MethodPyGTB
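// Illustrative usage sketch (kept as a comment so it is not compiled with the
// header): how this method could be booked through the TMVA Factory. The option
// names in the booking string are assumed to mirror the data members above
// (NEstimators, LearningRate, MaxDepth, Loss, ...); MethodPyGTB::DeclareOptions()
// in MethodPyGTB.cxx is the authoritative reference for the exact spelling and
// the TMVA-side defaults.
//
//    #include "TFile.h"
//    #include "TMVA/Factory.h"
//    #include "TMVA/DataLoader.h"
//    #include "TMVA/Types.h"
//    #include "TMVA/PyMethodBase.h"
//
//    void BookPyGTB()
//    {
//       // PyMVA methods require the embedded Python interpreter to be started
//       TMVA::PyMethodBase::PyInitialize();
//
//       auto outputFile = TFile::Open("TMVA_PyGTB.root", "RECREATE");
//       TMVA::Factory factory("TMVAClassification", outputFile,
//                             "!V:!Silent:AnalysisType=Classification");
//       TMVA::DataLoader loader("dataset");
//       // ... add variables and signal/background trees to the loader here ...
//
//       // book scikit-learn's GradientBoostingClassifier through MethodPyGTB
//       factory.BookMethod(&loader, TMVA::Types::kPyGTB, "PyGTB",
//                          "NEstimators=100:LearningRate=0.1:MaxDepth=3:Loss=deviance");
//
//       factory.TrainAllMethods();
//       factory.TestAllMethods();
//       factory.EvaluateAllMethods();
//       outputFile->Close();
//    }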