Logo ROOT  
Reference Guide
 
All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Properties Friends Macros Modules Pages
Loading...
Searching...
No Matches
TMVAMultipleBackgroundExample.C
Go to the documentation of this file.
1/// \file
2/// \ingroup tutorial_tmva
3/// \notebook -nodraw
4/// This example shows the training of signal with three different backgrounds
5/// Then in the application a tree is created with all signal and background
6/// events where the true class ID and the three classifier outputs are added
7/// finally with the application tree, the significance is maximized with the
8/// help of the TMVA genetic algorithm.
9/// - Project : TMVA - a Root-integrated toolkit for multivariate data analysis
10/// - Package : TMVA
11/// - Executable: TMVAGAexample
12///
13/// \macro_output
14/// \macro_code
15/// \author Andreas Hoecker
16
17
18#include <iostream> // Stream declarations
19#include <vector>
20#include <limits>
21
22#include "TChain.h"
23#include "TCut.h"
24#include "TDirectory.h"
25#include "TH1F.h"
26#include "TH1.h"
27#include "TMath.h"
28#include "TFile.h"
29#include "TStopwatch.h"
30#include "TROOT.h"
31#include "TSystem.h"
32
34#include "TMVA/GeneticFitter.h"
35#include "TMVA/IFitterTarget.h"
36#include "TMVA/Factory.h"
37#include "TMVA/DataLoader.h"//required to load dataset
38#include "TMVA/Reader.h"
39
40using std::vector, std::cout, std::endl;
41
42using namespace TMVA;
43
44// ----------------------------------------------------------------------------------------------
45// Training
46// ----------------------------------------------------------------------------------------------
47//
// ---------------------------------------------------------------------------
// Training(): trains one BDTG classifier for each of the three background
// classes. For every background k it builds its own DataLoader ("datasetBkg"k),
// books signal vs. background-k, trains/tests/evaluates, and writes the result
// to its own output file TMVASignalBackground<k>.root.
// NOTE(review): this listing is a Doxygen dump with several source lines
// dropped (53, 65-68, 79). The TFile::Open() call that fills `input`, the
// signalWeight/background*Weight definitions, and the first DataLoader
// construction are missing here — consult the original tutorial source.
48void Training(){
49 std::string factoryOptions( "!V:!Silent:Transformations=I;D;P;G,D:AnalysisType=Classification" );
50 TString fname = "./tmva_example_multiple_background.root";
51
 // NOTE(review): listing line 53 (presumably `input = TFile::Open( fname );`)
 // is missing from this dump, so `input` appears never assigned here.
52 TFile *input(nullptr);
54 if (!input) {
55 std::cout << "ERROR: could not open data file" << std::endl;
56 exit(1);
57 }
58
 // Fetch the signal tree and the three background trees written by
 // createData.C's create_MultipleBackground().
59 TTree *signal = (TTree*)input->Get("TreeS");
60 TTree *background0 = (TTree*)input->Get("TreeB0");
61 TTree *background1 = (TTree*)input->Get("TreeB1");
62 TTree *background2 = (TTree*)input->Get("TreeB2");
63
64 /// global event weights per tree (see below for setting event-wise weights)
 // NOTE(review): listing lines 65-68 (the Double_t signalWeight /
 // background0Weight / background1Weight / background2Weight definitions
 // used below) are missing from this dump.
69
70 // Create a new root output file.
71 TString outfileName( "TMVASignalBackground0.root" );
72 TFile* outputFile = TFile::Open( outfileName, "RECREATE" );
73
74
75
76 // background 0
77 // ____________
78 TMVA::Factory *factory = new TMVA::Factory( "TMVAMultiBkg0", outputFile, factoryOptions );
 // NOTE(review): listing line 79 (presumably
 // `TMVA::DataLoader *dataloader = new TMVA::DataLoader("datasetBkg0");`)
 // is missing from this dump; `dataloader` is used below without a visible
 // declaration.
80
81 dataloader->AddVariable( "var1", "Variable 1", "", 'F' );
82 dataloader->AddVariable( "var2", "Variable 2", "", 'F' );
83 dataloader->AddVariable( "var3", "Variable 3", "units", 'F' );
84 dataloader->AddVariable( "var4", "Variable 4", "units", 'F' );
85
86 dataloader->AddSignalTree ( signal, signalWeight );
87 dataloader->AddBackgroundTree( background0, background0Weight );
88
89 // factory->SetBackgroundWeightExpression("weight");
90 TCut mycuts = ""; // for example: TCut mycuts = "abs(var1)<0.5 && abs(var2-0.5)<1";
91 TCut mycutb = ""; // for example: TCut mycutb = "abs(var1)<0.5";
92
93 // tell the factory to use all remaining events in the trees after training for testing:
94 dataloader->PrepareTrainingAndTestTree( mycuts, mycutb,
95 "nTrain_Signal=0:nTrain_Background=0:SplitMode=Random:NormMode=NumEvents:!V" );
96
97 // Boosted Decision Trees
98 factory->BookMethod( dataloader, TMVA::Types::kBDT, "BDTG",
99 "!H:!V:NTrees=1000:BoostType=Grad:Shrinkage=0.30:UseBaggedBoost:BaggedSampleFraction=0.6:SeparationType=GiniIndex:nCuts=20:MaxDepth=2" );
100 factory->TrainAllMethods();
101 factory->TestAllMethods();
102 factory->EvaluateAllMethods();
103
104 outputFile->Close();
105
106 delete factory;
107 delete dataloader;
108
109
110
 // Same procedure repeated for background 1 with a fresh factory, a fresh
 // DataLoader ("datasetBkg1"), and its own output file.
111 // background 1
112 // ____________
113
114 outfileName = "TMVASignalBackground1.root";
115 outputFile = TFile::Open( outfileName, "RECREATE" );
116 dataloader=new TMVA::DataLoader("datasetBkg1");
117
118 factory = new TMVA::Factory( "TMVAMultiBkg1", outputFile, factoryOptions );
119 dataloader->AddVariable( "var1", "Variable 1", "", 'F' );
120 dataloader->AddVariable( "var2", "Variable 2", "", 'F' );
121 dataloader->AddVariable( "var3", "Variable 3", "units", 'F' );
122 dataloader->AddVariable( "var4", "Variable 4", "units", 'F' );
123
124 dataloader->AddSignalTree ( signal, signalWeight );
125 dataloader->AddBackgroundTree( background1, background1Weight );
126
127 // dataloader->SetBackgroundWeightExpression("weight");
128
129 // tell the factory to use all remaining events in the trees after training for testing:
130 dataloader->PrepareTrainingAndTestTree( mycuts, mycutb,
131 "nTrain_Signal=0:nTrain_Background=0:SplitMode=Random:NormMode=NumEvents:!V" );
132
133 // Boosted Decision Trees
134 factory->BookMethod( dataloader, TMVA::Types::kBDT, "BDTG",
135 "!H:!V:NTrees=1000:BoostType=Grad:Shrinkage=0.30:UseBaggedBoost:BaggedSampleFraction=0.6:SeparationType=GiniIndex:nCuts=20:MaxDepth=2" );
136 factory->TrainAllMethods();
137 factory->TestAllMethods();
138 factory->EvaluateAllMethods();
139
140 outputFile->Close();
141
142 delete factory;
143 delete dataloader;
144
145
 // And once more for background 2 ("datasetBkg2"). Note the only option
 // difference: BaggedSampleFraction=0.5 here versus 0.6 above.
146 // background 2
147 // ____________
148
149 outfileName = "TMVASignalBackground2.root";
150 outputFile = TFile::Open( outfileName, "RECREATE" );
151
152 factory = new TMVA::Factory( "TMVAMultiBkg2", outputFile, factoryOptions );
153 dataloader=new TMVA::DataLoader("datasetBkg2");
154
155 dataloader->AddVariable( "var1", "Variable 1", "", 'F' );
156 dataloader->AddVariable( "var2", "Variable 2", "", 'F' );
157 dataloader->AddVariable( "var3", "Variable 3", "units", 'F' );
158 dataloader->AddVariable( "var4", "Variable 4", "units", 'F' );
159
160 dataloader->AddSignalTree ( signal, signalWeight );
161 dataloader->AddBackgroundTree( background2, background2Weight );
162
163 // dataloader->SetBackgroundWeightExpression("weight");
164
165 // tell the dataloader to use all remaining events in the trees after training for testing:
166 dataloader->PrepareTrainingAndTestTree( mycuts, mycutb,
167 "nTrain_Signal=0:nTrain_Background=0:SplitMode=Random:NormMode=NumEvents:!V" );
168
169 // Boosted Decision Trees
170 factory->BookMethod( dataloader, TMVA::Types::kBDT, "BDTG",
171 "!H:!V:NTrees=1000:BoostType=Grad:Shrinkage=0.30:UseBaggedBoost:BaggedSampleFraction=0.5:SeparationType=GiniIndex:nCuts=20:MaxDepth=2" );
172 factory->TrainAllMethods();
173 factory->TestAllMethods();
174 factory->EvaluateAllMethods();
175
176 outputFile->Close();
177
178 delete factory;
179 delete dataloader;
180
181}
182
183
184
185
186
187// ----------------------------------------------------------------------------------------------
188// Application
189// ----------------------------------------------------------------------------------------------
190//
191// create a summary tree with all signal and background events and for each event the three classifier values and the true classID
// NOTE(review): listing line 192 — the function signature (presumably
// `void ApplicationCreateCombinedTree(){`) — is missing from this dump, as
// are lines 199-200 and 204-205 declaring the Float_t var1..var4 and
// classifier0..classifier2 locals used throughout the body below.
193
194 // Create a new root output file.
195 TString outfileName( "tmva_example_multiple_backgrounds__applied.root" );
196 TFile* outputFile = TFile::Open( outfileName, "RECREATE" );
197 TTree* outputTree = new TTree("multiBkg","multiple backgrounds tree");
198
201 Int_t classID = 0;
202 Float_t weight = 1.f;
203
205
 // One branch per input variable plus the true class ID, the per-event
 // weight, and the three classifier outputs — this is the summary tree the
 // genetic-algorithm step later cuts on.
