ROOT  6.06/09
Reference Guide
MethodRSNNS.cxx
Go to the documentation of this file.
1 // @(#)root/tmva/rmva $Id$
2 // Author: Omar Zapata,Lorenzo Moneta, Sergei Gleyzer 2015
3 
4 
5 /**********************************************************************************
6  * Project: TMVA - a Root-integrated toolkit for multivariate data analysis *
7  * Package: TMVA *
8  * Class : MethodRSNNS *
9  * Web : http://oproject.org *
10  * *
11  * Description: *
12  * Neural Networks in R using the Stuttgart Neural Network Simulator *
13  * *
14  * *
15  * Redistribution and use in source and binary forms, with or without *
16  * modification, are permitted according to the terms listed in LICENSE *
17  * (http://tmva.sourceforge.net/LICENSE) *
18  * *
19  **********************************************************************************/
20 
21 #include <iomanip>
22 
23 #include "TMath.h"
24 #include "Riostream.h"
25 #include "TMatrix.h"
26 #include "TMatrixD.h"
27 #include "TVectorD.h"
28 
30 #include "TMVA/MethodRSNNS.h"
31 #include "TMVA/Tools.h"
32 #include "TMVA/Config.h"
33 #include "TMVA/Ranking.h"
34 #include "TMVA/Types.h"
35 #include "TMVA/PDF.h"
36 #include "TMVA/ClassifierFactory.h"
37 
38 #include "TMVA/Results.h"
39 
40 using namespace TMVA;
41 
42 REGISTER_METHOD(RSNNS)
43 
45 
46 // Ask the R interface singleton to load the R package "RSNNS" while this static member is
// initialised; the stored result is checked later via `if (!IsModuleLoaded)`.
47 Bool_t MethodRSNNS::IsModuleLoaded = ROOT::R::TRInterface::Instance().Require("RSNNS");
48 
49 //_______________________________________________________________________
50 MethodRSNNS::MethodRSNNS(const TString &jobName,
51  const TString &methodTitle,
52  DataSetInfo &dsi,
53  const TString &theOption,
54  TDirectory *theTargetDir) :
55  RMethodBase(jobName, Types::kRSNNS, methodTitle, dsi, theOption, theTargetDir),
56  fMvaCounter(0),
57  predict("predict"),
58  mlp("mlp"),
59  asfactor("as.factor"),
60  fModel(NULL)
61 {
62  fNetType = methodTitle;
63  if (fNetType != "RMLP") {
64  Log() << kFATAL << " Unknow Method" + fNetType
65  << Endl;
66  return;
67  }
68 
69  // standard constructor for the RSNNS
70  //RSNNS Options for all NN methods
71  fSize = "c(5)";
72  fMaxit = 100;
73 
74  fInitFunc = "Randomize_Weights";
75  fInitFuncParams = "c(-0.3,0.3)"; //the maximun number of pacameter is 5 see RSNNS::getSnnsRFunctionTable() type 6
76 
77  fLearnFunc = "Std_Backpropagation"; //
78  fLearnFuncParams = "c(0.2,0)";
79 
80  fUpdateFunc = "Topological_Order";
81  fUpdateFuncParams = "c(0)";
82 
83  fHiddenActFunc = "Act_Logistic";
84  fShufflePatterns = kTRUE;
85  fLinOut = kFALSE;
86  fPruneFunc = "NULL";
87  fPruneFuncParams = "NULL";
88 
89  SetWeightFileDir(gConfig().GetIONames().fWeightFileDir);
90 }
91 
92 //_______________________________________________________________________
93 MethodRSNNS::MethodRSNNS(DataSetInfo &theData, const TString &theWeightFile, TDirectory *theTargetDir)
94  : RMethodBase(Types::kRSNNS, theData, theWeightFile, theTargetDir),
95  fMvaCounter(0),
96  predict("predict"),
97  mlp("mlp"),
98  asfactor("as.factor"),
99  fModel(NULL)
100 
101 {
102  fNetType = "RMLP"; //GetMethodName();//GetMethodName() is not returning RMLP is reting MethodBase why?
103  if (fNetType != "RMLP") {
104  Log() << kFATAL << " Unknow Method = " + fNetType
105  << Endl;
106  return;
107  }
108 
109  // standard constructor for the RSNNS
110  //RSNNS Options for all NN methods
111  fSize = "c(5)";
112  fMaxit = 100;
113 
114  fInitFunc = "Randomize_Weights";
115  fInitFuncParams = "c(-0.3,0.3)"; //the maximun number of pacameter is 5 see RSNNS::getSnnsRFunctionTable() type 6
116 
117  fLearnFunc = "Std_Backpropagation"; //
118  fLearnFuncParams = "c(0.2,0)";
119 
120  fUpdateFunc = "Topological_Order";
121  fUpdateFuncParams = "c(0)";
122 
123  fHiddenActFunc = "Act_Logistic";
125  fLinOut = kFALSE;
126  fPruneFunc = "NULL";
127  fPruneFuncParams = "NULL";
128 
129  SetWeightFileDir(gConfig().GetIONames().fWeightFileDir);
130 }
131 
132 
133 //_______________________________________________________________________
135 {
136  if (fModel) delete fModel;
137 }
138 
139 //_______________________________________________________________________
141 {
142  if (type == Types::kClassification && numberClasses == 2) return kTRUE;
143  return kFALSE;
144 }
145 
146 
147 //_______________________________________________________________________
149 {
150  if (!IsModuleLoaded) {
151  Error("Init", "R's package RSNNS can not be loaded.");
152  Log() << kFATAL << " R's package RSNNS can not be loaded."
153  << Endl;
154  return;
155  }
156  //factors creations
157  //RSNNS mlp require a numeric factor then background=0 signal=1 from fFactorTrain/fFactorTest
158  UInt_t size = fFactorTrain.size();
159  fFactorNumeric.resize(size);
160 
161  for (UInt_t i = 0; i < size; i++) {
162  if (fFactorTrain[i] == "signal") fFactorNumeric[i] = 1;
163  else fFactorNumeric[i] = 0;
164  }
165 }
166 
168 {
169  if (Data()->GetNTrainingEvents() == 0) Log() << kFATAL << "<Train> Data() has zero events" << Endl;
170  if (fNetType == "RMLP") {
171  ROOT::R::TRObject PruneFunc;
172  if (fPruneFunc == "NULL") PruneFunc = r.Eval("NULL");
173  else PruneFunc = r.Eval(Form("'%s'", fPruneFunc.Data()));
174 
175  SEXP Model = mlp(ROOT::R::Label["x"] = fDfTrain,
177  ROOT::R::Label["size"] = r.Eval(fSize),
178  ROOT::R::Label["maxit"] = fMaxit,
179  ROOT::R::Label["initFunc"] = fInitFunc,
180  ROOT::R::Label["initFuncParams"] = r.Eval(fInitFuncParams),
181  ROOT::R::Label["learnFunc"] = fLearnFunc,
182  ROOT::R::Label["learnFuncParams"] = r.Eval(fLearnFuncParams),
183  ROOT::R::Label["updateFunc"] = fUpdateFunc,
184  ROOT::R::Label["updateFuncParams"] = r.Eval(fUpdateFuncParams),
185  ROOT::R::Label["hiddenActFunc"] = fHiddenActFunc,
186  ROOT::R::Label["shufflePatterns"] = fShufflePatterns,
187  ROOT::R::Label["libOut"] = fLinOut,
188  ROOT::R::Label["pruneFunc"] = PruneFunc,
189  ROOT::R::Label["pruneFuncParams"] = r.Eval(fPruneFuncParams));
190  fModel = new ROOT::R::TRObject(Model);
191  TString path = GetWeightFileDir() + "/RMLPModel.RData";
192  Log() << Endl;
193  Log() << gTools().Color("bold") << "--- Saving State File In:" << gTools().Color("reset") << path << Endl;
194  Log() << Endl;
195  r["RMLPModel"] << Model;
196  r << "save(RMLPModel,file='" + path + "')";
197  }
198 }
199 
200 //_______________________________________________________________________
202 {
203  //RSNNS Options for all NN methods
204 // TVectorF fSize;//number of units in the hidden layer(s)
205  DeclareOptionRef(fSize, "Size", "number of units in the hidden layer(s)");
206  DeclareOptionRef(fMaxit, "Maxit", "Maximum of iterations to learn");
207 
208  DeclareOptionRef(fInitFunc, "InitFunc", "the initialization function to use");
209  DeclareOptionRef(fInitFuncParams, "InitFuncParams", "the parameters for the initialization function");
210 
211  DeclareOptionRef(fLearnFunc, "LearnFunc", "the learning function to use");
212  DeclareOptionRef(fLearnFuncParams, "LearnFuncParams", "the parameters for the learning function");
213 
214  DeclareOptionRef(fUpdateFunc, "UpdateFunc", "the update function to use");
215  DeclareOptionRef(fUpdateFuncParams, "UpdateFuncParams", "the parameters for the update function");
216 
217  DeclareOptionRef(fHiddenActFunc, "HiddenActFunc", "the activation function of all hidden units");
218  DeclareOptionRef(fShufflePatterns, "ShufflePatterns", "should the patterns be shuffled?");
219  DeclareOptionRef(fLinOut, "LinOut", "sets the activation function of the output units to linear or logistic");
220 
221  DeclareOptionRef(fPruneFunc, "PruneFunc", "the prune function to use");
222  DeclareOptionRef(fPruneFuncParams, "PruneFuncParams", "the parameters for the pruning function. Unlike the\
223  other functions, these have to be given in a named list. See\
224  the pruning demos for further explanation.the update function to use");
225 
226 }
227 
228 //_______________________________________________________________________
230 {
231  if (fMaxit <= 0) {
232  Log() << kERROR << " fMaxit <=0... that does not work !! "
233  << " I set it to 50 .. just so that the program does not crash"
234  << Endl;
235  fMaxit = 1;
236  }
237  // standard constructor for the RSNNS
238  //RSNNS Options for all NN methods
239 
240 }
241 
242 //_______________________________________________________________________
244 {
// Writes an informational log line that names the network type being tested.
// NOTE(review): the embedded listing numbers jump from 246 to 248 inside this body,
// so a statement may have been lost in extraction - confirm against the original file.
245  Log() << kINFO << "Testing Classification " << fNetType << " METHOD " << Endl;
246 
248 }
249 
250 
251 //_______________________________________________________________________
253 {
// Evaluates the trained model on the current event: copies the event's variable
// values into a one-row R data frame, calls R's predict() with type="prob" and
// returns element 0 of the result.
// Both error pointers are handed to NoErrorCalc() first.
254  NoErrorCalc(errLower, errUpper);
255  Double_t mvaValue;
256  const TMVA::Event *ev = GetEvent();
257  const UInt_t nvar = DataInfo().GetNVariables();
258  ROOT::R::TRDataFrame fDfEvent;
// One data-frame column per input variable, keyed by the variable's name.
259  for (UInt_t i = 0; i < nvar; i++) {
260  fDfEvent[DataInfo().GetListOfVariables()[i].Data()] = ev->GetValues()[i];
261  }
262  // when no model is held in memory (persistence case)
// NOTE(review): the body of this `if` is empty in this listing (the embedded numbering
// skips 264). As shown, a null fModel would be dereferenced below; presumably the
// model is loaded from the state file here - confirm against the original file.
263  if (!fModel) {
265  }
266  TVectorD result = predict(*fModel, fDfEvent, ROOT::R::Label["type"] = "prob");
267  mvaValue = result[0]; //returning signal prob
268  return mvaValue;
269 }
270 
271 //_______________________________________________________________________
273 {
// Restores the trained model: has R load "<weight dir>/RMLPModel.RData", fetches the
// R variable "RMLPModel" and stores it in fModel wrapped in a new TRObject.
// NOTE(review): the signature line and the first body line (embedded number 274) are
// absent from this listing - confirm against the original file.
275  TString path = GetWeightFileDir() + "/RMLPModel.RData";
276  Log() << Endl;
277  Log() << gTools().Color("bold") << "--- Loading State File From:" << gTools().Color("reset") << path << Endl;
278  Log() << Endl;
// The path is spliced into R source text between single quotes.
279  r << "load('" + path + "')";
280  SEXP Model;
281  r["RMLPModel"] >> Model;
// NOTE(review): a previously held fModel is not deleted before being overwritten here.
282  fModel = new ROOT::R::TRObject(Model);
283 
284 }
285 
286 
287 //_______________________________________________________________________
289 {
290 // get help message text
291 //
292 // typical length of text line:
293 // "|--------------------------------------------------------------|"
294  Log() << Endl;
295  Log() << gTools().Color("bold") << "--- Short description:" << gTools().Color("reset") << Endl;
296  Log() << Endl;
297  Log() << "Decision Trees and Rule-Based Models " << Endl;
298  Log() << Endl;
299  Log() << gTools().Color("bold") << "--- Performance optimisation:" << gTools().Color("reset") << Endl;
300  Log() << Endl;
301  Log() << Endl;
302  Log() << gTools().Color("bold") << "--- Performance tuning via configuration options:" << gTools().Color("reset") << Endl;
303  Log() << Endl;
304  Log() << "<None>" << Endl;
305 }
306 
Bool_t fShufflePatterns
Definition: MethodRSNNS.h:101
const TString & GetWeightFileDir() const
Definition: MethodBase.h:407
MsgLogger & Endl(MsgLogger &ml)
Definition: MsgLogger.h:162
Namespace for new ROOT classes and functions.
Definition: ROOT.py:1
void GetHelpMessage() const
Config & gConfig()
OptionBase * DeclareOptionRef(T &ref, const TString &name, const TString &desc="")
ROOT::R::TRFunctionImport predict
Definition: MethodRSNNS.h:111
DataSet * Data() const
Definition: MethodBase.h:363
EAnalysisType
Definition: Types.h:124
Basic string class.
Definition: TString.h:137
TString as(SEXP s)
Definition: RExports.h:85
std::vector< UInt_t > fFactorNumeric
Definition: MethodRSNNS.h:108
bool Bool_t
Definition: RtypesCore.h:59
const Bool_t kFALSE
Definition: Rtypes.h:92
ROOT::R::TRObject * fModel
Definition: MethodRSNNS.h:114
MethodRSNNS(const TString &jobName, const TString &methodTitle, DataSetInfo &theData, const TString &theOption="", TDirectory *theTargetDir=NULL)
UInt_t GetNVariables() const
Definition: DataSetInfo.h:128
Double_t GetMvaValue(Double_t *errLower=0, Double_t *errUpper=0)
virtual void TestClassification()
initialization
const char * Data() const
Definition: TString.h:349
Tools & gTools()
Definition: Tools.cxx:79
virtual void Error(const char *method, const char *msgfmt,...) const
Issue error message.
Definition: TObject.cxx:918
TString fUpdateFuncParams
Definition: MethodRSNNS.h:98
Bool_t Require(TString pkg)
Method to load an R's package.
This is a class to get ROOT's objects from R's objects
Definition: TRObject.h:73
std::vector< std::string > fFactorTrain
Definition: RMethodBase.h:98
ROOT::R::TRInterface & r
Definition: Object.C:4
unsigned int UInt_t
Definition: RtypesCore.h:42
ROOT::R::TRInterface & r
Definition: RMethodBase.h:53
const Event * GetEvent() const
Definition: MethodBase.h:667
char * Form(const char *fmt,...)
ROOT::R::TRFunctionImport mlp
Definition: MethodRSNNS.h:112
TString fUpdateFunc
Definition: MethodRSNNS.h:97
TString fInitFuncParams
Definition: MethodRSNNS.h:92
#define ClassImp(name)
Definition: Rtypes.h:279
double Double_t
Definition: RtypesCore.h:55
static Bool_t IsModuleLoaded
Definition: MethodRSNNS.h:110
Describe directory structure in memory.
Definition: TDirectory.h:41
TString fLearnFuncParams
Definition: MethodRSNNS.h:95
int type
Definition: TGX11.cxx:120
static TRInterface & Instance()
static method to get an TRInterface instance reference
MsgLogger & Log() const
Definition: Configurable.h:130
TString fPruneFuncParams
Definition: MethodRSNNS.h:105
DataSetInfo & DataInfo() const
Definition: MethodBase.h:364
Bool_t HasAnalysisType(Types::EAnalysisType type, UInt_t numberClasses, UInt_t numberTargets)
const TString & Color(const TString &)
human readable color strings
Definition: Tools.cxx:837
#define REGISTER_METHOD(CLASS)
for example
Abstract ClassifierFactory template that handles arbitrary types.
std::vector< Float_t > & GetValues()
Definition: Event.h:93
ROOT::R::TRDataFrame fDfTrain
Definition: RMethodBase.h:94
void SetWeightFileDir(TString fileDir)
set directory of weight file
std::vector< TString > GetListOfVariables() const
returns list of variables
double result[121]
TString fLearnFunc
Definition: MethodRSNNS.h:94
TString fHiddenActFunc
Definition: MethodRSNNS.h:100
Rcpp::internal::NamedPlaceHolder Label
Definition: RExports.cxx:14
const Bool_t kTRUE
Definition: Rtypes.h:91
Int_t Eval(const TString &code, TRObject &ans)
Method to eval R code and you get the result in a reference to TRObject.
Definition: TRInterface.cxx:58
virtual void TestClassification()
initialization
TRandom3 R
a TMatrixD.
Definition: testIO.cxx:28
Definition: math.cpp:60
void NoErrorCalc(Double_t *const err, Double_t *const errUpper)
Definition: MethodBase.cxx:820
This is a class to create DataFrames from ROOT to R
Definition: TRDataFrame.h:183