20#ifndef TMVA_TREEINFERENCE_FOREST 
   21#define TMVA_TREEINFERENCE_FOREST 
   40namespace Experimental {
 
   46   auto v = 
reinterpret_cast<T *
>(
f->Get(
m.c_str()));
 
   48      throw std::runtime_error(
"Failed to read " + 
m + 
" from file " + 
n + 
".");
 
   55   if (
a.fInputs[0] == 
b.fInputs[0])
 
   56      return a.fThresholds[0] < 
b.fThresholds[0];
 
   58      return a.fInputs[0] < 
b.fInputs[0];
 
   66template <
typename T, 
typename ForestType>
 
   73   void Inference(
const T *inputs, 
const int rows, 
bool layout, 
T *predictions);
 
   82template <
typename T, 
typename ForestType>
 
   85   const auto strideTree = layout ? 1 : rows;
 
   86   const auto strideBatch = layout ? fNumInputs : 1;
 
   87   for (
int i = 0; i < rows; i++) {
 
   89      for (
auto &
tree : fTrees) {
 
   90         predictions[i] += 
tree.Inference(inputs + i * strideBatch, strideTree);
 
   92      predictions[i] = fObjectiveFunc(predictions[i]);
 
  101   void Load(
const std::string &key, 
const std::string &filename, 
const int output = 0, 
const bool sortTrees = 
true);
 
  118   auto maxDepth = Internal::GetObjectSafe<std::vector<int>>(
file, filename, key + 
"/max_depth");
 
  119   auto numTrees = Internal::GetObjectSafe<std::vector<int>>(
file, filename, key + 
"/num_trees");
 
  120   auto numInputs = Internal::GetObjectSafe<std::vector<int>>(
file, filename, key + 
"/num_inputs");
 
  121   auto numOutputs = Internal::GetObjectSafe<std::vector<int>>(
file, filename, key + 
"/num_outputs");
 
  122   auto objective = Internal::GetObjectSafe<std::string>(
file, filename, key + 
"/objective");
 
  123   auto inputs = Internal::GetObjectSafe<std::vector<int>>(
file, filename, key + 
"/inputs");
 
  124   auto outputs = Internal::GetObjectSafe<std::vector<int>>(
file, filename, key + 
"/outputs");
 
  125   auto thresholds = Internal::GetObjectSafe<std::vector<T>>(
file, filename, key + 
"/thresholds");
 
  127   this->fNumInputs = numInputs->at(0);
 
  128   this->fObjectiveFunc = Objectives::GetFunction<T>(*objective);
 
  129   const auto lenInputs = 
std::pow(2, maxDepth->at(0)) - 1;
 
  130   const auto lenThresholds = 
std::pow(2, maxDepth->at(0) + 1) - 1;
 
  133   if (
output > numOutputs->at(0))
 
  134      throw std::runtime_error(
"Given output node of the forest is larger or equal to number of output nodes.");
 
  136   for (
int i = 0; i < numTrees->at(0); i++)
 
  137      if (outputs->at(i) == 
output)
 
  140      std::runtime_error(
"No trees found for given output node of the forest.");
 
  141   this->fTrees.resize(
c);
 
  145   for (
int i = 0; i < numTrees->at(0); i++) {
 
  147      if (outputs->at(i) != 
output)
 
  151      this->fTrees[
c].fTreeDepth = maxDepth->at(0);
 
  154      this->fTrees[
c].fInputs.resize(lenInputs);
 
  155      for (
int j = 0; j < lenInputs; j++)
 
  156         this->fTrees[
c].fInputs[j] = inputs->at(i * lenInputs + j);
 
  159      this->fTrees[
c].fThresholds.resize(lenThresholds);
 
  160      for (
int j = 0; j < lenThresholds; j++)
 
  161         this->fTrees[
c].fThresholds[j] = thresholds->at(i * lenThresholds + j);
 
  164      this->fTrees[
c].FillSparse();
 
  171      std::sort(this->fTrees.begin(), this->fTrees.end(), Internal::CompareTree<T>);
 
  188    std::string 
Load(
const std::string &key, 
const std::string &filename, 
const int output = 0, 
const bool sortTrees = 
true);
 
  189   void Inference(
const T *inputs, 
const int rows, 
bool layout, 
T *predictions);
 
  207   auto maxDepth = Internal::GetObjectSafe<std::vector<int>>(
file, filename, key + 
"/max_depth");
 
  208   auto numTrees = Internal::GetObjectSafe<std::vector<int>>(
file, filename, key + 
"/num_trees");
 
  209   auto numInputs = Internal::GetObjectSafe<std::vector<int>>(
file, filename, key + 
"/num_inputs");
 
  210   auto numOutputs = Internal::GetObjectSafe<std::vector<int>>(
file, filename, key + 
"/num_outputs");
 
  211   auto objective = Internal::GetObjectSafe<std::string>(
file, filename, key + 
"/objective");
 
  212   auto inputs = Internal::GetObjectSafe<std::vector<int>>(
file, filename, key + 
"/inputs");
 
  213   auto outputs = Internal::GetObjectSafe<std::vector<int>>(
file, filename, key + 
"/outputs");
 
  214   auto thresholds = Internal::GetObjectSafe<std::vector<T>>(
file, filename, key + 
"/thresholds");
 
  216   this->fNumInputs = numInputs->at(0);
 
  217   this->fObjectiveFunc = Objectives::GetFunction<T>(*objective);
 
  218   const auto lenInputs = 
std::pow(2, maxDepth->at(0)) - 1;
 
  219   const auto lenThresholds = 
std::pow(2, maxDepth->at(0) + 1) - 1;
 
  222   if (
output > numOutputs->at(0))
 
  223      throw std::runtime_error(
"Given output node of the forest is larger or equal to number of output nodes.");
 
  225   for (
int i = 0; i < numTrees->at(0); i++)
 
  226      if (outputs->at(i) == 
output)
 
  229      std::runtime_error(
"No trees found for given output node of the forest.");
 
  233   if (typeName.compare(
"") == 0) {
 
  234      throw std::runtime_error(
"Failed to just-in-time compile inference code for branchless forest (typename as string)");
 
  238   std::vector<T> firstThreshold(
c);
 
  239   std::vector<int> firstInput(
c, -1);
 
  240   std::vector<std::string> codes(
c);
 
  242   for (
int i = 0; i < numTrees->at(0); i++) {
 
  244      if (outputs->at(i) != 
output)
 
  249      tree.fTreeDepth = maxDepth->at(0);
 
  252      tree.fInputs.resize(lenInputs);
 
  253      for (
int j = 0; j < lenInputs; j++)
 
  254         tree.fInputs[j] = inputs->at(i * lenInputs + j);
 
  257      tree.fThresholds.resize(lenThresholds);
 
  258      for (
int j = 0; j < lenThresholds; j++)
 
  259         tree.fThresholds[j] = thresholds->at(i * lenThresholds + j);
 
  265      firstThreshold[
c] = 
tree.fThresholds[0];
 
  267          firstInput[
c] = 
tree.fInputs[0];
 
  270      std::stringstream ss;
 
  272      codes[
c] = 
tree.GetInferenceCode(ss.str(), typeName);
 
  278   std::vector<int> treeIndices(codes.size());
 
  279   for(
int i = 0; i < 
c; i++) treeIndices[i] = i;
 
  281      auto compareIndices = [&firstInput, &firstThreshold](
int i, 
int j)
 
  283                 if (firstInput[i] == firstInput[j])
 
  284                    return firstThreshold[i] < firstThreshold[j];
 
  286                    return firstInput[i] < firstInput[j];
 
  288      std::sort(treeIndices.begin(), treeIndices.end(), compareIndices);
 
  293   std::string nameSpace = uuid.
AsString();
 
  294   for (
auto& 
v : nameSpace) {
 
  295      if (
v == 
'-') 
v = 
'_';
 
  297   nameSpace = 
"ns_" + nameSpace;
 
  300   std::stringstream jitForest;
 
  301   jitForest << 
"#pragma cling optimize(3)\n" 
  302             << 
"namespace " << nameSpace << 
" {\n";
 
  303   for (
int i = 0; i < static_cast<int>(codes.size()); i++) {
 
  304      jitForest << codes[treeIndices[i]] << 
"\n\n";
 
  306   jitForest << 
"void Inference(const " 
  307             << typeName << 
"* inputs, const int rows, bool layout, " 
  308             << typeName << 
"* predictions)" 
  310             << 
"   const auto strideTree = layout ? 1 : rows;\n" 
  311             << 
"   const auto strideBatch = layout ? " << this->fNumInputs << 
" : 1;\n" 
  312             << 
"   for (int i = 0; i < rows; i++) {\n" 
  313             << 
"      predictions[i] = 0.0;\n";
 
  314   for (
int i = 0; i < static_cast<int>(codes.size()); i++) {
 
  315      std::stringstream ss;
 
  317      const std::string funcName = ss.str();
 
  318      jitForest << 
"      predictions[i] += " << funcName << 
"(inputs + i * strideBatch, strideTree);\n";
 
  322             << 
"} // end namespace " << nameSpace;
 
  323   const std::string jitForestStr = jitForest.str();
 
  324   const auto err = 
gInterpreter->Declare(jitForestStr.c_str());
 
  326      throw std::runtime_error(
"Failed to just-in-time compile inference code for branchless forest (declare function)");
 
  330   std::stringstream treesFunc;
 
  331   treesFunc << 
"#pragma cling optimize(3)\n" << nameSpace << 
"::Inference";
 
  332   const std::string treesFuncStr = treesFunc.str();
 
  335      throw std::runtime_error(
"Failed to just-in-time compile inference code for branchless forest (compile function)");
 
  337   this->fTrees = 
reinterpret_cast<void (*)(
const T *, 
int, 
bool, 
float*)
>(ptr);
 
  360   this->fTrees(inputs, rows, layout, predictions);
 
  361   for (
int i = 0; i < rows; i++)
 
  362      predictions[i] = this->fObjectiveFunc(predictions[i]);
 
double pow(double, double)
A ROOT file is a sequence of consecutive data records (TKey instances) with a well-defined format.
static TFile * Open(const char *name, Option_t *option="", const char *ftitle="", Int_t compress=ROOT::RCompressionSetting::EDefaults::kUseCompiledDefault, Int_t netopt=0)
Create / open a file.
This class defines a UUID (Universally Unique IDentifier), also known as GUIDs (Globally Unique IDentifiers).
const char * AsString() const
Return UUID as string. Copy string immediately since it will be reused.
std::string GetDemangledTypeName(const std::type_info &)
Returns a string with the demangled and normalized name for the given type.
void function(const Char_t *name_, T fun, const Char_t *docstring=0)
T * GetObjectSafe(TFile *f, const std::string &n, const std::string &m)
bool CompareTree(const BranchlessTree< T > &a, const BranchlessTree< T > &b)
create variable transformations
Forest using branchless trees.
void Load(const std::string &key, const std::string &filename, const int output=0, const bool sortTrees=true)
Load parameters from a ROOT file to the branchless trees.
Forest using branchless jitted trees.
void Inference(const T *inputs, const int rows, bool layout, T *predictions)
Perform inference of the forest with the jitted branchless implementation on a batch of inputs.
std::string Load(const std::string &key, const std::string &filename, const int output=0, const bool sortTrees=true)
Load parameters from a ROOT file to the branchless trees.
Branchless representation of a decision tree using topological ordering.
std::function< T(T)> fObjectiveFunc
Objective function.
int fNumInputs
Number of input variables.
ForestType fTrees
Store the forest, either as vector or jitted function.
void Inference(const T *inputs, const int rows, bool layout, T *predictions)
Perform inference of the forest on a batch of inputs.
static void output(int code)