Logo ROOT  
Reference Guide
 
Loading...
Searching...
No Matches
BatchModeDataHelpers.cxx
Go to the documentation of this file.
1/*
2 * Project: RooFit
3 * Authors:
4 * Jonas Rembser, CERN 2022
5 *
6 * Copyright (c) 2022, CERN
7 *
8 * Redistribution and use in source and binary forms,
9 * with or without modification, are permitted according to the terms
10 * listed in LICENSE (http://roofit.sourceforge.net/license.txt)
11 */
12
14#include <RooAbsData.h>
15#include <RooHelpers.h>
16#include "RooNLLVarNew.h"
17#include <RooRealVar.h>
18#include <RooSimultaneous.h>
19
20#include <ROOT/StringUtils.hxx>
21
22#include <numeric>
23
24namespace {
25
26// To avoid deleted move assignment.
27template <class T>
28void assignSpan(std::span<T> &to, std::span<T> const &from)
29{
30 to = from;
31}
32
33std::map<RooFit::Detail::DataKey, std::span<const double>>
34getSingleDataSpans(RooAbsData const &data, std::string_view rangeName, std::string const &prefix,
35 std::stack<std::vector<double>> &buffers, bool skipZeroWeights)
36{
37 std::map<RooFit::Detail::DataKey, std::span<const double>> dataSpans; // output variable
38
39 auto &nameReg = RooNameReg::instance();
40
41 auto insert = [&](const char *key, std::span<const double> span) {
42 const TNamed *namePtr = nameReg.constPtr((prefix + key).c_str());
43 dataSpans[namePtr] = span;
44 };
45
46 auto retrieve = [&](const char *key) {
47 const TNamed *namePtr = nameReg.constPtr((prefix + key).c_str());
48 return dataSpans.at(namePtr);
49 };
50
51 std::size_t nEvents = static_cast<size_t>(data.numEntries());
52
53 // We also want to support empty datasets: in this case the
54 // RooFitDriver::Dataset is not filled with anything.
55 if (nEvents == 0) {
56 return dataSpans;
57 }
58
59 auto weight = data.getWeightBatch(0, nEvents, /*sumW2=*/false);
60 auto weightSumW2 = data.getWeightBatch(0, nEvents, /*sumW2=*/true);
61
62 std::vector<bool> hasZeroWeight;
63 hasZeroWeight.resize(nEvents);
64 std::size_t nNonZeroWeight = 0;
65
66 // Add weights to the datamap. They should have the names expected by the
67 // RooNLLVarNew. We also add the sumW2 weights here under a different name,
68 // so we can apply the sumW2 correction by easily swapping the spans.
69 {
70 buffers.emplace();
71 auto &buffer = buffers.top();
72 buffers.emplace();
73 auto &bufferSumW2 = buffers.top();
74 if (weight.empty()) {
75 // If the dataset has no weight, we fill the data spans with a scalar
76 // unity weight so we don't need to check for the existence of weights
77 // later in the likelihood.
78 buffer.push_back(1.0);
79 bufferSumW2.push_back(1.0);
80 assignSpan(weight, {buffer.data(), 1});
81 assignSpan(weightSumW2, {bufferSumW2.data(), 1});
82 nNonZeroWeight = nEvents;
83 } else {
84 buffer.reserve(nEvents);
85 bufferSumW2.reserve(nEvents);
86 for (std::size_t i = 0; i < nEvents; ++i) {
87 if (!skipZeroWeights || weight[i] != 0) {
88 buffer.push_back(weight[i]);
89 bufferSumW2.push_back(weightSumW2[i]);
90 ++nNonZeroWeight;
91 } else {
92 hasZeroWeight[i] = true;
93 }
94 }
95 assignSpan(weight, {buffer.data(), nNonZeroWeight});
96 assignSpan(weightSumW2, {bufferSumW2.data(), nNonZeroWeight});
97 }
98 insert(RooNLLVarNew::weightVarName, weight);
99 insert(RooNLLVarNew::weightVarNameSumW2, weightSumW2);
100 }
101
102 // Get the real-valued batches and cast the also to double branches to put in
103 // the data map
104 for (auto const &item : data.getBatches(0, nEvents)) {
105
106 std::span<const double> span{item.second};
107
108 buffers.emplace();
109 auto &buffer = buffers.top();
110 buffer.reserve(nNonZeroWeight);
111
112 for (std::size_t i = 0; i < nEvents; ++i) {
113 if (!hasZeroWeight[i]) {
114 buffer.push_back(span[i]);
115 }
116 }
117 insert(item.first->GetName(), {buffer.data(), buffer.size()});
118 }
119
120 // Get the category batches and cast the also to double branches to put in
121 // the data map
122 for (auto const &item : data.getCategoryBatches(0, nEvents)) {
123
124 std::span<const RooAbsCategory::value_type> intSpan{item.second};
125
126 buffers.emplace();
127 auto &buffer = buffers.top();
128 buffer.reserve(nNonZeroWeight);
129
130 for (std::size_t i = 0; i < nEvents; ++i) {
131 if (!hasZeroWeight[i]) {
132 buffer.push_back(static_cast<double>(intSpan[i]));
133 }
134 }
135 insert(item.first->GetName(), {buffer.data(), buffer.size()});
136 }
137
138 nEvents = nNonZeroWeight;
139
140 // Now we have do do the range selection
141 if (!rangeName.empty()) {
142 // figure out which events are in the range
143 std::vector<bool> isInRange(nEvents, false);
144 for (auto const &range : ROOT::Split(rangeName, ",")) {
145 std::vector<bool> isInSubRange(nEvents, true);
146 for (auto *observable : dynamic_range_cast<RooAbsRealLValue *>(*data.get())) {
147 // If the observables is not real-valued, it will not be considered for the range selection
148 if (observable) {
149 observable->inRange({retrieve(observable->GetName()).data(), nEvents}, range, isInSubRange);
150 }
151 }
152 for (std::size_t i = 0; i < isInSubRange.size(); ++i) {
153 isInRange[i] = isInRange[i] || isInSubRange[i];
154 }
155 }
156
157 // reset the number of events
158 nEvents = std::accumulate(isInRange.begin(), isInRange.end(), 0);
159
160 // do the data reduction in the data map
161 for (auto const &item : dataSpans) {
162 auto const &allValues = item.second;
163 if (allValues.size() == 1) {
164 continue;
165 }
166 buffers.emplace(nEvents);
167 double *buffer = buffers.top().data();
168 std::size_t j = 0;
169 for (std::size_t i = 0; i < isInRange.size(); ++i) {
170 if (isInRange[i]) {
171 buffer[j] = allValues[i];
172 ++j;
173 }
174 }
175 assignSpan(dataSpans[item.first], {buffer, nEvents});
176 }
177 }
178
179 return dataSpans;
180}
181
182} // namespace
183
184////////////////////////////////////////////////////////////////////////////////
185/// Extract all content from a RooFit datasets as a map of spans.
186/// Spans with the weights and squared weights will be also stored in the map,
187/// keyed with the names `_weight` and the `_weight_sumW2`. If the dataset is
188/// unweighted, these weight spans will only contain the single value `1.0`.
189/// Entries with zero weight will be skipped.
190///
191/// \return A `std::map` with spans keyed to name pointers.
192/// \param[in] data The input dataset.
193/// \param[in] rangeName Select only entries from the data in a given range
194/// (empty string for no range).
195/// \param[in] simPdf A simultaneous pdf to use as a guide for splitting the
196/// dataset. The spans from each channel data will be prefixed with
197/// the channel name.
198/// \param[in] skipZeroWeights Skip entries with zero weight when filling the
199/// data spans. Be very careful with enabling it, because the user
200/// might not expect that the batch results are not aligned with the
201/// original dataset anymore!
202/// \param[in] takeGlobalObservablesFromData Take also the global observables
203/// stored in the dataset.
204/// \param[in] buffers Pass here an empty stack of `double` vectors, which will
205/// be used as memory for the data if the memory in the dataset
206/// object can't be used directly (e.g. because you used the range
207/// selection or the splitting by categories).
208std::map<RooFit::Detail::DataKey, std::span<const double>>
209RooFit::BatchModeDataHelpers::getDataSpans(RooAbsData const &data, std::string const &rangeName,
210 RooSimultaneous const *simPdf, bool skipZeroWeights,
211 bool takeGlobalObservablesFromData, std::stack<std::vector<double>> &buffers)
212{
213 std::vector<std::pair<std::string, RooAbsData const *>> datasets;
214 std::vector<bool> isBinnedL;
215 bool splitRange = false;
216 std::vector<std::unique_ptr<RooAbsData>> splitDataSets;
217
218 if (simPdf) {
219 std::unique_ptr<TList> splits{data.split(*simPdf, true)};
220 for (auto *d : static_range_cast<RooAbsData *>(*splits)) {
221 RooAbsPdf *simComponent = simPdf->getPdf(d->GetName());
222 // If there is no PDF for that component, we also don't need to fill the data
223 if (!simComponent) {
224 continue;
225 }
226 datasets.emplace_back(std::string("_") + d->GetName() + "_", d);
227 isBinnedL.emplace_back(simComponent->getAttribute("BinnedLikelihoodActive"));
228 // The dataset need to be kept alive because the datamap points to their content
229 splitDataSets.emplace_back(d);
230 }
231 splitRange = simPdf->getAttribute("SplitRange");
232 } else {
233 datasets.emplace_back("", &data);
234 isBinnedL.emplace_back(false);
235 }
236
237 std::map<RooFit::Detail::DataKey, std::span<const double>> dataSpans; // output variable
238
239 for (std::size_t iData = 0; iData < datasets.size(); ++iData) {
240 auto const &toAdd = datasets[iData];
241 auto spans = getSingleDataSpans(
242 *toAdd.second, RooHelpers::getRangeNameForSimComponent(rangeName, splitRange, toAdd.second->GetName()),
243 toAdd.first, buffers, skipZeroWeights && !isBinnedL[iData]);
244 for (auto const &item : spans) {
245 dataSpans.insert(item);
246 }
247 }
248
249 if (takeGlobalObservablesFromData && data.getGlobalObservables()) {
250 buffers.emplace();
251 auto &buffer = buffers.top();
252 buffer.reserve(data.getGlobalObservables()->size());
253 for (auto *arg : static_range_cast<RooRealVar const *>(*data.getGlobalObservables())) {
254 buffer.push_back(arg->getVal());
255 assignSpan(dataSpans[arg], {&buffer.back(), 1});
256 }
257 }
258
259 return dataSpans;
260}
261
262////////////////////////////////////////////////////////////////////////////////
263/// Figure out the output size for each node in the computation graph that
264/// leads up to the top node, given some vector data as an input. The input
265/// data spans are in general not of the same size, for example in the case of
266/// a simultaneous fit.
267///
268/// \return A `std::map` with output sizes for each node in the computation graph.
269/// \param[in] topNode The top node of the computation graph.
270/// \param[in] inputSizeFunc A function to get the input sizes.
271std::map<RooFit::Detail::DataKey, std::size_t> RooFit::BatchModeDataHelpers::determineOutputSizes(
272 RooAbsArg const &topNode, std::function<std::size_t(RooFit::Detail::DataKey)> const &inputSizeFunc)
273{
274 std::map<RooFit::Detail::DataKey, std::size_t> output;
275
276 RooArgSet serverSet;
277 RooHelpers::getSortedComputationGraph(topNode, serverSet);
278
279 for (RooAbsArg *arg : serverSet) {
280 std::size_t inputSize = inputSizeFunc(arg);
281 if (inputSize > 0) {
282 output[arg] = inputSize;
283 }
284 }
285
286 for (RooAbsArg *arg : serverSet) {
287 std::size_t size = 1;
288 if (output.find(arg) != output.end()) {
289 continue;
290 }
291 if (!arg->isReducerNode()) {
292 for (RooAbsArg *server : arg->servers()) {
293 if (server->isValueServer(*arg)) {
294 size = std::max(output.at(server), size);
295 }
296 }
297 }
298 output[arg] = size;
299 }
300
301 return output;
302}
#define d(i)
Definition RSha256.hxx:102
ROOT::RRangeCast< T, true, Range_t > dynamic_range_cast(Range_t &&coll)
static void retrieve(const gsl_integration_workspace *workspace, double *a, double *b, double *r, double *e)
size_t size(const MatrixT &matrix)
retrieve the size of a square matrix
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void data
RooAbsArg is the common abstract base class for objects that represent a value and a "shape" in RooFi...
Definition RooAbsArg.h:80
bool getAttribute(const Text_t *name) const
Check if a named attribute is set. By default, all attributes are unset.
RooAbsData is the common abstract base class for binned and unbinned datasets.
Definition RooAbsData.h:57
RooAbsRealLValue is the common abstract base class for objects that represent a real value that may a...
RooArgSet is a container object that can hold multiple RooAbsArg objects.
Definition RooArgSet.h:55
static constexpr const char * weightVarName
static constexpr const char * weightVarNameSumW2
static RooNameReg & instance()
Return reference to singleton instance.
RooSimultaneous facilitates simultaneous fitting of multiple PDFs to subsets of a given dataset.
RooAbsPdf * getPdf(RooStringView catName) const
Return the p.d.f associated with the given index category name.
The TNamed class is the base class for all named ROOT classes.
Definition TNamed.h:29
This file contains a specialised ROOT message handler to test for diagnostic in unit tests.
std::map< RooFit::Detail::DataKey, std::span< const double > > getDataSpans(RooAbsData const &data, std::string const &rangeName, RooSimultaneous const *simPdf, bool skipZeroWeights, bool takeGlobalObservablesFromData, std::stack< std::vector< double > > &buffers)
Extract all content from a RooFit datasets as a map of spans.
std::map< RooFit::Detail::DataKey, std::size_t > determineOutputSizes(RooAbsArg const &topNode, std::function< std::size_t(RooFit::Detail::DataKey)> const &inputSizeFunc)
Figure out the output size for each node in the computation graph that leads up to the top node,...
std::string getRangeNameForSimComponent(std::string const &rangeName, bool splitRange, std::string const &catName)
void getSortedComputationGraph(RooAbsArg const &func, RooArgSet &out)
Get the topologically-sorted list of all nodes in the computation graph.
#define Split(a, ahi, alo)
Definition triangle.c:4776
static void output()