Logo ROOT  
Reference Guide
 
Loading...
Searching...
No Matches
HistFactoryModelUtils.cxx
Go to the documentation of this file.
1/**
2 * \ingroup HistFactory
3 */
4
5// A set of utils for navigating HistFactory models
6#include <stdexcept>
7#include <typeinfo>
8
10#include "RooAbsPdf.h"
11#include "RooArgSet.h"
12#include "RooArgList.h"
13#include "RooSimultaneous.h"
14#include "RooCategory.h"
15#include "RooRealVar.h"
16#include "RooProdPdf.h"
17#include "TH1.h"
18
20
21namespace RooStats{
22namespace HistFactory{
23
24
// channelNameFromPdf: the signature line (source line 25) is not shown in this
// listing; the declaration at the end of this page reads
// `std::string channelNameFromPdf(RooAbsPdf *channelPdf)`.
//
// Returns channelPdf's name with its first 6 characters removed.
// NOTE(review): presumably this strips a fixed-length naming prefix added when
// the channel pdfs are built — confirm against the model builder.
// std::string::substr throws std::out_of_range if the name is shorter than
// 6 characters.
26 std::string channelPdfName = channelPdf->GetName();
27 std::string ChannelName = channelPdfName.substr(6, channelPdfName.size() );
28 return ChannelName;
29 }
30
32
// getSumPdfFromChannel: the signature line (source line 31) is not shown in
// this listing; the declaration at the end of this page reads
// `RooAbsPdf * getSumPdfFromChannel(RooAbsPdf *channel)`, while the body below
// refers to the parameter as `sim_channel`.
//
// Returns the first component of sim_channel (in getComponents() order) whose
// ClassName() is exactly "RooRealSumPdf" (subclasses do not match), or nullptr
// if there is none. Not finding one is not an error here: the throw below is
// commented out and the caller receives nullptr.
33 bool verbose=false;
34
35 if(verbose) std::cout << "Getting the RooRealSumPdf for the channel: "
36 << sim_channel->GetName() << std::endl;
37
// NOTE(review): channelPdfName, ChannelName and realSumPdfName are computed but
// never used by the search below, which matches on class name only. The substr
// call still throws std::out_of_range for names shorter than 6 characters.
38 std::string channelPdfName = sim_channel->GetName();
39 std::string ChannelName = channelPdfName.substr(6, channelPdfName.size() );
40
41 // Now, get the RooRealSumPdf
42 // ie the channel WITHOUT constraints
43 std::string realSumPdfName = ChannelName + "_model";
44
45 RooAbsPdf* sum_pdf = nullptr;
46 bool FoundSumPdf=false;
47 std::unique_ptr<RooArgSet> components{sim_channel->getComponents()};
48 for (auto *sum_pdf_arg : *components) {
49 std::string NodeClassName = sum_pdf_arg->ClassName();
50 if( NodeClassName == std::string("RooRealSumPdf") ) {
51 FoundSumPdf=true;
52 sum_pdf = static_cast<RooAbsPdf*>(sum_pdf_arg);
53 break;
54 }
55 }
56 if( ! FoundSumPdf ) {
57 if(verbose) {
58 std::cout << "Failed to find RooRealSumPdf for channel: " << sim_channel->GetName() << std::endl;
// NOTE(review): getComponents() creates a new set (see its description at the
// end of this page) that is never deleted here; `components->Print("V")` would
// print the same set without leaking. Only reachable when verbose is true,
// and verbose is hard-coded to false above.
59 sim_channel->getComponents()->Print("V");
60 }
61 sum_pdf=nullptr;
62 //throw std::runtime_error("Failed to find RooRealSumPdf for channel");
63 }
64 else {
65 if(verbose) std::cout << "Found RooRealSumPdf: " << sum_pdf->GetName() << std::endl;
66 }
67
68 return sum_pdf;
69
70 }
71
72
73 void FactorizeHistFactoryPdf(const RooArgSet &observables, RooAbsPdf &pdf, RooArgList &obsTerms, RooArgList &constraints) {
74 // utility function to factorize constraint terms from a pdf
75 // (from G. Petrucciani)
76 const std::type_info & id = typeid(pdf);
77 if (id == typeid(RooProdPdf)) {
78 RooProdPdf *prod = dynamic_cast<RooProdPdf *>(&pdf);
79 RooArgList list(prod->pdfList());
80 for (int i = 0, n = list.size(); i < n; ++i) {
81 RooAbsPdf *pdfi = static_cast<RooAbsPdf *>(list.at(i));
82 FactorizeHistFactoryPdf(observables, *pdfi, obsTerms, constraints);
83 }
84 } else if (id == typeid(RooSimultaneous)) { //|| id == typeid(RooSimultaneousOpt)) {
85 RooSimultaneous *sim = dynamic_cast<RooSimultaneous *>(&pdf);
86 std::unique_ptr<RooAbsCategoryLValue> cat{static_cast<RooAbsCategoryLValue *>(sim->indexCat().Clone())};
87 for (int ic = 0, nc = cat->numBins((const char *)nullptr); ic < nc; ++ic) {
88 cat->setBin(ic);
89 FactorizeHistFactoryPdf(observables, *sim->getPdf(cat->getCurrentLabel()), obsTerms, constraints);
90 }
91 } else if (pdf.dependsOn(observables)) {
92 if (!obsTerms.contains(pdf)) obsTerms.add(pdf);
93 } else {
94 if (!constraints.contains(pdf)) constraints.add(pdf);
95 }
96 }
97
98
100
// getStatUncertaintyFromChannel: the signature line (source line 99) is not
// shown in this listing; the declaration at the end of this page reads
// `bool getStatUncertaintyFromChannel(RooAbsPdf *channel, ParamHistFunc *&paramfunc, RooArgList *gammaList)`.
//
// Searches the components of `channel` for the first one whose ClassName() is
// exactly "ParamHistFunc" and whose name contains "mc_stat_", stores it in
// `paramfunc` (a reference parameter, so the caller sees it) and returns true.
// Returns false if no such component is found.
//
// NOTE(review): source line 111 is missing from this listing. As shown,
// FoundParamHistFunc is never set to true, so the check after the loop would
// always return false; presumably line 111 sets it — confirm against the
// original file.
101 bool verbose=false;
102
103 // Find the servers of this channel
104 bool FoundParamHistFunc=false;
105 std::unique_ptr<RooArgSet> components{channel->getComponents()};
106 for( auto *paramfunc_arg : *components) {
107 std::string NodeName = paramfunc_arg->GetName();
108 std::string NodeClassName = paramfunc_arg->ClassName();
109 if( NodeClassName != std::string("ParamHistFunc") ) continue;
110 if( NodeName.find("mc_stat_") != std::string::npos ) {
112 paramfunc = static_cast<ParamHistFunc*>(paramfunc_arg);
113 break;
114 }
115 }
116 if( ! FoundParamHistFunc || !paramfunc ) {
117 if(verbose) std::cout << "Failed to find ParamHistFunc for channel: " << channel->GetName() << std::endl;
118 return false;
119 }
120
121 // Now, get the set of gamma's
// NOTE(review): per the declaration quoted at the top of this function,
// gammaList is a RooArgList* passed by value, so this assignment only changes
// the local copy and the caller never receives the list. Callers can use
// paramfunc->paramList() instead.
122 gammaList = const_cast<RooArgList*>(&( paramfunc->paramList()));
123 if(verbose) gammaList->Print("V");
124
125 return true;
126
127 }
128
129
130 void getDataValuesForObservables( std::map< std::string, std::vector<double> >& ChannelBinDataMap,
131 RooAbsData* data, RooAbsPdf* pdf ) {
132
133 bool verbose=false;
134
135 RooSimultaneous* simPdf = static_cast<RooSimultaneous*>(pdf);
136
137 // get category label
138 RooCategory* cat = nullptr;
139 for (auto* temp : *data->get()) {
140 if( strcmp(temp->ClassName(),"RooCategory")==0){
141 cat = static_cast<RooCategory*>(temp);
142 break;
143 }
144 }
145 if(verbose) {
146 if(!cat) std::cout <<"didn't find category"<< std::endl;
147 else std::cout <<"found category"<< std::endl;
148 }
149
150 if (!cat) {
151 std::cerr <<"Category not found"<< std::endl;
152 return;
153 }
154
155 // split dataset
156 std::unique_ptr<TList> dataByCategory{data->split(*cat)};
157 if(verbose) dataByCategory->Print();
158 // note :
159 // RooAbsData* dataForChan = (RooAbsData*) dataByCategory->FindObject("");
160
161 // loop over channels
162 auto channelCat = static_cast<RooCategory const*>(&simPdf->indexCat());
163 for (const auto& nameIdx : *channelCat) {
164
165 // Get pdf associated with state from simpdf
166 RooAbsPdf* pdftmp = simPdf->getPdf(nameIdx.first.c_str());
167
168 std::string ChannelName = pdftmp->GetName(); //tt->GetName();
169 if(verbose) std::cout << "Getting data for channel: " << ChannelName << std::endl;
170 ChannelBinDataMap[ ChannelName ] = std::vector<double>();
171
172 RooAbsData* dataForChan = static_cast<RooAbsData*>(dataByCategory->FindObject(nameIdx.first.c_str()));
173 if(verbose) dataForChan->Print();
174
175 // Generate observables defined by the pdf associated with this state
176 std::unique_ptr<RooArgSet> obstmp{pdftmp->getObservables(*dataForChan->get())};
177 RooRealVar* obs = (static_cast<RooRealVar*>(obstmp->first()));
178 if(verbose) obs->Print();
179
180 //double expected = pdftmp->expectedEvents(*obstmp);
181
182 // set value to desired value (this is just an example)
183 // double obsVal = obs->getVal();
184 // set obs to desired value of observable
185 // obs->setVal( obsVal );
186 //double fracAtObsValue = pdftmp->getVal(*obstmp);
187
188 // get num events expected in bin for obsVal
189 // double nu = expected * fracAtObsValue;
190
191 // multidimensional way to get n
192 // credit goes to P. Hamilton
193 for (int i = 0; i < dataForChan->numEntries(); i++) {
194 const RooArgSet *tmpargs = dataForChan->get(i);
195 if (verbose)
196 tmpargs->Print();
197 const double n = dataForChan->weight();
198 if (verbose)
199 std::cout << "n" << i << " = " << n << std::endl;
200 ChannelBinDataMap[ChannelName].push_back(n);
201 }
202
203 } // End Loop Over Categories
204
205 dataByCategory->Delete();
206 }
207
208
210 RooAbsReal*& pois_nom, RooRealVar*& tau ) {
// getStatUncertaintyConstraintTerm: the first line of the signature (source
// line 209) is not shown in this listing; the declaration at the end of this
// page reads `int getStatUncertaintyConstraintTerm(RooArgList *constraints,
// RooRealVar *gamma_stat, RooAbsReal *&pois_mean, RooRealVar *&tau)` (this
// definition names the third parameter pois_nom).
//
// Steps:
//  1. constraint term = first element of `constraints` that depends on
//     gamma_stat and whose name contains "_constraint";
//  2. pois_nom = last server of that term whose name contains "nom_";
//  3. mean term = first server of that term that depends on gamma_stat;
//  4. tau = last server of the mean term whose name contains "_tau".
// Results are returned through the reference parameters pois_nom and tau.
// Returns 0 on success; every failure path throws std::runtime_error.
211 // Given a set of constraint terms,
212 // find the poisson constraint for the
213 // given gamma and return the mean
214 // as well as the 'tau' parameter
215
216 bool verbose=false;
217
218 // To get the constraint term, loop over all constraint terms
219 // and look for the gamma_stat name as well as '_constraint'
220
// NOTE(review): source line 227 is missing from this listing. As shown,
// FoundConstraintTerm is never set to true, so the check after the loop always
// throws; presumably line 227 sets it — confirm against the original file.
221 bool FoundConstraintTerm=false;
222 RooAbsPdf* constraintTerm=nullptr;
223 for (auto *term_constr : *constraints) {
224 std::string TermName = term_constr->GetName();
225 if( term_constr->dependsOn( *gamma_stat) ) {
226 if( TermName.find("_constraint")!=std::string::npos ) {
228 constraintTerm = static_cast<RooAbsPdf*>(term_constr);
229 break;
230 }
231 }
232 }
233 if( FoundConstraintTerm==false ) {
234 std::cout << "Error: Couldn't find constraint term for parameter: " << gamma_stat->GetName()
235 << " among constraints: " << constraints->GetName() << std::endl;
236 constraints->Print("V");
237 throw std::runtime_error("Failed to find Gamma ConstraintTerm");
// Unreachable: the throw above always leaves the function.
238 return -1;
239 }
240
241 /*
242 RooAbsPdf* constraintTerm = (RooAbsPdf*) constraints->find( constraintTermName.c_str() );
243 if( constraintTerm == nullptr ) {
244 std::cout << "Error: Couldn't find constraint term: " << constraintTermName
245 << " for parameter: " << gamma_stat->GetName()
246 << std::endl;
247 throw std::runtime_error("Failed to find Gamma ConstraintTerm");
248 return -1;
249 }
250 */
251
252 // Find the "data" of the poisson term
253 // This is the nominal value
// No break in this loop: if several servers match "nom_", the last one wins.
254 bool FoundNomMean=false;
255 for (RooAbsArg * term_pois : constraintTerm->servers()) {
256 std::string serverName = term_pois->GetName();
257 //std::cout << "Checking Server: " << serverName << std::endl;
258 if( serverName.find("nom_")!=std::string::npos ) {
259 FoundNomMean = true;
// NOTE(review): unchecked cast — assumes the matching server really is a
// RooRealVar; TODO confirm.
260 pois_nom = static_cast<RooRealVar*>(term_pois);
261 }
262 }
263 if( !FoundNomMean || !pois_nom ) {
264 std::cout << "Error: Did not find Nominal Pois Mean parameter in gamma constraint term PoissonMean: "
265 << constraintTerm->GetName() << std::endl;
266 throw std::runtime_error("Failed to find Nom Pois Mean");
267 }
268 else {
269 if(verbose) std::cout << "Found Poisson 'data' term: " << pois_nom->GetName() << std::endl;
270 }
271
272 // Taking the constraint term (a Poisson), find
273 // the "mean" which is the product: gamma*tau
274 // Then, from that mean, find tau
275 RooAbsArg * pois_mean_arg = nullptr;
276 for (RooAbsArg * arg : constraintTerm->servers()) {
277 if( arg->dependsOn( *gamma_stat ) ) {
278 pois_mean_arg = arg;
279 break;
280 }
281 }
282 if( !pois_mean_arg ) {
283 std::cout << "Error: Did not find PoissonMean parameter in gamma constraint term: "
284 << constraintTerm->GetName() << std::endl;
285 throw std::runtime_error("Failed to find PoissonMean");
// Unreachable: the throw above always leaves the function.
286 return -1;
287 }
288 else {
289 if(verbose) std::cout << "Found Poisson 'mean' term: " << pois_mean_arg->GetName() << std::endl;
290 }
291
// No break in this loop: if several servers match "_tau", the last one wins.
292 bool FoundTau=false;
293 for(RooAbsArg* term_in_product : pois_mean_arg->servers()) {
294 std::string serverName = term_in_product->GetName();
295 //std::cout << "Checking Server: " << serverName << std::endl;
296 if( serverName.find("_tau")!=std::string::npos ) {
297 FoundTau = true;
// NOTE(review): unchecked cast — assumes the matching server really is a
// RooRealVar; TODO confirm.
298 tau = static_cast<RooRealVar*>(term_in_product);
299 }
300 }
301 if( !FoundTau || !tau ) {
302 std::cout << "Error: Did not find Tau parameter in gamma constraint term PoissonMean: "
303 << pois_mean_arg->GetName() << std::endl;
304 throw std::runtime_error("Failed to find Tau");
305 }
306 else {
307 if(verbose) std::cout << "Found Poisson 'tau' term: " << tau->GetName() << std::endl;
308 }
309
310 return 0;
311
312 }
313
314
315
316} // close HistFactory namespace
317} // close RooStats namespace
ROOT::Detail::TRangeCast< T, true > TRangeDynCast
TRangeDynCast is an adapter class that allows the typed iteration through a TCollection.
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void data
A class which maps the current values of a RooRealVar (or a set of RooRealVars) to one of a number of...
Common abstract base class for objects that represent a value and a "shape" in RooFit.
Definition RooAbsArg.h:77
void Print(Option_t *options=nullptr) const override
Print the object to the defaultPrintStream().
Definition RooAbsArg.h:263
bool dependsOn(const RooAbsCollection &serverList, const RooAbsArg *ignoreArg=nullptr, bool valueOnly=false) const
Test whether we depend on (ie, are served by) any object in the specified collection.
RooFit::OwningPtr< RooArgSet > getComponents() const
Create a RooArgSet with all components (branch nodes) of the expression tree headed by this object.
const RefCountList_t & servers() const
List of all servers of this object.
Definition RooAbsArg.h:149
Abstract base class for objects that represent a discrete value that can be set from the outside,...
void setBin(Int_t ibin, const char *rangeName=nullptr) override
Set category to i-th fit bin, which is the i-th registered state.
bool contains(const RooAbsArg &var) const
Check if collection contains an argument with the same name as var.
const char * GetName() const override
Returns name of object.
virtual bool add(const RooAbsArg &var, bool silent=false)
Add the specified argument to list.
void Print(Option_t *options=nullptr) const override
This method must be overridden when a class wants to print itself.
Abstract base class for binned and unbinned datasets.
Definition RooAbsData.h:57
Abstract interface for all probability density functions.
Definition RooAbsPdf.h:40
Abstract base class for objects that represent a real value and implements functionality common to al...
Definition RooAbsReal.h:59
RooArgList is a container object that can hold multiple RooAbsArg objects.
Definition RooArgList.h:22
RooArgSet is a container object that can hold multiple RooAbsArg objects.
Definition RooArgSet.h:24
Object to represent discrete states.
Definition RooCategory.h:28
Efficient implementation of a product of PDFs of the form.
Definition RooProdPdf.h:39
const RooArgList & pdfList() const
Definition RooProdPdf.h:73
Variable that can be changed from the outside.
Definition RooRealVar.h:37
Facilitates simultaneous fitting of multiple PDFs to subsets of a given dataset.
const char * GetName() const override
Returns name of object.
Definition TNamed.h:47
bool getStatUncertaintyFromChannel(RooAbsPdf *channel, ParamHistFunc *&paramfunc, RooArgList *gammaList)
void FactorizeHistFactoryPdf(const RooArgSet &, RooAbsPdf &, RooArgList &, RooArgList &)
void getDataValuesForObservables(std::map< std::string, std::vector< double > > &ChannelBinDataMap, RooAbsData *data, RooAbsPdf *simPdf)
RooAbsPdf * getSumPdfFromChannel(RooAbsPdf *channel)
int getStatUncertaintyConstraintTerm(RooArgList *constraints, RooRealVar *gamma_stat, RooAbsReal *&pois_mean, RooRealVar *&tau)
std::string channelNameFromPdf(RooAbsPdf *channelPdf)
const Int_t n
Definition legend1.C:16
Namespace for the RooStats classes.
Definition CodegenImpl.h:58