//
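// RooProduct is a RooAbsReal implementation that represents the product
// of a given set of other RooAbsReal objects (RooAbsCategory objects may
// also be included; their integer index is then used as the factor).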
//
// END_HTML
#include "RooFit.h"
#include "Riostream.h"
#include "Riostream.h"
#include <math.h>
#include <vector>
#include <utility>
#include <memory>
#include <algorithm>
#include "RooProduct.h"
#include "RooNameReg.h"
#include "RooAbsReal.h"
#include "RooAbsCategory.h"
#include "RooErrorHandler.h"
#include "RooMsgService.h"
#include "RooTrace.h"
using namespace std ;
// ROOT class-implementation macro for RooProduct.
ClassImp(RooProduct)
;
// Result type of groupProductTerms(): a sequence of (variables, components)
// groups. Each entry pairs a set of integration variables with the list of
// real-valued components that depend on them. The RooArgSet/RooArgList
// pointers are heap-allocated and must be deleted by whoever owns the map
// (getPartIntList() does this). Presumably a class rather than a typedef so
// that it can be forward-declared in the header.
class RooProduct::ProdMap : public std::vector<std::pair<RooArgSet*,RooArgList*> > {} ;
// File-local helpers operating on ProdMap entries (defined at the end of this file).
namespace {
typedef RooProduct::ProdMap::iterator RPPMIter ;
// Find the first two entries in [i,end) whose component lists overlap;
// returns (end,end) if there is no such pair.
std::pair<RPPMIter,RPPMIter> findOverlap2nd(RPPMIter i, RPPMIter end) ;
// Debug printout of the entries in [i,end) as "vars -> components".
void dump_map(ostream& os, RPPMIter i, RPPMIter end) ;
}
// Default constructor (presumably for ROOT I/O). The component lists are
// default-constructed; iterators over them are allocated here and released
// in the destructor.
RooProduct::RooProduct() :
_compRIter( _compRSet.createIterator() ),
_compCIter( _compCSet.createIterator() )
{
TRACE_CREATE
}
// Destructor: releases the heap-allocated iterators over the real and
// category component lists. Deleting a null pointer is a no-op, so no
// explicit null checks are required.
RooProduct::~RooProduct()
{
  delete _compRIter ;
  delete _compCIter ;
  TRACE_DESTROY
}
// Construct the product of all objects in prodSet.
// RooAbsReal components are stored in _compRSet; RooAbsCategory components
// are stored in _compCSet (evaluate() multiplies by their integer index).
// Any other component type is reported as an error and
// RooErrorHandler::softAbort() is called.
RooProduct::RooProduct(const char* name, const char* title, const RooArgList& prodSet) :
  RooAbsReal(name, title),
  _compRSet("!compRSet","Set of real product components",this),
  _compCSet("!compCSet","Set of category product components",this),
  _compRIter( _compRSet.createIterator() ),
  _compCIter( _compCSet.createIterator() ),
  _cacheMgr(this,10)
{
  // Stack-allocated forward iterator (as used elsewhere in this file) instead
  // of a heap TIterator: there is nothing to delete, so nothing leaks if
  // softAbort() unwinds the stack.
  RooFIter compIter = prodSet.fwdIterator() ;
  RooAbsArg* comp ;
  while ((comp = compIter.next())) {
    if (dynamic_cast<RooAbsReal*>(comp)) {
      _compRSet.add(*comp) ;
    } else if (dynamic_cast<RooAbsCategory*>(comp)) {
      _compCSet.add(*comp) ;
    } else {
      coutE(InputArguments) << "RooProduct::ctor(" << GetName() << ") ERROR: component " << comp->GetName()
                            << " is not of type RooAbsReal or RooAbsCategory" << endl ;
      RooErrorHandler::softAbort() ;
    }
  }
  TRACE_CREATE
}
// Copy constructor: copies other's real and category component lists into
// lists owned by this object, allocates fresh iterators over the copies, and
// constructs the cache manager from other's cache manager.
RooProduct::RooProduct(const RooProduct& other, const char* name) :
RooAbsReal(other, name),
_compRSet("!compRSet",this,other._compRSet),
_compCSet("!compCSet",this,other._compCSet),
_compRIter(_compRSet.createIterator()),
_compCIter(_compCSet.createIterator()),
_cacheMgr(other._cacheMgr,this)
{
TRACE_CREATE
}
// Return kTRUE if any real-valued component depends on dep, so that
// integration over dep is routed through this class's analytical-integral
// interface (getAnalyticalIntegralWN / getPartIntList).
Bool_t RooProduct::forceAnalyticalInt(const RooAbsArg& dep) const
{
  // Use a local iterator rather than the shared _compRIter member: this const
  // method then does not modify object state and stays re-entrant.
  RooFIter compIter = _compRSet.fwdIterator() ;
  RooAbsArg* rcomp ;
  while ((rcomp = compIter.next())) {
    if (rcomp->dependsOn(dep)) {
      return kTRUE ;
    }
  }
  return kFALSE ;
}
// Partition the real-valued components into groups that can be integrated
// independently over allVars.
//  - Components depending on none of allVars form one group with an empty
//    variable set (only added if there is at least one such component).
//  - Each variable in allVars starts its own group: {var} -> components
//    that depend on var.
//  - Groups whose component lists overlap are merged repeatedly until all
//    groups are disjoint.
// Returns a heap-allocated ProdMap. The caller owns the map and every
// RooArgSet/RooArgList stored in it.
RooProduct::ProdMap* RooProduct::groupProductTerms(const RooArgSet& allVars) const
{
  ProdMap* map = new ProdMap ;

  // Factors independent of all integration variables.
  // Local iterators are used instead of the shared _compRIter member.
  RooArgList *indep = new RooArgList();
  RooFIter indepIter = _compRSet.fwdIterator() ;
  RooAbsArg* rcomp ;
  while ((rcomp = indepIter.next())) {
    if (!rcomp->dependsOn(allVars)) indep->add(*rcomp);
  }
  if (indep->getSize()!=0) {
    map->push_back( std::make_pair(new RooArgSet(),indep) );
  } else {
    // Not stored in the map, so nobody else would free it.
    delete indep ;
  }

  // One initial group per integration variable.
  RooFIter varIter = allVars.fwdIterator() ;
  RooAbsArg* var ;
  while ((var = varIter.next())) {
    RooArgSet *vars = new RooArgSet(); vars->add(*var);
    RooArgList *comps = new RooArgList();
    RooFIter compIter = _compRSet.fwdIterator() ;
    RooAbsArg* rcomp2 ;
    while ((rcomp2 = compIter.next())) {
      if (rcomp2->dependsOn(*var)) comps->add(*rcomp2);
    }
    map->push_back( std::make_pair(vars,comps) );
  }

  // Merge overlapping groups: the second group's variables and (not yet
  // present) components are added to the first, then the second is erased.
  Bool_t overlap;
  do {
    std::pair<ProdMap::iterator,ProdMap::iterator> i = findOverlap2nd(map->begin(),map->end());
    overlap = (i.first!=i.second);
    if (overlap) {
      i.first->first->add(*i.second->first);
      RooFIter it = i.second->second->fwdIterator() ;
      RooAbsArg* targ ;
      while ((targ = it.next())) {
        if (!i.first->second->find(*targ)) {
          i.first->second->add(*targ) ;
        }
      }
      delete i.second->first;
      delete i.second->second;
      map->erase(i.second);
    }
  } while (overlap);

#ifndef NDEBUG
  // Consistency check: every variable and every real component is assigned
  // to exactly one group. Guarded so release builds don't warn about unused
  // counters.
  int nVar=0; int nFunc=0;
  for (ProdMap::iterator i = map->begin();i!=map->end();++i) {
    nVar+=i->first->getSize();
    nFunc+=i->second->getSize();
  }
  assert(nVar==allVars.getSize());
  assert(nFunc==_compRSet.getSize());
#endif
  return map;
}
Int_t RooProduct::getPartIntList(const RooArgSet* iset, const char *isetRange) const
{
Int_t sterileIndex(-1);
CacheElem* cache = (CacheElem*) _cacheMgr.getObj(iset,iset,&sterileIndex,RooNameReg::ptr(isetRange));
if (cache!=0) {
Int_t code = _cacheMgr.lastIndex();
return code;
}
ProdMap* map = groupProductTerms(*iset);
cxcoutD(Integration) << "RooProduct::getPartIntList(" << GetName() << ") groupProductTerms returned map" ;
if (dologD(Integration)) {
dump_map(ccoutD(Integration),map->begin(),map->end());
ccoutD(Integration) << endl;
}
if (map->size()<2) {
for (ProdMap::iterator iter = map->begin() ; iter != map->end() ; ++iter) {
delete iter->first ;
delete iter->second ;
}
delete map ;
return -1;
}
cache = new CacheElem();
for (ProdMap::const_iterator i = map->begin();i!=map->end();++i) {
RooAbsReal *term(0);
if (i->second->getSize()>1) {
const char *name = makeFPName("SUBPROD_",*i->second);
term = new RooProduct(name,name,*i->second);
cache->_ownedList.addOwned(*term);
cxcoutD(Integration) << "RooProduct::getPartIntList(" << GetName() << ") created subexpression " << term->GetName() << endl;
} else {
assert(i->second->getSize()==1);
auto_ptr<TIterator> j( i->second->createIterator() );
term = (RooAbsReal*)j->Next();
}
assert(term!=0);
if (i->first->getSize()==0) {
cache->_prodList.add(*term);
cxcoutD(Integration) << "RooProduct::getPartIntList(" << GetName() << ") adding simple factor " << term->GetName() << endl;
} else {
RooAbsReal *integral = term->createIntegral(*i->first,isetRange);
cache->_prodList.add(*integral);
cache->_ownedList.addOwned(*integral);
cxcoutD(Integration) << "RooProduct::getPartIntList(" << GetName() << ") adding integral for " << term->GetName() << " : " << integral->GetName() << endl;
}
}
Int_t code = _cacheMgr.setObj(iset,iset,(RooAbsCacheElement*)cache,RooNameReg::ptr(isetRange));
cxcoutD(Integration) << "RooProduct::getPartIntList(" << GetName() << ") created list " << cache->_prodList << " with code " << code+1 << endl
<< " for iset=" << *iset << " @" << iset << " range: " << (isetRange?isetRange:"<none>") << endl ;
for (ProdMap::iterator iter = map->begin() ; iter != map->end() ; ++iter) {
delete iter->first ;
delete iter->second ;
}
delete map ;
return code;
}
// Advertise analytical integration over all of allVars.
// Returns 0 when numeric integration is forced. Otherwise returns
// getPartIntList()+1, which is 0 when the product does not factorize
// (getPartIntList() returned -1).
Int_t RooProduct::getAnalyticalIntegralWN(RooArgSet& allVars, RooArgSet& analVars,
                                          const RooArgSet* /*unused*/,
                                          const char* rangeName) const
{
  if (_forceNumInt) {
    return 0 ;
  }

  assert(analVars.getSize()==0);
  analVars.add(allVars) ;
  return 1 + getPartIntList(&analVars,rangeName) ;
}
// Evaluate the analytical integral for a code from getAnalyticalIntegralWN():
// the product of the cached partial-integral factors at index code-1.
// If no cache element is present at that index (presumably the cache was
// cleared), the integration set is reconstructed from the stored name set
// and the factor list is rebuilt before evaluating.
Double_t RooProduct::analyticalIntegral(Int_t code, const char* rangeName) const
{
  CacheElem* cached = (CacheElem*) _cacheMgr.getObjByIndex(code-1);
  if (cached) {
    return calculate(cached->_prodList);
  }

  std::auto_ptr<RooArgSet> vars( getParameters(RooArgSet()) );
  std::auto_ptr<RooArgSet> iset( _cacheMgr.nameSet2ByIndex(code-1)->select(*vars) );
  const Int_t rebuiltCode = getPartIntList(iset.get(),rangeName)+1;
  assert(code==rebuiltCode);
  return analyticalIntegral(rebuiltCode,rangeName);
}
// Multiply the values of all terms in partIntList (each treated as a
// RooAbsReal) and return the product. An empty list yields 1.
Double_t RooProduct::calculate(const RooArgList& partIntList) const
{
  Double_t product = 1 ;
  RooFIter iter = partIntList.fwdIterator() ;
  for (RooAbsArg* arg = iter.next() ; arg != 0 ; arg = iter.next()) {
    product *= static_cast<RooAbsReal*>(arg)->getVal() ;
  }
  return product ;
}
const char* RooProduct::makeFPName(const char *pfx,const RooArgSet& terms) const
{
static TString pname;
pname = pfx;
std::auto_ptr<TIterator> i( terms.createIterator() );
RooAbsArg *arg;
Bool_t first(kTRUE);
while((arg=(RooAbsArg*)i->Next())) {
if (first) { first=kFALSE;}
else pname.Append("_X_");
pname.Append(arg->GetName());
}
return pname.Data();
}
// Value of the product. Each real component is evaluated with the
// normalization set attached to _compRSet; the result is then multiplied
// by the integer index of each category component.
Double_t RooProduct::evaluate() const
{
  Double_t result(1) ;

  RooFIter realIter = _compRSet.fwdIterator() ;
  const RooArgSet* normSet = _compRSet.nset() ;
  for (RooAbsArg* arg = realIter.next() ; arg ; arg = realIter.next()) {
    result *= static_cast<RooAbsReal*>(arg)->getVal(normSet) ;
  }

  RooFIter catIter = _compCSet.fwdIterator() ;
  for (RooAbsArg* arg = catIter.next() ; arg ; arg = catIter.next()) {
    result *= static_cast<RooAbsCategory*>(arg)->getIndex() ;
  }

  return result ;
}
// Return the bin boundaries in [xlo,xhi] for obs reported by the first real
// component that provides any, or 0 if none does.
std::list<Double_t>* RooProduct::binBoundaries(RooAbsRealLValue& obs, Double_t xlo, Double_t xhi) const
{
  RooFIter iter = _compRSet.fwdIterator() ;
  for (RooAbsArg* arg = iter.next() ; arg ; arg = iter.next()) {
    std::list<Double_t>* boundaries = static_cast<RooAbsReal*>(arg)->binBoundaries(obs,xlo,xhi) ;
    if (boundaries) return boundaries ;
  }
  return 0 ;
}
// The product is binned in obs unless some real component that depends on
// obs reports that it is not binned in obs.
Bool_t RooProduct::isBinnedDistribution(const RooArgSet& obs) const
{
  RooFIter iter = _compRSet.fwdIterator() ;
  for (RooAbsArg* arg = iter.next() ; arg ; arg = iter.next()) {
    RooAbsReal* func = static_cast<RooAbsReal*>(arg) ;
    // Components that do not depend on obs are not consulted.
    if (!func->dependsOn(obs)) continue ;
    if (!func->isBinnedDistribution(obs)) return kFALSE ;
  }
  return kTRUE ;
}
// Return the plot sampling hint in [xlo,xhi] for obs from the first real
// component that provides one, or 0 if none does.
std::list<Double_t>* RooProduct::plotSamplingHint(RooAbsRealLValue& obs, Double_t xlo, Double_t xhi) const
{
  RooFIter iter = _compRSet.fwdIterator() ;
  for (RooAbsArg* arg = iter.next() ; arg ; arg = iter.next()) {
    std::list<Double_t>* samplingHint = static_cast<RooAbsReal*>(arg)->plotSamplingHint(obs,xlo,xhi) ;
    if (samplingHint) return samplingHint ;
  }
  return 0 ;
}
// CacheElem destructor. Nothing is deleted explicitly here. Objects added
// with addOwned() to _ownedList are presumably released by that list's own
// destructor; confirm against the RooAbsCollection ownership semantics.
RooProduct::CacheElem::~CacheElem()
{
}
// Return a list of the objects held in _ownedList (the sub-products and
// integrals created by getPartIntList()). The Action argument is ignored.
RooArgList RooProduct::CacheElem::containedArgs(Action)
{
  return RooArgList(_ownedList) ;
}
// Add to trackNodes every derived component of this object whose cache mode
// is Always.
void RooProduct::setCacheAndTrackHints(RooArgSet& trackNodes)
{
  RooArgSet comp(components()) ;
  RooFIter iter = comp.fwdIterator() ;
  for (RooAbsArg* node = iter.next() ; node ; node = iter.next()) {
    if (node->isDerived() && node->canNodeBeCached()==Always) {
      trackNodes.add(*node) ;
    }
  }
}
// Print the product as "a * b * ... " to os: real components first, then
// category components, followed by a single trailing space.
void RooProduct::printMetaArgs(ostream& os) const
{
  // Local iterators instead of the shared _compRIter/_compCIter members:
  // this const method then leaves object state untouched.
  Bool_t first(kTRUE) ;

  RooFIter realIter = _compRSet.fwdIterator() ;
  RooAbsArg* rcomp ;
  while ((rcomp = realIter.next())) {
    if (!first) { os << " * " ; } else { first = kFALSE ; }
    os << rcomp->GetName() ;
  }

  RooFIter catIter = _compCSet.fwdIterator() ;
  RooAbsArg* ccomp ;
  while ((ccomp = catIter.next())) {
    if (!first) { os << " * " ; } else { first = kFALSE ; }
    os << ccomp->GetName() ;
  }

  os << " " ;
}
namespace {

// Return the first pair (a,b), with a before b, of entries in [i,end) whose
// component lists share at least one element; (end,end) if none overlap.
std::pair<RPPMIter,RPPMIter> findOverlap2nd(RPPMIter i, RPPMIter end)
{
  for (RPPMIter a = i ; a != end ; ++a) {
    for (RPPMIter b = a + 1 ; b != end ; ++b) {
      if (a->second->overlaps(*b->second)) {
        return std::make_pair(a,b) ;
      }
    }
  }
  return std::make_pair(end,end) ;
}

// Print the entries in [i,end) as " [ vars -> comps , vars -> comps ] ".
void dump_map(ostream& os, RPPMIter i, RPPMIter end)
{
  os << " [ " ;
  for (RPPMIter it = i ; it != end ; ++it) {
    if (it != i) {
      os << " , " ;
    }
    os << *(it->first) << " -> " << *(it->second) ;
  }
  os << " ] " ;
}

}