Logo ROOT  
Reference Guide
 
Loading...
Searching...
No Matches
ROperator_BatchNormalization.hxx
Go to the documentation of this file.
1#ifndef TMVA_SOFIE_ROPERATOR_BatchNormalization
2#define TMVA_SOFIE_ROPERATOR_BatchNormalization
3
#include "SOFIE_common.hxx"
#include "ROperator.hxx"
#include "RModel.hxx"

#include <algorithm>
#include <cmath>
#include <limits>
#include <sstream>
#include <string>
11
12namespace TMVA{
13namespace Experimental{
14namespace SOFIE{
15
16template <typename T>
18{
19
20private:
21
22 /* Attributes */
23 float fepsilon = 1e-05;
24 float fmomentum = 0.9;
25 std::size_t ftraining_mode = 0;
26
27 std::string fNX;
28 std::string fNScale;
29 std::string fNB;
30 std::string fNMean;
31 std::string fNVar;
32 std::string fNY;
33
34 std::vector<size_t> fShapeX;
35 std::vector<size_t> fShapeScale;
36 std::vector<size_t> fShapeB;
37 std::vector<size_t> fShapeMean;
38 std::vector<size_t> fShapeVar;
39 std::vector<size_t> fShapeY;
40
41 std::string fType;
42
43public:
45
46 /* Constructor */
47 ROperator_BatchNormalization( float epsilon, float momentum, std::size_t training_mode,
48 std::string nameX, std::string nameScale, std::string nameB,
49 std::string nameMean, std::string nameVar, std::string nameY):
50 fepsilon(epsilon), fmomentum(momentum), ftraining_mode(training_mode),
51 fNX(UTILITY::Clean_name(nameX)), fNScale(UTILITY::Clean_name(nameScale)),
52 fNB(UTILITY::Clean_name(nameB)), fNMean(UTILITY::Clean_name(nameMean)),
53 fNVar(UTILITY::Clean_name(nameVar)), fNY(UTILITY::Clean_name(nameY))
54 {
55 if(std::is_same<T, float>::value){
56 fType = "float";
57 }
58 else{
59 throw
60 std::runtime_error("TMVA SOFIE Encountered unsupported type parsing a BatchNormalization operator");
61 }
62 }
63
64
65 std::vector<ETensorType> TypeInference(std::vector<ETensorType> input) {
66 ETensorType out = input[0];
67 return {out};
68 }
69
70 std::vector<std::vector<size_t>> ShapeInference(std::vector<std::vector<size_t>> input) {
71 if (input.size() != 5 ) {
72 throw
73 std::runtime_error("TMVA SOFIE BatchNormalization Op Shape inference need 5 input tensors");
74 }
75 for(size_t i = 0; i < input.size(); i++) {
76 if (input[i].size() != 4) {
77 throw
78 std::runtime_error("TMVA SOFIE BatchNormalization Op Shape inference only accept tensor with 4 dimensions");
79 }
80 }
81
82 auto ret = input;
83 return ret;
84 }
85
86 void Initialize(RModel& model){
87 if (!model.CheckIfTensorAlreadyExist(fNX)) {
88 throw
89 std::runtime_error("TMVA SOFIE BatchNormalization op Input Tensor " + fNX + " fnx is not found in model");
90 }
92 throw
93 std::runtime_error("TMVA SOFIE BatchNormalization op Input Tensor " + fNScale + " fns is not found in model");
94 }
95 if (!model.CheckIfTensorAlreadyExist(fNB)) {
96 throw
97 std::runtime_error("TMVA SOFIE BatchNormalization op Input Tensor " + fNB + " fnb is not found in model");
98 }
100 throw
101 std::runtime_error("TMVA SOFIE BatchNormalization op Input Tensor " + fNMean + " fnm is not found in model");
102 }
103 if (!model.CheckIfTensorAlreadyExist(fNVar)) {
104 throw
105 std::runtime_error("TMVA SOFIE BatchNormalization op Input Tensor " + fNVar + " fnv is not found in model");
106 }
107
108 fShapeX = model.GetTensorShape(fNX);
109
110 if (fShapeX.size() < 2 || fShapeX.size() > 4) {
111 throw
112 std::runtime_error("TMVA SOFIE BatchNormalization Op input tensor " + fNX + " fnx has wrong shape : " + ConvertShapeToString(fShapeX));
113 }
114
116 fShapeB = model.GetTensorShape(fNB);
121
122 if (fShapeB.size() == 1) {
123 // Broadcast scale, bias, input_mean and input_var to shape_X
124 auto original_B = model.GetInitializedTensorData(fNB);
125 auto original_S = model.GetInitializedTensorData(fNScale);
126 auto original_M = model.GetInitializedTensorData(fNMean);
127 auto original_V = model.GetInitializedTensorData(fNVar);
128 size_t batchSize = fShapeX[0];
129 size_t channels = fShapeX[1];
130 size_t height = (fShapeX.size() > 2) ? fShapeX[2] : 1;
131 size_t width = (fShapeX.size() > 3) ? fShapeX[3] : 1;
132 size_t n = batchSize * channels * height * width;
133 if (fType == "float") {
134 float *original_bias = static_cast<float*>(original_B.get());
135 float *original_scale = static_cast<float*>(original_S.get());
136 float *original_mean = static_cast<float*>(original_M.get());
137 float *original_var = static_cast<float*>(original_V.get());
138 float *new_bias = new float[n];
139 float *new_scale = new float[n];
140 float *new_mean = new float[n];
141 float *new_var = new float[n];
142 size_t bs = 0, ch = 0, h = 0, w = 0;
143 for(ch=0; ch<channels; ch++){
144 for(h=0; h<height; h++){
145 for(w=0; w<width; w++){
146 new_bias[bs*channels*height*width + ch*height*width + h*width + w] = original_bias[ch];
147 new_scale[bs*channels*height*width + ch*height*width + h*width + w] = original_scale[ch];
148 new_mean[bs*channels*height*width + ch*height*width + h*width + w] = original_mean[ch];
149 new_var[bs*channels*height*width + ch*height*width + h*width + w] = original_var[ch];
150 }
151 }
152 }
153 size_t Batchoffset = channels*height*width;
154 for(bs = 1; bs<batchSize; bs++){
155 std::copy(new_bias, new_bias+Batchoffset, new_bias+(bs*Batchoffset));
156 std::copy(new_scale, new_scale+Batchoffset, new_scale+(bs*Batchoffset));
157 std::copy(new_mean, new_mean+Batchoffset, new_mean+(bs*Batchoffset));
158 std::copy(new_var, new_var+Batchoffset, new_var+(bs*Batchoffset));
159 }
160 //// new_var =1. / sqrt(input_var + fepsilon)
161 for(size_t i=0; i<n; i++){
162 new_var[i] = 1./sqrt(new_var[i] + fepsilon);
163 }
164 std::vector<size_t> new_bias_shape = {batchSize,channels,height,width};
165 std::shared_ptr<void> new_bias_ptr(new_bias, std::default_delete<float[]>());
166 std::shared_ptr<void> new_scale_ptr(new_scale, std::default_delete<float[]>());
167 std::shared_ptr<void> new_mean_ptr(new_mean, std::default_delete<float[]>());
168 std::shared_ptr<void> new_var_ptr(new_var, std::default_delete<float[]>());
169 model.UpdateInitializedTensor(fNB, model.GetTensorType(fNB), new_bias_shape, new_bias_ptr);
170 model.UpdateInitializedTensor(fNScale, model.GetTensorType(fNScale), new_bias_shape, new_scale_ptr);
171 model.UpdateInitializedTensor(fNMean, model.GetTensorType(fNMean), new_bias_shape, new_mean_ptr);
172 model.UpdateInitializedTensor(fNVar, model.GetTensorType(fNVar), new_bias_shape, new_var_ptr);
173 fShapeB = model.GetTensorShape(fNB);
177 }
178 }
179 }
180
181
182 std::string Generate(std::string OpName){
183 OpName = "op_" + OpName;
184 if (fShapeX.empty()){
185 throw std::runtime_error("TMVA SOFIE Batch Normalization called to Generate without being initialized first");
186 }
187
188 std::stringstream out;
189 //// Batch Norm op
190 size_t batchSize = fShapeX[0];
191 size_t channels = fShapeX[1];
192 size_t height = (fShapeX.size() > 2) ? fShapeX[2] : 1;
193 size_t width = (fShapeX.size() > 3) ? fShapeX[3] : 1;
194 size_t n = batchSize * channels * height * width;
195
196 //// copy X into Y
197 out << SP << "constexpr int " << OpName << "_N =" << batchSize * channels * height * width << ";\n";
198 out << SP << "constexpr int "<<OpName<< "_incx = 1;\n";
199 out << SP << "constexpr int "<<OpName<< "_incy = 1;\n";
200 out << SP << "BLAS::scopy_(&" << OpName << "_N, " << "tensor_" << fNX << ", &" << OpName << "_incx," << "tensor_" << fNY << ", &" << OpName << "_incy);\n\n";
201
202 //// blas saxpy (Y = -Bmean + Y)
203 out << SP << "float "<<OpName<< "_alpha = -1;\n";
204 out << SP << "BLAS::saxpy_(&" << OpName << "_N, &" << OpName << "_alpha, " << "tensor_" << fNMean << ", &" << OpName << "_incx,"
205 << "tensor_" << fNY <<", &" << OpName << "_incy);\n\n ";
206
207 //// Y *= scale*var
208 out << SP << "for (size_t i = 0; i < " << n << "; i++) {\n";
209 out << SP << SP << "tensor_" << fNY << "[i] *= tensor_" << fNScale << "[i] * tensor_" << fNVar << "[i]; \n";
210 out << SP << "}\n";
211
212 //// blas saxpy (Y = Bbias + Y)
213 out << SP <<OpName<< "_alpha = 1;\n";
214 out << SP << "BLAS::saxpy_(&" << OpName << "_N, &" << OpName << "_alpha, " << "tensor_" << fNB << ", &" << OpName << "_incx, "
215 << "tensor_" << fNY << ", &" << OpName << "_incy);\n\n";
216
217 return out.str();
218 }
219
220};
221
222}//SOFIE
223}//Experimental
224}//TMVA
225
226
227#endif //TMVA_SOFIE_ROPERATOR_BatchNormalization
#define h(i)
Definition RSha256.hxx:106
#define e(i)
Definition RSha256.hxx:103
size_t size(const MatrixT &matrix)
retrieve the size of a square matrix
include TDocParser_001 C image html pict1_TDocParser_001 png width
const ETensorType & GetTensorType(std::string name)
Definition RModel.cxx:70
void AddIntermediateTensor(std::string tensor_name, ETensorType type, std::vector< std::size_t > shape)
Definition RModel.cxx:136
bool CheckIfTensorAlreadyExist(std::string tensor_name)
Definition RModel.cxx:91
const std::vector< size_t > & GetTensorShape(std::string name)
Definition RModel.cxx:49
std::shared_ptr< void > GetInitializedTensorData(std::string tensor_name)
Definition RModel.cxx:160
void UpdateInitializedTensor(std::string tensor_name, ETensorType type, std::vector< std::size_t > shape, std::shared_ptr< void > data)
Definition RModel.cxx:151
std::vector< std::vector< size_t > > ShapeInference(std::vector< std::vector< size_t > > input)
std::vector< ETensorType > TypeInference(std::vector< ETensorType > input)
ROperator_BatchNormalization(float epsilon, float momentum, std::size_t training_mode, std::string nameX, std::string nameScale, std::string nameB, std::string nameMean, std::string nameVar, std::string nameY)
const std::string SP
space used to correctly indent the generated C++ code
Definition ROperator.hxx:39
const Int_t n
Definition legend1.C:16
std::string ConvertShapeToString(std::vector< size_t > shape)
create variable transformations
REAL epsilon
Definition triangle.c:618