Logo ROOT  
Reference Guide
 
Loading...
Searching...
No Matches
MethodRXGB.cxx
Go to the documentation of this file.
1// @(#)root/tmva/rmva $Id$
2// Author: Omar Zapata,Lorenzo Moneta, Sergei Gleyzer 2015
3
4/**********************************************************************************
5 * Project: TMVA - a Root-integrated toolkit for multivariate data analysis *
6 * Package: TMVA *
7 * Class : MethodRXGB *
8 * Web : http://oproject.org *
9 * *
10 * Description: *
11 * R eXtreme Gradient Boosting *
12 * *
13 * *
14 * Redistribution and use in source and binary forms, with or without *
15 * modification, are permitted according to the terms listed in LICENSE *
16 * (see tmva/doc/LICENSE) *
17 * *
18 **********************************************************************************/
19
20#include <iomanip>
21
22#include "TMath.h"
23#include "Riostream.h"
24#include "TMatrix.h"
25#include "TMatrixD.h"
26#include "TVectorD.h"
27
29#include "TMVA/MethodRXGB.h"
30#include "TMVA/Tools.h"
31#include "TMVA/Config.h"
32#include "TMVA/Ranking.h"
33#include "TMVA/Types.h"
34#include "TMVA/PDF.h"
36
37#include "TMVA/Results.h"
38#include "TMVA/Timer.h"
39
40using namespace TMVA;
41
43
44
45//creating an Instance
47
48//_______________________________________________________________________
50 const TString &methodTitle,
52 const TString &theOption) : RMethodBase(jobName, Types::kRXGB, methodTitle, dsi, theOption),
53 fNRounds(10),
54 fEta(0.3),
55 fMaxDepth(6),
56 predict("predict", "xgboost"),
57 xgbtrain("xgboost"),
58 xgbdmatrix("xgb.DMatrix"),
59 xgbsave("xgb.save"),
60 xgbload("xgb.load"),
61 asfactor("as.factor"),
62 asmatrix("as.matrix"),
63 fModel(NULL)
64{
65 // standard constructor for the RXGB
66
67}
68
69//_______________________________________________________________________
72 fNRounds(10),
73 fEta(0.3),
74 fMaxDepth(6),
75 predict("predict", "xgboost"),
76 xgbtrain("xgboost"),
77 xgbdmatrix("xgb.DMatrix"),
78 xgbsave("xgb.save"),
79 xgbload("xgb.load"),
80 asfactor("as.factor"),
81 asmatrix("as.matrix"),
82 fModel(NULL)
83{
84
85}
86
87
88//_______________________________________________________________________
90{
91 if (fModel) delete fModel;
92}
93
94//_______________________________________________________________________
100
101
102//_______________________________________________________________________
104{
105
106 if (!IsModuleLoaded) {
107 Error("Init", "R's package xgboost can not be loaded.");
108 Log() << kFATAL << " R's package xgboost can not be loaded."
109 << Endl;
110 return;
111 }
112 //factors creations
113 // xgboost requires numeric factors, so fFactorTrain is mapped to signal=1 and everything else (background) to 0
114 UInt_t size = fFactorTrain.size();
115 fFactorNumeric.resize(size);
116
117 for (UInt_t i = 0; i < size; i++) {
118 if (fFactorTrain[i] == "signal") fFactorNumeric[i] = 1;
119 else fFactorNumeric[i] = 0;
120 }
121
122
123
124}
125
127{
128 if (Data()->GetNTrainingEvents() == 0) Log() << kFATAL << "<Train> Data() has zero events" << Endl;
131 params["eta"] = fEta;
132 params["max.depth"] = fMaxDepth;
133
136 ROOT::R::Label["weight"] = fWeightTrain,
137 ROOT::R::Label["nrounds"] = fNRounds,
138 ROOT::R::Label["params"] = params);
139
141 if (IsModelPersistence())
142 {
143 TString path = GetWeightFileDir() + "/" + GetName() + ".RData";
144 Log() << Endl;
145 Log() << gTools().Color("bold") << "--- Saving State File In:" << gTools().Color("reset") << path << Endl;
146 Log() << Endl;
147 xgbsave(Model, path);
148 }
149}
150
151//_______________________________________________________________________
153{
154 DeclareOptionRef(fNRounds, "NRounds", "The max number of iterations");
155 DeclareOptionRef(fEta, "Eta", "Step size shrinkage used in update to prevents overfitting. After each boosting step, we can directly get the weights of new features. and eta actually shrinks the feature weights to make the boosting process more conservative.");
156 DeclareOptionRef(fMaxDepth, "MaxDepth", "Maximum depth of the tree");
157}
158
159//_______________________________________________________________________
163
164//_______________________________________________________________________
166{
167 Log() << kINFO << "Testing Classification RXGB METHOD " << Endl;
169}
170
171
172//_______________________________________________________________________
174{
177 const TMVA::Event *ev = GetEvent();
178 const UInt_t nvar = DataInfo().GetNVariables();
180 for (UInt_t i = 0; i < nvar; i++) {
181 fDfEvent[DataInfo().GetListOfVariables()[i].Data()] = ev->GetValues()[i];
182 }
183 //if using persistence model
185
187 return mvaValue;
188}
189
190////////////////////////////////////////////////////////////////////////////////
191/// get all the MVA values for the events of the current Data type
193{
194 Long64_t nEvents = Data()->GetNEvents();
195 if (firstEvt > lastEvt || lastEvt > nEvents) lastEvt = nEvents;
196 if (firstEvt < 0) firstEvt = 0;
197
198 nEvents = lastEvt-firstEvt;
199
200 UInt_t nvars = Data()->GetNVariables();
201
202 // use timer
203 Timer timer( nEvents, GetName(), kTRUE );
204 if (logProgress)
205 Log() << kINFO<<Form("Dataset[%s] : ",DataInfo().GetName())<< "Evaluation of " << GetMethodName() << " on "
206 << (Data()->GetCurrentType()==Types::kTraining?"training":"testing") << " sample (" << nEvents << " events)" << Endl;
207
208
209 // fill R DATA FRAME with events data
210 std::vector<std::vector<Float_t> > inputData(nvars);
211 for (UInt_t i = 0; i < nvars; i++) {
212 inputData[i] = std::vector<Float_t>(nEvents);
213 }
214
215 for (Int_t ievt=firstEvt; ievt<lastEvt; ievt++) {
217 const TMVA::Event *e = Data()->GetEvent();
218 assert(nvars == e->GetNVariables());
219 for (UInt_t i = 0; i < nvars; i++) {
220 inputData[i][ievt] = e->GetValue(i);
221 }
222 // if (ievt%100 == 0)
223 // std::cout << "Event " << ievt << " type" << DataInfo().IsSignal(e) << " : " << pValue[ievt*nvars] << " " << pValue[ievt*nvars+1] << " " << pValue[ievt*nvars+2] << std::endl;
224 }
225
227 for (UInt_t i = 0; i < nvars; i++) {
228 evtData[DataInfo().GetListOfVariables()[i].Data()] = inputData[i];
229 }
230 //if using persistence model
232
233 std::vector<Double_t> mvaValues(nEvents);
235 mvaValues = pred.As<std::vector<Double_t>>();
236
237 if (logProgress) {
238 Log() << kINFO <<Form("Dataset[%s] : ",DataInfo().GetName())<< "Elapsed time for evaluation of " << nEvents << " events: "
239 << timer.GetElapsedTime() << " " << Endl;
240 }
241
242 return mvaValues;
243
244}
245//_______________________________________________________________________
247{
248// get help message text
249//
250// typical length of text line:
251// "|--------------------------------------------------------------|"
252 Log() << Endl;
253 Log() << gTools().Color("bold") << "--- Short description:" << gTools().Color("reset") << Endl;
254 Log() << Endl;
255 Log() << "Decision Trees and Rule-Based Models " << Endl;
256 Log() << Endl;
257 Log() << gTools().Color("bold") << "--- Performance optimisation:" << gTools().Color("reset") << Endl;
258 Log() << Endl;
259 Log() << Endl;
260 Log() << gTools().Color("bold") << "--- Performance tuning via configuration options:" << gTools().Color("reset") << Endl;
261 Log() << Endl;
262 Log() << "<None>" << Endl;
263}
264
265//_______________________________________________________________________
267{
268 ROOT::R::TRInterface::Instance().Require("RXGB");
269 TString path = GetWeightFileDir() + "/" + GetName() + ".RData";
270 Log() << Endl;
271 Log() << gTools().Color("bold") << "--- Loading State File From:" << gTools().Color("reset") << path << Endl;
272 Log() << Endl;
273
274 SEXP Model = xgbload(path);
275 fModel = new ROOT::R::TRObject(Model);
276
277}
278
279//_______________________________________________________________________
// Intentionally empty: this implementation emits no standalone class code.
// The file-name argument is left unnamed because the body never reads it.
280void TMVA::MethodRXGB::MakeClass(const TString &/*theClassFileName*/) const
281{
282}
#define REGISTER_METHOD(CLASS)
for example
#define e(i)
Definition RSha256.hxx:103
size_t size(const MatrixT &matrix)
retrieve the size of a square matrix
bool Bool_t
Boolean (0=false, 1=true) (bool)
Definition RtypesCore.h:77
constexpr Bool_t kFALSE
Definition RtypesCore.h:108
double Double_t
Double 8 bytes.
Definition RtypesCore.h:73
long long Long64_t
Portable signed long integer 8 bytes.
Definition RtypesCore.h:83
constexpr Bool_t kTRUE
Definition RtypesCore.h:107
ROOT::Detail::TRangeCast< T, true > TRangeDynCast
TRangeDynCast is an adapter class that allows the typed iteration through a TCollection.
void Error(const char *location, const char *msgfmt,...)
Use this function in case an error occurred.
Definition TError.cxx:208
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void char Point_t Rectangle_t WindowAttributes_t Float_t Float_t Float_t Int_t Int_t UInt_t UInt_t Rectangle_t Int_t Int_t Window_t TString Int_t GCValues_t GetPrimarySelectionOwner GetDisplay GetScreen GetColormap GetNativeEvent const char const char dpyName wid window const char font_name cursor keysym reg const char only_if_exist regb h Point_t winding char text const char depth char const char Int_t count const char ColorStruct_t color const char Pixmap_t Pixmap_t PictureAttributes_t attr const char char ret_data h unsigned char height h Atom_t Int_t ULong_t ULong_t unsigned char prop_list Atom_t Atom_t Atom_t Time_t type
char * Form(const char *fmt,...)
Formats a string in a circular formatting buffer.
Definition TString.cxx:2495
This is a class to create DataFrames from ROOT to R.
static TRInterface & Instance()
static method to get an TRInterface instance reference
This is a class to get ROOT's objects from R's objects.
Definition TRObject.h:70
OptionBase * DeclareOptionRef(T &ref, const TString &name, const TString &desc="")
MsgLogger & Log() const
Class that contains all the data information.
Definition DataSetInfo.h:62
UInt_t GetNVariables() const
std::vector< TString > GetListOfVariables() const
returns list of variables
const Event * GetEvent() const
returns event without transformations
Definition DataSet.cxx:202
Types::ETreeType GetCurrentType() const
Definition DataSet.h:194
Long64_t GetNEvents(Types::ETreeType type=Types::kMaxTreeType) const
Definition DataSet.h:206
UInt_t GetNVariables() const
access the number of variables through the datasetinfo
Definition DataSet.cxx:216
void SetCurrentEvent(Long64_t ievt) const
Definition DataSet.h:88
const char * GetName() const override
Definition MethodBase.h:334
Bool_t IsModelPersistence() const
Definition MethodBase.h:383
const TString & GetWeightFileDir() const
Definition MethodBase.h:492
const TString & GetMethodName() const
Definition MethodBase.h:331
const Event * GetEvent() const
Definition MethodBase.h:751
DataSetInfo & DataInfo() const
Definition MethodBase.h:410
virtual void TestClassification()
initialization
void ReadStateFromFile()
Function to read the saved model state back from file.
void NoErrorCalc(Double_t *const err, Double_t *const errUpper)
DataSet * Data() const
Definition MethodBase.h:409
std::vector< UInt_t > fFactorNumeric
Definition MethodRXGB.h:93
Double_t GetMvaValue(Double_t *errLower=nullptr, Double_t *errUpper=nullptr)
virtual std::vector< Double_t > GetMvaValues(Long64_t firstEvt=0, Long64_t lastEvt=-1, Bool_t logProgress=false)
get all the MVA values for the events of the current Data type
void GetHelpMessage() const
ROOT::R::TRFunctionImport xgbtrain
Definition MethodRXGB.h:97
MethodRXGB(const TString &jobName, const TString &methodTitle, DataSetInfo &theData, const TString &theOption="")
static Bool_t IsModuleLoaded
Definition MethodRXGB.h:91
ROOT::R::TRFunctionImport asmatrix
Definition MethodRXGB.h:102
virtual void TestClassification()
initialization
Bool_t HasAnalysisType(Types::EAnalysisType type, UInt_t numberClasses, UInt_t numberTargets)
virtual void MakeClass(const TString &classFileName=TString("")) const
create reader class for method (classification only at present)
ROOT::R::TRFunctionImport xgbsave
Definition MethodRXGB.h:99
ROOT::R::TRObject * fModel
Definition MethodRXGB.h:103
ROOT::R::TRFunctionImport xgbdmatrix
Definition MethodRXGB.h:98
ROOT::R::TRFunctionImport predict
Definition MethodRXGB.h:96
std::vector< std::string > fFactorTrain
Definition RMethodBase.h:95
ROOT::R::TRDataFrame fDfTrain
Definition RMethodBase.h:91
TVectorD fWeightTrain
Definition RMethodBase.h:93
Timing information for training and evaluation of MVA methods.
Definition Timer.h:58
const TString & Color(const TString &)
human readable color strings
Definition Tools.cxx:828
Singleton class for Global types used by TMVA.
Definition Types.h:71
@ kClassification
Definition Types.h:127
@ kTraining
Definition Types.h:143
Basic string class.
Definition TString.h:138
const Rcpp::internal::NamedPlaceHolder & Label
create variable transformations
Tools & gTools()
MsgLogger & Endl(MsgLogger &ml)
Definition MsgLogger.h:148