ROOT
git-r3/HEAD
Reference Guide
Loading...
Searching...
No Matches
RModel_Base.cxx
Go to the documentation of this file.
1
#include <limits>
2
#include <algorithm>
3
#include <cctype>
4
#include "
TMVA/RModel_Base.hxx
"
5
6
namespace
TMVA
{
7
namespace
Experimental
{
8
namespace
SOFIE
{
9
10
/// Construct the model base from the model file name and the parse time-stamp.
///
/// \param name        model file name; stored in fFileName
/// \param parsedtime  time-stamp of the parsing; stored in fParseTime
///
/// The model name fName is derived from the file name: everything from the
/// last '.' onwards is dropped (the full name is kept when there is no '.'),
/// and the result is passed through UTILITY::Clean_name.
RModel_Base::RModel_Base(std::string name, std::string parsedtime)
   : fFileName(name), fParseTime(parsedtime)
{
   // rfind returns npos when no '.' is present; substr(0, npos) then keeps the whole string
   const auto extensionPos = fFileName.rfind(".");
   fName = UTILITY::Clean_name(fFileName.substr(0, extensionPos));
}
14
15
void
RModel_Base::GenerateHeaderInfo
(std::string& hgname) {
16
fGC
+= (
"//Code generated automatically by TMVA for Inference of Model file ["
+
fFileName
+
"] at ["
+
fParseTime
.substr(0,
fParseTime
.length()-1) +
"] \n"
);
17
// add header guards
18
hgname =
fName
;
19
std::transform(hgname.begin(), hgname.end(), hgname.begin(), [](
unsigned
char
c
) {
20
return std::toupper(c);
21
} );
22
hgname =
"ROOT_TMVA_SOFIE_"
+ hgname;
23
fGC
+=
"\n#ifndef "
+ hgname +
"\n"
;
24
fGC
+=
"#define "
+ hgname +
"\n\n"
;
25
for
(
auto
& i:
fNeededStdLib
) {
26
fGC
+=
"#include <"
+ i +
">\n"
;
27
}
28
for
(
auto
& i:
fCustomOpHeaders
) {
29
fGC
+=
"#include \""
+ i +
"\"\n"
;
30
}
31
// for the session we need to include SOFIE_Common functions
32
//needed for convolution operator (need to add a flag)
33
fGC
+=
"#include \"TMVA/SOFIE_common.hxx\"\n"
;
34
if
(
fUseWeightFile
)
35
fGC
+=
"#include <fstream>\n"
;
36
// Include TFile when saving the weights in a binary ROOT file
37
if
(
fWeightFile
==
WeightFileType::RootBinary
)
38
fGC
+=
"#include \"TFile.h\"\n"
;
39
40
fGC
+=
"\nnamespace TMVA_SOFIE_"
+
fName
+
"{\n"
;
41
if
(!
fNeededBlasRoutines
.empty()) {
42
fGC
+= (
"namespace BLAS{\n"
);
43
for
(
auto
&routine :
fNeededBlasRoutines
) {
44
if
(routine ==
"Gemm"
) {
45
fGC
+= (
"\textern \"C\" void sgemm_(const char * transa, const char * transb, const int * m, const int * n, const int * k,\n"
46
"\t const float * alpha, const float * A, const int * lda, const float * B, const int * ldb,\n"
47
"\t const float * beta, float * C, const int * ldc);\n"
);
48
}
else
if
(routine ==
"Gemv"
) {
49
fGC
+= (
"\textern \"C\" void sgemv_(const char * trans, const int * m, const int * n, const float * alpha, const float * A,\n"
50
"\t const int * lda, const float * X, const int * incx, const float * beta, const float * Y, const int * incy);\n"
);
51
}
else
if
(routine ==
"Axpy"
) {
52
fGC
+= (
"\textern \"C\" void saxpy_(const int * n, const float * alpha, const float * x,\n"
53
"\t const int * incx, float * y, const int * incy);\n"
);
54
}
else
if
(routine ==
"Copy"
) {
55
fGC
+= (
"\textern \"C\" void scopy_(const int *n, const float* x, const int *incx, float* y, const int* incy);\n"
);
56
}
57
}
58
fGC
+= (
"}//BLAS\n"
);
59
}
60
}
61
62
void
RModel_Base::OutputGenerated
(std::string filename,
bool
append) {
63
// the model can be appended only if a file name is provided
64
if
(filename.empty()) {
65
// if a file is pr
66
filename =
fName
+
".hxx"
;
67
append =
false
;
68
}
69
std::ofstream
f
;
70
if
(append)
71
f
.open(filename, std::ios_base::app);
72
else
73
f
.open(filename);
74
if
(!
f
.is_open()) {
75
throw
std::runtime_error(
"tmva-sofie failed to open file for output generated inference code"
);
76
}
77
f
<<
fGC
;
78
f
.close();
79
}
80
81
}
//SOFIE
82
}
//Experimental
83
}
//TMVA
RModel_Base.hxx
f
#define f(i)
Definition
RSha256.hxx:104
c
#define c(i)
Definition
RSha256.hxx:101
name
char name[80]
Definition
TGX11.cxx:148
TMVA::Experimental::SOFIE::RModel_Base::GenerateHeaderInfo
void GenerateHeaderInfo(std::string &hgname)
Definition
RModel_Base.cxx:15
TMVA::Experimental::SOFIE::RModel_Base::RModel_Base
RModel_Base()=default
Default constructor.
TMVA::Experimental::SOFIE::RModel_Base::fNeededBlasRoutines
std::unordered_set< std::string > fNeededBlasRoutines
Definition
RModel_Base.hxx:53
TMVA::Experimental::SOFIE::RModel_Base::fGC
std::string fGC
Definition
RModel_Base.hxx:59
TMVA::Experimental::SOFIE::RModel_Base::fCustomOpHeaders
std::unordered_set< std::string > fCustomOpHeaders
Definition
RModel_Base.hxx:56
TMVA::Experimental::SOFIE::RModel_Base::fFileName
std::string fFileName
Definition
RModel_Base.hxx:48
TMVA::Experimental::SOFIE::RModel_Base::fParseTime
std::string fParseTime
Definition
RModel_Base.hxx:49
TMVA::Experimental::SOFIE::RModel_Base::OutputGenerated
void OutputGenerated(std::string filename="", bool append=false)
Definition
RModel_Base.cxx:62
TMVA::Experimental::SOFIE::RModel_Base::fNeededStdLib
std::unordered_set< std::string > fNeededStdLib
Definition
RModel_Base.hxx:55
TMVA::Experimental::SOFIE::RModel_Base::fUseWeightFile
bool fUseWeightFile
Definition
RModel_Base.hxx:60
TMVA::Experimental::SOFIE::RModel_Base::fWeightFile
WeightFileType fWeightFile
Definition
RModel_Base.hxx:51
TMVA::Experimental::SOFIE::RModel_Base::fName
std::string fName
Definition
RModel_Base.hxx:58
TMVA::Experimental::SOFIE::UTILITY::Clean_name
std::string Clean_name(std::string input_tensor_name)
Definition
SOFIE_common.cxx:512
TMVA::Experimental::SOFIE
Definition
RFunction.hxx:12
TMVA::Experimental::SOFIE::WeightFileType::RootBinary
@ RootBinary
Definition
RModel_Base.hxx:40
TMVA::Experimental
Definition
RFunction.hxx:11
TMVA
create variable transformations
Definition
GeneticMinimizer.h:22
tmva
sofie
src
RModel_Base.cxx
ROOT git-r3/HEAD - Reference Guide. Generated on
(GVA Time) using Doxygen 1.16.1