ROOT  6.06/09
Reference Guide
MetropolisHastings.cxx
Go to the documentation of this file.
1 // @(#)root/roostats:$Id$
2 // Authors: Kevin Belasco 17/06/2009
3 // Authors: Kyle Cranmer 17/06/2009
4 /*************************************************************************
5  * Copyright (C) 1995-2008, Rene Brun and Fons Rademakers. *
6  * All rights reserved. *
7  * *
8  * For the licensing terms see $ROOTSYS/LICENSE. *
9  * For the list of contributors see $ROOTSYS/README/CREDITS. *
10  *************************************************************************/
11 
12 ////////////////////////////////////////////////////////////////////////////////
13 
14 
15 #ifndef RooStats_RooStatsUtils
16 #include "RooStats/RooStatsUtils.h"
17 #endif
18 #ifndef ROOT_Rtypes
19 #include "Rtypes.h"
20 #endif
21 #ifndef ROO_REAL_VAR
22 #include "RooRealVar.h"
23 #endif
24 #ifndef ROO_NLL_VAR
25 #include "RooNLLVar.h"
26 #endif
27 #ifndef ROO_GLOBAL_FUNC
28 #include "RooGlobalFunc.h"
29 #endif
30 #ifndef ROO_DATA_SET
31 #include "RooDataSet.h"
32 #endif
33 #ifndef ROO_ARG_SET
34 #include "RooArgSet.h"
35 #endif
36 #ifndef ROO_ARG_LIST
37 #include "RooArgList.h"
38 #endif
39 #ifndef ROO_MSG_SERVICE
40 #include "RooMsgService.h"
41 #endif
42 #ifndef ROO_RANDOM
43 #include "RooRandom.h"
44 #endif
45 #ifndef ROOT_TH1
46 #include "TH1.h"
47 #endif
48 #ifndef ROOT_TMath
49 #include "TMath.h"
50 #endif
51 #ifndef ROOT_TFile
52 #include "TFile.h"
53 #endif
54 #ifndef ROOSTATS_MetropolisHastings
56 #endif
57 #ifndef ROOSTATS_MarkovChain
58 #include "RooStats/MarkovChain.h"
59 #endif
60 #ifndef RooStats_MCMCInterval
61 #include "RooStats/MCMCInterval.h"
62 #endif
63 
65 
66 using namespace RooFit;
67 using namespace RooStats;
68 using namespace std;
69 
70 MetropolisHastings::MetropolisHastings()
71 {
72  // default constructor
73  fFunction = NULL;
74  fPropFunc = NULL;
75  fNumIters = 0;
76  fNumBurnInSteps = 0;
77  fSign = kSignUnset;
78  fType = kTypeUnset;
79 }
80 
81 MetropolisHastings::MetropolisHastings(RooAbsReal& function, const RooArgSet& paramsOfInterest,
82  ProposalFunction& proposalFunction, Int_t numIters)
83 {
84  fFunction = &function;
85  SetParameters(paramsOfInterest);
86  SetProposalFunction(proposalFunction);
87  fNumIters = numIters;
88  fNumBurnInSteps = 0;
89  fSign = kSignUnset;
90  fType = kTypeUnset;
91 }
92 
// Run the Metropolis-Hastings loop and return the resulting chain.
//
// Returns a MarkovChain allocated with `new` (handed back as a raw pointer),
// or NULL when the parameters, proposal function or function are missing, or
// when the sign/type of the function have not been set.
//
// NOTE(review): the embedded listing numbers skip 112, 126-127, 132, 135,
// 149, 155 and 243, so some statements of the original function appear to be
// missing from this view -- confirm against the real source before changing
// any logic here.
93 MarkovChain* MetropolisHastings::ConstructChain()
94 {
// Refuse to run without parameters, a proposal function and a function.
95  if (fParameters.getSize() == 0 || !fPropFunc || !fFunction) {
96  coutE(Eval) << "Critical members unintialized: parameters, proposal " <<
97  " function, or (log) likelihood function" << endl;
98  return NULL;
99  }
// The acceptance test below depends on both the sign and the type.
100  if (fSign == kSignUnset || fType == kTypeUnset) {
101  coutE(Eval) << "Please set type and sign of your function using "
102  << "MetropolisHastings::SetType() and MetropolisHastings::SetSign()" <<
103  endl;
104  return NULL;
105  }
106 
// When no chain parameters were chosen, store all parameters in the chain.
107  if (fChainParams.getSize() == 0) fChainParams.add(fParameters);
108 
// x is the current point, xPrime the proposed point; both are clones of
// fParameters.
109  RooArgSet x;
110  RooArgSet xPrime;
111  x.addClone(fParameters);
113  xPrime.addClone(fParameters);
114  RandomizeCollection(xPrime);
115 
116  MarkovChain* chain = new MarkovChain();
117  // only the POI will be added to the chain
118  chain->SetParameters(fChainParams);
119 
// weight counts how many iterations the chain has stayed at the current
// point; xL / xPrimeL cache the function value at x / xPrime.
120  Int_t weight = 0;
121  Double_t xL = 0.0, xPrimeL = 0.0, a = 0.0;
122 
123  // ibucur: i think the user should have the possibility to display all the message
124  // levels should they want to; maybe a setPrintLevel would be appropriate
125  // (maybe for the other classes that use this approach as well)?
128 
129  // We will need to check if log-likelihood evaluation left an error status.
130  // Now using faster eval error logging with CountErrors.
131  if (fType == kLog) {
133  //N.B: need to clear the count in case of previous errors !
134  // the clear needs also to be done after calling setEvalErrorLoggingMode
136  }
137 
138  bool hadEvalError = true;
139 
140  Int_t i = 0;
141  // get a good starting point for x
142  // for fType == kLog, this means that fFunction->getVal() did not cause
143  // an eval error
144  // for fType == kRegular this means fFunction->getVal() != 0
145  //
146  // kbelasco: i < 1000 is sort of arbitrary, but way higher than the number of
147  // steps we should have to take for any reasonable (log) likelihood function
// NOTE(review): no increment of i is visible inside this loop in this
// listing; if there really is none, the i < 1000 bound can never stop the
// loop -- confirm against the real source.
148  while (i < 1000 && hadEvalError) {
150  RooStats::SetParameters(&x, &fParameters);
151  xL = fFunction->getVal();
152 
153  if (fType == kLog) {
154  if (RooAbsReal::numEvalErrors() > 0) {
156  hadEvalError = true;
157  } else
158  hadEvalError = false;
159  } else if (fType == kRegular) {
160  if (xL == 0.0)
161  hadEvalError = true;
162  else
163  hadEvalError = false;
164  } else
165  // for now the only 2 types are kLog and kRegular (won't get here)
166  hadEvalError = false;
167  }
168 
// A failed search is only reported; the main loop below still runs.
169  if(hadEvalError) {
170  coutE(Eval) << "Problem finding a good starting point in " <<
171  "MetropolisHastings::ConstructChain() " << endl;
172  }
173 
174 
175  ooccoutP((TObject *)0, Generation) << "Metropolis-Hastings progress: ";
176 
177  // do main loop
178  for (i = 0; i < fNumIters; i++) {
179  // reset error handling flag
180  hadEvalError = false;
181 
182  // print a dot every 1% of the chain construction
// NOTE(review): fNumIters / 100 looks like integer division (fNumIters is
// assigned from an Int_t in the constructor); it would be 0 for
// fNumIters < 100, making this a modulo by zero -- confirm and guard.
183  if (i % (fNumIters / 100) == 0) ooccoutP((TObject*)0, Generation) << ".";
184 
// Draw a proposal xPrime from the current point x and evaluate the function
// there.
185  fPropFunc->Propose(xPrime, x);
186 
187  RooStats::SetParameters(&xPrime, &fParameters);
188  xPrimeL = fFunction->getVal();
189 
190  // check if log-likelihood for xprime had an error status
191  if (fFunction->numEvalErrors() > 0 && fType == kLog) {
192  xPrimeL = RooNumber::infinity();
193  fFunction->clearEvalErrorLog();
194  hadEvalError = true;
195  }
196 
197  // why evaluate the last point again, can't we cache it?
198  // kbelasco: commenting out lines below to add/test caching support
199  //RooStats::SetParameters(&x, &fParameters);
200  //xL = fFunction->getVal();
201 
// a is the acceptance quantity handed to ShouldTakeStep(): a difference of
// function values for kLog (ordered by fSign), a ratio otherwise.
202  if (fType == kLog) {
203  if (fSign == kPositive)
204  a = xL - xPrimeL;
205  else
206  a = xPrimeL - xL;
207  }
208  else
209  a = xPrimeL / xL;
210  //a = xL / xPrimeL;
211 
// Correct a with the proposal densities only when the proposal function
// reports that it is not symmetric between xPrime and x.
212  if (!hadEvalError && !fPropFunc->IsSymmetric(xPrime, x)) {
213  Double_t xPrimePD = fPropFunc->GetProposalDensity(xPrime, x);
214  Double_t xPD = fPropFunc->GetProposalDensity(x, xPrime);
215  if (fType == kRegular)
216  a *= xPD / xPrimePD;
217  else
218  a += TMath::Log(xPrimePD) - TMath::Log(xPD);
219  }
220 
// A proposal whose evaluation raised an error is never accepted.
221  if (!hadEvalError && ShouldTakeStep(a)) {
222  // go to the proposed point xPrime
223 
224  // add the current point with the current weight
// weight is an Int_t compared against a double literal; it is 0 only before
// the first accepted step.
225  if (weight != 0.0)
226  chain->Add(x, CalcNLL(xL), (Double_t)weight);
227 
228  // reset the weight and go to xPrime
229  weight = 1;
230  RooStats::SetParameters(&xPrime, &x);
231  xL = xPrimeL;
232  } else {
233  // stay at the current point
234  weight++;
235  }
236  }
237 
238  // make sure to add the last point
239  if (weight != 0.0)
240  chain->Add(x, CalcNLL(xL), (Double_t)weight);
241  ooccoutP((TObject *)0, Generation) << endl;
242 
244 
// The number of entries in the chain is used as the number of accepted
// proposals when reporting the acceptance rate.
245  Int_t numAccepted = chain->Size();
246  coutI(Eval) << "Proposal acceptance rate: " <<
247  numAccepted/(Float_t)fNumIters * 100 << "%" << endl;
248  coutI(Eval) << "Number of steps in chain: " << numAccepted << endl;
249 
250  //TFile chainDataFile("chainData.root", "recreate");
251  //chain->GetDataSet()->Write();
252  //chainDataFile.Close();
253 
254  return chain;
255 }
256 
257 Bool_t MetropolisHastings::ShouldTakeStep(Double_t a)
258 {
259  if ((fType == kLog && a <= 0.0) || (fType == kRegular && a >= 1.0)) {
260  // The proposed point has a higher likelihood than the
261  // current point, so we should go there
262  return kTRUE;
263  }
264  else {
265  // generate numbers on a log distribution to decide
266  // whether to go to xPrime or stay at x
267  //Double_t rand = fGen.Uniform(1.0);
268  Double_t rand = RooRandom::uniform();
269  if (fType == kLog) {
270  rand = TMath::Log(rand);
271  // kbelasco: should this be changed to just (-rand > a) for logical
272  // consistency with below test when fType == kRegular?
273  if (-1.0 * rand >= a)
274  // we chose to go to the new proposed point
275  // even though it has a lower likelihood than the current one
276  return kTRUE;
277  } else {
278  // fType must be kRegular
279  // kbelasco: ensure that we never visit a point where PDF == 0
280  //if (rand <= a)
281  if (rand < a)
282  // we chose to go to the new proposed point
283  // even though it has a lower likelihood than the current one
284  return kTRUE;
285  }
286  return kFALSE;
287  }
288 }
289 
290 Double_t MetropolisHastings::CalcNLL(Double_t xL)
291 {
292  if (fType == kLog) {
293  if (fSign == kNegative)
294  return xL;
295  else
296  return -xL;
297  } else {
298  if (fSign == kPositive)
299  return -1.0 * TMath::Log(xL);
300  else
301  return -1.0 * TMath::Log(-xL);
302  }
303 }
ProposalFunction is an interface for all proposal functions that would be used with a Markov Chain Mo...
#define coutE(a)
Definition: RooMsgService.h:35
Double_t Log(Double_t x)
Definition: TMath.h:526
float Float_t
Definition: RtypesCore.h:53
#define coutI(a)
Definition: RooMsgService.h:32
void SetParameters(const RooArgSet *desiredVals, RooArgSet *paramsToChange)
Definition: RooStatsUtils.h:69
ClassImp(RooStats::MetropolisHastings)
static void clearEvalErrorLog()
Clear the stack of evaluation error messages.
int Int_t
Definition: RtypesCore.h:41
bool Bool_t
Definition: RtypesCore.h:59
TArc * a
Definition: textangle.C:12
const Bool_t kFALSE
Definition: Rtypes.h:92
static RooMsgService & instance()
Return reference to singleton instance.
#define ooccoutP(o, a)
Definition: RooMsgService.h:53
STL namespace.
This class uses the Metropolis-Hastings algorithm to construct a Markov Chain of data points using Mo...
static void setEvalErrorLoggingMode(ErrorLoggingMode m)
Set evaluation error logging mode.
void SetParameters(TFitEditor::FuncParams_t &pars, TF1 *func)
Restore the parameters from pars into the function.
Definition: TFitEditor.cxx:287
Double_t x[n]
Definition: legend1.C:17
virtual void Add(RooArgSet &entry, Double_t nllValue, Double_t weight=1.0)
safely add an entry to the chain
RooFit::MsgLevel globalKillBelow() const
static Int_t numEvalErrors()
Return the number of logged evaluation errors since the last clearing.
static Double_t infinity()
Return internal infinity representation.
Definition: RooNumber.cxx:48
void setGlobalKillBelow(RooFit::MsgLevel level)
void RandomizeCollection(RooAbsCollection &set, Bool_t randomizeConstants=kTRUE)
PyObject * fType
Stores the steps in a Markov Chain of points.
Definition: MarkovChain.h:53
Namespace for the RooStats classes.
Definition: Asimov.h:20
static Double_t uniform(TRandom *generator=randomGenerator())
Return a number uniformly distributed from (0,1)
Definition: RooRandom.cxx:83
double Double_t
Definition: RtypesCore.h:55
RooAbsReal is the common abstract base class for objects that represent a real value and implements f...
Definition: RooAbsReal.h:53
virtual void SetParameters(RooArgSet &parameters)
set which of your parameters this chain should store
Definition: MarkovChain.cxx:90
Mother of all ROOT objects.
Definition: TObject.h:58
#define NULL
Definition: Rtypes.h:82
virtual Int_t Size() const
get the number of steps in the chain
Definition: MarkovChain.h:72
virtual RooAbsArg * addClone(const RooAbsArg &var, Bool_t silent=kFALSE)
Add clone of specified element to an owning set.
Definition: RooArgSet.cxx:475
const Bool_t kTRUE
Definition: Rtypes.h:91