#ifndef ROOT_TMVA_RuleFitAPI
#define ROOT_TMVA_RuleFitAPI
#include <fstream>
#include <vector>
namespace TMVA {
class MsgLogger;
class MethodRuleFit;
class RuleFitAPI {
public:
RuleFitAPI( const TMVA::MethodRuleFit *rfbase, TMVA::RuleFit *rulefit, EMsgType minType );
virtual ~RuleFitAPI();
void WelcomeMessage();
void HowtoSetupRF();
void SetRFWorkDir(const char * wdir);
void CheckRFWorkDir();
inline void TrainRuleFit();
inline void TestRuleFit();
inline void VarImp();
Bool_t ReadModelSum();
const TString GetRFWorkDir() const { return fRFWorkDir; }
protected:
enum ERFMode { kRfRegress=1, kRfClass=2 };
enum EModel { kRfLinear=0, kRfRules=1, kRfBoth=2 };
enum ERFProgram { kRfTrain=0, kRfPredict, kRfVarimp };
typedef struct {
Int_t mode;
Int_t lmode;
Int_t n;
Int_t p;
Int_t max_rules;
Int_t tree_size;
Int_t path_speed;
Int_t path_xval;
Int_t path_steps;
Int_t path_testfreq;
Int_t tree_store;
Int_t cat_store;
} IntParms;
typedef struct {
Float_t xmiss;
Float_t trim_qntl;
Float_t huber;
Float_t inter_supp;
Float_t memory_par;
Float_t samp_fract;
Float_t path_inc;
Float_t conv_fac;
} RealParms;
void InitRuleFit();
void FillRealParmsDef();
void FillIntParmsDef();
void ImportSetup();
void SetTrainParms();
void SetTestParms();
Int_t RunRuleFit();
void SetRFTrain() { fRFProgram = kRfTrain; }
void SetRFPredict() { fRFProgram = kRfPredict; }
void SetRFVarimp() { fRFProgram = kRfVarimp; }
inline TString GetRFName(TString name);
inline Bool_t OpenRFile(TString name, std::ofstream & f);
inline Bool_t OpenRFile(TString name, std::ifstream & f);
inline Bool_t WriteInt(ofstream & f, const Int_t *v, Int_t n=1);
inline Bool_t WriteFloat(ofstream & f, const Float_t *v, Int_t n=1);
inline Int_t ReadInt(ifstream & f, Int_t *v, Int_t n=1) const;
inline Int_t ReadFloat(ifstream & f, Float_t *v, Int_t n=1) const;
Bool_t WriteAll();
Bool_t WriteIntParms();
Bool_t WriteRealParms();
Bool_t WriteLx();
Bool_t WriteProgram();
Bool_t WriteRealVarImp();
Bool_t WriteRfOut();
Bool_t WriteRfStatus();
Bool_t WriteRuleFitMod();
Bool_t WriteRuleFitSum();
Bool_t WriteTrain();
Bool_t WriteVarNames();
Bool_t WriteVarImp();
Bool_t WriteYhat();
Bool_t WriteTest();
Bool_t ReadYhat();
Bool_t ReadIntParms();
Bool_t ReadRealParms();
Bool_t ReadLx();
Bool_t ReadProgram();
Bool_t ReadRealVarImp();
Bool_t ReadRfOut();
Bool_t ReadRfStatus();
Bool_t ReadRuleFitMod();
Bool_t ReadRuleFitSum();
Bool_t ReadTrainX();
Bool_t ReadTrainY();
Bool_t ReadTrainW();
Bool_t ReadVarNames();
Bool_t ReadVarImp();
private:
RuleFitAPI();
const MethodRuleFit *fMethodRuleFit;
RuleFit *fRuleFit;
std::vector<Float_t> fRFYhat;
std::vector<Float_t> fRFVarImp;
std::vector<Int_t> fRFVarImpInd;
TString fRFWorkDir;
IntParms fRFIntParms;
RealParms fRFRealParms;
std::vector<int> fRFLx;
ERFProgram fRFProgram;
TString fModelType;
mutable MsgLogger fLogger;
ClassDef(RuleFitAPI,0)
};
}
void TMVA::RuleFitAPI::TrainRuleFit()
{
SetTrainParms();
WriteAll();
RunRuleFit();
}
void TMVA::RuleFitAPI::TestRuleFit()
{
SetTestParms();
WriteAll();
RunRuleFit();
ReadYhat();
}
void TMVA::RuleFitAPI::VarImp()
{
SetRFVarimp();
WriteAll();
RunRuleFit();
ReadVarImp();
}
TString TMVA::RuleFitAPI::GetRFName(TString name)
{
   // Path of a RuleFit file inside the work directory: "<fRFWorkDir>/<name>".
   // The separator is always appended; no path normalisation is performed.
   TString path(fRFWorkDir);
   path += "/";
   path += name;
   return path;
}
Bool_t TMVA::RuleFitAPI::OpenRFile(TString name, std::ofstream & f)
{
   // Open <work dir>/<name> for writing into f.
   // Returns kTRUE when f is open and usable; otherwise logs an error and
   // returns kFALSE.
   TString fullName = GetRFName(name);
   f.open(fullName.Data());
   // is_open() alone is not enough: if f was already open, open() fails by
   // only setting failbit and is_open() stays true, leaving a dead stream.
   if (!f.is_open() || !f.good()) {
      fLogger << kERROR << "Error opening RuleFit file for output: "
              << fullName << Endl;
      return kFALSE;
   }
   return kTRUE;
}
Bool_t TMVA::RuleFitAPI::OpenRFile(TString name, std::ifstream & f)
{
   // Open <work dir>/<name> for reading into f.
   // Returns kTRUE when f is open and usable; otherwise logs an error and
   // returns kFALSE.
   TString fullName = GetRFName(name);
   f.open(fullName.Data());
   // is_open() alone is not enough: if f was already open, open() fails by
   // only setting failbit and is_open() stays true, leaving a dead stream.
   if (!f.is_open() || !f.good()) {
      fLogger << kERROR << "Error opening RuleFit file for input: "
              << fullName << Endl;
      return kFALSE;
   }
   return kTRUE;
}
Bool_t TMVA::RuleFitAPI::WriteInt(std::ofstream & f, const Int_t *v, Int_t n)
{
   // Write n Int_t values from v to f as raw bytes (native byte order).
   // Returns kTRUE if the stream is not in a failed state afterwards;
   // kFALSE if f is not open, n is negative, or the write failed.
   if (!f.is_open()) return kFALSE;
   if (n < 0) return kFALSE;   // n*sizeof(Int_t) would wrap to a huge size
   f.write(reinterpret_cast<char const *>(v), n*sizeof(Int_t));
   // The stream's conversion to bool is explicit since C++11, so it cannot be
   // returned directly as Bool_t; test failbit/badbit instead.
   return f.fail() ? kFALSE : kTRUE;
}
Bool_t TMVA::RuleFitAPI::WriteFloat(std::ofstream & f, const Float_t *v, Int_t n)
{
   // Write n Float_t values from v to f as raw bytes (native byte order).
   // Returns kTRUE if the stream is not in a failed state afterwards;
   // kFALSE if f is not open, n is negative, or the write failed.
   if (!f.is_open()) return kFALSE;
   if (n < 0) return kFALSE;   // n*sizeof(Float_t) would wrap to a huge size
   f.write(reinterpret_cast<char const *>(v), n*sizeof(Float_t));
   // The stream's conversion to bool is explicit since C++11, so it cannot be
   // returned directly as Bool_t; test failbit/badbit instead.
   return f.fail() ? kFALSE : kTRUE;
}
Int_t TMVA::RuleFitAPI::ReadInt(std::ifstream & f, Int_t *v, Int_t n) const
{
   // Read n Int_t values (raw bytes, native byte order) from f into v, which
   // must have room for n values.
   // Returns 1 on success; 0 if f is not open, n is negative, or the read
   // failed or hit end-of-file before n values were read.
   if (!f.is_open()) return 0;
   if (n < 0) return 0;   // n*sizeof(Int_t) would wrap to a huge size
   if (f.read(reinterpret_cast<char *>(v), n*sizeof(Int_t))) return 1;
   return 0;
}
Int_t TMVA::RuleFitAPI::ReadFloat(std::ifstream & f, Float_t *v, Int_t n) const
{
   // Read n Float_t values (raw bytes, native byte order) from f into v, which
   // must have room for n values.
   // Returns 1 on success; 0 if f is not open, n is negative, or the read
   // failed or hit end-of-file before n values were read.
   if (!f.is_open()) return 0;
   if (n < 0) return 0;   // n*sizeof(Float_t) would wrap to a huge size
   if (f.read(reinterpret_cast<char *>(v), n*sizeof(Float_t))) return 1;
   return 0;
}
#endif // ROOT_TMVA_RuleFitAPI