Loading [MathJax]/extensions/tex2jax.js
Logo ROOT   6.18/05
Reference Guide
All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Properties Friends Macros Modules Pages
Dropout.cxx
Go to the documentation of this file.
1// @(#)root/tmva/tmva/dnn:$Id$
2// Author: Simon Pfreundschuh 21/07/16
3
4/*************************************************************************
5 * Copyright (C) 2016, Simon Pfreundschuh *
6 * All rights reserved. *
7 * *
8 * For the licensing terms see $ROOTSYS/LICENSE. *
9 * For the list of contributors see $ROOTSYS/README/CREDITS. *
10 *************************************************************************/
11
13#include "TRandom.h"
14
15/////////////////////////////////////////////////////////////////////
16// Implementation of Dropout for multi-threaded CPU architectures. //
17/////////////////////////////////////////////////////////////////////
18
19namespace TMVA {
20namespace DNN {
21//#if 0
22//____________________________________________________________________________
/// Apply inverted dropout in place to the elements of matrix A.
///
/// NOTE(review): this rendered listing omits two source lines: the start of
/// the signature (the line carrying the function name and the `A` parameter)
/// and the declaration of `dlRand` used below -- presumably the generator
/// returned by GetRandomGenerator(); confirm against the real source.
///
/// \param dropoutProbability despite its name, the probability of KEEPING an
///        element: an element is kept when its uniform deviate r satisfies
///        r <= dropoutProbability, and kept values are divided by
///        dropoutProbability; every other element is set to zero.
///
/// Randomness: a single seed is drawn from `dlRand`; each chunk of nSteps
/// consecutive elements then uses its own TRandom seeded with
/// seed + (index of the chunk's first element). The resulting mask therefore
/// depends both on the state of `dlRand` and on the chunk size returned by
/// TCpuMatrix<AFloat>::GetNWorkItems(nElements).
23template<typename AFloat>
25 AFloat dropoutProbability)
26{
27 AFloat *data = A.GetRawDataPointer();
28
30 size_t seed = dlRand.Integer(4294967295); // exclusive bound 2^32-1: seed lies in [0, 2^32-2]
31
32 size_t nElements = A.GetNoElements();
33 const size_t nSteps = TCpuMatrix<AFloat>::GetNWorkItems(nElements);
34
35 // apply dropout. The probability is actually the probability to keep the node
36 // (i.e. 1 - dropout_prob)
37 auto f = [&data, dropoutProbability, &nSteps, &nElements, &seed](UInt_t workerID)
38 {
   // workerID is the index of the first element of this chunk (not a thread
   // id); each chunk gets a distinct, deterministic seed derived from it.
39 TRandom rand(seed+workerID);
40 size_t iMax = std::min(workerID+nSteps,nElements);
41 for (size_t i = workerID; i < iMax; ++i) {
   // Uniform() returns a value in the open interval (0, 1), so with
   // dropoutProbability == 0 every element is zeroed and the division in
   // the keep branch is never evaluated.
42 AFloat r = rand.Uniform();
43 data[i] = (r > dropoutProbability) ? 0.0 : data[i] / dropoutProbability;
44 }
45 return 0;
46 };
47
// Run f on the chunk start indices 0, nSteps, 2*nSteps, ... below nElements,
// either through the matrix's thread executor or sequentially.
48#ifdef DL_USE_MTE
49 A.GetThreadExecutor().Foreach(f, ROOT::TSeqI(0,nElements,nSteps));
50#else
// NOTE(review): i is size_t while f takes UInt_t, so start indices beyond
// UInt_t's range would be truncated.
51 for (size_t i = 0; i < nElements; i+=nSteps)
52 f(i);
53#endif
54}
55 // Old implementation, kept for reference only (to be removed).
// It is disabled by the #if 0 below and is never compiled. It processes one
// element per call and builds a fresh TRandom for every element, seeded with
// time(nullptr) + element index. Because time() typically has one-second
// resolution, the mask depends on the wall clock: two calls within the same
// second produce the same mask, and a call one second later reuses shifted
// seeds (element i at second t+1 gets the seed of element i+1 at second t).
56#if 0
57//____________________________________________________________________________
58template<typename AFloat>
60 AFloat dropoutProbability)
61{
62 AFloat *data = A.GetRawDataPointer();
63
64 auto f = [&data, dropoutProbability](UInt_t workerID)
65 {
   // Here workerID is the element index itself (one call per element).
66 TRandom rand(time(nullptr) + workerID);
67 AFloat r = rand.Uniform();
68 data[workerID] = (r > dropoutProbability) ? 0.0 : data[workerID] / dropoutProbability;
69 return 0;
70 };
71
72 A.GetThreadExecutor().Map(f, ROOT::TSeqI(A.GetNoElements()));
73}
74#endif
75
76
77
78} // namespace DNN
79} // namespace TMVA
ROOT::R::TRInterface & r
Definition: Object.C:4
#define f(i)
Definition: RSha256.hxx:104
unsigned int UInt_t
Definition: RtypesCore.h:42
A pseudo container class which is a generator of indices.
Definition: TSeq.hxx:66
The TCpuMatrix class.
Definition: CpuMatrix.h:89
static size_t GetNWorkItems(size_t nelements)
Definition: CpuMatrix.h:180
static void Dropout(TCpuMatrix< Scalar_t > &A, Scalar_t p)
Apply dropout with activation (keep) probability p to the given matrix A and scale the kept elements by the reciprocal of p.
Definition: Dropout.cxx:24
static TRandom & GetRandomGenerator()
This is the base class for the ROOT random number generators.
Definition: TRandom.h:27
virtual Double_t Uniform(Double_t x1=1)
Returns a uniform deviate on the interval (0, x1).
Definition: TRandom.cxx:635
virtual UInt_t Integer(UInt_t imax)
Returns a random integer uniformly distributed on the interval [ 0, imax-1 ].
Definition: TRandom.cxx:349
static double A[]
create variable transformations