Logo ROOT   6.16/01
Reference Guide
TMVACrossValidation.C
Go to the documentation of this file.
1/// \file
2/// \ingroup tutorial_tmva
3/// \notebook -nodraw
4/// This macro provides an example of how to use TMVA for k-folds cross
5/// evaluation.
6///
/// The input data is a toy-MC sample consisting of two Gaussian
/// distributions.
9///
10/// The output file "TMVA.root" can be analysed with the use of dedicated
11/// macros (simply say: root -l <macro.C>), which can be conveniently
12/// invoked through a GUI that will appear at the end of the run of this macro.
13/// Launch the GUI via the command:
14///
15/// ```
16/// root -l -e 'TMVA::TMVAGui("TMVA.root")'
17/// ```
18///
19/// ## Cross Evaluation
/// Cross evaluation is a special case of k-folds cross validation where the
/// splitting into k folds is computed deterministically. This ensures that
/// a given event will always end up in the same fold.
23///
24/// In addition all resulting classifiers are saved and can be applied to new
25/// data using `MethodCrossValidation`. One requirement for this to work is a
26/// splitting function that is evaluated for each event to determine into what
27/// fold it goes (for training/evaluation) or to what classifier (for
28/// application).
29///
30/// ## Split Expression
/// Cross evaluation partitions the data into folds using a deterministic
/// split, which is defined by the split expression. The expression can be any
/// valid `TFormula` as long as all parts used are defined.
34///
35/// For each event the split expression is evaluated to a number and the event
36/// is put in the fold corresponding to that number.
37///
38/// It is recommended to always use `%int([NumFolds])` at the end of the
39/// expression.
40///
41/// The split expression has access to all spectators and variables defined in
42/// the dataloader. Additionally, the number of folds in the split can be
43/// accessed with `NumFolds` (or `numFolds`).
44///
45/// ### Example
46/// ```
47/// "int(fabs([eventID]))%int([NumFolds])"
48/// ```
49///
50/// - Project : TMVA - a ROOT-integrated toolkit for multivariate data analysis
51/// - Package : TMVA
52/// - Root Macro: TMVACrossValidation
53///
54/// \macro_output
55/// \macro_code
56/// \author Kim Albertsson (adapted from code originally by Andreas Hoecker)
57
#include <cstdlib>
#include <iostream>
#include <map>
#include <string>

#include "TChain.h"
#include "TFile.h"
#include "TTree.h"
#include "TString.h"
#include "TObjString.h"
#include "TSystem.h"
#include "TROOT.h"
#include "TRandom3.h"

