Logo ROOT  
Reference Guide
 
Loading...
Searching...
No Matches
BatchModeDataHelpers.cxx
Go to the documentation of this file.
1/*
2 * Project: RooFit
3 * Authors:
4 * Jonas Rembser, CERN 2022
5 *
6 * Copyright (c) 2022, CERN
7 *
8 * Redistribution and use in source and binary forms,
9 * with or without modification, are permitted according to the terms
10 * listed in LICENSE (http://roofit.sourceforge.net/license.txt)
11 */
12
14#include <RooAbsData.h>
15#include <RooDataHist.h>
16#include "RooNLLVarNew.h"
17#include <RooRealVar.h>
18#include <RooSimultaneous.h>
19
20#include <ROOT/StringUtils.hxx>
21
22#include <numeric>
23
24namespace {
25
26std::map<RooFit::Detail::DataKey, RooSpan<const double>>
27getSingleDataSpans(RooAbsData const &data, std::string_view rangeName, std::string const &prefix,
28 std::stack<std::vector<double>> &buffers, bool skipZeroWeights)
29{
30 std::map<RooFit::Detail::DataKey, RooSpan<const double>> dataSpans; // output variable
31
32 auto &nameReg = RooNameReg::instance();
33
34 auto insert = [&](const char *key, RooSpan<const double> span) {
35 const TNamed *namePtr = nameReg.constPtr((prefix + key).c_str());
36 dataSpans[namePtr] = span;
37 };
38
39 auto retrieve = [&](const char *key) {
40 const TNamed *namePtr = nameReg.constPtr((prefix + key).c_str());
41 return dataSpans.at(namePtr);
42 };
43
44 std::size_t nEvents = static_cast<size_t>(data.numEntries());
45
46 // We also want to support empty datasets: in this case the
47 // RooFitDriver::Dataset is not filled with anything.
48 if (nEvents == 0) {
49 return dataSpans;
50 }
51
52 auto weight = data.getWeightBatch(0, nEvents, /*sumW2=*/false);
53 auto weightSumW2 = data.getWeightBatch(0, nEvents, /*sumW2=*/true);
54
55 std::vector<bool> hasZeroWeight;
56 hasZeroWeight.resize(nEvents);
57 std::size_t nNonZeroWeight = 0;
58
59 // Add weights to the datamap. They should have the names expected by the
60 // RooNLLVarNew. We also add the sumW2 weights here under a different name,
61 // so we can apply the sumW2 correction by easily swapping the spans.
62 {
63 buffers.emplace();
64 auto &buffer = buffers.top();
65 buffers.emplace();
66 auto &bufferSumW2 = buffers.top();
67 if (weight.empty()) {
68 // If the dataset has no weight, we fill the data spans with a scalar
69 // unity weight so we don't need to check for the existance of weights
70 // later in the likelihood.
71 buffer.push_back(1.0);
72 bufferSumW2.push_back(1.0);
73 weight = RooSpan<const double>(buffer.data(), 1);
74 weightSumW2 = RooSpan<const double>(bufferSumW2.data(), 1);
75 nNonZeroWeight = nEvents;
76 } else {
77 buffer.reserve(nEvents);
78 bufferSumW2.reserve(nEvents);
79 for (std::size_t i = 0; i < nEvents; ++i) {
80 if (!skipZeroWeights || weight[i] != 0) {
81 buffer.push_back(weight[i]);
82 bufferSumW2.push_back(weightSumW2[i]);
83 ++nNonZeroWeight;
84 } else {
85 hasZeroWeight[i] = true;
86 }
87 }
88 weight = RooSpan<const double>(buffer.data(), nNonZeroWeight);
89 weightSumW2 = RooSpan<const double>(bufferSumW2.data(), nNonZeroWeight);
90 }
91 using namespace ROOT::Experimental;
92 insert(RooNLLVarNew::weightVarName, weight);
93 insert(RooNLLVarNew::weightVarNameSumW2, weightSumW2);
94 }
95
96 // Add also bin volume information if we are dealing with a RooDataHist
97 if (auto dataHist = dynamic_cast<RooDataHist const *>(&data)) {
98 buffers.emplace();
99 auto &buffer = buffers.top();
100 buffer.reserve(nNonZeroWeight);
101
102 for (std::size_t i = 0; i < nEvents; ++i) {
103 if (!hasZeroWeight[i]) {
104 buffer.push_back(dataHist->binVolume(i));
105 }
106 }
107
108 insert("_bin_volume", {buffer.data(), buffer.size()});
109 }
110
111 // Get the real-valued batches and cast the also to double branches to put in
112 // the data map
113 for (auto const &item : data.getBatches(0, nEvents)) {
114
115 RooSpan<const double> span{item.second};
116
117 buffers.emplace();
118 auto &buffer = buffers.top();
119 buffer.reserve(nNonZeroWeight);
120
121 for (std::size_t i = 0; i < nEvents; ++i) {
122 if (!hasZeroWeight[i]) {
123 buffer.push_back(span[i]);
124 }
125 }
126 insert(item.first->GetName(), {buffer.data(), buffer.size()});
127 }
128
129 // Get the category batches and cast the also to double branches to put in
130 // the data map
131 for (auto const &item : data.getCategoryBatches(0, nEvents)) {
132
134
135 buffers.emplace();
136 auto &buffer = buffers.top();
137 buffer.reserve(nNonZeroWeight);
138
139 for (std::size_t i = 0; i < nEvents; ++i) {
140 if (!hasZeroWeight[i]) {
141 buffer.push_back(static_cast<double>(intSpan[i]));
142 }
143 }
144 insert(item.first->GetName(), {buffer.data(), buffer.size()});
145 }
146
147 nEvents = nNonZeroWeight;
148
149 // Now we have do do the range selection
150 if (!rangeName.empty()) {
151 // figure out which events are in the range
152 std::vector<bool> isInRange(nEvents, false);
153 for (auto const &range : ROOT::Split(rangeName, ",")) {
154 std::vector<bool> isInSubRange(nEvents, true);
155 for (auto *observable : dynamic_range_cast<RooAbsRealLValue *>(*data.get())) {
156 // If the observables is not real-valued, it will not be considered for the range selection
157 if (observable) {
158 observable->inRange({retrieve(observable->GetName()).data(), nEvents}, range, isInSubRange);
159 }
160 }
161 for (std::size_t i = 0; i < isInSubRange.size(); ++i) {
162 isInRange[i] = isInRange[i] || isInSubRange[i];
163 }
164 }
165
166 // reset the number of events
167 nEvents = std::accumulate(isInRange.begin(), isInRange.end(), 0);
168
169 // do the data reduction in the data map
170 for (auto const &item : dataSpans) {
171 auto const &allValues = item.second;
172 if (allValues.size() == 1) {
173 continue;
174 }
175 buffers.emplace(nEvents);
176 double *buffer = buffers.top().data();
177 std::size_t j = 0;
178 for (std::size_t i = 0; i < isInRange.size(); ++i) {
179 if (isInRange[i]) {
180 buffer[j] = allValues[i];
181 ++j;
182 }
183 }
184 dataSpans[item.first] = RooSpan<const double>{buffer, nEvents};
185 }
186 }
187
188 return dataSpans;
189}
190
191} // namespace
192
193////////////////////////////////////////////////////////////////////////////////
194/// Extract all content from a RooFit datasets as a map of spans.
195/// Spans with the weights and squared weights will be also stored in the map,
196/// keyed with the names `_weight` and the `_weight_sumW2`. If the dataset is
197/// unweighted, these weight spans will only contain the single value `1.0`.
198/// Entries with zero weight will be skipped. If the input dataset is a
199/// RooDataHist, the output map will also contain an item for the key
200/// `_bin_volume` with the bin volumes.
201///
202/// \return A `std::map` with spans keyed to name pointers.
203/// \param[in] data The input dataset.
204/// \param[in] rangeName Select only entries from the data in a given range
205/// (empty string for no range).
206/// \param[in] simPdf A simultaneous pdf to use as a guide for splitting the
207/// dataset. The spans from each channel data will be prefixed with
208/// the channel name.
209/// \param[in] skipZeroWeights Skip entries with zero weight when filling the
210/// data spans. Be very careful with enabling it, because the user
211/// might not expect that the batch results are not aligned with the
212/// original dataset anymore!
213/// \param[in] takeGlobalObservablesFromData Take also the global observables
214/// stored in the dataset.
215/// \param[in] buffers Pass here an empty stack of `double` vectors, which will
216/// be used as memory for the data if the memory in the dataset
217/// object can't be used directly (e.g. because you used the range
218/// selection or the splitting by categories).
219std::map<RooFit::Detail::DataKey, RooSpan<const double>>
220RooFit::BatchModeDataHelpers::getDataSpans(RooAbsData const &data, std::string const &rangeName,
221 RooSimultaneous const *simPdf, bool skipZeroWeights,
222 bool takeGlobalObservablesFromData, std::stack<std::vector<double>> &buffers)
223{
224 std::vector<std::pair<std::string, RooAbsData const *>> datas;
225 std::vector<bool> isBinnedL;
226 bool splitRange = false;
227 std::vector<std::unique_ptr<RooAbsData>> splittedDataSets;
228
229 if (simPdf) {
230 std::unique_ptr<TList> splits{data.split(*simPdf, true)};
231 for (auto *d : static_range_cast<RooAbsData *>(*splits)) {
232 RooAbsPdf *simComponent = simPdf->getPdf(d->GetName());
233 // If there is no PDF for that component, we also don't need to fill the data
234 if (!simComponent) {
235 continue;
236 }
237 datas.emplace_back(std::string("_") + d->GetName() + "_", d);
238 isBinnedL.emplace_back(simComponent->getAttribute("BinnedLikelihoodActive"));
239 // The dataset need to be kept alive because the datamap points to their content
240 splittedDataSets.emplace_back(d);
241 }
242 splitRange = simPdf->getAttribute("SplitRange");
243 } else {
244 datas.emplace_back("", &data);
245 isBinnedL.emplace_back(false);
246 }
247
248 std::map<RooFit::Detail::DataKey, RooSpan<const double>> dataSpans; // output variable
249
250 for (std::size_t iData = 0; iData < datas.size(); ++iData) {
251 auto const &toAdd = datas[iData];
252 auto spans = getSingleDataSpans(
253 *toAdd.second, RooHelpers::getRangeNameForSimComponent(rangeName, splitRange, toAdd.second->GetName()),
254 toAdd.first, buffers, skipZeroWeights && !isBinnedL[iData]);
255 for (auto const &item : spans) {
256 dataSpans.insert(item);
257 }
258 }
259
260 if (takeGlobalObservablesFromData && data.getGlobalObservables()) {
261 buffers.emplace();
262 auto &buffer = buffers.top();
263 buffer.reserve(data.getGlobalObservables()->size());
264 for (auto *arg : static_range_cast<RooRealVar const *>(*data.getGlobalObservables())) {
265 buffer.push_back(arg->getVal());
266 dataSpans[arg] = RooSpan<const double>{&buffer.back(), 1};
267 }
268 }
269
270 return dataSpans;
271}
#define d(i)
Definition RSha256.hxx:102
ROOT::RRangeCast< T, true, Range_t > dynamic_range_cast(Range_t &&coll)
static void retrieve(const gsl_integration_workspace *workspace, double *a, double *b, double *r, double *e)
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void data
bool getAttribute(const Text_t *name) const
Check if a named attribute is set. By default, all attributes are unset.
RooAbsData is the common abstract base class for binned and unbinned datasets.
Definition RooAbsData.h:59
RooAbsRealLValue is the common abstract base class for objects that represent a real value that may a...
The RooDataHist is a container class to hold N-dimensional binned data.
Definition RooDataHist.h:39
static RooNameReg & instance()
Return reference to singleton instance.
RooSimultaneous facilitates simultaneous fitting of multiple PDFs to subsets of a given dataset.
RooAbsPdf * getPdf(RooStringView catName) const
Return the p.d.f associated with the given index category name.
A simple container to hold a batch of data values.
Definition RooSpan.h:34
The TNamed class is the base class for all named ROOT classes.
Definition TNamed.h:29
This file contains a specialised ROOT message handler to test for diagnostic in unit tests.
std::map< RooFit::Detail::DataKey, RooSpan< const double > > getDataSpans(RooAbsData const &data, std::string const &rangeName, RooSimultaneous const *simPdf, bool skipZeroWeights, bool takeGlobalObservablesFromData, std::stack< std::vector< double > > &buffers)
Extract all content from a RooFit datasets as a map of spans.
std::string getRangeNameForSimComponent(std::string const &rangeName, bool splitRange, std::string const &catName)
#define Split(a, ahi, alo)
Definition triangle.c:4776