ROperator_Gemm.hxx
#ifndef TMVA_SOFIE_ROPERATOR_GEMM
#define TMVA_SOFIE_ROPERATOR_GEMM


#include "TMVA/SOFIE_common.hxx"
#include "TMVA/ROperator.hxx"
#include "TMVA/RModel.hxx"

#include <sstream>
#include <algorithm>
#include <iterator>
#include <iomanip>
#include <limits>
#include <cassert>

namespace TMVA{
namespace Experimental{
namespace SOFIE{

   template <typename T>
   class ROperator_Gemm final : public ROperator
   {

   private:
      float fAttrAlpha = 1.0;
      float fAttrBeta = 1.0;
      int_t fAttrTransA = 0;
      int_t fAttrTransB = 0;

      std::string fNA;
      std::string fNB;
      std::string fNC = "";
      std::string fNC2; // bias tensor name after broadcasting
      std::string fNY;
      std::vector<size_t> fShapeA;
      std::vector<size_t> fShapeB;
      std::vector<size_t> fShapeC;
      std::vector<size_t> fShapeY;

      std::string fType;

   public:

      ROperator_Gemm(float alpha, float beta, int_t transA, int_t transB, std::string nameA, std::string nameB, std::string nameY):
         fAttrAlpha(alpha), fAttrBeta(beta), fAttrTransA(transA), fAttrTransB(transB), fNA(UTILITY::Clean_name(nameA)),
         fNB(UTILITY::Clean_name(nameB)), fNY(UTILITY::Clean_name(nameY)) {

         if (std::is_same<T, float>::value) {
            fType = "float";
         } else {
            throw std::runtime_error("TMVA SOFIE Encountered unsupported type parsing a gemm operator");
         }
      }

      ROperator_Gemm(float alpha, float beta, int_t transA, int_t transB, std::string nameA, std::string nameB, std::string nameC, std::string nameY):
         fAttrAlpha(alpha), fAttrBeta(beta), fAttrTransA(transA), fAttrTransB(transB), fNA(UTILITY::Clean_name(nameA)),
         fNB(UTILITY::Clean_name(nameB)), fNC(UTILITY::Clean_name(nameC)), fNY(UTILITY::Clean_name(nameY)) {

         if (std::is_same<T, float>::value) {
            fType = "float";
         } else {
            throw std::runtime_error("TMVA SOFIE Encountered unsupported type parsing a gemm operator");
         }
      }

      std::vector<ETensorType> TypeInference(std::vector<ETensorType> input){
         ETensorType out = input[0];
         return {out};
      }

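      // Shape inference for Y = alpha * op(A) * op(B) + beta * C, where op() is the optional transpose.
      // Illustrative example: A of shape (3, 4) and B of shape (4, 5) give Y of shape (3, 5);
      // with transB = 1 a B stored as (5, 4) is first reversed to (4, 5) before s_b[1] is read.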
      std::vector<std::vector<size_t>> ShapeInference(std::vector<std::vector<size_t>> input){
         if (input.size() > 3) throw std::runtime_error("TMVA SOFIE Gemm Op Shape Inference only need 2 or 3 input tensor");
         for (auto& i: input){
            if (i.size() > 2){
               throw std::runtime_error("TMVA SOFIE Gemm Op Shape Inference only accept input tensor with 2 dimensions");
            }
         }
         std::vector<std::vector<size_t>> ret;
         if (input.size() == 3){
            ret.push_back(input[2]); //shape of C is shape of Y
            return ret;
         }
         std::vector<size_t> s_a(input[0]);
         std::vector<size_t> s_b(input[1]);
         if (fAttrTransA){
            std::reverse(s_a.begin(), s_a.end());
         }
         if (fAttrTransB){
            std::reverse(s_b.begin(), s_b.end());
         }
         std::vector<size_t> s_y(2);
         s_y[0] = s_a[0];
         s_y[1] = s_b[1];
         ret.push_back(s_y);
         return ret;
      }

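      // Initialize checks that A, B (and the optional bias C) already exist in the RModel,
      // promotes a 1-dimensional A to shape (1, K), computes the output shape fShapeY and
      // prepares the bias: without a session the bias is broadcast immediately and the
      // initialized tensor is updated in place; with a session an intermediate tensor named
      // <fNC>bcast is registered and the broadcast code is emitted later by GenerateInitCode().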
      void Initialize(RModel& model){
         //TODO: propagate A or B as specified by ONNX standard

         if ((model.CheckIfTensorAlreadyExist(fNA) == false) || (model.CheckIfTensorAlreadyExist(fNB) == false)) {   // input must be a graph input, or already initialized intermediate tensor
            throw std::runtime_error("TMVA SOFIE Gemm Op Input Tensor " + fNA + " or " + fNB + " is not found in model");
         }
         if (fNC != ""){
            if (model.CheckIfTensorAlreadyExist(fNC) == false){   // input must be a graph input, or already initialized intermediate tensor
               throw std::runtime_error("TMVA SOFIE Gemm Op Input Tensor " + fNC + " is not found in model");
            }
         }
         fShapeA = model.GetTensorShape(fNA);
         if (fShapeA.size() != 2){
            if (fShapeA.size() == 1)
               fShapeA = {1, fShapeA[0]};
            else
               throw std::runtime_error("TMVA SOFIE Gemm Op Input Tensor " + fNA +
                                        " is not of 2 dimensions: A " + ConvertShapeToString(fShapeA));
         }
         fShapeB = model.GetTensorShape(fNB);
         if (fShapeB.size() != 2){
            throw std::runtime_error("TMVA SOFIE Gemm Op Input Tensor " + fNB + " is not of 2 dimensions: B " + ConvertShapeToString(fShapeB));
         }
         fShapeY = ShapeInference({fShapeA, fShapeB})[0];
         if (fNC != ""){
            fShapeC = model.GetTensorShape(fNC);
            fNC2 = fNC;
            bool broadcast_needed = !UTILITY::AreSameShape(fShapeC, fShapeY);
            // For Gemm broadcasting is not needed if fShapeY[0] == 1 i.e. C and Y have same length
            //if (fShapeY[0] == 1 && ConvertShapeToLength(fShapeC) != ConvertShapeToLength(fShapeY)) {
            //   broadcast_needed = false;
            //}

            // std::cout << "doing broadcast " << broadcast_needed << " use session " << model.UseSession() <<
            //    " shape C " << ConvertShapeToString(fShapeC) << " shape Y " << ConvertShapeToString(fShapeY)
            //    << std::endl;

            if (broadcast_needed) {
               if (!model.UseSession()) {
                  auto original_data = model.GetInitializedTensorData(fNC);
                  auto targetShape = UTILITY::UnidirectionalBroadcastShape(fShapeC, fShapeY);
                  if (fType == "float") {
                     std::shared_ptr<void> new_data_ptr(UTILITY::UnidirectionalBroadcast<float>(
                                                           static_cast<float *>(original_data.get()), fShapeC, targetShape),
                                                        std::default_delete<float[]>());

                     model.UpdateInitializedTensor(fNC, model.GetTensorType(fNC), fShapeY, new_data_ptr);
                  }
               } else {
                  // In case of session add broadcasting code in Session constructor and in GenerateInitCode
                  // we need to add a new intermediate tensor for broadcasted bias tensor
                  fNC2 = fNC + "bcast";
                  model.AddIntermediateTensor(fNC2, model.GetTensorType(fNC), fShapeY);
               }
            }
         }

         model.AddIntermediateTensor(fNY, model.GetTensorType(fNA), fShapeY);
         model.AddNeededStdLib("algorithm");
      }

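      // GenerateInitCode emits the session-initialization code that broadcasts the bias tensor C
      // to the output shape and stores the result in the intermediate tensor tensor_<fNC>bcast.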
      std::string GenerateInitCode()
      {
         std::stringstream out;
         // generate initialization code for broadcasting of bias tensor
         if (fShapeC.size() != fShapeY.size() && fNC != fNC2) {
            auto targetShape = fShapeY;
            // include a separate scope to avoid defining unique operator temp variables
            out << SP << "{\n";
            out << "      float * data = TMVA::Experimental::SOFIE::UTILITY::UnidirectionalBroadcast<float>(tensor_"
                << fNC << "," << ConvertShapeToString(fShapeC) << ", " << ConvertShapeToString(targetShape) << ");\n";
            auto length = ConvertShapeToLength(targetShape);
            out << SP << SP << "std::copy(data, data + " << length << ", tensor_" << fNC2 << ");\n";
            out << SP << SP << "delete [] data;\n";
            out << SP << "}\n";
         }
         return out.str();
      }

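      // Generate emits the inference code for this operator: it defines the BLAS parameters,
      // copies the (possibly broadcast) bias into the output tensor Y, and then computes
      // Y = alpha * op(A) * op(B) + beta * Y with a single sgemm call.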
      std::string Generate(std::string OpName){
         OpName = "op_" + OpName;

         if (fShapeA.empty() || fShapeB.empty() || fShapeY.empty() || (fNC != "" && fShapeC.empty())) {
            throw std::runtime_error("TMVA SOFIE Gemm Op called to Generate without being initialized first");
         }
         std::stringstream out;
         out << "\n//--------- Gemm\n";
         out << SP << "char " << OpName << "_transA = " << (fAttrTransA ? "\'t\'" : "\'n\'") << ";\n";
         out << SP << "char " << OpName << "_transB = " << (fAttrTransB ? "\'t\'" : "\'n\'") << ";\n";
         int m = (fAttrTransA ? fShapeA[1] : fShapeA[0]);
         int n = (fAttrTransB ? fShapeB[0] : fShapeB[1]);
         int k = (fAttrTransA ? fShapeA[0] : fShapeA[1]);
         out << SP << "int " << OpName << "_m = " << m << ";\n";
         out << SP << "int " << OpName << "_n = " << n << ";\n";
         out << SP << "int " << OpName << "_k = " << k << ";\n";
         out << SP << "float " << OpName << "_alpha = " << std::setprecision(std::numeric_limits<float>::max_digits10) << fAttrAlpha << ";\n";
         out << SP << "float " << OpName << "_beta = " << std::setprecision(std::numeric_limits<float>::max_digits10) << fAttrBeta << ";\n";
         out << SP << "int " << OpName << "_lda = " << (fAttrTransA ? m : k) << ";\n";
         out << SP << "int " << OpName << "_ldb = " << (fAttrTransB ? k : n) << ";\n";
         if (fNC != ""){
            int length = ConvertShapeToLength(fShapeY);
            if (fNC2 == fNC)
               // case broadcasting was not needed or done outside of session
               out << SP << "std::copy(" << "tensor_" << fNC << ", " << "tensor_" << fNC << " + " << length << ", " << "tensor_" << fNY << ");\n";
            else
               // the bias was broadcast into tensor_<fNC2> by the session initialization code
               out << SP << "std::copy(" << "tensor_" << fNC2 << ", " << "tensor_" << fNC2 << " + " << length << ", " << "tensor_" << fNY << ");\n";
         } else {
            // in this case fAttrBeta needs to be equal to zero, otherwise the second time we run
            // we will use the previous result
            if (fAttrBeta != 0) {
               throw std::runtime_error("TMVA SOFIE Gemm Op : Bias tensor is not present but beta value in Gemm is not zero");
            }
         }
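         // SOFIE stores tensors row-major while BLAS gemm assumes column-major layout: passing B
         // before A (and n before m) makes sgemm compute Y^T = op(B)^T * op(A)^T in column-major
         // terms, which is exactly the row-major Y = op(A) * op(B).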
         if (fType == "float"){
            out << SP << "BLAS::sgemm_(&" << OpName << "_transB, &" << OpName << "_transA, &" << OpName
                << "_n, &" << OpName << "_m, &" << OpName << "_k, &" << OpName << "_alpha, " << "tensor_" << fNB
                << ", &" << OpName << "_ldb, " << "tensor_" << fNA << ", &" << OpName << "_lda, &" << OpName << "_beta, " << "tensor_" << fNY << ", &"
                << OpName << "_n);\n";
         }

         return out.str();
      }

      std::vector<std::string> GetBlasRoutines() { return { std::string("Gemm"), std::string("Gemv") }; }

   };


}//SOFIE
}//Experimental
}//TMVA


#endif //TMVA_SOFIE_ROPERATOR_GEMM
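
For illustration, here is a rough sketch of the code that Generate("1") would emit for a Gemm node with A of shape (2, 3), B of shape (3, 4), a bias C broadcast to shape (2, 4) and default attributes (alpha = 1, beta = 1, no transposition). The tensor names tensor_A, tensor_B, tensor_C and tensor_Y are hypothetical placeholders for the cleaned ONNX tensor names, and the exact formatting of the generated text may differ:

//--------- Gemm
   char op_1_transA = 'n';
   char op_1_transB = 'n';
   int op_1_m = 2;
   int op_1_n = 4;
   int op_1_k = 3;
   float op_1_alpha = 1;
   float op_1_beta = 1;
   int op_1_lda = 3;
   int op_1_ldb = 4;
   std::copy(tensor_C, tensor_C + 8, tensor_Y);
   BLAS::sgemm_(&op_1_transB, &op_1_transA, &op_1_n, &op_1_m, &op_1_k, &op_1_alpha, tensor_B, &op_1_ldb, tensor_A, &op_1_lda, &op_1_beta, tensor_Y, &op_1_n);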