Logo ROOT  
Reference Guide
 
Loading...
Searching...
No Matches
TMVAMulticlass.C
Go to the documentation of this file.
/// \file
/// \ingroup tutorial_ml
/// \notebook -nodraw
/// This macro provides a simple example for the training and testing of the TMVA
/// multiclass classification
/// - Project   : TMVA - a ROOT-integrated toolkit for multivariate data analysis
/// - Package   : TMVA
/// - Root Macro: TMVAMulticlass
///
/// \macro_output
/// \macro_code
/// \author Andreas Hoecker
13
#include <cstdlib>
#include <iostream>
#include <map>
#include <string>
#include <vector>

#include "TFile.h"
#include "TTree.h"
#include "TString.h"
#include "TSystem.h"
#include "TROOT.h"

#include "TMVA/Tools.h"
#include "TMVA/Factory.h"
#include "TMVA/DataLoader.h"
#include "TMVA/TMVAMultiClassGui.h"

using namespace TMVA;
33
34void TMVAMulticlass( TString myMethodList = "" )
35{
36
37 // This loads the library
39
40 // to get access to the GUI and all tmva macros
41 //
42 // TString tmva_dir(TString(gRootDir) + "/tmva");
43 // if(gSystem->Getenv("TMVASYS"))
44 // tmva_dir = TString(gSystem->Getenv("TMVASYS"));
45 // gROOT->SetMacroPath(tmva_dir + "/test/:" + gROOT->GetMacroPath() );
46 // gROOT->ProcessLine(".L TMVAMultiClassGui.C");
47
48
49 //---------------------------------------------------------------
50 // Default MVA methods to be trained + tested
51 std::map<std::string,int> Use;
52 Use["MLP"] = 1;
53 Use["BDTG"] = 1;
54#ifdef R__HAS_TMVAGPU
55 Use["DL_CPU"] = 1;
56 Use["DL_GPU"] = 1;
57#else
58 Use["DL_CPU"] = 1;
59 Use["DL_GPU"] = 0;
60#endif
61 Use["FDA_GA"] = 0;
62 Use["PDEFoam"] = 1;
63
64 //---------------------------------------------------------------
65
66 std::cout << std::endl;
67 std::cout << "==> Start TMVAMulticlass" << std::endl;
68
69 if (myMethodList != "") {
70 for (std::map<std::string,int>::iterator it = Use.begin(); it != Use.end(); it++) it->second = 0;
71
72 std::vector<TString> mlist = TMVA::gTools().SplitString( myMethodList, ',' );
73 for (UInt_t i=0; i<mlist.size(); i++) {
74 std::string regMethod(mlist[i]);
75
76 if (Use.find(regMethod) == Use.end()) {
77 std::cout << "Method \"" << regMethod << "\" not known in TMVA under this name. Choose among the following:" << std::endl;
78 for (std::map<std::string,int>::iterator it = Use.begin(); it != Use.end(); it++) std::cout << it->first << " ";
79 std::cout << std::endl;
80 return;
81 }
82 Use[regMethod] = 1;
83 }
84 }
85
86 // Create a new root output file.
87 TString outfileName = "TMVAMulticlass.root";
88 TFile* outputFile = TFile::Open( outfileName, "RECREATE" );
89
90 TMVA::Factory *factory = new TMVA::Factory( "TMVAMulticlass", outputFile,
91 "!V:!Silent:Color:DrawProgressBar:Transformations=I;D;P;G,D:AnalysisType=multiclass" );
92 TMVA::DataLoader *dataloader=new TMVA::DataLoader("dataset");
93
94 dataloader->AddVariable( "var1", 'F' );
95 dataloader->AddVariable( "var2", "Variable 2", "", 'F' );
96 dataloader->AddVariable( "var3", "Variable 3", "units", 'F' );
97 dataloader->AddVariable( "var4", "Variable 4", "units", 'F' );
98
99 TFile *input(0);
100 TString fname = "./tmva_example_multiclass.root";
101 if (!gSystem->AccessPathName( fname )) {
102 input = TFile::Open( fname ); // check if file in local directory exists
103 }
104 else {
106 input = TFile::Open("http://root.cern/files/tmva_multiclass_example.root", "CACHEREAD");
107 }
108 if (!input) {
109 std::cout << "ERROR: could not open data file" << std::endl;
110 exit(1);
111 }
112 std::cout << "--- TMVAMulticlass: Using input file: " << input->GetName() << std::endl;
113
114 TTree *signalTree = (TTree*)input->Get("TreeS");
115 TTree *background0 = (TTree*)input->Get("TreeB0");
116 TTree *background1 = (TTree*)input->Get("TreeB1");
117 TTree *background2 = (TTree*)input->Get("TreeB2");
118
119 gROOT->cd( outfileName+TString(":/") );
120 dataloader->AddTree (signalTree,"Signal");
121 dataloader->AddTree (background0,"bg0");
122 dataloader->AddTree (background1,"bg1");
123 dataloader->AddTree (background2,"bg2");
124
125 dataloader->PrepareTrainingAndTestTree( "", "SplitMode=Random:NormMode=NumEvents:!V" );
126
127 if (Use["BDTG"]) // gradient boosted decision trees
128 factory->BookMethod( dataloader, TMVA::Types::kBDT, "BDTG", "!H:!V:NTrees=1000:BoostType=Grad:Shrinkage=0.10:UseBaggedBoost:BaggedSampleFraction=0.50:nCuts=20:MaxDepth=2");
129 if (Use["MLP"]) // neural network
130 factory->BookMethod( dataloader, TMVA::Types::kMLP, "MLP", "!H:!V:NeuronType=tanh:NCycles=1000:HiddenLayers=N+5,5:TestRate=5:EstimatorType=MSE");
131 if (Use["FDA_GA"]) // functional discriminant with GA minimizer
132 factory->BookMethod( dataloader, TMVA::Types::kFDA, "FDA_GA", "H:!V:Formula=(0)+(1)*x0+(2)*x1+(3)*x2+(4)*x3:ParRanges=(-1,1);(-10,10);(-10,10);(-10,10);(-10,10):FitMethod=GA:PopSize=300:Cycles=3:Steps=20:Trim=True:SaveBestGen=1" );
133 if (Use["PDEFoam"]) // PDE-Foam approach
134 factory->BookMethod( dataloader, TMVA::Types::kPDEFoam, "PDEFoam", "!H:!V:TailCut=0.001:VolFrac=0.0666:nActiveCells=500:nSampl=2000:nBin=5:Nmin=100:Kernel=None:Compress=T" );
135
136
137 if (Use["DL_CPU"]) {
138 TString layoutString("Layout=TANH|100,TANH|50,TANH|10,LINEAR");
139 TString trainingStrategyString("TrainingStrategy=Optimizer=ADAM,LearningRate=1e-3,"
140 "TestRepetitions=1,ConvergenceSteps=10,BatchSize=100,MaxEpochs=20");
141 TString nnOptions("!H:V:ErrorStrategy=CROSSENTROPY:VarTransform=N:"
142 "WeightInitialization=XAVIERUNIFORM:Architecture=GPU");
143 nnOptions.Append(":");
144 nnOptions.Append(layoutString);
145 nnOptions.Append(":");
146 nnOptions.Append(trainingStrategyString);
147 factory->BookMethod(dataloader, TMVA::Types::kDL, "DL_CPU", nnOptions);
148 }
149 if (Use["DL_GPU"]) {
150 TString layoutString("Layout=TANH|100,TANH|50,TANH|10,LINEAR");
151 TString trainingStrategyString("TrainingStrategy=Optimizer=ADAM,LearningRate=1e-3,"
152 "TestRepetitions=1,ConvergenceSteps=10,BatchSize=100,MaxEpochs=20");
153 TString nnOptions("!H:V:ErrorStrategy=CROSSENTROPY:VarTransform=N:"
154 "WeightInitialization=XAVIERUNIFORM:Architecture=GPU");
155 nnOptions.Append(":");
156 nnOptions.Append(layoutString);
157 nnOptions.Append(":");
158 nnOptions.Append(trainingStrategyString);
159 factory->BookMethod(dataloader, TMVA::Types::kDL, "DL_GPU", nnOptions);
160 }
161
162
163 // Train MVAs using the set of training events
164 factory->TrainAllMethods();
165
166 // Evaluate all MVAs using the set of test events
167 factory->TestAllMethods();
168
169 // Evaluate and compare performance of all configured MVAs
170 factory->EvaluateAllMethods();
171
172 // --------------------------------------------------------------
173
174 // Save the output
175 outputFile->Close();
176
177 std::cout << "==> Wrote root file: " << outputFile->GetName() << std::endl;
178 std::cout << "==> TMVAMulticlass is done!" << std::endl;
179
180 delete factory;
181 delete dataloader;
182
183 // Launch the GUI for the root macros
184 if (!gROOT->IsBatch()) TMVAMultiClassGui( outfileName );
185
186
187}
188
189int main( int argc, char** argv )
190{
191 // Select methods (don't look at this code - not of interest)
192 TString methodList;
193 for (int i=1; i<argc; i++) {
194 TString regMethod(argv[i]);
195 if(regMethod=="-b" || regMethod=="--batch") continue;
196 if (!methodList.IsNull()) methodList += TString(",");
197 methodList += regMethod;
198 }
199 TMVAMulticlass(methodList);
200 return 0;
201}
202
int main()
Definition Prototype.cxx:12
unsigned int UInt_t
Definition RtypesCore.h:46
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void input
#define gROOT
Definition TROOT.h:406
R__EXTERN TSystem * gSystem
Definition TSystem.h:561
A ROOT file is an on-disk file, usually with extension .root, that stores objects in a file-system-li...
Definition TFile.h:53
static TFile * Open(const char *name, Option_t *option="", const char *ftitle="", Int_t compress=ROOT::RCompressionSetting::EDefaults::kUseCompiledDefault, Int_t netopt=0)
Create / open a file.
Definition TFile.cxx:4086
static Bool_t SetCacheFileDir(std::string_view cacheDir, Bool_t operateDisconnected=kTRUE, Bool_t forceCacheread=kFALSE)
Sets the directory where to locally stage/cache remote files.
Definition TFile.cxx:4623
void Close(Option_t *option="") override
Close a file.
Definition TFile.cxx:947
void PrepareTrainingAndTestTree(const TCut &cut, const TString &splitOpt)
prepare the training and test trees -> same cuts for signal and background
void AddTree(TTree *tree, const TString &className, Double_t weight=1.0, const TCut &cut="", Types::ETreeType tt=Types::kMaxTreeType)
void AddVariable(const TString &expression, const TString &title, const TString &unit, char type='F', Double_t min=0, Double_t max=0)
user inserts discriminating variable in data set info
This is the main MVA steering class.
Definition Factory.h:80
void TrainAllMethods()
Iterates through all booked methods and calls training.
Definition Factory.cxx:1114
MethodBase * BookMethod(DataLoader *loader, TString theMethodName, TString methodTitle, TString theOption="")
Book a classifier or regression method.
Definition Factory.cxx:352
void TestAllMethods()
Evaluates all booked methods on the testing data and adds the output to the Results in the corresponi...
Definition Factory.cxx:1271
void EvaluateAllMethods(void)
Iterates over all MVAs that have been booked, and calls their evaluation methods.
Definition Factory.cxx:1376
static Tools & Instance()
Definition Tools.cxx:71
std::vector< TString > SplitString(const TString &theOpt, const char separator) const
splits the option string at 'separator' and fills the list 'splitV' with the primitive strings
Definition Tools.cxx:1199
@ kPDEFoam
Definition Types.h:94
const char * GetName() const override
Returns name of object.
Definition TNamed.h:47
Basic string class.
Definition TString.h:139
Bool_t IsNull() const
Definition TString.h:414
virtual Bool_t AccessPathName(const char *path, EAccessMode mode=kFileExists)
Returns FALSE if one can access a file using the specified access mode.
Definition TSystem.cxx:1296
A TTree represents a columnar dataset.
Definition TTree.h:79
create variable transformations
Tools & gTools()
void TMVAMultiClassGui(const char *fName="TMVAMulticlass.root", TString dataset="")