ROOT Reference Guide
MetropolisHastings.cxx
// @(#)root/roostats:$Id$
// Authors: Kevin Belasco 17/06/2009
// Authors: Kyle Cranmer 17/06/2009
/*************************************************************************
 * Copyright (C) 1995-2008, Rene Brun and Fons Rademakers.               *
 * All rights reserved.                                                  *
 *                                                                       *
 * For the licensing terms see $ROOTSYS/LICENSE.                         *
 * For the list of contributors see $ROOTSYS/README/CREDITS.             *
 *************************************************************************/

/** \class RooStats::MetropolisHastings
 \ingroup Roostats

This class uses the Metropolis-Hastings algorithm to construct a Markov Chain
of points in parameter space by Monte Carlo sampling. In the main algorithm,
new points in the parameter space are proposed and then accepted or rejected
based on their likelihood relative to the current point. This class can use
any implementation of ProposalFunction, including non-symmetric proposal
functions, to propose parameter points and still maintain detailed balance
when constructing the chain: for a non-symmetric proposal, the acceptance
probability is corrected by the ratio of the proposal densities in the two
directions.
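
As a minimal standalone sketch of that acceptance rule, assuming the target and proposal densities have already been evaluated, the hypothetical helpers below (not part of RooStats) compute the Metropolis-Hastings ratio and the accept/reject decision. The factor q(x | x') / q(x' | x) is what keeps detailed balance for a non-symmetric proposal.

```cpp
#include <algorithm>
#include <cassert>

// Hypothetical helper, not part of RooStats. Computes the Metropolis-Hastings ratio
//   r = [p(x') q(x | x')] / [p(x) q(x' | x)]
// where p is the unnormalised target density and q(a | b) is the density of
// proposing point a when the chain is currently at point b.
double HastingsRatio(double pCurrent, double pProposed,
                     double qCurrentGivenProposed, double qProposedGivenCurrent)
{
   assert(pCurrent > 0.0 && qProposedGivenCurrent > 0.0);
   return (pProposed * qCurrentGivenProposed) / (pCurrent * qProposedGivenCurrent);
}

// Hypothetical helper: accept the proposal with probability min(1, r), given a
// uniform random number u in (0, 1) drawn by the caller.
bool AcceptStep(double r, double u)
{
   return u < std::min(1.0, r);
}
```

For a symmetric proposal the two q values are equal and cancel, leaving the plain likelihood ratio p(x') / p(x).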

The "likelihood" function that is sampled when deciding which steps to take in
the chain has a deliberately generic implementation. The user can create any
RooAbsReal based on the parameters and pass it to a MetropolisHastings object
with the method SetFunction(RooAbsReal&). Be sure to tell MetropolisHastings
whether your RooAbsReal is on a regular or a log scale, and whether it is
positive or negative, so that it uses the correct acceptance logic when
sampling your RooAbsReal. For example, a common use is to sample from a
-log(Likelihood) distribution (NLL), for which the appropriate configuration
calls are SetType(MetropolisHastings::kLog) and
SetSign(MetropolisHastings::kNegative). For a traditional likelihood function,
call SetType(MetropolisHastings::kRegular) and
SetSign(MetropolisHastings::kPositive). You must set these type and sign
flags: if either is unset, ConstructChain() reports an error and returns
nullptr instead of constructing a MarkovChain.
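
The standalone sketch below illustrates why the type and sign flags matter, assuming a symmetric proposal. The enums and function names are hypothetical stand-ins, not the RooStats API; the logic is modelled on how MetropolisHastings compares its step statistic with a uniform random number. On the regular scale the statistic is the likelihood ratio, while on the log scale it is the negative log of that ratio, so an NLL and a log-likelihood enter with opposite signs.

```cpp
#include <cassert>
#include <cmath>

// Hypothetical stand-ins for the type and sign flags; these are not the
// RooStats enums.
enum class Scale { Regular, Log };
enum class Sign { Positive, Negative };

// Step statistic "a" for a symmetric proposal, given the function values at
// the current point x and the proposed point x'.
//   Regular scale: a = f(x') / f(x), the likelihood ratio.
//   Log scale:     a = -log[L(x') / L(x)], where f is log L (Positive)
//                  or -log L, i.e. an NLL (Negative).
double StepStatistic(Scale scale, Sign sign, double fCurrent, double fProposed)
{
   if (scale == Scale::Log)
      return (sign == Sign::Positive) ? fCurrent - fProposed : fProposed - fCurrent;
   return fProposed / fCurrent; // the ratio is positive for either sign
}

// Accept/reject rule for a uniform draw u in (0, 1). Both branches accept
// with probability min(1, L(x') / L(x)).
bool ShouldTakeStepSketch(Scale scale, double a, double u)
{
   assert(u > 0.0 && u < 1.0);
   if (scale == Scale::Log)
      return a <= 0.0 || -std::log(u) >= a;
   return a >= 1.0 || u < a;
}
```

With the same uniform draw u, the regular-scale test u < L(x')/L(x) and the log-scale test -log(u) >= -log[L(x')/L(x)] give the same decision except in the probability-zero case of exact equality.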
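
Also note that ConstructChain() randomizes the parameter values uniformly over
their ranges before construction of the MarkovChain begins, and repeats this
(up to 1000 times) until the function can be evaluated at the starting point:
without an evaluation error on the log scale, or to a non-zero value on the
regular scale.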
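
A rough standalone sketch of this starting-point search follows. The Param struct, FindStartingPoint, and the std::mt19937 generator are hypothetical stand-ins for RooRealVar and RooRandom, and only the regular-scale criterion (a non-zero function value) is modelled; the 1000-attempt default mirrors the limit used in ConstructChain().

```cpp
#include <cassert>
#include <random>
#include <vector>

// Hypothetical parameter description (not RooRealVar): a range and a value.
struct Param {
   double min;
   double max;
   double value;
};

// Draw every parameter uniformly within [min, max), retrying until `density`
// returns a non-zero value or maxTries attempts have been made. Returns true
// if a usable starting point was found.
template <class Density>
bool FindStartingPoint(std::vector<Param> &params, Density density,
                       std::mt19937 &rng, int maxTries = 1000)
{
   for (int i = 0; i < maxTries; ++i) {
      for (Param &p : params) {
         assert(p.min < p.max);
         std::uniform_real_distribution<double> dist(p.min, p.max);
         p.value = dist(rng);
      }
      if (density(params) != 0.0)
         return true;
   }
   return false;
}
```

If no usable point is found, this sketch returns false; ConstructChain() likewise prints an error in that case, but then continues from the last point it tried.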

*/
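
#include "RooStats/MetropolisHastings.h"
#include "RooStats/MarkovChain.h"
#include "RooStats/RooStatsUtils.h"
#include "RooNumber.h"
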
#include "Rtypes.h"
#include "RooRealVar.h"
#include "RooGlobalFunc.h"
#include "RooDataSet.h"
#include "RooArgSet.h"
#include "RooArgList.h"
#include "RooMsgService.h"
#include "RooRandom.h"
#include "TMath.h"

#include <cmath>

using namespace RooFit;
using namespace RooStats;
using std::endl;

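////////////////////////////////////////////////////////////////////////////////

MetropolisHastings::MetropolisHastings(RooAbsReal &function, const RooArgSet &paramsOfInterest,
                                       ProposalFunction &proposalFunction, Int_t numIters)
{
   fFunction = &function;
   SetParameters(paramsOfInterest);
   SetProposalFunction(proposalFunction);
   fNumIters = numIters;
}
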
////////////////////////////////////////////////////////////////////////////////

MarkovChain *MetropolisHastings::ConstructChain()
{
   if (fParameters.empty() || !fPropFunc || !fFunction) {
      coutE(Eval) << "Critical members uninitialized: parameters, proposal "
                  << "function, or (log) likelihood function" << std::endl;
      return nullptr;
   }
   if (fSign == kSignUnset || fType == kTypeUnset) {
      coutE(Eval) << "Please set type and sign of your function using "
                  << "MetropolisHastings::SetType() and MetropolisHastings::SetSign()" << std::endl;
      return nullptr;
   }
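
   RooArgSet x;
   RooArgSet xPrime;
   x.addClone(fParameters);
   RandomizeCollection(x);
   xPrime.addClone(fParameters);
   RandomizeCollection(xPrime);

   MarkovChain *chain = new MarkovChain();
   // only the parameters in fChainParams (e.g. the POI) will be added to the chain
   chain->SetParameters(fChainParams);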

   Int_t weight = 0;
   double xL = 0.0;
   double xPrimeL = 0.0;
   double a = 0.0;

   // ibucur: i think the user should have the possibility to display all the message
   // levels should they want to; maybe a setPrintLevel would be appropriate
   // (maybe for the other classes that use this approach as well)?
   RooFit::MsgLevel oldMsgLevel = RooMsgService::instance().globalKillBelow();
   RooMsgService::instance().setGlobalKillBelow(RooFit::PROGRESS);

   // We will need to check if log-likelihood evaluation left an error status.
   // Now using faster eval error logging with CountErrors.
   if (fType == kLog) {
      RooAbsReal::setEvalErrorLoggingMode(RooAbsReal::CountErrors);
      //N.B: need to clear the count in case of previous errors !
      // the clear needs also to be done after calling setEvalErrorLoggingMode
      RooAbsReal::clearEvalErrorLog();
   }
123
124 bool hadEvalError = true;
125
126 Int_t i = 0;
127 // get a good starting point for x
128 // for fType == kLog, this means that fFunction->getVal() did not cause
129 // an eval error
130 // for fType == kRegular this means fFunction->getVal() != 0
131 //
132 // kbelasco: i < 1000 is sort of arbitrary, but way higher than the number of
133 // steps we should have to take for any reasonable (log) likelihood function
134 while (i < 1000 && hadEvalError) {
137 xL = fFunction->getVal();
138
139 if (fType == kLog) {
140 if (RooAbsReal::numEvalErrors() > 0) {
142 hadEvalError = true;
143 } else
144 hadEvalError = false;
145 } else if (fType == kRegular) {
146 if (xL == 0.0) {
147 hadEvalError = true;
148 } else {
149 hadEvalError = false;
150 }
151 } else {
152 // for now the only 2 types are kLog and kRegular (won't get here)
153 hadEvalError = false;
154 }
155 ++i;
156 }

   if (hadEvalError) {
      coutE(Eval) << "Problem finding a good starting point in "
                  << "MetropolisHastings::ConstructChain() " << std::endl;
   }

   ooccoutP((TObject *)nullptr, Generation) << "Metropolis-Hastings progress: ";

   // do main loop
   for (i = 0; i < fNumIters; i++) {
      // reset error handling flag
      hadEvalError = false;

      // print a dot every 1% of the chain construction; the guard avoids a
      // modulo by zero when fNumIters < 100
      if (fNumIters >= 100 && i % (fNumIters / 100) == 0)
         ooccoutP((TObject *)nullptr, Generation) << ".";

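      // propose a new point xPrime based on the current point x
      fPropFunc->Propose(xPrime, x);

      // evaluate the function at xPrime
      RooStats::SetParameters(&xPrime, &fParameters);
      xPrimeL = fFunction->getVal();

      // check if log-likelihood for xprime had an error status
      if (fFunction->numEvalErrors() > 0 && fType == kLog) {
         xPrimeL = RooNumber::infinity();
         fFunction->clearEvalErrorLog();
         hadEvalError = true;
      }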

      // why evaluate the last point again, can't we cache it?
      // kbelasco: commenting out lines below to add/test caching support
      //RooStats::SetParameters(&x, &fParameters);
      //xL = fFunction->getVal();

      // a is -log[L(xPrime) / L(x)] on the log scale and L(xPrime) / L(x) on
      // the regular scale
      if (fType == kLog) {
         if (fSign == kPositive) {
            a = xL - xPrimeL;
         } else {
            a = xPrimeL - xL;
         }
      } else {
         a = xPrimeL / xL;
      }
      //a = xL / xPrimeL;
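
      // for a non-symmetric proposal, correct by the ratio of proposal
      // densities (the Hastings factor) to maintain detailed balance
      if (!fPropFunc->IsSymmetric(xPrime, x)) {
         double xPD = fPropFunc->GetProposalDensity(x, xPrime);
         double xPrimePD = fPropFunc->GetProposalDensity(xPrime, x);
         if (fType == kRegular) {
            a *= xPD / xPrimePD;
         } else {
            a += std::log(xPrimePD) - std::log(xPD);
         }
      }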

      if (!hadEvalError && ShouldTakeStep(a)) {
         // go to the proposed point xPrime

         // add the current point with the current weight
         if (weight != 0)
            chain->Add(x, CalcNLL(xL), (double)weight);

         // reset the weight and go to xPrime
         weight = 1;
         RooStats::SetParameters(&xPrime, &x);
         xL = xPrimeL;
      } else {
         // stay at the current point
         weight++;
      }
   }

   // make sure to add the last point
   if (weight != 0)
      chain->Add(x, CalcNLL(xL), (double)weight);
   ooccoutP((TObject *)nullptr, Generation) << std::endl;

   RooMsgService::instance().setGlobalKillBelow(oldMsgLevel);

   // chain->Size() counts the distinct points stored, which approximates the
   // number of accepted proposals
   Int_t numAccepted = chain->Size();
   coutI(Eval) << "Proposal acceptance rate: " << numAccepted / (Float_t)fNumIters * 100 << "%" << std::endl;
   coutI(Eval) << "Number of steps in chain: " << numAccepted << std::endl;

   //TFile chainDataFile("chainData.root", "recreate");
   //chain->GetDataSet()->Write();
   //chainDataFile.Close();

   return chain;
}

////////////////////////////////////////////////////////////////////////////////

bool MetropolisHastings::ShouldTakeStep(double a)
{
   if ((fType == kLog && a <= 0.0) || (fType == kRegular && a >= 1.0)) {
      // The proposed point has a likelihood at least as high as the
      // current point, so we should go there
      return true;
   } else {
      // draw a uniform random number (taking its log if fType == kLog) to
      // decide whether to go to xPrime or stay at x
      //double rand = fGen.Uniform(1.0);
      double rand = RooRandom::uniform();
      if (fType == kLog) {
         rand = std::log(rand);
         // kbelasco: should this be changed to just (-rand > a) for logical
         // consistency with below test when fType == kRegular?
         if (-1.0 * rand >= a) {
            // we chose to go to the new proposed point
            // even though it has a lower likelihood than the current one
            return true;
         }
      } else {
         // fType must be kRegular
         // kbelasco: ensure that we never visit a point where PDF == 0
         //if (rand <= a)
         if (rand < a) {
            // we chose to go to the new proposed point
            // even though it has a lower likelihood than the current one
            return true;
         }
      }
      return false;
   }
}

////////////////////////////////////////////////////////////////////////////////

double MetropolisHastings::CalcNLL(double xL)
{
   // convert the function value to a negative log-likelihood
   if (fType == kLog) {
      if (fSign == kNegative) {
         return xL;
      } else {
         return -xL;
      }
   } else {
      if (fSign == kPositive) {
         return -1.0 * std::log(xL);
      } else {
         return -1.0 * std::log(-xL);
      }
   }
}