Logo ROOT  
Reference Guide
Loading...
Searching...
No Matches
MethodRXGB.cxx
Go to the documentation of this file.
1// @(#)root/tmva/rmva $Id$
2// Author: Omar Zapata,Lorenzo Moneta, Sergei Gleyzer 2015
3
4/**********************************************************************************
5 * Project: TMVA - a Root-integrated toolkit for multivariate data analysis *
6 * Package: TMVA *
7 * Class : MethodRXGB *
8 * Web : http://oproject.org *
9 * *
10 * Description: *
11 * R eXtreme Gradient Boosting *
12 * *
13 * *
14 * Redistribution and use in source and binary forms, with or without *
15 * modification, are permitted according to the terms listed in LICENSE *
16 * (see tmva/doc/LICENSE) *
17 * *
18 **********************************************************************************/
19
20#include <iomanip>
21
22#include "TMath.h"
23#include "Riostream.h"
24#include "TMatrix.h"
25#include "TMatrixD.h"
26#include "TVectorD.h"
27
29#include "TMVA/MethodRXGB.h"
30#include "TMVA/Tools.h"
31#include "TMVA/Config.h"
32#include "TMVA/Ranking.h"
33#include "TMVA/Types.h"
34#include "TMVA/PDF.h"
36
37#include "TMVA/Results.h"
38#include "TMVA/Timer.h"
39
40using namespace TMVA;
41
43
44
45//creating an Instance
47
48//_______________________________________________________________________
50 const TString &methodTitle,
51 DataSetInfo &dsi,
52 const TString &theOption) : RMethodBase(jobName, Types::kRXGB, methodTitle, dsi, theOption),
53 fNRounds(10),
54 fEta(0.3),
55 fMaxDepth(6),
56 predict("predict", "xgboost"),
57 xgbtrain("xgboost"),
58 xgbdmatrix("xgb.DMatrix"),
59 xgbsave("xgb.save"),
60 xgbload("xgb.load"),
61 asfactor("as.factor"),
62 asmatrix("as.matrix"),
63 fModel(NULL)
64{
65 // standard constructor for the RXGB
66
67}
68
69//_______________________________________________________________________
70MethodRXGB::MethodRXGB(DataSetInfo &theData, const TString &theWeightFile)
71 : RMethodBase(Types::kRXGB, theData, theWeightFile),
72 fNRounds(10),
73 fEta(0.3),
74 fMaxDepth(6),
75 predict("predict", "xgboost"),
76 xgbtrain("xgboost"),
77 xgbdmatrix("xgb.DMatrix"),
78 xgbsave("xgb.save"),
79 xgbload("xgb.load"),
80 asfactor("as.factor"),
81 asmatrix("as.matrix"),
82 fModel(NULL)
83{
84
85}
86
87
88//_______________________________________________________________________
90{
91 if (fModel) delete fModel;
92}
93
94//_______________________________________________________________________
95Bool_t MethodRXGB::HasAnalysisType(Types::EAnalysisType type, UInt_t numberClasses, UInt_t /*numberTargets*/)
96{
97 if (type == Types::kClassification && numberClasses == 2) return kTRUE;
98 return kFALSE;
99}
100
101
102//_______________________________________________________________________
104{
105
106 if (!IsModuleLoaded) {
107 Error("Init", "R's package xgboost can not be loaded.");
108 Log() << kFATAL << " R's package xgboost can not be loaded."
109 << Endl;
110 return;
111 }
112 //factors creations
113 //xgboost require a numeric factor then background=0 signal=1 from fFactorTrain
114 UInt_t size = fFactorTrain.size();
115 fFactorNumeric.resize(size);
116
117 for (UInt_t i = 0; i < size; i++) {
118 if (fFactorTrain[i] == "signal") fFactorNumeric[i] = 1;
119 else fFactorNumeric[i] = 0;
120 }
121
122
123
124}
125
127{
128 if (Data()->GetNTrainingEvents() == 0) Log() << kFATAL << "<Train> Data() has zero events" << Endl;
131 params["eta"] = fEta;
132 params["max.depth"] = fMaxDepth;
133
134 SEXP Model = xgbtrain(ROOT::R::Label["data"] = dmatrix,
136 ROOT::R::Label["weight"] = fWeightTrain,
137 ROOT::R::Label["nrounds"] = fNRounds,
138 ROOT::R::Label["params"] = params);
139
140 fModel = new ROOT::R::TRObject(Model);
141 if (IsModelPersistence())
142 {
143 TString path = GetWeightFileDir() + "/" + GetName() + ".RData";
144 Log() << Endl;
145 Log() << gTools().Color("bold") << "--- Saving State File In:" << gTools().Color("reset") << path << Endl;
146 Log() << Endl;
147 xgbsave(Model, path);
148 }
149}
150
151//_______________________________________________________________________
153{
154 DeclareOptionRef(fNRounds, "NRounds", "The max number of iterations");
155 DeclareOptionRef(fEta, "Eta", "Step size shrinkage used in update to prevents overfitting. After each boosting step, we can directly get the weights of new features. and eta actually shrinks the feature weights to make the boosting process more conservative.");
156 DeclareOptionRef(fMaxDepth, "MaxDepth", "Maximum depth of the tree");
157}
158
159//_______________________________________________________________________
163
164//_______________________________________________________________________
166{
167 Log() << kINFO << "Testing Classification RXGB METHOD " << Endl;
169}
170
171
172//_______________________________________________________________________
174{
175 NoErrorCalc(errLower, errUpper);
176 Double_t mvaValue;
177 const TMVA::Event *ev = GetEvent();
178 const UInt_t nvar = DataInfo().GetNVariables();
179 ROOT::R::TRDataFrame fDfEvent;
180 for (UInt_t i = 0; i < nvar; i++) {
181 fDfEvent[DataInfo().GetListOfVariables()[i].Data()] = ev->GetValues()[i];
182 }
183 //if using persistence model
185
186 mvaValue = (Double_t)predict(*fModel, xgbdmatrix(ROOT::R::Label["data"] = asmatrix(fDfEvent)));
187 return mvaValue;
188}
189
190////////////////////////////////////////////////////////////////////////////////
191/// get all the MVA values for the events of the current Data type
192std::vector<Double_t> MethodRXGB::GetMvaValues(Long64_t firstEvt, Long64_t lastEvt, Bool_t logProgress)
193{
194 Long64_t nEvents = Data()->GetNEvents();
195 if (firstEvt > lastEvt || lastEvt > nEvents) lastEvt = nEvents;
196 if (firstEvt < 0) firstEvt = 0;
197
198 nEvents = lastEvt-firstEvt;
199
200 UInt_t nvars = Data()->GetNVariables();
201
202 // use timer
203 Timer timer( nEvents, GetName(), kTRUE );
204 if (logProgress)
205 Log() << kINFO<<Form("Dataset[%s] : ",DataInfo().GetName())<< "Evaluation of " << GetMethodName() << " on "
206 << (Data()->GetCurrentType()==Types::kTraining?"training":"testing") << " sample (" << nEvents << " events)" << Endl;
207
208
209 // fill R DATA FRAME with events data
210 std::vector<std::vector<Float_t> > inputData(nvars);
211 for (UInt_t i = 0; i < nvars; i++) {
212 inputData[i] = std::vector<Float_t>(nEvents);
213 }
214
215 for (Int_t ievt=firstEvt; ievt<lastEvt; ievt++) {
216 Data()->SetCurrentEvent(ievt);
217 const TMVA::Event *e = Data()->GetEvent();
218 assert(nvars == e->GetNVariables());
219 for (UInt_t i = 0; i < nvars; i++) {
220 inputData[i][ievt] = e->GetValue(i);
221 }
222 // if (ievt%100 == 0)
223 // std::cout << "Event " << ievt << " type" << DataInfo().IsSignal(e) << " : " << pValue[ievt*nvars] << " " << pValue[ievt*nvars+1] << " " << pValue[ievt*nvars+2] << std::endl;
224 }
225
226 ROOT::R::TRDataFrame evtData;
227 for (UInt_t i = 0; i < nvars; i++) {
228 evtData[DataInfo().GetListOfVariables()[i].Data()] = inputData[i];
229 }
230 //if using persistence model
232
233 std::vector<Double_t> mvaValues(nEvents);
235 mvaValues = pred.As<std::vector<Double_t>>();
236
237 if (logProgress) {
238 Log() << kINFO <<Form("Dataset[%s] : ",DataInfo().GetName())<< "Elapsed time for evaluation of " << nEvents << " events: "
239 << timer.GetElapsedTime() << " " << Endl;
240 }
241
242 return mvaValues;
243
244}
245//_______________________________________________________________________
247{
248// get help message text
249//
250// typical length of text line:
251// "|--------------------------------------------------------------|"
252 Log() << Endl;
253 Log() << gTools().Color("bold") << "--- Short description:" << gTools().Color("reset") << Endl;
254 Log() << Endl;
255 Log() << "Decision Trees and Rule-Based Models " << Endl;
256 Log() << Endl;
257 Log() << gTools().Color("bold") << "--- Performance optimisation:" << gTools().Color("reset") << Endl;
258 Log() << Endl;
259 Log() << Endl;
260 Log() << gTools().Color("bold") << "--- Performance tuning via configuration options:" << gTools().Color("reset") << Endl;
261 Log() << Endl;
262 Log() << "<None>" << Endl;
263}
264
265//_______________________________________________________________________
267{
269 TString path = GetWeightFileDir() + "/" + GetName() + ".RData";
270 Log() << Endl;
271 Log() << gTools().Color("bold") << "--- Loading State File From:" << gTools().Color("reset") << path << Endl;
272 Log() << Endl;
273
274 SEXP Model = xgbload(path);
275 fModel = new ROOT::R::TRObject(Model);
276
277}
278
279//_______________________________________________________________________
280void TMVA::MethodRXGB::MakeClass(const TString &/*theClassFileName*/) const
281{
282}
#define REGISTER_METHOD(CLASS)
for example
#define e(i)
Definition RSha256.hxx:103
size_t size(const MatrixT &matrix)
retrieve the size of a square matrix
int Int_t
Signed integer 4 bytes (int).
Definition RtypesCore.h:59
bool Bool_t
Boolean (0=false, 1=true) (bool).
Definition RtypesCore.h:77
constexpr Bool_t kFALSE
Definition RtypesCore.h:108
double Double_t
Double 8 bytes.
Definition RtypesCore.h:73
constexpr Bool_t kTRUE
Definition RtypesCore.h:107
Error("WriteTObject","The current directory (%s) is not associated with a file. The object (%s) has not been written.", GetName(), objname)
char * Form(const char *fmt,...)
Formats a string in a circular formatting buffer.
Definition TString.cxx:2496
This is a class to create DataFrames from ROOT to R.
static TRInterface & Instance()
static method to get an TRInterface instance reference
Bool_t Require(TString pkg)
Method to load an R package.
This is a class to get ROOT's objects from R's objects.
Definition TRObject.h:69
T As()
Some ROOT or C++ data types can be wrapped into a TRObject; this method lets you unwrap that data...
Definition TRObject.h:151
OptionBase * DeclareOptionRef(T &ref, const TString &name, const TString &desc="")
MsgLogger & Log() const
Class that contains all the data information.
Definition DataSetInfo.h:62
UInt_t GetNVariables() const
std::vector< TString > GetListOfVariables() const
returns list of variables
const Event * GetEvent() const
returns event without transformations
Definition DataSet.cxx:202
Types::ETreeType GetCurrentType() const
Definition DataSet.h:194
Long64_t GetNEvents(Types::ETreeType type=Types::kMaxTreeType) const
Definition DataSet.h:206
UInt_t GetNVariables() const
access the number of variables through the datasetinfo
Definition DataSet.cxx:216
void SetCurrentEvent(Long64_t ievt) const
Definition DataSet.h:88
std::vector< Float_t > & GetValues()
Definition Event.h:94
const char * GetName() const override
Definition MethodBase.h:337
Bool_t IsModelPersistence() const
Definition MethodBase.h:386
const TString & GetWeightFileDir() const
Definition MethodBase.h:495
const TString & GetMethodName() const
Definition MethodBase.h:334
const Event * GetEvent() const
Definition MethodBase.h:754
DataSetInfo & DataInfo() const
Definition MethodBase.h:413
virtual void TestClassification()
initialization
void ReadStateFromFile()
Function to read options and weights from file.
void NoErrorCalc(Double_t *const err, Double_t *const errUpper)
DataSet * Data() const
Definition MethodBase.h:412
std::vector< UInt_t > fFactorNumeric
Definition MethodRXGB.h:93
void Init() override
ROOT::R::TRFunctionImport xgbtrain
Definition MethodRXGB.h:97
MethodRXGB(const TString &jobName, const TString &methodTitle, DataSetInfo &theData, const TString &theOption="")
static Bool_t IsModuleLoaded
Definition MethodRXGB.h:91
ROOT::R::TRFunctionImport asmatrix
Definition MethodRXGB.h:102
void DeclareOptions() override
void Train() override
std::vector< Double_t > GetMvaValues(Long64_t firstEvt=0, Long64_t lastEvt=-1, Bool_t logProgress=false) override
get all the MVA values for the events of the current Data type
ROOT::R::TRFunctionImport xgbload
Definition MethodRXGB.h:100
Bool_t HasAnalysisType(Types::EAnalysisType type, UInt_t numberClasses, UInt_t numberTargets) override
Double_t GetMvaValue(Double_t *errLower=nullptr, Double_t *errUpper=nullptr) override
ROOT::R::TRFunctionImport asfactor
Definition MethodRXGB.h:101
void GetHelpMessage() const override
ROOT::R::TRFunctionImport xgbsave
Definition MethodRXGB.h:99
ROOT::R::TRObject * fModel
Definition MethodRXGB.h:103
ROOT::R::TRFunctionImport xgbdmatrix
Definition MethodRXGB.h:98
void MakeClass(const TString &classFileName=TString("")) const override
create reader class for method (classification only at present)
void ProcessOptions() override
ROOT::R::TRFunctionImport predict
Definition MethodRXGB.h:96
void TestClassification() override
initialization
std::vector< std::string > fFactorTrain
Definition RMethodBase.h:75
ROOT::R::TRDataFrame fDfTrain
Definition RMethodBase.h:71
RMethodBase(const TString &jobName, Types::EMVA methodType, const TString &methodTitle, DataSetInfo &dsi, const TString &theOption="", ROOT::R::TRInterface &_r=ROOT::R::TRInterface::Instance())
TVectorD fWeightTrain
Definition RMethodBase.h:73
Timing information for training and evaluation of MVA methods.
Definition Timer.h:58
TString GetElapsedTime(Bool_t Scientific=kTRUE)
returns pretty string with elapsed time
Definition Timer.cxx:145
const TString & Color(const TString &)
human readable color strings
Definition Tools.cxx:803
Singleton class for Global types used by TMVA.
Definition Types.h:71
@ kClassification
Definition Types.h:127
@ kTraining
Definition Types.h:143
Basic string class.
Definition TString.h:138
const Rcpp::internal::NamedPlaceHolder & Label
create variable transformations
Tools & gTools()
MsgLogger & Endl(MsgLogger &ml)
Definition MsgLogger.h:148