#ifndef ROOT_TMVA_MethodBase
#define ROOT_TMVA_MethodBase
#include <iosfwd>
#include <vector>
#include <map>
#include "assert.h"
#ifndef ROOT_TString
#include "TString.h"
#endif
#ifndef ROOT_TMVA_IMethod
#include "TMVA/IMethod.h"
#endif
#ifndef ROOT_TMVA_Configurable
#include "TMVA/Configurable.h"
#endif
#ifndef ROOT_TMVA_Types
#include "TMVA/Types.h"
#endif
#ifndef ROOT_TMVA_DataSet
#include "TMVA/DataSet.h"
#endif
#ifndef ROOT_TMVA_Event
#include "TMVA/Event.h"
#endif
#ifndef ROOT_TMVA_TransformationHandler
#include "TMVA/TransformationHandler.h"
#endif
#ifndef ROOT_TMVA_OptimizeConfigParameters
#include "TMVA/OptimizeConfigParameters.h"
#endif
class TGraph;
class TTree;
class TDirectory;
class TSpline;
class TH1F;
class TH1D;
namespace TMVA {
class Ranking;
class PDF;
class TSpline1;
class MethodCuts;
class MethodBoost;
class DataSetInfo;
// Abstract base class of the TMVA multivariate methods. It holds state shared
// by all methods (data-set info, variable transformations, weight-file and
// output-directory bookkeeping, MVA PDFs and evaluation splines) and declares
// the pure-virtual hooks (Train, Init, DeclareOptions, ProcessOptions,
// GetMvaValue, CreateRanking, weight I/O) that each concrete method implements.
class MethodBase : virtual public IMethod, public Configurable {

   friend class Factory;

public:

   // Weight-file type selector: kROOT or kTEXT.
   enum EWeightFileType { kROOT=0, kTEXT };

   // Construct a method from job name, method type/title, data-set info and
   // an option string.
   MethodBase( const TString& jobName,
               Types::EMVA methodType,
               const TString& methodTitle,
               DataSetInfo& dsi,
               const TString& theOption = "",
               TDirectory* theBaseDir = 0 );

   // Construct a method from an existing weight file
   // (see IsConstructedFromWeightFile()).
   MethodBase( Types::EMVA methodType,
               DataSetInfo& dsi,
               const TString& weightFile,
               TDirectory* theBaseDir = 0 );

   virtual ~MethodBase();

   // ---------- setup and training ----------
   void SetupMethod();
   void ProcessSetup();
   virtual void CheckSetup();
   void AddOutput( Types::ETreeType type, Types::EAnalysisType analysisType );
   void TrainMethod();
   virtual std::map<TString,Double_t> OptimizeTuningParameters(TString fomType="ROCIntegral", TString fitType="FitGA");
   virtual void SetTuneParameters(std::map<TString,Double_t> tuneParameters);
   virtual void Train() = 0;

   // Training/testing times are stored exactly as supplied by the caller
   // (the unit is not fixed here).
   void SetTrainTime( Double_t trainTime ) { fTrainTime = trainTime; }
   Double_t GetTrainTime() const { return fTrainTime; }
   void SetTestTime ( Double_t testTime ) { fTestTime = testTime; }
   Double_t GetTestTime () const { return fTestTime; }

   // ---------- testing ----------
   virtual void TestClassification();
   virtual Double_t GetKSTrainingVsTest(Char_t SorB, TString opt="X");
   virtual void TestMulticlass();
   virtual void TestRegression( Double_t& bias, Double_t& biasT,
                                Double_t& dev, Double_t& devT,
                                Double_t& rms, Double_t& rmsT,
                                Double_t& mInf, Double_t& mInfT,
                                Double_t& corr,
                                Types::ETreeType type );

   // ---------- hooks implemented by concrete methods ----------
   virtual void Init() = 0;
   virtual void DeclareOptions() = 0;
   virtual void ProcessOptions() = 0;
   virtual void DeclareCompatibilityOptions();
   virtual void Reset(){return;}   // default: no-op

   // ---------- evaluation ----------
   virtual Double_t GetMvaValue( Double_t* errLower = 0, Double_t* errUpper = 0) = 0;
   Double_t GetMvaValue( const TMVA::Event* const ev, Double_t* err = 0, Double_t* errUpper = 0 );

protected:
   void NoErrorCalc(Double_t* const err, Double_t* const errUpper);

public:
   // Regression response for event `ev`: makes `ev` the temporary current
   // event (fTmpEvent), evaluates GetRegressionValues(), then resets
   // fTmpEvent to 0.
   const std::vector<Float_t>& GetRegressionValues(const TMVA::Event* const ev){
      fTmpEvent = ev;
      const std::vector<Float_t>* ptr = &GetRegressionValues();
      fTmpEvent = 0;
      return (*ptr);
   }

