#include "TMVA/Reader.h"
#include "TTree.h"
#include "TLeaf.h"
#include "TString.h"
#include "TClass.h"
#include "TH1D.h"
#include "TKey.h"
#include "TVector.h"
#include "TXMLEngine.h"
#include "TMath.h"
#include <cstdlib>
#include <fstream>
#include <iostream>
#include <map>
#include <stdexcept>
#include <string>
#include <vector>
#ifndef ROOT_TMVA_Tools
#include "TMVA/Tools.h"
#endif
#include "TMVA/Config.h"
#include "TMVA/ClassifierFactory.h"
#include "TMVA/IMethod.h"
#include "TMVA/MethodCuts.h"
#include "TMVA/MethodCategory.h"
#include "TMVA/DataSetManager.h"
// ROOT class-implementation macro for TMVA::Reader (presumably paired with a
// ClassDef in TMVA/Reader.h — confirm against the header).
ClassImp(TMVA::Reader)
//_______________________________________________________________________
// Reader constructor without input variables; variables can be added
// afterwards with AddVariable()/AddSpectator().
//   theOption : option string; the options are declared in DeclareOptions()
//               ("V", "Color", "Silent", "Error") and parsed by ParseOptions()
//   verbose   : initial value of fVerbose (the "V" option is bound to the same
//               flag, so presumably an explicit "V" in theOption overrides it)
TMVA::Reader::Reader( const TString& theOption, Bool_t verbose )
: Configurable( theOption ),
fDataSetManager( NULL ),
fDataSetInfo(),
fVerbose( verbose ),
fSilent ( kFALSE ),
fColor ( kFALSE ),
fCalculateError(kFALSE),
fMvaEventError( 0 ),
fMvaEventErrorUpper( 0 ),
fLogger ( 0 )
{
// the data-set manager is handed this reader's own DataSetInfo
fDataSetManager = new DataSetManager( fDataInputHandler );
fDataSetManager->AddDataSetInfo(fDataSetInfo);
// the logger must exist before DeclareOptions(), which calls Log()
fLogger = new MsgLogger(this);
SetConfigName( GetName() );
DeclareOptions();
ParseOptions();
// apply the parsed flags (verbosity, colour, silence)
Init();
}
//_______________________________________________________________________
// Reader constructor with a list of input-variable expressions.
//   inputVars : variable names/expressions, registered in the given order
//   theOption : option string (see DeclareOptions()), parsed by ParseOptions()
//   verbose   : initial value of fVerbose (presumably overridden by an explicit "V")
TMVA::Reader::Reader( std::vector<TString>& inputVars, const TString& theOption, Bool_t verbose )
: Configurable( theOption ),
fDataSetManager( NULL ),
fDataSetInfo(),
fVerbose( verbose ),
fSilent ( kFALSE ),
fColor ( kFALSE ),
fCalculateError(kFALSE),
fMvaEventError( 0 ),
fMvaEventErrorUpper( 0 ),
fLogger ( 0 )
{
// the data-set manager is handed this reader's own DataSetInfo
fDataSetManager = new DataSetManager( fDataInputHandler );
fDataSetManager->AddDataSetInfo(fDataSetInfo);
// the logger must exist before DeclareOptions(), which calls Log()
fLogger = new MsgLogger(this);
SetConfigName( GetName() );
DeclareOptions();
ParseOptions();
// register every name, in order, as an input variable
for (std::vector<TString>::iterator ivar = inputVars.begin(); ivar != inputVars.end(); ivar++)
DataInfo().AddVariable( *ivar );
// apply the parsed flags (verbosity, colour, silence)
Init();
}
//_______________________________________________________________________
// Reader constructor with a list of input-variable expressions given as
// std::string; identical to the std::vector<TString> variant otherwise.
//   inputVars : variable names/expressions, registered in the given order
//   theOption : option string (see DeclareOptions()), parsed by ParseOptions()
//   verbose   : initial value of fVerbose (presumably overridden by an explicit "V")
TMVA::Reader::Reader( std::vector<std::string>& inputVars, const TString& theOption, Bool_t verbose )
: Configurable( theOption ),
fDataSetManager( NULL ),
fDataSetInfo(),
fVerbose( verbose ),
fSilent ( kFALSE ),
fColor ( kFALSE ),
fCalculateError(kFALSE),
fMvaEventError( 0 ),
fMvaEventErrorUpper( 0 ),
fLogger ( 0 )
{
// the data-set manager is handed this reader's own DataSetInfo
fDataSetManager = new DataSetManager( fDataInputHandler );
fDataSetManager->AddDataSetInfo(fDataSetInfo);
// the logger must exist before DeclareOptions(), which calls Log()
fLogger = new MsgLogger(this);
SetConfigName( GetName() );
DeclareOptions();
ParseOptions();
// register every name, in order, as an input variable
for (std::vector<std::string>::iterator ivar = inputVars.begin(); ivar != inputVars.end(); ivar++)
DataInfo().AddVariable( ivar->c_str() );
// apply the parsed flags (verbosity, colour, silence)
Init();
}
//_______________________________________________________________________
// Reader constructor with the input variables given as one ':'-separated
// std::string (e.g. "var1:var2:var3"), split by DecodeVarNames().
//   theOption : option string (see DeclareOptions()), parsed by ParseOptions()
//   verbose   : initial value of fVerbose (presumably overridden by an explicit "V")
TMVA::Reader::Reader( const std::string& varNames, const TString& theOption, Bool_t verbose )
: Configurable( theOption ),
fDataSetManager( NULL ),
fDataSetInfo(),
fVerbose( verbose ),
fSilent ( kFALSE ),
fColor ( kFALSE ),
fCalculateError(kFALSE),
fMvaEventError( 0 ),
fMvaEventErrorUpper( 0 ),
fLogger ( 0 )
{
// the data-set manager is handed this reader's own DataSetInfo
fDataSetManager = new DataSetManager( fDataInputHandler );
fDataSetManager->AddDataSetInfo(fDataSetInfo);
// the logger must exist before DeclareOptions(), which calls Log()
fLogger = new MsgLogger(this);
SetConfigName( GetName() );
DeclareOptions();
ParseOptions();
// split "a:b:c" and register each field as an input variable
DecodeVarNames(varNames);
// apply the parsed flags (verbosity, colour, silence)
Init();
}
//_______________________________________________________________________
// Reader constructor with the input variables given as one ':'-separated
// TString (e.g. "var1:var2:var3"), split by DecodeVarNames(const TString&),
// which also removes '@' characters from each name.
//   theOption : option string (see DeclareOptions()), parsed by ParseOptions()
//   verbose   : initial value of fVerbose (presumably overridden by an explicit "V")
TMVA::Reader::Reader( const TString& varNames, const TString& theOption, Bool_t verbose )
: Configurable( theOption ),
fDataSetManager( NULL ),
fDataSetInfo(),
fVerbose( verbose ),
fSilent ( kFALSE ),
fColor ( kFALSE ),
fCalculateError(kFALSE),
fMvaEventError( 0 ),
fMvaEventErrorUpper( 0 ),
fLogger ( 0 )
{
// the data-set manager is handed this reader's own DataSetInfo
fDataSetManager = new DataSetManager( fDataInputHandler );
fDataSetManager->AddDataSetInfo(fDataSetInfo);
// the logger must exist before DeclareOptions(), which calls Log()
fLogger = new MsgLogger(this);
SetConfigName( GetName() );
DeclareOptions();
ParseOptions();
// split "a:b:c" and register each field as an input variable
DecodeVarNames(varNames);
// apply the parsed flags (verbosity, colour, silence)
Init();
}
//_______________________________________________________________________
// Declare the options understood by the Reader. Each option is bound to a
// member flag that the constructors initialise before calling this method
// (fVerbose from the 'verbose' argument, all others to kFALSE).
void TMVA::Reader::DeclareOptions()
{
   // mute this reader's logger immediately if the option string asks for silence
   if (gTools().CheckForSilentOption( GetOptions() )) Log().InhibitOutput();

   DeclareOptionRef( fVerbose,        "V",      "Verbose flag" );
   // every constructor initialises fColor to kFALSE, so the advertised
   // default must be False (the help text used to claim True)
   DeclareOptionRef( fColor,          "Color",  "Color flag (default False)" );
   DeclareOptionRef( fSilent,         "Silent", "Boolean silent flag (default False)" );
   DeclareOptionRef( fCalculateError, "Error",  "Calculates errors (default False)" );
}
//_______________________________________________________________________
// Destructor: the reader owns every booked method as well as its
// data-set manager and logger, and releases them here.
TMVA::Reader::~Reader( void )
{
   for (std::map<TString, IMethod*>::iterator entry = fMethodMap.begin();
        entry != fMethodMap.end(); ++entry)
      delete entry->second;
   fMethodMap.clear();

   delete fDataSetManager;
   delete fLogger;
}
void TMVA::Reader::Init( void )
{
if (Verbose()) fLogger->SetMinType( kVERBOSE );
gConfig().SetUseColor( fColor );
gConfig().SetSilent ( fSilent );
}
//_______________________________________________________________________
// Register a float input variable 'expression'. The caller's storage is
// passed on as the variable's external link ('F' type), presumably read
// by the reader whenever an MVA is evaluated.
void TMVA::Reader::AddVariable( const TString& expression, Float_t* datalink )
{
   void* externalLink = static_cast<void*>( datalink );
   DataInfo().AddVariable( expression, "", "", 0, 0, 'F', kFALSE, externalLink );
}
//_______________________________________________________________________
// Deprecated integer overload: all reader inputs must be floats. The
// deprecation is reported once as kFATAL (it used to be emitted twice).
void TMVA::Reader::AddVariable( const TString& expression, Int_t* datalink )
{
   Log() << kFATAL << "Reader::AddVariable( const TString& expression, Int_t* datalink ), this function is deprecated, please provide all variables to the reader as floats" << Endl;
   // only reached if the kFATAL message does not abort (presumably
   // depends on the MsgLogger configuration)
   DataInfo().AddVariable(expression, "", "", 0, 0, 'I', kFALSE, (void*)datalink );
}
//_______________________________________________________________________
// Register a float spectator 'expression' linked to the caller's storage
// ('F' type external link).
void TMVA::Reader::AddSpectator( const TString& expression, Float_t* datalink )
{
   void* externalLink = static_cast<void*>( datalink );
   DataInfo().AddSpectator( expression, "", "", 0, 0, 'F', kFALSE, externalLink );
}
//_______________________________________________________________________
// Register an integer spectator 'expression' linked to the caller's storage
// ('I' type external link).
void TMVA::Reader::AddSpectator( const TString& expression, Int_t* datalink )
{
   void* externalLink = static_cast<void*>( datalink );
   DataInfo().AddSpectator( expression, "", "", 0, 0, 'I', kFALSE, externalLink );
}
//_______________________________________________________________________
// Read the method type name from a weight file.
//   ".xml" files : taken from the "Method" attribute of the XML root node
//   other files  : taken from the first text line starting with "Method"
// The method type is the part of that entry before "::", with everything up
// to the last blank removed. An unreadable file, an unparsable XML file, or a
// text file without a "Method" line is reported as kFATAL.
TString TMVA::Reader::GetMethodTypeFromFile( const TString& filename )
{
   std::ifstream fin( filename );
   if (!fin.good()) {
      Log() << kFATAL << "<BookMVA> fatal error: "
            << "unable to open input weight file: " << filename << Endl;
   }

   TString fullMethodName("");
   if (filename.EndsWith(".xml")) {
      fin.close();
#if ROOT_VERSION_CODE >= ROOT_VERSION(5,29,0)
      void* doc = gTools().xmlengine().ParseFile(filename,gTools().xmlenginebuffersize());
#else
      void* doc = gTools().xmlengine().ParseFile(filename);
#endif
      // a failed parse yields a null document; never walk it
      if (doc == 0) {
         Log() << kFATAL << "<BookMVA> fatal error: "
               << "unable to parse XML weight file: " << filename << Endl;
      }
      void* rootnode = gTools().xmlengine().DocGetRootElement(doc);
      gTools().ReadAttr(rootnode, "Method", fullMethodName);
      gTools().xmlengine().FreeDoc(doc);
   }
   else {
      // Scan with std::getline so that the loop ends at EOF and is not
      // limited by a fixed buffer. The former 512-byte getline loop spun
      // forever when no "Method" line existed or a line was too long,
      // because a failed read leaves the stream stuck.
      std::string line;
      Bool_t found = kFALSE;
      while (std::getline(fin, line)) {
         if (TString(line.c_str()).BeginsWith("Method")) {
            found = kTRUE;
            break;
         }
      }
      fin.close();
      if (!found) {
         Log() << kFATAL << "<BookMVA> fatal error: "
               << "no \"Method\" entry found in weight file: " << filename << Endl;
      }
      fullMethodName = TString(line.c_str());
   }

   // keep the text before "::", then drop everything up to the last blank
   TString methodType = fullMethodName(0,fullMethodName.Index("::"));
   if (methodType.Contains(" ")) methodType = methodType(methodType.Last(' ')+1,methodType.Length());
   return methodType;
}
//_______________________________________________________________________
// Book a method from 'weightfile' under the user-chosen 'methodTag'.
// The method type is read from the weight file itself. The booked object is
// stored in fMethodMap, which owns it (see the destructor), and is returned.
// A duplicate tag is reported as kFATAL.
TMVA::IMethod* TMVA::Reader::BookMVA( const TString& methodTag, const TString& weightfile )
{
   if (fMethodMap.find( methodTag ) != fMethodMap.end())
      Log() << kFATAL << "<BookMVA> method tag \"" << methodTag << "\" already exists!" << Endl;

   TString methodType(GetMethodTypeFromFile(weightfile));

   Log() << kINFO << "Booking \"" << methodTag << "\" of type \"" << methodType << "\" from " << weightfile << "." << Endl;

   IMethod* im = this->BookMVA( Types::Instance().GetMethodType(methodType), weightfile );

   MethodBase* method = dynamic_cast<MethodBase*>(im);
   if (method && method->GetMethodType() == Types::kCategory) {
      MethodCategory* methCat = dynamic_cast<MethodCategory*>(method);
      if (!methCat)
         Log() << kFATAL << "Method with type kCategory cannot be casted to MethodCategory. /Reader" << Endl;
      else
         methCat->fDataSetManager = fDataSetManager;
   }

   // Store the created object itself, not the MethodBase cast: storing the
   // cast put a null entry in the map for a non-MethodBase IMethod and leaked
   // the object, since the destructor only deletes what is in the map.
   return fMethodMap[methodTag] = im;
}
//_______________________________________________________________________
// Create a method of type 'methodType' through the ClassifierFactory and
// initialise it from 'weightfile'. If the created object is not a
// MethodBase, it is returned as is without any setup. The caller takes
// ownership of the returned object.
TMVA::IMethod* TMVA::Reader::BookMVA( TMVA::Types::EMVA methodType, const TString& weightfile )
{
   IMethod* im = ClassifierFactory::Instance().Create(std::string(Types::Instance().GetMethodName( methodType )),
                                                      DataInfo(), weightfile );

   MethodBase *method = (dynamic_cast<MethodBase*>(im));

   if (method==0) return im;

   if( method->GetMethodType() == Types::kCategory ){
      MethodCategory *methCat = (dynamic_cast<MethodCategory*>(method));
      // Previously this logged kERROR and then dereferenced the null pointer.
      // Now it is fatal, as in the other BookMVA overloads, and the
      // assignment is guarded.
      if( !methCat )
         Log() << kFATAL << "Method with type kCategory cannot be casted to MethodCategory. /Reader" << Endl;
      else
         methCat->fDataSetManager = fDataSetManager;   // handed over before SetupMethod()
   }

   method->SetupMethod();

   // when reading older weight files, they could include options
   // that are not supported any longer
   method->DeclareCompatibilityOptions();

   // read weight file
   method->ReadStateFromFile();

   // check for unused options
   method->CheckSetup();

   Log() << kINFO << "Booked classifier \"" << method->GetMethodName()
         << "\" of type: \"" << method->GetMethodTypeName() << "\"" << Endl;

   return method;
}
//_______________________________________________________________________
// Create a method of type 'methodType' and initialise it from an in-memory
// XML weight description 'xmlstr'. Returns 0 when the created object is not
// a MethodBase. This overload is only available for ROOT >= 5.26/00.
TMVA::IMethod* TMVA::Reader::BookMVA( TMVA::Types::EMVA methodType, const char* xmlstr )
{
#if (ROOT_SVN_REVISION >= 32259) && (ROOT_VERSION_CODE >= 334336) // 5.26/00

   // no weight file name: the state is read from xmlstr below
   IMethod* im = ClassifierFactory::Instance().Create(std::string(Types::Instance().GetMethodName( methodType )),
                                                      DataInfo(), "" );

   MethodBase *method = (dynamic_cast<MethodBase*>(im));

   if (!method) {
      // Nothing else holds this pointer, so returning 0 without freeing it
      // would leak the object.
      delete im;
      return 0;
   }

   if( method->GetMethodType() == Types::kCategory ){
      MethodCategory *methCat = (dynamic_cast<MethodCategory*>(method));
      if( !methCat )
         Log() << kFATAL << "Method with type kCategory cannot be casted to MethodCategory. /Reader" << Endl;
      else
         methCat->fDataSetManager = fDataSetManager;   // handed over before SetupMethod()
   }

   method->SetupMethod();

   // when reading older weight files, they could include options
   // that are not supported any longer
   method->DeclareCompatibilityOptions();

   // read weight description from the string
   method->ReadStateFromXMLString( xmlstr );

   // check for unused options
   method->CheckSetup();

   Log() << kINFO << "Booked classifier \"" << method->GetMethodName()
         << "\" of type: \"" << method->GetMethodTypeName() << "\"" << Endl;

   return method;
#else
   Log() << kFATAL << "Method Reader::BookMVA(TMVA::Types::EMVA methodType = " << methodType
         << ", const char* xmlstr = " << xmlstr
         << " ) is not available for ROOT versions prior to 5.26/00." << Endl;
   return 0;
#endif
}
//_______________________________________________________________________
// Evaluate the method booked as 'methodTag' on the explicit input values
// 'inputVec' (presumably one entry per declared input variable, in
// declaration order; the size is not checked here).
//   aux : signal efficiency passed to MethodCuts; ignored by other methods
// Returns 0 if the tag is unknown or not a MethodBase, and -999 if any input is NaN.
Double_t TMVA::Reader::EvaluateMVA( const std::vector<Float_t>& inputVec, const TString& methodTag, Double_t aux )
{
   IMethod* imeth = FindMVA( methodTag );
   MethodBase* meth = dynamic_cast<TMVA::MethodBase*>(imeth);
   if (meth == 0) return 0;

   // reject NaN inputs before any event is built
   for (UInt_t i = 0; i < inputVec.size(); i++) {
      if (TMath::IsNaN(inputVec[i])) {
         Log() << kERROR << i << "-th variable of the event is NaN --> return MVA value -999, \n that's all I can do, please fix or remove this event." << Endl;
         return -999;
      }
   }

   if (meth->GetMethodType() == TMVA::Types::kCuts) {
      TMVA::MethodCuts* mc = dynamic_cast<TMVA::MethodCuts*>(meth);
      if (mc)
         mc->SetTestSignalEfficiency( aux );
   }

   // Automatic storage: the temporary event is released even if
   // GetMvaValue throws. The former new/delete pair leaked it in that case.
   Event tmpEvent( inputVec, DataInfo().GetNVariables() );
   return meth->GetMvaValue( &tmpEvent, (fCalculateError?&fMvaEventError:0) );
}
//_______________________________________________________________________
// Double-precision convenience overload: converts the inputs into the
// reusable float buffer fTmpEvalVec and forwards to the float overload.
Double_t TMVA::Reader::EvaluateMVA( const std::vector<Double_t>& inputVec, const TString& methodTag, Double_t aux )
{
   fTmpEvalVec.assign( inputVec.begin(), inputVec.end() );
   return EvaluateMVA( fTmpEvalVec, methodTag, aux );
}
//_______________________________________________________________________
// Evaluate the method booked as 'methodTag' on its current event.
// An unknown tag or a non-MethodBase entry is reported as kFATAL. An event
// with a NaN input yields -999. 'aux' is forwarded to EvaluateMVA(MethodBase*, Double_t).
Double_t TMVA::Reader::EvaluateMVA( const TString& methodTag, Double_t aux )
{
   std::map<TString, IMethod*>::iterator entry = fMethodMap.find( methodTag );
   const Bool_t known = (entry != fMethodMap.end());
   IMethod* imethod = known ? entry->second : 0;

   if (!known) {
      Log() << kINFO << "<EvaluateMVA> unknown classifier in map; "
            << "you looked for \"" << methodTag << "\" within available methods: " << Endl;
      for (entry = fMethodMap.begin(); entry != fMethodMap.end(); ++entry)
         Log() << " --> " << entry->first << Endl;
      Log() << "Check calling string" << kFATAL << Endl;
   }

   MethodBase* base = dynamic_cast<TMVA::MethodBase*>( imethod );
   if (base == 0)
      Log() << kFATAL << methodTag << " is not a method" << Endl;

   const Event* current = base->GetEvent();
   const UInt_t nVars = current->GetNVariables();
   for (UInt_t ivar = 0; ivar < nVars; ++ivar) {
      if (!TMath::IsNaN( current->GetValue(ivar) )) continue;
      Log() << kERROR << ivar << "-th variable of the event is NaN --> return MVA value -999, \n that's all I can do, please fix or remove this event." << Endl;
      return -999;
   }

   return this->EvaluateMVA( base, aux );
}
//_______________________________________________________________________
// Evaluate 'method' on its current event. For MethodCuts, 'aux' sets the
// test signal efficiency first. When error calculation is enabled
// ("Error" option), the lower and upper errors are written to
// fMvaEventError and fMvaEventErrorUpper.
Double_t TMVA::Reader::EvaluateMVA( MethodBase* method, Double_t aux )
{
   Double_t* errLower = fCalculateError ? &fMvaEventError      : 0;
   Double_t* errUpper = fCalculateError ? &fMvaEventErrorUpper : 0;

   if (method->GetMethodType() == TMVA::Types::kCuts) {
      TMVA::MethodCuts* cuts = dynamic_cast<TMVA::MethodCuts*>( method );
      if (cuts != 0) cuts->SetTestSignalEfficiency( aux );
   }

   return method->GetMvaValue( errLower, errUpper );
}
//_______________________________________________________________________
// Regression targets of the method booked as 'methodTag' for its current
// event. An unknown tag or a non-MethodBase entry is reported as kFATAL.
const std::vector< Float_t >& TMVA::Reader::EvaluateRegression( const TString& methodTag, Double_t aux )
{
   IMethod* method = 0;
   std::map<TString, IMethod*>::iterator it = fMethodMap.find( methodTag );
   if (it == fMethodMap.end()) {
      Log() << kINFO << "<EvaluateRegression> unknown method in map; "
            << "you looked for \"" << methodTag << "\" within available methods: " << Endl;
      for (it = fMethodMap.begin(); it!=fMethodMap.end(); it++) Log() << " --> " << it->first << Endl;
      Log() << "Check calling string" << kFATAL << Endl;
   }
   else method = it->second;

   MethodBase * kl = dynamic_cast<TMVA::MethodBase*>(method);

   if(kl==0)
      Log() << kFATAL << methodTag << " is not a method" << Endl;

   // The NaN check is done by EvaluateRegression(MethodBase*, Double_t);
   // repeating it here printed every warning twice.
   return this->EvaluateRegression( kl, aux );
}
//_______________________________________________________________________
// Regression values of 'method' for its current event. NaN inputs only
// produce a warning; the values are returned regardless.
const std::vector< Float_t >& TMVA::Reader::EvaluateRegression( MethodBase* method, Double_t )
{
   const Event* current = method->GetEvent();
   const UInt_t nVars = current->GetNVariables();
   for (UInt_t ivar = 0; ivar < nVars; ++ivar) {
      if (!TMath::IsNaN( current->GetValue(ivar) )) continue;
      Log() << kERROR << ivar << "-th variable of the event is NaN, \n regression values might evaluate to .. what do I know. \n sorry this warning is all I can do, please fix or remove this event." << Endl;
   }
   return method->GetRegressionValues();
}
//_______________________________________________________________________
// Single regression target 'tgtNumber' of the method booked as 'methodTag'.
// An out-of-range target number logs a warning and returns 0.
Float_t TMVA::Reader::EvaluateRegression( UInt_t tgtNumber, const TString& methodTag, Double_t aux )
{
   try {
      return EvaluateRegression(methodTag, aux).at(tgtNumber);
   }
   catch (const std::out_of_range&) {   // by const reference: no copy of the exception
      Log() << kWARNING << "Regression could not be evaluated for target-number " << tgtNumber << Endl;
      return 0;
   }
}
//_______________________________________________________________________
// Multiclass response values of the method booked as 'methodTag' for its
// current event. An unknown tag or a non-MethodBase entry is reported as kFATAL.
const std::vector< Float_t >& TMVA::Reader::EvaluateMulticlass( const TString& methodTag, Double_t aux )
{
   IMethod* method = 0;
   std::map<TString, IMethod*>::iterator it = fMethodMap.find( methodTag );
   if (it == fMethodMap.end()) {
      Log() << kINFO << "<EvaluateMulticlass> unknown method in map; "
            << "you looked for \"" << methodTag << "\" within available methods: " << Endl;
      for (it = fMethodMap.begin(); it!=fMethodMap.end(); it++) Log() << " --> " << it->first << Endl;
      Log() << "Check calling string" << kFATAL << Endl;
   }
   else method = it->second;

   MethodBase * kl = dynamic_cast<TMVA::MethodBase*>(method);

   if(kl==0)
      Log() << kFATAL << methodTag << " is not a method" << Endl;

   // The NaN check is done by EvaluateMulticlass(MethodBase*, Double_t);
   // repeating it here printed every warning twice.
   return this->EvaluateMulticlass( kl, aux );
}
//_______________________________________________________________________
// Multiclass response values of 'method' for its current event. NaN inputs
// only produce a warning; the values are returned regardless.
const std::vector< Float_t >& TMVA::Reader::EvaluateMulticlass( MethodBase* method, Double_t )
{
   const Event* ev = method->GetEvent();
   for (UInt_t i=0; i<ev->GetNVariables(); i++){
      if (TMath::IsNaN(ev->GetValue(i))) {
         // the message used to talk about "regression values" (copied from
         // EvaluateRegression); this is the multiclass path
         Log() << kERROR << i << "-th variable of the event is NaN, \n multiclass values might evaluate to .. what do I know. \n sorry this warning is all I can do, please fix or remove this event." << Endl;
      }
   }
   return method->GetMulticlassValues();
}
//_______________________________________________________________________
// Single multiclass response for class 'clsNumber' of the method booked as
// 'methodTag'. An out-of-range class number logs a warning and returns 0.
Float_t TMVA::Reader::EvaluateMulticlass( UInt_t clsNumber, const TString& methodTag, Double_t aux )
{
   try {
      return EvaluateMulticlass(methodTag, aux).at(clsNumber);
   }
   catch (const std::out_of_range&) {   // by const reference: no copy of the exception
      Log() << kWARNING << "Multiclass could not be evaluated for class-number " << clsNumber << Endl;
      return 0;
   }
}
//_______________________________________________________________________
// Look up a booked method by its tag. Logs an error and returns 0 if there is none.
TMVA::IMethod* TMVA::Reader::FindMVA( const TString& methodTag )
{
   std::map<TString, IMethod*>::iterator entry = fMethodMap.find( methodTag );
   if (entry == fMethodMap.end()) {
      Log() << kERROR << "Method " << methodTag << " not found!" << Endl;
      return 0;
   }
   return entry->second;
}
//_______________________________________________________________________
// Like FindMVA(), but narrowed to MethodCuts. Returns 0 if the tag is
// unknown or the method is not a cuts method.
TMVA::MethodCuts* TMVA::Reader::FindCutsMVA( const TString& methodTag )
{
   IMethod* found = FindMVA( methodTag );
   MethodCuts* cuts = dynamic_cast<MethodCuts*>( found );
   return cuts;
}
//_______________________________________________________________________
// Forward to MethodBase::GetProba for the method booked as 'methodTag'.
//   ap_sig : passed through to MethodBase::GetProba (presumably the
//            a-priori signal fraction)
//   mvaVal : MVA value to convert; the sentinel -9999999 means "evaluate the
//            method on its current event first"
// An unknown tag is kFATAL. Returns -1 for a non-MethodBase entry and -999
// if the current event has a NaN input.
Double_t TMVA::Reader::GetProba( const TString& methodTag, Double_t ap_sig, Double_t mvaVal )
{
   IMethod* method = 0;
   std::map<TString, IMethod*>::iterator it = fMethodMap.find( methodTag );
   if (it == fMethodMap.end()) {
      // Report the requested tag and then the available ones. The old
      // message printed the (null) method pointer and listed the tags
      // before the text that announced them.
      Log() << kINFO << "<GetProba> unknown classifier in map; "
            << "you looked for \"" << methodTag << "\" within available methods: " << Endl;
      for (it = fMethodMap.begin(); it!=fMethodMap.end(); it++) Log() << " --> " << it->first << Endl;
      Log() << "Check calling string" << kFATAL << Endl;
   }
   else method = it->second;

   MethodBase* kl = dynamic_cast<MethodBase*>(method);
   if(kl==0) return -1;

   const Event* ev = kl->GetEvent();
   for (UInt_t i=0; i<ev->GetNVariables(); i++){
      if (TMath::IsNaN(ev->GetValue(i))) {
         Log() << kERROR << i << "-th variable of the event is NaN --> return MVA value -999, \n that's all I can do, please fix or remove this event." << Endl;
         return -999;
      }
   }

   if (mvaVal == -9999999) mvaVal = kl->GetMvaValue();

   return kl->GetProba( mvaVal, ap_sig );
}
//_______________________________________________________________________
// Forward to MethodBase::GetRarity for the method booked as 'methodTag'.
//   mvaVal : MVA value to convert; the sentinel -9999999 means "evaluate the
//            method on its current event first"
// An unknown tag is kFATAL. Returns -1 for a non-MethodBase entry and -999
// if the current event has a NaN input.
Double_t TMVA::Reader::GetRarity( const TString& methodTag, Double_t mvaVal )
{
   IMethod* method = 0;
   std::map<TString, IMethod*>::iterator it = fMethodMap.find( methodTag );
   if (it == fMethodMap.end()) {
      // Report the requested tag and then the available ones. The old
      // message printed the (null) method pointer and listed the tags
      // before the text that announced them.
      Log() << kINFO << "<GetRarity> unknown classifier in map; "
            << "you looked for \"" << methodTag << "\" within available methods: " << Endl;
      for (it = fMethodMap.begin(); it!=fMethodMap.end(); it++) Log() << " --> " << it->first << Endl;
      Log() << "Check calling string" << kFATAL << Endl;
   }
   else method = it->second;

   MethodBase* kl = dynamic_cast<MethodBase*>(method);
   if(kl==0) return -1;

   const Event* ev = kl->GetEvent();
   for (UInt_t i=0; i<ev->GetNVariables(); i++){
      if (TMath::IsNaN(ev->GetValue(i))) {
         Log() << kERROR << i << "-th variable of the event is NaN --> return MVA value -999, \n that's all I can do, please fix or remove this event." << Endl;
         return -999;
      }
   }

   if (mvaVal == -9999999) mvaVal = kl->GetMvaValue();

   return kl->GetRarity( mvaVal );
}
//_______________________________________________________________________
// Split a ':'-separated std::string into variable names and register them
// in order. An empty string registers nothing. Empty fields (e.g. a
// trailing ':') register empty names.
void TMVA::Reader::DecodeVarNames( const std::string& varNames )
{
   if (varNames.empty()) return;

   std::string::size_type fieldBegin = 0;
   for (;;) {
      const std::string::size_type colon = varNames.find( ':', fieldBegin );
      const std::string::size_type fieldEnd =
         (colon == std::string::npos) ? varNames.length() : colon;
      const std::string field = varNames.substr( fieldBegin, fieldEnd - fieldBegin );
      DataInfo().AddVariable( field.c_str() );
      if (fieldEnd == varNames.length()) break;
      fieldBegin = fieldEnd + 1;
   }
}
void TMVA::Reader::DecodeVarNames( const TString& varNames )
{
TString format;
Int_t n = varNames.Length();
TString format_obj;
for (int i=0; i< n+1 ; i++) {
format.Append(varNames(i));
if (varNames(i) == ':' || i == n) {
format.Chop();
format_obj = format;
format_obj.ReplaceAll("@","");
DataInfo().AddVariable( format_obj );
format.Resize(0);
}
}
}