#include "TMVA/Reader.h"
#include "TTree.h"
#include "TLeaf.h"
#include "TString.h"
#include "TClass.h"
#include "TH1D.h"
#include "TKey.h"
#include "TVector.h"
#include "TXMLEngine.h"
#include <cstdlib>
#include <string>
#include <vector>
#include <fstream>
#include <iostream>
#ifndef ROOT_TMVA_Tools
#include "TMVA/Tools.h"
#endif
#include "TMVA/Config.h"
#include "TMVA/ClassifierFactory.h"
#include "TMVA/IMethod.h"
#include "TMVA/MethodCuts.h"
// Debug switch used by BookMVA: when TMVA_Reader_TestIO__ is defined, BookMVA
// writes the state of the booked method to "<weightfile>.control".
// The macro is defined and immediately undefined, so that code is compiled out;
// remove the #undef line to enable it.
#define TMVA_Reader_TestIO__
#undef TMVA_Reader_TestIO__
ClassImp(TMVA::Reader)
// Construct a Reader from an option string only; no input variables are declared here.
//
// theOption : option string forwarded to the Configurable base and parsed below
// verbose   : initial value of the verbose flag (may be overridden by option "V")
//
// The body calls DataSetManager::CreateInstance with fDataInputHandler and registers
// fDataSetInfo with the resulting instance, then declares and parses the options and
// finishes with Init().
// NOTE(review): the other Reader constructors visible in this file do not perform the
// DataSetManager registration - confirm whether that difference is intended.
TMVA::Reader::Reader( const TString& theOption, Bool_t verbose )
   : Configurable  ( theOption ),
     fDataSetInfo  (),
     fVerbose      ( verbose ),
     fSilent       ( kFALSE ),
     fColor        ( kFALSE ),
     fMvaEventError( -1 ),
     fLogger       ( new MsgLogger(this) )
{
   // data-set bookkeeping
   DataSetManager::CreateInstance( fDataInputHandler );
   DataSetManager::Instance().AddDataSetInfo( fDataSetInfo );

   // option handling
   SetConfigName( GetName() );
   DeclareOptions();
   ParseOptions();

   // apply the parsed settings
   Init();
}
// Construct a Reader and declare one input variable per entry of "inputVars".
//
// inputVars : variable expressions, added in vector order via DataInfo().AddVariable
// theOption : option string forwarded to the Configurable base and parsed below
// verbose   : initial value of the verbose flag (may be overridden by option "V")
TMVA::Reader::Reader( std::vector<TString>& inputVars, const TString& theOption, Bool_t verbose )
   : Configurable  ( theOption ),
     fDataSetInfo  (),
     fVerbose      ( verbose ),
     fSilent       ( kFALSE ),
     fColor        ( kFALSE ),
     fMvaEventError( -1 ),
     fLogger       ( new MsgLogger(this) )
{
   // option handling
   SetConfigName( GetName() );
   DeclareOptions();
   ParseOptions();

   // register the variables in the order given by the caller
   for (std::vector<TString>::size_type idx = 0; idx < inputVars.size(); ++idx) {
      DataInfo().AddVariable( inputVars[idx] );
   }

   Init();
}
// Construct a Reader and declare one input variable per entry of "inputVars"
// (std::string flavour of the vector-based constructor).
//
// inputVars : variable expressions, added in vector order via DataInfo().AddVariable
// theOption : option string forwarded to the Configurable base and parsed below
// verbose   : initial value of the verbose flag (may be overridden by option "V")
TMVA::Reader::Reader( std::vector<std::string>& inputVars, const TString& theOption, Bool_t verbose )
   : Configurable  ( theOption ),
     fDataSetInfo  (),
     fVerbose      ( verbose ),
     fSilent       ( kFALSE ),
     fColor        ( kFALSE ),
     fMvaEventError( -1 ),
     fLogger       ( new MsgLogger(this) )
{
   // option handling
   SetConfigName( GetName() );
   DeclareOptions();
   ParseOptions();

   // register the variables in the order given by the caller
   for (std::vector<std::string>::size_type idx = 0; idx < inputVars.size(); ++idx) {
      DataInfo().AddVariable( inputVars[idx].c_str() );
   }

   Init();
}
// Construct a Reader whose input variables are given as one ':'-separated std::string.
//
// varNames  : variable list, split and registered by DecodeVarNames
// theOption : option string forwarded to the Configurable base and parsed below
// verbose   : initial value of the verbose flag (may be overridden by option "V")
TMVA::Reader::Reader( const std::string& varNames, const TString& theOption, Bool_t verbose )
   : Configurable  ( theOption ),
     fDataSetInfo  (),
     fVerbose      ( verbose ),
     fSilent       ( kFALSE ),
     fColor        ( kFALSE ),
     fMvaEventError( -1 ),
     fLogger       ( new MsgLogger(this) )
{
   // option handling
   SetConfigName( GetName() );
   DeclareOptions();
   ParseOptions();

   // split the list and declare each variable
   DecodeVarNames( varNames );

   Init();
}
// Construct a Reader whose input variables are given as one ':'-separated TString.
//
// varNames  : variable list, split and registered by DecodeVarNames
// theOption : option string forwarded to the Configurable base and parsed below
// verbose   : initial value of the verbose flag (may be overridden by option "V")
TMVA::Reader::Reader( const TString& varNames, const TString& theOption, Bool_t verbose )
   : Configurable  ( theOption ),
     fDataSetInfo  (),
     fVerbose      ( verbose ),
     fSilent       ( kFALSE ),
     fColor        ( kFALSE ),
     fMvaEventError( -1 ),
     fLogger       ( new MsgLogger(this) )
{
   // option handling
   SetConfigName( GetName() );
   DeclareOptions();
   ParseOptions();

   // split the list and declare each variable
   DecodeVarNames( varNames );

   Init();
}
void TMVA::Reader::DeclareOptions()
{
if (gTools().CheckForSilentOption( GetOptions() )) Log().InhibitOutput();
DeclareOptionRef( fVerbose, "V", "Verbose flag" );
DeclareOptionRef( fColor, "Color", "Color flag (default True)" );
DeclareOptionRef( fSilent, "Silent", "Boolean silent flag (default False)" );
}
// Destructor: releases the message logger allocated in the constructors.
// NOTE(review): the IMethod objects stored in fMethodMap are not deleted here -
// confirm who owns them.
TMVA::Reader::~Reader()
{
   delete fLogger;
}
// Apply the parsed options: lower the logger threshold to kVERBOSE when the
// verbose flag is set, and forward the colour and silent flags to the global
// TMVA configuration.
void TMVA::Reader::Init()
{
   if (Verbose()) {
      fLogger->SetMinType( kVERBOSE );
   }

   gConfig().SetUseColor( fColor );
   gConfig().SetSilent( fSilent );
}
// Declare a floating-point input variable.
//
// expression : variable expression passed on to DataInfo().AddVariable
// datalink   : address of the caller's Float_t, handed over as an untyped pointer
//              together with the type code 'F'
void TMVA::Reader::AddVariable( const TString& expression, Float_t* datalink )
{
   void* const external = static_cast<void*>( datalink );
   DataInfo().AddVariable( expression, "", "", 0, 0, 'F', kFALSE, external );
}
// Declare an integer input variable.
//
// expression : variable expression passed on to DataInfo().AddVariable
// datalink   : address of the caller's Int_t, handed over as an untyped pointer
//              together with the type code 'I'
void TMVA::Reader::AddVariable( const TString& expression, Int_t* datalink )
{
   void* const external = static_cast<void*>( datalink );
   DataInfo().AddVariable( expression, "", "", 0, 0, 'I', kFALSE, external );
}
// Book a method under the user-chosen tag "methodTag", restoring it from "weightfile".
//
// The method type is taken from the weight file itself:
//   - files ending in ".xml": attribute "Method" of the XML root element;
//   - any other file: the first text line that begins with "Method".
// In both cases the type is the text in front of "::", reduced to the part after the
// last blank if it contains one. The method is then created through
// BookMVA(Types::EMVA, weightfile) and stored in fMethodMap under "methodTag".
//
// Returns the booked method (also stored in the map), or 0 if the weight file could
// not be read or interpreted.
// NOTE(review): the early "return 0" statements only matter if logging at kFATAL
// returns control to the caller - confirm the logger's kFATAL behaviour.
TMVA::IMethod* TMVA::Reader::BookMVA( const TString& methodTag, const TString& weightfile )
{
   // a tag may be booked only once
   if (fMethodMap.find( methodTag ) != fMethodMap.end()) {
      Log() << kFATAL << "<BookMVA> method tag \"" << methodTag << "\" already exists!" << Endl;
   }

   Log() << kINFO << "Booking \"" << methodTag << "\" [" << weightfile << "]" << Endl;

   TString fullMethodName("");
   if (weightfile.EndsWith(".xml")) {
      void* doc = gTools().xmlengine().ParseFile( weightfile );
      if (doc == 0) {
         Log() << kFATAL << "<BookMVA> fatal error: "
               << "unable to parse XML weight file: " << weightfile << Endl;
         return 0;
      }
      void* rootnode = gTools().xmlengine().DocGetRootElement( doc );
      gTools().ReadAttr( rootnode, "Method", fullMethodName );
      // only the "Method" attribute is needed here, so release the parsed document
      gTools().xmlengine().FreeDoc( doc );
   }
   else {
      std::ifstream fin( weightfile.Data() );
      if (!fin.good()) {
         Log() << kFATAL << "<BookMVA> fatal error: "
               << "unable to open input weight file: " << weightfile << Endl;
         return 0;
      }

      // Scan for the first line starting with "Method". The loop stops at end of
      // file, and std::getline has no fixed line-length limit, so a file without
      // such a line (or with very long lines) cannot hang the reader.
      std::string line;
      Bool_t found = kFALSE;
      while (std::getline( fin, line )) {
         if (line.compare( 0, 6, "Method" ) == 0) {
            found = kTRUE;
            break;
         }
      }
      if (!found) {
         Log() << kFATAL << "<BookMVA> fatal error: "
               << "no \"Method\" line found in weight file: " << weightfile << Endl;
         return 0;
      }
      fullMethodName = line.c_str();
      // fin is closed by its destructor
   }

   // "<type>::<title>" -> "<type>"
   const Ssiz_t sep = fullMethodName.Index("::");
   if (sep < 0) {
      Log() << kFATAL << "<BookMVA> fatal error: "
            << "cannot extract method type from \"" << fullMethodName
            << "\" in weight file: " << weightfile << Endl;
      return 0;
   }
   TString methodType = fullMethodName( 0, sep );
   // text files carry a "Method ... : " prefix; keep only the last word
   if (methodType.Contains(" ")) methodType = methodType( methodType.Last(' ')+1, methodType.Length() );

   MethodBase* method = dynamic_cast<MethodBase*>( this->BookMVA( Types::Instance().GetMethodType( methodType ),
                                                                  weightfile ) );
   return fMethodMap[methodTag] = method;
}
// Create a method of the given type through the ClassifierFactory, set it up and
// restore its state from "weightfile".
//
// methodType : TMVA method type to instantiate
// weightfile : weight file passed to the factory
//
// Returns the created method. The method is NOT entered into fMethodMap here; that
// is done by the tag-based BookMVA overload.
// If the factory result is null or is not a MethodBase, a kFATAL message is logged
// and the raw factory result is returned instead of dereferencing a null pointer.
TMVA::IMethod* TMVA::Reader::BookMVA( TMVA::Types::EMVA methodType, const TString& weightfile )
{
   IMethod* im = ClassifierFactory::Instance().Create( std::string( Types::Instance().GetMethodName( methodType ) ),
                                                       DataInfo(), weightfile );

   MethodBase* method = dynamic_cast<MethodBase*>( im );
   if (method == 0) {
      Log() << kFATAL << "<BookMVA> fatal error: "
            << "could not create a method of type \"" << Types::Instance().GetMethodName( methodType )
            << "\" for weight file: " << weightfile << Endl;
      // NOTE(review): only reached if logging at kFATAL returns - confirm
      return im;
   }

   method->SetupMethod();
   method->ReadStateFromFile();

   Log() << kINFO << "Booked classifier \"" << method->GetMethodName()
         << "\" of type: \"" << method->GetMethodTypeName() << "\"" << Endl;

#ifdef TMVA_Reader_TestIO__
   // debug aid: write the state just read back to a control file
   std::ofstream tfile( weightfile+".control" );
   method->WriteStateToStream(tfile);
   tfile.close();
#endif

   return method;
}
// Evaluate the method booked under "methodTag".
// The Float_t vector argument is unnamed and not used: this overload simply
// forwards to EvaluateMVA( methodTag, aux ).
Double_t TMVA::Reader::EvaluateMVA( const std::vector<Float_t>& /* unused */, const TString& methodTag, Double_t aux )
{
   const Double_t mvaValue = EvaluateMVA( methodTag, aux );
   return mvaValue;
}
// Evaluate the method booked under "methodTag".
// The Double_t vector argument is unnamed and not used: this overload simply
// forwards to EvaluateMVA( methodTag, aux ).
Double_t TMVA::Reader::EvaluateMVA( const std::vector<Double_t>& /* unused */, const TString& methodTag, Double_t aux )
{
   const Double_t mvaValue = EvaluateMVA( methodTag, aux );
   return mvaValue;
}
// Look up the method booked under "methodTag" and evaluate it.
//
// methodTag : tag given to BookMVA
// aux       : forwarded unchanged to EvaluateMVA( MethodBase*, aux )
//
// If the tag is unknown, the available tags are listed and a kFATAL message is
// logged; the MethodBase overload is then still called, with a null pointer.
Double_t TMVA::Reader::EvaluateMVA( const TString& methodTag, Double_t aux )
{
   std::map<TString, IMethod*>::iterator entry = fMethodMap.find( methodTag );

   IMethod* method = 0;
   if (entry != fMethodMap.end()) {
      method = entry->second;
   }
   else {
      Log() << kINFO << "<EvaluateMVA> unknown classifier in map; "
            << "you looked for \"" << methodTag << "\" within available methods: " << Endl;
      for (entry = fMethodMap.begin(); entry != fMethodMap.end(); ++entry) {
         Log() << " --> " << entry->first << Endl;
      }
      Log() << "Check calling string" << kFATAL << Endl;
   }

   return this->EvaluateMVA( dynamic_cast<TMVA::MethodBase*>(method), aux );
}
// Evaluate the given method and return its MVA value; the associated error is
// written to fMvaEventError.
//
// method : method to evaluate (must not be null)
// aux    : used only for methods of type kCuts, where it is passed to
//          MethodCuts::SetTestSignalEfficiency before the evaluation
Double_t TMVA::Reader::EvaluateMVA( MethodBase* method, Double_t aux )
{
   if (method->GetMethodType() == TMVA::Types::kCuts) {
      // guard the downcast: a method reporting kCuts that is not a MethodCuts
      // must not be dereferenced through a null pointer
      TMVA::MethodCuts* cuts = dynamic_cast<TMVA::MethodCuts*>(method);
      if (cuts != 0) cuts->SetTestSignalEfficiency( aux );
   }

   return method->GetMvaValue( &fMvaEventError );
}
// Look up the method booked under "methodTag" and return its regression values.
//
// methodTag : tag given to BookMVA
// aux       : forwarded unchanged to EvaluateRegression( MethodBase*, aux )
//
// If the tag is unknown, the available tags are listed and a kFATAL message is
// logged; the MethodBase overload is then still called, with a null pointer.
const std::vector< Float_t >& TMVA::Reader::EvaluateRegression( const TString& methodTag, Double_t aux )
{
   IMethod* method = 0;

   std::map<TString, IMethod*>::iterator it = fMethodMap.find( methodTag );
   if (it == fMethodMap.end()) {
      // the diagnostic names this function (it used to say "<EvaluateMVA>")
      Log() << kINFO << "<EvaluateRegression> unknown method in map; "
            << "you looked for \"" << methodTag << "\" within available methods: " << Endl;
      for (it = fMethodMap.begin(); it != fMethodMap.end(); ++it) Log() << " --> " << it->first << Endl;
      Log() << "Check calling string" << kFATAL << Endl;
   }
   else method = it->second;

   return this->EvaluateRegression( dynamic_cast<TMVA::MethodBase*>(method), aux );
}
// Return the regression values of the given method, as provided by
// MethodBase::GetRegressionValues. The Double_t argument is unnamed and not used.
const std::vector< Float_t >& TMVA::Reader::EvaluateRegression( MethodBase* method, Double_t /* unused */ )
{
   return method->GetRegressionValues();
}
// Return the regression value for a single target.
//
// tgtNumber : index into the vector returned by EvaluateRegression( methodTag, aux )
// methodTag : tag given to BookMVA
// aux       : forwarded unchanged
//
// If tgtNumber is outside the returned vector, a warning is logged and 0 is returned.
Float_t TMVA::Reader::EvaluateRegression( UInt_t tgtNumber, const TString& methodTag, Double_t aux )
{
   try {
      return EvaluateRegression( methodTag, aux ).at( tgtNumber );
   }
   catch (const std::out_of_range&) {
      // caught by const reference: no copy of the exception, no unused variable
      Log() << kWARNING << "Regression could not be evaluated for target-number " << tgtNumber << Endl;
      return 0;
   }
}
// Return the method booked under "methodTag".
// If no such tag exists, an error is logged and 0 is returned.
TMVA::IMethod* TMVA::Reader::FindMVA( const TString& methodTag )
{
   std::map<TString, IMethod*>::iterator found = fMethodMap.find( methodTag );

   if (found == fMethodMap.end()) {
      Log() << kERROR << "Method " << methodTag << " not found!" << Endl;
      return 0;
   }

   return found->second;
}
// Return the method booked under "methodTag" as a MethodCuts.
// The result is 0 if FindMVA returns 0 or if the method is not a MethodCuts.
TMVA::MethodCuts* TMVA::Reader::FindCutsMVA( const TString& methodTag )
{
   IMethod* booked = FindMVA( methodTag );
   return dynamic_cast<MethodCuts*>( booked );
}
// Return MethodBase::GetProba for the method booked under "methodTag".
//
// methodTag : tag given to BookMVA
// ap_sig    : second argument passed to MethodBase::GetProba
// mvaVal    : MVA value to convert; the sentinel -9999999 means "take the
//             method's current GetMvaValue()"
//
// If the tag is unknown, the available tags are listed and a kFATAL message is
// logged. Returns -1 if the tag is unknown or the stored method is not a MethodBase.
// NOTE(review): the "return -1" after the kFATAL message only matters if logging at
// kFATAL returns control to the caller - confirm the logger's kFATAL behaviour.
Double_t TMVA::Reader::GetProba( const TString& methodTag, Double_t ap_sig, Double_t mvaVal )
{
   std::map<TString, IMethod*>::iterator it = fMethodMap.find( methodTag );
   if (it == fMethodMap.end()) {
      // same layout as the EvaluateMVA diagnostic: header, list of tags, fatal line
      Log() << kINFO << "<GetProba> unknown classifier in map; "
            << "you looked for \"" << methodTag << "\" within available methods: " << Endl;
      for (it = fMethodMap.begin(); it != fMethodMap.end(); ++it) Log() << " --> " << it->first << Endl;
      Log() << "Check calling string" << kFATAL << Endl;
      return -1;
   }

   MethodBase* kl = dynamic_cast<MethodBase*>( it->second );
   if (kl == 0) {
      Log() << kERROR << "<GetProba> method \"" << methodTag << "\" is not a MethodBase" << Endl;
      return -1;
   }

   if (mvaVal == -9999999) mvaVal = kl->GetMvaValue();

   return kl->GetProba( mvaVal, ap_sig );
}
// Return MethodBase::GetRarity for the method booked under "methodTag".
//
// methodTag : tag given to BookMVA
// mvaVal    : MVA value to convert; the sentinel -9999999 means "take the
//             method's current GetMvaValue()"
//
// If the tag is unknown, the available tags are listed and a kFATAL message is
// logged. Returns -1 if the tag is unknown or the stored method is not a MethodBase.
// NOTE(review): the "return -1" after the kFATAL message only matters if logging at
// kFATAL returns control to the caller - confirm the logger's kFATAL behaviour.
Double_t TMVA::Reader::GetRarity( const TString& methodTag, Double_t mvaVal )
{
   std::map<TString, IMethod*>::iterator it = fMethodMap.find( methodTag );
   if (it == fMethodMap.end()) {
      // same layout as the EvaluateMVA diagnostic: header, list of tags, fatal line
      Log() << kINFO << "<GetRarity> unknown classifier in map; "
            << "you looked for \"" << methodTag << "\" within available methods: " << Endl;
      for (it = fMethodMap.begin(); it != fMethodMap.end(); ++it) Log() << " --> " << it->first << Endl;
      Log() << "Check calling string" << kFATAL << Endl;
      return -1;
   }

   MethodBase* kl = dynamic_cast<MethodBase*>( it->second );
   if (kl == 0) {
      Log() << kERROR << "<GetRarity> method \"" << methodTag << "\" is not a MethodBase" << Endl;
      return -1;
   }

   if (mvaVal == -9999999) mvaVal = kl->GetMvaValue();

   return kl->GetRarity( mvaVal );
}
// Split a ':'-separated list of variable names and declare each piece through
// DataInfo().AddVariable, in order of appearance.
// An empty input declares nothing; an empty piece (leading, trailing or doubled
// ':') is passed on as an empty name.
void TMVA::Reader::DecodeVarNames( const std::string& varNames )
{
   const size_t total = varNames.length();
   size_t start = 0;   // first character of the current piece
   size_t stop  = 0;   // position of the separator that ends it (or end of string)

   while (stop != total) {
      stop = varNames.find( ':', start );
      if (stop == std::string::npos) stop = total;   // last piece runs to the end

      const std::string piece = varNames.substr( start, stop - start );
      start = stop + 1;

      DataInfo().AddVariable( piece.c_str() );
   }
}
void TMVA::Reader::DecodeVarNames( const TString& varNames )
{
TString format;
Int_t n = varNames.Length();
TString format_obj;
for (int i=0; i< n+1 ; i++) {
format.Append(varNames(i));
if (varNames(i) == ':' || i == n) {
format.Chop();
format_obj = format;
format_obj.ReplaceAll("@","");
DataInfo().AddVariable( format_obj );
format.Resize(0);
}
}
}