20#ifndef TMVA_TREEINFERENCE_FOREST
21#define TMVA_TREEINFERENCE_FOREST
40namespace Experimental {
46 auto v =
reinterpret_cast<T *
>(
f->Get(
m.c_str()));
// Error path: reports which key (m) could not be read from which file (n).
// NOTE(review): the condition guarding this throw (original line 47) is not
// visible in this listing — presumably a null check on `v`; confirm.
48 throw std::runtime_error(
"Failed to read " +
m +
" from file " +
n +
".");
55 if (
a.fInputs[0] ==
b.fInputs[0])
56 return a.fThresholds[0] <
b.fThresholds[0];
58 return a.fInputs[0] <
b.fInputs[0];
// Class template header of the forest base type. T is the value type of the
// inputs and predictions (see the Inference signature); ForestType is the type
// used to store the forest.
// NOTE(review): the struct line and its data members are not visible in this listing.
66template <
typename T,
typename ForestType>
// Batch inference entry point.
//   inputs      - pointer to the input values of the batch
//   rows        - number of rows in the batch
//   layout      - selects the addressing of `inputs`: when true, consecutive rows
//                 are fNumInputs elements apart; when false, they are 1 element apart
//   predictions - output buffer receiving one value per row
73 void Inference(
const T *inputs,
const int rows,
bool layout,
T *predictions);
// Template header of the Inference definition.
// NOTE(review): the function signature (original lines 83-84) and the closing
// braces of the loops below are not visible in this listing.
82template <
typename T,
typename ForestType>
// Stride handed to each tree: 1 when layout is true, otherwise `rows`.
85 const auto strideTree = layout ? 1 : rows;
// Distance between the start of consecutive rows in `inputs`:
// fNumInputs when layout is true, otherwise 1.
86 const auto strideBatch = layout ? fNumInputs : 1;
// One pass per row of the batch.
87 for (
int i = 0; i < rows; i++) {
// NOTE(review): original line 88 is not visible in this listing. The inner loop
// accumulates with +=, so predictions[i] has to be reset before it — confirm.
89 for (
auto &
tree : fTrees) {
// Add the response of this tree for row i.
90 predictions[i] +=
tree.Inference(inputs + i * strideBatch, strideTree);
// Map the accumulated response through the objective function.
92 predictions[i] = fObjectiveFunc(predictions[i]);
// Load the forest parameters from a ROOT file into the branchless trees.
//   key       - prefix of the entries read from the file ("<key>/max_depth", ...)
//   filename  - name of the ROOT file; also used in error messages
//   output    - output node of the forest whose trees are loaded (default 0)
//   sortTrees - defaults to true; presumably enables sorting of the loaded trees
//               (the condition using it is not visible in this listing — confirm)
101 void Load(
const std::string &key,
const std::string &filename,
const int output = 0,
const bool sortTrees =
true);
118 auto maxDepth = Internal::GetObjectSafe<std::vector<int>>(
file, filename, key +
"/max_depth");
119 auto numTrees = Internal::GetObjectSafe<std::vector<int>>(
file, filename, key +
"/num_trees");
120 auto numInputs = Internal::GetObjectSafe<std::vector<int>>(
file, filename, key +
"/num_inputs");
121 auto numOutputs = Internal::GetObjectSafe<std::vector<int>>(
file, filename, key +
"/num_outputs");
122 auto objective = Internal::GetObjectSafe<std::string>(
file, filename, key +
"/objective");
123 auto inputs = Internal::GetObjectSafe<std::vector<int>>(
file, filename, key +
"/inputs");
124 auto outputs = Internal::GetObjectSafe<std::vector<int>>(
file, filename, key +
"/outputs");
125 auto thresholds = Internal::GetObjectSafe<std::vector<T>>(
file, filename, key +
"/thresholds");
127 this->fNumInputs = numInputs->at(0);
128 this->fObjectiveFunc = Objectives::GetFunction<T>(*objective);
129 const auto lenInputs =
std::pow(2, maxDepth->at(0)) - 1;
130 const auto lenThresholds =
std::pow(2, maxDepth->at(0) + 1) - 1;
133 if (
output > numOutputs->at(0))
134 throw std::runtime_error(
"Given output node of the forest is larger or equal to number of output nodes.");
// Scan the per-tree output assignment for trees that belong to `output`.
// NOTE(review): the statement executed on a match (original line 138) is not
// visible in this listing — presumably it increments the counter `c`; confirm.
136 for (
int i = 0; i < numTrees->at(0); i++)
137 if (outputs->at(i) ==
output)
140 std::runtime_error(
"No trees found for given output node of the forest.");
// Allocate one tree slot per selected tree (c is the count determined above).
141 this->fTrees.resize(
c);
// Fill the tree slots from the flat parameter arrays read from the file.
// NOTE(review): the reset of `c`, the statement taken when the output does not
// match, the increment of `c` and the closing brace of this loop are on original
// lines that are not visible in this listing.
145 for (
int i = 0; i < numTrees->at(0); i++) {
// Trees assigned to a different output node are filtered here.
147 if (outputs->at(i) !=
output)
// All trees share the depth stored in the first element of max_depth.
151 this->fTrees[
c].fTreeDepth = maxDepth->at(0);
// Copy the input indices of stored tree i; each tree occupies lenInputs
// consecutive entries of the flat `inputs` array.
154 this->fTrees[
c].fInputs.resize(lenInputs);
155 for (
int j = 0; j < lenInputs; j++)
156 this->fTrees[
c].fInputs[j] = inputs->at(i * lenInputs + j);
// Copy the thresholds of stored tree i; each tree occupies lenThresholds
// consecutive entries of the flat `thresholds` array.
159 this->fTrees[
c].fThresholds.resize(lenThresholds);
160 for (
int j = 0; j < lenThresholds; j++)
161 this->fTrees[
c].fThresholds[j] = thresholds->at(i * lenThresholds + j);
// Post-process the tree via its FillSparse() member (defined outside this listing).
164 this->fTrees[
c].FillSparse();
// Order the trees with Internal::CompareTree.
// NOTE(review): presumably guarded by sortTrees on original line 170, which is
// not visible in this listing.
171 std::sort(this->fTrees.begin(), this->fTrees.end(), Internal::CompareTree<T>);
// Load the forest parameters from a ROOT file and build the jitted inference function.
//   key       - prefix of the entries read from the file ("<key>/max_depth", ...)
//   filename  - name of the ROOT file; also used in error messages
//   output    - output node of the forest whose trees are loaded (default 0)
//   sortTrees - defaults to true; presumably enables sorting of the trees
//               (the condition using it is not visible in this listing — confirm)
// Returns a std::string. NOTE(review): the return statement is not visible in
// this listing — presumably the generated source code; confirm.
188 std::string
Load(
const std::string &key,
const std::string &filename,
const int output = 0,
const bool sortTrees =
true);
// Batch inference through the jitted function.
//   inputs      - pointer to the input values of the batch
//   rows        - number of rows in the batch
//   layout      - memory layout selector forwarded to the jitted function
//   predictions - output buffer receiving one value per row
189 void Inference(
const T *inputs,
const int rows,
bool layout,
T *predictions);
207 auto maxDepth = Internal::GetObjectSafe<std::vector<int>>(
file, filename, key +
"/max_depth");
208 auto numTrees = Internal::GetObjectSafe<std::vector<int>>(
file, filename, key +
"/num_trees");
209 auto numInputs = Internal::GetObjectSafe<std::vector<int>>(
file, filename, key +
"/num_inputs");
210 auto numOutputs = Internal::GetObjectSafe<std::vector<int>>(
file, filename, key +
"/num_outputs");
211 auto objective = Internal::GetObjectSafe<std::string>(
file, filename, key +
"/objective");
212 auto inputs = Internal::GetObjectSafe<std::vector<int>>(
file, filename, key +
"/inputs");
213 auto outputs = Internal::GetObjectSafe<std::vector<int>>(
file, filename, key +
"/outputs");
214 auto thresholds = Internal::GetObjectSafe<std::vector<T>>(
file, filename, key +
"/thresholds");
216 this->fNumInputs = numInputs->at(0);
217 this->fObjectiveFunc = Objectives::GetFunction<T>(*objective);
218 const auto lenInputs =
std::pow(2, maxDepth->at(0)) - 1;
219 const auto lenThresholds =
std::pow(2, maxDepth->at(0) + 1) - 1;
222 if (
output > numOutputs->at(0))
223 throw std::runtime_error(
"Given output node of the forest is larger or equal to number of output nodes.");
// Scan the per-tree output assignment for trees that belong to `output`.
// NOTE(review): the statement executed on a match (original line 227) is not
// visible in this listing — presumably it increments the counter `c`; confirm.
225 for (
int i = 0; i < numTrees->at(0); i++)
226 if (outputs->at(i) ==
output)
229 std::runtime_error(
"No trees found for given output node of the forest.");
// The name of the value type is needed as text for the generated code; an empty
// name is treated as a failure.
// NOTE(review): the line producing `typeName` (original line 232) and the
// closing brace of this check are not visible in this listing.
233 if (typeName.compare(
"") == 0) {
234 throw std::runtime_error(
"Failed to just-in-time compile inference code for branchless forest (typename as string)");
// Per-tree bookkeeping, one entry per selected tree (c is the count determined
// above): threshold and input index at node 0, and the generated code.
// firstInput entries start out as -1.
238 std::vector<T> firstThreshold(
c);
239 std::vector<int> firstInput(
c, -1);
240 std::vector<std::string> codes(
c);
// Build each selected tree from the flat parameter arrays and generate its code.
// NOTE(review): the reset of `c`, the statement taken when the output does not
// match, the declaration of `tree`, the increment of `c` and the closing brace
// of this loop are on original lines that are not visible in this listing.
242 for (
int i = 0; i < numTrees->at(0); i++) {
// Trees assigned to a different output node are filtered here.
244 if (outputs->at(i) !=
output)
// All trees share the depth stored in the first element of max_depth.
249 tree.fTreeDepth = maxDepth->at(0);
// Copy the input indices of stored tree i (lenInputs consecutive entries).
252 tree.fInputs.resize(lenInputs);
253 for (
int j = 0; j < lenInputs; j++)
254 tree.fInputs[j] = inputs->at(i * lenInputs + j);
// Copy the thresholds of stored tree i (lenThresholds consecutive entries).
257 tree.fThresholds.resize(lenThresholds);
258 for (
int j = 0; j < lenThresholds; j++)
259 tree.fThresholds[j] = thresholds->at(i * lenThresholds + j);
// Remember the node-0 threshold and input index of this tree.
265 firstThreshold[
c] =
tree.fThresholds[0];
267 firstInput[
c] =
tree.fInputs[0];
// Generate the source code of this tree. The first argument is the text held by
// `ss`. NOTE(review): the line filling `ss` (original line 271) is not visible
// in this listing — presumably a per-tree function name; confirm.
270 std::stringstream ss;
272 codes[
c] =
tree.GetInferenceCode(ss.str(), typeName);
// Index permutation over the generated tree codes, initialised to 0, 1, 2, ...
278 std::vector<int> treeIndices(codes.size());
279 for(
int i = 0; i <
c; i++) treeIndices[i] = i;
// Comparator on tree indices: orders by the node-0 input index, and by the
// node-0 threshold when both trees use the same input there.
// NOTE(review): the braces of the lambda body are not visible in this listing.
281 auto compareIndices = [&firstInput, &firstThreshold](
int i,
int j)
283 if (firstInput[i] == firstInput[j])
284 return firstThreshold[i] < firstThreshold[j];
286 return firstInput[i] < firstInput[j];
// Sort the index permutation, leaving `codes` itself untouched.
// NOTE(review): presumably guarded by sortTrees on a line that is not visible
// in this listing.
288 std::sort(treeIndices.begin(), treeIndices.end(), compareIndices);
// Derive a namespace name for the generated code from a UUID string.
// The result of AsString() is copied into a std::string right away.
// NOTE(review): the declaration of `uuid` is not visible in this listing.
293 std::string nameSpace = uuid.
AsString();
// Replace every '-' by '_' so the text can be used inside a C++ identifier.
294 for (
auto&
v : nameSpace) {
295 if (
v ==
'-')
v =
'_';
// Prefix with "ns_".
297 nameSpace =
"ns_" + nameSpace;
// Assemble the source code of the whole forest as text: an optimisation pragma
// followed by the opening of the namespace named above.
300 std::stringstream jitForest;
301 jitForest <<
"#pragma cling optimize(3)\n"
302 <<
"namespace " << nameSpace <<
" {\n";
// Emit the code of every tree, in the order given by treeIndices.
303 for (
int i = 0; i < static_cast<int>(codes.size()); i++) {
304 jitForest << codes[treeIndices[i]] <<
"\n\n";
// Emit the Inference function of the forest. Its pointer parameters use the
// value type named by typeName; the number of inputs is baked in as a literal.
// NOTE(review): some pieces of this chain (original line 309) are not visible
// in this listing.
306 jitForest <<
"void Inference(const "
307 << typeName <<
"* inputs, const int rows, bool layout, "
308 << typeName <<
"* predictions)"
310 <<
" const auto strideTree = layout ? 1 : rows;\n"
311 <<
" const auto strideBatch = layout ? " << this->fNumInputs <<
" : 1;\n"
312 <<
" for (int i = 0; i < rows; i++) {\n"
313 <<
" predictions[i] = 0.0;\n";
// Emit one accumulation statement per tree, calling the function named by funcName.
// NOTE(review): the line filling `ss` (original line 316) and the end of this
// loop are not visible in this listing.
314 for (
int i = 0; i < static_cast<int>(codes.size()); i++) {
315 std::stringstream ss;
317 const std::string funcName = ss.str();
318 jitForest <<
" predictions[i] += " << funcName <<
"(inputs + i * strideBatch, strideTree);\n";
// Close the namespace.
// NOTE(review): the start of this statement (original lines 319-321) is not
// visible in this listing.
322 <<
"} // end namespace " << nameSpace;
// Hand the generated source to the interpreter.
323 const std::string jitForestStr = jitForest.str();
324 const auto err =
gInterpreter->Declare(jitForestStr.c_str());
// NOTE(review): the check of `err` guarding this throw (original line 325) is
// not visible in this listing.
326 throw std::runtime_error(
"Failed to just-in-time compile inference code for branchless forest (declare function)");
// Build the expression naming the generated Inference function.
330 std::stringstream treesFunc;
331 treesFunc <<
"#pragma cling optimize(3)\n" << nameSpace <<
"::Inference";
332 const std::string treesFuncStr = treesFunc.str();
// NOTE(review): the interpreter call producing `ptr` and the check guarding this
// throw (original lines 333-334) are not visible in this listing.
335 throw std::runtime_error(
"Failed to just-in-time compile inference code for branchless forest (compile function)");
337 this->fTrees =
reinterpret_cast<void (*)(
const T *,
int,
bool,
float*)
>(ptr);
360 this->fTrees(inputs, rows, layout, predictions);
361 for (
int i = 0; i < rows; i++)
362 predictions[i] = this->fObjectiveFunc(predictions[i]);
double pow(double, double)
A ROOT file is a suite of consecutive data records (TKey instances) with a well defined format.
static TFile * Open(const char *name, Option_t *option="", const char *ftitle="", Int_t compress=ROOT::RCompressionSetting::EDefaults::kUseCompiledDefault, Int_t netopt=0)
Create / open a file.
This class defines a UUID (Universally Unique IDentifier), also known as GUIDs (Globally Unique IDentifier).
const char * AsString() const
Return UUID as string. Copy string immediately since it will be reused.
std::string GetDemangledTypeName(const std::type_info &)
Returns a string with the demangled and normalized name for the given type.
void function(const Char_t *name_, T fun, const Char_t *docstring=0)
T * GetObjectSafe(TFile *f, const std::string &n, const std::string &m)
bool CompareTree(const BranchlessTree< T > &a, const BranchlessTree< T > &b)
create variable transformations
Forest using branchless trees.
void Load(const std::string &key, const std::string &filename, const int output=0, const bool sortTrees=true)
Load parameters from a ROOT file to the branchless trees.
Forest using branchless jitted trees.
void Inference(const T *inputs, const int rows, bool layout, T *predictions)
Perform inference of the forest with the jitted branchless implementation on a batch of inputs.
std::string Load(const std::string &key, const std::string &filename, const int output=0, const bool sortTrees=true)
Load parameters from a ROOT file to the branchless trees.
Branchless representation of a decision tree using topological ordering.
std::function< T(T)> fObjectiveFunc
Objective function.
int fNumInputs
Number of input variables.
ForestType fTrees
Store the forest, either as vector or jitted function.
void Inference(const T *inputs, const int rows, bool layout, T *predictions)
Perform inference of the forest on a batch of inputs.
static void output(int code)