ROOT 6.10/09
Reference Guide
Classification.C
Go to the documentation of this file.
#include <cstdlib>
#include <iostream>
#include <map>
#include <string>

#include "TChain.h"
#include "TCut.h"
#include "TFile.h"
#include "TObjString.h"
#include "TROOT.h"
#include "TString.h"
#include "TSystem.h"
#include "TTree.h"

#include "TMVA/DataLoader.h"
#include "TMVA/Factory.h"
#include "TMVA/PyMethodBase.h"
#include "TMVA/Tools.h"

20 {
23 
24  TString outfileName("TMVA.root");
25  TFile *outputFile = TFile::Open(outfileName, "RECREATE");
26 
27  TMVA::Factory *factory = new TMVA::Factory("TMVAClassification", outputFile,
28  "!V:!Silent:Color:DrawProgressBar:Transformations=I;D;P;G,D:AnalysisType=Classification");
29 
30  factory->AddVariable("myvar1 := var1+var2", 'F');
31  factory->AddVariable("myvar2 := var1-var2", "Expression 2", "", 'F');
32  factory->AddVariable("var3", "Variable 3", "units", 'F');
33  factory->AddVariable("var4", "Variable 4", "units", 'F');
34  factory->AddSpectator("spec1 := var1*2", "Spectator 1", "units", 'F');
35  factory->AddSpectator("spec2 := var1*3", "Spectator 2", "units", 'F');
36 
37  TFile *input(0);
38  TString fname = "./tmva_class_example.root";
39  if (!gSystem->AccessPathName( fname )) {
40  input = TFile::Open( fname ); // check if file in local directory exists
41  }
42  else {
44  input = TFile::Open("http://root.cern.ch/files/tmva_class_example.root", "CACHEREAD");
45  }
46  if (!input) {
47  std::cout << "ERROR: could not open data file" << std::endl;
48  exit(1);
49  }
50 
51  std::cout << "--- TMVAClassification : Using input file: " << input->GetName() << std::endl;
52 
53  // --- Register the training and test trees
54 
55  TTree *tsignal = (TTree *)input->Get("TreeS");
56  TTree *tbackground = (TTree *)input->Get("TreeB");
57 
58  // global event weights per tree (see below for setting event-wise weights)
59  Double_t signalWeight = 1.0;
60  Double_t backgroundWeight = 1.0;
61 
62  // You can add an arbitrary number of signal or background trees
63  factory->AddSignalTree(tsignal, signalWeight);
64  factory->AddBackgroundTree(tbackground, backgroundWeight);
65 
66 
67  // Set individual event weights (the variables must exist in the original TTree)
68  factory->SetBackgroundWeightExpression("weight");
69 
70 
71  // Apply additional cuts on the signal and background samples (can be different)
72  TCut mycuts = ""; // for example: TCut mycuts = "abs(var1)<0.5 && abs(var2-0.5)<1";
73  TCut mycutb = ""; // for example: TCut mycutb = "abs(var1)<0.5";
74 
75  // Tell the factory how to use the training and testing events
76  factory->PrepareTrainingAndTestTree(mycuts, mycutb,
77  "nTrain_Signal=0:nTrain_Background=0:nTest_Signal=0:nTest_Background=0:SplitMode=Random:NormMode=NumEvents:!V");
78 
79 
80  ///////////////////
81  //Booking //
82  ///////////////////
83  // Boosted Decision Trees
84 
85  //PyMVA methods
86  factory->BookMethod(TMVA::Types::kPyRandomForest, "PyRandomForest",
87  "!V:NEstimators=150:Criterion=gini:MaxFeatures=auto:MaxDepth=3:MinSamplesLeaf=1:MinWeightFractionLeaf=0:Bootstrap=kTRUE");
88  factory->BookMethod(TMVA::Types::kPyAdaBoost, "PyAdaBoost",
89  "!V:BaseEstimator=None:NEstimators=100:LearningRate=1:Algorithm=SAMME.R:RandomState=None");
90  factory->BookMethod(TMVA::Types::kPyGTB, "PyGTB",
91  "!V:NEstimators=150:Loss=deviance:LearningRate=0.1:Subsample=1:MaxDepth=6:MaxFeatures='auto'");
92 
93 
94  // Train MVAs using the set of training events
95  factory->TrainAllMethods();
96 
97  // ---- Evaluate all MVAs using the set of test events
98  factory->TestAllMethods();
99 
100  // ----- Evaluate and compare performance of all configured MVAs
101  factory->EvaluateAllMethods();
102  // --------------------------------------------------------------
103 
104  // Save the output
105  outputFile->Close();
106 
107  std::cout << "==> Wrote root file: " << outputFile->GetName() << std::endl;
108  std::cout << "==> TMVAClassification is done!" << std::endl;
109 
110 }
virtual const char * GetName() const
Returns name of object.
Definition: TNamed.h:47
virtual Bool_t AccessPathName(const char *path, EAccessMode mode=kFileExists)
Returns FALSE if the file can be accessed using the specified access mode (note the inverted sense: TRUE means the file is NOT accessible).
Definition: TSystem.cxx:1272
static Tools & Instance()
Definition: Tools.cxx:75
MethodBase * BookMethod(DataLoader *loader, TString theMethodName, TString methodTitle, TString theOption="")
Book a classifier or regression method.
Definition: Factory.cxx:343
A ROOT file is a suite of consecutive data records (TKey instances) with a well-defined format...
Definition: TFile.h:46
virtual TObject * Get(const char *namecycle)
Return pointer to object identified by namecycle.
Basic string class.
Definition: TString.h:129
void TrainAllMethods()
Iterates through all booked methods and calls training.
Definition: Factory.cxx:1017
static void PyInitialize()
Initialize Python interpreter.
static TFile * Open(const char *name, Option_t *option="", const char *ftitle="", Int_t compress=1, Int_t netopt=0)
Create / open a file.
Definition: TFile.cxx:3909
A specialized string object used for TTree selections.
Definition: TCut.h:25
void Classification()
R__EXTERN TSystem * gSystem
Definition: TSystem.h:539
void EvaluateAllMethods(void)
Iterates over all booked MVAs and calls their evaluation methods. ...
Definition: Factory.cxx:1255
void TestAllMethods()
Definition: Factory.cxx:1153
This is the main MVA steering class.
Definition: Factory.h:81
double Double_t
Definition: RtypesCore.h:55
static Bool_t SetCacheFileDir(const char *cacheDir, Bool_t operateDisconnected=kTRUE, Bool_t forceCacheread=kFALSE)
Sets the directory where to locally stage/cache remote files.
Definition: TFile.cxx:4429
A TTree object has a header with a name and a title.
Definition: TTree.h:78
virtual void Close(Option_t *option="")
Close a file.
Definition: TFile.cxx:904