27#ifndef TMVA_DNN_OPTIMIZER
28#define TMVA_DNN_OPTIMIZER
42template <
typename Architecture_t,
typename Layer_t = VGeneralLayer<Architecture_t>,
43 typename DeepNet_t = TDeepNet<Architecture_t, Layer_t>>
46 using Matrix_t =
typename Architecture_t::Matrix_t;
47 using Scalar_t =
typename Architecture_t::Scalar_t;
56 UpdateWeights(
size_t layerIndex, std::vector<Matrix_t> &weights,
const std::vector<Matrix_t> &weightGradients) = 0;
60 UpdateBiases(
size_t layerIndex, std::vector<Matrix_t> &biases,
const std::vector<Matrix_t> &biasGradients) = 0;
89template <
typename Architecture_t,
typename Layer_t,
typename DeepNet_t>
91 : fLearningRate(learningRate), fGlobalStep(0), fDeepNet(deepNet)
96template <
typename Architecture_t,
typename Layer_t,
typename DeepNet_t>
99 for (
size_t i = 0; i < this->GetLayers().size(); i++) {
100 this->UpdateWeights(i, this->GetLayerAt(i)->GetWeights(), this->GetLayerAt(i)->GetWeightGradients());
101 this->UpdateBiases(i, this->GetLayerAt(i)->GetBiases(), this->GetLayerAt(i)->GetBiasGradients());
size_t fGlobalStep
The current global step count during training.
Layer_t * GetLayerAt(size_t i)
std::vector< Layer_t * > & GetLayers()
void IncrementGlobalStep()
Increments the global step.
virtual ~VOptimizer()=default
Virtual Destructor.
virtual void UpdateBiases(size_t layerIndex, std::vector< Matrix_t > &biases, const std::vector< Matrix_t > &biasGradients)=0
Update the biases, given the current bias gradients.
virtual void UpdateWeights(size_t layerIndex, std::vector< Matrix_t > &weights, const std::vector< Matrix_t > &weightGradients)=0
Update the weights, given the current weight gradients.
void SetLearningRate(size_t learningRate)
Setters.
Scalar_t GetLearningRate() const
Getters.
void Step()
Performs one step of optimization.
Scalar_t fLearningRate
The learning rate used for training.
typename Architecture_t::Scalar_t Scalar_t
size_t GetGlobalStep() const
DeepNet_t & fDeepNet
The reference to the deep net.
VOptimizer(Scalar_t learningRate, DeepNet_t &deepNet)
Constructor.
typename Architecture_t::Matrix_t Matrix_t
Abstract ClassifierFactory template that handles arbitrary types.