Logo ROOT   6.08/07
Reference Guide
MethodRSNNS.cxx
Go to the documentation of this file.
1 // @(#)root/tmva/rmva $Id$
2 // Author: Omar Zapata,Lorenzo Moneta, Sergei Gleyzer 2015
3 
4 
5 /**********************************************************************************
6  * Project: TMVA - a Root-integrated toolkit for multivariate data analysis *
7  * Package: TMVA *
8  * Class : MethodRSNNS *
9  * Web : http://oproject.org *
10  * *
11  * Description: *
12  * Neural Networks in R using the Stuttgart Neural Network Simulator *
13  * *
14  * *
15  * Redistribution and use in source and binary forms, with or without *
16  * modification, are permitted according to the terms listed in LICENSE *
17  * (http://tmva.sourceforge.net/LICENSE) *
18  * *
19  **********************************************************************************/
20 
21 #include <iomanip>
22 
23 #include "TMath.h"
24 #include "Riostream.h"
25 #include "TMatrix.h"
26 #include "TMatrixD.h"
27 #include "TVectorD.h"
28 
30 #include "TMVA/MethodRSNNS.h"
31 #include "TMVA/Tools.h"
32 #include "TMVA/Config.h"
33 #include "TMVA/Ranking.h"
34 #include "TMVA/Types.h"
35 #include "TMVA/PDF.h"
36 #include "TMVA/ClassifierFactory.h"
37 
38 #include "TMVA/Results.h"
39 #include "TMVA/Timer.h"
40 
41 using namespace TMVA;
42 
43 REGISTER_METHOD(RSNNS)
44 
46 
47 //creating an Instance
48 Bool_t MethodRSNNS::IsModuleLoaded = ROOT::R::TRInterface::Instance().Require("RSNNS");
49 
50 //_______________________________________________________________________
51 MethodRSNNS::MethodRSNNS(const TString &jobName,
52  const TString &methodTitle,
53  DataSetInfo &dsi,
54  const TString &theOption) :
55  RMethodBase(jobName, Types::kRSNNS, methodTitle, dsi, theOption),
56  fMvaCounter(0),
57  predict("predict"),
58  mlp("mlp"),
59  asfactor("as.factor"),
60  fModel(NULL)
61 {
62  fNetType = methodTitle;
63  if (fNetType != "RMLP") {
64  Log() << kFATAL << " Unknow Method" + fNetType
65  << Endl;
66  return;
67  }
68 
69  // standard constructor for the RSNNS
70  //RSNNS Options for all NN methods
71  fSize = "c(5)";
72  fMaxit = 100;
73 
74  fInitFunc = "Randomize_Weights";
75  fInitFuncParams = "c(-0.3,0.3)"; //the maximun number of pacameter is 5 see RSNNS::getSnnsRFunctionTable() type 6
76 
77  fLearnFunc = "Std_Backpropagation"; //
78  fLearnFuncParams = "c(0.2,0)";
79 
80  fUpdateFunc = "Topological_Order";
81  fUpdateFuncParams = "c(0)";
82 
83  fHiddenActFunc = "Act_Logistic";
85  fLinOut = kFALSE;
86  fPruneFunc = "NULL";
87  fPruneFuncParams = "NULL";
88 
89 }
90 
91 //_______________________________________________________________________
92 MethodRSNNS::MethodRSNNS(DataSetInfo &theData, const TString &theWeightFile)
93  : RMethodBase(Types::kRSNNS, theData, theWeightFile),
94  fMvaCounter(0),
95  predict("predict"),
96  mlp("mlp"),
97  asfactor("as.factor"),
98  fModel(NULL)
99 
100 {
101  fNetType = "RMLP"; //GetMethodName();//GetMethodName() is not returning RMLP is reting MethodBase why?
102  if (fNetType != "RMLP") {
103  Log() << kFATAL << " Unknow Method = " + fNetType
104  << Endl;
105  return;
106  }
107 
108  // standard constructor for the RSNNS
109  //RSNNS Options for all NN methods
110  fSize = "c(5)";
111  fMaxit = 100;
112 
113  fInitFunc = "Randomize_Weights";
114  fInitFuncParams = "c(-0.3,0.3)"; //the maximun number of pacameter is 5 see RSNNS::getSnnsRFunctionTable() type 6
115 
116  fLearnFunc = "Std_Backpropagation"; //
117  fLearnFuncParams = "c(0.2,0)";
118 
119  fUpdateFunc = "Topological_Order";
120  fUpdateFuncParams = "c(0)";
121 
122  fHiddenActFunc = "Act_Logistic";
124  fLinOut = kFALSE;
125  fPruneFunc = "NULL";
126  fPruneFuncParams = "NULL";
127 }
128 
129 
130 //_______________________________________________________________________
132 {
133  if (fModel) delete fModel;
134 }
135 
136 //_______________________________________________________________________
138 {
139  if (type == Types::kClassification && numberClasses == 2) return kTRUE;
140  return kFALSE;
141 }
142 
143 
144 //_______________________________________________________________________
146 {
147  if (!IsModuleLoaded) {
148  Error("Init", "R's package RSNNS can not be loaded.");
149  Log() << kFATAL << " R's package RSNNS can not be loaded."
150  << Endl;
151  return;
152  }
153  //factors creations
154  //RSNNS mlp require a numeric factor then background=0 signal=1 from fFactorTrain/fFactorTest
155  UInt_t size = fFactorTrain.size();
156  fFactorNumeric.resize(size);
157 
158  for (UInt_t i = 0; i < size; i++) {
159  if (fFactorTrain[i] == "signal") fFactorNumeric[i] = 1;
160  else fFactorNumeric[i] = 0;
161  }
162 }
163 
165 {
166  if (Data()->GetNTrainingEvents() == 0) Log() << kFATAL << "<Train> Data() has zero events" << Endl;
167  if (fNetType == "RMLP") {
168  ROOT::R::TRObject PruneFunc;
169  if (fPruneFunc == "NULL") PruneFunc = r.Eval("NULL");
170  else PruneFunc = r.Eval(Form("'%s'", fPruneFunc.Data()));
171 
172  SEXP Model = mlp(ROOT::R::Label["x"] = fDfTrain,
174  ROOT::R::Label["size"] = r.Eval(fSize),
175  ROOT::R::Label["maxit"] = fMaxit,
176  ROOT::R::Label["initFunc"] = fInitFunc,
177  ROOT::R::Label["initFuncParams"] = r.Eval(fInitFuncParams),
178  ROOT::R::Label["learnFunc"] = fLearnFunc,
179  ROOT::R::Label["learnFuncParams"] = r.Eval(fLearnFuncParams),
180  ROOT::R::Label["updateFunc"] = fUpdateFunc,
181  ROOT::R::Label["updateFuncParams"] = r.Eval(fUpdateFuncParams),
182  ROOT::R::Label["hiddenActFunc"] = fHiddenActFunc,
183  ROOT::R::Label["shufflePatterns"] = fShufflePatterns,
184  ROOT::R::Label["libOut"] = fLinOut,
185  ROOT::R::Label["pruneFunc"] = PruneFunc,
186  ROOT::R::Label["pruneFuncParams"] = r.Eval(fPruneFuncParams));
187  fModel = new ROOT::R::TRObject(Model);
188  //if model persistence is enabled saving it is R serialziation.
189  if (IsModelPersistence())
190  {
191  TString path = GetWeightFileDir() + "/RMLPModel.RData";
192  Log() << Endl;
193  Log() << gTools().Color("bold") << "--- Saving State File In:" << gTools().Color("reset") << path << Endl;
194  Log() << Endl;
195  r["RMLPModel"] << Model;
196  r << "save(RMLPModel,file='" + path + "')";
197  }
198  }
199 }
200 
201 //_______________________________________________________________________
203 {
204  //RSNNS Options for all NN methods
205 // TVectorF fSize;//number of units in the hidden layer(s)
206  DeclareOptionRef(fSize, "Size", "number of units in the hidden layer(s)");
207  DeclareOptionRef(fMaxit, "Maxit", "Maximum of iterations to learn");
208 
209  DeclareOptionRef(fInitFunc, "InitFunc", "the initialization function to use");
210  DeclareOptionRef(fInitFuncParams, "InitFuncParams", "the parameters for the initialization function");
211 
212  DeclareOptionRef(fLearnFunc, "LearnFunc", "the learning function to use");
213  DeclareOptionRef(fLearnFuncParams, "LearnFuncParams", "the parameters for the learning function");
214 
215  DeclareOptionRef(fUpdateFunc, "UpdateFunc", "the update function to use");
216  DeclareOptionRef(fUpdateFuncParams, "UpdateFuncParams", "the parameters for the update function");
217 
218  DeclareOptionRef(fHiddenActFunc, "HiddenActFunc", "the activation function of all hidden units");
219  DeclareOptionRef(fShufflePatterns, "ShufflePatterns", "should the patterns be shuffled?");
220  DeclareOptionRef(fLinOut, "LinOut", "sets the activation function of the output units to linear or logistic");
221 
222  DeclareOptionRef(fPruneFunc, "PruneFunc", "the prune function to use");
223  DeclareOptionRef(fPruneFuncParams, "PruneFuncParams", "the parameters for the pruning function. Unlike the\
224  other functions, these have to be given in a named list. See\
225  the pruning demos for further explanation.the update function to use");
226 
227 }
228 
229 //_______________________________________________________________________
231 {
232  if (fMaxit <= 0) {
233  Log() << kERROR << " fMaxit <=0... that does not work !! "
234  << " I set it to 50 .. just so that the program does not crash"
235  << Endl;
236  fMaxit = 1;
237  }
238  // standard constructor for the RSNNS
239  //RSNNS Options for all NN methods
240 
241 }
242 
243 //_______________________________________________________________________
245 {
246  Log() << kINFO << "Testing Classification " << fNetType << " METHOD " << Endl;
247 
249 }
250 
251 
252 //_______________________________________________________________________
254 {
255  NoErrorCalc(errLower, errUpper);
256  Double_t mvaValue;
257  const TMVA::Event *ev = GetEvent();
258  const UInt_t nvar = DataInfo().GetNVariables();
259  ROOT::R::TRDataFrame fDfEvent;
260  for (UInt_t i = 0; i < nvar; i++) {
261  fDfEvent[DataInfo().GetListOfVariables()[i].Data()] = ev->GetValues()[i];
262  }
263  //if using persistence model
265 
266  TVectorD result = predict(*fModel, fDfEvent, ROOT::R::Label["type"] = "prob");
267  mvaValue = result[0]; //returning signal prob
268  return mvaValue;
269 }
270 
271 ////////////////////////////////////////////////////////////////////////////////
272 /// get all the MVA values for the events of the current Data type
273 std::vector<Double_t> MethodRSNNS::GetMvaValues(Long64_t firstEvt, Long64_t lastEvt, Bool_t logProgress)
274 {
276  if (firstEvt > lastEvt || lastEvt > nEvents) lastEvt = nEvents;
277  if (firstEvt < 0) firstEvt = 0;
278 
279  nEvents = lastEvt-firstEvt;
280 
281  UInt_t nvars = Data()->GetNVariables();
282 
283  // use timer
284  Timer timer( nEvents, GetName(), kTRUE );
285  if (logProgress)
286  Log() << kINFO<<Form("Dataset[%s] : ",DataInfo().GetName())<< "Evaluation of " << GetMethodName() << " on "
287  << (Data()->GetCurrentType()==Types::kTraining?"training":"testing") << " sample (" << nEvents << " events)" << Endl;
288 
289 
290  // fill R DATA FRAME with events data
291  std::vector<std::vector<Float_t> > inputData(nvars);
292  for (UInt_t i = 0; i < nvars; i++) {
293  inputData[i] = std::vector<Float_t>(nEvents);
294  }
295 
296  for (Int_t ievt=firstEvt; ievt<lastEvt; ievt++) {
297  Data()->SetCurrentEvent(ievt);
298  const TMVA::Event *e = Data()->GetEvent();
299  assert(nvars == e->GetNVariables());
300  for (UInt_t i = 0; i < nvars; i++) {
301  inputData[i][ievt] = e->GetValue(i);
302  }
303  // if (ievt%100 == 0)
304  // std::cout << "Event " << ievt << " type" << DataInfo().IsSignal(e) << " : " << pValue[ievt*nvars] << " " << pValue[ievt*nvars+1] << " " << pValue[ievt*nvars+2] << std::endl;
305  }
306 
307  ROOT::R::TRDataFrame evtData;
308  for (UInt_t i = 0; i < nvars; i++) {
309  evtData[DataInfo().GetListOfVariables()[i].Data()] = inputData[i];
310  }
311  //if using persistence model
313 
314  std::vector<Double_t> mvaValues(nEvents);
315  ROOT::R::TRObject result = predict(*fModel, evtData, ROOT::R::Label["type"] = "prob");
316  //std::vector<Double_t> probValues(2*nEvents);
317  mvaValues = result.As<std::vector<Double_t>>();
318  // assert(probValues.size() == 2*mvaValues.size());
319  // std::copy(probValues.begin()+nEvents, probValues.end(), mvaValues.begin() );
320 
321  if (logProgress) {
322  Log() << kINFO <<Form("Dataset[%s] : ",DataInfo().GetName())<< "Elapsed time for evaluation of " << nEvents << " events: "
323  << timer.GetElapsedTime() << " " << Endl;
324  }
325 
326  return mvaValues;
327 
328 }
329 
330 
331 //_______________________________________________________________________
333 {
334  ROOT::R::TRInterface::Instance().Require("RSNNS");
335  TString path = GetWeightFileDir() + "/RMLPModel.RData";
336  Log() << Endl;
337  Log() << gTools().Color("bold") << "--- Loading State File From:" << gTools().Color("reset") << path << Endl;
338  Log() << Endl;
339  r << "load('" + path + "')";
340  SEXP Model;
341  r["RMLPModel"] >> Model;
342  fModel = new ROOT::R::TRObject(Model);
343 
344 }
345 
346 
347 //_______________________________________________________________________
349 {
350 // get help message text
351 //
352 // typical length of text line:
353 // "|--------------------------------------------------------------|"
354  Log() << Endl;
355  Log() << gTools().Color("bold") << "--- Short description:" << gTools().Color("reset") << Endl;
356  Log() << Endl;
357  Log() << "Decision Trees and Rule-Based Models " << Endl;
358  Log() << Endl;
359  Log() << gTools().Color("bold") << "--- Performance optimisation:" << gTools().Color("reset") << Endl;
360  Log() << Endl;
361  Log() << Endl;
362  Log() << gTools().Color("bold") << "--- Performance tuning via configuration options:" << gTools().Color("reset") << Endl;
363  Log() << Endl;
364  Log() << "<None>" << Endl;
365 }
366 
Bool_t fShufflePatterns
Definition: MethodRSNNS.h:104
UInt_t GetNVariables() const
Definition: DataSetInfo.h:128
virtual std::vector< Double_t > GetMvaValues(Long64_t firstEvt=0, Long64_t lastEvt=-1, Bool_t logProgress=false)
get all the MVA values for the events of the current Data type
void SetCurrentEvent(Long64_t ievt) const
Definition: DataSet.h:113
MsgLogger & Endl(MsgLogger &ml)
Definition: MsgLogger.h:162
long long Long64_t
Definition: RtypesCore.h:69
RooCmdArg Label(const char *str)
MsgLogger & Log() const
Definition: Configurable.h:128
std::vector< TString > GetListOfVariables() const
returns list of variables
OptionBase * DeclareOptionRef(T &ref, const TString &name, const TString &desc="")
ROOT::R::TRFunctionImport predict
Definition: MethodRSNNS.h:114
EAnalysisType
Definition: Types.h:129
UInt_t GetNVariables() const
access the number of variables through the datasetinfo
Definition: DataSet.cxx:225
void GetHelpMessage() const
std::vector< UInt_t > fFactorNumeric
Definition: MethodRSNNS.h:111
int Int_t
Definition: RtypesCore.h:41
bool Bool_t
Definition: RtypesCore.h:59
const Bool_t kFALSE
Definition: Rtypes.h:92
ROOT::R::TRObject * fModel
Definition: MethodRSNNS.h:117
TString GetElapsedTime(Bool_t Scientific=kTRUE)
Definition: Timer.cxx:129
Double_t GetMvaValue(Double_t *errLower=0, Double_t *errUpper=0)
virtual void TestClassification()
initialization
const TString & GetWeightFileDir() const
Definition: MethodBase.h:486
Tools & gTools()
Definition: Tools.cxx:79
TStopwatch timer
Definition: pirndm.C:37
const Event * GetEvent() const
Definition: MethodBase.h:745
DataSet * Data() const
Definition: MethodBase.h:405
Types::ETreeType GetCurrentType() const
Definition: DataSet.h:217
DataSetInfo & DataInfo() const
Definition: MethodBase.h:406
TString fUpdateFuncParams
Definition: MethodRSNNS.h:101
const int nEvents
Definition: testRooFit.cxx:42
std::vector< std::string > fFactorTrain
Definition: RMethodBase.h:96
const char * GetName() const
Definition: MethodBase.h:330
unsigned int UInt_t
Definition: RtypesCore.h:42
ROOT::R::TRInterface & r
Definition: RMethodBase.h:53
virtual void Error(const char *method, const char *msgfmt,...) const
Issue error message.
Definition: TObject.cxx:925
char * Form(const char *fmt,...)
ROOT::R::TRFunctionImport mlp
Definition: MethodRSNNS.h:115
const TString & GetMethodName() const
Definition: MethodBase.h:327
UInt_t GetNVariables() const
accessor to the number of variables
Definition: Event.cxx:305
Float_t GetValue(UInt_t ivar) const
return value of i'th variable
Definition: Event.cxx:233
TString fInitFuncParams
Definition: MethodRSNNS.h:95
ROOT::R::TRFunctionImport asfactor
Definition: MethodRSNNS.h:116
MethodRSNNS(const TString &jobName, const TString &methodTitle, DataSetInfo &theData, const TString &theOption="")
Definition: MethodRSNNS.cxx:51
#define ClassImp(name)
Definition: Rtypes.h:279
double Double_t
Definition: RtypesCore.h:55
TString fLearnFuncParams
Definition: MethodRSNNS.h:98
int type
Definition: TGX11.cxx:120
you should not use this method at all Int_t Int_t Double_t Double_t Double_t e
Definition: TRolke.cxx:630
TString fPruneFuncParams
Definition: MethodRSNNS.h:108
Bool_t HasAnalysisType(Types::EAnalysisType type, UInt_t numberClasses, UInt_t numberTargets)
const TString & Color(const TString &)
human readable color strings
Definition: Tools.cxx:837
#define REGISTER_METHOD(CLASS)
for example
Abstract ClassifierFactory template that handles arbitrary types.
std::vector< Float_t > & GetValues()
Definition: Event.h:96
ROOT::R::TRDataFrame fDfTrain
Definition: RMethodBase.h:92
#define NULL
Definition: Rtypes.h:82
Long64_t GetNEvents(Types::ETreeType type=Types::kMaxTreeType) const
Definition: DataSet.h:229
double result[121]
TString fLearnFunc
Definition: MethodRSNNS.h:97
TString fHiddenActFunc
Definition: MethodRSNNS.h:103
const Bool_t kTRUE
Definition: Rtypes.h:91
static Bool_t IsModuleLoaded
Definition: MethodRSNNS.h:113
virtual void TestClassification()
initialization
const Event * GetEvent() const
Definition: DataSet.cxx:211
void NoErrorCalc(Double_t *const err, Double_t *const errUpper)
Definition: MethodBase.cxx:819
Bool_t IsModelPersistence()
Definition: MethodBase.h:379