Logo ROOT  
Reference Guide
 
Loading...
Searching...
No Matches
MethodDNN.h
Go to the documentation of this file.
1// @(#)root/tmva $Id$
2// Author: Peter Speckmayer
3
4/**********************************************************************************
5 * Project: TMVA - a Root-integrated toolkit for multivariate data analysis *
6 * Package: TMVA *
7 * Class : MethodDNN *
8 * Web : http://tmva.sourceforge.net *
9 * *
10 * Description: *
11 * NeuralNetwork *
12 * *
13 * Authors (alphabetical): *
14 * Peter Speckmayer <peter.speckmayer@gmx.at> - CERN, Switzerland *
15 * Simon Pfreundschuh <s.pfreundschuh@gmail.com> - CERN, Switzerland *
16 * *
17 * Copyright (c) 2005-2015: *
18 * CERN, Switzerland *
19 * U. of Victoria, Canada *
20 * MPI-K Heidelberg, Germany *
21 * U. of Bonn, Germany *
22 * *
23 * Redistribution and use in source and binary forms, with or without *
24 * modification, are permitted according to the terms listed in LICENSE *
25 * (http://tmva.sourceforge.net/LICENSE) *
26 **********************************************************************************/
27
28//#pragma once
29
30#ifndef ROOT_TMVA_MethodDNN
31#define ROOT_TMVA_MethodDNN
32
33//////////////////////////////////////////////////////////////////////////
34// //
35// MethodDNN //
36// //
37// Neural Network implementation //
38// //
39//////////////////////////////////////////////////////////////////////////
40
41#include <vector>
42#include <map>
43#include <string>
44#include <sstream>
45
46#include "TString.h"
47#include "TTree.h"
48#include "TRandom3.h"
49#include "TH1F.h"
50#include "TMVA/MethodBase.h"
51#include "TMVA/NeuralNet.h"
52
53#include "TMVA/Tools.h"
54
55#include "TMVA/DNN/Net.h"
56#include "TMVA/DNN/Minimizers.h"
58
59#ifdef R__HAS_TMVACPU
60#define DNNCPU
61#endif
62#ifdef R__HAS_TMVAGPU
63//#define DNNCUDA
64#endif
65
66#ifdef DNNCPU
68#endif
69
70#ifdef DNNCUDA
72#endif
73
74namespace TMVA {
75
76class MethodDNN : public MethodBase
77{
79
84
85private:
86 using LayoutVector_t = std::vector<std::pair<int, DNN::EActivationFunction>>;
87 using KeyValueVector_t = std::vector<std::map<TString, TString>>;
88
90 {
91 size_t batchSize;
98 std::vector<Double_t> dropoutProbabilities;
100 };
101
102 // the option handling methods
104 void ProcessOptions();
105
107
108 // general helper functions
109 void Init();
110
114
122 std::vector<TTrainingSettings> fTrainingSettings;
124
126
127 ClassDef(MethodDNN,0); // neural network
128
129 static inline void WriteMatrixXML(void *parent, const char *name,
130 const TMatrixT<Double_t> &X);
131 static inline void ReadMatrixXML(void *xml, const char *name,
133protected:
134
135 void MakeClassSpecific( std::ostream&, const TString& ) const;
136 void GetHelpMessage() const;
137
138public:
139
140 // Standard Constructors
141 MethodDNN(const TString& jobName,
142 const TString& methodTitle,
143 DataSetInfo& theData,
144 const TString& theOption);
146 const TString& theWeightFile);
147 virtual ~MethodDNN();
148
150 UInt_t numberClasses,
151 UInt_t numberTargets );
154 TString blockDelim,
155 TString tokenDelim);
156 void Train();
157 void TrainGpu();
158 void TrainCpu();
159
160 virtual Double_t GetMvaValue( Double_t* err=0, Double_t* errUpper=0 );
161 virtual const std::vector<Float_t>& GetRegressionValues();
162 virtual const std::vector<Float_t>& GetMulticlassValues();
163
165
166 // write weights to stream
167 void AddWeightsXMLTo ( void* parent ) const;
168
169 // read weights from stream
170 void ReadWeightsFromStream( std::istream & i );
171 void ReadWeightsFromXML ( void* wghtnode );
172
173 // ranking of input variables
174 const Ranking* CreateRanking();
175
176};
177
178inline void MethodDNN::WriteMatrixXML(void *parent,
179 const char *name,
180 const TMatrixT<Double_t> &X)
181{
182 std::stringstream matrixStringStream("");
183 matrixStringStream.precision( 16 );
184
185 for (size_t i = 0; i < (size_t) X.GetNrows(); i++)
186 {
187 for (size_t j = 0; j < (size_t) X.GetNcols(); j++)
188 {
189 matrixStringStream << std::scientific << X(i,j) << " ";
190 }
191 }
192 std::string s = matrixStringStream.str();
193 void* matxml = gTools().xmlengine().NewChild(parent, 0, name);
194 gTools().xmlengine().NewAttr(matxml, 0, "rows",
195 gTools().StringFromInt((int)X.GetNrows()));
196 gTools().xmlengine().NewAttr(matxml, 0, "cols",
197 gTools().StringFromInt((int)X.GetNcols()));
198 gTools().xmlengine().AddRawLine (matxml, s.c_str());
199}
200
201inline void MethodDNN::ReadMatrixXML(void *xml,
202 const char *name,
204{
205 void *matrixXML = gTools().GetChild(xml, name);
206 size_t rows, cols;
207 gTools().ReadAttr(matrixXML, "rows", rows);
208 gTools().ReadAttr(matrixXML, "cols", cols);
209
210 const char * matrixString = gTools().xmlengine().GetNodeContent(matrixXML);
211 std::stringstream matrixStringStream(matrixString);
212
213 for (size_t i = 0; i < rows; i++)
214 {
215 for (size_t j = 0; j < cols; j++)
216 {
217 matrixStringStream >> X(i,j);
218 }
219 }
220}
221} // namespace TMVA
222
223#endif
double Double_t
Definition RtypesCore.h:59
#define ClassDef(name, id)
Definition Rtypes.h:325
char name[80]
Definition TGX11.cxx:110
int type
Definition TGX11.cxx:121
Generic neural network class.
Definition Net.h:49
The reference architecture class.
Definition Reference.h:53
TMatrixT< AReal > Matrix_t
Definition Reference.h:58
Class that contains all the data information.
Definition DataSetInfo.h:62
Virtual base class for all MVA methods.
Definition MethodBase.h:111
virtual void ReadWeightsFromStream(std::istream &)=0
Deep Neural Network Implementation.
Definition MethodDNN.h:77
TString fLayoutString
Definition MethodDNN.h:115
virtual Bool_t HasAnalysisType(Types::EAnalysisType type, UInt_t numberClasses, UInt_t numberTargets)
typename Architecture_t::Scalar_t Scalar_t
Definition MethodDNN.h:83
virtual const std::vector< Float_t > & GetMulticlassValues()
UInt_t GetNumValidationSamples()
void ReadWeightsFromXML(void *wghtnode)
std::vector< std::map< TString, TString > > KeyValueVector_t
Definition MethodDNN.h:87
typename Architecture_t::Matrix_t Matrix_t
Definition MethodDNN.h:82
TString fTrainingStrategyString
Definition MethodDNN.h:117
KeyValueVector_t fSettings
Definition MethodDNN.h:125
void ReadWeightsFromStream(std::istream &i)
LayoutVector_t ParseLayoutString(TString layerSpec)
static void WriteMatrixXML(void *parent, const char *name, const TMatrixT< Double_t > &X)
Definition MethodDNN.h:178
MethodDNN(DataSetInfo &theData, const TString &theWeightFile)
void MakeClassSpecific(std::ostream &, const TString &) const
MethodDNN(const TString &jobName, const TString &methodTitle, DataSetInfo &theData, const TString &theOption)
void ProcessOptions()
virtual ~MethodDNN()
LayoutVector_t fLayout
Definition MethodDNN.h:121
TString fValidationSize
Definition MethodDNN.h:120
TString fWeightInitializationString
Definition MethodDNN.h:118
std::vector< std::pair< int, DNN::EActivationFunction > > LayoutVector_t
Definition MethodDNN.h:86
DNN::EInitialization fWeightInitialization
Definition MethodDNN.h:112
friend struct TestMethodDNNValidationSize
Definition MethodDNN.h:78
TString fErrorStrategy
Definition MethodDNN.h:116
std::vector< TTrainingSettings > fTrainingSettings
Definition MethodDNN.h:122
void DeclareOptions()
TString fArchitectureString
Definition MethodDNN.h:119
virtual Double_t GetMvaValue(Double_t *err=0, Double_t *errUpper=0)
const Ranking * CreateRanking()
KeyValueVector_t ParseKeyValueString(TString parseString, TString blockDelim, TString tokenDelim)
DNN::EOutputFunction fOutputFunction
Definition MethodDNN.h:113
void AddWeightsXMLTo(void *parent) const
void GetHelpMessage() const
static void ReadMatrixXML(void *xml, const char *name, TMatrixT< Double_t > &X)
Definition MethodDNN.h:201
virtual const std::vector< Float_t > & GetRegressionValues()
Ranking for variables in method (implementation)
Definition Ranking.h:48
void * GetChild(void *parent, const char *childname=0)
get child node
Definition Tools.cxx:1150
TXMLEngine & xmlengine()
Definition Tools.h:262
void ReadAttr(void *node, const char *, T &value)
read attribute from xml
Definition Tools.h:329
Int_t GetNrows() const
Int_t GetNcols() const
TMatrixT.
Definition TMatrixT.h:39
Basic string class.
Definition TString.h:136
Bool_t AddRawLine(XMLNodePointer_t parent, const char *line)
Add a raw line to the XML file; the line must have correct XML syntax so that it can later be decoded by the XML parser.
XMLNodePointer_t NewChild(XMLNodePointer_t parent, XMLNsPointer_t ns, const char *name, const char *content=nullptr)
create new child element for parent node
XMLAttrPointer_t NewAttr(XMLNodePointer_t xmlnode, XMLNsPointer_t, const char *name, const char *value)
creates new attribute for xmlnode, namespaces are not supported for attributes
const char * GetNodeContent(XMLNodePointer_t xmlnode)
get contents (if any) of xmlnode
EOutputFunction
Enum that represents output functions.
Definition Functions.h:46
ERegularization
Enum representing the regularization type applied for a given layer.
Definition Functions.h:65
create variable transformations
Tools & gTools()
DNN::ERegularization regularization
Definition MethodDNN.h:94
std::vector< Double_t > dropoutProbabilities
Definition MethodDNN.h:98