   // Default regression response for methods that do not override it: an
   // empty vector.
   virtual const std::vector<Float_t>& GetRegressionValues() {
      // One immutable empty vector is shared by all callers. The former
      // default heap-allocated a fresh vector on every call and never
      // released it, leaking memory per evaluation.
      static const std::vector<Float_t> emptyValues;
      return emptyValues;
   }

   // Default multiclass response for methods that do not override it: an
   // empty vector.
   virtual const std::vector<Float_t>& GetMulticlassValues() {
      // Shared immutable empty vector; avoids the per-call leak of the
      // former `new std::vector<Float_t>(0)` default.
      static const std::vector<Float_t> emptyValues;
      return emptyValues;
   }

   // ---------- probability / rarity ----------
   virtual Double_t GetProba( const Event *ev);
   virtual Double_t GetProba( Double_t mvaVal, Double_t ap_sig );
   virtual Double_t GetRarity( Double_t mvaVal, Types::ESBType reftype = Types::kBackground ) const;

   virtual const Ranking* CreateRanking() = 0;

   virtual void MakeClass( const TString& classFileName = TString("") ) const;

   void PrintHelpMessage() const;

public:
   // ---------- persistence of the method state ----------
   void WriteStateToFile () const;
   void ReadStateFromFile ();

protected:
   virtual void AddWeightsXMLTo ( void* parent ) const = 0;
   virtual void ReadWeightsFromXML ( void* wghtnode ) = 0;
   virtual void ReadWeightsFromStream( std::istream& ) = 0;
   virtual void ReadWeightsFromStream( TFile& ) {}   // default: no-op

private:
   friend class MethodCategory;
   friend class MethodCompositeBase;
   void WriteStateToXML ( void* parent ) const;
   void ReadStateFromXML ( void* parent );
   void WriteStateToStream ( std::ostream& tf ) const;
   void WriteVarsToStream ( std::ostream& tf, const TString& prefix = "" ) const;

public:
   void ReadStateFromStream ( std::istream& tf );
   void ReadStateFromStream ( TFile& rf );
   void ReadStateFromXMLString( const char* xmlstr );

private:
   void AddVarsXMLTo ( void* parent ) const;
   void AddSpectatorsXMLTo ( void* parent ) const;
   void AddTargetsXMLTo ( void* parent ) const;
   void AddClassesXMLTo ( void* parent ) const;
   void ReadVariablesFromXML ( void* varnode );
   void ReadSpectatorsFromXML( void* specnode);
   void ReadTargetsFromXML ( void* tarnode );
   void ReadClassesFromXML ( void* clsnode );
   void ReadVarsFromStream ( std::istream& istr );

public:
   // ---------- evaluation output and figures of merit ----------
   virtual void WriteEvaluationHistosToFile(Types::ETreeType treetype);
   virtual void WriteMonitoringHistosToFile() const;
   virtual Double_t GetEfficiency( const TString&, Types::ETreeType, Double_t& err );
   virtual Double_t GetTrainingEfficiency(const TString& );
   virtual std::vector<Float_t> GetMulticlassEfficiency( std::vector<std::vector<Float_t> >& purity );
   virtual std::vector<Float_t> GetMulticlassTrainingEfficiency(std::vector<std::vector<Float_t> >& purity );
   virtual Double_t GetSignificance() const;
   virtual Double_t GetROCIntegral(TH1D *histS, TH1D *histB) const;
   virtual Double_t GetROCIntegral(PDF *pdfS=0, PDF *pdfB=0) const;
   virtual Double_t GetMaximumSignificance( Double_t SignalEvents, Double_t BackgroundEvents,
                                            Double_t& optimal_significance_value ) const;
   virtual Double_t GetSeparation( TH1*, TH1* ) const;
   virtual Double_t GetSeparation( PDF* pdfS = 0, PDF* pdfB = 0 ) const;
   virtual void GetRegressionDeviation(UInt_t tgtNum, Types::ETreeType type, Double_t& stddev,Double_t& stddev90Percent ) const;

