Logo ROOT  
Reference Guide
 
Loading...
Searching...
No Matches
RooEvaluatorWrapper.cxx
Go to the documentation of this file.
1/// \cond ROOFIT_INTERNAL
2
3/*
4 * Project: RooFit
5 * Authors:
6 * Jonas Rembser, CERN 2023
7 *
8 * Copyright (c) 2023, CERN
9 *
10 * Redistribution and use in source and binary forms,
11 * with or without modification, are permitted according to the terms
12 * listed in LICENSE (http://roofit.sourceforge.net/license.txt)
13 */
14
15/**
16\internal
17\file RooEvaluatorWrapper.cxx
18\class RooEvaluatorWrapper
19\ingroup Roofitcore
20
21Wraps a RooFit::Evaluator that evaluates a RooAbsReal back into a RooAbsReal.
22**/
23
24#include "RooEvaluatorWrapper.h"
25
26#include <RooAbsData.h>
27#include <RooAbsPdf.h>
28#include <RooMsgService.h>
29#include <RooRealVar.h>
30#include <RooSimultaneous.h>
31
33
34#include <TInterpreter.h>
35
36#include <fstream>
37
38namespace RooFit::Experimental {
39
40RooEvaluatorWrapper::RooEvaluatorWrapper(RooAbsReal &topNode, RooAbsData *data, bool useGPU,
41 std::string const &rangeName, RooAbsPdf const *pdf,
43 : RooAbsReal{"RooEvaluatorWrapper", "RooEvaluatorWrapper"},
44 _evaluator{std::make_unique<RooFit::Evaluator>(topNode, useGPU)},
45 _topNode("topNode", "top node", this, topNode, false, false),
46 _data{data},
47 _paramSet("paramSet", "Set of parameters", this),
48 _rangeName{rangeName},
49 _pdf{pdf},
50 _takeGlobalObservablesFromData{takeGlobalObservablesFromData}
51{
52 if (data) {
53 setData(*data, false);
54 }
55 _paramSet.add(_evaluator->getParameters());
56 for (auto const &item : _dataSpans) {
57 _paramSet.remove(*_paramSet.find(item.first->GetName()));
58 }
59}
60
// Copy constructor. Copies the evaluator handle, data pointer, range name, pdf
// pointer and the global-observables flag from `other`, builds a new "topNode"
// proxy from the proxy of `other`, and fills a fresh "paramSet" collection with
// the contents of `other._paramSet`.
// NOTE(review): the base-class initializer (where `name` would be consumed) and
// the initializer following `_takeGlobalObservablesFromData` are not visible in
// this listing -- confirm against the repository version.
61RooEvaluatorWrapper::RooEvaluatorWrapper(const RooEvaluatorWrapper &other, const char *name)
63 _evaluator{other._evaluator},
64 _topNode("topNode", this, other._topNode),
65 _data{other._data},
66 _paramSet("paramSet", "Set of parameters", this),
67 _rangeName{other._rangeName},
68 _pdf{other._pdf},
69 _takeGlobalObservablesFromData{other._takeGlobalObservablesFromData},
71{
// The collection is constructed empty above and populated here.
72 _paramSet.add(other._paramSet);
73}
74
// Destructor, defaulted out of line in this translation unit.
// NOTE(review): presumably placed here because RooFuncWrapper is only defined in
// this file -- confirm against the header.
75RooEvaluatorWrapper::~RooEvaluatorWrapper() = default;
76
77bool RooEvaluatorWrapper::getParameters(const RooArgSet *observables, RooArgSet &outputSet,
78 bool stripDisconnected) const
79{
80 outputSet.add(_evaluator->getParameters());
81 if (observables) {
82 outputSet.remove(*observables, /*silent*/ false, /*matchByNameOnly*/ true);
83 }
84 // Exclude the data variables from the parameters which are not global observables
85 for (auto const &item : _dataSpans) {
86 if (_data->getGlobalObservables() && _data->getGlobalObservables()->find(item.first->GetName())) {
87 continue;
88 }
89 RooAbsArg *found = outputSet.find(item.first->GetName());
90 if (found) {
91 outputSet.remove(*found);
92 }
93 }
94 // If we take the global observables as data, we have to return these as
95 // parameters instead of the parameters in the model. Otherwise, the
96 // constant parameters in the fit result that are global observables will
97 // not have the right values.
98 if (_takeGlobalObservablesFromData && _data->getGlobalObservables()) {
99 outputSet.replace(*_data->getGlobalObservables());
100 }
101
102 // The disconnected parameters are stripped away in
103 // RooAbsArg::getParametersHook(), that is only called in the original
104 // RooAbsArg::getParameters() implementation. So he have to call it to
105 // identify disconnected parameters to remove.
106 if (stripDisconnected) {
108 _topNode->getParameters(observables, paramsStripped, true);
110 for (RooAbsArg *param : outputSet) {
111 if (!paramsStripped.find(param->GetName())) {
112 toRemove.add(*param);
113 }
114 }
115 outputSet.remove(toRemove, /*silent*/ false, /*matchByNameOnly*/ true);
116 }
117
118 return false;
119}
120
121bool RooEvaluatorWrapper::setData(RooAbsData &data, bool /*cloneData*/)
122{
123 // To make things easiear for RooFit, we only support resetting with
124 // datasets that have the same structure, e.g. the same columns and global
125 // observables. This is anyway the usecase: resetting same-structured data
126 // when iterating over toys.
127 constexpr auto errMsg = "Error in RooAbsReal::setData(): only resetting with same-structured data is supported.";
128
129 _data = &data;
130 bool isInitializing = _paramSet.empty();
131 const std::size_t oldSize = _dataSpans.size();
132
133 std::stack<std::vector<double>>{}.swap(_vectorBuffers);
134 bool skipZeroWeights = !_pdf || !_pdf->getAttribute("BinnedLikelihoodActive");
135 _dataSpans =
136 RooFit::BatchModeDataHelpers::getDataSpans(*_data, _rangeName, dynamic_cast<RooSimultaneous const *>(_pdf),
137 skipZeroWeights, _takeGlobalObservablesFromData, _vectorBuffers);
138 if (!isInitializing && _dataSpans.size() != oldSize) {
139 coutE(DataHandling) << errMsg << std::endl;
140 throw std::runtime_error(errMsg);
141 }
142 for (auto const &item : _dataSpans) {
143 const char *name = item.first->GetName();
144 _evaluator->setInput(name, item.second, false);
145 if (_paramSet.find(name)) {
146 coutE(DataHandling) << errMsg << std::endl;
147 throw std::runtime_error(errMsg);
148 }
149 }
150 return true;
151}
152
153/// @brief A wrapper class to store a C++ function of type 'double (*)(double *, double const *, double const *)'
/// (see the Func alias below) together with an optional gradient function (Grad).
154/// The parameters can be accessed as params[<relative position of param in paramSet>] in the function body.
155/// The observables can be accessed as obs[i + j], where i represents the observable position and j
156/// represents the data entry.
/// NOTE(review): several declarations used elsewhere in this file (the constructor,
/// loadParamsAndData()) are not visible in this listing of the class body.
157class RooFuncWrapper {
158public:
160
// True once createGradient() has succeeded.
161 bool hasGradient() const { return _hasGradient; }
// Zero-fills out[0 .. _params.size()) and then calls the gradient function with
// the parameter buffer, the observables and the auxiliary constants.
// NOTE(review): a statement between the opening brace and std::fill appears to be
// missing from this listing.
162 void gradient(double *out) const
163 {
165 std::fill(out, out + _params.size(), 0.0);
166
167 _grad(_gradientVarBuffer.data(), _observables.data(), _xlArr.data(), out);
168 }
169
170 void createGradient();
171
172 void writeDebugMacro(std::string const &) const;
173
174 std::vector<std::string> const &collectedFunctions() { return _collectedFunctions; }
175
// Calls the stored function with the parameter buffer, the observables and the
// auxiliary constants.
// NOTE(review): a statement before the return appears to be missing from this listing.
176 double evaluate() const
177 {
179 return _func(_gradientVarBuffer.data(), _observables.data(), _xlArr.data());
180 }
181
182private:
// Copies the current values of _params into _gradientVarBuffer.
183 void updateGradientVarBuffer() const;
184
185 std::map<RooFit::Detail::DataKey, std::span<const double>>
187
189
// Signatures of the stored function and of its gradient; the gradient takes an
// additional output pointer.
190 using Func = double (*)(double *, double const *, double const *);
191 using Grad = void (*)(double *, double const *, double const *, double *);
192
// Location of one observable inside _observables: start offset and number of entries.
193 struct ObsInfo {
194 ObsInfo(std::size_t i, std::size_t n) : idx{i}, size{n} {}
195 std::size_t idx = 0;
196 std::size_t size = 0;
197 };
198
// Parameters of the function, i.e. the entries of the input set without a data span.
199 RooArgList _params;
// Name of the generated function, as returned by the code generation context.
200 std::string _funcName;
201 Func _func;
202 Grad _grad;
203 bool _hasGradient = false;
// Current parameter values; mutable because it is refreshed from const member functions.
204 mutable std::vector<double> _gradientVarBuffer;
// Values of all data spans, stored back to back.
205 std::vector<double> _observables;
206 std::map<RooFit::Detail::DataKey, ObsInfo> _obsInfos;
// Auxiliary constants taken from the code generation context (xlArr()).
207 std::vector<double> _xlArr;
208 std::vector<std::string> _collectedFunctions;
209};
210
211namespace {
212
/// Replace every occurrence of `from` in `str` by `to`, in place.
/// The search resumes right after each inserted replacement, so text that was
/// just inserted is never matched again (e.g. replacing 'x' with 'yx').
/// Nothing happens if `from` is empty.
void replaceAll(std::string &str, const std::string &from, const std::string &to)
{
   if (from.empty()) {
      return;
   }
   for (std::size_t pos = str.find(from); pos != std::string::npos; pos = str.find(from, pos + to.size())) {
      str.replace(pos, from.size(), to);
   }
}
223
224} // namespace
225
// Generates C++ code for `obj`, declares it to the interpreter and stores a
// pointer to the resulting function in _func. If the generated code cannot be
// declared, the code is dumped to a "_codegen_<function name>.cxx" file, an
// error is logged and std::runtime_error is thrown.
// NOTE(review): the statements that define `spans` and `ctx` are not visible in
// this listing -- confirm against the repository version.
226RooFuncWrapper::RooFuncWrapper(RooAbsReal &obj, const RooAbsData *data, RooSimultaneous const *simPdf,
227 RooArgSet const &paramSet)
228{
229 // Load the parameters and observables.
231
232 // Set up the code generation context
// The callback reports the number of entries of a data span, or -1 for keys
// that have no span.
233 std::map<RooFit::Detail::DataKey, std::size_t> nodeOutputSizes =
234 RooFit::BatchModeDataHelpers::determineOutputSizes(obj, [&spans](RooFit::Detail::DataKey key) -> int {
235 auto found = spans.find(key);
236 return found != spans.end() ? found->second.size() : -1;
237 });
238
240
241 // First update the result variable of params in the compute graph to params[<position>].
242 int idx = 0;
243 for (RooAbsArg *param : _params) {
244 ctx.addResult(param, "params[" + std::to_string(idx) + "]");
245 idx++;
246 }
247
248 for (auto const &item : _obsInfos) {
249 const char *obsName = item.first->GetName();
250 // If the observable is scalar, set name to the start idx. else, store
251 // the start idx and later set the name to obs[start_idx + curr_idx],
252 // here curr_idx is defined by a loop producing parent node.
253 if (item.second.size == 1) {
254 ctx.addResult(obsName, "obs[" + std::to_string(item.second.idx) + "]");
255 } else {
256 ctx.addResult(obsName, "obs");
257 ctx.addVecObs(obsName, item.second.idx);
258 }
259 }
260
261 // Build the function code. The timing scope lives until the end of the constructor.
262 auto print = [](std::string const &msg) { oocoutI(nullptr, Fitting) << msg << std::endl; };
263 ROOT::Math::Util::TimingScope timingScope(print, "Function JIT time:");
264 _funcName = ctx.buildFunction(obj, nodeOutputSizes);
265
266 // Make sure the codegen implementations are known to the interpreter
267 gInterpreter->Declare("#include <RooFit/CodegenImpl.h>\n");
268
269 if (!gInterpreter->Declare(ctx.collectedCode().c_str())) {
270 std::stringstream errorMsg;
271 std::string debugFileName = "_codegen_" + _funcName + ".cxx";
// NOTE(review): there is no space between the file name and "for debugging" in
// the message assembled below.
272 errorMsg << "Function " << _funcName << " could not be compiled. See above for details. Full code dumped to file "
273 << debugFileName << "for debugging";
// The inner scope closes the dump file before the error is reported.
274 {
275 std::ofstream outFile;
276 outFile.open(debugFileName.c_str());
277 outFile << ctx.collectedCode();
278 }
279 oocoutE(nullptr, InputArguments) << errorMsg.str() << std::endl;
280 throw std::runtime_error(errorMsg.str().c_str());
281 }
282
// Evaluating "<function name>;" in the interpreter yields the function's address.
283 _func = reinterpret_cast<Func>(gInterpreter->ProcessLine((_funcName + ";").c_str()));
284
285 _xlArr = ctx.xlArr();
286 _collectedFunctions = ctx.collectedFunctions();
287}
288
289std::map<RooFit::Detail::DataKey, std::span<const double>>
290RooFuncWrapper::loadParamsAndData(RooArgSet const &paramSet, const RooAbsData *data, RooSimultaneous const *simPdf)
291{
292 // Extract observables
293 std::stack<std::vector<double>> vectorBuffers; // for data loading
294 std::map<RooFit::Detail::DataKey, std::span<const double>> spans;
295
296 if (data) {
297 spans = RooFit::BatchModeDataHelpers::getDataSpans(*data, "", simPdf, true, false, vectorBuffers);
298 }
299
300 std::size_t idx = 0;
301 for (auto const &item : spans) {
302 std::size_t n = item.second.size();
303 _obsInfos.emplace(item.first, ObsInfo{idx, n});
304 _observables.reserve(_observables.size() + n);
305 for (std::size_t i = 0; i < n; ++i) {
306 _observables.push_back(item.second[i]);
307 }
308 idx += n;
309 }
310
311 for (auto *param : paramSet) {
312 if (spans.find(param) == spans.end()) {
313 _params.add(*param);
314 }
315 }
316 _gradientVarBuffer.resize(_params.size());
317
318 return spans;
319}
320
321void RooFuncWrapper::createGradient()
322{
323#ifdef ROOFIT_CLAD
324 std::string gradName = _funcName + "_grad_0";
325 std::string requestName = _funcName + "_req";
326
327 // Calculate gradient
328 gInterpreter->Declare("#include <Math/CladDerivator.h>\n");
329 // disable clang-format for making the following code unreadable.
330 // clang-format off
331 std::stringstream requestFuncStrm;
332 requestFuncStrm << "#pragma clad ON\n"
333 "void " << requestName << "() {\n"
334 " clad::gradient(" << _funcName << ", \"params\");\n"
335 "}\n"
336 "#pragma clad OFF";
337 // clang-format on
338 auto print = [](std::string const &msg) { oocoutI(nullptr, Fitting) << msg << std::endl; };
339
340 bool cladSuccess = false;
341 {
342 ROOT::Math::Util::TimingScope timingScope(print, "Gradient generation time:");
343 cladSuccess = !gInterpreter->Declare(requestFuncStrm.str().c_str());
344 }
345 if (cladSuccess) {
346 std::stringstream errorMsg;
347 errorMsg << "Function could not be differentiated. See above for details.";
348 oocoutE(nullptr, InputArguments) << errorMsg.str() << std::endl;
349 throw std::runtime_error(errorMsg.str().c_str());
350 }
351
352 // Clad provides different overloads for the gradient, and we need to
353 // resolve to the one that we want. Without the static_cast, getting the
354 // function pointer would be ambiguous.
355 std::stringstream ss;
356 ROOT::Math::Util::TimingScope timingScope(print, "Gradient IR to machine code time:");
357 ss << "static_cast<void (*)(double *, double const *, double const *, double *)>(" << gradName << ");";
358 _grad = reinterpret_cast<Grad>(gInterpreter->ProcessLine(ss.str().c_str()));
359 _hasGradient = true;
360#else
361 _hasGradient = false;
362 std::stringstream errorMsg;
363 errorMsg << "Function could not be differentiated since ROOT was built without Clad support.";
364 oocoutE(nullptr, InputArguments) << errorMsg.str() << std::endl;
365 throw std::runtime_error(errorMsg.str().c_str());
366#endif
367}
368
369void RooFuncWrapper::updateGradientVarBuffer() const
370{
371 std::transform(_params.begin(), _params.end(), _gradientVarBuffer.begin(), [](RooAbsArg *obj) {
372 return obj->isCategory() ? static_cast<RooAbsCategory *>(obj)->getCurrentIndex()
373 : static_cast<RooAbsReal *>(obj)->getVal();
374 });
375}
376
377/// @brief Dumps a macro "filename.C" that can be used to test and debug the generated code and gradient.
/// The macro contains the source of the collected functions, a Clad gradient
/// request for the generated function, the parameter / observable / auxiliary
/// constant values as std::vector<double> definitions, and an entry function
/// named after `filename` that prints a numeric and the Clad derivative for
/// every parameter.
/// @param filename Name of the macro without the ".C" extension; also used as the
///                 name of the entry function.
378void RooFuncWrapper::writeDebugMacro(std::string const &filename) const
379{
380 std::stringstream allCode;
381 std::set<std::string> seenFunctions;
382
383 // Remove duplicated declared functions
384 for (std::string const &name : _collectedFunctions) {
385 if (seenFunctions.count(name) > 0) {
386 continue;
387 }
388 seenFunctions.insert(name);
// Ask the interpreter for the string representation of the function and drop
// its first two lines (everything up to and including the second newline).
389 std::unique_ptr<TInterpreterValue> v = gInterpreter->MakeInterpreterValue();
390 gInterpreter->Evaluate(name.c_str(), *v);
391 std::string s = v->ToString();
392 for (int i = 0; i < 2; ++i) {
393 s = s.erase(0, s.find("\n") + 1);
394 }
395 allCode << s << std::endl;
396 }
397
// Preamble of the macro: includes, the collected function code and the Clad
// gradient request.
398 std::ofstream outFile;
399 outFile.open(filename + ".C");
400 outFile << R"(//auto-generated test macro
401#include <RooFit/Detail/MathFuncs.h>
402#include <Math/CladDerivator.h>
403
404)" << allCode.str()
405 << R"(
406#pragma clad ON
407void gradient_request() {
408 clad::gradient()"
409 << _funcName << R"(, "params");
410}
411#pragma clad OFF
412)";
413
// NOTE(review): a statement appears to be missing from this listing at this point.
415
// Writes `vec` as the definition of a std::vector<double> called `name`, ten
// values per line. The textual forms "inf" and "nan" are replaced by C++
// expressions so that the output compiles.
416 auto writeVector = [&](std::string const &name, std::span<const double> vec) {
417 std::stringstream decl;
418 decl << "std::vector<double> " << name << " = {";
419 for (std::size_t i = 0; i < vec.size(); ++i) {
420 if (i % 10 == 0)
421 decl << "\n ";
422 decl << vec[i];
423 if (i < vec.size() - 1)
424 decl << ", ";
425 }
426 decl << "\n};\n";
427
428 std::string declStr = decl.str();
429
430 replaceAll(declStr, "inf", "std::numeric_limits<double>::infinity()");
431 replaceAll(declStr, "nan", "NAN");
432
433 outFile << declStr;
434 };
435
436 outFile << "// clang-format off\n" << std::endl;
437 writeVector("parametersVec", _gradientVarBuffer);
438 outFile << std::endl;
439 writeVector("observablesVec", _observables);
440 outFile << std::endl;
441 writeVector("auxConstantsVec", _xlArr);
442 outFile << std::endl;
443 outFile << "// clang-format on\n" << std::endl;
444
// Entry function of the macro. It compares a central finite difference with
// the Clad gradient for each parameter.
// NOTE(review): the emitted `grad` lambda passes parametersVec.data() rather
// than its own `params` argument to the gradient function.
445 outFile << R"(
446// To run as a ROOT macro
447void )" << filename
448 << R"(()
449{
450 std::vector<double> gradientVec(parametersVec.size());
451
452 auto func = [&](std::span<double> params) {
453 return )"
454 << _funcName << R"((params.data(), observablesVec.data(), auxConstantsVec.data());
455 };
456 auto grad = [&](std::span<double> params, std::span<double> out) {
457 return )"
458 << _funcName << R"(_grad_0(parametersVec.data(), observablesVec.data(), auxConstantsVec.data(),
459 out.data());
460 };
461
462 grad(parametersVec, gradientVec);
463
464 auto numDiff = [&](int i) {
465 const double eps = 1e-6;
466 std::vector<double> p{parametersVec};
467 p[i] = parametersVec[i] - eps;
468 double funcValDown = func(p);
469 p[i] = parametersVec[i] + eps;
470 double funcValUp = func(p);
471 return (funcValUp - funcValDown) / (2 * eps);
472 };
473
474 for (std::size_t i = 0; i < parametersVec.size(); ++i) {
475 std::cout << i << ":" << std::endl;
476 std::cout << " numr : " << numDiff(i) << std::endl;
477 std::cout << " clad : " << gradientVec[i] << std::endl;
478 }
479}
480)";
481}
482
// Returns the value of the wrapped computation, either from the generated
// function code (_funcWrapper) or from the evaluator.
// NOTE(review): as listed here the first return is unconditional, which makes
// the evaluator path below unreachable; the condition guarding it appears to be
// missing from this listing -- confirm against the repository version.
483double RooEvaluatorWrapper::evaluate() const
484{
486 return _funcWrapper->evaluate();
487
// Without an evaluator there is nothing to run.
488 if (!_evaluator)
489 return 0.0;
490
// Forward the offset mode according to hideOffset() before running.
491 _evaluator->setOffsetMode(hideOffset() ? RooFit::EvalContext::OffsetMode::WithoutOffset
492 : RooFit::EvalContext::OffsetMode::WithOffset);
493
// The result is the first element of the evaluator output.
494 return _evaluator->run()[0];
495}
496
497void RooEvaluatorWrapper::createFuncWrapper()
498{
499 // Get the parameters.
501 this->getParameters(_data ? _data->get() : nullptr, paramSet, /*sripDisconnectedParams=*/false);
502
504 std::make_unique<RooFuncWrapper>(*_topNode, _data, dynamic_cast<RooSimultaneous const *>(_pdf), paramSet);
505}
506
507void RooEvaluatorWrapper::generateGradient()
508{
509 if (!_funcWrapper)
511 _funcWrapper->createGradient();
512}
513
// Switches the use of the generated function code on or off.
// NOTE(review): the statements of the body are not visible in this listing, so
// what is done with `flag` cannot be stated from here -- confirm against the
// repository version.
514void RooEvaluatorWrapper::setUseGeneratedFunctionCode(bool flag)
515{
519}
520
// Writes the gradient into `out` by forwarding to the function wrapper.
// NOTE(review): unlike hasGradient() and writeDebugMacro(), `_funcWrapper` is
// not checked for null here; callers presumably check hasGradient() first.
521void RooEvaluatorWrapper::gradient(double *out) const
522{
523 _funcWrapper->gradient(out);
524}
525
526bool RooEvaluatorWrapper::hasGradient() const
527{
528 if (!_funcWrapper)
529 return false;
530 return _funcWrapper->hasGradient();
531}
532
533void RooEvaluatorWrapper::writeDebugMacro(std::string const &filename) const
534{
535 if (_funcWrapper)
536 return _funcWrapper->writeDebugMacro(filename);
537}
538
539} // namespace RooFit::Experimental
540
541/// \endcond
size_t size(const MatrixT &matrix)
retrieve the size of a square matrix
#define oocoutE(o, a)
#define oocoutI(o, a)
#define coutE(a)
ROOT::Detail::TRangeCast< T, true > TRangeDynCast
TRangeDynCast is an adapter class that allows the typed iteration through a TCollection.
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void data
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void char Point_t Rectangle_t WindowAttributes_t Float_t Float_t Float_t Int_t Int_t UInt_t UInt_t Rectangle_t Int_t Int_t Window_t TString Int_t GCValues_t GetPrimarySelectionOwner GetDisplay GetScreen GetColormap GetNativeEvent const char const char dpyName wid window const char font_name cursor keysym reg const char only_if_exist regb h Point_t winding char text const char depth char const char Int_t count const char ColorStruct_t color const char filename
char name[80]
Definition TGX11.cxx:110
#define gInterpreter
const_iterator begin() const
const_iterator end() const
Common abstract base class for objects that represent a value and a "shape" in RooFit.
Definition RooAbsArg.h:76
Abstract base class for binned and unbinned datasets.
Definition RooAbsData.h:57
Abstract interface for all probability density functions.
Definition RooAbsPdf.h:32
Abstract base class for objects that represent a real value and implements functionality common to al...
Definition RooAbsReal.h:63
RooArgList is a container object that can hold multiple RooAbsArg objects.
Definition RooArgList.h:22
RooArgSet is a container object that can hold multiple RooAbsArg objects.
Definition RooArgSet.h:24
A class to maintain the context for squashing of RooFit models into code.
void addResult(RooAbsArg const *key, std::string const &value)
A function to save an expression that includes/depends on the result of the input node.
Facilitates simultaneous fitting of multiple PDFs to subsets of a given dataset.
const Int_t n
Definition legend1.C:16
void replaceAll(std::string &inOut, std::string_view what, std::string_view with)
The namespace RooFit contains mostly switches that change the behaviour of functions of PDFs (or othe...
Definition CodegenImpl.h:67
void evaluate(typename Architecture_t::Tensor_t &A, EActivationFunction f)
Apply the given activation function to each value in the given tensor A.
Definition Functions.h:98