Logo ROOT  
Reference Guide
 
Loading...
Searching...
No Matches
RooFuncWrapper.cxx
Go to the documentation of this file.
1/*
2 * Project: RooFit
3 * Authors:
4 * Garima Singh, CERN 2022
5 *
6 * Copyright (c) 2022, CERN
7 *
8 * Redistribution and use in source and binary forms,
9 * with or without modification, are permitted according to the terms
10 * listed in LICENSE (http://roofit.sourceforge.net/license.txt)
11 */
12
13#include <RooFuncWrapper.h>
14
15#include <RooAbsData.h>
18#include <RooFit/Evaluator.h>
19#include <RooGlobalFunc.h>
20#include <RooHelpers.h>
21#include <RooMsgService.h>
22#include <RooRealVar.h>
23#include <RooSimultaneous.h>
24#include "RooEvaluatorWrapper.h"
25
26#include <TROOT.h>
27#include <TSystem.h>
28
29namespace RooFit {
30
31namespace Experimental {
32
33RooFuncWrapper::RooFuncWrapper(const char *name, const char *title, RooAbsReal &obj, const RooAbsData *data,
34 RooSimultaneous const *simPdf, bool useEvaluator)
35 : RooAbsReal{name, title}, _params{"!params", "List of parameters", this}
36{
37 if (useEvaluator) {
38 _absReal = std::make_unique<RooEvaluatorWrapper>(obj, const_cast<RooAbsData *>(data), false, "", simPdf, false);
39 }
40
41 std::string func;
42
43 // Get the parameters.
44 RooArgSet paramSet;
45 obj.getParameters(data ? data->get() : nullptr, paramSet);
46 RooArgSet floatingParamSet;
47 for (RooAbsArg *param : paramSet) {
48 if (!param->isConstant()) {
49 floatingParamSet.add(*param);
50 }
51 }
52
53 // Load the parameters and observables.
54 loadParamsAndData(&obj, floatingParamSet, data, simPdf);
55
56 func = buildCode(obj);
57
58 // Declare the function and create its derivative.
60 _func = reinterpret_cast<Func>(gInterpreter->ProcessLine((_funcName + ";").c_str()));
61}
62
64 : RooAbsReal(other, name),
65 _params("!params", this, other._params),
66 _funcName(other._funcName),
67 _func(other._func),
68 _grad(other._grad),
69 _hasGradient(other._hasGradient),
70 _gradientVarBuffer(other._gradientVarBuffer),
71 _observables(other._observables)
72{
73}
74
75void RooFuncWrapper::loadParamsAndData(RooAbsArg const *head, RooArgSet const &paramSet, const RooAbsData *data,
76 RooSimultaneous const *simPdf)
77{
78 // Extract observables
79 std::stack<std::vector<double>> vectorBuffers; // for data loading
80 std::map<RooFit::Detail::DataKey, std::span<const double>> spans;
81
82 if (data) {
83 spans = RooFit::Detail::BatchModeDataHelpers::getDataSpans(*data, "", simPdf, true, false, vectorBuffers);
84 }
85
86 std::size_t idx = 0;
87 for (auto const &item : spans) {
88 std::size_t n = item.second.size();
89 _obsInfos.emplace(item.first, ObsInfo{idx, n});
90 _observables.reserve(_observables.size() + n);
91 for (std::size_t i = 0; i < n; ++i) {
92 _observables.push_back(item.second[i]);
93 }
94 idx += n;
95 }
96
97 // Extract parameters
98 for (auto *param : paramSet) {
99 if (!dynamic_cast<RooAbsReal *>(param)) {
100 std::stringstream errorMsg;
101 errorMsg << "In creation of function " << GetName()
102 << " wrapper: input param expected to be of type RooAbsReal.";
103 coutE(InputArguments) << errorMsg.str() << std::endl;
104 throw std::runtime_error(errorMsg.str().c_str());
105 }
106 if (spans.find(param) == spans.end()) {
107 _params.add(*param);
108 }
109 }
110 _gradientVarBuffer.resize(_params.size());
111
112 if (head) {
113 _nodeOutputSizes = RooFit::Detail::BatchModeDataHelpers::determineOutputSizes(
114 *head, [&spans](RooFit::Detail::DataKey key) -> int {
115 auto found = spans.find(key);
116 return found != spans.end() ? found->second.size() : -1;
117 });
118 }
119}
120
121std::string RooFuncWrapper::declareFunction(std::string const &funcBody)
122{
123 static int iFuncWrapper = 0;
124 auto funcName = "roo_func_wrapper_" + std::to_string(iFuncWrapper++);
125
126 gInterpreter->Declare("#pragma cling optimize(2)");
127
128 // Declare the function
129 std::stringstream bodyWithSigStrm;
130 bodyWithSigStrm << "double " << funcName << "(double* params, double const* obs, double const* xlArr) {\n"
131 << funcBody << "\n}";
132 bool comp = gInterpreter->Declare(bodyWithSigStrm.str().c_str());
133 if (!comp) {
134 std::stringstream errorMsg;
135 errorMsg << "Function " << funcName << " could not be compiled. See above for details.";
136 oocoutE(nullptr, InputArguments) << errorMsg.str() << std::endl;
137 throw std::runtime_error(errorMsg.str().c_str());
138 }
139 return funcName;
140}
141
142void RooFuncWrapper::createGradient()
143{
144 std::string gradName = _funcName + "_grad_0";
145 std::string requestName = _funcName + "_req";
146 std::string wrapperName = _funcName + "_derivativeWrapper";
147
148 // Calculate gradient
149 gInterpreter->ProcessLine("#include <Math/CladDerivator.h>");
150 // disable clang-format for making the following code unreadable.
151 // clang-format off
152 std::stringstream requestFuncStrm;
153 requestFuncStrm << "#pragma clad ON\n"
154 "void " << requestName << "() {\n"
155 " clad::gradient(" << _funcName << ", \"params\");\n"
156 "}\n"
157 "#pragma clad OFF";
158 // clang-format on
159 auto comp = gInterpreter->Declare(requestFuncStrm.str().c_str());
160 if (!comp) {
161 std::stringstream errorMsg;
162 errorMsg << "Function " << GetName() << " could not be differentiated. See above for details.";
163 oocoutE(nullptr, InputArguments) << errorMsg.str() << std::endl;
164 throw std::runtime_error(errorMsg.str().c_str());
165 }
166
167 // Build a wrapper over the derivative to hide clad specific types such as 'array_ref'.
168 // disable clang-format for making the following code unreadable.
169 // clang-format off
170 std::stringstream dWrapperStrm;
171 dWrapperStrm << "void " << wrapperName << "(double* params, double const* obs, double const* xlArr, double* out) {\n"
172 " clad::array_ref<double> cladOut(out, " << _params.size() << ");\n"
173 " " << gradName << "(params, obs, xlArr, cladOut);\n"
174 "}";
175 // clang-format on
176 gInterpreter->Declare(dWrapperStrm.str().c_str());
177 _grad = reinterpret_cast<Grad>(gInterpreter->ProcessLine((wrapperName + ";").c_str()));
178 _hasGradient = true;
179}
180
181void RooFuncWrapper::gradient(double *out) const
182{
183 updateGradientVarBuffer();
184 std::fill(out, out + _params.size(), 0.0);
185
186 _grad(_gradientVarBuffer.data(), _observables.data(), _xlArr.data(), out);
187}
188
189void RooFuncWrapper::updateGradientVarBuffer() const
190{
191 std::transform(_params.begin(), _params.end(), _gradientVarBuffer.begin(),
192 [](RooAbsArg *obj) { return static_cast<RooAbsReal *>(obj)->getVal(); });
193}
194
195double RooFuncWrapper::evaluate() const
196{
197 if (_absReal)
198 return _absReal->getVal();
199 updateGradientVarBuffer();
200
201 return _func(_gradientVarBuffer.data(), _observables.data(), _xlArr.data());
202}
203
204void RooFuncWrapper::gradient(const double *x, double *g) const
205{
206 std::fill(g, g + _params.size(), 0.0);
207
208 _grad(const_cast<double *>(x), _observables.data(), _xlArr.data(), g);
209}
210
211std::string RooFuncWrapper::buildCode(RooAbsReal const &head)
212{
213 RooFit::Detail::CodeSquashContext ctx(_nodeOutputSizes, _xlArr);
214
215 // First update the result variable of params in the compute graph to in[<position>].
216 int idx = 0;
217 for (RooAbsArg *param : _params) {
218 ctx.addResult(param, "params[" + std::to_string(idx) + "]");
219 idx++;
220 }
221
222 for (auto const &item : _obsInfos) {
223 const char *name = item.first->GetName();
224 // If the observable is scalar, set name to the start idx. else, store
225 // the start idx and later set the the name to obs[start_idx + curr_idx],
226 // here curr_idx is defined by a loop producing parent node.
227 if (item.second.size == 1) {
228 ctx.addResult(name, "obs[" + std::to_string(item.second.idx) + "]");
229 } else {
230 ctx.addResult(name, "obs");
231 ctx.addVecObs(name, item.second.idx);
232 }
233 }
234
235 return ctx.assembleCode(ctx.getResult(head));
236}
237
238/// @brief Prints the squashed code body to console.
239void RooFuncWrapper::dumpCode()
240{
241 gInterpreter->ProcessLine(_funcName.c_str());
242}
243
244/// @brief Prints the derivative code body to console.
245void RooFuncWrapper::dumpGradient()
246{
247 gInterpreter->ProcessLine((_funcName + "_grad_0").c_str());
248}
249
250} // namespace Experimental
251
252} // namespace RooFit
#define g(i)
Definition RSha256.hxx:105
RooAbsReal * _func
Pointer to original input function.
#define oocoutE(o, a)
#define coutE(a)
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void data
char name[80]
Definition TGX11.cxx:110
#define gInterpreter
Common abstract base class for objects that represent a value and a "shape" in RooFit.
Definition RooAbsArg.h:77
RooFit::OwningPtr< RooArgSet > getParameters(const RooAbsData *data, bool stripDisconnected=true) const
Create a list of leaf nodes in the arg tree starting with ourself as top node that don't match any of...
virtual bool add(const RooAbsArg &var, bool silent=false)
Add the specified argument to list.
Abstract base class for binned and unbinned datasets.
Definition RooAbsData.h:57
Abstract base class for objects that represent a real value and implements functionality common to al...
Definition RooAbsReal.h:59
RooArgSet is a container object that can hold multiple RooAbsArg objects.
Definition RooArgSet.h:55
A class to maintain the context for squashing of RooFit models into code.
std::string assembleCode(std::string const &returnExpr)
Assemble and return the final code with the return expression and global statements.
void addResult(RooAbsArg const *key, std::string const &value)
A function to save an expression that includes/depends on the result of the input node.
void addVecObs(const char *key, int idx)
Since the squashed code represents all observables as a single flattened array, it is important to ke...
std::string const & getResult(RooAbsArg const &arg)
Gets the result for the given node using the node name.
A wrapper class to store a C++ function of type 'double (*)(double*, double const*, double const*)'.
double(*)(double *, double const *, double const *) Func
std::unique_ptr< RooAbsReal > _absReal
std::string buildCode(RooAbsReal const &head)
void loadParamsAndData(RooAbsArg const *head, RooArgSet const &paramSet, const RooAbsData *data, RooSimultaneous const *simPdf)
std::map< RooFit::Detail::DataKey, ObsInfo > _obsInfos
void(*)(double *, double const *, double const *, double *) Grad
static std::string declareFunction(std::string const &funcBody)
RooFuncWrapper(const char *name, const char *title, RooAbsReal &obj, const RooAbsData *data=nullptr, RooSimultaneous const *simPdf=nullptr, bool useEvaluator=false)
Facilitates simultaneous fitting of multiple PDFs to subsets of a given dataset.
Double_t x[n]
Definition legend1.C:17
const Int_t n
Definition legend1.C:16
The namespace RooFit contains mostly switches that change the behaviour of functions of PDFs (or othe...
Definition JSONIO.h:26
@ InputArguments