Logo ROOT  
Reference Guide
 
Loading...
Searching...
No Matches
TMVACrossValidation.C
Go to the documentation of this file.
1/// \file
2/// \ingroup tutorial_tmva
3/// \notebook -nodraw
4/// This macro provides an example of how to use TMVA for k-folds cross
5/// evaluation.
6///
7/// As input data is used a toy-MC sample consisting of two gaussian
8/// distributions.
9///
10/// The output file "TMVA.root" can be analysed with the use of dedicated
11/// macros (simply say: root -l <macro.C>), which can be conveniently
12/// invoked through a GUI that will appear at the end of the run of this macro.
13/// Launch the GUI via the command:
14///
15/// ```
16/// root -l -e 'TMVA::TMVAGui("TMVA.root")'
17/// ```
18///
19/// ## Cross Evaluation
20/// Cross evaluation is a special case of k-folds cross validation where the
21/// splitting into k folds is computed deterministically. This ensures that
22/// a given event will always end up in the same fold.
23///
24/// In addition all resulting classifiers are saved and can be applied to new
25/// data using `MethodCrossValidation`. One requirement for this to work is a
26/// splitting function that is evaluated for each event to determine into what
27/// fold it goes (for training/evaluation) or to what classifier (for
28/// application).
29///
30/// ## Split Expression
31/// Cross evaluation partitions the data into folds using a deterministic
32/// rule called the split expression. The expression can be any valid
33/// `TFormula` as long as all parts used are defined.
34///
35/// For each event the split expression is evaluated to a number and the event
36/// is put in the fold corresponding to that number.
37///
38/// It is recommended to always use `%int([NumFolds])` at the end of the
39/// expression.
40///
41/// The split expression has access to all spectators and variables defined in
42/// the dataloader. Additionally, the number of folds in the split can be
43/// accessed with `NumFolds` (or `numFolds`).
44///
45/// ### Example
46/// ```
47/// "int(fabs([eventID]))%int([NumFolds])"
48/// ```
49///
50/// - Project : TMVA - a ROOT-integrated toolkit for multivariate data analysis
51/// - Package : TMVA
52/// - Root Macro: TMVACrossValidation
53///
54/// \macro_output
55/// \macro_code
56/// \author Kim Albertsson (adapted from code originally by Andreas Hoecker)
57
58#include <cstdlib>
59#include <iostream>
60#include <map>
61#include <string>
62
63#include "TChain.h"
64#include "TFile.h"
65#include "TTree.h"
66#include "TString.h"
67#include "TObjString.h"
68#include "TSystem.h"
69#include "TROOT.h"
70
72#include "TMVA/DataLoader.h"
73#include "TMVA/Factory.h"
74#include "TMVA/Tools.h"
75#include "TMVA/TMVAGui.h"
76
77// Helper function to load data into TTrees.
78TTree *genTree(Int_t nPoints, Double_t offset, Double_t scale, UInt_t seed = 100)
79{
80 TRandom3 rng(seed);
81 Float_t x = 0;
82 Float_t y = 0;
83 UInt_t eventID = 0;
84
85 TTree *data = new TTree();
86 data->Branch("x", &x, "x/F");
87 data->Branch("y", &y, "y/F");
88 data->Branch("eventID", &eventID, "eventID/I");
89
90 for (Int_t n = 0; n < nPoints; ++n) {
91 x = rng.Gaus(offset, scale);
92 y = rng.Gaus(offset, scale);
93
94 // For our simple example it is enough that the id's are uniformly
95 // distributed and independent of the data.
96 ++eventID;
97
98 data->Fill();
99 }
100
101 // Important: Disconnects the tree from the memory locations of x and y.
102 data->ResetBranchAddresses();
103 return data;
104}
105
106int TMVACrossValidation(bool useRandomSplitting = false)
107{
108 // This loads the library
110
111 // --------------------------------------------------------------------------
112
113 // Load the data into TTrees. If you load data from file you can use a
114 // variant of
115 // ```
116 // TString filename = "/path/to/file";
117 // TFile * input = TFile::Open( filename );
118 // TTree * signalTree = (TTree*)input->Get("TreeName");
119 // ```
120 TTree *sigTree = genTree(1000, 1.0, 1.0, 100);
121 TTree *bkgTree = genTree(1000, -1.0, 1.0, 101);
122
123 // Create a ROOT output file where TMVA will store ntuples, histograms, etc.
124 TString outfileName("TMVA.root");
125 TFile *outputFile = TFile::Open(outfileName, "RECREATE");
126
127 // DataLoader definitions; We declare variables in the tree so that TMVA can
128 // find them. For more information see TMVAClassification tutorial.
129 TMVA::DataLoader *dataloader = new TMVA::DataLoader("dataset");
130
131 // Data variables
132 dataloader->AddVariable("x", 'F');
133 dataloader->AddVariable("y", 'F');
134
135 // Spectator used for split
136 dataloader->AddSpectator("eventID", 'I');
137
138 // NOTE: Currently TMVA treats all input variables, spectators etc as
139 // floats. Thus, if the absolute value of the input is too large
140 // there can be precision loss. This can especially be a problem for
141 // cross validation with large event numbers.
142 // A workaround is to define your splitting variable as:
143 // `dataloader->AddSpectator("eventID := eventID % 4096", 'I');`
144 // where 4096 should be a number much larger than the number of folds
145 // you intend to run with.
146
147 // Attaches the trees so they can be read from
148 dataloader->AddSignalTree(sigTree, 1.0);
149 dataloader->AddBackgroundTree(bkgTree, 1.0);
150
151 // The CV mechanism of TMVA splits up the training set into several folds.
152 // The test set is currently left unused. The `nTest_ClassName=1` assigns
153 // one event to the the test set for each class and puts the rest in the
154 // training set. A value of 0 is a special value and would split the
155 // datasets 50 / 50.
156 dataloader->PrepareTrainingAndTestTree("", "",
157 "nTest_Signal=1"
158 ":nTest_Background=1"
159 ":SplitMode=Random"
160 ":NormMode=NumEvents"
161 ":!V");
162
163 // --------------------------------------------------------------------------
164
165 //
166 // This sets up a CrossValidation class (which wraps a TMVA::Factory
167 // internally) for 2-fold cross validation.
168 //
169 // The split type can be "Random", "RandomStratified" or "Deterministic".
170 // For the last option, check the comment below. Random splitting randomises
171 // the order of events and distributes events as evenly as possible.
172 // RandomStratified applies the same logic but distributes events within a
173 // class as evenly as possible over the folds.
174 //
175 UInt_t numFolds = 2;
176 TString analysisType = "Classification";
177
178 TString splitType = (useRandomSplitting) ? "Random" : "Deterministic";
179
180 //
181 // One can also use a custom splitting function for producing the folds.
182 // The example uses a dataset spectator `eventID`.
183 //
184 // The idea here is that eventID should be an event number that is integral,
185 // random and independent of the data, generated only once. This last
186 // property ensures that if a calibration is changed the same event will
187 // still be assigned the same fold.
188 //
189 // This can be used to use the cross validated classifiers in application,
190 // a technique that can simplify statistical analysis.
191 //
192 // If you want to run TMVACrossValidationApplication, make sure you have
193 // run this tutorial with Deterministic splitting type, i.e.
194 // with the option useRandomSPlitting = false
195 //
196
197 TString splitExpr = (!useRandomSplitting) ? "int(fabs([eventID]))%int([NumFolds])" : "";
198
199 TString cvOptions = Form("!V"
200 ":!Silent"
201 ":ModelPersistence"
202 ":AnalysisType=%s"
203 ":SplitType=%s"
204 ":NumFolds=%i"
205 ":SplitExpr=%s",
206 analysisType.Data(), splitType.Data(), numFolds,
207 splitExpr.Data());
208
209 TMVA::CrossValidation cv{"TMVACrossValidation", dataloader, outputFile, cvOptions};
210
211 // --------------------------------------------------------------------------
212
213 //
214 // Books a method to use for evaluation
215 //
216 cv.BookMethod(TMVA::Types::kBDT, "BDTG",
217 "!H:!V:NTrees=100:MinNodeSize=2.5%:BoostType=Grad"
218 ":NegWeightTreatment=Pray:Shrinkage=0.10:nCuts=20"
219 ":MaxDepth=2");
220
221 cv.BookMethod(TMVA::Types::kFisher, "Fisher",
222 "!H:!V:Fisher:VarTransform=None");
223
224 // --------------------------------------------------------------------------
225
226 //
227 // Train, test and evaluate the booked methods.
228 // Evaluates the booked methods once for each fold and aggregates the result
229 // in the specified output file.
230 //
231 cv.Evaluate();
232
233 // --------------------------------------------------------------------------
234
235 //
236 // Process some output programatically, printing the ROC score for each
237 // booked method.
238 //
239 size_t iMethod = 0;
240 for (auto && result : cv.GetResults()) {
241 std::cout << "Summary for method " << cv.GetMethods()[iMethod++].GetValue<TString>("MethodName")
242 << std::endl;
243 for (UInt_t iFold = 0; iFold<cv.GetNumFolds(); ++iFold) {
244 std::cout << "\tFold " << iFold << ": "
245 << "ROC int: " << result.GetROCValues()[iFold]
246 << ", "
247 << "BkgEff@SigEff=0.3: " << result.GetEff30Values()[iFold]
248 << std::endl;
249 }
250 }
251
252 // --------------------------------------------------------------------------
253
254 //
255 // Save the output
256 //
257 outputFile->Close();
258
259 std::cout << "==> Wrote root file: " << outputFile->GetName() << std::endl;
260 std::cout << "==> TMVACrossValidation is done!" << std::endl;
261
262 // --------------------------------------------------------------------------
263
264 //
265 // Launch the GUI for the root macros
266 //
267 if (!gROOT->IsBatch()) {
268 // Draw cv-specific graphs
269 cv.GetResults()[0].DrawAvgROCCurve(kTRUE, "Avg ROC for BDTG");
270 cv.GetResults()[0].DrawAvgROCCurve(kTRUE, "Avg ROC for Fisher");
271
272 // You can also use the classical gui
273 TMVA::TMVAGui(outfileName);
274 }
275
276 return 0;
277}
278
279//
280// This is used if the macro is compiled. If run through ROOT with
281// `root -l -b -q MACRO.C` or similar it is unused.
282//
283int main(int argc, char **argv)
284{
285 TMVACrossValidation();
286}
int main()
Definition Prototype.cxx:12
int Int_t
Definition RtypesCore.h:45
unsigned int UInt_t
Definition RtypesCore.h:46
double Double_t
Definition RtypesCore.h:59
float Float_t
Definition RtypesCore.h:57
const Bool_t kTRUE
Definition RtypesCore.h:100
#define gROOT
Definition TROOT.h:404
char * Form(const char *fmt,...)
A ROOT file is a suite of consecutive data records (TKey instances) with a well defined format.
Definition TFile.h:54
static TFile * Open(const char *name, Option_t *option="", const char *ftitle="", Int_t compress=ROOT::RCompressionSetting::EDefaults::kUseCompiledDefault, Int_t netopt=0)
Create / open a file.
Definition TFile.cxx:4025
void Close(Option_t *option="") override
Close a file.
Definition TFile.cxx:899
Class to perform cross validation, splitting the dataloader into folds.
void AddSignalTree(TTree *signal, Double_t weight=1.0, Types::ETreeType treetype=Types::kMaxTreeType)
number of signal events (used to compute significance)
void AddSpectator(const TString &expression, const TString &title="", const TString &unit="", Double_t min=0, Double_t max=0)
user inserts target in data set info
void PrepareTrainingAndTestTree(const TCut &cut, const TString &splitOpt)
prepare the training and test trees -> same cuts for signal and background
void AddBackgroundTree(TTree *background, Double_t weight=1.0, Types::ETreeType treetype=Types::kMaxTreeType)
number of signal events (used to compute significance)
void AddVariable(const TString &expression, const TString &title, const TString &unit, char type='F', Double_t min=0, Double_t max=0)
user inserts discriminating variable in data set info
static Tools & Instance()
Definition Tools.cxx:71
@ kFisher
Definition Types.h:82
virtual const char * GetName() const
Returns name of object.
Definition TNamed.h:47
Random number generator class based on M.
Definition TRandom3.h:27
Basic string class.
Definition TString.h:136
const char * Data() const
Definition TString.h:369
A TTree represents a columnar dataset.
Definition TTree.h:79
virtual Int_t Fill()
Fill all branches.
Definition TTree.cxx:4594
TBranch * Branch(const char *name, T *obj, Int_t bufsize=32000, Int_t splitlevel=99)
Add a new branch, and infer the data type from the type of obj being passed.
Definition TTree.h:350
virtual void ResetBranchAddresses()
Tell all of our branches to drop their current objects and allocate new ones.
Definition TTree.cxx:8054
Double_t y[n]
Definition legend1.C:17
Double_t x[n]
Definition legend1.C:17
const Int_t n
Definition legend1.C:16
void TMVAGui(const char *fName="TMVA.root", TString dataset="")