ROOT   Reference Guide
Searching...
No Matches
BranchlessTree.hxx
Go to the documentation of this file.
1/**********************************************************************************
2 * Project: ROOT - a Root-integrated toolkit for multivariate data analysis *
3 * Package: TMVA *
4 * Web : http://tmva.sourceforge.net *
5 * *
6 * Description: *
7 * *
8 * Authors: *
9 * Stefan Wunsch (stefan.wunsch@cern.ch) *
10 * Luca Zampieri (luca.zampieri@alumni.epfl.ch) *
11 * *
12 * Copyright (c) 2019: *
13 * CERN, Switzerland *
14 * *
15 * Redistribution and use in source and binary forms, with or without *
16 * modification, are permitted according to the terms listed in LICENSE *
18 **********************************************************************************/
19
20#ifndef TMVA_TREEINFERENCE_BRANCHLESSTREE
21#define TMVA_TREEINFERENCE_BRANCHLESSTREE
22
#include <algorithm>
#include <limits>
#include <sstream>
#include <string>
#include <vector>
27
namespace TMVA {
namespace Experimental {

namespace Internal {

/// Fill the empty nodes of a sparse tree recursively
///
/// A node is empty if its parent is a leaf, i.e. the parent's feature index is -1.
/// Such a node inherits the parent's threshold value, so that the score of the
/// original leaf is propagated down to the last layer of the full tree.
///
/// \param[in] thisIndex Index of the node to fill
/// \param[in] lastIndex Index of the parent node of `thisIndex`
/// \param[in] treeDepth Depth of the node `thisIndex` (the root has depth 0)
/// \param[in] maxTreeDepth Depth of the full tree
/// \param[in,out] thresholds Cut thresholds / leaf scores in topological order
/// \param[in,out] inputs Cut variables in topological order, -1 marks a leaf;
///                       no entries are stored for the last layer of the tree
template <typename T>
void RecursiveFill(int thisIndex, int lastIndex, int treeDepth, int maxTreeDepth, std::vector<T> &thresholds,
                   std::vector<int> &inputs)
{
   // If we are below a leaf in a sparse branch, copy the parent's threshold value
   // (the leaf score) and mark this node as a leaf as well
   if (inputs[lastIndex] == -1) {
      thresholds.at(thisIndex) = thresholds.at(lastIndex);
      // Don't access the feature vector in the last layer of the tree since we
      // don't store these values in the inputs vector
      if (treeDepth < maxTreeDepth)
         inputs.at(thisIndex) = -1;
   }

   // Fill the children of this node if we are not in the final layer of the tree
   if (treeDepth < maxTreeDepth) {
      Internal::RecursiveFill<T>(2 * thisIndex + 1, thisIndex, treeDepth + 1, maxTreeDepth, thresholds, inputs);
      Internal::RecursiveFill<T>(2 * thisIndex + 2, thisIndex, treeDepth + 1, maxTreeDepth, thresholds, inputs);
   }
}

} // namespace Internal

/// \class BranchlessTree
/// \brief Branchless representation of a decision tree using topological ordering
///
/// The nodes are stored breadth-first: the children of node `i` are node `2 * i + 1`
/// (taken if the input is not greater than the threshold) and node `2 * i + 2`
/// (taken if the input is greater than the threshold).
///
/// \tparam T Value type for the computation (usually floating point type)
template <typename T>
struct BranchlessTree {
   int fTreeDepth;             ///< Depth of the tree
   std::vector<T> fThresholds; ///< Cut thresholds or scores if corresponding node is a leaf
   std::vector<int> fInputs;   ///< Cut variables / inputs

