Logo ROOT  
Reference Guide
 
Loading...
Searching...
No Matches
RModel_Base.cxx
Go to the documentation of this file.
#include <algorithm>
#include <cctype>
#include <fstream>
#include <limits>
#include <set>
#include <stdexcept>
#include <string>
#include <utility>

6namespace TMVA {
7namespace Experimental {
8namespace SOFIE {
9
10RModel_Base::RModel_Base(std::string name, std::string parsedtime):fFileName(name), fParseTime(parsedtime) {
11 fName = fFileName.substr(0, fFileName.rfind("."));
12}
13
14void RModel_Base::GenerateHeaderInfo(std::string& hgname) {
15 fGC += ("//Code generated automatically by TMVA for Inference of Model file [" + fFileName + "] at [" + fParseTime.substr(0, fParseTime.length()-1) +"] \n");
16 // add header guards
17 hgname = fName;
18 std::transform(hgname.begin(), hgname.end(), hgname.begin(), [](unsigned char c) {
19 return std::toupper(c);
20 } );
21 hgname = "ROOT_TMVA_SOFIE_" + hgname;
22 fGC += "\n#ifndef " + hgname + "\n";
23 fGC += "#define " + hgname + "\n\n";
24 for (auto& i: fNeededStdLib) {
25 fGC += "#include <" + i + ">\n";
26 }
27 for (auto& i: fCustomOpHeaders) {
28 fGC += "#include \"" + i + "\"\n";
29 }
30 // for the session we need to include SOFIE_Common functions
31 //needed for convolution operator (need to add a flag)
32 fGC += "#include \"TMVA/SOFIE_common.hxx\"\n";
34 fGC += "#include <fstream>\n";
35 // Include TFile when saving the weights in a binary ROOT file
37 fGC += "#include \"TFile.h\"\n";
38
39 fGC += "\nnamespace TMVA_SOFIE_" + fName + "{\n";
40 if (!fNeededBlasRoutines.empty()) {
41 fGC += ("namespace BLAS{\n");
42 for (auto &routine : fNeededBlasRoutines) {
43 if (routine == "Gemm") {
44 fGC += ("\textern \"C\" void sgemm_(const char * transa, const char * transb, const int * m, const int * n, const int * k,\n"
45 "\t const float * alpha, const float * A, const int * lda, const float * B, const int * ldb,\n"
46 "\t const float * beta, float * C, const int * ldc);\n");
47 } else if (routine == "Gemv") {
48 fGC += ("\textern \"C\" void sgemv_(const char * trans, const int * m, const int * n, const float * alpha, const float * A,\n"
49 "\t const int * lda, const float * X, const int * incx, const float * beta, const float * Y, const int * incy);\n");
50 } else if (routine == "Axpy") {
51 fGC += ("\textern \"C\" void saxpy_(const int * n, const float * alpha, const float * x,\n"
52 "\t const int * incx, float * y, const int * incy);\n");
53 } else if (routine == "Copy") {
54 fGC += ("\textern \"C\" void scopy_(const int *n, const float* x, const int *incx, float* y, const int* incy);\n");
55 }
56 }
57 fGC += ("}//BLAS\n");
58 }
59}
60
61void RModel_Base::OutputGenerated(std::string filename, bool append) {
62 // the model can be appended only if a file name is provided
63 if (filename.empty()) {
64 // if a file is pr
65 filename = fName + ".hxx";
66 append = false;
67 }
68 std::ofstream f;
69 if (append)
70 f.open(filename, std::ios_base::app);
71 else
72 f.open(filename);
73 if (!f.is_open()) {
74 throw std::runtime_error("tmva-sofie failed to open file for output generated inference code");
75 }
76 f << fGC;
77 f.close();
78}
79
80}//SOFIE
81}//Experimental
82}//TMVA
#define f(i)
Definition RSha256.hxx:104
#define c(i)
Definition RSha256.hxx:101
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void char Point_t Rectangle_t WindowAttributes_t Float_t Float_t Float_t Int_t Int_t UInt_t UInt_t Rectangle_t Int_t Int_t Window_t TString Int_t GCValues_t GetPrimarySelectionOwner GetDisplay GetScreen GetColormap GetNativeEvent const char const char dpyName wid window const char font_name cursor keysym reg const char only_if_exist regb h Point_t winding char text const char depth char const char Int_t count const char ColorStruct_t color const char filename
char name[80]
Definition TGX11.cxx:110
void GenerateHeaderInfo(std::string &hgname)
RModel_Base()=default
Default constructor.
std::unordered_set< std::string > fNeededBlasRoutines
std::unordered_set< std::string > fCustomOpHeaders
void OutputGenerated(std::string filename="", bool append=false)
std::unordered_set< std::string > fNeededStdLib
create variable transformations