ROOT logo
// @(#)root/tmva $Id: RuleFitAPI.h 20882 2007-11-19 11:31:26Z rdm $
// Author: Andreas Hoecker, Joerg Stelzer, Fredrik Tegenfeldt, Helge Voss

/**********************************************************************************
 * Project: TMVA - a Root-integrated toolkit for multivariate data analysis       *
 * Package: TMVA                                                                  *
 * Class  : RuleFitAPI                                                            *
 * Web    : http://tmva.sourceforge.net                                           *
 *                                                                                *
 * Description:                                                                   *
 *      Interface to Friedman's RuleFit method                                    * 
 *                                                                                *
 * Authors (alphabetical):                                                        *
 *      Andreas Hoecker    <Andreas.Hocker@cern.ch>     - CERN, Switzerland       *
 *      Fredrik Tegenfeldt <Fredrik.Tegenfeldt@cern.ch> - Iowa State U., USA      *
 *      Helge Voss         <Helge.Voss@cern.ch>         - MPI-KP Heidelberg, Ger. *
 *      Kai Voss           <Kai.Voss@cern.ch>           - U. of Victoria, Canada  *
 *                                                                                *
 * Copyright (c) 2005:                                                            *
 *      CERN, Switzerland                                                         * 
 *      U. of Victoria, Canada                                                    * 
 *      MPI-KP Heidelberg, Germany                                                * 
 *      LAPP, Annecy, France                                                      *
 *                                                                                *
 * Redistribution and use in source and binary forms, with or without             *
 * modification, are permitted according to the terms listed in LICENSE           *
 *                                                                                *
 **********************************************************************************/

#ifndef ROOT_TMVA_RuleFitAPI
#define ROOT_TMVA_RuleFitAPI

//////////////////////////////////////////////////////////////////////////
//                                                                      //
// RuleFitAPI                                                           //
//                                                                      //
// J Friedman's RuleFit method                                          //
//                                                                      //
//////////////////////////////////////////////////////////////////////////

#include <fstream>
#include <vector>

namespace TMVA {

   // Forward declarations of the TMVA classes named by this header.
   // NOTE(review): RuleFitAPI holds a MsgLogger by value and the inline
   // functions at the end of this header stream into it, so the including
   // file presumably has to provide the full MsgLogger definition - confirm.
   class MsgLogger;
   class MethodRuleFit;
   class RuleFit;       // used by pointer only (constructor argument and fRuleFit member)

   class RuleFitAPI {

   public:

      RuleFitAPI( const TMVA::MethodRuleFit *rfbase, TMVA::RuleFit *rulefit, EMsgType minType );

      virtual ~RuleFitAPI();

      // welcome message
      void WelcomeMessage();

      // message on howto get the binary
      void HowtoSetupRF();

      // Set RuleFit working directory
      void SetRFWorkDir(const char * wdir);

      // Check RF work dir - aborts if it fails
      void CheckRFWorkDir();

      // run rf_go.exe in various modes
      inline void TrainRuleFit();
      inline void TestRuleFit();
      inline void VarImp();

      // read result into MethodRuleFit
      Bool_t ReadModelSum();

      // Get working directory
      const TString GetRFWorkDir() const { return fRFWorkDir; }

   protected:

      enum ERFMode    { kRfRegress=1, kRfClass=2 };          // RuleFit modes, default=Class
      enum EModel     { kRfLinear=0, kRfRules=1, kRfBoth=2 }; // models, default=Both (rules+linear)
      enum ERFProgram { kRfTrain=0, kRfPredict, kRfVarimp };    // rf_go.exe running mode
  
      // integer parameters
      typedef struct {
         Int_t mode;
         Int_t lmode;
         Int_t n;
         Int_t p;
         Int_t max_rules;
         Int_t tree_size;
         Int_t path_speed;
         Int_t path_xval;
         Int_t path_steps;
         Int_t path_testfreq;
         Int_t tree_store;
         Int_t cat_store;
      } IntParms;

