#ifndef ROOSTATS_ToyMCSampler
#define ROOSTATS_ToyMCSampler
#ifndef ROOT_Rtypes
#include "Rtypes.h"
#endif
#include <iostream>
#include <sstream>
#include <string>
#include <vector>
#include "RooStats/TestStatSampler.h"
#include "RooStats/SamplingDistribution.h"
#include "RooStats/TestStatistic.h"
#include "RooStats/RooStatsUtils.h"
#ifndef __CINT__
#include "RooGlobalFunc.h"
#endif
#include "RooWorkspace.h"
#include "RooMsgService.h"
#include "RooAbsPdf.h"
#include "TRandom.h"
#include "RooDataSet.h"
namespace RooStats {
// ToyMCSampler: a TestStatSampler that builds the sampling distribution of a
// test statistic by generating toy datasets from a pdf held in an internal
// RooWorkspace and evaluating the test statistic on each toy.
class ToyMCSampler : public TestStatSampler {

public:

   // Build a sampler that evaluates the test statistic `ts` on toy datasets.
   //
   // `ts` is not owned and must outlive the sampler. The sampler creates and
   // owns an empty RooWorkspace (filled by SetPdf/SetData) and a TRandom.
   // Every data member is initialized here. Previously fSize, fNtoys and fNevents
   // were left indeterminate until their setters were called. The defaults
   // (test size 0.05, 1000 toys, 1 event per toy) only make an unconfigured
   // sampler well-defined; callers normally override them via the setters.
   ToyMCSampler(TestStatistic &ts) :
      fSize(0.05),
      fWS(new RooWorkspace()),
      fOwnsWorkspace(true),
      fPdfName(""),
      fDataName(""),
      fPOI(0),
      fNuisParams(0),
      fObservables(0),
      fOwnsObservables(kFALSE),
      fTestStat(&ts),
      fNtoys(1000),
      fNevents(1),
      fExtended(kTRUE),
      fRand(new TRandom()),
      fVarName(ts.GetVarName()),
      fCounter(0),
      fLastDataSet(0)
   {}

   // Releases everything this sampler owns: the workspace (if owned), the
   // random generator, the last generated toy dataset, and the observables
   // set if it was auto-determined in GenerateToyData.
   virtual ~ToyMCSampler() {
      if (fOwnsWorkspace) delete fWS;
      if (fOwnsObservables) delete fObservables;
      delete fRand;
      delete fLastDataSet;
   }

   // Generate `additionalMC` more toys and merge them into `last`.
   //
   // If `last` is null, the freshly generated distribution is returned
   // (caller owns it). Otherwise the new samples are added to `last`, the
   // temporary distribution is deleted, and `last` is returned. The configured
   // number of toys is restored afterwards.
   virtual SamplingDistribution* AppendSamplingDistribution(RooArgSet& allParameters,
                                                            SamplingDistribution* last,
                                                            Int_t additionalMC) {
      const Int_t savedNtoys = fNtoys;
      fNtoys = additionalMC;
      SamplingDistribution* newSamples = GetSamplingDistribution(allParameters);
      fNtoys = savedNtoys;

      if (!last) return newSamples;
      // Guard against a (possibly overridden) GetSamplingDistribution that
      // returns null, instead of dereferencing it.
      if (newSamples) {
         last->Add(newSamples);
         delete newSamples;
      }
      return last;
   }

