Logo ROOT  
Reference Guide
 
Loading...
Searching...
No Matches
MetropolisHastings.cxx
Go to the documentation of this file.
1// @(#)root/roostats:$Id$
2// Authors: Kevin Belasco 17/06/2009
3// Authors: Kyle Cranmer 17/06/2009
4/*************************************************************************
5 * Copyright (C) 1995-2008, Rene Brun and Fons Rademakers. *
6 * All rights reserved. *
7 * *
8 * For the licensing terms see $ROOTSYS/LICENSE. *
9 * For the list of contributors see $ROOTSYS/README/CREDITS. *
10 *************************************************************************/
11
12/** \class RooStats::MetropolisHastings
13 \ingroup Roostats
14
15This class uses the Metropolis-Hastings algorithm to construct a Markov Chain
16of data points using Monte Carlo. In the main algorithm, new points in the
17parameter space are proposed and then visited based on their relative
18likelihoods. This class can use any implementation of the ProposalFunction,
19including non-symmetric proposal functions, to propose parameter points and
20still maintain detailed balance when constructing the chain.
21
22
23
24The "Likelihood" function that is sampled when deciding what steps to take in
25the chain has been given a very generic implementation. The user can create
26any RooAbsReal based on the parameters and pass it to a MetropolisHastings
27object with the method SetFunction(RooAbsReal&). Be sure to tell
28MetropolisHastings both the sign (positive or negative) and the type (regular
29or log scale) of your RooAbsReal, so that it knows what logic to use when sampling it. For example,
30a common use is to sample from a -log(Likelihood) distribution (NLL), for which
31the appropriate configuration calls are SetType(MetropolisHastings::kLog);
32SetSign(MetropolisHastings::kNegative);
33If you're using a traditional likelihood function, the calls are instead:
34SetType(MetropolisHastings::kRegular); SetSign(MetropolisHastings::kPositive);
35You must set these type and sign flags or MetropolisHastings will not construct
36a MarkovChain.
37
38Also note that in ConstructChain(), the values of the variables are randomized
39uniformly over their intervals before construction of the MarkovChain begins.
40
41*/
42
44
49
50#include "Rtypes.h"
51#include "RooRealVar.h"
52#include "RooGlobalFunc.h"
53#include "RooDataSet.h"
54#include "RooArgSet.h"
55#include "RooArgList.h"
56#include "RooMsgService.h"
57#include "RooRandom.h"
58#include "TMath.h"
59
61
62using namespace RooFit;
63using namespace RooStats;
64using namespace std;
65
66////////////////////////////////////////////////////////////////////////////////
67
69{
70 // default constructor
 // Leaves the sampler unconfigured: no function to sample, no proposal
 // function and zero iterations. ConstructChain() logs an error and returns
 // nullptr while fFunction or fPropFunc is still null.
71 fFunction = nullptr;
72 fPropFunc = nullptr;
73 fNumIters = 0;
 // NOTE(review): the listing numbering jumps from 73 to 77 here, so any
 // statements on the skipped lines are not visible in this view.
77}
78
79////////////////////////////////////////////////////////////////////////////////
80
82 ProposalFunction& proposalFunction, Int_t numIters)
83{
 // Configures the sampler in one call: the function to sample, the
 // parameters, the proposal function and the number of iterations.
 // fFunction stores the address of the caller's object; no copy is made here.
84 fFunction = &function;
85 SetParameters(paramsOfInterest);
86 SetProposalFunction(proposalFunction);
87 fNumIters = numIters;
 // NOTE(review): the listing numbering jumps from 87 to 91 here, so any
 // statements on the skipped lines are not visible in this view.
91}
92
93////////////////////////////////////////////////////////////////////////////////
94
96{
 // Runs the Metropolis-Hastings loop for fNumIters iterations and returns the
 // resulting MarkovChain. Returns nullptr (after logging an error) when the
 // parameters, proposal function or function are missing, or when the
 // type/sign flags are still unset.
 // NOTE(review): several listing numbers are skipped inside this body
 // (e.g. 109, 114, 120, 128-129, 134, 137, 151-152, 157, 190, 196, 246), so
 // some statements are not visible in this view.
97 if (fParameters.empty() || !fPropFunc || !fFunction) {
98 coutE(Eval) << "Critical members uninitialized: parameters, proposal " <<
99 " function, or (log) likelihood function" << endl;
100 return nullptr;
101 }
102 if (fSign == kSignUnset || fType == kTypeUnset) {
103 coutE(Eval) << "Please set type and sign of your function using "
104 << "MetropolisHastings::SetType() and MetropolisHastings::SetSign()" <<
105 endl;
106 return nullptr;
107 }
108
110
 // x is the current point of the chain, xPrime the proposed point; both are
 // built from clones of fParameters.
111 RooArgSet x;
112 RooArgSet xPrime;
113 x.addClone(fParameters);
115 xPrime.addClone(fParameters);
116 RandomizeCollection(xPrime);
117
 // The chain is allocated with a raw new and handed back through the return
 // value; presumably the caller takes ownership -- TODO confirm.
118 MarkovChain* chain = new MarkovChain();
119 // only the POI will be added to the chain
121
 // weight counts how many iterations the chain has spent at the current
 // point x. It stays 0 until the first proposal is accepted, so the starting
 // point itself is never added to the chain.
122 Int_t weight = 0;
 // xL / xPrimeL cache the function value at x / xPrime; a is the acceptance
 // quantity handed to ShouldTakeStep().
123 double xL = 0.0, xPrimeL = 0.0, a = 0.0;
124
125 // ibucur: i think the user should have the possibility to display all the message
126 // levels should they want to; maybe a setPrintLevel would be appropriate
127 // (maybe for the other classes that use this approach as well)?
130
131 // We will need to check if log-likelihood evaluation left an error status.
132 // Now using faster eval error logging with CountErrors.
133 if (fType == kLog) {
135 //N.B: need to clear the count in case of previous errors !
136 // the clear needs also to be done after calling setEvalErrorLoggingMode
138 }
139
140 bool hadEvalError = true;
141
142 Int_t i = 0;
143 // get a good starting point for x
144 // for fType == kLog, this means that fFunction->getVal() did not cause
145 // an eval error
146 // for fType == kRegular this means fFunction->getVal() != 0
147 //
148 // kbelasco: i < 1000 is sort of arbitrary, but way higher than the number of
149 // steps we should have to take for any reasonable (log) likelihood function
 // NOTE(review): as shown, nothing in this loop changes x or the function's
 // parameters between attempts; listing lines 151-152 are skipped, so the
 // re-randomization is presumably on lines not visible here -- confirm.
150 while (i < 1000 && hadEvalError) {
153 xL = fFunction->getVal();
154
155 if (fType == kLog) {
156 if (RooAbsReal::numEvalErrors() > 0) {
158 hadEvalError = true;
159 } else
160 hadEvalError = false;
161 } else if (fType == kRegular) {
162 if (xL == 0.0)
163 hadEvalError = true;
164 else
165 hadEvalError = false;
166 } else
167 // for now the only 2 types are kLog and kRegular (won't get here)
168 hadEvalError = false;
169 ++i;
170 }
171
 // A failed search for a starting point is only reported; the main loop
 // below still runs from whatever x and xL currently hold.
172 if(hadEvalError) {
173 coutE(Eval) << "Problem finding a good starting point in " <<
174 "MetropolisHastings::ConstructChain() " << endl;
175 }
176
177
178 ooccoutP((TObject *)nullptr, Generation) << "Metropolis-Hastings progress: ";
179
180 // do main loop
181 for (i = 0; i < fNumIters; i++) {
182 // reset error handling flag
183 hadEvalError = false;
184
185 // print a dot every 1% of the chain construction
 // NOTE(review): fNumIters is an Int_t, so fNumIters / 100 is integer
 // division and evaluates to 0 whenever fNumIters < 100; the modulo below is
 // then a division by zero (undefined behavior). Needs a guard.
186 if (i % (fNumIters / 100) == 0) ooccoutP((TObject*)nullptr, Generation) << ".";
187
188 fPropFunc->Propose(xPrime, x);
189
191 xPrimeL = fFunction->getVal();
192
193 // check if log-likelihood for xprime had an error status
 // On an eval error the proposal is marked bad; both the proposal-density
 // correction and the acceptance test below are skipped for it, so the chain
 // stays at x.
194 if (fFunction->numEvalErrors() > 0 && fType == kLog) {
195 xPrimeL = RooNumber::infinity();
197 hadEvalError = true;
198 }
199
200 // why evaluate the last point again, can't we cache it?
201 // kbelasco: commenting out lines below to add/test caching support
202 //RooStats::SetParameters(&x, &fParameters);
203 //xL = fFunction->getVal();
204
 // Build the acceptance quantity. For kLog it is a difference of the cached
 // function values, ordered by fSign; ShouldTakeStep() accepts outright when
 // it is <= 0. For the non-log case it is the plain ratio xPrimeL / xL.
205 if (fType == kLog) {
206 if (fSign == kPositive)
207 a = xL - xPrimeL;
208 else
209 a = xPrimeL - xL;
210 }
211 else
212 a = xPrimeL / xL;
213 //a = xL / xPrimeL;
214
 // Correction for a non-symmetric proposal: xPrimePD is the density of
 // proposing xPrime from x, xPD the density of proposing x from xPrime.
 // The ratio multiplies a on the regular scale and is added as a log
 // difference on the log scale.
215 if (!hadEvalError && !fPropFunc->IsSymmetric(xPrime, x)) {
216 double xPrimePD = fPropFunc->GetProposalDensity(xPrime, x);
217 double xPD = fPropFunc->GetProposalDensity(x, xPrime);
218 if (fType == kRegular)
219 a *= xPD / xPrimePD;
220 else
221 a += TMath::Log(xPrimePD) - TMath::Log(xPD);
222 }
223
224 if (!hadEvalError && ShouldTakeStep(a)) {
225 // go to the proposed point xPrime
226
227 // add the current point with the current weight
 // (weight is an Int_t compared against a double literal; it is 0 only
 // before the first accepted step)
228 if (weight != 0.0)
229 chain->Add(x, CalcNLL(xL), (double)weight);
230
231 // reset the weight and go to xPrime
232 weight = 1;
233 RooStats::SetParameters(&xPrime, &x);
234 xL = xPrimeL;
235 } else {
236 // stay at the current point
237 weight++;
238 }
239 }
240
241 // make sure to add the last point
242 if (weight != 0.0)
243 chain->Add(x, CalcNLL(xL), (double)weight);
244 ooccoutP((TObject *)nullptr, Generation) << endl;
245
247
 // The number of entries stored in the chain is reported as the number of
 // accepted proposals.
 // NOTE(review): the rate below divides by fNumIters converted to Float_t;
 // with fNumIters == 0 the printed value is not a meaningful percentage.
248 Int_t numAccepted = chain->Size();
249 coutI(Eval) << "Proposal acceptance rate: " <<
250 numAccepted/(Float_t)fNumIters * 100 << "%" << endl;
251 coutI(Eval) << "Number of steps in chain: " << numAccepted << endl;
252
253 //TFile chainDataFile("chainData.root", "recreate");
254 //chain->GetDataSet()->Write();
255 //chainDataFile.Close();
256
257 return chain;
258}
259
260////////////////////////////////////////////////////////////////////////////////
261
263{
 // Decides whether the chain moves to the proposed point, given the
 // acceptance quantity a. On the log scale a <= 0 is accepted outright; on
 // the regular scale a >= 1 is accepted outright. Otherwise the step is
 // accepted at random using one draw from RooRandom::uniform().
 // Returns true to take the step, false to stay.
264 if ((fType == kLog && a <= 0.0) || (fType == kRegular && a >= 1.0)) {
265 // The proposed point has a higher likelihood than the
266 // current point, so we should go there
267 return true;
268 }
269 else {
270 // generate numbers on a log distribution to decide
271 // whether to go to xPrime or stay at x
272 //double rand = fGen.Uniform(1.0);
273 double rand = RooRandom::uniform();
274 if (fType == kLog) {
 // Accepting when -log(rand) >= a is the same test as rand <= exp(-a).
275 rand = TMath::Log(rand);
276 // kbelasco: should this be changed to just (-rand > a) for logical
277 // consistency with below test when fType == kRegular?
278 if (-1.0 * rand >= a)
279 // we chose to go to the new proposed point
280 // even though it has a lower likelihood than the current one
281 return true;
282 } else {
283 // fType must be kRegular
 // NOTE(review): this branch is taken for any fType other than kLog; the
 // code here does not itself check for kRegular.
284 // kbelasco: ensure that we never visit a point where PDF == 0
285 //if (rand <= a)
286 if (rand < a)
287 // we chose to go to the new proposed point
288 // even though it has a lower likelihood than the current one
289 return true;
290 }
291 return false;
292 }
293}
294
295////////////////////////////////////////////////////////////////////////////////
296
298{
 // Converts a raw function value xL into the value stored with each chain
 // entry, according to the configured type and sign:
 //   kLog,  kNegative : xL is returned unchanged
 //   kLog,  otherwise : -xL
 //   other, kPositive : -log(xL)
 //   other, otherwise : -log(-xL)
299 if (fType == kLog) {
300 if (fSign == kNegative)
301 return xL;
302 else
303 return -xL;
304 } else {
305 if (fSign == kPositive)
306 return -1.0 * TMath::Log(xL);
307 else
308 return -1.0 * TMath::Log(-xL);
309 }
310}
#define a(i)
Definition RSha256.hxx:99
#define coutI(a)
#define coutE(a)
#define ooccoutP(o, a)
#define ClassImp(name)
Definition Rtypes.h:377
virtual bool add(const RooAbsArg &var, bool silent=false)
Add the specified argument to list.
virtual RooAbsArg * addClone(const RooAbsArg &var, bool silent=false)
Add a clone of the specified argument to list.
Abstract base class for objects that represent a real value and implements functionality common to al...
Definition RooAbsReal.h:59
double getVal(const RooArgSet *normalisationSet=nullptr) const
Evaluate object.
Definition RooAbsReal.h:103
static Int_t numEvalErrors()
Return the number of logged evaluation errors since the last clearing.
static void setEvalErrorLoggingMode(ErrorLoggingMode m)
Set evaluation error logging mode.
static void clearEvalErrorLog()
Clear the stack of evaluation error messages.
RooArgSet is a container object that can hold multiple RooAbsArg objects.
Definition RooArgSet.h:55
static RooMsgService & instance()
Return reference to singleton instance.
void setGlobalKillBelow(RooFit::MsgLevel level)
RooFit::MsgLevel globalKillBelow() const
static constexpr double infinity()
Return internal infinity representation.
Definition RooNumber.h:25
static double uniform(TRandom *generator=randomGenerator())
Return a number uniformly distributed from (0,1)
Definition RooRandom.cxx:81
Stores the steps in a Markov Chain of points.
Definition MarkovChain.h:30
virtual void Add(RooArgSet &entry, double nllValue, double weight=1.0)
safely add an entry to the chain
virtual void SetParameters(RooArgSet &parameters)
set which of your parameters this chain should store
virtual Int_t Size() const
get the number of steps in the chain
Definition MarkovChain.h:49
This class uses the Metropolis-Hastings algorithm to construct a Markov Chain of data points using Mo...
RooArgSet fParameters
RooRealVars that define all parameter space.
virtual void SetProposalFunction(ProposalFunction &proposalFunction)
set the proposal function for suggesting new points for the MCMC
Int_t fNumIters
number of iterations to run metropolis algorithm
RooAbsReal * fFunction
function that will generate likelihood values
virtual void SetParameters(const RooArgSet &set)
specify all the parameters of interest in the interval
virtual MarkovChain * ConstructChain()
main purpose of MetropolisHastings - run Metropolis-Hastings algorithm to generate Markov Chain of po...
enum FunctionType fType
whether the likelihood is on a regular, log, (or other) scale
enum FunctionSign fSign
whether the likelihood is negative (like NLL) or positive
RooArgSet fChainParams
RooRealVars that are stored in the chain.
virtual bool ShouldTakeStep(double d)
virtual double CalcNLL(double xL)
MetropolisHastings()
default constructor
Int_t fNumBurnInSteps
number of iterations to discard as burn-in, starting from the first
ProposalFunction * fPropFunc
Proposal function for MCMC integration.
ProposalFunction is an interface for all proposal functions that would be used with a Markov Chain Mo...
virtual void Propose(RooArgSet &xPrime, RooArgSet &x)=0
Populate xPrime with the new proposed point, possibly based on the current point x.
virtual double GetProposalDensity(RooArgSet &x1, RooArgSet &x2)=0
Return the probability of proposing the point x1 given the starting point x2.
virtual bool IsSymmetric(RooArgSet &x1, RooArgSet &x2)=0
Determine whether or not the proposal density is symmetric for points x1 and x2 - that is,...
Mother of all ROOT objects.
Definition TObject.h:41
Double_t x[n]
Definition legend1.C:17
The namespace RooFit contains mostly switches that change the behaviour of functions of PDFs (or othe...
Definition JSONIO.h:26
MsgLevel
Verbosity level for RooMsgService::StreamConfig in RooMsgService.
Namespace for the RooStats classes.
Definition Asimov.h:19
void SetParameters(const RooArgSet *desiredVals, RooArgSet *paramsToChange)
void RandomizeCollection(RooAbsCollection &set, bool randomizeConstants=true)
assuming all values in set are RooRealVars, randomize their values
Double_t Log(Double_t x)
Returns the natural logarithm of x.
Definition TMath.h:756