206 outputTree->Branch("classID", &classID, "classID/I");
207 outputTree->Branch("var1", &var1, "var1/F");
208 outputTree->Branch("var2", &var2, "var2/F");
209 outputTree->Branch("var3", &var3, "var3/F");
210 outputTree->Branch("var4", &var4, "var4/F");
211 outputTree->Branch("weight", &weight, "weight/F");
212 outputTree->Branch("cls0", &classifier0, "cls0/F");
213 outputTree->Branch("cls1", &classifier1, "cls1/F");
214 outputTree->Branch("cls2", &classifier2, "cls2/F");
215
216
217 // create three readers for the three different signal/background classifications, .. one for each background
218 TMVA::Reader *reader0 = new TMVA::Reader( "!Color:!Silent" );
219 TMVA::Reader *reader1 = new TMVA::Reader( "!Color:!Silent" );
220 TMVA::Reader *reader2 = new TMVA::Reader( "!Color:!Silent" );
221
 // All three readers share the same four input-variable addresses; the
 // variables must be registered in the same order as during training.
222 reader0->AddVariable( "var1", &var1 );
223 reader0->AddVariable( "var2", &var2 );
224 reader0->AddVariable( "var3", &var3 );
225 reader0->AddVariable( "var4", &var4 );
226
227 reader1->AddVariable( "var1", &var1 );
228 reader1->AddVariable( "var2", &var2 );
229 reader1->AddVariable( "var3", &var3 );
230 reader1->AddVariable( "var4", &var4 );
231
232 reader2->AddVariable( "var1", &var1 );
233 reader2->AddVariable( "var2", &var2 );
234 reader2->AddVariable( "var3", &var3 );
235 reader2->AddVariable( "var4", &var4 );
236
237 // load the weight files for the readers
238 TString method = "BDT method";
239 reader0->BookMVA( "BDT method", "datasetBkg0/weights/TMVAMultiBkg0_BDTG.weights.xml" );
240 reader1->BookMVA( "BDT method", "datasetBkg1/weights/TMVAMultiBkg1_BDTG.weights.xml" );
241 reader2->BookMVA( "BDT method", "datasetBkg2/weights/TMVAMultiBkg2_BDTG.weights.xml" );
242
243 // load the input file
 // NOTE(review): listing line 246 (presumably `input = TFile::Open( fname );`)
 // is missing from this dump, so `input` appears never assigned here.
