Logo ROOT  
Reference Guide
MethodC50.cxx
Go to the documentation of this file.
1 // @(#)root/tmva/rmva $Id$
2 // Author: Omar Zapata,Lorenzo Moneta, Sergei Gleyzer 2015
3 
4 /**********************************************************************************
5  * Project: TMVA - a Root-integrated toolkit for multivariate data analysis *
6  * Package: TMVA *
7  * Class : MethodC50 *
8  * Web : http://oproject.org *
9  * *
10  * Description: *
11  * Decision Trees and Rule-Based Models *
12  * *
13  * *
14  * Redistribution and use in source and binary forms, with or without *
15  * modification, are permitted according to the terms listed in LICENSE *
16  * (http://tmva.sourceforge.net/LICENSE) *
17  * *
18  **********************************************************************************/
19 
20 #include <iomanip>
21 
22 #include "TMath.h"
23 #include "Riostream.h"
24 #include "TMatrix.h"
25 #include "TMatrixD.h"
26 #include "TVectorD.h"
27 
29 #include "TMVA/MethodC50.h"
30 #include "TMVA/Tools.h"
31 #include "TMVA/Config.h"
32 #include "TMVA/Ranking.h"
33 #include "TMVA/Types.h"
34 #include "TMVA/PDF.h"
35 #include "TMVA/ClassifierFactory.h"
36 
37 #include "TMVA/Results.h"
38 #include "TMVA/Timer.h"
39 
40 using namespace TMVA;
41 
// REGISTER_METHOD (defined in TMVA/ClassifierFactory.h) presumably registers
// MethodC50 with TMVA's classifier factory so it can be booked as a method
// type -- confirm against the macro definition.
42 REGISTER_METHOD(C50)
43 
45 
46 //creating an Instance
48 
49 //_______________________________________________________________________
51  const TString &methodTitle,
52  DataSetInfo &dsi,
53  const TString &theOption) : RMethodBase(jobName, Types::kC50, methodTitle, dsi, theOption),
54  fNTrials(1),
55  fRules(kFALSE),
56  fMvaCounter(0),
57  predict("predict.C5.0"),
58  //predict("predict"),
59  C50("C5.0"),
60  C50Control("C5.0Control"),
61  asfactor("as.factor"),
62  fModel(NULL)
63 {
64  // standard constructor for the C50
// The initializer list forwards the job name, title, data set info and option
// string to RMethodBase and binds the R functions predict.C5.0, C5.0,
// C5.0Control and as.factor. It starts with one boosting trial, no rule-based
// model, and no trained model: fModel stays NULL until Train() or
// ReadModelFromFile() creates it.
65 
66  //C5.0Control options
// Defaults for the numeric C5.0Control() arguments. They can be overridden
// through the options bound in DeclareOptions() and are passed to
// C5.0Control() in ProcessOptions().
68  fControlBands = 0;
71  fControlCF = 0.25;
72  fControlMinCases = 2;
74  fControlSample = 0;
// Default seed drawn from R: sample.int(4096, size = 1) yields an integer in
// 1..4096, so after "- 1L" the seed lies in 0..4095.
75  r["sample.int(4096, size = 1) - 1L"] >> fControlSeed;
// NOTE(review): the Bool_t C5.0Control members (fControlSubset, fControlWinnow,
// fControlNoGlobalPruning, fControlFuzzyThreshold, fControlEarlyStopping) are
// not assigned in the lines visible here -- confirm they get defaults before
// DeclareOptions()/ProcessOptions() read them.
77 
79 }
80 
81 //_______________________________________________________________________
// Constructor used when the method is rebuilt from a weight file.
// It binds the same R functions and applies the same C5.0Control defaults as
// the job/title constructor.
// NOTE(review): as in the other constructor, the Bool_t C5.0Control members
// are not assigned in the lines visible here -- confirm their defaults.
82 MethodC50::MethodC50(DataSetInfo &theData, const TString &theWeightFile)
83  : RMethodBase(Types::kC50, theData, theWeightFile),
84  fNTrials(1),
85  fRules(kFALSE),
86  fMvaCounter(0),
87  predict("predict.C5.0"),
88  C50("C5.0"),
89  C50Control("C5.0Control"),
90  asfactor("as.factor"),
91  fModel(NULL)
92 {
93 
94  // constructor from weight file
// Defaults for the numeric C5.0Control() arguments (see ProcessOptions()).
96  fControlBands = 0;
99  fControlCF = 0.25;
100  fControlMinCases = 2;
102  fControlSample = 0;
// Random default seed in 0..4095, drawn from R.
103  r["sample.int(4096, size = 1) - 1L"] >> fControlSeed;
105 }
106 
107 
108 //_______________________________________________________________________
110 {
111  if (fModel) delete fModel;
112 }
113 
114 //_______________________________________________________________________
116 {
117  if (type == Types::kClassification && numberClasses == 2) return kTRUE;
118  return kFALSE;
119 }
120 
121 
122 //_______________________________________________________________________
124 {
125 
126  if (!IsModuleLoaded) {
127  Error("Init", "R's package C50 can not be loaded.");
128  Log() << kFATAL << " R's package C50 can not be loaded."
129  << Endl;
130  return;
131  }
132 }
133 
135 {
// Train the C5.0 model in R. The training data frame (fDfTrain), event weights
// (fWeightTrain), number of boosting trials, rule-based flag and the
// C5.0Control object built in ProcessOptions() are passed to R's C5.0().
// The resulting R object is wrapped in fModel.
// When model persistence is enabled, the model is also saved as the R variable
// "C50Model" in <WeightFileDir>/<MethodName>.RData, which ReadModelFromFile()
// loads back.
// NOTE(review): the class labels are not among the C5.0() arguments visible in
// this view (asfactor and fFactorTrain are otherwise unused here) -- confirm the
// response is passed.
// NOTE(review): a previous fModel is overwritten without being deleted; this
// leaks if Train() can run more than once on the same instance.
// NOTE(review): path is spliced into the R command unescaped, so a path
// containing a single quote breaks the save() call.
136  if (Data()->GetNTrainingEvents() == 0) Log() << kFATAL << "<Train> Data() has zero events" << Endl;
137  SEXP Model = C50(ROOT::R::Label["x"] = fDfTrain, \
139  ROOT::R::Label["trials"] = fNTrials, \
140  ROOT::R::Label["rules"] = fRules, \
141  ROOT::R::Label["weights"] = fWeightTrain, \
142  ROOT::R::Label["control"] = fModelControl);
143  fModel = new ROOT::R::TRObject(Model);
144  if (IsModelPersistence())
145  {
146  TString path = GetWeightFileDir() + "/" + GetName() + ".RData";
147  Log() << Endl;
148  Log() << gTools().Color("bold") << "--- Saving State File In:" << gTools().Color("reset") << path << Endl;
149  Log() << Endl;
150  r["C50Model"] << Model;
151  r << "save(C50Model,file='" + path + "')";
152  }
153 }
154 
155 //_______________________________________________________________________
157 {
158  //
159  DeclareOptionRef(fNTrials, "NTrials", "An integer specifying the number of boosting iterations");
160  DeclareOptionRef(fRules, "Rules", "A logical: should the tree be decomposed into a rule-basedmodel?");
161 
162  //C5.0Control Options
163  DeclareOptionRef(fControlSubset, "ControlSubset", "A logical: should the model evaluate groups of discrete \
164  predictors for splits? Note: the C5.0 command line version defaults this \
165  parameter to ‘FALSE’, meaning no attempted gropings will be evaluated \
166  during the tree growing stage.");
167  DeclareOptionRef(fControlBands, "ControlBands", "An integer between 2 and 1000. If ‘TRUE’, the model orders \
168  the rules by their affect on the error rate and groups the \
169  rules into the specified number of bands. This modifies the \
170  output so that the effect on the error rate can be seen for \
171  the groups of rules within a band. If this options is \
172  selected and ‘rules = kFALSE’, a warning is issued and ‘rules’ \
173  is changed to ‘kTRUE’.");
174  DeclareOptionRef(fControlWinnow, "ControlWinnow", "A logical: should predictor winnowing (i.e feature selection) be used?");
175  DeclareOptionRef(fControlNoGlobalPruning, "ControlNoGlobalPruning", "A logical to toggle whether the final, global pruning \
176  step to simplify the tree.");
177  DeclareOptionRef(fControlCF, "ControlCF", "A number in (0, 1) for the confidence factor.");
178  DeclareOptionRef(fControlMinCases, "ControlMinCases", "an integer for the smallest number of samples that must be \
179  put in at least two of the splits.");
180 
181  DeclareOptionRef(fControlFuzzyThreshold, "ControlFuzzyThreshold", "A logical toggle to evaluate possible advanced splits \
182  of the data. See Quinlan (1993) for details and examples.");
183  DeclareOptionRef(fControlSample, "ControlSample", "A value between (0, .999) that specifies the random \
184  proportion of the data should be used to train the model. By \
185  default, all the samples are used for model training. Samples \
186  not used for training are used to evaluate the accuracy of \
187  the model in the printed output.");
188  DeclareOptionRef(fControlSeed, "ControlSeed", " An integer for the random number seed within the C code.");
189  DeclareOptionRef(fControlEarlyStopping, "ControlEarlyStopping", " A logical to toggle whether the internal method for \
190  stopping boosting should be used.");
191 
192 
193 }
194 
195 //_______________________________________________________________________
197 {
// Validate the parsed options, then build the R C5.0Control object from the
// Control* members. Train() passes that object to C5.0() through
// fModelControl.
// NOTE(review): fNTrials is declared UInt_t, so "<= 0" can only catch 0.
// NOTE(review): the opening of the C5.0Control(...) call (its target and first
// argument) is not visible in this view. Presumably it is
// "fModelControl = C50Control(... subset ...," -- confirm.
198  if (fNTrials <= 0) {
199  Log() << kERROR << " fNTrials <=0... that does not work !! "
200  << " I set it to 1 .. just so that the program does not crash"
201  << Endl;
202  fNTrials = 1;
203  }
205  ROOT::R::Label["bands"] = fControlBands, \
206  ROOT::R::Label["winnow"] = fControlWinnow, \
207  ROOT::R::Label["noGlobalPruning"] = fControlNoGlobalPruning, \
208  ROOT::R::Label["CF"] = fControlCF, \
209  ROOT::R::Label["minCases"] = fControlMinCases, \
210  ROOT::R::Label["fuzzyThreshold"] = fControlFuzzyThreshold, \
211  ROOT::R::Label["sample"] = fControlSample, \
212  ROOT::R::Label["seed"] = fControlSeed, \
213  ROOT::R::Label["earlyStopping"] = fControlEarlyStopping);
214 }
215 
216 //_______________________________________________________________________
218 {
// Log that the C50 classification test starts.
// NOTE(review): the remainder of this body is not visible in this view.
// Presumably it delegates to MethodBase::TestClassification() -- confirm.
219  Log() << kINFO << "Testing Classification C50 METHOD " << Endl;
221 }
222 
223 
224 //_______________________________________________________________________
226 {
// Evaluate the current event. No error estimate is provided (NoErrorCalc).
// The event's variables are copied into a one-row R data frame, with one
// column per input variable. predict.C5.0(type = "prob") is called on it, and
// the second entry of the returned probabilities (result[1]) is returned.
// NOTE(review): treating result[1] as the signal probability assumes signal is
// the second class level of the trained factor -- confirm against how labels
// are encoded for training.
227  NoErrorCalc(errLower, errUpper);
228  Double_t mvaValue;
229  const TMVA::Event *ev = GetEvent();
230  const UInt_t nvar = DataInfo().GetNVariables();
// Local data frame despite the member-style "f" prefix.
231  ROOT::R::TRDataFrame fDfEvent;
// GetListOfVariables() returns the vector by value, so it is rebuilt on every
// iteration of this loop.
232  for (UInt_t i = 0; i < nvar; i++) {
233  fDfEvent[DataInfo().GetListOfVariables()[i].Data()] = ev->GetValues()[i];
234  }
235  //if using persistence model
237 
238  TVectorD result = predict(*fModel, fDfEvent, ROOT::R::Label["type"] = "prob");
239  mvaValue = result[1]; //returning signal prob
240  return mvaValue;
241 }
242 
243 
244 ////////////////////////////////////////////////////////////////////////////////
245 /// get all the MVA values for the events of the current Data type
246 std::vector<Double_t> MethodC50::GetMvaValues(Long64_t firstEvt, Long64_t lastEvt, Bool_t logProgress)
247 {
248  Long64_t nEvents = Data()->GetNEvents();
249  if (firstEvt > lastEvt || lastEvt > nEvents) lastEvt = nEvents;
250  if (firstEvt < 0) firstEvt = 0;
251 
252  nEvents = lastEvt-firstEvt;
253 
254  UInt_t nvars = Data()->GetNVariables();
255 
256  // use timer
257  Timer timer( nEvents, GetName(), kTRUE );
258  if (logProgress)
259  Log() << kINFO<<Form("Dataset[%s] : ",DataInfo().GetName())<< "Evaluation of " << GetMethodName() << " on "
260  << (Data()->GetCurrentType()==Types::kTraining?"training":"testing") << " sample (" << nEvents << " events)" << Endl;
261 
262 
263  // fill R DATA FRAME with events data
264  std::vector<std::vector<Float_t> > inputData(nvars);
265  for (UInt_t i = 0; i < nvars; i++) {
266  inputData[i] = std::vector<Float_t>(nEvents);
267  }
268 
269  for (Int_t ievt=firstEvt; ievt<lastEvt; ievt++) {
270  Data()->SetCurrentEvent(ievt);
271  const TMVA::Event *e = Data()->GetEvent();
272  assert(nvars == e->GetNVariables());
273  for (UInt_t i = 0; i < nvars; i++) {
274  inputData[i][ievt] = e->GetValue(i);
275  }
276  // if (ievt%100 == 0)
277  // std::cout << "Event " << ievt << " type" << DataInfo().IsSignal(e) << " : " << pValue[ievt*nvars] << " " << pValue[ievt*nvars+1] << " " << pValue[ievt*nvars+2] << std::endl;
278  }
279 
280  ROOT::R::TRDataFrame evtData;
281  for (UInt_t i = 0; i < nvars; i++) {
282  evtData[DataInfo().GetListOfVariables()[i].Data()] = inputData[i];
283  }
284  //if using persistence model
286 
287  std::vector<Double_t> mvaValues(nEvents);
288  ROOT::R::TRObject result = predict(*fModel, evtData, ROOT::R::Label["type"] = "prob");
289  std::vector<Double_t> probValues(2*nEvents);
290  probValues = result.As<std::vector<Double_t>>();
291  assert(probValues.size() == 2*mvaValues.size());
292  std::copy(probValues.begin()+nEvents, probValues.end(), mvaValues.begin() );
293 
294  if (logProgress) {
295  Log() << kINFO <<Form("Dataset[%s] : ",DataInfo().GetName())<< "Elapsed time for evaluation of " << nEvents << " events: "
296  << timer.GetElapsedTime() << " " << Endl;
297  }
298 
299  return mvaValues;
300 
301 }
302 
303 //_______________________________________________________________________
305 {
306 // get help message text
307 //
308 // typical length of text line:
309 // "|--------------------------------------------------------------|"
310  Log() << Endl;
311  Log() << gTools().Color("bold") << "--- Short description:" << gTools().Color("reset") << Endl;
312  Log() << Endl;
313  Log() << "Decision Trees and Rule-Based Models " << Endl;
314  Log() << Endl;
315  Log() << gTools().Color("bold") << "--- Performance optimisation:" << gTools().Color("reset") << Endl;
316  Log() << Endl;
317  Log() << Endl;
318  Log() << gTools().Color("bold") << "--- Performance tuning via configuration options:" << gTools().Color("reset") << Endl;
319  Log() << Endl;
320  Log() << "<None>" << Endl;
321 }
322 
323 //_______________________________________________________________________
325 {
// Load the model that Train() saved when model persistence was enabled. The
// file <WeightFileDir>/<MethodName>.RData is loaded into R, and the R variable
// "C50Model" is wrapped in fModel.
// NOTE(review): an existing fModel is overwritten without being deleted; this
// leaks if a model is already present (e.g. after Train() or a repeated call).
// NOTE(review): path is spliced into the R command unescaped, so a path
// containing a single quote breaks the load() call.
327  TString path = GetWeightFileDir() + "/" + GetName() + ".RData";
328  Log() << Endl;
329  Log() << gTools().Color("bold") << "--- Loading State File From:" << gTools().Color("reset") << path << Endl;
330  Log() << Endl;
331  r << "load('" + path + "')";
332  SEXP Model;
333  r["C50Model"] >> Model;
334  fModel = new ROOT::R::TRObject(Model);
335 
336 }
337 
338 //_______________________________________________________________________
339 void TMVA::MethodC50::MakeClass(const TString &/*theClassFileName*/) const
340 {
341 }
TMVA::MethodBase::TestClassification
virtual void TestClassification()
initialization
Definition: MethodBase.cxx:1125
TMVA::DataSet::GetNVariables
UInt_t GetNVariables() const
access the number of variables through the datasetinfo
Definition: DataSet.cxx:216
kTRUE
const Bool_t kTRUE
Definition: RtypesCore.h:100
TMVA::Configurable::Log
MsgLogger & Log() const
Definition: Configurable.h:122
TMVA::DataSet::GetCurrentType
Types::ETreeType GetCurrentType() const
Definition: DataSet.h:194
e
#define e(i)
Definition: RSha256.hxx:103
TVectorD.h
TMVA::MethodBase::Data
DataSet * Data() const
Definition: MethodBase.h:409
TMVA::kERROR
@ kERROR
Definition: Types.h:62
TMVA::MethodC50::MakeClass
virtual void MakeClass(const TString &classFileName=TString("")) const
create reader class for method (classification only at present); MethodC50 provides an empty implementation
Definition: MethodC50.cxx:339
TMVA::MethodBase::IsModelPersistence
Bool_t IsModelPersistence() const
Definition: MethodBase.h:383
ClassImp
#define ClassImp(name)
Definition: Rtypes.h:364
Form
char * Form(const char *fmt,...)
r
ROOT::R::TRInterface & r
Definition: Object.C:4
Long64_t
long long Long64_t
Definition: RtypesCore.h:80
TMath::Log
Double_t Log(Double_t x)
Definition: TMath.h:760
TObject::Error
virtual void Error(const char *method, const char *msgfmt,...) const
Issue error message.
Definition: TObject.cxx:893
TMVA::MethodC50::fControlSeed
Int_t fControlSeed
Definition: MethodC50.h:96
Ranking.h
TMVA::DataSetInfo::GetNVariables
UInt_t GetNVariables() const
Definition: DataSetInfo.h:127
TMVA::MethodC50::fControlSubset
Bool_t fControlSubset
Definition: MethodC50.h:88
TMVA::MethodC50::fControlFuzzyThreshold
Bool_t fControlFuzzyThreshold
Definition: MethodC50.h:94
TMVA::MethodC50::ReadModelFromFile
void ReadModelFromFile()
Definition: MethodC50.cxx:324
TMVA::MethodC50::C50Control
ROOT::R::TRFunctionImport C50Control
Definition: MethodC50.h:104
VariableTransformBase.h
TString
Basic string class.
Definition: TString.h:136
TMVA::MethodC50::Init
void Init()
Definition: MethodC50.cxx:123
Bool_t
bool Bool_t
Definition: RtypesCore.h:63
ROOT::R::TRObject
This is a class to get ROOT's objects from R's objects.
Definition: TRObject.h:70
TMVA::MethodC50::Train
void Train()
Definition: MethodC50.cxx:134
ROOT::R::TRInterface::Require
Bool_t Require(TString pkg)
Method to load an R's package.
Definition: TRInterface.cxx:200
REGISTER_METHOD
#define REGISTER_METHOD(CLASS)
for example
Definition: ClassifierFactory.h:124
TMVA::RMethodBase::r
ROOT::R::TRInterface & r
Definition: RMethodBase.h:52
bool
TMatrix.h
PDF.h
TMVA::MethodBase::DataInfo
DataSetInfo & DataInfo() const
Definition: MethodBase.h:410
PyTorch_Generate_CNN_Model.predict
def predict(model, test_X, batch_size=100)
Definition: PyTorch_Generate_CNN_Model.py:91
TMVA::MethodC50::asfactor
ROOT::R::TRFunctionImport asfactor
Definition: MethodC50.h:105
TMVA::DataSetInfo
Class that contains all the data information.
Definition: DataSetInfo.h:62
TMVA::MethodC50::MethodC50
MethodC50(const TString &jobName, const TString &methodTitle, DataSetInfo &theData, const TString &theOption="")
Definition: MethodC50.cxx:50
TMVA::MethodC50::fControlCF
Double_t fControlCF
Definition: MethodC50.h:92
TMVA::Timer::GetElapsedTime
TString GetElapsedTime(Bool_t Scientific=kTRUE)
returns pretty string with elapsed time
Definition: Timer.cxx:146
TMVA::MethodBase::GetMethodName
const TString & GetMethodName() const
Definition: MethodBase.h:331
Timer.h
TMVA::Event::GetValues
std::vector< Float_t > & GetValues()
Definition: Event.h:94
TMVA::Types::EAnalysisType
EAnalysisType
Definition: Types.h:128
TMVA::MethodC50::fControlSample
Double_t fControlSample
Definition: MethodC50.h:95
TMVA::DataSet::GetEvent
const Event * GetEvent() const
Definition: DataSet.cxx:202
TMVA::DataSet::GetNEvents
Long64_t GetNEvents(Types::ETreeType type=Types::kMaxTreeType) const
Definition: DataSet.h:206
kFALSE
const Bool_t kFALSE
Definition: RtypesCore.h:101
TMVA::MethodC50::fRules
Bool_t fRules
Definition: MethodC50.h:85
TMVA::MethodC50::predict
ROOT::R::TRFunctionImport predict
Definition: MethodC50.h:102
TMVA::Types::kClassification
@ kClassification
Definition: Types.h:129
TMVA::MethodC50::fModelControl
ROOT::R::TRObject fModelControl
Definition: MethodC50.h:107
TMVA::RMethodBase
Definition: RMethodBase.h:48
TMVA::MethodBase::NoErrorCalc
void NoErrorCalc(Double_t *const err, Double_t *const errUpper)
Definition: MethodBase.cxx:836
TMVA::RMethodBase::fDfTrain
ROOT::R::TRDataFrame fDfTrain
Definition: RMethodBase.h:91
TMVA::MethodBase::ReadStateFromFile
void ReadStateFromFile()
Function to read options and weights from file.
Definition: MethodBase.cxx:1426
TMVA::MethodC50::fControlMinCases
UInt_t fControlMinCases
Definition: MethodC50.h:93
TMVA::MethodBase::GetWeightFileDir
const TString & GetWeightFileDir() const
Definition: MethodBase.h:492
TMVA::Types
Singleton class for Global types used by TMVA.
Definition: Types.h:73
Types.h
TMVA::MethodC50::HasAnalysisType
Bool_t HasAnalysisType(Types::EAnalysisType type, UInt_t numberClasses, UInt_t numberTargets)
Definition: MethodC50.cxx:115
MethodC50.h
TMVA::Endl
MsgLogger & Endl(MsgLogger &ml)
Definition: MsgLogger.h:158
Config.h
TMVA::MethodC50::GetMvaValues
virtual std::vector< Double_t > GetMvaValues(Long64_t firstEvt=0, Long64_t lastEvt=-1, Bool_t logProgress=false)
get all the MVA values for the events of the current Data type
Definition: MethodC50.cxx:246
unsigned int
TMVA::Timer
Timing information for training and evaluation of MVA methods.
Definition: Timer.h:58
TMVA::Tools::Color
const TString & Color(const TString &)
human readable color strings
Definition: Tools.cxx:840
TMVA::Types::kTraining
@ kTraining
Definition: Types.h:145
TMVA::MethodC50::C50
ROOT::R::TRFunctionImport C50
Definition: MethodC50.h:103
TMVA::MethodC50::fNTrials
UInt_t fNTrials
Definition: MethodC50.h:84
TMVA::MethodC50::TestClassification
virtual void TestClassification()
initialization
Definition: MethodC50.cxx:217
TVectorT< Double_t >
TMVA::RMethodBase::fWeightTrain
TVectorD fWeightTrain
Definition: RMethodBase.h:93
TMVA::DataSet::SetCurrentEvent
void SetCurrentEvent(Long64_t ievt) const
Definition: DataSet.h:88
Double_t
double Double_t
Definition: RtypesCore.h:59
ROOT::R::TRObject::As
T As()
Some ROOT or C++ data types can be wrapped into a TRObject; this method lets you unwrap that data...
Definition: TRObject.h:152
TMVA::kFATAL
@ kFATAL
Definition: Types.h:63
TMVA::MethodC50::fControlEarlyStopping
Bool_t fControlEarlyStopping
Definition: MethodC50.h:97
TMVA::MethodBase::GetName
const char * GetName() const
Definition: MethodBase.h:334
TMVA::MethodC50::ListOfVariables
std::vector< TString > ListOfVariables
Definition: MethodC50.h:108
TMVA::Event
Definition: Event.h:51
TMVA::MethodBase::GetEvent
const Event * GetEvent() const
Definition: MethodBase.h:751
TMVA::MethodC50::fControlBands
UInt_t fControlBands
Definition: MethodC50.h:89
TMVA::MethodC50::fControlNoGlobalPruning
Bool_t fControlNoGlobalPruning
Definition: MethodC50.h:91
ROOT::R::TRInterface::Instance
static TRInterface & Instance()
static method to get a TRInterface instance reference
Definition: TRInterface.cxx:187
ROOT::R::Label
const Rcpp::internal::NamedPlaceHolder & Label
TMVA::kINFO
@ kINFO
Definition: Types.h:60
TMVA::MethodC50::IsModuleLoaded
static Bool_t IsModuleLoaded
Definition: MethodC50.h:100
Tools.h
ClassifierFactory.h
type
int type
Definition: TGX11.cxx:121
TMVA::MethodC50::fModel
ROOT::R::TRObject * fModel
Definition: MethodC50.h:106
TMVA::MethodC50::ProcessOptions
void ProcessOptions()
Definition: MethodC50.cxx:196
TMatrixD.h
Results.h
TMVA::MethodC50::GetHelpMessage
void GetHelpMessage() const
Definition: MethodC50.cxx:304
Riostream.h
TMVA::gTools
Tools & gTools()
TMVA::Configurable::DeclareOptionRef
OptionBase * DeclareOptionRef(T &ref, const TString &name, const TString &desc="")
TMVA::RMethodBase::fFactorTrain
std::vector< std::string > fFactorTrain
Definition: RMethodBase.h:95
TMVA::MethodC50
Definition: MethodC50.h:33
TMVA::MethodC50::DeclareOptions
void DeclareOptions()
Definition: MethodC50.cxx:156
ROOT::R::TRDataFrame
This is a class to create DataFrames from ROOT to R.
Definition: TRDataFrame.h:176
TMVA::MethodC50::fControlWinnow
Bool_t fControlWinnow
Definition: MethodC50.h:90
TMath.h
TMVA::DataSetInfo::GetListOfVariables
std::vector< TString > GetListOfVariables() const
returns list of variables
Definition: DataSetInfo.cxx:393
TMVA::MethodC50::GetMvaValue
Double_t GetMvaValue(Double_t *errLower=0, Double_t *errUpper=0)
Definition: MethodC50.cxx:225
TMVA
create variable transformations
Definition: GeneticMinimizer.h:22
int
TMVA::MethodC50::~MethodC50
~MethodC50(void)
Definition: MethodC50.cxx:109