Logo ROOT  
Reference Guide
 
Loading...
Searching...
No Matches
RooFuncWrapper.cxx
Go to the documentation of this file.
1/*
2 * Project: RooFit
3 * Authors:
4 * Garima Singh, CERN 2022
5 *
6 * Copyright (c) 2022, CERN
7 *
8 * Redistribution and use in source and binary forms,
9 * with or without modification, are permitted according to the terms
10 * listed in LICENSE (http://roofit.sourceforge.net/license.txt)
11 */
12
13#include <RooFuncWrapper.h>
14
15#include <RooAbsData.h>
17#include <RooFit/Evaluator.h>
18#include <RooGlobalFunc.h>
19#include <RooHelpers.h>
20#include <RooMsgService.h>
21#include <RooRealVar.h>
22#include <RooSimultaneous.h>
23
24#include "RooEvaluatorWrapper.h"
26#include "RooFitImplHelpers.h"
27
28#include <TROOT.h>
29#include <TSystem.h>
30
31#include <Math/Util.h>
32
33#include <fstream>
34#include <set>
35
36namespace {
37
/// Replaces every non-overlapping occurrence of `from` in `str` by `to`,
/// scanning from left to right.
///
/// Text inserted by a replacement is never re-scanned, so a `to` that itself
/// contains `from` (e.g. replacing "x" by "yx") does not cause an endless loop.
/// An empty `from` leaves `str` untouched.
///
/// \param[in,out] str String that is modified in place.
/// \param[in] from Substring to search for.
/// \param[in] to Replacement text (may be empty).
void replaceAll(std::string &str, const std::string &from, const std::string &to)
{
   if (from.empty())
      return;

   std::size_t hit = str.find(from);
   if (hit == std::string::npos)
      return; // nothing to replace: avoid rebuilding the string

   // Assemble the result in a single pass instead of calling str.replace()
   // repeatedly, which would shift the remaining tail once per occurrence.
   std::string rebuilt;
   rebuilt.reserve(str.size());
   std::size_t copiedUpTo = 0;
   while (hit != std::string::npos) {
      rebuilt.append(str, copiedUpTo, hit - copiedUpTo);
      rebuilt += to;
      copiedUpTo = hit + from.size();
      hit = str.find(from, copiedUpTo);
   }
   rebuilt.append(str, copiedUpTo, std::string::npos);
   str.swap(rebuilt);
}
48
49} // namespace
50
51namespace RooFit {
52
53namespace Experimental {
54
/// Wraps the RooFit function `obj` into generated C++ code.
///
/// The constructor optionally creates a RooEvaluatorWrapper (when
/// `useEvaluator` is set), collects the parameters of `obj`, registers
/// parameters and observables with the code generation context `ctx`, lets
/// `ctx` build the function, and asks the interpreter for a pointer to it.
///
/// \param name Name passed on to the RooAbsReal base.
/// \param title Title passed on to the RooAbsReal base.
/// \param obj Function for which the code is generated.
/// \param data Optional dataset; may be null, in which case no normalization
///        set is passed to getParameters().
///
/// NOTE(review): this listing has lost several original lines (its embedded
/// numbering jumps 55->57, 63->65, 67->69 and 76->78). The remainder of the
/// parameter list (`simPdf`, `useEvaluator`) and the declarations of
/// `paramSet`, `spans` and `ctx` are therefore not visible here; confirm them
/// against the original file before editing this constructor.
55RooFuncWrapper::RooFuncWrapper(const char *name, const char *title, RooAbsReal &obj, const RooAbsData *data,
 // NOTE(review): the rest of the parameter list is missing from this listing.
57 : RooAbsReal{name, title}, _params{"!params", "List of parameters", this}, _useEvaluator{useEvaluator}
58{
 // When requested, evaluate() is later delegated to this evaluator-based wrapper.
59 if (_useEvaluator) {
60 _absReal = std::make_unique<RooEvaluatorWrapper>(obj, const_cast<RooAbsData *>(data), false, "", simPdf, false);
61 }
62
63 // Get the parameters.
 // NOTE(review): the declaration of `paramSet` is missing from this listing.
65 obj.getParameters(data ? data->get() : nullptr, paramSet);
66
67 // Load the parameters and observables.
 // NOTE(review): the statement that defines `spans` is missing from this listing
 // (presumably a call to loadParamsAndData() - confirm).
69
70 // Set up the code generation context
 // The callback reports the number of entries of each data span, or -1 for
 // keys that have no associated span.
71 std::map<RooFit::Detail::DataKey, std::size_t> nodeOutputSizes =
72 RooFit::BatchModeDataHelpers::determineOutputSizes(obj, [&spans](RooFit::Detail::DataKey key) -> int {
73 auto found = spans.find(key);
74 return found != spans.end() ? found->second.size() : -1;
75 });
76
 // NOTE(review): the declaration of `ctx` is missing from this listing.
78
79 // First update the result variable of params in the compute graph to params[<position>].
80 int idx = 0;
81 for (RooAbsArg *param : _params) {
82 ctx.addResult(param, "params[" + std::to_string(idx) + "]");
83 idx++;
84 }
85
86 for (auto const &item : _obsInfos) {
87 const char *obsName = item.first->GetName();
88 // If the observable is scalar, set name to the start idx. else, store
89 // the start idx and later set the name to obs[start_idx + curr_idx],
90 // here curr_idx is defined by a loop producing parent node.
91 if (item.second.size == 1) {
92 ctx.addResult(obsName, "obs[" + std::to_string(item.second.idx) + "]");
93 } else {
94 ctx.addResult(obsName, "obs");
95 ctx.addVecObs(obsName, item.second.idx);
96 }
97 }
98
99 gInterpreter->Declare("#pragma cling optimize(2)");
100
101 // Declare the function and fetch a pointer to it. The derivative is not
 // created here but on demand in createGradient().
102 auto print = [](std::string const &msg) { oocoutI(nullptr, Fitting) << msg << std::endl; };
103 ROOT::Math::Util::TimingScope timingScope(print, "Function JIT time:");
104 _funcName = ctx.buildFunction(obj, nodeOutputSizes);
 // Evaluating the bare function name in the interpreter yields its address.
105 _func = reinterpret_cast<Func>(gInterpreter->ProcessLine((_funcName + ";").c_str()));
106
 // Keep what the context collected: both are needed later by evaluate(),
 // gradient() and writeDebugMacro().
107 _xlArr = ctx.xlArr();
108 _collectedFunctions = ctx.collectedFunctions();
109}
110
// NOTE(review): the head of this definition (listing lines 111-112, i.e. the
// signature and the first initializer) is missing from this listing. What
// follows is the tail of a member-initializer list that copies state from
// `other`: the parameter proxy, the generated function's name, the function
// and gradient pointers, the gradient flag, and the parameter/observable
// buffers.
// NOTE(review): `_xlArr`, `_collectedFunctions`, `_obsInfos`, `_useEvaluator`
// and `_absReal` are used elsewhere in this file but are not copied in the
// visible part - confirm whether that is intentional, since evaluate() and
// gradient() pass `_xlArr.data()` to the generated code.
113 _params("!params", this, other._params),
114 _funcName(other._funcName),
115 _func(other._func),
116 _grad(other._grad),
117 _hasGradient(other._hasGradient),
118 _gradientVarBuffer(other._gradientVarBuffer),
119 _observables(other._observables)
120{
121}
122
123std::map<RooFit::Detail::DataKey, std::span<const double>>
125{
126 // Extract observables
127 std::stack<std::vector<double>> vectorBuffers; // for data loading
128 std::map<RooFit::Detail::DataKey, std::span<const double>> spans;
129
130 if (data) {
131 spans = RooFit::BatchModeDataHelpers::getDataSpans(*data, "", simPdf, true, false, vectorBuffers);
132 }
133
134 std::size_t idx = 0;
135 for (auto const &item : spans) {
136 std::size_t n = item.second.size();
137 _obsInfos.emplace(item.first, ObsInfo{idx, n});
138 _observables.reserve(_observables.size() + n);
139 for (std::size_t i = 0; i < n; ++i) {
140 _observables.push_back(item.second[i]);
141 }
142 idx += n;
143 }
144
145 // Extract parameters
146 for (auto *param : paramSet) {
147 if (!dynamic_cast<RooAbsReal *>(param)) {
148 std::stringstream errorMsg;
149 errorMsg << "In creation of function " << GetName()
150 << " wrapper: input param expected to be of type RooAbsReal.";
151 coutE(InputArguments) << errorMsg.str() << std::endl;
152 throw std::runtime_error(errorMsg.str().c_str());
153 }
154 if (spans.find(param) == spans.end()) {
155 _params.add(*param);
156 }
157 }
158 _gradientVarBuffer.resize(_params.size());
159
160 return spans;
161}
162
163void RooFuncWrapper::createGradient()
164{
165 std::string gradName = _funcName + "_grad_0";
166 std::string requestName = _funcName + "_req";
167
168 // Calculate gradient
169 gInterpreter->Declare("#include <Math/CladDerivator.h>\n");
170 // disable clang-format for making the following code unreadable.
171 // clang-format off
172 std::stringstream requestFuncStrm;
173 requestFuncStrm << "#pragma clad ON\n"
174 "void " << requestName << "() {\n"
175 " clad::gradient(" << _funcName << ", \"params\");\n"
176 "}\n"
177 "#pragma clad OFF";
178 // clang-format on
179 auto print = [](std::string const &msg) { oocoutI(nullptr, Fitting) << msg << std::endl; };
180
181 bool cladSuccess = false;
182 {
183 ROOT::Math::Util::TimingScope timingScope(print, "Gradient generation time:");
184 cladSuccess = !gInterpreter->Declare(requestFuncStrm.str().c_str());
185 }
186 if (cladSuccess) {
187 std::stringstream errorMsg;
188 errorMsg << "Function " << GetName() << " could not be differentiated. See above for details.";
189 oocoutE(nullptr, InputArguments) << errorMsg.str() << std::endl;
190 throw std::runtime_error(errorMsg.str().c_str());
191 }
192
193 // Clad provides different overloads for the gradient, and we need to
194 // resolve to the one that we want. Without the static_cast, getting the
195 // function pointer would be ambiguous.
196 std::stringstream ss;
197 ROOT::Math::Util::TimingScope timingScope(print, "Gradient IR to machine code time:");
198 ss << "static_cast<void (*)(double *, double const *, double const *, double *)>(" << gradName << ");";
199 _grad = reinterpret_cast<Grad>(gInterpreter->ProcessLine(ss.str().c_str()));
200 _hasGradient = true;
201}
202
203void RooFuncWrapper::gradient(double *out) const
204{
205 updateGradientVarBuffer();
206 std::fill(out, out + _params.size(), 0.0);
207
208 _grad(_gradientVarBuffer.data(), _observables.data(), _xlArr.data(), out);
209}
210
211void RooFuncWrapper::updateGradientVarBuffer() const
212{
213 std::transform(_params.begin(), _params.end(), _gradientVarBuffer.begin(),
214 [](RooAbsArg *obj) { return static_cast<RooAbsReal *>(obj)->getVal(); });
215}
216
217double RooFuncWrapper::evaluate() const
218{
219 if (_useEvaluator)
220 return _absReal->getVal();
221 updateGradientVarBuffer();
222
223 return _func(_gradientVarBuffer.data(), _observables.data(), _xlArr.data());
224}
225
226void RooFuncWrapper::gradient(const double *x, double *g) const
227{
228 std::fill(g, g + _params.size(), 0.0);
229
230 _grad(const_cast<double *>(x), _observables.data(), _xlArr.data(), g);
231}
232
233/// @brief Dumps a macro "filename.C" that can be used to test and debug the generated code and gradient.
234void RooFuncWrapper::writeDebugMacro(std::string const &filename) const
235{
236 std::stringstream allCode;
237 std::set<std::string> seenFunctions;
238
239 // Remove duplicated declared functions
240 for (std::string const &name : _collectedFunctions) {
241 if (seenFunctions.count(name) > 0) {
242 continue;
243 }
244 seenFunctions.insert(name);
245 std::unique_ptr<TInterpreterValue> v = gInterpreter->MakeInterpreterValue();
246 gInterpreter->Evaluate(name.c_str(), *v);
247 std::string s = v->ToString();
248 for (int i = 0; i < 2; ++i) {
249 s = s.erase(0, s.find("\n") + 1);
250 }
251 allCode << s << std::endl;
252 }
253
254 std::ofstream outFile;
255 outFile.open(filename + ".C");
256 outFile << R"(//auto-generated test macro
257#include <RooFit/Detail/MathFuncs.h>
258#include <Math/CladDerivator.h>
259
260#pragma cling optimize(2)
261)" << allCode.str()
262 << R"(
263#pragma clad ON
264void gradient_request() {
265 clad::gradient()"
266 << _funcName << R"(, "params");
267}
268#pragma clad OFF
269)";
270
271 updateGradientVarBuffer();
272
273 auto writeVector = [&](std::string const &name, std::span<const double> vec) {
274 std::stringstream decl;
275 decl << "std::vector<double> " << name << " = {";
276 for (std::size_t i = 0; i < vec.size(); ++i) {
277 if (i % 10 == 0)
278 decl << "\n ";
279 decl << vec[i];
280 if (i < vec.size() - 1)
281 decl << ", ";
282 }
283 decl << "\n};\n";
284
285 std::string declStr = decl.str();
286
287 replaceAll(declStr, "inf", "std::numeric_limits<double>::infinity()");
288 replaceAll(declStr, "nan", "NAN");
289
290 outFile << declStr;
291 };
292
293 outFile << "// clang-format off\n" << std::endl;
294 writeVector("parametersVec", _gradientVarBuffer);
295 outFile << std::endl;
296 writeVector("observablesVec", _observables);
297 outFile << std::endl;
298 writeVector("auxConstantsVec", _xlArr);
299 outFile << std::endl;
300 outFile << "// clang-format on\n" << std::endl;
301
302 outFile << R"(
303// To run as a ROOT macro
304void )" << filename
305 << R"(()
306{
307 std::vector<double> gradientVec(parametersVec.size());
308
309 auto func = [&](std::span<double> params) {
310 return )"
311 << _funcName << R"((params.data(), observablesVec.data(), auxConstantsVec.data());
312 };
313 auto grad = [&](std::span<double> params, std::span<double> out) {
314 return )"
315 << _funcName << R"(_grad_0(parametersVec.data(), observablesVec.data(), auxConstantsVec.data(),
316 out.data());
317 };
318
319 grad(parametersVec, gradientVec);
320
321 auto numDiff = [&](int i) {
322 const double eps = 1e-6;
323 std::vector<double> p{parametersVec};
324 p[i] = parametersVec[i] - eps;
325 double funcValDown = func(p);
326 p[i] = parametersVec[i] + eps;
327 double funcValUp = func(p);
328 return (funcValUp - funcValDown) / (2 * eps);
329 };
330
331 for (std::size_t i = 0; i < parametersVec.size(); ++i) {
332 std::cout << i << ":" << std::endl;
333 std::cout << " numr : " << numDiff(i) << std::endl;
334 std::cout << " clad : " << gradientVec[i] << std::endl;
335 }
336}
337)";
338}
339
340} // namespace Experimental
341
342} // namespace RooFit
#define g(i)
Definition RSha256.hxx:105
#define oocoutE(o, a)
#define oocoutI(o, a)
#define coutE(a)
ROOT::Detail::TRangeCast< T, true > TRangeDynCast
TRangeDynCast is an adapter class that allows the typed iteration through a TCollection.
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void data
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void char Point_t Rectangle_t WindowAttributes_t Float_t Float_t Float_t Int_t Int_t UInt_t UInt_t Rectangle_t Int_t Int_t Window_t TString Int_t GCValues_t GetPrimarySelectionOwner GetDisplay GetScreen GetColormap GetNativeEvent const char const char dpyName wid window const char font_name cursor keysym reg const char only_if_exist regb h Point_t winding char text const char depth char const char Int_t count const char ColorStruct_t color const char filename
char name[80]
Definition TGX11.cxx:110
#define gInterpreter
const_iterator end() const
Common abstract base class for objects that represent a value and a "shape" in RooFit.
Definition RooAbsArg.h:77
RooFit::OwningPtr< RooArgSet > getParameters(const RooAbsData *data, bool stripDisconnected=true) const
Create a list of leaf nodes in the arg tree starting with ourself as top node that don't match any of...
Abstract base class for binned and unbinned datasets.
Definition RooAbsData.h:57
Abstract base class for objects that represent a real value and implements functionality common to al...
Definition RooAbsReal.h:59
RooArgSet is a container object that can hold multiple RooAbsArg objects.
Definition RooArgSet.h:24
A class to maintain the context for squashing of RooFit models into code.
void addResult(RooAbsArg const *key, std::string const &value)
A function to save an expression that includes/depends on the result of the input node.
A wrapper class to store a C++ function of type 'double (*)(double*, double*)'.
std::unique_ptr< RooAbsReal > _absReal
std::vector< std::string > _collectedFunctions
std::map< RooFit::Detail::DataKey, std::span< const double > > loadParamsAndData(RooArgSet const &paramSet, const RooAbsData *data, RooSimultaneous const *simPdf)
std::map< RooFit::Detail::DataKey, ObsInfo > _obsInfos
RooFuncWrapper(const char *name, const char *title, RooAbsReal &obj, const RooAbsData *data=nullptr, RooSimultaneous const *simPdf=nullptr, bool useEvaluator=false)
Facilitates simultaneous fitting of multiple PDFs to subsets of a given dataset.
Double_t x[n]
Definition legend1.C:17
const Int_t n
Definition legend1.C:16
void replaceAll(std::string &inOut, std::string_view what, std::string_view with)
The namespace RooFit contains mostly switches that change the behaviour of functions of PDFs (or othe...
Definition CodegenImpl.h:64
@ InputArguments