Logo ROOT  
Reference Guide
 
All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Properties Friends Macros Modules Pages
Loading...
Searching...
No Matches
RModelParser_ONNX.hxx
Go to the documentation of this file.
1#ifndef TMVA_SOFIE_RMODELPARSER_ONNX
2#define TMVA_SOFIE_RMODELPARSER_ONNX
3
4#include "TMVA/RModel.hxx"
5
6#include <memory>
7#include <functional>
8#include <unordered_map>
9
10// forward declarations
11namespace onnx {
12class NodeProto;
13class GraphProto;
14class ModelProto;
15} // namespace onnx
16
17namespace TMVA {
18namespace Experimental {
19namespace SOFIE {
20
21class RModelParser_ONNX;
22
24 std::function<std::unique_ptr<ROperator>(RModelParser_ONNX & /*parser*/, const onnx::NodeProto & /*nodeproto*/)>;
26 std::function<std::unique_ptr<ROperator> (RModelParser_ONNX& /*parser*/, const onnx::NodeProto& /*firstnode*/, const onnx::NodeProto& /*secondnode*/)>;
27
29public:
30 struct OperatorsMapImpl;
31
32private:
33
34 bool fVerbose = false;
35 // Registered operators
36 std::unique_ptr<OperatorsMapImpl> fOperatorsMapImpl;
37 // Type of the tensors
38 std::unordered_map<std::string, ETensorType> fTensorTypeMap;
39 // list of flags marking the fused operators
40 std::vector<bool> fFusedOperators;
41
42
43public:
44 // Register an ONNX operator
45 void RegisterOperator(const std::string &name, ParserFuncSignature func);
46
47 // Check if the operator is registered
48 bool IsRegisteredOperator(const std::string &name);
49
50 // List of registered operators (in alphabetical order)
51 std::vector<std::string> GetRegisteredOperators();
52
53 // Set the type of the tensor
54 void RegisterTensorType(const std::string & /*name*/, ETensorType /*type*/);
55
56 // Check if the type of the tensor is registered
57 bool IsRegisteredTensorType(const std::string & /*name*/);
58
59 // Report the current verbosity setting: returns the value of the fVerbose flag
60 bool Verbose() const {
61 return fVerbose;
62 }
63
64 // Get the type of the tensor
65 ETensorType GetTensorType(const std::string &name);
66
67 // Parse the index'th node from the ONNX graph
68 std::unique_ptr<ROperator> ParseOperator(const size_t /*index*/, const onnx::GraphProto & /*graphproto*/,
69 const std::vector<size_t> & /*nodes*/, const std::vector<int> & /* children */);
70
71 // check a graph for missing operators
72 void CheckGraph(const onnx::GraphProto & g, int & level, std::map<std::string, int> & missingOperators);
73
74 // parse the ONNX graph
75 void ParseONNXGraph(RModel & model, const onnx::GraphProto & g, std::string name = "");
76
77 std::unique_ptr<onnx::ModelProto> LoadModel(std::string filename);
78
79public:
80
82
83 RModel Parse(std::string filename, bool verbose = false);
84
85 // check the model for missing operators - return false in case some operator implementation is missing
86 bool CheckModel(std::string filename, bool verbose = false);
87
89};
90
91} // namespace SOFIE
92} // namespace Experimental
93} // namespace TMVA
94
95#endif // TMVA_SOFIE_RMODELPARSER_ONNX
#define g(i)
Definition RSha256.hxx:105
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void char Point_t Rectangle_t WindowAttributes_t Float_t Float_t Float_t Int_t Int_t UInt_t UInt_t Rectangle_t Int_t Int_t Window_t TString Int_t GCValues_t GetPrimarySelectionOwner GetDisplay GetScreen GetColormap GetNativeEvent const char const char dpyName wid window const char font_name cursor keysym reg const char only_if_exist regb h Point_t winding char text const char depth char const char Int_t count const char ColorStruct_t color const char filename
char name[80]
Definition TGX11.cxx:110
void RegisterOperator(const std::string &name, ParserFuncSignature func)
std::unique_ptr< ROperator > ParseOperator(const size_t, const onnx::GraphProto &, const std::vector< size_t > &, const std::vector< int > &)
bool IsRegisteredOperator(const std::string &name)
void CheckGraph(const onnx::GraphProto &g, int &level, std::map< std::string, int > &missingOperators)
void ParseONNXGraph(RModel &model, const onnx::GraphProto &g, std::string name="")
std::unordered_map< std::string, ETensorType > fTensorTypeMap
RModel Parse(std::string filename, bool verbose=false)
void RegisterTensorType(const std::string &, ETensorType)
std::unique_ptr< onnx::ModelProto > LoadModel(std::string filename)
ETensorType GetTensorType(const std::string &name)
std::vector< std::string > GetRegisteredOperators()
std::unique_ptr< OperatorsMapImpl > fOperatorsMapImpl
bool CheckModel(std::string filename, bool verbose=false)
std::function< std::unique_ptr< ROperator >(RModelParser_ONNX &, const onnx::NodeProto &, const onnx::NodeProto &)> ParserFuseFuncSignature
std::function< std::unique_ptr< ROperator >(RModelParser_ONNX &, const onnx::NodeProto &)> ParserFuncSignature
create variable transformations