   // ---------- names and identifiers ----------
   const TString& GetJobName () const { return fJobName; }
   const TString& GetMethodName () const { return fMethodName; }
   TString GetMethodTypeName() const { return Types::Instance().GetMethodName(fMethodType); }
   Types::EMVA GetMethodType () const { return fMethodType; }
   const char* GetName () const { return fMethodName.Data(); }
   const TString& GetTestvarName () const { return fTestvar; }
   const TString GetProbaName () const { return fTestvar + "_Proba"; }
   TString GetWeightFileName() const;
   // An empty argument selects the default name "MVA_" + method name.
   void SetTestvarName ( const TString & v="" ) { fTestvar = (v=="") ? ("MVA_" + GetMethodName()) : v; }

   // ---------- input variables (forwarded to DataSetInfo) ----------
   UInt_t GetNvar() const { return DataInfo().GetNVariables(); }
   UInt_t GetNVariables() const { return DataInfo().GetNVariables(); }
   UInt_t GetNTargets() const { return DataInfo().GetNTargets(); }
   const TString& GetInputVar ( Int_t i ) const { return DataInfo().GetVariableInfo(i).GetInternalName(); }
   const TString& GetInputLabel( Int_t i ) const { return DataInfo().GetVariableInfo(i).GetLabel(); }
   const TString& GetInputTitle( Int_t i ) const { return DataInfo().GetVariableInfo(i).GetTitle(); }

   // ---------- per-variable statistics (forwarded to the active transformation handler) ----------
   Double_t GetMean( Int_t ivar ) const { return GetTransformationHandler().GetMean(ivar); }
   Double_t GetRMS ( Int_t ivar ) const { return GetTransformationHandler().GetRMS(ivar); }
   Double_t GetXmin( Int_t ivar ) const { return GetTransformationHandler().GetMin(ivar); }
   Double_t GetXmax( Int_t ivar ) const { return GetTransformationHandler().GetMax(ivar); }

   // ---------- signal reference cut ----------
   Double_t GetSignalReferenceCut() const { return fSignalReferenceCut; }
   Double_t GetSignalReferenceCutOrientation() const { return fSignalReferenceCutOrientation; }
   void SetSignalReferenceCut( Double_t cut ) { fSignalReferenceCut = cut; }
   void SetSignalReferenceCutOrientation( Double_t cutOrientation ) { fSignalReferenceCutOrientation = cutOrientation; }

   // ---------- output directories ----------
   TDirectory* BaseDir() const;
   TDirectory* MethodBaseDir() const;
   // Sets both the base directory and the method base directory.
   void SetMethodDir ( TDirectory* methodDir ) { fBaseDir = fMethodBaseDir = methodDir; }
   void SetBaseDir( TDirectory* methodDir ){ fBaseDir = methodDir; }
   void SetMethodBaseDir( TDirectory* methodDir ){ fMethodBaseDir = methodDir; }

   // ---------- TMVA/ROOT versions recorded at training ----------
   UInt_t GetTrainingTMVAVersionCode() const { return fTMVATrainingVersion; }
   UInt_t GetTrainingROOTVersionCode() const { return fROOTTrainingVersion; }
   TString GetTrainingTMVAVersionString() const;
   TString GetTrainingROOTVersionString() const;

   // Active transformation handler: the handler installed with
   // RerouteTransformationHandler() if one is set and takeReroutedIfAvailable
   // is true, otherwise this method's own fTransformation.
   TransformationHandler& GetTransformationHandler(Bool_t takeReroutedIfAvailable=true)
   {
      if(fTransformationPointer && takeReroutedIfAvailable) return *fTransformationPointer; else return fTransformation;
   }
   const TransformationHandler& GetTransformationHandler(Bool_t takeReroutedIfAvailable=true) const
   {
      if(fTransformationPointer && takeReroutedIfAvailable) return *fTransformationPointer; else return fTransformation;
   }
   // Non-owning: only the pointer is stored.
   void RerouteTransformationHandler (TransformationHandler* fTargetTransformation) { fTransformationPointer=fTargetTransformation; }

   // ---------- data access ----------
   DataSet* Data() const { return DataInfo().GetDataSet(); }
   DataSetInfo& DataInfo() const { return fDataSetInfo; }

   // Temporary "current" event set by GetRegressionValues(const Event*);
   // while non-null, GetEvent() transforms it instead of the data set's event.
   mutable const Event* fTmpEvent;

