Logo ROOT   6.07/09
Reference Guide
TestActivationFunctions.cxx
Go to the documentation of this file.
1 // @(#)root/tmva $Id$
2 // Author: Simon Pfreundschuh
3 
4 /*************************************************************************
5  * Copyright (C) 2016, Simon Pfreundschuh *
6  * All rights reserved. *
7  * *
8  * For the licensing terms see $ROOTSYS/LICENSE. *
9  * For the list of contributors see $ROOTSYS/README/CREDITS. *
10  *************************************************************************/
11 
12 //////////////////////////////////////////////////////////////////////
13 // Concrete instantiation of the generic activation function test //
14 // for the reference architecture. //
15 //////////////////////////////////////////////////////////////////////
16 
17 #include <iostream>
19 
20 using namespace TMVA::DNN;
21 
22 int main()
23 {
24  using Scalar_t = Double_t;
25  std::cout << "Testing Activation Functions:" << std::endl;
26 
27  Scalar_t error;
28 
29  // Identity.
30 
31  error = testIdentity<TReference<Scalar_t>>(10);
32  std::cout << "Testing identity activation: ";
33  std::cout << "maximum relative error = " << print_error(error) << std::endl;
34  if (error > 1e-10)
35  return 1;
36 
37  error = testIdentityDerivative<TReference<Scalar_t>>(10);
38  std::cout << "Testing identity activation derivative: ";
39  std::cout << "maximum relative error = " << print_error(error) << std::endl;
40  if (error > 1e-10)
41  return 1;
42 
43  // ReLU.
44 
45  error = testRelu<TReference<Scalar_t>>(10);
46  std::cout << "Testing ReLU activation: ";
47  std::cout << "maximum relative error = " << print_error(error) << std::endl;
48  if (error > 1e-10)
49  return 1;
50 
51  error = testReluDerivative<TReference<Scalar_t>>(10);
52  std::cout << "Testing ReLU activation derivative: ";
53  std::cout << "maximum relative error = " << print_error(error) << std::endl;
54  if (error > 1e-10)
55  return 1;
56 
57  // Sigmoid.
58 
59  error = testSigmoid<TReference<Scalar_t>>(10);
60  std::cout << "Testing Sigmoid activation: ";
61  std::cout << "maximum relative error = " << print_error(error) << std::endl;
62  if (error > 1e-10)
63  return 1;
64 
65  error = testSigmoidDerivative<TReference<Scalar_t>>(10);
66  std::cout << "Testing Sigmoid activation derivative: ";
67  std::cout << "maximum relative error = " << print_error(error) << std::endl;
68  if (error > 1e-10)
69  return 1;
70 
71  // TanH.
72 
73  error = testTanh<TReference<Scalar_t>>(10);
74  std::cout << "Testing TanH activation: ";
75  std::cout << "maximum relative error = " << print_error(error) << std::endl;
76  if (error > 1e-10)
77  return 1;
78 
79  error = testTanhDerivative<TReference<Scalar_t>>(10);
80  std::cout << "Testing TanH activation derivative: ";
81  std::cout << "maximum relative error = " << print_error(error) << std::endl;
82  if (error > 1e-10)
83  return 1;
84 
85  // Symmetric ReLU.
86 
87  error = testSymmetricReluDerivative<TReference<Scalar_t>>(10);
88  std::cout << "Testing Symm. ReLU activation: ";
89  std::cout << "maximum relative error = " << print_error(error) << std::endl;
90  if (error > 1e-10)
91  return 1;
92 
93  error = testSymmetricReluDerivative<TReference<Scalar_t>>(10);
94  std::cout << "Testing Symm. ReLU activation derivative: ";
95  std::cout << "maximum relative error = " << print_error(error) << std::endl;
96  if (error > 1e-10)
97  return 1;
98 
99  // Soft Sign.
100 
101  error = testSoftSign<TReference<Scalar_t>>(10);
102  std::cout << "Testing Soft Sign activation: ";
103  std::cout << "maximum relative error = " << print_error(error) << std::endl;
104  if (error > 1e-10)
105  return 1;
106 
107  error = testSoftSignDerivative<TReference<Scalar_t>>(10);
108  std::cout << "Testing Soft Sign activation derivative: ";
109  std::cout << "maximum relative error = " << print_error(error) << std::endl;
110  if (error > 1e-10)
111  return 1;
112 
113  // Gauss.
114 
115  error = testGauss<TReference<Scalar_t>>(10);
116  std::cout << "Testing Gauss activation: ";
117  std::cout << "maximum relative error = " << print_error(error) << std::endl;
118  if (error > 1e-10)
119  return 1;
120 
121  error = testGaussDerivative<TReference<Scalar_t>>(10);
122  std::cout << "Testing Gauss activation derivative: ";
123  std::cout << "maximum relative error = " << print_error(error) << std::endl;
124  if (error > 1e-10)
125  return 1;
126 
127  return 0;
128 }
Definition: Blas.h:58
std::string print_error(AFloat &e)
Color code error.
Definition: Utility.h:247
double Double_t
Definition: RtypesCore.h:55
you should not use this method at all Int_t Int_t Double_t Double_t Double_t e
Definition: TRolke.cxx:630