ROOT  6.07/01
Reference Guide
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Properties Friends Macros Groups Pages
MethodRSVM.cxx
Go to the documentation of this file.
1 // @(#)root/tmva/rmva $Id$
2 // Author: Omar Zapata,Lorenzo Moneta, Sergei Gleyzer 2015
3 
4 /**********************************************************************************
5  * Project: TMVA - a Root-integrated toolkit for multivariate data analysis *
6  * Package: TMVA *
7  * Class : MethodRSVM- *
8  * Web : http://oproject.org *
9  * *
10  * Description: *
11  * Support Vector Machines *
12  * *
13  * *
14  * Redistribution and use in source and binary forms, with or without *
15  * modification, are permitted according to the terms listed in LICENSE *
16  * (http://tmva.sourceforge.net/LICENSE) *
17  * *
18  **********************************************************************************/
19 
20 #include <iomanip>
21 
22 #include "TMath.h"
23 #include "Riostream.h"
24 #include "TMatrix.h"
25 #include "TMatrixD.h"
26 #include "TVectorD.h"
27 
29 #include "TMVA/MethodRSVM.h"
30 #include "TMVA/Tools.h"
31 #include "TMVA/Config.h"
32 #include "TMVA/Ranking.h"
33 #include "TMVA/Types.h"
34 #include "TMVA/PDF.h"
35 #include "TMVA/ClassifierFactory.h"
36 
37 #include "TMVA/Results.h"
38 
39 using namespace TMVA;
40 
41 REGISTER_METHOD(RSVM)
42 
44 // Static flag: asks the TRInterface singleton to load R's "e1071" package and stores the Bool_t returned by Require()
45 Bool_t MethodRSVM::IsModuleLoaded = ROOT::R::TRInterface::Instance().Require("e1071");
46 
47 
48 //_______________________________________________________________________
49 MethodRSVM::MethodRSVM(const TString &jobName,
50  const TString &methodTitle,
51  DataSetInfo &dsi,
52  const TString &theOption,
53  TDirectory *theTargetDir) :
54  RMethodBase(jobName, Types::kRSVM, methodTitle, dsi, theOption, theTargetDir),
55  fMvaCounter(0),
56  svm("svm"),
57  predict("predict"),
58  asfactor("as.factor"),
59  fModel(NULL)
60 {
61  // standard constructor for the RSVM
62  //Booking options
63  fScale = kTRUE;
64  fType = "C-classification";
65  fKernel = "radial";
66  fDegree = 3;
67 
68  fGamma = (fDfTrain.GetNcols() == 1) ? 1.0 : (1.0 / fDfTrain.GetNcols());
69  fCoef0 = 0;
70  fCost = 1;
71  fNu = 0.5;
72  fCacheSize = 40;
73  fTolerance = 0.001;
74  fEpsilon = 0.1;
75  fShrinking = kTRUE;
76  fCross = 0;
77  fProbability = kTRUE;
78  fFitted = kTRUE;
79  SetWeightFileDir(gConfig().GetIONames().fWeightFileDir);
80 }
81 
82 //_______________________________________________________________________
83 MethodRSVM::MethodRSVM(DataSetInfo &theData, const TString &theWeightFile, TDirectory *theTargetDir)
84  : RMethodBase(Types::kRSVM, theData, theWeightFile, theTargetDir),
85  fMvaCounter(0),
86  svm("svm"),
87  predict("predict"),
88  asfactor("as.factor"),
89  fModel(NULL)
90 {
91  // standard constructor for the RSVM
92  //Booking options
93  fScale = kTRUE;
94  fType = "C-classification";
95  fKernel = "radial";
96  fDegree = 3;
97 
98  fGamma = (fDfTrain.GetNcols() == 1) ? 1.0 : (1.0 / fDfTrain.GetNcols());
99  fCoef0 = 0;
100  fCost = 1;
101  fNu = 0.5;
102  fCacheSize = 40;
103  fTolerance = 0.001;
104  fEpsilon = 0.1;
105  fShrinking = kTRUE;
106  fCross = 0;
108  fFitted = kTRUE;
109  SetWeightFileDir(gConfig().GetIONames().fWeightFileDir);
110 }
111 
112 
113 //_______________________________________________________________________
// Body of what appears to be the MethodRSVM destructor (the signature line is absent from this listing - TODO confirm).
// Releases the TRObject allocated with new and stored in fModel.
115 {
116  if (fModel) delete fModel; // the null check is redundant: delete on a null pointer is a no-op
117 }
118 
119 //_______________________________________________________________________
// Body of HasAnalysisType(type, numberClasses, numberTargets); the signature line is absent from this listing.
// Returns kTRUE only for Types::kClassification with exactly two classes, kFALSE in every other case.
// numberTargets is not consulted.
121 {
122  if (type == Types::kClassification && numberClasses == 2) return kTRUE;
123  return kFALSE;
124 }
125 
126 
127 //_______________________________________________________________________
// Body of what appears to be Init() (the signature line is absent from this listing - TODO confirm).
// Checks the static IsModuleLoaded flag; when it is false, reports through Error() and logs the
// same message at kFATAL level before returning.
129 {
130  if (!IsModuleLoaded) {
131  Error("Init", "R's package e1071 can not be loaded.");
132  Log() << kFATAL << " R's package e1071 can not be loaded."
133  << Endl;
134  return;
135  }
136 }
137 
// Body of what appears to be Train() (the signature line is absent from this listing - TODO confirm).
// Calls the imported R function svm on the training data frame fDfTrain with the current option
// values, wraps the returned R object in a new TRObject stored in fModel, and asks R to save it
// under the name RSVMModel to <weight file dir>/RSVMModel.RData.
// NOTE(review): the listing numbering skips a line inside the svm(...) call (147 -> 149), so one
// argument (presumably the "y" labels) is not visible here - confirm against the original file.
// NOTE(review): fModel is overwritten without deleting a previous value; verify the method cannot
// be called twice on the same object.
139 {
140  if (Data()->GetNTrainingEvents() == 0) Log() << kFATAL << "<Train> Data() has zero events" << Endl;
141  // svm() needs class.weights as a named vector; the names used here are "background" and "signal"
142  ROOT::R::TRDataFrame ClassWeightsTrain;
143  ClassWeightsTrain["background"] = Data()->GetNEvtBkgdTrain(); // number of background training events
144  ClassWeightsTrain["signal"] = Data()->GetNEvtSigTrain(); // number of signal training events
145 
146 
147  SEXP Model = svm(ROOT::R::Label["x"] = fDfTrain, \
149  ROOT::R::Label["scale"] = fScale, \
150  ROOT::R::Label["type"] = fType, \
151  ROOT::R::Label["kernel"] = fKernel, \
152  ROOT::R::Label["degree"] = fDegree, \
153  ROOT::R::Label["gamma"] = fGamma, \
154  ROOT::R::Label["coef0"] = fCoef0, \
155  ROOT::R::Label["cost"] = fCost, \
156  ROOT::R::Label["nu"] = fNu, \
157  ROOT::R::Label["class.weights"] = ClassWeightsTrain, \
158  ROOT::R::Label["cachesize"] = fCacheSize, \
159  ROOT::R::Label["tolerance"] = fTolerance, \
160  ROOT::R::Label["epsilon"] = fEpsilon, \
161  ROOT::R::Label["shrinking"] = fShrinking, \
162  ROOT::R::Label["cross"] = fCross, \
163  ROOT::R::Label["probability"] = fProbability, \
164  ROOT::R::Label["fitted"] = fFitted);
165  fModel = new ROOT::R::TRObject(Model); // keep the fitted model for later predict() calls
166  TString path = GetWeightFileDir() + "/RSVMModel.RData";
167  Log() << Endl;
168  Log() << gTools().Color("bold") << "--- Saving State File In:" << gTools().Color("reset") << path << Endl;
169  Log() << Endl;
170  r["RSVMModel"] << Model; // hand the model to the R interface under the name RSVMModel
171  r << "save(RSVMModel,file='" + path + "')"; // R command writing that object to the state file
172 
173 }
174 
175 //_______________________________________________________________________
// Body of what appears to be DeclareOptions() (the signature line is absent from this listing - TODO confirm).
// Registers every option member with DeclareOptionRef(member, option name, description).
// The option names declared are: Scale, Type, Kernel, Degree, Gamma, Coef0, Cost, Nu, CacheSize,
// Tolerance, Epsilon, Shrinking, Cross, Probability and Fitted.
// The long descriptions are single string literals continued over several source lines with a
// trailing backslash, so no comment may be inserted between those lines.
177 {
178  DeclareOptionRef(fScale, "Scale", "A logical vector indicating the variables to be scaled. If\
179  ‘scale’ is of length 1, the value is recycled as many times \
180  as needed. Per default, data are scaled internally (both ‘x’\
181  and ‘y’ variables) to zero mean and unit variance. The center \
182  and scale values are returned and used for later predictions.");
183  DeclareOptionRef(fType, "Type", "‘svm’ can be used as a classification machine, as a \
184  regression machine, or for novelty detection. Depending of\
185  whether ‘y’ is a factor or not, the default setting for\
186  ‘type’ is ‘C-classification’ or ‘eps-regression’,\
187  respectively, but may be overwritten by setting an explicit value.\
188  Valid options are:\
189  - ‘C-classification’\
190  - ‘nu-classification’\
191  - ‘one-classification’ (for novelty detection)\
192  - ‘eps-regression’\
193  - ‘nu-regression’");
194  DeclareOptionRef(fKernel, "Kernel", "the kernel used in training and predicting. You might\
195  consider changing some of the following parameters, depending on the kernel type.\
196  linear: u'*v\
197  polynomial: (gamma*u'*v + coef0)^degree\
198  radial basis: exp(-gamma*|u-v|^2)\
199  sigmoid: tanh(gamma*u'*v + coef0)");
200  DeclareOptionRef(fDegree, "Degree", "parameter needed for kernel of type ‘polynomial’ (default: 3)");
201  DeclareOptionRef(fGamma, "Gamma", "parameter needed for all kernels except ‘linear’ (default:1/(data dimension))");
202  DeclareOptionRef(fCoef0, "Coef0", "parameter needed for kernels of type ‘polynomial’ and ‘sigmoid’ (default: 0)");
203  DeclareOptionRef(fCost, "Cost", "cost of constraints violation (default: 1)-it is the ‘C’-constant of the regularization term in the Lagrange formulation.");
204  DeclareOptionRef(fNu, "Nu", "parameter needed for ‘nu-classification’, ‘nu-regression’,and ‘one-classification’");
205  DeclareOptionRef(fCacheSize, "CacheSize", "cache memory in MB (default 40)");
206  DeclareOptionRef(fTolerance, "Tolerance", "tolerance of termination criterion (default: 0.001)");
207  DeclareOptionRef(fEpsilon, "Epsilon", "epsilon in the insensitive-loss function (default: 0.1)");
208  DeclareOptionRef(fShrinking, "Shrinking", "option whether to use the shrinking-heuristics (default:‘TRUE’)");
209  DeclareOptionRef(fCross, "Cross", "if a integer value k>0 is specified, a k-fold cross validation on the training data is performed to assess the\
210  quality of the model: the accuracy rate for classification and the Mean Squared Error for regression");
211  DeclareOptionRef(fProbability, "Probability", "logical indicating whether the model should allow for probability predictions.");
212  DeclareOptionRef(fFitted, "Fitted", "logical indicating whether the fitted values should be computed and included in the model or not (default: ‘TRUE’)");
213 
214 }
215 
216 //_______________________________________________________________________
// Body of what appears to be ProcessOptions() (the signature line is absent from this listing - TODO confirm).
// Hands the current value of every option member to the R interface r under the key
// "RMVA.RSVM.<OptionName>"; presumably this defines R variables of those names - verify against
// the TRInterface documentation.
218 {
219  r["RMVA.RSVM.Scale"] = fScale;
220  r["RMVA.RSVM.Type"] = fType;
221  r["RMVA.RSVM.Kernel"] = fKernel;
222  r["RMVA.RSVM.Degree"] = fDegree;
223  r["RMVA.RSVM.Gamma"] = fGamma;
224  r["RMVA.RSVM.Coef0"] = fCoef0;
225  r["RMVA.RSVM.Cost"] = fCost;
226  r["RMVA.RSVM.Nu"] = fNu;
227  r["RMVA.RSVM.CacheSize"] = fCacheSize;
228  r["RMVA.RSVM.Tolerance"] = fTolerance;
229  r["RMVA.RSVM.Epsilon"] = fEpsilon;
230  r["RMVA.RSVM.Shrinking"] = fShrinking;
231  r["RMVA.RSVM.Cross"] = fCross;
232  r["RMVA.RSVM.Probability"] = fProbability;
233  r["RMVA.RSVM.Fitted"] = fFitted;
234 
235 }
236 
237 //_______________________________________________________________________
// Body of TestClassification(); the signature line is absent from this listing.
// Logs an informational message announcing the classification test.
// NOTE(review): the listing numbering skips line 242, so a statement following the log message
// is not visible here - confirm against the original file.
239 {
240  Log() << kINFO << "Testing Classification RSVM METHOD " << Endl;
241 
243 }
244 
245 
246 //_______________________________________________________________________
248 {
249  NoErrorCalc(errLower, errUpper);
250  Double_t mvaValue;
251  const TMVA::Event *ev = GetEvent();
252  const UInt_t nvar = DataInfo().GetNVariables();
253  ROOT::R::TRDataFrame fDfEvent;
254  for (UInt_t i = 0; i < nvar; i++) {
255  fDfEvent[DataInfo().GetListOfVariables()[i].Data()] = ev->GetValues()[i];
256  }
257  //if using persistence model
258  if (!fModel) {
260  }
261  ROOT::R::TRObject result = predict(*fModel, fDfEvent, ROOT::R::Label["decision.values"] = kTRUE, ROOT::R::Label["probability"] = kTRUE);
262  TVectorD values = result.GetAttribute("decision.values");
263  mvaValue = values[0]; //returning signal prob
264  return mvaValue;
265 }
266 
267 
268 //_______________________________________________________________________
// Body of ReadStateFromFile(); the signature line is absent from this listing.
// Asks R to load <weight file dir>/RSVMModel.RData, fetches the R object named RSVMModel from the
// interface r and stores it, wrapped in a new TRObject, in fModel.
// NOTE(review): the listing numbering skips line 271, so the first statement of the body is not
// visible here - confirm against the original file.
// NOTE(review): fModel is overwritten without deleting a previous value; verify callers only
// invoke this while fModel is null.
270 {
272  TString path = GetWeightFileDir() + "/RSVMModel.RData";
273  Log() << Endl;
274  Log() << gTools().Color("bold") << "--- Loading State File From:" << gTools().Color("reset") << path << Endl;
275  Log() << Endl;
276  r << "load('" + path + "')"; // R command restoring the objects saved in the state file
277  SEXP Model;
278  r["RSVMModel"] >> Model; // retrieve the restored R object named RSVMModel
279  fModel = new ROOT::R::TRObject(Model);
280 
281 }
282 
283 //_______________________________________________________________________
285 {
286 // get help message text
287 //
288 // typical length of text line:
289 // "|--------------------------------------------------------------|"
290  Log() << Endl;
291  Log() << gTools().Color("bold") << "--- Short description:" << gTools().Color("reset") << Endl;
292  Log() << Endl;
293  Log() << "Decision Trees and Rule-Based Models " << Endl;
294  Log() << Endl;
295  Log() << gTools().Color("bold") << "--- Performance optimisation:" << gTools().Color("reset") << Endl;
296  Log() << Endl;
297  Log() << Endl;
298  Log() << gTools().Color("bold") << "--- Performance tuning via configuration options:" << gTools().Color("reset") << Endl;
299  Log() << Endl;
300  Log() << "<None>" << Endl;
301 }
302 
MethodRSVM(const TString &jobName, const TString &methodTitle, DataSetInfo &theData, const TString &theOption="", TDirectory *theTargetDir=NULL)
Float_t fEpsilon
Definition: MethodRSVM.h:117
const TString & GetWeightFileDir() const
Definition: MethodBase.h:407
MsgLogger & Endl(MsgLogger &ml)
Definition: MsgLogger.h:162
Config & gConfig()
OptionBase * DeclareOptionRef(T &ref, const TString &name, const TString &desc="")
std::vector< double > values
Definition: TwoHistoFit2D.C:32
DataSet * Data() const
Definition: MethodBase.h:363
EAnalysisType
Definition: Types.h:124
ROOT::R::TRObject * fModel
Definition: MethodRSVM.h:130
Basic string class.
Definition: TString.h:137
TString as(SEXP s)
Definition: RExports.h:85
bool Bool_t
Definition: RtypesCore.h:59
const Bool_t kFALSE
Definition: Rtypes.h:92
UInt_t GetNVariables() const
Definition: DataSetInfo.h:128
ROOT::R::TRFunctionImport asfactor
Definition: MethodRSVM.h:129
Long64_t GetNEvtBkgdTrain()
return number of background training events in dataset
Definition: DataSet.cxx:420
ROOT::R::TRFunctionImport svm
Definition: MethodRSVM.h:127
Tools & gTools()
Definition: Tools.cxx:79
TRObject GetAttribute(const TString name)
R objects can have associated attributes; with this method you can add an attribute to a TRObject g...
Definition: TRObject.h:130
ROOT::R::TRFunctionImport predict
Definition: MethodRSVM.h:128
virtual void Error(const char *method, const char *msgfmt,...) const
Issue error message.
Definition: TObject.cxx:918
Float_t fProbability
Definition: MethodRSVM.h:123
Bool_t Require(TString pkg)
Method to load an R package.
This is a class to get ROOT's objects from R's objects
Definition: TRObject.h:73
std::vector< std::string > fFactorTrain
Definition: RMethodBase.h:98
ROOT::R::TRInterface & r
Definition: Object.C:4
unsigned int UInt_t
Definition: RtypesCore.h:42
ROOT::R::TRInterface & r
Definition: RMethodBase.h:53
int GetNcols()
Method to get the number of columns.
Definition: TRDataFrame.h:401
const Event * GetEvent() const
Definition: MethodBase.h:667
Double_t GetMvaValue(Double_t *errLower=0, Double_t *errUpper=0)
Definition: MethodRSVM.cxx:247
PyObject * fType
static Bool_t IsModuleLoaded
Definition: MethodRSVM.h:126
#define ClassImp(name)
Definition: Rtypes.h:279
double Double_t
Definition: RtypesCore.h:55
Describe directory structure in memory.
Definition: TDirectory.h:44
Long64_t GetNEvtSigTrain()
return number of signal training events in dataset
Definition: DataSet.cxx:412
int type
Definition: TGX11.cxx:120
Bool_t HasAnalysisType(Types::EAnalysisType type, UInt_t numberClasses, UInt_t numberTargets)
Definition: MethodRSVM.cxx:120
static TRInterface & Instance()
static method to get a TRInterface instance reference
MsgLogger & Log() const
Definition: Configurable.h:130
DataSetInfo & DataInfo() const
Definition: MethodBase.h:364
void GetHelpMessage() const
Definition: MethodRSVM.cxx:284
Float_t fCacheSize
Definition: MethodRSVM.h:115
const TString & Color(const TString &)
human readable color strings
Definition: Tools.cxx:837
#define REGISTER_METHOD(CLASS)
for example
std::vector< Float_t > & GetValues()
Definition: Event.h:93
ROOT::R::TRDataFrame fDfTrain
Definition: RMethodBase.h:94
virtual void TestClassification()
initialization
Definition: MethodRSVM.cxx:238
void SetWeightFileDir(TString fileDir)
set directory of weight file
std::vector< TString > GetListOfVariables() const
returns list of variables
double result[121]
Rcpp::internal::NamedPlaceHolder Label
Definition: RExports.cxx:14
void ReadStateFromFile()
Definition: MethodRSVM.cxx:269
const Bool_t kTRUE
Definition: Rtypes.h:91
virtual void TestClassification()
initialization
TRandom3 R
a TMatrixD.
Definition: testIO.cxx:28
Definition: math.cpp:60
Float_t fTolerance
Definition: MethodRSVM.h:116
void NoErrorCalc(Double_t *const err, Double_t *const errUpper)
Definition: MethodBase.cxx:827
This is a class to create DataFrames from ROOT to R
Definition: TRDataFrame.h:183