Logo ROOT  
Reference Guide
 
Loading...
Searching...
No Matches
TMVAMinimalClassification.C
Go to the documentation of this file.
1/// \file
2/// \ingroup tutorial_tmva
3/// \notebook -nodraw
4/// Minimal self-contained example for setting up TMVA with binary
5/// classification.
6///
7/// This is intended as a simple foundation to build on. It assumes you are
8/// familiar with TMVA already. As such, concepts like the Factory, the DataLoader
9/// and others are not explained. For descriptions and tutorials, use the TMVA
10/// User's Guide (https://root.cern.ch/root-user-guides-and-manuals under TMVA)
11/// or the more detailed examples provided with TMVA, e.g. TMVAClassification.C.
12///
13/// Sets up a minimal binary classification example with two slightly overlapping
14/// 2-D gaussian distributions and trains a BDT classifier to discriminate the
15/// data.
16///
17/// - Project : TMVA - a ROOT-integrated toolkit for multivariate data analysis
18/// - Package : TMVA
19/// - Root Macro: TMVAMinimalClassification.C
20///
21/// \macro_output
22/// \macro_code
23/// \author Kim Albertsson
24
25#include "TMVA/DataLoader.h"
26#include "TMVA/Factory.h"
27
28#include "TFile.h"
29#include "TString.h"
30#include "TTree.h"
31
32//
33// Helper function to generate 2-D gaussian data points and fill to a ROOT
34// TTree.
35//
36// Arguments:
37// nPoints Number of points to generate.
38// offset Mean of the generated numbers
39// scale Standard deviation of the generated numbers.
40// seed Seed for random number generator. Use `seed=0` for random
41// seed.
42// Returns a TTree ready to be used as input to TMVA.
43//
44TTree *genTree(Int_t nPoints, Double_t offset, Double_t scale, UInt_t seed = 100)
45{
46 TRandom rng(seed);
47 Double_t x = 0;
48 Double_t y = 0;
49
50 TTree *data = new TTree();
51 data->Branch("x", &x, "x/D");
52 data->Branch("y", &y, "y/D");
53
54 for (Int_t n = 0; n < nPoints; ++n) {
55 x = rng.Rndm() * scale;
56 y = offset + rng.Rndm() * scale;
57 data->Fill();
58 }
59
60 // Important: Disconnects the tree from the memory locations of x and y.
62 return data;
63}
64
65//
66// Minimal setup for performing binary classification in TMVA.
67//
68// Modify the setup to your liking and run with
69// `root -l -b -q TMVAMinimalClassification.C`.
70// This will generate an output file "out.root" that can be viewed with
71// `root -l -e 'TMVA::TMVAGui("out.root")'`.
72//
73void TMVAMinimalClassification()
74{
75 TString outputFilename = "out.root";
76 TFile *outFile = new TFile(outputFilename, "RECREATE");
77
78 // Data generation
79 TTree *signalTree = genTree(1000, 0.0, 2.0, 100);
80 TTree *backgroundTree = genTree(1000, 1.0, 2.0, 101);
81
82 TString factoryOptions = "AnalysisType=Classification";
83 TMVA::Factory factory{"", outFile, factoryOptions};
84
85 TMVA::DataLoader dataloader{"dataset"};
86
87 // Data specification
88 dataloader.AddVariable("x", 'D');
89 dataloader.AddVariable("y", 'D');
90
91 dataloader.AddSignalTree(signalTree, 1.0);
92 dataloader.AddBackgroundTree(backgroundTree, 1.0);
93
94 TCut signalCut = "";
95 TCut backgroundCut = "";
96 TString datasetOptions = "SplitMode=Random";
97 dataloader.PrepareTrainingAndTestTree(signalCut, backgroundCut, datasetOptions);
98
99 // Method specification
100 TString methodOptions = "";
101 factory.BookMethod(&dataloader, TMVA::Types::kBDT, "BDT", methodOptions);
102
103 // Training and Evaluation
104 factory.TrainAllMethods();
105 factory.TestAllMethods();
106 factory.EvaluateAllMethods();
107
108 // Clean up
109 outFile->Close();
110
111 delete outFile;
112 delete signalTree;
113 delete backgroundTree;
114}
int Int_t
Definition RtypesCore.h:45
unsigned int UInt_t
Definition RtypesCore.h:46
double Double_t
Definition RtypesCore.h:59
A specialized string object used for TTree selections.
Definition TCut.h:25
A ROOT file is a suite of consecutive data records (TKey instances) with a well defined format.
Definition TFile.h:54
void Close(Option_t *option="") override
Close a file.
Definition TFile.cxx:899
void AddVariable(const TString &expression, const TString &title, const TString &unit, char type='F', Double_t min=0, Double_t max=0)
user inserts discriminating variable in data set info
This is the main MVA steering class.
Definition Factory.h:80
This is the base class for the ROOT Random number generators.
Definition TRandom.h:27
Basic string class.
Definition TString.h:136
A TTree represents a columnar dataset.
Definition TTree.h:79
virtual Int_t Fill()
Fill all branches.
Definition TTree.cxx:4594
TBranch * Branch(const char *name, T *obj, Int_t bufsize=32000, Int_t splitlevel=99)
Add a new branch, and infer the data type from the type of obj being passed.
Definition TTree.h:350
virtual void ResetBranchAddresses()
Tell all of our branches to drop their current objects and allocate new ones.
Definition TTree.cxx:8054
Double_t y[n]
Definition legend1.C:17
Double_t x[n]
Definition legend1.C:17
const Int_t n
Definition legend1.C:16