      // float parameters
      typedef struct {
         Float_t  xmiss;
         Float_t  trim_qntl;
         Float_t  huber;
         Float_t  inter_supp;
         Float_t  memory_par;
         Float_t  samp_fract;
         Float_t  path_inc;
         Float_t  conv_fac;
      } RealParms;

      // setup
      void InitRuleFit();
      void FillRealParmsDef();
      void FillIntParmsDef();
      void ImportSetup();
      void SetTrainParms();
      void SetTestParms();

      // run
      Int_t  RunRuleFit();

      // set rf_go.exe running mode
      void SetRFTrain()   { fRFProgram = kRfTrain; }
      void SetRFPredict() { fRFProgram = kRfPredict; }
      void SetRFVarimp()  { fRFProgram = kRfVarimp; }

      // handle rulefit files
      inline TString GetRFName(TString name);
      inline Bool_t  OpenRFile(TString name, std::ofstream & f);
      inline Bool_t  OpenRFile(TString name, std::ifstream & f);

      // read/write binary files
      inline Bool_t WriteInt(ofstream &   f, const Int_t   *v, Int_t n=1);
      inline Bool_t WriteFloat(ofstream & f, const Float_t *v, Int_t n=1);
      inline Int_t  ReadInt(ifstream & f,   Int_t *v, Int_t n=1) const;
      inline Int_t  ReadFloat(ifstream & f, Float_t *v, Int_t n=1) const;
  
      // write rf_go.exe i/o files
      Bool_t WriteAll();
      Bool_t WriteIntParms();
      Bool_t WriteRealParms();
      Bool_t WriteLx();
      Bool_t WriteProgram();
      Bool_t WriteRealVarImp();
      Bool_t WriteRfOut();
      Bool_t WriteRfStatus();
      Bool_t WriteRuleFitMod();
      Bool_t WriteRuleFitSum();
      Bool_t WriteTrain();
      Bool_t WriteVarNames();
      Bool_t WriteVarImp();
      Bool_t WriteYhat();
      Bool_t WriteTest();

      // read rf_go.exe i/o files
      Bool_t ReadYhat();
      Bool_t ReadIntParms();
      Bool_t ReadRealParms();
      Bool_t ReadLx();
      Bool_t ReadProgram();
      Bool_t ReadRealVarImp();
      Bool_t ReadRfOut();
      Bool_t ReadRfStatus();
      Bool_t ReadRuleFitMod();
      Bool_t ReadRuleFitSum();
      Bool_t ReadTrainX();
      Bool_t ReadTrainY();
      Bool_t ReadTrainW();
      Bool_t ReadVarNames();
      Bool_t ReadVarImp();

   private:
      // prevent empty constructor from being used
      RuleFitAPI();
      const MethodRuleFit *fMethodRuleFit; // parent method - set in constructor
      RuleFit             *fRuleFit;       // non const ptr to RuleFit class in MethodRuleFit
      //
      std::vector<Float_t> fRFYhat;      // score results from test sample
      std::vector<Float_t> fRFVarImp;    // variable importances
      std::vector<Int_t>   fRFVarImpInd; // variable index
      TString              fRFWorkDir;   // working directory
      IntParms             fRFIntParms;  // integer parameters
      RealParms            fRFRealParms; // real parameters
      std::vector<int>     fRFLx;        // variable selector
      ERFProgram           fRFProgram;   // what to run
      TString              fModelType;   // model type string

      mutable MsgLogger    fLogger;          // message logger

      ClassDef(RuleFitAPI,0)        // Friedman's RuleFit method

   };

} // namespace TMVA

//_______________________________________________________________________
void TMVA::RuleFitAPI::TrainRuleFit()
{
   // run rf_go.exe to train the model
   SetTrainParms();
   WriteAll();
   RunRuleFit();
}

