Logo ROOT  
Reference Guide
HistFactoryModelUtils.cxx
Go to the documentation of this file.
1/**
2 * \ingroup HistFactory
3 */
4
5// A set of utils for navigating HistFactory models
6#include <stdexcept>
7#include <typeinfo>
8
10#include "RooAbsPdf.h"
11#include "RooArgSet.h"
12#include "RooArgList.h"
13#include "RooSimultaneous.h"
14#include "RooCategory.h"
15#include "RooRealVar.h"
16#include "RooProdPdf.h"
17#include "TH1.h"
18
20
21namespace RooStats{
22namespace HistFactory{
23
24
25 std::string channelNameFromPdf( RooAbsPdf* channelPdf ) {
26 std::string channelPdfName = channelPdf->GetName();
27 std::string ChannelName = channelPdfName.substr(6, channelPdfName.size() );
28 return ChannelName;
29 }
30
32
33 bool verbose=false;
34
35 if(verbose) std::cout << "Getting the RooRealSumPdf for the channel: "
36 << sim_channel->GetName() << std::endl;
37
38 std::string channelPdfName = sim_channel->GetName();
39 std::string ChannelName = channelPdfName.substr(6, channelPdfName.size() );
40
41 // Now, get the RooRealSumPdf
42 // ie the channel WITHOUT constraints
43 std::string realSumPdfName = ChannelName + "_model";
44
45 RooAbsPdf* sum_pdf = nullptr;
46 bool FoundSumPdf=false;
47 for (auto *sum_pdf_arg : *sim_channel->getComponents()) {
48 std::string NodeClassName = sum_pdf_arg->ClassName();
49 if( NodeClassName == std::string("RooRealSumPdf") ) {
50 FoundSumPdf=true;
51 sum_pdf = static_cast<RooAbsPdf*>(sum_pdf_arg);
52 break;
53 }
54 }
55 if( ! FoundSumPdf ) {
56 if(verbose) {
57 std::cout << "Failed to find RooRealSumPdf for channel: " << sim_channel->GetName() << std::endl;
58 sim_channel->getComponents()->Print("V");
59 }
60 sum_pdf=nullptr;
61 //throw std::runtime_error("Failed to find RooRealSumPdf for channel");
62 }
63 else {
64 if(verbose) std::cout << "Found RooRealSumPdf: " << sum_pdf->GetName() << std::endl;
65 }
66
67 return sum_pdf;
68
69 }
70
71
72 void FactorizeHistFactoryPdf(const RooArgSet &observables, RooAbsPdf &pdf, RooArgList &obsTerms, RooArgList &constraints) {
73 // utility function to factorize constraint terms from a pdf
74 // (from G. Petrucciani)
75 const std::type_info & id = typeid(pdf);
76 if (id == typeid(RooProdPdf)) {
77 RooProdPdf *prod = dynamic_cast<RooProdPdf *>(&pdf);
78 RooArgList list(prod->pdfList());
79 for (int i = 0, n = list.getSize(); i < n; ++i) {
80 RooAbsPdf *pdfi = (RooAbsPdf *) list.at(i);
81 FactorizeHistFactoryPdf(observables, *pdfi, obsTerms, constraints);
82 }
83 } else if (id == typeid(RooSimultaneous)) { //|| id == typeid(RooSimultaneousOpt)) {
84 RooSimultaneous *sim = dynamic_cast<RooSimultaneous *>(&pdf);
85 std::unique_ptr<RooAbsCategoryLValue> cat{static_cast<RooAbsCategoryLValue *>(sim->indexCat().Clone())};
86 for (int ic = 0, nc = cat->numBins((const char *)0); ic < nc; ++ic) {
87 cat->setBin(ic);
88 FactorizeHistFactoryPdf(observables, *sim->getPdf(cat->getCurrentLabel()), obsTerms, constraints);
89 }
90 } else if (pdf.dependsOn(observables)) {
91 if (!obsTerms.contains(pdf)) obsTerms.add(pdf);
92 } else {
93 if (!constraints.contains(pdf)) constraints.add(pdf);
94 }
95 }
96
97
98 bool getStatUncertaintyFromChannel( RooAbsPdf* channel, ParamHistFunc*& paramfunc, RooArgList* gammaList ) {
99
100 bool verbose=false;
101
102 // Find the servers of this channel
103 bool FoundParamHistFunc=false;
104 for( auto *paramfunc_arg : *channel->getComponents() ) {
105 std::string NodeName = paramfunc_arg->GetName();
106 std::string NodeClassName = paramfunc_arg->ClassName();
107 if( NodeClassName != std::string("ParamHistFunc") ) continue;
108 if( NodeName.find("mc_stat_") != std::string::npos ) {
109 FoundParamHistFunc=true;
110 paramfunc = static_cast<ParamHistFunc*>(paramfunc_arg);
111 break;
112 }
113 }
114 if( ! FoundParamHistFunc || !paramfunc ) {
115 if(verbose) std::cout << "Failed to find ParamHistFunc for channel: " << channel->GetName() << std::endl;
116 return false;
117 }
118
119 // Now, get the set of gamma's
120 gammaList = (RooArgList*) &( paramfunc->paramList());
121 if(verbose) gammaList->Print("V");
122
123 return true;
124
125 }
126
127
128 void getDataValuesForObservables( std::map< std::string, std::vector<double> >& ChannelBinDataMap,
129 RooAbsData* data, RooAbsPdf* pdf ) {
130
131 bool verbose=false;
132
133 RooSimultaneous* simPdf = (RooSimultaneous*) pdf;
134
135 // get category label
136 RooCategory* cat = nullptr;
137 for (auto* temp : *data->get()) {
138 if( strcmp(temp->ClassName(),"RooCategory")==0){
139 cat = static_cast<RooCategory*>(temp);
140 break;
141 }
142 }
143 if(verbose) {
144 if(!cat) std::cout <<"didn't find category"<< std::endl;
145 else std::cout <<"found category"<< std::endl;
146 }
147
148 if (!cat) {
149 std::cerr <<"Category not found"<< std::endl;
150 return;
151 }
152
153 // split dataset
154 std::unique_ptr<TList> dataByCategory{data->split(*cat)};
155 if(verbose) dataByCategory->Print();
156 // note :
157 // RooAbsData* dataForChan = (RooAbsData*) dataByCategory->FindObject("");
158
159 // loop over channels
160 RooCategory* channelCat = (RooCategory*) (&simPdf->indexCat());
161 for (const auto& nameIdx : *channelCat) {
162
163 // Get pdf associated with state from simpdf
164 RooAbsPdf* pdftmp = simPdf->getPdf(nameIdx.first.c_str());
165
166 std::string ChannelName = pdftmp->GetName(); //tt->GetName();
167 if(verbose) std::cout << "Getting data for channel: " << ChannelName << std::endl;
168 ChannelBinDataMap[ ChannelName ] = std::vector<double>();
169
170 RooAbsData* dataForChan = (RooAbsData*) dataByCategory->FindObject(nameIdx.first.c_str());
171 if(verbose) dataForChan->Print();
172
173 // Generate observables defined by the pdf associated with this state
174 RooArgSet* obstmp = pdftmp->getObservables(*dataForChan->get()) ;
175 RooRealVar* obs = ((RooRealVar*)obstmp->first());
176 if(verbose) obs->Print();
177
178 //double expected = pdftmp->expectedEvents(*obstmp);
179
180 // set value to desired value (this is just an example)
181 // double obsVal = obs->getVal();
182 // set obs to desired value of observable
183 // obs->setVal( obsVal );
184 //double fracAtObsValue = pdftmp->getVal(*obstmp);
185
186 // get num events expected in bin for obsVal
187 // double nu = expected * fracAtObsValue;
188
189 // multidimensional way to get n
190 // credit goes to P. Hamilton
191 for (int i = 0; i < dataForChan->numEntries(); i++) {
192 const RooArgSet *tmpargs = dataForChan->get(i);
193 if (verbose)
194 tmpargs->Print();
195 const double n = dataForChan->weight();
196 if (verbose)
197 std::cout << "n" << i << " = " << n << std::endl;
198 ChannelBinDataMap[ChannelName].push_back(n);
199 }
200
201 } // End Loop Over Categories
202
203 dataByCategory->Delete();
204 }
205
206
208 RooAbsReal*& pois_nom, RooRealVar*& tau ) {
209 // Given a set of constraint terms,
210 // find the poisson constraint for the
211 // given gamma and return the mean
212 // as well as the 'tau' parameter
213
214 bool verbose=false;
215
216 // To get the constraint term, loop over all constraint terms
217 // and look for the gamma_stat name as well as '_constraint'
218
219 bool FoundConstraintTerm=false;
220 RooAbsPdf* constraintTerm=nullptr;
221 for (auto *term_constr : *constraints) {
222 std::string TermName = term_constr->GetName();
223 if( term_constr->dependsOn( *gamma_stat) ) {
224 if( TermName.find("_constraint")!=std::string::npos ) {
225 FoundConstraintTerm=true;
226 constraintTerm = static_cast<RooAbsPdf*>(term_constr);
227 break;
228 }
229 }
230 }
231 if( FoundConstraintTerm==false ) {
232 std::cout << "Error: Couldn't find constraint term for parameter: " << gamma_stat->GetName()
233 << " among constraints: " << constraints->GetName() << std::endl;
234 constraints->Print("V");
235 throw std::runtime_error("Failed to find Gamma ConstraintTerm");
236 return -1;
237 }
238
239 /*
240 RooAbsPdf* constraintTerm = (RooAbsPdf*) constraints->find( constraintTermName.c_str() );
241 if( constraintTerm == nullptr ) {
242 std::cout << "Error: Couldn't find constraint term: " << constraintTermName
243 << " for parameter: " << gamma_stat->GetName()
244 << std::endl;
245 throw std::runtime_error("Failed to find Gamma ConstraintTerm");
246 return -1;
247 }
248 */
249
250 // Find the "data" of the poisson term
251 // This is the nominal value
252 bool FoundNomMean=false;
253 for (RooAbsArg * term_pois : constraintTerm->servers()) {
254 std::string serverName = term_pois->GetName();
255 //std::cout << "Checking Server: " << serverName << std::endl;
256 if( serverName.find("nom_")!=std::string::npos ) {
257 FoundNomMean = true;
258 pois_nom = (RooRealVar*) term_pois;
259 }
260 }
261 if( !FoundNomMean || !pois_nom ) {
262 std::cout << "Error: Did not find Nominal Pois Mean parameter in gamma constraint term PoissonMean: "
263 << constraintTerm->GetName() << std::endl;
264 throw std::runtime_error("Failed to find Nom Pois Mean");
265 }
266 else {
267 if(verbose) std::cout << "Found Poisson 'data' term: " << pois_nom->GetName() << std::endl;
268 }
269
270 // Taking the constraint term (a Poisson), find
271 // the "mean" which is the product: gamma*tau
272 // Then, from that mean, find tau
273 RooAbsArg * pois_mean_arg = nullptr;
274 for (RooAbsArg * arg : constraintTerm->servers()) {
275 if( arg->dependsOn( *gamma_stat ) ) {
276 pois_mean_arg = arg;
277 break;
278 }
279 }
280 if( !pois_mean_arg ) {
281 std::cout << "Error: Did not find PoissonMean parameter in gamma constraint term: "
282 << constraintTerm->GetName() << std::endl;
283 throw std::runtime_error("Failed to find PoissonMean");
284 return -1;
285 }
286 else {
287 if(verbose) std::cout << "Found Poisson 'mean' term: " << pois_mean_arg->GetName() << std::endl;
288 }
289
290 bool FoundTau=false;
291 for(RooAbsArg* term_in_product : pois_mean_arg->servers()) {
292 std::string serverName = term_in_product->GetName();
293 //std::cout << "Checking Server: " << serverName << std::endl;
294 if( serverName.find("_tau")!=std::string::npos ) {
295 FoundTau = true;
296 tau = (RooRealVar*) term_in_product;
297 }
298 }
299 if( !FoundTau || !tau ) {
300 std::cout << "Error: Did not find Tau parameter in gamma constraint term PoissonMean: "
301 << pois_mean_arg->GetName() << std::endl;
302 throw std::runtime_error("Failed to find Tau");
303 }
304 else {
305 if(verbose) std::cout << "Found Poisson 'tau' term: " << tau->GetName() << std::endl;
306 }
307
308 return 0;
309
310 }
311
312
313
314} // close HistFactory namespace
315} // close RooStats namespace
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void data
A class which maps the current values of a RooRealVar (or a set of RooRealVars) to one of a number of...
Definition: ParamHistFunc.h:24
const RooArgList & paramList() const
Definition: ParamHistFunc.h:34
RooAbsArg is the common abstract base class for objects that represent a value and a "shape" in RooFi...
Definition: RooAbsArg.h:71
void Print(Option_t *options=nullptr) const override
Print the object to the defaultPrintStream().
Definition: RooAbsArg.h:321
bool dependsOn(const RooAbsCollection &serverList, const RooAbsArg *ignoreArg=nullptr, bool valueOnly=false) const
Test whether we depend on (ie, are served by) any object in the specified collection.
Definition: RooAbsArg.cxx:805
RooArgSet * getObservables(const RooArgSet &set, bool valueOnly=true) const
Given a set of possible observables, return the observables that this PDF depends on.
Definition: RooAbsArg.h:293
const RefCountList_t & servers() const
List of all servers of this object.
Definition: RooAbsArg.h:198
TObject * Clone(const char *newname=nullptr) const override
Make a clone of an object using the Streamer facility.
Definition: RooAbsArg.h:83
RooArgSet * getComponents() const
Create a RooArgSet with all components (branch nodes) of the expression tree headed by this object.
Definition: RooAbsArg.cxx:754
RooAbsCategoryLValue is the common abstract base class for objects that represent a discrete value th...
void setBin(Int_t ibin, const char *rangeName=nullptr) override
Set category to i-th fit bin, which is the i-th registered state.
bool contains(const RooAbsArg &var) const
Check if collection contains an argument with the same name as var.
Int_t getSize() const
Return the number of elements in the collection.
const char * GetName() const override
Returns name of object.
virtual bool add(const RooAbsArg &var, bool silent=false)
Add the specified argument to list.
RooAbsArg * first() const
void Print(Option_t *options=nullptr) const override
This method must be overridden when a class wants to print itself.
RooAbsData is the common abstract base class for binned and unbinned datasets.
Definition: RooAbsData.h:62
virtual double weight() const =0
virtual const RooArgSet * get() const
Definition: RooAbsData.h:106
void Print(Option_t *options=nullptr) const override
Print TNamed name and title.
Definition: RooAbsData.h:239
virtual Int_t numEntries() const
Return number of entries in dataset, i.e., count unweighted entries.
Definition: RooAbsData.cxx:374
RooAbsReal is the common abstract base class for objects that represent a real value and implements f...
Definition: RooAbsReal.h:62
RooArgList is a container object that can hold multiple RooAbsArg objects.
Definition: RooArgList.h:22
RooAbsArg * at(Int_t idx) const
Return object at given index, or nullptr if index is out of range.
Definition: RooArgList.h:110
RooArgSet is a container object that can hold multiple RooAbsArg objects.
Definition: RooArgSet.h:56
RooCategory is an object to represent discrete states.
Definition: RooCategory.h:28
RooProdPdf is an efficient implementation of a product of PDFs of the form.
Definition: RooProdPdf.h:33
const RooArgList & pdfList() const
Definition: RooProdPdf.h:68
RooRealVar represents a variable that can be changed from the outside.
Definition: RooRealVar.h:40
RooSimultaneous facilitates simultaneous fitting of multiple PDFs to subsets of a given dataset.
const RooAbsCategoryLValue & indexCat() const
RooAbsPdf * getPdf(const char *catName) const
Return the p.d.f associated with the given index category name.
const char * GetName() const override
Returns name of object.
Definition: TNamed.h:47
virtual const char * ClassName() const
Returns name of class to which the object belongs.
Definition: TObject.cxx:207
virtual TObject * FindObject(const char *name) const
Must be redefined in derived classes.
Definition: TObject.cxx:404
bool getStatUncertaintyFromChannel(RooAbsPdf *channel, ParamHistFunc *&paramfunc, RooArgList *gammaList)
void FactorizeHistFactoryPdf(const RooArgSet &, RooAbsPdf &, RooArgList &, RooArgList &)
void getDataValuesForObservables(std::map< std::string, std::vector< double > > &ChannelBinDataMap, RooAbsData *data, RooAbsPdf *simPdf)
RooAbsPdf * getSumPdfFromChannel(RooAbsPdf *channel)
int getStatUncertaintyConstraintTerm(RooArgList *constraints, RooRealVar *gamma_stat, RooAbsReal *&pois_mean, RooRealVar *&tau)
std::string channelNameFromPdf(RooAbsPdf *channelPdf)
const Int_t n
Definition: legend1.C:16
@ HistFactory
Definition: RooGlobalFunc.h:63
Namespace for the RooStats classes.
Definition: Asimov.h:19