Logo ROOT  
Reference Guide
 
Loading...
Searching...
No Matches
MetropolisHastings.cxx
Go to the documentation of this file.
1// @(#)root/roostats:$Id$
2// Authors: Kevin Belasco 17/06/2009
3// Authors: Kyle Cranmer 17/06/2009
4/*************************************************************************
5 * Copyright (C) 1995-2008, Rene Brun and Fons Rademakers. *
6 * All rights reserved. *
7 * *
8 * For the licensing terms see $ROOTSYS/LICENSE. *
9 * For the list of contributors see $ROOTSYS/README/CREDITS. *
10 *************************************************************************/
11
12/** \class RooStats::MetropolisHastings
13 \ingroup Roostats
14
15This class uses the Metropolis-Hastings algorithm to construct a Markov Chain
16of data points using Monte Carlo. In the main algorithm, new points in the
17parameter space are proposed and then visited based on their relative
18likelihoods. This class can use any implementation of the ProposalFunction,
19including non-symmetric proposal functions, to propose parameter points and
20still maintain detailed balance when constructing the chain.
21
22
23
24The "Likelihood" function that is sampled when deciding what steps to take in
25the chain has been given a very generic implementation. The user can create
26any RooAbsReal based on the parameters and pass it to a MetropolisHastings
27object with the method SetFunction(RooAbsReal&). Be sure to tell
28MetropolisHastings whether your RooAbsReal is positive or negative (SetSign) and on a regular or log scale (SetType),
29so that it knows what logic to use when sampling your RooAbsReal. For example,
30a common use is to sample from a -log(Likelihood) distribution (NLL), for which
31the appropriate configuration calls are SetType(MetropolisHastings::kLog);
32SetSign(MetropolisHastings::kNegative);
33If you're using a traditional likelihood function:
34SetType(MetropolisHastings::kRegular); SetSign(MetropolisHastings::kPositive);
35You must set these type and sign flags or MetropolisHastings will not construct
36a MarkovChain.
37
38Also note that in ConstructChain(), the values of the variables are randomized
39uniformly over their intervals before construction of the MarkovChain begins.
40
41*/
42
44
49
50#include "Rtypes.h"
51#include "RooRealVar.h"
52#include "RooGlobalFunc.h"
53#include "RooDataSet.h"
54#include "RooArgSet.h"
55#include "RooArgList.h"
56#include "RooMsgService.h"
57#include "RooRandom.h"
58#include "TMath.h"
59
61
62using namespace RooFit;
63using namespace RooStats;
64using std::endl;
65
66////////////////////////////////////////////////////////////////////////////////
67
69 ProposalFunction &proposalFunction, Int_t numIters)
70 : fFunction(&function), fNumIters(numIters)
71{
 // Stores the address of `function` (the quantity sampled by the chain) and
 // the number of Metropolis-Hastings iterations, then registers the
 // parameters of interest and the proposal function through their setters.
 // NOTE(review): only the address of `function` is kept, so the caller
 // presumably has to keep that object alive while this sampler is used — confirm.
72 SetParameters(paramsOfInterest);
73 SetProposalFunction(proposalFunction);
74}
75
76////////////////////////////////////////////////////////////////////////////////
77
79{
 // Runs the Metropolis-Hastings algorithm and returns a newly allocated
 // MarkovChain (nullptr if the sampler is not fully configured).
 // Steps:
 //  1. validate that parameters, proposal function, target function, sign and
 //     type are all set;
 //  2. randomize a starting point and retry (up to 1000 times) until the
 //     function can be evaluated there without an eval error (kLog) or with a
 //     non-zero value (kRegular);
 //  3. for fNumIters iterations, propose xPrime from x, build the acceptance
 //     quantity `a` (log-difference for kLog, ratio for kRegular, corrected
 //     for non-symmetric proposals) and let ShouldTakeStep() decide;
 //  4. store each visited point once, weighted by the number of iterations the
 //     chain stayed there, together with its NLL from CalcNLL().
80 if (fParameters.empty() || !fPropFunc || !fFunction) {
81 coutE(Eval) << "Critical members uninitialized: parameters, proposal " <<
82 " function, or (log) likelihood function" << endl;
83 return nullptr;
84 }
85 if (fSign == kSignUnset || fType == kTypeUnset) {
86 coutE(Eval) << "Please set type and sign of your function using "
87 << "MetropolisHastings::SetType() and MetropolisHastings::SetSign()" <<
88 endl;
89 return nullptr;
90 }
91
93
95 RooArgSet xPrime;
96 x.addClone(fParameters);
98 xPrime.addClone(fParameters);
99 RandomizeCollection(xPrime);
100
101 MarkovChain* chain = new MarkovChain();
102 // only the POI will be added to the chain
104
105 Int_t weight = 0;
106 double xL = 0.0;
107 double xPrimeL = 0.0;
108 double a = 0.0;
109
110 // ibucur: i think the user should have the possibility to display all the message
111 // levels should they want to; maybe a setPrintLevel would be appropriate
112 // (maybe for the other classes that use this approach as well)?
115
116 // We will need to check if log-likelihood evaluation left an error status.
117 // Now using faster eval error logging with CountErrors.
118 if (fType == kLog) {
120 //N.B: need to clear the count in case of previous errors !
121 // the clear needs also to be done after calling setEvalErrorLoggingMode
123 }
124
125 bool hadEvalError = true;
126
127 Int_t i = 0;
128 // get a good starting point for x
129 // for fType == kLog, this means that fFunction->getVal() did not cause
130 // an eval error
131 // for fType == kRegular this means fFunction->getVal() != 0
132 //
133 // kbelasco: i < 1000 is sort of arbitrary, but way higher than the number of
134 // steps we should have to take for any reasonable (log) likelihood function
135 while (i < 1000 && hadEvalError) {
138 xL = fFunction->getVal();
139
140 if (fType == kLog) {
141 if (RooAbsReal::numEvalErrors() > 0) {
143 hadEvalError = true;
144 } else
145 hadEvalError = false;
146 } else if (fType == kRegular) {
147 if (xL == 0.0) {
148 hadEvalError = true;
149 } else {
150 hadEvalError = false;
151 }
152 } else {
153 // for now the only 2 types are kLog and kRegular (won't get here)
154 hadEvalError = false;
155 }
156 ++i;
157 }
158
159 if(hadEvalError) {
160 coutE(Eval) << "Problem finding a good starting point in " <<
161 "MetropolisHastings::ConstructChain() " << endl;
162 }
163
164
165 ooccoutP((TObject *)nullptr, Generation) << "Metropolis-Hastings progress: ";
166
167 // do main loop
168 for (i = 0; i < fNumIters; i++) {
169 // reset error handling flag
170 hadEvalError = false;
171
172 // print a dot every 1% of the chain construction
 // fNumIters / 100 is 0 when fNumIters < 100; guard it so the modulo is
 // never taken by zero (undefined behaviour) and print a dot every
 // iteration instead in that case.
173 if (fNumIters < 100 || i % (fNumIters / 100) == 0) ooccoutP((TObject*)nullptr, Generation) << ".";
174
175 fPropFunc->Propose(xPrime, x);
176
178 xPrimeL = fFunction->getVal();
179
180 // check if log-likelihood for xprime had an error status
181 if (fFunction->numEvalErrors() > 0 && fType == kLog) {
182 xPrimeL = RooNumber::infinity();
184 hadEvalError = true;
185 }
186
187 // why evaluate the last point again, can't we cache it?
188 // kbelasco: commenting out lines below to add/test caching support
189 //RooStats::SetParameters(&x, &fParameters);
190 //xL = fFunction->getVal();
191
 // kLog: `a` is oriented so that a <= 0 means "xPrime is at least as
 // likely as x" (see ShouldTakeStep); kRegular: `a` is the ratio
 // f(xPrime) / f(x).
192 if (fType == kLog) {
193 if (fSign == kPositive) {
194 a = xL - xPrimeL;
195 } else {
196 a = xPrimeL - xL;
197 }
198 }
199 else
200 a = xPrimeL / xL;
201 //a = xL / xPrimeL;
202
 // Hastings correction for non-symmetric proposals: multiply the ratio by
 // q(x|xPrime) / q(xPrime|x); on the log scale the orientation of `a` is
 // inverted, so add log q(xPrime|x) - log q(x|xPrime).
203 if (!hadEvalError && !fPropFunc->IsSymmetric(xPrime, x)) {
204 double xPrimePD = fPropFunc->GetProposalDensity(xPrime, x);
205 double xPD = fPropFunc->GetProposalDensity(x, xPrime);
206 if (fType == kRegular) {
207 a *= xPD / xPrimePD;
208 } else {
209 a += std::log(xPrimePD) - std::log(xPD);
210 }
211 }
212
213 if (!hadEvalError && ShouldTakeStep(a)) {
214 // go to the proposed point xPrime
215
216 // add the current point with the current weight
217 if (weight != 0)
218 chain->Add(x, CalcNLL(xL), static_cast<double>(weight));
219
220 // reset the weight and go to xPrime
221 weight = 1;
222 RooStats::SetParameters(&xPrime, &x);
223 xL = xPrimeL;
224 } else {
225 // stay at the current point
226 weight++;
227 }
228 }
229
230 // make sure to add the last point
231 if (weight != 0)
232 chain->Add(x, CalcNLL(xL), static_cast<double>(weight));
233 ooccoutP((TObject *)nullptr, Generation) << endl;
234
236
237 Int_t numAccepted = chain->Size();
238 coutI(Eval) << "Proposal acceptance rate: " <<
239 numAccepted/(Float_t)fNumIters * 100 << "%" << endl;
240 coutI(Eval) << "Number of steps in chain: " << numAccepted << endl;
241
242 //TFile chainDataFile("chainData.root", "recreate");
243 //chain->GetDataSet()->Write();
244 //chainDataFile.Close();
245
246 return chain;
247}
248
249////////////////////////////////////////////////////////////////////////////////
250
252{
 // Metropolis acceptance test for the quantity `a` built in ConstructChain().
 // kLog:     the move is accepted if a <= 0, otherwise with probability
 //           exp(-a) (accepted when -log(u) >= a for u ~ Uniform(0,1)).
 // kRegular: `a` is a likelihood ratio; the move is accepted if a >= 1,
 //           otherwise with probability a (accepted when u < a).
 // Any type other than kLog takes the kRegular branch below.
 // Returns true if the chain should move to the proposed point.
253 if ((fType == kLog && a <= 0.0) || (fType == kRegular && a >= 1.0)) {
254 // The proposed point has at least as high a (proposal-corrected)
255 // likelihood as the current point, so we always go there
256 return true;
257 }
258 else {
259 // draw u uniformly in (0,1) (taking log(u) for kLog) to decide
260 // whether to go to xPrime or stay at x
261 //double rand = fGen.Uniform(1.0);
262 double rand = RooRandom::uniform();
263 if (fType == kLog) {
264 rand = std::log(rand);
265 // kbelasco: should this be changed to just (-rand > a) for logical
266 // consistency with below test when fType == kRegular?
267 if (-1.0 * rand >= a) {
268 // we chose to go to the new proposed point
269 // even though it has a lower likelihood than the current one
270 return true;
271 }
272 } else {
273 // fType must be kRegular
274 // kbelasco: ensure that we never visit a point where PDF == 0
275 //if (rand <= a)
276 if (rand < a) {
277 // we chose to go to the new proposed point
278 // even though it has a lower likelihood than the current one
279 return true;
280 }
281 }
282 return false;
283 }
284}
285
286////////////////////////////////////////////////////////////////////////////////
287
289{
 // Converts a raw function value xL into the negative log-likelihood stored
 // in the MarkovChain, according to the configured type and sign:
 //   kLog,     kNegative -> xL            (xL already is an NLL)
 //   kLog,     otherwise -> -xL           (xL is a log-likelihood)
 //   kRegular, kPositive -> -log(xL)      (xL is a likelihood)
 //   kRegular, otherwise -> -log(-xL)     (xL is a negated likelihood)
290 if (fType == kLog) {
291 if (fSign == kNegative) {
292 return xL;
293 } else {
294 return -xL;
295 }
296 } else {
297 if (fSign == kPositive) {
298 return -1.0 * std::log(xL);
299 } else {
300 return -1.0 * std::log(-xL);
301 }
302 }
303}
#define a(i)
Definition RSha256.hxx:99
#define coutI(a)
#define coutE(a)
#define ooccoutP(o, a)
#define ClassImp(name)
Definition Rtypes.h:382
virtual bool add(const RooAbsArg &var, bool silent=false)
Add the specified argument to list.
virtual RooAbsArg * addClone(const RooAbsArg &var, bool silent=false)
Add a clone of the specified argument to list.
Abstract base class for objects that represent a real value and implements functionality common to al...
Definition RooAbsReal.h:59
double getVal(const RooArgSet *normalisationSet=nullptr) const
Evaluate object.
Definition RooAbsReal.h:103
static Int_t numEvalErrors()
Return the number of logged evaluation errors since the last clearing.
static void setEvalErrorLoggingMode(ErrorLoggingMode m)
Set evaluation error logging mode.
static void clearEvalErrorLog()
Clear the stack of evaluation error messages.
RooArgSet is a container object that can hold multiple RooAbsArg objects.
Definition RooArgSet.h:24
static RooMsgService & instance()
Return reference to singleton instance.
void setGlobalKillBelow(RooFit::MsgLevel level)
RooFit::MsgLevel globalKillBelow() const
static constexpr double infinity()
Return internal infinity representation.
Definition RooNumber.h:25
static double uniform(TRandom *generator=randomGenerator())
Return a number uniformly distributed from (0,1)
Definition RooRandom.cxx:78
Stores the steps in a Markov Chain of points.
Definition MarkovChain.h:33
virtual void Add(RooArgSet &entry, double nllValue, double weight=1.0)
safely add an entry to the chain
virtual void SetParameters(RooArgSet &parameters)
set which of your parameters this chain should store
virtual Int_t Size() const
get the number of steps in the chain
Definition MarkovChain.h:52
This class uses the Metropolis-Hastings algorithm to construct a Markov Chain of data points using Mo...
RooArgSet fParameters
RooRealVars that define all parameter space.
virtual void SetProposalFunction(ProposalFunction &proposalFunction)
set the proposal function for suggesting new points for the MCMC
Int_t fNumIters
number of iterations to run metropolis algorithm
MetropolisHastings()=default
default constructor
RooAbsReal * fFunction
function that will generate likelihood values
virtual void SetParameters(const RooArgSet &set)
specify all the parameters of interest in the interval
virtual MarkovChain * ConstructChain()
main purpose of MetropolisHastings - run Metropolis-Hastings algorithm to generate Markov Chain of po...
enum FunctionType fType
whether the likelihood is on a regular, log, (or other) scale
enum FunctionSign fSign
whether the likelihood is negative (like NLL) or positive
RooArgSet fChainParams
RooRealVars that are stored in the chain.
virtual bool ShouldTakeStep(double d)
virtual double CalcNLL(double xL)
ProposalFunction * fPropFunc
Proposal function for MCMC integration.
ProposalFunction is an interface for all proposal functions that would be used with a Markov Chain Mo...
virtual void Propose(RooArgSet &xPrime, RooArgSet &x)=0
Populate xPrime with the new proposed point, possibly based on the current point x.
virtual double GetProposalDensity(RooArgSet &x1, RooArgSet &x2)=0
Return the probability of proposing the point x1 given the starting point x2.
virtual bool IsSymmetric(RooArgSet &x1, RooArgSet &x2)=0
Determine whether or not the proposal density is symmetric for points x1 and x2 - that is,...
Mother of all ROOT objects.
Definition TObject.h:41
Double_t x[n]
Definition legend1.C:17
The namespace RooFit contains mostly switches that change the behaviour of functions of PDFs (or othe...
Definition JSONIO.h:26
MsgLevel
Verbosity level for RooMsgService::StreamConfig in RooMsgService.
Namespace for the RooStats classes.
Definition Asimov.h:19
void SetParameters(const RooArgSet *desiredVals, RooArgSet *paramsToChange)
void RandomizeCollection(RooAbsCollection &set, bool randomizeConstants=true)
assuming all values in set are RooRealVars, randomize their values
Double_t Log(Double_t x)
Returns the natural logarithm of x.
Definition TMath.h:756