ROOT 6.10/09
Reference Guide
testPyRandomForestClassification.C
Go to the documentation of this file.
#include <iostream>
#include <memory>

#include "TString.h"
#include "TFile.h"
#include "TTree.h"
#include "TSystem.h"
#include "TMVA/Factory.h"
#include "TMVA/Reader.h"
#include "TMVA/DataLoader.h"
#include "TMVA/PyMethodBase.h"

13  // Get data file
14  std::cout << "Get test data..." << std::endl;
15  TString fname = "./tmva_class_example.root";
16  if (gSystem->AccessPathName(fname)) // file does not exist in local directory
17  gSystem->Exec("curl -O http://root.cern.ch/files/tmva_class_example.root");
18  TFile *input = TFile::Open(fname);
19 
20  // Setup PyMVA and factory
21  std::cout << "Setup TMVA..." << std::endl;
23  TFile* outputFile = TFile::Open("ResultsTestPyRandomForestClassification.root", "RECREATE");
24  TMVA::Factory *factory = new TMVA::Factory("testPyRandomForestClassification", outputFile,
25  "!V:Silent:Color:!DrawProgressBar:AnalysisType=Classification");
26 
27  // Load data
28  TMVA::DataLoader *dataloader = new TMVA::DataLoader("datasetTestPyRandomForestClassification");
29 
30  TTree *signal = (TTree*)input->Get("TreeS");
31  TTree *background = (TTree*)input->Get("TreeB");
32  dataloader->AddSignalTree(signal);
33  dataloader->AddBackgroundTree(background);
34 
35  dataloader->AddVariable("var1");
36  dataloader->AddVariable("var2");
37  dataloader->AddVariable("var3");
38  dataloader->AddVariable("var4");
39 
40  dataloader->PrepareTrainingAndTestTree("",
41  "SplitMode=Random:NormMode=NumEvents:!V");
42 
43  // Book and train method
44  factory->BookMethod(dataloader, TMVA::Types::kPyRandomForest, "PyRandomForest",
45  "H:V:NEstimators=10");
46  std::cout << "Train classifier..." << std::endl;
47  factory->TrainAllMethods();
48 
49  // Clean-up
50  delete factory;
51  delete dataloader;
52  delete outputFile;
53 
54  // Setup reader
55  UInt_t numEvents = 100;
56  std::cout << "Run reader and classify " << numEvents << " events..." << std::endl;
57  TMVA::Reader *reader = new TMVA::Reader("Color:Silent");
58  Float_t vars[4];
59  reader->AddVariable("var1", vars+0);
60  reader->AddVariable("var2", vars+1);
61  reader->AddVariable("var3", vars+2);
62  reader->AddVariable("var4", vars+3);
63  reader->BookMVA("PyRandomForest", "datasetTestPyRandomForestClassification/weights/testPyRandomForestClassification_PyRandomForest.weights.xml");
64 
65  // Get mean response of method on signal and background events
66  signal->SetBranchAddress("var1", vars+0);
67  signal->SetBranchAddress("var2", vars+1);
68  signal->SetBranchAddress("var3", vars+2);
69  signal->SetBranchAddress("var4", vars+3);
70 
71  background->SetBranchAddress("var1", vars+0);
72  background->SetBranchAddress("var2", vars+1);
73  background->SetBranchAddress("var3", vars+2);
74  background->SetBranchAddress("var4", vars+3);
75 
76  Float_t meanMvaSignal = 0;
77  Float_t meanMvaBackground = 0;
78  for(UInt_t i=0; i<numEvents; i++){
79  signal->GetEntry(i);
80  meanMvaSignal += reader->EvaluateMVA("PyRandomForest");
81  background->GetEntry(i);
82  meanMvaBackground += reader->EvaluateMVA("PyRandomForest");
83  }
84  meanMvaSignal = meanMvaSignal/float(numEvents);
85  meanMvaBackground = meanMvaBackground/float(numEvents);
86 
87  // Check whether the response is obviously better than guessing
88  std::cout << "Mean MVA response on signal: " << meanMvaSignal << std::endl;
89  if(meanMvaSignal < 0.6){
90  std::cout << "[ERROR] Mean response on signal is " << meanMvaSignal << " (<0.6)" << std::endl;
91  return 1;
92  }
93  std::cout << "Mean MVA response on background: " << meanMvaBackground << std::endl;
94  if(meanMvaBackground > 0.4){
95  std::cout << "[ERROR] Mean response on background is " << meanMvaBackground << " (>0.4)" << std::endl;
96  return 1;
97  }
98 
99  return 0;
100 }
101 
102 int main(){
104  return err;
105 }
int testPyRandomForestClassification()
void AddBackgroundTree(TTree *background, Double_t weight=1.0, Types::ETreeType treetype=Types::kMaxTreeType)
Add a background tree to the dataset, with an optional event weight and tree type.
Definition: DataLoader.cxx:408
virtual Bool_t AccessPathName(const char *path, EAccessMode mode=kFileExists)
Returns FALSE if one can access a file using the specified access mode (note the inverted convention: TRUE means the file is not accessible).
Definition: TSystem.cxx:1272
MethodBase * BookMethod(DataLoader *loader, TString theMethodName, TString methodTitle, TString theOption="")
Book a classifier or regression method.
Definition: Factory.cxx:343
float Float_t
Definition: RtypesCore.h:53
void AddVariable(const TString &expression, Float_t *)
Add a float variable or expression to the reader.
Definition: Reader.cxx:308
A ROOT file is a suite of consecutive data records (TKey instances) with a well defined format...
Definition: TFile.h:46
virtual TObject * Get(const char *namecycle)
Return pointer to object identified by namecycle.
virtual Int_t GetEntry(Long64_t entry=0, Int_t getall=0)
Read all branches of entry and return total number of bytes read.
Definition: TTree.cxx:5321
Basic string class.
Definition: TString.h:129
void TrainAllMethods()
Iterates through all booked methods and calls training.
Definition: Factory.cxx:1017
void AddVariable(const TString &expression, const TString &title, const TString &unit, char type='F', Double_t min=0, Double_t max=0)
Add a discriminating variable (expression) to the dataset info.
Definition: DataLoader.cxx:491
static void PyInitialize()
Initialize Python interpreter.
static TFile * Open(const char *name, Option_t *option="", const char *ftitle="", Int_t compress=1, Int_t netopt=0)
Create / open a file.
Definition: TFile.cxx:3909
virtual Int_t SetBranchAddress(const char *bname, void *add, TBranch **ptr=0)
Change branch address, dealing with clone trees properly.
Definition: TTree.cxx:7873
IMethod * BookMVA(const TString &methodTag, const TString &weightfile)
Book the MVA method under the tag methodTag; the method type is read from the weight file.
Definition: Reader.cxx:377
R__EXTERN TSystem * gSystem
Definition: TSystem.h:539
unsigned int UInt_t
Definition: RtypesCore.h:42
virtual Int_t Exec(const char *shellcmd)
Execute a command.
Definition: TSystem.cxx:660
This is the main MVA steering class.
Definition: Factory.h:81
void PrepareTrainingAndTestTree(const TCut &cut, const TString &splitOpt)
Prepare the training and test trees, applying the same cut to signal and background.
Definition: DataLoader.cxx:629
Double_t EvaluateMVA(const std::vector< Float_t > &, const TString &methodTag, Double_t aux=0)
Evaluate a std::vector<float> of input data for a given method. The parameter aux is obligatory for the Cuts method, where it represents the efficiency cutoff.
Definition: Reader.cxx:485
The Reader class serves to use the MVAs in a specific analysis context.
Definition: Reader.h:63
A TTree object has a header with a name and a title.
Definition: TTree.h:78
void AddSignalTree(TTree *signal, Double_t weight=1.0, Types::ETreeType treetype=Types::kMaxTreeType)
Add a signal tree to the dataset, with an optional event weight and tree type.
Definition: DataLoader.cxx:377