Logo ROOT  
Reference Guide
 
Loading...
Searching...
No Matches
BatchModeDataHelpers.cxx
Go to the documentation of this file.
1/// \cond ROOFIT_INTERNAL
2
3/*
4 * Project: RooFit
5 * Authors:
6 * Jonas Rembser, CERN 2022
7 *
8 * Copyright (c) 2022, CERN
9 *
10 * Redistribution and use in source and binary forms,
11 * with or without modification, are permitted according to the terms
12 * listed in LICENSE (http://roofit.sourceforge.net/license.txt)
13 */
14
16
17#include <RooAbsData.h>
18#include <RooRealVar.h>
19#include <RooSimultaneous.h>
20
21#include "RooFitImplHelpers.h"
22#include "RooNLLVarNew.h"
23
24#include <ROOT/StringUtils.hxx>
25
26#include <numeric>
27
28namespace {
29
30// To avoid deleted move assignment.
31template <class T>
32void assignSpan(std::span<T> &to, std::span<T> const &from)
33{
34 to = from;
35}
36
37std::map<RooFit::Detail::DataKey, std::span<const double>>
38getSingleDataSpans(RooAbsData const &data, std::string_view rangeName, std::string const &prefix,
39 std::stack<std::vector<double>> &buffers, bool skipZeroWeights)
40{
41 std::map<RooFit::Detail::DataKey, std::span<const double>> dataSpans; // output variable
42
43 auto &nameReg = RooNameReg::instance();
44
45 auto insert = [&](const char *key, std::span<const double> span) {
46 const TNamed *namePtr = nameReg.constPtr((prefix + key).c_str());
47 dataSpans[namePtr] = span;
48 };
49
50 auto retrieve = [&](const char *key) {
51 const TNamed *namePtr = nameReg.constPtr((prefix + key).c_str());
52 return dataSpans.at(namePtr);
53 };
54
55 std::size_t nEvents = static_cast<size_t>(data.numEntries());
56
57 auto weight = data.getWeightBatch(0, nEvents, /*sumW2=*/false);
58 auto weightSumW2 = data.getWeightBatch(0, nEvents, /*sumW2=*/true);
59
60 std::vector<bool> hasZeroWeight;
61 hasZeroWeight.resize(nEvents);
62 std::size_t nNonZeroWeight = 0;
63
64 // Add weights to the datamap. They should have the names expected by the
65 // RooNLLVarNew. We also add the sumW2 weights here under a different name,
66 // so we can apply the sumW2 correction by easily swapping the spans.
67 {
68 buffers.emplace();
69 auto &buffer = buffers.top();
70 buffers.emplace();
71 auto &bufferSumW2 = buffers.top();
72 if (weight.empty()) {
73 // If the dataset has no weight, we fill the data spans with a scalar
74 // unity weight so we don't need to check for the existence of weights
75 // later in the likelihood.
76 buffer.push_back(1.0);
77 bufferSumW2.push_back(1.0);
78 assignSpan(weight, {buffer.data(), 1});
79 assignSpan(weightSumW2, {bufferSumW2.data(), 1});
80 nNonZeroWeight = nEvents;
81 } else {
82 buffer.reserve(nEvents);
83 bufferSumW2.reserve(nEvents);
84 for (std::size_t i = 0; i < nEvents; ++i) {
85 if (!skipZeroWeights || weight[i] != 0) {
86 buffer.push_back(weight[i]);
87 bufferSumW2.push_back(weightSumW2[i]);
88 ++nNonZeroWeight;
89 } else {
90 hasZeroWeight[i] = true;
91 }
92 }
93 assignSpan(weight, {buffer.data(), nNonZeroWeight});
94 assignSpan(weightSumW2, {bufferSumW2.data(), nNonZeroWeight});
95 }
96 insert(RooNLLVarNew::weightVarName, weight);
97 insert(RooNLLVarNew::weightVarNameSumW2, weightSumW2);
98 }
99
100 // Get the real-valued batches and cast the also to double branches to put in
101 // the data map
102 for (auto const &item : data.getBatches(0, nEvents)) {
103
104 std::span<const double> span{item.second};
105
106 buffers.emplace();
107 auto &buffer = buffers.top();
108 buffer.reserve(nNonZeroWeight);
109
110 for (std::size_t i = 0; i < nEvents; ++i) {
111 if (!hasZeroWeight[i]) {
112 buffer.push_back(span[i]);
113 }
114 }
115 insert(item.first->GetName(), {buffer.data(), buffer.size()});
116 }
117
118 // Get the category batches and cast the also to double branches to put in
119 // the data map
120 for (auto const &item : data.getCategoryBatches(0, nEvents)) {
121
122 std::span<const RooAbsCategory::value_type> intSpan{item.second};
123
124 buffers.emplace();
125 auto &buffer = buffers.top();
126 buffer.reserve(nNonZeroWeight);
127
128 for (std::size_t i = 0; i < nEvents; ++i) {
129 if (!hasZeroWeight[i]) {
130 buffer.push_back(static_cast<double>(intSpan[i]));
131 }
132 }
133 insert(item.first->GetName(), {buffer.data(), buffer.size()});
134 }
135
136 nEvents = nNonZeroWeight;
137
138 // Now we have do do the range selection
139 if (!rangeName.empty()) {
140 // figure out which events are in the range
141 std::vector<bool> isInRange(nEvents, false);
142 for (auto const &range : ROOT::Split(rangeName, ",")) {
143 std::vector<bool> isInSubRange(nEvents, true);
144 for (auto *observable : dynamic_range_cast<RooAbsRealLValue *>(*data.get())) {
145 // If the observables is not real-valued, it will not be considered for the range selection
146 if (observable) {
147 observable->inRange({retrieve(observable->GetName()).data(), nEvents}, range, isInSubRange);
148 }
149 }
150 for (std::size_t i = 0; i < isInSubRange.size(); ++i) {
151 isInRange[i] = isInRange[i] || isInSubRange[i];
152 }
153 }
154
155 // reset the number of events
156 nEvents = std::accumulate(isInRange.begin(), isInRange.end(), 0);
157
158 // do the data reduction in the data map
159 for (auto const &item : dataSpans) {
160 auto const &allValues = item.second;
161 if (allValues.size() == 1) {
162 continue;
163 }
164 buffers.emplace(nEvents);
165 double *buffer = buffers.top().data();
166 std::size_t j = 0;
167 for (std::size_t i = 0; i < isInRange.size(); ++i) {
168 if (isInRange[i]) {
169 buffer[j] = allValues[i];
170 ++j;
171 }
172 }
173 assignSpan(dataSpans[item.first], {buffer, nEvents});
174 }
175 }
176
177 return dataSpans;
178}
179
180} // namespace
181
182////////////////////////////////////////////////////////////////////////////////
183/// Extract all content from a RooFit datasets as a map of spans.
184/// Spans with the weights and squared weights will be also stored in the map,
185/// keyed with the names `_weight` and the `_weight_sumW2`. If the dataset is
186/// unweighted, these weight spans will only contain the single value `1.0`.
187/// Entries with zero weight will be skipped.
188///
189/// \return A `std::map` with spans keyed to name pointers.
190/// \param[in] data The input dataset.
191/// \param[in] rangeName Select only entries from the data in a given range
192/// (empty string for no range).
193/// \param[in] simPdf A simultaneous pdf to use as a guide for splitting the
194/// dataset. The spans from each channel data will be prefixed with
195/// the channel name.
196/// \param[in] skipZeroWeights Skip entries with zero weight when filling the
197/// data spans. Be very careful with enabling it, because the user
198/// might not expect that the batch results are not aligned with the
199/// original dataset anymore!
200/// \param[in] takeGlobalObservablesFromData Take also the global observables
201/// stored in the dataset.
202/// \param[in] buffers Pass here an empty stack of `double` vectors, which will
203/// be used as memory for the data if the memory in the dataset
204/// object can't be used directly (e.g. because you used the range
205/// selection or the splitting by categories).
206std::map<RooFit::Detail::DataKey, std::span<const double>>
207RooFit::Detail::BatchModeDataHelpers::getDataSpans(RooAbsData const &data, std::string const &rangeName,
208 RooSimultaneous const *simPdf, bool skipZeroWeights,
209 bool takeGlobalObservablesFromData,
210 std::stack<std::vector<double>> &buffers)
211{
212 std::vector<std::pair<std::string, RooAbsData const *>> datasets;
213 std::vector<bool> isBinnedL;
214 bool splitRange = false;
215 std::vector<std::unique_ptr<RooAbsData>> splitDataSets;
216
217 if (simPdf) {
218 std::unique_ptr<TList> splits{data.split(*simPdf, true)};
219 for (auto *d : static_range_cast<RooAbsData *>(*splits)) {
220 RooAbsPdf *simComponent = simPdf->getPdf(d->GetName());
221 // If there is no PDF for that component, we also don't need to fill the data
222 if (!simComponent) {
223 continue;
224 }
225 datasets.emplace_back(std::string("_") + d->GetName() + "_", d);
226 isBinnedL.emplace_back(simComponent->getAttribute("BinnedLikelihoodActive"));
227 // The dataset need to be kept alive because the datamap points to their content
228 splitDataSets.emplace_back(d);
229 }
230 splitRange = simPdf->getAttribute("SplitRange");
231 } else {
232 datasets.emplace_back("", &data);
233 isBinnedL.emplace_back(false);
234 }
235
236 std::map<RooFit::Detail::DataKey, std::span<const double>> dataSpans; // output variable
237
238 for (std::size_t iData = 0; iData < datasets.size(); ++iData) {
239 auto const &toAdd = datasets[iData];
240 auto spans = getSingleDataSpans(
241 *toAdd.second, RooHelpers::getRangeNameForSimComponent(rangeName, splitRange, toAdd.second->GetName()),
242 toAdd.first, buffers, skipZeroWeights && !isBinnedL[iData]);
243 for (auto const &item : spans) {
244 dataSpans.insert(item);
245 }
246 }
247
248 if (takeGlobalObservablesFromData && data.getGlobalObservables()) {
249 buffers.emplace();
250 auto &buffer = buffers.top();
251 buffer.reserve(data.getGlobalObservables()->size());
252 for (auto *arg : static_range_cast<RooRealVar const *>(*data.getGlobalObservables())) {
253 buffer.push_back(arg->getVal());
254 assignSpan(dataSpans[arg], {&buffer.back(), 1});
255 }
256 }
257
258 return dataSpans;
259}
260
261////////////////////////////////////////////////////////////////////////////////
262/// Figure out the output size for each node in the computation graph that
263/// leads up to the top node, given some vector data as an input. The input
264/// data spans are in general not of the same size, for example in the case of
265/// a simultaneous fit.
266///
267/// \return A `std::map` with output sizes for each node in the computation graph.
268/// \param[in] topNode The top node of the computation graph.
269/// \param[in] inputSizeFunc A function to get the input sizes.
270std::map<RooFit::Detail::DataKey, std::size_t> RooFit::Detail::BatchModeDataHelpers::determineOutputSizes(
271 RooAbsArg const &topNode, std::function<int(RooFit::Detail::DataKey)> const &inputSizeFunc)
272{
273 std::map<RooFit::Detail::DataKey, std::size_t> output;
274
275 RooArgSet serverSet;
276 RooHelpers::getSortedComputationGraph(topNode, serverSet);
277
278 for (RooAbsArg *arg : serverSet) {
279 int inputSize = inputSizeFunc(arg);
280 // The size == -1 encodes that the input doesn't come from an array
281 // input.
282 if (inputSize != -1) {
283 output[arg] = inputSize;
284 }
285 }
286
287 for (RooAbsArg *arg : serverSet) {
288 std::size_t size = 1;
289 if (output.find(arg) != output.end()) {
290 continue;
291 }
292 if (!arg->isReducerNode()) {
293 for (RooAbsArg *server : arg->servers()) {
294 if (server->isValueServer(*arg)) {
295 std::size_t inputSize = output.at(server);
296 if (inputSize != 1) {
297 // If the input if from an external array, the output will
298 // adopt its size and we can stop the checking of other
299 // servers.
300 size = inputSize;
301 break;
302 }
303 }
304 }
305 }
306 output[arg] = size;
307 }
308
309 return output;
310}
311
312/// \endcond
#define d(i)
Definition RSha256.hxx:102
ROOT::RRangeCast< T, true, Range_t > dynamic_range_cast(Range_t &&coll)
ROOT::RRangeCast< T, false, Range_t > static_range_cast(Range_t &&coll)
static void retrieve(const gsl_integration_workspace *workspace, double *a, double *b, double *r, double *e)
size_t size(const MatrixT &matrix)
retrieve the size of a square matrix
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void data
Common abstract base class for objects that represent a value and a "shape" in RooFit.
Definition RooAbsArg.h:77
bool getAttribute(const Text_t *name) const
Check if a named attribute is set. By default, all attributes are unset.
Abstract base class for binned and unbinned datasets.
Definition RooAbsData.h:57
Abstract interface for all probability density functions.
Definition RooAbsPdf.h:40
Abstract base class for objects that represent a real value that may appear on the left hand side of ...
RooArgSet is a container object that can hold multiple RooAbsArg objects.
Definition RooArgSet.h:24
static RooNameReg & instance()
Return reference to singleton instance.
Variable that can be changed from the outside.
Definition RooRealVar.h:37
Facilitates simultaneous fitting of multiple PDFs to subsets of a given dataset.
RooAbsPdf * getPdf(RooStringView catName) const
Return the p.d.f associated with the given index category name.
The TNamed class is the base class for all named ROOT classes.
Definition TNamed.h:29
tbb::task_arena is an alias of tbb::interface7::task_arena, which doesn't allow to forward declare tb...
std::vector< std::string > Split(std::string_view str, std::string_view delims, bool skipEmpty=false)
Splits a string at each character in delims.
std::string getRangeNameForSimComponent(std::string const &rangeName, bool splitRange, std::string const &catName)
void getSortedComputationGraph(RooAbsArg const &func, RooArgSet &out)
static void output()