16#ifndef TMVA_RSOFIEREADER
17#define TMVA_RSOFIEREADER
34namespace Experimental {
58 void Load(
const std::string &path, std::vector<std::vector<size_t>>
inputShapes = {},
int verbose = 0)
64 size_t pos2 = std::string::npos;
65 if ( (
pos2 = path.find(
".onnx")) != std::string::npos) {
66 if (verbose) std::cout <<
"input model type is ONNX" << std::endl;
68 }
else if ( (
pos2 = path.find(
".h5")) != std::string::npos || (
pos2 = path.find(
".keras")) != std::string::npos) {
69 if (verbose) std::cout <<
"input model type is Keras" << std::endl;
71 }
else if ( (
pos2 = path.find(
".pt")) != std::string::npos) {
72 if (verbose) std::cout <<
"input model type is PyTorch" << std::endl;
74 }
else if ( (
pos2 = path.find(
".root")) != std::string::npos) {
75 if (verbose) std::cout <<
"input model type is ROOT" << std::endl;
80 throw std::runtime_error(
"Input file is not an ONNX or Keras or PyTorch file");
82 auto pos1 = path.rfind(
"/");
83 if (
pos1 == std::string::npos)
89 if (verbose) std::cout <<
"Parsing SOFIE model " <<
modelName <<
" of type " <<
fileType << std::endl;
97 throw std::runtime_error(
"RSofieReader: cannot use SOFIE with ONNX since libROOTTMVASofieParser is missing");
99 gInterpreter->Declare(
"#include \"TMVA/RModelParser_ONNX.hxx\"");
100 parserCode +=
"{\nTMVA::Experimental::SOFIE::RModelParser_ONNX parser ; \n";
102 parserCode +=
"TMVA::Experimental::SOFIE::RModel model = parser.Parse(\"" + path +
"\",true); \n";
104 parserCode +=
"TMVA::Experimental::SOFIE::RModel model = parser.Parse(\"" + path +
"\"); \n";
108 if (
gSystem->
Load(
"libROOTTMVASofiePyParsers") < 0) {
109 throw std::runtime_error(
"RSofieReader: cannot use SOFIE with Keras since libROOTTMVASofiePyParsers is missing");
115 parserCode +=
"{\nTMVA::Experimental::SOFIE::RModel model = TMVA::Experimental::SOFIE::PyKeras::Parse(\"" + path +
120 if (
gSystem->
Load(
"libROOTTMVASofiePyParsers") < 0) {
121 throw std::runtime_error(
"RSofieReader: cannot use SOFIE with PyTorch since libROOTTMVASofiePyParsers is missing");
124 throw std::runtime_error(
"RSofieReader: cannot use SOFIE with PyTorch since the input tensor shape is missing and is needed by the PyTorch parser");
127 for (
unsigned int i = 0; i <
inputShapes.size(); i++) {
137 parserCode +=
"{\nTMVA::Experimental::SOFIE::RModel model = TMVA::Experimental::SOFIE::PyTorch::Parse(\"" + path +
"\", "
140 else if (
type == kROOT) {
142 parserCode +=
"{\nauto fileRead = TFile::Open(\"" + path +
"\",\"READ\");\n";
143 parserCode +=
"TMVA::Experimental::SOFIE::RModel * modelPtr;\n";
144 parserCode +=
"auto keyList = fileRead->GetListOfKeys(); TString name;\n";
145 parserCode +=
"for (const auto&& k : *keyList) { \n";
146 parserCode +=
" TString cname = ((TKey*)k)->GetClassName(); if (cname==\"TMVA::Experimental::SOFIE::RModel\") name = k->GetName(); }\n";
147 parserCode +=
"fileRead->GetObject(name,modelPtr); fileRead->Close(); delete fileRead;\n";
148 parserCode +=
"TMVA::Experimental::SOFIE::RModel & model = *modelPtr;\n";
155 parserCode +=
"{ auto p = new TMVA::Experimental::SOFIE::ROperator_Custom<float>(\""
156 +
op.fOpName +
"\"," +
op.fInputNames +
"," +
op.fOutputNames +
"," +
op.fOutputShapes +
",\"" +
op.fFileName +
"\");\n";
157 parserCode +=
"std::unique_ptr<TMVA::Experimental::SOFIE::ROperator> op(p);\n";
158 parserCode +=
"model.AddOperator(std::move(op));\n}\n";
165 if (batchSize < 1) batchSize = 1;
167 if (verbose) std::cout <<
"generating the code with batch size = " << batchSize <<
" ...\n";
169 parserCode +=
"model.Generate(TMVA::Experimental::SOFIE::Options::kDefault,"
173 parserCode +=
"model.PrintRequiredInputTensors();\n";
174 parserCode +=
"model.PrintIntermediateTensors();\n";
175 parserCode +=
"model.PrintOutputTensors();\n";
182 parserCode +=
"model.PrintRequiredInputTensors();\n";
183 parserCode +=
"model.PrintIntermediateTensors();\n";
184 parserCode +=
"model.PrintOutputTensors();\n";
187 parserCode +=
"{ auto p = new TMVA::Experimental::SOFIE::ROperator_Custom<float>(\""
188 +
op.fOpName +
"\"," +
op.fInputNames +
"," +
op.fOutputNames +
"," +
op.fOutputShapes +
",\"" +
op.fFileName +
"\");\n";
189 parserCode +=
"std::unique_ptr<TMVA::Experimental::SOFIE::ROperator> op(p);\n";
190 parserCode +=
"model.AddOperator(std::move(op));\n}\n";
192 parserCode +=
"model.Generate(TMVA::Experimental::SOFIE::Options::kDefault,"
200 parserCode +=
"int nInputs = model.GetInputTensorNames().size();\n";
207 if (verbose) std::cout <<
"//ParserCode being executed:\n" <<
parserCode << std::endl;
211 std::string
msg =
"RSofieReader: error processing the parser code: \n" +
parserCode;
212 throw std::runtime_error(
msg);
216 throw std::runtime_error(
"RSofieReader does not yet support model with > 3 inputs");
221 if (verbose) std::cout <<
"compile generated code from file " <<
modelHeader << std::endl;
223 std::string
msg =
"RSofieReader: input header file " +
modelHeader +
" is not existing";
224 throw std::runtime_error(
msg);
226 if (verbose) std::cout <<
"Creating Inference function for model " <<
modelName << std::endl;
228 declCode +=
"#pragma cling optimize(2)\n";
235 [](
char const&
c ) ->
bool { return !std::isalnum(c); } ),
uidName.
end());
240 if (verbose) std::cout <<
"//global session declaration\n" <<
declCode << std::endl;
244 std::string
msg =
"RSofieReader: error compiling inference code and creating session class\n" +
declCode;
245 throw std::runtime_error(
msg);
259 for (
int i = 0; i <
fNInputs; i++) {
266 if (verbose) std::cout <<
"//Inference function code using global session instance\n"
271 std::string
msg =
"RSofieReader: error compiling inference function\n" +
ifuncCode.str();
272 throw std::runtime_error(
msg);
281 const std::string &
outputShapes,
const std::string & fileName) {
282 if (
fInitialized) std::cout <<
"WARNING: Model is already loaded and initialised. It must be done after adding the custom operators" << std::endl;
289 std::string
msg =
"Wrong number of inputs - model requires " + std::to_string(
fNInputs);
290 throw std::runtime_error(
msg);
292 auto fptr =
reinterpret_cast<std::vector<float> (*)(
void *,
const float *)
>(
fFuncPtr);
295 std::vector<float>
DoCompute(
const std::vector<float> &
x1,
const std::vector<float> &
x2) {
297 std::string
msg =
"Wrong number of inputs - model requires " + std::to_string(
fNInputs);
298 throw std::runtime_error(
msg);
300 auto fptr =
reinterpret_cast<std::vector<float> (*)(
void *,
const float *,
const float *)
>(
fFuncPtr);
303 std::vector<float>
DoCompute(
const std::vector<float> &
x1,
const std::vector<float> &
x2,
const std::vector<float> &
x3) {
305 std::string
msg =
"Wrong number of inputs - model requires " + std::to_string(
fNInputs);
306 throw std::runtime_error(
msg);
308 auto fptr =
reinterpret_cast<std::vector<float> (*)(
void *,
const float *,
const float *,
const float *)
>(
fFuncPtr);
313 template<
typename... T>
317 return std::vector<float>();
327 std::vector<float>
Compute(
const std::vector<float> &
x) {
329 return std::vector<float>();
347 const auto nrows =
x.GetShape()[0];
348 const auto rowsize =
x.GetStrides()[0];
349 auto fptr =
reinterpret_cast<std::vector<float> (*)(
void *,
const float *)
>(
fFuncPtr);
356 for (
size_t i = 1; i <
nrows; i++) {
size_t size(const MatrixT &matrix)
retrieve the size of a square matrix
ROOT::Detail::TRangeCast< T, true > TRangeDynCast
TRangeDynCast is an adapter class that allows the typed iteration through a TCollection.
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void char Point_t Rectangle_t WindowAttributes_t Float_t Float_t Float_t Int_t Int_t UInt_t UInt_t Rectangle_t result
Option_t Option_t TPoint TPoint const char x2
Option_t Option_t TPoint TPoint const char x1
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void char Point_t Rectangle_t WindowAttributes_t Float_t Float_t Float_t Int_t Int_t UInt_t UInt_t Rectangle_t Int_t Int_t Window_t TString Int_t GCValues_t GetPrimarySelectionOwner GetDisplay GetScreen GetColormap GetNativeEvent const char const char dpyName wid window const char font_name cursor keysym reg const char only_if_exist regb h Point_t winding char text const char depth char const char Int_t count const char ColorStruct_t color const char Pixmap_t Pixmap_t PictureAttributes_t attr const char char ret_data h unsigned char height h Atom_t Int_t ULong_t ULong_t unsigned char prop_list Atom_t Atom_t Atom_t Time_t type
R__EXTERN TSystem * gSystem
#define R__WRITE_LOCKGUARD(mutex)
const_iterator begin() const
const_iterator end() const
TMVA::RSofieReader class for reading external Machine Learning models in ONNX files,...
RSofieReader(const std::string &path, std::vector< std::vector< size_t > > inputShapes={}, int verbose=0)
Create TMVA model from ONNX file print level can be 0 (minimal) 1 with info , 2 with all ONNX parsing...
RTensor< float > Compute(RTensor< float > &x)
Compute model prediction on input RTensor The shape of the input tensor should be {nevents,...
std::vector< float > Compute(const std::vector< float > &x)
std::vector< float > Compute(T... x)
Compute model prediction on vector.
void Load(const std::string &path, std::vector< std::vector< size_t > > inputShapes={}, int verbose=0)
std::vector< float > DoCompute(const std::vector< float > &x1, const std::vector< float > &x2, const std::vector< float > &x3)
std::vector< CustomOperatorData > fCustomOperators
std::vector< float > DoCompute(const std::vector< float > &x1)
void AddCustomOperator(const std::string &opName, const std::string &inputNames, const std::string &outputNames, const std::string &outputShapes, const std::string &fileName)
std::vector< float > DoCompute(const std::vector< float > &x1, const std::vector< float > &x2)
RSofieReader()
Dummy constructor which needs model loading afterwards.
virtual int Load(const char *module, const char *entry="", Bool_t system=kFALSE)
Load a shared library.
virtual Bool_t AccessPathName(const char *path, EAccessMode mode=kFileExists)
Returns FALSE if one can access a file using the specified access mode.
This class defines a UUID (Universally Unique IDentifier), also known as GUIDs (Globally Unique IDent...
const char * AsString() const
Return UUID as string. Copy string immediately since it will be reused.
std::string ToString(const T &val)
Utility function for conversion to strings.
R__EXTERN TVirtualRWMutex * gCoreMutex
create variable transformations
std::string fOutputShapes