   // Generate fNtoys toy datasets, evaluate the test statistic on each, and
   // return a newly allocated SamplingDistribution (caller owns it).
   //
   // RooFit messages below ERROR are suppressed during the loop. The previous
   // global kill level is restored before returning; the original left it
   // permanently at ERROR. The last generated toy is kept in fLastDataSet.
   // If toy generation fails, the loop stops early and the values collected
   // so far are returned.
   virtual SamplingDistribution* GetSamplingDistribution(RooArgSet& allParameters) {
      std::vector<Double_t> testStatVec;
      if (fNtoys > 0) testStatVec.reserve(fNtoys);

      const RooFit::MsgLevel savedLevel = RooMsgService::instance().globalKillBelow();
      RooMsgService::instance().setGlobalKillBelow(RooFit::ERROR);

      for (Int_t i = 0; i < fNtoys; ++i) {
         // GenerateToyData returns the RooDataSet produced by RooAbsPdf::generate.
         RooDataSet* toydata = static_cast<RooDataSet*>(GenerateToyData(allParameters));
         if (!toydata) {
            std::cout << "ToyMCSampler: failed to generate toy data, stopping after "
                      << i << " toys" << std::endl;
            break;
         }
         testStatVec.push_back(fTestStat->Evaluate(*toydata, allParameters));
         delete fLastDataSet;
         fLastDataSet = toydata;
      }

      RooMsgService::instance().setGlobalKillBelow(savedLevel);

      return new SamplingDistribution("temp",
                                      "Sampling Distribution of Test Statistic",
                                      testStatVec, fVarName);
   }

   // Generate one toy dataset from the workspace pdf named fPdfName.
   //
   // If no observables were set, they are derived once from the pdf's
   // variables, minus constant parameters, parameters of interest and
   // nuisance parameters, and cached (owned by this sampler). The pdf's
   // parameters are set from `allParameters` before generating. For extended
   // sampling, the event count is Poisson-distributed around the pdf's
   // expected events, or around fNevents if that is not positive.
   // Returns a new dataset owned by the caller, or null if the pdf is missing.
   virtual RooAbsData* GenerateToyData(RooArgSet& allParameters) const {
      RooAbsPdf* pdf = fWS->pdf(fPdfName);
      if (!pdf) {
         std::cout << "ToyMCSampler: pdf '" << fPdfName
                   << "' not found in workspace, cannot generate toy data" << std::endl;
         return 0;
      }

      if (!fObservables) {
         std::cout << "Observables not specified in ToyMCSampler, will try to determine. "
                   << "Will ignore all constant parameters, parameters of interest, and nuisance parameters." << std::endl;
         RooArgSet* observables = pdf->getVariables();
         RemoveConstantParameters(observables);
         if (fPOI) observables->remove(*fPOI, kFALSE, kTRUE);
         if (fNuisParams) observables->remove(*fNuisParams, kFALSE, kTRUE);
         std::cout << "will use the following as observables when generating data" << std::endl;
         observables->Print();
         fObservables = observables;
         // getVariables() hands us a new set; remember to delete it later.
         fOwnsObservables = kTRUE;
      }

      Int_t nEvents = fNevents;
      if (fExtended) {
         const Double_t expected = pdf->expectedEvents(*fObservables);
         nEvents = (expected > 0) ? fRand->Poisson(expected) : fRand->Poisson(fNevents);
      }

      // Set the pdf's parameters from allParameters; the RooStats:: qualifier
      // avoids the member SetParameters(RooArgSet&) hiding the free function.
      RooArgSet* parameters = pdf->getParameters(fObservables);
      RooStats::SetParameters(&allParameters, parameters);

      const RooFit::MsgLevel level = RooMsgService::instance().globalKillBelow();
      RooMsgService::instance().setGlobalKillBelow(RooFit::ERROR);
      RooAbsData* data = pdf->generate(*fObservables, nEvents);
      RooMsgService::instance().setGlobalKillBelow(level);

      delete parameters;
      return data;
   }

   // Return a unique name of the form "SamplingDist_<n>", incrementing the
   // internal counter. The argument set is unused.
   std::string MakeName(RooArgSet&) {
      std::ostringstream name;
      name << "SamplingDist_" << fCounter++;
      return name.str();
   }

   // Evaluate the configured test statistic on `data`.
   virtual Double_t EvaluateTestStatistic(RooAbsData& data, RooArgSet& allParameters) {
      return fTestStat->Evaluate(data, allParameters);
   }

