Logo ROOT  
Reference Guide
 
Loading...
Searching...
No Matches
CodeSquashContext.cxx
Go to the documentation of this file.
1/*
2 * Project: RooFit
3 * Authors:
4 * Garima Singh, CERN 2023
5 * Jonas Rembser, CERN 2023
6 *
7 * Copyright (c) 2023, CERN
8 *
9 * Redistribution and use in source and binary forms,
10 * with or without modification, are permitted according to the terms
11 * listed in LICENSE (http://roofit.sourceforge.net/license.txt)
12 */
13
15
#include "RooFitImplHelpers.h"

#include <algorithm>
#include <cctype>
#include <memory>
#include <sstream>
#include <stdexcept>
#include <string>
20
21namespace RooFit {
22
23namespace Detail {
24
25CodeSquashContext::CodeSquashContext(std::map<RooFit::Detail::DataKey, std::size_t> const &outputSizes,
26 std::vector<double> &xlarr, Experimental::RooFuncWrapper &wrapper)
27 : _wrapper{&wrapper}, _nodeOutputSizes(outputSizes), _xlArr(xlarr)
28{
29}
30
31/// @brief Adds (or overwrites) the string representing the result of a node.
32/// @param key The name of the node to add the result for.
33/// @param value The new name to assign/overwrite.
34void CodeSquashContext::addResult(const char *key, std::string const &value)
35{
36 const TNamed *namePtr = RooNameReg::known(key);
37 if (namePtr)
38 addResult(namePtr, value);
39}
40
41void CodeSquashContext::addResult(TNamed const *key, std::string const &value)
42{
43 _nodeNames[key] = value;
44}
45
46/// @brief Gets the result for the given node using the node name. This node also performs the necessary
47/// code generation through recursive calls to 'translate'. A call to this function modifies the already
48/// existing code body.
49/// @param key The node to get the result string for.
50/// @return String representing the result of this node.
51std::string const &CodeSquashContext::getResult(RooAbsArg const &arg)
52{
53 // If the result has already been recorded, just return the result.
54 // It is usually the responsibility of each translate function to assign
55 // the proper result to its class. Hence, if a result has already been recorded
56 // for a particular node, it means the node has already been 'translate'd and we
57 // dont need to visit it again.
58 auto found = _nodeNames.find(arg.namePtr());
59 if (found != _nodeNames.end())
60 return found->second;
61
62 // The result for vector observables should already be in the map if you
63 // opened the loop scope. This is just to check if we did not request the
64 // result of a vector-valued observable outside of the scope of a loop.
65 auto foundVecObs = _vecObsIndices.find(arg.namePtr());
66 if (foundVecObs != _vecObsIndices.end()) {
67 throw std::runtime_error("You requested the result of a vector observable outside a loop scope for it!");
68 }
69
70 // Now, recursively call translate into the current argument to load the correct result.
71 arg.translate(*this);
72
73 return _nodeNames.at(arg.namePtr());
74}
75
76/// @brief Adds the given string to the string block that will be emitted at the top of the squashed function. Useful
77/// for variable declarations.
78/// @param str The string to add to the global scope.
79void CodeSquashContext::addToGlobalScope(std::string const &str)
80{
81 _globalScope += str;
82}
83
84/// @brief Assemble and return the final code with the return expression and global statements.
85/// @param returnExpr The string representation of what the squashed function should return, usually the head node.
86/// @return The final body of the function.
87std::string CodeSquashContext::assembleCode(std::string const &returnExpr)
88{
89 return _globalScope + _code + "\n return " + returnExpr + ";\n";
90}
91
92/// @brief Since the squashed code represents all observables as a single flattened array, it is important
93/// to keep track of the start index for a vector valued observable which can later be expanded to access the correct
94/// element. For example, a vector valued variable x with 10 entries will be squashed to obs[start_idx + i].
95/// @param key The name of the node representing the vector valued observable.
96/// @param idx The start index (or relative position of the observable in the set of all observables).
97void CodeSquashContext::addVecObs(const char *key, int idx)
98{
99 const TNamed *namePtr = RooNameReg::known(key);
100 if (namePtr)
101 _vecObsIndices[namePtr] = idx;
102}
103
104/// @brief Adds the input string to the squashed code body. If a class implements a translate function that wants to
105/// emit something to the squashed code body, it must call this function with the code it wants to emit. In case of
106/// loops, automatically determines if code needs to be stored inside or outside loop scope.
107/// @param klass The class requesting this addition, usually 'this'.
108/// @param in String to add to the squashed code.
109void CodeSquashContext::addToCodeBody(RooAbsArg const *klass, std::string const &in)
110{
111 // If we are in a loop and the value is scope independent, save it at the top of the loop.
112 // else, just save it in the current scope.
114}
115
116/// @brief A variation of the previous addToCodeBody that takes in a bool value that determines
117/// if input is independent. This overload exists because there might other ways to determine if
118/// a value/collection of values is scope independent.
119/// @param in String to add to the squashed code.
120/// @param isScopeIndep The value determining if the input is scope dependent.
121void CodeSquashContext::addToCodeBody(std::string const &in, bool isScopeIndep /* = false */)
122{
123 // If we are in a loop and the value is scope independent, save it at the top of the loop.
124 // else, just save it in the current scope.
125 if (_scopePtr != -1 && isScopeIndep) {
126 _tempScope += in;
127 } else {
128 _code += in;
129 }
130}
131
132/// @brief Create a RAII scope for iterating over vector observables. You can't use the result of vector observables
133/// outside these loop scopes.
134/// @param in A pointer to the calling class, used to determine the loop dependent variables.
135std::unique_ptr<CodeSquashContext::LoopScope> CodeSquashContext::beginLoop(RooAbsArg const *in)
136{
137 std::string idx = "loopIdx" + std::to_string(_loopLevel);
138
139 std::vector<TNamed const *> vars;
140 // set the results of the vector observables
141 for (auto const &it : _vecObsIndices) {
142 if (!in->dependsOn(it.first))
143 continue;
144
145 vars.push_back(it.first);
146 _nodeNames[it.first] = "obs[" + std::to_string(it.second) + " + " + idx + "]";
147 }
148
149 // TODO: we are using the size of the first loop variable to the the number
150 // of iterations, but it should be made sure that all loop vars are either
151 // scalar or have the same size.
152 std::size_t numEntries = 1;
153 for (auto &it : vars) {
154 std::size_t n = outputSize(it);
155 if (n > 1 && numEntries > 1 && n != numEntries) {
156 throw std::runtime_error("Trying to loop over variables with different sizes!");
157 }
158 numEntries = std::max(n, numEntries);
159 }
160
161 // Save the current size of the code array so that we can insert the code at the right position.
162 _scopePtr = _code.size();
163
164 // Make sure that the name of this variable doesn't clash with other stuff
165 addToCodeBody(in, "for(int " + idx + " = 0; " + idx + " < " + std::to_string(numEntries) + "; " + idx + "++) {\n");
166
167 ++_loopLevel;
168 return std::make_unique<LoopScope>(*this, std::move(vars));
169}
170
172{
173 _code += "}\n";
174
175 // Insert the temporary code into the correct code position.
176 _code.insert(_scopePtr, _tempScope);
177 _tempScope.erase();
178 _scopePtr = -1;
179
180 // clear the results of the loop variables if they were vector observables
181 for (auto const &ptr : scope.vars()) {
182 if (_vecObsIndices.find(ptr) != _vecObsIndices.end())
183 _nodeNames.erase(ptr);
184 }
185 --_loopLevel;
186}
187
188/// @brief Get a unique variable name to be used in the generated code.
190{
191 return "t" + std::to_string(_tmpVarIdx++);
192}
193
194/// @brief A function to save an expression that includes/depends on the result of the input node.
195/// @param in The node on which the valueToSave depends on/belongs to.
196/// @param valueToSave The actual string value to save as a temporary.
197void CodeSquashContext::addResult(RooAbsArg const *in, std::string const &valueToSave)
198{
199 //std::string savedName = RooFit::Detail::makeValidVarName(in->GetName());
200 std::string savedName = getTmpVarName();
201
202 // Only save values if they contain operations.
203 bool hasOperations = valueToSave.find_first_of(":-+/*") != std::string::npos;
204
205 // If the name is not empty and this value is worth saving, save it to the correct scope.
206 // otherwise, just return the actual value itself
207 if (hasOperations) {
208 // If this is a scalar result, it will go just outside the loop because
209 // it doesn't need to be recomputed inside loops.
210 std::string outVarDecl = "const double " + savedName + " = " + valueToSave + ";\n";
211 addToCodeBody(in, outVarDecl);
212 } else {
213 savedName = valueToSave;
214 }
215
216 addResult(in->namePtr(), savedName);
217}
218
219/// @brief Function to save a RooListProxy as an array in the squashed code.
220/// @param in The list to convert to array.
221/// @return Name of the array that stores the input list in the squashed code.
223{
224 if (in.empty()) {
225 return "nullptr";
226 }
227
228 auto it = listNames.find(in.uniqueId().value());
229 if (it != listNames.end())
230 return it->second;
231
232 std::string savedName = getTmpVarName();
233 bool canSaveOutside = true;
234
235 std::stringstream declStrm;
236 declStrm << "double " << savedName << "[] = {";
237 for (const auto arg : in) {
238 declStrm << getResult(*arg) << ",";
239 canSaveOutside = canSaveOutside && isScopeIndependent(arg);
240 }
241 declStrm.seekp(-1, declStrm.cur);
242 declStrm << "};\n";
243
244 addToCodeBody(declStrm.str(), canSaveOutside);
245
246 listNames.insert({in.uniqueId().value(), savedName});
247 return savedName;
248}
249
250std::string CodeSquashContext::buildArg(std::span<const double> arr)
251{
252 unsigned int n = arr.size();
253 std::string offset = std::to_string(_xlArr.size());
254 _xlArr.reserve(_xlArr.size() + n);
255 for (unsigned int i = 0; i < n; i++) {
256 _xlArr.push_back(arr[i]);
257 }
258 return "xlArr + " + offset;
259}
260
262{
263 return !in->isReducerNode() && outputSize(in->namePtr()) == 1;
264}
265
266} // namespace Detail
267} // namespace RooFit
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void char Point_t Rectangle_t WindowAttributes_t Float_t Float_t Float_t Int_t Int_t UInt_t UInt_t Rectangle_t Int_t Int_t Window_t TString Int_t GCValues_t GetPrimarySelectionOwner GetDisplay GetScreen GetColormap GetNativeEvent const char const char dpyName wid window const char font_name cursor keysym reg const char only_if_exist regb h Point_t winding char text const char depth char const char Int_t count const char ColorStruct_t color const char Pixmap_t Pixmap_t PictureAttributes_t attr const char char ret_data h unsigned char height h offset
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void value
Common abstract base class for objects that represent a value and a "shape" in RooFit.
Definition RooAbsArg.h:77
bool dependsOn(const RooAbsCollection &serverList, const RooAbsArg *ignoreArg=nullptr, bool valueOnly=false) const
Test whether we depend on (ie, are served by) any object in the specified collection.
const TNamed * namePtr() const
De-duplicated pointer to this object's name.
Definition RooAbsArg.h:561
virtual void translate(RooFit::Detail::CodeSquashContext &ctx) const
This function defines a translation for each RooAbsReal based object that can be used to express the ...
virtual bool isReducerNode() const
Definition RooAbsArg.h:575
Abstract container object that can hold multiple RooAbsArg objects.
RooFit::UniqueId< RooAbsCollection > const & uniqueId() const
Returns a unique ID that is different for every instantiated RooAbsCollection.
A class to manage loop scopes using the RAII technique.
std::vector< TNamed const * > const & vars() const
std::string assembleCode(std::string const &returnExpr)
Assemble and return the final code with the return expression and global statements.
std::string _tempScope
Stores code that eventually gets injected into main code body.
void addResult(RooAbsArg const *key, std::string const &value)
A function to save an expression that includes/depends on the result of the input node.
void endLoop(LoopScope const &scope)
std::unordered_map< const TNamed *, int > _vecObsIndices
A map to keep track of the observable indices if they are non scalar.
int _loopLevel
The number of currently open for loops (the current loop nesting level).
int _tmpVarIdx
Index to get unique names for temporary variables.
std::size_t outputSize(RooFit::Detail::DataKey key) const
Figure out the output size of a node.
std::unordered_map< const TNamed *, std::string > _nodeNames
Map of node names to their result strings.
void addToCodeBody(RooAbsArg const *klass, std::string const &in)
Adds the input string to the squashed code body.
void addVecObs(const char *key, int idx)
Since the squashed code represents all observables as a single flattened array, it is important to ke...
bool isScopeIndependent(RooAbsArg const *in) const
std::string getTmpVarName() const
Get a unique variable name to be used in the generated code.
std::string const & getResult(RooAbsArg const &arg)
Gets the result for the given node using the node name.
std::string _code
Stores the squashed code body.
void addToGlobalScope(std::string const &str)
Adds the given string to the string block that will be emitted at the top of the squashed function.
std::unordered_map< RooFit::UniqueId< RooAbsCollection >::Value_t, std::string > listNames
A map to keep track of list names as assigned by buildArg.
std::string buildArg(RooAbsCollection const &x)
Function to save a RooListProxy as an array in the squashed code.
CodeSquashContext(std::map< RooFit::Detail::DataKey, std::size_t > const &outputSizes, std::vector< double > &xlarr, Experimental::RooFuncWrapper &wrapper)
int _scopePtr
Keeps track of the position to go back and insert code to.
std::string _globalScope
Block of code that is placed before the rest of the function body.
std::unique_ptr< LoopScope > beginLoop(RooAbsArg const *in)
Create a RAII scope for iterating over vector observables.
A wrapper class to store a C++ function of type 'double (*)(double*, double*)'.
static const TNamed * known(const char *stringPtr)
If the name is already known, return its TNamed pointer. Otherwise return 0 (don't register the name)...
The TNamed class is the base class for all named ROOT classes.
Definition TNamed.h:29
const Int_t n
Definition legend1.C:16
The namespace RooFit contains mostly switches that change the behaviour of functions of PDFs (or othe...
Definition JSONIO.h:26
constexpr Value_t value() const
Return numerical value of ID.
Definition UniqueId.h:59