Logo ROOT   6.10/09
Reference Guide
TDFInterface.cxx
Go to the documentation of this file.
1 // Author: Enrico Guiraud, Danilo Piparo CERN 03/2017
2 
3 /*************************************************************************
4  * Copyright (C) 1995-2016, Rene Brun and Fons Rademakers. *
5  * All rights reserved. *
6  * *
7  * For the licensing terms see $ROOTSYS/LICENSE. *
8  * For the list of contributors see $ROOTSYS/README/CREDITS. *
9  *************************************************************************/
10 
11 #include "TClass.h"
12 #include "TRegexp.h"
13 
14 #include "ROOT/TDFInterface.hxx"
15 
16 #include <vector>
17 #include <string>
18 using namespace ROOT::Experimental::TDF;
19 using namespace ROOT::Internal::TDF;
20 using namespace ROOT::Detail::TDF;
21 
22 namespace ROOT {
23 namespace Experimental {
24 namespace TDF {
25 // Explicit instantiation definitions of TInterface for the three node types below (presumably paired with `extern template` declarations in ROOT/TDFInterface.hxx -- TODO confirm)
26 template class TInterface<TLoopManager>;
27 template class TInterface<TFilterBase>;
28 template class TInterface<TCustomColumnBase>;
29 }
30 }
31 
32 namespace Internal {
33 namespace TDF {
34 // Match expression against names of branches passed as parameter
35 // Return vector of names of the branches used in the expression
36 std::vector<std::string> GetUsedBranchesNames(const std::string expression, TObjArray *branches,
37  const std::vector<std::string> &tmpBranches)
38 {
39  // Check what branches and temporary branches are used in the expression
40  // To help matching the regex
41  std::string paddedExpr = " " + expression + " ";
42  int paddedExprLen = paddedExpr.size();
43  static const std::string regexBit("[^a-zA-Z0-9_]");
44  std::vector<std::string> usedBranches;
45  for (auto brName : tmpBranches) {
46  std::string bNameRegexContent = regexBit + brName + regexBit;
47  TRegexp bNameRegex(bNameRegexContent.c_str());
48  if (-1 != bNameRegex.Index(paddedExpr.c_str(), &paddedExprLen)) {
49  usedBranches.emplace_back(brName.c_str());
50  }
51  }
52  if (!branches) return usedBranches;
53  for (auto bro : *branches) {
54  auto brName = bro->GetName();
55  std::string bNameRegexContent = regexBit + brName + regexBit;
56  TRegexp bNameRegex(bNameRegexContent.c_str());
57  if (-1 != bNameRegex.Index(paddedExpr.c_str(), &paddedExprLen)) {
58  usedBranches.emplace_back(brName);
59  }
60  }
61  return usedBranches;
62 }
63 
64 // Jit a string filter or a string temporary column, call this->Define or this->Filter as needed
65 // Return pointer to the new functional chain node returned by the call, cast to Long_t
66 Long_t JitTransformation(void *thisPtr, const std::string &methodName, const std::string &nodeTypeName,
67  const std::string &name, const std::string &expression, TObjArray *branches,
68  const std::vector<std::string> &tmpBranches,
69  const std::map<std::string, TmpBranchBasePtr_t> &tmpBookedBranches, TTree *tree)
70 {
71  auto usedBranches = GetUsedBranchesNames(expression, branches, tmpBranches);
72  auto exprNeedsVariables = !usedBranches.empty();
73 
74  // Move to the preparation of the jitting
75  // We put all of the jitted entities in a namespace called
76  // __tdf_filter_N, where N is a monotonically increasing index.
77  std::vector<std::string> usedBranchesTypes;
78  std::stringstream ss;
79  static unsigned int iNs = 0U;
80  ss << "__tdf_" << iNs++;
81  const auto nsName = ss.str();
82  ss.str("");
83 
84  if (exprNeedsVariables) {
85  // Declare a namespace and inside it the variables in the expression
86  ss << "namespace " << nsName;
87  ss << " {\n";
88  for (auto brName : usedBranches) {
89  // The map is a const reference, so no operator[]
90  auto tmpBrIt = tmpBookedBranches.find(brName);
91  auto tmpBr = tmpBrIt == tmpBookedBranches.end() ? nullptr : tmpBrIt->second.get();
92  auto brTypeName = ColumnName2ColumnTypeName(brName, tree, tmpBr);
93  ss << brTypeName << " " << brName << ";\n";
94  usedBranchesTypes.emplace_back(brTypeName);
95  }
96  ss << "}";
97  auto variableDeclarations = ss.str();
98  ss.str("");
99  // We need ProcessLine to trigger auto{parsing,loading} where needed
100  TInterpreter::EErrorCode interpErrCode;
101  gInterpreter->ProcessLine(variableDeclarations.c_str(), &interpErrCode);
102  if (TInterpreter::EErrorCode::kNoError != interpErrCode) {
103  std::string msg = "Cannot declare these variables: ";
104  msg += variableDeclarations;
105  msg += "\nInterpreter error code is " + std::to_string(interpErrCode) + ".";
106  throw std::runtime_error(msg);
107  }
108  }
109 
110  // Declare within the same namespace, the expression to make sure it
111  // is proper C++
112  ss << "namespace " << nsName << "{ auto res = " << expression << ";}\n";
113  // Headers must have been parsed and libraries loaded: we can use Declare
114  if (!gInterpreter->Declare(ss.str().c_str())) {
115  std::string msg = "Cannot interpret this expression: ";
116  msg += " ";
117  msg += ss.str();
118  throw std::runtime_error(msg);
119  }
120 
121  // Now we build the lambda and we invoke the method with it in the jitted world
122  ss.str("");
123  ss << "[](";
124  for (unsigned int i = 0; i < usedBranchesTypes.size(); ++i) {
125  // We pass by reference to avoid expensive copies
126  ss << usedBranchesTypes[i] << "& " << usedBranches[i] << ", ";
127  }
128  if (!usedBranchesTypes.empty()) ss.seekp(-2, ss.cur);
129  ss << "){ return " << expression << ";}";
130  auto filterLambda = ss.str();
131 
132  // Here we have two cases: filter and column
133  ss.str("");
134  ss << "((" << nodeTypeName << "*)" << thisPtr << ")->" << methodName << "(";
135  if (methodName == "Define") {
136  ss << "\"" << name << "\", ";
137  }
138  ss << filterLambda << ", {";
139  for (auto brName : usedBranches) {
140  ss << "\"" << brName << "\", ";
141  }
142  if (exprNeedsVariables) ss.seekp(-2, ss.cur); // remove the last ",
143  ss << "}";
144 
145  if (methodName == "Filter") {
146  ss << ", \"" << name << "\"";
147  }
148 
149  ss << ");";
150 
151  TInterpreter::EErrorCode interpErrCode;
152  auto retVal = gInterpreter->Calc(ss.str().c_str(), &interpErrCode);
153  if (TInterpreter::EErrorCode::kNoError != interpErrCode || !retVal) {
154  std::string msg = "Cannot interpret the invocation to " + methodName + ": ";
155  msg += ss.str();
156  if (TInterpreter::EErrorCode::kNoError != interpErrCode) {
157  msg += "\nInterpreter error code is " + std::to_string(interpErrCode) + ".";
158  }
159  throw std::runtime_error(msg);
160  }
161  return retVal;
162 }
163 
164 // Jit and call something equivalent to "this->BuildAndBook<BranchTypes...>(params...)"
165 // (see comments in the body for actual jitted code)
166 void JitBuildAndBook(const ColumnNames_t &bl, const std::string &nodeTypename, void *thisPtr, const std::type_info &art,
167  const std::type_info &at, const void *r, TTree *tree, unsigned int nSlots,
168  const std::map<std::string, TmpBranchBasePtr_t> &tmpBranches)
169 {
170  gInterpreter->Declare("#include \"ROOT/TDataFrame.hxx\"");
171  auto nBranches = bl.size();
172 
173  // retrieve pointers to temporary columns (null if the column is not temporary)
174  std::vector<TCustomColumnBase *> tmpBranchPtrs(nBranches, nullptr);
175  for (auto i = 0u; i < nBranches; ++i) {
176  auto tmpBranchIt = tmpBranches.find(bl[i]);
177  if (tmpBranchIt != tmpBranches.end()) tmpBranchPtrs[i] = tmpBranchIt->second.get();
178  }
179 
180  // retrieve branch type names as strings
181  std::vector<std::string> branchTypeNames(nBranches);
182  for (auto i = 0u; i < nBranches; ++i) {
183  const auto branchTypeName = ColumnName2ColumnTypeName(bl[i], tree, tmpBranchPtrs[i]);
184  if (branchTypeName.empty()) {
185  std::string exceptionText = "The type of column ";
186  exceptionText += bl[i];
187  exceptionText += " could not be guessed. Please specify one.";
188  throw std::runtime_error(exceptionText.c_str());
189  }
190  branchTypeNames[i] = branchTypeName;
191  }
192 
193  // retrieve type of result of the action as a string
194  auto actionResultTypeClass = TClass::GetClass(art);
195  if (!actionResultTypeClass) {
196  std::string exceptionText = "An error occurred while inferring the result type of an operation.";
197  throw std::runtime_error(exceptionText.c_str());
198  }
199  const auto actionResultTypeName = actionResultTypeClass->GetName();
200 
201  // retrieve type of action as a string
202  auto actionTypeClass = TClass::GetClass(at);
203  if (!actionTypeClass) {
204  std::string exceptionText = "An error occurred while inferring the action type of the operation.";
205  throw std::runtime_error(exceptionText.c_str());
206  }
207  const auto actionTypeName = actionTypeClass->GetName();
208 
209  // createAction_str will contain the following:
210  // ROOT::Internal::TDF::CallBuildAndBook<nodeType, actionType, branchType1, branchType2...>(
211  // reinterpret_cast<nodeType*>(thisPtr), *reinterpret_cast<ROOT::ColumnNames_t*>(&bl),
212  // *reinterpret_cast<actionResultType*>(r), reinterpret_cast<ActionType*>(nullptr))
213  std::stringstream createAction_str;
214  createAction_str << "ROOT::Internal::TDF::CallBuildAndBook<" << nodeTypename << ", " << actionTypeName;
215  for (auto &branchTypeName : branchTypeNames) createAction_str << ", " << branchTypeName;
216  createAction_str << ">("
217  << "reinterpret_cast<" << nodeTypename << "*>(" << thisPtr << "), "
218  << "*reinterpret_cast<ROOT::Detail::TDF::ColumnNames_t*>(" << &bl << "), " << nSlots
219  << ", *reinterpret_cast<" << actionResultTypeName << "*>(" << r << "));";
220  auto error = TInterpreter::EErrorCode::kNoError;
221  gInterpreter->Calc(createAction_str.str().c_str(), &error);
222  if (error) {
223  std::string exceptionText = "An error occurred while jitting this action:\n";
224  exceptionText += createAction_str.str();
225  throw std::runtime_error(exceptionText.c_str());
226  }
227 }
228 } // end ns TDF
229 } // end ns Internal
230 } // end ns ROOT
An array of TObjects.
Definition: TObjArray.h:37
Namespace for new ROOT classes and functions.
Definition: StringConv.hxx:21
Regular expression class.
Definition: TRegexp.h:31
#define gInterpreter
Definition: TInterpreter.h:499
TRandom2 r(17)
std::string ColumnName2ColumnTypeName(const std::string &colName, TTree *tree, TCustomColumnBase *tmpBranch)
Return a string containing the type of the given branch.
Definition: TDFUtils.cxx:30
Long_t JitTransformation(void *thisPtr, const std::string &methodName, const std::string &nodeTypeName, const std::string &name, const std::string &expression, TObjArray *branches, const std::vector< std::string > &tmpBranches, const std::map< std::string, TmpBranchBasePtr_t > &tmpBookedBranches, TTree *tree)
long Long_t
Definition: RtypesCore.h:50
static TClass * GetClass(const char *name, Bool_t load=kTRUE, Bool_t silent=kFALSE)
Static method returning pointer to TClass of the specified class name.
Definition: TClass.cxx:2885
Definition: tree.py:1
void JitBuildAndBook(const ColumnNames_t &bl, const std::string &nodeTypename, void *thisPtr, const std::type_info &art, const std::type_info &at, const void *r, TTree *tree, unsigned int nSlots, const std::map< std::string, TmpBranchBasePtr_t > &tmpBranches)
std::vector< std::string > GetUsedBranchesNames(const std::string, TObjArray *, const std::vector< std::string > &)
The public interface to the TDataFrame federation of classes.