Logo ROOT  
Reference Guide
MethodRXGB.cxx
Go to the documentation of this file.
1// @(#)root/tmva/rmva $Id$
2// Author: Omar Zapata,Lorenzo Moneta, Sergei Gleyzer 2015
3
4/**********************************************************************************
5 * Project: TMVA - a Root-integrated toolkit for multivariate data analysis *
6 * Package: TMVA *
7 * Class : MethodRXGB *
8 * Web : http://oproject.org *
9 * *
10 * Description: *
11 * R eXtreme Gradient Boosting *
12 * *
13 * *
14 * Redistribution and use in source and binary forms, with or without *
15 * modification, are permitted according to the terms listed in LICENSE *
16 * (http://tmva.sourceforge.net/LICENSE) *
17 * *
18 **********************************************************************************/
19
20#include <iomanip>
21
22#include "TMath.h"
23#include "Riostream.h"
24#include "TMatrix.h"
25#include "TMatrixD.h"
26#include "TVectorD.h"
27
29#include "TMVA/MethodRXGB.h"
30#include "TMVA/Tools.h"
31#include "TMVA/Config.h"
32#include "TMVA/Ranking.h"
33#include "TMVA/Types.h"
34#include "TMVA/PDF.h"
36
37#include "TMVA/Results.h"
38#include "TMVA/Timer.h"
39
40using namespace TMVA;
41
43
45
46//creating an Instance
47Bool_t MethodRXGB::IsModuleLoaded = ROOT::R::TRInterface::Instance().Require("xgboost");
48
49//_______________________________________________________________________
50MethodRXGB::MethodRXGB(const TString &jobName,
51 const TString &methodTitle,
52 DataSetInfo &dsi,
53 const TString &theOption) : RMethodBase(jobName, Types::kRXGB, methodTitle, dsi, theOption),
54 fNRounds(10),
55 fEta(0.3),
56 fMaxDepth(6),
57 predict("predict", "xgboost"),
58 xgbtrain("xgboost"),
59 xgbdmatrix("xgb.DMatrix"),
60 xgbsave("xgb.save"),
61 xgbload("xgb.load"),
62 asfactor("as.factor"),
63 asmatrix("as.matrix"),
64 fModel(NULL)
65{
66 // standard constructor for the RXGB
67
68}
69
70//_______________________________________________________________________
71MethodRXGB::MethodRXGB(DataSetInfo &theData, const TString &theWeightFile)
72 : RMethodBase(Types::kRXGB, theData, theWeightFile),
73 fNRounds(10),
74 fEta(0.3),
75 fMaxDepth(6),
76 predict("predict", "xgboost"),
77 xgbtrain("xgboost"),
78 xgbdmatrix("xgb.DMatrix"),
79 xgbsave("xgb.save"),
80 xgbload("xgb.load"),
81 asfactor("as.factor"),
82 asmatrix("as.matrix"),
83 fModel(NULL)
84{
85
86}
87
88
89//_______________________________________________________________________
91{
92 if (fModel) delete fModel;
93}
94
95//_______________________________________________________________________
97{
98 if (type == Types::kClassification && numberClasses == 2) return kTRUE;
99 return kFALSE;
100}
101
102
103//_______________________________________________________________________
105{
106
107 if (!IsModuleLoaded) {
108 Error("Init", "R's package xgboost can not be loaded.");
109 Log() << kFATAL << " R's package xgboost can not be loaded."
110 << Endl;
111 return;
112 }
113 //factors creations
114 //xgboost require a numeric factor then background=0 signal=1 from fFactorTrain
115 UInt_t size = fFactorTrain.size();
116 fFactorNumeric.resize(size);
117
118 for (UInt_t i = 0; i < size; i++) {
119 if (fFactorTrain[i] == "signal") fFactorNumeric[i] = 1;
120 else fFactorNumeric[i] = 0;
121 }
122
123
124
125}
126
// Training: logs a kFATAL message if there are no training events, then runs
// R's xgboost (xgbtrain) on the training DMatrix with the configured eta and
// max.depth parameters, fNRounds boosting rounds and the training weights
// (fWeightTrain). The resulting R model is wrapped in a new TRObject held by
// fModel. When model persistence is enabled, the model is also written with
// xgb.save to <weight-file dir>/<method name>.RData.
// NOTE(review): the declarations of `dmatrix` and `params`, and one argument of
// the xgbtrain call, are missing from this listing — confirm against the full source.
128{
129 if (Data()->GetNTrainingEvents() == 0) Log() << kFATAL << "<Train> Data() has zero events" << Endl;
// Boosting hyper-parameters forwarded to xgboost through its `params` argument.
132 params["eta"] = fEta;
133 params["max.depth"] = fMaxDepth;
134
135 SEXP Model = xgbtrain(ROOT::R::Label["data"] = dmatrix,
137 ROOT::R::Label["weight"] = fWeightTrain,
138 ROOT::R::Label["nrounds"] = fNRounds,
139 ROOT::R::Label["params"] = params);
140
// NOTE(review): fModel is reassigned without deleting any previous model, so
// training the same instance twice would leak the earlier TRObject.
141 fModel = new ROOT::R::TRObject(Model);
142 if (IsModelPersistence())
143 {
// Same path that the model-loading routine reads back from.
144 TString path = GetWeightFileDir() + "/" + GetName() + ".RData";
145 Log() << Endl;
146 Log() << gTools().Color("bold") << "--- Saving State File In:" << gTools().Color("reset") << path << Endl;
147 Log() << Endl;
148 xgbsave(Model, path);
149 }
150}
151
152//_______________________________________________________________________
// Registers the user-configurable options NRounds (fNRounds), Eta (fEta) and
// MaxDepth (fMaxDepth); their current member values act as defaults.
// The signature line is missing from this listing — presumably
// MethodRXGB::DeclareOptions(), TODO confirm.
// NOTE(review): the Eta help string has grammar slips ("to prevents",
// "features. and eta") worth fixing in the literal.
154{
155 DeclareOptionRef(fNRounds, "NRounds", "The max number of iterations");
156 DeclareOptionRef(fEta, "Eta", "Step size shrinkage used in update to prevents overfitting. After each boosting step, we can directly get the weights of new features. and eta actually shrinks the feature weights to make the boosting process more conservative.");
157 DeclareOptionRef(fMaxDepth, "MaxDepth", "Maximum depth of the tree");
158}
159
160//_______________________________________________________________________
// Empty hook: no option post-processing is performed.
// The signature line is missing from this listing — presumably
// MethodRXGB::ProcessOptions(), TODO confirm.
162{
163}
164
165//_______________________________________________________________________
// Classification test hook: logs an info banner.
// NOTE(review): one statement after the log line is missing from this listing
// (presumably a call to the MethodBase implementation) — confirm.
167{
168 Log() << kINFO << "Testing Classification RXGB METHOD " << Endl;
170}
171
172
173//_______________________________________________________________________
// Returns the classifier response for the current event (GetEvent()).
// The event's variable values are copied into a one-row R data frame keyed by
// variable name, converted with as.matrix -> xgb.DMatrix, and passed together
// with fModel to R's predict(); the result is cast to Double_t.
// No error estimate is provided (NoErrorCalc). For many events, GetMvaValues()
// builds a single data frame for the whole range instead of one per event.
175{
176 NoErrorCalc(errLower, errUpper);
177 Double_t mvaValue;
178 const TMVA::Event *ev = GetEvent();
179 const UInt_t nvar = DataInfo().GetNVariables();
// NOTE(review): despite the member-style `f` prefix, fDfEvent is a local.
180 ROOT::R::TRDataFrame fDfEvent;
// NOTE(review): GetListOfVariables() returns the vector by value, so it is
// copied on every iteration; hoisting it out of the loop would avoid that.
181 for (UInt_t i = 0; i < nvar; i++) {
182 fDfEvent[DataInfo().GetListOfVariables()[i].Data()] = ev->GetValues()[i];
183 }
184 //if using persistence model
// NOTE(review): the statement that followed this comment is missing from this listing.
186
187 mvaValue = (Double_t)predict(*fModel, xgbdmatrix(ROOT::R::Label["data"] = asmatrix(fDfEvent)));
188 return mvaValue;
189}
190
191////////////////////////////////////////////////////////////////////////////////
192/// get all the MVA values for the events of the current Data type
/// The range [firstEvt, lastEvt) is clamped to the dataset: an inverted range
/// or a lastEvt past the end selects up to the last event, and a negative
/// firstEvt is raised to 0. All selected events are packed into one R data
/// frame and evaluated with a single predict() call.
/// Side effect: the dataset's current event is moved (SetCurrentEvent).
193std::vector<Double_t> MethodRXGB::GetMvaValues(Long64_t firstEvt, Long64_t lastEvt, Bool_t logProgress)
194{
195 Long64_t nEvents = Data()->GetNEvents();
196 if (firstEvt > lastEvt || lastEvt > nEvents) lastEvt = nEvents;
197 if (firstEvt < 0) firstEvt = 0;
198
199 nEvents = lastEvt-firstEvt;
200
201 UInt_t nvars = Data()->GetNVariables();
202
203 // use timer
204 Timer timer( nEvents, GetName(), kTRUE );
205 if (logProgress)
206 Log() << kINFO<<Form("Dataset[%s] : ",DataInfo().GetName())<< "Evaluation of " << GetMethodName() << " on "
207 << (Data()->GetCurrentType()==Types::kTraining?"training":"testing") << " sample (" << nEvents << " events)" << Endl;
208
209
210 // fill R DATA FRAME with events data
// One column vector per input variable, each holding nEvents entries.
211 std::vector<std::vector<Float_t> > inputData(nvars);
212 for (UInt_t i = 0; i < nvars; i++) {
213 inputData[i] = std::vector<Float_t>(nEvents);
214 }
215
// NOTE(review): BUG — each inputData[i] has nEvents = lastEvt - firstEvt
// entries, but it is indexed below with the absolute event number ievt. Any call
// with firstEvt > 0 therefore writes past the end of the vectors and leaves the
// leading entries zero. The index should be (ievt - firstEvt).
// NOTE(review): ievt is Int_t while firstEvt/lastEvt are Long64_t.
216 for (Int_t ievt=firstEvt; ievt<lastEvt; ievt++) {
217 Data()->SetCurrentEvent(ievt);
218 const TMVA::Event *e = Data()->GetEvent();
219 assert(nvars == e->GetNVariables());
220 for (UInt_t i = 0; i < nvars; i++) {
221 inputData[i][ievt] = e->GetValue(i);
222 }
223 // if (ievt%100 == 0)
224 // std::cout << "Event " << ievt << " type" << DataInfo().IsSignal(e) << " : " << pValue[ievt*nvars] << " " << pValue[ievt*nvars+1] << " " << pValue[ievt*nvars+2] << std::endl;
225 }
226
// Build the R data frame column-by-column, keyed by variable name.
// NOTE(review): GetListOfVariables() returns by value and is copied per iteration.
227 ROOT::R::TRDataFrame evtData;
228 for (UInt_t i = 0; i < nvars; i++) {
229 evtData[DataInfo().GetListOfVariables()[i].Data()] = inputData[i];
230 }
231 //if using persistence model
// NOTE(review): the statement that followed this comment, and the declaration of
// `pred` (the predict() call on evtData), are missing from this listing.
233
234 std::vector<Double_t> mvaValues(nEvents);
236 mvaValues = pred.As<std::vector<Double_t>>();
237
238 if (logProgress) {
239 Log() << kINFO <<Form("Dataset[%s] : ",DataInfo().GetName())<< "Elapsed time for evaluation of " << nEvents << " events: "
240 << timer.GetElapsedTime() << " " << Endl;
241 }
242
243 return mvaValues;
244
245}
246//_______________________________________________________________________
248{
249// get help message text
250//
251// typical length of text line:
252// "|--------------------------------------------------------------|"
253 Log() << Endl;
254 Log() << gTools().Color("bold") << "--- Short description:" << gTools().Color("reset") << Endl;
255 Log() << Endl;
256 Log() << "Decision Trees and Rule-Based Models " << Endl;
257 Log() << Endl;
258 Log() << gTools().Color("bold") << "--- Performance optimisation:" << gTools().Color("reset") << Endl;
259 Log() << Endl;
260 Log() << Endl;
261 Log() << gTools().Color("bold") << "--- Performance tuning via configuration options:" << gTools().Color("reset") << Endl;
262 Log() << Endl;
263 Log() << "<None>" << Endl;
264}
265
266//_______________________________________________________________________
// Loads the persisted xgboost model from <weight-file dir>/<method name>.RData
// (the same path the training routine saves to when model persistence is
// enabled) using R's xgb.load, and wraps it in a new TRObject stored in fModel.
// NOTE(review): the signature line and one statement right after the opening
// brace are missing from this listing.
268{
270 TString path = GetWeightFileDir() + "/" + GetName() + ".RData";
271 Log() << Endl;
272 Log() << gTools().Color("bold") << "--- Loading State File From:" << gTools().Color("reset") << path << Endl;
273 Log() << Endl;
274
275 SEXP Model = xgbload(path);
// NOTE(review): any previous fModel is overwritten without being deleted, so
// each repeated call leaks one TRObject; `delete fModel;` before reassigning
// would fix it.
276 fModel = new ROOT::R::TRObject(Model);
277
278}
279
280//_______________________________________________________________________
281void TMVA::MethodRXGB::MakeClass(const TString &/*theClassFileName*/) const
282{
283}
#define REGISTER_METHOD(CLASS)
for example
#define e(i)
Definition: RSha256.hxx:103
size_t size(const MatrixT &matrix)
retrieve the size of a square matrix
int Int_t
Definition: RtypesCore.h:45
unsigned int UInt_t
Definition: RtypesCore.h:46
const Bool_t kFALSE
Definition: RtypesCore.h:101
bool Bool_t
Definition: RtypesCore.h:63
double Double_t
Definition: RtypesCore.h:59
long long Long64_t
Definition: RtypesCore.h:80
const Bool_t kTRUE
Definition: RtypesCore.h:100
#define ClassImp(name)
Definition: Rtypes.h:364
int type
Definition: TGX11.cxx:121
char * Form(const char *fmt,...)
This is a class to create DataFrames from ROOT to R.
Definition: TRDataFrame.h:176
static TRInterface & Instance()
static method to get an TRInterface instance reference
Bool_t Require(TString pkg)
Method to load an R's package.
This is a class to get ROOT's objects from R's objects.
Definition: TRObject.h:70
T As()
Some datatypes of ROOT or c++ can be wrapped in to a TRObject, this method lets you unwrap those data...
Definition: TRObject.h:152
OptionBase * DeclareOptionRef(T &ref, const TString &name, const TString &desc="")
MsgLogger & Log() const
Definition: Configurable.h:122
Class that contains all the data information.
Definition: DataSetInfo.h:62
UInt_t GetNVariables() const
Definition: DataSetInfo.h:127
std::vector< TString > GetListOfVariables() const
returns list of variables
const Event * GetEvent() const
Definition: DataSet.cxx:202
Types::ETreeType GetCurrentType() const
Definition: DataSet.h:194
Long64_t GetNEvents(Types::ETreeType type=Types::kMaxTreeType) const
Definition: DataSet.h:206
UInt_t GetNVariables() const
access the number of variables through the datasetinfo
Definition: DataSet.cxx:216
void SetCurrentEvent(Long64_t ievt) const
Definition: DataSet.h:88
std::vector< Float_t > & GetValues()
Definition: Event.h:94
const char * GetName() const
Definition: MethodBase.h:334
Bool_t IsModelPersistence() const
Definition: MethodBase.h:383
const TString & GetWeightFileDir() const
Definition: MethodBase.h:492
const TString & GetMethodName() const
Definition: MethodBase.h:331
const Event * GetEvent() const
Definition: MethodBase.h:751
DataSetInfo & DataInfo() const
Definition: MethodBase.h:410
virtual void TestClassification()
initialization
void ReadStateFromFile()
Function to write options and weights to file.
void NoErrorCalc(Double_t *const err, Double_t *const errUpper)
Definition: MethodBase.cxx:836
DataSet * Data() const
Definition: MethodBase.h:409
Double_t fEta
Definition: MethodRXGB.h:89
std::vector< UInt_t > fFactorNumeric
Definition: MethodRXGB.h:93
virtual std::vector< Double_t > GetMvaValues(Long64_t firstEvt=0, Long64_t lastEvt=-1, Bool_t logProgress=false)
get all the MVA values for the events of the current Data type
Definition: MethodRXGB.cxx:193
UInt_t fMaxDepth
Definition: MethodRXGB.h:90
void GetHelpMessage() const
Definition: MethodRXGB.cxx:247
ROOT::R::TRFunctionImport xgbtrain
Definition: MethodRXGB.h:97
MethodRXGB(const TString &jobName, const TString &methodTitle, DataSetInfo &theData, const TString &theOption="")
Definition: MethodRXGB.cxx:50
static Bool_t IsModuleLoaded
Definition: MethodRXGB.h:91
ROOT::R::TRFunctionImport asmatrix
Definition: MethodRXGB.h:102
virtual void TestClassification()
initialization
Definition: MethodRXGB.cxx:166
Bool_t HasAnalysisType(Types::EAnalysisType type, UInt_t numberClasses, UInt_t numberTargets)
Definition: MethodRXGB.cxx:96
virtual void MakeClass(const TString &classFileName=TString("")) const
create reader class for method (classification only at present)
Definition: MethodRXGB.cxx:281
void ReadModelFromFile()
Definition: MethodRXGB.cxx:267
ROOT::R::TRFunctionImport xgbsave
Definition: MethodRXGB.h:99
ROOT::R::TRObject * fModel
Definition: MethodRXGB.h:103
ROOT::R::TRFunctionImport xgbdmatrix
Definition: MethodRXGB.h:98
ROOT::R::TRFunctionImport predict
Definition: MethodRXGB.h:96
Double_t GetMvaValue(Double_t *errLower=0, Double_t *errUpper=0)
Definition: MethodRXGB.cxx:174
std::vector< std::string > fFactorTrain
Definition: RMethodBase.h:95
ROOT::R::TRDataFrame fDfTrain
Definition: RMethodBase.h:91
TVectorD fWeightTrain
Definition: RMethodBase.h:93
Timing information for training and evaluation of MVA methods.
Definition: Timer.h:58
TString GetElapsedTime(Bool_t Scientific=kTRUE)
returns pretty string with elapsed time
Definition: Timer.cxx:146
const TString & Color(const TString &)
human readable color strings
Definition: Tools.cxx:840
Singleton class for Global types used by TMVA.
Definition: Types.h:73
EAnalysisType
Definition: Types.h:128
@ kClassification
Definition: Types.h:129
@ kTraining
Definition: Types.h:145
@ kINFO
Definition: Types.h:60
@ kFATAL
Definition: Types.h:63
virtual void Error(const char *method, const char *msgfmt,...) const
Issue error message.
Definition: TObject.cxx:893
Basic string class.
Definition: TString.h:136
def predict(model, test_X, batch_size=100)
const Rcpp::internal::NamedPlaceHolder & Label
create variable transformations
Tools & gTools()
MsgLogger & Endl(MsgLogger &ml)
Definition: MsgLogger.h:158
Double_t Log(Double_t x)
Definition: TMath.h:760