ROOT  6.06/09
Reference Guide
MethodRXGB.cxx
Go to the documentation of this file.
1 // @(#)root/tmva/rmva $Id$
2 // Author: Omar Zapata,Lorenzo Moneta, Sergei Gleyzer 2015
3 
4 /**********************************************************************************
5  * Project: TMVA - a Root-integrated toolkit for multivariate data analysis *
6  * Package: TMVA *
7  * Class : MethodRXGB *
8  * Web : http://oproject.org *
9  * *
10  * Description: *
11  * R eXtreme Gradient Boosting *
12  * *
13  * *
14  * Redistribution and use in source and binary forms, with or without *
15  * modification, are permitted according to the terms listed in LICENSE *
16  * (http://tmva.sourceforge.net/LICENSE) *
17  * *
18  **********************************************************************************/
19 
20 #include <iomanip>
21 
22 #include "TMath.h"
23 #include "Riostream.h"
24 #include "TMatrix.h"
25 #include "TMatrixD.h"
26 #include "TVectorD.h"
27 
29 #include "TMVA/MethodRXGB.h"
30 #include "TMVA/Tools.h"
31 #include "TMVA/Config.h"
32 #include "TMVA/Ranking.h"
33 #include "TMVA/Types.h"
34 #include "TMVA/PDF.h"
35 #include "TMVA/ClassifierFactory.h"
36 
37 #include "TMVA/Results.h"
38 
39 using namespace TMVA;
40 
41 REGISTER_METHOD(RXGB)
42 
44 
45 // Static member definition: asks the TRInterface instance to load the R package
// "xgboost" through Require() and stores the Bool_t it returns. The flag is tested
// later in this file (the `!IsModuleLoaded` check) before the class labels are converted.
46 Bool_t MethodRXGB::IsModuleLoaded = ROOT::R::TRInterface::Instance().Require("xgboost");
47 
48 //_______________________________________________________________________
// Constructor taking a job name, method title, dataset info and option string.
// Forwards those arguments to RMethodBase together with Types::kRXGB, sets the
// defaults fNRounds = 10, fEta = 0.3 and fMaxDepth = 6, initialises the
// TRFunctionImport members with the names of the R functions they refer to,
// and starts with no model (fModel is NULL).
// NOTE(review): the second argument of predict("predict", "xgboost") is presumably
// the R package to look the function up in - confirm against TRFunctionImport.
49 MethodRXGB::MethodRXGB(const TString &jobName,
50  const TString &methodTitle,
51  DataSetInfo &dsi,
52  const TString &theOption,
53  TDirectory *theTargetDir) : RMethodBase(jobName, Types::kRXGB, methodTitle, dsi, theOption, theTargetDir),
54  fNRounds(10),
55  fEta(0.3),
56  fMaxDepth(6),
57  predict("predict", "xgboost"),
58  xgbtrain("xgboost"),
59  xgbdmatrix("xgb.DMatrix"),
60  xgbsave("xgb.save"),
61  xgbload("xgb.load"),
62  asfactor("as.factor"),
63  asmatrix("as.matrix"),
64  fModel(NULL)
65 {
66  // standard constructor for the RXGB method
67 
68 // use the weight file directory from the global configuration (gConfig)
69  SetWeightFileDir(gConfig().GetIONames().fWeightFileDir);
70 }
71 
72 //_______________________________________________________________________
// Constructor taking dataset info and a weight file name.
// Forwards its arguments to RMethodBase together with Types::kRXGB and applies the
// same member defaults as the other constructor: fNRounds = 10, fEta = 0.3,
// fMaxDepth = 6, the R function names for the TRFunctionImport members, and
// fModel = NULL.
73 MethodRXGB::MethodRXGB(DataSetInfo &theData, const TString &theWeightFile, TDirectory *theTargetDir)
74  : RMethodBase(Types::kRXGB, theData, theWeightFile, theTargetDir),
75  fNRounds(10),
76  fEta(0.3),
77  fMaxDepth(6),
78  predict("predict", "xgboost"),
79  xgbtrain("xgboost"),
80  xgbdmatrix("xgb.DMatrix"),
81  xgbsave("xgb.save"),
82  xgbload("xgb.load"),
83  asfactor("as.factor"),
84  asmatrix("as.matrix"),
85  fModel(NULL)
86 {
87 
88 // use the weight file directory from the global configuration (gConfig)
89  SetWeightFileDir(gConfig().GetIONames().fWeightFileDir);
90 }
91 
92 
93 //_______________________________________________________________________
// Function body whose signature (embedded listing line 94) is absent from this
// listing - presumably the destructor MethodRXGB::~MethodRXGB; confirm against the header.
95 {
 // Frees the TRObject that this file allocates with `new` and stores in fModel.
 // The null test is redundant: `delete` on a null pointer does nothing.
96  if (fModel) delete fModel;
97 }
98 
99 //_______________________________________________________________________
// Body of HasAnalysisType(type, numberClasses, numberTargets); the signature line
// (embedded listing line 100) is absent from this listing.
// Returns kTRUE only for classification with exactly two classes and kFALSE for
// everything else. numberTargets is not used in the body.
101 {
102  if (type == Types::kClassification && numberClasses == 2) return kTRUE;
103  return kFALSE;
104 }
105 
106 
107 //_______________________________________________________________________
// Function body whose signature (embedded listing line 108) is absent from this
// listing - presumably MethodRXGB::Init, going by the Error("Init", ...) call; confirm.
// Reports an error and returns early when the xgboost R package was not loaded;
// otherwise fills fFactorNumeric with one numeric label per entry of fFactorTrain.
109 {
110 
111  if (!IsModuleLoaded) {
112  Error("Init", "R's package xgboost can not be loaded.");
113  Log() << kFATAL << " R's package xgboost can not be loaded."
114  << Endl;
115  return;
116  }
117  // build the numeric labels
118  // xgboost needs numeric labels: entries of fFactorTrain equal to "signal" become 1, all others 0
119  UInt_t size = fFactorTrain.size();
120  fFactorNumeric.resize(size);
121 
122  for (UInt_t i = 0; i < size; i++) {
123  if (fFactorTrain[i] == "signal") fFactorNumeric[i] = 1;
124  else fFactorNumeric[i] = 0;
125  }
126 
127 
128 
129 }
130 
// Function body whose signature (embedded listing line 131) is absent from this
// listing - presumably MethodRXGB::Train, going by the "<Train>" log text; confirm.
// Calls the R training function with the numeric labels, event weights, number of
// rounds and a parameter frame, keeps the resulting model in fModel, and saves it
// to "<weight file dir>/RXGBModel.RData".
// NOTE(review): embedded listing line 134 is also absent; `dmatrix`, used below, is
// not declared in any visible line.
132 {
133  if (Data()->GetNTrainingEvents() == 0) Log() << kFATAL << "<Train> Data() has zero events" << Endl;
 // Parameters passed to R under their R-side names "eta" and "max.depth".
135  ROOT::R::TRDataFrame params;
136  params["eta"] = fEta;
137  params["max.depth"] = fMaxDepth;
138 
 // xgbtrain is bound to the R function "xgboost" in the constructors.
139  SEXP Model = xgbtrain(ROOT::R::Label["data"] = dmatrix,
140  ROOT::R::Label["label"] = fFactorNumeric,
141  ROOT::R::Label["weight"] = fWeightTrain,
142  ROOT::R::Label["nrounds"] = fNRounds,
143  ROOT::R::Label["params"] = params);
144 
 // NOTE(review): a previously assigned fModel is overwritten here without being deleted.
145  fModel = new ROOT::R::TRObject(Model);
146  TString path = GetWeightFileDir() + "/RXGBModel.RData";
147  Log() << Endl;
148  Log() << gTools().Color("bold") << "--- Saving State File In:" << gTools().Color("reset") << path << Endl;
149  Log() << Endl;
 // xgbsave is bound to the R function "xgb.save" in the constructors.
150  xgbsave(Model, path);
151 }
152 
153 //_______________________________________________________________________
// Function body whose signature (embedded listing line 154) is absent from this
// listing - presumably MethodRXGB::DeclareOptions; confirm against the header.
// Registers fNRounds, fEta and fMaxDepth by reference under the option names
// "NRounds", "Eta" and "MaxDepth", each with a description string.
155 {
156  DeclareOptionRef(fNRounds, "NRounds", "The max number of iterations");
157  DeclareOptionRef(fEta, "Eta", "Step size shrinkage used in update to prevents overfitting. After each boosting step, we can directly get the weights of new features. and eta actually shrinks the feature weights to make the boosting process more conservative.");
158  DeclareOptionRef(fMaxDepth, "MaxDepth", "Maximum depth of the tree");
159 }
160 
161 //_______________________________________________________________________
// Empty function body; its signature (embedded listing line 162) is absent from
// this listing - presumably MethodRXGB::ProcessOptions; confirm against the header.
163 {
164 }
165 
166 //_______________________________________________________________________
// Body of TestClassification(); the signature line (embedded listing line 167) is
// absent from this listing.
// Writes an informational log message.
// NOTE(review): embedded listing line 170 is also absent, so any statement that
// followed the log message is not visible here.
168 {
169  Log() << kINFO << "Testing Classification RXGB METHOD " << Endl;
171 }
172 
173 
174 //_______________________________________________________________________
// Body of GetMvaValue(errLower, errUpper); the signature line (embedded listing
// line 175) is absent from this listing.
// Copies the current event's variable values into a one-event data frame keyed by
// variable name, passes it through as.matrix and xgb.DMatrix to the R predict
// function together with the model, and returns the result converted to Double_t.
176 {
177  NoErrorCalc(errLower, errUpper);
178  Double_t mvaValue;
179  const TMVA::Event *ev = GetEvent();
180  const UInt_t nvar = DataInfo().GetNVariables();
181  ROOT::R::TRDataFrame fDfEvent;
 // NOTE(review): GetListOfVariables() is declared as returning std::vector<TString>
 // by value (see the index at the end of this listing), so it is called and copied
 // on every iteration.
182  for (UInt_t i = 0; i < nvar; i++) {
183  fDfEvent[DataInfo().GetListOfVariables()[i].Data()] = ev->GetValues()[i];
184  }
185  // no model in memory yet (case of a persisted model)
 // NOTE(review): embedded listing line 187 is absent, so this branch is empty as
 // shown and *fModel below would be dereferenced while still null. Presumably the
 // missing line loads the saved model - confirm against the original file.
186  if (!fModel) {
188  }
189  mvaValue = (Double_t)predict(*fModel, xgbdmatrix(ROOT::R::Label["data"] = asmatrix(fDfEvent)));
190  return mvaValue;
191 }
192 
193 //_______________________________________________________________________
// Body of GetHelpMessage(); the signature line (embedded listing line 194) is
// absent from this listing.
// Writes a fixed help text to the logger in three sections: short description,
// performance optimisation (no content) and performance tuning ("<None>").
195 {
196 // get help message text
197 //
198 // typical length of text line:
199 // "|--------------------------------------------------------------|"
200  Log() << Endl;
201  Log() << gTools().Color("bold") << "--- Short description:" << gTools().Color("reset") << Endl;
202  Log() << Endl;
203  Log() << "Decision Trees and Rule-Based Models " << Endl;
204  Log() << Endl;
205  Log() << gTools().Color("bold") << "--- Performance optimisation:" << gTools().Color("reset") << Endl;
206  Log() << Endl;
207  Log() << Endl;
208  Log() << gTools().Color("bold") << "--- Performance tuning via configuration options:" << gTools().Color("reset") << Endl;
209  Log() << Endl;
210  Log() << "<None>" << Endl;
211 }
212 
213 //_______________________________________________________________________
// Body of ReadStateFromFile(); the signature line (embedded listing line 214) is
// absent from this listing, and so is embedded listing line 216.
// Loads the model from "<weight file dir>/RXGBModel.RData" with the R function
// bound to xgbload ("xgb.load" in the constructors) and stores it in fModel.
215 {
 // Same file name that the training code in this file writes with xgbsave.
217  TString path = GetWeightFileDir() + "/RXGBModel.RData";
218  Log() << Endl;
219  Log() << gTools().Color("bold") << "--- Loading State File From:" << gTools().Color("reset") << path << Endl;
220  Log() << Endl;
221 
222  SEXP Model = xgbload(path);
 // NOTE(review): a previously assigned fModel is overwritten here without being deleted.
223  fModel = new ROOT::R::TRObject(Model);
224 
225 }
226 
227 //_______________________________________________________________________
// Empty implementation: nothing is written and theClassFileName is not used.
228 void TMVA::MethodRXGB::MakeClass(const TString &theClassFileName) const
229 {
230 }
ROOT::R::TRFunctionImport xgbdmatrix
Definition: MethodRXGB.h:95
const TString & GetWeightFileDir() const
Definition: MethodBase.h:407
MsgLogger & Endl(MsgLogger &ml)
Definition: MsgLogger.h:162
Namespace for new ROOT classes and functions.
Definition: ROOT.py:1
Bool_t HasAnalysisType(Types::EAnalysisType type, UInt_t numberClasses, UInt_t numberTargets)
Definition: MethodRXGB.cxx:100
Config & gConfig()
OptionBase * DeclareOptionRef(T &ref, const TString &name, const TString &desc="")
DataSet * Data() const
Definition: MethodBase.h:363
EAnalysisType
Definition: Types.h:124
ROOT::R::TRObject * fModel
Definition: MethodRXGB.h:100
Basic string class.
Definition: TString.h:137
TString as(SEXP s)
Definition: RExports.h:85
bool Bool_t
Definition: RtypesCore.h:59
const Bool_t kFALSE
Definition: Rtypes.h:92
UInt_t GetNVariables() const
Definition: DataSetInfo.h:128
MethodRXGB(const TString &jobName, const TString &methodTitle, DataSetInfo &theData, const TString &theOption="", TDirectory *theTargetDir=NULL)
void ReadStateFromFile()
Definition: MethodRXGB.cxx:214
TVectorD fWeightTrain
Definition: RMethodBase.h:96
Tools & gTools()
Definition: Tools.cxx:79
virtual void Error(const char *method, const char *msgfmt,...) const
Issue error message.
Definition: TObject.cxx:918
Bool_t Require(TString pkg)
Method to load an R's package.
virtual void MakeClass(const TString &classFileName=TString("")) const
create reader class for method (classification only at present)
Definition: MethodRXGB.cxx:228
This is a class to get ROOT's objects from R's objects
Definition: TRObject.h:73
std::vector< std::string > fFactorTrain
Definition: RMethodBase.h:98
ROOT::R::TRFunctionImport xgbsave
Definition: MethodRXGB.h:96
void GetHelpMessage() const
Definition: MethodRXGB.cxx:194
unsigned int UInt_t
Definition: RtypesCore.h:42
const Event * GetEvent() const
Definition: MethodBase.h:667
ROOT::R::TRFunctionImport asmatrix
Definition: MethodRXGB.h:99
Double_t fEta
Definition: MethodRXGB.h:86
static Bool_t IsModuleLoaded
Definition: MethodRXGB.h:88
#define ClassImp(name)
Definition: Rtypes.h:279
double Double_t
Definition: RtypesCore.h:55
Describe directory structure in memory.
Definition: TDirectory.h:41
virtual void TestClassification()
initialization
Definition: MethodRXGB.cxx:167
int type
Definition: TGX11.cxx:120
static TRInterface & Instance()
static method to get an TRInterface instance reference
MsgLogger & Log() const
Definition: Configurable.h:130
DataSetInfo & DataInfo() const
Definition: MethodBase.h:364
const TString & Color(const TString &)
human readable color strings
Definition: Tools.cxx:837
#define REGISTER_METHOD(CLASS)
for example
Abstract ClassifierFactory template that handles arbitrary types.
UInt_t fMaxDepth
Definition: MethodRXGB.h:87
std::vector< Float_t > & GetValues()
Definition: Event.h:93
ROOT::R::TRDataFrame fDfTrain
Definition: RMethodBase.h:94
void SetWeightFileDir(TString fileDir)
set directory of weight file
ROOT::R::TRFunctionImport predict
Definition: MethodRXGB.h:93
std::vector< TString > GetListOfVariables() const
returns list of variables
Rcpp::internal::NamedPlaceHolder Label
Definition: RExports.cxx:14
Double_t GetMvaValue(Double_t *errLower=0, Double_t *errUpper=0)
Definition: MethodRXGB.cxx:175
const Bool_t kTRUE
Definition: Rtypes.h:91
virtual void TestClassification()
initialization
ROOT::R::TRFunctionImport xgbtrain
Definition: MethodRXGB.h:94
TRandom3 R
a TMatrixD.
Definition: testIO.cxx:28
Definition: math.cpp:60
void NoErrorCalc(Double_t *const err, Double_t *const errUpper)
Definition: MethodBase.cxx:820
This is a class to create DataFrames from ROOT to R
Definition: TRDataFrame.h:183
std::vector< UInt_t > fFactorNumeric
Definition: MethodRXGB.h:90