Logo ROOT  
Reference Guide
 
Loading...
Searching...
No Matches
BatchModeDataHelpers.cxx
Go to the documentation of this file.
1/// \cond ROOFIT_INTERNAL
2
3/*
4 * Project: RooFit
5 * Authors:
6 * Jonas Rembser, CERN 2022
7 *
8 * Copyright (c) 2022, CERN
9 *
10 * Redistribution and use in source and binary forms,
11 * with or without modification, are permitted according to the terms
12 * listed in LICENSE (http://roofit.sourceforge.net/license.txt)
13 */
14
16
17#include <RooAbsData.h>
18#include <RooRealVar.h>
19#include <RooSimultaneous.h>
20
21#include "RooFitImplHelpers.h"
22#include "RooNLLVarNew.h"
23
24#include <ROOT/StringUtils.hxx>
25
26#include <numeric>
27
28namespace {
29
30// To avoid deleted move assignment.
31template <class T>
32void assignSpan(std::span<T> &to, std::span<T> const &from)
33{
34 to = from;
35}
36
37std::map<RooFit::Detail::DataKey, std::span<const double>>
38getSingleDataSpans(RooAbsData const &data, std::string_view rangeName, std::string const &prefix,
39 std::stack<std::vector<double>> &buffers, bool skipZeroWeights)
40{
41 std::map<RooFit::Detail::DataKey, std::span<const double>> dataSpans; // output variable
42
43 auto &nameReg = RooNameReg::instance();
44
45 auto insert = [&](const char *key, std::span<const double> span) {
46 const TNamed *namePtr = nameReg.constPtr((prefix + key).c_str());
47 dataSpans[namePtr] = span;
48 };
49
50 auto retrieve = [&](const char *key) {
51 const TNamed *namePtr = nameReg.constPtr((prefix + key).c_str());
52 return dataSpans.at(namePtr);
53 };
54
55 std::size_t nEvents = static_cast<size_t>(data.numEntries());
56
57 // We also want to support empty datasets: in this case the
58 // RooFitDriver::Dataset is not filled with anything.
59 if (nEvents == 0) {
60 return dataSpans;
61 }
62
63 auto weight = data.getWeightBatch(0, nEvents, /*sumW2=*/false);
64 auto weightSumW2 = data.getWeightBatch(0, nEvents, /*sumW2=*/true);
65
66 std::vector<bool> hasZeroWeight;
67 hasZeroWeight.resize(nEvents);
68 std::size_t nNonZeroWeight = 0;
69
70 // Add weights to the datamap. They should have the names expected by the
71 // RooNLLVarNew. We also add the sumW2 weights here under a different name,
72 // so we can apply the sumW2 correction by easily swapping the spans.
73 {
74 buffers.emplace();
75 auto &buffer = buffers.top();
76 buffers.emplace();
77 auto &bufferSumW2 = buffers.top();
78 if (weight.empty()) {
79 // If the dataset has no weight, we fill the data spans with a scalar
80 // unity weight so we don't need to check for the existence of weights
81 // later in the likelihood.
82 buffer.push_back(1.0);
83 bufferSumW2.push_back(1.0);
84 assignSpan(weight, {buffer.data(), 1});
85 assignSpan(weightSumW2, {bufferSumW2.data(), 1});
86 nNonZeroWeight = nEvents;
87 } else {
88 buffer.reserve(nEvents);
89 bufferSumW2.reserve(nEvents);
90 for (std::size_t i = 0; i < nEvents; ++i) {
91 if (!skipZeroWeights || weight[i] != 0) {
92 buffer.push_back(weight[i]);
93 bufferSumW2.push_back(weightSumW2[i]);
94 ++nNonZeroWeight;
95 } else {
96 hasZeroWeight[i] = true;
97 }
98 }
99 assignSpan(weight, {buffer.data(), nNonZeroWeight});
100 assignSpan(weightSumW2, {bufferSumW2.data(), nNonZeroWeight});
101 }
102 insert(RooNLLVarNew::weightVarName, weight);
103 insert(RooNLLVarNew::weightVarNameSumW2, weightSumW2);
104 }
105
106 // Get the real-valued batches and cast the also to double branches to put in
107 // the data map
108 for (auto const &item : data.getBatches(0, nEvents)) {
109
110 std::span<const double> span{item.second};
111
112 buffers.emplace();
113 auto &buffer = buffers.top();
114 buffer.reserve(nNonZeroWeight);
115
116 for (std::size_t i = 0; i < nEvents; ++i) {
117 if (!hasZeroWeight[i]) {
118 buffer.push_back(span[i]);
119 }
120 }
121 insert(item.first->GetName(), {buffer.data(), buffer.size()});
122 }
123
124 // Get the category batches and cast the also to double branches to put in
125 // the data map
126 for (auto const &item : data.getCategoryBatches(0, nEvents)) {
127
128 std::span<const RooAbsCategory::value_type> intSpan{item.second};
129
130 buffers.emplace();
131 auto &buffer = buffers.top();
132 buffer.reserve(nNonZeroWeight);
133
134 for (std::size_t i = 0; i < nEvents; ++i) {
135 if (!hasZeroWeight[i]) {
136 buffer.push_back(static_cast<double>(intSpan[i]));
137 }
138 }
139 insert(item.first->GetName(), {buffer.data(), buffer.size()});
140 }
141
142 nEvents = nNonZeroWeight;
143
144 // Now we have do do the range selection
145 if (!rangeName.empty()) {
146 // figure out which events are in the range
147 std::vector<bool> isInRange(nEvents, false);
148 for (auto const &range : ROOT::Split(rangeName, ",")) {
149 std::vector<bool> isInSubRange(nEvents, true);
150 for (auto *observable : dynamic_range_cast<RooAbsRealLValue *>(*data.get())) {
151 // If the observables is not real-valued, it will not be considered for the range selection
152 if (observable) {
153 observable->inRange({retrieve(observable->GetName()).data(), nEvents}, range, isInSubRange);
154 }
155 }
156 for (std::size_t i = 0; i < isInSubRange.size(); ++i) {
157 isInRange[i] = isInRange[i] || isInSubRange[i];
158 }
159 }
160
161 // reset the number of events
162 nEvents = std::accumulate(isInRange.begin(), isInRange.end(), 0);
163
164 // do the data reduction in the data map
165 for (auto const &item : dataSpans) {
166 auto const &allValues = item.second;
167 if (allValues.size() == 1) {
168 continue;
169 }
170 buffers.emplace(nEvents);
171 double *buffer = buffers.top().data();
172 std::size_t j = 0;
173 for (std::size_t i = 0; i < isInRange.size(); ++i) {
174 if (isInRange[i]) {
175 buffer[j] = allValues[i];
176 ++j;
177 }
178 }
179 assignSpan(dataSpans[item.first], {buffer, nEvents});
180 }
181 }
182
183 return dataSpans;
184}
185
186} // namespace
187
188////////////////////////////////////////////////////////////////////////////////
189/// Extract all content from a RooFit datasets as a map of spans.
190/// Spans with the weights and squared weights will be also stored in the map,
191/// keyed with the names `_weight` and the `_weight_sumW2`. If the dataset is
192/// unweighted, these weight spans will only contain the single value `1.0`.
193/// Entries with zero weight will be skipped.
194///
195/// \return A `std::map` with spans keyed to name pointers.
196/// \param[in] data The input dataset.
197/// \param[in] rangeName Select only entries from the data in a given range
198/// (empty string for no range).
199/// \param[in] simPdf A simultaneous pdf to use as a guide for splitting the
200/// dataset. The spans from each channel data will be prefixed with
201/// the channel name.
202/// \param[in] skipZeroWeights Skip entries with zero weight when filling the
203/// data spans. Be very careful with enabling it, because the user
204/// might not expect that the batch results are not aligned with the
205/// original dataset anymore!
206/// \param[in] takeGlobalObservablesFromData Take also the global observables
207/// stored in the dataset.
208/// \param[in] buffers Pass here an empty stack of `double` vectors, which will
209/// be used as memory for the data if the memory in the dataset
210/// object can't be used directly (e.g. because you used the range
211/// selection or the splitting by categories).
212std::map<RooFit::Detail::DataKey, std::span<const double>>
213RooFit::BatchModeDataHelpers::getDataSpans(RooAbsData const &data, std::string const &rangeName,
214 RooSimultaneous const *simPdf, bool skipZeroWeights,
215 bool takeGlobalObservablesFromData, std::stack<std::vector<double>> &buffers)
216{
217 std::vector<std::pair<std::string, RooAbsData const *>> datasets;
218 std::vector<bool> isBinnedL;
219 bool splitRange = false;
220 std::vector<std::unique_ptr<RooAbsData>> splitDataSets;
221
222 if (simPdf) {
223 std::unique_ptr<TList> splits{data.split(*simPdf, true)};
224 for (auto *d : static_range_cast<RooAbsData *>(*splits)) {
225 RooAbsPdf *simComponent = simPdf->getPdf(d->GetName());
226 // If there is no PDF for that component, we also don't need to fill the data
227 if (!simComponent) {
228 continue;
229 }
230 datasets.emplace_back(std::string("_") + d->GetName() + "_", d);
231 isBinnedL.emplace_back(simComponent->getAttribute("BinnedLikelihoodActive"));
232 // The dataset need to be kept alive because the datamap points to their content
233 splitDataSets.emplace_back(d);
234 }
235 splitRange = simPdf->getAttribute("SplitRange");
236 } else {
237 datasets.emplace_back("", &data);
238 isBinnedL.emplace_back(false);
239 }
240
241 std::map<RooFit::Detail::DataKey, std::span<const double>> dataSpans; // output variable
242
243 for (std::size_t iData = 0; iData < datasets.size(); ++iData) {
244 auto const &toAdd = datasets[iData];
245 auto spans = getSingleDataSpans(
246 *toAdd.second, RooHelpers::getRangeNameForSimComponent(rangeName, splitRange, toAdd.second->GetName()),
247 toAdd.first, buffers, skipZeroWeights && !isBinnedL[iData]);
248 for (auto const &item : spans) {
249 dataSpans.insert(item);
250 }
251 }
252
253 if (takeGlobalObservablesFromData && data.getGlobalObservables()) {
254 buffers.emplace();
255 auto &buffer = buffers.top();
256 buffer.reserve(data.getGlobalObservables()->size());
257 for (auto *arg : static_range_cast<RooRealVar const *>(*data.getGlobalObservables())) {
258 buffer.push_back(arg->getVal());
259 assignSpan(dataSpans[arg], {&buffer.back(), 1});
260 }
261 }
262
263 return dataSpans;
264}
265
266////////////////////////////////////////////////////////////////////////////////
267/// Figure out the output size for each node in the computation graph that
268/// leads up to the top node, given some vector data as an input. The input
269/// data spans are in general not of the same size, for example in the case of
270/// a simultaneous fit.
271///
272/// \return A `std::map` with output sizes for each node in the computation graph.
273/// \param[in] topNode The top node of the computation graph.
274/// \param[in] inputSizeFunc A function to get the input sizes.
275std::map<RooFit::Detail::DataKey, std::size_t> RooFit::BatchModeDataHelpers::determineOutputSizes(
276 RooAbsArg const &topNode, std::function<std::size_t(RooFit::Detail::DataKey)> const &inputSizeFunc)
277{
278 std::map<RooFit::Detail::DataKey, std::size_t> output;
279
280 RooArgSet serverSet;
281 RooHelpers::getSortedComputationGraph(topNode, serverSet);
282
283 for (RooAbsArg *arg : serverSet) {
284 std::size_t inputSize = inputSizeFunc(arg);
285 if (inputSize > 0) {
286 output[arg] = inputSize;
287 }
288 }
289
290 for (RooAbsArg *arg : serverSet) {
291 std::size_t size = 1;
292 if (output.find(arg) != output.end()) {
293 continue;
294 }
295 if (!arg->isReducerNode()) {
296 for (RooAbsArg *server : arg->servers()) {
297 if (server->isValueServer(*arg)) {
298 size = std::max(output.at(server), size);
299 }
300 }
301 }
302 output[arg] = size;
303 }
304
305 return output;
306}
307
308/// \endcond
#define d(i)
Definition RSha256.hxx:102
ROOT::RRangeCast< T, true, Range_t > dynamic_range_cast(Range_t &&coll)
ROOT::RRangeCast< T, false, Range_t > static_range_cast(Range_t &&coll)
static void retrieve(const gsl_integration_workspace *workspace, double *a, double *b, double *r, double *e)
size_t size(const MatrixT &matrix)
retrieve the size of a square matrix
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void data
Common abstract base class for objects that represent a value and a "shape" in RooFit.
Definition RooAbsArg.h:79
bool getAttribute(const Text_t *name) const
Check if a named attribute is set. By default, all attributes are unset.
Abstract base class for binned and unbinned datasets.
Definition RooAbsData.h:57
Abstract interface for all probability density functions.
Definition RooAbsPdf.h:40
RooAbsRealLValue is the common abstract base class for objects that represent a real value that may a...
RooArgSet is a container object that can hold multiple RooAbsArg objects.
Definition RooArgSet.h:55
static constexpr const char * weightVarName
static constexpr const char * weightVarNameSumW2
static RooNameReg & instance()
Return reference to singleton instance.
RooRealVar represents a variable that can be changed from the outside.
Definition RooRealVar.h:37
Facilitates simultaneous fitting of multiple PDFs to subsets of a given dataset.
RooAbsPdf * getPdf(RooStringView catName) const
Return the p.d.f associated with the given index category name.
The TNamed class is the base class for all named ROOT classes.
Definition TNamed.h:29
This file contains a specialised ROOT message handler to test for diagnostic in unit tests.
std::string getRangeNameForSimComponent(std::string const &rangeName, bool splitRange, std::string const &catName)
void getSortedComputationGraph(RooAbsArg const &func, RooArgSet &out)
#define Split(a, ahi, aLo)
Definition triangle.c:4776
static void output()