Logo ROOT  
Reference Guide
 
Loading...
Searching...
No Matches
MethodRSNNS.h
Go to the documentation of this file.
1// @(#)root/tmva/rmva $Id$
2// Author: Omar Zapata, Lorenzo Moneta, Sergei Gleyzer 2015
3
4/**********************************************************************************
5 * Project: TMVA - a Root-integrated toolkit for multivariate data analysis *
6 * Package: TMVA *
 7 * Class : MethodRSNNS *
8 * *
9 * Description: *
 10 * TMVA method based on R's RSNNS package, accessed through ROOTR *
11 * *
12 **********************************************************************************/
13
14#ifndef ROOT_TMVA_RMethodRSNNS
15#define ROOT_TMVA_RMethodRSNNS
16
17//////////////////////////////////////////////////////////////////////////
18// //
19// MethodRSNNS //
20// //
21// //
22//////////////////////////////////////////////////////////////////////////
23
24#include "TMVA/RMethodBase.h"
25#include <vector>
26
27namespace TMVA {
28
29 class Factory; // DSMTEST
30 class Reader; // DSMTEST
31 class DataSetManager; // DSMTEST
32 class Types;
33 class MethodRSNNS : public RMethodBase {
34
35 public :
36
37 // constructors
38 MethodRSNNS(const TString &jobName,
39 const TString &methodTitle,
40 DataSetInfo &theData,
41 const TString &theOption = "");
42
44 const TString &theWeightFile);
45
46
47 ~MethodRSNNS(void);
48 void Train();
49 // options treatment
50 void Init();
51 void DeclareOptions();
52 void ProcessOptions();
53 // create ranking
55 {
56 return NULL; // = 0;
57 }
58
59
60 Bool_t HasAnalysisType(Types::EAnalysisType type, UInt_t numberClasses, UInt_t numberTargets);
61
62 // performs classifier testing
63 virtual void TestClassification();
64
65
66 Double_t GetMvaValue(Double_t *errLower = 0, Double_t *errUpper = 0);
67
69 // the actual "weights"
70 virtual void AddWeightsXMLTo(void * /*parent*/) const {} // = 0;
71 virtual void ReadWeightsFromXML(void * /*wghtnode*/) {} // = 0;
72 virtual void ReadWeightsFromStream(std::istream &) {} //= 0; // backward compatibility
73
74 void ReadModelFromFile();
75
76 // signal/background classification response for all current set of data
77 virtual std::vector<Double_t> GetMvaValues(Long64_t firstEvt = 0, Long64_t lastEvt = -1, Bool_t logProgress = false);
78
79 private :
81 friend class Factory; // DSMTEST
82 friend class Reader; // DSMTEST
83 protected:
85 std::vector<Float_t> fProbResultForTrainSig;
86 std::vector<Float_t> fProbResultForTestSig;
87
88 TString fNetType;//default RMPL
89 //RSNNS Options for all NN methods
90 TString fSize;//number of units in the hidden layer(s)
91 UInt_t fMaxit;//maximum of iterations to learn
92
93 TString fInitFunc;//the initialization function to use
94 TString fInitFuncParams;//the parameters for the initialization function (type 6 see getSnnsRFunctionTable() in RSNNS package)
95
96 TString fLearnFunc;//the learning function to use
97 TString fLearnFuncParams;//the parameters for the learning function
98
99 TString fUpdateFunc;//the update function to use
100 TString fUpdateFuncParams;//the parameters for the update function
101
102 TString fHiddenActFunc;//the activation function of all hidden units
103 Bool_t fShufflePatterns;//should the patterns be shuffled?
104 Bool_t fLinOut;//sets the activation function of the output units to linear or logistic
105
106 TString fPruneFunc;//the pruning function to use
107 TString fPruneFuncParams;//the parameters for the pruning function. Unlike the
108 //other functions, these have to be given in a named list. See
109 //the pruning demos for further explanation.
110 std::vector<UInt_t> fFactorNumeric; //factors creations
111 //RSNNS mlp requires a numeric factor, so background=0 and signal=1 are taken from fFactorTrain
117 // get help message text
118 void GetHelpMessage() const;
119
121 };
122} // namespace TMVA
123#endif
unsigned int UInt_t
Definition RtypesCore.h:46
bool Bool_t
Definition RtypesCore.h:63
double Double_t
Definition RtypesCore.h:59
long long Long64_t
Definition RtypesCore.h:73
#define ClassDef(name, id)
Definition Rtypes.h:325
int type
Definition TGX11.cxx:121
This is a class to pass functions from ROOT to R.
This is a class to get ROOT's objects from R's objects.
Definition TRObject.h:70
Class that contains all the data information.
Definition DataSetInfo.h:62
Class that contains all the data information.
This is the main MVA steering class.
Definition Factory.h:80
virtual void ReadWeightsFromStream(std::istream &)=0
static Bool_t IsModuleLoaded
virtual void ReadWeightsFromStream(std::istream &)
Definition MethodRSNNS.h:72
virtual void AddWeightsXMLTo(void *) const
Definition MethodRSNNS.h:70
std::vector< Float_t > fProbResultForTrainSig
Definition MethodRSNNS.h:85
ROOT::R::TRFunctionImport asfactor
void GetHelpMessage() const
Double_t GetMvaValue(Double_t *errLower=0, Double_t *errUpper=0)
DataSetManager * fDataSetManager
Definition MethodRSNNS.h:80
const Ranking * CreateRanking()
Definition MethodRSNNS.h:54
Bool_t HasAnalysisType(Types::EAnalysisType type, UInt_t numberClasses, UInt_t numberTargets)
ROOT::R::TRFunctionImport predict
virtual std::vector< Double_t > GetMvaValues(Long64_t firstEvt=0, Long64_t lastEvt=-1, Bool_t logProgress=false)
get all the MVA values for the events of the current Data type
TString fUpdateFuncParams
TString fLearnFuncParams
Definition MethodRSNNS.h:97
virtual void TestClassification()
initialization
TString fPruneFuncParams
virtual void ReadWeightsFromXML(void *)
Definition MethodRSNNS.h:71
TString fInitFuncParams
Definition MethodRSNNS.h:94
std::vector< UInt_t > fFactorNumeric
ROOT::R::TRFunctionImport mlp
ROOT::R::TRObject * fModel
std::vector< Float_t > fProbResultForTestSig
Definition MethodRSNNS.h:86
Ranking for variables in method (implementation)
Definition Ranking.h:48
The Reader class serves to use the MVAs in a specific analysis context.
Definition Reader.h:64
Basic string class.
Definition TString.h:136
create variable transformations