Logo ROOT   6.18/05
Reference Guide
MethodRXGB.cxx
Go to the documentation of this file.
1// @(#)root/tmva/rmva $Id$
2// Author: Omar Zapata,Lorenzo Moneta, Sergei Gleyzer 2015
3
4/**********************************************************************************
5 * Project: TMVA - a Root-integrated toolkit for multivariate data analysis *
6 * Package: TMVA *
7 * Class : MethodRXGB *
8 * Web : http://oproject.org *
9 * *
10 * Description: *
11 * R eXtreme Gradient Boosting *
12 * *
13 * *
14 * Redistribution and use in source and binary forms, with or without *
15 * modification, are permitted according to the terms listed in LICENSE *
16 * (http://tmva.sourceforge.net/LICENSE) *
17 * *
18 **********************************************************************************/
19
20#include <iomanip>
21
22#include "TMath.h"
23#include "Riostream.h"
24#include "TMatrix.h"
25#include "TMatrixD.h"
26#include "TVectorD.h"
27
29#include "TMVA/MethodRXGB.h"
30#include "TMVA/Tools.h"
31#include "TMVA/Config.h"
32#include "TMVA/Ranking.h"
33#include "TMVA/Types.h"
34#include "TMVA/PDF.h"
36
37#include "TMVA/Results.h"
38#include "TMVA/Timer.h"
39
40using namespace TMVA;
41
43
45
46//creating an Instance
47Bool_t MethodRXGB::IsModuleLoaded = ROOT::R::TRInterface::Instance().Require("xgboost");
48
49//_______________________________________________________________________
50MethodRXGB::MethodRXGB(const TString &jobName,
51 const TString &methodTitle,
52 DataSetInfo &dsi,
53 const TString &theOption) : RMethodBase(jobName, Types::kRXGB, methodTitle, dsi, theOption),
54 fNRounds(10),
55 fEta(0.3),
56 fMaxDepth(6),
57 predict("predict", "xgboost"),
58 xgbtrain("xgboost"),
59 xgbdmatrix("xgb.DMatrix"),
60 xgbsave("xgb.save"),
61 xgbload("xgb.load"),
62 asfactor("as.factor"),
63 asmatrix("as.matrix"),
64 fModel(NULL)
65{
66 // standard constructor for the RXGB
67
68}
69
70//_______________________________________________________________________
71MethodRXGB::MethodRXGB(DataSetInfo &theData, const TString &theWeightFile)
72 : RMethodBase(Types::kRXGB, theData, theWeightFile),
73 fNRounds(10),
74 fEta(0.3),
75 fMaxDepth(6),
76 predict("predict", "xgboost"),
77 xgbtrain("xgboost"),
78 xgbdmatrix("xgb.DMatrix"),
79 xgbsave("xgb.save"),
80 xgbload("xgb.load"),
81 asfactor("as.factor"),
82 asmatrix("as.matrix"),
83 fModel(NULL)
84{
85
86}
87
88
89//_______________________________________________________________________
91{
92 if (fModel) delete fModel;
93}
94
95//_______________________________________________________________________
97{
98 if (type == Types::kClassification && numberClasses == 2) return kTRUE;
99 return kFALSE;
100}
101
102
103//_______________________________________________________________________
105{
106
107 if (!IsModuleLoaded) {
108 Error("Init", "R's package xgboost can not be loaded.");
109 Log() << kFATAL << " R's package xgboost can not be loaded."
110 << Endl;
111 return;
112 }
113 //factors creations
114 //xgboost require a numeric factor then background=0 signal=1 from fFactorTrain
115 UInt_t size = fFactorTrain.size();
116 fFactorNumeric.resize(size);
117
118 for (UInt_t i = 0; i < size; i++) {
119 if (fFactorTrain[i] == "signal") fFactorNumeric[i] = 1;
120 else fFactorNumeric[i] = 0;
121 }
122
123
124
125}
126
128{
129 if (Data()->GetNTrainingEvents() == 0) Log() << kFATAL << "<Train> Data() has zero events" << Endl;
132 params["eta"] = fEta;
133 params["max.depth"] = fMaxDepth;
134
135 SEXP Model = xgbtrain(ROOT::R::Label["data"] = dmatrix,
137 ROOT::R::Label["weight"] = fWeightTrain,
138 ROOT::R::Label["nrounds"] = fNRounds,
139 ROOT::R::Label["params"] = params);
140
141 fModel = new ROOT::R::TRObject(Model);
142 if (IsModelPersistence())
143 {
144 TString path = GetWeightFileDir() + "/" + GetName() + ".RData";
145 Log() << Endl;
146 Log() << gTools().Color("bold") << "--- Saving State File In:" << gTools().Color("reset") << path << Endl;
147 Log() << Endl;
148 xgbsave(Model, path);
149 }
150}
151
152//_______________________________________________________________________
154{
155 DeclareOptionRef(fNRounds, "NRounds", "The max number of iterations");
156 DeclareOptionRef(fEta, "Eta", "Step size shrinkage used in update to prevents overfitting. After each boosting step, we can directly get the weights of new features. and eta actually shrinks the feature weights to make the boosting process more conservative.");
157 DeclareOptionRef(fMaxDepth, "MaxDepth", "Maximum depth of the tree");
158}
159
160//_______________________________________________________________________
162{
163}
164
165//_______________________________________________________________________
167{
168 Log() << kINFO << "Testing Classification RXGB METHOD " << Endl;
170}
171
172
173//_______________________________________________________________________
175{
176 NoErrorCalc(errLower, errUpper);
177 Double_t mvaValue;
178 const TMVA::Event *ev = GetEvent();
179 const UInt_t nvar = DataInfo().GetNVariables();
180 ROOT::R::TRDataFrame fDfEvent;
181 for (UInt_t i = 0; i < nvar; i++) {
182 fDfEvent[DataInfo().GetListOfVariables()[i].Data()] = ev->GetValues()[i];
183 }
184 //if using persistence model
186
187 mvaValue = (Double_t)predict(*fModel, xgbdmatrix(ROOT::R::Label["data"] = asmatrix(fDfEvent)));
188 return mvaValue;
189}
190
191////////////////////////////////////////////////////////////////////////////////
192/// get all the MVA values for the events of the current Data type
193std::vector<Double_t> MethodRXGB::GetMvaValues(Long64_t firstEvt, Long64_t lastEvt, Bool_t logProgress)
194{
195 Long64_t nEvents = Data()->GetNEvents();
196 if (firstEvt > lastEvt || lastEvt > nEvents) lastEvt = nEvents;
197 if (firstEvt < 0) firstEvt = 0;
198
199 nEvents = lastEvt-firstEvt;
200
201 UInt_t nvars = Data()->GetNVariables();
202
203 // use timer
204 Timer timer( nEvents, GetName(), kTRUE );
205 if (logProgress)
206 Log() << kINFO<<Form("Dataset[%s] : ",DataInfo().GetName())<< "Evaluation of " << GetMethodName() << " on "
207 << (Data()->GetCurrentType()==Types::kTraining?"training":"testing") << " sample (" << nEvents << " events)" << Endl;
208
209
210 // fill R DATA FRAME with events data
211 std::vector<std::vector<Float_t> > inputData(nvars);
212 for (UInt_t i = 0; i < nvars; i++) {
213 inputData[i] = std::vector<Float_t>(nEvents);
214 }
215
216 for (Int_t ievt=firstEvt; ievt<lastEvt; ievt++) {
217 Data()->SetCurrentEvent(ievt);
218 const TMVA::Event *e = Data()->GetEvent();
219 assert(nvars == e->GetNVariables());
220 for (UInt_t i = 0; i < nvars; i++) {
221 inputData[i][ievt] = e->GetValue(i);
222 }
223 // if (ievt%100 == 0)
224 // std::cout << "Event " << ievt << " type" << DataInfo().IsSignal(e) << " : " << pValue[ievt*nvars] << " " << pValue[ievt*nvars+1] << " " << pValue[ievt*nvars+2] << std::endl;
225 }
226
227 ROOT::R::TRDataFrame evtData;
228 for (UInt_t i = 0; i < nvars; i++) {
229 evtData[DataInfo().GetListOfVariables()[i].Data()] = inputData[i];
230 }
231 //if using persistence model
233
234 std::vector<Double_t> mvaValues(nEvents);
236 mvaValues = pred.As<std::vector<Double_t>>();
237
238 if (logProgress) {
239 Log() << kINFO <<Form("Dataset[%s] : ",DataInfo().GetName())<< "Elapsed time for evaluation of " << nEvents << " events: "
240 << timer.GetElapsedTime() << " " << Endl;
241 }
242
243 return mvaValues;
244
245}
246//_______________________________________________________________________
248{
249// get help message text
250//
251// typical length of text line:
252// "|--------------------------------------------------------------|"
253 Log() << Endl;
254 Log() << gTools().Color("bold") << "--- Short description:" << gTools().Color("reset") << Endl;
255 Log() << Endl;
256 Log() << "Decision Trees and Rule-Based Models " << Endl;
257 Log() << Endl;
258 Log() << gTools().Color("bold") << "--- Performance optimisation:" << gTools().Color("reset") << Endl;
259 Log() << Endl;
260 Log() << Endl;
261 Log() << gTools().Color("bold") << "--- Performance tuning via configuration options:" << gTools().Color("reset") << Endl;
262 Log() << Endl;
263 Log() << "<None>" << Endl;
264}
265
266//_______________________________________________________________________
268{
270 TString path = GetWeightFileDir() + "/" + GetName() + ".RData";
271 Log() << Endl;
272 Log() << gTools().Color("bold") << "--- Loading State File From:" << gTools().Color("reset") << path << Endl;
273 Log() << Endl;
274
275 SEXP Model = xgbload(path);
276 fModel = new ROOT::R::TRObject(Model);
277
278}
279
280//_______________________________________________________________________
281void TMVA::MethodRXGB::MakeClass(const TString &/*theClassFileName*/) const
282{
283}
#define REGISTER_METHOD(CLASS)
for example
#define e(i)
Definition: RSha256.hxx:103
int Int_t
Definition: RtypesCore.h:41
unsigned int UInt_t
Definition: RtypesCore.h:42
const Bool_t kFALSE
Definition: RtypesCore.h:88
bool Bool_t
Definition: RtypesCore.h:59
double Double_t
Definition: RtypesCore.h:55
long long Long64_t
Definition: RtypesCore.h:69
const Bool_t kTRUE
Definition: RtypesCore.h:87
#define ClassImp(name)
Definition: Rtypes.h:365
int type
Definition: TGX11.cxx:120
char * Form(const char *fmt,...)
This is a class to create DataFrames from ROOT to R.
Definition: TRDataFrame.h:177
static TRInterface & Instance()
static method to get an TRInterface instance reference
Bool_t Require(TString pkg)
Method to load an R package.
This is a class to get ROOT's objects from R's objects.
Definition: TRObject.h:71
T As()
Some ROOT or C++ data types can be wrapped into a TRObject; this method lets you unwrap that data...
Definition: TRObject.h:153
OptionBase * DeclareOptionRef(T &ref, const TString &name, const TString &desc="")
MsgLogger & Log() const
Definition: Configurable.h:122
Class that contains all the data information.
Definition: DataSetInfo.h:60
UInt_t GetNVariables() const
Definition: DataSetInfo.h:110
std::vector< TString > GetListOfVariables() const
returns list of variables
const Event * GetEvent() const
Definition: DataSet.cxx:202
Types::ETreeType GetCurrentType() const
Definition: DataSet.h:205
Long64_t GetNEvents(Types::ETreeType type=Types::kMaxTreeType) const
Definition: DataSet.h:217
UInt_t GetNVariables() const
access the number of variables through the datasetinfo
Definition: DataSet.cxx:216
void SetCurrentEvent(Long64_t ievt) const
Definition: DataSet.h:99
std::vector< Float_t > & GetValues()
Definition: Event.h:95
const char * GetName() const
Definition: MethodBase.h:325
const TString & GetWeightFileDir() const
Definition: MethodBase.h:481
const TString & GetMethodName() const
Definition: MethodBase.h:322
const Event * GetEvent() const
Definition: MethodBase.h:740
DataSetInfo & DataInfo() const
Definition: MethodBase.h:401
virtual void TestClassification()
initialization
void ReadStateFromFile()
Function to read options and weights from file.
Bool_t IsModelPersistence()
Definition: MethodBase.h:374
void NoErrorCalc(Double_t *const err, Double_t *const errUpper)
Definition: MethodBase.cxx:841
DataSet * Data() const
Definition: MethodBase.h:400
Double_t fEta
Definition: MethodRXGB.h:88
std::vector< UInt_t > fFactorNumeric
Definition: MethodRXGB.h:92
virtual std::vector< Double_t > GetMvaValues(Long64_t firstEvt=0, Long64_t lastEvt=-1, Bool_t logProgress=false)
get all the MVA values for the events of the current Data type
Definition: MethodRXGB.cxx:193
UInt_t fMaxDepth
Definition: MethodRXGB.h:89
void GetHelpMessage() const
Definition: MethodRXGB.cxx:247
ROOT::R::TRFunctionImport xgbtrain
Definition: MethodRXGB.h:96
MethodRXGB(const TString &jobName, const TString &methodTitle, DataSetInfo &theData, const TString &theOption="")
Definition: MethodRXGB.cxx:50
static Bool_t IsModuleLoaded
Definition: MethodRXGB.h:90
ROOT::R::TRFunctionImport asmatrix
Definition: MethodRXGB.h:101
virtual void TestClassification()
initialization
Definition: MethodRXGB.cxx:166
Bool_t HasAnalysisType(Types::EAnalysisType type, UInt_t numberClasses, UInt_t numberTargets)
Definition: MethodRXGB.cxx:96
virtual void MakeClass(const TString &classFileName=TString("")) const
create reader class for method (classification only at present)
Definition: MethodRXGB.cxx:281
void ReadModelFromFile()
Definition: MethodRXGB.cxx:267
ROOT::R::TRFunctionImport xgbsave
Definition: MethodRXGB.h:98
ROOT::R::TRObject * fModel
Definition: MethodRXGB.h:102
ROOT::R::TRFunctionImport xgbdmatrix
Definition: MethodRXGB.h:97
ROOT::R::TRFunctionImport predict
Definition: MethodRXGB.h:95
Double_t GetMvaValue(Double_t *errLower=0, Double_t *errUpper=0)
Definition: MethodRXGB.cxx:174
std::vector< std::string > fFactorTrain
Definition: RMethodBase.h:92
ROOT::R::TRDataFrame fDfTrain
Definition: RMethodBase.h:88
TVectorD fWeightTrain
Definition: RMethodBase.h:90
Timing information for training and evaluation of MVA methods.
Definition: Timer.h:58
TString GetElapsedTime(Bool_t Scientific=kTRUE)
returns pretty string with elapsed time
Definition: Timer.cxx:134
const TString & Color(const TString &)
human readable color strings
Definition: Tools.cxx:840
Singleton class for Global types used by TMVA.
Definition: Types.h:73
EAnalysisType
Definition: Types.h:127
@ kClassification
Definition: Types.h:128
@ kTraining
Definition: Types.h:144
virtual void Error(const char *method, const char *msgfmt,...) const
Issue error message.
Definition: TObject.cxx:880
Basic string class.
Definition: TString.h:131
std::string GetName(const std::string &scope_name)
Definition: Cppyy.cxx:146
Rcpp::internal::NamedPlaceHolder Label
Definition: RExports.cxx:14
create variable transformations
Tools & gTools()
MsgLogger & Endl(MsgLogger &ml)
Definition: MsgLogger.h:158
Double_t Log(Double_t x)
Definition: TMath.h:748