// NOTE(review): this chunk is an extraction fragment of a Load() routine
// (ROOT/TMVA SOFIE-style model reader).  The original file's line numbers
// have been fused into the text (the leading "58", "61", ... tokens) and
// many interior lines are missing (numbering jumps, unmatched braces), so
// the fragment is not compilable as-is.  Only comments are added below;
// every code token, including the fused numbers, is left byte-identical.
//
// Purpose (from the visible code): load a machine-learning model file
// (.onnx / .h5 / .keras / .pt / .root), synthesize SOFIE parser code for it,
// run that code through the cling interpreter (or python3 for Keras),
// compile the generated inference header, and declare a uniquely-named
// global Session object plus an inference wrapper function.
//
// Parameters:
//   path        - model file path; its extension selects the parser backend.
//   inputShapes - optional per-input tensor shapes; inputShapes[0][0] is used
//                 as the batch size, and the full list is mandatory for the
//                 PyTorch backend (its parser cannot infer shapes).
//   verbose     - when non-zero, progress messages and the generated code
//                 are printed to std::cout.
//
// Throws std::runtime_error on: unrecognized file type, missing parser
// libraries, missing PyTorch input shapes, parser/compilation failures,
// or a missing generated header.
58 void Load(
const std::string &path, std::vector<std::vector<size_t>> inputShapes = {},
int verbose = 0)
// --- 1. Detect the model type from the file-name extension ---------------
// pos2 ends up pointing at the extension separator inside `path`; the
// `type = kONNX;` etc. assignments are among the elided lines — TODO confirm.
61 enum EModelType {kONNX, kKeras, kPt, kROOT, kNotDef};
62 EModelType type = kNotDef;
64 size_t pos2 = std::string::npos;
65 if ( (pos2 = path.find(
".onnx")) != std::string::npos) {
66 if (verbose) std::cout <<
"input model type is ONNX" << std::endl;
68 }
// Both legacy ".h5" and newer ".keras" files are routed to the Keras parser.
else if ( (pos2 = path.find(
".h5")) != std::string::npos || (pos2 = path.find(
".keras")) != std::string::npos) {
69 if (verbose) std::cout <<
"input model type is Keras" << std::endl;
71 }
else if ( (pos2 = path.find(
".pt")) != std::string::npos) {
72 if (verbose) std::cout <<
"input model type is PyTorch" << std::endl;
74 }
// A .root file is expected to contain an already-serialized SOFIE RModel.
else if ( (pos2 = path.find(
".root")) != std::string::npos) {
75 if (verbose) std::cout <<
"input model type is ROOT" << std::endl;
// Unrecognized extension -> hard error.
79 if (type == kNotDef) {
80 throw std::runtime_error(
"Input file is not an ONNX or Keras or PyTorch file");
// --- 2. Derive model name and file type from the path --------------------
// pos1 = last '/' (directory part stripped); the adjustment of pos1 when no
// '/' is present is among the elided lines — TODO confirm.
82 auto pos1 = path.rfind(
"/");
83 if (pos1 == std::string::npos)
87 std::string modelName = path.substr(pos1,pos2-pos1);
88 std::string fileType = path.substr(pos2+1, path.length()-pos2-1);
89 if (verbose) std::cout <<
"Parsing SOFIE model " << modelName <<
" of type " << fileType << std::endl;
// Output artifacts produced by the generator: a C++ header with the
// inference code and a .dat file with the model weights.
92 std::string modelHeader = modelName +
"_fromRSofieR.hxx";
93 std::string modelWeights = modelName +
"_fromRSofieR.dat";
// --- 3. Build the parser program as a string ------------------------------
// parserCode is C++ executed via gROOT->ProcessLine (cling);
// parserPythonCode is a python script executed via `python3 -c` (Keras only).
// NOTE(review): `path` is spliced unescaped into interpreted code below —
// quotes or backslashes in the path would break the generated program.
97 std::string parserCode;
98 std::string parserPythonCode;
// ONNX branch: requires the native SOFIE ONNX parser library.
101 if (
gSystem->Load(
"libROOTTMVASofieParser") < 0) {
102 throw std::runtime_error(
"RSofieReader: cannot use SOFIE with ONNX since libROOTTMVASofieParser is missing");
104 gInterpreter->Declare(
"#include \"TMVA/RModelParser_ONNX.hxx\"");
105 parserCode +=
"{\nTMVA::Experimental::SOFIE::RModelParser_ONNX parser ; \n";
// Two Parse() variants are appended under conditions elided from this view
// (presumably verbose vs not: the first passes `true` for verbose parsing)
// — TODO confirm against the full file.
107 parserCode +=
"TMVA::Experimental::SOFIE::RModel model = parser.Parse(\"" + path +
"\",true); \n";
109 parserCode +=
"TMVA::Experimental::SOFIE::RModel model = parser.Parse(\"" + path +
"\"); \n";
// Keras branch: builds a python script (delimited by triple quotes) that
// calls the PyKeras parser through ROOT's python bindings.
111 else if (type == kKeras) {
113 parserPythonCode +=
"\"\"\"\n";
114 parserPythonCode +=
"import ROOT\n";
// Batch size defaults to 1 and is overridden by inputShapes[0][0] if given.
117 std::string batch_size =
"1";
118 if (!inputShapes.empty() && ! inputShapes[0].empty())
119 batch_size = std::to_string(inputShapes[0][0]);
120 parserPythonCode +=
"model = ROOT.TMVA.Experimental.SOFIE.PyKeras.Parse('" + path +
"'," + batch_size +
")\n";
// PyTorch branch: needs the python-parser library AND explicit input shapes.
122 else if (type == kPt) {
124 if (
gSystem->Load(
"libROOTTMVASofiePyParsers") < 0) {
125 throw std::runtime_error(
"RSofieReader: cannot use SOFIE with PyTorch since libROOTTMVASofiePyParsers is missing");
127 if (inputShapes.size() == 0) {
128 throw std::runtime_error(
"RSofieReader: cannot use SOFIE with PyTorch since the input tensor shape is missing and is needed by the PyTorch parser");
// Serialize inputShapes as a C++ brace-list literal, e.g. "{{ 1, 3},{ 2}}",
// to embed in the generated Parse() call.  The append of each dimension
// value itself is among the elided lines — TODO confirm.
130 std::string inputShapesStr =
"{";
131 for (
unsigned int i = 0; i < inputShapes.size(); i++) {
132 inputShapesStr +=
"{ ";
133 for (
unsigned int j = 0; j < inputShapes[i].size(); j++) {
// Comma between dimensions, but not after the last one.
135 if (j < inputShapes[i].
size()-1) inputShapesStr +=
", ";
137 inputShapesStr +=
"}";
// Comma between shapes, but not after the last one.
138 if (i < inputShapes.size()-1) inputShapesStr +=
", ";
140 inputShapesStr +=
"}";
141 parserCode +=
"{\nTMVA::Experimental::SOFIE::RModel model = TMVA::Experimental::SOFIE::PyTorch::Parse(\"" + path +
"\", "
142 + inputShapesStr +
"); \n";
// ROOT-file branch: scan the file's keys for a serialized RModel object,
// read it out, and work on it by reference.
144 else if (type == kROOT) {
146 parserCode +=
"{\nauto fileRead = TFile::Open(\"" + path +
"\",\"READ\");\n";
147 parserCode +=
"TMVA::Experimental::SOFIE::RModel * modelPtr;\n";
148 parserCode +=
"auto keyList = fileRead->GetListOfKeys(); TString name;\n";
149 parserCode +=
"for (const auto&& k : *keyList) { \n";
150 parserCode +=
" TString cname = ((TKey*)k)->GetClassName(); if (cname==\"TMVA::Experimental::SOFIE::RModel\") name = k->GetName(); }\n";
151 parserCode +=
"fileRead->GetObject(name,modelPtr); fileRead->Close(); delete fileRead;\n";
152 parserCode +=
"TMVA::Experimental::SOFIE::RModel & model = *modelPtr;\n";
// --- 4. Inject user-registered custom operators ---------------------------
// Only supported for the C++ (cling) path; `op` is presumably the loop
// variable of an elided loop over registered custom operators — TODO confirm.
157 if (!parserPythonCode.empty())
158 throw std::runtime_error(
"Cannot use Custom operator with a Python parser (e.g. from a Keras model)");
161 parserCode +=
"{ auto p = new TMVA::Experimental::SOFIE::ROperator_Custom<float>(\""
162 + op.fOpName +
"\"," + op.fInputNames +
"," + op.fOutputNames +
"," + op.fOutputShapes +
",\"" + op.fFileName +
"\");\n";
163 parserCode +=
"std::unique_ptr<TMVA::Experimental::SOFIE::ROperator> op(p);\n";
164 parserCode +=
"model.AddOperator(std::move(op));\n}\n";
// --- 5. Determine batch size and emit the Generate/Output calls -----------
// `batchSize` is declared in an elided line; clamped to at least 1.
169 if (inputShapes.size() > 0 && inputShapes[0].size() > 0) {
170 batchSize = inputShapes[0][0];
171 if (batchSize < 1) batchSize = 1;
173 if (verbose) std::cout <<
"generating the code with batch size = " << batchSize <<
" ...\n";
// C++ path: Generate + OutputGenerated, plus (verbose-gated, gating lines
// elided) tensor-debug printing; finally return the input count from the
// interpreted snippet so the caller learns nInputs.
175 if (parserPythonCode.empty()) {
176 parserCode +=
"model.Generate(TMVA::Experimental::SOFIE::Options::kDefault,"
179 parserCode +=
"model.OutputGenerated(\"" + modelHeader +
"\");\n";
181 parserCode +=
"model.PrintRequiredInputTensors();\n";
182 parserCode +=
"model.PrintIntermediateTensors();\n";
183 parserCode +=
"model.PrintOutputTensors();\n";
185 parserCode +=
"model.PrintGenerated(); \n";
189 parserCode +=
"int nInputs = model.GetInputTensorNames().size();\n";
192 parserCode +=
"return nInputs;\n}\n";
// Python (Keras) path: same Generate/Output/print sequence in python syntax,
// closed with the terminating triple quote.
195 parserPythonCode +=
"model.Generate(ROOT.TMVA.Experimental.SOFIE.Options.kDefault,"
198 parserPythonCode +=
"model.OutputGenerated('" + modelHeader +
"');\n";
200 parserPythonCode +=
"model.PrintRequiredInputTensors()\n";
201 parserPythonCode +=
"model.PrintIntermediateTensors()\n";
202 parserPythonCode +=
"model.PrintOutputTensors()\n";
204 parserPythonCode +=
"model.PrintGenerated()\n";
207 parserPythonCode +=
"\"\"\"";
// --- 6. Execute the parser program ----------------------------------------
// C++ snippet runs in-process via cling; the python script is shelled out
// as `python3 -c <script>`.  `iret` is declared in an elided line.
211 if (parserPythonCode.empty()) {
213 std::cout <<
"...ParserCode being executed...:\n";
214 std::cout << parserCode << std::endl;
216 iret =
gROOT->ProcessLine(parserCode.c_str());
220 std::cout <<
"executing python3 -c ......" << std::endl;
221 std::cout << parserPythonCode << std::endl;
223 iret =
gSystem->Exec(TString(
"python3 -c ") + TString(parserPythonCode.c_str()));
// For the python path the input count cannot come back through a return
// value, so it is taken from the size of inputShapes instead.
226 if (!inputShapes.empty())
fNInputs = inputShapes.size();
// Error handling on `iret` — the guarding condition is elided; on failure
// the full parser code is embedded in the exception message for debugging.
230 std::string msg =
"RSofieReader: error processing the parser code: \n" + parserCode;
231 throw std::runtime_error(msg);
232 }
else if (verbose) {
233 std::cout <<
"Model Header file is generated!" << std::endl;
// The inference wrapper below hard-codes up to 3 data pointers (elided),
// hence this cap on the number of model inputs — TODO confirm.
236 throw std::runtime_error(
"RSofieReader does not yet support model with > 3 inputs");
// --- 7. Compile the generated header and declare a global Session ---------
// AccessPathName returns true when the file is NOT accessible (ROOT quirk).
240 if (verbose) std::cout <<
"compile generated code from file " <<modelHeader << std::endl;
241 if (
gSystem->AccessPathName(modelHeader.c_str())) {
242 std::string msg =
"RSofieReader: input header file " + modelHeader +
" is not existing";
243 throw std::runtime_error(msg);
245 if (verbose) std::cout <<
"Creating Inference function for model " << modelName << std::endl;
// declCode is a cling declaration: include the generated header and
// instantiate one global Session bound to the weight file.
246 std::string declCode;
247 declCode +=
"#pragma cling optimize(2)\n";
248 declCode +=
"#include \"" + modelHeader +
"\"\n";
// The generated code lives in namespace TMVA_SOFIE_<modelName>.
250 std::string sessionClassName =
"TMVA_SOFIE_" + modelName +
"::Session";
// Build a unique identifier suffix from a UUID (`uuid` is constructed in an
// elided line); strip every non-alphanumeric character so the result is a
// valid C++ identifier.
252 std::string uidName = uuid.
AsString();
253 uidName.erase(std::remove_if(uidName.begin(), uidName.end(),
254 [](
char const&
c ) ->
bool { return !std::isalnum(c); } ), uidName.end());
256 std::string sessionName =
"session_" + uidName;
257 declCode += sessionClassName +
" " + sessionName +
"(\"" + modelWeights +
"\");";
259 if (verbose) std::cout <<
"//global session declaration\n" << declCode << std::endl;
// The SOFIE runtime library must be loadable before declaring the session.
262 iret =
gSystem->Load(
"libROOTTMVASofie");
264 throw std::runtime_error(
"Error loading libROOTTMVASofie library");
// Failure path of the (elided) gInterpreter->Declare(declCode) call.
268 std::string msg =
"RSofieReader: error compiling inference code and creating session class\n" + declCode;
269 throw std::runtime_error(msg);
// --- 8. Synthesize the inference wrapper function --------------------------
// Builds, as text, a free function
//   std::vector<float> SofieInference_<uid>(void* ptr, float* data0, ...)
// that casts ptr back to Session* and forwards the data pointers to infer().
// (The loop emitting the ", float * data" parameters is partially elided.)
275 std::stringstream ifuncCode;
276 std::string funcName =
"SofieInference_" + uidName;
277 ifuncCode <<
"std::vector<float> " + funcName +
"( void * ptr";
279 ifuncCode <<
", float * data" << i;
280 ifuncCode <<
") {\n";
281 ifuncCode <<
" " << sessionClassName <<
" * s = " <<
"(" << sessionClassName <<
"*) (ptr);\n";
282 ifuncCode <<
" return s->infer(";
283 for (
int i = 0; i <
fNInputs; i++) {
// Comma-separate the forwarded data pointers: infer(data0,data1,...).
284 if (i>0) ifuncCode <<
",";
285 ifuncCode <<
"data" << i;
290 if (verbose) std::cout <<
"//Inference function code using global session instance\n"
291 << ifuncCode.str() << std::endl;
// Failure path of the (elided) compilation of the inference function.
295 std::string msg =
"RSofieReader: error compiling inference function\n" + ifuncCode.str();
296 throw std::runtime_error(msg);