Logo ROOT   6.10/09
Reference Guide
MethodRXGB.cxx
Go to the documentation of this file.
1 // @(#)root/tmva/rmva $Id$
2 // Author: Omar Zapata,Lorenzo Moneta, Sergei Gleyzer 2015
3 
4 /**********************************************************************************
5  * Project: TMVA - a Root-integrated toolkit for multivariate data analysis *
6  * Package: TMVA *
7  * Class : MethodRXGB *
8  * Web : http://oproject.org *
9  * *
10  * Description: *
11  * R eXtreme Gradient Boosting *
12  * *
13  * *
14  * Redistribution and use in source and binary forms, with or without *
15  * modification, are permitted according to the terms listed in LICENSE *
16  * (http://tmva.sourceforge.net/LICENSE) *
17  * *
18  **********************************************************************************/
19 
20 #include <iomanip>
21 
22 #include "TMath.h"
23 #include "Riostream.h"
24 #include "TMatrix.h"
25 #include "TMatrixD.h"
26 #include "TVectorD.h"
27 
29 #include "TMVA/MethodRXGB.h"
30 #include "TMVA/Tools.h"
31 #include "TMVA/Config.h"
32 #include "TMVA/Ranking.h"
33 #include "TMVA/Types.h"
34 #include "TMVA/PDF.h"
35 #include "TMVA/ClassifierFactory.h"
36 
37 #include "TMVA/Results.h"
38 #include "TMVA/Timer.h"
39 
40 using namespace TMVA;
41 
// Register MethodRXGB with TMVA's classifier factory under the RXGB method
// type (REGISTER_METHOD presumably comes from TMVA/ClassifierFactory.h).
// NOTE(review): file lines 44 and 47 are missing from this listing; the
// static IsModuleLoaded flag read in Init is presumably defined on line 44
// and ClassImp invoked on line 47 -- confirm against upstream.
42 REGISTER_METHOD(RXGB)
43 
45 
46 //creating an Instance
48 
49 //_______________________________________________________________________
50 MethodRXGB::MethodRXGB(const TString &jobName,
51  const TString &methodTitle,
52  DataSetInfo &dsi,
53  const TString &theOption) : RMethodBase(jobName, Types::kRXGB, methodTitle, dsi, theOption),
54  fNRounds(10),
55  fEta(0.3),
56  fMaxDepth(6),
57  predict("predict", "xgboost"),
58  xgbtrain("xgboost"),
59  xgbdmatrix("xgb.DMatrix"),
60  xgbsave("xgb.save"),
61  xgbload("xgb.load"),
62  asfactor("as.factor"),
63  asmatrix("as.matrix"),
64  fModel(NULL)
65 {
66  // standard constructor for the RXGB
67 
68 }
69 
70 //_______________________________________________________________________
71 MethodRXGB::MethodRXGB(DataSetInfo &theData, const TString &theWeightFile)
72  : RMethodBase(Types::kRXGB, theData, theWeightFile),
73  fNRounds(10),
74  fEta(0.3),
75  fMaxDepth(6),
76  predict("predict", "xgboost"),
77  xgbtrain("xgboost"),
78  xgbdmatrix("xgb.DMatrix"),
79  xgbsave("xgb.save"),
80  xgbload("xgb.load"),
81  asfactor("as.factor"),
82  asmatrix("as.matrix"),
83  fModel(NULL)
84 {
85 
86 }
87 
88 
89 //_______________________________________________________________________
// Destructor body: releases the R model wrapper held in fModel, which is
// allocated with `new` elsewhere in this file (training and model loading).
// NOTE(review): the definition line for this body (file line 90) is missing
// from this listing; presumably MethodRXGB::~MethodRXGB() -- confirm.
// NOTE(review): the null check is redundant (deleting a null pointer is a
// no-op). fModel is a raw owning pointer, so if this class is copyable an
// implicit copy would lead to a double delete -- consider a smart pointer.
91 {
92  if (fModel) delete fModel;
93 }
94 
95 //_______________________________________________________________________
97 {
98  if (type == Types::kClassification && numberClasses == 2) return kTRUE;
99  return kFALSE;
100 }
101 
102 
103 //_______________________________________________________________________
105 {
106 
107  if (!IsModuleLoaded) {
108  Error("Init", "R's package xgboost can not be loaded.");
109  Log() << kFATAL << " R's package xgboost can not be loaded."
110  << Endl;
111  return;
112  }
113  //factors creations
114  //xgboost require a numeric factor then background=0 signal=1 from fFactorTrain
115  UInt_t size = fFactorTrain.size();
116  fFactorNumeric.resize(size);
117 
118  for (UInt_t i = 0; i < size; i++) {
119  if (fFactorTrain[i] == "signal") fFactorNumeric[i] = 1;
120  else fFactorNumeric[i] = 0;
121  }
122 
123 
124 
125 }
126 
// Trains the xgboost model through R. Presumably the body of Train(): the
// "<Train>" message below suggests so, but its definition line (file
// line 127) is missing from this listing -- confirm.
128 {
// A fatal message is logged when there are no training events.
129  if (Data()->GetNTrainingEvents() == 0) Log() << kFATAL << "<Train> Data() has zero events" << Endl;
// NOTE(review): file line 130 is missing from this listing; it must declare
// `dmatrix`, which is passed below as the training data -- confirm upstream.
// Booster parameters taken from the Eta and MaxDepth options.
131  ROOT::R::TRDataFrame params;
132  params["eta"] = fEta;
133  params["max.depth"] = fMaxDepth;
134 
// Call R's xgboost() with the numeric labels prepared in Init, the training
// weights and NRounds boosting rounds.
135  SEXP Model = xgbtrain(ROOT::R::Label["data"] = dmatrix,
136  ROOT::R::Label["label"] = fFactorNumeric,
137  ROOT::R::Label["weight"] = fWeightTrain,
138  ROOT::R::Label["nrounds"] = fNRounds,
139  ROOT::R::Label["params"] = params);
140 
// Keep the trained model for evaluation.
// NOTE(review): a previously held fModel is overwritten without being
// deleted, so training more than once leaks the old wrapper.
141  fModel = new ROOT::R::TRObject(Model);
// With model persistence enabled, also write the model to
// <weight file dir>/RXGBModel.RData using R's xgb.save.
142  if (IsModelPersistence())
143  {
144  TString path = GetWeightFileDir() + "/RXGBModel.RData";
145  Log() << Endl;
146  Log() << gTools().Color("bold") << "--- Saving State File In:" << gTools().Color("reset") << path << Endl;
147  Log() << Endl;
148  xgbsave(Model, path);
149  }
150 }
151 
152 //_______________________________________________________________________
// Declares the user-configurable options NRounds, Eta and MaxDepth, bound to
// fNRounds, fEta and fMaxDepth (the constructors set their defaults to 10,
// 0.3 and 6).
// NOTE(review): the definition line (file line 153) is missing from this
// listing; presumably MethodRXGB::DeclareOptions() -- confirm.
// NOTE(review): the Eta description has grammatical slips ("to prevents",
// ". and eta").
154 {
155  DeclareOptionRef(fNRounds, "NRounds", "The max number of iterations");
156  DeclareOptionRef(fEta, "Eta", "Step size shrinkage used in update to prevents overfitting. After each boosting step, we can directly get the weights of new features. and eta actually shrinks the feature weights to make the boosting process more conservative.");
157  DeclareOptionRef(fMaxDepth, "MaxDepth", "Maximum depth of the tree");
158 }
159 
160 //_______________________________________________________________________
// Empty body: no option post-processing is done here.
// NOTE(review): the definition line (file line 161) is missing from this
// listing; presumably MethodRXGB::ProcessOptions() -- confirm.
162 {
163 }
164 
165 //_______________________________________________________________________
// Logs an info message when classification testing starts.
// NOTE(review): the definition line (file line 166; TestClassification per
// the symbol index) and file line 169 are missing from this listing. Line 169
// presumably delegates to the base-class TestClassification -- confirm.
167 {
168  Log() << kINFO << "Testing Classification RXGB METHOD " << Endl;
170 }
171 
172 
173 //_______________________________________________________________________
// Returns the MVA value of the current event.
//
// The event's input variables are copied into a one-row R data frame (one
// named column per variable), converted with as.matrix and xgb.DMatrix, and
// scored with R's predict(). No error estimate is computed (NoErrorCalc).
// NOTE(review): the definition line (file line 174; per the symbol index
// Double_t GetMvaValue(Double_t *errLower, Double_t *errUpper)) and file
// line 185, which follows the "persistence model" comment, are missing from
// this listing -- confirm against upstream.
// NOTE(review): this makes one round trip to R per event; GetMvaValues scores
// a whole range with a single predict call. `fDfEvent` is a local variable
// despite its member-style "f" prefix.
175 {
176  NoErrorCalc(errLower, errUpper);
177  Double_t mvaValue;
178  const TMVA::Event *ev = GetEvent();
179  const UInt_t nvar = DataInfo().GetNVariables();
180  ROOT::R::TRDataFrame fDfEvent;
181  for (UInt_t i = 0; i < nvar; i++) {
182  fDfEvent[DataInfo().GetListOfVariables()[i].Data()] = ev->GetValues()[i];
183  }
184  //if using persistence model
186 
187  mvaValue = (Double_t)predict(*fModel, xgbdmatrix(ROOT::R::Label["data"] = asmatrix(fDfEvent)));
188  return mvaValue;
189 }
190 
191 ////////////////////////////////////////////////////////////////////////////////
192 /// get all the MVA values for the events of the current Data type
193 std::vector<Double_t> MethodRXGB::GetMvaValues(Long64_t firstEvt, Long64_t lastEvt, Bool_t logProgress)
194 {
196  if (firstEvt > lastEvt || lastEvt > nEvents) lastEvt = nEvents;
197  if (firstEvt < 0) firstEvt = 0;
198 
199  nEvents = lastEvt-firstEvt;
200 
201  UInt_t nvars = Data()->GetNVariables();
202 
203  // use timer
204  Timer timer( nEvents, GetName(), kTRUE );
205  if (logProgress)
206  Log() << kINFO<<Form("Dataset[%s] : ",DataInfo().GetName())<< "Evaluation of " << GetMethodName() << " on "
207  << (Data()->GetCurrentType()==Types::kTraining?"training":"testing") << " sample (" << nEvents << " events)" << Endl;
208 
209 
210  // fill R DATA FRAME with events data
211  std::vector<std::vector<Float_t> > inputData(nvars);
212  for (UInt_t i = 0; i < nvars; i++) {
213  inputData[i] = std::vector<Float_t>(nEvents);
214  }
215 
216  for (Int_t ievt=firstEvt; ievt<lastEvt; ievt++) {
217  Data()->SetCurrentEvent(ievt);
218  const TMVA::Event *e = Data()->GetEvent();
219  assert(nvars == e->GetNVariables());
220  for (UInt_t i = 0; i < nvars; i++) {
221  inputData[i][ievt] = e->GetValue(i);
222  }
223  // if (ievt%100 == 0)
224  // std::cout << "Event " << ievt << " type" << DataInfo().IsSignal(e) << " : " << pValue[ievt*nvars] << " " << pValue[ievt*nvars+1] << " " << pValue[ievt*nvars+2] << std::endl;
225  }
226 
227  ROOT::R::TRDataFrame evtData;
228  for (UInt_t i = 0; i < nvars; i++) {
229  evtData[DataInfo().GetListOfVariables()[i].Data()] = inputData[i];
230  }
231  //if using persistence model
233 
234  std::vector<Double_t> mvaValues(nEvents);
235  ROOT::R::TRObject pred = predict(*fModel, xgbdmatrix(ROOT::R::Label["data"] = asmatrix(evtData)));
236  mvaValues = pred.As<std::vector<Double_t>>();
237 
238  if (logProgress) {
239  Log() << kINFO <<Form("Dataset[%s] : ",DataInfo().GetName())<< "Elapsed time for evaluation of " << nEvents << " events: "
240  << timer.GetElapsedTime() << " " << Endl;
241  }
242 
243  return mvaValues;
244 
245 }
246 //_______________________________________________________________________
248 {
249 // get help message text
250 //
251 // typical length of text line:
252 // "|--------------------------------------------------------------|"
253  Log() << Endl;
254  Log() << gTools().Color("bold") << "--- Short description:" << gTools().Color("reset") << Endl;
255  Log() << Endl;
256  Log() << "Decision Trees and Rule-Based Models " << Endl;
257  Log() << Endl;
258  Log() << gTools().Color("bold") << "--- Performance optimisation:" << gTools().Color("reset") << Endl;
259  Log() << Endl;
260  Log() << Endl;
261  Log() << gTools().Color("bold") << "--- Performance tuning via configuration options:" << gTools().Color("reset") << Endl;
262  Log() << Endl;
263  Log() << "<None>" << Endl;
264 }
265 
266 //_______________________________________________________________________
// Loads the xgboost model saved during training (when model persistence is
// enabled) from <weight file dir>/RXGBModel.RData using R's xgb.load, and
// wraps it in fModel.
// NOTE(review): the definition line (file line 267; ReadModelFromFile per the
// symbol index) and file line 269 are missing from this listing -- confirm.
// NOTE(review): fModel is overwritten without deleting a previously held
// model, so repeated calls leak the old wrapper.
268 {
270  TString path = GetWeightFileDir() + "/RXGBModel.RData";
271  Log() << Endl;
272  Log() << gTools().Color("bold") << "--- Loading State File From:" << gTools().Color("reset") << path << Endl;
273  Log() << Endl;
274 
275  SEXP Model = xgbload(path);
276  fModel = new ROOT::R::TRObject(Model);
277 
278 }
279 
280 //_______________________________________________________________________
281 void TMVA::MethodRXGB::MakeClass(const TString &/*theClassFileName*/) const
282 {
283 }
ROOT::R::TRFunctionImport xgbdmatrix
Definition: MethodRXGB.h:97
UInt_t GetNVariables() const
Definition: DataSetInfo.h:110
void SetCurrentEvent(Long64_t ievt) const
Definition: DataSet.h:99
MsgLogger & Endl(MsgLogger &ml)
Definition: MsgLogger.h:158
Singleton class for Global types used by TMVA.
Definition: Types.h:73
long long Long64_t
Definition: RtypesCore.h:69
Bool_t HasAnalysisType(Types::EAnalysisType type, UInt_t numberClasses, UInt_t numberTargets)
Definition: MethodRXGB.cxx:96
MsgLogger & Log() const
Definition: Configurable.h:122
std::vector< TString > GetListOfVariables() const
returns list of variables
OptionBase * DeclareOptionRef(T &ref, const TString &name, const TString &desc="")
EAnalysisType
Definition: Types.h:125
UInt_t GetNVariables() const
access the number of variables through the datasetinfo
Definition: DataSet.cxx:216
ROOT::R::TRFunctionImport asfactor
Definition: MethodRXGB.h:100
ROOT::R::TRObject * fModel
Definition: MethodRXGB.h:102
int Int_t
Definition: RtypesCore.h:41
bool Bool_t
Definition: RtypesCore.h:59
#define NULL
Definition: RtypesCore.h:88
void ReadModelFromFile()
Definition: MethodRXGB.cxx:267
const TString & GetWeightFileDir() const
Definition: MethodBase.h:474
TVectorD fWeightTrain
Definition: RMethodBase.h:90
TStopwatch timer
Definition: pirndm.C:37
virtual std::vector< Double_t > GetMvaValues(Long64_t firstEvt=0, Long64_t lastEvt=-1, Bool_t logProgress=false)
get all the MVA values for the events of the current Data type
Definition: MethodRXGB.cxx:193
const Event * GetEvent() const
Definition: MethodBase.h:733
DataSet * Data() const
Definition: MethodBase.h:393
void ReadStateFromFile()
Function to write options and weights to file.
Types::ETreeType GetCurrentType() const
Definition: DataSet.h:203
DataSetInfo & DataInfo() const
Definition: MethodBase.h:394
Class that contains all the data information.
Definition: DataSetInfo.h:60
T As()
Some datatypes of ROOT or c++ can be wrapped in to a TRObject, this method lets you unwrap those data...
Definition: TRObject.h:153
Bool_t Require(TString pkg)
Method to load an R package.
TString GetElapsedTime(Bool_t Scientific=kTRUE)
returns pretty string with elapsed time
Definition: Timer.cxx:134
const int nEvents
Definition: testRooFit.cxx:42
This is a class to get ROOT's objects from R's objects
Definition: TRObject.h:71
std::vector< std::string > fFactorTrain
Definition: RMethodBase.h:92
ROOT::R::TRFunctionImport xgbsave
Definition: MethodRXGB.h:98
const char * GetName() const
Definition: MethodBase.h:318
ROOT::R::TRFunctionImport xgbload
Definition: MethodRXGB.h:99
unsigned int UInt_t
Definition: RtypesCore.h:42
virtual void Error(const char *method, const char *msgfmt,...) const
Issue error message.
Definition: TObject.cxx:873
char * Form(const char *fmt,...)
const TString & GetMethodName() const
Definition: MethodBase.h:315
ROOT::R::TRFunctionImport asmatrix
Definition: MethodRXGB.h:101
Double_t fEta
Definition: MethodRXGB.h:88
Tools & gTools()
UInt_t GetNVariables() const
accessor to the number of variables
Definition: Event.cxx:309
const Bool_t kFALSE
Definition: RtypesCore.h:92
Float_t GetValue(UInt_t ivar) const
return value of i'th variable
Definition: Event.cxx:237
#define ClassImp(name)
Definition: Rtypes.h:336
double Double_t
Definition: RtypesCore.h:55
virtual void TestClassification()
initialization
Definition: MethodRXGB.cxx:166
int type
Definition: TGX11.cxx:120
static TRInterface & Instance()
static method to get an TRInterface instance reference
you should not use this method at all Int_t Int_t Double_t Double_t Double_t e
Definition: TRolke.cxx:630
MethodRXGB(const TString &jobName, const TString &methodTitle, DataSetInfo &theData, const TString &theOption="")
Definition: MethodRXGB.cxx:50
const TString & Color(const TString &)
human readable color strings
Definition: Tools.cxx:839
#define REGISTER_METHOD(CLASS)
for example
Abstract ClassifierFactory template that handles arbitrary types.
UInt_t fMaxDepth
Definition: MethodRXGB.h:89
std::vector< Float_t > & GetValues()
Definition: Event.h:89
ROOT::R::TRDataFrame fDfTrain
Definition: RMethodBase.h:88
ROOT::R::TRFunctionImport predict
Definition: MethodRXGB.h:95
Long64_t GetNEvents(Types::ETreeType type=Types::kMaxTreeType) const
Definition: DataSet.h:215
Rcpp::internal::NamedPlaceHolder Label
Definition: RExports.cxx:14
virtual void MakeClass(const TString &classFileName=TString("")) const
create reader class for method (classification only at present)
Definition: MethodRXGB.cxx:281
Double_t GetMvaValue(Double_t *errLower=0, Double_t *errUpper=0)
Definition: MethodRXGB.cxx:174
static Bool_t IsModuleLoaded
Definition: MethodRXGB.h:90
void GetHelpMessage() const
Definition: MethodRXGB.cxx:247
const Bool_t kTRUE
Definition: RtypesCore.h:91
Timing information for training and evaluation of MVA methods.
Definition: Timer.h:58
virtual void TestClassification()
initialization
ROOT::R::TRFunctionImport xgbtrain
Definition: MethodRXGB.h:96
const Event * GetEvent() const
Definition: DataSet.cxx:202
void NoErrorCalc(Double_t *const err, Double_t *const errUpper)
Definition: MethodBase.cxx:829
This is a class to create DataFrames from ROOT to R
Definition: TRDataFrame.h:177
std::vector< UInt_t > fFactorNumeric
Definition: MethodRXGB.h:92
Bool_t IsModelPersistence()
Definition: MethodBase.h:367