Logo ROOT  
Reference Guide
 
Loading...
Searching...
No Matches
ConstantTermsOptimizer.cxx
Go to the documentation of this file.
1/*
2 * Project: RooFit
3 * Authors:
4 * PB, Patrick Bos, Netherlands eScience Center, p.bos@esciencecenter.nl
5 *
6 * Copyright (c) 2021, CERN
7 *
8 * Redistribution and use in source and binary forms,
9 * with or without modification, are permitted according to the terms
10 * listed in LICENSE (http://roofit.sourceforge.net/license.txt)
11 */
12
14
15#include <RooMsgService.h>
16#include <RooVectorDataStore.h> // complete type for dynamic cast
17#include <RooAbsReal.h>
18#include <RooArgSet.h>
19#include <RooAbsData.h>
20
21namespace RooFit {
22namespace TestStatistics {
23
24/** \class ConstantTermsOptimizer
25 *
26 * \brief Analyzes a function given a dataset/observables for constant terms and caches those in the dataset
27 *
28 * This optimizer should be used on a consistent combination of function (usually a pdf) and a dataset with observables.
29 * It then analyzes the function to find parts that can be precalculated because they are constant given the set of
30 * observables. These are cached inside the dataset and used in subsequent evaluations of the function on that dataset.
31 * The typical use case for this is inside likelihood minimization where many calls of the same pdf/dataset combination
32 * are made. \p norm_set must provide the normalization set of the function, which would typically be the set of
33 * observables in the dataset; this is used to make sure all object caches are created before analysis by evaluating the
34 * function on this set at the beginning of enableConstantTermsOptimization.
35 */
36
38{
 // Extra observables handed to RooAbsData::optimizeReadingWithCaching() by enableConstantTermsOptimization()
 // and optimizeCaching() in this file (presumably observables whose reading must stay enabled even when the
 // function itself does not use them -- confirm against RooAbsData). This implementation always returns an
 // empty set.
39 // TODO: the RooAbsOptTestStatistic::requiredExtraObservables() function this code was copied
40 // from was overridden only in RooXYChi2Var; implement different options when necessary
41 return RooArgSet();
42}
43
45 RooAbsData *dataset, bool applyTrackingOpt)
46{
47 // Trigger create of all object caches now in nodes that have deferred object creation
48 // so that cache contents can be processed immediately
49 function->getVal(norm_set);
50
51 // Apply tracking optimization here. Default strategy is to track components
52 // of RooAddPdfs and RooRealSumPdfs. If these components are a RooProdPdf
53 // or a RooProduct respectively, track the components of these products instead
54 // of the product term
55 RooArgSet trackNodes;
56
57 // Add safety check here - applyTrackingOpt will only be applied if present
58 // dataset is constructed in terms of a RooVectorDataStore
59 if (applyTrackingOpt) {
60 if (!dynamic_cast<RooVectorDataStore *>(dataset->store())) {
61 oocoutW(nullptr, Optimization)
62 << "enableConstantTermsOptimization(function: " << function->GetName()
63 << ", dataset: " << dataset->GetName()
64 << ") WARNING Cache-and-track optimization (Optimize level 2) is only available for datasets"
65 << " implemented in terms of RooVectorDataStore - ignoring this option for current dataset" << std::endl;
66 applyTrackingOpt = false;
67 }
68 }
69
70 if (applyTrackingOpt) {
71 RooArgSet branches;
72 function->branchNodeServerList(&branches);
73 for (const auto arg : branches) {
74 arg->setCacheAndTrackHints(trackNodes);
75 }
76 // Do not set CacheAndTrack on constant expressions
77 std::unique_ptr<RooArgSet> constNodes{static_cast<RooArgSet *>(trackNodes.selectByAttrib("Constant", true))};
78 trackNodes.remove(*constNodes);
79
80 // Set CacheAndTrack flag on all remaining nodes
81 trackNodes.setAttribAll("CacheAndTrack", true);
82 }
83
84 // Find all nodes that depend exclusively on constant parameters
85 RooArgSet cached_nodes;
86
87 function->findConstantNodes(*dataset->get(), cached_nodes);
88
89 // Cache constant nodes with dataset - also cache entries corresponding to zero-weights in data when using
90 // BinnedLikelihood
91 // NOTE: we pass nullptr as cache-owner here, because we don't use the cacheOwner() anywhere in TestStatistics
92 // TODO: make sure this (nullptr) is always correct
93 dataset->cacheArgs(nullptr, cached_nodes, norm_set, !function->getAttribute("BinnedLikelihood"));
94
95 // Put all cached nodes in AClean value caching mode so that their evaluate() is never called
96 for (const auto cacheArg : cached_nodes) {
97 cacheArg->setOperMode(RooAbsArg::AClean);
98 }
99
100 std::unique_ptr<RooArgSet> constNodes{
101 static_cast<RooArgSet *>(cached_nodes.selectByAttrib("ConstantExpressionCached", true))};
102 RooArgSet actualTrackNodes(cached_nodes);
103 actualTrackNodes.remove(*constNodes);
104 if (!constNodes->empty()) {
105 if (constNodes->size() < 20) {
106 oocoutI(nullptr, Minimization)
107 << " The following expressions have been identified as constant and will be precalculated and cached: "
108 << *constNodes << std::endl;
109 } else {
110 oocoutI(nullptr, Minimization)
111 << " A total of " << constNodes->size()
112 << " expressions have been identified as constant and will be precalculated and cached." << std::endl;
113 }
114 }
115 if (!actualTrackNodes.empty()) {
116 if (actualTrackNodes.size() < 20) {
117 oocoutI(nullptr, Minimization) << " The following expressions will be evaluated in cache-and-track mode: "
118 << actualTrackNodes << std::endl;
119 } else {
120 oocoutI(nullptr, Minimization) << " A total of " << constNodes->size()
121 << " expressions will be evaluated in cache-and-track-mode." << std::endl;
122 }
123 }
124
125 // Disable reading of observables that are no longer used
126 dataset->optimizeReadingWithCaching(*function, cached_nodes, requiredExtraObservables());
127}
128
130 RooAbsData *dataset, RooArgSet *observables)
131{
132 // Delete the cache
133 dataset->resetCache();
134
135 // Reactivate all tree branches
136 dataset->setArgStatus(*dataset->get(), true);
137
138 // Reset all nodes to ADirty
139 optimizeCaching(function, norm_set, dataset, observables);
140
141 // Disable propagation of dirty state flags for observables
142 dataset->setDirtyProp(false);
143
144 // _cachedNodes.removeAll();
145
146 // _optimized = false;
147}
148
150 RooArgSet *observables)
151{
152 // Trigger create of all object caches now in nodes that have deferred object creation
153 // so that cache contents can be processed immediately
154 function->getVal(norm_set);
155
156 // Set value caching mode for all nodes that depend on any of the observables to ADirty
157 std::unique_ptr<RooArgSet> ownedObservables;
158 if (observables == nullptr) {
159 ownedObservables = std::unique_ptr<RooArgSet>{function->getObservables(dataset)};
160 observables = ownedObservables.get();
161 }
162 function->optimizeCacheMode(*observables);
163
164 // Disable propagation of dirty state flags for observables
165 dataset->setDirtyProp(false);
166
167 // Disable reading of observables that are not used
168 dataset->optimizeReadingWithCaching(*function, RooArgSet(), requiredExtraObservables());
169}
170
171} // namespace TestStatistics
172} // namespace RooFit
#define oocoutW(o, a)
#define oocoutI(o, a)
RooAbsCollection * selectByAttrib(const char *name, bool value) const
Create a subset of the current collection, consisting only of those elements with the specified attri...
virtual bool remove(const RooAbsArg &var, bool silent=false, bool matchByNameOnly=false)
Remove the specified argument from our list.
Storage_t const & get() const
Const access to the underlying stl container.
void setAttribAll(const Text_t *name, bool value=true)
Set given attribute in each element of the collection by calling each elements setAttribute() functio...
Storage_t::size_type size() const
Abstract base class for binned and unbinned datasets.
Definition RooAbsData.h:57
Abstract base class for objects that represent a real value and implements functionality common to al...
Definition RooAbsReal.h:59
RooArgSet is a container object that can hold multiple RooAbsArg objects.
Definition RooArgSet.h:24
Uses std::vector to store data columns.
The namespace RooFit contains mostly switches that change the behaviour of functions of PDFs (or othe...
Definition CodegenImpl.h:64
static void disableConstantTermsOptimization(RooAbsReal *function, RooArgSet *norm_set, RooAbsData *dataset, RooArgSet *observables=nullptr)
static void optimizeCaching(RooAbsReal *function, RooArgSet *norm_set, RooAbsData *dataset, RooArgSet *observables=nullptr)
static void enableConstantTermsOptimization(RooAbsReal *function, RooArgSet *norm_set, RooAbsData *dataset, bool applyTrackingOpt)