/// \file
/// \ingroup tutorial_tmva
/// \notebook -nodraw
/// This macro provides a simple example for the training and testing of the TMVA
/// multiclass classification
/// - Project   : TMVA - a Root-integrated toolkit for multivariate data analysis
/// - Package   : TMVA
/// - Root Macro: TMVAMulticlass
///
/// \macro_output
/// \macro_code
/// \author Andreas Hoecker

#include <cstdlib>
#include <iostream>
#include <map>
#include <string>

#include "TFile.h"
#include "TTree.h"
#include "TString.h"
#include "TSystem.h"
#include "TROOT.h"

#include "TMVA/Tools.h"
#include "TMVA/Factory.h"
#include "TMVA/DataLoader.h"
#include "TMVA/TMVAMultiClassGui.h"

using namespace TMVA;

void TMVAMulticlass( TString myMethodList = "" )
{
   // This loads the library
   TMVA::Tools::Instance();

   // to get access to the GUI and all tmva macros
   //
   //     TString tmva_dir(TString(gRootDir) + "/tmva");
   //     if(gSystem->Getenv("TMVASYS"))
   //        tmva_dir = TString(gSystem->Getenv("TMVASYS"));
   //     gROOT->SetMacroPath(tmva_dir + "/test/:" + gROOT->GetMacroPath() );
   //     gROOT->ProcessLine(".L TMVAMultiClassGui.C");

   //---------------------------------------------------------------
   // Default MVA methods to be trained + tested
   std::map<std::string,int> Use;
   Use["MLP"]     = 1;
   Use["BDTG"]    = 1;
#ifdef R__HAS_TMVAGPU
   Use["DL_CPU"]  = 1;
   Use["DL_GPU"]  = 1;
#else
   Use["DL_CPU"]  = 1;
   Use["DL_GPU"]  = 0;
#endif
   Use["FDA_GA"]  = 0;
   Use["PDEFoam"] = 1;
   //---------------------------------------------------------------

   std::cout << std::endl;
   std::cout << "==> Start TMVAMulticlass" << std::endl;

   // If a method list was passed, disable the defaults and enable only the
   // requested methods.
   if (myMethodList != "") {
      for (std::map<std::string,int>::iterator it = Use.begin(); it != Use.end(); it++) it->second = 0;

      std::vector<TString> mlist = TMVA::gTools().SplitString( myMethodList, ',' );
      for (UInt_t i = 0; i < mlist.size(); i++) {
         std::string regMethod(mlist[i]);

         if (Use.find(regMethod) == Use.end()) {
            std::cout << "Method \"" << regMethod << "\" not known in TMVA under this name. Choose among the following:" << std::endl;
            for (std::map<std::string,int>::iterator it = Use.begin(); it != Use.end(); it++) std::cout << it->first << " ";
            std::cout << std::endl;
            return;
         }
         Use[regMethod] = 1;
      }
   }

   // Create a new root output file.
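   // The Factory writes its training/test trees and evaluation histograms
   // into this file, and TMVAMultiClassGui() reads it back at the end of the
   // job; the method weight files go to the dataset/weights/ directory.
   // In the Factory options below, "Transformations=I;D;P;G,D" only requests
   // preview plots for the identity, decorrelation, PCA and
   // gaussianisation+decorrelation input transformations, while
   // "AnalysisType=multiclass" selects multiclass (rather than binary)
   // classification.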
   TString outfileName = "TMVAMulticlass.root";
   TFile* outputFile = TFile::Open( outfileName, "RECREATE" );

   TMVA::Factory *factory = new TMVA::Factory( "TMVAMulticlass", outputFile,
                                               "!V:!Silent:Color:DrawProgressBar:Transformations=I;D;P;G,D:AnalysisType=multiclass" );
   TMVA::DataLoader *dataloader = new TMVA::DataLoader("dataset");

   dataloader->AddVariable( "var1", 'F' );
   dataloader->AddVariable( "var2", "Variable 2", "", 'F' );
   dataloader->AddVariable( "var3", "Variable 3", "units", 'F' );
   dataloader->AddVariable( "var4", "Variable 4", "units", 'F' );

   TFile *input(0);
   TString fname = "./tmva_example_multiclass.root";
   if (!gSystem->AccessPathName( fname )) {
      input = TFile::Open( fname ); // check if file in local directory exists
   }
   else {
      TFile::SetCacheFileDir(".");
      input = TFile::Open("http://root.cern.ch/files/tmva_multiclass_example.root", "CACHEREAD");
   }
   if (!input) {
      std::cout << "ERROR: could not open data file" << std::endl;
      exit(1);
   }
   std::cout << "--- TMVAMulticlass: Using input file: " << input->GetName() << std::endl;

   // One tree per class: one "signal" and three "background" categories.
   TTree *signalTree  = (TTree*)input->Get("TreeS");
   TTree *background0 = (TTree*)input->Get("TreeB0");
   TTree *background1 = (TTree*)input->Get("TreeB1");
   TTree *background2 = (TTree*)input->Get("TreeB2");

   gROOT->cd( outfileName+TString(":/") );
   dataloader->AddTree( signalTree,  "Signal" );
   dataloader->AddTree( background0, "bg0" );
   dataloader->AddTree( background1, "bg1" );
   dataloader->AddTree( background2, "bg2" );

   dataloader->PrepareTrainingAndTestTree( "", "SplitMode=Random:NormMode=NumEvents:!V" );

   if (Use["BDTG"]) // gradient boosted decision trees
      factory->BookMethod( dataloader, TMVA::Types::kBDT, "BDTG",
                           "!H:!V:NTrees=1000:BoostType=Grad:Shrinkage=0.10:UseBaggedBoost:BaggedSampleFraction=0.50:nCuts=20:MaxDepth=2" );

   if (Use["MLP"]) // neural network
      factory->BookMethod( dataloader, TMVA::Types::kMLP, "MLP",
                           "!H:!V:NeuronType=tanh:NCycles=1000:HiddenLayers=N+5,5:TestRate=5:EstimatorType=MSE" );

   if (Use["FDA_GA"]) // functional discriminant with GA minimizer
      factory->BookMethod( dataloader, TMVA::Types::kFDA, "FDA_GA",
                           "H:!V:Formula=(0)+(1)*x0+(2)*x1+(3)*x2+(4)*x3:ParRanges=(-1,1);(-10,10);(-10,10);(-10,10);(-10,10):FitMethod=GA:PopSize=300:Cycles=3:Steps=20:Trim=True:SaveBestGen=1" );

   if (Use["PDEFoam"]) // PDE-Foam approach
      factory->BookMethod( dataloader, TMVA::Types::kPDEFoam, "PDEFoam",
                           "!H:!V:TailCut=0.001:VolFrac=0.0666:nActiveCells=500:nSampl=2000:nBin=5:Nmin=100:Kernel=None:Compress=T" );

   if (Use["DL_CPU"]) {
      TString layoutString("Layout=TANH|100,TANH|50,TANH|10,LINEAR");
      TString trainingStrategyString("TrainingStrategy=Optimizer=ADAM,LearningRate=1e-3,"
                                     "TestRepetitions=1,ConvergenceSteps=10,BatchSize=100,MaxEpochs=20");
      TString nnOptions("!H:V:ErrorStrategy=CROSSENTROPY:VarTransform=N:"
                        "WeightInitialization=XAVIERUNIFORM:Architecture=CPU");
      nnOptions.Append(":");
      nnOptions.Append(layoutString);
      nnOptions.Append(":");
      nnOptions.Append(trainingStrategyString);
      factory->BookMethod(dataloader, TMVA::Types::kDL, "DL_CPU", nnOptions);
   }

   if (Use["DL_GPU"]) {
      TString layoutString("Layout=TANH|100,TANH|50,TANH|10,LINEAR");
      TString trainingStrategyString("TrainingStrategy=Optimizer=ADAM,LearningRate=1e-3,"
                                     "TestRepetitions=1,ConvergenceSteps=10,BatchSize=100,MaxEpochs=20");
      TString nnOptions("!H:V:ErrorStrategy=CROSSENTROPY:VarTransform=N:"
                        "WeightInitialization=XAVIERUNIFORM:Architecture=GPU");
      nnOptions.Append(":");
      nnOptions.Append(layoutString);
      nnOptions.Append(":");
      nnOptions.Append(trainingStrategyString);
      factory->BookMethod(dataloader, TMVA::Types::kDL, "DL_GPU", nnOptions);
   }
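   // Note on the deep-learning option strings assembled above: "Layout=..."
   // lists one ACTIVATION|WIDTH entry per hidden layer plus the output-layer
   // activation (LINEAR), "TrainingStrategy=..." sets the optimiser, learning
   // rate, batch size and stopping criteria, and "Architecture=CPU/GPU"
   // chooses the backend. The GPU variant is only enabled by default when
   // ROOT was built with GPU support (see the R__HAS_TMVAGPU block above).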

   // Train MVAs using the set of training events
   factory->TrainAllMethods();

   // Evaluate all MVAs using the set of test events
   factory->TestAllMethods();

   // Evaluate and compare performance of all configured MVAs
   factory->EvaluateAllMethods();

   // --------------------------------------------------------------

   // Save the output
   outputFile->Close();

   std::cout << "==> Wrote root file: " << outputFile->GetName() << std::endl;
   std::cout << "==> TMVAMulticlass is done!" << std::endl;

   delete factory;
   delete dataloader;

   // Launch the GUI for the root macros
   if (!gROOT->IsBatch()) TMVAMultiClassGui( outfileName );
}

int main( int argc, char** argv )
{
   // Select methods (don't look at this code - not of interest)
   TString methodList;
   for (int i = 1; i < argc; i++) {
      TString regMethod(argv[i]);
      if (regMethod == "-b" || regMethod == "--batch") continue;
      if (!methodList.IsNull()) methodList += TString(",");
      methodList += regMethod;
   }
   TMVAMulticlass(methodList);
   return 0;
}
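// Typical invocations (assumed usage):
//    root -l -b -q 'TMVAMulticlass.C("BDTG,MLP")'   (train only BDTG and MLP)
//    root -l TMVAMulticlass.C                       (train the default set)
// When compiled as a standalone executable, the method names are passed as
// command-line arguments instead, e.g.  ./TMVAMulticlass BDTG MLP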