   // The RooAbsArg representing the configured test statistic.
   virtual const RooAbsArg* GetTestStatistic() const {
      return fTestStat->GetTestStatistic();
   }

   // Confidence level, i.e. 1 - test size.
   virtual Double_t ConfidenceLevel() const { return 1. - fSize; }

   // No-op: this sampler needs no extra initialization.
   virtual void Initialize(RooAbsArg&,
                           RooArgSet&,
                           RooArgSet&) {}

   virtual void SetNToys(const Int_t ntoy) { fNtoys = ntoy; }

   // Events per toy. Used as-is for non-extended sampling, and as the
   // Poisson mean when the pdf reports no positive expected event count.
   virtual void SetNEventsPerToy(const Int_t nevents) { fNevents = nevents; }

   virtual void SetExtended(const Bool_t isExtended) { fExtended = isExtended; }

   // Import `data` into the internal workspace and remember its name.
   // The name is copied, so it stays valid after `data` is destroyed.
   virtual void SetData(RooAbsData& data) {
      fWS->import(data);
      fDataName = data.GetName();
      fWS->Print();
   }

   // Import `pdf` into the internal workspace and remember its name.
   // The name is copied, so it stays valid after `pdf` is destroyed.
   virtual void SetPdf(RooAbsPdf& pdf) {
      fWS->import(pdf);
      fPdfName = pdf.GetName();
   }

   // Select, by name, a dataset / pdf that is already in the workspace.
   // The name is copied.
   virtual void SetData(const char* name) { fDataName = name; }
   virtual void SetPdf(const char* name) { fPdfName = name; }

   // The following sets are referenced, not owned; they must outlive the sampler.
   virtual void SetParameters(RooArgSet& set) { fPOI = &set; }
   virtual void SetNuisanceParameters(RooArgSet& set) { fNuisParams = &set; }

   // Use `set` as the observables (not owned). Releases any observables set
   // that was previously auto-determined by GenerateToyData.
   virtual void SetObservables(RooArgSet& set) {
      if (fOwnsObservables) delete fObservables;
      fOwnsObservables = kFALSE;
      fObservables = &set;
   }

   virtual void SetTestSize(Double_t size) { fSize = size; }
   virtual void SetConfidenceLevel(Double_t cl) { fSize = 1. - cl; }

   // No-op: the test statistic is fixed at construction.
   virtual void SetTestStatistic(RooAbsArg&) const {}

private:
   // Non-copyable: the class owns fWS, fRand, fLastDataSet (and possibly
   // fObservables) through raw pointers, so an implicit copy would
   // double-delete them. Declared but intentionally not defined.
   ToyMCSampler(const ToyMCSampler&);
   ToyMCSampler& operator=(const ToyMCSampler&);

   Double_t fSize;               // test size (1 - confidence level)
   RooWorkspace* fWS;            // workspace holding the pdf / data
   Bool_t fOwnsWorkspace;        // delete fWS in the destructor
   TString fPdfName;             // name of the pdf in fWS (owned copy)
   TString fDataName;            // name of the data in fWS (owned copy)
   RooArgSet* fPOI;              // parameters of interest (not owned)
   RooArgSet* fNuisParams;       // nuisance parameters (not owned)
   mutable RooArgSet* fObservables;   // observables used for generation
   mutable Bool_t fOwnsObservables;   // fObservables was auto-created and must be deleted
   TestStatistic* fTestStat;     // test statistic (not owned)
   Int_t fNtoys;                 // number of toys per GetSamplingDistribution call
   Int_t fNevents;               // events per toy (or Poisson mean fallback)
   Bool_t fExtended;             // Poisson-fluctuate the number of events
   TRandom* fRand;               // owned random generator
   TString fVarName;             // variable name of the test statistic
   Int_t fCounter;               // counter for MakeName
   RooDataSet* fLastDataSet;     // last generated toy (owned)

protected:
   ClassDef(ToyMCSampler,2)
};
}
#endif