Logo ROOT   6.14/05
Reference Guide
TMVACrossValidationRegression.C
Go to the documentation of this file.
1 /// \file
2 /// \ingroup tutorial_tmva
3 /// \notebook -nodraw
4 /// This macro provides an example of how to use TMVA for k-folds cross
5 /// evaluation.
6 ///
7 /// The input data is a toy-MC sample consisting of two Gaussian
8 /// distributions.
9 ///
10 /// The output file "TMVAReg.root" can be analysed with the use of dedicated
11 /// macros (simply say: root -l <macro.C>), which can be conveniently
12 /// invoked through a GUI that will appear at the end of the run of this macro.
13 /// Launch the GUI via the command:
14 ///
15 /// ```
16 /// root -l -e 'TMVA::TMVAGui("TMVAReg.root")'
17 /// ```
18 ///
19 /// ## Cross Evaluation
20 /// Cross evaluation is a special case of k-folds cross validation where the
21 /// splitting into k folds is computed deterministically. This ensures that a
22 /// given event will always end up in the same fold.
23 ///
24 /// In addition all resulting classifiers are saved and can be applied to new
25 /// data using `MethodCrossValidation`. One requirement for this to work is a
26 /// splitting function that is evaluated for each event to determine into what
27 /// fold it goes (for training/evaluation) or to what classifier (for
28 /// application).
29 ///
30 /// ## Split Expression
31 /// Cross evaluation uses a deterministic split to partition the data into
32 /// folds called the split expression. The expression can be any valid
33 /// `TFormula` as long as all parts used are defined.
34 ///
35 /// For each event the split expression is evaluated to a number and the event
36 /// is put in the fold corresponding to that number.
37 ///
38 /// It is recommended to always use `%int([NumFolds])` at the end of the
39 /// expression.
40 ///
41 /// The split expression has access to all spectators and variables defined in
42 /// the dataloader. Additionally, the number of folds in the split can be
43 /// accessed with `NumFolds` (or `numFolds`).
44 ///
45 /// ### Example
46 /// ```
47 /// "int(fabs([eventID]))%int([NumFolds])"
48 /// ```
49 ///
50 /// - Project : TMVA - a ROOT-integrated toolkit for multivariate data analysis
51 /// - Package : TMVA
52 /// - Root Macro: TMVACrossValidationRegression
53 ///
54 /// \macro_output
55 /// \macro_code
56 /// \author Kim Albertsson (adapted from code originally by Andreas Hoecker)
57 
58 #include <cstdlib>
59 #include <iostream>
60 #include <map>
61 #include <string>
62 
63 #include "TChain.h"
64 #include "TFile.h"
65 #include "TTree.h"
66 #include "TString.h"
67 #include "TObjString.h"
68 #include "TSystem.h"
69 #include "TROOT.h"
70 
71 #include "TMVA/Factory.h"
72 #include "TMVA/DataLoader.h"
73 #include "TMVA/Tools.h"
74 #include "TMVA/TMVAGui.h"
75 
76 TFile * getDataFile(TString fname) {
77  TFile *input(0);
78 
79  if (!gSystem->AccessPathName(fname)) {
80  input = TFile::Open(fname); // check if file in local directory exists
81  } else {
82  // if not: download from ROOT server
84  input = TFile::Open("http://root.cern.ch/files/tmva_reg_example.root", "CACHEREAD");
85  }
86 
87  if (!input) {
88  std::cout << "ERROR: could not open data file " << fname << std::endl;
89  exit(1);
90  }
91 
92  return input;
93 }
94 
95 int TMVACrossValidationRegression()
96 {
97  // This loads the library
99 
100  // --------------------------------------------------------------------------
101 
102  // Create a ROOT output file where TMVA will store ntuples, histograms, etc.
103  TString outfileName("TMVAReg.root");
104  TFile * outputFile = TFile::Open(outfileName, "RECREATE");
105 
106  TString infileName("./files/tmva_reg_example.root");
107  TFile * inputFile = getDataFile(infileName);
108 
110 
111  dataloader->AddVariable("var1", "Variable 1", "units", 'F');
112  dataloader->AddVariable("var2", "Variable 2", "units", 'F');
113 
114  dataloader->AddSpectator("spec1 := var1*100 + var2*100", 'F');
115 
116  // Add the variable carrying the regression target
117  dataloader->AddTarget("fvalue");
118 
119  TTree * regTree = (TTree*)inputFile->Get("TreeR");
120  dataloader->AddRegressionTree(regTree, 1.0);
121  dataloader->SetWeightExpression("var1", "Regression");
122 
123  std::cout << "--- TMVACrossValidationRegression: Using input file: " << inputFile->GetName() << std::endl;
124 
125  // Bypasses the normal splitting mechanism. Unfortunately we must set the
126  // number of events in the training and test sets to 1, otherwise the non-CV
127  // part of TMVA is unhappy.
128  dataloader->PrepareTrainingAndTestTree("", "",
129  ":nTest_Regression=0"
130  ":SplitMode=Random"
131  ":NormMode=NumEvents"
132  ":!V");
133 
134  // --------------------------------------------------------------------------
135 
136  //
137  // This sets up a CrossValidation class (which wraps a TMVA::Factory
138  // internally) for 2-fold cross validation. The data will be split into the
139  // two folds randomly if `splitExpr` is `""`.
140  //
141  // One can also give a deterministic split using spectator variables. An
142  // example would be e.g. `"int(fabs([spec1]))%int([NumFolds])"`.
143  //
144  UInt_t numFolds = 2;
145  TString analysisType = "Regression";
146  TString splitExpr = "";
147 
148  TString cvOptions = Form("!V"
149  ":!Silent"
150  ":ModelPersistence"
151  ":!FoldFileOutput"
152  ":AnalysisType=%s"
153  ":NumFolds=%i"
154  ":SplitExpr=%s",
155  analysisType.Data(), numFolds, splitExpr.Data());
156 
157  TMVA::CrossValidation ce{"TMVACrossValidationRegression", dataloader, outputFile, cvOptions};
158 
159  // --------------------------------------------------------------------------
160 
161  //
162  // Books a method to use for evaluation
163  //
164  ce.BookMethod(TMVA::Types::kBDT, "BDTG",
165  "!H:!V:NTrees=2000::BoostType=Grad:Shrinkage=0.1:"
166  "UseBaggedBoost:BaggedSampleFraction=0.5:nCuts=20:MaxDepth=3:"
167  "MaxDepth=4");
168 
169  ce.BookMethod(TMVA::Types::kMLP, "MLP",
170  "!H:!V:VarTransform=Norm:NeuronType=tanh:NCycles=200:"
171  "HiddenLayers=N+20:TestRate=6:TrainingMethod=BFGS:"
172  "Sampling=0.3:SamplingEpoch=0.8:ConvergenceImprove=1e-6:"
173  "ConvergenceTests=15:!UseRegulator" );
174 
175  // --------------------------------------------------------------------------
176 
177  //
178  // Train, test and evaluate the booked methods.
179  // Evaluates the booked methods once for each fold and aggregates the result
180  // in the specified output file.
181  //
182  ce.Evaluate();
183 
184  // --------------------------------------------------------------------------
185 
186  //
187  // Save the output
188  //
189  outputFile->Close();
190 
191  std::cout << "==> Wrote root file: " << outputFile->GetName() << std::endl;
192  std::cout << "==> TMVACrossValidationRegression is done!" << std::endl;
193 
194  // --------------------------------------------------------------------------
195 
196  //
197  // Launch the GUI for the root macros
198  //
199  if (!gROOT->IsBatch()) {
200  TMVA::TMVAGui(outfileName);
201  }
202 
203  return 0;
204 }
205 
206 //
207 // This is used if the macro is compiled. If run through ROOT with
208 // `root -l -b -q MACRO.C` or similar it is unused.
209 //
210 int main(int argc, char **argv)
211 {
212  TMVACrossValidationRegression();
213 }
virtual const char * GetName() const
Returns name of object.
Definition: TNamed.h:47
virtual Bool_t AccessPathName(const char *path, EAccessMode mode=kFileExists)
Returns FALSE if one can access a file using the specified access mode.
Definition: TSystem.cxx:1276
static Tools & Instance()
Definition: Tools.cxx:75
void TMVAGui(const char *fName="TMVA.root", TString dataset="")
A ROOT file is a suite of consecutive data records (TKey instances) with a well defined format...
Definition: TFile.h:47
static Bool_t SetCacheFileDir(ROOT::Internal::TStringView cacheDir, Bool_t operateDisconnected=kTRUE, Bool_t forceCacheread=kFALSE)
Definition: TFile.h:315
virtual TObject * Get(const char *namecycle)
Return pointer to object identified by namecycle.
#define gROOT
Definition: TROOT.h:410
Basic string class.
Definition: TString.h:131
void AddVariable(const TString &expression, const TString &title, const TString &unit, char type='F', Double_t min=0, Double_t max=0)
user inserts discriminating variable in data set info
Definition: DataLoader.cxx:491
static TFile * Open(const char *name, Option_t *option="", const char *ftitle="", Int_t compress=1, Int_t netopt=0)
Create / open a file.
Definition: TFile.cxx:3976
int main(int argc, char **argv)
R__EXTERN TSystem * gSystem
Definition: TSystem.h:540
unsigned int UInt_t
Definition: RtypesCore.h:42
char * Form(const char *fmt,...)
void AddRegressionTree(TTree *tree, Double_t weight=1.0, Types::ETreeType treetype=Types::kMaxTreeType)
Definition: DataLoader.h:113
void PrepareTrainingAndTestTree(const TCut &cut, const TString &splitOpt)
prepare the training and test trees -> same cuts for signal and background
Definition: DataLoader.cxx:629
Class to perform cross validation, splitting the dataloader into folds.
void AddTarget(const TString &expression, const TString &title="", const TString &unit="", Double_t min=0, Double_t max=0)
user inserts target in data set info
Definition: DataLoader.cxx:509
void SetWeightExpression(const TString &variable, const TString &className="")
Definition: DataLoader.cxx:560
A TTree object has a header with a name and a title.
Definition: TTree.h:70
void AddSpectator(const TString &expression, const TString &title="", const TString &unit="", Double_t min=0, Double_t max=0)
user inserts spectator in data set info
Definition: DataLoader.cxx:521
virtual void Close(Option_t *option="")
Close a file.
Definition: TFile.cxx:917
const char * Data() const
Definition: TString.h:364