Logo ROOT  
Reference Guide
 
Loading...
Searching...
No Matches
RooEvaluatorWrapper.cxx
Go to the documentation of this file.
1/// \cond ROOFIT_INTERNAL
2
3/*
4 * Project: RooFit
5 * Authors:
6 * Jonas Rembser, CERN 2023
7 *
8 * Copyright (c) 2023, CERN
9 *
10 * Redistribution and use in source and binary forms,
11 * with or without modification, are permitted according to the terms
12 * listed in LICENSE (http://roofit.sourceforge.net/license.txt)
13 */
14
15/**
16\internal
17\file RooEvaluatorWrapper.cxx
18\class RooEvaluatorWrapper
19\ingroup Roofitcore
20
21Wraps a RooFit::Evaluator that evaluates a RooAbsReal back into a RooAbsReal.
22**/
23
24#include "RooEvaluatorWrapper.h"
25
26#include <RooAbsData.h>
27#include <RooAbsPdf.h>
28#include <RooMsgService.h>
29#include <RooRealVar.h>
30#include <RooSimultaneous.h>
31
33#include "RooFitImplHelpers.h"
34
35#include <TInterpreter.h>
36
37#include <fstream>
38
39namespace RooFit::Experimental {
40
41RooEvaluatorWrapper::RooEvaluatorWrapper(RooAbsReal &topNode, RooAbsData *data, bool useGPU,
42 std::string const &rangeName, RooAbsPdf const *pdf,
44 : RooAbsReal{"RooEvaluatorWrapper", "RooEvaluatorWrapper"},
45 _evaluator{std::make_unique<RooFit::Evaluator>(topNode, useGPU)},
46 _topNode("topNode", "top node", this, topNode, false, false),
47 _data{data},
48 _paramSet("paramSet", "Set of parameters", this),
49 _rangeName{rangeName},
50 _pdf{pdf},
51 _takeGlobalObservablesFromData{takeGlobalObservablesFromData}
52{
53 if (data) {
54 setData(*data, false);
55 }
56 _paramSet.add(_evaluator->getParameters());
57 for (auto const &item : _dataSpans) {
58 _paramSet.remove(*_paramSet.find(item.first->GetName()));
59 }
60}
61
62RooEvaluatorWrapper::RooEvaluatorWrapper(const RooEvaluatorWrapper &other, const char *name)
64 _evaluator{other._evaluator},
65 _topNode("topNode", this, other._topNode),
66 _data{other._data},
67 _paramSet("paramSet", "Set of parameters", this),
68 _rangeName{other._rangeName},
69 _pdf{other._pdf},
70 _takeGlobalObservablesFromData{other._takeGlobalObservablesFromData},
72{
73 _paramSet.add(other._paramSet);
74}
75
// Out-of-line defaulted destructor. NOTE(review): presumably defined here so
// that unique_ptr members can be destroyed where their pointee types are
// complete — confirm against the header.
RooEvaluatorWrapper::~RooEvaluatorWrapper() = default;
77
78bool RooEvaluatorWrapper::getParameters(const RooArgSet *observables, RooArgSet &outputSet,
79 bool stripDisconnected) const
80{
81 outputSet.add(_evaluator->getParameters());
82 if (observables) {
83 outputSet.remove(*observables, /*silent*/ false, /*matchByNameOnly*/ true);
84 }
85 // Exclude the data variables from the parameters which are not global observables
86 for (auto const &item : _dataSpans) {
87 if (_data->getGlobalObservables() && _data->getGlobalObservables()->find(item.first->GetName())) {
88 continue;
89 }
90 RooAbsArg *found = outputSet.find(item.first->GetName());
91 if (found) {
92 outputSet.remove(*found);
93 }
94 }
95 // If we take the global observables as data, we have to return these as
96 // parameters instead of the parameters in the model. Otherwise, the
97 // constant parameters in the fit result that are global observables will
98 // not have the right values.
99 if (_takeGlobalObservablesFromData && _data->getGlobalObservables()) {
100 outputSet.replace(*_data->getGlobalObservables());
101 }
102
103 // The disconnected parameters are stripped away in
104 // RooAbsArg::getParametersHook(), that is only called in the original
105 // RooAbsArg::getParameters() implementation. So he have to call it to
106 // identify disconnected parameters to remove.
107 if (stripDisconnected) {
109 _topNode->getParameters(observables, paramsStripped, true);
111 for (RooAbsArg *param : outputSet) {
112 if (!paramsStripped.find(param->GetName())) {
113 toRemove.add(*param);
114 }
115 }
116 outputSet.remove(toRemove, /*silent*/ false, /*matchByNameOnly*/ true);
117 }
118
119 return false;
120}
121
122/// @brief A wrapper class to store a C++ function of type 'double (*)(double*, double*)'.
123/// The parameters can be accessed as params[<relative position of param in paramSet>] in the function body.
124/// The observables can be accessed as obs[i + j], where i represents the observable position and j
125/// represents the data entry.
126class RooFuncWrapper {
127public:
129 std::string const &rangeName, bool skipZeroWeights);
130
131 bool hasGradient() const { return _hasGradient; }
132 bool hasHessian() const { return _hasHessian; }
133 void gradient(double *out) const
134 {
136 std::fill(out, out + _params.size(), 0.0);
137 _grad(_varBuffer.data(), _observables.data(), _xlArr.data(), out);
138 }
139 void hessian(double *out) const
140 {
142 std::fill(out, out + _params.size() * _params.size(), 0.0);
143 _hessian(_varBuffer.data(), _observables.data(), _xlArr.data(), out);
144 }
145
146 void createGradient();
147 void createHessian();
148
149 void writeDebugMacro(std::string const &) const;
150
151 std::vector<std::string> const &collectedFunctions() { return _collectedFunctions; }
152
153 double evaluate() const
154 {
156 return _func(_varBuffer.data(), _observables.data(), _xlArr.data());
157 }
158
159 void
160 loadData(RooAbsData const &data, RooSimultaneous const *simPdf, std::string const &rangeName, bool skipZeroWeights);
161
162private:
163 void updateGradientVarBuffer() const;
164
166
167 using Func = double (*)(double *, double const *, double const *);
168 using Grad = void (*)(double *, double const *, double const *, double *);
169 using Hessian = void (*)(double *, double const *, double const *, double *);
170
171 RooArgList _params;
172 std::string _funcName;
173 Func _func;
174 Grad _grad;
175 Hessian _hessian;
176 bool _hasGradient = false;
177 bool _hasHessian = false;
178 mutable std::vector<double> _varBuffer;
179 std::vector<double> _observables;
180 std::unordered_map<RooFit::Detail::DataKey, std::size_t> _obsInfos;
181 std::vector<double> _xlArr;
182 std::vector<std::string> _collectedFunctions;
183};
184
185namespace {
186
/// Replace every occurrence of `from` in `str` with `to`, in place.
/// Scanning resumes after each inserted replacement, so a `to` that contains
/// `from` (e.g. replacing "x" with "yx") cannot loop forever.
void replaceAll(std::string &str, const std::string &from, const std::string &to)
{
   if (from.empty()) {
      return;
   }
   for (std::size_t pos = str.find(from); pos != std::string::npos; pos = str.find(from, pos + to.size())) {
      str.replace(pos, from.size(), to);
   }
}
197
199{
202
203 std::unordered_set<RooFit::Detail::DataKey> dependsOnData;
204 for (RooAbsArg *arg : dataObs) {
205 dependsOnData.insert(arg);
206 }
207
208 for (RooAbsArg *arg : serverSet) {
209 if (arg->getAttribute("__obs__")) {
210 dependsOnData.insert(arg);
211 }
212 for (RooAbsArg *server : arg->servers()) {
213 if (server->isValueServer(*arg)) {
214 if (dependsOnData.find(server) != dependsOnData.end() && !arg->isReducerNode()) {
215 dependsOnData.insert(arg);
216 break;
217 }
218 }
219 }
220 }
221
222 return dependsOnData;
223}
224
225} // namespace
226
227RooFuncWrapper::RooFuncWrapper(RooAbsReal &obj, const RooAbsData *data, RooSimultaneous const *simPdf,
228 RooArgSet const &paramSet, std::string const &rangeName, bool skipZeroWeights)
229{
230 // Load the observables from the dataset
231 if (data) {
233 }
234
235 // Define the parameters
236 for (auto *param : paramSet) {
237 if (_obsInfos.find(param) == _obsInfos.end()) {
238 _params.add(*param);
239 }
240 }
241 _varBuffer.resize(_params.size());
242
243 // Figure out which part of the computation graph depends on data
244 std::unordered_set<RooFit::Detail::DataKey> dependsOnData;
245 if (data) {
246 dependsOnData = getDependsOnData(obj, *data->get());
247 }
248
249 // Set up the code generation context
251
252 // First update the result variable of params in the compute graph to in[<position>].
253 int idx = 0;
254 for (RooAbsArg *param : _params) {
255 ctx.addResult(param, "params[" + std::to_string(idx) + "]");
256 idx++;
257 }
258
259 for (auto const &item : _obsInfos) {
260 const char *obsName = item.first->GetName();
261 ctx.addResult(obsName, "obs");
262 ctx.addVecObs(obsName, item.second);
263 }
264
265 // Declare the function and create its derivative.
266 auto print = [](std::string const &msg) { oocoutI(nullptr, Fitting) << msg << std::endl; };
267 ROOT::Math::Util::TimingScope timingScope(print, "Function JIT time:");
268 _funcName = ctx.buildFunction(obj, dependsOnData);
269
270 // Make sure the codegen implementations are known to the interpreter
271 gInterpreter->Declare("#include <RooFit/CodegenImpl.h>\n");
272
273 if (!gInterpreter->Declare(ctx.collectedCode().c_str())) {
274 std::stringstream errorMsg;
275 std::string debugFileName = "_codegen_" + _funcName + ".cxx";
276 errorMsg << "Function " << _funcName << " could not be compiled. See above for details. Full code dumped to file "
277 << debugFileName << " for debugging";
278 {
279 std::ofstream outFile;
280 outFile.open(debugFileName.c_str());
281 outFile << ctx.collectedCode();
282 }
283 oocoutE(nullptr, InputArguments) << errorMsg.str() << std::endl;
284 throw std::runtime_error(errorMsg.str().c_str());
285 }
286
287 _func = reinterpret_cast<Func>(gInterpreter->ProcessLine((_funcName + ";").c_str()));
288
289 _xlArr = ctx.xlArr();
290 _collectedFunctions = ctx.collectedFunctions();
291}
292
293void RooFuncWrapper::loadData(RooAbsData const &data, RooSimultaneous const *simPdf, std::string const &rangeName,
294 bool skipZeroWeights)
295{
296 // Extract observables
297 std::stack<std::vector<double>> vectorBuffers; // for data loading
298 auto spans =
299 RooFit::BatchModeDataHelpers::getDataSpans(data, rangeName, simPdf, skipZeroWeights, false, vectorBuffers);
300
301 _observables.clear();
302 // The first elements contain the sizes of the packed observable arrays
303 std::size_t total = 0;
304 _observables.reserve(2 * spans.size());
305 std::size_t idx = 0;
306 for (auto const &item : spans) {
307 _obsInfos.emplace(item.first, idx);
308 _observables.push_back(total + 2 * spans.size());
309 _observables.push_back(item.second.size());
310 total += item.second.size();
311 idx += 1;
312 }
313 idx = 0;
314 for (auto const &item : spans) {
315 std::size_t n = item.second.size();
316 _observables.reserve(_observables.size() + n);
317 for (std::size_t i = 0; i < n; ++i) {
318 _observables.push_back(item.second[i]);
319 }
320 idx += n;
321 }
322}
323
/// Generate the gradient of the JIT-compiled function with Clad and store a
/// pointer to the compiled gradient in _grad.
/// \throws std::runtime_error if differentiation fails or if ROOT was built
///         without Clad support.
void RooFuncWrapper::createGradient()
{
#ifdef ROOFIT_CLAD
   // Names Clad will produce: the gradient overload and the helper function
   // that carries the differentiation request.
   std::string gradName = _funcName + "_grad_0";
   std::string requestName = _funcName + "_req";

   // Calculate gradient
   gInterpreter->Declare("#include <Math/CladDerivator.h>\n");
   // disable clang-format for making the following code unreadable.
   // clang-format off
   std::stringstream requestFuncStrm;
   requestFuncStrm << "#pragma clad ON\n"
                      "void " << requestName << "() {\n"
                      "   clad::gradient(" << _funcName << ", \"params\");\n"
                      "}\n"
                      "#pragma clad OFF";
   // clang-format on
   auto print = [](std::string const &msg) { oocoutI(nullptr, Fitting) << msg << std::endl; };

   // NOTE: despite its name, cladSuccess is true when Declare() FAILED
   // (the result is negated) — it flags the error condition.
   bool cladSuccess = false;
   {
      // RAII scope so the timing is reported right after the Declare call.
      ROOT::Math::Util::TimingScope timingScope(print, "Gradient generation time:");
      cladSuccess = !gInterpreter->Declare(requestFuncStrm.str().c_str());
   }
   if (cladSuccess) {
      std::stringstream errorMsg;
      errorMsg << "Function could not be differentiated. See above for details.";
      oocoutE(nullptr, InputArguments) << errorMsg.str() << std::endl;
      throw std::runtime_error(errorMsg.str().c_str());
   }

   // Clad provides different overloads for the gradient, and we need to
   // resolve to the one that we want. Without the static_cast, getting the
   // function pointer would be ambiguous.
   std::stringstream ss;
   ROOT::Math::Util::TimingScope timingScope(print, "Gradient IR to machine code time:");
   ss << "static_cast<void (*)(double *, double const *, double const *, double *)>(" << gradName << ");";
   _grad = reinterpret_cast<Grad>(gInterpreter->ProcessLine(ss.str().c_str()));
   _hasGradient = true;
#else
   // No Clad support compiled in: fail loudly so callers don't assume a
   // gradient exists.
   _hasGradient = false;
   std::stringstream errorMsg;
   errorMsg << "Function could not be differentiated since ROOT was built without Clad support.";
   oocoutE(nullptr, InputArguments) << errorMsg.str() << std::endl;
   throw std::runtime_error(errorMsg.str().c_str());
#endif
}
371
/// Generate the Hessian of the JIT-compiled function with Clad and store a
/// pointer to the compiled Hessian in _hessian.
/// \throws std::runtime_error if differentiation fails or if ROOT was built
///         without Clad support.
void RooFuncWrapper::createHessian()
{
#ifdef ROOFIT_CLAD
   // Names Clad will produce: the Hessian overload and the helper function
   // that carries the differentiation request.
   std::string hessianName = _funcName + "_hessian_0";
   std::string requestName = _funcName + "_hessian_req";

   // Calculate Hessian
   gInterpreter->Declare("#include <Math/CladDerivator.h>\n");
   // disable clang-format for making the following code unreadable.
   // clang-format off
   std::stringstream requestFuncStrm;
   // clad::hessian needs the explicit parameter-index range, e.g. "params[0:n-1]".
   std::string paramsStr =
      _params.size() == 1 ? "\"params[0]\"" : ("\"params[0:" + std::to_string(_params.size() - 1) + "]\"");
   requestFuncStrm << "#pragma clad ON\n"
                      "void " << requestName << "() {\n"
                      "   clad::hessian(" << _funcName << ", " << paramsStr << ");\n"
                      "}\n"
                      "#pragma clad OFF";
   // clang-format on
   auto print = [](std::string const &msg) { oocoutI(nullptr, Fitting) << msg << std::endl; };

   // NOTE: despite its name, cladSuccess is true when Declare() FAILED
   // (the result is negated) — it flags the error condition.
   bool cladSuccess = false;
   {
      // RAII scope so the timing is reported right after the Declare call.
      ROOT::Math::Util::TimingScope timingScope(print, "Hessian generation time:");
      cladSuccess = !gInterpreter->Declare(requestFuncStrm.str().c_str());
   }
   if (cladSuccess) {
      std::stringstream errorMsg;
      errorMsg << "Function could not be differentiated. See above for details.";
      oocoutE(nullptr, InputArguments) << errorMsg.str() << std::endl;
      throw std::runtime_error(errorMsg.str().c_str());
   }

   // Clad provides different overloads for the Hessian, and we need to
   // resolve to the one that we want. Without the static_cast, getting the
   // function pointer would be ambiguous.
   std::stringstream ss;
   ROOT::Math::Util::TimingScope timingScope(print, "Hessian IR to machine code time:");
   ss << "static_cast<void (*)(double *, double const *, double const *, double *)>(" << hessianName << ");";
   _hessian = reinterpret_cast<Hessian>(gInterpreter->ProcessLine(ss.str().c_str()));
   _hasHessian = true;
#else
   // No Clad support compiled in: fail loudly so callers don't assume a
   // Hessian exists.
   _hasHessian = false;
   std::stringstream errorMsg;
   errorMsg << "Function could not be differentiated since ROOT was built without Clad support.";
   oocoutE(nullptr, InputArguments) << errorMsg.str() << std::endl;
   throw std::runtime_error(errorMsg.str().c_str());
#endif
}
421
422void RooFuncWrapper::updateGradientVarBuffer() const
423{
424 std::transform(_params.begin(), _params.end(), _varBuffer.begin(), [](RooAbsArg *obj) {
425 return obj->isCategory() ? static_cast<RooAbsCategory *>(obj)->getCurrentIndex()
426 : static_cast<RooAbsReal *>(obj)->getVal();
427 });
428}
429
430/// @brief Dumps a macro "filename.C" that can be used to test and debug the generated code and gradient.
431void RooFuncWrapper::writeDebugMacro(std::string const &filename) const
432{
433 std::stringstream allCode;
434 std::set<std::string> seenFunctions;
435
436 // Remove duplicated declared functions
437 for (std::string const &name : _collectedFunctions) {
438 if (seenFunctions.count(name) > 0) {
439 continue;
440 }
441 seenFunctions.insert(name);
442 std::unique_ptr<TInterpreterValue> v = gInterpreter->MakeInterpreterValue();
443 gInterpreter->Evaluate(name.c_str(), *v);
444 std::string s = v->ToString();
445 for (int i = 0; i < 2; ++i) {
446 s = s.erase(0, s.find("\n") + 1);
447 }
448 allCode << s << std::endl;
449 }
450
451 std::ofstream outFile;
452 std::string paramsStr =
453 _params.size() == 1 ? "\"params[0]\"" : ("\"params[0:" + std::to_string(_params.size() - 1) + "]\"");
454 outFile.open(filename + ".C");
455 outFile << R"(//auto-generated test macro
456#include <RooFit/Detail/MathFuncs.h>
457#include <Math/CladDerivator.h>
458
459//#define DO_HESSIAN
460
461)" << allCode.str()
462 << R"(
463#pragma clad ON
464void gradient_request() {
465 clad::gradient()"
466 << _funcName << R"(, "params");
467#ifdef DO_HESSIAN
468 clad::hessian()"
469 << _funcName << ", " << paramsStr << R"();
470#endif
471}
472#pragma clad OFF
473)";
474
476
477 auto writeVector = [&](std::string const &name, std::span<const double> vec) {
478 std::stringstream decl;
479 decl << "std::vector<double> " << name << " = {";
480 for (std::size_t i = 0; i < vec.size(); ++i) {
481 if (i % 10 == 0)
482 decl << "\n ";
483 decl << vec[i];
484 if (i < vec.size() - 1)
485 decl << ", ";
486 }
487 decl << "\n};\n";
488
489 std::string declStr = decl.str();
490
491 replaceAll(declStr, "inf", "std::numeric_limits<double>::infinity()");
492 replaceAll(declStr, "nan", "NAN");
493
494 outFile << declStr;
495 };
496
497 outFile << "// clang-format off\n" << std::endl;
498 writeVector("parametersVec", _varBuffer);
499 outFile << std::endl;
500 writeVector("observablesVec", _observables);
501 outFile << std::endl;
502 writeVector("auxConstantsVec", _xlArr);
503 outFile << std::endl;
504 outFile << "// clang-format on\n" << std::endl;
505
506 outFile << R"(
507// To run as a ROOT macro
508void )" << filename
509 << R"(()
510{
511 const std::size_t n = parametersVec.size();
512
513 std::vector<double> gradientVec(n);
514
515 auto func = [&](std::span<double> params) {
516 return )"
517 << _funcName << R"((params.data(), observablesVec.data(), auxConstantsVec.data());
518 };
519 auto grad = [&](std::span<double> params, std::span<double> out) {
520 return )"
521 << _funcName << R"(_grad_0(parametersVec.data(), observablesVec.data(), auxConstantsVec.data(),
522 out.data());
523 };
524
525 grad(parametersVec, gradientVec);
526
527 auto numDiff = [&](int i) {
528 const double eps = 1e-6;
529 std::vector<double> p{parametersVec};
530 p[i] = parametersVec[i] - eps;
531 double funcValDown = func(p);
532 p[i] = parametersVec[i] + eps;
533 double funcValUp = func(p);
534 return (funcValUp - funcValDown) / (2 * eps);
535 };
536
537 for (std::size_t i = 0; i < parametersVec.size(); ++i) {
538 std::cout << i << ":" << std::endl;
539 std::cout << " numr : " << numDiff(i) << std::endl;
540 std::cout << " clad : " << gradientVec[i] << std::endl;
541 }
542
543#ifdef DO_HESSIAN
544 std::cout << "\n";
545
546 auto hess = [&](std::span<double> params, std::span<double> out) {
547 return )"
548 << _funcName << R"(_hessian_0(params.data(), observablesVec.data(), auxConstantsVec.data(), out.data());
549 };
550
551 std::vector<double> hessianVec(n * n);
552 hess(parametersVec, hessianVec);
553
554 // ---------- Numerical Hessian ----------
555 // Uses central differences:
556 // diag: (f(x+ei)-2f(x)+f(x-ei))/eps^2
557 // offdiag: (f(++ ) - f(+-) - f(-+) + f(--)) / (4 eps^2)
558 auto numHess = [&](std::size_t i, std::size_t j) {
559 const double eps = 1e-5; // often needs to be a bit larger than grad eps
560 std::vector<double> p(parametersVec.begin(), parametersVec.end());
561
562 if (i == j) {
563 const double f0 = func(p);
564
565 p[i] = parametersVec[i] + eps;
566 const double fUp = func(p);
567
568 p[i] = parametersVec[i] - eps;
569 const double fDown = func(p);
570
571 return (fUp - 2.0 * f0 + fDown) / (eps * eps);
572 } else {
573 // f(x_i + eps, x_j + eps)
574 p[i] = parametersVec[i] + eps;
575 p[j] = parametersVec[j] + eps;
576 const double fPP = func(p);
577
578 // f(x_i + eps, x_j - eps)
579 p[i] = parametersVec[i] + eps;
580 p[j] = parametersVec[j] - eps;
581 const double fPM = func(p);
582
583 // f(x_i - eps, x_j + eps)
584 p[i] = parametersVec[i] - eps;
585 p[j] = parametersVec[j] + eps;
586 const double fMP = func(p);
587
588 // f(x_i - eps, x_j - eps)
589 p[i] = parametersVec[i] - eps;
590 p[j] = parametersVec[j] - eps;
591 const double fMM = func(p);
592
593 return (fPP - fPM - fMP + fMM) / (4.0 * eps * eps);
594 }
595 };
596
597 // Compute full numerical Hessian
598 std::vector<double> numHessianVec(n * n);
599 for (std::size_t i = 0; i < n; ++i) {
600 for (std::size_t j = 0; j < n; ++j) {
601 numHessianVec[i + n * j] = numHess(i, j); // keep same layout as your print
602 }
603 }
604
605 // ---------- Compare & print ----------
606 std::cout << "Hessian comparison (clad vs numeric vs diff):\n\n";
607
608 for (std::size_t i = 0; i < n; ++i) {
609 for (std::size_t j = 0; j < n; ++j) {
610 const std::size_t idx = i + n * j; // same indexing you used
611 const double cladH = hessianVec[idx];
612 const double numH = numHessianVec[idx];
613 const double diff = cladH - numH;
614
615 std::cout << "[" << i << "," << j << "] "
616 << "clad=" << cladH << " num=" << numH << " diff=" << diff << "\n";
617 }
618 }
619
620 std::cout << "\nRaw Clad Hessian matrix:\n";
621 for (std::size_t i = 0; i < n; ++i) {
622 for (std::size_t j = 0; j < n; ++j) {
623 std::cout << hessianVec[i + n * j] << " ";
624 }
625 std::cout << "\n";
626 }
627
628 std::cout << "\nRaw Numerical Hessian matrix:\n";
629 for (std::size_t i = 0; i < n; ++i) {
630 for (std::size_t j = 0; j < n; ++j) {
631 std::cout << numHessianVec[i + n * j] << " ";
632 }
633 std::cout << "\n";
634 }
635#endif
636}
637)";
638}
639
640double RooEvaluatorWrapper::evaluate() const
641{
643 return _funcWrapper->evaluate();
644
645 if (!_evaluator)
646 return 0.0;
647
648 _evaluator->setOffsetMode(hideOffset() ? RooFit::EvalContext::OffsetMode::WithoutOffset
649 : RooFit::EvalContext::OffsetMode::WithOffset);
650
651 return _evaluator->run()[0];
652}
653
/// Attach a dataset to the evaluator, replacing any previously loaded data.
/// \param data The dataset to load; must have the same structure (columns and
///        global observables) as the previously attached one.
/// \param cloneData Ignored; the data is referenced, not cloned.
/// \return Always true.
/// \throws std::runtime_error if the new data has a different structure.
bool RooEvaluatorWrapper::setData(RooAbsData &data, bool /*cloneData*/)
{
   // To make things easier for RooFit, we only support resetting with
   // datasets that have the same structure, e.g. the same columns and global
   // observables. This is anyway the usecase: resetting same-structured data
   // when iterating over toys.
   constexpr auto errMsg = "Error in RooAbsReal::setData(): only resetting with same-structured data is supported.";

   _data = &data;
   // An empty parameter set means we are still inside the constructor, where
   // the structure check must be skipped.
   bool isInitializing = _paramSet.empty();
   const std::size_t oldSize = _dataSpans.size();

   // Discard the old data buffers by swapping with an empty stack.
   std::stack<std::vector<double>>{}.swap(_vectorBuffers);
   const bool isChi2 = _topNode->getAttribute("Chi2EvaluationActive");
   bool skipZeroWeights = !isChi2 && (!_pdf || !_pdf->getAttribute("BinnedLikelihoodActive"));
   auto simPdf = dynamic_cast<RooSimultaneous const *>(_pdf);
   _dataSpans = RooFit::BatchModeDataHelpers::getDataSpans(*_data, _rangeName, simPdf, skipZeroWeights,
                                                           _takeGlobalObservablesFromData, _vectorBuffers);
   // A change in the number of data columns means the structure differs.
   if (!isInitializing && _dataSpans.size() != oldSize) {
      coutE(DataHandling) << errMsg << std::endl;
      throw std::runtime_error(errMsg);
   }
   for (auto const &item : _dataSpans) {
      const char *name = item.first->GetName();
      _evaluator->setInput(name, item.second, false);
      // A data column that is still in the parameter set indicates a
      // structural mismatch with the previously attached data.
      if (_paramSet.find(name)) {
         coutE(DataHandling) << errMsg << std::endl;
         throw std::runtime_error(errMsg);
      }
   }
   // Keep the JIT-compiled wrapper, if any, in sync with the new data.
   if (_funcWrapper) {
      _funcWrapper->loadData(*_data, simPdf, _rangeName, skipZeroWeights);
   }
   return true;
}
689
690void RooEvaluatorWrapper::createFuncWrapper()
691{
692 // Get the parameters.
694 this->getParameters(_data ? _data->get() : nullptr, paramSet, /*sripDisconnectedParams=*/false);
695
696 const bool isChi2 = _topNode->getAttribute("Chi2EvaluationActive");
697 const bool skipZeroWeights = !isChi2 && (!_pdf || !_pdf->getAttribute("BinnedLikelihoodActive"));
698 _funcWrapper = std::make_unique<RooFuncWrapper>(*_topNode, _data, dynamic_cast<RooSimultaneous const *>(_pdf),
699 paramSet, _rangeName, skipZeroWeights);
700}
701
702void RooEvaluatorWrapper::generateGradient()
703{
704 if (!_funcWrapper)
706 if (!_funcWrapper->hasGradient())
707 _funcWrapper->createGradient();
708}
709
710void RooEvaluatorWrapper::generateHessian()
711{
712 if (!_funcWrapper)
714 if (!_funcWrapper->hasHessian())
715 _funcWrapper->createHessian();
716}
717
718void RooEvaluatorWrapper::setUseGeneratedFunctionCode(bool flag)
719{
723}
724
/// Evaluate the gradient into `out` (length: number of parameters).
/// Precondition: generateGradient() must have succeeded beforehand —
/// _funcWrapper is dereferenced here without a null check.
void RooEvaluatorWrapper::gradient(double *out) const
{
   _funcWrapper->gradient(out);
}
729
/// Evaluate the Hessian into `out` (length: nParams * nParams).
/// Precondition: generateHessian() must have succeeded beforehand —
/// _funcWrapper is dereferenced here without a null check.
void RooEvaluatorWrapper::hessian(double *out) const
{
   _funcWrapper->hessian(out);
}
734
735bool RooEvaluatorWrapper::hasGradient() const
736{
737 return _funcWrapper && _funcWrapper->hasGradient();
738}
739
740bool RooEvaluatorWrapper::hasHessian() const
741{
742 return _funcWrapper && _funcWrapper->hasHessian();
743}
744
745void RooEvaluatorWrapper::writeDebugMacro(std::string const &filename) const
746{
747 if (_funcWrapper)
748 return _funcWrapper->writeDebugMacro(filename);
749}
750
/// Delegate the operation-mode change to the evaluator; the returned RAII
/// object restores the previous modes when it goes out of scope.
std::unique_ptr<ChangeOperModeRAII> RooEvaluatorWrapper::setOperModes(RooAbsArg::OperMode opMode)
{
   return _evaluator->setOperModes(opMode);
}
755
756} // namespace RooFit::Experimental
757
758/// \endcond
#define oocoutE(o, a)
#define oocoutI(o, a)
#define coutE(a)
ROOT::Detail::TRangeCast< T, true > TRangeDynCast
TRangeDynCast is an adapter class that allows the typed iteration through a TCollection.
static unsigned int total
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void data
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void char Point_t Rectangle_t WindowAttributes_t Float_t Float_t Float_t Int_t Int_t UInt_t UInt_t Rectangle_t Int_t Int_t Window_t TString Int_t GCValues_t GetPrimarySelectionOwner GetDisplay GetScreen GetColormap GetNativeEvent const char const char dpyName wid window const char font_name cursor keysym reg const char only_if_exist regb h Point_t winding char text const char depth char const char Int_t count const char ColorStruct_t color const char filename
char name[80]
Definition TGX11.cxx:145
#define gInterpreter
const_iterator begin() const
const_iterator end() const
Common abstract base class for objects that represent a value and a "shape" in RooFit.
Definition RooAbsArg.h:76
Abstract base class for binned and unbinned datasets.
Definition RooAbsData.h:56
Abstract interface for all probability density functions.
Definition RooAbsPdf.h:32
Abstract base class for objects that represent a real value and implements functionality common to al...
Definition RooAbsReal.h:63
RooArgList is a container object that can hold multiple RooAbsArg objects.
Definition RooArgList.h:22
RooArgSet is a container object that can hold multiple RooAbsArg objects.
Definition RooArgSet.h:24
A class to maintain the context for squashing of RooFit models into code.
void addResult(RooAbsArg const *key, std::string const &value)
A function to save an expression that includes/depends on the result of the input node.
void addVecObs(const char *key, int idx)
Since the squashed code represents all observables as a single flattened array, it is important to ke...
std::string buildFunction(RooAbsArg const &arg, std::unordered_set< RooFit::Detail::DataKey > const &dependsOnData={})
Assemble and return the final code with the return expression and global statements.
std::vector< std::string > const & collectedFunctions()
std::vector< double > const & xlArr()
Facilitates simultaneous fitting of multiple PDFs to subsets of a given dataset.
const Int_t n
Definition legend1.C:16
void replaceAll(std::string &inOut, std::string_view what, std::string_view with)
The namespace RooFit contains mostly switches that change the behaviour of functions of PDFs (or othe...
Definition CodegenImpl.h:72
void getSortedComputationGraph(RooAbsArg const &func, RooArgSet &out)
void evaluate(typename Architecture_t::Tensor_t &A, EActivationFunction f)
Apply the given activation function to each value in the given tensor A.
Definition Functions.h:98