Logo ROOT  
Reference Guide
 
Loading...
Searching...
No Matches
RooFuncWrapper.cxx
Go to the documentation of this file.
1/*
2 * Project: RooFit
3 * Authors:
4 * Garima Singh, CERN 2022
5 *
6 * Copyright (c) 2022, CERN
7 *
8 * Redistribution and use in source and binary forms,
9 * with or without modification, are permitted according to the terms
10 * listed in LICENSE (http://roofit.sourceforge.net/license.txt)
11 */
12
13#include <RooFuncWrapper.h>
14
15#include <RooAbsData.h>
17#include <RooFit/Evaluator.h>
18#include <RooGlobalFunc.h>
19#include <RooHelpers.h>
20#include <RooMsgService.h>
21#include <RooRealVar.h>
22#include <RooSimultaneous.h>
23
24#include "RooEvaluatorWrapper.h"
26
27#include <TROOT.h>
28#include <TSystem.h>
29
30#include <fstream>
31#include <set>
32
33namespace {
34
/// Replace, in place, every occurrence of `from` inside `str` by `to`.
///
/// \param[in,out] str String that is modified.
/// \param[in] from Pattern to look for. An empty pattern leaves `str` untouched.
/// \param[in] to Text that is inserted instead of each match.
void replaceAll(std::string &str, const std::string &from, const std::string &to)
{
   if (from.empty()) {
      return;
   }
   // The search resumes right behind the text that was just inserted, so a
   // replacement that itself contains the pattern (e.g. 'x' -> 'yx') is
   // never matched a second time.
   for (std::size_t pos = str.find(from); pos != std::string::npos; pos = str.find(from, pos + to.length())) {
      str.replace(pos, from.length(), to);
   }
}
45
46} // namespace
47
48namespace RooFit {
49
50namespace Experimental {
51
52RooFuncWrapper::RooFuncWrapper(const char *name, const char *title, RooAbsReal &obj, const RooAbsData *data,
54 : RooAbsReal{name, title}, _params{"!params", "List of parameters", this}, _useEvaluator{useEvaluator}
55{
56 if (_useEvaluator) {
57 _absReal = std::make_unique<RooEvaluatorWrapper>(obj, const_cast<RooAbsData *>(data), false, "", simPdf, false);
58 }
59
60 std::string func;
61
62 // Get the parameters.
64 obj.getParameters(data ? data->get() : nullptr, paramSet);
65
66 // Load the parameters and observables.
68
69 func = buildCode(obj);
70
71 gInterpreter->Declare("#pragma cling optimize(2)");
72
73 // Declare the function and create its derivative.
75 _func = reinterpret_cast<Func>(gInterpreter->ProcessLine((_funcName + ";").c_str()));
76}
77
80 _params("!params", this, other._params),
81 _funcName(other._funcName),
82 _func(other._func),
83 _grad(other._grad),
84 _hasGradient(other._hasGradient),
85 _gradientVarBuffer(other._gradientVarBuffer),
86 _observables(other._observables)
87{
88}
89
92{
93 // Extract observables
94 std::stack<std::vector<double>> vectorBuffers; // for data loading
95 std::map<RooFit::Detail::DataKey, std::span<const double>> spans;
96
97 if (data) {
98 spans = RooFit::BatchModeDataHelpers::getDataSpans(*data, "", simPdf, true, false, vectorBuffers);
99 }
100
101 std::size_t idx = 0;
102 for (auto const &item : spans) {
103 std::size_t n = item.second.size();
104 _obsInfos.emplace(item.first, ObsInfo{idx, n});
105 _observables.reserve(_observables.size() + n);
106 for (std::size_t i = 0; i < n; ++i) {
107 _observables.push_back(item.second[i]);
108 }
109 idx += n;
110 }
111
112 // Extract parameters
113 for (auto *param : paramSet) {
114 if (!dynamic_cast<RooAbsReal *>(param)) {
115 std::stringstream errorMsg;
116 errorMsg << "In creation of function " << GetName()
117 << " wrapper: input param expected to be of type RooAbsReal.";
118 coutE(InputArguments) << errorMsg.str() << std::endl;
119 throw std::runtime_error(errorMsg.str().c_str());
120 }
121 if (spans.find(param) == spans.end()) {
122 _params.add(*param);
123 }
124 }
125 _gradientVarBuffer.resize(_params.size());
126
127 if (head) {
128 _nodeOutputSizes = RooFit::BatchModeDataHelpers::determineOutputSizes(
129 *head, [&spans](RooFit::Detail::DataKey key) -> int {
130 auto found = spans.find(key);
131 return found != spans.end() ? found->second.size() : -1;
132 });
133 }
134}
135
136std::string RooFuncWrapper::declareFunction(std::string const &funcBody)
137{
138 static int iFuncWrapper = 0;
139 auto funcName = "roo_func_wrapper_" + std::to_string(iFuncWrapper++);
140
141 // Declare the function
142 std::stringstream bodyWithSigStrm;
143 bodyWithSigStrm << "double " << funcName << "(double* params, double const* obs, double const* xlArr) {\n"
144 << funcBody << "\n}";
145 _collectedFunctions.emplace_back(funcName);
146 if (!gInterpreter->Declare(bodyWithSigStrm.str().c_str())) {
147 std::stringstream errorMsg;
148 errorMsg << "Function " << funcName << " could not be compiled. See above for details.";
149 oocoutE(nullptr, InputArguments) << errorMsg.str() << std::endl;
150 throw std::runtime_error(errorMsg.str().c_str());
151 }
152 return funcName;
153}
154
155void RooFuncWrapper::createGradient()
156{
157 std::string gradName = _funcName + "_grad_0";
158 std::string requestName = _funcName + "_req";
159
160 // Calculate gradient
161 gInterpreter->Declare("#include <Math/CladDerivator.h>\n");
162 // disable clang-format for making the following code unreadable.
163 // clang-format off
164 std::stringstream requestFuncStrm;
165 requestFuncStrm << "#pragma clad ON\n"
166 "void " << requestName << "() {\n"
167 " clad::gradient(" << _funcName << ", \"params\");\n"
168 "}\n"
169 "#pragma clad OFF";
170 // clang-format on
171 if (!gInterpreter->Declare(requestFuncStrm.str().c_str())) {
172 std::stringstream errorMsg;
173 errorMsg << "Function " << GetName() << " could not be differentiated. See above for details.";
174 oocoutE(nullptr, InputArguments) << errorMsg.str() << std::endl;
175 throw std::runtime_error(errorMsg.str().c_str());
176 }
177
178 _grad = reinterpret_cast<Grad>(gInterpreter->ProcessLine((gradName + ";").c_str()));
179 _hasGradient = true;
180}
181
182void RooFuncWrapper::gradient(double *out) const
183{
184 updateGradientVarBuffer();
185 std::fill(out, out + _params.size(), 0.0);
186
187 _grad(_gradientVarBuffer.data(), _observables.data(), _xlArr.data(), out);
188}
189
190void RooFuncWrapper::updateGradientVarBuffer() const
191{
192 std::transform(_params.begin(), _params.end(), _gradientVarBuffer.begin(),
193 [](RooAbsArg *obj) { return static_cast<RooAbsReal *>(obj)->getVal(); });
194}
195
196double RooFuncWrapper::evaluate() const
197{
198 if (_useEvaluator)
199 return _absReal->getVal();
200 updateGradientVarBuffer();
201
202 return _func(_gradientVarBuffer.data(), _observables.data(), _xlArr.data());
203}
204
205void RooFuncWrapper::gradient(const double *x, double *g) const
206{
207 std::fill(g, g + _params.size(), 0.0);
208
209 _grad(const_cast<double *>(x), _observables.data(), _xlArr.data(), g);
210}
211
212std::string RooFuncWrapper::buildCode(RooAbsReal const &head)
213{
214 RooFit::Detail::CodeSquashContext ctx(_nodeOutputSizes, _xlArr, *this);
215
216 // First update the result variable of params in the compute graph to in[<position>].
217 int idx = 0;
218 for (RooAbsArg *param : _params) {
219 ctx.addResult(param, "params[" + std::to_string(idx) + "]");
220 idx++;
221 }
222
223 for (auto const &item : _obsInfos) {
224 const char *name = item.first->GetName();
225 // If the observable is scalar, set name to the start idx. else, store
226 // the start idx and later set the the name to obs[start_idx + curr_idx],
227 // here curr_idx is defined by a loop producing parent node.
228 if (item.second.size == 1) {
229 ctx.addResult(name, "obs[" + std::to_string(item.second.idx) + "]");
230 } else {
231 ctx.addResult(name, "obs");
232 ctx.addVecObs(name, item.second.idx);
233 }
234 }
235
236 return ctx.assembleCode(ctx.getResult(head));
237}
238
239/// @brief Dumps a macro "filename.C" that can be used to test and debug the generated code and gradient.
240void RooFuncWrapper::writeDebugMacro(std::string const &filename) const
241{
242 std::stringstream allCode;
243 std::set<std::string> seenFunctions;
244
245 // Remove duplicated declared functions
246 for (std::string const &name : _collectedFunctions) {
247 if (seenFunctions.count(name) > 0) {
248 continue;
249 }
250 seenFunctions.insert(name);
251 std::unique_ptr<TInterpreterValue> v = gInterpreter->MakeInterpreterValue();
252 gInterpreter->Evaluate(name.c_str(), *v);
253 std::string s = v->ToString();
254 for (int i = 0; i < 2; ++i) {
255 s = s.erase(0, s.find("\n") + 1);
256 }
257 allCode << s << std::endl;
258 }
259
260 std::ofstream outFile;
261 outFile.open(filename + ".C");
262 outFile << R"(//auto-generated test macro
263#include <RooFit/Detail/MathFuncs.h>
264#include <Math/CladDerivator.h>
265
266#pragma cling optimize(2)
267)" << allCode.str()
268 << R"(
269#pragma clad ON
270void gradient_request() {
271 clad::gradient()"
272 << _funcName << R"(, "params");
273}
274#pragma clad OFF
275)";
276
277 updateGradientVarBuffer();
278
279 auto writeVector = [&](std::string const &name, std::span<const double> vec) {
280 std::stringstream decl;
281 decl << "std::vector<double> " << name << " = {";
282 for (std::size_t i = 0; i < vec.size(); ++i) {
283 if (i % 10 == 0)
284 decl << "\n ";
285 decl << vec[i];
286 if (i < vec.size() - 1)
287 decl << ", ";
288 }
289 decl << "\n};\n";
290
291 std::string declStr = decl.str();
292
293 replaceAll(declStr, "inf", "std::numeric_limits<double>::infinity()");
294 replaceAll(declStr, "nan", "NAN");
295
296 outFile << declStr;
297 };
298
299 outFile << "// clang-format off\n" << std::endl;
300 writeVector("parametersVec", _gradientVarBuffer);
301 outFile << std::endl;
302 writeVector("observablesVec", _observables);
303 outFile << std::endl;
304 writeVector("auxConstantsVec", _xlArr);
305 outFile << std::endl;
306 outFile << "// clang-format on\n" << std::endl;
307
308 outFile << R"(
309// To run as a ROOT macro
310void )" << filename
311 << R"(()
312{
313 std::vector<double> gradientVec(parametersVec.size());
314
315 auto func = [&](std::span<double> params) {
316 return )"
317 << _funcName << R"((params.data(), observablesVec.data(), auxConstantsVec.data());
318 };
319 auto grad = [&](std::span<double> params, std::span<double> out) {
320 return )"
321 << _funcName << R"(_grad_0(parametersVec.data(), observablesVec.data(), auxConstantsVec.data(),
322 out.data());
323 };
324
325 grad(parametersVec, gradientVec);
326
327 auto numDiff = [&](int i) {
328 const double eps = 1e-6;
329 std::vector<double> p{parametersVec};
330 p[i] = parametersVec[i] - eps;
331 double funcValDown = func(p);
332 p[i] = parametersVec[i] + eps;
333 double funcValUp = func(p);
334 return (funcValUp - funcValDown) / (2 * eps);
335 };
336
337 for (std::size_t i = 0; i < parametersVec.size(); ++i) {
338 std::cout << i << ":" << std::endl;
339 std::cout << " numr : " << numDiff(i) << std::endl;
340 std::cout << " clad : " << gradientVec[i] << std::endl;
341 }
342}
343)";
344}
345
346} // namespace Experimental
347
348} // namespace RooFit
#define g(i)
Definition RSha256.hxx:105
#define oocoutE(o, a)
#define coutE(a)
ROOT::Detail::TRangeCast< T, true > TRangeDynCast
TRangeDynCast is an adapter class that allows the typed iteration through a TCollection.
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void data
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void char Point_t Rectangle_t WindowAttributes_t Float_t Float_t Float_t Int_t Int_t UInt_t UInt_t Rectangle_t Int_t Int_t Window_t TString Int_t GCValues_t GetPrimarySelectionOwner GetDisplay GetScreen GetColormap GetNativeEvent const char const char dpyName wid window const char font_name cursor keysym reg const char only_if_exist regb h Point_t winding char text const char depth char const char Int_t count const char ColorStruct_t color const char filename
char name[80]
Definition TGX11.cxx:110
#define gInterpreter
const_iterator end() const
Common abstract base class for objects that represent a value and a "shape" in RooFit.
Definition RooAbsArg.h:79
RooFit::OwningPtr< RooArgSet > getParameters(const RooAbsData *data, bool stripDisconnected=true) const
Create a list of leaf nodes in the arg tree starting with ourself as top node that don't match any of...
Abstract base class for binned and unbinned datasets.
Definition RooAbsData.h:57
Abstract base class for objects that represent a real value and implements functionality common to al...
Definition RooAbsReal.h:59
RooArgSet is a container object that can hold multiple RooAbsArg objects.
Definition RooArgSet.h:24
A class to maintain the context for squashing of RooFit models into code.
std::string assembleCode(std::string const &returnExpr)
Assemble and return the final code with the return expression and global statements.
void addResult(RooAbsArg const *key, std::string const &value)
A function to save an expression that includes/depends on the result of the input node.
void addVecObs(const char *key, int idx)
Since the squashed code represents all observables as a single flattened array, it is important to ke...
std::string const & getResult(RooAbsArg const &arg)
Gets the result for the given node using the node name.
A wrapper class to store a C++ function of type 'double (*)(double*, double const*, double const*)'.
std::unique_ptr< RooAbsReal > _absReal
std::string buildCode(RooAbsReal const &head)
void loadParamsAndData(RooAbsArg const *head, RooArgSet const &paramSet, const RooAbsData *data, RooSimultaneous const *simPdf)
std::map< RooFit::Detail::DataKey, ObsInfo > _obsInfos
std::string declareFunction(std::string const &funcBody)
RooFuncWrapper(const char *name, const char *title, RooAbsReal &obj, const RooAbsData *data=nullptr, RooSimultaneous const *simPdf=nullptr, bool useEvaluator=false)
Facilitates simultaneous fitting of multiple PDFs to subsets of a given dataset.
Double_t x[n]
Definition legend1.C:17
const Int_t n
Definition legend1.C:16
The namespace RooFit contains mostly switches that change the behaviour of functions of PDFs (or othe...
Definition JSONIO.h:26
@ InputArguments