244 TFile *input(0);
245 TString fname = "./tmva_example_multiple_background.root";
247
248 TTree* theTree = NULL;
249
 // loop through signal and all background trees; classID encodes the true
 // class (0 = signal, 1..3 = backgrounds 0..2) for every filled event.
250 // loop through signal and all background trees
251 for( int treeNumber = 0; treeNumber < 4; ++treeNumber ) {
252 if( treeNumber == 0 ){
253 theTree = (TTree*)input->Get("TreeS");
254 std::cout << "--- Select signal sample" << std::endl;
255// theTree->SetBranchAddress( "weight", &weight );
256 weight = 1;
257 classID = 0;
258 }else if( treeNumber == 1 ){
259 theTree = (TTree*)input->Get("TreeB0");
260 std::cout << "--- Select background 0 sample" << std::endl;
261// theTree->SetBranchAddress( "weight", &weight );
262 weight = 1;
263 classID = 1;
264 }else if( treeNumber == 2 ){
265 theTree = (TTree*)input->Get("TreeB1");
266 std::cout << "--- Select background 1 sample" << std::endl;
267// theTree->SetBranchAddress( "weight", &weight );
268 weight = 1;
269 classID = 2;
270 }else if( treeNumber == 3 ){
271 theTree = (TTree*)input->Get("TreeB2");
272 std::cout << "--- Select background 2 sample" << std::endl;
273// theTree->SetBranchAddress( "weight", &weight );
274 weight = 1;
275 classID = 3;
276 }
277
278
279 theTree->SetBranchAddress( "var1", &var1 );
280 theTree->SetBranchAddress( "var2", &var2 );
281 theTree->SetBranchAddress( "var3", &var3 );
282 theTree->SetBranchAddress( "var4", &var4 );
283
284
285 std::cout << "--- Processing: " << theTree->GetEntries() << " events" << std::endl;
 // NOTE(review): listing line 286 (presumably `TStopwatch sw;`) is missing
 // from this dump; `sw` is used below without a visible declaration.
287 sw.Start();
288 Int_t nEvent = theTree->GetEntries();
289// Int_t nEvent = 100;
290 for (Long64_t ievt=0; ievt<nEvent; ievt++) {
291
292 if (ievt%1000 == 0){
293 std::cout << "--- ... Processing event: " << ievt << std::endl;
294 }
295
296 theTree->GetEntry(ievt);
297
298 // get the classifiers for each of the signal/background classifications
299 classifier0 = reader0->EvaluateMVA( method );
300 classifier1 = reader1->EvaluateMVA( method );
301 classifier2 = reader2->EvaluateMVA( method );
302
303 outputTree->Fill();
304 }
305
306
307 // get elapsed time
308 sw.Stop();
309 std::cout << "--- End of event loop: "; sw.Print();
310 }
311 input->Close();
312
313
314 // write output tree
315/* outputTree->SetDirectory(outputFile);
316 outputTree->Write(); */
317 outputFile->Write();
318
319 outputFile->Close();
320
321 std::cout << "--- Created root file: \"" << outfileName.Data() << "\" containing the MVA output histograms" << std::endl;
322
323 delete reader0;
324 delete reader1;
325 delete reader2;
326
327 std::cout << "==> Application of readers is done! combined tree created" << std::endl << std::endl;
328
329}
330
331
332
333
334// -----------------------------------------------------------------------------------------
335// Genetic Algorithm Fitness definition
336// -----------------------------------------------------------------------------------------
337//
// ---------------------------------------------------------------------------
// MyFitness: fitness (estimator) definition for the TMVA genetic algorithm.
// It holds the combined application tree and, for a given triple of cut
// values on cls0/cls1/cls2, computes weighted true/false positives and
// returns 1/(efficiency*purity) so that the GA's minimization maximizes
// significance.
338class MyFitness : public IFitterTarget {
339public:
340 // constructor
 // NOTE(review): listing line 341 — the constructor signature (presumably
 // `MyFitness( TChain* _chain ) : IFitterTarget() {`) — is missing from
 // this dump.
342 chain = _chain;
343
 // Scratch histograms filled via TChain::Draw with weighted cuts; only
 // their integrals (the summed event weights) are used below.
344 hSignal = new TH1F("hsignal","hsignal",100,-1,1);
345 hFP = new TH1F("hfp","hfp",100,-1,1);
346 hTP = new TH1F("htp","htp",100,-1,1);
347
 // Total signal weight is fixed once here; it is the denominator of the
 // efficiency computed in EstimatorFunction().
348 TString cutsAndWeightSignal = "weight*(classID==0)";
349 nSignal = chain->Draw("Entry$/Entries$>>hsignal",cutsAndWeightSignal,"goff");
350 weightsSignal = hSignal->Integral();
351
352 }
353
354 // the output of this function will be minimized
 // factors holds the three trial cut values (one per classifier output).
