Logo ROOT   6.07/09
Reference Guide
TestLossFunctionsCpu.cxx
Go to the documentation of this file.
1 // @(#)root/tmva $Id$
2 // Author: Simon Pfreundschuh
3 
4 /*************************************************************************
5  * Copyright (C) 2016, Simon Pfreundschuh *
6  * All rights reserved. *
7  * *
8  * For the licensing terms see $ROOTSYS/LICENSE. *
9  * For the list of contributors see $ROOTSYS/README/CREDITS. *
10  *************************************************************************/
11 
12 //////////////////////////////////////////////////////////////////
13 // Test for the loss function implementations for the //
14 // multi-threaded CPU version using the generic test defined in //
15 // TestLossFunctions.h. //
16 //////////////////////////////////////////////////////////////////
17 
18 #include <iostream>
20 #include "TestLossFunctions.h"
21 
22 using namespace TMVA::DNN;
23 
24 int main()
25 {
26  using Scalar_t = Double_t;
27 
28  std::cout << "Testing Loss Functions:" << std::endl << std::endl;
29 
30  double error;
31 
32  //
33  // Mean Squared Error.
34  //
35 
36  error = testMeanSquaredError<TCpu<Scalar_t>>(10);
37  std::cout << "Testing mean squared error loss: ";
38  std::cout << "maximum relative error = " << print_error(error) << std::endl;
39  if (error > 1e-3)
40  return 1;
41 
42  error = testMeanSquaredErrorGradients<TCpu<Scalar_t>>(10);
43  std::cout << "Testing mean squared error gradient: ";
44  std::cout << "maximum relative error = " << print_error(error) << std::endl;
45  if (error > 1e-3)
46  return 1;
47 
48  //
49  // Cross Entropy.
50  //
51 
52  error = testCrossEntropy<TCpu<Scalar_t>>(10);
53  std::cout << "Testing cross entropy loss: ";
54  std::cout << "maximum relative error = " << print_error(error) << std::endl;
55  if (error > 1e-3)
56  return 1;
57 
58  error = testCrossEntropyGradients<TCpu<Scalar_t>>(10);
59  std::cout << "Testing mean squared error gradient: ";
60  std::cout << "maximum relative error = " << print_error(error) << std::endl;
61  if (error > 1e-3)
62  return 1;
63 }
Definition: Blas.h:58
std::string print_error(AFloat &e)
Color code error.
Definition: Utility.h:247
double Double_t
Definition: RtypesCore.h:55
you should not use this method at all Int_t Int_t Double_t Double_t Double_t e
Definition: TRolke.cxx:630
int main()