1#ifndef TMVA_SOFIE_ROPERATOR_BatchNormalization
2#define TMVA_SOFIE_ROPERATOR_BatchNormalization
4#include "SOFIE_common.hxx"
5#include "ROperator.hxx"
6#include "RModel.hxx"
9#include <cmath>
10#include <sstream>
12namespace TMVA{
13namespace Experimental{
14namespace SOFIE{
16template <typename T>
/* Attributes */
float fepsilon{1e-05};          ///< added to the variance before the square root is taken in Initialize()
float fmomentum{0.9};           ///< stored from the constructor argument; not read by the code shown here
std::size_t ftraining_mode{0};  ///< stored from the constructor argument; not read by the code shown here

/* Tensor names, each passed through UTILITY::Clean_name by the constructor */
std::string fNX;      ///< input tensor
std::string fNScale;  ///< scale tensor
std::string fNB;      ///< bias tensor
std::string fNMean;   ///< mean tensor
std::string fNVar;    ///< variance tensor
std::string fNY;      ///< output tensor

/* Tensor shapes, one per tensor name above */
std::vector<std::size_t> fShapeX;
std::vector<std::size_t> fShapeScale;
std::vector<std::size_t> fShapeB;
std::vector<std::size_t> fShapeMean;
std::vector<std::size_t> fShapeVar;
std::vector<std::size_t> fShapeY;

/* Element type as text; the constructor only ever sets it to "float" */
std::string fType;
46 /* Constructor */
47 ROperator_BatchNormalization( float epsilon, float momentum, std::size_t training_mode,
48 std::string nameX, std::string nameScale, std::string nameB,
49 std::string nameMean, std::string nameVar, std::string nameY):
50 fepsilon(epsilon), fmomentum(momentum), ftraining_mode(training_mode),
51 fNX(UTILITY::Clean_name(nameX)), fNScale(UTILITY::Clean_name(nameScale)),
52 fNB(UTILITY::Clean_name(nameB)), fNMean(UTILITY::Clean_name(nameMean)),
53 fNVar(UTILITY::Clean_name(nameVar)), fNY(UTILITY::Clean_name(nameY))
54 {
55 if(std::is_same<T, float>::value){
56 fType = "float";
57 }
58 else{
59 throw
60 std::runtime_error("TMVA SOFIE Encountered unsupported type parsing a BatchNormalization operator");
61 }
62 }
65 std::vector<ETensorType> TypeInference(std::vector<ETensorType> input) {
66 ETensorType out = input[0];
67 return {out};
68 }
/// Check the input shapes and hand them back unchanged.
///
/// @param input shapes of the operator inputs; exactly five are required and
///        each must have four dimensions
/// @return a copy of input (one shape per input tensor)
/// @throws std::runtime_error when the number of shapes or the rank of any
///         shape does not match
///
/// NOTE(review): this returns all five shapes rather than a single output
/// shape, and demands rank 4 although Initialize() accepts an input of rank
/// 2 to 4 - confirm this is what callers expect.
std::vector<std::vector<size_t>> ShapeInference(std::vector<std::vector<size_t>> input) {
   constexpr size_t kExpectedInputs = 5;
   constexpr size_t kExpectedRank = 4;

   if (input.size() != kExpectedInputs) {
      throw std::runtime_error("TMVA SOFIE BatchNormalization Op Shape inference need 5 input tensors");
   }
   for (const auto &shape : input) {
      if (shape.size() != kExpectedRank) {
         throw std::runtime_error(
            "TMVA SOFIE BatchNormalization Op Shape inference only accept tensor with 4 dimensions");
      }
   }
   return input;
}
86 void Initialize(RModel& model){
87 if (!model.CheckIfTensorAlreadyExist(fNX)) {
88 throw
89 std::runtime_error("TMVA SOFIE BatchNormalization op Input Tensor " + fNX + " fnx is not found in model");
90 }
92 throw
93 std::runtime_error("TMVA SOFIE BatchNormalization op Input Tensor " + fNScale + " fns is not found in model");
94 }
95 if (!model.CheckIfTensorAlreadyExist(fNB)) {
96 throw
97 std::runtime_error("TMVA SOFIE BatchNormalization op Input Tensor " + fNB + " fnb is not found in model");
98 }
100 throw
101 std::runtime_error("TMVA SOFIE BatchNormalization op Input Tensor " + fNMean + " fnm is not found in model");
102 }
103 if (!model.CheckIfTensorAlreadyExist(fNVar)) {
104 throw
105 std::runtime_error("TMVA SOFIE BatchNormalization op Input Tensor " + fNVar + " fnv is not found in model");
106 }
108 fShapeX = model.GetTensorShape(fNX);
110 if (fShapeX.size() < 2 || fShapeX.size() > 4) {
111 throw
112 std::runtime_error("TMVA SOFIE BatchNormalization Op input tensor " + fNX + " fnx has wrong shape : " + ConvertShapeToString(fShapeX));
113 }
116 fShapeB = model.GetTensorShape(fNB);
122 if (fShapeB.size() == 1) {
123 // Broadcast scale, bias, input_mean and input_var to shape_X
124 auto original_B = model.GetInitializedTensorData(fNB);
125 auto original_S = model.GetInitializedTensorData(fNScale);
126 auto original_M = model.GetInitializedTensorData(fNMean);
127 auto original_V = model.GetInitializedTensorData(fNVar);
128 size_t batchSize = fShapeX[0];
129 size_t channels = fShapeX[1];
130 size_t height = (fShapeX.size() > 2) ? fShapeX[2] : 1;
131 size_t width = (fShapeX.size() > 3) ? fShapeX[3] : 1;
132 size_t n = batchSize * channels * height * width;
133 if (fType == "float") {
134 float *original_bias = static_cast<float *>(original_B.get());
135 float *original_scale = static_cast<float *>(original_S.get());
136 float *original_mean = static_cast<float *>(original_M.get());
137 float *original_var = static_cast<float *>(original_V.get());
138 float *new_bias = new float[n];
139 float *new_scale = new float[n];
140 float *new_mean = new float[n];
141 float *new_var = new float[n];
142 size_t bs = 0, ch = 0, h = 0, w = 0;
143 for (ch = 0; ch < channels; ch++) {
144 for (h = 0; h < height; h++) {
145 for (w = 0; w < width; w++) {
146 new_bias[bs * channels * height * width + ch * height * width + h * width + w] = original_bias[ch];
147 new_scale[bs * channels * height * width + ch * height * width + h * width + w] =
148 original_scale[ch];
149 new_mean[bs * channels * height * width + ch * height * width + h * width + w] = original_mean[ch];
150 new_var[bs * channels * height * width + ch * height * width + h * width + w] = original_var[ch];
151 }
152 }
153 }
154 size_t Batchoffset = channels * height * width;
155 for (bs = 1; bs < batchSize; bs++) {
156 std::copy(new_bias, new_bias + Batchoffset, new_bias + (bs * Batchoffset));
157 std::copy(new_scale, new_scale + Batchoffset, new_scale + (bs * Batchoffset));
158 std::copy(new_mean, new_mean + Batchoffset, new_mean + (bs * Batchoffset));
159 std::copy(new_var, new_var + Batchoffset, new_var + (bs * Batchoffset));
160 }
161 //// new_var =1. / sqrt(input_var + fepsilon)
162 for (size_t i = 0; i < n; i++) {
163 new_var[i] = 1. / sqrt(new_var[i] + fepsilon);
164 new_scale[i] *= new_var[i]; // include var in new scale
165 }
166 std::vector<size_t> new_bias_shape = {batchSize, channels, height, width};
167 std::shared_ptr<void> new_bias_ptr(new_bias, std::default_delete<float[]>());
168 std::shared_ptr<void> new_scale_ptr(new_scale, std::default_delete<float[]>());
169 std::shared_ptr<void> new_mean_ptr(new_mean, std::default_delete<float[]>());
170 std::shared_ptr<void> new_var_ptr(new_var, std::default_delete<float[]>());
171 model.UpdateInitializedTensor(fNB, model.GetTensorType(fNB), new_bias_shape, new_bias_ptr);
172 model.UpdateInitializedTensor(fNScale, model.GetTensorType(fNScale), new_bias_shape, new_scale_ptr);
173 model.UpdateInitializedTensor(fNMean, model.GetTensorType(fNMean), new_bias_shape, new_mean_ptr);
174 model.UpdateInitializedTensor(fNVar, model.GetTensorType(fNVar), new_bias_shape, new_var_ptr);
175 fShapeB = model.GetTensorShape(fNB);
179 }
180 }
181 }
183 std::string Generate(std::string OpName){
184 OpName = "op_" + OpName;
185 if (fShapeX.empty()){
186 throw std::runtime_error("TMVA SOFIE Batch Normalization called to Generate without being initialized first");
187 }
189 std::stringstream out;
190 //// Batch Norm op
191 size_t batchSize = fShapeX[0];
192 size_t channels = fShapeX[1];
193 size_t height = (fShapeX.size() > 2) ? fShapeX[2] : 1;
194 size_t width = (fShapeX.size() > 3) ? fShapeX[3] : 1;
195 size_t n = batchSize * channels * height * width;
197 //// copy X into Y
198 out << SP << "constexpr int " << OpName << "_N =" << batchSize * channels * height * width << ";\n";
199 out << SP << "constexpr int "<<OpName<< "_incx = 1;\n";
200 out << SP << "constexpr int "<<OpName<< "_incy = 1;\n";
201 out << SP << "BLAS::scopy_(&" << OpName << "_N, " << "tensor_" << fNX << ", &" << OpName << "_incx," << "tensor_" << fNY << ", &" << OpName << "_incy);\n\n";
203 //// blas saxpy (Y = -Bmean + Y)
204 out << SP << "float "<<OpName<< "_alpha = -1;\n";
205 out << SP << "BLAS::saxpy_(&" << OpName << "_N, &" << OpName << "_alpha, " << "tensor_" << fNMean << ", &" << OpName << "_incx,"
206 << "tensor_" << fNY <<", &" << OpName << "_incy);\n\n ";
208 //// Y *= scale*var
209 out << SP << "for (size_t i = 0; i < " << n << "; i++) {\n";
210 // scale tensor contains already the var
211 out << SP << SP << "tensor_" << fNY << "[i] *= tensor_" << fNScale << "[i]; \n";
212 out << SP << "}\n";
214 //// blas saxpy (Y = Bbias + Y)
215 out << SP <<OpName<< "_alpha = 1;\n";
216 out << SP << "BLAS::saxpy_(&" << OpName << "_N, &" << OpName << "_alpha, " << "tensor_" << fNB << ", &" << OpName << "_incx, "
217 << "tensor_" << fNY << ", &" << OpName << "_incy);\n\n";
219 return out.str();
220 }
/// List the BLAS routine names this operator reports: "Copy" and "Axpy".
/// The code produced by Generate() calls BLAS::scopy_ and BLAS::saxpy_.
std::vector<std::string> GetBlasRoutines() {
   std::vector<std::string> routines;
   routines.reserve(2);
   routines.emplace_back("Copy");
   routines.emplace_back("Axpy");
   return routines;
}
230#endif //TMVA_SOFIE_ROPERATOR_BatchNormalization
