Logo ROOT   6.07/09
Reference Guide
TestLossFunctionsCuda.cxx
Go to the documentation of this file.
1 // @(#)root/tmva $Id$
2 // Author: Simon Pfreundschuh
3 
4 /*************************************************************************
5  * Copyright (C) 2016, Simon Pfreundschuh *
6  * All rights reserved. *
7  * *
8  * For the licensing terms see $ROOTSYS/LICENSE. *
9  * For the list of contributors see $ROOTSYS/README/CREDITS. *
10  *************************************************************************/
11 
12 ///////////////////////////////////////////////////////////////////
13 // Test for the loss functions of the CUDA implementation using //
14 // the generic test defined in TestLossFunctions.h. //
15 ///////////////////////////////////////////////////////////////////
16 
17 #include <iostream>
19 #include "TestLossFunctions.h"
20 
21 using namespace TMVA::DNN;
22 
23 int main()
24 {
25  using Scalar_t = Double_t;
26  std::cout << "Testing Loss Functions:" << std::endl << std::endl;
27 
28  double error;
29 
30  //
31  // Mean Squared Error.
32  //
33 
34  error = testMeanSquaredError<TCuda<Scalar_t>>(10);
35  std::cout << "Testing mean squared error loss: ";
36  std::cout << "maximum relative error = " << print_error(error) << std::endl;
37  if (error > 1e-3)
38  return 1;
39 
40  error = testMeanSquaredErrorGradients<TCuda<Scalar_t>>(10);
41  std::cout << "Testing mean squared error gradient: ";
42  std::cout << "maximum relative error = " << print_error(error) << std::endl;
43  if (error > 1e-3)
44  return 1;
45 
46  //
47  // Cross Entropy.
48  //
49 
50  error = testCrossEntropy<TCuda<Scalar_t>>(10);
51  std::cout << "Testing cross entropy loss: ";
52  std::cout << "maximum relative error = " << print_error(error) << std::endl;
53  if (error > 1e-3)
54  return 1;
55 
56  error = testCrossEntropyGradients<TCuda<Scalar_t>>(10);
57  std::cout << "Testing mean squared error gradient: ";
58  std::cout << "maximum relative error = " << print_error(error) << std::endl;
59  if (error > 1e-3)
60  return 1;
61 }
int main()
Definition: Blas.h:58
std::string print_error(AFloat &e)
Color code error.
Definition: Utility.h:247
double Double_t
Definition: RtypesCore.h:55
you should not use this method at all Int_t Int_t Double_t Double_t Double_t e
Definition: TRolke.cxx:630