Logo ROOT  
Reference Guide
 
Loading...
Searching...
No Matches
RooEvaluatorWrapper.cxx
Go to the documentation of this file.
1/// \cond ROOFIT_INTERNAL
2
3/*
4 * Project: RooFit
5 * Authors:
6 * Jonas Rembser, CERN 2023
7 *
8 * Copyright (c) 2023, CERN
9 *
10 * Redistribution and use in source and binary forms,
11 * with or without modification, are permitted according to the terms
12 * listed in LICENSE (http://roofit.sourceforge.net/license.txt)
13 */
14
15/**
16\internal
17\file RooEvaluatorWrapper.cxx
18\class RooEvaluatorWrapper
19\ingroup Roofitcore
20
21Wraps a RooFit::Evaluator that evaluates a RooAbsReal back into a RooAbsReal.
22**/
23
24#include "RooEvaluatorWrapper.h"
25
26#include <RooAbsData.h>
27#include <RooAbsPdf.h>
28#include <RooConstVar.h>
29#include <RooHelpers.h>
30#include <RooMsgService.h>
31#include <RooRealVar.h>
32#include <RooSimultaneous.h>
33
34#include <TInterpreter.h>
35
36#include <fstream>
37
38RooEvaluatorWrapper::RooEvaluatorWrapper(RooAbsReal &topNode, RooAbsData *data, bool useGPU,
39 std::string const &rangeName, RooAbsPdf const *pdf,
41 : RooAbsReal{"RooEvaluatorWrapper", "RooEvaluatorWrapper"},
42 _evaluator{std::make_unique<RooFit::Evaluator>(topNode, useGPU)},
43 _topNode("topNode", "top node", this, topNode, false, false),
44 _data{data},
45 _paramSet("paramSet", "Set of parameters", this),
46 _rangeName{rangeName},
47 _pdf{pdf},
48 _takeGlobalObservablesFromData{takeGlobalObservablesFromData}
49{
50 if (data) {
51 setData(*data, false);
52 }
53 _paramSet.add(_evaluator->getParameters());
54 for (auto const &item : _dataSpans) {
55 _paramSet.remove(*_paramSet.find(item.first->GetName()));
56 }
57}
58
59RooEvaluatorWrapper::RooEvaluatorWrapper(const RooEvaluatorWrapper &other, const char *name)
61 _evaluator{other._evaluator},
62 _topNode("topNode", this, other._topNode),
63 _data{other._data},
64 _paramSet("paramSet", "Set of parameters", this),
65 _rangeName{other._rangeName},
66 _pdf{other._pdf},
67 _takeGlobalObservablesFromData{other._takeGlobalObservablesFromData},
69{
70 _paramSet.add(other._paramSet);
71}
72
/// Out-of-line defaulted destructor.
RooEvaluatorWrapper::~RooEvaluatorWrapper() = default;
74
75bool RooEvaluatorWrapper::getParameters(const RooArgSet *observables, RooArgSet &outputSet,
76 bool stripDisconnected) const
77{
78 outputSet.add(_evaluator->getParameters());
79 if (observables) {
80 outputSet.remove(*observables, /*silent*/ false, /*matchByNameOnly*/ true);
81 }
82 // Exclude the data variables from the parameters which are not global observables
83 for (auto const &item : _dataSpans) {
84 if (_data->getGlobalObservables() && _data->getGlobalObservables()->find(item.first->GetName())) {
85 continue;
86 }
87 RooAbsArg *found = outputSet.find(item.first->GetName());
88 if (found) {
89 outputSet.remove(*found);
90 }
91 }
92 // If we take the global observables as data, we have to return these as
93 // parameters instead of the parameters in the model. Otherwise, the
94 // constant parameters in the fit result that are global observables will
95 // not have the right values.
96 if (_takeGlobalObservablesFromData && _data->getGlobalObservables()) {
97 outputSet.replace(*_data->getGlobalObservables());
98 }
99
100 // The disconnected parameters are stripped away in
101 // RooAbsArg::getParametersHook(), that is only called in the original
102 // RooAbsArg::getParameters() implementation. So he have to call it to
103 // identify disconnected parameters to remove.
104 if (stripDisconnected) {
106 _topNode->getParameters(observables, paramsStripped, true);
108 for (RooAbsArg *param : outputSet) {
109 if (!paramsStripped.find(param->GetName())) {
110 toRemove.add(*param);
111 }
112 }
113 outputSet.remove(toRemove, /*silent*/ false, /*matchByNameOnly*/ true);
114 }
115
116 return false;
117}
118
119bool RooEvaluatorWrapper::setData(RooAbsData &data, bool /*cloneData*/)
120{
121 // To make things easiear for RooFit, we only support resetting with
122 // datasets that have the same structure, e.g. the same columns and global
123 // observables. This is anyway the usecase: resetting same-structured data
124 // when iterating over toys.
125 constexpr auto errMsg = "Error in RooAbsReal::setData(): only resetting with same-structured data is supported.";
126
127 _data = &data;
128 bool isInitializing = _paramSet.empty();
129 const std::size_t oldSize = _dataSpans.size();
130
131 std::stack<std::vector<double>>{}.swap(_vectorBuffers);
132 bool skipZeroWeights = !_pdf || !_pdf->getAttribute("BinnedLikelihoodActive");
133 _dataSpans =
134 RooFit::BatchModeDataHelpers::getDataSpans(*_data, _rangeName, dynamic_cast<RooSimultaneous const *>(_pdf),
135 skipZeroWeights, _takeGlobalObservablesFromData, _vectorBuffers);
136 if (!isInitializing && _dataSpans.size() != oldSize) {
137 coutE(DataHandling) << errMsg << std::endl;
138 throw std::runtime_error(errMsg);
139 }
140 for (auto const &item : _dataSpans) {
141 const char *name = item.first->GetName();
142 _evaluator->setInput(name, item.second, false);
143 if (_paramSet.find(name)) {
144 coutE(DataHandling) << errMsg << std::endl;
145 throw std::runtime_error(errMsg);
146 }
147 }
148 return true;
149}
150
151/// @brief A wrapper class to store a C++ function of type 'double (*)(double*, double*)'.
152/// The parameters can be accessed as params[<relative position of param in paramSet>] in the function body.
153/// The observables can be accessed as obs[i + j], where i represents the observable position and j
154/// represents the data entry.
155class RooFuncWrapper {
156public:
158
159 bool hasGradient() const { return _hasGradient; }
160 void gradient(double *out) const
161 {
163 std::fill(out, out + _params.size(), 0.0);
164
165 _grad(_gradientVarBuffer.data(), _observables.data(), _xlArr.data(), out);
166 }
167
168 void createGradient();
169
170 void writeDebugMacro(std::string const &) const;
171
172 std::vector<std::string> const &collectedFunctions() { return _collectedFunctions; }
173
174 double evaluate() const
175 {
177 return _func(_gradientVarBuffer.data(), _observables.data(), _xlArr.data());
178 }
179
180private:
181 void updateGradientVarBuffer() const;
182
183 std::map<RooFit::Detail::DataKey, std::span<const double>>
185
187
188 using Func = double (*)(double *, double const *, double const *);
189 using Grad = void (*)(double *, double const *, double const *, double *);
190
191 struct ObsInfo {
192 ObsInfo(std::size_t i, std::size_t n) : idx{i}, size{n} {}
193 std::size_t idx = 0;
194 std::size_t size = 0;
195 };
196
197 RooArgList _params;
198 std::string _funcName;
199 Func _func;
200 Grad _grad;
201 bool _hasGradient = false;
202 mutable std::vector<double> _gradientVarBuffer;
203 std::vector<double> _observables;
204 std::map<RooFit::Detail::DataKey, ObsInfo> _obsInfos;
205 std::vector<double> _xlArr;
206 std::vector<std::string> _collectedFunctions;
207};
208
namespace {

/// Replaces every occurrence of \p from in \p str with \p to, in place.
/// The scan resumes just past each inserted replacement, so a `to` that
/// contains `from` (e.g. replacing "x" with "yx") does not recurse.
void replaceAll(std::string &str, const std::string &from, const std::string &to)
{
   if (from.empty()) {
      return;
   }
   for (std::string::size_type pos = str.find(from); pos != std::string::npos; pos = str.find(from, pos)) {
      str.replace(pos, from.size(), to);
      pos += to.size();
   }
}

} // namespace
223
224RooFuncWrapper::RooFuncWrapper(RooAbsReal &obj, const RooAbsData *data, RooSimultaneous const *simPdf,
225 RooArgSet const &paramSet)
226{
227 // Load the parameters and observables.
229
230 // Set up the code generation context
231 std::map<RooFit::Detail::DataKey, std::size_t> nodeOutputSizes =
232 RooFit::BatchModeDataHelpers::determineOutputSizes(obj, [&spans](RooFit::Detail::DataKey key) -> int {
233 auto found = spans.find(key);
234 return found != spans.end() ? found->second.size() : -1;
235 });
236
238
239 // First update the result variable of params in the compute graph to in[<position>].
240 int idx = 0;
241 for (RooAbsArg *param : _params) {
242 ctx.addResult(param, "params[" + std::to_string(idx) + "]");
243 idx++;
244 }
245
246 for (auto const &item : _obsInfos) {
247 const char *obsName = item.first->GetName();
248 // If the observable is scalar, set name to the start idx. else, store
249 // the start idx and later set the the name to obs[start_idx + curr_idx],
250 // here curr_idx is defined by a loop producing parent node.
251 if (item.second.size == 1) {
252 ctx.addResult(obsName, "obs[" + std::to_string(item.second.idx) + "]");
253 } else {
254 ctx.addResult(obsName, "obs");
255 ctx.addVecObs(obsName, item.second.idx);
256 }
257 }
258
259 gInterpreter->Declare("#pragma cling optimize(2)");
260
261 // Declare the function and create its derivative.
262 auto print = [](std::string const &msg) { oocoutI(nullptr, Fitting) << msg << std::endl; };
263 ROOT::Math::Util::TimingScope timingScope(print, "Function JIT time:");
264 _funcName = ctx.buildFunction(obj, nodeOutputSizes);
265 _func = reinterpret_cast<Func>(gInterpreter->ProcessLine((_funcName + ";").c_str()));
266
267 _xlArr = ctx.xlArr();
268 _collectedFunctions = ctx.collectedFunctions();
269}
270
271std::map<RooFit::Detail::DataKey, std::span<const double>>
272RooFuncWrapper::loadParamsAndData(RooArgSet const &paramSet, const RooAbsData *data, RooSimultaneous const *simPdf)
273{
274 // Extract observables
275 std::stack<std::vector<double>> vectorBuffers; // for data loading
276 std::map<RooFit::Detail::DataKey, std::span<const double>> spans;
277
278 if (data) {
279 spans = RooFit::BatchModeDataHelpers::getDataSpans(*data, "", simPdf, true, false, vectorBuffers);
280 }
281
282 std::size_t idx = 0;
283 for (auto const &item : spans) {
284 std::size_t n = item.second.size();
285 _obsInfos.emplace(item.first, ObsInfo{idx, n});
286 _observables.reserve(_observables.size() + n);
287 for (std::size_t i = 0; i < n; ++i) {
288 _observables.push_back(item.second[i]);
289 }
290 idx += n;
291 }
292
293 for (auto *param : paramSet) {
294 if (spans.find(param) == spans.end()) {
295 _params.add(*param);
296 }
297 }
298 _gradientVarBuffer.resize(_params.size());
299
300 return spans;
301}
302
/// Generates the gradient of the JIT-compiled function via Clad automatic
/// differentiation and resolves the resulting function pointer into _grad.
/// Sets _hasGradient to true on success.
/// \throws std::runtime_error if the function could not be differentiated,
///         or if ROOT was built without Clad support.
void RooFuncWrapper::createGradient()
{
#ifdef ROOFIT_CLAD
   // Clad emits the gradient overload under the name "<funcName>_grad_0".
   std::string gradName = _funcName + "_grad_0";
   std::string requestName = _funcName + "_req";

   // Calculate gradient
   gInterpreter->Declare("#include <Math/CladDerivator.h>\n");
   // disable clang-format for making the following code unreadable.
   // clang-format off
   std::stringstream requestFuncStrm;
   requestFuncStrm << "#pragma clad ON\n"
                      "void " << requestName << "() {\n"
                      "  clad::gradient(" << _funcName << ", \"params\");\n"
                      "}\n"
                      "#pragma clad OFF";
   // clang-format on
   auto print = [](std::string const &msg) { oocoutI(nullptr, Fitting) << msg << std::endl; };

   // NOTE: despite the name, cladSuccess is true when Declare() FAILED —
   // Declare() returns true on success and the result is negated here.
   bool cladSuccess = false;
   {
      ROOT::Math::Util::TimingScope timingScope(print, "Gradient generation time:");
      cladSuccess = !gInterpreter->Declare(requestFuncStrm.str().c_str());
   }
   if (cladSuccess) {
      std::stringstream errorMsg;
      errorMsg << "Function could not be differentiated. See above for details.";
      oocoutE(nullptr, InputArguments) << errorMsg.str() << std::endl;
      throw std::runtime_error(errorMsg.str().c_str());
   }

   // Clad provides different overloads for the gradient, and we need to
   // resolve to the one that we want. Without the static_cast, getting the
   // function pointer would be ambiguous.
   std::stringstream ss;
   ROOT::Math::Util::TimingScope timingScope(print, "Gradient IR to machine code time:");
   ss << "static_cast<void (*)(double *, double const *, double const *, double *)>(" << gradName << ");";
   _grad = reinterpret_cast<Grad>(gInterpreter->ProcessLine(ss.str().c_str()));
   _hasGradient = true;
#else
   // Without Clad there is nothing to differentiate with: report and bail out.
   _hasGradient = false;
   std::stringstream errorMsg;
   errorMsg << "Function could not be differentiated since ROOT was built without Clad support.";
   oocoutE(nullptr, InputArguments) << errorMsg.str() << std::endl;
   throw std::runtime_error(errorMsg.str().c_str());
#endif
}
350
351void RooFuncWrapper::updateGradientVarBuffer() const
352{
353 std::transform(_params.begin(), _params.end(), _gradientVarBuffer.begin(), [](RooAbsArg *obj) {
354 return obj->isCategory() ? static_cast<RooAbsCategory *>(obj)->getCurrentIndex()
355 : static_cast<RooAbsReal *>(obj)->getVal();
356 });
357}
358
359/// @brief Dumps a macro "filename.C" that can be used to test and debug the generated code and gradient.
360void RooFuncWrapper::writeDebugMacro(std::string const &filename) const
361{
362 std::stringstream allCode;
363 std::set<std::string> seenFunctions;
364
365 // Remove duplicated declared functions
366 for (std::string const &name : _collectedFunctions) {
367 if (seenFunctions.count(name) > 0) {
368 continue;
369 }
370 seenFunctions.insert(name);
371 std::unique_ptr<TInterpreterValue> v = gInterpreter->MakeInterpreterValue();
372 gInterpreter->Evaluate(name.c_str(), *v);
373 std::string s = v->ToString();
374 for (int i = 0; i < 2; ++i) {
375 s = s.erase(0, s.find("\n") + 1);
376 }
377 allCode << s << std::endl;
378 }
379
380 std::ofstream outFile;
381 outFile.open(filename + ".C");
382 outFile << R"(//auto-generated test macro
383#include <RooFit/Detail/MathFuncs.h>
384#include <Math/CladDerivator.h>
385
386#pragma cling optimize(2)
387)" << allCode.str()
388 << R"(
389#pragma clad ON
390void gradient_request() {
391 clad::gradient()"
392 << _funcName << R"(, "params");
393}
394#pragma clad OFF
395)";
396
398
399 auto writeVector = [&](std::string const &name, std::span<const double> vec) {
400 std::stringstream decl;
401 decl << "std::vector<double> " << name << " = {";
402 for (std::size_t i = 0; i < vec.size(); ++i) {
403 if (i % 10 == 0)
404 decl << "\n ";
405 decl << vec[i];
406 if (i < vec.size() - 1)
407 decl << ", ";
408 }
409 decl << "\n};\n";
410
411 std::string declStr = decl.str();
412
413 replaceAll(declStr, "inf", "std::numeric_limits<double>::infinity()");
414 replaceAll(declStr, "nan", "NAN");
415
416 outFile << declStr;
417 };
418
419 outFile << "// clang-format off\n" << std::endl;
420 writeVector("parametersVec", _gradientVarBuffer);
421 outFile << std::endl;
422 writeVector("observablesVec", _observables);
423 outFile << std::endl;
424 writeVector("auxConstantsVec", _xlArr);
425 outFile << std::endl;
426 outFile << "// clang-format on\n" << std::endl;
427
428 outFile << R"(
429// To run as a ROOT macro
430void )" << filename
431 << R"(()
432{
433 std::vector<double> gradientVec(parametersVec.size());
434
435 auto func = [&](std::span<double> params) {
436 return )"
437 << _funcName << R"((params.data(), observablesVec.data(), auxConstantsVec.data());
438 };
439 auto grad = [&](std::span<double> params, std::span<double> out) {
440 return )"
441 << _funcName << R"(_grad_0(parametersVec.data(), observablesVec.data(), auxConstantsVec.data(),
442 out.data());
443 };
444
445 grad(parametersVec, gradientVec);
446
447 auto numDiff = [&](int i) {
448 const double eps = 1e-6;
449 std::vector<double> p{parametersVec};
450 p[i] = parametersVec[i] - eps;
451 double funcValDown = func(p);
452 p[i] = parametersVec[i] + eps;
453 double funcValUp = func(p);
454 return (funcValUp - funcValDown) / (2 * eps);
455 };
456
457 for (std::size_t i = 0; i < parametersVec.size(); ++i) {
458 std::cout << i << ":" << std::endl;
459 std::cout << " numr : " << numDiff(i) << std::endl;
460 std::cout << " clad : " << gradientVec[i] << std::endl;
461 }
462}
463)";
464}
465
466double RooEvaluatorWrapper::evaluate() const
467{
469 return _funcWrapper->evaluate();
470
471 if (!_evaluator)
472 return 0.0;
473
474 _evaluator->setOffsetMode(hideOffset() ? RooFit::EvalContext::OffsetMode::WithoutOffset
475 : RooFit::EvalContext::OffsetMode::WithOffset);
476
477 return _evaluator->run()[0];
478}
479
480void RooEvaluatorWrapper::createFuncWrapper()
481{
482 // Get the parameters.
484 this->getParameters(_data ? _data->get() : nullptr, paramSet, /*sripDisconnectedParams=*/false);
485
487 std::make_unique<RooFuncWrapper>(*_topNode, _data, dynamic_cast<RooSimultaneous const *>(_pdf), paramSet);
488}
489
490void RooEvaluatorWrapper::generateGradient()
491{
492 if (!_funcWrapper)
494 _funcWrapper->createGradient();
495}
496
497void RooEvaluatorWrapper::setUseGeneratedFunctionCode(bool flag)
498{
502}
503
/// Writes the gradient of the wrapped function with respect to the parameters
/// into \p out. Precondition: generateGradient() must have been called
/// successfully before; _funcWrapper is dereferenced here without a null check.
void RooEvaluatorWrapper::gradient(double *out) const
{
   _funcWrapper->gradient(out);
}
508
509bool RooEvaluatorWrapper::hasGradient() const
510{
511 if (!_funcWrapper)
512 return false;
513 return _funcWrapper->hasGradient();
514}
515
516/// \endcond
size_t size(const MatrixT &matrix)
retrieve the size of a square matrix
#define oocoutE(o, a)
#define oocoutI(o, a)
#define coutE(a)
ROOT::Detail::TRangeCast< T, true > TRangeDynCast
TRangeDynCast is an adapter class that allows the typed iteration through a TCollection.
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void data
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void char Point_t Rectangle_t WindowAttributes_t Float_t Float_t Float_t Int_t Int_t UInt_t UInt_t Rectangle_t Int_t Int_t Window_t TString Int_t GCValues_t GetPrimarySelectionOwner GetDisplay GetScreen GetColormap GetNativeEvent const char const char dpyName wid window const char font_name cursor keysym reg const char only_if_exist regb h Point_t winding char text const char depth char const char Int_t count const char ColorStruct_t color const char filename
char name[80]
Definition TGX11.cxx:110
#define gInterpreter
const_iterator begin() const
const_iterator end() const
Common abstract base class for objects that represent a value and a "shape" in RooFit.
Definition RooAbsArg.h:76
Abstract base class for binned and unbinned datasets.
Definition RooAbsData.h:57
Abstract interface for all probability density functions.
Definition RooAbsPdf.h:32
Abstract base class for objects that represent a real value and implements functionality common to al...
Definition RooAbsReal.h:63
RooArgList is a container object that can hold multiple RooAbsArg objects.
Definition RooArgList.h:22
RooArgSet is a container object that can hold multiple RooAbsArg objects.
Definition RooArgSet.h:24
A class to maintain the context for squashing of RooFit models into code.
void addResult(RooAbsArg const *key, std::string const &value)
A function to save an expression that includes/depends on the result of the input node.
Facilitates simultaneous fitting of multiple PDFs to subsets of a given dataset.
const Int_t n
Definition legend1.C:16
void replaceAll(std::string &inOut, std::string_view what, std::string_view with)
The namespace RooFit contains mostly switches that change the behaviour of functions of PDFs (or othe...
Definition CodegenImpl.h:67
void evaluate(typename Architecture_t::Tensor_t &A, EActivationFunction f)
Apply the given activation function to each value in the given tensor A.
Definition Functions.h:98