Logo ROOT  
Reference Guide
NormalizationHelpers.cxx
Go to the documentation of this file.
1/*
2 * Project: RooFit
3 * Authors:
4 * Jonas Rembser, CERN 2022
5 *
6 * Copyright (c) 2022, CERN
7 *
8 * Redistribution and use in source and binary forms,
9 * with or without modification, are permitted according to the terms
10 * listed in LICENSE (http://roofit.sourceforge.net/license.txt)
11 */
12
14
15#include <RooAbsCachedPdf.h>
16#include <RooAbsPdf.h>
17#include <RooAbsReal.h>
18#include <RooAddition.h>
19#include <RooConstraintSum.h>
20#include <RooProdPdf.h>
21
22#include "RooNormalizedPdf.h"
23
24#include <unordered_set>
25
26namespace {
27
29using ServerLists = std::map<DataKey, std::vector<DataKey>>;
30
31class GraphChecker {
32public:
33 GraphChecker(RooAbsArg const &topNode)
34 {
35
36 // To track the RooProdPdfs to figure out which ones are responsible for constraints.
37 std::vector<RooAbsArg *> prodPdfs;
38
39 // Get the list of servers for each node by data key.
40 {
41 RooArgList nodes;
42 topNode.treeNodeServerList(&nodes, nullptr, true, true, false, true);
43 RooArgSet nodesSet{nodes};
44 for (RooAbsArg *node : nodesSet) {
45 if (dynamic_cast<RooProdPdf *>(node)) {
46 prodPdfs.push_back(node);
47 }
48 _serverLists[node];
49 bool isConstraintSum = dynamic_cast<RooConstraintSum const *>(node);
50 for (RooAbsArg *server : node->servers()) {
51 _serverLists[node].push_back(server);
52 if (isConstraintSum)
53 _constraints.insert(server);
54 }
55 }
56 }
57 for (auto &item : _serverLists) {
58 auto &l = item.second;
59 std::sort(l.begin(), l.end());
60 l.erase(std::unique(l.begin(), l.end()), l.end());
61 }
62
63 // Loop over the RooProdPdfs to figure out which ones are responsible for constraints.
64 for (auto *prodPdf : static_range_cast<RooProdPdf *>(prodPdfs)) {
65 std::size_t actualPdfIdx = 0;
66 std::size_t nNonConstraint = 0;
67 for (std::size_t i = 0; i < prodPdf->pdfList().size(); ++i) {
68 RooAbsArg &pdf = prodPdf->pdfList()[i];
69
70 // Heuristic for HistFactory models to find also the constraints
71 // that were not extracted for the RooConstraint sum, e.g. because
72 // they were constant. TODO: fix RooProdPdf such that is also
73 // extracts constraints for which the parameters is set constant.
74 bool isProbablyConstraint = std::string(pdf.GetName()).find("onstrain") != std::string::npos;
75
76 if (_constraints.find(&pdf) == _constraints.end() && !isProbablyConstraint) {
77 actualPdfIdx = i;
78 ++nNonConstraint;
79 }
80 }
81 if (nNonConstraint != prodPdf->pdfList().size()) {
82 if (nNonConstraint != 1) {
83 throw std::runtime_error("A RooProdPdf that multiplies a pdf with constraints should contain only one "
84 "pdf that is not a constraint!");
85 }
86 _prodPdfsWithConstraints[prodPdf] = actualPdfIdx;
87 }
88 }
89 }
90
91 bool dependsOn(DataKey arg, DataKey testArg)
92 {
93
94 std::pair<DataKey, DataKey> p{arg, testArg};
95
96 auto found = _results.find(p);
97 if (found != _results.end())
98 return found->second;
99
100 if (arg == testArg)
101 return true;
102
103 auto const &serverList = _serverLists.at(arg);
104
105 // Next test direct dependence
106 auto foundServer = std::find(serverList.begin(), serverList.end(), testArg);
107 if (foundServer != serverList.end()) {
108 _results.emplace(p, true);
109 return true;
110 }
111
112 // If not, recurse
113 for (auto const &server : serverList) {
114 bool t = dependsOn(server, testArg);
115 _results.emplace(std::pair<DataKey, DataKey>{server, testArg}, t);
116 if (t) {
117 return true;
118 }
119 }
120
121 _results.emplace(p, false);
122 return false;
123 }
124
125 bool isConstraint(DataKey key) const
126 {
127 auto found = _constraints.find(key);
128 return found != _constraints.end();
129 }
130
131 std::unordered_map<RooAbsArg *, std::size_t> const &prodPdfsWithConstraints() const
132 {
133 return _prodPdfsWithConstraints;
134 }
135
136private:
137 std::unordered_set<DataKey> _constraints;
138 std::unordered_map<RooAbsArg *, std::size_t> _prodPdfsWithConstraints;
139 ServerLists _serverLists;
140 std::map<std::pair<DataKey, DataKey>, bool> _results;
141};
142
143void treeNodeServerListAndNormSets(const RooAbsArg &arg, RooAbsCollection &list, RooArgSet const &normSet,
144 std::unordered_map<DataKey, RooArgSet *> &normSets, GraphChecker const &checker)
145{
146 if (normSets.find(&arg) != normSets.end())
147 return;
148
149 list.add(arg, true);
150
151 // normalization sets only need to be added for pdfs
152 if (dynamic_cast<RooAbsPdf const *>(&arg)) {
153 normSets.insert({&arg, new RooArgSet{normSet}});
154 }
155
156 // Recurse if current node is derived
157 if (arg.isDerived() && !arg.isFundamental()) {
158 for (const auto server : arg.servers()) {
159
160 if (!server->isValueServer(arg)) {
161 continue;
162 }
163
164 // If this is a server that is also serving a RooConstraintSum, it
165 // should be skipped because it is not evaluated by this client (e.g.
166 // a RooProdPdf). It was only part of the servers to be extracted for
167 // the constraint sum.
168 if (!dynamic_cast<RooConstraintSum const *>(&arg) && checker.isConstraint(server)) {
169 continue;
170 }
171
172 auto differentSet = arg.fillNormSetForServer(normSet, *server);
173 if (differentSet)
174 differentSet->sort();
175
176 auto &serverNormSet = differentSet ? *differentSet : normSet;
177
178 // Make sure that the server is not already part of the computation
179 // graph with a different normalization set.
180 auto found = normSets.find(server);
181 if (found != normSets.end()) {
182 if (found->second->size() != serverNormSet.size() || !serverNormSet.hasSameLayout(*found->second)) {
183 std::stringstream ss;
184 ss << server->ClassName() << "::" << server->GetName()
185 << " is requested to be evaluated with two different normalization sets in the same model!";
186 ss << " This is not supported yet. The conflicting norm sets are:\n RooArgSet";
187 serverNormSet.printValue(ss);
188 ss << " requested by " << arg.ClassName() << "::" << arg.GetName() << "\n RooArgSet";
189 found->second->printValue(ss);
190 ss << " first requested by other client";
191 auto errMsg = ss.str();
192 oocoutE(server, Minimization) << errMsg << std::endl;
193 throw std::runtime_error(errMsg);
194 }
195 continue;
196 }
197
198 treeNodeServerListAndNormSets(*server, list, serverNormSet, normSets, checker);
199 }
200 }
201}
202
203std::vector<std::unique_ptr<RooAbsArg>> unfoldIntegrals(RooAbsArg const &topNode, RooArgSet const &normSet,
204 std::unordered_map<DataKey, RooArgSet *> &normSets,
205 RooArgSet &replacedArgs)
206{
207 std::vector<std::unique_ptr<RooAbsArg>> newNodes;
208
209 // No normalization set: we don't need to create any integrals
210 if (normSet.empty())
211 return newNodes;
212
213 GraphChecker checker{topNode};
214
215 RooArgSet nodes;
216 // The norm sets are sorted to compare them for equality more easliy
217 RooArgSet normSetSorted{normSet};
218 normSetSorted.sort();
219 treeNodeServerListAndNormSets(topNode, nodes, normSetSorted, normSets, checker);
220
221 // Clean normsets of the variables that the arg does not depend on
222 // std::unordered_map<std::pair<RooAbsArg const*,RooAbsArg const*>,bool> dependsResults;
223 for (auto &item : normSets) {
224 if (!item.second || item.second->empty())
225 continue;
226 auto actualNormSet = new RooArgSet{};
227 for (auto *narg : *item.second) {
228 if (checker.dependsOn(item.first, narg))
229 actualNormSet->add(*narg);
230 }
231 delete item.second;
232 item.second = actualNormSet;
233 }
234
235 // Function to `oldArg` with `newArg` in the computation graph.
236 auto replaceArg = [&](RooAbsArg &newArg, RooAbsArg const &oldArg) {
237 const std::string attrib = std::string("ORIGNAME:") + oldArg.GetName();
238
239 newArg.setAttribute(attrib.c_str());
240 newArg.setStringAttribute("_replaced_arg", oldArg.GetName());
241
242 RooArgList newServerList{newArg};
243
244 RooArgList originalClients;
245 for (auto *client : oldArg.clients()) {
246 originalClients.add(*client);
247 }
248 for (auto *client : originalClients) {
249 if (!nodes.containsInstance(*client))
250 continue;
251 if (dynamic_cast<RooAbsCachedPdf *>(client))
252 continue;
253 client->redirectServers(newServerList, false, true);
254 }
255 replacedArgs.add(oldArg);
256
257 newArg.setAttribute(attrib.c_str(), false);
258 };
259
260 // Replaces the RooProdPdfs that were used to wrap constraints with the actual pdf.
261 for (RooAbsArg *node : nodes) {
262 if (auto prodPdf = dynamic_cast<RooProdPdf *>(node)) {
263 auto found = checker.prodPdfsWithConstraints().find(prodPdf);
264 if (found != checker.prodPdfsWithConstraints().end()) {
265 replaceArg(prodPdf->pdfList()[found->second], *prodPdf);
266 }
267 }
268 }
269
270 // Replace all pdfs that need to be normalized with a pdf wrapper that
271 // applies the right normalization.
272 for (RooAbsArg *node : nodes) {
273 if (auto pdf = dynamic_cast<RooAbsPdf *>(node)) {
274 RooArgSet const &currNormSet = *normSets.at(pdf);
275
276 if (currNormSet.empty())
277 continue;
278
279 // The call to getVal() sets up cached states for this normalization
280 // set, which is important in case this pdf is also used by clients
281 // using the getVal() interface (without this, test 28 in stressRooFit
282 // is failing for example).
283 pdf->getVal(currNormSet);
284
285 if (pdf->selfNormalized() && !dynamic_cast<RooAbsCachedPdf *>(pdf))
286 continue;
287
288 auto normalizedPdf = std::make_unique<RooNormalizedPdf>(*pdf, currNormSet);
289
290 replaceArg(*normalizedPdf, *pdf);
291
292 newNodes.emplace_back(std::move(normalizedPdf));
293 }
294 }
295
296 return newNodes;
297}
298
299void foldIntegrals(RooAbsArg const &topNode, RooArgSet &replacedArgs)
300{
301 RooArgSet nodes;
302 topNode.treeNodeServerList(&nodes);
303
304 for (RooAbsArg *normalizedPdf : nodes) {
305
306 if (auto const &replacedArgName = normalizedPdf->getStringAttribute("_replaced_arg")) {
307
308 auto pdf = &replacedArgs[replacedArgName];
309
310 pdf->setAttribute((std::string("ORIGNAME:") + normalizedPdf->GetName()).c_str());
311
312 RooArgList newServerList{*pdf};
313 for (auto *client : normalizedPdf->clients()) {
314 client->redirectServers(newServerList, false, true);
315 }
316
317 pdf->setAttribute((std::string("ORIGNAME:") + normalizedPdf->GetName()).c_str(), false);
318
319 normalizedPdf->removeStringAttribute("_replaced_arg");
320 }
321 }
322}
323
324} // namespace
325
326/// \class NormalizationIntegralUnfolder
327/// \ingroup Roofitcore
328///
329/// A NormalizationIntegralUnfolder takes the top node of a computation graph
330/// and a normalization set for its constructor. The normalization integrals
331/// for the PDFs in that graph will be created, and placed into the computation
332/// graph itself, rewiring the existing RooAbsArgs. When the unfolder goes out
333/// of scope, all changes to the computation graph will be reverted.
334///
335/// It also performs some other optimizations of the computation graph that are
336/// reverted when the object goes out of scope:
337///
338/// 1. Replacing RooProdPdfs that were used to bring constraints into the
339/// likelihood with the actual pdf that is not a constraint.
340///
341/// Note that for evaluation, the original topNode should not be used anymore,
342/// because if it is a pdf there is now a new normalized pdf wrapping it,
343/// serving as the new top node. This normalized top node can be retrieved by
344/// NormalizationIntegralUnfolder::arg().
345
347 : _topNodeWrapper{std::make_unique<RooAddition>("_dummy", "_dummy", RooArgList{topNode})}, _normSetWasEmpty{
348 normSet.empty()}
349{
350 auto ownedArgs = unfoldIntegrals(*_topNodeWrapper, normSet, _normSets, _replacedArgs);
351 for (std::unique_ptr<RooAbsArg> &arg : ownedArgs) {
352 _topNodeWrapper->addOwnedComponents(std::move(arg));
353 }
354 _arg = &static_cast<RooAddition &>(*_topNodeWrapper).list()[0];
355}
356
358{
359 // If there was no normalization set to compile the computation graph for,
360 // we also don't need to fold the integrals back in.
361 if (_normSetWasEmpty)
362 return;
363
364 foldIntegrals(*_topNodeWrapper, _replacedArgs);
365
366 for (auto &item : _normSets) {
367 delete item.second;
368 }
369}
size_t size(const MatrixT &matrix)
retrieve the size of a square matrix
#define oocoutE(o, a)
Definition: RooMsgService.h:52
winID h TVirtualViewer3D TVirtualGLPainter p
RooAbsArg is the common abstract base class for objects that represent a value and a "shape" in RooFi...
Definition: RooAbsArg.h:71
void setStringAttribute(const Text_t *key, const Text_t *value)
Associate string 'value' to this object under key 'key'.
Definition: RooAbsArg.cxx:278
bool redirectServers(const RooAbsCollection &newServerList, bool mustReplaceAll=false, bool nameChange=false, bool isRecursionStep=false)
Replace all direct servers of this object with the new servers in newServerList.
Definition: RooAbsArg.cxx:999
const RefCountList_t & servers() const
List of all servers of this object.
Definition: RooAbsArg.h:198
virtual bool isDerived() const
Does value or shape of this arg depend on any other arg?
Definition: RooAbsArg.h:91
void setAttribute(const Text_t *name, bool value=true)
Set (default) or clear a named boolean attribute of this object.
Definition: RooAbsArg.cxx:246
virtual bool isFundamental() const
Is this object a fundamental type that can be added to a dataset? Fundamental-type subclasses overrid...
Definition: RooAbsArg.h:241
virtual std::unique_ptr< RooArgSet > fillNormSetForServer(RooArgSet const &normSet, RooAbsArg const &server) const
Fills a RooArgSet to be used as the normalization set for a server, given a normalization set for thi...
Definition: RooAbsArg.cxx:2460
void treeNodeServerList(RooAbsCollection *list, const RooAbsArg *arg=nullptr, bool doBranch=true, bool doLeaf=true, bool valueOnly=false, bool recurseNonDerived=false) const
Fill supplied list with nodes of the arg tree, following all server links, starting with ourself as t...
Definition: RooAbsArg.cxx:499
RooAbsCachedPdf is the abstract base class for p.d.f.s that need or want to cache their evaluate() ou...
RooAbsCollection is an abstract container object that can hold multiple RooAbsArg objects.
bool empty() const
virtual bool add(const RooAbsArg &var, bool silent=false)
Add the specified argument to list.
void sort(bool reverse=false)
Sort collection using std::sort and name comparison.
RooAbsArg * find(const char *name) const
Find object with given name in list.
RooAddition calculates the sum of a set of RooAbsReal terms, or when constructed with two sets,...
Definition: RooAddition.h:27
const RooArgList & list() const
Definition: RooAddition.h:42
RooArgList is a container object that can hold multiple RooAbsArg objects.
Definition: RooArgList.h:22
RooArgSet is a container object that can hold multiple RooAbsArg objects.
Definition: RooArgSet.h:56
bool containsInstance(const RooAbsArg &var) const override
Check if this exact instance is in this collection.
Definition: RooArgSet.h:170
RooConstraintSum calculates the sum of the -(log) likelihoods of a set of RooAbsPdfs that represent co...
To use as a key type for RooFit data maps and containers.
std::unique_ptr< RooAbsArg > _topNodeWrapper
NormalizationIntegralUnfolder(RooAbsArg const &topNode, RooArgSet const &normSet)
std::unordered_map< RooFit::Detail::DataKey, RooArgSet * > _normSets
RooProdPdf is an efficient implementation of a product of PDFs of the form.
Definition: RooProdPdf.h:33
const char * GetName() const override
Returns name of object.
Definition: TNamed.h:47
virtual const char * ClassName() const
Returns name of class to which the object belongs.
Definition: TObject.cxx:130
@ Minimization
Definition: RooGlobalFunc.h:63
TLine l
Definition: textangle.C:4