46 std::string nameX, std::string nameScale, std::string nameB,
57 if(std::is_same<T, float>::value){
62 std::runtime_error(
"TMVA SOFIE Encountered unsupported type parsing a BatchNormalization operator");
89 if (!model.CheckIfTensorAlreadyExist(
fNX)) {
91 std::runtime_error(
"TMVA SOFIE BatchNormalization op Input Tensor " +
fNX +
" fnx is not found in model");
93 if (!model.CheckIfTensorAlreadyExist(
fNScale)) {
95 std::runtime_error(
"TMVA SOFIE BatchNormalization op Input Tensor " +
fNScale +
" fns is not found in model");
97 if (!model.CheckIfTensorAlreadyExist(
fNB)) {
99 std::runtime_error(
"TMVA SOFIE BatchNormalization op Input Tensor " +
fNB +
" fnb is not found in model");
101 if (!model.CheckIfTensorAlreadyExist(
fNMean)) {
103 std::runtime_error(
"TMVA SOFIE BatchNormalization op Input Tensor " +
fNMean +
" fnm is not found in model");
105 if (!model.CheckIfTensorAlreadyExist(
fNVar)) {
107 std::runtime_error(
"TMVA SOFIE BatchNormalization op Input Tensor " +
fNVar +
" fnv is not found in model");
118 model.AddIntermediateTensor(
fNY, model.GetTensorType(
fNX),
fShapeY);
120 auto original_S = model.GetInitializedTensorData(
fNScale);
121 auto original_V = model.GetInitializedTensorData(
fNVar);
123 auto shape_S = model.GetTensorShape(
fNScale);
124 if (shape_S.size() != 1) {
125 throw std::runtime_error(
"TMVA SOFIE BatchNormalization 'scale' tensor must be 1D (per-channel).");
127 size_t channels = shape_S[0];
129 if (
fType ==
"float") {
130 float *original_scale_ptr =
static_cast<float *
>(original_S.get());
131 float *original_var_ptr =
static_cast<float *
>(original_V.get());
132 float *fused_scale_data =
new float[channels];
134 for (
size_t i = 0; i < channels; i++) {
136 fused_scale_data[i] = original_scale_ptr[i] / std::sqrt(original_var_ptr[i] +
fepsilon);
139 std::shared_ptr<void> fused_scale_ptr(fused_scale_data, std::default_delete<
float[]>());
140 model.AddInitializedTensor(
fNFusedScale, model.GetTensorType(
fNScale), {channels}, fused_scale_ptr);