Logo ROOT  
Reference Guide
 
Loading...
Searching...
No Matches
JSONFactories_HistFactory.cxx
Go to the documentation of this file.
1/*
2 * Project: RooFit
3 * Authors:
4 * Carsten D. Burgard, DESY/ATLAS, Dec 2021
5 *
6 * Copyright (c) 2022, CERN
7 *
8 * Redistribution and use in source and binary forms,
9 * with or without modification, are permitted according to the terms
10 * listed in LICENSE (http://roofit.sourceforge.net/license.txt)
11 */
12
15
19#include <RooConstVar.h>
20#include <RooCategory.h>
21#include <RooRealVar.h>
22#include <RooDataHist.h>
23#include <RooHistFunc.h>
24#include <RooRealSumPdf.h>
25#include <RooBinWidthFunction.h>
26#include <RooProdPdf.h>
27#include <RooPoisson.h>
28#include <RooGaussian.h>
29#include <RooProduct.h>
30#include <RooWorkspace.h>
31
32#include <TH1.h>
33
34#include <stack>
35
36#include "static_execute.h"
37
39
40namespace {
41inline void collectNames(const JSONNode &n, std::vector<std::string> &names)
42{
43 for (const auto &c : n.children()) {
44 names.push_back(RooJSONFactoryWSTool::name(c));
45 }
46}
47
48inline void stackError(const JSONNode &n, std::vector<double> &sumW, std::vector<double> &sumW2)
49{
50 if (!n.is_map())
51 return;
52 if (!n.has_child("counts"))
53 throw "no counts given";
54 if (!n["counts"].is_seq())
55 throw "counts are not in list form";
56 if (!n.has_child("errors"))
57 throw "no errors given";
58 if (!n["errors"].is_seq())
59 throw "errors are not in list form";
60 if (n["counts"].num_children() != n["errors"].num_children()) {
61 throw "inconsistent bin numbers";
62 }
63 const size_t nbins = n["counts"].num_children();
64 for (size_t ibin = 0; ibin < nbins; ++ibin) {
65 double w = n["counts"][ibin].val_float();
66 double e = n["errors"][ibin].val_float();
67 if (ibin < sumW.size())
68 sumW[ibin] += w;
69 else
70 sumW.push_back(w);
71 if (ibin < sumW2.size())
72 sumW2[ibin] += e * e;
73 else
74 sumW2.push_back(e * e);
75 }
76}
77
78std::vector<std::string> getVarnames(const RooHistFunc *hf)
79{
80 const RooDataHist &dh = hf->dataHist();
81 RooArgList vars(*dh.get());
82 return RooJSONFactoryWSTool::names(&vars);
83}
84
85std::unique_ptr<TH1> histFunc2TH1(const RooHistFunc *hf)
86{
87 if (!hf)
88 RooJSONFactoryWSTool::error("null pointer passed to histFunc2TH1");
89 const RooDataHist &dh = hf->dataHist();
90 RooArgSet *vars = hf->getVariables();
91 auto varnames = RooJSONFactoryWSTool::names(vars);
92 std::unique_ptr<TH1> hist{hf->createHistogram(RooJSONFactoryWSTool::concat(vars).c_str())};
93 hist->SetDirectory(nullptr);
94 auto volumes = dh.binVolumes(0, dh.numEntries());
95 for (size_t i = 0; i < volumes.size(); ++i) {
96 hist->SetBinContent(i + 1, hist->GetBinContent(i + 1) / volumes[i]);
97 hist->SetBinError(i + 1, sqrt(hist->GetBinContent(i + 1)));
98 }
99 return hist;
100}
101
102template <class T>
103T *findClient(RooAbsArg *gamma)
104{
105 for (const auto &client : gamma->clients()) {
106 if (client->InheritsFrom(T::Class())) {
107 return static_cast<T *>(client);
108 } else {
109 T *c = findClient<T>(client);
110 if (c)
111 return c;
112 }
113 }
114 return nullptr;
115}
116
117RooRealVar *getNP(RooJSONFactoryWSTool *tool, const char *parname)
118{
119 RooRealVar *par = tool->workspace()->var(parname);
120 if (!tool->workspace()->var(parname)) {
121 par = (RooRealVar *)tool->workspace()->factory(TString::Format("%s[0.,-5,5]", parname).Data());
122 }
123 if (par) {
124 par->setAttribute("np");
125 }
126 TString globname = TString::Format("nom_%s", parname);
127 RooRealVar *nom = tool->workspace()->var(globname.Data());
128 if (!nom) {
129 nom = (RooRealVar *)tool->workspace()->factory((globname + "[0.]").Data());
130 }
131 if (nom) {
132 nom->setAttribute("glob");
133 nom->setRange(-5, 5);
134 nom->setConstant(true);
135 }
136 TString constrname = TString::Format("sigma_%s", parname);
137 RooRealVar *sigma = tool->workspace()->var(constrname.Data());
138 if (!sigma) {
139 sigma = (RooRealVar *)tool->workspace()->factory((constrname + "[1.]").Data());
140 }
141 if (sigma) {
142 sigma->setRange(sigma->getVal(), sigma->getVal());
143 sigma->setConstant(true);
144 }
145 if (!par)
146 RooJSONFactoryWSTool::error(TString::Format("unable to find nuisance parameter '%s'", parname));
147 return par;
148}
149RooAbsPdf *getConstraint(RooJSONFactoryWSTool *tool, const std::string &sysname)
150{
151 RooAbsPdf *pdf = tool->workspace()->pdf((sysname + "_constraint").c_str());
152 if (!pdf) {
153 pdf = (RooAbsPdf *)(tool->workspace()->factory(
154 TString::Format("RooGaussian::%s_constraint(alpha_%s,nom_alpha_%s,sigma_alpha_%s)", sysname.c_str(),
155 sysname.c_str(), sysname.c_str(), sysname.c_str())
156 .Data()));
157 }
158 if (!pdf) {
159 RooJSONFactoryWSTool::error(TString::Format("unable to find constraint term '%s'", sysname.c_str()));
160 }
161 return pdf;
162}
163
164class RooHistogramFactory : public RooJSONFactoryWSTool::Importer {
165public:
166 bool importFunction(RooJSONFactoryWSTool *tool, const JSONNode &p) const override
167 {
168 std::string name(RooJSONFactoryWSTool::name(p));
169 std::string prefix = RooJSONFactoryWSTool::genPrefix(p, true);
170 if (prefix.size() > 0)
171 name = prefix + name;
172 if (!p.has_child("data")) {
173 RooJSONFactoryWSTool::error("function '" + name + "' is of histogram type, but does not define a 'data' key");
174 }
175 try {
176 std::stack<std::unique_ptr<RooAbsArg>> ownedArgsStack;
177 RooArgSet shapeElems;
178 RooArgSet normElems;
179 RooArgSet varlist;
180 tool->getObservables(p["data"], prefix, varlist);
181
182 auto getBinnedData = [&tool, &p, &varlist](std::string const &binnedDataName) -> RooDataHist & {
183 auto *dh = dynamic_cast<RooDataHist *>(tool->workspace()->embeddedData(binnedDataName.c_str()));
184 if (!dh) {
185 auto dhForImport = tool->readBinnedData(p["data"], binnedDataName, varlist);
186 tool->workspace()->import(*dhForImport, RooFit::Silence(true), RooFit::Embedded());
187 dh = static_cast<RooDataHist *>(tool->workspace()->embeddedData(dhForImport->GetName()));
188 }
189 return *dh;
190 };
191
192 RooDataHist &dh = getBinnedData(name);
193 auto hf = std::make_unique<RooHistFunc>(("hist_" + name).c_str(), RooJSONFactoryWSTool::name(p).c_str(),
194 *(dh.get()), dh);
195 ownedArgsStack.push(std::make_unique<RooBinWidthFunction>(
196 TString::Format("%s_binWidth", (!prefix.empty() ? prefix : name).c_str()).Data(),
197 TString::Format("%s_binWidth", (!prefix.empty() ? prefix : name).c_str()).Data(), *hf, true));
198 shapeElems.add(*ownedArgsStack.top());
199
200 if (p.has_child("statError") && p["statError"].val_bool()) {
201 RooAbsArg *phf = tool->getScopeObject("mcstat");
202 if (phf) {
203 shapeElems.add(*phf);
204 } else {
205 RooJSONFactoryWSTool::error("function '" + name +
206 "' has 'statError' active, but no element called 'mcstat' in scope!");
207 }
208 }
209
210 if (p.has_child("normFactors")) {
211 for (const auto &nf : p["normFactors"].children()) {
212 std::string nfname(RooJSONFactoryWSTool::name(nf));
213 RooAbsReal *r = tool->workspace()->var(nfname.c_str());
214 if (r) {
215 normElems.add(*r);
216 } else {
217 normElems.add(
218 *(RooRealVar *)tool->workspace()->factory(TString::Format("%s[1.]", nfname.c_str()).Data()));
219 }
220 }
221 }
222
223 if (p.has_child("overallSystematics")) {
224 RooArgList nps;
225 std::vector<double> low;
226 std::vector<double> high;
227 for (const auto &sys : p["overallSystematics"].children()) {
228 std::string sysname(RooJSONFactoryWSTool::name(sys));
229 std::string parname(sys.has_child("parameter") ? RooJSONFactoryWSTool::name(sys["parameter"])
230 : "alpha_" + sysname);
231 RooRealVar *par = ::getNP(tool, parname.c_str());
232 if (par) {
233 nps.add(*par);
234 low.push_back(sys["low"].val_float());
235 high.push_back(sys["high"].val_float());
236 } else {
237 RooJSONFactoryWSTool::error("overall systematic '" + sysname + "' doesn't have a valid parameter!");
238 }
239 }
240 auto v = std::make_unique<RooStats::HistFactory::FlexibleInterpVar>(
241 ("overallSys_" + name).c_str(), ("overallSys_" + name).c_str(), nps, 1., low, high);
242 v->setAllInterpCodes(4); // default HistFactory interpCode
243 normElems.add(*v);
244 ownedArgsStack.push(std::move(v));
245 }
246
247 if (p.has_child("histogramSystematics")) {
248 RooArgList nps;
249 RooArgList low;
250 RooArgList high;
251 for (const auto &sys : p["histogramSystematics"].children()) {
252 std::string sysname(RooJSONFactoryWSTool::name(sys));
253 std::string parname(sys.has_child("parameter") ? RooJSONFactoryWSTool::name(sys["parameter"])
254 : "alpha_" + sysname);
255 RooAbsReal *par = ::getNP(tool, parname.c_str());
256 nps.add(*par);
257 RooDataHist &dh_low = getBinnedData(sysname + "Low_" + name);
258 ownedArgsStack.push(std::make_unique<RooHistFunc>(
259 (sysname + "Low_" + name).c_str(), RooJSONFactoryWSTool::name(p).c_str(), *(dh_low.get()), dh_low));
260 low.add(*ownedArgsStack.top());
261 RooDataHist &dh_high = getBinnedData(sysname + "High_" + name);
262 ownedArgsStack.push(std::make_unique<RooHistFunc>((sysname + "High_" + name).c_str(),
264 *(dh_high.get()), dh_high));
265 high.add(*ownedArgsStack.top());
266 }
267 auto v = std::make_unique<PiecewiseInterpolation>(("histoSys_" + name).c_str(),
268 ("histoSys_" + name).c_str(), *hf, low, high, nps, false);
269 v->setAllInterpCodes(4); // default interpCode for HistFactory
270 shapeElems.add(*v);
271 ownedArgsStack.push(std::move(v));
272 } else {
273 shapeElems.add(*hf);
274 ownedArgsStack.push(std::move(hf));
275 }
276
277 if (p.has_child("shapeSystematics")) {
278 for (const auto &sys : p["shapeSystematics"].children()) {
279 std::string sysname(RooJSONFactoryWSTool::name(sys));
280 std::string funcName = prefix + sysname + "_ShapeSys";
281 RooAbsArg *phf = tool->getScopeObject(funcName);
282 if (!phf) {
283 RooJSONFactoryWSTool::error("PHF '" + funcName +
284 "' should have been created but cannot be found in scope.");
285 }
286 shapeElems.add(*phf);
287 }
288 }
289
290 RooProduct shape(name.c_str(), (name + "_shape").c_str(), shapeElems);
291 tool->workspace()->import(shape, RooFit::RecycleConflictNodes(true), RooFit::Silence(true));
292 if (normElems.size() > 0) {
293 RooProduct norm((name + "_norm").c_str(), (name + "_norm").c_str(), normElems);
295 } else {
296 tool->workspace()->factory(("RooConstVar::" + name + "_norm(1.)").c_str());
297 }
298 } catch (const std::runtime_error &e) {
299 RooJSONFactoryWSTool::error("function '" + name +
300 "' is of histogram type, but 'data' is not a valid definition. " + e.what() + ".");
301 }
302 return true;
303 }
304};
305
306class RooRealSumPdfFactory : public RooJSONFactoryWSTool::Importer {
307public:
308 std::unique_ptr<ParamHistFunc> createPHF(const std::string &sysname, const std::string &phfname,
309 const std::vector<double> &vals, RooWorkspace &w, RooArgList &constraints,
310 const RooArgSet &observables, const std::string &constraintType,
311 RooArgList &gammas, double gamma_min, double gamma_max) const
312 {
313 RooArgList ownedComponents;
314
315 std::string funcParams = "gamma_" + sysname;
316 gammas.add(ParamHistFunc::createParamSet(w, funcParams.c_str(), observables, gamma_min, gamma_max));
317 auto phf = std::make_unique<ParamHistFunc>(phfname.c_str(), phfname.c_str(), observables, gammas);
318 for (auto &g : gammas) {
319 g->setAttribute("np");
320 }
321
322 if (constraintType == "Gauss") {
323 for (size_t i = 0; i < vals.size(); ++i) {
324 TString nomname = TString::Format("nom_%s", gammas[i].GetName());
325 TString poisname = TString::Format("%s_constraint", gammas[i].GetName());
326 TString sname = TString::Format("%s_sigma", gammas[i].GetName());
327 auto nom = std::make_unique<RooRealVar>(nomname.Data(), nomname.Data(), 1);
328 nom->setAttribute("glob");
329 nom->setConstant(true);
330 nom->setRange(0, std::max(10., gamma_max));
331 auto sigma = std::make_unique<RooConstVar>(sname.Data(), sname.Data(), vals[i]);
332 auto g = static_cast<RooRealVar *>(gammas.at(i));
333 auto gaus = std::make_unique<RooGaussian>(poisname.Data(), poisname.Data(), *nom, *g, *sigma);
334 gaus->addOwnedComponents(std::move(nom), std::move(sigma));
335 constraints.add(*gaus, true);
336 ownedComponents.addOwned(std::move(gaus), true);
337 }
338 } else if (constraintType == "Poisson") {
339 for (size_t i = 0; i < vals.size(); ++i) {
340 double tau_float = vals[i];
341 TString tname = TString::Format("%s_tau", gammas[i].GetName());
342 TString nomname = TString::Format("nom_%s", gammas[i].GetName());
343 TString prodname = TString::Format("%s_poisMean", gammas[i].GetName());
344 TString poisname = TString::Format("%s_constraint", gammas[i].GetName());
345 auto tau = std::make_unique<RooConstVar>(tname.Data(), tname.Data(), tau_float);
346 auto nom = std::make_unique<RooRealVar>(nomname.Data(), nomname.Data(), tau_float);
347 nom->setAttribute("glob");
348 nom->setConstant(true);
349 nom->setMin(0);
350 RooArgSet elems{gammas[i], *tau};
351 auto prod = std::make_unique<RooProduct>(prodname.Data(), prodname.Data(), elems);
352 auto pois = std::make_unique<RooPoisson>(poisname.Data(), poisname.Data(), *nom, *prod);
353 pois->addOwnedComponents(std::move(tau), std::move(nom), std::move(prod));
354 pois->setNoRounding(true);
355 constraints.add(*pois, true);
356 ownedComponents.addOwned(std::move(pois), true);
357 }
358 } else {
359 RooJSONFactoryWSTool::error("unknown constraint type " + constraintType);
360 }
361 for (auto &g : gammas) {
362 for (auto client : g->clients()) {
363 if (client->InheritsFrom(RooAbsPdf::Class()) && !constraints.find(*client)) {
364 constraints.add(*client);
365 }
366 }
367 }
368 phf->recursiveRedirectServers(observables);
369 // Transfer ownership of gammas and owned constraints to the ParamHistFunc
370 phf->addOwnedComponents(std::move(ownedComponents));
371
372 return phf;
373 }
374
375 std::unique_ptr<ParamHistFunc> createPHFMCStat(const std::string &name, const std::vector<double> &sumW,
376 const std::vector<double> &sumW2, RooWorkspace &w,
377 RooArgList &constraints, const RooArgSet &observables,
378 double statErrorThreshold, const std::string &statErrorType) const
379 {
380 if (sumW.size() == 0)
381 return nullptr;
382
383 RooArgList gammas;
384 std::string phfname = std::string("mc_stat_") + name;
385 std::string sysname = std::string("stat_") + name;
386 std::vector<double> vals(sumW.size());
387 std::vector<double> errs(sumW.size());
388
389 for (size_t i = 0; i < sumW.size(); ++i) {
390 errs[i] = sqrt(sumW2[i]) / sumW[i];
391 if (statErrorType == "Gauss") {
392 vals[i] = std::max(errs[i], 0.); // avoid negative sigma. This NP will be set constant anyway later
393 } else if (statErrorType == "Poisson") {
394 vals[i] = sumW[i] * sumW[i] / sumW2[i];
395 }
396 }
397
398 auto phf = createPHF(sysname, phfname, vals, w, constraints, observables, statErrorType, gammas, 0, 10);
399
400 // set constant NPs which are below the MC stat threshold, and remove them from the np list
401 for (size_t i = 0; i < sumW.size(); ++i) {
402 auto g = static_cast<RooRealVar *>(gammas.at(i));
403 g->setError(errs[i]);
404 if (errs[i] < statErrorThreshold) {
405 g->setConstant(true); // all negative errs are set constant
406 }
407 }
408
409 return phf;
410 }
411
412 std::unique_ptr<ParamHistFunc> createPHFShapeSys(const JSONNode &p, const std::string &phfname, RooWorkspace &w,
413 RooArgList &constraints, const RooArgSet &observables) const
414 {
415 std::string sysname(RooJSONFactoryWSTool::name(p));
416 std::vector<double> vals;
417 for (const auto &v : p["vals"].children()) {
418 vals.push_back(v.val_float());
419 }
420 RooArgList gammas;
421 return createPHF(sysname, phfname, vals, w, constraints, observables, p["constraint"].val(), gammas, 0, 1000);
422 }
423
424 bool importPdf(RooJSONFactoryWSTool *tool, const JSONNode &p) const override
425 {
426 std::string name(RooJSONFactoryWSTool::name(p));
427 RooArgList funcs;
428 RooArgList coefs;
429 RooArgList constraints;
430 if (!p.has_child("samples")) {
431 RooJSONFactoryWSTool::error("no samples in '" + name + "', skipping.");
432 }
433 std::vector<std::string> usesStatError;
434 double statErrorThreshold = 0;
435 std::string statErrorType = "Poisson";
436 if (p.has_child("statError")) {
437 auto &staterr = p["statError"];
438 if (staterr.has_child("relThreshold"))
439 statErrorThreshold = staterr["relThreshold"].val_float();
440 if (staterr.has_child("constraint"))
441 statErrorType = staterr["constraint"].val();
442 }
443 std::vector<double> sumW;
444 std::vector<double> sumW2;
445 std::vector<double> dummy;
446 std::vector<std::string> sysnames;
447 std::vector<std::string> funcnames;
448 std::vector<std::string> coefnames;
449 RooArgSet observables;
450 if (p.has_child("observables")) {
451 tool->getObservables(p, name, observables);
452 tool->setScopeObservables(observables);
453 }
454 for (const auto &comp : p["samples"].children()) {
455 std::string fname(RooJSONFactoryWSTool::name(comp));
456 auto &def = comp.is_container() ? comp : p["functions"][fname.c_str()];
457 std::string fprefix = RooJSONFactoryWSTool::genPrefix(def, true);
458 if (def["type"].val() == "hist-sample") {
459 try {
460 if (observables.empty()) {
461 tool->getObservables(comp["data"], fprefix, observables);
462 }
463 if (def.has_child("overallSystematics"))
464 ::collectNames(def["overallSystematics"], sysnames);
465 if (def.has_child("histogramSystematics"))
466 ::collectNames(def["histogramSystematics"], sysnames);
467 if (def.has_child("shapeSystematics")) { // ShapeSys are special case. Create PHFs here if needed
468 std::vector<std::string> shapeSysNames;
469 ::collectNames(def["shapeSystematics"], shapeSysNames);
470 for (auto &sysname : shapeSysNames) {
471 std::string phfname = name + "_" + sysname + "_ShapeSys";
472 auto phf = tool->getScopeObject(phfname);
473 if (!phf) {
474 auto newphf = createPHFShapeSys(def["shapeSystematics"][sysname], phfname, *(tool->workspace()),
475 constraints, observables);
477 tool->setScopeObject(phfname, tool->workspace()->function(phfname.c_str()));
478 }
479 }
480 }
481 } catch (const char *s) {
482 RooJSONFactoryWSTool::error("function '" + name + "' unable to collect observables from function " +
483 fname + ". " + s);
484 }
485 try {
486 if (comp["statError"].val_bool()) {
487 ::stackError(def["data"], sumW, sumW2);
488 }
489 } catch (const char *s) {
490 RooJSONFactoryWSTool::error("function '" + name + "' unable to sum statError from function " + fname +
491 ". " + s);
492 }
493 }
494 funcnames.push_back(fprefix + fname);
495 coefnames.push_back(fprefix + fname + "_norm");
496 }
497
498 auto phf = createPHFMCStat(name, sumW, sumW2, *(tool->workspace()), constraints, observables, statErrorThreshold,
499 statErrorType);
500 if (phf) {
502 tool->setScopeObject("mcstat", tool->workspace()->function(phf->GetName()));
503 }
504
505 tool->importFunctions(p["samples"]);
506 for (const auto &fname : funcnames) {
507 RooAbsReal *func = tool->request<RooAbsReal>(fname.c_str(), name);
508 funcs.add(*func);
509 }
510 for (const auto &coefname : coefnames) {
511 RooAbsReal *coef = tool->request<RooAbsReal>(coefname.c_str(), name);
512 coefs.add(*coef);
513 }
514 for (auto sysname : sysnames) {
515 RooAbsPdf *pdf = ::getConstraint(tool, sysname.c_str());
516 constraints.add(*pdf);
517 }
518 if (constraints.empty()) {
519 RooRealSumPdf sum(name.c_str(), name.c_str(), funcs, coefs, true);
520 sum.setAttribute("BinnedLikelihood");
522 } else {
523 RooRealSumPdf sum((name + "_model").c_str(), name.c_str(), funcs, coefs, true);
524 sum.setAttribute("BinnedLikelihood");
526 RooArgList lhelems;
527 lhelems.add(sum);
528 RooProdPdf prod(name.c_str(), name.c_str(), RooArgSet(constraints), RooFit::Conditional(lhelems, observables));
530 }
531
532 tool->clearScope();
533
534 return true;
535 }
536};
537
538} // namespace
539
540namespace {
541class FlexibleInterpVarStreamer : public RooJSONFactoryWSTool::Exporter {
542public:
543 std::string const &key() const override
544 {
545 static const std::string keystring = "interpolation0d";
546 return keystring;
547 }
548 bool exportObject(RooJSONFactoryWSTool *, const RooAbsArg *func, JSONNode &elem) const override
549 {
551 static_cast<const RooStats::HistFactory::FlexibleInterpVar *>(func);
552 elem["type"] << key();
553 auto &vars = elem["vars"];
554 vars.set_seq();
555 for (const auto &v : fip->variables()) {
556 vars.append_child() << v->GetName();
557 }
558 elem["nom"] << fip->nominal();
559 elem["high"] << fip->high();
560 elem["low"] << fip->low();
561 return true;
562 }
563};
564
565class PiecewiseInterpolationStreamer : public RooJSONFactoryWSTool::Exporter {
566public:
567 std::string const &key() const override
568 {
569 static const std::string keystring = "interpolation";
570 return keystring;
571 }
572 bool exportObject(RooJSONFactoryWSTool *, const RooAbsArg *func, JSONNode &elem) const override
573 {
574 const PiecewiseInterpolation *pip = static_cast<const PiecewiseInterpolation *>(func);
575 elem["type"] << key();
576 elem["interpolationCodes"] << pip->interpolationCodes();
577 auto &vars = elem["vars"];
578 vars.set_seq();
579 for (const auto &v : pip->paramList()) {
580 vars.append_child() << v->GetName();
581 }
582
583 auto &nom = elem["nom"];
584 nom << pip->nominalHist()->GetName();
585
586 auto &high = elem["high"];
587 high.set_seq();
588 for (const auto &v : pip->highList()) {
589 high.append_child() << v->GetName();
590 }
591
592 auto &low = elem["low"];
593 low.set_seq();
594 for (const auto &v : pip->lowList()) {
595 low.append_child() << v->GetName();
596 }
597 return true;
598 }
599};
600} // namespace
601
602namespace {
603class PiecewiseInterpolationFactory : public RooJSONFactoryWSTool::Importer {
604public:
605 bool importFunction(RooJSONFactoryWSTool *tool, const JSONNode &p) const override
606 {
607 std::string name(RooJSONFactoryWSTool::name(p));
608 if (!p.has_child("vars")) {
609 RooJSONFactoryWSTool::error("no vars of '" + name + "'");
610 }
611 if (!p.has_child("high")) {
612 RooJSONFactoryWSTool::error("no high variations of '" + name + "'");
613 }
614 if (!p.has_child("low")) {
615 RooJSONFactoryWSTool::error("no low variations of '" + name + "'");
616 }
617 if (!p.has_child("nom")) {
618 RooJSONFactoryWSTool::error("no nominal variation of '" + name + "'");
619 }
620
621 std::string nomname(p["nom"].val());
622 RooAbsReal *nominal = tool->request<RooAbsReal>(nomname, name);
623
624 RooArgList vars;
625 for (const auto &d : p["vars"].children()) {
626 std::string objname(RooJSONFactoryWSTool::name(d));
627 RooRealVar *obj = tool->request<RooRealVar>(objname, name);
628 vars.add(*obj);
629 }
630
631 RooArgList high;
632 for (const auto &d : p["high"].children()) {
633 std::string objname(RooJSONFactoryWSTool::name(d));
634 RooAbsReal *obj = tool->request<RooAbsReal>(objname, name);
635 high.add(*obj);
636 }
637
638 RooArgList low;
639 for (const auto &d : p["low"].children()) {
640 std::string objname(RooJSONFactoryWSTool::name(d));
641 RooAbsReal *obj = tool->request<RooAbsReal>(objname, name);
642 low.add(*obj);
643 }
644
645 PiecewiseInterpolation pip(name.c_str(), name.c_str(), *nominal, low, high, vars);
646
647 if (p.has_child("interpolationCodes")) {
648 for (size_t i = 0; i < vars.size(); ++i) {
649 pip.setInterpCode(*static_cast<RooAbsReal *>(vars.at(i)), p["interpolationCodes"][i].val_int(), true);
650 }
651 }
652
654 return true;
655 }
656};
657} // namespace
658
659namespace {
660class HistFactoryStreamer : public RooJSONFactoryWSTool::Exporter {
661public:
662 bool autoExportDependants() const override { return false; }
663 void collectElements(RooArgSet &elems, RooProduct *prod) const
664 {
665 for (const auto &e : prod->components()) {
666 if (e->InheritsFrom(RooProduct::Class())) {
667 collectElements(elems, (RooProduct *)e);
668 } else {
669 elems.add(*e);
670 }
671 }
672 }
673 bool tryExport(const RooProdPdf *prodpdf, JSONNode &elem) const
674 {
675 std::string chname(prodpdf->GetName());
676 if (chname.find("model_") == 0) {
677 chname = chname.substr(6);
678 }
679 RooRealSumPdf *sumpdf = nullptr;
680 for (const auto &v : prodpdf->pdfList()) {
681 if (v->InheritsFrom(RooRealSumPdf::Class())) {
682 sumpdf = static_cast<RooRealSumPdf *>(v);
683 }
684 }
685 if (!sumpdf)
686 return false;
687 for (const auto &sample : sumpdf->funcList()) {
688 if (!sample->InheritsFrom(RooProduct::Class()) && !sample->InheritsFrom(RooRealSumPdf::Class()))
689 return false;
690 }
691
692 bool has_poisson_constraints = false;
693 bool has_gauss_constraints = false;
694 std::map<int, double> tot_yield;
695 std::map<int, double> tot_yield2;
696 std::map<int, double> rel_errors;
697 std::map<std::string, std::unique_ptr<TH1>> bb_histograms;
698 std::map<std::string, std::unique_ptr<TH1>> nonbb_histograms;
699 std::vector<std::string> varnames;
700
701 for (size_t sampleidx = 0; sampleidx < sumpdf->funcList().size(); ++sampleidx) {
702 const auto func = sumpdf->funcList().at(sampleidx);
703 const auto coef = sumpdf->coefList().at(sampleidx);
704 std::string samplename = func->GetName();
705 if (samplename.find("L_x_") == 0)
706 samplename = samplename.substr(4);
707 auto end = samplename.find("_" + chname);
708 if (end < samplename.size())
709 samplename = samplename.substr(0, end);
710
711 RooArgSet elems;
712 if (func->InheritsFrom(RooProduct::Class())) {
713 collectElements(elems, (RooProduct *)func);
714 } else {
715 elems.add(*func);
716 }
717 if (coef->InheritsFrom(RooProduct::Class())) {
718 collectElements(elems, (RooProduct *)coef);
719 } else {
720 elems.add(*coef);
721 }
722 std::unique_ptr<TH1> hist;
723 std::vector<ParamHistFunc *> phfs;
724 PiecewiseInterpolation *pip = nullptr;
725 std::vector<const RooAbsArg *> norms;
726
728 for (const auto &e : elems) {
729 if (e->InheritsFrom(RooConstVar::Class())) {
730 if (((RooConstVar *)e)->getVal() == 1.)
731 continue;
732 norms.push_back(e);
733 } else if (e->InheritsFrom(RooRealVar::Class())) {
734 norms.push_back(e);
735 } else if (e->InheritsFrom(RooHistFunc::Class())) {
736 const RooHistFunc *hf = static_cast<const RooHistFunc *>(e);
737 if (varnames.size() == 0) {
738 varnames = getVarnames(hf);
739 }
740 if (!hist) {
741 hist = histFunc2TH1(hf);
742 }
743 } else if (e->InheritsFrom(RooStats::HistFactory::FlexibleInterpVar::Class())) {
744 fip = static_cast<RooStats::HistFactory::FlexibleInterpVar *>(e);
745 } else if (e->InheritsFrom(PiecewiseInterpolation::Class())) {
746 pip = static_cast<PiecewiseInterpolation *>(e);
747 } else if (e->InheritsFrom(ParamHistFunc::Class())) {
748 phfs.push_back((ParamHistFunc *)e);
749 }
750 }
751 if (pip) {
752 if (!hist && pip->nominalHist()->InheritsFrom(RooHistFunc::Class())) {
753 hist = histFunc2TH1(static_cast<const RooHistFunc *>(pip->nominalHist()));
754 }
755 if (varnames.empty() && pip->nominalHist()->InheritsFrom(RooHistFunc::Class())) {
756 varnames = getVarnames(dynamic_cast<const RooHistFunc *>(pip->nominalHist()));
757 }
758 }
759 if (!hist) {
760 return false;
761 }
762
763 elem["name"] << chname;
764 elem["type"] << key();
765
766 auto &samples = elem["samples"];
767 samples.set_map();
768 auto &s = samples[samplename];
769 s.set_map();
770 s["type"] << "hist-sample";
771
772 for (const auto &norm : norms) {
773 auto &nfs = s["normFactors"];
774 nfs.set_seq();
775 nfs.append_child() << norm->GetName();
776 }
777
778 if (pip) {
779 auto &systs = s["histogramSystematics"];
780 systs.set_map();
781 for (size_t i = 0; i < pip->paramList().size(); ++i) {
782 std::string sysname(pip->paramList().at(i)->GetName());
783 if (sysname.find("alpha_") == 0) {
784 sysname = sysname.substr(6);
785 }
786 auto &sys = systs[sysname];
787 sys.set_map();
788 auto &dataLow = sys["dataLow"];
789 if (pip->lowList().at(i)->InheritsFrom(RooHistFunc::Class())) {
790 auto histLow = histFunc2TH1(static_cast<RooHistFunc *>(pip->lowList().at(i)));
791 RooJSONFactoryWSTool::exportHistogram(*histLow, dataLow, varnames, 0, false, false);
792 }
793 auto &dataHigh = sys["dataHigh"];
794 if (pip->highList().at(i)->InheritsFrom(RooHistFunc::Class())) {
795 auto histHigh = histFunc2TH1(static_cast<RooHistFunc *>(pip->highList().at(i)));
796 RooJSONFactoryWSTool::exportHistogram(*histHigh, dataHigh, varnames, 0, false, false);
797 }
798 }
799 }
800
801 if (fip) {
802 auto &systs = s["overallSystematics"];
803 systs.set_map();
804 for (size_t i = 0; i < fip->variables().size(); ++i) {
805 std::string sysname(fip->variables().at(i)->GetName());
806 if (sysname.find("alpha_") == 0) {
807 sysname = sysname.substr(6);
808 }
809 auto &sys = systs[sysname];
810 sys.set_map();
811 sys["low"] << fip->low()[i];
812 sys["high"] << fip->high()[i];
813 }
814 }
815 bool has_mc_stat = false;
816 for (auto phf : phfs) {
817 if (TString(phf->GetName()).BeginsWith("mc_stat_")) { // MC stat uncertainty
818 has_mc_stat = true;
819 s["statError"] << 1;
820 int idx = 0;
821 for (const auto &g : phf->paramList()) {
822 ++idx;
823 RooPoisson *constraint_p = findClient<RooPoisson>(g);
824 RooGaussian *constraint_g = findClient<RooGaussian>(g);
825 if (tot_yield.find(idx) == tot_yield.end()) {
826 tot_yield[idx] = 0;
827 tot_yield2[idx] = 0;
828 }
829 tot_yield[idx] += hist->GetBinContent(idx);
830 tot_yield2[idx] += (hist->GetBinContent(idx) * hist->GetBinContent(idx));
831 if (constraint_p) {
832 double erel = 1. / std::sqrt(constraint_p->getX().getVal());
833 rel_errors[idx] = erel;
834 has_poisson_constraints = true;
835 } else if (constraint_g) {
836 double erel = constraint_g->getSigma().getVal() / constraint_g->getMean().getVal();
837 rel_errors[idx] = erel;
838 has_gauss_constraints = true;
839 }
840 }
841 bb_histograms[samplename] = std::move(hist);
842 } else { // other ShapeSys
843 auto &shapesysts = s["shapeSystematics"];
844 shapesysts.set_map();
845 // Getting the name of the syst is tricky.
846 TString sysName(phf->GetName());
847 sysName.Remove(sysName.Index("_ShapeSys"));
848 sysName.Remove(0, chname.size() + 1);
849 auto &sys = shapesysts[sysName.Data()];
850 sys.set_map();
851 auto &cstrts = sys["vals"];
852 cstrts.set_seq();
853 bool is_poisson = false;
854 for (const auto &g : phf->paramList()) {
855 RooPoisson *constraint_p = findClient<RooPoisson>(g);
856 RooGaussian *constraint_g = findClient<RooGaussian>(g);
857 if (constraint_p) {
858 is_poisson = true;
859 cstrts.append_child() << constraint_p->getX().getVal();
860 } else if (constraint_g) {
861 is_poisson = false;
862 cstrts.append_child() << constraint_g->getSigma().getVal() / constraint_g->getMean().getVal();
863 }
864 }
865 if (is_poisson) {
866 sys["constraint"] << "Poisson";
867 } else {
868 sys["constraint"] << "Gauss";
869 }
870 }
871 }
872 if (!has_mc_stat) {
873 nonbb_histograms[samplename] = std::move(hist);
874 s["statError"] << 0;
875 }
876 auto &ns = s["namespaces"];
877 ns.set_seq();
878 ns.append_child() << chname;
879 }
880
881 auto &samples = elem["samples"];
882 for (const auto &hist : nonbb_histograms) {
883 auto &s = samples[hist.first];
884 auto &data = s["data"];
885 RooJSONFactoryWSTool::writeObservables(*hist.second, elem, varnames);
886 RooJSONFactoryWSTool::exportHistogram(*hist.second, data, varnames, 0, false, false);
887 }
888 for (const auto &hist : bb_histograms) {
889 auto &s = samples[hist.first];
890 for (auto bin : rel_errors) {
891 // reverse engineering the correct partial error
892 // the (arbitrary) convention used here is that all samples should have the same relative error
893 const int i = bin.first;
894 const double relerr_tot = bin.second;
895 const double count = hist.second->GetBinContent(i);
896 hist.second->SetBinError(i, relerr_tot * tot_yield[i] / sqrt(tot_yield2[i]) * count);
897 }
898 auto &data = s["data"];
899 RooJSONFactoryWSTool::writeObservables(*hist.second, elem, varnames);
900 RooJSONFactoryWSTool::exportHistogram(*hist.second, data, varnames, 0, false, true);
901 }
902 auto &statError = elem["statError"];
903 statError.set_map();
904 if (has_poisson_constraints) {
905 statError["constraint"] << "Poisson";
906 } else if (has_gauss_constraints) {
907 statError["constraint"] << "Gauss";
908 }
909 return true;
910 }
911
912 std::string const &key() const override
913 {
914 static const std::string keystring = "histfactory";
915 return keystring;
916 }
917 bool exportObject(RooJSONFactoryWSTool *, const RooAbsArg *p, JSONNode &elem) const override
918 {
919 const RooProdPdf *prodpdf = static_cast<const RooProdPdf *>(p);
920 if (tryExport(prodpdf, elem)) {
921 return true;
922 }
923 return false;
924 }
925};
926
928
929 using Tool = RooJSONFactoryWSTool;
930
931 Tool::registerImporter<RooRealSumPdfFactory>("histfactory", true);
932 Tool::registerImporter<RooHistogramFactory>("hist-sample", true);
933 Tool::registerImporter<PiecewiseInterpolationFactory>("interpolation", true);
934 Tool::registerExporter<FlexibleInterpVarStreamer>(RooStats::HistFactory::FlexibleInterpVar::Class(), true);
935 Tool::registerExporter<PiecewiseInterpolationStreamer>(PiecewiseInterpolation::Class(), true);
936 Tool::registerExporter<HistFactoryStreamer>(RooProdPdf::Class(), true);
937
938)
939
940} // namespace
ROOT::R::TRInterface & r
Definition Object.C:4
#define d(i)
Definition RSha256.hxx:102
#define c(i)
Definition RSha256.hxx:101
#define g(i)
Definition RSha256.hxx:105
#define e(i)
Definition RSha256.hxx:103
size_t size(const MatrixT &matrix)
retrieve the size of a square matrix
char name[80]
Definition TGX11.cxx:110
static char * Format(const char *format, va_list ap)
Format a string in a circular formatting buffer (using a printf style format descriptor).
Definition TString.cxx:2400
A class which maps the current values of a RooRealVar (or a set of RooRealVars) to one of a number of...
static RooArgList createParamSet(RooWorkspace &w, const std::string &, const RooArgList &Vars)
Create the list of RooRealVar parameters which represent the height of the histogram bins.
The PiecewiseInterpolation is a class that can morph distributions into each other,...
const RooArgList & highList() const
const RooAbsReal * nominalHist() const
Return pointer to the nominal hist function.
void setInterpCode(RooAbsReal &param, int code, bool silent=false)
const RooArgList & lowList() const
const RooArgList & paramList() const
const std::vector< int > & interpolationCodes() const
RooAbsArg is the common abstract base class for objects that represent a value and a "shape" in RooFi...
Definition RooAbsArg.h:69
bool addOwnedComponents(const RooAbsCollection &comps)
Take ownership of the contents of 'comps'.
void setAttribute(const Text_t *name, Bool_t value=kTRUE)
Set (default) or clear a named boolean attribute of this object.
RooArgSet * getVariables(Bool_t stripDisconnected=kTRUE) const
Return RooArgSet with all variables (tree leaf nodes of expression tree)
Bool_t recursiveRedirectServers(const RooAbsCollection &newServerList, Bool_t mustReplaceAll=kFALSE, Bool_t nameChange=kFALSE, Bool_t recurseInNewSet=kTRUE)
Recursively replace all servers with the new servers in newSet.
virtual Bool_t add(const RooAbsArg &var, Bool_t silent=kFALSE)
Add the specified argument to list.
virtual Bool_t addOwned(RooAbsArg &var, Bool_t silent=kFALSE)
Add an argument and transfer the ownership to the collection.
Storage_t::size_type size() const
const char * GetName() const
Returns name of object.
RooAbsArg * find(const char *name) const
Find object with given name in list.
void setConstant(Bool_t value=kTRUE)
RooAbsReal is the common abstract base class for objects that represent a real value and implements f...
Definition RooAbsReal.h:64
TH1 * createHistogram(const char *varNameList, Int_t xbins=0, Int_t ybins=0, Int_t zbins=0) const
Create and fill a ROOT histogram TH1, TH2 or TH3 with the values of this function for the variables w...
Double_t getVal(const RooArgSet *normalisationSet=nullptr) const
Evaluate object.
Definition RooAbsReal.h:94
RooArgList is a container object that can hold multiple RooAbsArg objects.
Definition RooArgList.h:22
RooAbsArg * at(Int_t idx) const
Return object at given index, or nullptr if index is out of range.
Definition RooArgList.h:110
RooArgSet is a container object that can hold multiple RooAbsArg objects.
Definition RooArgSet.h:35
RooConstVar represent a constant real-valued object.
Definition RooConstVar.h:26
The RooDataHist is a container class to hold N-dimensional binned data.
Definition RooDataHist.h:45
RooSpan< const double > binVolumes(std::size_t first, std::size_t len) const
Retrieve all bin volumes. Bins are indexed according to getIndex().
Definition RooDataHist.h:98
Int_t numEntries() const override
Return the number of bins.
const RooArgSet * get() const override
Get bin centre of current bin.
Definition RooDataHist.h:84
virtual bool has_child(std::string const &) const =0
virtual float val_float() const
Plain Gaussian p.d.f.
Definition RooGaussian.h:24
RooAbsReal const & getMean() const
Get the mean parameter.
Definition RooGaussian.h:45
RooAbsReal const & getSigma() const
Get the sigma parameter.
Definition RooGaussian.h:48
RooHistFunc implements a real-valued function sampled from a multidimensional histogram.
Definition RooHistFunc.h:30
RooDataHist & dataHist()
Return RooDataHist that is represented.
Definition RooHistFunc.h:40
virtual std::string const & key() const =0
virtual bool exportObject(RooJSONFactoryWSTool *, const RooAbsArg *, RooFit::Experimental::JSONNode &) const
virtual bool autoExportDependants() const
virtual bool importFunction(RooJSONFactoryWSTool *, const RooFit::Experimental::JSONNode &) const
virtual bool importPdf(RooJSONFactoryWSTool *, const RooFit::Experimental::JSONNode &) const
When using RooFit, statistical models can be conveniently handled and stored as a RooWorkspace.
static void exportHistogram(const TH1 &h, RooFit::Experimental::JSONNode &n, const std::vector< std::string > &obsnames, const TH1 *errH=0, bool writeObservables=true, bool writeErrors=true)
T * request(const std::string &objname, const std::string &requestAuthor)
RooAbsArg * getScopeObject(const std::string &name)
void setScopeObject(const std::string &key, RooAbsArg *obj)
static std::string genPrefix(const RooFit::Experimental::JSONNode &p, bool trailing_underscore)
static std::string concat(const T *items, const std::string &sep=",")
static std::string name(const RooFit::Experimental::JSONNode &n)
static std::vector< std::string > names(const T *items)
void importFunctions(const RooFit::Experimental::JSONNode &n)
std::unique_ptr< RooDataHist > readBinnedData(const RooFit::Experimental::JSONNode &n, const std::string &namecomp, RooArgList observables)
static void writeObservables(const TH1 &h, RooFit::Experimental::JSONNode &n, const std::vector< std::string > &varnames)
void getObservables(const RooFit::Experimental::JSONNode &n, const std::string &obsnamecomp, RooArgSet &out)
void setScopeObservables(const RooArgList &args)
static void error(const char *s)
Poisson pdf.
Definition RooPoisson.h:19
RooAbsReal const & getX() const
Get the x variable.
Definition RooPoisson.h:39
RooProdPdf is an efficient implementation of a product of PDFs of the form.
Definition RooProdPdf.h:33
A RooProduct represents the product of a given set of RooAbsReal objects.
Definition RooProduct.h:29
The class RooRealSumPdf implements a PDF constructed from a sum of functions:
const RooArgList & funcList() const
const RooArgList & coefList() const
RooRealVar represents a variable that can be changed from the outside.
Definition RooRealVar.h:39
void setMin(const char *name, Double_t value)
Set minimum of name range to given value.
void setError(Double_t value)
Definition RooRealVar.h:64
void setRange(const char *name, Double_t min, Double_t max)
Set a fit or plotting range.
const std::vector< double > & high() const
const std::vector< double > & low() const
const RooListProxy & variables() const
Const getters.
The RooWorkspace is a persistable container for RooFit projects.
RooAbsData * embeddedData(const char *name) const
Retrieve dataset (binned or unbinned) with given name. A null pointer is returned if not found.
Bool_t import(const RooAbsArg &arg, const RooCmdArg &arg1=RooCmdArg(), const RooCmdArg &arg2=RooCmdArg(), const RooCmdArg &arg3=RooCmdArg(), const RooCmdArg &arg4=RooCmdArg(), const RooCmdArg &arg5=RooCmdArg(), const RooCmdArg &arg6=RooCmdArg(), const RooCmdArg &arg7=RooCmdArg(), const RooCmdArg &arg8=RooCmdArg(), const RooCmdArg &arg9=RooCmdArg())
Import a RooAbsArg object, e.g.
RooAbsReal * function(const char *name) const
Retrieve function (RooAbsReal) with given name. Note that all RooAbsPdfs are also RooAbsReals....
RooRealVar * var(const char *name) const
Retrieve real-valued variable (RooRealVar) with given name. A null pointer is returned if not found.
RooFactoryWSTool & factory()
Return instance to factory tool.
RooAbsPdf * pdf(const char *name) const
Retrieve p.d.f (RooAbsPdf) with given name. A null pointer is returned if not found.
virtual void SetDirectory(TDirectory *dir)
By default, when a histogram is created, it is added to the list of histogram objects in the current ...
Definition TH1.cxx:8767
virtual const char * GetName() const
Returns name of object.
Definition TNamed.h:47
virtual Bool_t InheritsFrom(const char *classname) const
Returns kTRUE if object inherits from class "classname".
Definition TObject.cxx:515
Basic string class.
Definition TString.h:136
const char * Data() const
Definition TString.h:369
Bool_t BeginsWith(const char *s, ECaseCompare cmp=kExact) const
Definition TString.h:615
static TString Format(const char *fmt,...)
Static method which formats a string using a printf style format descriptor and return a TString.
Definition TString.cxx:2336
RooCmdArg RecycleConflictNodes(Bool_t flag=kTRUE)
RooCmdArg Embedded(Bool_t flag=kTRUE)
RooCmdArg Silence(Bool_t flag=kTRUE)
RooCmdArg Conditional(const RooArgSet &pdfSet, const RooArgSet &depSet, Bool_t depsAreCond=kFALSE)
const Double_t sigma
const Int_t n
Definition legend1.C:16
double gamma(double x)
double T(double x)
VecExpr< UnaryOp< Sqrt< T >, VecExpr< A, T, D >, T >, T, D > sqrt(const VecExpr< A, T, D > &rhs)
void variables(TString dataset, TString fin="TMVA.root", TString dirName="InputVariables_Id", TString title="TMVA Input Variables", Bool_t isRegression=kFALSE, Bool_t useTMVAStyle=kTRUE)
#define STATIC_EXECUTE(MY_CODE)
static uint64_t sum(uint64_t i)
Definition Factory.cxx:2345