Logo ROOT  
Reference Guide
 
Loading...
Searching...
No Matches
RooNLLVarNew.cxx
Go to the documentation of this file.
1/*
2 * Project: RooFit
3 * Authors:
4 * Jonas Rembser, CERN 2021
5 * Emmanouil Michalainas, CERN 2021
6 *
7 * Copyright (c) 2021, CERN
8 *
9 * Redistribution and use in source and binary forms,
10 * with or without modification, are permitted according to the terms
11 * listed in LICENSE (http://roofit.sourceforge.net/license.txt)
12 */
13
14/**
15\file RooNLLVarNew.cxx
16\class RooNLLVarNew
17\ingroup Roofitcore
18
19This is a simple class designed to produce the nll values needed by the fitter.
20This class calls functions from `RooBatchCompute` library to provide faster
21computation times.
22**/
23
24#include "RooNLLVarNew.h"
25
26#include <RooHistPdf.h>
27#include <RooBatchCompute.h>
28#include <RooDataHist.h>
29#include <RooNaNPacker.h>
30#include <RooConstVar.h>
31#include <RooRealVar.h>
32#include <RooSetProxy.h>
34
35#include <ROOT/StringUtils.hxx>
36
37#include <TClass.h>
38#include <TMath.h>
39#include <Math/Util.h>
40
41#include <numeric>
42#include <stdexcept>
43#include <vector>
44
45// Declare constexpr static members to make them available if odr-used in C++14.
// These are the out-of-class definitions of the two name constants declared in
// the class; from C++17 on, constexpr static data members are implicitly inline
// and such redeclarations are redundant (but still accepted).
46constexpr const char *RooNLLVarNew::weightVarName;
47constexpr const char *RooNLLVarNew::weightVarNameSumW2;
48
49namespace {
50
// Returns, by value, the set filled by `arg.getObservables(&observables, out)`,
// i.e. the members of `observables` that `arg` depends on.
// NOTE(review): the declaration of the local `out` (listing line 53) is absent
// from this extract; presumably `RooArgSet out;` -- confirm against the real file.
51RooArgSet getObs(RooAbsArg const &arg, RooArgSet const &observables)
52{
54 arg.getObservables(&observables, out);
55 return out;
56}
57
58// Use RooConstVar for dummies such that they don't get included in getParameters().
// Allocates a RooConstVar with value 1.0 whose name and title are both `name`
// and returns the raw pointer; this helper itself never frees it.
// NOTE(review): the constructor below dereferences the result straight into a
// proxy with a trailing `true` argument, presumably asking the proxy to take
// ownership -- confirm against the RooTemplateProxy constructor.
59RooConstVar *dummyVar(const char *name)
60{
61 return new RooConstVar(name, name, 1.0);
62}
63
64// Helper class to represent a template pdf based on the fit dataset.
// Its batch evaluation fills a RooDataHist from the per-event observable values
// and weights, then reads one value per event back through a RooHistPdf.
65class RooOffsetPdf : public RooAbsPdf {
66public:
 // Builds the pdf: every element of `observables` is added to the set proxy
 // and `weightVar` is held through a proxy.
67 RooOffsetPdf(const char *name, const char *title, RooArgSet const &observables, RooAbsReal &weightVar)
68 : RooAbsPdf(name, title),
69 _observables("!observables", "List of observables", this),
70 _weightVar{"!weightVar", "weightVar", this, weightVar, true, false}
71 {
72 for (RooAbsArg *obs : observables) {
73 _observables.add(*obs);
74 }
75 }
 // Copy constructor, optionally giving the copy a new name.
 // NOTE(review): the set proxy is named "!servers" here but "!observables" in
 // the constructor above -- confirm whether the mismatch is intentional.
76 RooOffsetPdf(const RooOffsetPdf &other, const char *name = nullptr)
77 : RooAbsPdf(other, name),
78 _observables("!servers", this, other._observables),
79 _weightVar{"!weightVar", this, other._weightVar}
80 {
81 }
 // Returns a heap-allocated copy of this object named `newname`.
82 TObject *clone(const char *newname) const override { return new RooOffsetPdf(*this, newname); }
83
 // Writes one value per event into output[0 .. nEvents).
 // \param output   destination array, indexed by event
 // \param nEvents  number of events to process
 // \param dataMap  source of the weight span and of one span per observable
84 void computeBatch(double *output, size_t nEvents, RooFit::Detail::DataMap const &dataMap) const override
85 {
86 std::span<const double> weights = dataMap.at(_weightVar);
87
88 // Create the template histogram from the data. This operation is very
89 // expensive, but since the offset only depends on the observables it
90 // only has to be done once.
 // NOTE(review): nothing in this method caches the histogram; "once"
 // presumably relies on the caller evaluating this node a single time.
91
92 RooDataHist dataHist{"data", "data", _observables};
93 // Loop over events to fill the histogram
94 for (std::size_t i = 0; i < nEvents; ++i) {
 // Move each observable to its value for event i before filling.
95 for (auto *var : static_range_cast<RooRealVar *>(_observables)) {
96 var->setVal(dataMap.at(var)[i]);
97 }
 // A weight span of size 1 is used as one weight shared by all events.
98 dataHist.add(_observables, weights[weights.size() == 1 ? 0 : i]);
99 }
100
101 // Lookup bin weights via RooHistPdf
102 RooHistPdf pdf{"offsetPdf", "offsetPdf", _observables, dataHist};
103 for (std::size_t i = 0; i < nEvents; ++i) {
 // Same positioning as above, now to evaluate the histogram pdf at event i.
104 for (auto *var : static_range_cast<RooRealVar *>(_observables)) {
105 var->setVal(dataMap.at(var)[i]);
106 }
107 output[i] = pdf.getVal(_observables);
108 }
109 }
110
111private:
112 double evaluate() const override { return 0.0; } // should never be called
113
 // Observables registered in the constructor; iterated in computeBatch().
114 RooSetProxy _observables;
 // NOTE(review): listing line 115 is absent from this extract; presumably the
 // declaration of the `_weightVar` proxy used above -- confirm.
116};
117
118} // namespace
119
120/** Construct a RooNLLVarNew
121\param name the name
122\param title the title
123\param pdf The pdf for which the nll is computed
124\param observables The observables of the pdf
125\param isExtended Set to true if this is an extended fit
\param offsetMode The requested offset mode. NOTE(review): the statements that
       consume it (listing lines 153-155) are absent from this extract.
126**/
127RooNLLVarNew::RooNLLVarNew(const char *name, const char *title, RooAbsPdf &pdf, RooArgSet const &observables,
128 bool isExtended, RooFit::OffsetMode offsetMode)
129 : RooAbsReal(name, title),
130 _pdf{"pdf", "pdf", this, pdf},
 // Both weight proxies start out pointing at freshly allocated constant dummies.
 // NOTE(review): the dummy name "weightSquardVar" is spelled without the 'e';
 // it is a runtime string and is deliberately left untouched here.
