Logo ROOT  
Reference Guide
 
Loading...
Searching...
No Matches
MethodC50.cxx
Go to the documentation of this file.
1// @(#)root/tmva/rmva $Id$
2// Author: Omar Zapata,Lorenzo Moneta, Sergei Gleyzer 2015
3
4/**********************************************************************************
5 * Project: TMVA - a Root-integrated toolkit for multivariate data analysis *
6 * Package: TMVA *
7 * Class : MethodC50 *
8 * Web : http://oproject.org *
9 * *
10 * Description: *
11 * Decision Trees and Rule-Based Models *
12 * *
13 * *
14 * Redistribution and use in source and binary forms, with or without *
15 * modification, are permitted according to the terms listed in LICENSE *
16 * (see tmva/doc/LICENSE) *
17 * *
18 **********************************************************************************/
19
20#include <iomanip>
21
22#include "TMath.h"
23#include "Riostream.h"
24#include "TMatrix.h"
25#include "TMatrixD.h"
26#include "TVectorD.h"
27
29#include "TMVA/MethodC50.h"
30#include "TMVA/Tools.h"
31#include "TMVA/Config.h"
32#include "TMVA/Ranking.h"
33#include "TMVA/Types.h"
34#include "TMVA/PDF.h"
36
37#include "TMVA/Results.h"
38#include "TMVA/Timer.h"
39
40using namespace TMVA;
41
43
44
45//creating an Instance
47
48//_______________________________________________________________________
/// Standard constructor of the C5.0 classifier (R package "C50").
/// The initializer list binds the R functions predict.C5.0, C5.0, C5.0Control
/// and as.factor by name, and starts without a fitted model (fModel is null).
/// NOTE(review): this listing omits the signature line and several body lines
/// (source lines 49, 51, 66, 68-69, 71-73, 75, 77); other fControl* defaults
/// may be set there — confirm against the full source file.
50 const TString &methodTitle,
52 const TString &theOption) : RMethodBase(jobName, Types::kC50, methodTitle, dsi, theOption),
53 fNTrials(1),
54 fRules(kFALSE),
55 fMvaCounter(0),
56 predict("predict.C5.0"),
57 //predict("predict"),
58 C50("C5.0"),
59 C50Control("C5.0Control"),
60 asfactor("as.factor"),
61 fModel(NULL)
62{
63 // standard constructor for the C50
64
65 //C5.0Control options
67 fControlBands = 0;
70 fControlCF = 0.25;
// Default seed drawn in R: sample.int(4096, size = 1) - 1L gives an integer in
// [0, 4095]. It is random per construction, so training is not reproducible
// unless the "ControlSeed" option (declared in DeclareOptions) overrides it.
74 r["sample.int(4096, size = 1) - 1L"] >> fControlSeed;
76
78}
79
80//_______________________________________________________________________
/// Constructor used when the method is rebuilt from a weight file.
/// NOTE(review): the signature and base-class initializer (source lines 81-82)
/// and several body lines are omitted from this listing. The visible member
/// initializers mirror those of the standard constructor; a shared init helper
/// would keep the two constructors in sync.
83 fNTrials(1),
84 fRules(kFALSE),
85 fMvaCounter(0),
86 predict("predict.C5.0"),
87 C50("C5.0"),
88 C50Control("C5.0Control"),
89 asfactor("as.factor"),
90 fModel(NULL)
91{
92
93 // constructor from weight file
95 fControlBands = 0;
98 fControlCF = 0.25;
101 fControlSample = 0;
// Random default seed in [0, 4095], drawn in R with the same expression as the
// standard constructor.
102 r["sample.int(4096, size = 1) - 1L"] >> fControlSeed;
104}
105
106
107//_______________________________________________________________________
/// Destructor: deletes the R model wrapper held in fModel.
/// NOTE(review): `delete` on a null pointer is already a no-op, so the check is
/// redundant. fModel is a raw owning pointer; if the class is copyable (the
/// header is not visible here), a copy would double-delete it. Consider
/// std::unique_ptr.
109{
110 if (fModel) delete fModel;
111}
112
113//_______________________________________________________________________
119
120
121//_______________________________________________________________________
/// Initialization check. The signature line (source line 122) is omitted from
/// this listing; the Error() location string suggests this is Init().
/// Reports an error when the R package C50 was not loaded (IsModuleLoaded).
123{
124
// NOTE(review): the same condition is reported twice (Error() and a kFATAL log).
// Whether kFATAL terminates, which would make the `return` unreachable, depends
// on MsgLogger, which is not visible here.
125 if (!IsModuleLoaded) {
126 Error("Init", "R's package C50 can not be loaded.");
127 Log() << kFATAL << " R's package C50 can not be loaded."
128 << Endl;
129 return;
130 }
131}
132
/// Trains the C5.0 model through R.
/// Logs a kFATAL message when there are no training events. When model
/// persistence is enabled, it saves the R object `Model` (declared in an omitted
/// line) as "C50Model" into <weight-file-dir>/<method-name>.RData.
/// NOTE(review): the signature (source line 133) and the head of the fit call
/// (source lines 136-137) are omitted from this listing.
134{
135 if (Data()->GetNTrainingEvents() == 0) Log() << kFATAL << "<Train> Data() has zero events" << Endl;
// Named arguments of the (omitted) R fit call. This is presumably the C50
// import bound to R's C5.0(); confirm in the full source. The trailing
// backslashes are plain line splices and are not needed inside a call.
138 ROOT::R::Label["trials"] = fNTrials, \
139 ROOT::R::Label["rules"] = fRules, \
140 ROOT::R::Label["weights"] = fWeightTrain, \
141 ROOT::R::Label["control"] = fModelControl);
143 if (IsModelPersistence())
144 {
145 TString path = GetWeightFileDir() + "/" + GetName() + ".RData";
146 Log() << Endl;
147 Log() << gTools().Color("bold") << "--- Saving State File In:" << gTools().Color("reset") << path << Endl;
148 Log() << Endl;
149 r["C50Model"] << Model;
// NOTE(review): `path` is spliced into R code inside single quotes. A weight
// directory or method name containing a quote would break or alter this command.
150 r << "save(C50Model,file='" + path + "')";
151 }
152}
153
154//_______________________________________________________________________
/// Declares the user-configurable options: the number of boosting trials,
/// rule-based decomposition, and the C5.0Control settings (the fControl*
/// members, which are later passed to R as named C5.0Control arguments).
/// NOTE(review): several help strings need cleanup:
/// - typos: "rule-basedmodel", "gropings", "affect", "If this options is";
/// - the ControlNoGlobalPruning sentence is incomplete;
/// - "If ‘TRUE’" is used for the integer ControlBands option;
/// - the strings contain non-ASCII quotes.
/// The backslash-newline continuations inside the literals also splice the
/// next line's leading whitespace into the text. Adjacent string-literal
/// concatenation would avoid that.
156{
157 //
158 DeclareOptionRef(fNTrials, "NTrials", "An integer specifying the number of boosting iterations");
159 DeclareOptionRef(fRules, "Rules", "A logical: should the tree be decomposed into a rule-basedmodel?");
160
161 //C5.0Control Options
162 DeclareOptionRef(fControlSubset, "ControlSubset", "A logical: should the model evaluate groups of discrete \
163 predictors for splits? Note: the C5.0 command line version defaults this \
164 parameter to ‘FALSE’, meaning no attempted gropings will be evaluated \
165 during the tree growing stage.");
166 DeclareOptionRef(fControlBands, "ControlBands", "An integer between 2 and 1000. If ‘TRUE’, the model orders \
167 the rules by their affect on the error rate and groups the \
168 rules into the specified number of bands. This modifies the \
169 output so that the effect on the error rate can be seen for \
170 the groups of rules within a band. If this options is \
171 selected and ‘rules = kFALSE’, a warning is issued and ‘rules’ \
172 is changed to ‘kTRUE’.");
173 DeclareOptionRef(fControlWinnow, "ControlWinnow", "A logical: should predictor winnowing (i.e feature selection) be used?");
174 DeclareOptionRef(fControlNoGlobalPruning, "ControlNoGlobalPruning", "A logical to toggle whether the final, global pruning \
175 step to simplify the tree.");
176 DeclareOptionRef(fControlCF, "ControlCF", "A number in (0, 1) for the confidence factor.");
177 DeclareOptionRef(fControlMinCases, "ControlMinCases", "an integer for the smallest number of samples that must be \
178 put in at least two of the splits.");
179
180 DeclareOptionRef(fControlFuzzyThreshold, "ControlFuzzyThreshold", "A logical toggle to evaluate possible advanced splits \
181 of the data. See Quinlan (1993) for details and examples.");
182 DeclareOptionRef(fControlSample, "ControlSample", "A value between (0, .999) that specifies the random \
183 proportion of the data should be used to train the model. By \
184 default, all the samples are used for model training. Samples \
185 not used for training are used to evaluate the accuracy of \
186 the model in the printed output.");
187 DeclareOptionRef(fControlSeed, "ControlSeed", " An integer for the random number seed within the C code.");
188 DeclareOptionRef(fControlEarlyStopping, "ControlEarlyStopping", " A logical to toggle whether the internal method for \
189 stopping boosting should be used.");
190
191
192}
193
194//_______________________________________________________________________
/// Validates the options and builds the R control object from the fControl*
/// members.
/// NOTE(review): the signature (source line 195) and the head of the control
/// call (source line 203) are omitted from this listing. That call presumably
/// assigns fModelControl via C50Control; confirm in the full source.
196{
// NOTE(review): fNTrials is declared UInt_t, so `<= 0` can only match 0.
197 if (fNTrials <= 0) {
198 Log() << kERROR << " fNTrials <=0... that does not work !! "
199 << " I set it to 1 .. just so that the program does not crash"
200 << Endl;
201 fNTrials = 1;
202 }
// Named arguments mapping each fControl* option onto the R control parameter
// of the same name.
204 ROOT::R::Label["bands"] = fControlBands, \
205 ROOT::R::Label["winnow"] = fControlWinnow, \
206 ROOT::R::Label["noGlobalPruning"] = fControlNoGlobalPruning, \
207 ROOT::R::Label["CF"] = fControlCF, \
208 ROOT::R::Label["minCases"] = fControlMinCases, \
209 ROOT::R::Label["fuzzyThreshold"] = fControlFuzzyThreshold, \
210 ROOT::R::Label["sample"] = fControlSample, \
211 ROOT::R::Label["seed"] = fControlSeed, \
212 ROOT::R::Label["earlyStopping"] = fControlEarlyStopping);
213}
214
215//_______________________________________________________________________
/// Logs the start of classification testing. Judging by the log text, this is
/// TestClassification().
/// NOTE(review): the signature and source line 219 are omitted from this
/// listing. Line 219 presumably delegates to the base-class implementation.
217{
218 Log() << kINFO << "Testing Classification C50 METHOD " << Endl;
220}
221
222
223//_______________________________________________________________________
/// Returns the MVA value of the current event: a class probability predicted
/// in R by predict.C5.0 with type = "prob".
/// NOTE(review): source lines 224, 226-227, 230 and 235 are omitted from this
/// listing. They contain the signature, local declarations (including mvaValue
/// and possibly fDfEvent) and the persistence branch.
225{
228 const TMVA::Event *ev = GetEvent();
229 const UInt_t nvar = DataInfo().GetNVariables();
// Copy the event's variable values into fDfEvent, keyed by variable name.
// NOTE(review): GetListOfVariables() returns the vector by value, so the whole
// list is copied on every iteration of this per-event loop. Hoist it.
231 for (UInt_t i = 0; i < nvar; i++) {
232 fDfEvent[DataInfo().GetListOfVariables()[i].Data()] = ev->GetValues()[i];
233 }
234 //if using persistence model
236
// fModel is dereferenced unconditionally. It must already hold a model, e.g.
// one set by ReadModelFromFile.
237 TVectorD result = predict(*fModel, fDfEvent, ROOT::R::Label["type"] = "prob");
// Index 1 is the second probability column, which the code treats as the
// signal probability. This presumably assumes signal is the second class level;
// confirm.
238 mvaValue = result[1]; //returning signal prob
239 return mvaValue;
240}
241
242
243////////////////////////////////////////////////////////////////////////////////
244/// get all the MVA values for the events of the current Data type
/// Evaluates events [firstEvt, lastEvt) with a single R predict call and
/// returns one value per event: the second half of the flattened probability
/// output, matching the column GetMvaValue uses.
/// NOTE(review): source lines 245, 269, 279, 284 and 287 are omitted from this
/// listing. They contain the signature, per-event setup, the evtData and
/// `result` declarations, and the persistence branch.
246{
247 Long64_t nEvents = Data()->GetNEvents();
// Clamp the range. An inverted range or an end past the data (including the
// default lastEvt = -1) means "up to the last event".
// NOTE(review): if firstEvt is greater than the number of events, the count
// computed below is negative. The vector constructors then convert it to a
// huge size.
248 if (firstEvt > lastEvt || lastEvt > nEvents) lastEvt = nEvents;
249 if (firstEvt < 0) firstEvt = 0;
250
251 nEvents = lastEvt-firstEvt;
252
253 UInt_t nvars = Data()->GetNVariables();
254
255 // use timer
256 Timer timer( nEvents, GetName(), kTRUE );
257 if (logProgress)
258 Log() << kINFO<<Form("Dataset[%s] : ",DataInfo().GetName())<< "Evaluation of " << GetMethodName() << " on "
259 << (Data()->GetCurrentType()==Types::kTraining?"training":"testing") << " sample (" << nEvents << " events)" << Endl;
260
261
262 // fill R DATA FRAME with events data
263 std::vector<std::vector<Float_t> > inputData(nvars);
264 for (UInt_t i = 0; i < nvars; i++) {
265 inputData[i] = std::vector<Float_t>(nEvents);
266 }
267
// NOTE(review): BUG. Each inputData[i] holds nEvents = lastEvt - firstEvt
// entries but is indexed below by the absolute event number ievt. This writes
// out of bounds whenever firstEvt > 0; the index should be `ievt - firstEvt`.
// The loop counter is also Int_t while the bounds are Long64_t.
268 for (Int_t ievt=firstEvt; ievt<lastEvt; ievt++) {
270 const TMVA::Event *e = Data()->GetEvent();
271 assert(nvars == e->GetNVariables());
272 for (UInt_t i = 0; i < nvars; i++) {
273 inputData[i][ievt] = e->GetValue(i);
274 }
275 // if (ievt%100 == 0)
276 // std::cout << "Event " << ievt << " type" << DataInfo().IsSignal(e) << " : " << pValue[ievt*nvars] << " " << pValue[ievt*nvars+1] << " " << pValue[ievt*nvars+2] << std::endl;
277 }
278
// One column per input variable, keyed by variable name.
280 for (UInt_t i = 0; i < nvars; i++) {
281 evtData[DataInfo().GetListOfVariables()[i].Data()] = inputData[i];
282 }
283 //if using persistence model
285
286 std::vector<Double_t> mvaValues(nEvents);
// `result` comes from the omitted predict call (source line 287) and is
// flattened to 2*nEvents values. With R's column-major matrix layout, the
// second half is presumably the second probability column; confirm.
// NOTE(review): the preallocation below is discarded by the assignment that
// follows it. The size check is only an assert, so it is compiled out under
// NDEBUG and std::copy is then unguarded.
288 std::vector<Double_t> probValues(2*nEvents);
289 probValues = result.As<std::vector<Double_t>>();
290 assert(probValues.size() == 2*mvaValues.size());
291 std::copy(probValues.begin()+nEvents, probValues.end(), mvaValues.begin() );
292
293 if (logProgress) {
294 Log() << kINFO <<Form("Dataset[%s] : ",DataInfo().GetName())<< "Elapsed time for evaluation of " << nEvents << " events: "
295 << timer.GetElapsedTime() << " " << Endl;
296 }
297
298 return mvaValues;
299
300}
301
302//_______________________________________________________________________
/// Prints the method's help text to the TMVA log.
/// NOTE(review): the "Performance optimisation" section prints no content, and
/// the tuning section prints "<None>" even though this method declares many
/// configuration options. Consider listing them here.
304{
305// get help message text
306//
307// typical length of text line:
308// "|--------------------------------------------------------------|"
309 Log() << Endl;
310 Log() << gTools().Color("bold") << "--- Short description:" << gTools().Color("reset") << Endl;
311 Log() << Endl;
312 Log() << "Decision Trees and Rule-Based Models " << Endl;
313 Log() << Endl;
314 Log() << gTools().Color("bold") << "--- Performance optimisation:" << gTools().Color("reset") << Endl;
315 Log() << Endl;
316 Log() << Endl;
317 Log() << gTools().Color("bold") << "--- Performance tuning via configuration options:" << gTools().Color("reset") << Endl;
318 Log() << Endl;
319 Log() << "<None>" << Endl;
320}
321
322//_______________________________________________________________________
/// Loads the persisted C5.0 model and stores it in fModel. It requires the R
/// package C50, loads <weight-file-dir>/<method-name>.RData into R, and wraps
/// the R object "C50Model" in a new TRObject.
/// The signature line (source line 323) is omitted from this listing; by its
/// behaviour this appears to be ReadModelFromFile().
324{
// NOTE(review): any status returned by Require() is not checked here. A missing
// package only surfaces when the R calls below fail.
325 ROOT::R::TRInterface::Instance().Require("C50");
326 TString path = GetWeightFileDir() + "/" + GetName() + ".RData";
327 Log() << Endl;
328 Log() << gTools().Color("bold") << "--- Loading State File From:" << gTools().Color("reset") << path << Endl;
329 Log() << Endl;
// NOTE(review): `path` is spliced into R code inside single quotes. A path
// containing a quote would break or alter this command.
330 r << "load('" + path + "')";
331 SEXP Model;
332 r["C50Model"] >> Model;
// NOTE(review): a previously held fModel is overwritten without being deleted.
// Each repeated call leaks the old TRObject; delete it (or use a smart pointer)
// before reassigning.
333 fModel = new ROOT::R::TRObject(Model);
334
335}
336
337//_______________________________________________________________________
/// Standalone response-class generation for this R-backed method. This override
/// has an empty body and ignores the file-name argument, so no class file is
/// written.
/// NOTE(review): logging a warning here would make the missing feature explicit
/// to users.
338void TMVA::MethodC50::MakeClass(const TString &/*theClassFileName*/) const
339{
340}
#define REGISTER_METHOD(CLASS)
for example
#define e(i)
Definition RSha256.hxx:103
bool Bool_t
Boolean (0=false, 1=true) (bool)
Definition RtypesCore.h:77
constexpr Bool_t kFALSE
Definition RtypesCore.h:108
long long Long64_t
Portable signed long integer 8 bytes.
Definition RtypesCore.h:83
constexpr Bool_t kTRUE
Definition RtypesCore.h:107
ROOT::Detail::TRangeCast< T, true > TRangeDynCast
TRangeDynCast is an adapter class that allows the typed iteration through a TCollection.
void Error(const char *location, const char *msgfmt,...)
Use this function in case an error occurred.
Definition TError.cxx:208
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void char Point_t Rectangle_t WindowAttributes_t Float_t r
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void char Point_t Rectangle_t WindowAttributes_t Float_t Float_t Float_t Int_t Int_t UInt_t UInt_t Rectangle_t result
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void char Point_t Rectangle_t WindowAttributes_t Float_t Float_t Float_t Int_t Int_t UInt_t UInt_t Rectangle_t Int_t Int_t Window_t TString Int_t GCValues_t GetPrimarySelectionOwner GetDisplay GetScreen GetColormap GetNativeEvent const char const char dpyName wid window const char font_name cursor keysym reg const char only_if_exist regb h Point_t winding char text const char depth char const char Int_t count const char ColorStruct_t color const char Pixmap_t Pixmap_t PictureAttributes_t attr const char char ret_data h unsigned char height h Atom_t Int_t ULong_t ULong_t unsigned char prop_list Atom_t Atom_t Atom_t Time_t type
char * Form(const char *fmt,...)
Formats a string in a circular formatting buffer.
Definition TString.cxx:2495
const_iterator begin() const
const_iterator end() const
This is a class to create DataFrames from ROOT to R.
static TRInterface & Instance()
static method to get an TRInterface instance reference
This is a class to get ROOT's objects from R's objects.
Definition TRObject.h:70
OptionBase * DeclareOptionRef(T &ref, const TString &name, const TString &desc="")
MsgLogger & Log() const
Class that contains all the data information.
Definition DataSetInfo.h:62
UInt_t GetNVariables() const
std::vector< TString > GetListOfVariables() const
returns list of variables
const Event * GetEvent() const
returns event without transformations
Definition DataSet.cxx:202
Types::ETreeType GetCurrentType() const
Definition DataSet.h:194
Long64_t GetNEvents(Types::ETreeType type=Types::kMaxTreeType) const
Definition DataSet.h:206
UInt_t GetNVariables() const
access the number of variables through the datasetinfo
Definition DataSet.cxx:216
void SetCurrentEvent(Long64_t ievt) const
Definition DataSet.h:88
const char * GetName() const override
Definition MethodBase.h:334
Bool_t IsModelPersistence() const
Definition MethodBase.h:383
const TString & GetWeightFileDir() const
Definition MethodBase.h:492
const TString & GetMethodName() const
Definition MethodBase.h:331
const Event * GetEvent() const
Definition MethodBase.h:751
DataSetInfo & DataInfo() const
Definition MethodBase.h:410
virtual void TestClassification()
initialization
void ReadStateFromFile()
Function to read options and weights from file.
void NoErrorCalc(Double_t *const err, Double_t *const errUpper)
DataSet * Data() const
Definition MethodBase.h:409
void GetHelpMessage() const
Double_t GetMvaValue(Double_t *errLower=nullptr, Double_t *errUpper=nullptr)
static Bool_t IsModuleLoaded
Definition MethodC50.h:100
virtual void TestClassification()
initialization
Bool_t fControlSubset
Definition MethodC50.h:88
ROOT::R::TRFunctionImport asfactor
Definition MethodC50.h:105
ROOT::R::TRObject * fModel
Definition MethodC50.h:106
void DeclareOptions()
Double_t fControlCF
Definition MethodC50.h:92
Bool_t HasAnalysisType(Types::EAnalysisType type, UInt_t numberClasses, UInt_t numberTargets)
Int_t fControlSeed
Definition MethodC50.h:96
virtual void MakeClass(const TString &classFileName=TString("")) const
create reader class for method (classification only at present)
Bool_t fControlNoGlobalPruning
Definition MethodC50.h:91
Bool_t fControlFuzzyThreshold
Definition MethodC50.h:94
Double_t fControlSample
Definition MethodC50.h:95
UInt_t fNTrials
Definition MethodC50.h:84
std::vector< TString > ListOfVariables
Definition MethodC50.h:108
Bool_t fControlEarlyStopping
Definition MethodC50.h:97
ROOT::R::TRFunctionImport C50Control
Definition MethodC50.h:104
virtual std::vector< Double_t > GetMvaValues(Long64_t firstEvt=0, Long64_t lastEvt=-1, Bool_t logProgress=false)
get all the MVA values for the events of the current Data type
void ProcessOptions()
ROOT::R::TRObject fModelControl
Definition MethodC50.h:107
ROOT::R::TRFunctionImport predict
Definition MethodC50.h:102
Bool_t fControlWinnow
Definition MethodC50.h:90
ROOT::R::TRFunctionImport C50
Definition MethodC50.h:103
MethodC50(const TString &jobName, const TString &methodTitle, DataSetInfo &theData, const TString &theOption="")
Definition MethodC50.cxx:49
UInt_t fControlMinCases
Definition MethodC50.h:93
void ReadModelFromFile()
UInt_t fControlBands
Definition MethodC50.h:89
std::vector< std::string > fFactorTrain
Definition RMethodBase.h:95
ROOT::R::TRInterface & r
Definition RMethodBase.h:52
ROOT::R::TRDataFrame fDfTrain
Definition RMethodBase.h:91
TVectorD fWeightTrain
Definition RMethodBase.h:93
Timing information for training and evaluation of MVA methods.
Definition Timer.h:58
const TString & Color(const TString &)
human readable color strings
Definition Tools.cxx:828
Singleton class for Global types used by TMVA.
Definition Types.h:71
@ kClassification
Definition Types.h:127
@ kTraining
Definition Types.h:143
Basic string class.
Definition TString.h:138
const Rcpp::internal::NamedPlaceHolder & Label
create variable transformations
Tools & gTools()
MsgLogger & Endl(MsgLogger &ml)
Definition MsgLogger.h:148