Logo ROOT   6.12/07
Reference Guide
LossFunctions.cxx
Go to the documentation of this file.
1 // @(#)root/tmva/tmva/dnn:$Id$
2 // Author: Simon Pfreundschuh 10/07/16
3 
4 /*************************************************************************
5  * Copyright (C) 2016, Simon Pfreundschuh *
6  * All rights reserved. *
7  * *
8  * For the licensing terms see $ROOTSYS/LICENSE. *
9  * For the list of contributors see $ROOTSYS/README/CREDITS. *
10  *************************************************************************/
11 
12  ////////////////////////////////////////////////////////////
13  // Implementation of the loss functions for the reference //
14  // implementation. //
15  ////////////////////////////////////////////////////////////
16 
18 
19 namespace TMVA
20 {
21 namespace DNN
22 {
23 //______________________________________________________________________________
24 template <typename AReal>
26  const TMatrixT<AReal> &weights)
27 {
28  size_t m,n;
29  m = Y.GetNrows();
30  n = Y.GetNcols();
31  AReal result = 0.0;
32 
33  for (size_t i = 0; i < m; i++) {
34  for (size_t j = 0; j < n; j++) {
35  AReal dY = (Y(i,j) - output(i,j));
36  result += weights(i, 0) * dY * dY;
37  }
38  }
39  result /= static_cast<AReal>(m * n);
40  return result;
41 }
42 
43 //______________________________________________________________________________
44 template <typename AReal>
46  const TMatrixT<AReal> &output, const TMatrixT<AReal> &weights)
47 {
48  size_t m,n;
49  m = Y.GetNrows();
50  n = Y.GetNcols();
51 
52  dY.Minus(Y, output);
53  dY *= -2.0 / static_cast<AReal>(m * n);
54 
55  for (size_t i = 0; i < m; i++) {
56  for (size_t j = 0; j < n; j++) {
57  dY(i, j) *= weights(i, 0);
58  }
59  }
60 }
61 
62 //______________________________________________________________________________
63 template <typename AReal>
65  const TMatrixT<AReal> &weights)
66 {
67  size_t m,n;
68  m = Y.GetNrows();
69  n = Y.GetNcols();
70  AReal result = 0.0;
71 
72  for (size_t i = 0; i < m; i++) {
73  AReal w = weights(i, 0);
74  for (size_t j = 0; j < n; j++) {
75  AReal sig = 1.0 / (1.0 + std::exp(-output(i,j)));
76  result += w * (Y(i, j) * std::log(sig) + (1.0 - Y(i, j)) * std::log(1.0 - sig));
77  }
78  }
79  result /= -static_cast<AReal>(m * n);
80  return result;
81 }
82 
83 //______________________________________________________________________________
84 template <typename AReal>
86  const TMatrixT<AReal> &output, const TMatrixT<AReal> &weights)
87 {
88  size_t m,n;
89  m = Y.GetNrows();
90  n = Y.GetNcols();
91 
92  AReal norm = 1.0 / static_cast<AReal>(m * n);
93  for (size_t i = 0; i < m; i++)
94  {
95  AReal w = weights(i, 0);
96  for (size_t j = 0; j < n; j++)
97  {
98  AReal y = Y(i,j);
99  AReal sig = 1.0 / (1.0 + std::exp(-output(i,j)));
100  dY(i, j) = norm * w * (sig - y);
101  }
102  }
103 }
104 
105 //______________________________________________________________________________
106 template <typename AReal>
108  const TMatrixT<AReal> &weights)
109 {
110  size_t m,n;
111  m = Y.GetNrows();
112  n = Y.GetNcols();
113  AReal result = 0.0;
114 
115  for (size_t i = 0; i < m; i++) {
116  AReal sum = 0.0;
117  AReal w = weights(i, 0);
118  for (size_t j = 0; j < n; j++) {
119  sum += exp(output(i,j));
120  }
121  for (size_t j = 0; j < n; j++) {
122  result += w * Y(i, j) * log(exp(output(i, j)) / sum);
123  }
124  }
125  result /= -static_cast<AReal>(m);
126  return result;
127 }
128 
129 //______________________________________________________________________________
130 template <typename AReal>
132  const TMatrixT<AReal> &output, const TMatrixT<AReal> &weights)
133 {
134  size_t m,n;
135  m = Y.GetNrows();
136  n = Y.GetNcols();
137  AReal norm = 1.0 / m ;
138 
139  for (size_t i = 0; i < m; i++)
140  {
141  AReal sum = 0.0;
142  AReal sumY = 0.0;
143  AReal w = weights(i, 0);
144  for (size_t j = 0; j < n; j++) {
145  sum += exp(output(i,j));
146  sumY += Y(i,j);
147  }
148  for (size_t j = 0; j < n; j++) {
149  dY(i, j) = w * norm * (exp(output(i, j)) / sum * sumY - Y(i, j));
150  }
151  }
152 }
153 
154 } // namespace DNN
155 } // namespace TMVA
static long int sum(long int i)
Definition: Factory.cxx:2173
auto * m
Definition: textangle.C:8
Int_t GetNcols() const
Definition: TMatrixTBase.h:125
static AReal SoftmaxCrossEntropy(const TMatrixT< AReal > &Y, const TMatrixT< AReal > &output, const TMatrixT< AReal > &weights)
Softmax transformation is implicitly applied, thus output should hold the linear activations of the l...
static AReal CrossEntropy(const TMatrixT< AReal > &Y, const TMatrixT< AReal > &output, const TMatrixT< AReal > &weights)
Sigmoid transformation is implicitly applied, thus output should hold the linear activations of the l...
void Minus(const TMatrixT< Element > &a, const TMatrixT< Element > &b)
General matrix summation. Create a matrix C such that C = A - B.
Definition: TMatrixT.cxx:580
static void SoftmaxCrossEntropyGradients(TMatrixT< AReal > &dY, const TMatrixT< AReal > &Y, const TMatrixT< AReal > &output, const TMatrixT< AReal > &weights)
Int_t GetNrows() const
Definition: TMatrixTBase.h:122
static void CrossEntropyGradients(TMatrixT< AReal > &dY, const TMatrixT< AReal > &Y, const TMatrixT< AReal > &output, const TMatrixT< AReal > &weights)
Double_t y[n]
Definition: legend1.C:17
Abstract ClassifierFactory template that handles arbitrary types.
double exp(double)
static void MeanSquaredErrorGradients(TMatrixT< AReal > &dY, const TMatrixT< AReal > &Y, const TMatrixT< AReal > &output, const TMatrixT< AReal > &weights)
const Int_t n
Definition: legend1.C:16
static AReal MeanSquaredError(const TMatrixT< AReal > &Y, const TMatrixT< AReal > &output, const TMatrixT< AReal > &weights)
double log(double)