131 _weightVar{"weightVar", "weightVar", this, *dummyVar(weightVarName), true, false, true},
132 _weightSquaredVar{weightVarNameSumW2, weightVarNameSumW2, this, *dummyVar("weightSquardVar"), true, false, true},
 // The binned-likelihood path is selected by an attribute set on the pdf.
133 _binnedL{pdf.getAttribute("BinnedLikelihoodActive")}
134{
 // Restrict the given observables to the ones the pdf actually depends on.
135 RooArgSet obs{getObs(pdf, observables)};
136
137 // In the "BinnedLikelihoodActiveYields" mode, the pdf values can directly
138 // be interpreted as yields and don't need to be multiplied by the bin
139 // widths. That's why we don't need to even fill them in this case.
140 if (_binnedL && !pdf.getAttribute("BinnedLikelihoodActiveYields")) {
 // NOTE(review): the body (listing line 141) is absent from this extract;
 // presumably the call that fills the bin widths -- confirm.
142 }
143
 // An expected-events function is only requested for extended, unbinned fits.
144 if (isExtended && !_binnedL) {
145 std::unique_ptr<RooAbsReal> expectedEvents = pdf.createExpectedEventsFunc(&obs);
146 if (expectedEvents) {
 // NOTE(review): listing line 147 is absent; presumably `_expectedEvents =`,
 // receiving the proxy created on the next line -- confirm.
148 std::make_unique<RooTemplateProxy<RooAbsReal>>("expectedEvents", "expectedEvents", this, *expectedEvents);
 // Hand the function object over to this node.
149 addOwnedComponents(std::move(expectedEvents));
150 }
151 }
152
156
 // With bin offsetting on, create the helper pdf over the caller's observables
 // and the weight variable, reference it through a proxy and take ownership.
157 if (_doBinOffset) {
158 auto offsetPdf = std::make_unique<RooOffsetPdf>("_offset_pdf", "_offset_pdf", observables, *_weightVar);
159 _offsetPdf = std::make_unique<RooTemplateProxy<RooAbsPdf>>("offsetPdf", "offsetPdf", this, *offsetPdf);
160 addOwnedComponents(std::move(offsetPdf));
161 }
162}
163
// Copy constructor body. The signature (listing line 164) is absent from this
// extract; the initializers show it receives `other` and `name`.
// NOTE(review): the visible code copies neither `_doBinOffset` nor `_offsetPdf`,
// and names the squared-weight proxy "weightSquaredVar" whereas the main
// constructor uses weightVarNameSumW2 -- confirm both are intended.
165 : RooAbsReal(other, name),
166 _pdf{"pdf", this, other._pdf},
167 _weightVar{"weightVar", this, other._weightVar},
168 _weightSquaredVar{"weightSquaredVar", this, other._weightSquaredVar},
169 _weightSquared{other._weightSquared},
170 _binnedL{other._binnedL},
171 _doOffset{other._doOffset},
172 _simCount{other._simCount},
173 _prefix{other._prefix},
174 _binw{other._binw}
175{
 // The expected-events proxy is optional; recreate it only if the source has one.
176 if (other._expectedEvents) {
177 _expectedEvents = std::make_unique<RooTemplateProxy<RooAbsReal>>("expectedEvents", this, *other._expectedEvents);
178 }
179}
180
// Fills `_binw` with the widths of consecutive bins reported by
// `pdf.binBoundaries()` for the single observable. Does nothing if `_binw` is
// already filled; throws std::runtime_error unless exactly one observable is given.
// NOTE(review): the signature (listing line 181) is absent from this extract;
// the names `pdf` and `observables` are used as seen in the body.
182{
183 // Check if the bin widths were already filled
184 if (!_binw.empty()) {
185 return;
186 }
187
188 if (observables.size() != 1) {
189 throw std::runtime_error("BinnedPdf optimization only works with a 1D pdf.");
190 } else {
191 auto *var = static_cast<RooRealVar *>(observables.first());
 // NOTE(review): `boundaries` is dereferenced without a null check and is
 // never deleted here; if the callee hands over ownership this leaks, and
 // an empty list would make `size() - 1` wrap around -- confirm.
192 std::list<double> *boundaries = pdf.binBoundaries(*var, var->getMin(), var->getMax());
193 std::list<double>::iterator biter = boundaries->begin();
 // N boundaries delimit N - 1 bins.
194 _binw.resize(boundaries->size() - 1);
195 double lastBound = (*biter);
196 ++biter;
197 int ibin = 0;
 // Width of bin i is the distance between boundary i+1 and boundary i.
198 while (biter != boundaries->end()) {
199 _binw[ibin] = (*biter) - lastBound;
200 lastBound = (*biter);
201 ibin++;
202 ++biter;
203 }
204 }
205}
206
// Binned likelihood: accumulates -log(Poisson(N|mu)) over all bins, where N is
// the bin weight and mu the prediction, and returns finalizeResult() of that
// sum and the summed weights.
// \param preds    one prediction per bin (multiplied by `_binw[i]` unless `_binw` is empty)
// \param weights  one observed weight per bin, indexed like `preds`
207double RooNLLVarNew::computeBatchBinnedL(std::span<const double> preds, std::span<const double> weights) const
208{
 // NOTE(review): listing line 209 is absent from this extract; presumably the
 // declaration of the `result` accumulator used below -- confirm.
210 ROOT::Math::KahanSum<double> sumWeightKahanSum{0.0};
211
 // No stored bin widths means the predictions are used as yields unchanged.
212 const bool predsAreYields = _binw.empty();
213
214 for (std::size_t i = 0; i < preds.size(); ++i) {
215
216 double eventWeight = weights[i];
217
218 // Calculate log(Poisson(N|mu)) for this bin
219 double N = eventWeight;
220 double mu = preds[i];
221 if (!predsAreYields) {
222 mu *= _binw[i];
223 }
224
225 if (mu <= 0 && N > 0) {
226
227 // Catch error condition: data present where zero events are predicted
 // Only an evaluation error is logged; the bin adds nothing to either sum.
228 logEvalError(Form("Observed %f events in bin %lu with zero event yield", N, (unsigned long)i));
229
230 } else if (std::abs(mu) < 1e-10 && std::abs(N) < 1e-10) {
231
232 // Special handling of this case since log(Poisson(0,0))=0 but can't be calculated with usual log-formula
233 // since log(mu) is not finite for mu=0. No update of result is required since term=0.
234
235 } else {
236
237 result += -1 * (-mu + N * log(mu) - TMath::LnGamma(N + 1));
238 sumWeightKahanSum += eventWeight;
239 }
240 }
241
242 return finalizeResult(result, sumWeightKahanSum.Sum());
243}
244
245/** Compute multiple negative logs of probabilities.
The reduced negative log-likelihood is written to output[0] only.
246
247\param output An array of doubles where the computation results will be stored
248\param nOut not used
249\note nEvents is the number of events to be processed (the dataMap size)
250\param dataMap A map containing spans with the input data for the computation
251**/
252void RooNLLVarNew::computeBatch(double *output, size_t /*nOut*/, RooFit::Detail::DataMap const &dataMap) const
253{
254 std::span<const double> weights = dataMap.at(_weightVar);
255 std::span<const double> weightsSumW2 = dataMap.at(_weightSquaredVar);
256
 // Binned case: delegate entirely, choosing the weight span by the weight-squared flag.
257 if (_binnedL) {
258 output[0] = computeBatchBinnedL(dataMap.at(&*_pdf), _weightSquared ? weightsSumW2 : weights);
259 return;
260 }
261
262 auto config = dataMap.config(this);
263
264 auto probas = dataMap.at(_pdf);
265
 // A weight span of size 1 stands for one weight shared by every event.
266 _sumWeight = weights.size() == 1 ? weights[0] * probas.size()
267 : RooBatchCompute::reduceSum(config, weights.data(), weights.size());
 // The squared-weight sum is only computed while the stored value is still 0.0.
 // NOTE(review): the scalar/vector choice below tests `weights.size()` rather
 // than `weightsSumW2.size()`; correct only if both spans always have the
 // same size -- confirm.
268 if (_expectedEvents && _weightSquared && _sumWeight2 == 0.0) {
269 _sumWeight2 = weights.size() == 1 ? weightsSumW2[0] * probas.size()
270 : RooBatchCompute::reduceSum(config, weightsSumW2.data(), weightsSumW2.size());
271 }
272
 // The offset probabilities are passed only with bin offsetting, otherwise an empty span.
273 auto nllOut = RooBatchCompute::reduceNLL(config, probas, _weightSquared ? weightsSumW2 : weights,
274 _doBinOffset ? dataMap.at(*_offsetPdf) : std::span<const double>{});
275
 // Report the anomaly counts returned by the reduction: one warning for large
 // values, one logged evaluation error per non-positive and per NaN value.
276 if (nllOut.nLargeValues > 0) {
277 oocoutW(&*_pdf, Eval) << "RooAbsPdf::getLogVal(" << _pdf->GetName()
278 << ") WARNING: top-level pdf has unexpectedly large values" << std::endl;
279 }
280 for (std::size_t i = 0; i < nllOut.nNonPositiveValues; ++i) {
281 _pdf->logEvalError("getLogVal() top-level p.d.f not greater than zero");
282 }
283 for (std::size_t i = 0; i < nllOut.nNaNValues; ++i) {
284 _pdf->logEvalError("getLogVal() top-level p.d.f evaluates to NaN");
285 }
286
 // Extended fits add the pdf's extended term, using element 0 of the expected-events span.
287 if (_expectedEvents) {
288 std::span<const double> expected = dataMap.at(*_expectedEvents);
289 nllOut.nllSum += _pdf->extendedTerm(_sumWeight, expected[0], _weightSquared ? _sumWeight2 : 0.0, _doBinOffset);
290 }
291
292 output[0] = finalizeResult(nllOut.nllSum, _sumWeight);
293}
294
// Removes the two weight variables from `params`. The two `true` arguments to
// remove() are `silent` and `matchByNameOnly`. `nset` and `stripDisconnected`
// are not used.
295void RooNLLVarNew::getParametersHook(const RooArgSet * /*nset*/, RooArgSet *params, bool /*stripDisconnected*/) const
296{
297 // strip away the special variables
298 params->remove(RooArgList{*_weightVar, *_weightSquaredVar}, true, true);
299}
300
301////////////////////////////////////////////////////////////////////////////////
302/// Sets the prefix for the special variables of this NLL, like weights or bin
303/// volumes.
304/// \param[in] prefix The prefix to add to the observables and weight names.
305void RooNLLVarNew::setPrefix(std::string const &prefix)
306{
 // Store the prefix.
307 _prefix = prefix;
308
 // NOTE(review): listing line 309 is absent from this extract; presumably the
 // call that re-applies the prefix to the special variable names -- confirm.
310}
311
// NOTE(review): the signature (listing line 312) and the first statements
// (listing lines 314-315) are absent from this extract; this is presumably
// RooNLLVarNew::resetWeightVarNames() -- confirm.
// Visible part: if an offset pdf proxy exists, rename it to `_prefix` followed
// by "_offset_pdf".
313{
316 if (_offsetPdf) {
317 _offsetPdf->SetName((_prefix + "_offset_pdf").c_str());
318 }
319}
320
321////////////////////////////////////////////////////////////////////////////////
322/// Toggles the weight square correction.
/// Stores `flag` in `_weightSquared`, which makes the compute functions read
/// the squared-weight span instead of the plain weights.
/// NOTE(review): the signature (listing line 323) is absent from this extract.
324{
325 _weightSquared = flag;
326}
327
// Stores `flag` in `_doOffset`, the switch tested by finalizeResult().
// NOTE(review): the signature (listing line 328) and listing line 331 are
// absent from this extract; presumably RooNLLVarNew::enableOffsetting(bool),
// with a further statement following the assignment -- confirm.
329{
330 _doOffset = flag;
332}
333
// Applies the simultaneous-pdf normalisation and the optional value offset to
// the accumulated sum `result` and returns the final value as a plain double.
// `weightSum` is the sum of weights that multiplies the log(_simCount) term.
// NOTE(review): the signature (listing line 334) is absent from this extract.
335{
336 // If part of simultaneous PDF normalize probability over
337 // number of simultaneous PDFs: -sum(log(p/n)) = -sum(log(p)) + N*log(n)
338 if (_simCount > 1) {
339 result += weightSum * std::log(static_cast<double>(_simCount));
340 }
341
342 // Check if value offset flag is set.
343 if (_doOffset) {
344
345 // If no offset is stored enable this feature now
 // "No offset stored" means sum and carry of `_offset` are both zero; the
 // current result becomes the offset only if it is itself non-zero.
346 if (_offset.Sum() == 0 && _offset.Carry() == 0 && (result.Sum() != 0 || result.Carry() != 0)) {
347 _offset = result;
348 }
349
350 // Subtract offset
 // Skipped while RooAbsReal::hideOffset() returns true.
351 if (!RooAbsReal::hideOffset()) {
352 result -= _offset;
353 }
354 }
355 return result.Sum();
356}
357
// Emits the code for this NLL into `ctx`: two global doubles (weight sum and
// result), an optional loop summing the weights, the per-event NLL loop, and
// the optional simultaneous-pdf and extended terms.
// Throws std::runtime_error for a binned likelihood whose pdf lacks the
// "BinnedLikelihoodActiveYields" attribute.
// NOTE(review): the signature (listing line 358) is absent from this extract.
359{
360 std::string weightSumName = ctx.makeValidVarName(GetName()) + "WeightSum";
361 std::string resName = ctx.makeValidVarName(GetName()) + "Result";
362 ctx.addResult(this, resName);
363 ctx.addToGlobalScope("double " + weightSumName + " = 0.0;\n");
364 ctx.addToGlobalScope("double " + resName + " = 0.0;\n");
365
 // The weight sum is only read by the extended term and the log(_simCount) term.
366 const bool needWeightSum = _expectedEvents || _simCount > 1;
367
368 if (needWeightSum) {
 // Separate loop scope that only accumulates the weights.
369 auto scope = ctx.beginLoop(this);
370 ctx.addToCodeBody(weightSumName + " += " + ctx.getResult(*_weightVar) + ";\n");
371 }
372 if (_simCount > 1) {
 // The count is written out as a double literal via std::to_string.
373 std::string simCountStr = std::to_string(static_cast<double>(_simCount));
374 ctx.addToCodeBody(resName + " += " + weightSumName + " * std::log(" + simCountStr + ");\n");
375 }
376
377 // Begin loop scope for the observables and weight variable. If the weight
378 // is a scalar, the context will ignore it for the loop scope. The closing
379 // brackets of the loop is written at the end of the scopes lifetime.
380 {
381 auto scope = ctx.beginLoop(this);
382 std::string const &weight = ctx.getResult(_weightVar.arg());
383 std::string const &pdfName = ctx.getResult(_pdf.arg());
384
385 if (_binnedL) {
386 // Since we only support uniform binning, bin width is the same for all.
387 if (!_pdf->getAttribute("BinnedLikelihoodActiveYields")) {
388 std::stringstream errorMsg;
389 errorMsg << "RooNLLVarNew::translate(): binned likelihood optimization is only supported when raw pdf "
390 "values can be interpreted as yields."
391 << " This is not the case for HistFactory models written with ROOT versions before 6.26.00";
392 coutE(InputArguments) << errorMsg.str() << std::endl;
393 throw std::runtime_error(errorMsg.str());
394 }
 // Same Poisson term as computeBatchBinnedL(), with the pdf value as mu.
395 std::string muName = pdfName;
396 ctx.addToCodeBody(this, resName + " += -1 * (-" + muName + " + " + weight + " * std::log(" + muName +
397 ") - TMath::LnGamma(" + weight + "+ 1));\n");
398 } else {
 // Unbinned case: subtract weight * log(pdf) per event.
399 ctx.addToCodeBody(this, resName + " -= " + weight + " * std::log(" + pdfName + ");\n");
400 }
401 }
402 if (_expectedEvents) {
 // Extended term: expected - weightSum * log(expected).
403 std::string expected = ctx.getResult(**_expectedEvents);
404 ctx.addToCodeBody(resName + " += " + expected + " - " + weightSumName + " * std::log(" + expected + ");\n");
405 }
406}
#define e(i)
Definition RSha256.hxx:103
ROOT::RRangeCast< T, false, Range_t > static_range_cast(Range_t &&coll)
TObject * clone(const char *newname) const override
#define oocoutW(o, a)
#define coutE(a)
#define N
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void char Point_t Rectangle_t WindowAttributes_t Float_t Float_t Float_t Int_t Int_t UInt_t UInt_t Rectangle_t result
char name[80]
Definition TGX11.cxx:110
char * Form(const char *fmt,...)
Formats a string in a circular formatting buffer.
Definition TString.cxx:2468
The Kahan summation is a compensated summation algorithm, which significantly reduces numerical error...
Definition Util.h:122
T Sum() const
Definition Util.h:240
T Carry() const
Definition Util.h:250
RooAbsArg is the common abstract base class for objects that represent a value and a "shape" in RooFi...
Definition RooAbsArg.h:80
RooFit::OwningPtr< RooArgSet > getObservables(const RooArgSet &set, bool valueOnly=true) const
Given a set of possible observables, return the observables that this PDF depends on.
void SetName(const char *name) override
Set the name of the TNamed.
bool addOwnedComponents(const RooAbsCollection &comps)
Take ownership of the contents of 'comps'.
bool getAttribute(const Text_t *name) const
Check if a named attribute is set. By default, all attributes are unset.
virtual bool remove(const RooAbsArg &var, bool silent=false, bool matchByNameOnly=false)
Remove the specified argument from our list.
Storage_t::size_type size() const
RooAbsArg * first() const
virtual std::unique_ptr< RooAbsReal > createExpectedEventsFunc(const RooArgSet *nset) const
Returns an object that represents the expected number of events for a given normalization set,...
double extendedTerm(double sumEntries, double expected, double sumEntriesW2=0.0, bool doOffset=false) const
RooAbsReal is the common abstract base class for objects that represent a real value and implements f...
Definition RooAbsReal.h:59
virtual std::list< double > * binBoundaries(RooAbsRealLValue &obs, double xlo, double xhi) const
Retrieve bin boundaries if this distribution is binned in obs.
double getVal(const RooArgSet *normalisationSet=nullptr) const
Evaluate object.
Definition RooAbsReal.h:103
static bool hideOffset()
void logEvalError(const char *message, const char *serverValueString=nullptr) const
Log evaluation error message.
RooArgList is a container object that can hold multiple RooAbsArg objects.
Definition RooArgList.h:22
RooArgSet is a container object that can hold multiple RooAbsArg objects.
Definition RooArgSet.h:55
RooConstVar represents a constant real-valued object.
Definition RooConstVar.h:23
The RooDataHist is a container class to hold N-dimensional binned data.
Definition RooDataHist.h:39
void add(const RooArgSet &row, double wgt=1.0) override
Add wgt to the bin content enclosed by the coordinates passed in row.
Definition RooDataHist.h:66
A class to maintain the context for squashing of RooFit models into code.
void addResult(RooAbsArg const *key, std::string const &value)
A function to save an expression that includes/depends on the result of the input node.
std::string makeValidVarName(TString in) const
Transform a string into a valid C++ variable name by replacing forbidden characters.
void addToCodeBody(RooAbsArg const *klass, std::string const &in)
Adds the input string to the squashed code body.
std::string const & getResult(RooAbsArg const &arg)
Gets the result for the given node using the node name.
void addToGlobalScope(std::string const &str)
Adds the given string to the string block that will be emitted at the top of the squashed function.
std::unique_ptr< LoopScope > beginLoop(RooAbsArg const *in)
Create a RAII scope for iterating over vector observables.
RooBatchCompute::Config config(RooAbsArg const *arg) const
Definition DataMap.cxx:40
std::span< const double > at(RooAbsArg const *arg, RooAbsArg const *caller=nullptr)
Definition DataMap.cxx:22
RooHistPdf implements a probability density function sampled from a multidimensional histogram.
Definition RooHistPdf.h:30
This is a simple class designed to produce the nll values needed by the fitter.
ROOT::Math::KahanSum< double > _offset
! Offset as KahanSum to avoid loss of precision
void enableOffsetting(bool) override
std::unique_ptr< RooTemplateProxy< RooAbsReal > > _expectedEvents
RooNLLVarNew(const char *name, const char *title, RooAbsPdf &pdf, RooArgSet const &observables, bool isExtended, RooFit::OffsetMode offsetMode)
Construct a RooNLLVarNew.
bool _weightSquared
void translate(RooFit::Detail::CodeSquashContext &ctx) const override
This function defines a translation for each RooAbsReal based object that can be used to express the ...
void applyWeightSquared(bool flag) override
Toggles the weight square correction.
void getParametersHook(const RooArgSet *nset, RooArgSet *list, bool stripDisconnected) const override
void enableBinOffsetting(bool on=true)
std::unique_ptr< RooTemplateProxy< RooAbsPdf > > _offsetPdf
void computeBatch(double *output, size_t nOut, RooFit::Detail::DataMap const &) const override
Compute multiple negative logs of probabilities.
void setPrefix(std::string const &prefix)
Sets the prefix for the special variables of this NLL, like weights or bin volumes.
double _sumWeight
void fillBinWidthsFromPdfBoundaries(RooAbsReal const &pdf, RooArgSet const &observables)
void resetWeightVarNames()
std::string _prefix
std::vector< double > _binw
RooTemplateProxy< RooAbsPdf > _pdf
double computeBatchBinnedL(std::span< const double > preds, std::span< const double > weights) const
double _sumWeight2
RooTemplateProxy< RooAbsReal > _weightVar
static constexpr const char * weightVarName
double finalizeResult(ROOT::Math::KahanSum< double > result, double weightSum) const
static constexpr const char * weightVarNameSumW2
RooTemplateProxy< RooAbsReal > _weightSquaredVar
RooRealVar represents a variable that can be changed from the outside.
Definition RooRealVar.h:37
const T & arg() const
Return reference to object held in proxy.
const char * GetName() const override
Returns name of object.
Definition TNamed.h:47
Mother of all ROOT objects.
Definition TObject.h:41
double reduceSum(Config cfg, InputArr input, size_t n)
ReduceNLLOutput reduceNLL(Config cfg, std::span< const double > probas, std::span< const double > weights, std::span< const double > offsetProbas)
OffsetMode
For setting the offset mode with the Offset() command argument to RooAbsPdf::fitTo()
void evaluate(typename Architecture_t::Tensor_t &A, EActivationFunction f)
Apply the given activation function to each value in the given tensor A.
Definition Functions.h:98
Double_t LnGamma(Double_t z)
Computation of ln[gamma(z)] for all z.
Definition TMath.cxx:509
static void output()