355 Double_t EstimatorFunction( std::vector<Double_t> & factors ){
356
357 TString cutsAndWeightTruePositive = Form("weight*((classID==0) && cls0>%f && cls1>%f && cls2>%f )",factors.at(0), factors.at(1), factors.at(2));
358 TString cutsAndWeightFalsePositive = Form("weight*((classID >0) && cls0>%f && cls1>%f && cls2>%f )",factors.at(0), factors.at(1), factors.at(2));
359
360 // Entry$/Entries$ just draws something reasonable. Could in principle anything
361 Float_t nTP = chain->Draw("Entry$/Entries$>>htp",cutsAndWeightTruePositive,"goff");
362 Float_t nFP = chain->Draw("Entry$/Entries$>>hfp",cutsAndWeightFalsePositive,"goff");
363
364 weightsTruePositive = hTP->Integral();
365 weightsFalsePositive = hFP->Integral();
366
 // NOTE(review): listing lines 369 and 372-376 — the assignments computing
 // `efficiency`, `purity`, and `effTimesPur` from the weights above — are
 // missing from this dump.
367 efficiency = 0;
368 if( weightsSignal > 0 )
370
371 purity = 0;
374

376

377 Float_t toMinimize = std::numeric_limits<float>::max(); // set to the highest existing number
378 if( effTimesPur > 0 ) // if larger than 0, take 1/x. This is the value to minimize
379 toMinimize = 1./(effTimesPur); // we want to minimize 1/efficiency*purity
380
381 // Print();
382
383 return toMinimize;
384 }
385

