#ifndef ROOT_TMVA_RuleFit
#define ROOT_TMVA_RuleFit
#include <algorithm>
#include <random>
#include <vector>

#ifndef ROOT_TMVA_DecisionTree
#include "TMVA/DecisionTree.h"
#endif
#ifndef ROOT_TMVA_RuleEnsemble
#include "TMVA/RuleEnsemble.h"
#endif
#ifndef ROOT_TMVA_RuleFitParams
#include "TMVA/RuleFitParams.h"
#endif
#ifndef ROOT_TMVA_Event
#include "TMVA/Event.h"
#endif
namespace TMVA {

   class MethodBase;
   class MethodRuleFit;
   class MsgLogger;

   /// RuleFit driver class.
   ///
   /// Holds the training-event lists, the decision-tree forest, and a
   /// RuleEnsemble / RuleFitParams pair that most configuration setters
   /// forward to. Most methods are implemented out of line; the comments in
   /// this header only describe what is visible here.
   class RuleFit {

   public:

      // --- construction / destruction (implemented out of line) ---
      RuleFit( const TMVA::MethodBase *rfbase );
      RuleFit( void );
      virtual ~RuleFit( void );

      // --- setup (implemented out of line) ---
      void InitNEveEff();
      void InitPtrs( const TMVA::MethodBase *rfbase );
      void Initialize( const TMVA::MethodBase *rfbase );

      void SetMsgType( EMsgType t );

      void SetTrainingEvents( const std::vector<const TMVA::Event *> & el );

      /// Shuffle fTrainingEventsRndm in place; fTrainingEvents is not touched.
      ///
      /// Uses std::shuffle with the per-instance engine fRNGEngine, because
      /// std::random_shuffle was deprecated in C++14 and removed in C++17.
      /// The engine is seeded from randSEED, so the sequence of shuffles is
      /// deterministic for a given standard-library implementation.
      void ReshuffleEvents()
      {
         std::shuffle( fTrainingEventsRndm.begin(), fTrainingEventsRndm.end(), fRNGEngine );
      }

      void SetMethodBase( const MethodBase *rfbase );

      // --- forest building / boosting (implemented out of line) ---
      void MakeForest();
      void BuildTree( TMVA::DecisionTree *dt );
      void SaveEventWeights();
      void RestoreEventWeights();
      void Boost( TMVA::DecisionTree *dt );
      void ForestStatistics();

      // --- evaluation / fitting (implemented out of line) ---
      Double_t EvalEvent( const Event& e );
      Double_t CalcWeightSum( const std::vector<const TMVA::Event *> *events, UInt_t neve=0 );
      void FitCoefficients();
      void CalcImportance();

      // --- model configuration: thin forwards to fRuleEnsemble ---
      void SetModelLinear()                      { fRuleEnsemble.SetModelLinear(); }
      void SetModelRules()                       { fRuleEnsemble.SetModelRules(); }
      void SetModelFull()                        { fRuleEnsemble.SetModelFull(); }
      void SetImportanceCut( Double_t minimp=0 ) { fRuleEnsemble.SetImportanceCut(minimp); }
      void SetRuleMinDist( Double_t d )          { fRuleEnsemble.SetRuleMinDist(d); }

      // --- gradient-descent path settings: thin forwards to fRuleFitParams ---
      void SetGDTau( Double_t t=0.0 )            { fRuleFitParams.SetGDTau(t); }
      void SetGDPathStep( Double_t s=0.01 )      { fRuleFitParams.SetGDPathStep(s); }
      void SetGDNPathSteps( Int_t n=100 )        { fRuleFitParams.SetGDNPathSteps(n); }

