Logo ROOT  
Reference Guide
RooNLLVarNew.cxx
Go to the documentation of this file.
1/*
2 * Project: RooFit
3 * Authors:
4 * Jonas Rembser, CERN 2021
5 * Emmanouil Michalainas, CERN 2021
6 *
7 * Copyright (c) 2021, CERN
8 *
9 * Redistribution and use in source and binary forms,
10 * with or without modification, are permitted according to the terms
11 * listed in LICENSE (http://roofit.sourceforge.net/license.txt)
12 */
13
14/**
15\file RooNLLVarNew.cxx
16\class RooNLLVarNew
17\ingroup Roofitcore
18
19This is a simple class designed to produce the NLL values needed by the fitter.
20In contrast to the `RooNLLVar` class, any logic except the bare minimum has been
21transferred away to other classes, like the `RooFitDriver`. This class also calls
22functions from the `RooBatchCompute` library to provide faster computation times.
23**/
24
25#include <RooNLLVarNew.h>
26
27#include <RooAddition.h>
28#include <RooFormulaVar.h>
29#include <RooNaNPacker.h>
30#include <RooRealSumPdf.h>
31#include <RooProdPdf.h>
32#include <RooRealVar.h>
34
35#include <ROOT/StringUtils.hxx>
36
37#include <TClass.h>
38#include <TMath.h>
39#include <Math/Util.h>
40#include <TMath.h>
41
42#include <numeric>
43#include <stdexcept>
44#include <vector>
45
46using namespace ROOT::Experimental;
47
48// Declare constexpr static members to make them available if odr-used in C++14.
49constexpr const char *RooNLLVarNew::weightVarName;
50constexpr const char *RooNLLVarNew::weightVarNameSumW2;
51
52namespace {
53
54template <class Input>
55double kahanSum(Input const &input)
56{
57 return ROOT::Math::KahanSum<double, 4u>::Accumulate(input.begin(), input.end()).Sum();
58}
59
60RooArgSet getObservablesInPdf(RooAbsPdf const &pdf, RooArgSet const &observables)
61{
62 RooArgSet observablesInPdf;
63 pdf.getObservables(&observables, observablesInPdf);
64 return observablesInPdf;
65}
66
67} // namespace
68
69/** Construct a RooNLLVarNew
70\param name the name
71\param title the title
72\param pdf The pdf for which the nll is computed
73\param observables The observables of the pdf
74\param isExtended Set to true if this is an extended fit
\param doOffset If true, computeBatch() stores the first non-zero result as an offset and subtracts it from the output
\throw std::runtime_error if a binned likelihood is detected and the pdf does not have exactly one observable
75**/
76RooNLLVarNew::RooNLLVarNew(const char *name, const char *title, RooAbsPdf &pdf, RooArgSet const &observables,
77 bool isExtended, bool doOffset)
78 : RooAbsReal(name, title), _pdf{"pdf", "pdf", this, pdf}, _observables{getObservablesInPdf(pdf, observables)},
79 _isExtended{isExtended}, _doOffset{doOffset},
80 _weightVar{"weightVar", "weightVar", this, *new RooRealVar(weightVarName, weightVarName, 1.0), true, false, true},
81 _weightSquaredVar{weightVarNameSumW2,
82 weightVarNameSumW2,
83 this,
 // NOTE(review): the name "weightSquardVar" differs from the title "weightSquaredVar" - looks like a typo; confirm.
84 *new RooRealVar("weightSquardVar", "weightSquaredVar", 1.0),
85 true,
86 false,
87 true}
88{
 // Start from the pdf that was passed in; replaced below when a binned RooRealSumPdf component is found.
89 RooAbsPdf *actualPdf = &pdf;
90
91 if (pdf.getAttribute("BinnedLikelihood") && pdf.IsA()->InheritsFrom(RooRealSumPdf::Class())) {
92 // Simplest case: top-level of component is a RooRealSumPdf
93 _binnedL = true;
94 } else if (pdf.IsA()->InheritsFrom(RooProdPdf::Class())) {
95 // Default case: top-level pdf is a product of RooRealSumPdf and other pdfs
 // The loop has no break: if several components match, the last one is kept.
96 for (RooAbsArg *component : static_cast<RooProdPdf &>(pdf).pdfList()) {
97 if (component->getAttribute("BinnedLikelihood") && component->IsA()->InheritsFrom(RooRealSumPdf::Class())) {
98 actualPdf = static_cast<RooAbsPdf *>(component);
99 _binnedL = true;
100 }
101 }
102 }
103
 // Point the pdf proxy at the selected component instead of the full product.
104 if (actualPdf != &pdf) {
105 _pdf.setArg(*actualPdf);
106 }
107
 // For binned likelihoods, cache the bin widths as differences of consecutive bin boundaries.
108 if (_binnedL) {
109 if (_observables.size() != 1) {
110 throw std::runtime_error("BinnedPdf optimization only works with a 1D pdf.");
111 } else {
112 auto *var = static_cast<RooRealVar *>(_observables.first());
 // NOTE(review): `boundaries` is a raw pointer that is dereferenced without a null check and is never
 // deleted in this block - confirm who owns the list returned by binBoundaries().
113 std::list<double> *boundaries = actualPdf->binBoundaries(*var, var->getMin(), var->getMax());
114 std::list<double>::iterator biter = boundaries->begin();
 // N boundaries delimit N - 1 bins.
115 _binw.resize(boundaries->size() - 1);
116 double lastBound = (*biter);
117 ++biter;
118 int ibin = 0;
119 while (biter != boundaries->end()) {
120 _binw[ibin] = (*biter) - lastBound;
121 lastBound = (*biter);
122 ibin++;
123 ++biter;
124 }
125 }
126 }
127
129}
130
132 : RooAbsReal(other, name), _pdf{"pdf", this, other._pdf}, _observables{other._observables},
133 _isExtended{other._isExtended}, _weightSquared{other._weightSquared}, _binnedL{other._binnedL},
134 _prefix{other._prefix}, _weightVar{"weightVar", this, other._weightVar}, _weightSquaredVar{"weightSquaredVar",
135 this,
136 other._weightSquaredVar}
137{
 // NOTE(review): _doOffset, _binw and _offset do not appear in the initializer list above, although
 // _binnedL is copied and computeBatch() indexes _binw whenever _binnedL is set - confirm that a copy
 // is meant to start without them.
138}
139
140/** Compute the negative log-likelihood summed over all events and store it in `output[0]`.
141
142\param output An array of doubles where the computation results will be stored; only `output[0]` is written
143\param nOut not used
144\note nEvents, the number of events to be processed, is the size of the span stored in `dataMap` for the pdf
145\param dataMap A map containing spans with the input data for the computation
146**/
147void RooNLLVarNew::computeBatch(cudaStream_t * /*stream*/, double *output, size_t /*nOut*/,
148 RooFit::Detail::DataMap const &dataMap) const
149{
150 std::size_t nEvents = dataMap.at(_pdf).size();
151
152 auto weights = dataMap.at(_weightVar);
153 auto weightsSumW2 = dataMap.at(_weightSquaredVar);
 // Use the squared-weight span when the weight-squared correction is switched on.
154 auto weightSpan = _weightSquared ? weightsSumW2 : weights;
155
 // Binned case: one Poisson term per bin. This branch returns before the unbinned code further down.
156 if (_binnedL) {
158 ROOT::Math::KahanSum<double> sumWeightKahanSum{0.0};
159 auto preds = dataMap.at(&*_pdf);
160
161 for (std::size_t i = 0; i < nEvents; ++i) {
162
 // NOTE(review): weightSpan[i] and _binw[i] are indexed directly, without the size-1 broadcast used in
 // the unbinned loop below - presumably binned data always provides one entry per bin; confirm.
163 double eventWeight = weightSpan[i];
164
165 // Calculate log(Poisson(N|mu)) for this bin
166 double N = eventWeight;
167 double mu = preds[i] * _binw[i];
168
169 if (mu <= 0 && N > 0) {
170
171 // Catch error condition: data present where zero events are predicted
172 logEvalError(Form("Observed %f events in bin %lu with zero event yield", N, (unsigned long)i));
173
174 } else if (std::abs(mu) < 1e-10 && std::abs(N) < 1e-10) {
175
176 // Special case: log(Poisson(0|0)) = 0, but the general formula below cannot be used because
177 // log(mu) is not finite for mu = 0. Nothing is added to the result since the term is 0.
178
179 } else {
180
181 result += -1 * (-mu + N * log(mu) - TMath::LnGamma(N + 1));
182 sumWeightKahanSum += eventWeight;
183 }
184 }
185
186 result += sumWeightKahanSum.Sum();
187
188 // Check if value offset flag is set.
189 if (_doOffset) {
190
191 // If no offset is stored enable this feature now
192 if (_offset == 0 && result != 0) {
193 _offset = result;
194 }
195
196 // Subtract offset
197 result -= _offset;
198 }
199
200 output[0] = result.Sum();
201
202 return;
203 }
204
 // Unbinned case: sum of -weight * log(probability) over all events.
205 auto probas = dataMap.at(_pdf);
206
207 _logProbasBuffer.resize(nEvents);
208 (*_pdf).getLogProbabilities(probas, _logProbasBuffer.data());
209
 // The weight sums are only computed while the cached value is still 0. A weight span of size 1 is
 // treated as a single weight shared by all events.
210 if (_isExtended && _sumWeight == 0.0) {
211 _sumWeight = weights.size() == 1 ? weights[0] * nEvents : kahanSum(weights);
212 }
213 if (_isExtended && _weightSquared && _sumWeight2 == 0.0) {
 // NOTE(review): the size check uses `weights`, not `weightsSumW2` - confirm both spans always have the same size.
214 _sumWeight2 = weights.size() == 1 ? weightsSumW2[0] * nEvents : kahanSum(weightsSumW2);
215 }
216
218 RooNaNPacker packedNaN(0.f);
219
220 for (std::size_t i = 0; i < nEvents; ++i) {
221
222 double eventWeight = weightSpan.size() > 1 ? weightSpan[i] : weightSpan[0];
 // Skip events whose squared weight is exactly zero.
223 if (0. == eventWeight * eventWeight)
224 continue;
225
226 const double term = -eventWeight * _logProbasBuffer[i];
227
228 kahanProb.Add(term);
229 packedNaN.accumulate(term);
230 }
231
232 if (packedNaN.getPayload() != 0.) {
233 // Some events with evaluation errors. Return "badness" of errors.
234 kahanProb = packedNaN.getNaNWithPayload();
235 }
236
 // Extended fits add the pdf's extended term, evaluated with the cached weight sums.
237 if (_isExtended) {
238 assert(_sumWeight != 0.0);
239 double expected = _pdf->expectedEvents(&_observables);
240 kahanProb += _pdf->extendedTerm(_sumWeight, expected, _weightSquared ? _sumWeight2 : 0.0);
241 }
242
243 // Check if value offset flag is set.
244 if (_doOffset) {
245
246 // If no offset is stored enable this feature now
247 if (_offset == 0 && kahanProb != 0) {
248 _offset = kahanProb;
249 }
250
251 // Subtract offset
252 kahanProb -= _offset;
253 }
254
255 output[0] = kahanProb.Sum();
256}
257
259{
 // Return the value currently held in the _value cache; nothing is recomputed here.
260 return _value;
261}
262
263void RooNLLVarNew::getParametersHook(const RooArgSet * /*nset*/, RooArgSet *params, bool /*stripDisconnected*/) const
264{
265 // strip away the observables and weights
266 params->remove(_observables, true, true);
267 params->remove(RooArgList{*_weightVar, *_weightSquaredVar}, true, true);
268}
269
270////////////////////////////////////////////////////////////////////////////////
271/// Replaces all observables and the weight variable of this NLL with clones
272/// that only differ by a prefix added to the names. Used for simultaneous fits.
273/// \return A RooArgSet with the new observable args.
274/// \param[in] prefix The prefix to add to the observables and weight names.
276{
277 _prefix = prefix;
278
 // Snapshot the current observables into obsClones.
279 RooArgSet obsSet{_observables};
280 RooArgSet obsClones;
281 obsSet.snapshot(obsClones);
 // Record the original name in an ORIGNAME:<name> attribute, then rename with the prefix and set constant.
282 for (auto *arg : static_range_cast<RooRealVar *>(obsClones)) {
283 arg->setAttribute((std::string("ORIGNAME:") + arg->GetName()).c_str());
284 arg->SetName((prefix + arg->GetName()).c_str());
285 arg->setConstant();
286 }
 // Redirect servers to the renamed clones (mustReplaceAll = false, nameChange = true).
287 recursiveRedirectServers(obsClones, false, true);
288
 // Taken before obsClones is moved into addOwnedComponents() below; this is the set that is returned.
289 RooArgSet newObservables{obsClones};
290
292 _observables.add(obsClones);
293
294 addOwnedComponents(std::move(obsClones));
295
297
298 return newObservables;
299}
300
302{
305}
306
307////////////////////////////////////////////////////////////////////////////////
308/// Toggles the weight square correction. While the flag is set, computeBatch()
/// reads the squared-weight span instead of the plain weights.
/// \param[in] flag New value stored in _weightSquared.
310{
311 _weightSquared = flag;
312}
313
314std::unique_ptr<RooArgSet>
315RooNLLVarNew::fillNormSetForServer(RooArgSet const & /*normSet*/, RooAbsArg const & /*server*/) const
316{
317 if (_binnedL) {
318 return std::make_unique<RooArgSet>();
319 }
320 return nullptr;
321}
#define e(i)
Definition: RSha256.hxx:103
#define N
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void input
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void char Point_t Rectangle_t WindowAttributes_t Float_t Float_t Float_t Int_t Int_t UInt_t UInt_t Rectangle_t result
char name[80]
Definition: TGX11.cxx:110
char * Form(const char *fmt,...)
Formats a string in a circular formatting buffer.
Definition: TString.cxx:2452
RooTemplateProxy< RooAbsReal > _weightVar
Definition: RooNLLVarNew.h:69
void computeBatch(cudaStream_t *, double *output, size_t nOut, RooFit::Detail::DataMap const &) const override
Compute multiple negative logs of probabilities.
void applyWeightSquared(bool flag) override
Toggles the weight square correction.
void getParametersHook(const RooArgSet *nset, RooArgSet *list, bool stripDisconnected) const override
ROOT::Math::KahanSum< double > _offset
! Offset as KahanSum to avoid loss of precision
Definition: RooNLLVarNew.h:73
RooTemplateProxy< RooAbsPdf > _pdf
Definition: RooNLLVarNew.h:60
std::vector< double > _logProbasBuffer
!
Definition: RooNLLVarNew.h:72
RooTemplateProxy< RooAbsReal > _weightSquaredVar
Definition: RooNLLVarNew.h:70
std::vector< double > _binw
!
Definition: RooNLLVarNew.h:71
static constexpr const char * weightVarName
Definition: RooNLLVarNew.h:32
std::unique_ptr< RooArgSet > fillNormSetForServer(RooArgSet const &normSet, RooAbsArg const &server) const override
Fills a RooArgSet to be used as the normalization set for a server, given a normalization set for thi...
RooArgSet prefixObservableAndWeightNames(std::string const &prefix)
Replaces all observables and the weight variable of this NLL with clones that only differ by a prefix...
static constexpr const char * weightVarNameSumW2
Definition: RooNLLVarNew.h:33
double evaluate() const override
Evaluate this PDF / function / constant. Needs to be overridden by all derived classes.
T Sum() const
Definition: Util.h:240
static KahanSum< T, N > Accumulate(Iterator begin, Iterator end, T initialValue=T{})
Iterate over a range and return an instance of a KahanSum.
Definition: Util.h:211
void Add(T x)
Single-element accumulation. Will not vectorise.
Definition: Util.h:165
RooAbsArg is the common abstract base class for objects that represent a value and a "shape" in RooFi...
Definition: RooAbsArg.h:71
bool recursiveRedirectServers(const RooAbsCollection &newServerList, bool mustReplaceAll=false, bool nameChange=false, bool recurseInNewSet=true)
Recursively replace all servers with the new servers in newSet.
Definition: RooAbsArg.cxx:1165
RooArgSet * getObservables(const RooArgSet &set, bool valueOnly=true) const
Given a set of possible observables, return the observables that this PDF depends on.
Definition: RooAbsArg.h:293
void SetName(const char *name) override
Set the name of the TNamed.
Definition: RooAbsArg.cxx:2314
bool addOwnedComponents(const RooAbsCollection &comps)
Take ownership of the contents of 'comps'.
Definition: RooAbsArg.cxx:2185
bool getAttribute(const Text_t *name) const
Check if a named attribute is set. By default, all attributes are unset.
Definition: RooAbsArg.cxx:269
virtual bool remove(const RooAbsArg &var, bool silent=false, bool matchByNameOnly=false)
Remove the specified argument from our list.
virtual bool add(const RooAbsArg &var, bool silent=false)
Add the specified argument to list.
Storage_t::size_type size() const
RooAbsArg * first() const
void clear()
Clear contents. If the collection is owning, it will also delete the contents.
virtual double expectedEvents(const RooArgSet *nset) const
Return expected number of events to be used in calculation of extended likelihood.
Definition: RooAbsPdf.cxx:3101
TClass * IsA() const override
Definition: RooAbsPdf.h:391
double extendedTerm(double sumEntries, double expected, double sumEntriesW2=0.0) const
Definition: RooAbsPdf.cxx:785
RooAbsReal is the common abstract base class for objects that represent a real value and implements f...
Definition: RooAbsReal.h:62
virtual std::list< double > * binBoundaries(RooAbsRealLValue &obs, double xlo, double xhi) const
Retrieve bin boundaries if this distribution is binned in obs.
double _value
Cache for current value of object.
Definition: RooAbsReal.h:480
void logEvalError(const char *message, const char *serverValueString=nullptr) const
Log evaluation error message.
RooArgList is a container object that can hold multiple RooAbsArg objects.
Definition: RooArgList.h:22
RooArgSet is a container object that can hold multiple RooAbsArg objects.
Definition: RooArgSet.h:56
RooArgSet * snapshot(bool deepCopy=true) const
Use RooAbsCollection::snapshot(), but return as RooArgSet.
Definition: RooArgSet.h:179
auto & at(RooAbsArg const *arg, RooAbsArg const *=nullptr)
Definition: DataMap.h:88
RooProdPdf is an efficient implementation of a product of PDFs of the form.
Definition: RooProdPdf.h:33
static TClass * Class()
static TClass * Class()
RooRealVar represents a variable that can be changed from the outside.
Definition: RooRealVar.h:40
bool setArg(T &newRef)
Change object held in proxy into newRef.
Bool_t InheritsFrom(const char *cl) const override
Return kTRUE if this class inherits from a class with name "classname".
Definition: TClass.cxx:4863
RVec< PromoteType< T > > abs(const RVec< T > &v)
Definition: RVec.hxx:1756
RVec< PromoteType< T > > log(const RVec< T > &v)
Definition: RVec.hxx:1765
void probas(TString dataset, TString fin="TMVA.root", Bool_t useTMVAStyle=kTRUE)
Double_t LnGamma(Double_t z)
Computation of ln[gamma(z)] for all z.
Definition: TMath.cxx:509
Little struct that can pack a float into the unused bits of the mantissa of a NaN double.
Definition: RooNaNPacker.h:28
float getPayload() const
Retrieve packed float.
Definition: RooNaNPacker.h:85
double getNaNWithPayload() const
Retrieve a NaN with the current float payload packed into the mantissa.
Definition: RooNaNPacker.h:90
void accumulate(double val)
Accumulate a packed float from another NaN into this.
Definition: RooNaNPacker.h:57
static void output()