15void train(
const std::string &
filename)
20 output,
"!V:!DrawProgressBar:AnalysisType=Classification");
24 auto signal = (
TTree *)
data->Get(
"TreeS");
25 auto background = (
TTree *)
data->Get(
"TreeB");
29 const std::vector<std::string>
variables = {
"var1",
"var2",
"var3",
"var4"};
31 dataloader->AddVariable(var);
33 dataloader->AddSignalTree(signal, 1.0);
34 dataloader->AddBackgroundTree(background, 1.0);
35 dataloader->PrepareTrainingAndTestTree(
"",
"");
38 factory->BookMethod(dataloader,
TMVA::Types::kBDT,
"BDT",
"!V:!H:NTrees=300:MaxDepth=2");
39 factory->TrainAllMethods();
45 const std::string
filename =
"http://root.cern/files/tmva_class_example.root";
49 RReader model(
"tmva003_BDT/weights/tmva003_BDT.weights.xml");
53 auto variables = model.GetVariableNames();
64 auto prediction = model.Compute({0.5, 1.0, -0.2, 1.5});
65 std::cout <<
"Single-event inference: " << prediction[0] <<
"\n\n";
72 auto df2 = df.Range(3);
73 auto x = AsTensor<float>(df2, variables);
74 auto y = model.Compute(
x);
76 std::cout <<
"RTensor input for inference on data of multiple events:\n" <<
x <<
"\n\n";
77 std::cout <<
"Prediction performed on multiple events: " <<
y <<
"\n\n";
82 auto make_histo = [&](
const std::string &treename) {
84 auto df2 = df.Define(
"y", Compute<4, float>(model), variables);
85 return df2.Histo1D({treename.c_str(),
";BDT score;N_{Events}", 30, -0.5, 0.5},
"y");
88 auto sig = make_histo(
"TreeS");
89 auto bkg = make_histo(
"TreeB");
93 auto c =
new TCanvas(
"",
"", 800, 800);
95 sig->SetLineColor(
kRed);
96 bkg->SetLineColor(
kBlue);
100 sig->Draw(
"HIST SAME");
102 TLegend legend(0.7, 0.7, 0.89, 0.89);
103 legend.SetBorderSize(0);
104 legend.AddEntry(
"TreeS",
"Signal",
"l");
105 legend.AddEntry(
"TreeB",
"Background",
"l");
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void char Point_t Rectangle_t WindowAttributes_t Float_t Float_t Float_t Int_t Int_t UInt_t UInt_t Rectangle_t Int_t Int_t Window_t TString Int_t GCValues_t GetPrimarySelectionOwner GetDisplay GetScreen GetColormap GetNativeEvent const char const char dpyName wid window const char font_name cursor keysym reg const char only_if_exist regb h Point_t winding char text const char depth char const char Int_t count const char ColorStruct_t color const char filename
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void data
R__EXTERN TStyle * gStyle
ROOT's RDataFrame offers a modern, high-level interface for analysis of data stored in TTree ,...
static TFile * Open(const char *name, Option_t *option="", const char *ftitle="", Int_t compress=ROOT::RCompressionSetting::EDefaults::kUseCompiledDefault, Int_t netopt=0)
Create / open a file.
This class displays a legend box (TPaveText) containing several legend entries.
A replacement for the TMVA::Reader legacy interface.
This is the main MVA steering class.
void SetOptStat(Int_t stat=1)
The type of information printed in the histogram statistics box can be selected via the parameter mod...
A TTree represents a columnar dataset.
void variables(TString dataset, TString fin="TMVA.root", TString dirName="InputVariables_Id", TString title="TMVA Input Variables", Bool_t isRegression=kFALSE, Bool_t useTMVAStyle=kTRUE)