      // --- visualisation histograms ---
      // fVisHistsUseImp selects importance (kTRUE) or coefficients (kFALSE);
      // how it is consumed is decided by the out-of-line fill routines.
      void SetVisHistsUseImp( Bool_t f )         { fVisHistsUseImp = f; }
      void UseImportanceVisHists()               { fVisHistsUseImp = kTRUE; }
      void UseCoefficientsVisHists()             { fVisHistsUseImp = kFALSE; }
      void MakeVisHists();
      void FillVisHistCut(const Rule * rule, std::vector<TH2F *> & hlist);
      void FillVisHistCorr(const Rule * rule, std::vector<TH2F *> & hlist);
      void FillCut(TH2F* h2,const TMVA::Rule *rule,Int_t vind);
      void FillLin(TH2F* h2,Int_t vind);
      void FillCorr(TH2F* h2,const TMVA::Rule *rule,Int_t v1, Int_t v2);
      void NormVisHists(std::vector<TH2F *> & hlist);
      void MakeDebugHists();

      Bool_t GetCorrVars(TString & title, TString & var1, TString & var2);

      // --- accessors ---
      UInt_t   GetNTreeSample() const { return fNTreeSample; }
      Double_t GetNEveEff()     const { return fNEveEffTrain; }

      /// i-th training event; no bounds check (uses operator[]).
      const Event* GetTrainingEvent(UInt_t i) const { return fTrainingEvents[i]; }
      /// Weight of the i-th training event; no bounds check.
      Double_t GetTrainingEventWeight(UInt_t i) const { return fTrainingEvents[i]->GetWeight(); }

      const std::vector< const TMVA::Event * > & GetTrainingEvents()  const { return fTrainingEvents; }

      void GetRndmSampleEvents(std::vector< const TMVA::Event * > & evevec, UInt_t nevents);

      const std::vector< const TMVA::DecisionTree *> & GetForest()     const { return fForest; }
      const RuleEnsemble       & GetRuleEnsemble()    const { return fRuleEnsemble; }
      RuleEnsemble             * GetRuleEnsemblePtr()       { return &fRuleEnsemble; }
      const RuleFitParams      & GetRuleFitParams()   const { return fRuleFitParams; }
      RuleFitParams            * GetRuleFitParamsPtr()      { return &fRuleFitParams; }
      const MethodRuleFit      * GetMethodRuleFit()   const { return fMethodRuleFit; }
      const MethodBase         * GetMethodBase()      const { return fMethodBase; }

   private:

      // Copy construction is private (defined out of line, if at all).
      RuleFit( const RuleFit & other );
      void Copy( const RuleFit & other );

      // NOTE(review): ownership of the pointed-to events/trees is not visible
      // from this header — confirm against the implementation.
      std::vector<const TMVA::Event *>          fTrainingEvents;      // returned by GetTrainingEvents()
      std::vector<const TMVA::Event *>          fTrainingEventsRndm;  // shuffled by ReshuffleEvents()
      std::vector<Double_t>                     fEventWeights;        // presumably used by Save/RestoreEventWeights()
      UInt_t                                    fNTreeSample;         // returned by GetNTreeSample()
      Double_t                                  fNEveEffTrain;        // returned by GetNEveEff()
      std::vector< const TMVA::DecisionTree *>  fForest;              // returned by GetForest()
      RuleEnsemble                              fRuleEnsemble;
      RuleFitParams                             fRuleFitParams;
      const MethodRuleFit                      *fMethodRuleFit;
      const MethodBase                         *fMethodBase;
      Bool_t                                    fVisHistsUseImp;      // kTRUE: importance, kFALSE: coefficients

      mutable MsgLogger*                        fLogger;              // dereferenced by Log(); must be non-null when logging
      MsgLogger& Log() const { return *fLogger; }

      static const Int_t randSEED = 0;                                // seed for fRNGEngine

      // Random engine for ReshuffleEvents(). Declared last so out-of-line
      // constructor initializer lists keep their existing member order; the
      // default member initializer means no constructor needs to touch it.
      std::default_random_engine fRNGEngine{
         static_cast<std::default_random_engine::result_type>(randSEED) };

      ClassDef(RuleFit,0)
   };
}
#endif