Namespaces
namespace INTERNAL
Functions
RModel Parse (std::string filepath, std::vector< std::vector< size_t > > inputShapes)
Overloaded Parser function for translating a PyTorch .pt model into an RModel object.
RModel Parse (std::string filepath, std::vector< std::vector< size_t > > inputShapes, std::vector< ETensorType > dtype)
Parser function for translating a PyTorch .pt model into an RModel object.
Variables
static void(&) PyRunString (TString, PyObject *, PyObject *) = PyMethodBase::PyRunString
static const char *(&) PyStringAsString (PyObject *) = PyMethodBase::PyStringAsString
RModel TMVA::Experimental::SOFIE::PyTorch::Parse (std::string filepath, std::vector< std::vector< size_t > > inputShapes)
Overloaded Parser function for translating a PyTorch .pt model into an RModel object.
Accepts the file location of a PyTorch model and the shapes of its input tensors. Only the inputShapes vector is required: the function builds the vector of data-types for the input tensors, using Float as the default for every input, and then calls the Parse() overload that takes this vector, returning the parsed RModel object.
[in] filepath: file location of the PyTorch .pt model
[in] inputShapes: vector of input shape vectors
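The default-type behaviour of this overload can be illustrated with a small standalone sketch that does not use ROOT. ETensorType, ParsedArgs, ParseWithTypes() and ParseDefaultFloat() are hypothetical stand-ins written for this example (the real overload returns an RModel); only the step described above is modelled, namely filling one Float entry per input shape and forwarding to the typed overload.

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <vector>

// Hypothetical stand-in for SOFIE's ETensorType; only a few members are modelled.
enum class ETensorType { FLOAT, DOUBLE, INT64 };

// Hypothetical stand-in for the parse result: it records the arguments that
// reached the typed overload instead of building an RModel.
struct ParsedArgs {
    std::string filepath;
    std::vector<std::vector<std::size_t>> inputShapes;
    std::vector<ETensorType> inputDTypes;
};

// Stand-in for the three-argument Parse(): expects one data-type per input shape.
ParsedArgs ParseWithTypes(const std::string& filepath,
                          const std::vector<std::vector<std::size_t>>& inputShapes,
                          const std::vector<ETensorType>& inputDTypes) {
    assert(inputShapes.size() == inputDTypes.size());
    return ParsedArgs{filepath, inputShapes, inputDTypes};
}

// Sketch of the two-argument overload: default every input to FLOAT, then forward.
ParsedArgs ParseDefaultFloat(const std::string& filepath,
                             const std::vector<std::vector<std::size_t>>& inputShapes) {
    std::vector<ETensorType> inputDTypes(inputShapes.size(), ETensorType::FLOAT);
    return ParseWithTypes(filepath, inputShapes, inputDTypes);
}
```

For example, ParseDefaultFloat("model.pt", {{120, 1}, {4}}) forwards two FLOAT entries, one for each input shape.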
Definition at line 557 of file RModelParser_PyTorch.cxx.
RModel TMVA::Experimental::SOFIE::PyTorch::Parse (std::string filename, std::vector< std::vector< size_t > > inputShapes, std::vector< ETensorType > inputDTypes)
Parser function for translating a PyTorch .pt model into an RModel object.
Accepts the file location of a PyTorch model and the shapes and data-types of its input tensors, and returns the equivalent RModel object.
[in] filename: file location of the PyTorch .pt model
[in] inputShapes: vector of input shape vectors
[in] inputDTypes: vector of ETensorType values giving the data-type of each input tensor
The Parse() function defined in TMVA::Experimental::SOFIE::PyTorch parses a trained PyTorch .pt model into an RModel object. The parser uses internal functions of PyTorch to convert the model into its equivalent ONNX graph. For this conversion, dummy inputs are built and passed through the model, and the operators applied to them are recorded to populate the ONNX graph. This is why Parse() requires the shapes and data-types of the input tensors: they are used to build the dummy inputs. After the conversion, the nodes of the ONNX graph are traversed to extract properties such as the node type, the attributes, and the input and output tensor names. AddOperator() is then called on each extracted node to add the corresponding operator to the RModel object. The nodes are also checked to determine whether any additional routines are required to execute the generated inference code.
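The node-handling step can be pictured as a dispatch from node type to a handler, as in the following standalone sketch. NodeInfo, ModelSketch, AddNodes() and the handler table are hypothetical stand-ins for the parser's internals and for RModel. The two handled node types, and the assumption that "Gemm" needs an extra routine, are illustrative only. An unsupported node type raises an error instead of being silently skipped.

```cpp
#include <cassert>
#include <functional>
#include <map>
#include <set>
#include <stdexcept>
#include <string>
#include <vector>

// Hypothetical flattened view of one ONNX graph node: node type, attributes,
// and input and output tensor names.
struct NodeInfo {
    std::string type;
    std::map<std::string, std::string> attributes;
    std::vector<std::string> inputs;
    std::vector<std::string> outputs;
};

// Hypothetical model being populated; stands in for RModel.
struct ModelSketch {
    std::vector<std::string> operators;  // one entry per added operator
    std::set<std::string> routines;      // extra routines the inference code needs
};

using Handler = std::function<void(const NodeInfo&, ModelSketch&)>;

// Add each node through a handler keyed by its type, in the spirit of AddOperator().
void AddNodes(const std::vector<NodeInfo>& nodes, ModelSketch& model) {
    static const std::map<std::string, Handler> handlers = {
        {"Gemm", [](const NodeInfo& n, ModelSketch& m) {
             m.operators.push_back("Gemm:" + n.outputs.at(0));
             m.routines.insert("Gemm");  // assumption: Gemm needs an extra routine
         }},
        {"Relu", [](const NodeInfo& n, ModelSketch& m) {
             m.operators.push_back("Relu:" + n.outputs.at(0));
         }},
    };
    for (const auto& node : nodes) {
        auto it = handlers.find(node.type);
        if (it == handlers.end())
            throw std::runtime_error("operator not supported: " + node.type);
        it->second(node, model);
    }
}
```

Keeping the handlers in a table keyed by node type means that supporting a new operator only adds one entry, while unknown types fail loudly.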
The internal function used to convert the model to a graph object returns a list containing a Graph object and a dictionary of weights. This dictionary is used to extract the initialized tensors of the model: their names and data-types are extracted along with their values as NumPy arrays, and after the appropriate type conversions they are added to the RModel object.
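The type conversion of the weights can be sketched in the same standalone way. WeightEntry, InitializedTensor, ConvertDType() and ConvertWeights() are hypothetical. The lookup table uses the usual NumPy dtype spellings. This sketch only accepts float32 weights: it assumes the values were transported as double and converts them back to float. The real parser's conversion code may differ.

```cpp
#include <cassert>
#include <cstddef>
#include <map>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>

enum class ETensorType { FLOAT, DOUBLE, INT64 };  // hypothetical stand-in

// Hypothetical view of one entry of the weights dictionary after reading it
// out of Python: name, NumPy dtype name, shape and values.
struct WeightEntry {
    std::string name;
    std::string numpyDType;          // e.g. "float32"
    std::vector<std::size_t> shape;
    std::vector<double> values;      // assumed transported as double
};

// Hypothetical initialized tensor as stored in the model sketch.
struct InitializedTensor {
    ETensorType type;
    std::vector<std::size_t> shape;
    std::vector<float> data;         // only FLOAT is handled in this sketch
};

// Map a NumPy dtype name to the stand-in tensor type.
ETensorType ConvertDType(const std::string& numpyDType) {
    static const std::map<std::string, ETensorType> table = {
        {"float32", ETensorType::FLOAT},
        {"float64", ETensorType::DOUBLE},
        {"int64", ETensorType::INT64},
    };
    auto it = table.find(numpyDType);
    if (it == table.end())
        throw std::invalid_argument("unsupported dtype: " + numpyDType);
    return it->second;
}

// Convert every weight entry into a named, typed initialized tensor.
std::map<std::string, InitializedTensor>
ConvertWeights(const std::vector<WeightEntry>& weights) {
    std::map<std::string, InitializedTensor> out;
    for (const auto& w : weights) {
        ETensorType t = ConvertDType(w.numpyDType);
        if (t != ETensorType::FLOAT)
            throw std::invalid_argument("sketch handles float32 only: " + w.name);
        std::size_t n = 1;
        for (std::size_t d : w.shape) n *= d;
        assert(n == w.values.size());  // shape must match the number of values
        InitializedTensor tensor{t, w.shape, {}};
        tensor.data.assign(w.values.begin(), w.values.end());  // double -> float
        out.emplace(w.name, std::move(tensor));
    }
    return out;
}
```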
For the input tensor infos, the names of the input tensors are extracted from the PyTorch ONNX graph object. The shape and data-type of each input tensor are taken from the vectors of shapes and data-types passed to Parse(). The resulting input tensor infos are added to the RModel object by calling AddInputTensorInfo().
For the output tensor infos, the names of the output tensors are also extracted from the Graph object and added to the RModel object by calling AddOutputTensorNameList().
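The following sketch shows the bookkeeping for input and output tensors. It assumes that the i-th input name taken from the graph corresponds to the i-th entry of inputShapes and inputDTypes. InputTensorInfo, ModelIO and AddIO() are hypothetical stand-ins. In the parser, the equivalent work is done through AddInputTensorInfo() and AddOutputTensorNameList().

```cpp
#include <cassert>
#include <cstddef>
#include <stdexcept>
#include <string>
#include <vector>

enum class ETensorType { FLOAT, DOUBLE, INT64 };  // hypothetical stand-in

// Hypothetical stand-in for one input tensor info.
struct InputTensorInfo {
    std::string name;
    ETensorType type;
    std::vector<std::size_t> shape;
};

// Hypothetical stand-in for the input/output bookkeeping kept in RModel.
struct ModelIO {
    std::vector<InputTensorInfo> inputs;
    std::vector<std::string> outputs;
};

// Pair the i-th graph input name with the i-th caller-supplied shape and
// data-type (an assumption of this sketch), then record the output names.
ModelIO AddIO(const std::vector<std::string>& graphInputNames,
              const std::vector<std::vector<std::size_t>>& inputShapes,
              const std::vector<ETensorType>& inputDTypes,
              const std::vector<std::string>& graphOutputNames) {
    if (graphInputNames.size() != inputShapes.size() ||
        inputShapes.size() != inputDTypes.size())
        throw std::invalid_argument("one shape and one data-type are needed per input");
    ModelIO io;
    for (std::size_t i = 0; i < graphInputNames.size(); ++i)
        io.inputs.push_back({graphInputNames[i], inputDTypes[i], inputShapes[i]});
    io.outputs = graphOutputNames;  // analogous to AddOutputTensorNameList()
    return io;
}
```

Checking the three sizes up front turns a mismatch between the model's inputs and the supplied shapes or data-types into an immediate error rather than a mislabeled tensor.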
Example Usage:
Definition at line 357 of file RModelParser_PyTorch.cxx.
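static void(&) PyRunString (TString, PyObject *, PyObject *) = PyMethodBase::PyRunString
Definition at line 37 of file RModelParser_PyTorch.cxx.
static const char *(&) PyStringAsString (PyObject *) = PyMethodBase::PyStringAsString
Definition at line 38 of file RModelParser_PyTorch.cxx.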