16#ifndef TMVA_RSOFIEREADER 
   17#define TMVA_RSOFIEREADER 
   34namespace Experimental {
 
   51    RSofieReader(
const std::string &path, std::vector<std::vector<size_t>> inputShapes = {}, 
int verbose = 0)
 
   54      enum EModelType {kONNX, kKeras, kPt, kROOT, kNotDef}; 
 
   55      EModelType 
type = kNotDef;
 
   57      auto pos1 = path.rfind(
"/");
 
   58      auto pos2 = path.find(
".onnx");
 
   59      if (pos2 != std::string::npos) {
 
   62         pos2 = path.find(
".h5");
 
   63         if (pos2 != std::string::npos) {
 
   66            pos2 = path.find(
".pt");
 
   67            if (pos2 != std::string::npos) {
 
   71               pos2 = path.find(
".root");
 
   72               if (pos2 != std::string::npos) {
 
   78      if (
type == kNotDef) {
 
   79         throw std::runtime_error(
"Input file is not an ONNX or Keras or PyTorch file");
 
   81      if (pos1 == std::string::npos)
 
   85      std::string modelName = path.substr(pos1,pos2-pos1);
 
   86      std::string fileType = path.substr(pos2+1, path.length()-pos2-1);
 
   87      if (verbose) std::cout << 
"Parsing SOFIE model " << modelName << 
" of type " << fileType << std::endl;
 
   91      std::string parserCode;
 
   95            throw std::runtime_error(
"RSofieReader: cannot use SOFIE with ONNX since libROOTTMVASofieParser is missing");
 
   97         gInterpreter->Declare(
"#include \"TMVA/RModelParser_ONNX.hxx\"");
 
   98         parserCode += 
"{\nTMVA::Experimental::SOFIE::RModelParser_ONNX parser ; \n";
 
  100            parserCode += 
"TMVA::Experimental::SOFIE::RModel model = parser.Parse(\"" + path + 
"\",true); \n";
 
  102            parserCode += 
"TMVA::Experimental::SOFIE::RModel model = parser.Parse(\"" + path + 
"\"); \n";
 
  104      else if (
type == kKeras) {
 
  107            throw std::runtime_error(
"RSofieReader: cannot use SOFIE with Keras since libPyMVA is missing");
 
  109         parserCode += 
"{\nTMVA::Experimental::SOFIE::RModel model = TMVA::Experimental::SOFIE::PyKeras::Parse(\"" + path + 
"\"); \n";
 
  111      else if (
type == kPt) {
 
  114            throw std::runtime_error(
"RSofieReader: cannot use SOFIE with PyTorch since libPyMVA is missing");
 
  116         if (inputShapes.size() == 0) {
 
  117            throw std::runtime_error(
"RSofieReader: cannot use SOFIE with PyTorch since the input tensor shape is missing and is needed by the PyTorch parser");
 
  119         std::string inputShapesStr = 
"{";
 
  120         for (
unsigned int i = 0; i < inputShapes.size(); i++) {
 
  121            inputShapesStr += 
"{ ";
 
  122            for (
unsigned int j = 0; j < inputShapes[i].size(); j++) {
 
  124               if (j < inputShapes[i].
size()-1) inputShapesStr += 
", ";
 
  126            inputShapesStr += 
"}";
 
  127            if (i < inputShapes.size()-1) inputShapesStr += 
", ";
 
  129         inputShapesStr += 
"}";
 
  130         parserCode += 
"{\nTMVA::Experimental::SOFIE::RModel model = TMVA::Experimental::SOFIE::PyTorch::Parse(\"" + path + 
"\", " 
  131                    + inputShapesStr + 
"); \n";
 
  133      else if (
type == kROOT) {
 
  135         parserCode += 
"{\nauto fileRead = TFile::Open(\"" + path + 
"\",\"READ\");\n";
 
  136         parserCode += 
"TMVA::Experimental::SOFIE::RModel * modelPtr;\n";
 
  137         parserCode += 
"auto keyList = fileRead->GetListOfKeys(); TString name;\n";
 
  138         parserCode += 
"for (const auto&& k : *keyList)  { \n";
 
  139         parserCode += 
"   TString cname =  ((TKey*)k)->GetClassName();  if (cname==\"TMVA::Experimental::SOFIE::RModel\") name = k->GetName(); }\n";
 
  140         parserCode += 
"fileRead->GetObject(name,modelPtr); fileRead->Close(); delete fileRead;\n";
 
  141         parserCode += 
"TMVA::Experimental::SOFIE::RModel & model = *modelPtr;\n";
 
  145      if (inputShapes.size() > 0 && inputShapes[0].size() > 0) {
 
  146         batchSize = inputShapes[0][0];
 
  147         if (batchSize < 1) batchSize = 1;
 
  149      if (verbose) std::cout << 
"generating the code with batch size = " << batchSize << 
" ...\n";
 
  150      parserCode += 
"model.Generate(TMVA::Experimental::SOFIE::Options::kDefault," 
  153         parserCode += 
"model.PrintGenerated(); \n";
 
  154      parserCode += 
"model.OutputGenerated();\n";
 
  157      parserCode += 
"return 1;\n }\n";
 
  159      if (verbose) std::cout << 
"//ParserCode being executed:\n" << parserCode << std::endl;
 
  161      auto iret = 
gROOT->ProcessLine(parserCode.c_str());
 
  163         std::string msg = 
"RSofieReader: error processing the parser code: \n" + parserCode;
 
  164         throw std::runtime_error(msg);
 
  168      std::string modelHeader = modelName + 
".hxx";
 
  169      if (verbose) std::cout << 
"compile generated code from file " <<modelHeader << std::endl;
 
  171         std::string msg = 
"RSofieReader: input header file " + modelHeader + 
" is not existing";
 
  172         throw std::runtime_error(msg);
 
  174      if (verbose) std::cout << 
"Creating Inference function for model " << modelName << std::endl;
 
  175      std::string declCode;
 
  176      declCode += 
"#pragma cling optimize(2)\n";
 
  177      declCode += 
"#include \"" + modelHeader + 
"\"\n";
 
  179      std::string sessionClassName = 
"TMVA_SOFIE_" + modelName + 
"::Session";
 
  181      std::string uidName = uuid.
AsString();
 
  182      uidName.erase(std::remove_if(uidName.begin(), uidName.end(),
 
  183         []( 
char const& 
c ) -> 
bool { return !std::isalnum(c); } ), uidName.end());
 
  185      std::string sessionName = 
"session_" + uidName;
 
  186      declCode += sessionClassName + 
" " + sessionName + 
";";
 
  188      if (verbose) std::cout << 
"//global session declaration\n" << declCode << std::endl;
 
  192         std::string msg = 
"RSofieReader: error compiling inference code and creating session class\n" + declCode;
 
  193         throw std::runtime_error(msg);
 
  199      std::stringstream ifuncCode;
 
  200      std::string funcName = 
"SofieInference_" + uidName;
 
  201      ifuncCode << 
"std::vector<float> " + funcName + 
"( void * ptr, float * data) {\n";
 
  202      ifuncCode << 
"   " << sessionClassName << 
" * s = " << 
"(" << sessionClassName << 
"*) (ptr);\n";
 
  203      ifuncCode << 
"   return s->infer(data);\n";
 
  206      if (verbose) std::cout << 
"//Inference function code using global session instance\n" 
  207                              << ifuncCode.str() << std::endl;
 
  211         std::string msg = 
"RSofieReader: error compiling inference function\n" + ifuncCode.str();
 
  212         throw std::runtime_error(msg);
 
  215      fFuncPtr = 
reinterpret_cast<std::vector<float> (*)(
void *, 
const float *)
>(fptr);
 
  220   std::vector<float> 
Compute(
const std::vector<float> &
x)
 
  223         return std::vector<float>();
 
  243      const auto rowsize = 
x.GetStrides()[0];
 
  250      for (
size_t i = 1; i < nrows; i++) {
 
  261   std::function<std::vector<float> (
void *, 
const float *)> 
fFuncPtr;
 
size_t size(const MatrixT &matrix)
Retrieve the size of a square matrix.
 
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void char Point_t Rectangle_t WindowAttributes_t Float_t Float_t Float_t Int_t Int_t UInt_t UInt_t Rectangle_t result
 
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void char Point_t Rectangle_t WindowAttributes_t Float_t Float_t Float_t Int_t Int_t UInt_t UInt_t Rectangle_t Int_t Int_t Window_t TString Int_t GCValues_t GetPrimarySelectionOwner GetDisplay GetScreen GetColormap GetNativeEvent const char const char dpyName wid window const char font_name cursor keysym reg const char only_if_exist regb h Point_t winding char text const char depth char const char Int_t count const char ColorStruct_t color const char Pixmap_t Pixmap_t PictureAttributes_t attr const char char ret_data h unsigned char height h Atom_t Int_t ULong_t ULong_t unsigned char prop_list Atom_t Atom_t Atom_t Time_t type
 
R__EXTERN TSystem * gSystem
 
#define R__WRITE_LOCKGUARD(mutex)
 
TMVA::RSofieReader class for reading external Machine Learning models in ONNX files,...
 
RSofieReader(const std::string &path, std::vector< std::vector< size_t > > inputShapes={}, int verbose=0)
Create TMVA model from ONNX file. Print level can be 0 (minimal), 1 with info, 2 with all ONNX parsing...
 
RTensor< float > Compute(RTensor< float > &x)
Compute model prediction on input RTensor. The shape of the input tensor should be {nevents,...
 
std::vector< float > Compute(const std::vector< float > &x)
Compute model prediction on vector.
 
std::function< std::vector< float >(void *, const float *)> fFuncPtr
 
RTensor is a container with contiguous memory and shape information.
 
const Shape_t & GetShape() const
 
virtual int Load(const char *module, const char *entry="", Bool_t system=kFALSE)
Load a shared library.
 
virtual Bool_t AccessPathName(const char *path, EAccessMode mode=kFileExists)
Returns FALSE if one can access a file using the specified access mode.
 
This class defines a UUID (Universally Unique IDentifier), also known as GUIDs (Globally Unique IDent...
 
const char * AsString() const
Return UUID as string. Copy string immediately since it will be reused.
 
std::string ToString(const T &val)
Utility function for conversion to strings.
 
R__EXTERN TVirtualRWMutex * gCoreMutex
 
create variable transformations