Logo ROOT  
Reference Guide
 
Loading...
Searching...
No Matches
ActivationFunctions.hxx
Go to the documentation of this file.
1// @(#)root/tmva/tmva/dnn:$Id$
2// Author: Simon Pfreundschuh 10/07/16
3
4/*************************************************************************
5 * Copyright (C) 2016, Simon Pfreundschuh *
6 * All rights reserved. *
7 * *
8 * For the licensing terms see $ROOTSYS/LICENSE. *
9 * For the list of contributors see $ROOTSYS/README/CREDITS. *
10 *************************************************************************/
11
12 //////////////////////////////////////////////////////////////////
13 // Implementation of the activation functions for the reference //
14 // implementation. //
15 //////////////////////////////////////////////////////////////////
16
18#include <math.h>
19
20namespace TMVA
21{
22namespace DNN
23{
24
25//______________________________________________________________________________
// Overwrites every element of B with 1.0.
// NOTE(review): the line holding the return type, function name and first
// parameter is absent from this listing; matching the body against the member
// index, this is presumably IdentityDerivative(B, A) -- confirm in the header.
// The second argument is unnamed (commented out as /*A*/) and is never read.
26template<typename Real_t>
28 const TMatrixT<Real_t> &/*A*/)
29{
30 size_t m,n;
31 m = B.GetNrows(); // loop bounds come from B itself, the only matrix used
32 n = B.GetNcols();
33
34 for (size_t i = 0; i < m; i++) {
35 for (size_t j = 0; j < n; j++) {
36 B(i,j) = 1.0;
37 }
38 }
39}
40
41//______________________________________________________________________________
// Replaces each element of A, in place, with max(0, A(i,j)); negative entries
// become zero and all other entries are left unchanged.
// NOTE(review): the signature line is absent from this listing; matching the
// body against the member index, this is presumably Relu -- confirm.
42template<typename Real_t>
44{
45 size_t m,n;
46 m = A.GetNrows();
47 n = A.GetNcols();
48
49 for (size_t i = 0; i < m; i++) {
50 for (size_t j = 0; j < n; j++) {
51 A(i,j) = std::max((Real_t) 0.0, A(i,j)); // cast so both std::max arguments share one type
52 }
53 }
54}
55
56//______________________________________________________________________________
// Writes into B an element-wise step function of A: 0.0 where A(i,j) is
// negative, 1.0 otherwise (so an entry of exactly zero yields 1.0).
// A is only read; B is only written.
// NOTE(review): the signature line is absent from this listing; matching the
// body against the member index, this is presumably ReluDerivative(B, A) --
// confirm.
57template<typename Real_t>
59 const TMatrixT<Real_t> & A)
60{
61 size_t m,n;
62 m = A.GetNrows(); // bounds are taken from A; presumably B has the same shape -- confirm
63 n = A.GetNcols();
64
65 for (size_t i = 0; i < m; i++)
66 {
67 for (size_t j = 0; j < n; j++)
68 {
69 B(i,j) = (A(i,j) < 0) ? 0.0 : 1.0;
70 }
71 }
72}
73
74//______________________________________________________________________________
// Applies the logistic function 1 / (1 + exp(-x)) to every element of B,
// in place.
// NOTE(review): the signature line is absent from this listing; matching the
// body against the member index, this is presumably Sigmoid(B) -- confirm.
75template<typename Real_t>
77{
78 size_t m,n;
79 m = B.GetNrows();
80 n = B.GetNcols();
81
82 for (size_t i = 0; i < m; i++) {
83 for (size_t j = 0; j < n; j++) {
84 Real_t sig = 1.0 / (1.0 + std::exp(-B(i,j)));
85 B(i,j) = sig;
86 }
87 }
88}
89
90//______________________________________________________________________________
// For each element x of A, computes sig = 1 / (1 + exp(-x)) and stores
// sig * (1 - sig) in the matching element of B. A is only read.
// NOTE(review): the signature line is absent from this listing; matching the
// body against the member index, this is presumably SigmoidDerivative(B, A) --
// confirm.
91template<typename Real_t>
93 const TMatrixT<Real_t> & A)
94{
95 size_t m,n;
96 m = A.GetNrows(); // bounds are taken from A; presumably B has the same shape -- confirm
97 n = A.GetNcols();
98
99 for (size_t i = 0; i < m; i++) {
100 for (size_t j = 0; j < n; j++) {
101 Real_t sig = 1.0 / (1.0 + std::exp(-A(i,j)));
102 B(i,j) = sig * (1.0 - sig);
103 }
104 }
105}
106
107//______________________________________________________________________________
// Replaces every element of B, in place, with its hyperbolic tangent.
// NOTE(review): the signature line is absent from this listing; matching the
// body against the member index, this is presumably Tanh(B) -- confirm.
108template<typename Real_t>
110{
111 size_t m,n;
112 m = B.GetNrows();
113 n = B.GetNcols();
114
115 for (size_t i = 0; i < m; i++) {
116 for (size_t j = 0; j < n; j++) {
117 Real_t t = tanh(B(i,j));
118 B(i,j) = t;
119 }
120 }
121}
122
123//______________________________________________________________________________
// For each element x of A, computes t = tanh(x) and stores 1 - t * t in the
// matching element of B. A is only read.
// NOTE(review): the signature line is absent from this listing; matching the
// body against the member index, this is presumably TanhDerivative(B, A) --
// confirm.
124template<typename Real_t>
126 const TMatrixT<Real_t> & A)
127{
128 size_t m,n;
129 m = A.GetNrows(); // bounds are taken from A; presumably B has the same shape -- confirm
130 n = A.GetNcols();
131
132 for (size_t i = 0; i < m; i++) {
133 for (size_t j = 0; j < n; j++) {
134 Real_t t = tanh(A(i,j));
135 B(i,j) = 1 - t * t;
136 }
137 }
138}
139
140//______________________________________________________________________________
// Replaces every element of B, in place, with its absolute value.
// NOTE(review): the signature line is absent from this listing; matching the
// body against the member index, this is presumably SymmetricRelu(B) --
// confirm.
141template<typename Real_t>
143{
144 size_t m,n;
145 m = B.GetNrows();
146 n = B.GetNcols();
147
148 for (size_t i = 0; i < m; i++) {
149 for (size_t j = 0; j < n; j++) {
150 B(i,j) = fabs(B(i,j));
151 }
152 }
153}
154
155//______________________________________________________________________________
// Writes into B the sign pattern of A: -1.0 where A(i,j) is negative, 1.0
// otherwise (so an entry of exactly zero yields 1.0). A is only read.
// NOTE(review): the signature line is absent from this listing; matching the
// body against the member index, this is presumably
// SymmetricReluDerivative(B, A) -- confirm.
156template<typename Real_t>
158 const TMatrixT<Real_t> & A)
159{
160 size_t m,n;
161 m = A.GetNrows(); // bounds are taken from A; presumably B has the same shape -- confirm
162 n = A.GetNcols();
163
164 for (size_t i = 0; i < m; i++) {
165 for (size_t j = 0; j < n; j++) {
166 B(i,j) = (A(i,j) < 0.0) ? -1.0 : 1.0;
167 }
168 }
169}
170
171//______________________________________________________________________________
// Replaces every element x of A, in place, with x / (1 + |x|).
// NOTE(review): the signature line is absent from this listing; matching the
// body against the member index, this is presumably SoftSign -- confirm.
172template<typename Real_t>
174{
175 size_t m,n;
176 m = A.GetNrows();
177 n = A.GetNcols();
178
179 for (size_t i = 0; i < m; i++) {
180 for (size_t j = 0; j < n; j++) {
181 Real_t x = A(i,j);
182 A(i,j) = x / (1 + fabs(x)); // denominator is always >= 1, so no division by zero
183 }
184 }
185}
186
187//______________________________________________________________________________
// For each element a of A, computes x = 1 + |a| and stores 1 / (x * x) in the
// matching element of B. A is only read.
// NOTE(review): the signature line is absent from this listing; matching the
// body against the member index, this is presumably SoftSignDerivative(B, A)
// -- confirm.
188template<typename Real_t>
190 const TMatrixT<Real_t> & A)
191{
192 size_t m,n;
193 m = A.GetNrows(); // bounds are taken from A; presumably B has the same shape -- confirm
194 n = A.GetNcols();
195
196 for (size_t i = 0; i < m; i++) {
197 for (size_t j = 0; j < n; j++) {
198 Real_t x = 1.0 + fabs(A(i,j)); // note: x here is 1 + |a|, not the raw element
199 B(i,j) = 1.0 / (x * x);
200 }
201 }
202}
203
204//______________________________________________________________________________
// Replaces every element x of A, in place, with exp(-x * x).
// NOTE(review): the signature line is absent from this listing; matching the
// body against the member index, this is presumably Gauss -- confirm.
205template<typename Real_t>
207{
208 size_t m,n;
209 m = A.GetNrows();
210 n = A.GetNcols();
211
212 for (size_t i = 0; i < m; i++) {
213 for (size_t j = 0; j < n; j++) {
214 Real_t x = A(i,j);
215 A(i,j) = exp(- x * x);
216 }
217 }
218}
219
220//______________________________________________________________________________
// For each element x of A, stores -2 * x * exp(-x * x) in the matching
// element of B. A is only read.
// NOTE(review): the signature line is absent from this listing; matching the
// body against the member index, this is presumably GaussDerivative(B, A) --
// confirm.
221template<typename Real_t>
223 const TMatrixT<Real_t> & A)
224{
225 size_t m,n;
226 m = A.GetNrows(); // bounds are taken from A; presumably B has the same shape -- confirm
227 n = A.GetNcols();
228
229 for (size_t i = 0; i < m; i++) {
230 for (size_t j = 0; j < n; j++) {
231 Real_t x = A(i,j);
232 B(i,j) = - 2.0 * x * exp(- x * x);
233 }
234 }
235}
236} // namespace DNN
237} // namespace TMVA
float Real_t
Definition RtypesCore.h:68
static void SymmetricRelu(TMatrixT< AReal > &B)
static void Relu(TMatrixT< AReal > &B)
static void GaussDerivative(TMatrixT< AReal > &B, const TMatrixT< AReal > &A)
static void TanhDerivative(TMatrixT< AReal > &B, const TMatrixT< AReal > &A)
static void IdentityDerivative(TMatrixT< AReal > &B, const TMatrixT< AReal > &A)
static void Gauss(TMatrixT< AReal > &B)
static void SigmoidDerivative(TMatrixT< AReal > &B, const TMatrixT< AReal > &A)
static void SoftSignDerivative(TMatrixT< AReal > &B, const TMatrixT< AReal > &A)
static void Tanh(TMatrixT< AReal > &B)
static void SoftSign(TMatrixT< AReal > &B)
static void Sigmoid(TMatrixT< AReal > &B)
static void SymmetricReluDerivative(TMatrixT< AReal > &B, const TMatrixT< AReal > &A)
static void ReluDerivative(TMatrixT< AReal > &B, const TMatrixT< AReal > &A)
Int_t GetNrows() const
Int_t GetNcols() const
TMatrixT.
Definition TMatrixT.h:39
Double_t x[n]
Definition legend1.C:17
const Int_t n
Definition legend1.C:16
create variable transformations
TMarker m
Definition textangle.C:8