   UInt_t GetNEvents () const { return Data()->GetNEvents(); }
   const Event* GetEvent () const;
   const Event* GetEvent ( const TMVA::Event* ev ) const;
   const Event* GetEvent ( Long64_t ievt ) const;
   const Event* GetEvent ( Long64_t ievt , Types::ETreeType type ) const;
   const Event* GetTrainingEvent( Long64_t ievt ) const;
   const Event* GetTestingEvent ( Long64_t ievt ) const;
   const std::vector<TMVA::Event*>& GetEventCollection( Types::ETreeType type );

   virtual Bool_t IsSignalLike();
   virtual Bool_t IsSignalLike(Double_t mvaVal);

   Bool_t HasMVAPdfs() const { return fHasMVAPdfs; }
   virtual void SetAnalysisType( Types::EAnalysisType type ) { fAnalysisType = type; }
   Types::EAnalysisType GetAnalysisType() const { return fAnalysisType; }
   Bool_t DoRegression() const { return fAnalysisType == Types::kRegression; }
   Bool_t DoMulticlass() const { return fAnalysisType == Types::kMulticlass; }

   void DisableWriting(Bool_t setter){ fDisableWriting = setter; }

protected:
   void SetWeightFileName( TString );
   const TString& GetWeightFileDir() const { return fFileDir; }
   void SetWeightFileDir( TString fileDir );
   Bool_t IsNormalised() const { return fNormalise; }
   void SetNormalised( Bool_t norm ) { fNormalise = norm; }
   Bool_t Verbose() const { return fVerbose; }
   Bool_t Help () const { return fHelp; }
   // Indexes fInputVars without bounds checking.
   const TString& GetInternalVarName( Int_t ivar ) const { return (*fInputVars)[ivar]; }
   const TString& GetOriginalVarName( Int_t ivar ) const { return DataInfo().GetVariableInfo(ivar).GetExpression(); }
   // True when the data set contains at least one training event.
   Bool_t HasTrainingTree() const { return Data()->GetNTrainingEvents() != 0; }

protected:
   // Standalone-class generation hooks; defaults write nothing.
   virtual void MakeClassSpecific( std::ostream&, const TString& = "" ) const {}
   virtual void MakeClassSpecificHeader( std::ostream&, const TString& = "" ) const {}
   static MethodBase* GetThisBase();
   void Statistics( Types::ETreeType treeType, const TString& theVarName,
                    Double_t&, Double_t&, Double_t&,
                    Double_t&, Double_t&, Double_t& );
   Bool_t TxtWeightsOnly() const { return kTRUE; }   // always true

protected:
   Bool_t IsConstructedFromWeightFile() const { return fConstructedFromWeightFile; }

private:
   void InitBase();
   void DeclareBaseOptions();
   void ProcessBaseOptions();
   enum ECutOrientation { kNegative = -1, kPositive = +1 };
   ECutOrientation GetCutOrientation() const { return fCutOrientation; }
   void ResetThisBase();
   void CreateMVAPdfs();
   static Double_t IGetEffForRoot( Double_t );
   Double_t GetEffForRoot ( Double_t );
   Bool_t GetLine( std::istream& fin, char * buf );
   virtual void AddClassifierOutput ( Types::ETreeType type );
   virtual void AddClassifierOutputProb( Types::ETreeType type );
   virtual void AddRegressionOutput ( Types::ETreeType type );
   virtual void AddMulticlassOutput ( Types::ETreeType type );

private:
   void AddInfoItem( void* gi, const TString& name,
                     const TString& value) const;
   static void CreateVariableTransforms(const TString& trafoDefinition,
                                        TMVA::DataSetInfo& dataInfo,
                                        TMVA::TransformationHandler& transformationHandler,
                                        TMVA::MsgLogger& log );

protected:
   Ranking* fRanking;
   std::vector<TString>* fInputVars;
   Int_t fNbins;
   Int_t fNbinsMVAoutput;
   Int_t fNbinsH;
   Types::EAnalysisType fAnalysisType;
   std::vector<Float_t>* fRegressionReturnVal;
   std::vector<Float_t>* fMulticlassReturnVal;

private:
   friend class MethodCuts;
   Bool_t fDisableWriting;
   DataSetInfo& fDataSetInfo;
   Double_t fSignalReferenceCut;
   Double_t fSignalReferenceCutOrientation;
   Types::ESBType fVariableTransformType;
   TString fJobName;
   TString fMethodName;
   Types::EMVA fMethodType;
   TString fTestvar;
   UInt_t fTMVATrainingVersion;
   UInt_t fROOTTrainingVersion;
   Bool_t fConstructedFromWeightFile;
   TDirectory* fBaseDir;
   mutable TDirectory* fMethodBaseDir;
   TString fParentDir;
   TString fFileDir;
   TString fWeightFile;

private:
   TH1* fEffS;
   PDF* fDefaultPDF;
   PDF* fMVAPdfS;
   PDF* fMVAPdfB;
   PDF* fSplS;
   PDF* fSplB;
   TSpline* fSpleffBvsS;
   PDF* fSplTrainS;
   PDF* fSplTrainB;
   TSpline* fSplTrainEffBvsS;

private:
   Double_t fMeanS;
   Double_t fMeanB;
   Double_t fRmsS;
   Double_t fRmsB;
   Double_t fXmin;
   Double_t fXmax;
   TString fVarTransformString;
   TransformationHandler* fTransformationPointer;   // rerouted handler (non-owning), may be 0
   TransformationHandler fTransformation;           // this method's own handler
   Bool_t fVerbose;
   TString fVerbosityLevelString;
   EMsgType fVerbosityLevel;
   Bool_t fHelp;
   Bool_t fHasMVAPdfs;
   Bool_t fIgnoreNegWeightsInTraining;

protected:
   Bool_t IgnoreEventsWithNegWeightsInTraining() const { return fIgnoreNegWeightsInTraining; }
   UInt_t fSignalClass;
   UInt_t fBackgroundClass;

private:
   Double_t fTrainTime;
   Double_t fTestTime;
   ECutOrientation fCutOrientation;
   TSpline1* fSplRefS;
   TSpline1* fSplRefB;
   TSpline1* fSplTrainRefS;
   TSpline1* fSplTrainRefB;
   mutable std::vector<const std::vector<TMVA::Event*>*> fEventCollections;

public:
   Bool_t fSetupCompleted;

private:
   static MethodBase* fgThisBase;

private:
   Bool_t fNormalise;
   Bool_t fUseDecorr;
   TString fVariableTransformTypeString;
   Bool_t fTxtWeightsOnly;
   Int_t fNbinsMVAPdf;
   Int_t fNsmoothMVAPdf;

protected:
   ClassDef(MethodBase,0)
};
}
// Run an explicitly supplied event through the active transformation handler
// and return the transformed event.
inline const TMVA::Event* TMVA::MethodBase::GetEvent( const TMVA::Event* ev ) const
{
   const TransformationHandler& handler = GetTransformationHandler();
   return handler.Transform(ev);
}
// Transformed "current" event. If a temporary event has been installed in
// fTmpEvent (as GetRegressionValues(const Event*) does), that event is used.
// Otherwise the data set's current event is used.
inline const TMVA::Event* TMVA::MethodBase::GetEvent() const
{
   if (!fTmpEvent) return GetTransformationHandler().Transform(Data()->GetEvent());
   return GetTransformationHandler().Transform(fTmpEvent);
}
// Fetch event `ievt` from the data set and return it transformed.
// Must not be called while a temporary event (fTmpEvent) is installed.
inline const TMVA::Event* TMVA::MethodBase::GetEvent( Long64_t ievt ) const
{
   assert(fTmpEvent==0);
   const Event* rawEvent = Data()->GetEvent(ievt);
   return GetTransformationHandler().Transform(rawEvent);
}
// Fetch event `ievt` of tree type `type` from the data set and return it
// transformed. Must not be called while fTmpEvent is installed.
inline const TMVA::Event* TMVA::MethodBase::GetEvent( Long64_t ievt, Types::ETreeType type ) const
{
   assert(fTmpEvent==0);
   const Event* rawEvent = Data()->GetEvent(ievt, type);
   return GetTransformationHandler().Transform(rawEvent);
}
// Transformed training event `ievt`. Convenience wrapper around
// GetEvent(ievt, Types::kTraining).
inline const TMVA::Event* TMVA::MethodBase::GetTrainingEvent( Long64_t ievt ) const
{
   assert(fTmpEvent==0);
   const Types::ETreeType tree = Types::kTraining;
   return GetEvent(ievt, tree);
}
// Transformed testing event `ievt`. Convenience wrapper around
// GetEvent(ievt, Types::kTesting).
inline const TMVA::Event* TMVA::MethodBase::GetTestingEvent( Long64_t ievt ) const
{
   assert(fTmpEvent==0);
   const Types::ETreeType tree = Types::kTesting;
   return GetEvent(ievt, tree);
}
#endif