Logo ROOT   6.12/07
Reference Guide
MetropolisHastings.cxx
Go to the documentation of this file.
1 // @(#)root/roostats:$Id$
2 // Authors: Kevin Belasco 17/06/2009
3 // Authors: Kyle Cranmer 17/06/2009
4 /*************************************************************************
5  * Copyright (C) 1995-2008, Rene Brun and Fons Rademakers. *
6  * All rights reserved. *
7  * *
8  * For the licensing terms see $ROOTSYS/LICENSE. *
9  * For the list of contributors see $ROOTSYS/README/CREDITS. *
10  *************************************************************************/
11 
12 /** \class RooStats::MetropolisHastings
13  \ingroup Roostats
14 
15 This class uses the Metropolis-Hastings algorithm to construct a Markov Chain
16 of data points using Monte Carlo. In the main algorithm, new points in the
17 parameter space are proposed and then visited based on their relative
18 likelihoods. This class can use any implementation of the ProposalFunction,
19 including non-symmetric proposal functions, to propose parameter points and
20 still maintain detailed balance when constructing the chain.
21 
22 
23 
24 The "Likelihood" function that is sampled when deciding what steps to take in
25 the chain has been given a very generic implementation. The user can create
26 any RooAbsReal based on the parameters and pass it to a MetropolisHastings
27 object with the method SetFunction(RooAbsReal&). Be sure to tell
28 MetropolisHastings whether your RooAbsReal is on a (+/-) regular or log scale,
29 so that it knows what logic to use when sampling your RooAbsReal. For example,
30 a common use is to sample from a -log(Likelihood) distribution (NLL), for which
31 the appropriate configuration calls are SetType(MetropolisHastings::kLog);
32 SetSign(MetropolisHastings::kNegative);
33 If you're using a traditional likelihood function:
34 SetType(MetropolisHastings::kRegular); SetSign(MetropolisHastings::kPositive);
35 You must set these type and sign flags or MetropolisHastings will not construct
36 a MarkovChain.
37 
38 Also note that in ConstructChain(), the values of the variables are randomized
39 uniformly over their intervals before construction of the MarkovChain begins.
40 
41 */
42 
44 
45 #include "RooStats/MarkovChain.h"
46 #include "RooStats/MCMCInterval.h"
47 #include "RooStats/RooStatsUtils.h"
49 
50 #include "Rtypes.h"
51 #include "RooRealVar.h"
52 #include "RooNLLVar.h"
53 #include "RooGlobalFunc.h"
54 #include "RooDataSet.h"
55 #include "RooArgSet.h"
56 #include "RooArgList.h"
57 #include "RooMsgService.h"
58 #include "RooRandom.h"
59 #include "TH1.h"
60 #include "TMath.h"
61 #include "TFile.h"
62 
64 
65 using namespace RooFit;
66 using namespace RooStats;
67 using namespace std;
68 
69 ////////////////////////////////////////////////////////////////////////////////
70 
71 MetropolisHastings::MetropolisHastings()
72 {
73  // default constructor
74  fFunction = NULL;
75  fPropFunc = NULL;
76  fNumIters = 0;
77  fNumBurnInSteps = 0;
78  fSign = kSignUnset;
79  fType = kTypeUnset;
80 }
81 
82 ////////////////////////////////////////////////////////////////////////////////
83 
84 MetropolisHastings::MetropolisHastings(RooAbsReal& function, const RooArgSet& paramsOfInterest,
85  ProposalFunction& proposalFunction, Int_t numIters)
86 {
87  fFunction = &function;
88  SetParameters(paramsOfInterest);
89  SetProposalFunction(proposalFunction);
90  fNumIters = numIters;
91  fNumBurnInSteps = 0;
92  fSign = kSignUnset;
93  fType = kTypeUnset;
94 }
95 
96 ////////////////////////////////////////////////////////////////////////////////
97 
98 MarkovChain* MetropolisHastings::ConstructChain()
99 {
100  if (fParameters.getSize() == 0 || !fPropFunc || !fFunction) {
101  coutE(Eval) << "Critical members unintialized: parameters, proposal " <<
102  " function, or (log) likelihood function" << endl;
103  return NULL;
104  }
105  if (fSign == kSignUnset || fType == kTypeUnset) {
106  coutE(Eval) << "Please set type and sign of your function using "
107  << "MetropolisHastings::SetType() and MetropolisHastings::SetSign()" <<
108  endl;
109  return NULL;
110  }
111 
112  if (fChainParams.getSize() == 0) fChainParams.add(fParameters);
113 
114  RooArgSet x;
115  RooArgSet xPrime;
116  x.addClone(fParameters);
118  xPrime.addClone(fParameters);
119  RandomizeCollection(xPrime);
120 
121  MarkovChain* chain = new MarkovChain();
122  // only the POI will be added to the chain
123  chain->SetParameters(fChainParams);
124 
125  Int_t weight = 0;
126  Double_t xL = 0.0, xPrimeL = 0.0, a = 0.0;
127 
128  // ibucur: i think the user should have the possibility to display all the message
129  // levels should they want to; maybe a setPrintLevel would be appropriate
130  // (maybe for the other classes that use this approach as well)?
133 
134  // We will need to check if log-likelihood evaluation left an error status.
135  // Now using faster eval error logging with CountErrors.
136  if (fType == kLog) {
138  //N.B: need to clear the count in case of previous errors !
139  // the clear needs also to be done after calling setEvalErrorLoggingMode
141  }
142 
143  bool hadEvalError = true;
144 
145  Int_t i = 0;
146  // get a good starting point for x
147  // for fType == kLog, this means that fFunction->getVal() did not cause
148  // an eval error
149  // for fType == kRegular this means fFunction->getVal() != 0
150  //
151  // kbelasco: i < 1000 is sort of arbitrary, but way higher than the number of
152  // steps we should have to take for any reasonable (log) likelihood function
153  while (i < 1000 && hadEvalError) {
155  RooStats::SetParameters(&x, &fParameters);
156  xL = fFunction->getVal();
157 
158  if (fType == kLog) {
159  if (RooAbsReal::numEvalErrors() > 0) {
161  hadEvalError = true;
162  } else
163  hadEvalError = false;
164  } else if (fType == kRegular) {
165  if (xL == 0.0)
166  hadEvalError = true;
167  else
168  hadEvalError = false;
169  } else
170  // for now the only 2 types are kLog and kRegular (won't get here)
171  hadEvalError = false;
172  }
173 
174  if(hadEvalError) {
175  coutE(Eval) << "Problem finding a good starting point in " <<
176  "MetropolisHastings::ConstructChain() " << endl;
177  }
178 
179 
180  ooccoutP((TObject *)0, Generation) << "Metropolis-Hastings progress: ";
181 
182  // do main loop
183  for (i = 0; i < fNumIters; i++) {
184  // reset error handling flag
185  hadEvalError = false;
186 
187  // print a dot every 1% of the chain construction
188  if (i % (fNumIters / 100) == 0) ooccoutP((TObject*)0, Generation) << ".";
189 
190  fPropFunc->Propose(xPrime, x);
191 
192  RooStats::SetParameters(&xPrime, &fParameters);
193  xPrimeL = fFunction->getVal();
194 
195  // check if log-likelihood for xprime had an error status
196  if (fFunction->numEvalErrors() > 0 && fType == kLog) {
197  xPrimeL = RooNumber::infinity();
198  fFunction->clearEvalErrorLog();
199  hadEvalError = true;
200  }
201 
202  // why evaluate the last point again, can't we cache it?
203  // kbelasco: commenting out lines below to add/test caching support
204  //RooStats::SetParameters(&x, &fParameters);
205  //xL = fFunction->getVal();
206 
207  if (fType == kLog) {
208  if (fSign == kPositive)
209  a = xL - xPrimeL;
210  else
211  a = xPrimeL - xL;
212  }
213  else
214  a = xPrimeL / xL;
215  //a = xL / xPrimeL;
216 
217  if (!hadEvalError && !fPropFunc->IsSymmetric(xPrime, x)) {
218  Double_t xPrimePD = fPropFunc->GetProposalDensity(xPrime, x);
219  Double_t xPD = fPropFunc->GetProposalDensity(x, xPrime);
220  if (fType == kRegular)
221  a *= xPD / xPrimePD;
222  else
223  a += TMath::Log(xPrimePD) - TMath::Log(xPD);
224  }
225 
226  if (!hadEvalError && ShouldTakeStep(a)) {
227  // go to the proposed point xPrime
228 
229  // add the current point with the current weight
230  if (weight != 0.0)
231  chain->Add(x, CalcNLL(xL), (Double_t)weight);
232 
233  // reset the weight and go to xPrime
234  weight = 1;
235  RooStats::SetParameters(&xPrime, &x);
236  xL = xPrimeL;
237  } else {
238  // stay at the current point
239  weight++;
240  }
241  }
242 
243  // make sure to add the last point
244  if (weight != 0.0)
245  chain->Add(x, CalcNLL(xL), (Double_t)weight);
246  ooccoutP((TObject *)0, Generation) << endl;
247 
249 
250  Int_t numAccepted = chain->Size();
251  coutI(Eval) << "Proposal acceptance rate: " <<
252  numAccepted/(Float_t)fNumIters * 100 << "%" << endl;
253  coutI(Eval) << "Number of steps in chain: " << numAccepted << endl;
254 
255  //TFile chainDataFile("chainData.root", "recreate");
256  //chain->GetDataSet()->Write();
257  //chainDataFile.Close();
258 
259  return chain;
260 }
261 
262 ////////////////////////////////////////////////////////////////////////////////
263 
264 Bool_t MetropolisHastings::ShouldTakeStep(Double_t a)
265 {
266  if ((fType == kLog && a <= 0.0) || (fType == kRegular && a >= 1.0)) {
267  // The proposed point has a higher likelihood than the
268  // current point, so we should go there
269  return kTRUE;
270  }
271  else {
272  // generate numbers on a log distribution to decide
273  // whether to go to xPrime or stay at x
274  //Double_t rand = fGen.Uniform(1.0);
275  Double_t rand = RooRandom::uniform();
276  if (fType == kLog) {
277  rand = TMath::Log(rand);
278  // kbelasco: should this be changed to just (-rand > a) for logical
279  // consistency with below test when fType == kRegular?
280  if (-1.0 * rand >= a)
281  // we chose to go to the new proposed point
282  // even though it has a lower likelihood than the current one
283  return kTRUE;
284  } else {
285  // fType must be kRegular
286  // kbelasco: ensure that we never visit a point where PDF == 0
287  //if (rand <= a)
288  if (rand < a)
289  // we chose to go to the new proposed point
290  // even though it has a lower likelihood than the current one
291  return kTRUE;
292  }
293  return kFALSE;
294  }
295 }
296 
297 ////////////////////////////////////////////////////////////////////////////////
298 
299 Double_t MetropolisHastings::CalcNLL(Double_t xL)
300 {
301  if (fType == kLog) {
302  if (fSign == kNegative)
303  return xL;
304  else
305  return -xL;
306  } else {
307  if (fSign == kPositive)
308  return -1.0 * TMath::Log(xL);
309  else
310  return -1.0 * TMath::Log(-xL);
311  }
312 }
ProposalFunction is an interface for all proposal functions that would be used with a Markov Chain Mo...
#define coutE(a)
Definition: RooMsgService.h:34
Double_t Log(Double_t x)
Definition: TMath.h:648
float Float_t
Definition: RtypesCore.h:53
#define coutI(a)
Definition: RooMsgService.h:31
void SetParameters(const RooArgSet *desiredVals, RooArgSet *paramsToChange)
Definition: RooStatsUtils.h:58
RooFit::MsgLevel globalKillBelow() const
static void clearEvalErrorLog()
Clear the stack of evaluation error messages.
virtual Int_t Size() const
get the number of steps in the chain
Definition: MarkovChain.h:49
int Int_t
Definition: RtypesCore.h:41
bool Bool_t
Definition: RtypesCore.h:59
static RooMsgService & instance()
Return reference to singleton instance.
#define ooccoutP(o, a)
Definition: RooMsgService.h:52
STL namespace.
This class uses the Metropolis-Hastings algorithm to construct a Markov Chain of data points using Mo...
static void setEvalErrorLoggingMode(ErrorLoggingMode m)
Set evaluation error logging mode.
void SetParameters(TFitEditor::FuncParams_t &pars, TF1 *func)
Restore the parameters from pars into the function.
Definition: TFitEditor.cxx:287
Double_t x[n]
Definition: legend1.C:17
virtual void Add(RooArgSet &entry, Double_t nllValue, Double_t weight=1.0)
safely add an entry to the chain
virtual void addClone(const RooAbsCollection &col, Bool_t silent=kFALSE)
Add a collection of arguments to this collection by calling addOwned() for each element in the source...
Definition: RooArgSet.h:94
static Int_t numEvalErrors()
Return the number of logged evaluation errors since the last clearing.
auto * a
Definition: textangle.C:12
static Double_t infinity()
Return internal infinity representation.
Definition: RooNumber.cxx:49
void setGlobalKillBelow(RooFit::MsgLevel level)
void RandomizeCollection(RooAbsCollection &set, Bool_t randomizeConstants=kTRUE)
Definition: RooStatsUtils.h:99
const Bool_t kFALSE
Definition: RtypesCore.h:88
PyObject * fType
Stores the steps in a Markov Chain of points.
Definition: MarkovChain.h:30
Namespace for the RooStats classes.
Definition: Asimov.h:20
#define ClassImp(name)
Definition: Rtypes.h:359
static Double_t uniform(TRandom *generator=randomGenerator())
Return a number uniformly distributed from (0,1)
Definition: RooRandom.cxx:84
double Double_t
Definition: RtypesCore.h:55
RooAbsReal is the common abstract base class for objects that represent a real value and implements f...
Definition: RooAbsReal.h:53
virtual void SetParameters(RooArgSet &parameters)
set which of your parameters this chain should store
Definition: MarkovChain.cxx:77
Mother of all ROOT objects.
Definition: TObject.h:37
const Bool_t kTRUE
Definition: RtypesCore.h:87