Logo ROOT   6.14/05
Reference Guide
TMVACrossValidationApplication.C
Go to the documentation of this file.
1 /// \file
2 /// \ingroup tutorial_tmva
3 /// \notebook -nodraw
4 /// This macro provides an example of how to use TMVA for k-folds cross
5 /// evaluation in application.
6 ///
7 /// This requires that CrossValidation was run with a deterministic split, such
8 /// as `"...:splitExpr=int([eventID])%int([numFolds]):..."`.
9 ///
10 /// - Project : TMVA - a ROOT-integrated toolkit for multivariate data analysis
11 /// - Package : TMVA
12 /// - Root Macro: TMVACrossValidationApplication
13 ///
14 /// \macro_output
15 /// \macro_code
16 /// \author Kim Albertsson (adapted from code originally by Andreas Hoecker)
17 
#include <cstdlib>
#include <iostream>
#include <map>
#include <string>

#include "TChain.h"
#include "TFile.h"
#include "TH1F.h"
#include "TObjString.h"
#include "TRandom3.h"
#include "TROOT.h"
#include "TString.h"
#include "TSystem.h"
#include "TTree.h"

#include "TMVA/DataLoader.h"
#include "TMVA/Factory.h"
#include "TMVA/Reader.h"
#include "TMVA/TMVAGui.h"
#include "TMVA/Tools.h"
35 
36 // Helper function to load data into TTrees.
37 TTree *fillTree(TTree * tree, Int_t nPoints, Double_t offset, Double_t scale, UInt_t seed = 100)
38 {
39  TRandom3 rng(seed);
40  Float_t x = 0;
41  Float_t y = 0;
42  Int_t eventID = 0;
43 
44  tree->SetBranchAddress("x", &x);
45  tree->SetBranchAddress("y", &y);
46  tree->SetBranchAddress("eventID", &eventID);
47 
48  for (Int_t n = 0; n < nPoints; ++n) {
49  x = rng.Gaus(offset, scale);
50  y = rng.Gaus(offset, scale);
51 
52  // For our simple example it is enough that the id's are uniformly
53  // distributed and independent of the data.
54  ++eventID;
55 
56  tree->Fill();
57  }
58 
59  // Important: Disconnects the tree from the memory locations of x and y.
60  tree->ResetBranchAddresses();
61  return tree;
62 }
63 
64 int TMVACrossValidationApplication()
65 {
66  // This loads the library
68 
69  // Set up the TMVA::Reader
70  TMVA::Reader *reader = new TMVA::Reader("!Color:!Silent:!V");
71 
72  Float_t x;
73  Float_t y;
74  Int_t eventID;
75 
76  reader->AddVariable("x", &x);
77  reader->AddVariable("y", &y);
78  reader->AddSpectator("eventID", &eventID);
79 
80  // Book the serialised methods
81  TString jobname("TMVACrossEvaluation");
82  {
83  TString methodName = "BDTG";
84  TString weightfile = TString("dataset/weights/") + jobname + "_" + methodName + TString(".weights.xml");
85 
86  Bool_t weightfileExists = (gSystem->AccessPathName(weightfile) == kFALSE);
87  if (weightfileExists) {
88  reader->BookMVA(methodName, weightfile);
89  } else {
90  std::cout << "Weightfile for method " << methodName << " not found."
91  " Did you run TMVACrossValidation with a specified"
92  " splitExpr?" << std::endl;
93  exit(0);
94  }
95 
96  }
97  {
98  TString methodName = "Fisher";
99  TString weightfile = TString("dataset/weights/") + jobname + "_" + methodName + TString(".weights.xml");
100 
101  Bool_t weightfileExists = (gSystem->AccessPathName(weightfile) == kFALSE);
102  if (weightfileExists) {
103  reader->BookMVA(methodName, weightfile);
104  } else {
105  std::cout << "Weightfile for method " << methodName << " not found."
106  " Did you run TMVACrossValidation with a specified"
107  " splitExpr?" << std::endl;
108  exit(0);
109  }
110  }
111 
112  // Load data
113  TTree *tree = new TTree();
114  tree->Branch("x", &x, "x/F");
115  tree->Branch("y", &y, "y/F");
116  tree->Branch("eventID", &eventID, "eventID/I");
117 
118  fillTree(tree, 1000, 1.0, 1.0, 100);
119  fillTree(tree, 1000, -1.0, 1.0, 101);
120  tree->SetBranchAddress("x", &x);
121  tree->SetBranchAddress("y", &y);
122  tree->SetBranchAddress("eventID", &eventID);
123 
124  // Prepare histograms
125  Int_t nbin = 100;
126  TH1F histBDTG{"BDTG", "BDTG", nbin, -1, 1};
127  TH1F histFisher{"Fisher", "Fisher", nbin, -1, 1};
128 
129  // Evaluate classifiers
130  for (Long64_t ievt = 0; ievt < tree->GetEntries(); ievt++) {
131  tree->GetEntry(ievt);
132 
133  Double_t valBDTG = reader->EvaluateMVA("BDTG");
134  Double_t valFisher = reader->EvaluateMVA("Fisher");
135 
136  histBDTG.Fill(valBDTG);
137  histFisher.Fill(valFisher);
138  }
139 
140  tree->ResetBranchAddresses();
141  delete tree;
142 
143  { // Write histograms to output file
144  TFile *target = new TFile("TMVACrossEvaluationApp.root", "RECREATE");
145  histBDTG.Write();
146  histFisher.Write();
147  target->Close();
148  delete target;
149  }
150 
151  delete reader;
152 
153  return 0;
154 }
155 
156 //
157 // This is used if the macro is compiled. If run through ROOT with
158 // `root -l -b -q MACRO.C` or similar it is unused.
159 //
160 int main(int argc, char **argv)
161 {
162  TMVACrossValidationApplication();
163 }
virtual Bool_t AccessPathName(const char *path, EAccessMode mode=kFileExists)
Returns FALSE if one can access a file using the specified access mode.
Definition: TSystem.cxx:1276
static Tools & Instance()
Definition: Tools.cxx:75
Random number generator class based on M. Matsumoto and T. Nishimura's Mersenne Twister algorithm.
Definition: TRandom3.h:27
long long Long64_t
Definition: RtypesCore.h:69
float Float_t
Definition: RtypesCore.h:53
void AddVariable(const TString &expression, Float_t *)
Add a float variable or expression to the reader.
Definition: Reader.cxx:308
virtual Int_t Fill()
Fill all branches.
Definition: TTree.cxx:4374
A ROOT file is a suite of consecutive data records (TKey instances) with a well defined format...
Definition: TFile.h:47
virtual Int_t GetEntry(Long64_t entry=0, Int_t getall=0)
Read all branches of entry and return total number of bytes read.
Definition: TTree.cxx:5363
Basic string class.
Definition: TString.h:131
1-D histogram with a float per channel (see TH1 documentation)
Definition: TH1.h:567
int Int_t
Definition: RtypesCore.h:41
bool Bool_t
Definition: RtypesCore.h:59
Double_t x[n]
Definition: legend1.C:17
virtual Int_t SetBranchAddress(const char *bname, void *add, TBranch **ptr=0)
Change branch address, dealing with clone trees properly.
Definition: TTree.cxx:7982
int main(int argc, char **argv)
IMethod * BookMVA(const TString &methodTag, const TString &weightfile)
read method name from weight file
Definition: Reader.cxx:377
R__EXTERN TSystem * gSystem
Definition: TSystem.h:540
unsigned int UInt_t
Definition: RtypesCore.h:42
const Bool_t kFALSE
Definition: RtypesCore.h:88
void AddSpectator(const TString &expression, Float_t *)
Add a float spectator or expression to the reader.
Definition: Reader.cxx:326
double Double_t
Definition: RtypesCore.h:55
virtual void ResetBranchAddresses()
Tell all of our branches to drop their current objects and allocate new ones.
Definition: TTree.cxx:7714
Double_t y[n]
Definition: legend1.C:17
Double_t EvaluateMVA(const std::vector< Float_t > &, const TString &methodTag, Double_t aux=0)
Evaluate a std::vector<float> of input data for a given method. The parameter aux is obligatory for the cuts method, where it represents the efficiency cutoff.
Definition: Reader.cxx:485
virtual Long64_t GetEntries() const
Definition: TTree.h:384
virtual Int_t Branch(TCollection *list, Int_t bufsize=32000, Int_t splitlevel=99, const char *name="")
Create one branch for each element in the collection.
Definition: TTree.cxx:1711
The Reader class serves to use the MVAs in a specific analysis context.
Definition: Reader.h:63
Definition: tree.py:1
A TTree object has a header with a name and a title.
Definition: TTree.h:70
const Int_t n
Definition: legend1.C:16
virtual void Close(Option_t *option="")
Close a file.
Definition: TFile.cxx:917