//_______________________________________________________________________
void TMVA::RuleFitAPI::TestRuleFit()
{
   // run rf_go.exe with the test data
   SetTestParms();
   WriteAll();
   RunRuleFit();
   ReadYhat(); // read in the scores
}

//_______________________________________________________________________
void TMVA::RuleFitAPI::VarImp()
{
   // run rf_go.exe to get the variable importance
   SetRFVarimp();
   WriteAll();
   RunRuleFit();
   ReadVarImp(); // read in the variable importances
}

//_______________________________________________________________________
TString TMVA::RuleFitAPI::GetRFName(TString name)
{
   // Return the given name including the rulefit directory, i.e.
   // "<working directory>/<name>".
   TString path(fRFWorkDir);
   path += "/";
   path += name;
   return path;
}

//_______________________________________________________________________
Bool_t TMVA::RuleFitAPI::OpenRFile(TString name, std::ofstream & f)
{
   // Open the file 'name' inside the rulefit directory for writing.
   // Returns kTRUE if the stream reports an open file afterwards; otherwise
   // an error is logged and kFALSE is returned.
   TString path = GetRFName(name);
   f.open(path);
   if (f.is_open()) return kTRUE;

   fLogger << kERROR << "Error opening RuleFit file for output: "
           << path << Endl;
   return kFALSE;
}

//_______________________________________________________________________
Bool_t TMVA::RuleFitAPI::OpenRFile(TString name, std::ifstream & f)
{
   // Open the file 'name' inside the rulefit directory for reading.
   // Returns kTRUE if the stream reports an open file afterwards; otherwise
   // an error is logged and kFALSE is returned.
   TString path = GetRFName(name);
   f.open(path);
   if (f.is_open()) return kTRUE;

   fLogger << kERROR << "Error opening RuleFit file for input: "
           << path << Endl;
   return kFALSE;
}

//_______________________________________________________________________
Bool_t TMVA::RuleFitAPI::WriteInt(std::ofstream & f, const Int_t *v, Int_t n)
{
   // Write n Int_t values, starting at v, to f as raw in-memory bytes.
   //   f - output stream; must already be open
   //   v - first value to write
   //   n - number of values (declared default: 1)
   // Returns kTRUE if the stream is not in a failed state after the write,
   // kFALSE if f is not open, the arguments are invalid or the write failed.
   if (!f.is_open()) return kFALSE;
   // a negative count would wrap around to a huge byte count below
   if (n < 0 || (n > 0 && v == 0)) return kFALSE;
   f.write(reinterpret_cast<char const *>(v), n*sizeof(Int_t));
   // test the stream state explicitly: the stream-to-bool conversion is
   // explicit since C++11 and cannot be used directly in a return statement
   return !f.fail();
}

//_______________________________________________________________________
Bool_t TMVA::RuleFitAPI::WriteFloat(std::ofstream & f, const Float_t *v, Int_t n)
{
   // Write n Float_t values, starting at v, to f as raw in-memory bytes.
   //   f - output stream; must already be open
   //   v - first value to write
   //   n - number of values (declared default: 1)
   // Returns kTRUE if the stream is not in a failed state after the write,
   // kFALSE if f is not open, the arguments are invalid or the write failed.
   if (!f.is_open()) return kFALSE;
   // a negative count would wrap around to a huge byte count below
   if (n < 0 || (n > 0 && v == 0)) return kFALSE;
   f.write(reinterpret_cast<char const *>(v), n*sizeof(Float_t));
   // test the stream state explicitly: the stream-to-bool conversion is
   // explicit since C++11 and cannot be used directly in a return statement
   return !f.fail();
}

