Logo ROOT  
Reference Guide
 
Loading...
Searching...
No Matches
tmva003_RReader.C
Go to the documentation of this file.
/// \file
/// \ingroup tutorial_tmva
/// \notebook -nodraw
/// This tutorial shows how to apply models saved in TMVA XML files using the
/// modern interfaces.
///
/// \macro_code
/// \macro_output
///
/// \date July 2019
/// \author Stefan Wunsch
12
#include <memory>
#include <stdexcept>

using namespace TMVA::Experimental;
14
15void train(const std::string &filename)
16{
17 // Create factory
18 auto output = TFile::Open("TMVARR.root", "RECREATE");
19 auto factory = new TMVA::Factory("tmva003",
20 output, "!V:!DrawProgressBar:AnalysisType=Classification");
21
22 // Open trees with signal and background events
23 auto data = TFile::Open(filename.c_str());
24 auto signal = (TTree *)data->Get("TreeS");
25 auto background = (TTree *)data->Get("TreeB");
26
27 // Add variables and register the trees with the dataloader
28 auto dataloader = new TMVA::DataLoader("tmva003_BDT");
29 const std::vector<std::string> variables = {"var1", "var2", "var3", "var4"};
30 for (const auto &var : variables) {
31 dataloader->AddVariable(var);
32 }
33 dataloader->AddSignalTree(signal, 1.0);
34 dataloader->AddBackgroundTree(background, 1.0);
35 dataloader->PrepareTrainingAndTestTree("", "");
36
37 // Train a TMVA method
38 factory->BookMethod(dataloader, TMVA::Types::kBDT, "BDT", "!V:!H:NTrees=300:MaxDepth=2");
39 factory->TrainAllMethods();
40}
41
42void tmva003_RReader()
43{
44 // First, let's train a model with TMVA.
45 const std::string filename = "http://root.cern/files/tmva_class_example.root";
46 train(filename);
47
48 // Next, we load the model from the TMVA XML file.
49 RReader model("tmva003_BDT/weights/tmva003_BDT.weights.xml");
50
51 // In case you need a reminder of the names and order of the variables during
52 // training, you can ask the model for it.
53 auto variables = model.GetVariableNames();
54
55 // The model can now be applied in different scenarios:
56 // 1) Event-by-event inference
57 // 2) Batch inference on data of multiple events
58 // 3) Inference as part of an RDataFrame graph
59
60 // 1) Event-by-event inference
61 // The event-by-event inference takes the values of the variables as a std::vector<float>.
62 // Note that the return value is as well a std::vector<float> since the reader
63 // is also capable to process models with multiple outputs.
64 auto prediction = model.Compute({0.5, 1.0, -0.2, 1.5});
65 std::cout << "Single-event inference: " << prediction[0] << "\n\n";
66
67 // 2) Batch inference on data of multiple events
68 // For batch inference, the data needs to be structured as a matrix. For this
69 // purpose, TMVA makes use of the RTensor class. For convenience, we use RDataFrame
70 // and the AsTensor utility to make the read-out from the ROOT file.
71 ROOT::RDataFrame df("TreeS", filename);
72 auto df2 = df.Range(3); // Read only a small subset of the dataset
73 auto x = AsTensor<float>(df2, variables);
74 auto y = model.Compute(x);
75
76 std::cout << "RTensor input for inference on data of multiple events:\n" << x << "\n\n";
77 std::cout << "Prediction performed on multiple events: " << y << "\n\n";
78
79 // 3) Perform inference as part of an RDataFrame graph
80 // We write a small lambda function that performs for us the inference on
81 // a dataframe to omit code duplication.
82 auto make_histo = [&](const std::string &treename) {
83 ROOT::RDataFrame df(treename, filename);
84 auto df2 = df.Define("y", Compute<4, float>(model), variables);
85 return df2.Histo1D({treename.c_str(), ";BDT score;N_{Events}", 30, -0.5, 0.5}, "y");
86 };
87
88 auto sig = make_histo("TreeS");
89 auto bkg = make_histo("TreeB");
90
91 // Make plot
93 auto c = new TCanvas("", "", 800, 800);
94
95 sig->SetLineColor(kRed);
96 bkg->SetLineColor(kBlue);
97 sig->SetLineWidth(2);
98 bkg->SetLineWidth(2);
99 bkg->Draw("HIST");
100 sig->Draw("HIST SAME");
101
102 TLegend legend(0.7, 0.7, 0.89, 0.89);
103 legend.SetBorderSize(0);
104 legend.AddEntry("TreeS", "Signal", "l");
105 legend.AddEntry("TreeB", "Background", "l");
106 legend.Draw();
107
108 c->DrawClone();
109}
#define c(i)
Definition RSha256.hxx:101
@ kRed
Definition Rtypes.h:66
@ kBlue
Definition Rtypes.h:66
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void char Point_t Rectangle_t WindowAttributes_t Float_t Float_t Float_t Int_t Int_t UInt_t UInt_t Rectangle_t Int_t Int_t Window_t TString Int_t GCValues_t GetPrimarySelectionOwner GetDisplay GetScreen GetColormap GetNativeEvent const char const char dpyName wid window const char font_name cursor keysym reg const char only_if_exist regb h Point_t winding char text const char depth char const char Int_t count const char ColorStruct_t color const char filename
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void data
R__EXTERN TStyle * gStyle
Definition TStyle.h:433
ROOT's RDataFrame offers a modern, high-level interface for analysis of data stored in TTree ,...
The Canvas class.
Definition TCanvas.h:23
static TFile * Open(const char *name, Option_t *option="", const char *ftitle="", Int_t compress=ROOT::RCompressionSetting::EDefaults::kUseCompiledDefault, Int_t netopt=0)
Create / open a file.
Definition TFile.cxx:4089
This class displays a legend box (TPaveText) containing several legend entries.
Definition TLegend.h:23
A replacement for the TMVA::Reader legacy interface.
Definition RReader.hxx:136
This is the main MVA steering class.
Definition Factory.h:80
void SetOptStat(Int_t stat=1)
The type of information printed in the histogram statistics box can be selected via the parameter mod...
Definition TStyle.cxx:1636
A TTree represents a columnar dataset.
Definition TTree.h:79
Double_t y[n]
Definition legend1.C:17
Double_t x[n]
Definition legend1.C:17
void variables(TString dataset, TString fin="TMVA.root", TString dirName="InputVariables_Id", TString title="TMVA Input Variables", Bool_t isRegression=kFALSE, Bool_t useTMVAStyle=kTRUE)
static void output()