386
 // Dump the quantities of the most recent estimator evaluation.
387 void Print(){
388 std::cout << std::endl;
389 std::cout << "======================" << std::endl
390 << "Efficiency : " << efficiency << std::endl
391 << "Purity : " << purity << std::endl << std::endl
392 << "True positive weights : " << weightsTruePositive << std::endl
393 << "False positive weights: " << weightsFalsePositive << std::endl
394 << "Signal weights : " << weightsSignal << std::endl;
395 }
396
398
404

405
406private:
 // NOTE(review): listing lines 397-403 — the declarations of nSignal,
 // efficiency, purity, effTimesPur, weightsSignal, weightsTruePositive and
 // weightsFalsePositive used above — are missing from this dump.
407 TChain* chain;
408 TH1F* hSignal;
409 TH1F* hFP;
410 TH1F* hTP;
411
412};
413
414
415
416
417
418
419
420
421// ----------------------------------------------------------------------------------------------
422// Call of Genetic algorithm
423// ----------------------------------------------------------------------------------------------
424//
// ---------------------------------------------------------------------------
// Genetic-algorithm driver: defines cut ranges for the three classifier
// outputs, loads the combined application tree, runs TMVA's GeneticFitter
// against the MyFitness estimator, and prints the optimal cut values.
// NOTE(review): listing line 425 — this function's signature (the GA entry
// point called from the macro body below) — is missing from this dump.
426
427 // define all the parameters by their minimum and maximum value
428 // in this example 3 parameters (=cuts on the classifiers) are defined.
429 vector<Interval*> ranges;
430 ranges.push_back( new Interval(-1,1) ); // for some classifiers (especially LD) the ranges have to be taken larger
431 ranges.push_back( new Interval(-1,1) );
432 ranges.push_back( new Interval(-1,1) );
433
434 std::cout << "Classifier ranges (defined by the user)" << std::endl;
435 for( std::vector<Interval*>::iterator it = ranges.begin(); it != ranges.end(); it++ ){
436 std::cout << " range: " << (*it)->GetMin() << " " << (*it)->GetMax() << std::endl;
437 }
438
 // The summary tree written by the application step is the data the
 // fitness function cuts on.
