Logo ROOT  
Reference Guide
 
Loading...
Searching...
No Matches
MetropolisHastings.cxx
Go to the documentation of this file.
1// @(#)root/roostats:$Id$
2// Authors: Kevin Belasco 17/06/2009
3// Authors: Kyle Cranmer 17/06/2009
4/*************************************************************************
5 * Copyright (C) 1995-2008, Rene Brun and Fons Rademakers. *
6 * All rights reserved. *
7 * *
8 * For the licensing terms see $ROOTSYS/LICENSE. *
9 * For the list of contributors see $ROOTSYS/README/CREDITS. *
10 *************************************************************************/
11
12/** \class RooStats::MetropolisHastings
13 \ingroup Roostats
14
15This class uses the Metropolis-Hastings algorithm to construct a Markov Chain
16of data points using Monte Carlo. In the main algorithm, new points in the
17parameter space are proposed and then visited based on their relative
18likelihoods. This class can use any implementation of the ProposalFunction,
19including non-symmetric proposal functions, to propose parameter points and
20still maintain detailed balance when constructing the chain.
21
22
23
24The "Likelihood" function that is sampled when deciding what steps to take in
25the chain has been given a very generic implementation. The user can create
26any RooAbsReal based on the parameters and pass it to a MetropolisHastings
27object with the method SetFunction(RooAbsReal&). Be sure to tell
28MetropolisHastings both the sign (+/-) and the scale (regular or log) of your RooAbsReal,
29so that it knows what logic to use when sampling your RooAbsReal. For example,
30a common use is to sample from a -log(Likelihood) distribution (NLL), for which
31the appropriate configuration calls are SetType(MetropolisHastings::kLog);
32SetSign(MetropolisHastings::kNegative);
33If you're using a traditional likelihood function:
34SetType(MetropolisHastings::kRegular); SetSign(MetropolisHastings::kPositive);
35You must set these type and sign flags or MetropolisHastings will not construct
36a MarkovChain.
37
38Also note that in ConstructChain(), the values of the variables are randomized
39uniformly over their intervals before construction of the MarkovChain begins.
40
41*/
42
44
49
50#include "Rtypes.h"
51#include "RooRealVar.h"
52#include "RooNLLVar.h"
53#include "RooGlobalFunc.h"
54#include "RooDataSet.h"
55#include "RooArgSet.h"
56#include "RooArgList.h"
57#include "RooMsgService.h"
58#include "RooRandom.h"
59#include "TMath.h"
60#include "TFile.h"
61
63
64using namespace RooFit;
65using namespace RooStats;
66using namespace std;
67
68////////////////////////////////////////////////////////////////////////////////
69
71{
72 // default constructor
 // Leaves the object without a function or a proposal function and with a
 // zero iteration count. ConstructChain() returns NULL while fFunction or
 // fPropFunc is still null, so both must be supplied before building a chain.
 // NOTE(review): the signature (listing line 70) and listing lines 76-78 are
 // absent from this extract; presumably they initialise further members --
 // confirm against the real file.
73 fFunction = NULL;
74 fPropFunc = NULL;
75 fNumIters = 0;
79}
80
81////////////////////////////////////////////////////////////////////////////////
82
84 ProposalFunction& proposalFunction, Int_t numIters)
85{
 // Fully configuring constructor.
 // - function: its address is stored in fFunction (no copy is made), so the
 // object passed in has to outlive this MetropolisHastings instance.
 // - paramsOfInterest: forwarded to SetParameters().
 // - proposalFunction: forwarded to SetProposalFunction().
 // - numIters: number of iterations of the main loop in ConstructChain().
 // NOTE(review): the first line of the signature (listing line 83) and
 // listing lines 90-92 are absent from this extract, so the declared types of
 // `function` and `paramsOfInterest` and any further initialisation cannot be
 // seen here -- confirm against the real file.
86 fFunction = &function;
87 SetParameters(paramsOfInterest);
88 SetProposalFunction(proposalFunction);
89 fNumIters = numIters;
93}
94
95////////////////////////////////////////////////////////////////////////////////
96
98{
 // Builds a Markov chain by running fNumIters Metropolis-Hastings steps on
 // fFunction, using fPropFunc to propose points.
 //
 // Returns the newly allocated MarkovChain, or NULL when the parameters, the
 // proposal function or the function are missing, or when the sign/type of
 // the function has not been set. The chain is created with `new` and handed
 // back as a raw pointer; presumably the caller is expected to delete it --
 // confirm with the class documentation.
 //
 // NOTE(review): this extract is missing the signature and listing lines 111,
 // 115-116, 122, 130-131, 136, 139, 153-154, 159, 192, 198 and 248. The code
 // on those lines (e.g. how `x` is filled and how parameter values are pushed
 // to fFunction before each getVal()) is not visible here.
99 if (fParameters.getSize() == 0 || !fPropFunc || !fFunction) {
100 coutE(Eval) << "Critical members unintialized: parameters, proposal " <<
101 " function, or (log) likelihood function" << endl;
102 return NULL;
103 }
104 if (fSign == kSignUnset || fType == kTypeUnset) {
105 coutE(Eval) << "Please set type and sign of your function using "
106 << "MetropolisHastings::SetType() and MetropolisHastings::SetSign()" <<
107 endl;
108 return NULL;
109 }
110
112
 // x is the current point of the chain, xPrime the proposed point.
113 RooArgSet x;
114 RooArgSet xPrime;
117 xPrime.addClone(fParameters);
118 RandomizeCollection(xPrime);
119
120 MarkovChain* chain = new MarkovChain();
121 // only the POI will be added to the chain
123
 // weight counts how many iterations the chain has spent at the current
 // point; xL / xPrimeL cache the function value at x / xPrime; a is the
 // acceptance quantity handed to ShouldTakeStep().
124 Int_t weight = 0;
125 Double_t xL = 0.0, xPrimeL = 0.0, a = 0.0;
126
127 // ibucur: i think the user should have the possibility to display all the message
128 // levels should they want to; maybe a setPrintLevel would be appropriate
129 // (maybe for the other classes that use this approach as well)?
132
133 // We will need to check if log-likelihood evaluation left an error status.
134 // Now using faster eval error logging with CountErrors.
135 if (fType == kLog) {
137 //N.B: need to clear the count in case of previous errors !
138 // the clear needs also to be done after calling setEvalErrorLoggingMode
140 }
141
142 bool hadEvalError = true;
143
144 Int_t i = 0;
145 // get a good starting point for x
146 // for fType == kLog, this means that fFunction->getVal() did not cause
147 // an eval error
148 // for fType == kRegular this means fFunction->getVal() != 0
149 //
150 // kbelasco: i < 1000 is sort of arbitrary, but way higher than the number of
151 // steps we should have to take for any reasonable (log) likelihood function
152 while (i < 1000 && hadEvalError) {
155 xL = fFunction->getVal();
156
157 if (fType == kLog) {
158 if (RooAbsReal::numEvalErrors() > 0) {
160 hadEvalError = true;
161 } else
162 hadEvalError = false;
163 } else if (fType == kRegular) {
164 if (xL == 0.0)
165 hadEvalError = true;
166 else
167 hadEvalError = false;
168 } else
169 // for now the only 2 types are kLog and kRegular (won't get here)
170 hadEvalError = false;
171 ++i;
172 }
173
 // A failed search is only reported; the main loop below still runs.
174 if(hadEvalError) {
175 coutE(Eval) << "Problem finding a good starting point in " <<
176 "MetropolisHastings::ConstructChain() " << endl;
177 }
178
179
180 ooccoutP((TObject *)0, Generation) << "Metropolis-Hastings progress: ";
181
182 // do main loop
183 for (i = 0; i < fNumIters; i++) {
184 // reset error handling flag
185 hadEvalError = false;
186
187 // print a dot every 1% of the chain construction
 // fNumIters / 100 is 0 for chains shorter than 100 iterations, and a
 // modulo by zero is undefined behaviour. The short-circuit test skips the
 // modulo in that case and prints one dot per iteration instead.
188 if (fNumIters < 100 || i % (fNumIters / 100) == 0) ooccoutP((TObject*)0, Generation) << ".";
189
190 fPropFunc->Propose(xPrime, x);
191
193 xPrimeL = fFunction->getVal();
194
195 // check if log-likelihood for xprime had an error status
196 if (fFunction->numEvalErrors() > 0 && fType == kLog) {
197 xPrimeL = RooNumber::infinity();
199 hadEvalError = true;
200 }
201
202 // why evaluate the last point again, can't we cache it?
203 // kbelasco: commenting out lines below to add/test caching support
204 //RooStats::SetParameters(&x, &fParameters);
205 //xL = fFunction->getVal();
206
 // For kLog, a is a difference of log values oriented so that a <= 0 means
 // the proposed point is at least as likely as the current one; for
 // kRegular, a is the plain ratio proposed / current.
207 if (fType == kLog) {
208 if (fSign == kPositive)
209 a = xL - xPrimeL;
210 else
211 a = xPrimeL - xL;
212 }
213 else
214 a = xPrimeL / xL;
215 //a = xL / xPrimeL;
216
 // Asymmetric proposals: fold the ratio of proposal densities into a
 // (multiplicatively for kRegular, additively in log space otherwise).
217 if (!hadEvalError && !fPropFunc->IsSymmetric(xPrime, x)) {
218 Double_t xPrimePD = fPropFunc->GetProposalDensity(xPrime, x);
219 Double_t xPD = fPropFunc->GetProposalDensity(x, xPrime);
220 if (fType == kRegular)
221 a *= xPD / xPrimePD;
222 else
223 a += TMath::Log(xPrimePD) - TMath::Log(xPD);
224 }
225
226 if (!hadEvalError && ShouldTakeStep(a)) {
227 // go to the proposed point xPrime
228
229 // add the current point with the current weight
230 if (weight != 0.0)
231 chain->Add(x, CalcNLL(xL), (Double_t)weight);
232
233 // reset the weight and go to xPrime
234 weight = 1;
235 RooStats::SetParameters(&xPrime, &x);
236 xL = xPrimeL;
237 } else {
238 // stay at the current point
239 weight++;
240 }
241 }
242
243 // make sure to add the last point
244 if (weight != 0.0)
245 chain->Add(x, CalcNLL(xL), (Double_t)weight);
246 ooccoutP((TObject *)0, Generation) << endl;
247
249
 // The acceptance rate is reported as chain size over iteration count.
250 Int_t numAccepted = chain->Size();
251 coutI(Eval) << "Proposal acceptance rate: " <<
252 numAccepted/(Float_t)fNumIters * 100 << "%" << endl;
253 coutI(Eval) << "Number of steps in chain: " << numAccepted << endl;
254
255 //TFile chainDataFile("chainData.root", "recreate");
256 //chain->GetDataSet()->Write();
257 //chainDataFile.Close();
258
259 return chain;
260}
261
262////////////////////////////////////////////////////////////////////////////////
263
265{
 // Decides whether the chain moves to the proposed point.
 // `a` is the acceptance quantity computed in ConstructChain(): a log-space
 // difference when fType == kLog (the step is taken outright when a <= 0) or
 // a plain ratio when fType == kRegular (taken outright when a >= 1).
 // Returns kTRUE to move to the proposed point, kFALSE to stay.
 // NOTE(review): the signature and listing line 275, which declares `rand`,
 // are absent from this extract; presumably `rand` is a uniform draw in (0,1)
 // such as RooRandom::uniform() -- confirm against the real file.
266 if ((fType == kLog && a <= 0.0) || (fType == kRegular && a >= 1.0)) {
267 // The proposed point has a higher likelihood than the
268 // current point, so we should go there
269 return kTRUE;
270 }
271 else {
272 // generate numbers on a log distribution to decide
273 // whether to go to xPrime or stay at x
274 //Double_t rand = fGen.Uniform(1.0);
276 if (fType == kLog) {
 // Compare in log space: take the step when -log(rand) >= a.
277 rand = TMath::Log(rand);
278 // kbelasco: should this be changed to just (-rand > a) for logical
279 // consistency with below test when fType == kRegular?
280 if (-1.0 * rand >= a)
281 // we chose to go to the new proposed point
282 // even though it has a lower likelihood than the current one
283 return kTRUE;
284 } else {
285 // fType must be kRegular
286 // kbelasco: ensure that we never visit a point where PDF == 0
 // The strict '<' means a == 0 can never be accepted.
287 //if (rand <= a)
288 if (rand < a)
289 // we chose to go to the new proposed point
290 // even though it has a lower likelihood than the current one
291 return kTRUE;
292 }
293 return kFALSE;
294 }
295}
296
297////////////////////////////////////////////////////////////////////////////////
298
300{
 // Converts a raw function value xL into a negative log-likelihood, using
 // the configured scale (fType) and sign (fSign):
 // kLog, kNegative -> xL is returned unchanged
 // kLog, any other sign -> -xL
 // other type, kPositive -> -log(xL)
 // other type, other sign -> -log(-xL)
 // The plain `else` branches also catch unset type/sign values; the visible
 // caller, ConstructChain(), rejects those before it gets here.
301 if (fType == kLog) {
302 if (fSign == kNegative)
303 return xL;
304 else
305 return -xL;
306 } else {
307 if (fSign == kPositive)
308 return -1.0 * TMath::Log(xL);
309 else
310 return -1.0 * TMath::Log(-xL);
311 }
312}
#define a(i)
Definition RSha256.hxx:99
#define coutI(a)
#define coutE(a)
#define ooccoutP(o, a)
const Bool_t kFALSE
Definition RtypesCore.h:92
const Bool_t kTRUE
Definition RtypesCore.h:91
#define ClassImp(name)
Definition Rtypes.h:364
Int_t getSize() const
RooAbsReal is the common abstract base class for objects that represent a real value and implements f...
Definition RooAbsReal.h:61
Double_t getVal(const RooArgSet *normalisationSet=nullptr) const
Evaluate object.
Definition RooAbsReal.h:91
static Int_t numEvalErrors()
Return the number of logged evaluation errors since the last clearing.
static void setEvalErrorLoggingMode(ErrorLoggingMode m)
Set evaluation error logging mode.
static void clearEvalErrorLog()
Clear the stack of evaluation error messages.
RooArgSet is a container object that can hold multiple RooAbsArg objects.
Definition RooArgSet.h:29
Bool_t add(const RooAbsArg &var, Bool_t silent=kFALSE) override
Add element to non-owning set.
RooAbsArg * addClone(const RooAbsArg &var, Bool_t silent=kFALSE) override
Add clone of specified element to an owning set.
static RooMsgService & instance()
Return reference to singleton instance.
void setGlobalKillBelow(RooFit::MsgLevel level)
RooFit::MsgLevel globalKillBelow() const
static Double_t infinity()
Return internal infinity representation.
Definition RooNumber.cxx:49
static Double_t uniform(TRandom *generator=randomGenerator())
Return a number uniformly distributed from (0,1)
Definition RooRandom.cxx:83
Stores the steps in a Markov Chain of points.
Definition MarkovChain.h:30
virtual void Add(RooArgSet &entry, Double_t nllValue, Double_t weight=1.0)
safely add an entry to the chain
virtual void SetParameters(RooArgSet &parameters)
set which of your parameters this chain should store
virtual Int_t Size() const
get the number of steps in the chain
Definition MarkovChain.h:49
This class uses the Metropolis-Hastings algorithm to construct a Markov Chain of data points using Mo...
virtual void SetProposalFunction(ProposalFunction &proposalFunction)
virtual Bool_t ShouldTakeStep(Double_t d)
virtual void SetParameters(const RooArgSet &set)
virtual MarkovChain * ConstructChain()
virtual Double_t CalcNLL(Double_t xL)
ProposalFunction is an interface for all proposal functions that would be used with a Markov Chain Mo...
virtual void Propose(RooArgSet &xPrime, RooArgSet &x)=0
Populate xPrime with the new proposed point, possibly based on the current point x.
virtual Double_t GetProposalDensity(RooArgSet &x1, RooArgSet &x2)=0
Return the probability of proposing the point x1 given the starting point x2.
virtual Bool_t IsSymmetric(RooArgSet &x1, RooArgSet &x2)=0
Determine whether or not the proposal density is symmetric for points x1 and x2 - that is,...
Mother of all ROOT objects.
Definition TObject.h:37
Double_t x[n]
Definition legend1.C:17
The namespace RooFit contains mostly switches that change the behaviour of functions of PDFs (or othe...
MsgLevel
Verbosity level for RooMsgService::StreamConfig in RooMsgService.
Namespace for the RooStats classes.
Definition Asimov.h:19
void SetParameters(const RooArgSet *desiredVals, RooArgSet *paramsToChange)
void RandomizeCollection(RooAbsCollection &set, Bool_t randomizeConstants=kTRUE)
Double_t Log(Double_t x)
Definition TMath.h:760