20#ifndef TMVA_TREEINFERENCE_BRANCHLESSTREE
21#define TMVA_TREEINFERENCE_BRANCHLESSTREE
29namespace Experimental {
35void RecursiveFill(
int thisIndex,
int lastIndex,
int treeDepth,
int maxTreeDepth, std::vector<T> &thresholds,
36 std::vector<int> &inputs)
40 if (inputs[lastIndex] == -1) {
41 thresholds.at(thisIndex) = thresholds.at(lastIndex);
44 if (treeDepth < maxTreeDepth)
45 inputs.at(thisIndex) = -1;
49 if (treeDepth < maxTreeDepth) {
50 Internal::RecursiveFill<T>(2 * thisIndex + 1, thisIndex, treeDepth + 1, maxTreeDepth, thresholds, inputs);
51 Internal::RecursiveFill<T>(2 * thisIndex + 2, thisIndex, treeDepth + 1, maxTreeDepth, thresholds, inputs);
69 inline std::string
GetInferenceCode(
const std::string& funcName,
const std::string& typeName);
80 for (
int level = 0; level < fTreeDepth; ++level) {
83 return fThresholds[
index];
96 Internal::RecursiveFill<T>(1, 0, 1, fTreeDepth, fThresholds, fInputs);
97 Internal::RecursiveFill<T>(2, 0, 1, fTreeDepth, fThresholds, fInputs);
100 std::replace(fInputs.begin(), fInputs.end(), -1.0, 0.0);
112 std::stringstream ss;
115 ss <<
"inline " << typeName <<
" " << funcName <<
"(const " << typeName <<
"* input, const int stride)";
121 ss <<
" const int inputs[" << fInputs.size() <<
"] = {";
122 int last =
static_cast<int>(fInputs.size() - 1);
123 for (
int i = 0; i < last + 1; i++) {
125 if (i != last) ss <<
", ";
129 ss <<
" const " << typeName <<
" thresholds[" << fThresholds.size() <<
"] = {";
130 last =
static_cast<int>(fThresholds.size() - 1);
131 for (
int i = 0; i < last + 1; i++) {
132 ss << fThresholds[i];
133 if (i != last) ss <<
", ";
138 ss <<
" int index = 0;\n";
139 for (
int level = 0; level < fTreeDepth; ++level) {
140 ss <<
" index = 2 * index + 1 + (input[inputs[index] * stride] > thresholds[index]);\n";
142 ss <<
" return thresholds[index];\n";
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void input
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void char Point_t Rectangle_t WindowAttributes_t index
void RecursiveFill(int thisIndex, int lastIndex, int treeDepth, int maxTreeDepth, std::vector< T > &thresholds, std::vector< int > &inputs)
Fill the empty nodes of a sparse tree recursively.
create variable transformations
Branchless representation of a decision tree using topological ordering.
std::vector< int > fInputs
Cut variables / inputs.
std::vector< T > fThresholds
Cut thresholds or scores if corresponding node is a leaf.
void FillSparse()
Fill nodes of a sparse tree forming a full tree.
int fTreeDepth
Depth of the tree.
T Inference(const T *input, const int stride)
Perform inference on a single input vector.
std::string GetInferenceCode(const std::string &funcName, const std::string &typeName)
Get code for compiling the inference function of the branchless tree with the current thresholds and ...