Logo ROOT   6.07/09
Reference Guide
Dropout.cxx
Go to the documentation of this file.
1 // @(#)root/tmva/tmva/dnn:$Id$
2 // Author: Simon Pfreundschuh 21/07/16
3 
4 /*************************************************************************
5  * Copyright (C) 2016, Simon Pfreundschuh *
6  * All rights reserved. *
7  * *
8  * For the licensing terms see $ROOTSYS/LICENSE. *
9  * For the list of contributors see $ROOTSYS/README/CREDITS. *
10  *************************************************************************/
11 
13 #include "TRandom.h"
14 
15 /////////////////////////////////////////////////////////////////////
16 // Implementation of Dropout for multi-threaded CPU architectures. //
17 /////////////////////////////////////////////////////////////////////
18 
19 namespace TMVA {
20 namespace DNN {
21 
22 //____________________________________________________________________________
23 template<typename AFloat>
25  AFloat dropoutProbability)
26 {
27  AFloat *data = A.GetRawDataPointer();
28 
29  auto f = [&data, dropoutProbability](UInt_t workerID)
30  {
31  TRandom rand(time(nullptr) + workerID);
32  AFloat r = rand.Uniform();
33  data[workerID] = (r > dropoutProbability) ? 0.0 : data[workerID] / dropoutProbability;
34  return 0;
35  };
36 
38 }
39 
40 } // namespace DNN
41 } // namespace TMVA
The TCpuMatrix class.
Definition: CpuMatrix.h:46
ROOT::TThreadExecutor & GetThreadExecutor() const
Definition: CpuMatrix.h:106
static double A[]
This is the base class for the ROOT Random number generators.
Definition: TRandom.h:31
static void Dropout(TCpuMatrix< Scalar_t > &A, Scalar_t p)
Apply dropout with activation probability p to the given matrix A and scale the result by the reciprocal of p.
Definition: Dropout.cxx:24
TRandom2 r(17)
unsigned int UInt_t
Definition: RtypesCore.h:42
AFloat * GetRawDataPointer()
Return raw pointer to the elements stored contiguously in column-major order.
Definition: CpuMatrix.h:103
size_t GetNElements() const
Definition: CpuMatrix.h:95
double f(double x)
A pseudo container class which is a generator of indices.
Definition: TSeq.hxx:66
virtual Double_t Uniform(Double_t x1=1)
Returns a uniform deviate on the interval (0, x1).
Definition: TRandom.cxx:606
Abstract ClassifierFactory template that handles arbitrary types.