Logo ROOT   6.16/01
Reference Guide
classification.C
Go to the documentation of this file.
1#include "TMVA/Factory.h"
2#include "TMVA/DataLoader.h"
3#include "TMVA/Tools.h"
5
// Tutorial macro: opens the TMVA example data file, registers variables, spectators and
// signal/background trees on `dataloader`, books four methods (BDTG, SVM, BDTB, Cuts) on `cl`,
// evaluates them and draws one ROC graph per non-Cuts result on a canvas.
//
// NOTE(review): the line numbers embedded in this listing skip 8, 15, 28 and 67, and
// `dataloader` / `cl` are used below without a visible declaration. Presumably they are
// declared on the dropped lines -- confirm against the original classification.C.
// NOTE(review): the `jobs` parameter is not referenced on any line visible here; it is
// presumably consumed on one of the dropped lines -- confirm.
6void classification(UInt_t jobs = 4)
7{
9
// Use the local copy of the example file when present, otherwise fall back to the remote copy.
10 TFile *input(0);
11 TString fname = "./tmva_class_example.root";
// AccessPathName returns FALSE when the path IS accessible, hence the negation.
12 if (!gSystem->AccessPathName(fname)) {
13 input = TFile::Open(fname); // check if file in local directory exists
14 } else {
// Remote fallback, opened with the "CACHEREAD" option.
16 input = TFile::Open("http://root.cern.ch/files/tmva_class_example.root", "CACHEREAD");
17 }
// Neither source could be opened: report and terminate the process with exit code 1.
18 if (!input) {
19 std::cout << "ERROR: could not open data file" << std::endl;
20 exit(1);
21 }
22
23 // Register the training and test trees
24
// NOTE(review): the results of Get() are cast without a null check; a file lacking
// "TreeS"/"TreeB" would hand null pointers to the AddSignalTree/AddBackgroundTree calls below.
25 TTree *signalTree = (TTree *)input->Get("TreeS");
26 TTree *background = (TTree *)input->Get("TreeB");
27
29 // If you wish to modify default settings
30 // (please check "src/Config.h" to see all available global options)
31 //
32 // (TMVA::gConfig().GetVariablePlotting()).fTimesRMS = 8.0;
33 // (TMVA::gConfig().GetIONames()).fWeightFileDir = "myWeightDirectory";
34
35 // Define the input variables that shall be used for the MVA training
36 // note that you may also use variable expressions, such as: "3*var1/var2*abs(var3)"
37 // [all types of expressions that can also be parsed by TTree::Draw( "expression" )]
38 dataloader->AddVariable("myvar1 := var1+var2", 'F');
39 dataloader->AddVariable("myvar2 := var1-var2", "Expression 2", "", 'F');
40 dataloader->AddVariable("var3", "Variable 3", "units", 'F');
41 dataloader->AddVariable("var4", "Variable 4", "units", 'F');
42
43 // You can add so-called "Spectator variables", which are not used in the MVA training,
44 // but will appear in the final "TestTree" produced by TMVA. This TestTree will contain the
45 // input variables, the response values of all trained MVAs, and the spectator variables
46
47 dataloader->AddSpectator("spec1 := var1*2", "Spectator 1", "units", 'F');
48 dataloader->AddSpectator("spec2 := var1*3", "Spectator 2", "units", 'F');
49
50 // global event weights per tree (see below for setting event-wise weights)
51 Double_t signalWeight = 1.0;
52 Double_t backgroundWeight = 1.0;
53
54 // You can add an arbitrary number of signal or background trees
55 dataloader->AddSignalTree(signalTree, signalWeight);
56 dataloader->AddBackgroundTree(background, backgroundWeight);
57
58 // Set individual event weights (the variables must exist in the original TTree)
59 // - for signal : `dataloader->SetSignalWeightExpression ("weight1*weight2");`
60 // - for background: `dataloader->SetBackgroundWeightExpression("weight1*weight2");`
// Only the background weight expression is set; no signal weight expression is set in the visible lines.
61 dataloader->SetBackgroundWeightExpression("weight");
// The first two arguments are empty strings (presumably the signal and background selection
// cuts -- confirm); the third is the option string for the train/test split.
62 dataloader->PrepareTrainingAndTestTree(
63 "", "", "nTrain_Signal=1000:nTrain_Background=1000:SplitMode=Random:NormMode=NumEvents:!V");
64
// "RECREATE" mode; the file is closed at the end of the function.
65 TFile *outputFile = TFile::Open("TMVAClass.root", "RECREATE");
66
68
// Book the four methods: two BDT configurations (BDTG, BDTB), one SVM and one Cuts method.
69 cl->BookMethod(TMVA::Types::kBDT, "BDTG", "!H:!V:NTrees=2000:MinNodeSize=2.5%:BoostType=Grad:Shrinkage=0.10:"
70 "UseBaggedBoost:BaggedSampleFraction=0.5:nCuts=20:MaxDepth=2");
71 cl->BookMethod(TMVA::Types::kSVM, "SVM", "Gamma=0.25:Tol=0.001:VarTransform=Norm");
72
73 cl->BookMethod(TMVA::Types::kBDT, "BDTB", "!H:!V:NTrees=2000:BoostType=Bagging:SeparationType=GiniIndex:nCuts=20");
74
75 cl->BookMethod(TMVA::Types::kCuts, "Cuts", "!H:!V:FitMethod=MC:EffSel:SampleSize=200000:VarProp=FSmart");
76
77 cl->Evaluate(); // Train and Test all methods
78
// Bound by reference to the results container returned by `cl`.
79 auto &results = cl->GetResults();
80
// Form() is called with a plain string and no format arguments, so the canvas name is "ROC".
81 TCanvas *c = new TCanvas(Form("ROC"));
82 c->SetTitle("ROC-Integral Curve");
83
// Collect one ROC graph per result, skipping results for which IsCutsMethod() is true.
84 auto mg = new TMultiGraph();
85 for (UInt_t i = 0; i < results.size(); i++) {
86 if (!results[i].IsCutsMethod()) {
87 auto roc = results[i].GetROCGraph();
// Colour index is the result index + 1, drawn with alpha 0.1.
88 roc->SetLineColorAlpha(i + 1, 0.1);
89 mg->Add(roc);
90 }
91 }
// The axis titles are set after Draw("AL"), in that order.
92 mg->Draw("AL");
93 mg->GetXaxis()->SetTitle(" Signal Efficiency ");
94 mg->GetYaxis()->SetTitle(" Background Rejection ");
95 c->BuildLegend(0.15, 0.15, 0.3, 0.3);
96 c->Draw();
97
// Cleanup: only the output file and `cl` are released here; `c`, `mg` and `input` are not
// deleted or closed on any visible line.
98 outputFile->Close();
99 delete cl;
100}
#define c(i)
Definition: RSha256.hxx:101
unsigned int UInt_t
Definition: RtypesCore.h:42
double Double_t
Definition: RtypesCore.h:55
char * Form(const char *fmt,...)
R__EXTERN TSystem * gSystem
Definition: TSystem.h:540
The Canvas class.
Definition: TCanvas.h:31
virtual TObject * Get(const char *namecycle)
Return pointer to object identified by namecycle.
A ROOT file is a suite of consecutive data records (TKey instances) with a well defined format.
Definition: TFile.h:48
virtual void Close(Option_t *option="")
Close a file.
Definition: TFile.cxx:912
static Bool_t SetCacheFileDir(ROOT::Internal::TStringView cacheDir, Bool_t operateDisconnected=kTRUE, Bool_t forceCacheread=kFALSE)
Definition: TFile.h:316
static TFile * Open(const char *name, Option_t *option="", const char *ftitle="", Int_t compress=ROOT::RCompressionSetting::EDefaults::kUseGeneralPurpose, Int_t netopt=0)
Create / open a file.
Definition: TFile.cxx:3975
virtual void BookMethod(TString methodname, TString methodtitle, TString options="")
Method to book the machine learning method to perform the algorithm.
Definition: Envelope.cxx:168
std::vector< ClassificationResult > & GetResults()
Return the vector of TMVA::Experimental::ClassificationResult objects.
virtual void Evaluate()
Method to perform Train/Test over all ml method booked.
static Tools & Instance()
Definition: Tools.cxx:75
@ kBDT
Definition: Types.h:88
@ kCuts
Definition: Types.h:80
@ kSVM
Definition: Types.h:91
A TMultiGraph is a collection of TGraph (or derived) objects.
Definition: TMultiGraph.h:35
Basic string class.
Definition: TString.h:131
virtual Bool_t AccessPathName(const char *path, EAccessMode mode=kFileExists)
Returns FALSE if one can access a file using the specified access mode.
Definition: TSystem.cxx:1286
A TTree object has a header with a name and a title.
Definition: TTree.h:71
void classification(UInt_t jobs=4)
Definition: classification.C:6
static constexpr double mg