Logo ROOT  
Reference Guide
 
Loading...
Searching...
No Matches
MultivariateGaussianTest.C
Go to the documentation of this file.
1/// \file
2/// \ingroup tutorial_roostats
3/// \notebook
4/// Comparison of MCMC and the profile likelihood calculator (PLC) in a multivariate Gaussian problem
5///
6/// This tutorial produces an N-dimensional multivariate Gaussian
7/// with a non-trivial covariance matrix. By default N=4 (called "dim").
8///
9/// A subset of these are considered parameters of interest.
10/// This problem is tractable analytically.
11///
12/// We use this mainly as a test of Markov Chain Monte Carlo
13/// and we compare the result to the profile likelihood ratio.
14///
15/// We use the proposal helper to create a customized
16/// proposal function for this problem.
17///
18/// For N=4 and 2 parameters of interest it takes about 10-20 seconds
19/// and the acceptance rate is 37%
20///
21/// \macro_image
22/// \macro_output
23/// \macro_code
24///
25/// \authors Kevin Belasco, Kyle Cranmer
26
27#include "RooGlobalFunc.h"
28#include <cstdlib>
29#include "TMatrixDSym.h"
30#include "RooMultiVarGaussian.h"
31#include "RooArgList.h"
32#include "RooRealVar.h"
33#include "TH2F.h"
34#include "TCanvas.h"
35#include "RooAbsReal.h"
36#include "RooFitResult.h"
37#include "TStopwatch.h"
51
52using namespace std;
53using namespace RooFit;
54using namespace RooStats;
55
56void MultivariateGaussianTest(Int_t dim = 4, Int_t nPOI = 2)
57{
58 // let's time this challenging example
59 TStopwatch t;
60 t.Start();
61
62 RooArgList xVec;
63 RooArgList muVec;
64 RooArgSet poi;
65
66 // make the observable and means
67 Int_t i, j;
69 RooRealVar *mu_x;
70 for (i = 0; i < dim; i++) {
71 char *name = Form("x%d", i);
72 x = new RooRealVar(name, name, 0, -3, 3);
73 xVec.add(*x);
74
75 char *mu_name = Form("mu_x%d", i);
76 mu_x = new RooRealVar(mu_name, mu_name, 0, -2, 2);
77 muVec.add(*mu_x);
78 }
79
80 // put them into the list of parameters of interest
81 for (i = 0; i < nPOI; i++) {
82 poi.add(*muVec.at(i));
83 }
84
85 // make a covariance matrix that is all 1's
86 TMatrixDSym cov(dim);
87 for (i = 0; i < dim; i++) {
88 for (j = 0; j < dim; j++) {
89 if (i == j)
90 cov(i, j) = 3.;
91 else
92 cov(i, j) = 1.0;
93 }
94 }
95
96 // now make the multivariate Gaussian
97 RooMultiVarGaussian mvg("mvg", "mvg", xVec, muVec, cov);
98
99 // --------------------
100 // make a toy dataset
101 std::unique_ptr<RooDataSet> data{mvg.generate(xVec, 100)};
102
103 // now create the model config for this problem
104 RooWorkspace *w = new RooWorkspace("MVG");
105 ModelConfig modelConfig(w);
106 modelConfig.SetPdf(mvg);
107 modelConfig.SetParametersOfInterest(poi);
108
109 // -------------------------------------------------------
110 // Setup calculators
111
112 // MCMC
113 // we want to setup an efficient proposal function
114 // using the covariance matrix from a fit to the data
115 std::unique_ptr<RooFitResult> fit{mvg.fitTo(*data, Save(true))};
117 ph.SetVariables((RooArgSet &)fit->floatParsFinal());
118 ph.SetCovMatrix(fit->covarianceMatrix());
120 ph.SetCacheSize(100);
122
123 // now create the calculator
124 MCMCCalculator mc(*data, modelConfig);
125 mc.SetConfidenceLevel(0.95);
126 mc.SetNumBurnInSteps(100);
127 mc.SetNumIters(10000);
128 mc.SetNumBins(50);
129 mc.SetProposalFunction(*pdfProp);
130
131 MCMCInterval *mcInt = mc.GetInterval();
132 RooArgList *poiList = mcInt->GetAxes();
133
134 // now setup the profile likelihood calculator
135 ProfileLikelihoodCalculator plc(*data, modelConfig);
136 plc.SetConfidenceLevel(0.95);
137 LikelihoodInterval *plInt = (LikelihoodInterval *)plc.GetInterval();
138
139 // make some plots
140 MCMCIntervalPlot mcPlot(*mcInt);
141
142 TCanvas *c1 = new TCanvas();
143 mcPlot.SetLineColor(kGreen);
144 mcPlot.SetLineWidth(2);
145 mcPlot.Draw();
146
147 LikelihoodIntervalPlot plPlot(plInt);
148 plPlot.Draw("same");
149
150 if (poiList->getSize() == 1) {
151 RooRealVar *p = (RooRealVar *)poiList->at(0);
152 Double_t ll = mcInt->LowerLimit(*p);
153 Double_t ul = mcInt->UpperLimit(*p);
154 cout << "MCMC interval: [" << ll << ", " << ul << "]" << endl;
155 }
156
157 if (poiList->getSize() == 2) {
158 RooRealVar *p0 = (RooRealVar *)poiList->at(0);
159 RooRealVar *p1 = (RooRealVar *)poiList->at(1);
160 TCanvas *scatter = new TCanvas();
161 Double_t ll = mcInt->LowerLimit(*p0);
162 Double_t ul = mcInt->UpperLimit(*p0);
163 cout << "MCMC interval on p0: [" << ll << ", " << ul << "]" << endl;
164 ll = mcInt->LowerLimit(*p1);
165 ul = mcInt->UpperLimit(*p1);
166 cout << "MCMC interval on p1: [" << ll << ", " << ul << "]" << endl;
167
168 // MCMC interval on p0: [-0.2, 0.6]
169 // MCMC interval on p1: [-0.2, 0.6]
170
171 mcPlot.DrawChainScatter(*p0, *p1);
172 scatter->Update();
173 }
174
175 t.Print();
176}
int Int_t
Definition RtypesCore.h:45
double Double_t
Definition RtypesCore.h:59
@ kGreen
Definition Rtypes.h:66
winID h TVirtualViewer3D TVirtualGLPainter p
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void data
char name[80]
Definition TGX11.cxx:110
char * Form(const char *fmt,...)
Formats a string in a circular formatting buffer.
Definition TString.cxx:2467
Int_t getSize() const
Return the number of elements in the collection.
virtual bool add(const RooAbsArg &var, bool silent=false)
Add the specified argument to list.
RooArgList is a container object that can hold multiple RooAbsArg objects.
Definition RooArgList.h:22
RooAbsArg * at(Int_t idx) const
Return object at given index, or nullptr if index is out of range.
Definition RooArgList.h:110
RooArgSet is a container object that can hold multiple RooAbsArg objects.
Definition RooArgSet.h:55
Multivariate Gaussian p.d.f.
RooRealVar represents a variable that can be changed from the outside.
Definition RooRealVar.h:37
This class provides simple and straightforward utilities to plot a LikelihoodInterval object.
LikelihoodInterval is a concrete implementation of the RooStats::ConfInterval interface.
Bayesian Calculator estimating an interval or a credible region using the Markov-Chain Monte Carlo me...
This class provides simple and straightforward utilities to plot a MCMCInterval object.
MCMCInterval is a concrete implementation of the RooStats::ConfInterval interface.
virtual double UpperLimit(RooRealVar &param)
get the highest value of param that is within the confidence interval
virtual RooArgList * GetAxes()
Return a list of RooRealVars representing the axes; you own the returned RooArgList.
virtual double LowerLimit(RooRealVar &param)
get the lowest value of param that is within the confidence interval
ModelConfig is a simple class that holds configuration information specifying how a model should be u...
Definition ModelConfig.h:35
The ProfileLikelihoodCalculator is a concrete implementation of CombinedCalculator (the interface cla...
ProposalFunction is an interface for all proposal functions that would be used with a Markov Chain Mo...
virtual void SetCovMatrix(const TMatrixDSym &covMatrix)
set the covariance matrix to use for a multi-variate Gaussian proposal
virtual ProposalFunction * GetProposalFunction()
Get the ProposalFunction that we've been designing.
virtual void SetVariables(RooArgList &vars)
virtual void SetCacheSize(Int_t size)
virtual void SetUpdateProposalParameters(bool updateParams)
Persistable container for RooFit projects.
The Canvas class.
Definition TCanvas.h:23
void Update() override
Update canvas pad buffers.
Definition TCanvas.cxx:2475
Stopwatch class.
Definition TStopwatch.h:28
void Start(Bool_t reset=kTRUE)
Start the stopwatch.
void Print(Option_t *option="") const override
Print the real and cpu time passed between the start and stop events.
RooCmdArg Save(bool flag=true)
return c1
Definition legend1.C:41
Double_t x[n]
Definition legend1.C:17
fit(model, train_loader, val_loader, num_epochs, batch_size, optimizer, criterion, save_best, scheduler)
The namespace RooFit contains mostly switches that change the behaviour of functions of PDFs (or othe...
Definition JSONIO.h:26
Namespace for the RooStats classes.
Definition Asimov.h:19