Logo ROOT  
Reference Guide
 
Loading...
Searching...
No Matches
Arithmetic.hxx
Go to the documentation of this file.
// @(#)root/tmva/tmva/dnn:$Id$
// Author: Ravi Kiran S

/*************************************************************************
 * Copyright (C) 2018, Ravi Kiran S *
 * All rights reserved. *
 * *
 * For the licensing terms see $ROOTSYS/LICENSE. *
 * For the list of contributors see $ROOTSYS/README/CREDITS. *
 *************************************************************************/

//////////////////////////////////////////////////////////////////
// Implementation of the helper arithmetic functions for the //
// reference implementation. //
//////////////////////////////////////////////////////////////////
16
18#include <math.h>
19
20namespace TMVA {
21namespace DNN {
22
23//______________________________________________________________________________
24template <typename AReal>
26{
27 B = 0.0;
28 for (Int_t i = 0; i < A.GetNrows(); i++) {
29 for (Int_t j = 0; j < A.GetNcols(); j++) {
30 B(0, j) += A(i, j);
31 }
32 }
33}
34
35//______________________________________________________________________________
36template <typename AReal>
38{
39 for (Int_t i = 0; i < A.GetNrows(); i++) {
40 for (Int_t j = 0; j < A.GetNcols(); j++) {
41 A(i, j) *= B(i, j);
42 }
43 }
44}
45
46//______________________________________________________________________________
47template <typename AReal>
49{
50 for (Int_t i = 0; i < A.GetNrows(); i++) {
51 for (Int_t j = 0; j < A.GetNcols(); j++) {
52 A(i, j) += beta;
53 }
54 }
55}
56
57//______________________________________________________________________________
58template <typename AReal>
60{
61 for (Int_t i = 0; i < A.GetNrows(); i++) {
62 for (Int_t j = 0; j < A.GetNcols(); j++) {
63 A(i, j) *= beta;
64 }
65 }
66}
67
68//______________________________________________________________________________
69template <typename AReal>
71{
72 for (Int_t i = 0; i < A.GetNrows(); i++) {
73 for (Int_t j = 0; j < A.GetNcols(); j++) {
74 A(i, j) = 1.0 / A(i, j);
75 }
76 }
77}
78
79//______________________________________________________________________________
80template <typename AReal>
82{
83 for (Int_t i = 0; i < A.GetNrows(); i++) {
84 for (Int_t j = 0; j < A.GetNcols(); j++) {
85 A(i, j) *= A(i, j);
86 }
87 }
88}
89
90//______________________________________________________________________________
91template <typename AReal>
93{
94 for (Int_t i = 0; i < A.GetNrows(); i++) {
95 for (Int_t j = 0; j < A.GetNcols(); j++) {
96 A(i, j) = sqrt(A(i, j));
97 }
98 }
99}
100/// Adam updates
101//____________________________________________________________________________
102template<typename AReal>
104{
105 // ADAM update the weights.
106 // Weight = Weight - alpha * M / (sqrt(V) + epsilon)
107 AReal * a = A.GetMatrixArray();
108 const AReal * m = M.GetMatrixArray();
109 const AReal * v = V.GetMatrixArray();
110 for (int index = 0; index < A.GetNoElements() ; ++index) {
111 a[index] = a[index] - alpha * m[index]/( sqrt(v[index]) + eps);
112 }
113}
114
115//____________________________________________________________________________
116template<typename AReal>
118{
119 // First momentum weight gradient update for ADAM
120 // Mt = beta1 * Mt-1 + (1-beta1) * WeightGradients
121 AReal * a = A.GetMatrixArray();
122 const AReal * b = B.GetMatrixArray();
123 for (int index = 0; index < A.GetNoElements() ; ++index) {
124 a[index] = beta * a[index] + (1.-beta) * b[index];
125 }
126}
127//____________________________________________________________________________
128template<typename AReal>
130{
131 // Second momentum weight gradient update for ADAM
132 // Vt = beta2 * Vt-1 + (1-beta2) * WeightGradients^2
133 AReal * a = A.GetMatrixArray();
134 const AReal * b = B.GetMatrixArray();
135 for (int index = 0; index < A.GetNoElements() ; ++index) {
136 a[index] = beta * a[index] + (1.-beta) * b[index] * b[index];
137 }
138}
139
140} // namespace DNN
141} // namespace TMVA
#define b(i)
Definition RSha256.hxx:100
#define a(i)
Definition RSha256.hxx:99
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void char Point_t Rectangle_t WindowAttributes_t index
static void AdamUpdate(TMatrixT< AReal > &A, const TMatrixT< AReal > &M, const TMatrixT< AReal > &V, AReal alpha, AReal eps)
Update functions for ADAM optimizer.
static void AdamUpdateSecondMom(TMatrixT< AReal > &A, const TMatrixT< AReal > &B, AReal beta)
static void AdamUpdateFirstMom(TMatrixT< AReal > &A, const TMatrixT< AReal > &B, AReal beta)
static void ConstAdd(TMatrixT< AReal > &A, AReal beta)
Add the constant beta to all the elements of matrix A and write the result into A.
static void ReciprocalElementWise(TMatrixT< AReal > &A)
Take the reciprocal of each element of the matrix A and write the result into A.
static void SquareElementWise(TMatrixT< AReal > &A)
Square each element of the matrix A and write the result into A.
static void Hadamard(TMatrixT< AReal > &A, const TMatrixT< AReal > &B)
In-place Hadamard (element-wise) product of matrices A and B with the result being written into A.
static void SqrtElementWise(TMatrixT< AReal > &A)
Take the square root of each element of the matrix A and write the result into A.
static void SumColumns(TMatrixT< AReal > &B, const TMatrixT< AReal > &A)
Sum the columns of the (m x n) matrix A and write the n column sums into the first row of B (B is zeroed first).
static void ConstMult(TMatrixT< AReal > &A, AReal beta)
Multiply all the elements of matrix A by the constant beta and write the result into A.
Int_t GetNrows() const
Int_t GetNoElements() const
Int_t GetNcols() const
TMatrixT.
Definition TMatrixT.h:39
const Element * GetMatrixArray() const override
Definition TMatrixT.h:225
create variable transformations
TMarker m
Definition textangle.C:8