439 TChain* chain = new TChain("multiBkg");
440 chain->Add("tmva_example_multiple_backgrounds__applied.root");
441
 // NOTE(review): listing line 442 (presumably
 // `IFitterTarget* myFitness = new MyFitness( chain );`) is missing from
 // this dump; `myFitness` is used below without a visible declaration.
443
444 // prepare the genetic algorithm with an initial population size of 20
445 // mind: big population sizes will help in searching the domain space of the solution
446 // but you have to weight this out to the number of generations
447 // the extreme case of 1 generation and populationsize n is equal to
448 // a Monte Carlo calculation with n tries
449
450 const TString name( "multipleBackgroundGA" );
451 const TString opts( "PopSize=100:Steps=30" );
452
453 GeneticFitter mg( *myFitness, name, ranges, opts);
454 // mg.SetParameters( 4, 30, 200, 10,5, 0.95, 0.001 );
455
 // Run the GA; `result` receives the best cut triple found, `estimator`
 // the corresponding (minimized) fitness value.
456 std::vector<Double_t> result;
457 Double_t estimator = mg.Run(result);
458
459 dynamic_cast<MyFitness*>(myFitness)->Print();
460 std::cout << std::endl;
461
462 int n = 0;
463 for( std::vector<Double_t>::iterator it = result.begin(); it<result.end(); it++ ){
464 std::cout << " cutValue[" << n << "] = " << (*it) << ";"<< std::endl;
465 n++;
466 }
467
468
469}
470
471
472
473
// ---------------------------------------------------------------------------
// Macro entry point: generates the toy data, then runs the three stages in
// order — Training(), the application step that builds the combined tree,
// and the genetic-algorithm significance maximization.
// NOTE(review): listing line 474 — the macro's signature line (presumably
// `void TMVAMultipleBackgroundExample()`) — is missing from this dump, as
// are lines 496 and 501 with the calls to the application and GA functions.
475{
476 // ----------------------------------------------------------------------------------------
477 // Run all
478 // ----------------------------------------------------------------------------------------
479 cout << "Start Test TMVAGAexample" << endl
480 << "========================" << endl
481 << endl;
482
 // Produce the toy input file via the tutorial's createData.C macro
 // (200 events per tree).
483 TString createDataMacro = gROOT->GetTutorialDir() + "/tmva/createData.C";
484 gROOT->ProcessLine(TString::Format(".L %s",createDataMacro.Data()));
485 gROOT->ProcessLine("create_MultipleBackground(200)");
486
487
488 cout << endl;
489 cout << "========================" << endl;
490 cout << "--- Training" << endl;
491 Training();
492
493 cout << endl;
494 cout << "========================" << endl;
495 cout << "--- Application & create combined tree" << endl;
497
498 cout << endl;
499 cout << "========================" << endl;
500 cout << "--- maximize significance" << endl;
502}
503
// Standalone entry point for running the tutorial as a compiled program.
// NOTE(review): listing line 505 — the body (presumably a call to the macro
// function above) — is missing from this dump; argc/argv are unused here.
504int main( int argc, char** argv ) {
506}
int main()
Definition Prototype.cxx:12
int Int_t
Definition RtypesCore.h:45
float Float_t
Definition RtypesCore.h:57
double Double_t
Definition RtypesCore.h:59
long long Long64_t
Definition RtypesCore.h:80
ROOT::Detail::TRangeCast< T, true > TRangeDynCast
TRangeDynCast is an adapter class that allows the typed iteration through a TCollection.
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void input
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void char Point_t Rectangle_t WindowAttributes_t Float_t Float_t Float_t Int_t Int_t UInt_t UInt_t Rectangle_t result
char name[80]
Definition TGX11.cxx:110
void Print(GNN_Data &d, std::string txt="")
#define gROOT
Definition TROOT.h:406
char * Form(const char *fmt,...)
Formats a string in a circular formatting buffer.
Definition TString.cxx:2489
const_iterator begin() const
const_iterator end() const
A chain is a collection of files containing TTree objects.
Definition TChain.h:33
A specialized string object used for TTree selections.
Definition TCut.h:25
A ROOT file is an on-disk file, usually with extension .root, that stores objects in a file-system-li...
Definition TFile.h:53
static TFile * Open(const char *name, Option_t *option="", const char *ftitle="", Int_t compress=ROOT::RCompressionSetting::EDefaults::kUseCompiledDefault, Int_t netopt=0)
Create / open a file.
Definition TFile.cxx:4094
1-D histogram with a float per channel (see TH1 documentation)
Definition TH1.h:621
This is the main MVA steering class.
Definition Factory.h:80
void TrainAllMethods()
Iterates through all booked methods and calls training.
Definition Factory.cxx:1114
MethodBase * BookMethod(DataLoader *loader, TString theMethodName, TString methodTitle, TString theOption="")
Book a classifier or regression method.
Definition Factory.cxx:352
void TestAllMethods()
Evaluates all booked methods on the testing data and adds the output to the Results in the corresponding dataset.
Definition Factory.cxx:1271
void EvaluateAllMethods(void)
Iterates over all MVAs that have been booked, and calls their evaluation methods.
Definition Factory.cxx:1376
Fitter using a Genetic Algorithm.
Interface for a fitter 'target'.
The TMVA::Interval Class.
Definition Interval.h:61
The Reader class serves to use the MVAs in a specific analysis context.
Definition Reader.h:64
Stopwatch class.
Definition TStopwatch.h:28
Basic string class.
Definition TString.h:139
static TString Format(const char *fmt,...)
Static method which formats a string using a printf style format descriptor and return a TString.
Definition TString.cxx:2378
A TTree represents a columnar dataset.
Definition TTree.h:79
const Int_t n
Definition legend1.C:16
double efficiency(double effFuncVal, int catIndex, int sigCatIndex)
Definition MathFuncs.h:117
create variable transformations