#include "TMVA/ClassifierFactory.h"
#include "TMVA/MethodCommittee.h"
#include "TMVA/Tools.h"
#include "TMVA/Timer.h"
#include "Riostream.h"
#include "TMath.h"
#include "TRandom3.h"
#include <algorithm>
#include "TObjString.h"
#include "TDirectory.h"
#include "TMVA/Ranking.h"
#include "TMVA/IMethod.h"
using std::vector;
// Method registration macro for "Committee" (presumably hooks the class into
// the ClassifierFactory used in Train/ReadWeightsFromStream -- confirm in
// ClassifierFactory.h).
REGISTER_METHOD(Committee)
// ROOT ClassImp macro for TMVA::MethodCommittee.
ClassImp(TMVA::MethodCommittee)
// Constructor taking the job name, method title, dataset info, option string
// and target directory. Everything is forwarded to MethodBase with method
// type Types::kCommittee; the committee size is preset to 100 members and the
// boost type to "AdaBoost".
TMVA::MethodCommittee::MethodCommittee( const TString& jobName,
                                        const TString& methodTitle,
                                        DataSetInfo&   dsi,
                                        const TString& theOption,
                                        TDirectory*    theTargetDir )
   : MethodBase( jobName, Types::kCommittee, methodTitle, dsi, theOption, theTargetDir )
   , fNMembers ( 100 )
   , fBoostType( "AdaBoost" )
{
   // intentionally empty
}
// Constructor taking the dataset info, a weight file name and the target
// directory. Arguments are forwarded to MethodBase with method type
// Types::kCommittee; the committee size is preset to 100 members and the
// boost type to "AdaBoost".
TMVA::MethodCommittee::MethodCommittee( DataSetInfo&   theData,
                                        const TString& theWeightFile,
                                        TDirectory*    theTargetDir )
   : MethodBase( Types::kCommittee, theData, theWeightFile, theTargetDir )
   , fNMembers ( 100 )
   , fBoostType( "AdaBoost" )
{
   // intentionally empty
}
// Reports which analysis types this method supports:
//   - classification with exactly two classes
//   - regression with exactly one target
// Returns kTRUE for those two cases and kFALSE otherwise.
Bool_t TMVA::MethodCommittee::HasAnalysisType( Types::EAnalysisType type, UInt_t numberClasses, UInt_t numberTargets )
{
   const Bool_t twoClassClassification = ( type == Types::kClassification ) && ( numberClasses == 2 );
   const Bool_t singleTargetRegression = ( type == Types::kRegression )     && ( numberTargets == 1 );

   return ( twoClassClassification || singleTargetRegression ) ? kTRUE : kFALSE;
}
// Declares the configurable options of the committee method:
//   NMembers           number of members in the committee
//   UseMemberDecision  (default kFALSE)
//   UseWeightedMembers (default kTRUE)
//   BoostType          with the predefined values "AdaBoost" and "Bagging"
void TMVA::MethodCommittee::DeclareOptions()
{
   DeclareOptionRef( fNMembers, "NMembers", "number of members in the committee" );

   // the two boolean flags receive their default right before being declared
   fUseMemberDecision = kFALSE;
   DeclareOptionRef( fUseMemberDecision, "UseMemberDecision", "use binary information from IsSignal" );

   fUseWeightedMembers = kTRUE;
   DeclareOptionRef( fUseWeightedMembers, "UseWeightedMembers", "use weighted trees or simple average in classification from the forest" );

   // the AddPreDefVal calls must directly follow the BoostType declaration
   // (presumably they attach to the most recently declared option -- confirm)
   DeclareOptionRef( fBoostType, "BoostType", "boosting type" );
   AddPreDefVal( TString("AdaBoost") );
   AddPreDefVal( TString("Bagging") );
}
void TMVA::MethodCommittee::ProcessOptions()
{
fBoostFactorHist = new TH1F("fBoostFactor","Ada Boost weights",100,1,100);
fErrFractHist = new TH2F("fErrFractHist","error fraction vs tree number",
fNMembers,0,fNMembers,50,0,0.5);
fMonitorNtuple = new TTree("fMonitorNtuple","Committee variables");
fMonitorNtuple->Branch("iTree",&fITree,"iTree/I");
fMonitorNtuple->Branch("boostFactor",&fBoostFactor,"boostFactor/D");
fMonitorNtuple->Branch("errorFraction",&fErrorFraction,"errorFraction/D");
}
// Resets the method to its defaults: empties the committee and boost-weight
// containers and restores 100 members with boost type "AdaBoost".
// Note that the committee pointers are only dropped here, not deleted.
void TMVA::MethodCommittee::Init( void )
{
   // containers first ...
   fCommittee.clear();
   fBoostWeights.clear();

   // ... then the scalar defaults
   fBoostType = "AdaBoost";
   fNMembers  = 100;
}
// Destructor: deletes every committee member and empties the container.
TMVA::MethodCommittee::~MethodCommittee( void )
{
   const UInt_t nMembers = GetCommittee().size();
   for (UInt_t idx = 0; idx < nMembers; ++idx) {
      delete fCommittee[idx];
   }
   fCommittee.clear();
}
// Writes the state of the method (via WriteStateToStream) into the weight
// file whose name is given by GetWeightFileName().
// Issues a kFATAL message if the file cannot be opened.
void TMVA::MethodCommittee::WriteStateToFile() const
{
   TString fname( GetWeightFileName() );
   Log() << kINFO << "creating weight file: " << fname << Endl;

   // Stack-allocated stream: it is flushed and closed when it goes out of
   // scope. The former heap-allocated stream was never deleted.
   std::ofstream fout( fname.Data() );
   if (!fout.good()) {
      // terminate with the logger's Endl (not std::endl), as everywhere else
      // in this file, so that the fatal message is actually dispatched
      Log() << kFATAL << "<WriteStateToFile> "
            << "unable to open output weight file: " << fname << Endl;
   }

   WriteStateToStream( fout );
}
void TMVA::MethodCommittee::Train( void )
{
Log() << kINFO << "will train "<< fNMembers << " committee members ... patience please" << Endl;
Timer timer( fNMembers, GetName() );
for (UInt_t imember=0; imember<fNMembers; imember++){
timer.DrawProgressBar( imember );
IMethod* method = ClassifierFactory::Instance().Create(std::string(Types::Instance().GetMethodName( fMemberType )),
GetJobName(),
GetMethodName(),
DataInfo(),
fMemberOption );
method->Train();
GetBoostWeights().push_back( this->Boost( dynamic_cast<MethodBase*>(method), imember ) );
GetCommittee().push_back( method );
fMonitorNtuple->Fill();
}
Log() << kINFO << "elapsed time: " << timer.GetElapsedTime()
<< " " << Endl;
}
// Dispatches to the boosting procedure selected by fBoostType:
//   "AdaBoost" -> AdaBoost( method )
//   "Bagging"  -> Bagging( imember )
// Any other value prints the options and issues a kFATAL message; 1.0 is the
// fall-through return value.
Double_t TMVA::MethodCommittee::Boost( TMVA::MethodBase* method, UInt_t imember )
{
   if (fBoostType == "AdaBoost") {
      return this->AdaBoost( method );
   }
   if (fBoostType == "Bagging") {
      return this->Bagging( imember );
   }

   // unknown boost type
   Log() << kINFO << GetOptions() << Endl;
   Log() << kFATAL << "<Boost> unknown boost option called" << Endl;
   return 1.0;
}
// AdaBoost step for one committee member.
//
// Loops over the training events and compares the member's IsSignalLike()
// decision with the event's IsSignal() flag. The weighted misclassification
// rate err = sum(w_misclassified) / sum(w) defines the boost factor
// ((1-err)/err)^beta with beta fixed to 1 (boost factor 1000 if err is zero).
// Boost weights of misclassified events are multiplied by that factor and all
// boost weights are then rescaled so that their sum equals the sum before
// boosting. The monitoring histograms and ntuple variables are updated.
//
// Returns log(boostFactor), which Train() stores as the member's weight.
Double_t TMVA::MethodCommittee::AdaBoost( TMVA::MethodBase* method )
{
   const Double_t adaBoostBeta = 1.;

   // Train() hands over the result of a dynamic_cast, which may be zero
   if (method == 0) Log() << kFATAL << "<AdaBoost> committee member is a zero pointer" << Endl;

   // Abort when there is nothing to boost on. The previous condition was
   // inverted and raised this fatal whenever training events DID exist.
   const Long64_t nEvents = Data()->GetNTrainingEvents();
   if (nEvents == 0) Log() << kFATAL << "<AdaBoost> Data().TrainingTree() is zero pointer" << Endl;

   Double_t sumw = 0, sumwfalse = 0;
   vector<Char_t> correctSelected;   // per event: was it classified correctly?
   correctSelected.reserve( nEvents );

   // first pass: accumulate total and misclassified weight
   for (Int_t ievt=0; ievt<nEvents; ievt++) {
      Event* ev = Data()->GetEvent(ievt);
      sumw += ev->GetBoostWeight();

      const Bool_t isSignalType = method->IsSignalLike();
      if (isSignalType == ev->IsSignal()) {
         correctSelected.push_back( kTRUE );
      }
      else {
         sumwfalse += ev->GetBoostWeight();
         correctSelected.push_back( kFALSE );
      }
   }

   if (0 == sumw) {
      Log() << kFATAL << "<AdaBoost> fatal error sum of event boostweights is zero" << Endl;
   }

   // weighted misclassification rate and resulting boost factor
   const Double_t err = sumwfalse/sumw;
   Double_t boostFactor = 1;
   if (err>0) {
      if (adaBoostBeta == 1) {
         boostFactor = (1-err)/err;
      }
      else {
         boostFactor = TMath::Power((1-err)/err,adaBoostBeta);
      }
   }
   else {
      // perfect member: use a large finite factor instead of dividing by zero
      boostFactor = 1000;
   }

   // second pass: boost the misclassified events
   // NOTE(review): err == 1 gives boostFactor == 0, hence newSumw == 0 below
   // and a return value of log(0) -- decide how that case should be handled.
   Double_t newSumw = 0;
   for (Int_t ievt=0; ievt<nEvents; ievt++) {
      Event* ev = Data()->GetEvent(ievt);
      if (!correctSelected[ievt]) ev->SetBoostWeight( ev->GetBoostWeight() * boostFactor );
      newSumw += ev->GetBoostWeight();
   }

   // third pass: rescale so that the sum of boost weights is unchanged
   for (Int_t ievt=0; ievt<nEvents; ievt++) {
      Event* ev = Data()->GetEvent(ievt);
      ev->SetBoostWeight( ev->GetBoostWeight() * sumw / newSumw );
   }

   // monitoring
   fBoostFactorHist->Fill(boostFactor);
   fErrFractHist->Fill(GetCommittee().size(),err);
   fBoostFactor   = boostFactor;
   fErrorFraction = err;

   return TMath::Log(boostFactor);
}
// Bagging step: assigns every training event a boost weight drawn from
// TRandom3::Rndm() (generator seeded with the member index) and then rescales
// the weights so that they sum up to the number of training events.
// Always returns 1.0 as the member's boost weight.
Double_t TMVA::MethodCommittee::Bagging( UInt_t imember )
{
   // Stack-allocated generator: the former heap-allocated TRandom3 was never
   // deleted and leaked once per committee member.
   // NOTE(review): ROOT documents seed 0 as "choose a unique seed", which would
   // make member 0 non-reproducible -- confirm whether imember+1 is wanted.
   TRandom3 trandom( imember );

   const Long64_t nEvents = Data()->GetNTrainingEvents();

   // draw the new weights
   Double_t newSumw = 0;
   for (Int_t ievt=0; ievt<nEvents; ievt++) {
      Event* ev = Data()->GetEvent(ievt);
      const Double_t newWeight = trandom.Rndm();
      ev->SetBoostWeight( newWeight );
      newSumw += newWeight;
   }

   // normalise: sum of boost weights == number of training events
   for (Int_t ievt=0; ievt<nEvents; ievt++) {
      Event* ev = Data()->GetEvent(ievt);
      ev->SetBoostWeight( ev->GetBoostWeight() * nEvents / newSumw );
   }

   return 1.0;
}
// Writes the committee to a text stream. For every member it emits an empty
// line, a separator line carrying the member index, a "boost weight:" line,
// and then the member's own state via WriteStateToStream.
void TMVA::MethodCommittee::WriteWeightsToStream( ostream& o ) const
{
   const UInt_t nMembers = GetCommittee().size();
   for (UInt_t idx = 0; idx < nMembers; ++idx) {
      // per-member header
      o << endl
        << "------------------------------ new member: " << idx << " ---------------" << endl
        << "boost weight: " << GetBoostWeights()[idx] << endl;

      // member payload
      MethodBase* member = dynamic_cast<MethodBase*>( GetCommittee()[idx] );
      member->WriteStateToStream( o );
   }
}
// XML weight output is not implemented for this method: calling it issues a
// kFATAL message. The pointer argument is ignored.
void TMVA::MethodCommittee::AddWeightsXMLTo( void* /*unused*/ ) const
{
   Log() << kFATAL << "Please implement writing of weights as XML" << Endl;
}
// Rebuilds the committee from a text stream in the layout produced by
// WriteWeightsToStream. Any previously held members are deleted first.
// For each of the fNMembers members the header (member index and boost weight)
// is parsed, a method of type fMemberType is created through the
// ClassifierFactory, and its state is read via ReadStateFromStream.
// A kFATAL message is issued on a malformed header, an index mismatch, or a
// member that cannot be created.
void TMVA::MethodCommittee::ReadWeightsFromStream( istream& istr )
{
   // drop the current committee
   for (std::vector<IMethod*>::iterator member = GetCommittee().begin();
        member != GetCommittee().end(); ++member) {
      delete *member;
   }
   GetCommittee().clear();
   GetBoostWeights().clear();

   TString dummy;
   DataSetInfo& dsi = DataInfo();

   for (UInt_t i=0; i<fNMembers; i++) {
      UInt_t   imember     = 0;
      Double_t boostWeight = 0;

      // Header line: "------------------------------ new member: <i> ---------------"
      // The trailing dash token has to be consumed as well; otherwise the next
      // extraction is shifted by one token and "weight:" is parsed as the number.
      istr >> dummy >> dummy >> dummy >> imember >> dummy;
      // Weight line: "boost weight: <w>"
      istr >> dummy >> dummy >> boostWeight;

      if (istr.fail()) {
         Log() << kFATAL << "<ReadWeightsFromStream> could not parse the header of committee member "
               << i << Endl;
      }
      if (imember != i) {
         Log() << kFATAL << "<ReadWeightsFromStream> fatal error while reading Weight file \n "
               << ": mismatch imember: " << imember << " != i: " << i << Endl;
      }

      IMethod* method = ClassifierFactory::Instance().Create(std::string(Types::Instance().GetMethodName( fMemberType )), dsi, "" );
      MethodBase* memberBase = dynamic_cast<MethodBase*>(method);
      if (memberBase == 0) {
         Log() << kFATAL << "<ReadWeightsFromStream> could not create committee member " << i << Endl;
      }
      memberBase->ReadStateFromStream(istr);

      GetCommittee().push_back(method);
      GetBoostWeights().push_back(boostWeight);
   }
}
// Returns the committee response for the current event: the (optionally
// boost-weighted) average over all members of either the member's MVA value
// or, if fUseMemberDecision is set, of +1/-1 according to IsSignalLike().
// If err is non-zero, *err is set to -1. Returns -999 when the normalisation
// is zero (e.g. empty committee).
Double_t TMVA::MethodCommittee::GetMvaValue( Double_t* err )
{
   if (err != 0) *err = -1;

   Double_t responseSum = 0;
   Double_t weightSum   = 0;

   const UInt_t nMembers = GetCommittee().size();
   for (UInt_t idx = 0; idx < nMembers; ++idx) {
      // response of this member: hard decision or continuous MVA output
      Double_t response;
      if (fUseMemberDecision) {
         MethodBase* member = dynamic_cast<MethodBase*>( GetCommittee()[idx] );
         response = member->IsSignalLike() ? 1.0 : -1.0;
      }
      else {
         response = GetCommittee()[idx]->GetMvaValue();
      }

      // members count with their boost weight, or equally with weight one
      const Double_t weight = fUseWeightedMembers ? GetBoostWeights()[idx] : 1.0;
      responseSum += weight * response;
      weightSum   += weight;
   }

   if (weightSum == 0) return -999;
   return responseSum / weightSum;
}
// Writes the monitoring objects booked in ProcessOptions (boost-factor
// histogram, error-fraction histogram and monitoring ntuple) and afterwards
// changes into the directory returned by BaseDir().
void TMVA::MethodCommittee::WriteMonitoringHistosToFile( void ) const
{
Log() << kINFO << "Write monitoring histograms to file: " << BaseDir()->GetPath() << Endl;
// NOTE(review): the objects are written before BaseDir()->cd() is called, so
// they presumably end up in whatever directory is current on entry, while the
// log line above announces BaseDir() -- confirm whether cd() should come first.
fBoostFactorHist->Write();
fErrFractHist->Write();
fMonitorNtuple->Write();
BaseDir()->cd();
}
// Returns a copy of the variable-importance vector after resizing it to one
// entry per input variable. No importance values are computed here.
vector< Double_t > TMVA::MethodCommittee::GetVariableImportance()
{
   const UInt_t nVariables = GetNvar();
   fVariableImportance.resize( nVariables );

   return fVariableImportance;
}
// Returns the importance of input variable ivar. An index beyond the number
// of entries issues a kFATAL message; -1 is the fall-through return value.
Double_t TMVA::MethodCommittee::GetVariableImportance(UInt_t ivar)
{
   const vector<Double_t> importances = this->GetVariableImportance();

   // guard: index out of range
   if (ivar >= static_cast<UInt_t>( importances.size() )) {
      Log() << kFATAL << "<GetVariableImportance> ivar = " << ivar << " is out of range " << Endl;
      return -1;
   }

   return importances[ivar];
}
// Creates a new Ranking titled "Variable Importance", fills it with one Rank
// per input variable (input label plus importance value), stores it in
// fRanking and returns it.
const TMVA::Ranking* TMVA::MethodCommittee::CreateRanking()
{
   fRanking = new Ranking( GetName(), "Variable Importance" );

   const vector<Double_t> importance = this->GetVariableImportance();
   const UInt_t nVariables = GetNvar();
   for (UInt_t ivar = 0; ivar < nVariables; ++ivar) {
      const Rank entry( GetInputLabel(ivar), importance[ivar] );
      fRanking->AddRank( entry );
   }

   return fRanking;
}
// Standalone-class output is not implemented for this method: only a
// placeholder comment naming the class and the closing "};" are written.
void TMVA::MethodCommittee::MakeClassSpecific( std::ostream& fout, const TString& className ) const
{
   fout << " // not implemented for class: \"" << className << "\"" << endl
        << "};" << endl;
}
void TMVA::MethodCommittee::GetHelpMessage() const
{
Log() << Endl;
Log() << gTools().Color("bold") << "--- Short description:" << gTools().Color("reset") << Endl;
Log() << Endl;
Log() << "<None>" << Endl;
Log() << Endl;
Log() << gTools().Color("bold") << "--- Performance optimisation:" << gTools().Color("reset") << Endl;
Log() << Endl;
Log() << "<None>" << Endl;
Log() << Endl;
Log() << gTools().Color("bold") << "--- Performance tuning via configuration options:" << gTools().Color("reset") << Endl;
Log() << Endl;
Log() << "<None>" << Endl;
}