1#ifndef TMVA_SOFIE_ROPERATOR_GEMM
2#define TMVA_SOFIE_ROPERATOR_GEMM
17namespace Experimental{
48 fNB(UTILITY::Clean_name(nameB)),
fNY(UTILITY::Clean_name(nameY)) {
50 if (std::is_same<T, float>::value) {
53 throw std::runtime_error(
"TMVA SOFIE Encountered unsupported type parsing a gemm operator");
57 ROperator_Gemm(
float alpha,
float beta,
int_t transA,
int_t transB, std::string nameA, std::string nameB, std::string nameC, std::string nameY):
59 fNB(UTILITY::Clean_name(nameB)),
fNC(UTILITY::Clean_name(nameC)),
fNY(UTILITY::Clean_name(nameY)) {
61 if (std::is_same<T, float>::value) {
64 throw std::runtime_error(
"TMVA SOFIE Encountered unsupported type parsing a gemm operator");
74 if (
input.size() > 3)
throw std::runtime_error(
"TMVA SOFIE Gemm Op Shape Inference only need 2 or 3 input tensor");
77 throw std::runtime_error(
"TMVA SOFIE Gemm Op Shape Inference only accept input tensor with 2 dimensions");
80 std::vector<std::vector<size_t>> ret;
81 if (
input.size() == 3){
82 ret.push_back(
input[2]);
85 std::vector<size_t> s_a(
input[0]);
86 std::vector<size_t> s_b(
input[1]);
88 std::reverse(s_a.begin(), s_a.end());
91 std::reverse(s_b.begin(), s_b.end());
93 std::vector<size_t> s_y(2);
106 throw std::runtime_error(
"TMVA SOFIE Gemm Op Input Tensor " +
fNA +
" or " +
fNB +
" is not found in model");
110 throw std::runtime_error(
"TMVA SOFIE Gemm Op Input Tensor" +
fNC +
" is not found in model");
118 throw std::runtime_error(
"TMVA SOFIE Gemm Op Input Tensor" +
fNA +
139 if (broadcast_needed) {
143 if (
fType ==
"float") {
144 std::shared_ptr<void> new_data_ptr(UTILITY::UnidirectionalBroadcast<float>(
145 static_cast<float *
>(original_data.get()),
fShapeC, targetShape),
146 std::default_delete<
float[]>());
170 std::stringstream out;
176 out <<
" float * data = TMVA::Experimental::SOFIE::UTILITY::UnidirectionalBroadcast<float>(tensor_"
179 out <<
SP <<
SP <<
"std::copy(data, data + " <<
length <<
", tensor_" <<
fNC2 <<
");\n";
180 out <<
SP <<
SP <<
"delete [] data;\n";
187 OpName =
"op_" + OpName;
190 throw std::runtime_error(
"TMVA SOFIE Gemm Op called to Generate without being initialized first");
192 std::stringstream out;
193 out <<
"\n//--------- Gemm\n";
194 out <<
SP <<
"char " << OpName <<
"_transA = " << (
fAttrTransA ?
"\'t\'" :
"\'n\'") <<
";\n";
195 out <<
SP <<
"char " << OpName <<
"_transB = " << (
fAttrTransB ?
"\'t\'" :
"\'n\'") <<
";\n";
199 out <<
SP <<
"int " << OpName <<
"_m = " <<
m <<
";\n";
200 out <<
SP <<
"int " << OpName <<
"_n = " <<
n <<
";\n";
201 out <<
SP <<
"int " << OpName <<
"_k = " << k <<
";\n";
202 out <<
SP <<
"float " << OpName <<
"_alpha = " << std::setprecision(std::numeric_limits<float>::max_digits10) <<
fAttrAlpha <<
";\n";
203 out <<
SP <<
"float " << OpName <<
"_beta = " << std::setprecision(std::numeric_limits<float>::max_digits10) <<
fAttrBeta <<
";\n";
204 out <<
SP <<
"int " << OpName <<
"_lda = " << (
fAttrTransA ?
m : k) <<
";\n";
205 out <<
SP <<
"int " << OpName <<
"_ldb = " << (
fAttrTransB ? k :
n) <<
";\n";
211 out <<
SP <<
"std::copy(" <<
"tensor_" <<
fNC2 <<
", " <<
"tensor_" <<
fNC2 <<
" + " <<
length <<
", " <<
"tensor_" <<
fNY <<
");\n";
216 throw std::runtime_error(
"TMVA SOFIE Gemm Op : Bias tensor is not present but beta value in Gemm is not zero");
219 if (
fType ==
"float"){
220 out <<
SP <<
"BLAS::sgemm_(&" << OpName <<
"_transB, &" << OpName <<
"_transA, &" << OpName
221 <<
"_n, &" << OpName <<
"_m, &" << OpName <<
"_k, &" << OpName <<
"_alpha, " <<
"tensor_" <<
fNB
222 <<
", &" << OpName <<
"_ldb, " <<
"tensor_" <<
fNA <<
", &" << OpName <<
"_lda, &" << OpName <<
"_beta, " <<
"tensor_" <<
fNY <<
", &"
223 << OpName <<
"_n);\n";
/// Lists the BLAS routine names reported by this operator.
///
/// @return a two-element vector holding "Gemm" followed by "Gemv".
std::vector<std::string> GetBlasRoutines()
{
   return {std::string("Gemm"), std::string("Gemv")};
}
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void input
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void char Point_t Rectangle_t WindowAttributes_t Float_t Float_t Float_t Int_t Int_t UInt_t UInt_t Rectangle_t Int_t Int_t Window_t TString Int_t GCValues_t GetPrimarySelectionOwner GetDisplay GetScreen GetColormap GetNativeEvent const char const char dpyName wid window const char font_name cursor keysym reg const char only_if_exist regb h Point_t winding char text const char depth char const char Int_t count const char ColorStruct_t color const char Pixmap_t Pixmap_t PictureAttributes_t attr const char char ret_data h unsigned char height h length
void AddNeededStdLib(std::string libname)
const ETensorType & GetTensorType(std::string name)
void AddIntermediateTensor(std::string tensor_name, ETensorType type, std::vector< std::size_t > shape)
bool CheckIfTensorAlreadyExist(std::string tensor_name)
const std::vector< size_t > & GetTensorShape(std::string name)
std::shared_ptr< void > GetInitializedTensorData(std::string tensor_name)
void UpdateInitializedTensor(std::string tensor_name, ETensorType type, std::vector< std::size_t > shape, std::shared_ptr< void > data)
ROperator_Gemm(float alpha, float beta, int_t transA, int_t transB, std::string nameA, std::string nameB, std::string nameC, std::string nameY)
std::vector< size_t > fShapeY
std::string GenerateInitCode()
std::vector< size_t > fShapeC
std::string Generate(std::string OpName)
std::vector< std::string > GetBlasRoutines()
std::vector< std::vector< size_t > > ShapeInference(std::vector< std::vector< size_t > > input)
void Initialize(RModel &model)
std::vector< size_t > fShapeA
ROperator_Gemm(float alpha, float beta, int_t transA, int_t transB, std::string nameA, std::string nameB, std::string nameY)
std::vector< size_t > fShapeB
std::vector< ETensorType > TypeInference(std::vector< ETensorType > input)
const std::string SP
space used to correctly indent the generated C++ code
bool AreSameShape(const std::vector< size_t > &, const std::vector< size_t > &)
std::vector< size_t > UnidirectionalBroadcastShape(std::vector< size_t >, std::vector< size_t >)
std::string ConvertShapeToString(std::vector< size_t > shape)
std::size_t ConvertShapeToLength(std::vector< size_t > shape)
create variable transformations