MethodRSNNS.cxx
// @(#)root/tmva/rmva $Id$
// Author: Omar Zapata, Lorenzo Moneta, Sergei Gleyzer 2015


/**********************************************************************************
 * Project: TMVA - a Root-integrated toolkit for multivariate data analysis      *
 * Package: TMVA                                                                  *
 * Class  : MethodRSNNS                                                           *
 * Web    : http://oproject.org                                                   *
 *                                                                                *
 * Description:                                                                   *
 *      Neural Networks in R using the Stuttgart Neural Network Simulator         *
 *                                                                                *
 *                                                                                *
 * Redistribution and use in source and binary forms, with or without             *
 * modification, are permitted according to the terms listed in LICENSE           *
 * (http://tmva.sourceforge.net/LICENSE)                                          *
 *                                                                                *
 **********************************************************************************/

#include <iomanip>

#include "TMath.h"
#include "Riostream.h"
#include "TMatrix.h"
#include "TMatrixD.h"
#include "TVectorD.h"

#include "TMVA/MethodRSNNS.h"
#include "TMVA/Tools.h"
#include "TMVA/Config.h"
#include "TMVA/Ranking.h"
#include "TMVA/Types.h"
#include "TMVA/PDF.h"

#include "TMVA/Results.h"
#include "TMVA/Timer.h"

using namespace TMVA;

REGISTER_METHOD(RSNNS)

ClassImp(MethodRSNNS);

// creating an Instance (loads the RSNNS package into the embedded R interpreter)
Bool_t MethodRSNNS::IsModuleLoaded = ROOT::R::TRInterface::Instance().Require("RSNNS");
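// Usage sketch: this method is normally booked through the TMVA Factory with the
// method title "RMLP" (the constructors below require it). The option names follow
// DeclareOptions() further down; the data loader, output file and option values in
// this example are illustrative placeholders, not part of this file's API.
//
//    TMVA::Factory factory("TMVAClassification", outputFile,
//                          "!V:!Silent:AnalysisType=Classification");
//    factory.BookMethod(dataloader, TMVA::Types::kRSNNS, "RMLP",
//                       "Size=c(5):Maxit=200:InitFunc=Randomize_Weights:"
//                       "LearnFunc=Std_Backpropagation:LearnFuncParams=c(0.2,0)");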

//_______________________________________________________________________
MethodRSNNS::MethodRSNNS(const TString &jobName,
                         const TString &methodTitle,
                         DataSetInfo &dsi,
                         const TString &theOption) :
   RMethodBase(jobName, Types::kRSNNS, methodTitle, dsi, theOption),
   fMvaCounter(0),
   predict("predict"),
   mlp("mlp"),
   asfactor("as.factor"),
   fModel(NULL)
{
   fNetType = methodTitle;
   if (fNetType != "RMLP") {
      Log() << kFATAL << " Unknown Method = " + fNetType
            << Endl;
      return;
   }

   // standard constructor for the RSNNS
   // RSNNS options common to all NN methods
   fSize = "c(5)";
   fMaxit = 100;

   fInitFunc = "Randomize_Weights";
   fInitFuncParams = "c(-0.3,0.3)"; // the maximum number of parameters is 5, see RSNNS::getSnnsRFunctionTable() type 6

   fLearnFunc = "Std_Backpropagation";
   fLearnFuncParams = "c(0.2,0)";

   fUpdateFunc = "Topological_Order";
   fUpdateFuncParams = "c(0)";

   fHiddenActFunc = "Act_Logistic";
   fShufflePatterns = kTRUE;
   fLinOut = kFALSE;
   fPruneFunc = "NULL";
   fPruneFuncParams = "NULL";

}

//_______________________________________________________________________
MethodRSNNS::MethodRSNNS(DataSetInfo &theData, const TString &theWeightFile)
   : RMethodBase(Types::kRSNNS, theData, theWeightFile),
     fMvaCounter(0),
     predict("predict"),
     mlp("mlp"),
     asfactor("as.factor"),
     fModel(NULL)

{
   fNetType = "RMLP"; // GetMethodName() is not returning "RMLP" here but "MethodBase"; why?
   if (fNetType != "RMLP") {
      Log() << kFATAL << " Unknown Method = " + fNetType
            << Endl;
      return;
   }

   // standard constructor for the RSNNS
   // RSNNS options common to all NN methods
   fSize = "c(5)";
   fMaxit = 100;

   fInitFunc = "Randomize_Weights";
   fInitFuncParams = "c(-0.3,0.3)"; // the maximum number of parameters is 5, see RSNNS::getSnnsRFunctionTable() type 6

   fLearnFunc = "Std_Backpropagation";
   fLearnFuncParams = "c(0.2,0)";

   fUpdateFunc = "Topological_Order";
   fUpdateFuncParams = "c(0)";

   fHiddenActFunc = "Act_Logistic";
   fShufflePatterns = kTRUE;
   fLinOut = kFALSE;
   fPruneFunc = "NULL";
   fPruneFuncParams = "NULL";
}


//_______________________________________________________________________
MethodRSNNS::~MethodRSNNS(void)
{
   if (fModel) delete fModel;
}

//_______________________________________________________________________
Bool_t MethodRSNNS::HasAnalysisType(Types::EAnalysisType type, UInt_t numberClasses, UInt_t /*numberTargets*/)
{
   if (type == Types::kClassification && numberClasses == 2) return kTRUE;
   return kFALSE;
}


//_______________________________________________________________________
void MethodRSNNS::Init()
{
   if (!IsModuleLoaded) {
      Error("Init", "R's package RSNNS cannot be loaded.");
      Log() << kFATAL << " R's package RSNNS cannot be loaded."
            << Endl;
      return;
   }
   // factor creation:
   // RSNNS::mlp requires a numeric factor, so background=0 and signal=1 are built from fFactorTrain/fFactorTest
   UInt_t size = fFactorTrain.size();
   fFactorNumeric.resize(size);

   for (UInt_t i = 0; i < size; i++) {
      if (fFactorTrain[i] == "signal") fFactorNumeric[i] = 1;
      else fFactorNumeric[i] = 0;
   }
}
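// For illustration, with hypothetical class labels the conversion in Init() above
// behaves as follows (signal -> 1, anything else -> 0):
//
//    fFactorTrain   = {"signal", "background", "signal"};
//    fFactorNumeric = { 1,        0,            1      };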

void MethodRSNNS::Train()
{
   if (Data()->GetNTrainingEvents() == 0) Log() << kFATAL << "<Train> Data() has zero events" << Endl;
   if (fNetType == "RMLP") {
      ROOT::R::TRObject PruneFunc;
      if (fPruneFunc == "NULL") PruneFunc = r.Eval("NULL");
      else PruneFunc = r.Eval(Form("'%s'", fPruneFunc.Data()));

      SEXP Model = mlp(ROOT::R::Label["x"] = fDfTrain,
                       ROOT::R::Label["y"] = fFactorNumeric,
                       ROOT::R::Label["size"] = r.Eval(fSize),
                       ROOT::R::Label["maxit"] = fMaxit,
                       ROOT::R::Label["initFunc"] = fInitFunc,
                       ROOT::R::Label["initFuncParams"] = r.Eval(fInitFuncParams),
                       ROOT::R::Label["learnFunc"] = fLearnFunc,
                       ROOT::R::Label["learnFuncParams"] = r.Eval(fLearnFuncParams),
                       ROOT::R::Label["updateFunc"] = fUpdateFunc,
                       ROOT::R::Label["updateFuncParams"] = r.Eval(fUpdateFuncParams),
                       ROOT::R::Label["hiddenActFunc"] = fHiddenActFunc,
                       ROOT::R::Label["shufflePatterns"] = fShufflePatterns,
                       ROOT::R::Label["linOut"] = fLinOut,   // RSNNS::mlp expects "linOut", not "libOut"
                       ROOT::R::Label["pruneFunc"] = PruneFunc,
                       ROOT::R::Label["pruneFuncParams"] = r.Eval(fPruneFuncParams));
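      // Note: each ROOT::R::Label["..."] entry above maps onto the formal argument
      // of RSNNS::mlp with the same name. String members such as fSize or
      // fInitFuncParams hold R expressions (e.g. "c(5)", "c(-0.3,0.3)") and are
      // therefore passed through r.Eval() rather than as plain strings.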
      fModel = new ROOT::R::TRObject(Model);
      // if model persistence is enabled, save the model using R's serialization
      if (IsModelPersistence())
      {
         TString path = GetWeightFileDir() + "/" + GetName() + ".RData";
         Log() << Endl;
         Log() << gTools().Color("bold") << "--- Saving State File In:" << gTools().Color("reset") << path << Endl;
         Log() << Endl;
         r["RMLPModel"] << Model;
         r << "save(RMLPModel,file='" + path + "')";
      }
   }
}

//_______________________________________________________________________
void MethodRSNNS::DeclareOptions()
{
   // RSNNS options common to all NN methods
// TVectorF fSize; // number of units in the hidden layer(s)
   DeclareOptionRef(fSize, "Size", "number of units in the hidden layer(s)");
   DeclareOptionRef(fMaxit, "Maxit", "maximum number of iterations to learn");

   DeclareOptionRef(fInitFunc, "InitFunc", "the initialization function to use");
   DeclareOptionRef(fInitFuncParams, "InitFuncParams", "the parameters for the initialization function");

   DeclareOptionRef(fLearnFunc, "LearnFunc", "the learning function to use");
   DeclareOptionRef(fLearnFuncParams, "LearnFuncParams", "the parameters for the learning function");

   DeclareOptionRef(fUpdateFunc, "UpdateFunc", "the update function to use");
   DeclareOptionRef(fUpdateFuncParams, "UpdateFuncParams", "the parameters for the update function");

   DeclareOptionRef(fHiddenActFunc, "HiddenActFunc", "the activation function of all hidden units");
   DeclareOptionRef(fShufflePatterns, "ShufflePatterns", "should the patterns be shuffled?");
   DeclareOptionRef(fLinOut, "LinOut", "sets the activation function of the output units to linear or logistic");

   DeclareOptionRef(fPruneFunc, "PruneFunc", "the prune function to use");
   DeclareOptionRef(fPruneFuncParams, "PruneFuncParams", "the parameters for the pruning function; unlike the"
                    " other functions, these have to be given in a named list (see"
                    " the pruning demos for further explanation)");

}

//_______________________________________________________________________
void MethodRSNNS::ProcessOptions()
{
   if (fMaxit <= 0) {
      Log() << kERROR << " fMaxit <= 0... that does not work!"
            << " It is reset to 1 so that the program does not crash"
            << Endl;
      fMaxit = 1;
   }
   // RSNNS options common to all NN methods

}

//_______________________________________________________________________
void MethodRSNNS::TestClassification()
{
   Log() << kINFO << "Testing Classification " << fNetType << " METHOD " << Endl;

   MethodBase::TestClassification();
}


//_______________________________________________________________________
Double_t MethodRSNNS::GetMvaValue(Double_t *errLower, Double_t *errUpper)
{
   NoErrorCalc(errLower, errUpper);
   Double_t mvaValue;
   const TMVA::Event *ev = GetEvent();
   const UInt_t nvar = DataInfo().GetNVariables();
   ROOT::R::TRDataFrame fDfEvent;
   for (UInt_t i = 0; i < nvar; i++) {
      fDfEvent[DataInfo().GetListOfVariables()[i].Data()] = ev->GetValues()[i];
   }
   // if the model was persistified, load it back from file first
   if (IsModelPersistence()) ReadModelFromFile();

   TVectorD result = predict(*fModel, fDfEvent, ROOT::R::Label["type"] = "prob");
   mvaValue = result[0]; // return the signal probability
   return mvaValue;
}

////////////////////////////////////////////////////////////////////////////////
/// get all the MVA values for the events of the current Data type
std::vector<Double_t> MethodRSNNS::GetMvaValues(Long64_t firstEvt, Long64_t lastEvt, Bool_t logProgress)
{
   Long64_t nEvents = Data()->GetNEvents();
   if (firstEvt > lastEvt || lastEvt > nEvents) lastEvt = nEvents;
   if (firstEvt < 0) firstEvt = 0;

   nEvents = lastEvt - firstEvt;

   UInt_t nvars = Data()->GetNVariables();

   // use timer
   Timer timer(nEvents, GetName(), kTRUE);
   if (logProgress)
      Log() << kINFO << Form("Dataset[%s] : ", DataInfo().GetName()) << "Evaluation of " << GetMethodName() << " on "
            << (Data()->GetCurrentType() == Types::kTraining ? "training" : "testing") << " sample (" << nEvents << " events)" << Endl;


   // fill an R data frame with the event data
   std::vector<std::vector<Float_t> > inputData(nvars);
   for (UInt_t i = 0; i < nvars; i++) {
      inputData[i] = std::vector<Float_t>(nEvents);
   }

   for (Int_t ievt = firstEvt; ievt < lastEvt; ievt++) {
      Data()->SetCurrentEvent(ievt);
      const TMVA::Event *e = Data()->GetEvent();
      assert(nvars == e->GetNVariables());
      for (UInt_t i = 0; i < nvars; i++) {
         inputData[i][ievt - firstEvt] = e->GetValue(i);   // index relative to firstEvt; inputData holds only nEvents entries
      }
      // if (ievt%100 == 0)
      //    std::cout << "Event " << ievt << " type" << DataInfo().IsSignal(e) << " : " << pValue[ievt*nvars] << " " << pValue[ievt*nvars+1] << " " << pValue[ievt*nvars+2] << std::endl;
   }

   ROOT::R::TRDataFrame evtData;
   for (UInt_t i = 0; i < nvars; i++) {
      evtData[DataInfo().GetListOfVariables()[i].Data()] = inputData[i];
   }
   // if the model was persistified, load it back from file first
   if (IsModelPersistence()) ReadModelFromFile();

   std::vector<Double_t> mvaValues(nEvents);
   ROOT::R::TRObject result = predict(*fModel, evtData, ROOT::R::Label["type"] = "prob");
   //std::vector<Double_t> probValues(2*nEvents);
   mvaValues = result.As<std::vector<Double_t>>();
   // assert(probValues.size() == 2*mvaValues.size());
   // std::copy(probValues.begin()+nEvents, probValues.end(), mvaValues.begin() );
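   // Note: the network is trained with a single numeric target (the signal/background
   // factor built in Init()), so the vector read back via As<std::vector<Double_t>>()
   // holds one value per event, which is used directly as the signal probability
   // (consistent with GetMvaValue above).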

   if (logProgress) {
      Log() << kINFO << Form("Dataset[%s] : ", DataInfo().GetName()) << "Elapsed time for evaluation of " << nEvents << " events: "
            << timer.GetElapsedTime() << " " << Endl;
   }

   return mvaValues;

}


//_______________________________________________________________________
void MethodRSNNS::ReadModelFromFile()
{
   TString path = GetWeightFileDir() + "/" + GetName() + ".RData";
   Log() << Endl;
   Log() << gTools().Color("bold") << "--- Loading State File From:" << gTools().Color("reset") << path << Endl;
   Log() << Endl;
   r << "load('" + path + "')";
   SEXP Model;
   r["RMLPModel"] >> Model;
   fModel = new ROOT::R::TRObject(Model);

}


//_______________________________________________________________________
void MethodRSNNS::GetHelpMessage() const
{
// get help message text
//
// typical length of text line:
// "|--------------------------------------------------------------|"
   Log() << Endl;
   Log() << gTools().Color("bold") << "--- Short description:" << gTools().Color("reset") << Endl;
   Log() << Endl;
   Log() << "Multilayer perceptrons trained in R with the Stuttgart Neural Network Simulator (RSNNS) " << Endl;
   Log() << Endl;
   Log() << gTools().Color("bold") << "--- Performance optimisation:" << gTools().Color("reset") << Endl;
   Log() << Endl;
   Log() << Endl;
   Log() << gTools().Color("bold") << "--- Performance tuning via configuration options:" << gTools().Color("reset") << Endl;
   Log() << Endl;
   Log() << "<None>" << Endl;
}