Logo ROOT  
Reference Guide
 
All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Properties Friends Macros Modules Pages
Loading...
Searching...
No Matches
BatchModeDataHelpers.cxx
Go to the documentation of this file.
1/// \cond ROOFIT_INTERNAL
2
3/*
4 * Project: RooFit
5 * Authors:
6 * Jonas Rembser, CERN 2022
7 *
8 * Copyright (c) 2022, CERN
9 *
10 * Redistribution and use in source and binary forms,
11 * with or without modification, are permitted according to the terms
12 * listed in LICENSE (http://roofit.sourceforge.net/license.txt)
13 */
14
16
17#include <RooAbsData.h>
18#include <RooRealVar.h>
19#include <RooSimultaneous.h>
20
21#include "RooFitImplHelpers.h"
23
24#include <ROOT/StringUtils.hxx>
25
26#include <numeric>
27
28namespace {
29
30// Copy-assign the span `from` to the span `to`.
// Because `from` is taken by const reference, the assignment below is always
// a copy assignment. This avoids relying on a move assignment operator, which
// the original author notes can be deleted (presumably for some std::span
// implementations or backports -- TODO confirm which ones). It also lets
// callers pass a braced initializer such as `{ptr, size}` for `from`.
31template <class T>
32void assignSpan(std::span<T> &to, std::span<T> const &from)
33{
34 to = from;
35}
36
// Collect the content of a single dataset as a map of spans.
//
// The weights, the real-valued batches and the category batches of `data` are
// copied into vectors that are pushed onto `buffers`, and the returned map
// holds spans pointing into those vectors. Keys are the name pointers obtained
// from `nameReg.constPtr()` for the string `prefix + <name>`.
//
// \param[in] data The dataset to read from.
// \param[in] rangeName Comma-separated list of range names; an entry is kept
//            if it is in at least one of them. Empty means no range selection.
// \param[in] prefix String prepended to every key name.
// \param[in,out] buffers Stack that receives the vectors backing the spans.
// \param[in] skipZeroWeights If true, entries whose weight compares equal to
//            zero are flagged and left out of all copied batches.
// \return The map from name pointers to data spans.
37std::map<RooFit::Detail::DataKey, std::span<const double>>
38getSingleDataSpans(RooAbsData const &data, std::string_view rangeName, std::string const &prefix,
39 std::stack<std::vector<double>> &buffers, bool skipZeroWeights)
40{
41 std::map<RooFit::Detail::DataKey, std::span<const double>> dataSpans; // output variable
42
 // NOTE(review): `nameReg`, used by the two lambdas below, is not declared in
 // the lines visible here (listing line 43 is absent) -- confirm its
 // definition against the full file.
44
 // Store `span` in the output map under the name pointer for `prefix + key`.
45 auto insert = [&](const char *key, std::span<const double> span) {
46 const TNamed *namePtr = nameReg.constPtr((prefix + key).c_str());
47 dataSpans[namePtr] = span;
48 };
49
 // Look up the span stored for `prefix + key`; `at()` throws if it is missing.
50 auto retrieve = [&](const char *key) {
51 const TNamed *namePtr = nameReg.constPtr((prefix + key).c_str());
52 return dataSpans.at(namePtr);
53 };
54
55 std::size_t nEvents = static_cast<size_t>(data.numEntries());
56
57 auto weight = data.getWeightBatch(0, nEvents, /*sumW2=*/false);
58 auto weightSumW2 = data.getWeightBatch(0, nEvents, /*sumW2=*/true);
59
 // Per-entry flag, set for entries that are dropped because of a zero weight.
60 std::vector<bool> hasZeroWeight;
61 hasZeroWeight.resize(nEvents);
62 std::size_t nNonZeroWeight = 0;
63
64 // Add weights to the datamap. They should have the names expected by the
65 // RooNLLVarNew. We also add the sumW2 weights here under a different name,
66 // so we can apply the sumW2 correction by easily swapping the spans.
67 {
 // `buffer` stays a valid reference after the second emplace(): the default
 // container of std::stack is std::deque, which keeps references to existing
 // elements valid when inserting at the end.
68 buffers.emplace();
69 auto &buffer = buffers.top();
70 buffers.emplace();
71 auto &bufferSumW2 = buffers.top();
72 buffer.reserve(nEvents);
73 bufferSumW2.reserve(nEvents);
 // NOTE(review): no visible line increments `nNonZeroWeight` (listing lines
 // 79 and 83 are absent); presumably it is incremented for every entry that
 // is pushed into the buffers -- confirm against the full file.
74 for (std::size_t i = 0; i < nEvents; ++i) {
75 if (weight.empty()) {
76 // No weights in the dataset imply a constant weight of one
77 buffer.push_back(1.0);
78 bufferSumW2.push_back(1.0);
80 } else if (!skipZeroWeights || weight[i] != 0) {
81 buffer.push_back(weight[i]);
82 bufferSumW2.push_back(weightSumW2[i]);
84 } else {
 // Only reached when skipZeroWeights is set and the weight is zero.
85 hasZeroWeight[i] = true;
86 }
87 }
88 assignSpan(weight, {buffer.data(), nNonZeroWeight});
 // NOTE(review): listing line 89 is absent; presumably it re-points
 // `weightSumW2` at `bufferSumW2` in the same way -- confirm.
90 insert(RooFit::Detail::RooNLLVarNew::weightVarName, weight);
91 insert(RooFit::Detail::RooNLLVarNew::weightVarNameSumW2, weightSumW2);
92 }
93
94 // Get the real-valued batches and copy them, without the entries flagged
95 // as zero-weight, into new buffers to put in the data map
96 for (auto const &item : data.getBatches(0, nEvents)) {
97
98 std::span<const double> span{item.second};
99
100 buffers.emplace();
101 auto &buffer = buffers.top();
102 buffer.reserve(nNonZeroWeight);
103
104 for (std::size_t i = 0; i < nEvents; ++i) {
105 if (!hasZeroWeight[i]) {
106 buffer.push_back(span[i]);
107 }
108 }
109 insert(item.first->GetName(), {buffer.data(), buffer.size()});
110 }
111
112 // Get the category batches and cast them to double, again without the
113 // entries flagged as zero-weight, to put in the data map
114 for (auto const &item : data.getCategoryBatches(0, nEvents)) {
115
116 std::span<const RooAbsCategory::value_type> intSpan{item.second};
117
118 buffers.emplace();
119 auto &buffer = buffers.top();
120 buffer.reserve(nNonZeroWeight);
121
122 for (std::size_t i = 0; i < nEvents; ++i) {
123 if (!hasZeroWeight[i]) {
124 buffer.push_back(static_cast<double>(intSpan[i]));
125 }
126 }
127 insert(item.first->GetName(), {buffer.data(), buffer.size()});
128 }
129
 // From here on, only the entries that were kept are counted.
130 nEvents = nNonZeroWeight;
131
132 // Now we have to do the range selection
133 if (!rangeName.empty()) {
134 // figure out which events are in the range
135 std::vector<bool> isInRange(nEvents, false);
 // An entry is selected if it is in at least one of the comma-separated
 // ranges (the per-range masks are OR-ed together below).
136 for (auto const &range : ROOT::Split(rangeName, ",")) {
137 std::vector<bool> isInSubRange(nEvents, true);
138 for (auto *observable : dynamic_range_cast<RooAbsRealLValue *>(*data.get())) {
139 // If the observable is not real-valued, it will not be considered for the range selection
140 if (observable) {
 // presumably inRange() clears the mask for entries outside `range` and
 // leaves the other entries untouched -- TODO confirm
141 observable->inRange({retrieve(observable->GetName()).data(), nEvents}, range, isInSubRange);
142 }
143 }
144 for (std::size_t i = 0; i < isInSubRange.size(); ++i) {
145 isInRange[i] = isInRange[i] || isInSubRange[i];
146 }
147 }
148
149 // reset the number of events to the number of selected entries
150 nEvents = std::accumulate(isInRange.begin(), isInRange.end(), 0);
151
152 // do the data reduction in the data map
153 for (auto const &item : dataSpans) {
154 auto const &allValues = item.second;
 // Spans holding a single value are left as they are.
155 if (allValues.size() == 1) {
156 continue;
157 }
158 buffers.emplace(nEvents);
159 double *buffer = buffers.top().data();
160 std::size_t j = 0;
161 for (std::size_t i = 0; i < isInRange.size(); ++i) {
162 if (isInRange[i]) {
163 buffer[j] = allValues[i];
164 ++j;
165 }
166 }
 // The key already exists, so operator[] only replaces the mapped span and
 // does not insert a new element while iterating.
167 assignSpan(dataSpans[item.first], {buffer, nEvents});
168 }
169 }
170
171 return dataSpans;
172}
173
174} // namespace
175
176////////////////////////////////////////////////////////////////////////////////
177/// Extract all content from a RooFit dataset as a map of spans.
178/// Spans with the weights and squared weights will be also stored in the map,
179/// keyed with the names `_weight` and the `_weight_sumW2`. If the dataset is
180/// unweighted, these weight spans will only contain the single value `1.0`.
181/// Entries with zero weight will be skipped if `skipZeroWeights` is set.
///
/// NOTE(review): the helper that fills the weight spans pushes one `1.0` per
/// entry for unweighted data in the code visible here -- confirm whether the
/// "single value" statement above is still accurate.
182///
183/// \return A `std::map` with spans keyed to name pointers.
184/// \param[in] data The input dataset.
185/// \param[in] rangeName Select only entries from the data in a given range
186/// (empty string for no range).
187/// \param[in] simPdf A simultaneous pdf to use as a guide for splitting the
188/// dataset. The spans from each channel data will be prefixed with
189/// the channel name.
190/// \param[in] skipZeroWeights Skip entries with zero weight when filling the
191/// data spans. Be very careful with enabling it, because the user
192/// might not expect that the batch results are not aligned with the
193/// original dataset anymore!
194/// \param[in] takeGlobalObservablesFromData Take also the global observables
195/// stored in the dataset.
196/// \param[in] buffers Pass here an empty stack of `double` vectors, which will
197/// be used as memory for the data if the memory in the dataset
198/// object can't be used directly (e.g. because you used the range
199/// selection or the splitting by categories).
200std::map<RooFit::Detail::DataKey, std::span<const double>>
201RooFit::BatchModeDataHelpers::getDataSpans(RooAbsData const &data, std::string const &rangeName,
203 bool takeGlobalObservablesFromData, std::stack<std::vector<double>> &buffers)
204{
 // NOTE(review): `simPdf` and `skipZeroWeights` are used below, but the
 // signature line that declares them (listing line 202) is absent here.

 // One (key prefix, dataset) pair per dataset to process, plus a parallel flag
 // telling whether the "BinnedLikelihoodActive" attribute is set for it.
205 std::vector<std::pair<std::string, RooAbsData const *>> datasets;
206 std::vector<bool> isBinnedL;
207 bool splitRange = false;
208 std::vector<std::unique_ptr<RooAbsData>> splitDataSets;
209
210 if (simPdf) {
211 std::unique_ptr<TList> splits{data.split(*simPdf, true)};
212 for (auto *d : static_range_cast<RooAbsData *>(*splits)) {
213 RooAbsPdf *simComponent = simPdf->getPdf(d->GetName());
214 // If there is no PDF for that component, we also don't need to fill the data
215 if (!simComponent) {
216 continue;
217 }
 // Keys of this channel are prefixed with "_<channel name>_".
218 datasets.emplace_back(std::string("_") + d->GetName() + "_", d);
219 isBinnedL.emplace_back(simComponent->getAttribute("BinnedLikelihoodActive"));
220 // The datasets need to be kept alive because the datamap points to their content
 // NOTE(review): `splitDataSets` is a local that is destroyed on return, and
 // datasets skipped by the `continue` above are never handed to it --
 // confirm the intended ownership of the entries in `splits`.
221 splitDataSets.emplace_back(d);
222 }
223 splitRange = simPdf->getAttribute("SplitRange");
224 } else {
 // No simultaneous pdf: process the input dataset as a whole, without prefix.
225 datasets.emplace_back("", &data);
226 isBinnedL.emplace_back(false);
227 }
228
229 std::map<RooFit::Detail::DataKey, std::span<const double>> dataSpans; // output variable
230
231 for (std::size_t iData = 0; iData < datasets.size(); ++iData) {
232 auto const &toAdd = datasets[iData];
 // NOTE(review): the start of the statement that produces `spans` (listing
 // lines 233-234) is absent here. The visible tail shows that zero-weight
 // skipping is only requested when the binned-likelihood flag is not set.
235 toAdd.first, buffers, skipZeroWeights && !isBinnedL[iData]);
 // insert() keeps an already existing entry if a key occurs twice.
236 for (auto const &item : spans) {
237 dataSpans.insert(item);
238 }
239 }
240
241 if (takeGlobalObservablesFromData && data.getGlobalObservables()) {
242 buffers.emplace();
243 auto &buffer = buffers.top();
 // Reserving the full size up front means push_back() never reallocates, so
 // the one-element spans taken from `&buffer.back()` below remain valid.
244 buffer.reserve(data.getGlobalObservables()->size());
 // The cast is unchecked: this assumes every global observable is a
 // RooRealVar -- TODO confirm.
245 for (auto *arg : static_range_cast<RooRealVar const *>(*data.getGlobalObservables())) {
246 buffer.push_back(arg->getVal());
247 assignSpan(dataSpans[arg], {&buffer.back(), 1});
248 }
249 }
250
251 return dataSpans;
252}
253
254////////////////////////////////////////////////////////////////////////////////
255/// Figure out the output size for each node in the computation graph that
256/// leads up to the top node, given some vector data as an input. The input
257/// data spans are in general not of the same size, for example in the case of
258/// a simultaneous fit.
///
/// Nodes for which `inputSizeFunc` returns something other than `-1` get that
/// size. Every other node gets size `1`, unless it is not a reducer node and
/// one of its value servers has an output size different from `1`, in which
/// case it adopts the size of the first such server.
259///
260/// \return A `std::map` with output sizes for each node in the computation graph.
261/// \param[in] topNode The top node of the computation graph.
262/// \param[in] inputSizeFunc A function to get the input sizes.
263std::map<RooFit::Detail::DataKey, std::size_t>
264RooFit::BatchModeDataHelpers::determineOutputSizes(RooAbsArg const &topNode,
265 std::function<int(RooFit::Detail::DataKey)> const &inputSizeFunc)
266{
267 std::map<RooFit::Detail::DataKey, std::size_t> output;
268
 // NOTE(review): `serverSet`, iterated below, is not declared in the lines
 // visible here (listing lines 269-270 are absent). Presumably it holds the
 // nodes of the computation graph of `topNode`, ordered so that servers come
 // before their clients -- confirm against the full file.
271
 // First pass: record the sizes of the nodes that are fed from array inputs.
272 for (RooAbsArg *arg : serverSet) {
273 int inputSize = inputSizeFunc(arg);
274 // The size == -1 encodes that the input doesn't come from an array
275 // input.
276 if (inputSize != -1) {
277 output[arg] = inputSize;
278 }
279 }
280
 // Second pass: derive the sizes of all remaining nodes from their servers.
281 for (RooAbsArg *arg : serverSet) {
282 std::size_t size = 1;
 // Sizes that were already set in the first pass are not overwritten.
283 if (output.find(arg) != output.end()) {
284 continue;
285 }
 // Reducer nodes keep the default size of one, whatever their servers are.
286 if (!arg->isReducerNode()) {
287 for (RooAbsArg *server : arg->servers()) {
288 if (server->isValueServer(*arg)) {
 // at() throws if the server has no entry yet, so each value server
 // must have been handled before the node that depends on it.
289 std::size_t inputSize = output.at(server);
290 if (inputSize != 1) {
291 // If the input is from an external array, the output will
292 // adopt its size and we can stop the checking of other
293 // servers.
294 size = inputSize;
295 break;
296 }
297 }
298 }
299 }
300 output[arg] = size;
301 }
302
303 return output;
304}
305
306/// \endcond
#define d(i)
Definition RSha256.hxx:102
ROOT::RRangeCast< T, true, Range_t > dynamic_range_cast(Range_t &&coll)
ROOT::RRangeCast< T, false, Range_t > static_range_cast(Range_t &&coll)
static void retrieve(const gsl_integration_workspace *workspace, double *a, double *b, double *r, double *e)
size_t size(const MatrixT &matrix)
retrieve the size of a square matrix
ROOT::Detail::TRangeCast< T, true > TRangeDynCast
TRangeDynCast is an adapter class that allows the typed iteration through a TCollection.
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void data
const_iterator end() const
Common abstract base class for objects that represent a value and a "shape" in RooFit.
Definition RooAbsArg.h:77
Abstract base class for binned and unbinned datasets.
Definition RooAbsData.h:57
Abstract interface for all probability density functions.
Definition RooAbsPdf.h:40
Abstract base class for objects that represent a real value that may appear on the left hand side of ...
RooArgSet is a container object that can hold multiple RooAbsArg objects.
Definition RooArgSet.h:24
static RooNameReg & instance()
Return reference to singleton instance.
Variable that can be changed from the outside.
Definition RooRealVar.h:37
Facilitates simultaneous fitting of multiple PDFs to subsets of a given dataset.
The TNamed class is the base class for all named ROOT classes.
Definition TNamed.h:29
tbb::task_arena is an alias of tbb::interface7::task_arena, which doesn't allow to forward declare tb...
std::vector< std::string > Split(std::string_view str, std::string_view delims, bool skipEmpty=false)
Splits a string at each character in delims.
std::string getRangeNameForSimComponent(std::string const &rangeName, bool splitRange, std::string const &catName)
void getSortedComputationGraph(RooAbsArg const &func, RooArgSet &out)
static void output()