Logo ROOT  
Reference Guide
 
Loading...
Searching...
No Matches
MethodC50.cxx
Go to the documentation of this file.
1// @(#)root/tmva/rmva $Id$
2// Author: Omar Zapata,Lorenzo Moneta, Sergei Gleyzer 2015
3
4/**********************************************************************************
5 * Project: TMVA - a Root-integrated toolkit for multivariate data analysis *
6 * Package: TMVA *
7 * Class : MethodC50 *
8 * Web : http://oproject.org *
9 * *
10 * Description: *
11 * Decision Trees and Rule-Based Models *
12 * *
13 * *
14 * Redistribution and use in source and binary forms, with or without *
15 * modification, are permitted according to the terms listed in LICENSE *
16 * (http://tmva.sourceforge.net/LICENSE) *
17 * *
18 **********************************************************************************/
19
20#include <iomanip>
21
22#include "TMath.h"
23#include "Riostream.h"
24#include "TMatrix.h"
25#include "TMatrixD.h"
26#include "TVectorD.h"
27
29#include "TMVA/MethodC50.h"
30#include "TMVA/Tools.h"
31#include "TMVA/Config.h"
32#include "TMVA/Ranking.h"
33#include "TMVA/Types.h"
34#include "TMVA/PDF.h"
36
37#include "TMVA/Results.h"
38#include "TMVA/Timer.h"
39
40using namespace TMVA;
41
43
45
46//creating an Instance
48
49//_______________________________________________________________________
// Standard constructor: forwards the job name, method title, dataset info and
// option string to RMethodBase and binds the R functions used by this method
// (predict.C5.0, C5.0, C5.0Control, as.factor). No model exists yet: fModel
// starts as NULL and is only created when a model is trained or loaded.
// NOTE(review): this rendering is missing the signature (source line 50) and
// source lines 67, 69-70, 72-74, 76 and 78. The C5.0Control defaults not visible
// here (subset, winnow, noGlobalPruning, minCases, fuzzyThreshold, sample,
// earlyStopping) are presumably initialised on those lines -- confirm.
51 const TString &methodTitle,
52 DataSetInfo &dsi,
53 const TString &theOption) : RMethodBase(jobName, Types::kC50, methodTitle, dsi, theOption),
54 fNTrials(1),
55 fRules(kFALSE),
56 fMvaCounter(0),
57 predict("predict.C5.0"),
58 //predict("predict"),
59 C50("C5.0"),
60 C50Control("C5.0Control"),
61 asfactor("as.factor"),
62 fModel(NULL)
63{
64 // standard constructor for the C50
65
66 //C5.0Control options
68 fControlBands = 0;
71 fControlCF = 0.25;
// Draw a random seed in [0, 4095] from R (sample.int(4096) yields 1..4096, minus 1).
75 r["sample.int(4096, size = 1) - 1L"] >> fControlSeed;
77
79}
80
81//_______________________________________________________________________
// Constructor used when a trained method is read back from a weight file.
// Binds the same R functions as the standard constructor (predict.C5.0, C5.0,
// C5.0Control, as.factor); fModel starts as NULL until a model is loaded.
// NOTE(review): source lines 95, 97-98, 100-101 and 104 are missing from this
// rendering; the remaining C5.0Control defaults are presumably set there.
82MethodC50::MethodC50(DataSetInfo &theData, const TString &theWeightFile)
83 : RMethodBase(Types::kC50, theData, theWeightFile),
84 fNTrials(1),
85 fRules(kFALSE),
86 fMvaCounter(0),
87 predict("predict.C5.0"),
88 C50("C5.0"),
89 C50Control("C5.0Control"),
90 asfactor("as.factor"),
91 fModel(NULL)
92{
93
94 // constructor from weight file
96 fControlBands = 0;
99 fControlCF = 0.25;
102 fControlSample = 0;
// Random seed in [0, 4095] drawn from R.
103 r["sample.int(4096, size = 1) - 1L"] >> fControlSeed;
105}
106
107
108//_______________________________________________________________________
// Destructor (signature, source line 109, not shown in this rendering):
// releases the R model wrapper allocated with new in training/loading.
// NOTE(review): the null check is redundant -- deleting a null pointer is a no-op.
110{
111 if (fModel) delete fModel;
112}
113
114//_______________________________________________________________________
// Reports which analysis types are supported: only two-class classification.
// NOTE(review): presumably MethodC50::HasAnalysisType(type, numberClasses,
// numberTargets) -- the signature (source line 115) is not shown here; per the
// declaration numberTargets exists but is not used by this body.
116{
117 if (type == Types::kClassification && numberClasses == 2) return kTRUE;
118 return kFALSE;
119}
120
121
122//_______________________________________________________________________
// Initialisation: aborts if the R package C50 is not available.
// NOTE(review): presumably MethodC50::Init (signature, source line 123, not
// shown; the Error() location string says "Init"). IsModuleLoaded is a static
// flag whose initialisation is not visible here -- presumably the result of
// requiring the "C50" R package.
124{
125
126 if (!IsModuleLoaded) {
// Reported twice: through ROOT's Error() and through the TMVA logger at kFATAL.
127 Error("Init", "R's package C50 can not be loaded.");
128 Log() << kFATAL << " R's package C50 can not be loaded."
129 << Endl;
130 return;
131 }
132}
133
// Trains the C5.0 model in R and keeps the result in fModel.
// NOTE(review): presumably MethodC50::Train (signature, source line 134, not
// shown). Source line 138 is also missing: it sits inside the C50(...) argument
// list and presumably supplies the response (e.g. via as.factor on the training
// labels) -- confirm against the real file.
135{
136 if (Data()->GetNTrainingEvents() == 0) Log() << kFATAL << "<Train> Data() has zero events" << Endl;
137 SEXP Model = C50(ROOT::R::Label["x"] = fDfTrain, \
139 ROOT::R::Label["trials"] = fNTrials, \
140 ROOT::R::Label["rules"] = fRules, \
141 ROOT::R::Label["weights"] = fWeightTrain, \
142 ROOT::R::Label["control"] = fModelControl);
// NOTE(review): a previously held fModel is overwritten without delete (leaks if Train runs twice).
143 fModel = new ROOT::R::TRObject(Model);
// With model persistence enabled, store the R object as "C50Model" in
// <weight file dir>/<method name>.RData.
144 if (IsModelPersistence())
145 {
146 TString path = GetWeightFileDir() + "/" + GetName() + ".RData";
147 Log() << Endl;
148 Log() << gTools().Color("bold") << "--- Saving State File In:" << gTools().Color("reset") << path << Endl;
149 Log() << Endl;
150 r["C50Model"] << Model;
// NOTE(review): path is spliced unescaped into R code; a path containing a
// single quote would break this save() call.
151 r << "save(C50Model,file='" + path + "')";
152 }
153}
154
155//_______________________________________________________________________
157{
158 //
159 DeclareOptionRef(fNTrials, "NTrials", "An integer specifying the number of boosting iterations");
160 DeclareOptionRef(fRules, "Rules", "A logical: should the tree be decomposed into a rule-basedmodel?");
161
162 //C5.0Control Options
163 DeclareOptionRef(fControlSubset, "ControlSubset", "A logical: should the model evaluate groups of discrete \
164 predictors for splits? Note: the C5.0 command line version defaults this \
165 parameter to ‘FALSE’, meaning no attempted gropings will be evaluated \
166 during the tree growing stage.");
167 DeclareOptionRef(fControlBands, "ControlBands", "An integer between 2 and 1000. If ‘TRUE’, the model orders \
168 the rules by their affect on the error rate and groups the \
169 rules into the specified number of bands. This modifies the \
170 output so that the effect on the error rate can be seen for \
171 the groups of rules within a band. If this options is \
172 selected and ‘rules = kFALSE’, a warning is issued and ‘rules’ \
173 is changed to ‘kTRUE’.");
174 DeclareOptionRef(fControlWinnow, "ControlWinnow", "A logical: should predictor winnowing (i.e feature selection) be used?");
175 DeclareOptionRef(fControlNoGlobalPruning, "ControlNoGlobalPruning", "A logical to toggle whether the final, global pruning \
176 step to simplify the tree.");
177 DeclareOptionRef(fControlCF, "ControlCF", "A number in (0, 1) for the confidence factor.");
178 DeclareOptionRef(fControlMinCases, "ControlMinCases", "an integer for the smallest number of samples that must be \
179 put in at least two of the splits.");
180
181 DeclareOptionRef(fControlFuzzyThreshold, "ControlFuzzyThreshold", "A logical toggle to evaluate possible advanced splits \
182 of the data. See Quinlan (1993) for details and examples.");
183 DeclareOptionRef(fControlSample, "ControlSample", "A value between (0, .999) that specifies the random \
184 proportion of the data should be used to train the model. By \
185 default, all the samples are used for model training. Samples \
186 not used for training are used to evaluate the accuracy of \
187 the model in the printed output.");
188 DeclareOptionRef(fControlSeed, "ControlSeed", " An integer for the random number seed within the C code.");
189 DeclareOptionRef(fControlEarlyStopping, "ControlEarlyStopping", " A logical to toggle whether the internal method for \
190 stopping boosting should be used.");
191
192
193}
194
195//_______________________________________________________________________
// Validates options and builds the C5.0Control settings from the Control* options.
// NOTE(review): presumably MethodC50::ProcessOptions (signature, source line 196,
// not shown). Source line 204 is also missing: it opens the C50Control(...) call
// whose argument list continues below; presumably it assigns the result to
// fModelControl (which Train passes as "control") and passes the "subset"
// option -- confirm against the real file.
197{
// fNTrials is declared UInt_t, so this guard can only catch 0.
198 if (fNTrials <= 0) {
199 Log() << kERROR << " fNTrials <=0... that does not work !! "
200 << " I set it to 1 .. just so that the program does not crash"
201 << Endl;
202 fNTrials = 1;
203 }
205 ROOT::R::Label["bands"] = fControlBands, \
206 ROOT::R::Label["winnow"] = fControlWinnow, \
207 ROOT::R::Label["noGlobalPruning"] = fControlNoGlobalPruning, \
208 ROOT::R::Label["CF"] = fControlCF, \
209 ROOT::R::Label["minCases"] = fControlMinCases, \
210 ROOT::R::Label["fuzzyThreshold"] = fControlFuzzyThreshold, \
211 ROOT::R::Label["sample"] = fControlSample, \
212 ROOT::R::Label["seed"] = fControlSeed, \
213 ROOT::R::Label["earlyStopping"] = fControlEarlyStopping);
214}
215
216//_______________________________________________________________________
// Classification test entry point; logs a banner before testing.
// NOTE(review): the signature (source line 217) and source line 220 are not shown
// in this rendering; line 220 presumably delegates to the base-class
// TestClassification -- confirm.
218{
219 Log() << kINFO << "Testing Classification C50 METHOD " << Endl;
221}
222
223
224//_______________________________________________________________________
// Returns the MVA value of the current event: element [1] of
// predict.C5.0(model, event, type = "prob"), labelled "signal prob" below.
// No error estimate is computed (NoErrorCalc).
// NOTE(review): presumably MethodC50::GetMvaValue(Double_t *errLower, Double_t *errUpper);
// the signature (source line 225) and source line 236 -- the statement under
// the "if using persistence model" comment -- are not shown in this rendering.
// NOTE(review): fModel is dereferenced without a null check; a model must have
// been trained or loaded beforehand.
226{
227 NoErrorCalc(errLower, errUpper);
228 Double_t mvaValue;
229 const TMVA::Event *ev = GetEvent();
230 const UInt_t nvar = DataInfo().GetNVariables();
// One-row R data frame keyed by variable name.
// NOTE(review): the local name uses the member "f" prefix; GetListOfVariables()
// returns a copy and is called once per variable.
231 ROOT::R::TRDataFrame fDfEvent;
232 for (UInt_t i = 0; i < nvar; i++) {
233 fDfEvent[DataInfo().GetListOfVariables()[i].Data()] = ev->GetValues()[i];
234 }
235 //if using persistence model
237
238 TVectorD result = predict(*fModel, fDfEvent, ROOT::R::Label["type"] = "prob");
239 mvaValue = result[1]; //returning signal prob
240 return mvaValue;
241}
242
243
244////////////////////////////////////////////////////////////////////////////////
245/// get all the MVA values for the events of the current Data type
246std::vector<Double_t> MethodC50::GetMvaValues(Long64_t firstEvt, Long64_t lastEvt, Bool_t logProgress)
247{
248 Long64_t nEvents = Data()->GetNEvents();
249 if (firstEvt > lastEvt || lastEvt > nEvents) lastEvt = nEvents;
250 if (firstEvt < 0) firstEvt = 0;
251
252 nEvents = lastEvt-firstEvt;
253
254 UInt_t nvars = Data()->GetNVariables();
255
256 // use timer
257 Timer timer( nEvents, GetName(), kTRUE );
258 if (logProgress)
259 Log() << kINFO<<Form("Dataset[%s] : ",DataInfo().GetName())<< "Evaluation of " << GetMethodName() << " on "
260 << (Data()->GetCurrentType()==Types::kTraining?"training":"testing") << " sample (" << nEvents << " events)" << Endl;
261
262
263 // fill R DATA FRAME with events data
264 std::vector<std::vector<Float_t> > inputData(nvars);
265 for (UInt_t i = 0; i < nvars; i++) {
266 inputData[i] = std::vector<Float_t>(nEvents);
267 }
268
269 for (Int_t ievt=firstEvt; ievt<lastEvt; ievt++) {
270 Data()->SetCurrentEvent(ievt);
271 const TMVA::Event *e = Data()->GetEvent();
272 assert(nvars == e->GetNVariables());
273 for (UInt_t i = 0; i < nvars; i++) {
274 inputData[i][ievt] = e->GetValue(i);
275 }
276 // if (ievt%100 == 0)
277 // std::cout << "Event " << ievt << " type" << DataInfo().IsSignal(e) << " : " << pValue[ievt*nvars] << " " << pValue[ievt*nvars+1] << " " << pValue[ievt*nvars+2] << std::endl;
278 }
279
280 ROOT::R::TRDataFrame evtData;
281 for (UInt_t i = 0; i < nvars; i++) {
282 evtData[DataInfo().GetListOfVariables()[i].Data()] = inputData[i];
283 }
284 //if using persistence model
286
287 std::vector<Double_t> mvaValues(nEvents);
288 ROOT::R::TRObject result = predict(*fModel, evtData, ROOT::R::Label["type"] = "prob");
289 std::vector<Double_t> probValues(2*nEvents);
290 probValues = result.As<std::vector<Double_t>>();
291 assert(probValues.size() == 2*mvaValues.size());
292 std::copy(probValues.begin()+nEvents, probValues.end(), mvaValues.begin() );
293
294 if (logProgress) {
295 Log() << kINFO <<Form("Dataset[%s] : ",DataInfo().GetName())<< "Elapsed time for evaluation of " << nEvents << " events: "
296 << timer.GetElapsedTime() << " " << Endl;
297 }
298
299 return mvaValues;
300
301}
302
303//_______________________________________________________________________
// Prints the method's help text through Log(): a short description followed by
// the (currently empty) "Performance optimisation" section and a tuning section
// that reports "<None>".
// NOTE(review): presumably MethodC50::GetHelpMessage (signature, source line 304,
// not shown in this rendering).
305{
306// get help message text
307//
308// typical length of text line:
309// "|--------------------------------------------------------------|"
310 Log() << Endl;
311 Log() << gTools().Color("bold") << "--- Short description:" << gTools().Color("reset") << Endl;
312 Log() << Endl;
313 Log() << "Decision Trees and Rule-Based Models " << Endl;
314 Log() << Endl;
315 Log() << gTools().Color("bold") << "--- Performance optimisation:" << gTools().Color("reset") << Endl;
316 Log() << Endl;
317 Log() << Endl;
318 Log() << gTools().Color("bold") << "--- Performance tuning via configuration options:" << gTools().Color("reset") << Endl;
319 Log() << Endl;
320 Log() << "<None>" << Endl;
321}
322
323//_______________________________________________________________________
// Loads the persisted C5.0 model: runs R's load() on
// <weight file dir>/<method name>.RData (the file Train writes when model
// persistence is enabled), fetches the R object "C50Model" and wraps it in a
// new TRObject stored in fModel.
// NOTE(review): presumably MethodC50::ReadModelFromFile; the signature (source
// line 324) and source line 326 are not shown in this rendering.
325{
327 TString path = GetWeightFileDir() + "/" + GetName() + ".RData";
328 Log() << Endl;
329 Log() << gTools().Color("bold") << "--- Loading State File From:" << gTools().Color("reset") << path << Endl;
330 Log() << Endl;
// NOTE(review): path is spliced unescaped into R code (breaks on a single quote).
331 r << "load('" + path + "')";
332 SEXP Model;
333 r["C50Model"] >> Model;
// NOTE(review): any existing fModel is overwritten without delete, so every
// additional call leaks the previous wrapper.
334 fModel = new ROOT::R::TRObject(Model);
335
336}
337
338//_______________________________________________________________________
339void TMVA::MethodC50::MakeClass(const TString &/*theClassFileName*/) const
340{
341}
#define REGISTER_METHOD(CLASS)
for example
ROOT::R::TRInterface & r
Definition Object.C:4
#define e(i)
Definition RSha256.hxx:103
const Bool_t kFALSE
Definition RtypesCore.h:101
bool Bool_t
Definition RtypesCore.h:63
long long Long64_t
Definition RtypesCore.h:80
const Bool_t kTRUE
Definition RtypesCore.h:100
#define ClassImp(name)
Definition Rtypes.h:364
void Error(const char *location, const char *msgfmt,...)
Use this function in case an error occurred.
Definition TError.cxx:187
int type
Definition TGX11.cxx:121
char * Form(const char *fmt,...)
This is a class to create DataFrames from ROOT to R.
static TRInterface & Instance()
Static method returning a reference to the TRInterface instance.
Bool_t Require(TString pkg)
Method to load an R package.
This is a class to get ROOT's objects from R's objects.
Definition TRObject.h:70
T As()
Some ROOT or C++ data types can be wrapped into a TRObject; this method lets you unwrap that data...
Definition TRObject.h:152
OptionBase * DeclareOptionRef(T &ref, const TString &name, const TString &desc="")
MsgLogger & Log() const
Class that contains all the data information.
Definition DataSetInfo.h:62
UInt_t GetNVariables() const
std::vector< TString > GetListOfVariables() const
returns list of variables
const Event * GetEvent() const
Definition DataSet.cxx:202
Types::ETreeType GetCurrentType() const
Definition DataSet.h:194
Long64_t GetNEvents(Types::ETreeType type=Types::kMaxTreeType) const
Definition DataSet.h:206
UInt_t GetNVariables() const
access the number of variables through the datasetinfo
Definition DataSet.cxx:216
void SetCurrentEvent(Long64_t ievt) const
Definition DataSet.h:88
std::vector< Float_t > & GetValues()
Definition Event.h:94
const char * GetName() const
Definition MethodBase.h:334
Bool_t IsModelPersistence() const
Definition MethodBase.h:383
const TString & GetWeightFileDir() const
Definition MethodBase.h:492
const TString & GetMethodName() const
Definition MethodBase.h:331
const Event * GetEvent() const
Definition MethodBase.h:751
DataSetInfo & DataInfo() const
Definition MethodBase.h:410
virtual void TestClassification()
initialization
void ReadStateFromFile()
Function to read options and weights from file.
void NoErrorCalc(Double_t *const err, Double_t *const errUpper)
DataSet * Data() const
Definition MethodBase.h:409
void GetHelpMessage() const
static Bool_t IsModuleLoaded
Definition MethodC50.h:100
virtual void TestClassification()
Runs the classification test; logs "Testing Classification C50 METHOD" first.
Bool_t fControlSubset
Definition MethodC50.h:88
ROOT::R::TRFunctionImport asfactor
Definition MethodC50.h:105
ROOT::R::TRObject * fModel
Definition MethodC50.h:106
Double_t GetMvaValue(Double_t *errLower=0, Double_t *errUpper=0)
void DeclareOptions()
Double_t fControlCF
Definition MethodC50.h:92
Bool_t HasAnalysisType(Types::EAnalysisType type, UInt_t numberClasses, UInt_t numberTargets)
Int_t fControlSeed
Definition MethodC50.h:96
virtual void MakeClass(const TString &classFileName=TString("")) const
create reader class for method (classification only at present)
Bool_t fControlNoGlobalPruning
Definition MethodC50.h:91
Bool_t fControlFuzzyThreshold
Definition MethodC50.h:94
Double_t fControlSample
Definition MethodC50.h:95
UInt_t fNTrials
Definition MethodC50.h:84
std::vector< TString > ListOfVariables
Definition MethodC50.h:108
Bool_t fControlEarlyStopping
Definition MethodC50.h:97
ROOT::R::TRFunctionImport C50Control
Definition MethodC50.h:104
virtual std::vector< Double_t > GetMvaValues(Long64_t firstEvt=0, Long64_t lastEvt=-1, Bool_t logProgress=false)
get all the MVA values for the events of the current Data type
void ProcessOptions()
ROOT::R::TRObject fModelControl
Definition MethodC50.h:107
ROOT::R::TRFunctionImport predict
Definition MethodC50.h:102
Bool_t fControlWinnow
Definition MethodC50.h:90
ROOT::R::TRFunctionImport C50
Definition MethodC50.h:103
MethodC50(const TString &jobName, const TString &methodTitle, DataSetInfo &theData, const TString &theOption="")
Definition MethodC50.cxx:50
UInt_t fControlMinCases
Definition MethodC50.h:93
void ReadModelFromFile()
UInt_t fControlBands
Definition MethodC50.h:89
std::vector< std::string > fFactorTrain
Definition RMethodBase.h:95
ROOT::R::TRInterface & r
Definition RMethodBase.h:52
ROOT::R::TRDataFrame fDfTrain
Definition RMethodBase.h:91
TVectorD fWeightTrain
Definition RMethodBase.h:93
Timing information for training and evaluation of MVA methods.
Definition Timer.h:58
TString GetElapsedTime(Bool_t Scientific=kTRUE)
returns pretty string with elapsed time
Definition Timer.cxx:146
const TString & Color(const TString &)
human readable color strings
Definition Tools.cxx:828
Singleton class for Global types used by TMVA.
Definition Types.h:71
@ kClassification
Definition Types.h:127
@ kTraining
Definition Types.h:143
Basic string class.
Definition TString.h:136
const Rcpp::internal::NamedPlaceHolder & Label
Placeholder for passing named arguments to R functions, e.g. ROOT::R::Label["type"] = "prob".
Tools & gTools()
MsgLogger & Endl(MsgLogger &ml)
Definition MsgLogger.h:148