Logo ROOT   6.14/05
Reference Guide
TMVACrossValidation.C
Go to the documentation of this file.
1 /// \file
2 /// \ingroup tutorial_tmva
3 /// \notebook -nodraw
4 /// This macro provides an example of how to use TMVA for k-folds cross
5 /// evaluation.
6 ///
7 /// The input data is a toy-MC sample consisting of two Gaussian
8 /// distributions.
9 ///
10 /// The output file "TMVA.root" can be analysed with the use of dedicated
11 /// macros (simply say: root -l <macro.C>), which can be conveniently
12 /// invoked through a GUI that will appear at the end of the run of this macro.
13 /// Launch the GUI via the command:
14 ///
15 /// ```
16 /// root -l -e 'TMVA::TMVAGui("TMVA.root")'
17 /// ```
18 ///
19 /// ## Cross Evaluation
20 /// Cross evaluation is a special case of k-folds cross validation where the
21 /// splitting into k folds is computed deterministically. This ensures that
22 /// a given event will always end up in the same fold.
23 ///
24 /// In addition all resulting classifiers are saved and can be applied to new
25 /// data using `MethodCrossValidation`. One requirement for this to work is a
26 /// splitting function that is evaluated for each event to determine into what
27 /// fold it goes (for training/evaluation) or to what classifier (for
28 /// application).
29 ///
30 /// ## Split Expression
31 /// Cross evaluation partitions the data into folds using a deterministic
32 /// rule called the split expression. The expression can be any valid
33 /// `TFormula` as long as all parts used are defined.
34 ///
35 /// For each event the split expression is evaluated to a number and the event
36 /// is put in the fold corresponding to that number.
37 ///
38 /// It is recommended to always use `%int([NumFolds])` at the end of the
39 /// expression.
40 ///
41 /// The split expression has access to all spectators and variables defined in
42 /// the dataloader. Additionally, the number of folds in the split can be
43 /// accessed with `NumFolds` (or `numFolds`).
44 ///
45 /// ### Example
46 /// ```
47 /// "int(fabs([eventID]))%int([NumFolds])"
48 /// ```
49 ///
50 /// - Project : TMVA - a ROOT-integrated toolkit for multivariate data analysis
51 /// - Package : TMVA
52 /// - Root Macro: TMVACrossValidation
53 ///
54 /// \macro_output
55 /// \macro_code
56 /// \author Kim Albertsson (adapted from code originally by Andreas Hoecker)
57 
58 #include <cstdlib>
59 #include <iostream>
60 #include <map>
61 #include <string>
62 
63 #include "TChain.h"
64 #include "TFile.h"
65 #include "TTree.h"
66 #include "TString.h"
67 #include "TObjString.h"
68 #include "TSystem.h"
69 #include "TROOT.h"
70 
71 #include "TMVA/CrossValidation.h"
72 #include "TMVA/DataLoader.h"
73 #include "TMVA/Factory.h"
74 #include "TMVA/Tools.h"
75 #include "TMVA/TMVAGui.h"
76 
77 // Helper function to load data into TTrees.
78 TTree *genTree(Int_t nPoints, Double_t offset, Double_t scale, UInt_t seed = 100)
79 {
80  TRandom3 rng(seed);
81  Float_t x = 0;
82  Float_t y = 0;
83  UInt_t eventID = 0;
84 
85  TTree *data = new TTree();
86  data->Branch("x", &x, "x/F");
87  data->Branch("y", &y, "y/F");
88  data->Branch("eventID", &eventID, "eventID/I");
89 
90  for (Int_t n = 0; n < nPoints; ++n) {
91  x = rng.Gaus(offset, scale);
92  y = rng.Gaus(offset, scale);
93 
94  // For our simple example it is enough that the id's are uniformly
95  // distributed and independent of the data.
96  ++eventID;
97 
98  data->Fill();
99  }
100 
101  // Important: Disconnects the tree from the memory locations of x and y.
102  data->ResetBranchAddresses();
103  return data;
104 }
105 
106 int TMVACrossValidation()
107 {
108  // This loads the library
110 
111  // --------------------------------------------------------------------------
112 
113  // Load the data into TTrees. If you load data from file you can use a
114  // variant of
115  // ```
116  // TString filename = "/path/to/file";
117  // TFile * input = TFile::Open( filename );
118  // TTree * signalTree = (TTree*)input->Get("TreeName");
119  // ```
120  TTree *sigTree = genTree(1000, 1.0, 1.0, 100);
121  TTree *bkgTree = genTree(1000, -1.0, 1.0, 101);
122 
123  // Create a ROOT output file where TMVA will store ntuples, histograms, etc.
124  TString outfileName("TMVA.root");
125  TFile *outputFile = TFile::Open(outfileName, "RECREATE");
126 
127  // DataLoader definitions; We declare variables in the tree so that TMVA can
128  // find them. For more information see TMVAClassification tutorial.
129  TMVA::DataLoader *dataloader = new TMVA::DataLoader("dataset");
130 
131  // Data variables
132  dataloader->AddVariable("x", 'F');
133  dataloader->AddVariable("y", 'F');
134 
135  // Spectator used for split
136  dataloader->AddSpectator("eventID", 'I');
137 
138  // Attaches the trees so they can be read from
139  dataloader->AddSignalTree(sigTree, 1.0);
140  dataloader->AddBackgroundTree(bkgTree, 1.0);
141 
142  // The CV mechanism of TMVA splits up the training set into several folds.
143  // The test set is currently left unused. The `nTest_ClassName=1` assigns
144  // one event to the the test set for each class and puts the rest in the
145  // training set. A value of 0 is a special value and would split the
146  // datasets 50 / 50.
147  dataloader->PrepareTrainingAndTestTree("", "",
148  "nTest_Signal=1"
149  ":nTest_Background=1"
150  ":SplitMode=Random"
151  ":NormMode=NumEvents"
152  ":!V");
153 
154  // --------------------------------------------------------------------------
155 
156  //
157  // This sets up a CrossValidation class (which wraps a TMVA::Factory
158  // internally) for 2-fold cross validation.
159  //
160  UInt_t numFolds = 2;
161  TString analysisType = "Classification";
162  TString splitExpr = "";
163 
164  //
165  // One can also use a custom splitting function for producing the folds.
166  // The example uses a dataset spectator `eventID`.
167  //
168  // The idea here is that eventID should be an event number that is integral,
169  // random and independent of the data, generated only once. This last
170  // property ensures that if a calibration is changed the same event will
171  // still be assigned the same fold.
172  //
173  // This can be used to use the cross validated classifiers in application,
174  // a technique that can simplify statistical analysis.
175  //
176  // If you want to run TMVACrossValidationApplication, make sure you have
177  // run this tutorial with the below line uncommented first.
178  //
179 
180  // TString splitExpr = "int(fabs([eventID]))%int([NumFolds])";
181 
182  TString cvOptions = Form("!V"
183  ":!Silent"
184  ":ModelPersistence"
185  ":AnalysisType=%s"
186  ":NumFolds=%i"
187  ":SplitExpr=%s",
188  analysisType.Data(), numFolds, splitExpr.Data());
189 
190  TMVA::CrossValidation cv{"TMVACrossValidation", dataloader, outputFile, cvOptions};
191 
192  // --------------------------------------------------------------------------
193 
194  //
195  // Books a method to use for evaluation
196  //
197  cv.BookMethod(TMVA::Types::kBDT, "BDTG",
198  "!H:!V:NTrees=100:MinNodeSize=2.5%:BoostType=Grad"
199  ":NegWeightTreatment=Pray:Shrinkage=0.10:nCuts=20"
200  ":MaxDepth=2");
201 
202  cv.BookMethod(TMVA::Types::kFisher, "Fisher",
203  "!H:!V:Fisher:VarTransform=None");
204 
205  // --------------------------------------------------------------------------
206 
207  //
208  // Train, test and evaluate the booked methods.
209  // Evaluates the booked methods once for each fold and aggregates the result
210  // in the specified output file.
211  //
212  cv.Evaluate();
213 
214  // --------------------------------------------------------------------------
215 
216  //
217  // Process some output programatically, printing the ROC score for each
218  // booked method.
219  //
220  size_t iMethod = 0;
221  for (auto && result : cv.GetResults()) {
222  std::cout << "Summary for method " << cv.GetMethods()[iMethod++].GetValue<TString>("MethodName")
223  << std::endl;
224  for (UInt_t iFold = 0; iFold<cv.GetNumFolds(); ++iFold) {
225  std::cout << "\tFold " << iFold << ": "
226  << "ROC int: " << result.GetROCValues()[iFold]
227  << ", "
228  << "BkgEff@SigEff=0.3: " << result.GetEff30Values()[iFold]
229  << std::endl;
230  }
231  }
232 
233  // --------------------------------------------------------------------------
234 
235  //
236  // Save the output
237  //
238  outputFile->Close();
239 
240  std::cout << "==> Wrote root file: " << outputFile->GetName() << std::endl;
241  std::cout << "==> TMVACrossValidation is done!" << std::endl;
242 
243  // --------------------------------------------------------------------------
244 
245  //
246  // Launch the GUI for the root macros
247  //
248  if (!gROOT->IsBatch()) {
249  TMVA::TMVAGui(outfileName);
250  }
251 
252  return 0;
253 }
254 
255 //
256 // This is used if the macro is compiled. If run through ROOT with
257 // `root -l -b -q MACRO.C` or similar it is unused.
258 //
259 int main(int argc, char **argv)
260 {
261  TMVACrossValidation();
262 }
void AddBackgroundTree(TTree *background, Double_t weight=1.0, Types::ETreeType treetype=Types::kMaxTreeType)
number of signal events (used to compute significance)
Definition: DataLoader.cxx:408
virtual const char * GetName() const
Returns name of object.
Definition: TNamed.h:47
static Tools & Instance()
Definition: Tools.cxx:75
Random number generator class based on the Mersenne Twister algorithm by M. Matsumoto and T. Nishimura.
Definition: TRandom3.h:27
float Float_t
Definition: RtypesCore.h:53
virtual Int_t Fill()
Fill all branches.
Definition: TTree.cxx:4374
void TMVAGui(const char *fName="TMVA.root", TString dataset="")
A ROOT file is a suite of consecutive data records (TKey instances) with a well defined format...
Definition: TFile.h:47
#define gROOT
Definition: TROOT.h:410
Basic string class.
Definition: TString.h:131
int Int_t
Definition: RtypesCore.h:41
void AddVariable(const TString &expression, const TString &title, const TString &unit, char type='F', Double_t min=0, Double_t max=0)
user inserts discriminating variable in data set info
Definition: DataLoader.cxx:491
static TFile * Open(const char *name, Option_t *option="", const char *ftitle="", Int_t compress=1, Int_t netopt=0)
Create / open a file.
Definition: TFile.cxx:3976
Double_t x[n]
Definition: legend1.C:17
int main(int argc, char **argv)
unsigned int UInt_t
Definition: RtypesCore.h:42
char * Form(const char *fmt,...)
void PrepareTrainingAndTestTree(const TCut &cut, const TString &splitOpt)
prepare the training and test trees -> same cuts for signal and background
Definition: DataLoader.cxx:629
double Double_t
Definition: RtypesCore.h:55
Class to perform cross validation, splitting the dataloader into folds.
virtual void ResetBranchAddresses()
Tell all of our branches to drop their current objects and allocate new ones.
Definition: TTree.cxx:7714
Double_t y[n]
Definition: legend1.C:17
virtual Int_t Branch(TCollection *list, Int_t bufsize=32000, Int_t splitlevel=99, const char *name="")
Create one branch for each element in the collection.
Definition: TTree.cxx:1711
A TTree object has a header with a name and a title.
Definition: TTree.h:70
void AddSignalTree(TTree *signal, Double_t weight=1.0, Types::ETreeType treetype=Types::kMaxTreeType)
number of signal events (used to compute significance)
Definition: DataLoader.cxx:377
const Int_t n
Definition: legend1.C:16
void AddSpectator(const TString &expression, const TString &title="", const TString &unit="", Double_t min=0, Double_t max=0)
user inserts target in data set info
Definition: DataLoader.cxx:521
virtual void Close(Option_t *option="")
Close a file.
Definition: TFile.cxx:917
const char * Data() const
Definition: TString.h:364