Logo ROOT  
Reference Guide
 
Loading...
Searching...
No Matches
CodeSquashContext.cxx
Go to the documentation of this file.
1/*
2 * Project: RooFit
3 * Authors:
4 * Garima Singh, CERN 2023
5 * Jonas Rembser, CERN 2023
6 *
7 * Copyright (c) 2023, CERN
8 *
9 * Redistribution and use in source and binary forms,
10 * with or without modification, are permitted according to the terms
11 * listed in LICENSE (http://roofit.sourceforge.net/license.txt)
12 */
13
15
16#include "RooFitImplHelpers.h"
17
18#include <algorithm>
19#include <cctype>
20
21namespace RooFit {
22
23namespace Detail {
24
25/// @brief Adds (or overwrites) the string representing the result of a node.
26/// @param key The name of the node to add the result for.
27/// @param value The new name to assign/overwrite.
28void CodeSquashContext::addResult(const char *key, std::string const &value)
29{
30 const TNamed *namePtr = RooNameReg::known(key);
31 if (namePtr)
32 addResult(namePtr, value);
33}
34
35void CodeSquashContext::addResult(TNamed const *key, std::string const &value)
36{
37 _nodeNames[key] = value;
38}
39
40/// @brief Gets the result for the given node using the node name. This node also performs the necessary
41/// code generation through recursive calls to 'translate'. A call to this function modifies the already
42/// existing code body.
43/// @param key The node to get the result string for.
44/// @return String representing the result of this node.
45std::string const &CodeSquashContext::getResult(RooAbsArg const &arg)
46{
47 // If the result has already been recorded, just return the result.
48 // It is usually the responsibility of each translate function to assign
49 // the proper result to its class. Hence, if a result has already been recorded
50 // for a particular node, it means the node has already been 'translate'd and we
51 // dont need to visit it again.
52 auto found = _nodeNames.find(arg.namePtr());
53 if (found != _nodeNames.end())
54 return found->second;
55
56 // The result for vector observables should already be in the map if you
57 // opened the loop scope. This is just to check if we did not request the
58 // result of a vector-valued observable outside of the scope of a loop.
59 auto foundVecObs = _vecObsIndices.find(arg.namePtr());
60 if (foundVecObs != _vecObsIndices.end()) {
61 throw std::runtime_error("You requested the result of a vector observable outside a loop scope for it!");
62 }
63
64 // Now, recursively call translate into the current argument to load the correct result.
65 arg.translate(*this);
66
67 return _nodeNames.at(arg.namePtr());
68}
69
70/// @brief Adds the given string to the string block that will be emitted at the top of the squashed function. Useful
71/// for variable declarations.
72/// @param str The string to add to the global scope.
73void CodeSquashContext::addToGlobalScope(std::string const &str)
74{
75 _globalScope += str;
76}
77
78/// @brief Assemble and return the final code with the return expression and global statements.
79/// @param returnExpr The string representation of what the squashed function should return, usually the head node.
80/// @return The final body of the function.
81std::string CodeSquashContext::assembleCode(std::string const &returnExpr)
82{
83 std::string arrDecl;
84 if(!_xlArr.empty()) {
85 arrDecl += "double auxArr[" + std::to_string(_xlArr.size()) + "];\n";
86 arrDecl += "for (int i = 0; i < " + std::to_string(_xlArr.size()) + "; i++) auxArr[i] = xlArr[i];\n";
87 }
88 return arrDecl + _globalScope + _code + "\n return " + returnExpr + ";\n";
89}
90
91/// @brief Since the squashed code represents all observables as a single flattened array, it is important
92/// to keep track of the start index for a vector valued observable which can later be expanded to access the correct
93/// element. For example, a vector valued variable x with 10 entries will be squashed to obs[start_idx + i].
94/// @param key The name of the node representing the vector valued observable.
95/// @param idx The start index (or relative position of the observable in the set of all observables).
96void CodeSquashContext::addVecObs(const char *key, int idx)
97{
98 const TNamed *namePtr = RooNameReg::known(key);
99 if (namePtr)
100 _vecObsIndices[namePtr] = idx;
101}
102
103/// @brief Adds the input string to the squashed code body. If a class implements a translate function that wants to
104/// emit something to the squashed code body, it must call this function with the code it wants to emit. In case of
105/// loops, automatically determines if code needs to be stored inside or outside loop scope.
106/// @param klass The class requesting this addition, usually 'this'.
107/// @param in String to add to the squashed code.
108void CodeSquashContext::addToCodeBody(RooAbsArg const *klass, std::string const &in)
109{
110 // If we are in a loop and the value is scope independent, save it at the top of the loop.
111 // else, just save it in the current scope.
113}
114
115/// @brief A variation of the previous addToCodeBody that takes in a bool value that determines
116/// if input is independent. This overload exists because there might other ways to determine if
117/// a value/collection of values is scope independent.
118/// @param in String to add to the squashed code.
119/// @param isScopeIndep The value determining if the input is scope dependent.
120void CodeSquashContext::addToCodeBody(std::string const &in, bool isScopeIndep /* = false */)
121{
122 // If we are in a loop and the value is scope independent, save it at the top of the loop.
123 // else, just save it in the current scope.
124 if (_scopePtr != -1 && isScopeIndep) {
125 _tempScope += in;
126 } else {
127 _code += in;
128 }
129}
130
131/// @brief Create a RAII scope for iterating over vector observables. You can't use the result of vector observables
132/// outside these loop scopes.
133/// @param in A pointer to the calling class, used to determine the loop dependent variables.
134std::unique_ptr<CodeSquashContext::LoopScope> CodeSquashContext::beginLoop(RooAbsArg const *in)
135{
136 std::string idx = "loopIdx" + std::to_string(_loopLevel);
137
138 std::vector<TNamed const *> vars;
139 // set the results of the vector observables
140 for (auto const &it : _vecObsIndices) {
141 if (!in->dependsOn(it.first))
142 continue;
143
144 vars.push_back(it.first);
145 _nodeNames[it.first] = "obs[" + std::to_string(it.second) + " + " + idx + "]";
146 }
147
148 // TODO: we are using the size of the first loop variable to the the number
149 // of iterations, but it should be made sure that all loop vars are either
150 // scalar or have the same size.
151 std::size_t numEntries = 1;
152 for (auto &it : vars) {
153 std::size_t n = outputSize(it);
154 if (n > 1 && numEntries > 1 && n != numEntries) {
155 throw std::runtime_error("Trying to loop over variables with different sizes!");
156 }
157 numEntries = std::max(n, numEntries);
158 }
159
160 // Save the current size of the code array so that we can insert the code at the right position.
161 _scopePtr = _code.size();
162
163 // Make sure that the name of this variable doesn't clash with other stuff
164 addToCodeBody(in, "for(int " + idx + " = 0; " + idx + " < " + std::to_string(numEntries) + "; " + idx + "++) {\n");
165
166 ++_loopLevel;
167 return std::make_unique<LoopScope>(*this, std::move(vars));
168}
169
171{
172 _code += "}\n";
173
174 // Insert the temporary code into the correct code position.
175 _code.insert(_scopePtr, _tempScope);
176 _tempScope.erase();
177 _scopePtr = -1;
178
179 // clear the results of the loop variables if they were vector observables
180 for (auto const &ptr : scope.vars()) {
181 if (_vecObsIndices.find(ptr) != _vecObsIndices.end())
182 _nodeNames.erase(ptr);
183 }
184 --_loopLevel;
185}
186
187/// @brief Get a unique variable name to be used in the generated code.
189{
190 return "t" + std::to_string(_tmpVarIdx++);
191}
192
193/// @brief A function to save an expression that includes/depends on the result of the input node.
194/// @param in The node on which the valueToSave depends on/belongs to.
195/// @param valueToSave The actual string value to save as a temporary.
196void CodeSquashContext::addResult(RooAbsArg const *in, std::string const &valueToSave)
197{
198 //std::string savedName = RooFit::Detail::makeValidVarName(in->GetName());
199 std::string savedName = getTmpVarName();
200
201 // Only save values if they contain operations.
202 bool hasOperations = valueToSave.find_first_of(":-+/*") != std::string::npos;
203
204 // If the name is not empty and this value is worth saving, save it to the correct scope.
205 // otherwise, just return the actual value itself
206 if (hasOperations) {
207 // If this is a scalar result, it will go just outside the loop because
208 // it doesn't need to be recomputed inside loops.
209 std::string outVarDecl = "const double " + savedName + " = " + valueToSave + ";\n";
210 addToCodeBody(in, outVarDecl);
211 } else {
212 savedName = valueToSave;
213 }
214
215 addResult(in->namePtr(), savedName);
216}
217
218/// @brief Function to save a RooListProxy as an array in the squashed code.
219/// @param in The list to convert to array.
220/// @return Name of the array that stores the input list in the squashed code.
222{
223 auto it = listNames.find(in.uniqueId().value());
224 if (it != listNames.end())
225 return it->second;
226
227 std::string savedName = getTmpVarName();
228 bool canSaveOutside = true;
229
230 std::stringstream declStrm;
231 declStrm << "double " << savedName << "[] = {";
232 for (const auto arg : in) {
233 declStrm << getResult(*arg) << ",";
234 canSaveOutside = canSaveOutside && isScopeIndependent(arg);
235 }
236 declStrm.seekp(-1, declStrm.cur);
237 declStrm << "};\n";
238
239 addToCodeBody(declStrm.str(), canSaveOutside);
240
241 listNames.insert({in.uniqueId().value(), savedName});
242 return savedName;
243}
244
245std::string CodeSquashContext::buildArg(std::span<const double> arr)
246{
247 unsigned int n = arr.size();
248 std::string offset = std::to_string(_xlArr.size());
249 _xlArr.reserve(_xlArr.size() + n);
250 for (unsigned int i = 0; i < n; i++) {
251 _xlArr.push_back(arr[i]);
252 }
253 return "auxArr + " + offset;
254}
255
257{
258 return !in->isReducerNode() && outputSize(in->namePtr()) == 1;
259}
260
261} // namespace Detail
262} // namespace RooFit
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void char Point_t Rectangle_t WindowAttributes_t Float_t Float_t Float_t Int_t Int_t UInt_t UInt_t Rectangle_t Int_t Int_t Window_t TString Int_t GCValues_t GetPrimarySelectionOwner GetDisplay GetScreen GetColormap GetNativeEvent const char const char dpyName wid window const char font_name cursor keysym reg const char only_if_exist regb h Point_t winding char text const char depth char const char Int_t count const char ColorStruct_t color const char Pixmap_t Pixmap_t PictureAttributes_t attr const char char ret_data h unsigned char height h offset
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void value
Common abstract base class for objects that represent a value and a "shape" in RooFit.
Definition RooAbsArg.h:77
bool dependsOn(const RooAbsCollection &serverList, const RooAbsArg *ignoreArg=nullptr, bool valueOnly=false) const
Test whether we depend on (ie, are served by) any object in the specified collection.
const TNamed * namePtr() const
De-duplicated pointer to this object's name.
Definition RooAbsArg.h:561
virtual void translate(RooFit::Detail::CodeSquashContext &ctx) const
This function defines a translation for each RooAbsReal based object that can be used to express the ...
virtual bool isReducerNode() const
Definition RooAbsArg.h:575
Abstract container object that can hold multiple RooAbsArg objects.
RooFit::UniqueId< RooAbsCollection > const & uniqueId() const
Returns a unique ID that is different for every instantiated RooAbsCollection.
A class to manage loop scopes using the RAII technique.
std::vector< TNamed const * > const & vars() const
std::string assembleCode(std::string const &returnExpr)
Assemble and return the final code with the return expression and global statements.
std::string _tempScope
Stores code that eventually gets injected into main code body.
void addResult(RooAbsArg const *key, std::string const &value)
A function to save an expression that includes/depends on the result of the input node.
void endLoop(LoopScope const &scope)
std::unordered_map< const TNamed *, int > _vecObsIndices
A map to keep track of the observable indices if they are non scalar.
int _loopLevel
The current number of for loops that have been started.
int _tmpVarIdx
Index to get unique names for temporary variables.
std::size_t outputSize(RooFit::Detail::DataKey key) const
Figure out the output size of a node.
std::unordered_map< const TNamed *, std::string > _nodeNames
Map of node names to their result strings.
void addToCodeBody(RooAbsArg const *klass, std::string const &in)
Adds the input string to the squashed code body.
void addVecObs(const char *key, int idx)
Since the squashed code represents all observables as a single flattened array, it is important to ke...
bool isScopeIndependent(RooAbsArg const *in) const
std::string getTmpVarName() const
Get a unique variable name to be used in the generated code.
std::string const & getResult(RooAbsArg const &arg)
Gets the result for the given node using the node name.
std::string _code
Stores the squashed code body.
void addToGlobalScope(std::string const &str)
Adds the given string to the string block that will be emitted at the top of the squashed function.
std::unordered_map< RooFit::UniqueId< RooAbsCollection >::Value_t, std::string > listNames
A map to keep track of list names as assigned by addResult.
std::string buildArg(RooAbsCollection const &x)
Function to save a RooListProxy as an array in the squashed code.
int _scopePtr
Keeps track of the position to go back and insert code to.
std::string _globalScope
Block of code that is placed before the rest of the function body.
std::unique_ptr< LoopScope > beginLoop(RooAbsArg const *in)
Create a RAII scope for iterating over vector observables.
static const TNamed * known(const char *stringPtr)
If the name is already known, return its TNamed pointer. Otherwise return 0 (don't register the name)...
The TNamed class is the base class for all named ROOT classes.
Definition TNamed.h:29
const Int_t n
Definition legend1.C:16
The namespace RooFit contains mostly switches that change the behaviour of functions of PDFs (or othe...
Definition JSONIO.h:26
constexpr Value_t value() const
Return numerical value of ID.
Definition UniqueId.h:59