Logo ROOT  
Reference Guide
 
All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Properties Friends Macros Modules Pages
Loading...
Searching...
No Matches
classification.C
Go to the documentation of this file.
1/// \file
2/// \ingroup tutorial_tmva_envelope
3/// \notebook -nodraw
4
5/// \macro_output
6/// \macro_code
7
8
9#include "TMVA/Factory.h"
10#include "TMVA/DataLoader.h"
11#include "TMVA/Tools.h"
12#include "TROOT.h"
13#include "TMVA/Classification.h"
14
16{
18
19 TFile *input(nullptr);
20 TString fname = gROOT->GetTutorialDir() + "/machine_learning/data/tmva_class_example.root";
22 input = TFile::Open(fname); // check if file in local directory exists
23 }
24 if (!input) {
25 std::cout << "ERROR: could not open data file" << fname << std::endl;
26 exit(1);
27 }
28
29 // Register the training and test trees
30
31 TTree *signalTree = (TTree *)input->Get("TreeS");
32 TTree *background = (TTree *)input->Get("TreeB");
33
35 // If you wish to modify default settings
36 // (please check "src/Config.h" to see all available global options)
37 //
38 // (TMVA::gConfig().GetVariablePlotting()).fTimesRMS = 8.0;
39 // (TMVA::gConfig().GetIONames()).fWeightFileDir = "myWeightDirectory";
40
41 // Define the input variables that shall be used for the MVA training
42 // note that you may also use variable expressions, such as: "3*var1/var2*abs(var3)"
43 // [all types of expressions that can also be parsed by TTree::Draw( "expression" )]
44 dataloader->AddVariable("myvar1 := var1+var2", 'F');
45 dataloader->AddVariable("myvar2 := var1-var2", "Expression 2", "", 'F');
46 dataloader->AddVariable("var3", "Variable 3", "units", 'F');
47 dataloader->AddVariable("var4", "Variable 4", "units", 'F');
48
49 // You can add so-called "Spectator variables", which are not used in the MVA training,
50 // but will appear in the final "TestTree" produced by TMVA. This TestTree will contain the
51 // input variables, the response values of all trained MVAs, and the spectator variables
52
53 dataloader->AddSpectator("spec1 := var1*2", "Spectator 1", "units", 'F');
54 dataloader->AddSpectator("spec2 := var1*3", "Spectator 2", "units", 'F');
55
56 // global event weights per tree (see below for setting event-wise weights)
59
60 // You can add an arbitrary number of signal or background trees
61 dataloader->AddSignalTree(signalTree, signalWeight);
62 dataloader->AddBackgroundTree(background, backgroundWeight);
63
64 // Set individual event weights (the variables must exist in the original TTree)
65 // - for signal : `dataloader->SetSignalWeightExpression ("weight1*weight2");`
66 // - for background: `dataloader->SetBackgroundWeightExpression("weight1*weight2");`
67 dataloader->SetBackgroundWeightExpression("weight");
68 dataloader->PrepareTrainingAndTestTree(
69 "", "", "nTrain_Signal=1000:nTrain_Background=1000:SplitMode=Random:NormMode=NumEvents:!V");
70
71 TFile *outputFile = TFile::Open("TMVAClass.root", "RECREATE");
72
74
75 cl->BookMethod(TMVA::Types::kBDT, "BDTG", "!H:!V:NTrees=2000:MinNodeSize=2.5%:BoostType=Grad:Shrinkage=0.10:"
76 "UseBaggedBoost:BaggedSampleFraction=0.5:nCuts=20:MaxDepth=2");
77 cl->BookMethod(TMVA::Types::kSVM, "SVM", "Gamma=0.25:Tol=0.001:VarTransform=Norm");
78
79 cl->BookMethod(TMVA::Types::kBDT, "BDTB", "!H:!V:NTrees=2000:BoostType=Bagging:SeparationType=GiniIndex:nCuts=20");
80
81 cl->BookMethod(TMVA::Types::kCuts, "Cuts", "!H:!V:FitMethod=MC:EffSel:SampleSize=200000:VarProp=FSmart");
82
83 cl->Evaluate(); // Train and Test all methods
84
85 auto &results = cl->GetResults();
86
87 TCanvas *c = new TCanvas(Form("ROC"));
88 c->SetTitle("ROC-Integral Curve");
89
90 auto mg = new TMultiGraph();
91 for (UInt_t i = 0; i < results.size(); i++) {
92 if (!results[i].IsCutsMethod()) {
93 auto roc = results[i].GetROCGraph();
94 roc->SetLineColorAlpha(i + 1, 0.1);
95 mg->Add(roc);
96 }
97 }
98 mg->Draw("AL");
99 mg->GetXaxis()->SetTitle(" Signal Efficiency ");
100 mg->GetYaxis()->SetTitle(" Background Rejection ");
101 c->BuildLegend(0.15, 0.15, 0.3, 0.3);
102 c->Draw();
103
104 outputFile->Close();
105 delete cl;
106}
#define c(i)
Definition RSha256.hxx:101
unsigned int UInt_t
Definition RtypesCore.h:46
double Double_t
Definition RtypesCore.h:59
ROOT::Detail::TRangeCast< T, true > TRangeDynCast
TRangeDynCast is an adapter class that allows the typed iteration through a TCollection.
TFile *input — pointer to the input data file, initialised to nullptr and then set by TFile::Open(fname). (Doxygen's auto-generated cross-reference for this name was garbled into an unrelated run of signature tokens.)
#define gROOT
Definition TROOT.h:406
char * Form(const char *fmt,...)
Formats a string in a circular formatting buffer.
Definition TString.cxx:2489
R__EXTERN TSystem * gSystem
Definition TSystem.h:572
The Canvas class.
Definition TCanvas.h:23
A ROOT file is an on-disk file, usually with extension .root, that stores objects in a file-system-like structure.
Definition TFile.h:131
static TFile * Open(const char *name, Option_t *option="", const char *ftitle="", Int_t compress=ROOT::RCompressionSetting::EDefaults::kUseCompiledDefault, Int_t netopt=0)
Create / open a file.
Definition TFile.cxx:4130
virtual void BookMethod(TString methodname, TString methodtitle, TString options="")
Method to book the machine learning method to perform the algorithm.
Definition Envelope.cxx:163
std::vector< ClassificationResult > & GetResults()
Return the vector of TMVA::Experimental::ClassificationResult objects.
virtual void Evaluate()
Method to perform Train/Test over all ml method booked.
static Tools & Instance()
Definition Tools.cxx:71
A TMultiGraph is a collection of TGraph (or derived) objects.
Definition TMultiGraph.h:34
Basic string class.
Definition TString.h:139
virtual Bool_t AccessPathName(const char *path, EAccessMode mode=kFileExists)
Returns FALSE if one can access a file using the specified access mode.
Definition TSystem.cxx:1307
A TTree represents a columnar dataset.
Definition TTree.h:79
void classification(UInt_t jobs=4)