58 void Load(
const std::string &path, std::vector<std::vector<size_t>> inputShapes = {},
int verbose = 0)
61 enum EModelType {kONNX, kKeras, kPt, kROOT, kNotDef};
62 EModelType
type = kNotDef;
64 auto pos1 = path.rfind(
"/");
65 auto pos2 = path.find(
".onnx");
66 if (pos2 != std::string::npos) {
69 pos2 = path.find(
".h5");
70 if (pos2 != std::string::npos) {
73 pos2 = path.find(
".pt");
74 if (pos2 != std::string::npos) {
78 pos2 = path.find(
".root");
79 if (pos2 != std::string::npos) {
85 if (
type == kNotDef) {
86 throw std::runtime_error(
"Input file is not an ONNX or Keras or PyTorch file");
88 if (pos1 == std::string::npos)
92 std::string modelName = path.substr(pos1,pos2-pos1);
93 std::string fileType = path.substr(pos2+1, path.length()-pos2-1);
94 if (verbose) std::cout <<
"Parsing SOFIE model " << modelName <<
" of type " << fileType << std::endl;
98 std::string parserCode;
102 throw std::runtime_error(
"RSofieReader: cannot use SOFIE with ONNX since libROOTTMVASofieParser is missing");
104 gInterpreter->Declare(
"#include \"TMVA/RModelParser_ONNX.hxx\"");
105 parserCode +=
"{\nTMVA::Experimental::SOFIE::RModelParser_ONNX parser ; \n";
107 parserCode +=
"TMVA::Experimental::SOFIE::RModel model = parser.Parse(\"" + path +
"\",true); \n";
109 parserCode +=
"TMVA::Experimental::SOFIE::RModel model = parser.Parse(\"" + path +
"\"); \n";
111 else if (
type == kKeras) {
114 throw std::runtime_error(
"RSofieReader: cannot use SOFIE with Keras since libPyMVA is missing");
118 if (!inputShapes.empty() && ! inputShapes[0].empty())
119 batch_size = std::to_string(inputShapes[0][0]);
120 parserCode +=
"{\nTMVA::Experimental::SOFIE::RModel model = TMVA::Experimental::SOFIE::PyKeras::Parse(\"" + path +
123 else if (
type == kPt) {
126 throw std::runtime_error(
"RSofieReader: cannot use SOFIE with PyTorch since libPyMVA is missing");
128 if (inputShapes.size() == 0) {
129 throw std::runtime_error(
"RSofieReader: cannot use SOFIE with PyTorch since the input tensor shape is missing and is needed by the PyTorch parser");
131 std::string inputShapesStr =
"{";
132 for (
unsigned int i = 0;
i < inputShapes.size();
i++) {
133 inputShapesStr +=
"{ ";
134 for (
unsigned int j = 0; j < inputShapes[
i].size(); j++) {
136 if (j < inputShapes[
i].
size()-1) inputShapesStr +=
", ";
138 inputShapesStr +=
"}";
139 if (
i < inputShapes.size()-1) inputShapesStr +=
", ";
141 inputShapesStr +=
"}";
142 parserCode +=
"{\nTMVA::Experimental::SOFIE::RModel model = TMVA::Experimental::SOFIE::PyTorch::Parse(\"" + path +
"\", "
143 + inputShapesStr +
"); \n";
145 else if (
type == kROOT) {
147 parserCode +=
"{\nauto fileRead = TFile::Open(\"" + path +
"\",\"READ\");\n";
148 parserCode +=
"TMVA::Experimental::SOFIE::RModel * modelPtr;\n";
149 parserCode +=
"auto keyList = fileRead->GetListOfKeys(); TString name;\n";
150 parserCode +=
"for (const auto&& k : *keyList) { \n";
151 parserCode +=
" TString cname = ((TKey*)k)->GetClassName(); if (cname==\"TMVA::Experimental::SOFIE::RModel\") name = k->GetName(); }\n";
152 parserCode +=
"fileRead->GetObject(name,modelPtr); fileRead->Close(); delete fileRead;\n";
153 parserCode +=
"TMVA::Experimental::SOFIE::RModel & model = *modelPtr;\n";
157 if (inputShapes.size() > 0 && inputShapes[0].size() > 0) {
158 batchSize = inputShapes[0][0];
159 if (batchSize < 1) batchSize = 1;
161 if (verbose) std::cout <<
"generating the code with batch size = " << batchSize <<
" ...\n";
163 parserCode +=
"model.Generate(TMVA::Experimental::SOFIE::Options::kDefault,"
169 parserCode +=
"model.PrintRequiredInputTensors();\n";
170 parserCode +=
"model.PrintIntermediateTensors();\n";
171 parserCode +=
"model.PrintOutputTensors();\n";
174 parserCode +=
"{ auto p = new TMVA::Experimental::SOFIE::ROperator_Custom<float>(\""
175 + op.fOpName +
"\"," + op.fInputNames +
"," + op.fOutputNames +
"," + op.fOutputShapes +
",\"" + op.fFileName +
"\");\n";
176 parserCode +=
"std::unique_ptr<TMVA::Experimental::SOFIE::ROperator> op(p);\n";
177 parserCode +=
"model.AddOperator(std::move(op));\n}\n";
179 parserCode +=
"model.Generate(TMVA::Experimental::SOFIE::Options::kDefault,"
183 parserCode +=
"model.PrintGenerated(); \n";
184 parserCode +=
"model.OutputGenerated();\n";
186 parserCode +=
"int nInputs = model.GetInputTensorNames().size();\n";
191 parserCode +=
"return nInputs;\n}\n";
193 if (verbose) std::cout <<
"//ParserCode being executed:\n" << parserCode << std::endl;
195 auto iret =
gROOT->ProcessLine(parserCode.c_str());
197 std::string msg =
"RSofieReader: error processing the parser code: \n" + parserCode;
198 throw std::runtime_error(msg);
202 throw std::runtime_error(
"RSofieReader does not yet support model with > 3 inputs");
206 std::string modelHeader = modelName +
".hxx";
207 if (verbose) std::cout <<
"compile generated code from file " <<modelHeader << std::endl;
209 std::string msg =
"RSofieReader: input header file " + modelHeader +
" is not existing";
210 throw std::runtime_error(msg);
212 if (verbose) std::cout <<
"Creating Inference function for model " << modelName << std::endl;
213 std::string declCode;
214 declCode +=
"#pragma cling optimize(2)\n";
215 declCode +=
"#include \"" + modelHeader +
"\"\n";
217 std::string sessionClassName =
"TMVA_SOFIE_" + modelName +
"::Session";
219 std::string uidName = uuid.
AsString();
220 uidName.erase(std::remove_if(uidName.begin(), uidName.end(),
221 [](
char const&
c ) ->
bool { return !std::isalnum(c); } ), uidName.end());
223 std::string sessionName =
"session_" + uidName;
224 declCode += sessionClassName +
" " + sessionName +
";";
226 if (verbose) std::cout <<
"//global session declaration\n" << declCode << std::endl;
230 std::string msg =
"RSofieReader: error compiling inference code and creating session class\n" + declCode;
231 throw std::runtime_error(msg);
237 std::stringstream ifuncCode;
238 std::string funcName =
"SofieInference_" + uidName;
239 ifuncCode <<
"std::vector<float> " + funcName +
"( void * ptr";
241 ifuncCode <<
", float * data" <<
i;
242 ifuncCode <<
") {\n";
243 ifuncCode <<
" " << sessionClassName <<
" * s = " <<
"(" << sessionClassName <<
"*) (ptr);\n";
244 ifuncCode <<
" return s->infer(";
246 if (
i>0) ifuncCode <<
",";
247 ifuncCode <<
"data" <<
i;
252 if (verbose) std::cout <<
"//Inference function code using global session instance\n"
253 << ifuncCode.str() << std::endl;
257 std::string msg =
"RSofieReader: error compiling inference function\n" + ifuncCode.str();
258 throw std::runtime_error(msg);
266 void AddCustomOperator(
const std::string &opName,
const std::string &inputNames,
const std::string & outputNames,
267 const std::string & outputShapes,
const std::string & fileName) {
268 if (
fInitialized) std::cout <<
"WARNING: Model is already loaded and initialised. It must be done after adding the custom operators" << std::endl;
269 fCustomOperators.push_back( {fileName, opName,inputNames, outputNames,outputShapes});