Logo ROOT  
Reference Guide
 
Loading...
Searching...
No Matches
RBDT.hxx
Go to the documentation of this file.
1/**********************************************************************************
2 * Project: ROOT - a Root-integrated toolkit for multivariate data analysis *
3 * Package: TMVA *
4 * Web : http://tmva.sourceforge.net *
5 * *
6 * Description: Fast boosted decision tree inference (RBDT) *
7 * *
8 * Authors: *
9 * Stefan Wunsch (stefan.wunsch@cern.ch) *
10 * *
11 * Copyright (c) 2019: *
12 * CERN, Switzerland *
13 * *
14 * Redistribution and use in source and binary forms, with or without *
15 * modification, are permitted according to the terms listed in LICENSE *
16 * (http://tmva.sourceforge.net/LICENSE) *
17 **********************************************************************************/
18
19#ifndef TMVA_RBDT
20#define TMVA_RBDT
21
22#include "TMVA/RTensor.hxx"
24#include "TFile.h"
25
26#include <vector>
27#include <string>
28#include <sstream> // std::stringstream
29#include <memory>
30
31namespace TMVA {
32namespace Experimental {
33
34/// Fast boosted decision tree inference
35template <typename Backend = BranchlessJittedForest<float>>
36class RBDT {
37public:
38 using Value_t = typename Backend::Value_t;
39 using Backend_t = Backend;
40
41private:
44 std::vector<Backend_t> fBackends;
45
46public:
47 /// Construct backends from model in ROOT file
48 RBDT(const std::string &key, const std::string &filename)
49 {
50 // Get number of output nodes of the forest
51 std::unique_ptr<TFile> file{TFile::Open(filename.c_str(),"READ")};
52 if (!file || file->IsZombie()) {
53 throw std::runtime_error("Failed to open input file " + filename);
54 }
55 auto numOutputs = Internal::GetObjectSafe<std::vector<int>>(file.get(), filename, key + "/num_outputs");
56 fNumOutputs = numOutputs->at(0);
57 delete numOutputs;
58
59 // Get objective and decide whether to normalize output nodes for example in the multiclass case
60 auto objective = Internal::GetObjectSafe<std::string>(file.get(), filename, key + "/objective");
61 if (objective->compare("softmax") == 0)
62 fNormalizeOutputs = true;
63 else
64 fNormalizeOutputs = false;
65 delete objective;
66 file->Close();
67
68 // Initialize backends
69 fBackends = std::vector<Backend_t>(fNumOutputs);
70 for (int i = 0; i < fNumOutputs; i++)
71 fBackends[i].Load(key, filename, i);
72 }
73
74 /// Compute model prediction on a single event
75 ///
76 /// The method is intended to be used with std::vectors-like containers,
77 /// for example RVecs.
78 template <typename Vector>
79 Vector Compute(const Vector &x)
80 {
81 Vector y;
82 y.resize(fNumOutputs);
83 for (int i = 0; i < fNumOutputs; i++)
84 fBackends[i].Inference(&x[0], 1, true, &y[i]);
86 Value_t s = 0.0;
87 for (int i = 0; i < fNumOutputs; i++)
88 s += y[i];
89 for (int i = 0; i < fNumOutputs; i++)
90 y[i] /= s;
91 }
92 return y;
93 }
94
95 /// Compute model prediction on a single event
96 std::vector<Value_t> Compute(const std::vector<Value_t> &x) { return this->Compute<std::vector<Value_t>>(x); }
97
98 /// Compute model prediction on input RTensor
100 {
101 const auto rows = x.GetShape()[0];
102 RTensor<Value_t> y({rows, static_cast<std::size_t>(fNumOutputs)}, MemoryLayout::ColumnMajor);
103 const bool layout = x.GetMemoryLayout() == MemoryLayout::ColumnMajor ? false : true;
104 for (int i = 0; i < fNumOutputs; i++)
105 fBackends[i].Inference(x.GetData(), rows, layout, &y(0, i));
106 if (fNormalizeOutputs) {
107 Value_t s;
108 for (int i = 0; i < static_cast<int>(rows); i++) {
109 s = 0.0;
110 for (int j = 0; j < fNumOutputs; j++)
111 s += y(i, j);
112 for (int j = 0; j < fNumOutputs; j++)
113 y(i, j) /= s;
114 }
115 }
116 return y;
117 }
118};
119
122
123} // namespace Experimental
124} // namespace TMVA
125
126#endif // TMVA_RBDT
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void char Point_t Rectangle_t WindowAttributes_t Float_t Float_t Float_t Int_t Int_t UInt_t UInt_t Rectangle_t Int_t Int_t Window_t TString Int_t GCValues_t GetPrimarySelectionOwner GetDisplay GetScreen GetColormap GetNativeEvent const char const char dpyName wid window const char font_name cursor keysym reg const char only_if_exist regb h Point_t winding char text const char depth char const char Int_t count const char ColorStruct_t color const char filename
static TFile * Open(const char *name, Option_t *option="", const char *ftitle="", Int_t compress=ROOT::RCompressionSetting::EDefaults::kUseCompiledDefault, Int_t netopt=0)
Create / open a file.
Definition TFile.cxx:4075
Fast boosted decision tree inference.
Definition RBDT.hxx:36
RTensor< Value_t > Compute(const RTensor< Value_t > &x)
Compute model prediction on input RTensor.
Definition RBDT.hxx:99
std::vector< Value_t > Compute(const std::vector< Value_t > &x)
Compute model prediction on a single event.
Definition RBDT.hxx:96
typename Backend::Value_t Value_t
Definition RBDT.hxx:38
RBDT(const std::string &key, const std::string &filename)
Construct backends from model in ROOT file.
Definition RBDT.hxx:48
Vector Compute(const Vector &x)
Compute model prediction on a single event.
Definition RBDT.hxx:79
std::vector< Backend_t > fBackends
Definition RBDT.hxx:44
RTensor is a container with contiguous memory and shape information.
Definition RTensor.hxx:162
MemoryLayout GetMemoryLayout() const
Definition RTensor.hxx:248
Double_t y[n]
Definition legend1.C:17
Double_t x[n]
Definition legend1.C:17
create variable transformations
Definition file.py:1