Logo ROOT  
Reference Guide
Loading...
Searching...
No Matches
TMVAMinimalClassification.C
Go to the documentation of this file.
1/// \file
2/// \ingroup tutorial_ml
3/// \notebook -nodraw
4/// Minimal self-contained example for setting up TMVA with binary
5/// classification.
6///
7/// This is intended as a simple foundation to build on. It assumes you are
8/// familiar with TMVA already. As such, concepts like the Factory, the DataLoader
9/// and others are not explained. For descriptions and tutorials, use the TMVA online manual
10/// https://root.cern/manual/tmva/, the more detailed examples provided with TMVA
11/// (e.g. TMVAClassification.C), or the TMVA Users Guide
12/// https://github.com/root-project/root/blob/master/documentation/tmva/UsersGuide/TMVAUsersGuide.pdf
13///
14/// Sets up a minimal binary classification example with two slightly overlapping
15/// 2-D gaussian distributions and trains a BDT classifier to discriminate the
16/// data.
17///
18/// - Project : TMVA - a ROOT-integrated toolkit for multivariate data analysis
19/// - Package : TMVA
20/// - Root Macro: TMVAMinimalClassification.C
21///
22/// \macro_code
23/// \macro_output
24/// \author Kim Albertsson
25
#include "TMVA/DataLoader.h"
#include "TMVA/Factory.h"

#include "TCut.h"
#include "TError.h"
#include "TFile.h"
#include "TRandom.h"
#include "TString.h"
#include "TTree.h"

#include <memory>
32
33//
34// Helper function to generate 2-D gaussian data points and fill to a ROOT
35// TTree.
36//
37// Arguments:
38// nPoints Number of points to generate.
39// offset Mean of the generated numbers
40// scale Standard deviation of the generated numbers.
41// seed Seed for random number generator. Use `seed=0` for random
42// seed.
43// Returns a TTree ready to be used as input to TMVA.
44//
45TTree *genTree(Int_t nPoints, Double_t offset, Double_t scale, UInt_t seed = 100)
46{
47 TRandom rng(seed);
48 Double_t x = 0;
49 Double_t y = 0;
50
51 TTree *data = new TTree();
52 data->Branch("x", &x, "x/D");
53 data->Branch("y", &y, "y/D");
54
55 for (Int_t n = 0; n < nPoints; ++n) {
56 x = rng.Rndm() * scale;
57 y = offset + rng.Rndm() * scale;
58 data->Fill();
59 }
60
61 // Important: Disconnects the tree from the memory locations of x and y.
62 data->ResetBranchAddresses();
63 return data;
64}
65
66//
67// Minimal setup for performing binary classification in TMVA.
68//
69// Modify the setup to your liking and run with
70// `root -b -q TMVAMinimalClassification.C`.
71// This will generate an output file "out.root" that can be viewed with
72// `root -l -e 'TMVA::TMVAGui("out.root")'`.
73//
74void TMVAMinimalClassification()
75{
76 TString outputFilename = "out.root";
77 TFile *outFile = new TFile(outputFilename, "RECREATE");
78
79 // Data generation
80 TTree *signalTree = genTree(1000, 0.0, 2.0, 100);
81 TTree *backgroundTree = genTree(1000, 1.0, 2.0, 101);
82
83 TString factoryOptions = "AnalysisType=Classification";
84 TMVA::Factory factory{"", outFile, factoryOptions};
85
86 TMVA::DataLoader dataloader{"dataset"};
87
88 // Data specification
89 dataloader.AddVariable("x", 'D');
90 dataloader.AddVariable("y", 'D');
91
92 dataloader.AddSignalTree(signalTree, 1.0);
93 dataloader.AddBackgroundTree(backgroundTree, 1.0);
94
95 TCut signalCut = "";
96 TCut backgroundCut = "";
97 TString datasetOptions = "SplitMode=Random";
98 dataloader.PrepareTrainingAndTestTree(signalCut, backgroundCut, datasetOptions);
99
100 // Method specification
101 TString methodOptions = "";
102 factory.BookMethod(&dataloader, TMVA::Types::kBDT, "BDT", methodOptions);
103
104 // Training and Evaluation
105 factory.TrainAllMethods();
106 factory.TestAllMethods();
107 factory.EvaluateAllMethods();
108
109 // Clean up
110 outFile->Close();
111
112 delete outFile;
113 delete signalTree;
114 delete backgroundTree;
115}
int Int_t
Signed integer 4 bytes (int).
Definition RtypesCore.h:59
unsigned int UInt_t
Unsigned integer 4 bytes (unsigned int).
Definition RtypesCore.h:60
double Double_t
Double 8 bytes.
Definition RtypesCore.h:73
A specialized string object used for TTree selections.
Definition TCut.h:25
A file, usually with extension .root, that stores data and code in the form of serialized objects in ...
Definition TFile.h:130
void Close(Option_t *option="") override
Close a file.
Definition TFile.cxx:981
void AddSignalTree(TTree *signal, Double_t weight=1.0, Types::ETreeType treetype=Types::kMaxTreeType)
number of signal events (used to compute significance)
void PrepareTrainingAndTestTree(const TCut &cut, const TString &splitOpt)
prepare the training and test trees -> same cuts for signal and background
void AddBackgroundTree(TTree *background, Double_t weight=1.0, Types::ETreeType treetype=Types::kMaxTreeType)
number of signal events (used to compute significance)
void AddVariable(const TString &expression, const TString &title, const TString &unit, char type='F', Double_t min=0, Double_t max=0)
user inserts discriminating variable in data set info
This is the main MVA steering class.
Definition Factory.h:80
void TrainAllMethods()
Iterates through all booked methods and calls training.
Definition Factory.cxx:1108
void TestAllMethods()
Evaluates all booked methods on the testing data and adds the output to the Results in the corresponi...
Definition Factory.cxx:1265
void EvaluateAllMethods(void)
Iterates over all MVAs that have been booked, and calls their evaluation methods.
Definition Factory.cxx:1370
MethodBase * BookMethod(DataLoader *loader, MethodName theMethodName, TString methodTitle, TString theOption="")
Books an MVA classifier or regression method.
Definition Factory.cxx:357
This is the base class for the ROOT Random number generators.
Definition TRandom.h:27
Basic string class.
Definition TString.h:138
A TTree represents a columnar dataset.
Definition TTree.h:89
Double_t y[n]
Definition legend1.C:17
Double_t x[n]
Definition legend1.C:17
const Int_t n
Definition legend1.C:16