RooEvaluatorWrapper::RooEvaluatorWrapper(RooAbsReal &topNode, RooAbsData *data, bool useGPU,
                                         std::string const &rangeName, RooAbsPdf const *pdf,
                                         bool takeGlobalObservablesFromData)
   : RooAbsReal{"RooEvaluatorWrapper", "RooEvaluatorWrapper"},
     _evaluator{std::make_unique<RooFit::Evaluator>(topNode, useGPU)},
     _topNode("topNode", "top node", this, topNode, false, false),
     _paramSet("paramSet", "Set of parameters", this),
     _rangeName{rangeName},
     _pdf{pdf},
     _takeGlobalObservablesFromData{takeGlobalObservablesFromData}
{
   setData(*data, false);
   _paramSet.add(_evaluator->getParameters());
   // The data variables are fed to the evaluator via data spans, so they must
   // not be treated as parameters of this wrapper.
   for (auto const &item : _dataSpans) {
      _paramSet.remove(*_paramSet.find(item.first->GetName()));
   }
}
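// Illustrative usage sketch (hypothetical setup, not part of the original
// file): the wrapper is normally created internally by RooFit when a
// likelihood is evaluated with the batch-mode backend, conceptually:
//
//    RooAbsReal &nll = ...;  // some NLL built on top of a pdf
//    RooEvaluatorWrapper wrapper{nll, &data, /*useGPU=*/false,
//                                /*rangeName=*/"", pdf,
//                                /*takeGlobalObservablesFromData=*/false};
//    double value = wrapper.getVal();  // runs the RooFit::Evaluator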
RooEvaluatorWrapper::RooEvaluatorWrapper(const RooEvaluatorWrapper &other, const char *name)
   : RooAbsReal(other, name),
     _evaluator{other._evaluator},
     _topNode("topNode", this, other._topNode),
     _data{other._data},
     _paramSet("paramSet", "Set of parameters", this),
     _rangeName{other._rangeName},
     _pdf{other._pdf},
     _takeGlobalObservablesFromData{other._takeGlobalObservablesFromData},
     _dataSpans{other._dataSpans}
{
   _paramSet.add(other._paramSet);
}

RooEvaluatorWrapper::~RooEvaluatorWrapper() = default;
bool RooEvaluatorWrapper::getParameters(const RooArgSet *observables, RooArgSet &outputSet,
                                        bool stripDisconnected) const
{
   outputSet.add(_evaluator->getParameters());
   if (observables) {
      outputSet.remove(*observables, /*silent=*/false, /*matchByNameOnly=*/true);
   }

   // Exclude the data variables from the parameter set, unless they are
   // global observables of the dataset.
   for (auto const &item : _dataSpans) {
      if (_data->getGlobalObservables() && _data->getGlobalObservables()->find(item.first->GetName())) {
         continue;
      }
      if (RooAbsArg *found = outputSet.find(item.first->GetName())) {
         outputSet.remove(*found);
      }
   }

   // If the global observables are taken from the data, report those
   // instances instead of the corresponding nodes inside the model.
   if (_takeGlobalObservablesFromData && _data->getGlobalObservables()) {
      outputSet.replace(*_data->getGlobalObservables());
   }

   // Optionally drop parameters that are not connected to the top node.
   if (stripDisconnected) {
      RooArgSet paramsStripped;
      _topNode->getParameters(observables, paramsStripped, true);
      RooArgSet toRemove;
      for (RooAbsArg *param : outputSet) {
         if (!paramsStripped.find(param->GetName())) {
            toRemove.add(*param);
         }
      }
      outputSet.remove(toRemove, /*silent=*/false, /*matchByNameOnly=*/true);
   }

   return false;
}
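////////////////////////////////////////////////////////////////////////////////
/// Wraps a RooFit computation graph as a plain, JIT-compiled C++ function of
/// the form `double f(params, observables, auxConstants)`, and can optionally
/// derive gradient and Hessian functions for it with Clad. (This summary is
/// inferred from the member functions below, not taken from original
/// documentation.)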
class RooFuncWrapper {
public:
   RooFuncWrapper(RooAbsReal &obj, const RooAbsData *data, RooSimultaneous const *simPdf,
                  RooArgSet const &paramSet, std::string const &rangeName, bool skipZeroWeights);

   bool hasGradient() const { return _hasGradient; }
   bool hasHessian() const { return _hasHessian; }

   void gradient(double *out) const
   {
      updateGradientVarBuffer();
      std::fill(out, out + _params.size(), 0.0);
      _grad(_varBuffer.data(), _observables.data(), _xlArr.data(), out);
   }

   void hessian(double *out) const
   {
      updateGradientVarBuffer();
      std::fill(out, out + _params.size() * _params.size(), 0.0);
      _hessian(_varBuffer.data(), _observables.data(), _xlArr.data(), out);
   }

   void createGradient();
   void createHessian();

   void writeDebugMacro(std::string const &) const;

   std::vector<std::string> const &collectedFunctions() { return _collectedFunctions; }

   double evaluate() const
   {
      updateGradientVarBuffer();
      return _func(_varBuffer.data(), _observables.data(), _xlArr.data());
   }

   void loadData(RooAbsData const &data, RooSimultaneous const *simPdf, std::string const &rangeName,
                 bool skipZeroWeights);

private:
   void updateGradientVarBuffer() const;

   void buildFuncAndGradFunctors();

   using Func = double (*)(double *, double const *, double const *);
   using Grad = void (*)(double *, double const *, double const *, double *);
   using Hessian = void (*)(double *, double const *, double const *, double *);

   std::string _funcName;
   Func _func = nullptr;
   Grad _grad = nullptr;
   Hessian _hessian = nullptr;
   RooArgList _params; // type reconstructed; iterated as RooAbsArg* below
   bool _hasGradient = false;
   bool _hasHessian = false;
   mutable std::vector<double> _varBuffer;
   std::vector<double> _observables;
   std::unordered_map<RooFit::Detail::DataKey, std::size_t> _obsInfos;
   std::vector<double> _xlArr;
   std::vector<std::string> _collectedFunctions;
};
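// Calling convention implied by the typedefs above: the generated function is
//
//    double f(double *params, double const *obs, double const *xlArr);
//
// and the Clad-generated gradient/Hessian take the same inputs plus an output
// buffer, which gradient()/hessian() zero-initialize to nParams and
// nParams * nParams doubles respectively.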
void replaceAll(std::string &str, const std::string &from, const std::string &to)
{
   size_t start_pos = 0;
   while ((start_pos = str.find(from, start_pos)) != std::string::npos) {
      str.replace(start_pos, from.length(), to);
      start_pos += to.length();
   }
}
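////////////////////////////////////////////////////////////////////////////////
/// Collect the set of nodes in the computation graph whose value depends on
/// the dataset: the observables themselves, nodes flagged with the "__obs__"
/// attribute, and any node with a data-dependent value server that is not a
/// reducer node. (Summary inferred from the function body below.)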
std::unordered_set<RooFit::Detail::DataKey> getDependsOnData(RooAbsArg const &topNode, RooArgSet const &observables)
{
   std::unordered_set<RooFit::Detail::DataKey> dependsOnData;

   for (RooAbsArg *arg : observables) {
      dependsOnData.insert(arg);
   }

   // Walk the graph in evaluation order and propagate data dependence from
   // value servers to their clients. (Traversal reconstructed; the original
   // presumably iterates the topologically sorted graph obtained via
   // getSortedComputationGraph().)
   RooArgSet nodes;
   getSortedComputationGraph(topNode, nodes);
   for (RooAbsArg *arg : nodes) {
      if (arg->getAttribute("__obs__")) {
         dependsOnData.insert(arg);
      }
      for (RooAbsArg *server : arg->servers()) {
         if (server->isValueServer(*arg)) {
            if (dependsOnData.find(server) != dependsOnData.end() && !arg->isReducerNode()) {
               dependsOnData.insert(arg);
            }
         }
      }
   }

   return dependsOnData;
}
RooFuncWrapper::RooFuncWrapper(RooAbsReal &obj, const RooAbsData *data, RooSimultaneous const *simPdf,
                               RooArgSet const &paramSet, std::string const &rangeName, bool skipZeroWeights)
{
   // Load the observables from the dataset.
   loadData(*data, simPdf, rangeName, skipZeroWeights);

   // Everything in paramSet that is not an observable is a parameter.
   for (auto *param : paramSet) {
      if (_obsInfos.find(param) == _obsInfos.end()) {
         _params.add(*param);
      }
   }
   _varBuffer.resize(_params.size());

   std::unordered_set<RooFit::Detail::DataKey> dependsOnData;
   if (data) {
      dependsOnData = getDependsOnData(obj, *data->get());
   }

   // Set up the code generation context: parameters are referenced through
   // the flat "params" array, observables through the flattened data array.
   // (Context type assumed from the <RooFit/CodegenImpl.h> include below.)
   RooFit::Experimental::CodegenContext ctx;
   std::size_t idx = 0;
   for (RooAbsArg *param : _params) {
      ctx.addResult(param, "params[" + std::to_string(idx) + "]");
      ++idx;
   }
   for (auto const &item : _obsInfos) {
      const char *obsName = item.first->GetName();
      ctx.addVecObs(obsName, item.second);
   }

   auto print = [](std::string const &msg) { oocoutI(nullptr, Fitting) << msg << std::endl; };
   // (print is presumably consumed by timing/logging code elided in this listing)

   gInterpreter->Declare("#include <RooFit/CodegenImpl.h>\n");

   // Assemble the generated code and JIT-compile it with the interpreter.
   // (This compile step is reconstructed; the derivation of _funcName was not
   // recoverable from this listing.)
   std::string funcCode = ctx.buildFunction(obj, dependsOnData);
   if (!gInterpreter->Declare(funcCode.c_str())) {
      std::stringstream errorMsg;
      std::string debugFileName = "_codegen_" + _funcName + ".cxx";
      errorMsg << "Function " << _funcName << " could not be compiled. See above for details. Full code dumped to file "
               << debugFileName << " for debugging";
      std::ofstream outFile;
      outFile.open(debugFileName.c_str());
      outFile << funcCode;
      oocoutE(nullptr, InputArguments) << errorMsg.str() << std::endl;
      throw std::runtime_error(errorMsg.str().c_str());
   }

   _func = reinterpret_cast<Func>(gInterpreter->ProcessLine((_funcName + ";").c_str()));
   _xlArr = ctx.xlArr();
   _collectedFunctions = ctx.collectedFunctions();
}
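////////////////////////////////////////////////////////////////////////////////
/// Flatten the dataset into the single `_observables` array that the generated
/// code indexes into: a header with (offset, size) pairs for each observable,
/// followed by the concatenated data columns. (Summary inferred from the code.)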
void RooFuncWrapper::loadData(RooAbsData const &data, RooSimultaneous const *simPdf, std::string const &rangeName,
                              bool skipZeroWeights)
{
   std::stack<std::vector<double>> vectorBuffers; // for data loading
   auto spans = RooFit::BatchModeDataHelpers::getDataSpans(data, rangeName, simPdf, skipZeroWeights,
                                                           /*takeGlobalObservablesFromData=*/false, vectorBuffers);

   _observables.clear();
   std::size_t idx = 0;
   std::size_t total = 0;
   _observables.reserve(2 * spans.size());

   // Header: for each observable, store the offset into the flattened array
   // (the header itself occupies the first 2 * nObs slots) and the span size.
   for (auto const &item : spans) {
      _obsInfos.emplace(item.first, idx);
      _observables.push_back(total + 2 * spans.size());
      _observables.push_back(item.second.size());
      total += item.second.size();
      ++idx;
   }

   // Payload: the concatenated data columns.
   for (auto const &item : spans) {
      std::size_t n = item.second.size();
      _observables.reserve(_observables.size() + n);
      for (std::size_t i = 0; i < n; ++i) {
         _observables.push_back(item.second[i]);
      }
   }
}
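// Worked example of the layout (illustrative numbers): with two observables
// "x" (3 entries) and "y" (2 entries), the header occupies 2 * 2 = 4 slots and
// _observables becomes
//
//    { 4, 3,   7, 2,   x0, x1, x2,   y0, y1 }
//      offset/size of x, offset/size of y, then the data columns,
//
// so the generated code finds x at indices [4, 7) and y at [7, 9).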
void RooFuncWrapper::createGradient()
{
#ifdef ROOFIT_CLAD // macro name assumed; this branch requires a Clad-enabled ROOT build
   std::string gradName = _funcName + "_grad_0";
   std::string requestName = _funcName + "_req";

   // Ask Clad to generate the gradient with respect to the "params" array.
   gInterpreter->Declare("#include <Math/CladDerivator.h>\n");

   std::stringstream requestFuncStrm;
   requestFuncStrm << "#pragma clad ON\n"
                      "void " << requestName << "() {\n"
                      "  clad::gradient(" << _funcName << ", \"params\");\n"
                      "}\n"
                      "#pragma clad OFF\n";

   auto print = [](std::string const &msg) { oocoutI(nullptr, Fitting) << msg << std::endl; };
   // (print is presumably consumed by timing/logging code elided in this listing)

   bool cladFailed = !gInterpreter->Declare(requestFuncStrm.str().c_str());
   if (cladFailed) {
      std::stringstream errorMsg;
      errorMsg << "Function could not be differentiated. See above for details.";
      oocoutE(nullptr, InputArguments) << errorMsg.str() << std::endl;
      throw std::runtime_error(errorMsg.str().c_str());
   }

   // Clad emits several overloads; resolve to the one with the expected
   // signature before grabbing the function pointer.
   std::stringstream ss;
   ss << "static_cast<void (*)(double *, double const *, double const *, double *)>(" << gradName << ");";
   _grad = reinterpret_cast<Grad>(gInterpreter->ProcessLine(ss.str().c_str()));
   _hasGradient = true;
#else
   _hasGradient = false;
   std::stringstream errorMsg;
   errorMsg << "Function could not be differentiated since ROOT was built without Clad support.";
   oocoutE(nullptr, InputArguments) << errorMsg.str() << std::endl;
   throw std::runtime_error(errorMsg.str().c_str());
#endif
}
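// Usage sketch (hypothetical variable names, assuming createGradient()
// succeeded):
//
//    std::vector<double> grad(nParams);  // nParams = number of fit parameters
//    wrapper.gradient(grad.data());      // grad[i] = d f / d params[i]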
void RooFuncWrapper::createHessian()
{
#ifdef ROOFIT_CLAD // macro name assumed; this branch requires a Clad-enabled ROOT build
   std::string hessianName = _funcName + "_hessian_0";
   std::string requestName = _funcName + "_hessian_req";

   // Ask Clad to generate the Hessian with respect to all parameters.
   gInterpreter->Declare("#include <Math/CladDerivator.h>\n");

   std::stringstream requestFuncStrm;
   std::string paramsStr =
      _params.size() == 1 ? "\"params[0]\"" : ("\"params[0:" + std::to_string(_params.size() - 1) + "]\"");
   requestFuncStrm << "#pragma clad ON\n"
                      "void " << requestName << "() {\n"
                      "  clad::hessian(" << _funcName << ", " << paramsStr << ");\n"
                      "}\n"
                      "#pragma clad OFF\n";

   auto print = [](std::string const &msg) { oocoutI(nullptr, Fitting) << msg << std::endl; };
   // (print is presumably consumed by timing/logging code elided in this listing)

   bool cladFailed = !gInterpreter->Declare(requestFuncStrm.str().c_str());
   if (cladFailed) {
      std::stringstream errorMsg;
      errorMsg << "Function could not be differentiated. See above for details.";
      oocoutE(nullptr, InputArguments) << errorMsg.str() << std::endl;
      throw std::runtime_error(errorMsg.str().c_str());
   }

   std::stringstream ss;
   ss << "static_cast<void (*)(double *, double const *, double const *, double *)>(" << hessianName << ");";
   _hessian = reinterpret_cast<Hessian>(gInterpreter->ProcessLine(ss.str().c_str()));
   _hasHessian = true;
#else
   std::stringstream errorMsg;
   errorMsg << "Function could not be differentiated since ROOT was built without Clad support.";
   oocoutE(nullptr, InputArguments) << errorMsg.str() << std::endl;
   throw std::runtime_error(errorMsg.str().c_str());
#endif
}
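// Note: for more than one parameter, the clad::hessian request uses the range
// syntax "params[0:N-1]" built above, and the resulting matrix is written into
// a flat buffer of nParams * nParams doubles (see hessian() in the class above).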
void RooFuncWrapper::updateGradientVarBuffer() const
{
   // Copy the current parameter values into the flat buffer that is passed to
   // the generated code. Categories are encoded by their current index.
   std::transform(_params.begin(), _params.end(), _varBuffer.begin(), [](RooAbsArg *obj) {
      return obj->isCategory() ? static_cast<RooAbsCategory *>(obj)->getCurrentIndex()
                               : static_cast<RooAbsReal *>(obj)->getVal();
   });
}
void RooFuncWrapper::writeDebugMacro(std::string const &filename) const
{
   std::stringstream allCode;
   std::set<std::string> seenFunctions;

   // Collect the code of all generated functions, skipping duplicates.
   for (std::string const &name : _collectedFunctions) {
      if (seenFunctions.count(name) > 0) {
         continue;
      }
      seenFunctions.insert(name);
      std::unique_ptr<TInterpreterValue> v = gInterpreter->MakeInterpreterValue();
      gInterpreter->Evaluate(name.c_str(), *v); // reconstructed: fills v with the function's source
      std::string s = v->ToString();
      // Strip the first two lines of the interpreter output, which are not
      // part of the function definition.
      for (int i = 0; i < 2; ++i) {
         s = s.erase(0, s.find("\n") + 1);
      }
      allCode << s << std::endl;
   }
   std::ofstream outFile;
   std::string paramsStr =
      _params.size() == 1 ? "\"params[0]\"" : ("\"params[0:" + std::to_string(_params.size() - 1) + "]\"");
   outFile.open(filename + ".C");
   // Write the collected code followed by the Clad differentiation requests.
   // (Parts of this raw-string template were reconstructed around the
   // surviving fragments.)
   outFile << R"(//auto-generated test macro
#include <RooFit/Detail/MathFuncs.h>
#include <Math/CladDerivator.h>

)" << allCode.str() << R"(
#pragma clad ON
void gradient_request() {
  clad::gradient()" << _funcName << R"(, "params");
}
void hessian_request() {
  clad::hessian()" << _funcName << ", " << paramsStr << R"();
}
#pragma clad OFF
)";
   updateGradientVarBuffer();

   // Helper to dump a std::vector<double> as a C++ declaration, mapping "inf"
   // and "nan" to the corresponding std::numeric_limits expressions.
   auto writeVector = [&](std::string const &name, std::span<const double> vec) {
      std::stringstream decl;
      decl << "std::vector<double> " << name << " = {";
      for (std::size_t i = 0; i < vec.size(); ++i) {
         decl << vec[i];
         if (i < vec.size() - 1)
            decl << ", ";
      }
      decl << "};\n";

      std::string declStr = decl.str();
      replaceAll(declStr, "inf", "std::numeric_limits<double>::infinity()");
      replaceAll(declStr, "nan", "std::numeric_limits<double>::quiet_NaN()");
      outFile << declStr;
   };

   // Dump the current parameter values, observables, and auxiliary constants.
   outFile << "// clang-format off\n" << std::endl;
   writeVector("parametersVec", _varBuffer);
   outFile << std::endl;
   writeVector("observablesVec", _observables);
   outFile << std::endl;
   writeVector("auxConstantsVec", _xlArr);
   outFile << std::endl;
   outFile << "// clang-format on\n" << std::endl;
   // Finally, the macro body itself: evaluate the compiled function and
   // compare the Clad gradient and Hessian against finite differences.
   outFile << R"(
// To run as a ROOT macro
void )" << filename << R"(()
{
   const std::size_t n = parametersVec.size();

   std::vector<double> gradientVec(n);

   auto func = [&](std::span<double> params) {
      return )" << _funcName << R"((params.data(), observablesVec.data(), auxConstantsVec.data());
   };
   auto grad = [&](std::span<double> params, std::span<double> out) {
      return )" << _funcName << R"(_grad_0(parametersVec.data(), observablesVec.data(), auxConstantsVec.data(),
                                           out.data());
   };

   grad(parametersVec, gradientVec);

   // Central finite difference for cross-checking the gradient.
   auto numDiff = [&](int i) {
      const double eps = 1e-6;
      std::vector<double> p{parametersVec};
      p[i] = parametersVec[i] - eps;
      double funcValDown = func(p);
      p[i] = parametersVec[i] + eps;
      double funcValUp = func(p);
      return (funcValUp - funcValDown) / (2 * eps);
   };

   for (std::size_t i = 0; i < parametersVec.size(); ++i) {
      std::cout << i << ":" << std::endl;
      std::cout << "  numr : " << numDiff(i) << std::endl;
      std::cout << "  clad : " << gradientVec[i] << std::endl;
   }

   auto hess = [&](std::span<double> params, std::span<double> out) {
      )" << _funcName << R"(_hessian_0(params.data(), observablesVec.data(), auxConstantsVec.data(), out.data());
   };

   std::vector<double> hessianVec(n * n);
   hess(parametersVec, hessianVec);
   // ---------- Numerical Hessian ----------
   // Uses central differences:
   //   diag:    (f(x+ei) - 2 f(x) + f(x-ei)) / eps^2
   //   offdiag: (f(++) - f(+-) - f(-+) + f(--)) / (4 eps^2)
   auto numHess = [&](std::size_t i, std::size_t j) {
      const double eps = 1e-5; // often needs to be a bit larger than the gradient eps
      std::vector<double> p(parametersVec.begin(), parametersVec.end());

      if (i == j) {
         const double f0 = func(p);

         p[i] = parametersVec[i] + eps;
         const double fUp = func(p);

         p[i] = parametersVec[i] - eps;
         const double fDown = func(p);

         return (fUp - 2.0 * f0 + fDown) / (eps * eps);
      }

      // f(x_i + eps, x_j + eps)
      p[i] = parametersVec[i] + eps;
      p[j] = parametersVec[j] + eps;
      const double fPP = func(p);

      // f(x_i + eps, x_j - eps)
      p[i] = parametersVec[i] + eps;
      p[j] = parametersVec[j] - eps;
      const double fPM = func(p);

      // f(x_i - eps, x_j + eps)
      p[i] = parametersVec[i] - eps;
      p[j] = parametersVec[j] + eps;
      const double fMP = func(p);

      // f(x_i - eps, x_j - eps)
      p[i] = parametersVec[i] - eps;
      p[j] = parametersVec[j] - eps;
      const double fMM = func(p);

      return (fPP - fPM - fMP + fMM) / (4.0 * eps * eps);
   };

   // Compute the full numerical Hessian.
   std::vector<double> numHessianVec(n * n);
   for (std::size_t i = 0; i < n; ++i) {
      for (std::size_t j = 0; j < n; ++j) {
         numHessianVec[i + n * j] = numHess(i, j); // same layout as the printout below
      }
   }
605 // ---------- Compare & print ----------
606 std::cout << "Hessian comparison (clad vs numeric vs diff):\n\n";
608 for (std::size_t i = 0; i < n; ++i) {
609 for (std::size_t j = 0; j < n; ++j) {
610 const std::size_t idx = i + n * j; // same indexing you used
611 const double cladH = hessianVec[idx];
612 const double numH = numHessianVec[idx];
613 const double diff = cladH - numH;
615 std::cout << "[" << i << "," << j << "] "
616 << "clad=" << cladH << " num=" << numH << " diff=" << diff << "\n";
620 std::cout << "\nRaw Clad Hessian matrix:\n";
621 for (std::size_t i = 0; i < n; ++i) {
622 for (std::size_t j = 0; j < n; ++j) {
623 std::cout << hessianVec[i + n * j] << " ";
628 std::cout << "\nRaw Numerical Hessian matrix:\n";
629 for (std::size_t i = 0; i < n; ++i) {
630 for (std::size_t j = 0; j < n; ++j) {
631 std::cout << numHessianVec[i + n * j] << " ";
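// The macro written above can then be inspected and executed in a ROOT
// session, e.g. (hypothetical file name) `root -l myFit_debug.C`, which
// recompiles the generated code, re-derives it with Clad, and prints the
// gradient and Hessian comparisons against finite differences.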
double RooEvaluatorWrapper::evaluate() const
{
   if (_useGeneratedFunctionCode)
      return _funcWrapper->evaluate();

   return _evaluator->run()[0];
}
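// Note: evaluate() has two paths. By default the vectorized RooFit::Evaluator
// runs over the computation graph; after setUseGeneratedFunctionCode(true) the
// JIT-compiled standalone function from RooFuncWrapper is called instead.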
bool RooEvaluatorWrapper::setData(RooAbsData &data, bool /*cloneData*/)
{
   // Resetting is only supported with a dataset that has the same structure
   // as the original one.
   constexpr auto errMsg = "Error in RooAbsReal::setData(): only resetting with same-structured data is supported.";

   _data = &data;

   bool isInitializing = _paramSet.empty();
   const std::size_t oldSize = _dataSpans.size();

   std::stack<std::vector<double>>{}.swap(_vectorBuffers);
   auto *simPdf = dynamic_cast<RooSimultaneous const *>(_pdf);
   const bool isChi2 = _topNode->getAttribute("Chi2EvaluationActive");
   bool skipZeroWeights = !isChi2 && (!_pdf || !_pdf->getAttribute("BinnedLikelihoodActive"));
   _dataSpans = RooFit::BatchModeDataHelpers::getDataSpans(*_data, _rangeName, simPdf, skipZeroWeights,
                                                           _takeGlobalObservablesFromData, _vectorBuffers);
   if (!isInitializing && _dataSpans.size() != oldSize) {
      coutE(DataHandling) << errMsg << std::endl;
      throw std::runtime_error(errMsg);
   }
   for (auto const &item : _dataSpans) {
      const char *name = item.first->GetName();
      _evaluator->setInput(name, item.second, false);
      if (_paramSet.find(name)) {
         coutE(DataHandling) << errMsg << std::endl;
         throw std::runtime_error(errMsg);
      }
   }
   if (_funcWrapper) {
      _funcWrapper->loadData(*_data, simPdf, _rangeName, skipZeroWeights);
   }
   return true;
}
void RooEvaluatorWrapper::createFuncWrapper()
{
   RooArgSet paramSet;
   this->getParameters(_data ? _data->get() : nullptr, paramSet, false);

   const bool isChi2 = _topNode->getAttribute("Chi2EvaluationActive");
   const bool skipZeroWeights = !isChi2 && (!_pdf || !_pdf->getAttribute("BinnedLikelihoodActive"));
   _funcWrapper = std::make_unique<RooFuncWrapper>(*_topNode, _data, dynamic_cast<RooSimultaneous const *>(_pdf),
                                                   paramSet, _rangeName, skipZeroWeights);
}
void RooEvaluatorWrapper::generateGradient()
{
   createFuncWrapper();
   if (!_funcWrapper->hasGradient())
      _funcWrapper->createGradient();
}

void RooEvaluatorWrapper::generateHessian()
{
   createFuncWrapper();
   if (!_funcWrapper->hasHessian())
      _funcWrapper->createHessian();
}

void RooEvaluatorWrapper::setUseGeneratedFunctionCode(bool flag)
{
   _useGeneratedFunctionCode = flag;
   if (!_funcWrapper && _useGeneratedFunctionCode)
      createFuncWrapper();
}

void RooEvaluatorWrapper::gradient(double *out) const
{
   _funcWrapper->gradient(out);
}

void RooEvaluatorWrapper::hessian(double *out) const
{
   _funcWrapper->hessian(out);
}

bool RooEvaluatorWrapper::hasGradient() const
{
   return _funcWrapper && _funcWrapper->hasGradient();
}

bool RooEvaluatorWrapper::hasHessian() const
{
   return _funcWrapper && _funcWrapper->hasHessian();
}

void RooEvaluatorWrapper::writeDebugMacro(std::string const &filename) const
{
   return _funcWrapper->writeDebugMacro(filename);
}

std::unique_ptr<ChangeOperModeRAII> RooEvaluatorWrapper::setOperModes(RooAbsArg::OperMode opMode)
{
   return _evaluator->setOperModes(opMode);
}