Logo ROOT  
Reference Guide
 
Loading...
Searching...
No Matches
CodegenContext.cxx
Go to the documentation of this file.
1/*
2 * Project: RooFit
3 * Authors:
4 * Garima Singh, CERN 2023
5 * Jonas Rembser, CERN 2023
6 *
7 * Copyright (c) 2023, CERN
8 *
9 * Redistribution and use in source and binary forms,
10 * with or without modification, are permitted according to the terms
11 * listed in LICENSE (http://roofit.sourceforge.net/license.txt)
12 */
13
15#include <RooAbsArg.h>
16
17#include "RooFitImplHelpers.h"
18
19#include <TInterpreter.h>
20
21#include <algorithm>
22#include <cctype>
23#include <fstream>
24#include <type_traits>
25#include <unordered_map>
26
27namespace {
28
/// Returns true if `str` begins with `prefix` (an empty prefix always matches).
bool startsWith(std::string_view str, std::string_view prefix)
{
   // A prefix longer than the string itself can never match.
   if (prefix.size() > str.size())
      return false;
   return str.substr(0, prefix.size()) == prefix;
}
33
34} // namespace
35
36namespace RooFit {
37namespace Experimental {
38
39/// @brief Adds (or overwrites) the string representing the result of a node.
40/// @param key The name of the node to add the result for.
41/// @param value The new name to assign/overwrite.
42void CodegenContext::addResult(const char *key, std::string const &value)
43{
44 const TNamed *namePtr = RooNameReg::known(key);
45 if (namePtr)
46 addResult(namePtr, value);
47}
48
49void CodegenContext::addResult(TNamed const *key, std::string const &value)
50{
51 _nodeNames[key] = value;
52}
53
54/// @brief Gets the result for the given node using the node name. This node also performs the necessary
55/// code generation through recursive calls to 'translate'. A call to this function modifies the already
56/// existing code body.
57/// @param key The node to get the result string for.
58/// @return String representing the result of this node.
59std::string const &CodegenContext::getResult(RooAbsArg const &arg)
60{
61 // If the result has already been recorded, just return the result.
62 // It is usually the responsibility of each translate function to assign
63 // the proper result to its class. Hence, if a result has already been recorded
64 // for a particular node, it means the node has already been 'translate'd and we
65 // dont need to visit it again.
66 auto found = _nodeNames.find(arg.namePtr());
67 if (found != _nodeNames.end())
68 return found->second;
69
70 // The result for vector observables should already be in the map if you
71 // opened the loop scope. This is just to check if we did not request the
72 // result of a vector-valued observable outside of the scope of a loop.
73 auto foundVecObs = _vecObsIndices.find(arg.namePtr());
74 if (foundVecObs != _vecObsIndices.end()) {
75 throw std::runtime_error("You requested the result of a vector observable outside a loop scope for it!");
76 }
77
78 auto RAII(OutputScopeRangeComment(&arg));
79
80 // Now, recursively call translate into the current argument to load the correct result.
81 codegen(const_cast<RooAbsArg &>(arg), *this);
82
83 return _nodeNames.at(arg.namePtr());
84}
85
86/// @brief Adds the given string to the string block that will be emitted at the top of the squashed function. Useful
87/// for variable declarations.
88/// @param str The string to add to the global scope.
89void CodegenContext::addToGlobalScope(std::string const &str)
90{
91 // Introduce proper indentation for multiline strings.
92 _code[0] += str;
93}
94
95/// @brief Since the squashed code represents all observables as a single flattened array, it is important
96/// to keep track of the start index for a vector valued observable which can later be expanded to access the correct
97/// element. For example, a vector valued variable x with 10 entries will be squashed to obs[start_idx + i].
98/// @param key The name of the node representing the vector valued observable.
99/// @param idx The start index (or relative position of the observable in the set of all observables).
100void CodegenContext::addVecObs(const char *key, int idx)
101{
102 const TNamed *namePtr = RooNameReg::known(key);
103 if (namePtr)
104 _vecObsIndices[namePtr] = idx;
105}
106
108{
109 auto it = _vecObsIndices.find(arg.namePtr());
110 if (it != _vecObsIndices.end()) {
111 return it->second;
112 }
113
114 return -1; // Not found
115}
116/// @brief Adds the input string to the squashed code body. If a class implements a translate function that wants to
117/// emit something to the squashed code body, it must call this function with the code it wants to emit. In case of
118/// loops, automatically determines if code needs to be stored inside or outside loop scope.
119/// @param klass The class requesting this addition, usually 'this'.
120/// @param in String to add to the squashed code.
121void CodegenContext::addToCodeBody(RooAbsArg const *klass, std::string const &in)
122{
123 // If we are in a loop and the value is scope independent, save it at the top of the loop.
124 // else, just save it in the current scope.
126}
127
128/// @brief A variation of the previous addToCodeBody that takes in a bool value that determines
129/// if input is independent. This overload exists because there might other ways to determine if
130/// a value/collection of values is scope independent.
131/// @param in String to add to the squashed code.
132/// @param isScopeIndep The value determining if the input is scope dependent.
133void CodegenContext::addToCodeBody(std::string const &in, bool isScopeIndep /* = false */)
134{
135 TString indented = in;
136 indented = indented.Strip(TString::kBoth); // trim
137
138 std::string indent_str = "";
139 for (unsigned i = 0; i < _indent; ++i)
140 indent_str += " ";
141 indented = indented.Prepend(indent_str);
142
143 // FIXME: Multiline input.
144 // indent_str += "\n";
145 // indented = indented.ReplaceAll("\n", indent_str);
146
147 // If we are in a loop and the value is scope independent, save it at the top of the loop.
148 // else, just save it in the current scope.
149 if (_code.size() > 2 && isScopeIndep) {
150 _code[_code.size() - 2] += indented;
151 } else {
152 _code.back() += indented;
153 }
154}
155
156/// @brief Create a RAII scope for iterating over vector observables. You can't use the result of vector observables
157/// outside these loop scopes.
158/// @param in A pointer to the calling class, used to determine the loop dependent variables.
159std::unique_ptr<CodegenContext::LoopScope> CodegenContext::beginLoop(RooAbsArg const *in)
160{
161 pushScope();
162 unsigned loopLevel = _code.size() - 2; // subtract global + function scope.
163 std::string idx = "loopIdx" + std::to_string(loopLevel);
164
165 std::vector<TNamed const *> vars;
166 // set the results of the vector observables
167 for (auto const &it : _vecObsIndices) {
168 if (!in->dependsOn(it.first))
169 continue;
170
171 vars.push_back(it.first);
172 _nodeNames[it.first] = "obs[" + std::to_string(it.second) + " + " + idx + "]";
173 }
174
175 // TODO: we are using the size of the first loop variable to the the number
176 // of iterations, but it should be made sure that all loop vars are either
177 // scalar or have the same size.
178 std::size_t numEntries = 1;
179 for (auto &it : vars) {
180 std::size_t n = outputSize(it);
181 if (n > 1 && numEntries > 1 && n != numEntries) {
182 throw std::runtime_error("Trying to loop over variables with different sizes!");
183 }
184 numEntries = std::max(n, numEntries);
185 }
186
187 // Make sure that the name of this variable doesn't clash with other stuff
188 addToCodeBody(in, "for(int " + idx + " = 0; " + idx + " < " + std::to_string(numEntries) + "; " + idx + "++) {\n");
189
190 return std::make_unique<LoopScope>(*this, std::move(vars));
191}
192
194{
195 addToCodeBody("}\n");
196
197 // clear the results of the loop variables if they were vector observables
198 for (auto const &ptr : scope.vars()) {
199 if (_vecObsIndices.find(ptr) != _vecObsIndices.end())
200 _nodeNames.erase(ptr);
201 }
202 popScope();
203}
204
205/// @brief Get a unique variable name to be used in the generated code.
207{
208 return "t" + std::to_string(_tmpVarIdx++);
209}
210
211/// @brief A function to save an expression that includes/depends on the result of the input node.
212/// @param in The node on which the valueToSave depends on/belongs to.
213/// @param valueToSave The actual string value to save as a temporary.
214void CodegenContext::addResult(RooAbsArg const *in, std::string const &valueToSave)
215{
216 // std::string savedName = RooFit::Detail::makeValidVarName(in->GetName());
217 std::string savedName = getTmpVarName();
218
219 // Only save values if they contain operations.
220 bool hasOperations = valueToSave.find_first_of(":-+/*") != std::string::npos;
221
222 // If the name is not empty and this value is worth saving, save it to the correct scope.
223 // otherwise, just return the actual value itself
224 if (hasOperations) {
225 // If this is a scalar result, it will go just outside the loop because
226 // it doesn't need to be recomputed inside loops.
227 std::string outVarDecl = "const double " + savedName + " = " + valueToSave + ";\n";
229 } else {
231 }
232
234}
235
236/// @brief Function to save a RooListProxy as an array in the squashed code.
237/// @param in The list to convert to array.
238/// @return Name of the array that stores the input list in the squashed code.
240{
241 if (in.empty()) {
242 return "nullptr";
243 }
244
245 auto it = _listNames.find(in.uniqueId().value());
246 if (it != _listNames.end())
247 return it->second;
248
249 std::string savedName = getTmpVarName();
250 bool canSaveOutside = true;
251
252 std::stringstream declStrm;
253 declStrm << "double " << savedName << "[] = {";
254 for (const auto arg : in) {
255 declStrm << getResult(*arg) << ",";
257 }
258 declStrm.seekp(-1, declStrm.cur);
259 declStrm << "};\n";
260
262
263 _listNames.insert({in.uniqueId().value(), savedName});
264 return savedName;
265}
266
267std::string CodegenContext::buildArg(std::span<const double> arr)
268{
269 unsigned int n = arr.size();
270 std::string offset = std::to_string(_xlArr.size());
271 _xlArr.reserve(_xlArr.size() + n);
272 for (unsigned int i = 0; i < n; i++) {
273 _xlArr.push_back(arr[i]);
274 }
275 return "xlArr + " + offset;
276}
277
278CodegenContext::ScopeRAII::ScopeRAII(RooAbsArg const *arg, CodegenContext &ctx) : _ctx(ctx), _arg(arg)
279{
280 std::ostringstream os;
281 Option_t *opts = nullptr;
283 _fn = os.str();
284 const std::string info = "// Begin -- " + _fn;
285 _ctx._indent++;
287}
288
290{
291 const std::string info = "// End -- " + _fn + "\n";
292 _ctx.addToCodeBody(_arg, info);
293 _ctx._indent--;
294}
295
297{
298 _code.push_back("");
299}
300
302{
303 std::string active_scope = _code.back();
304 _code.pop_back();
305 _code.back() += active_scope;
306}
307
309{
310 return !in->isReducerNode() && outputSize(in->namePtr()) == 1;
311}
312
313/// @brief Register a function that is only know to the interpreter to the context.
314/// This is useful to dump the standalone C++ code for the computation graph.
315void CodegenContext::collectFunction(std::string const &name)
316{
317 _collectedFunctions.emplace_back(name);
318}
319
320/// @brief Assemble and return the final code with the return expression and global statements.
321/// @param returnExpr The string representation of what the squashed function should return, usually the head node.
322/// @return The name of the declared function.
323std::string
324CodegenContext::buildFunction(RooAbsArg const &arg, std::map<RooFit::Detail::DataKey, std::size_t> const &outputSizes)
325{
326 CodegenContext ctx;
327 ctx.pushScope(); // push our global scope.
330 // We only want to take over parameters and observables
331 for (auto const &item : _nodeNames) {
332 if (startsWith(item.second, "params[") || startsWith(item.second, "obs[")) {
333 ctx._nodeNames.insert(item);
334 }
335 }
336 ctx._xlArr = _xlArr;
338
339 static int iCodegen = 0;
340 auto funcName = "roo_codegen_" + std::to_string(iCodegen++);
341
342 // Make sure the codegen implementations are known to the interpreter
343 gInterpreter->Declare("#include <RooFit/CodegenImpl.h>\n");
344
345 ctx.pushScope();
346 std::string funcBody = ctx.getResult(arg);
347 ctx.popScope();
348 funcBody = ctx._code[0] + "\n return " + funcBody + ";\n";
349
350 // Declare the function
351 std::stringstream bodyWithSigStrm;
352 bodyWithSigStrm << "double " << funcName << "(double* params, double const* obs, double const* xlArr) {\n"
353 << "constexpr double inf = std::numeric_limits<double>::infinity();\n"
354 << funcBody << "\n}";
355 ctx._collectedFunctions.emplace_back(funcName);
356 if (!gInterpreter->Declare(bodyWithSigStrm.str().c_str())) {
357 std::stringstream errorMsg;
358 std::string debugFileName = "_codegen_" + funcName + ".cxx";
359 errorMsg << "Function " << funcName << " could not be compiled. See above for details. Full code dumped to file "
360 << debugFileName << "for debugging";
361 {
362 std::ofstream outFile;
363 outFile.open(debugFileName.c_str());
364 outFile << bodyWithSigStrm.str();
365 }
366 oocoutE(nullptr, InputArguments) << errorMsg.str() << std::endl;
367 throw std::runtime_error(errorMsg.str().c_str());
368 }
369
370 _xlArr = ctx._xlArr;
372
373 return funcName;
374}
375
376void declareDispatcherCode(std::string const &funcName)
377{
378 std::string dispatcherCode = R"(
379namespace RooFit {
380namespace Experimental {
381
382template <class Arg_t, int P>
383auto FUNC_NAME(Arg_t &arg, CodegenContext &ctx, Prio<P> p)
384{
385 if constexpr (std::is_same<Prio<P>, PrioLowest>::value) {
386 return FUNC_NAME(arg, ctx);
387 } else {
388 return FUNC_NAME(arg, ctx, p.next());
389 }
390}
391
392template <class Arg_t>
393struct Caller_FUNC_NAME {
394
395 static auto call(RooAbsArg &arg, CodegenContext &ctx)
396 {
397 return FUNC_NAME(static_cast<Arg_t &>(arg), ctx, PrioHighest{});
398 }
399};
400
401} // namespace Experimental
402} // namespace RooFit
403 )";
404
406 gInterpreter->Declare(dispatcherCode.c_str());
407}
408
410{
411 static bool codeDeclared = false;
412 if (!codeDeclared) {
413 declareDispatcherCode("codegenImpl");
414 codeDeclared = true;
415 }
416
417 using Func = void (*)(RooAbsArg &, CodegenContext &);
418
419 Func func;
420
421 TClass *tclass = arg.IsA();
422
423 // Cache the overload resolutions
424 static std::unordered_map<TClass *, Func> dispatchMap;
425
426 auto found = dispatchMap.find(tclass);
427
428 if (found != dispatchMap.end()) {
429 func = found->second;
430 } else {
431 // Can probably done with CppInterop in the future to avoid string manipulation.
432 std::stringstream cmd;
433 cmd << "&RooFit::Experimental::Caller_codegenImpl<" << tclass->GetName() << ">::call;";
434 func = reinterpret_cast<Func>(gInterpreter->ProcessLine(cmd.str().c_str()));
435 dispatchMap[tclass] = func;
436 }
437
438 return func(arg, ctx);
439}
440
441} // namespace Experimental
442} // namespace RooFit
bool startsWith(std::string_view str, std::string_view prefix)
#define oocoutE(o, a)
const char Option_t
Option string (const char)
Definition RtypesCore.h:80
ROOT::Detail::TRangeCast< T, true > TRangeDynCast
TRangeDynCast is an adapter class that allows the typed iteration through a TCollection.
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void char Point_t Rectangle_t WindowAttributes_t Float_t Float_t Float_t Int_t Int_t UInt_t UInt_t Rectangle_t Int_t Int_t Window_t TString Int_t GCValues_t GetPrimarySelectionOwner GetDisplay GetScreen GetColormap GetNativeEvent const char const char dpyName wid window const char font_name cursor keysym reg const char only_if_exist regb h Point_t winding char text const char depth char const char Int_t count const char ColorStruct_t color const char Pixmap_t Pixmap_t PictureAttributes_t attr const char char ret_data h unsigned char height h offset
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void value
char name[80]
Definition TGX11.cxx:110
#define gInterpreter
const_iterator end() const
Common abstract base class for objects that represent a value and a "shape" in RooFit.
Definition RooAbsArg.h:76
bool dependsOn(const RooAbsCollection &serverList, const RooAbsArg *ignoreArg=nullptr, bool valueOnly=false) const
Test whether we depend on (ie, are served by) any object in the specified collection.
TClass * IsA() const override
Definition RooAbsArg.h:678
const TNamed * namePtr() const
De-duplicated pointer to this object's name.
Definition RooAbsArg.h:502
Int_t defaultPrintContents(Option_t *opt) const override
Define default contents to print.
virtual bool isReducerNode() const
Definition RooAbsArg.h:514
Abstract container object that can hold multiple RooAbsArg objects.
RooFit::UniqueId< RooAbsCollection > const & uniqueId() const
Returns a unique ID that is different for every instantiated RooAbsCollection.
A class to manage loop scopes using the RAII technique.
A class to maintain the context for squashing of RooFit models into code.
std::unordered_map< RooFit::UniqueId< RooAbsCollection >::Value_t, std::string > _listNames
A map to keep track of list names as assigned by buildArg.
void addToGlobalScope(std::string const &str)
Adds the given string to the string block that will be emitted at the top of the squashed function.
std::string const & getResult(RooAbsArg const &arg)
Gets the result for the given node using the node name.
std::string getTmpVarName() const
Get a unique variable name to be used in the generated code.
void addResult(RooAbsArg const *key, std::string const &value)
A function to save an expression that includes/depends on the result of the input node.
void addToCodeBody(RooAbsArg const *klass, std::string const &in)
Adds the input string to the squashed code body.
std::unique_ptr< LoopScope > beginLoop(RooAbsArg const *in)
Create a RAII scope for iterating over vector observables.
void collectFunction(std::string const &name)
Register with the context a function that is only known to the interpreter.
void addVecObs(const char *key, int idx)
Since the squashed code represents all observables as a single flattened array, it is important to ke...
std::unordered_map< const TNamed *, int > _vecObsIndices
A map to keep track of the observable indices if they are non scalar.
int observableIndexOf(const RooAbsArg &arg) const
std::map< RooFit::Detail::DataKey, std::size_t > _nodeOutputSizes
Map of node output sizes.
std::string buildFunction(RooAbsArg const &arg, std::map< RooFit::Detail::DataKey, std::size_t > const &outputSizes={})
Assemble and return the final code with the return expression and global statements.
void endLoop(LoopScope const &scope)
std::vector< std::string > _collectedFunctions
bool isScopeIndependent(RooAbsArg const *in) const
std::vector< std::string > _code
The code layered by lexical scopes used as a stack.
unsigned _indent
The indentation level for pretty-printing.
std::unordered_map< const TNamed *, std::string > _nodeNames
Map of node names to their result strings.
std::size_t outputSize(RooFit::Detail::DataKey key) const
Figure out the output size of a node.
ScopeRAII OutputScopeRangeComment(RooAbsArg const *arg)
std::string buildArg(RooAbsCollection const &x)
Function to save a RooListProxy as an array in the squashed code.
int _tmpVarIdx
Index to get unique names for temporary variables.
static const TNamed * known(const char *stringPtr)
If the name is already known, return its TNamed pointer. Otherwise return 0 (don't register the name)...
virtual StyleOption defaultPrintStyle(Option_t *opt) const
virtual void printStream(std::ostream &os, Int_t contents, StyleOption style, TString indent="") const
Print description of object on ostream, printing contents set by contents integer,...
TClass instances represent classes, structs and namespaces in the ROOT type system.
Definition TClass.h:84
The TNamed class is the base class for all named ROOT classes.
Definition TNamed.h:29
const char * GetName() const override
Returns name of object.
Definition TNamed.h:49
Basic string class.
Definition TString.h:138
@ kBoth
Definition TString.h:284
const Int_t n
Definition legend1.C:16
void replaceAll(std::string &inOut, std::string_view what, std::string_view with)
void declareDispatcherCode(std::string const &funcName)
void codegen(RooAbsArg &arg, CodegenContext &ctx)
The namespace RooFit contains mostly switches that change the behaviour of functions of PDFs (or othe...
Definition CodegenImpl.h:67
@ InputArguments
ScopeRAII(RooAbsArg const *arg, CodegenContext &ctx)
constexpr Value_t value() const
Return numerical value of ID.
Definition UniqueId.h:59