//_______________________________________________________________________
Int_t TMVA::RuleFitAPI::ReadInt(std::ifstream & f, Int_t *v, Int_t n) const
{
   // Read n Int_t values from f into the buffer at v as raw bytes.
   //   f - input stream; must already be open
   //   v - destination buffer with room for n values
   //   n - number of values (declared default: 1)
   // Returns 1 if the whole read succeeded and 0 otherwise (stream not
   // open, invalid arguments or failed/short read). The return value is a
   // success flag, not the number of values read.
   if (!f.is_open()) return 0;
   // a negative count would wrap around to a huge byte count below
   if (n < 0 || (n > 0 && v == 0)) return 0;
   if (f.read(reinterpret_cast<char *>(v), n*sizeof(Int_t))) return 1;
   return 0;
}

//_______________________________________________________________________
Int_t TMVA::RuleFitAPI::ReadFloat(std::ifstream & f, Float_t *v, Int_t n) const
{
   // Read n Float_t values from f into the buffer at v as raw bytes.
   //   f - input stream; must already be open
   //   v - destination buffer with room for n values
   //   n - number of values (declared default: 1)
   // Returns 1 if the whole read succeeded and 0 otherwise (stream not
   // open, invalid arguments or failed/short read). The return value is a
   // success flag, not the number of values read.
   if (!f.is_open()) return 0;
   // a negative count would wrap around to a huge byte count below
   if (n < 0 || (n > 0 && v == 0)) return 0;
   if (f.read(reinterpret_cast<char *>(v), n*sizeof(Float_t))) return 1;
   return 0;
}

