Logo ROOT  
Reference Guide
 
Loading...
Searching...
No Matches
MethodRXGB.cxx
Go to the documentation of this file.
1// @(#)root/tmva/rmva $Id$
2// Author: Omar Zapata,Lorenzo Moneta, Sergei Gleyzer 2015
3
4/**********************************************************************************
5 * Project: TMVA - a Root-integrated toolkit for multivariate data analysis *
6 * Package: TMVA *
7 * Class : MethodRXGB *
8 * Web : http://oproject.org *
9 * *
10 * Description: *
11 * R eXtreme Gradient Boosting *
12 * *
13 * *
14 * Redistribution and use in source and binary forms, with or without *
15 * modification, are permitted according to the terms listed in LICENSE *
16 * (http://tmva.sourceforge.net/LICENSE) *
17 * *
18 **********************************************************************************/
19
20#include <iomanip>
21
22#include "TMath.h"
23#include "Riostream.h"
24#include "TMatrix.h"
25#include "TMatrixD.h"
26#include "TVectorD.h"
27
29#include "TMVA/MethodRXGB.h"
30#include "TMVA/Tools.h"
31#include "TMVA/Config.h"
32#include "TMVA/Ranking.h"
33#include "TMVA/Types.h"
34#include "TMVA/PDF.h"
36
37#include "TMVA/Results.h"
38#include "TMVA/Timer.h"
39
40using namespace TMVA;
41
43
45
46//creating an Instance
48
49//_______________________________________________________________________
51 const TString &methodTitle,
52 DataSetInfo &dsi,
53 const TString &theOption) : RMethodBase(jobName, Types::kRXGB, methodTitle, dsi, theOption),
54 fNRounds(10),
55 fEta(0.3),
56 fMaxDepth(6),
57 predict("predict", "xgboost"),
58 xgbtrain("xgboost"),
59 xgbdmatrix("xgb.DMatrix"),
60 xgbsave("xgb.save"),
61 xgbload("xgb.load"),
62 asfactor("as.factor"),
63 asmatrix("as.matrix"),
64 fModel(NULL)
65{
66 // standard constructor for the RXGB
67
68}
69
70//_______________________________________________________________________
/// Constructor taking a data-set description and a weight-file name.
///
/// Forwards both arguments to RMethodBase together with Types::kRXGB, sets the
/// hyper-parameters to their defaults (fNRounds = 10, fEta = 0.3, fMaxDepth = 6)
/// and creates the R function imports by name. No model exists yet, so fModel
/// starts out null.
///
/// \param theData       data-set description passed on to RMethodBase
/// \param theWeightFile weight-file name passed on to RMethodBase
71MethodRXGB::MethodRXGB(DataSetInfo &theData, const TString &theWeightFile)
72 : RMethodBase(Types::kRXGB, theData, theWeightFile),
73 fNRounds(10),
74 fEta(0.3),
75 fMaxDepth(6),
76 predict("predict", "xgboost"),
77 xgbtrain("xgboost"),
78 xgbdmatrix("xgb.DMatrix"),
79 xgbsave("xgb.save"),
80 xgbload("xgb.load"),
81 asfactor("as.factor"),
82 asmatrix("as.matrix"),
83 fModel(nullptr) // typed null pointer instead of the NULL macro
84{
85 // nothing to do here: all state is set up in the initializer list
86}
87
88
89//_______________________________________________________________________
91{
 // Release the model wrapper; it is allocated with `new` wherever this file
 // assigns fModel, and it is initialised to null by the constructors.
92 if (fModel) delete fModel;
93}
94
95//_______________________________________________________________________
97{
 // Only classification with exactly two classes is reported as supported;
 // every other analysis type or class count yields kFALSE.
98 if (type == Types::kClassification && numberClasses == 2) return kTRUE;
99 return kFALSE;
100}
101
102
103//_______________________________________________________________________
105{
106
 // Without the R package there is nothing this method can do: report the
 // problem through both Error() and the kFATAL log stream, then bail out.
107 if (!IsModuleLoaded) {
108 Error("Init", "R's package xgboost can not be loaded.");
109 Log() << kFATAL << " R's package xgboost can not be loaded."
110 << Endl;
111 return;
112 }
113 // build the numeric class labels
114 // xgboost requires numeric labels, so fFactorTrain is mapped to background=0, signal=1
115 UInt_t size = fFactorTrain.size();
116 fFactorNumeric.resize(size);
117
 // Any entry that is not exactly "signal" is treated as background (0).
118 for (UInt_t i = 0; i < size; i++) {
119 if (fFactorTrain[i] == "signal") fFactorNumeric[i] = 1;
120 else fFactorNumeric[i] = 0;
121 }
122
123
124
125}
126
128{
 // Training on an empty sample is reported on the kFATAL log stream.
129 if (Data()->GetNTrainingEvents() == 0) Log() << kFATAL << "<Train> Data() has zero events" << Endl;
 // Hyper-parameters handed to xgboost through its `params` argument.
132 params["eta"] = fEta;
133 params["max.depth"] = fMaxDepth;
134
 // Run the boosting in R; the result is kept as a raw R object (SEXP).
135 SEXP Model = xgbtrain(ROOT::R::Label["data"] = dmatrix,
137 ROOT::R::Label["weight"] = fWeightTrain,
138 ROOT::R::Label["nrounds"] = fNRounds,
139 ROOT::R::Label["params"] = params);
140
 // Wrap the new model first and only then release a model left over from an
 // earlier call, so the old wrapper is not leaked and fModel never dangles.
 ROOT::R::TRObject *trained = new ROOT::R::TRObject(Model);
 delete fModel;
141 fModel = trained;
 // When persistence is enabled, store the model next to the other weight
 // files so that it can be loaded again later.
142 if (IsModelPersistence())
143 {
144 TString path = GetWeightFileDir() + "/" + GetName() + ".RData";
145 Log() << Endl;
146 Log() << gTools().Color("bold") << "--- Saving State File In:" << gTools().Color("reset") << path << Endl;
147 Log() << Endl;
148 xgbsave(Model, path);
149 }
150}
151
152//_______________________________________________________________________
154{
 // Register the three tunable settings as options. Each option is bound by
 // reference to the member that the training step passes on to xgboost.
155 DeclareOptionRef(fNRounds, "NRounds", "The max number of iterations");
156 DeclareOptionRef(fEta, "Eta", "Step size shrinkage used in update to prevents overfitting. After each boosting step, we can directly get the weights of new features. and eta actually shrinks the feature weights to make the boosting process more conservative.");
157 DeclareOptionRef(fMaxDepth, "MaxDepth", "Maximum depth of the tree");
158}
159
160//_______________________________________________________________________
162{
163}
164
165//_______________________________________________________________________
167{
 // Announce the start of the classification test in the log.
168 Log() << kINFO << "Testing Classification RXGB METHOD " << Endl;
170}
171
172
173//_______________________________________________________________________
175{
 // Computes the classifier response for the current event by sending a
 // one-row data frame to R and calling predict() on the stored model.
 // NOTE(review): NoErrorCalc presumably flags the error estimates as
 // unavailable - confirm against MethodBase.
176 NoErrorCalc(errLower, errUpper);
177 Double_t mvaValue;
178 const TMVA::Event *ev = GetEvent();
179 const UInt_t nvar = DataInfo().GetNVariables();
 // One column per input variable, keyed by the variable's name.
 // NOTE(review): GetListOfVariables() returns the list by value, so it is
 // rebuilt on every loop iteration; hoisting it out of the loop would avoid that.
180 ROOT::R::TRDataFrame fDfEvent;
181 for (UInt_t i = 0; i < nvar; i++) {
182 fDfEvent[DataInfo().GetListOfVariables()[i].Data()] = ev->GetValues()[i];
183 }
184 //if using persistence model
186
 // fModel is dereferenced unconditionally, so a model must be present here.
187 mvaValue = (Double_t)predict(*fModel, xgbdmatrix(ROOT::R::Label["data"] = asmatrix(fDfEvent)));
188 return mvaValue;
189}
190
191////////////////////////////////////////////////////////////////////////////////
192/// get all the MVA values for the events of the current Data type
///
/// The events in [firstEvt, lastEvt) are copied into one column per input
/// variable, passed to R as a single data frame and evaluated in one call.
///
/// \param firstEvt    index of the first event; negative values are clamped to 0
/// \param lastEvt     one past the last event; it is replaced by the total number
///                    of events when it is smaller than firstEvt or larger than
///                    that total
/// \param logProgress when true, log a start message and the elapsed time
/// \return the values produced by predict(), converted to std::vector<Double_t>
193std::vector<Double_t> MethodRXGB::GetMvaValues(Long64_t firstEvt, Long64_t lastEvt, Bool_t logProgress)
194{
195 Long64_t nEvents = Data()->GetNEvents();
196 if (firstEvt > lastEvt || lastEvt > nEvents) lastEvt = nEvents;
197 if (firstEvt < 0) firstEvt = 0;
198
 // From here on nEvents is the size of the requested range, not of the sample.
199 nEvents = lastEvt-firstEvt;
200
201 UInt_t nvars = Data()->GetNVariables();
202
203 // use timer
204 Timer timer( nEvents, GetName(), kTRUE );
205 if (logProgress)
206 Log() << kINFO<<Form("Dataset[%s] : ",DataInfo().GetName())<< "Evaluation of " << GetMethodName() << " on "
207 << (Data()->GetCurrentType()==Types::kTraining?"training":"testing") << " sample (" << nEvents << " events)" << Endl;
208
209
210 // fill R DATA FRAME with events data
 // Column-major staging buffer: inputData[variable][event within the range].
211 std::vector<std::vector<Float_t> > inputData(nvars);
212 for (UInt_t i = 0; i < nvars; i++) {
213 inputData[i] = std::vector<Float_t>(nEvents);
214 }
215
 // The counter has the same 64-bit type as the bounds so that large event
 // indices are not truncated.
216 for (Long64_t ievt=firstEvt; ievt<lastEvt; ievt++) {
217 Data()->SetCurrentEvent(ievt);
218 const TMVA::Event *e = Data()->GetEvent();
219 assert(nvars == e->GetNVariables());
220 for (UInt_t i = 0; i < nvars; i++) {
 // Each column holds only nEvents = lastEvt-firstEvt entries, so the slot is
 // the position inside the range; using ievt directly would write past the
 // end of the buffer whenever firstEvt > 0.
221 inputData[i][ievt - firstEvt] = e->GetValue(i);
222 }
223 // if (ievt%100 == 0)
224 // std::cout << "Event " << ievt << " type" << DataInfo().IsSignal(e) << " : " << pValue[ievt*nvars] << " " << pValue[ievt*nvars+1] << " " << pValue[ievt*nvars+2] << std::endl;
225 }
226
 // Hand the columns to R, keyed by variable name.
227 ROOT::R::TRDataFrame evtData;
228 for (UInt_t i = 0; i < nvars; i++) {
229 evtData[DataInfo().GetListOfVariables()[i].Data()] = inputData[i];
230 }
231 //if using persistence model
233
234 std::vector<Double_t> mvaValues(nEvents);
236 mvaValues = pred.As<std::vector<Double_t>>();
237
238 if (logProgress) {
239 Log() << kINFO <<Form("Dataset[%s] : ",DataInfo().GetName())<< "Elapsed time for evaluation of " << nEvents << " events: "
240 << timer.GetElapsedTime() << " " << Endl;
241 }
242
243 return mvaValues;
244
245}
246//_______________________________________________________________________
248{
249// get help message text
250//
251// typical length of text line:
252// "|--------------------------------------------------------------|"
 // The message has three bold section headings. Only the first and the last
 // carry text; the "Performance optimisation" section is left empty.
253 Log() << Endl;
254 Log() << gTools().Color("bold") << "--- Short description:" << gTools().Color("reset") << Endl;
255 Log() << Endl;
256 Log() << "Decision Trees and Rule-Based Models " << Endl;
257 Log() << Endl;
258 Log() << gTools().Color("bold") << "--- Performance optimisation:" << gTools().Color("reset") << Endl;
259 Log() << Endl;
260 Log() << Endl;
261 Log() << gTools().Color("bold") << "--- Performance tuning via configuration options:" << gTools().Color("reset") << Endl;
262 Log() << Endl;
263 Log() << "<None>" << Endl;
264}
265
266//_______________________________________________________________________
268{
 // The path is built the same way as in the training step, which saves the
 // model to "<weight dir>/<method name>.RData" when persistence is enabled.
270 TString path = GetWeightFileDir() + "/" + GetName() + ".RData";
271 Log() << Endl;
272 Log() << gTools().Color("bold") << "--- Loading State File From:" << gTools().Color("reset") << path << Endl;
273 Log() << Endl;
274
275 SEXP Model = xgbload(path);
 // Wrap the freshly loaded model first and only then release the previous
 // wrapper: every call used to overwrite fModel without freeing it, which
 // leaked one TRObject per reload.
 ROOT::R::TRObject *loaded = new ROOT::R::TRObject(Model);
 delete fModel;
276 fModel = loaded;
277
278}
279
280//_______________________________________________________________________
281void TMVA::MethodRXGB::MakeClass(const TString &/*theClassFileName*/) const
282{
 // Intentionally empty: nothing is written for this method, and the file-name
 // argument is left unnamed because it is not used.
283}
#define REGISTER_METHOD(CLASS)
for example
#define e(i)
Definition RSha256.hxx:103
size_t size(const MatrixT &matrix)
retrieve the size of a square matrix
const Bool_t kFALSE
Definition RtypesCore.h:101
bool Bool_t
Definition RtypesCore.h:63
long long Long64_t
Definition RtypesCore.h:80
const Bool_t kTRUE
Definition RtypesCore.h:100
#define ClassImp(name)
Definition Rtypes.h:364
void Error(const char *location, const char *msgfmt,...)
Use this function in case an error occurred.
Definition TError.cxx:187
int type
Definition TGX11.cxx:121
char * Form(const char *fmt,...)
This is a class to create DataFrames from ROOT to R.
static TRInterface & Instance()
static method to get a TRInterface instance reference
Bool_t Require(TString pkg)
Method to load an R package.
This is a class to get ROOT's objects from R's objects.
Definition TRObject.h:70
T As()
Some datatypes of ROOT or C++ can be wrapped into a TRObject; this method lets you unwrap those data...
Definition TRObject.h:152
OptionBase * DeclareOptionRef(T &ref, const TString &name, const TString &desc="")
MsgLogger & Log() const
Class that contains all the data information.
Definition DataSetInfo.h:62
UInt_t GetNVariables() const
std::vector< TString > GetListOfVariables() const
returns list of variables
const Event * GetEvent() const
Definition DataSet.cxx:202
Types::ETreeType GetCurrentType() const
Definition DataSet.h:194
Long64_t GetNEvents(Types::ETreeType type=Types::kMaxTreeType) const
Definition DataSet.h:206
UInt_t GetNVariables() const
access the number of variables through the datasetinfo
Definition DataSet.cxx:216
void SetCurrentEvent(Long64_t ievt) const
Definition DataSet.h:88
std::vector< Float_t > & GetValues()
Definition Event.h:94
const char * GetName() const
Definition MethodBase.h:334
Bool_t IsModelPersistence() const
Definition MethodBase.h:383
const TString & GetWeightFileDir() const
Definition MethodBase.h:492
const TString & GetMethodName() const
Definition MethodBase.h:331
const Event * GetEvent() const
Definition MethodBase.h:751
DataSetInfo & DataInfo() const
Definition MethodBase.h:410
virtual void TestClassification()
initialization
void ReadStateFromFile()
Function to read options and weights from file.
void NoErrorCalc(Double_t *const err, Double_t *const errUpper)
DataSet * Data() const
Definition MethodBase.h:409
std::vector< UInt_t > fFactorNumeric
Definition MethodRXGB.h:93
virtual std::vector< Double_t > GetMvaValues(Long64_t firstEvt=0, Long64_t lastEvt=-1, Bool_t logProgress=false)
get all the MVA values for the events of the current Data type
void GetHelpMessage() const
ROOT::R::TRFunctionImport xgbtrain
Definition MethodRXGB.h:97
MethodRXGB(const TString &jobName, const TString &methodTitle, DataSetInfo &theData, const TString &theOption="")
static Bool_t IsModuleLoaded
Definition MethodRXGB.h:91
ROOT::R::TRFunctionImport asmatrix
Definition MethodRXGB.h:102
virtual void TestClassification()
initialization
Bool_t HasAnalysisType(Types::EAnalysisType type, UInt_t numberClasses, UInt_t numberTargets)
virtual void MakeClass(const TString &classFileName=TString("")) const
create reader class for method (classification only at present)
ROOT::R::TRFunctionImport xgbsave
Definition MethodRXGB.h:99
ROOT::R::TRObject * fModel
Definition MethodRXGB.h:103
ROOT::R::TRFunctionImport xgbdmatrix
Definition MethodRXGB.h:98
ROOT::R::TRFunctionImport predict
Definition MethodRXGB.h:96
Double_t GetMvaValue(Double_t *errLower=0, Double_t *errUpper=0)
std::vector< std::string > fFactorTrain
Definition RMethodBase.h:95
ROOT::R::TRDataFrame fDfTrain
Definition RMethodBase.h:91
TVectorD fWeightTrain
Definition RMethodBase.h:93
Timing information for training and evaluation of MVA methods.
Definition Timer.h:58
TString GetElapsedTime(Bool_t Scientific=kTRUE)
returns pretty string with elapsed time
Definition Timer.cxx:146
const TString & Color(const TString &)
human readable color strings
Definition Tools.cxx:828
Singleton class for Global types used by TMVA.
Definition Types.h:71
@ kClassification
Definition Types.h:127
@ kTraining
Definition Types.h:143
Basic string class.
Definition TString.h:136
const Rcpp::internal::NamedPlaceHolder & Label
create variable transformations
Tools & gTools()
MsgLogger & Endl(MsgLogger &ml)
Definition MsgLogger.h:148