ROOT 6.10/09
Reference Guide
TMVAMulticlass.C
Go to the documentation of this file.
1 /// \file
2 /// \ingroup tutorial_tmva
3 /// \notebook -nodraw
4 /// This macro provides a simple example for the training and testing of the TMVA
5 /// multiclass classification
6 /// - Project : TMVA - a Root-integrated toolkit for multivariate data analysis
7 /// - Package : TMVA
8 /// - Root Macro: TMVAMulticlass
9 ///
10 /// \macro_output
11 /// \macro_code
12 /// \author Andreas Hoecker
13 
14 #include <cstdlib>
15 #include <iostream>
16 #include <map>
17 #include <string>
18 
19 #include "TFile.h"
20 #include "TTree.h"
21 #include "TString.h"
22 #include "TSystem.h"
23 #include "TROOT.h"
24 
25 
26 #include "TMVA/Tools.h"
27 #include "TMVA/Factory.h"
28 #include "TMVA/DataLoader.h"
29 #include "TMVA/TMVAMultiClassGui.h"
30 
31 
32 using namespace TMVA;
33 
34 void TMVAMulticlass( TString myMethodList = "" )
35 {
36 
37  // This loads the library
39 
40  // to get access to the GUI and all tmva macros
41  //
42  // TString tmva_dir(TString(gRootDir) + "/tmva");
43  // if(gSystem->Getenv("TMVASYS"))
44  // tmva_dir = TString(gSystem->Getenv("TMVASYS"));
45  // gROOT->SetMacroPath(tmva_dir + "/test/:" + gROOT->GetMacroPath() );
46  // gROOT->ProcessLine(".L TMVAMultiClassGui.C");
47 
48 
49  //---------------------------------------------------------------
50  // Default MVA methods to be trained + tested
51  std::map<std::string,int> Use;
52  Use["MLP"] = 1;
53  Use["BDTG"] = 1;
54  Use["DNN_CPU"] = 0;
55  Use["FDA_GA"] = 0;
56  Use["PDEFoam"] = 0;
57  //---------------------------------------------------------------
58 
59  std::cout << std::endl;
60  std::cout << "==> Start TMVAMulticlass" << std::endl;
61 
62  if (myMethodList != "") {
63  for (std::map<std::string,int>::iterator it = Use.begin(); it != Use.end(); it++) it->second = 0;
64 
65  std::vector<TString> mlist = TMVA::gTools().SplitString( myMethodList, ',' );
66  for (UInt_t i=0; i<mlist.size(); i++) {
67  std::string regMethod(mlist[i]);
68 
69  if (Use.find(regMethod) == Use.end()) {
70  std::cout << "Method \"" << regMethod << "\" not known in TMVA under this name. Choose among the following:" << std::endl;
71  for (std::map<std::string,int>::iterator it = Use.begin(); it != Use.end(); it++) std::cout << it->first << " ";
72  std::cout << std::endl;
73  return;
74  }
75  Use[regMethod] = 1;
76  }
77  }
78 
79  // Create a new root output file.
80  TString outfileName = "TMVAMulticlass.root";
81  TFile* outputFile = TFile::Open( outfileName, "RECREATE" );
82 
83  TMVA::Factory *factory = new TMVA::Factory( "TMVAMulticlass", outputFile,
84  "!V:!Silent:Color:DrawProgressBar:Transformations=I;D;P;G,D:AnalysisType=multiclass" );
86 
87  dataloader->AddVariable( "var1", 'F' );
88  dataloader->AddVariable( "var2", "Variable 2", "", 'F' );
89  dataloader->AddVariable( "var3", "Variable 3", "units", 'F' );
90  dataloader->AddVariable( "var4", "Variable 4", "units", 'F' );
91 
92  TFile *input(0);
93  TString fname = "./tmva_example_multiple_background.root";
94  if (!gSystem->AccessPathName( fname )) {
95  // first we try to find the file in the local directory
96  std::cout << "--- TMVAMulticlass : Accessing " << fname << std::endl;
97  input = TFile::Open( fname );
98  }
99  else {
100  std::cout << "Creating testdata...." << std::endl;
101  TString createDataMacro = gROOT->GetTutorialDir() + "/tmva/createData.C";
102  gROOT->ProcessLine(TString::Format(".L %s",createDataMacro.Data()));
103  gROOT->ProcessLine("create_MultipleBackground(2000)");
104  std::cout << " created tmva_example_multiple_background.root for tests of the multiclass features"<<std::endl;
105  input = TFile::Open( fname );
106  }
107  if (!input) {
108  std::cout << "ERROR: could not open data file" << std::endl;
109  exit(1);
110  }
111 
112  TTree *signalTree = (TTree*)input->Get("TreeS");
113  TTree *background0 = (TTree*)input->Get("TreeB0");
114  TTree *background1 = (TTree*)input->Get("TreeB1");
115  TTree *background2 = (TTree*)input->Get("TreeB2");
116 
117  gROOT->cd( outfileName+TString(":/") );
118  dataloader->AddTree (signalTree,"Signal");
119  dataloader->AddTree (background0,"bg0");
120  dataloader->AddTree (background1,"bg1");
121  dataloader->AddTree (background2,"bg2");
122 
123  dataloader->PrepareTrainingAndTestTree( "", "SplitMode=Random:NormMode=NumEvents:!V" );
124 
125  if (Use["BDTG"]) // gradient boosted decision trees
126  factory->BookMethod( dataloader, TMVA::Types::kBDT, "BDTG", "!H:!V:NTrees=1000:BoostType=Grad:Shrinkage=0.10:UseBaggedBoost:BaggedSampleFraction=0.50:nCuts=20:MaxDepth=2");
127  if (Use["MLP"]) // neural network
128  factory->BookMethod( dataloader, TMVA::Types::kMLP, "MLP", "!H:!V:NeuronType=tanh:NCycles=1000:HiddenLayers=N+5,5:TestRate=5:EstimatorType=MSE");
129  if (Use["FDA_GA"]) // functional discriminant with GA minimizer
130  factory->BookMethod( dataloader, TMVA::Types::kFDA, "FDA_GA", "H:!V:Formula=(0)+(1)*x0+(2)*x1+(3)*x2+(4)*x3:ParRanges=(-1,1);(-10,10);(-10,10);(-10,10);(-10,10):FitMethod=GA:PopSize=300:Cycles=3:Steps=20:Trim=True:SaveBestGen=1" );
131  if (Use["PDEFoam"]) // PDE-Foam approach
132  factory->BookMethod( dataloader, TMVA::Types::kPDEFoam, "PDEFoam", "!H:!V:TailCut=0.001:VolFrac=0.0666:nActiveCells=500:nSampl=2000:nBin=5:Nmin=100:Kernel=None:Compress=T" );
133 
134  if (Use["DNN_CPU"]) {
135  TString layoutString("Layout=TANH|100,TANH|50,TANH|10,LINEAR");
136  TString training0("LearningRate=1e-1, Momentum=0.5, Repetitions=1, ConvergenceSteps=10,"
137  " BatchSize=256, TestRepetitions=10, Multithreading=True");
138  TString training1("LearningRate=1e-2, Momentum=0.0, Repetitions=1, ConvergenceSteps=10,"
139  " BatchSize=256, TestRepetitions=7, Multithreading=True");
140  TString trainingStrategyString("TrainingStrategy=");
141  trainingStrategyString += training0 + "|" + training1;
142  TString nnOptions("!H:V:ErrorStrategy=CROSSENTROPY:VarTransform=N:"
143  "WeightInitialization=XAVIERUNIFORM:Architecture=CPU");
144  nnOptions.Append(":");
145  nnOptions.Append(layoutString);
146  nnOptions.Append(":");
147  nnOptions.Append(trainingStrategyString);
148  factory->BookMethod(dataloader, TMVA::Types::kDNN, "DNN_CPU", nnOptions);
149  }
150 
151  // Train MVAs using the set of training events
152  factory->TrainAllMethods();
153 
154  // Evaluate all MVAs using the set of test events
155  factory->TestAllMethods();
156 
157  // Evaluate and compare performance of all configured MVAs
158  factory->EvaluateAllMethods();
159 
160  // --------------------------------------------------------------
161 
162  // Save the output
163  outputFile->Close();
164 
165  std::cout << "==> Wrote root file: " << outputFile->GetName() << std::endl;
166  std::cout << "==> TMVAMulticlass is done!" << std::endl;
167 
168  delete factory;
169  delete dataloader;
170 
171  // Launch the GUI for the root macros
172  if (!gROOT->IsBatch()) TMVAMultiClassGui( outfileName );
173 
174 
175 }
176 
177 int main( int argc, char** argv )
178 {
179  // Select methods (don't look at this code - not of interest)
180  TString methodList;
181  for (int i=1; i<argc; i++) {
182  TString regMethod(argv[i]);
183  if(regMethod=="-b" || regMethod=="--batch") continue;
184  if (!methodList.IsNull()) methodList += TString(",");
185  methodList += regMethod;
186  }
187  TMVAMulticlass(methodList);
188  return 0;
189 }
190 
virtual Bool_t AccessPathName(const char *path, EAccessMode mode=kFileExists)
Returns FALSE if one can access a file using the specified access mode.
Definition: TSystem.cxx:1272
static Tools & Instance()
Definition: Tools.cxx:75
MethodBase * BookMethod(DataLoader *loader, TString theMethodName, TString methodTitle, TString theOption="")
Book a classifier or regression method.
Definition: Factory.cxx:343
#define gROOT
Definition: TROOT.h:375
void TrainAllMethods()
Iterates through all booked methods and calls training.
Definition: Factory.cxx:1017
void AddVariable(const TString &expression, const TString &title, const TString &unit, char type='F', Double_t min=0, Double_t max=0)
user inserts discriminating variable in data set info
Definition: DataLoader.cxx:491
static TFile * Open(const char *name, Option_t *option="", const char *ftitle="", Int_t compress=1, Int_t netopt=0)
Create / open a file.
Definition: TFile.cxx:3909
void TMVAMultiClassGui(const char *fName="TMVAMulticlass.root", TString dataset="")
static TString Format(const char *fmt,...)
Static method which formats a string using a printf style format descriptor and returns a TString...
Definition: TString.cxx:2345
R__EXTERN TSystem * gSystem
Definition: TSystem.h:539
void EvaluateAllMethods(void)
Iterates over all MVAs that have been booked, and calls their evaluation methods. ...
Definition: Factory.cxx:1255
void TestAllMethods()
Definition: Factory.cxx:1153
unsigned int UInt_t
Definition: RtypesCore.h:42
This is the main MVA steering class.
Definition: Factory.h:81
Tools & gTools()
void AddTree(TTree *tree, const TString &className, Double_t weight=1.0, const TCut &cut="", Types::ETreeType tt=Types::kMaxTreeType)
Definition: DataLoader.cxx:357
void PrepareTrainingAndTestTree(const TCut &cut, const TString &splitOpt)
prepare the training and test trees -> same cuts for signal and background
Definition: DataLoader.cxx:629
Abstract ClassifierFactory template that handles arbitrary types.
std::vector< TString > SplitString(const TString &theOpt, const char separator) const
splits the option string at 'separator' and fills the list 'splitV' with the primitive strings ...
Definition: Tools.cxx:1210
int main(int argc, char **argv)