// @(#)root/tmva $Id$ // Author: Omar Zapata, Thomas James Stevenson, Pourya Vakilipourtakalou, Kim Albertsson /************************************************************************* * Copyright (C) 2018, Rene Brun and Fons Rademakers. * * All rights reserved. * * * * For the licensing terms see $ROOTSYS/LICENSE. * * For the list of contributors see $ROOTSYS/README/CREDITS. * *************************************************************************/ #ifndef ROOT_TMVA_CROSS_EVALUATION #define ROOT_TMVA_CROSS_EVALUATION #include "TGraph.h" #include "TMultiGraph.h" #include "TString.h" #include #include #include "TMVA/IMethod.h" #include "TMVA/Configurable.h" #include "TMVA/Types.h" #include "TMVA/DataSet.h" #include "TMVA/Event.h" #include #include #include #include #include /*! \class TMVA::CrossValidationResult * Class to save the results of cross validation, * the metric for the classification ins ROC and you can ROC curves * ROC integrals, ROC average and ROC standard deviation. \ingroup TMVA */ /*! \class TMVA::CrossValidation * Class to perform cross validation, splitting the dataloader into folds. \ingroup TMVA */ namespace TMVA { class CvSplitKFolds; using EventCollection_t = std::vector; using EventTypes_t = std::vector; using EventOutputs_t = std::vector; using EventOutputsMulticlass_t = std::vector>; class CrossValidationFoldResult { public: CrossValidationFoldResult() {} // For multi-proc serialisation CrossValidationFoldResult(UInt_t iFold) : fFold(iFold) {} UInt_t fFold; Float_t fROCIntegral; TGraph fROC; Double_t fSig; Double_t fSep; Double_t fEff01; Double_t fEff10; Double_t fEff30; Double_t fEffArea; Double_t fTrainEff01; Double_t fTrainEff10; Double_t fTrainEff30; }; // Used internally to keep per-fold aggregate statistics // such as ROC curves, ROC integrals and efficiencies. class CrossValidationResult { friend class CrossValidation; private: std::map fROCs; std::shared_ptr fROCCurves; std::vector fSigs; std::vector fSeps; std::vector fEff01s; std::vector fEff10s; std::vector fEff30s; std::vector fEffAreas; std::vector fTrainEff01s; std::vector fTrainEff10s; std::vector fTrainEff30s; public: CrossValidationResult(UInt_t numFolds); CrossValidationResult(const CrossValidationResult &); ~CrossValidationResult() { fROCCurves = nullptr; } std::map GetROCValues() const { return fROCs; } Float_t GetROCAverage() const; Float_t GetROCStandardDeviation() const; TMultiGraph *GetROCCurves(Bool_t fLegend = kTRUE); TGraph *GetAvgROCCurve(UInt_t numSamples = 100) const; void Print() const; TCanvas *Draw(const TString name = "CrossValidation") const; TCanvas *DrawAvgROCCurve(Bool_t drawFolds=kFALSE, TString title="") const; std::vector GetSigValues() const { return fSigs; } std::vector GetSepValues() const { return fSeps; } std::vector GetEff01Values() const { return fEff01s; } std::vector GetEff10Values() const { return fEff10s; } std::vector GetEff30Values() const { return fEff30s; } std::vector GetEffAreaValues() const { return fEffAreas; } std::vector GetTrainEff01Values() const { return fTrainEff01s; } std::vector GetTrainEff10Values() const { return fTrainEff10s; } std::vector GetTrainEff30Values() const { return fTrainEff30s; } private: void Fill(CrossValidationFoldResult const & fr); }; class CrossValidation : public Envelope { public: explicit CrossValidation(TString jobName, TMVA::DataLoader *dataloader, TString options); explicit CrossValidation(TString jobName, TMVA::DataLoader *dataloader, TFile *outputFile, TString options); ~CrossValidation(); void InitOptions(); void ParseOptions(); void SetNumFolds(UInt_t i); void SetSplitExpr(TString splitExpr); UInt_t GetNumFolds() { return fNumFolds; } TString GetSplitExpr() { return fSplitExprString; } Factory &GetFactory() { return *fFactory; } const std::vector &GetResults() const; void Evaluate(); private: CrossValidationFoldResult ProcessFold(UInt_t iFold, const OptionMap & methodInfo); Types::EAnalysisType fAnalysisType; TString fAnalysisTypeStr; TString fSplitTypeStr; Bool_t fCorrelations; TString fCvFactoryOptions; Bool_t fDrawProgressBar; Bool_t fFoldFileOutput; /// fResults; /// fFoldFactory; std::unique_ptr fFactory; std::unique_ptr fSplit; ClassDef(CrossValidation, 0); }; } // namespace TMVA #endif // ROOT_TMVA_CROSS_EVALUATION