27#ifndef TMVA_DNN_BatchNormLayer
28#define TMVA_DNN_BatchNormLayer
63template <
typename Architecture_t>
67 using Scalar_t =
typename Architecture_t::Scalar_t;
68 using Matrix_t =
typename Architecture_t::Matrix_t;
69 using Tensor_t =
typename Architecture_t::Tensor_t;
101 TBatchNormLayer(
size_t batchSize,
size_t inputDepth,
size_t inputHeight,
size_t inputWidth,
186 std::vector<Matrix_t> params(2);
218template <
typename Architecture_t>
220 size_t inputWidth,
const std::vector<size_t> &shape,
int axis,
222 :
VGeneralLayer<Architecture_t>(batchSize, inputDepth, inputHeight, inputWidth,
223 inputDepth, inputHeight, inputWidth,
225 CalculateNormDim(axis, inputDepth, inputHeight, inputWidth),
227 shape[2], shape[0], shape[1],
229 fNormAxis(axis), fMomentum(momentum), fEpsilon(
epsilon),
230 fMu(1,
VGeneralLayer<Architecture_t>::GetWeightsAt(0).GetNcols()),
231 fVar(1,
VGeneralLayer<Architecture_t>::GetWeightsAt(0).GetNcols()),
232 fIVar(1,
VGeneralLayer<Architecture_t>::GetWeightsAt(0).GetNcols()),
233 fMu_Training(1,
VGeneralLayer<Architecture_t>::GetWeightsAt(0).GetNcols()),
234 fVar_Training(1,
VGeneralLayer<Architecture_t>::GetWeightsAt(0).GetNcols()),
241template <
typename Architecture_t>
246 printf(
"Error - copy ctor not implemented\n");
250template <
typename Architecture_t>
254 printf(
"Error - copy ctor not implemented\n");
258template <
typename Architecture_t>
263 Architecture_t::ReleaseBNormDescriptors(fDescriptors);
268template <
typename Architecture_t>
271 Matrix_t &gamma = this->GetWeightsAt(0);
272 Matrix_t &beta = this->GetWeightsAt(1);
273 size_t bndim = gamma.GetNcols();
276 for (
size_t i = 0; i < bndim; ++i) {
279 fMu_Training(0,i) = 0;
280 fVar_Training(0,i) = 1;
283 Matrix_t &dgamma = this->GetWeightGradientsAt(0);
284 Matrix_t &dbeta = this->GetWeightGradientsAt(1);
290 Architecture_t::InitializeBNormDescriptors(fDescriptors,
this);
294template <
typename Architecture_t>
299 if (
x.GetLayout() != fReshapedData.GetLayout()) {
300 x2 =
Tensor_t(
x.GetDeviceBuffer(), fReshapedData.GetShape(), fReshapedData.GetLayout());
301 y2 =
Tensor_t(this->GetOutput().GetDeviceBuffer(), fReshapedData.GetShape(), fReshapedData.GetLayout());
305 y2 = this->GetOutput();
310 Architecture_t::BatchNormLayerForwardTraining(fNormAxis,
x2,
y2,
311 this->GetWeightsAt(0), this->GetWeightsAt(1),
312 this->GetBatchMean(), this->GetVariance(), this->GetIVariance(),
314 this->GetVarVector(), this->GetNTrainedBatches(),
315 this->GetMomentum(), this->GetEpsilon(),
316 descr->HelperDescriptor);
327 Architecture_t::BatchNormLayerForwardInference(fNormAxis,
x2, this->GetWeightsAt(0), this->GetWeightsAt(1),
328 y2, this->GetMuVector(), this->GetVarVector(),
329 this->GetEpsilon(), descr->HelperDescriptor);
336template <
typename Architecture_t>
338 const Tensor_t & activations_backward ) ->
void
344 if (activations_backward.GetLayout() != fReshapedData.GetLayout()) {
345 Tensor_t x =
Tensor_t(activations_backward.GetDeviceBuffer(), fReshapedData.GetShape(), fReshapedData.GetLayout());
346 Tensor_t dx =
Tensor_t(gradients_backward.GetDeviceBuffer(), fReshapedData.GetShape(), fReshapedData.GetLayout());
347 Tensor_t dy =
Tensor_t(this->GetActivationGradients().GetDeviceBuffer(), fReshapedData.GetShape(), fReshapedData.GetLayout());
349 Architecture_t::BatchNormLayerBackward(fNormAxis,
x, dy, dx,
350 this->GetWeightsAt(0),
351 this->GetWeightGradientsAt(0), this->GetWeightGradientsAt(1),
352 this->GetBatchMean(), this->GetVariance(), this->GetIVariance(),
353 this->GetEpsilon(), descr->HelperDescriptor);
357 Architecture_t::BatchNormLayerBackward(fNormAxis, activations_backward,
358 this->GetActivationGradients(),
360 this->GetWeightsAt(0),
361 this->GetWeightGradientsAt(0), this->GetWeightGradientsAt(1),
362 this->GetBatchMean(), this->GetVariance(), this->GetIVariance(),
363 this->GetEpsilon(), descr->HelperDescriptor);
368template <
typename Architecture_t>
371 std::cout <<
" BATCH NORM Layer: \t";
372 std::cout <<
" Input/Output = ( " ;
373 auto &shape = this->GetOutput().GetShape();
374 for (
size_t i = 0; i < shape.size(); ++i) {
375 if (i > 0) std::cout <<
" , ";
376 std::cout << shape[i];
379 std::cout <<
"\t Norm dim =" << std::setw(6) << this->GetWeightsAt(0).GetNcols();
380 std::cout <<
"\t axis = " << fNormAxis << std::endl;
381 std::cout << std::endl;
386template <
typename Architecture_t>
401 this->WriteMatrixToXML(layerxml,
"Training-mu", this->GetMuVector());
402 this->WriteMatrixToXML(layerxml,
"Training-variance", this->GetVarVector());
405 this->WriteMatrixToXML(layerxml,
"Gamma", this->GetWeightsAt(0));
406 this->WriteMatrixToXML(layerxml,
"Beta", this->GetWeightsAt(1));
411template <
typename Architecture_t>
419 this->ReadMatrixXML(parent,
"Training-mu", this->GetMuVector());
420 this->ReadMatrixXML(parent,
"Training-variance", this->GetVarVector());
422 this->ReadMatrixXML(parent,
"Gamma", this->GetWeightsAt(0));
423 this->ReadMatrixXML(parent,
"Beta", this->GetWeightsAt(1));
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void input
Option_t Option_t TPoint TPoint const char x2
Option_t Option_t TPoint TPoint const char y2
Layer implementing Batch Normalization.
static size_t CalculateNormDim(int axis, size_t c, size_t h, size_t w)
const Matrix_t & GetMuVector() const
int fNormAxis
Normalization axis. For each element of this axis we will compute mean and stddev.
typename Architecture_t::Matrix_t Matrix_t
Scalar_t GetMomentum() const
TDescriptors * fDescriptors
int & GetNTrainedBatches()
Scalar_t GetEpsilon() const
virtual void Initialize()
Initialize the weights and biases according to the given initialization method.
Scalar_t GetNormAxis() const
void SetExtraLayerParameters(const std::vector< Matrix_t > ¶ms)
void ResetTraining()
Reset some training flags after a loop on all batches Some layer (e.g.
std::vector< Matrix_t > GetExtraLayerParameters() const
typename Architecture_t::Tensor_t Tensor_t
virtual void AddWeightsXMLTo(void *parent)
Writes the information and the weights about the layer in an XML node.
Scalar_t fMomentum
The weight decay.
Matrix_t & GetVarVector()
const Matrix_t & GetVariance() const
Matrix_t & GetBatchMean()
const Matrix_t & GetReshapedData() const
void Print() const
Printing the layer info.
virtual void ReadWeightsFromXML(void *parent)
Read the information and the weights about the layer from XML node.
const int & GetNTrainedBatches() const
const Matrix_t & GetIVariance() const
typename Architecture_t::Scalar_t Scalar_t
void Backward(Tensor_t &gradients_backward, const Tensor_t &activations_backward)
Compute weight, bias and activation gradients.
Matrix_t & GetReshapedData()
void Forward(Tensor_t &input, bool inTraining=true)
Compute activation of the layer for the given input.
~TBatchNormLayer()
Destructor.
typename Architecture_t::TensorDescriptor_t HelperDescriptor_t
typename Architecture_t::BNormDescriptors_t BNormDescriptors_t
const Matrix_t & GetVarVector() const
Tensor_t fDerivatives
First fDerivatives of the activations of this layer.
const Matrix_t & GetBatchMean() const
TBatchNormLayer(size_t batchSize, size_t inputDepth, size_t inputHeight, size_t inputWidth, const std::vector< size_t > &shape, int axis=-1, Scalar_t momentum=-1., Scalar_t epsilon=0.0001)
Constructor.
Matrix_t & GetIVariance()
Generic General Layer class.
XMLNodePointer_t NewChild(XMLNodePointer_t parent, XMLNsPointer_t ns, const char *name, const char *content=nullptr)
create new child element for parent node
create variable transformations