Logo ROOT  
Reference Guide
 
Loading...
Searching...
No Matches
ConstantTermsOptimizer.cxx
Go to the documentation of this file.
1/*
2 * Project: RooFit
3 * Authors:
4 * PB, Patrick Bos, Netherlands eScience Center, p.bos@esciencecenter.nl
5 *
6 * Copyright (c) 2021, CERN
7 *
8 * Redistribution and use in source and binary forms,
9 * with or without modification, are permitted according to the terms
10 * listed in LICENSE (http://roofit.sourceforge.net/license.txt)
11 */
12
14
15#include <RooMsgService.h>
16#include <RooVectorDataStore.h> // complete type for dynamic cast
17#include <RooAbsReal.h>
18#include <RooArgSet.h>
19#include <RooAbsData.h>
20
21namespace RooFit {
22namespace TestStatistics {
23
24/** \class ConstantTermsOptimizer
25 *
26 * \brief Analyzes a function given a dataset/observables for constant terms and caches those in the dataset
27 *
28 * This optimizer should be used on a consistent combination of function (usually a pdf) and a dataset with observables.
29 * It then analyzes the function to find parts that can be precalculated because they are constant given the set of
30 * observables. These are cached inside the dataset and used in subsequent evaluations of the function on that dataset.
31 * The typical use case for this is inside likelihood minimization where many calls of the same pdf/dataset combination
32 * are made. \p norm_set must provide the normalization set of the function, which would typically be the set of
33 * observables in the dataset; this is used to make sure all object caches are created before analysis by evaluating the
34 * function on this set at the beginning of enableConstantTermsOptimization.
35 */
36
38{
39 // TODO: the RooAbsOptTestStatistics::requiredExtraObservables() call this code was copied
40 // from was overloaded for RooXYChi2Var only; implement different options when necessary
41 return RooArgSet();
42}
43
45 RooAbsData *dataset, bool applyTrackingOpt)
46{
47 // Trigger create of all object caches now in nodes that have deferred object creation
48 // so that cache contents can be processed immediately
49 function->getVal(norm_set);
50
51 // Apply tracking optimization here. Default strategy is to track components
52 // of RooAddPdfs and RooRealSumPdfs. If these components are a RooProdPdf
53 // or a RooProduct respectively, track the components of these products instead
54 // of the product term
55 RooArgSet trackNodes;
56
57 // Add safety check here - applyTrackingOpt will only be applied if present
58 // dataset is constructed in terms of a RooVectorDataStore
59 if (applyTrackingOpt) {
60 if (!dynamic_cast<RooVectorDataStore *>(dataset->store())) {
61 oocoutW(nullptr, Optimization)
62 << "enableConstantTermsOptimization(function: " << function->GetName()
63 << ", dataset: " << dataset->GetName()
64 << ") WARNING Cache-and-track optimization (Optimize level 2) is only available for datasets"
65 << " implemented in terms of RooVectorDataStore - ignoring this option for current dataset" << std::endl;
66 applyTrackingOpt = false;
67 }
68 }
69
70 if (applyTrackingOpt) {
71 RooArgSet branches;
72 function->branchNodeServerList(&branches);
73 for (const auto arg : branches) {
74 arg->setCacheAndTrackHints(trackNodes);
75 }
76 // Do not set CacheAndTrack on constant expressions
77 std::unique_ptr<RooArgSet> constNodes{trackNodes.selectByAttrib("Constant", true)};
78 trackNodes.remove(*constNodes);
79
80 // Set CacheAndTrack flag on all remaining nodes
81 trackNodes.setAttribAll("CacheAndTrack", true);
82 }
83
84 // Find all nodes that depend exclusively on constant parameters
85 RooArgSet cached_nodes;
86
87 function->findConstantNodes(*dataset->get(), cached_nodes);
88
89 // Cache constant nodes with dataset - also cache entries corresponding to zero-weights in data when using
90 // BinnedLikelihood
91 // NOTE: we pass nullptr as cache-owner here, because we don't use the cacheOwner() anywhere in TestStatistics
92 // TODO: make sure this (nullptr) is always correct
93 dataset->cacheArgs(nullptr, cached_nodes, norm_set, !function->getAttribute("BinnedLikelihood"));
94
95 // Put all cached nodes in AClean value caching mode so that their evaluate() is never called
96 for (const auto cacheArg : cached_nodes) {
97 cacheArg->setOperMode(RooAbsArg::AClean);
98 }
99
100 std::unique_ptr<RooArgSet> constNodes{cached_nodes.selectByAttrib("ConstantExpressionCached", true)};
101 RooArgSet actualTrackNodes(cached_nodes);
102 actualTrackNodes.remove(*constNodes);
103 if (!constNodes->empty()) {
104 if (constNodes->size() < 20) {
105 oocoutI(nullptr, Minimization)
106 << " The following expressions have been identified as constant and will be precalculated and cached: "
107 << *constNodes << std::endl;
108 } else {
109 oocoutI(nullptr, Minimization)
110 << " A total of " << constNodes->size()
111 << " expressions have been identified as constant and will be precalculated and cached." << std::endl;
112 }
113 }
114 if (!actualTrackNodes.empty()) {
115 if (actualTrackNodes.size() < 20) {
116 oocoutI(nullptr, Minimization) << " The following expressions will be evaluated in cache-and-track mode: "
117 << actualTrackNodes << std::endl;
118 } else {
119 oocoutI(nullptr, Minimization) << " A total of " << constNodes->size()
120 << " expressions will be evaluated in cache-and-track-mode." << std::endl;
121 }
122 }
123
124 // Disable reading of observables that are no longer used
125 dataset->optimizeReadingWithCaching(*function, cached_nodes, requiredExtraObservables());
126}
127
129 RooAbsData *dataset, RooArgSet *observables)
130{
131 // Delete the cache
132 dataset->resetCache();
133
134 // Reactivate all tree branches
135 dataset->setArgStatus(*dataset->get(), true);
136
137 // Reset all nodes to ADirty
138 optimizeCaching(function, norm_set, dataset, observables);
139
140 // Disable propagation of dirty state flags for observables
141 dataset->setDirtyProp(false);
142
143 // _cachedNodes.removeAll();
144
145 // _optimized = false;
146}
147
149 RooArgSet *observables)
150{
151 // Trigger create of all object caches now in nodes that have deferred object creation
152 // so that cache contents can be processed immediately
153 function->getVal(norm_set);
154
155 // Set value caching mode for all nodes that depend on any of the observables to ADirty
156 std::unique_ptr<RooArgSet> ownedObservables;
157 if (observables == nullptr) {
158 ownedObservables = std::unique_ptr<RooArgSet>{function->getObservables(dataset)};
159 observables = ownedObservables.get();
160 }
161 function->optimizeCacheMode(*observables);
162
163 // Disable propagation of dirty state flags for observables
164 dataset->setDirtyProp(false);
165
166 // Disable reading of observables that are not used
167 dataset->optimizeReadingWithCaching(*function, RooArgSet(), requiredExtraObservables());
168}
169
170} // namespace TestStatistics
171} // namespace RooFit
#define oocoutW(o, a)
#define oocoutI(o, a)
virtual bool remove(const RooAbsArg &var, bool silent=false, bool matchByNameOnly=false)
Remove the specified argument from our list.
Storage_t const & get() const
Const access to the underlying stl container.
void setAttribAll(const Text_t *name, bool value=true)
Set given attribute in each element of the collection by calling each element's setAttribute() function.
Storage_t::size_type size() const
Abstract base class for binned and unbinned datasets.
Definition RooAbsData.h:57
Abstract base class for objects that represent a real value and implements functionality common to all real-valued objects.
Definition RooAbsReal.h:59
RooArgSet is a container object that can hold multiple RooAbsArg objects.
Definition RooArgSet.h:24
RooArgSet * selectByAttrib(const char *name, bool value) const
Use RooAbsCollection::selectByAttrib(), but return as RooArgSet.
Definition RooArgSet.h:144
Uses std::vector to store data columns.
The namespace RooFit contains mostly switches that change the behaviour of functions of PDFs (or othe...
Definition CodegenImpl.h:64
static void disableConstantTermsOptimization(RooAbsReal *function, RooArgSet *norm_set, RooAbsData *dataset, RooArgSet *observables=nullptr)
static void optimizeCaching(RooAbsReal *function, RooArgSet *norm_set, RooAbsData *dataset, RooArgSet *observables=nullptr)
static void enableConstantTermsOptimization(RooAbsReal *function, RooArgSet *norm_set, RooAbsData *dataset, bool applyTrackingOpt)