Logo ROOT  
Reference Guide
 
Loading...
Searching...
No Matches
RModel_Base.cxx
Go to the documentation of this file.
1#include <limits>
2#include <algorithm>
3#include <cctype>
5
6namespace TMVA {
7namespace Experimental {
8namespace SOFIE {
9
10RModel_Base::RModel_Base(std::string name, std::string parsedtime):fFileName(name), fParseTime(parsedtime) {
11 fName = fFileName.substr(0, fFileName.rfind("."));
13}
14
15void RModel_Base::GenerateHeaderInfo(std::string& hgname) {
16 fGC += ("//Code generated automatically by TMVA for Inference of Model file [" + fFileName + "] at [" + fParseTime.substr(0, fParseTime.length()-1) +"] \n");
17 // add header guards
18 hgname = fName;
19 std::transform(hgname.begin(), hgname.end(), hgname.begin(), [](unsigned char c) {
20 return std::toupper(c);
21 } );
22 hgname = "ROOT_TMVA_SOFIE_" + hgname;
23 fGC += "\n#ifndef " + hgname + "\n";
24 fGC += "#define " + hgname + "\n\n";
25 for (auto& i: fNeededStdLib) {
26 fGC += "#include <" + i + ">\n";
27 }
28 for (auto& i: fCustomOpHeaders) {
29 fGC += "#include \"" + i + "\"\n";
30 }
31 // for the session we need to include SOFIE_Common functions
32 //needed for convolution operator (need to add a flag)
33 fGC += "#include \"TMVA/SOFIE_common.hxx\"\n";
35 fGC += "#include <fstream>\n";
36 // Include TFile when saving the weights in a binary ROOT file
38 fGC += "#include \"TFile.h\"\n";
39
40 fGC += "\nnamespace TMVA_SOFIE_" + fName + "{\n";
41 if (!fNeededBlasRoutines.empty()) {
42 fGC += ("namespace BLAS{\n");
43 for (auto &routine : fNeededBlasRoutines) {
44 if (routine == "Gemm") {
45 fGC += ("\textern \"C\" void sgemm_(const char * transa, const char * transb, const int * m, const int * n, const int * k,\n"
46 "\t const float * alpha, const float * A, const int * lda, const float * B, const int * ldb,\n"
47 "\t const float * beta, float * C, const int * ldc);\n");
48 } else if (routine == "Gemv") {
49 fGC += ("\textern \"C\" void sgemv_(const char * trans, const int * m, const int * n, const float * alpha, const float * A,\n"
50 "\t const int * lda, const float * X, const int * incx, const float * beta, const float * Y, const int * incy);\n");
51 } else if (routine == "Axpy") {
52 fGC += ("\textern \"C\" void saxpy_(const int * n, const float * alpha, const float * x,\n"
53 "\t const int * incx, float * y, const int * incy);\n");
54 } else if (routine == "Copy") {
55 fGC += ("\textern \"C\" void scopy_(const int *n, const float* x, const int *incx, float* y, const int* incy);\n");
56 }
57 }
58 fGC += ("}//BLAS\n");
59 }
60}
61
62void RModel_Base::OutputGenerated(std::string filename, bool append) {
63 // the model can be appended only if a file name is provided
64 if (filename.empty()) {
65 // if a file is pr
66 filename = fName + ".hxx";
67 append = false;
68 }
69 std::ofstream f;
70 if (append)
71 f.open(filename, std::ios_base::app);
72 else
73 f.open(filename);
74 if (!f.is_open()) {
75 throw std::runtime_error("tmva-sofie failed to open file for output generated inference code");
76 }
77 f << fGC;
78 f.close();
79}
80
81}//SOFIE
82}//Experimental
83}//TMVA
#define f(i)
Definition RSha256.hxx:104
#define c(i)
Definition RSha256.hxx:101
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void char Point_t Rectangle_t WindowAttributes_t Float_t Float_t Float_t Int_t Int_t UInt_t UInt_t Rectangle_t Int_t Int_t Window_t TString Int_t GCValues_t GetPrimarySelectionOwner GetDisplay GetScreen GetColormap GetNativeEvent const char const char dpyName wid window const char font_name cursor keysym reg const char only_if_exist regb h Point_t winding char text const char depth char const char Int_t count const char ColorStruct_t color const char filename
char name[80]
Definition TGX11.cxx:110
void GenerateHeaderInfo(std::string &hgname)
RModel_Base()=default
Default constructor.
std::unordered_set< std::string > fNeededBlasRoutines
std::unordered_set< std::string > fCustomOpHeaders
void OutputGenerated(std::string filename="", bool append=false)
std::unordered_set< std::string > fNeededStdLib
std::string Clean_name(std::string input_tensor_name)
create variable transformations