#endif // ROOT_TMVA_RuleFitAPI
 RuleFitAPI.h:1
 RuleFitAPI.h:2
 RuleFitAPI.h:3
 RuleFitAPI.h:4
 RuleFitAPI.h:5
 RuleFitAPI.h:6
 RuleFitAPI.h:7
 RuleFitAPI.h:8
 RuleFitAPI.h:9
 RuleFitAPI.h:10
 RuleFitAPI.h:11
 RuleFitAPI.h:12
 RuleFitAPI.h:13
 RuleFitAPI.h:14
 RuleFitAPI.h:15
 RuleFitAPI.h:16
 RuleFitAPI.h:17
 RuleFitAPI.h:18
 RuleFitAPI.h:19
 RuleFitAPI.h:20
 RuleFitAPI.h:21
 RuleFitAPI.h:22
 RuleFitAPI.h:23
 RuleFitAPI.h:24
 RuleFitAPI.h:25
 RuleFitAPI.h:26
 RuleFitAPI.h:27
 RuleFitAPI.h:28
 RuleFitAPI.h:29
 RuleFitAPI.h:30
 RuleFitAPI.h:31
 RuleFitAPI.h:32
 RuleFitAPI.h:33
 RuleFitAPI.h:34
 RuleFitAPI.h:35
 RuleFitAPI.h:36
 RuleFitAPI.h:37
 RuleFitAPI.h:38
 RuleFitAPI.h:39
 RuleFitAPI.h:40
 RuleFitAPI.h:41
 RuleFitAPI.h:42
 RuleFitAPI.h:43
 RuleFitAPI.h:44
 RuleFitAPI.h:45
 RuleFitAPI.h:46
 RuleFitAPI.h:47
 RuleFitAPI.h:48
 RuleFitAPI.h:49
 RuleFitAPI.h:50
 RuleFitAPI.h:51
 RuleFitAPI.h:52
 RuleFitAPI.h:53
 RuleFitAPI.h:54
 RuleFitAPI.h:55
 RuleFitAPI.h:56
 RuleFitAPI.h:57
 RuleFitAPI.h:58
 RuleFitAPI.h:59
 RuleFitAPI.h:60
 RuleFitAPI.h:61
 RuleFitAPI.h:62
 RuleFitAPI.h:63
 RuleFitAPI.h:64
 RuleFitAPI.h:65
 RuleFitAPI.h:66
 RuleFitAPI.h:67
 RuleFitAPI.h:68
 RuleFitAPI.h:69
 RuleFitAPI.h:70
 RuleFitAPI.h:71
 RuleFitAPI.h:72
 RuleFitAPI.h:73
 RuleFitAPI.h:74
 RuleFitAPI.h:75
 RuleFitAPI.h:76
 RuleFitAPI.h:77
 RuleFitAPI.h:78
 RuleFitAPI.h:79
 RuleFitAPI.h:80
 RuleFitAPI.h:81
 RuleFitAPI.h:82
 RuleFitAPI.h:83
 RuleFitAPI.h:84
 RuleFitAPI.h:85
 RuleFitAPI.h:86
 RuleFitAPI.h:87
 RuleFitAPI.h:88
 RuleFitAPI.h:89
 RuleFitAPI.h:90
 RuleFitAPI.h:91
 RuleFitAPI.h:92
 RuleFitAPI.h:93
 RuleFitAPI.h:94
 RuleFitAPI.h:95
 RuleFitAPI.h:96
 RuleFitAPI.h:97
 RuleFitAPI.h:98
 RuleFitAPI.h:99
 RuleFitAPI.h:100
 RuleFitAPI.h:101
 RuleFitAPI.h:102
 RuleFitAPI.h:103
 RuleFitAPI.h:104
 RuleFitAPI.h:105
 RuleFitAPI.h:106
 RuleFitAPI.h:107
 RuleFitAPI.h:108
 RuleFitAPI.h:109
 RuleFitAPI.h:110
 RuleFitAPI.h:111
 RuleFitAPI.h:112
 RuleFitAPI.h:113
 RuleFitAPI.h:114
 RuleFitAPI.h:115
 RuleFitAPI.h:116
 RuleFitAPI.h:117
 RuleFitAPI.h:118
 RuleFitAPI.h:119
 RuleFitAPI.h:120
 RuleFitAPI.h:121
 RuleFitAPI.h:122
 RuleFitAPI.h:123
 RuleFitAPI.h:124
 RuleFitAPI.h:125
 RuleFitAPI.h:126
 RuleFitAPI.h:127
 RuleFitAPI.h:128
 RuleFitAPI.h:129
 RuleFitAPI.h:130
 RuleFitAPI.h:131
 RuleFitAPI.h:132
 RuleFitAPI.h:133
 RuleFitAPI.h:134
 RuleFitAPI.h:135
 RuleFitAPI.h:136
 RuleFitAPI.h:137
 RuleFitAPI.h:138
 RuleFitAPI.h:139
 RuleFitAPI.h:140
 RuleFitAPI.h:141
 RuleFitAPI.h:142
 RuleFitAPI.h:143
 RuleFitAPI.h:144
 RuleFitAPI.h:145
 RuleFitAPI.h:146
 RuleFitAPI.h:147
 RuleFitAPI.h:148
 RuleFitAPI.h:149
 RuleFitAPI.h:150
 RuleFitAPI.h:151
 RuleFitAPI.h:152
 RuleFitAPI.h:153
 RuleFitAPI.h:154
 RuleFitAPI.h:155
 RuleFitAPI.h:156
 RuleFitAPI.h:157
 RuleFitAPI.h:158
 RuleFitAPI.h:159
 RuleFitAPI.h:160
 RuleFitAPI.h:161
 RuleFitAPI.h:162
 RuleFitAPI.h:163
 RuleFitAPI.h:164
 RuleFitAPI.h:165
 RuleFitAPI.h:166
 RuleFitAPI.h:167
 RuleFitAPI.h:168
 RuleFitAPI.h:169
 RuleFitAPI.h:170
 RuleFitAPI.h:171
 RuleFitAPI.h:172
 RuleFitAPI.h:173
 RuleFitAPI.h:174
 RuleFitAPI.h:175
 RuleFitAPI.h:176
 RuleFitAPI.h:177
 RuleFitAPI.h:178
 RuleFitAPI.h:179
 RuleFitAPI.h:180
 RuleFitAPI.h:181
 RuleFitAPI.h:182
 RuleFitAPI.h:183
 RuleFitAPI.h:184
 RuleFitAPI.h:185
 RuleFitAPI.h:186
 RuleFitAPI.h:187
 RuleFitAPI.h:188
 RuleFitAPI.h:189
 RuleFitAPI.h:190
 RuleFitAPI.h:191
 RuleFitAPI.h:192
 RuleFitAPI.h:193
 RuleFitAPI.h:194
 RuleFitAPI.h:195
 RuleFitAPI.h:196
 RuleFitAPI.h:197
 RuleFitAPI.h:198
 RuleFitAPI.h:199
 RuleFitAPI.h:200
 RuleFitAPI.h:201
 RuleFitAPI.h:202
 RuleFitAPI.h:203
 RuleFitAPI.h:204
 RuleFitAPI.h:205
 RuleFitAPI.h:206
 RuleFitAPI.h:207
 RuleFitAPI.h:208
 RuleFitAPI.h:209
 RuleFitAPI.h:210
 RuleFitAPI.h:211
 RuleFitAPI.h:212
 RuleFitAPI.h:213
 RuleFitAPI.h:214
 RuleFitAPI.h:215
 RuleFitAPI.h:216
 RuleFitAPI.h:217
 RuleFitAPI.h:218
 RuleFitAPI.h:219
 RuleFitAPI.h:220
 RuleFitAPI.h:221
 RuleFitAPI.h:222
 RuleFitAPI.h:223
 RuleFitAPI.h:224
 RuleFitAPI.h:225
 RuleFitAPI.h:226
 RuleFitAPI.h:227
 RuleFitAPI.h:228
 RuleFitAPI.h:229
 RuleFitAPI.h:230
 RuleFitAPI.h:231
 RuleFitAPI.h:232
 RuleFitAPI.h:233
 RuleFitAPI.h:234
 RuleFitAPI.h:235
 RuleFitAPI.h:236
 RuleFitAPI.h:237
 RuleFitAPI.h:238
 RuleFitAPI.h:239
 RuleFitAPI.h:240
 RuleFitAPI.h:241
 RuleFitAPI.h:242
 RuleFitAPI.h:243
 RuleFitAPI.h:244
 RuleFitAPI.h:245
 RuleFitAPI.h:246
 RuleFitAPI.h:247
 RuleFitAPI.h:248
 RuleFitAPI.h:249
 RuleFitAPI.h:250
 RuleFitAPI.h:251
 RuleFitAPI.h:252
 RuleFitAPI.h:253
 RuleFitAPI.h:254
 RuleFitAPI.h:255
 RuleFitAPI.h:256
 RuleFitAPI.h:257
 RuleFitAPI.h:258
 RuleFitAPI.h:259
 RuleFitAPI.h:260
 RuleFitAPI.h:261
 RuleFitAPI.h:262
 RuleFitAPI.h:263
 RuleFitAPI.h:264
 RuleFitAPI.h:265
 RuleFitAPI.h:266
 RuleFitAPI.h:267
 RuleFitAPI.h:268
 RuleFitAPI.h:269
 RuleFitAPI.h:270
 RuleFitAPI.h:271
 RuleFitAPI.h:272
 RuleFitAPI.h:273
 RuleFitAPI.h:274
 RuleFitAPI.h:275
 RuleFitAPI.h:276
 RuleFitAPI.h:277
 RuleFitAPI.h:278
 RuleFitAPI.h:279
 RuleFitAPI.h:280
 RuleFitAPI.h:281
 RuleFitAPI.h:282
 RuleFitAPI.h:283
 RuleFitAPI.h:284
 RuleFitAPI.h:285
 RuleFitAPI.h:286
 RuleFitAPI.h:287
 RuleFitAPI.h:288
 RuleFitAPI.h:289
 RuleFitAPI.h:290
 RuleFitAPI.h:291
 RuleFitAPI.h:292
 RuleFitAPI.h:293
 RuleFitAPI.h:294
 RuleFitAPI.h:295
 RuleFitAPI.h:296