ROOT
Reference Guide
 
TMVA::Experimental::SOFIE::PyTorch Namespace Reference

Namespaces

namespace  INTERNAL
 

Functions

RModel Parse (std::string filepath, std::vector< std::vector< size_t > > inputShapes)
 Overloaded parser function that translates a PyTorch .pt model into an RModel object, using Float as the data type of every input tensor.
 
RModel Parse (std::string filepath, std::vector< std::vector< size_t > > inputShapes, std::vector< ETensorType > dtype)
 Parser function that translates a PyTorch .pt model into an RModel object.
 

Variables

static void(&) PyRunString (TString, PyObject *, PyObject *) = PyMethodBase::PyRunString
 
static const char *(&) PyStringAsString (PyObject *) = PyMethodBase::PyStringAsString
 

Function Documentation

◆ Parse() [1/2]

RModel TMVA::Experimental::SOFIE::PyTorch::Parse ( std::string  filepath,
std::vector< std::vector< size_t > >  inputShapes 
)

Overloaded parser function that translates a PyTorch .pt model into an RModel object.

Accepts the file location of a PyTorch model and the shapes of input tensors. Builds the vector of data-types for input tensors and calls the Parse() function to return the equivalent RModel object.

Parameters
[in] filepath: file location of the PyTorch .pt model
[in] inputShapes: vector of input shape vectors
Returns
Parsed RModel object

This overload requires only the file location and the inputShapes vector. It builds the vector of data types for the input tensors, using Float as the default, and then calls the Parse() overload that also takes the data types, returning the parsed RModel object.
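The delegation described above can be sketched as follows. This is a minimal illustration, not the code in RModelParser_PyTorch.cxx: the `ETensorType` enum and the `RModel` struct below are simplified stand-ins defined only for this example, and the three-argument `Parse()` is a stub that merely stores its arguments.

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <utility>
#include <vector>

// Stand-in types for illustration only; the real ETensorType and RModel
// live in TMVA::Experimental::SOFIE and carry far more state.
enum class ETensorType { UNDEFINED, FLOAT, INT64 };

struct RModel {
    std::string source;                            // file the model came from
    std::vector<std::vector<std::size_t>> shapes;  // one shape per input tensor
    std::vector<ETensorType> dtypes;               // one data type per input tensor
};

// Full overload (stub): a real parser would convert the model here.
RModel Parse(std::string filepath,
             std::vector<std::vector<std::size_t>> inputShapes,
             std::vector<ETensorType> inputDTypes)
{
    return RModel{std::move(filepath), std::move(inputShapes), std::move(inputDTypes)};
}

// Convenience overload: assume FLOAT for every input tensor and delegate.
RModel Parse(std::string filepath,
             std::vector<std::vector<std::size_t>> inputShapes)
{
    // One FLOAT entry per input shape.
    std::vector<ETensorType> dtypes(inputShapes.size(), ETensorType::FLOAT);
    return Parse(std::move(filepath), std::move(inputShapes), std::move(dtypes));
}
```

Keeping the default in a thin overload means the conversion logic exists in one place only, and callers with non-Float inputs use the three-argument form directly.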

Definition at line 548 of file RModelParser_PyTorch.cxx.

◆ Parse() [2/2]

RModel TMVA::Experimental::SOFIE::PyTorch::Parse ( std::string  filename,
std::vector< std::vector< size_t > >  inputShapes,
std::vector< ETensorType >  inputDTypes 
)

Parser function that translates a PyTorch .pt model into an RModel object.

Accepts the file location of a PyTorch model, shapes and data-types of input tensors and returns the equivalent RModel object.

Parameters
[in] filename: file location of the PyTorch .pt model
[in] inputShapes: vector of input shape vectors
[in] inputDTypes: vector of ETensorType giving the data types of the input tensors
Returns
Parsed RModel object

The Parse() function defined in TMVA::Experimental::SOFIE::PyTorch parses a trained PyTorch .pt model into an RModel object. The parser uses internal PyTorch functions to convert the model into its equivalent ONNX graph. For this conversion, dummy inputs are built from the shapes and data types passed to Parse(); they are passed through the model, and the applied operators are recorded to populate the ONNX graph. The nodes of the ONNX graph are then traversed to extract properties such as the node type, attributes, and input and output tensor names. AddOperator() is called for each extracted node to add the corresponding operator to the RModel object. The nodes are also checked so that any routines required to execute the generated inference code are added.
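The node traversal can be pictured with the sketch below. It is an illustration under stated assumptions, not the parser's implementation: `Node`, `ModelSketch`, `Handlers()` and `AddOperators()` are hypothetical names invented for this example, the two node types in the table are examples only, and the stand-in `AddOperator()` records a text description instead of a real operator object.

```cpp
#include <cassert>
#include <functional>
#include <stdexcept>
#include <string>
#include <unordered_map>
#include <vector>

// Hypothetical, simplified view of one ONNX graph node.
struct Node {
    std::string type;                  // node type, e.g. "onnx::Gemm"
    std::vector<std::string> inputs;   // input tensor names
    std::vector<std::string> outputs;  // output tensor names
};

// Stand-in for RModel: it only records which operators were added.
struct ModelSketch {
    std::vector<std::string> operators;
    void AddOperator(const std::string &description) { operators.push_back(description); }
};

using NodeHandler = std::function<std::string(const Node &)>;

// Illustrative dispatch table keyed by node type.
const std::unordered_map<std::string, NodeHandler> &Handlers()
{
    static const std::unordered_map<std::string, NodeHandler> table = {
        {"onnx::Gemm", [](const Node &n) { return "Gemm -> " + n.outputs.at(0); }},
        {"onnx::Relu", [](const Node &n) { return "Relu -> " + n.outputs.at(0); }},
    };
    return table;
}

// Walk the nodes in graph order and add one operator per node.
// Node types without a handler are reported instead of being skipped.
void AddOperators(const std::vector<Node> &graph, ModelSketch &model)
{
    for (const Node &node : graph) {
        auto it = Handlers().find(node.type);
        if (it == Handlers().end())
            throw std::runtime_error("Unsupported node type: " + node.type);
        model.AddOperator(it->second(node));
    }
}
```

A lookup table keeps the traversal loop independent of the set of supported operators: supporting a new node type means adding one table entry.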

The internal function used to convert the model into a graph object returns a list containing a Graph object and a dictionary of weights. This dictionary is used to extract the initialized tensors of the model. The names and data types of the initialized tensors are extracted along with their values as NumPy arrays and, after appropriate type conversions, are added to the RModel object.
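One plausible form of the type conversion for a Float tensor is a copy of the raw array buffer into memory owned on the C++ side. The sketch below assumes that approach; `ElementCount()` and `CopyFloatTensor()` are hypothetical helpers written for this example, and a plain `float` array stands in for the data behind a NumPy array.

```cpp
#include <cassert>
#include <cstddef>
#include <cstring>
#include <functional>
#include <memory>
#include <numeric>
#include <vector>

// Number of elements in a tensor of the given shape (product of the dimensions).
std::size_t ElementCount(const std::vector<std::size_t> &shape)
{
    return std::accumulate(shape.begin(), shape.end(), std::size_t{1},
                           std::multiplies<std::size_t>());
}

// Copy a raw float buffer into storage that owns its memory, so the values
// stay valid after the object that provided the buffer is released.
std::shared_ptr<void> CopyFloatTensor(const float *source,
                                      const std::vector<std::size_t> &shape)
{
    const std::size_t count = ElementCount(shape);
    // Type-erased ownership: the deleter restores the element type before freeing.
    std::shared_ptr<void> data(new float[count],
                               [](void *p) { delete[] static_cast<float *>(p); });
    std::memcpy(data.get(), source, count * sizeof(float));
    return data;
}
```

The type-erased `std::shared_ptr<void>` lets tensors of different element types share one storage interface, while the data type recorded next to the tensor tells consumers how to cast the pointer back.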

To add the input tensor infos, the names of the input tensors are extracted from the PyTorch ONNX graph object. The shapes and data types passed to Parse() supply the shape and data type of each input tensor. The resulting input tensor infos are then added to the RModel object by calling AddInputTensorInfo().
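The pairing of names, shapes and data types can be sketched as below. The names `InputTensorInfo` and `BuildInputInfos()` are hypothetical, `ETensorType` is a simplified stand-in enum, and the size check is an assumption of this example about how mismatched arguments might be handled, not a statement about the parser's behaviour.

```cpp
#include <cassert>
#include <cstddef>
#include <stdexcept>
#include <string>
#include <vector>

// Stand-in for TMVA::Experimental::SOFIE::ETensorType.
enum class ETensorType { UNDEFINED, FLOAT, INT64 };

// Hypothetical record holding what is known about one input tensor.
struct InputTensorInfo {
    std::string name;                // taken from the graph
    ETensorType type;                // taken from the caller's data-type vector
    std::vector<std::size_t> shape;  // taken from the caller's shape vector
};

// Combine the i-th name, shape and data type into one record per input tensor.
std::vector<InputTensorInfo> BuildInputInfos(
    const std::vector<std::string> &names,
    const std::vector<std::vector<std::size_t>> &shapes,
    const std::vector<ETensorType> &dtypes)
{
    // Assumption of this sketch: all three vectors must describe the same inputs.
    if (names.size() != shapes.size() || names.size() != dtypes.size())
        throw std::invalid_argument("names, shapes and dtypes must have equal length");

    std::vector<InputTensorInfo> infos;
    infos.reserve(names.size());
    for (std::size_t i = 0; i < names.size(); ++i)
        infos.push_back(InputTensorInfo{names[i], dtypes[i], shapes[i]});
    return infos;
}
```

Because the correspondence is purely positional, the order of the shapes and data types given by the caller has to match the order of the inputs in the graph.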

For the output tensor infos, the names of the output tensors are likewise extracted from the Graph object and added to the RModel object by calling AddOutputTensorNameList().

Example Usage:

// Assumes: using namespace TMVA::Experimental::SOFIE;
// Build the vector of input tensor shapes
std::vector<size_t> s1{120, 1};
std::vector<std::vector<size_t>> inputShape{s1};
RModel model = PyTorch::Parse("trained_model_dense.pt", inputShape);

Definition at line 357 of file RModelParser_PyTorch.cxx.

Variable Documentation

◆ PyRunString

static void(&) TMVA::Experimental::SOFIE::PyTorch::PyRunString (TString, PyObject *, PyObject *) = PyMethodBase::PyRunString

Definition at line 37 of file RModelParser_PyTorch.cxx.

◆ PyStringAsString

static const char *(&) TMVA::Experimental::SOFIE::PyTorch::PyStringAsString (PyObject *) = PyMethodBase::PyStringAsString

Definition at line 38 of file RModelParser_PyTorch.cxx.