Logo ROOT   6.12/07
Reference Guide
Dropout.cxx
Go to the documentation of this file.
1 // @(#)root/tmva/tmva/dnn:$Id$
2 // Author: Simon Pfreundschuh 10/07/16
3 
4 /*************************************************************************
5  * Copyright (C) 2016, Simon Pfreundschuh *
6  * All rights reserved. *
7  * *
8  * For the licensing terms see $ROOTSYS/LICENSE. *
9  * For the list of contributors see $ROOTSYS/README/CREDITS. *
10  *************************************************************************/
11 
12  //////////////////////////////////////////////////////////////////
13  // Implementation of the dropout function for the reference //
14  // implementation. //
15  //////////////////////////////////////////////////////////////////
16 
17 
19 #include "TRandom.h"
20 
21 namespace TMVA
22 {
23 namespace DNN
24 {
25 
26 //______________________________________________________________________________
27 
// Dropout for the reference (plain TMatrixT) implementation.
//
// For every element B(i,j), a value r is drawn with rand.Uniform(), i.e.
// uniformly on the open interval (0, 1):
//   * if r >= dropoutProbability, the element is set to 0;
//   * otherwise the element is divided by dropoutProbability.
// Despite its name, `dropoutProbability` is therefore the probability of
// KEEPING an element (the "activation probability" p). Surviving entries are
// rescaled by 1/p. B is modified in place.
//
// Edge cases that follow from the code:
//   * p <= 0 zeroes every element. Uniform() never returns 0, so r >= p always
//     holds and no division by p takes place.
//   * p >= 1 zeroes nothing, because r < 1 <= p. Every element is divided by p.
//
// NOTE(review): source line 29, the function declarator, is missing from this
// rendered listing. The generated docs give the declaration as
// `static void Dropout(TMatrixT<AReal> &A, AReal dropoutProbability)`. Restore
// the out-of-class definition line from the original Dropout.cxx.
28 template<typename Real_t>
30 {
31  size_t m,n;
32  m = B.GetNrows(); // number of rows (GetNrows() returns Int_t; converted to size_t)
33  n = B.GetNcols(); // number of columns
34 
// NOTE(review): a new generator is constructed and seeded from time(nullptr)
// on every call. time() typically has one-second resolution, so calls made
// within the same second reuse the seed and produce identical dropout masks.
// Consider a single persistent generator, e.g. a function-local static or an
// architecture-level RNG. TODO: confirm against how other DNN code draws
// random numbers.
35  TRandom rand(time(nullptr));
36 
// Visit every element once, in row-major order.
37  for (size_t i = 0; i < m; i++) {
38  for (size_t j = 0; j < n; j++) {
39  Real_t r = rand.Uniform(); // uniform on (0, 1)
40  if (r >= dropoutProbability) {
41  B(i,j) = 0.0; // dropped: happens with probability 1 - p when 0 <= p <= 1
42  } else {
43  B(i,j) /= dropoutProbability; // kept: rescale by 1/p
44  }
45  }
46  }
47 }
48 
49 } // namespace DNN
50 } // namespace TMVA
static double B[]
auto * m
Definition: textangle.C:8
Int_t GetNcols() const
Definition: TMatrixTBase.h:125
TMatrixT.
Definition: TMatrixDfwd.h:22
This is the base class for the ROOT Random number generators.
Definition: TRandom.h:27
ROOT::R::TRInterface & r
Definition: Object.C:4
Int_t GetNrows() const
Definition: TMatrixTBase.h:122
virtual Double_t Uniform(Double_t x1=1)
Returns a uniform deviate on the interval (0, x1).
Definition: TRandom.cxx:627
float Real_t
Definition: RtypesCore.h:64
Abstract ClassifierFactory template that handles arbitrary types.
const Int_t n
Definition: legend1.C:16
static void Dropout(TMatrixT< AReal > &A, AReal dropoutProbability)
Apply dropout with activation probability p to the given matrix A and scale the result by the reciprocal of p.
Definition: Dropout.cxx:29