   inline T Inference(const T *input, const int stride) const;
   inline void FillSparse();
   inline std::string GetInferenceCode(const std::string &funcName, const std::string &typeName) const;
};

/// Perform inference on a single input vector
/// \param[in] input Pointer to data containing the input values
/// \param[in] stride Stride to go from one input variable to the next one
/// \return Tree score, result of the inference
template <typename T>
inline T BranchlessTree<T>::Inference(const T *input, const int stride) const
{
   int index = 0;
   for (int level = 0; level < fTreeDepth; ++level) {
      // The comparison result (0 or 1) selects the child without a branch
      index = 2 * index + 1 + (input[fInputs[index] * stride] > fThresholds[index]);
   }
   return fThresholds[index];
}

/// Fill nodes of a sparse tree forming a full tree
///
/// Sparse parts of the tree are marked with -1 values in the feature vector. The
/// algorithm fills these parts up with the last threshold value so that the result
/// of the inference stays the same but the computation always traverses the full tree,
/// which is needed to avoid branching logic.
template <typename T>
inline void BranchlessTree<T>::FillSparse()
{
   // Fill threshold / leaf values recursively. A tree of depth zero is a single
   // leaf: its only layer is the last one, for which no cut variables are stored,
   // so the recursion (which reads the root's cut variable) must not be entered.
   if (fTreeDepth > 0) {
      Internal::RecursiveFill<T>(1, 0, 1, fTreeDepth, fThresholds, fInputs);
      Internal::RecursiveFill<T>(2, 0, 1, fTreeDepth, fThresholds, fInputs);
   }

   // Replace feature indices of -1 with 0 so that every node reads a valid input
   std::replace(fInputs.begin(), fInputs.end(), -1, 0);
}

/// Get code for compiling the inference function of the branchless tree with
/// the current thresholds and cut variables
///
/// \param[in] funcName Name of the function
/// \param[in] typeName Name of the type used for the computation
/// \return Code of the inference function as string
template <typename T>
inline std::string
BranchlessTree<T>::GetInferenceCode(const std::string &funcName, const std::string &typeName) const
{
   std::stringstream ss;

   // Build signature
   ss << "inline " << typeName << " " << funcName << "(const " << typeName << "* input, const int stride)";

   // Function body
   ss << "\n{\n";

   // Hard-code cut variables
   ss << " const int inputs[" << fInputs.size() << "] = {";
   const char *separator = "";
   for (const int input : fInputs) {
      ss << separator << input;
      separator = ", ";
   }
   ss << "};\n";

   // Hard-code thresholds with enough significant digits to round-trip exactly,
   // so that the compiled function returns the same scores as Inference(). The
   // default precision is kept as a lower bound for types without max_digits10.
   ss.precision(std::max<std::streamsize>(ss.precision(), std::numeric_limits<T>::max_digits10));
   ss << " const " << typeName << " thresholds[" << fThresholds.size() << "] = {";
   separator = "";
   for (const T &threshold : fThresholds) {
      ss << separator << threshold;
      separator = ", ";
   }
   ss << "};\n";

   // Generate the unrolled traversal of the tree
   ss << " int index = 0;\n";
   for (int level = 0; level < fTreeDepth; ++level) {
      ss << " index = 2 * index + 1 + (input[inputs[index] * stride] > thresholds[index]);\n";
   }
   ss << " return thresholds[index];\n";
   ss << "}";

   return ss.str();
}

} // namespace Experimental
} // namespace TMVA
150
151#endif // TMVA_TREEINFERENCE_BRANCHLESSTREE
void RecursiveFill(int thisIndex, int lastIndex, int treeDepth, int maxTreeDepth, std::vector< T > &thresholds, std::vector< int > &inputs)
Fill the empty nodes of a sparse tree recursively.
create variable transformations
Branchless representation of a decision tree using topological ordering.
std::vector< int > fInputs
Cut variables / inputs.
std::vector< T > fThresholds
Cut thresholds or scores if corresponding node is a leaf.
void FillSparse()
Fill nodes of a sparse tree forming a full tree.
T Inference(const T *input, const int stride)
Perform inference on a single input vector.
std::string GetInferenceCode(const std::string &funcName, const std::string &typeName)
Get code for compiling the inference function of the branchless tree with the current thresholds and cut variables.