Logo ROOT  
Reference Guide
 
Loading...
Searching...
No Matches
Forest.hxx
Go to the documentation of this file.
1/**********************************************************************************
2 * Project: ROOT - a Root-integrated toolkit for multivariate data analysis *
3 * Package: TMVA *
4 * Web : http://tmva.sourceforge.net *
5 * *
6 * Description: *
7 * *
8 * Authors: *
9 * Stefan Wunsch (stefan.wunsch@cern.ch) *
10 * Luca Zampieri (luca.zampieri@alumni.epfl.ch) *
11 * *
12 * Copyright (c) 2019: *
13 * CERN, Switzerland *
14 * *
15 * Redistribution and use in source and binary forms, with or without *
16 * modification, are permitted according to the terms listed in LICENSE *
17 * (http://tmva.sourceforge.net/LICENSE) *
18 **********************************************************************************/
19
20#ifndef TMVA_TREEINFERENCE_FOREST
21#define TMVA_TREEINFERENCE_FOREST
22
#include <algorithm>
#include <cmath>
#include <functional>
#include <memory>
#include <sstream>
#include <stdexcept>
#include <string>
#include <vector>
29
30#include "TFile.h"
31#include "TDirectory.h"
32#include "TInterpreter.h"
33#include "TUUID.h"
34#include "TGenericClassInfo.h" // ROOT::Internal::GetDemangledTypeName
35
36#include "BranchlessTree.hxx"
37#include "Objectives.hxx"
38
39namespace TMVA {
40namespace Experimental {
41
42namespace Internal {
43template <typename T>
44T *GetObjectSafe(TFile *f, const std::string &n, const std::string &m)
45{
46 auto *v = f->Get<T>(m.c_str());
47 if (v == nullptr)
48 throw std::runtime_error("Failed to read " + m + " from file " + n + ".");
49 return v;
50}
51
52template <typename T>
54{
55 if (a.fInputs[0] == b.fInputs[0])
56 return a.fThresholds[0] < b.fThresholds[0];
57 else
58 return a.fInputs[0] < b.fInputs[0];
59}
60} // namespace Internal
61
/// Forest base class
///
/// Holds the state shared by all forest implementations: the objective function applied
/// to the summed tree responses, the trees themselves and the number of input variables.
///
/// \tparam T Value type for the computation (usually floating point type)
/// \tparam ForestType Type of the collection of trees
template <typename T, typename ForestType>
struct ForestBase {
   using Value_t = T;
   std::function<T(T)> fObjectiveFunc; ///< Objective function
   ForestType fTrees;                  ///< Store the forest, either as vector or jitted function
   int fNumInputs = 0;                 ///< Number of input variables (0 until a model is loaded)

   void Inference(const T *inputs, const int rows, bool layout, T *predictions);
};
75
76/// Perform inference of the forest on a batch of inputs
77///
78/// \param[in] inputs Pointer to data containing the inputs
79/// \param[in] rows Number of events in inputs vector
80/// \param[in] layout Row major (true) or column major (false) memory layout
81/// \param[in] predictions Pointer to the buffer to be filled with the predictions
82template <typename T, typename ForestType>
83inline void ForestBase<T, ForestType>::Inference(const T *inputs, const int rows, bool layout, T *predictions)
84{
85 const auto strideTree = layout ? 1 : rows;
86 const auto strideBatch = layout ? fNumInputs : 1;
87 for (int i = 0; i < rows; i++) {
88 predictions[i] = 0.0;
89 for (auto &tree : fTrees) {
90 predictions[i] += tree.Inference(inputs + i * strideBatch, strideTree);
91 }
92 predictions[i] = fObjectiveFunc(predictions[i]);
93 }
94}
95
96/// Forest using branchless trees
97///
98/// \tparam T Value type for the computation (usually floating point type)
99template <typename T>
100struct BranchlessForest : public ForestBase<T, std::vector<BranchlessTree<T>>> {
101 void Load(const std::string &key, const std::string &filename, const int output = 0, const bool sortTrees = true);
102};
103
104/// Load parameters from a ROOT file to the branchless trees
105///
106/// \param[in] key Name of folder in the ROOT file containing the model parameters
107/// \param[in] filename Filename of the ROOT file
108/// \param[in] output Load trees corresponding to the given output node of the forest
109/// \param[in] sortTrees Flag to indicate sorting the input trees by the cut value of the first node of each tree
110template <typename T>
111inline void
112BranchlessForest<T>::Load(const std::string &key, const std::string &filename, const int output, const bool sortTrees)
113{
114 // Open input file and get folder from key
115 auto file = TFile::Open(filename.c_str(), "READ");
116
117 // Load parameters from file
118 auto maxDepth = Internal::GetObjectSafe<std::vector<int>>(file, filename, key + "/max_depth");
119 auto numTrees = Internal::GetObjectSafe<std::vector<int>>(file, filename, key + "/num_trees");
120 auto numInputs = Internal::GetObjectSafe<std::vector<int>>(file, filename, key + "/num_inputs");
121 auto numOutputs = Internal::GetObjectSafe<std::vector<int>>(file, filename, key + "/num_outputs");
122 auto objective = Internal::GetObjectSafe<std::string>(file, filename, key + "/objective");
123 auto inputs = Internal::GetObjectSafe<std::vector<int>>(file, filename, key + "/inputs");
124 auto outputs = Internal::GetObjectSafe<std::vector<int>>(file, filename, key + "/outputs");
125 auto thresholds = Internal::GetObjectSafe<std::vector<T>>(file, filename, key + "/thresholds");
126
127 this->fNumInputs = numInputs->at(0);
128 this->fObjectiveFunc = Objectives::GetFunction<T>(*objective);
129 const auto lenInputs = std::pow(2, maxDepth->at(0)) - 1;
130 const auto lenThresholds = std::pow(2, maxDepth->at(0) + 1) - 1;
131
132 // Find number of trees corresponding to given output node
133 if (output > numOutputs->at(0))
134 throw std::runtime_error("Given output node of the forest is larger or equal to number of output nodes.");
135 int c = 0;
136 for (int i = 0; i < numTrees->at(0); i++)
137 if (outputs->at(i) == output)
138 c++;
139 if (c == 0)
140 std::runtime_error("No trees found for given output node of the forest.");
141 this->fTrees.resize(c);
142
143 // Load parameters in trees
144 c = 0;
145 for (int i = 0; i < numTrees->at(0); i++) {
146 // Select only trees for the given output node of the forest
147 if (outputs->at(i) != output)
148 continue;
149
150 // Set tree depth
151 this->fTrees[c].fTreeDepth = maxDepth->at(0);
152
153 // Set feature indices
154 this->fTrees[c].fInputs.resize(lenInputs);
155 for (int j = 0; j < lenInputs; j++)
156 this->fTrees[c].fInputs[j] = inputs->at(i * lenInputs + j);
157
158 // Set threshold values
159 this->fTrees[c].fThresholds.resize(lenThresholds);
160 for (int j = 0; j < lenThresholds; j++)
161 this->fTrees[c].fThresholds[j] = thresholds->at(i * lenThresholds + j);
162
163 // Fill sparse trees fully
164 this->fTrees[c].FillSparse();
165
166 c++;
167 }
168
169 // Sort trees by first cut variable and threshold value
170 if (sortTrees)
171 std::sort(this->fTrees.begin(), this->fTrees.end(), Internal::CompareTree<T>);
172
173 // Clean-up
174 delete maxDepth;
175 delete numTrees;
176 delete numInputs;
177 delete objective;
178 delete inputs;
179 delete thresholds;
180 file->Close();
181}
182
183/// Forest using branchless jitted trees
184///
185/// \tparam T Value type for the computation (usually floating point type)
186template <typename T>
187struct BranchlessJittedForest : public ForestBase<T, std::function<void (const T *, const int, bool, T*)>> {
188 std::string Load(const std::string &key, const std::string &filename, const int output = 0, const bool sortTrees = true);
189 void Inference(const T *inputs, const int rows, bool layout, T *predictions);
190};
191
192/// Load parameters from a ROOT file to the branchless trees
193///
194/// \param[in] key Name of folder in the ROOT file containing the model parameters
195/// \param[in] filename Filename of the ROOT file
196/// \param[in] output Load trees corresponding to the given output node of the forest
197/// \param[in] sortTrees Flag to indicate sorting the input trees by the cut value of the first node of each tree
198/// \return Return jitted code as string
199template <typename T>
200inline std::string
201BranchlessJittedForest<T>::Load(const std::string &key, const std::string &filename, const int output, const bool sortTrees)
202{
203 // Open input file and get folder from key
204 auto file = TFile::Open(filename.c_str(), "READ");
205
206 // Load parameters from file
207 auto maxDepth = Internal::GetObjectSafe<std::vector<int>>(file, filename, key + "/max_depth");
208 auto numTrees = Internal::GetObjectSafe<std::vector<int>>(file, filename, key + "/num_trees");
209 auto numInputs = Internal::GetObjectSafe<std::vector<int>>(file, filename, key + "/num_inputs");
210 auto numOutputs = Internal::GetObjectSafe<std::vector<int>>(file, filename, key + "/num_outputs");
211 auto objective = Internal::GetObjectSafe<std::string>(file, filename, key + "/objective");
212 auto inputs = Internal::GetObjectSafe<std::vector<int>>(file, filename, key + "/inputs");
213 auto outputs = Internal::GetObjectSafe<std::vector<int>>(file, filename, key + "/outputs");
214 auto thresholds = Internal::GetObjectSafe<std::vector<T>>(file, filename, key + "/thresholds");
215
216 this->fNumInputs = numInputs->at(0);
217 this->fObjectiveFunc = Objectives::GetFunction<T>(*objective);
218 const auto lenInputs = std::pow(2, maxDepth->at(0)) - 1;
219 const auto lenThresholds = std::pow(2, maxDepth->at(0) + 1) - 1;
220
221 // Find number of trees corresponding to given output node
222 if (output > numOutputs->at(0))
223 throw std::runtime_error("Given output node of the forest is larger or equal to number of output nodes.");
224 int c = 0;
225 for (int i = 0; i < numTrees->at(0); i++)
226 if (outputs->at(i) == output)
227 c++;
228 if (c == 0)
229 std::runtime_error("No trees found for given output node of the forest.");
230
231 // Get typename of template argument as string
232 std::string typeName = ROOT::Internal::GetDemangledTypeName(typeid(T));
233 if (typeName.compare("") == 0) {
234 throw std::runtime_error("Failed to just-in-time compile inference code for branchless forest (typename as string)");
235 }
236
237 // Load parameters in trees
238 std::vector<T> firstThreshold(c);
239 std::vector<int> firstInput(c, -1);
240 std::vector<std::string> codes(c);
241 c = 0;
242 for (int i = 0; i < numTrees->at(0); i++) {
243 // Select only trees for the given output node of the forest
244 if (outputs->at(i) != output)
245 continue;
246
247 // Set tree depth
249 tree.fTreeDepth = maxDepth->at(0);
250
251 // Set feature indices
252 tree.fInputs.resize(lenInputs);
253 for (int j = 0; j < lenInputs; j++)
254 tree.fInputs[j] = inputs->at(i * lenInputs + j);
255
256 // Set threshold values
257 tree.fThresholds.resize(lenThresholds);
258 for (int j = 0; j < lenThresholds; j++)
259 tree.fThresholds[j] = thresholds->at(i * lenThresholds + j);
260
261 // Fill sparse trees fully
262 tree.FillSparse();
263
264 // Save first threshold and input index for ordering the trees later
265 firstThreshold[c] = tree.fThresholds[0];
266 if (lenInputs != 0)
267 firstInput[c] = tree.fInputs[0];
268
269 // Save code for jitting
270 std::stringstream ss;
271 ss << "tree" << c;
272 codes[c] = tree.GetInferenceCode(ss.str(), typeName);
273
274 c++;
275 }
276
277 // Sort trees by first cut variable and threshold value
278 std::vector<int> treeIndices(codes.size());
279 for(int i = 0; i < c; i++) treeIndices[i] = i;
280 if (sortTrees) {
281 auto compareIndices = [&firstInput, &firstThreshold](int i, int j)
282 {
283 if (firstInput[i] == firstInput[j])
284 return firstThreshold[i] < firstThreshold[j];
285 else
286 return firstInput[i] < firstInput[j];
287 };
288 std::sort(treeIndices.begin(), treeIndices.end(), compareIndices);
289 }
290
291 // Get unique ID for a private namespace
292 TUUID uuid;
293 std::string nameSpace = uuid.AsString();
294 for (auto& v : nameSpace) {
295 if (v == '-') v = '_';
296 }
297 nameSpace = "ns_" + nameSpace;
298
299 // JIT the forest
300 std::stringstream jitForest;
301 jitForest << "#pragma cling optimize(3)\n"
302 << "namespace " << nameSpace << " {\n";
303 for (int i = 0; i < static_cast<int>(codes.size()); i++) {
304 jitForest << codes[treeIndices[i]] << "\n\n";
305 }
306 jitForest << "void Inference(const "
307 << typeName << "* inputs, const int rows, bool layout, "
308 << typeName << "* predictions)"
309 << "\n{\n"
310 << " const auto strideTree = layout ? 1 : rows;\n"
311 << " const auto strideBatch = layout ? " << this->fNumInputs << " : 1;\n"
312 << " for (int i = 0; i < rows; i++) {\n"
313 << " predictions[i] = 0.0;\n";
314 for (int i = 0; i < static_cast<int>(codes.size()); i++) {
315 std::stringstream ss;
316 ss << "tree" << i;
317 const std::string funcName = ss.str();
318 jitForest << " predictions[i] += " << funcName << "(inputs + i * strideBatch, strideTree);\n";
319 }
320 jitForest << " }\n"
321 << "}\n"
322 << "} // end namespace " << nameSpace;
323 const std::string jitForestStr = jitForest.str();
324 const auto err = gInterpreter->Declare(jitForestStr.c_str());
325 if (err == 0) {
326 throw std::runtime_error("Failed to just-in-time compile inference code for branchless forest (declare function)");
327 }
328
329 // Get function pointer and attach pointer to the forest
330 std::stringstream treesFunc;
331 treesFunc << "#pragma cling optimize(3)\n" << nameSpace << "::Inference";
332 const std::string treesFuncStr = treesFunc.str();
333 auto ptr = gInterpreter->Calc(treesFuncStr.c_str());
334 if (ptr == 0) {
335 throw std::runtime_error("Failed to just-in-time compile inference code for branchless forest (compile function)");
336 }
337 this->fTrees = reinterpret_cast<void (*)(const T *, int, bool, float*)>(ptr);
338
339 // Clean-up
340 delete maxDepth;
341 delete numTrees;
342 delete numInputs;
343 delete objective;
344 delete inputs;
345 delete thresholds;
346 file->Close();
347
348 return jitForestStr;
349}
350
351/// Perform inference of the forest with the jitted branchless implementation on a batch of inputs
352///
353/// \param[in] inputs Pointer to data containing the inputs
354/// \param[in] rows Number of events in inputs vector
355/// \param[in] layout Row major (true) or column major (false) memory layout
356/// \param[in] predictions Pointer to the buffer to be filled with the predictions
357template <typename T>
358void BranchlessJittedForest<T>::Inference(const T *inputs, const int rows, bool layout, T *predictions)
359{
360 this->fTrees(inputs, rows, layout, predictions);
361 for (int i = 0; i < rows; i++)
362 predictions[i] = this->fObjectiveFunc(predictions[i]);
363}
364
365} // namespace Experimental
366} // namespace TMVA
367
368#endif // TMVA_TREEINFERENCE_FOREST
#define b(i)
Definition RSha256.hxx:100
#define f(i)
Definition RSha256.hxx:104
#define c(i)
Definition RSha256.hxx:101
#define a(i)
Definition RSha256.hxx:99
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void char Point_t Rectangle_t WindowAttributes_t Float_t Float_t Float_t Int_t Int_t UInt_t UInt_t Rectangle_t Int_t Int_t Window_t TString Int_t GCValues_t GetPrimarySelectionOwner GetDisplay GetScreen GetColormap GetNativeEvent const char const char dpyName wid window const char font_name cursor keysym reg const char only_if_exist regb h Point_t winding char text const char depth char const char Int_t count const char ColorStruct_t color const char filename
#define gInterpreter
A ROOT file is composed of a header, followed by consecutive data records (TKey instances) with a wel...
Definition TFile.h:53
static TFile * Open(const char *name, Option_t *option="", const char *ftitle="", Int_t compress=ROOT::RCompressionSetting::EDefaults::kUseCompiledDefault, Int_t netopt=0)
Create / open a file.
Definition TFile.cxx:4075
This class defines a UUID (Universally Unique IDentifier), also known as GUIDs (Globally Unique IDent...
Definition TUUID.h:42
const char * AsString() const
Return UUID as string. Copy string immediately since it will be reused.
Definition TUUID.cxx:571
const Int_t n
Definition legend1.C:16
std::string GetDemangledTypeName(const std::type_info &t)
T * GetObjectSafe(TFile *f, const std::string &n, const std::string &m)
Definition Forest.hxx:44
bool CompareTree(const BranchlessTree< T > &a, const BranchlessTree< T > &b)
Definition Forest.hxx:53
create variable transformations
Definition file.py:1
Definition tree.py:1
Forest using branchless trees.
Definition Forest.hxx:100
void Load(const std::string &key, const std::string &filename, const int output=0, const bool sortTrees=true)
Load parameters from a ROOT file to the branchless trees.
Definition Forest.hxx:112
Forest using branchless jitted trees.
Definition Forest.hxx:187
void Inference(const T *inputs, const int rows, bool layout, T *predictions)
Perform inference of the forest with the jitted branchless implementation on a batch of inputs.
Definition Forest.hxx:358
std::string Load(const std::string &key, const std::string &filename, const int output=0, const bool sortTrees=true)
Load parameters from a ROOT file to the branchless trees.
Definition Forest.hxx:201
Branchless representation of a decision tree using topological ordering.
Forest base class.
Definition Forest.hxx:67
std::function< T(T)> fObjectiveFunc
Objective function.
Definition Forest.hxx:69
int fNumInputs
Number of input variables.
Definition Forest.hxx:71
ForestType fTrees
Store the forest, either as vector or jitted function.
Definition Forest.hxx:70
void Inference(const T *inputs, const int rows, bool layout, T *predictions)
Perform inference of the forest on a batch of inputs.
Definition Forest.hxx:83
TMarker m
Definition textangle.C:8
static void output()