Logo ROOT  
Reference Guide
 
Loading...
Searching...
No Matches
BatchModeHelpers.cxx
Go to the documentation of this file.
1/*
2 * Project: RooFit
3 * Authors:
4 * Jonas Rembser, CERN 2021
5 *
6 * Copyright (c) 2021, CERN
7 *
8 * Redistribution and use in source and binary forms,
9 * with or without modification, are permitted according to the terms
10 * listed in LICENSE (http://roofit.sourceforge.net/license.txt)
11 */
12
14
15#include <RooAbsData.h>
16#include <RooAbsPdf.h>
17#include <RooAddition.h>
18#include <RooBatchCompute.h>
19#include <RooBinSamplingPdf.h>
20#include <RooConstraintSum.h>
21#include <RooDataSet.h>
22#include <RooFitDriver.h>
23#include <RooNLLVarNew.h>
24#include <RooRealVar.h>
25#include <RooSimultaneous.h>
26#include <RooFitDriver.h>
27
28#include <ROOT/StringUtils.hxx>
29
30#include <string>
31
34
35namespace {
36
37std::unique_ptr<RooAbsArg> createSimultaneousNLL(RooSimultaneous const &simPdf, RooArgSet &observables, bool isExtended,
38 std::string const &rangeName, bool doOffset)
39{
40 // Prepare the NLLTerms for each component
41 RooArgList nllTerms;
42 RooArgSet newObservables;
43 for (auto const &catItem : simPdf.indexCat()) {
44 auto const &catName = catItem.first;
45 if (RooAbsPdf *pdf = simPdf.getPdf(catName.c_str())) {
46 auto name = std::string("nll_") + pdf->GetName();
47 auto nll = std::make_unique<RooNLLVarNew>(name.c_str(), name.c_str(), *pdf, observables, isExtended, rangeName,
48 doOffset);
49 // Rename the observables and weights
50 newObservables.add(nll->prefixObservableAndWeightNames(std::string("_") + catName + "_"));
51 nllTerms.addOwned(std::move(nll));
52 }
53 }
54
55 observables.clear();
56 observables.add(newObservables);
57
58 // Time to sum the NLLs
59 auto nll = std::make_unique<RooAddition>("mynll", "mynll", nllTerms);
60 nll->addOwnedComponents(std::move(nllTerms));
61 return nll;
62}
63
64class RooAbsRealWrapper final : public RooAbsReal {
65public:
66 RooAbsRealWrapper(std::unique_ptr<RooFitDriver> driver, std::string const &rangeName,
67 RooAbsCategory const *indexCatForSplitting, bool takeGlobalObservablesFromData)
68 : RooAbsReal{"RooFitDriverWrapper", "RooFitDriverWrapper"}, _driver{std::move(driver)},
69 _topNode("topNode", "top node", this, _driver->topNode()), _rangeName{rangeName},
70 _indexCatForSplitting{indexCatForSplitting}, _takeGlobalObservablesFromData{takeGlobalObservablesFromData}
71 {
72 }
73
74 RooAbsRealWrapper(const RooAbsRealWrapper &other, const char *name = nullptr)
75 : RooAbsReal{other, name}, _driver{other._driver},
76 _topNode("topNode", this, other._topNode), _data{other._data}, _parameters{other._parameters},
77 _rangeName{other._rangeName}, _indexCatForSplitting{other._indexCatForSplitting},
78 _takeGlobalObservablesFromData{other._takeGlobalObservablesFromData}
79 {
80 }
81
82 TObject *clone(const char *newname) const override { return new RooAbsRealWrapper(*this, newname); }
83
84 double defaultErrorLevel() const override { return _driver->topNode().defaultErrorLevel(); }
85
86 bool getParameters(const RooArgSet *observables, RooArgSet &outputSet, bool /*stripDisconnected*/) const override
87 {
88 outputSet.add(_parameters);
89 if (observables) {
90 outputSet.remove(*observables);
91 }
92 // If we take the global observables as data, we have to return these as
93 // parameters instead of the parameters in the model. Otherwise, the
94 // constant parameters in the fit result that are global observables will
95 // not have the right values.
96 if (_takeGlobalObservablesFromData && _data->getGlobalObservables()) {
97 outputSet.replace(*_data->getGlobalObservables());
98 }
99 return false;
100 }
101
102 bool setData(RooAbsData &data, bool /*cloneData*/) override
103 {
104 _data = &data;
105 _driver->topNode().getParameters(_data->get(), _parameters, true);
106 _driver->setData(*_data, _rangeName, _indexCatForSplitting, /*skipZeroWeights=*/true,
107 _takeGlobalObservablesFromData);
108 return true;
109 }
110
111 double getValV(const RooArgSet *) const override { return evaluate(); }
112
113 void applyWeightSquared(bool flag) override
114 {
115 const_cast<RooAbsReal &>(_driver->topNode()).applyWeightSquared(flag);
116 }
117
118protected:
119 double evaluate() const override { return _driver ? _driver->getVal() : 0.0; }
120
121private:
122 std::shared_ptr<RooFitDriver> _driver;
123 RooRealProxy _topNode;
124 RooAbsData *_data = nullptr;
125 RooArgSet _parameters;
126 std::string _rangeName;
127 RooAbsCategory const *_indexCatForSplitting = nullptr;
128 const bool _takeGlobalObservablesFromData;
129};
130
131} // namespace
132
133std::unique_ptr<RooAbsReal>
134RooFit::BatchModeHelpers::createNLL(RooAbsPdf &pdf, RooAbsData &data, std::unique_ptr<RooAbsReal> &&constraints,
135 std::string const &rangeName, std::string const &addCoefRangeName,
136 RooArgSet const &projDeps, bool isExtended, double integrateOverBinsPrecision,
137 RooFit::BatchModeOption batchMode, bool doOffset,
138 bool takeGlobalObservablesFromData)
139{
140 // Clone PDF and reattach to original parameters
141 std::unique_ptr<RooAbsPdf> pdfClone{static_cast<RooAbsPdf *>(pdf.cloneTree())};
142 {
143 RooArgSet origParams;
144 pdf.getParameters(data.get(), origParams);
145 pdfClone->recursiveRedirectServers(origParams);
146 }
147
148 RooArgSet observables;
149 RooArgSet obsClones;
150 pdfClone->getObservables(data.get(), obsClones);
151 pdf.getObservables(data.get(), observables);
152 observables.remove(projDeps, true, true);
153 obsClones.remove(projDeps, true, true);
154
155 oocxcoutI(&pdf, Fitting) << "RooAbsPdf::fitTo(" << pdf.GetName()
156 << ") fixing normalization set for coefficient determination to observables in data"
157 << "\n";
158 pdfClone->fixAddCoefNormalization(obsClones, false);
159 pdf.fixAddCoefNormalization(observables, false);
160 if (!addCoefRangeName.empty()) {
161 oocxcoutI(&pdf, Fitting) << "RooAbsPdf::fitTo(" << pdf.GetName()
162 << ") fixing interpretation of coefficients of any component to range "
163 << addCoefRangeName << "\n";
164 pdfClone->fixAddCoefRange(addCoefRangeName.c_str(), false);
165 pdf.fixAddCoefRange(addCoefRangeName.c_str(), false);
166 }
167
168 // Deal with the IntegrateBins argument
169 RooArgList binSamplingPdfs;
170 std::unique_ptr<RooAbsPdf> wrappedPdf = RooBinSamplingPdf::create(*pdfClone, data, integrateOverBinsPrecision);
171 RooAbsPdf &finalPdf = wrappedPdf ? *wrappedPdf : *pdfClone;
172 if (wrappedPdf) {
173 binSamplingPdfs.addOwned(std::move(wrappedPdf));
174 }
175 // Done dealing with the IntegrateBins option
176
177 RooArgList nllTerms;
178
179 RooAbsCategory const *indexCatForSplitting = nullptr;
180 if (auto simPdf = dynamic_cast<RooSimultaneous *>(&finalPdf)) {
181 indexCatForSplitting = &simPdf->indexCat();
182 simPdf->wrapPdfsInBinSamplingPdfs(data, integrateOverBinsPrecision);
183 // Warning! This mutates "obsClones"
184 nllTerms.addOwned(createSimultaneousNLL(*simPdf, obsClones, isExtended, rangeName, doOffset));
185 } else {
186 nllTerms.addOwned(std::make_unique<RooNLLVarNew>("RooNLLVarNew", "RooNLLVarNew", finalPdf, obsClones, isExtended,
187 rangeName, doOffset));
188 }
189 if (constraints) {
190 nllTerms.addOwned(std::move(constraints));
191 }
192
193 std::string nllName = std::string("nll_") + pdfClone->GetName() + "_" + data.GetName();
194 auto nll = std::make_unique<RooAddition>(nllName.c_str(), nllName.c_str(), nllTerms);
195 nll->addOwnedComponents(std::move(binSamplingPdfs));
196 nll->addOwnedComponents(std::move(nllTerms));
197
198 auto driver = std::make_unique<RooFitDriver>(*nll, obsClones, batchMode);
199
200 // Set the fitrange attribute so that RooPlot can automatically plot the fitting range by default
201 if (!rangeName.empty()) {
202
203 std::string fitrangeValue;
204 auto subranges = ROOT::Split(rangeName, ",");
205 for (auto const &subrange : subranges) {
206 if (subrange.empty())
207 continue;
208 std::string fitrangeValueSubrange = std::string("fit_") + nll->GetName();
209 if (subranges.size() > 1) {
210 fitrangeValueSubrange += "_" + subrange;
211 }
212 fitrangeValue += fitrangeValueSubrange + ",";
213 for (auto *observable : static_range_cast<RooRealVar *>(obsClones)) {
214 observable->setRange(fitrangeValueSubrange.c_str(), observable->getMin(subrange.c_str()),
215 observable->getMax(subrange.c_str()));
216 }
217 }
218 pdf.setStringAttribute("fitrange", fitrangeValue.substr(0, fitrangeValue.size() - 1).c_str());
219 }
220
221 auto driverWrapper = std::make_unique<RooAbsRealWrapper>(std::move(driver), rangeName, indexCatForSplitting,
222 takeGlobalObservablesFromData);
223 driverWrapper->setData(data, false);
224 driverWrapper->addOwnedComponents(std::move(nll));
225 driverWrapper->addOwnedComponents(std::move(pdfClone));
226
227 return driverWrapper;
228}
229
231{
232 // We have to exit early if the message stream is not active. Otherwise it's
233 // possible that this funciton skips logging because it thinks it has
234 // already logged, but actually it didn't.
235 if (!RooMsgService::instance().isActive(static_cast<RooAbsArg *>(nullptr), RooFit::Fitting, RooFit::INFO)) {
236 return;
237 }
238
239 // Don't repeat logging architecture info if the batchMode option didn't change
240 {
241 // Second element of pair tracks whether this function has already been called
242 static std::pair<RooFit::BatchModeOption, bool> lastBatchMode;
243 if (lastBatchMode.second && lastBatchMode.first == batchMode)
244 return;
245 lastBatchMode = {batchMode, true};
246 }
247
248 auto log = [](std::string_view message) {
249 oocxcoutI(static_cast<RooAbsArg *>(nullptr), Fitting) << message << std::endl;
250 };
251
253 throw std::runtime_error(std::string("In: ") + __func__ + "(), " + __FILE__ + ":" + __LINE__ +
254 ": Cuda implementation of the computing library is not available\n");
255 }
257 log("using generic CPU library compiled with no vectorizations");
258 } else {
259 log(std::string("using CPU computation library compiled with -m") +
260 RooBatchCompute::dispatchCPU->architectureName());
261 }
262 if (batchMode == RooFit::BatchModeOption::Cuda) {
263 log("using CUDA computation library");
264 }
265}
#define oocxcoutI(o, a)
char name[80]
Definition TGX11.cxx:110
RooAbsArg is the common abstract base class for objects that represent a value and a "shape" in RooFi...
Definition RooAbsArg.h:69
virtual RooAbsArg * cloneTree(const char *newname=0) const
Clone tree expression of objects.
RooArgSet * getObservables(const RooArgSet &set, Bool_t valueOnly=kTRUE) const
Given a set of possible observables, return the observables that this PDF depends on.
Definition RooAbsArg.h:309
void setStringAttribute(const Text_t *key, const Text_t *value)
Associate string 'value' to this object under key 'key'.
RooArgSet * getParameters(const RooAbsData *data, bool stripDisconnected=true) const
Create a list of leaf nodes in the arg tree starting with ourself as top node that don't match any of...
RooAbsCategory is the base class for objects that represent a discrete value with a finite number of ...
virtual Bool_t replace(const RooAbsArg &var1, const RooAbsArg &var2)
Replace var1 with var2 and return kTRUE for success.
virtual Bool_t add(const RooAbsArg &var, Bool_t silent=kFALSE)
Add the specified argument to list.
virtual Bool_t addOwned(RooAbsArg &var, Bool_t silent=kFALSE)
Add an argument and transfer the ownership to the collection.
RooAbsArg * first() const
void clear()
Clear contents. If the collection is owning, it will also delete the contents.
virtual Bool_t remove(const RooAbsArg &var, Bool_t silent=kFALSE, Bool_t matchByNameOnly=kFALSE)
Remove the specified argument from our list.
RooAbsData is the common abstract base class for binned and unbinned datasets.
Definition RooAbsData.h:82
virtual const RooArgSet * get() const
Definition RooAbsData.h:128
RooAbsReal is the common abstract base class for objects that represent a real value and implements f...
Definition RooAbsReal.h:64
virtual void fixAddCoefRange(const char *rangeName=0, Bool_t force=kTRUE)
Fix the interpretation of the coefficient of any RooAddPdf component in the expression tree headed by...
virtual void fixAddCoefNormalization(const RooArgSet &addNormSet=RooArgSet(), Bool_t force=kTRUE)
Fix the interpretation of the coefficient of any RooAddPdf component in the expression tree headed by...
RooArgList is a container object that can hold multiple RooAbsArg objects.
Definition RooArgList.h:22
RooArgSet is a container object that can hold multiple RooAbsArg objects.
Definition RooArgSet.h:35
static std::unique_ptr< RooAbsPdf > create(RooAbsPdf &pdf, RooAbsData const &data, double precision)
Creates a wrapping RooBinSamplingPdf if appropriate.
static RooMsgService & instance()
Return reference to singleton instance.
RooSimultaneous facilitates simultaneous fitting of multiple PDFs to subsets of a given dataset.
void wrapPdfsInBinSamplingPdfs(RooAbsData const &data, double precision)
Wraps the components of this RooSimultaneous in RooBinSamplingPdfs.
const RooAbsCategoryLValue & indexCat() const
RooAbsPdf * getPdf(const char *catName) const
Return the p.d.f associated with the given index category name.
virtual const char * GetName() const
Returns name of object.
Definition TNamed.h:47
Mother of all ROOT objects.
Definition TObject.h:41
std::vector< std::string > Split(std::string_view str, std::string_view delims, bool skipEmpty=false)
Splits a string at each character in delims.
R__EXTERN RooBatchComputeInterface * dispatchCUDA
R__EXTERN RooBatchComputeInterface * dispatchCPU
This dispatch pointer points to an implementation of the compute library, provided one has been loade...
std::unique_ptr< RooAbsReal > createNLL(RooAbsPdf &pdf, RooAbsData &data, std::unique_ptr< RooAbsReal > &&constraints, std::string const &rangeName, std::string const &addCoefRangeName, RooArgSet const &projDeps, bool isExtended, double integrateOverBinsPrecision, RooFit::BatchModeOption batchMode, bool doOffset, bool takeGlobalObservablesFromData)
void logArchitectureInfo(RooFit::BatchModeOption batchMode)
BatchModeOption
For setting the batch mode flag with the BatchMode() command argument to RooAbsPdf::fitTo();.
void evaluate(typename Architecture_t::Tensor_t &A, EActivationFunction f)
Apply the given activation function to each value in the given tensor A.
Definition Functions.h:98