Logo ROOT  
Reference Guide
 
Loading...
Searching...
No Matches
BatchModeDataHelpers.cxx
Go to the documentation of this file.
1/// \cond ROOFIT_INTERNAL
2
3/*
4 * Project: RooFit
5 * Authors:
6 * Jonas Rembser, CERN 2022
7 *
8 * Copyright (c) 2022, CERN
9 *
10 * Redistribution and use in source and binary forms,
11 * with or without modification, are permitted according to the terms
12 * listed in LICENSE (http://roofit.sourceforge.net/license.txt)
13 */
14
16
17#include <RooAbsData.h>
18#include <RooDataHist.h>
19#include <RooRealVar.h>
20#include <RooSimultaneous.h>
21
22#include "RooFitImplHelpers.h"
24
25#include <ROOT/StringUtils.hxx>
26
27#include <numeric>
28
29namespace {
30
31// To avoid deleted move assignment.
32template <class T>
33void assignSpan(std::span<T> &to, std::span<T> const &from)
34{
35 to = from;
36}
37
// Collect the content of a single dataset as spans keyed by name pointers.
//
// Every key is the TNamed pointer that `nameReg.constPtr()` returns for
// `prefix + <name>`, where <name> is one of the weight names taken from
// RooNLLVarNew, one of the RooDataHist helper names, or the name of an
// observable/category batch. The values are copied into vectors pushed onto
// `buffers`, and the spans in the returned map point into those vectors, so
// `buffers` has to outlive the returned map.
//
//   data             dataset to read the weights and batches from
//   rangeName        comma-separated range names; empty means no selection
//   prefix           string prepended to every key name
//   buffers          stack receiving the vectors that back the returned spans
//   skipZeroWeights  if true, entries whose weight compares equal to zero are
//                    left out of every span
//
// NOTE(review): the line numbers embedded in this listing have gaps (44, 80,
// 84, 90, 102, 105, 108), so some statements of this function are not visible
// here. The notes placed at those gaps are assumptions to be confirmed against
// the real file.
38std::map<RooFit::Detail::DataKey, std::span<const double>>
39getSingleDataSpans(RooAbsData const &data, std::string_view rangeName, std::string const &prefix,
40 std::stack<std::vector<double>> &buffers, bool skipZeroWeights)
41{
42 std::map<RooFit::Detail::DataKey, std::span<const double>> dataSpans; // output variable
43
   // NOTE(review): listing line 44 is not visible. `nameReg` is used by the two
   // lambdas below but is not declared in the visible text - presumably a
   // reference obtained from RooNameReg::instance(); confirm.
45
   // Store `span` in the output map under the name pointer for `prefix + key`.
46 auto insert = [&](const char *key, std::span<const double> span) {
47 const TNamed *namePtr = nameReg.constPtr((prefix + key).c_str());
48 dataSpans[namePtr] = span;
49 };
50
   // Look up the span stored under `prefix + key`. std::map::at() throws
   // std::out_of_range if no such entry exists.
51 auto retrieve = [&](const char *key) {
52 const TNamed *namePtr = nameReg.constPtr((prefix + key).c_str());
53 return dataSpans.at(namePtr);
54 };
55
56 std::size_t nEvents = static_cast<size_t>(data.numEntries());
57
58 auto weight = data.getWeightBatch(0, nEvents, /*sumW2=*/false);
59 auto weightSumW2 = data.getWeightBatch(0, nEvents, /*sumW2=*/true);
60
   // Per-entry flag, set for entries that are dropped because of a zero
   // weight. The loops further down use it to drop the same entries from
   // every other span, which keeps all spans aligned with each other.
61 std::vector<bool> hasZeroWeight;
62 hasZeroWeight.resize(nEvents);
63 std::size_t nNonZeroWeight = 0;
64
65 // Add weights to the datamap. They should have the names expected by the
66 // RooNLLVarNew. We also add the sumW2 weights here under a different name,
67 // so we can apply the sumW2 correction by easily swapping the spans.
68 {
   // std::stack uses std::deque by default, and pushing to a deque keeps
   // references to existing elements valid, so `buffer` is still usable
   // after the second emplace().
69 buffers.emplace();
70 auto &buffer = buffers.top();
71 buffers.emplace();
72 auto &bufferSumW2 = buffers.top();
73 buffer.reserve(nEvents);
74 bufferSumW2.reserve(nEvents);
75 for (std::size_t i = 0; i < nEvents; ++i) {
76 if (weight.empty()) {
77 // No weights in the dataset imply a constant weight of one
78 buffer.push_back(1.0);
79 bufferSumW2.push_back(1.0);
   // NOTE(review): listing line 80 is not visible - presumably
   // `++nNonZeroWeight;`, since nothing in the visible text increments the
   // counter that is used as the span length below; confirm.
81 } else if (!skipZeroWeights || weight[i] != 0) {
82 buffer.push_back(weight[i]);
83 bufferSumW2.push_back(weightSumW2[i]);
   // NOTE(review): listing line 84 is not visible - presumably
   // `++nNonZeroWeight;` as well; confirm.
85 } else {
86 hasZeroWeight[i] = true;
87 }
88 }
   // Re-point `weight` from the dataset's own batch to the filtered copy.
89 assignSpan(weight, {buffer.data(), nNonZeroWeight});
   // NOTE(review): listing line 90 is not visible. In the visible text
   // `weightSumW2` still refers to the batch returned by getWeightBatch()
   // when it is inserted below - presumably the missing line re-points it to
   // `bufferSumW2` in the same way as `weight`; confirm.
91 insert(RooFit::Detail::RooNLLVarNew::weightVarName, weight);
92 insert(RooFit::Detail::RooNLLVarNew::weightVarNameSumW2, weightSumW2);
93 }
94
95 // For RooDataHist datasets, also publish per-bin volumes and asymmetric
96 // Poisson errors. These are consumed by the chi2 evaluation path in
97 // RooNLLVarNew.
98 if (auto const *dataHist = dynamic_cast<RooDataHist const *>(&data)) {
99 auto binVolumes = dataHist->binVolumes(0, static_cast<std::size_t>(data.numEntries()));
   // NOTE(review): listing lines 102, 105 and 108 (one after each top()
   // below) are not visible; what they contain cannot be told from here.
100 buffers.emplace();
101 auto &bufferBinVol = buffers.top();
103 buffers.emplace();
104 auto &bufferErrLo = buffers.top();
106 buffers.emplace();
107 auto &bufferErrHi = buffers.top();
109
   // get() and weightError() are called through a non-const pointer.
   // NOTE(review): whether these members really need a non-const object is
   // not visible here - confirm that the const_cast is required.
110 auto *dataHistMutable = const_cast<RooDataHist *>(dataHist);
111 for (std::size_t i = 0; i < binVolumes.size(); ++i) {
   // Drop the same entries that were dropped from the weight spans.
112 if (hasZeroWeight[i]) {
113 continue;
114 }
115 bufferBinVol.push_back(binVolumes[i]);
   // NOTE(review): presumably get(i) makes bin i the current bin that
   // weightError() then reports on - confirm against RooDataHist.
116 dataHistMutable->get(static_cast<int>(i));
117 double lo = 0.0;
118 double hi = 0.0;
119 dataHistMutable->weightError(lo, hi, RooAbsData::Poisson);
120 bufferErrLo.push_back(lo);
121 bufferErrHi.push_back(hi);
122 }
123 insert(RooFit::Detail::RooNLLVarNew::binVolumeVarName, {bufferBinVol.data(), bufferBinVol.size()});
124 insert(RooFit::Detail::RooNLLVarNew::weightErrorLoVarName, {bufferErrLo.data(), bufferErrLo.size()});
125 insert(RooFit::Detail::RooNLLVarNew::weightErrorHiVarName, {bufferErrHi.data(), bufferErrHi.size()});
126 }
127
128 // Get the real-valued batches and copy them into double buffers to put in
129 // the data map
130 for (auto const &item : data.getBatches(0, nEvents)) {
131
132 std::span<const double> span{item.second};
133
134 buffers.emplace();
135 auto &buffer = buffers.top();
136 buffer.reserve(nNonZeroWeight);
137
   // Copy only the entries that were not flagged as zero-weight.
138 for (std::size_t i = 0; i < nEvents; ++i) {
139 if (!hasZeroWeight[i]) {
140 buffer.push_back(span[i]);
141 }
142 }
143 insert(item.first->GetName(), {buffer.data(), buffer.size()});
144 }
145
146 // Get the category batches and cast them to double buffers to put in
147 // the data map
148 for (auto const &item : data.getCategoryBatches(0, nEvents)) {
149
150 std::span<const RooAbsCategory::value_type> intSpan{item.second};
151
152 buffers.emplace();
153 auto &buffer = buffers.top();
154 buffer.reserve(nNonZeroWeight);
155
   // Same filtering as for the real-valued batches, with a cast to double.
156 for (std::size_t i = 0; i < nEvents; ++i) {
157 if (!hasZeroWeight[i]) {
158 buffer.push_back(static_cast<double>(intSpan[i]));
159 }
160 }
161 insert(item.first->GetName(), {buffer.data(), buffer.size()});
162 }
163
   // From here on, nEvents is the number of entries kept after the
   // zero-weight filtering.
164 nEvents = nNonZeroWeight;
165
166 // Now we have to do the range selection
167 if (!rangeName.empty()) {
168 // figure out which events are in the range
   // An entry is kept if it is inside at least one of the comma-separated
   // ranges (the per-range masks are OR-ed into isInRange below).
169 std::vector<bool> isInRange(nEvents, false);
170 for (auto const &range : ROOT::Split(rangeName, ",")) {
   // The mask starts as all-true and is handed to inRange() of every
   // real-valued observable in turn.
   // NOTE(review): presumably inRange() clears the flag of entries
   // outside `range`, so that an entry must pass for all observables -
   // confirm against RooAbsRealLValue::inRange.
171 std::vector<bool> isInSubRange(nEvents, true);
172 for (auto *observable : dynamic_range_cast<RooAbsRealLValue *>(*data.get())) {
173 // If the observable is not real-valued, it will not be considered for the range selection
174 if (observable) {
175 observable->inRange({retrieve(observable->GetName()).data(), nEvents}, range, isInSubRange);
176 }
177 }
178 for (std::size_t i = 0; i < isInSubRange.size(); ++i) {
179 isInRange[i] = isInRange[i] || isInSubRange[i];
180 }
181 }
182
183 // reset the number of events
   // (the sum of the bool flags is the number of selected entries)
184 nEvents = std::accumulate(isInRange.begin(), isInRange.end(), 0);
185
186 // do the data reduction in the data map
   // Only the mapped value of an already existing key is replaced inside
   // this loop, so iterating over dataSpans stays valid.
187 for (auto const &item : dataSpans) {
188 auto const &allValues = item.second;
   // Spans holding exactly one value are left untouched.
   // NOTE(review): presumably these are scalar entries rather than
   // per-event data - confirm.
189 if (allValues.size() == 1) {
190 continue;
191 }
   // New buffer with nEvents elements, filled with the selected values.
192 buffers.emplace(nEvents);
193 double *buffer = buffers.top().data();
194 std::size_t j = 0;
195 for (std::size_t i = 0; i < isInRange.size(); ++i) {
196 if (isInRange[i]) {
197 buffer[j] = allValues[i];
198 ++j;
199 }
200 }
201 assignSpan(dataSpans[item.first], {buffer, nEvents});
202 }
203 }
204
205 return dataSpans;
206}
207
208} // namespace
209
210////////////////////////////////////////////////////////////////////////////////
211/// Extract all content from a RooFit dataset as a map of spans.
212/// Spans with the weights and squared weights will be also stored in the map,
213/// keyed with the names `_weight` and the `_weight_sumW2`. If the dataset is
214/// unweighted, these weight spans will only contain the single value `1.0`.
215/// Entries with zero weight are skipped only if `skipZeroWeights` is set.
/// NOTE(review): for unweighted data the visible helper pushes one `1.0` per
/// entry rather than a single value - confirm which behaviour is intended.
216///
217/// \return A `std::map` with spans keyed to name pointers.
218/// \param[in] data The input dataset.
219/// \param[in] rangeName Select only entries from the data in a given range
220/// (empty string for no range).
221/// \param[in] simPdf A simultaneous pdf to use as a guide for splitting the
222/// dataset. The spans from each channel data will be prefixed with
223/// the channel name.
224/// \param[in] skipZeroWeights Skip entries with zero weight when filling the
225/// data spans. Be very careful with enabling it, because the user
226/// might not expect that the batch results are not aligned with the
227/// original dataset anymore!
228/// \param[in] takeGlobalObservablesFromData Take also the global observables
229/// stored in the dataset.
230/// \param[in] buffers Pass here an empty stack of `double` vectors, which will
231/// be used as memory for the data if the memory in the dataset
232/// object can't be used directly (e.g. because you used the range
233/// selection or the splitting by categories).
234std::map<RooFit::Detail::DataKey, std::span<const double>>
235RooFit::BatchModeDataHelpers::getDataSpans(RooAbsData const &data, std::string const &rangeName,
   // NOTE(review): listing line 236 is not visible - presumably it declares
   // the `simPdf` and `skipZeroWeights` parameters, which are documented above
   // and used in the body; confirm.
237 bool takeGlobalObservablesFromData, std::stack<std::vector<double>> &buffers)
238{
   // One (key prefix, dataset) pair per dataset to process, and in parallel
   // whether the binned-likelihood attribute is set for that dataset.
239 std::vector<std::pair<std::string, RooAbsData const *>> datasets;
240 std::vector<bool> isBinnedL;
241 bool splitRange = false;
242
243 // The split datasets need to be kept alive because the datamap points to their content
244 std::vector<std::unique_ptr<RooAbsData>> splitDataSets;
245
246 if (simPdf) {
247 splitDataSets = data.split(*simPdf, true);
248 for (auto const &d : splitDataSets) {
249 RooAbsPdf *simComponent = simPdf->getPdf(d->GetName());
250 // If there is no PDF for that component, we also don't need to fill the data
251 if (!simComponent) {
252 continue;
253 }
   // The keys of this channel get the prefix "_<channel name>_".
254 datasets.emplace_back(std::string("_") + d->GetName() + "_", d.get());
255 isBinnedL.emplace_back(simComponent->getAttribute("BinnedLikelihoodActive"));
256 }
257 splitRange = simPdf->getAttribute("SplitRange");
258 } else {
   // Without a simultaneous pdf there is one dataset and no key prefix.
259 datasets.emplace_back("", &data);
260 isBinnedL.emplace_back(false);
261 }
262
263 std::map<RooFit::Detail::DataKey, std::span<const double>> dataSpans; // output variable
264
265 for (std::size_t iData = 0; iData < datasets.size(); ++iData) {
266 auto const &toAdd = datasets[iData];
   // NOTE(review): listing lines 267-268 are not visible - presumably they
   // start `auto spans = getSingleDataSpans(...)` with the dataset and a
   // range name derived from `rangeName` and `splitRange`; confirm.
   // The last visible argument turns zero-weight skipping off whenever the
   // "BinnedLikelihoodActive" attribute was set for this dataset.
269 toAdd.first, buffers, skipZeroWeights && !isBinnedL[iData]);
   // std::map::insert() keeps an existing entry if the key is already
   // present.
270 for (auto const &item : spans) {
271 dataSpans.insert(item);
272 }
273 }
274
275 if (takeGlobalObservablesFromData && data.getGlobalObservables()) {
276 buffers.emplace();
277 auto &buffer = buffers.top();
   // Reserve up front so that push_back() never reallocates; the one-element
   // spans created below point at individual elements of this vector.
278 buffer.reserve(data.getGlobalObservables()->size());
279 for (auto *arg : static_range_cast<RooRealVar const *>(*data.getGlobalObservables())) {
280 buffer.push_back(arg->getVal());
   // Keyed by the variable itself, not by a prefixed name pointer.
281 assignSpan(dataSpans[arg], {&buffer.back(), 1});
282 }
283 }
284
285 return dataSpans;
286}
287
288////////////////////////////////////////////////////////////////////////////////
289/// Figure out the output size for each node in the computation graph that
290/// leads up to the top node, given some vector data as an input. The input
291/// data spans are in general not of the same size, for example in the case of
292/// a simultaneous fit.
293///
294/// \return A `std::map` with output sizes for each node in the computation graph.
295/// \param[in] topNode The top node of the computation graph.
296/// \param[in] inputSizeFunc A function to get the input sizes. A return value
/// of `-1` means that the node does not get its value from an array input.
297std::map<RooFit::Detail::DataKey, std::size_t>
298RooFit::BatchModeDataHelpers::determineOutputSizes(RooAbsArg const &topNode,
299 std::function<int(RooFit::Detail::DataKey)> const &inputSizeFunc)
300{
301 std::map<RooFit::Detail::DataKey, std::size_t> output;
302
   // NOTE(review): listing lines 303-304 are not visible. `serverSet` is used
   // below but not declared in the visible text - presumably a RooArgSet
   // filled from `topNode` by getSortedComputationGraph(); confirm.
305
   // First pass: record the size of every node that has an array input.
306 for (RooAbsArg *arg : serverSet) {
307 int inputSize = inputSizeFunc(arg);
308 // The size == -1 encodes that the input doesn't come from an array
309 // input.
310 if (inputSize != -1) {
311 output[arg] = inputSize;
312 }
313 }
314
   // Second pass: derive the size of each remaining node from its value
   // servers. output.at(server) throws if the server has no entry yet, so
   // this relies on `serverSet` listing every server before its clients.
315 for (RooAbsArg *arg : serverSet) {
   // Default for nodes without a non-scalar value server, and for reducer
   // nodes, whose servers are not inspected at all.
316 std::size_t size = 1;
317 if (output.find(arg) != output.end()) {
318 continue;
319 }
320 if (!arg->isReducerNode()) {
321 for (RooAbsArg *server : arg->servers()) {
322 if (server->isValueServer(*arg)) {
323 std::size_t inputSize = output.at(server);
324 if (inputSize != 1) {
325 // If the input is from an external array, the output will
326 // adopt its size and we can stop the checking of other
327 // servers.
328 size = inputSize;
329 break;
330 }
331 }
332 }
333 }
334 output[arg] = size;
335 }
336
337 return output;
338}
339
340/// \endcond
#define d(i)
Definition RSha256.hxx:102
ROOT::RRangeCast< T, true, Range_t > dynamic_range_cast(Range_t &&coll)
ROOT::RRangeCast< T, false, Range_t > static_range_cast(Range_t &&coll)
static void retrieve(const gsl_integration_workspace *workspace, double *a, double *b, double *r, double *e)
size_t size(const MatrixT &matrix)
retrieve the size of a square matrix
ROOT::Detail::TRangeCast< T, true > TRangeDynCast
TRangeDynCast is an adapter class that allows the typed iteration through a TCollection.
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void data
#define hi
const_iterator end() const
Common abstract base class for objects that represent a value and a "shape" in RooFit.
Definition RooAbsArg.h:76
Abstract base class for binned and unbinned datasets.
Definition RooAbsData.h:56
Abstract interface for all probability density functions.
Definition RooAbsPdf.h:32
Abstract base class for objects that represent a real value that may appear on the left hand side of ...
RooArgSet is a container object that can hold multiple RooAbsArg objects.
Definition RooArgSet.h:24
Container class to hold N-dimensional binned data.
Definition RooDataHist.h:40
static RooNameReg & instance()
Return reference to singleton instance.
Variable that can be changed from the outside.
Definition RooRealVar.h:37
Facilitates simultaneous fitting of multiple PDFs to subsets of a given dataset.
The TNamed class is the base class for all named ROOT classes.
Definition TNamed.h:29
CoordSystem::Scalar get(DisplacementVector2D< CoordSystem, Tag > const &p)
std::vector< std::string > Split(std::string_view str, std::string_view delims, bool skipEmpty=false)
Splits a string at each character in delims.
std::string getRangeNameForSimComponent(std::string const &rangeName, bool splitRange, std::string const &catName)
void getSortedComputationGraph(RooAbsArg const &func, RooArgSet &out)
static void output()