#include "TMVA/CrossValidation.h"
#include "TMVA/DataLoader.h"
#include "TMVA/Factory.h"
#include "TMVA/Tools.h"
#include "TMVA/TMVAGui.h"
76
77// Helper function to load data into TTrees.
78TTree *genTree(Int_t nPoints, Double_t offset, Double_t scale, UInt_t seed = 100)
79{
80 TRandom3 rng(seed);
81 Float_t x = 0;
82 Float_t y = 0;
83 UInt_t eventID = 0;
84
85 TTree *data = new TTree();
86 data->Branch("x", &x, "x/F");
87 data->Branch("y", &y, "y/F");
88 data->Branch("eventID", &eventID, "eventID/I");
89
90 for (Int_t n = 0; n < nPoints; ++n) {
91 x = rng.Gaus(offset, scale);
92 y = rng.Gaus(offset, scale);
93
94 // For our simple example it is enough that the id's are uniformly
95 // distributed and independent of the data.
96 ++eventID;
97
98 data->Fill();
99 }
100
101 // Important: Disconnects the tree from the memory locations of x and y.
102 data->ResetBranchAddresses();
103 return data;
104}
105
106int TMVACrossValidation()
107{
108 // This loads the library
110
111 // --------------------------------------------------------------------------
112
113 // Load the data into TTrees. If you load data from file you can use a
114 // variant of
115 // ```
116 // TString filename = "/path/to/file";
117 // TFile * input = TFile::Open( filename );
118 // TTree * signalTree = (TTree*)input->Get("TreeName");
119 // ```
120 TTree *sigTree = genTree(1000, 1.0, 1.0, 100);
121 TTree *bkgTree = genTree(1000, -1.0, 1.0, 101);
122
123 // Create a ROOT output file where TMVA will store ntuples, histograms, etc.
124 TString outfileName("TMVA.root");
125 TFile *outputFile = TFile::Open(outfileName, "RECREATE");
126
127 // DataLoader definitions; We declare variables in the tree so that TMVA can
128 // find them. For more information see TMVAClassification tutorial.
130
131 // Data variables
132 dataloader->AddVariable("x", 'F');
133 dataloader->AddVariable("y", 'F');
134
135 // Spectator used for split
136 dataloader->AddSpectator("eventID", 'I');
137
138 // Attaches the trees so they can be read from
139 dataloader->AddSignalTree(sigTree, 1.0);
140 dataloader->AddBackgroundTree(bkgTree, 1.0);
141
142 // The CV mechanism of TMVA splits up the training set into several folds.
143 // The test set is currently left unused. The `nTest_ClassName=1` assigns
144 // one event to the the test set for each class and puts the rest in the
145 // training set. A value of 0 is a special value and would split the
146 // datasets 50 / 50.
147 dataloader->PrepareTrainingAndTestTree("", "",
148 "nTest_Signal=1"
149 ":nTest_Background=1"
150 ":SplitMode=Random"
151 ":NormMode=NumEvents"
152 ":!V");
153
154 // --------------------------------------------------------------------------
155
156 //
157 // This sets up a CrossValidation class (which wraps a TMVA::Factory
158 // internally) for 2-fold cross validation.
159 //
160 // The split type can be "Random", "RandomStratified" or "Deterministic".
161 // For the last option, check the comment below. Random splitting randomises
162 // the order of events and distributes events as evenly as possible.
163 // RandomStratified applies the same logic but distributes events within a
164 // class as evenly as possible over the folds.
165 //
166 UInt_t numFolds = 2;
167 TString analysisType = "Classification";
168 TString splitType = "Random";
169 TString splitExpr = "";
170
171 //
172 // One can also use a custom splitting function for producing the folds.
173 // The example uses a dataset spectator `eventID`.
174 //
175 // The idea here is that eventID should be an event number that is integral,
176 // random and independent of the data, generated only once. This last
177 // property ensures that if a calibration is changed the same event will
178 // still be assigned the same fold.
179 //
180 // This can be used to use the cross validated classifiers in application,
181 // a technique that can simplify statistical analysis.
182 //
183 // If you want to run TMVACrossValidationApplication, make sure you have
184 // run this tutorial with the below line uncommented first.
185 //
186
187 // TString splitExpr = "int(fabs([eventID]))%int([NumFolds])";
188
189 TString cvOptions = Form("!V"
190 ":!Silent"
191 ":ModelPersistence"
192 ":AnalysisType=%s"
193 ":SplitType=%s"
194 ":NumFolds=%i"
195 ":SplitExpr=%s",
196 analysisType.Data(), splitType.Data(), numFolds,
197 splitExpr.Data());
198
199 TMVA::CrossValidation cv{"TMVACrossValidation", dataloader, outputFile, cvOptions};
200
201 // --------------------------------------------------------------------------
202
203 //
204 // Books a method to use for evaluation
205 //
206 cv.BookMethod(TMVA::Types::kBDT, "BDTG",
207 "!H:!V:NTrees=100:MinNodeSize=2.5%:BoostType=Grad"
208 ":NegWeightTreatment=Pray:Shrinkage=0.10:nCuts=20"
209 ":MaxDepth=2");
210
211 cv.BookMethod(TMVA::Types::kFisher, "Fisher",
212 "!H:!V:Fisher:VarTransform=None");
213
214 // --------------------------------------------------------------------------
215
216 //
217 // Train, test and evaluate the booked methods.
218 // Evaluates the booked methods once for each fold and aggregates the result
219 // in the specified output file.
220 //
221 cv.Evaluate();
222
223 // --------------------------------------------------------------------------
224
225 //
226 // Process some output programatically, printing the ROC score for each
227 // booked method.
228 //
229 size_t iMethod = 0;
230 for (auto && result : cv.GetResults()) {
231 std::cout << "Summary for method " << cv.GetMethods()[iMethod++].GetValue<TString>("MethodName")
232 << std::endl;
233 for (UInt_t iFold = 0; iFold<cv.GetNumFolds(); ++iFold) {
234 std::cout << "\tFold " << iFold << ": "
235 << "ROC int: " << result.GetROCValues()[iFold]
236 << ", "
237 << "BkgEff@SigEff=0.3: " << result.GetEff30Values()[iFold]
238 << std::endl;
239 }
240 }
241
242 // --------------------------------------------------------------------------
243
244 //
245 // Save the output
246 //
247 outputFile->Close();
248
249 std::cout << "==> Wrote root file: " << outputFile->GetName() << std::endl;
250 std::cout << "==> TMVACrossValidation is done!" << std::endl;
251
252 // --------------------------------------------------------------------------
253
254 //
255 // Launch the GUI for the root macros
256 //
257 if (!gROOT->IsBatch()) {
258 // Draw cv-specific graphs
259 cv.GetResults()[0].DrawAvgROCCurve(kTRUE, "Avg ROC for BDTG");
260 cv.GetResults()[0].DrawAvgROCCurve(kTRUE, "Avg ROC for Fisher");
261
262 // You can also use the classical gui
263 TMVA::TMVAGui(outfileName);
264 }
265
266 return 0;
267}
268
269//
270// This is used if the macro is compiled. If run through ROOT with
271// `root -l -b -q MACRO.C` or similar it is unused.
272//
273int main(int argc, char **argv)
274{
275 TMVACrossValidation();
276}
int Int_t
Definition: RtypesCore.h:41
unsigned int UInt_t
Definition: RtypesCore.h:42
double Double_t
Definition: RtypesCore.h:55
float Float_t
Definition: RtypesCore.h:53
const Bool_t kTRUE
Definition: RtypesCore.h:87
#define gROOT
Definition: TROOT.h:410
char * Form(const char *fmt,...)
A ROOT file is a suite of consecutive data records (TKey instances) with a well defined format.
Definition: TFile.h:48
virtual void Close(Option_t *option="")
Close a file.
Definition: TFile.cxx:912
static TFile * Open(const char *name, Option_t *option="", const char *ftitle="", Int_t compress=ROOT::RCompressionSetting::EDefaults::kUseGeneralPurpose, Int_t netopt=0)
Create / open a file.
Definition: TFile.cxx:3975
Class to perform cross validation, splitting the dataloader into folds.
static Tools & Instance()
Definition: Tools.cxx:75
@ kFisher
Definition: Types.h:84
@ kBDT
Definition: Types.h:88
virtual const char * GetName() const
Returns name of object.
Definition: TNamed.h:47
Random number generator class based on M.
Definition: TRandom3.h:27
Basic string class.
Definition: TString.h:131
const char * Data() const
Definition: TString.h:364
A TTree object has a header with a name and a title.
Definition: TTree.h:71
int main(int argc, char **argv)
Double_t y[n]
Definition: legend1.C:17
Double_t x[n]
Definition: legend1.C:17
const Int_t n
Definition: legend1.C:16
void TMVAGui(const char *fName="TMVA.root", TString dataset="")