ROOT Reference Guide
 
TMVA::Experimental::SOFIE::PyTorch::INTERNAL Namespace Reference

Typedefs

using PyTorchMethodMap = std::unordered_map< std::string, std::unique_ptr< ROperator >(*)(PyObject *fNode)>
 

Functions

std::unique_ptr< ROperator > MakePyTorchConv (PyObject *fNode)
 Prepares a ROperator_Conv object.
 
std::unique_ptr< ROperator > MakePyTorchGemm (PyObject *fNode)
 Prepares a ROperator_Gemm object.
 
std::unique_ptr< ROperator > MakePyTorchNode (PyObject *fNode)
 Prepares the ROperator equivalent to a PyTorch ONNX node.
 
std::unique_ptr< ROperator > MakePyTorchRelu (PyObject *fNode)
 Prepares a ROperator_Relu object.
 
std::unique_ptr< ROperator > MakePyTorchSelu (PyObject *fNode)
 Prepares a ROperator_Selu object.
 
std::unique_ptr< ROperator > MakePyTorchSigmoid (PyObject *fNode)
 Prepares a ROperator_Sigmoid object.
 
std::unique_ptr< ROperator > MakePyTorchTranspose (PyObject *fNode)
 Prepares a ROperator_Transpose object.
 

Variables

const PyTorchMethodMap mapPyTorchNode
 

Typedef Documentation

◆ PyTorchMethodMap

using TMVA::Experimental::SOFIE::PyTorch::INTERNAL::PyTorchMethodMap = std::unordered_map<std::string, std::unique_ptr<ROperator> (*)(PyObject* fNode)>

Definition at line 55 of file RModelParser_PyTorch.cxx.
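As a minimal, self-contained sketch of what this typedef describes, the following C++ declares a map with the same value type (a pointer to a free function that takes the node and returns a std::unique_ptr to an operator) and calls through it. ROperator, ROperator_Relu, MakeRelu and kMap are simplified hypothetical stand-ins rather than SOFIE's own classes, and PyObject is only forward-declared so that the example compiles without ROOT or Python.h.

```cpp
#include <memory>
#include <string>
#include <unordered_map>

// Opaque stand-in: the real PyObject comes from Python.h.
struct PyObject;

// Hypothetical, simplified stand-in for SOFIE's ROperator base class.
struct ROperator {
    virtual ~ROperator() = default;
    virtual std::string Kind() const = 0;
};

// Hypothetical concrete operator used only for this sketch.
struct ROperator_Relu : ROperator {
    std::string Kind() const override { return "Relu"; }
};

// Same value type as PyTorchMethodMap: a pointer to a free function that
// takes the node and returns an owning pointer to a newly built operator.
using MethodMap =
    std::unordered_map<std::string, std::unique_ptr<ROperator> (*)(PyObject *fNode)>;

// Hypothetical preparatory function; a real one would read fields from fNode.
std::unique_ptr<ROperator> MakeRelu(PyObject * /*fNode*/) {
    return std::make_unique<ROperator_Relu>();
}

// Node-type string -> preparatory function.
const MethodMap kMap = {{"onnx::Relu", &MakeRelu}};
```

Calling `kMap.find("onnx::Relu")->second(nullptr)` in this sketch yields a `ROperator_Relu`. Because the values are plain function pointers, the table holds no captured state and can be a constant initialized once.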

Function Documentation

◆ MakePyTorchConv()

std::unique_ptr< ROperator > TMVA::Experimental::SOFIE::PyTorch::INTERNAL::MakePyTorchConv (PyObject *fNode)

Prepares a ROperator_Conv object.

Parameters
[in] fNode  Python dictionary describing a PyTorch ONNX graph node
Returns
Unique pointer to ROperator object

For the Conv operator of PyTorch's ONNX graph, the dilations, group, kernel shape, pads, and strides attributes are extracted and passed to the ROperator_Conv constructor, with autopad defaulting to NOTSET.
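To illustrate what the extracted attributes control, here is a sketch that, assuming standard ONNX Conv semantics with explicit padding (autopad NOTSET), computes the output spatial shape from dilations, kernel shape, pads, and strides. The ConvAttributes struct and the ConvOutputSpatialShape function are hypothetical and not part of SOFIE; group is carried along only for completeness, since it does not affect the spatial shape.

```cpp
#include <cstddef>
#include <stdexcept>
#include <string>
#include <vector>

// Hypothetical container for the Conv attributes listed in the text.
// Field names follow the ONNX attribute names, not SOFIE's members.
struct ConvAttributes {
    std::string autopad = "NOTSET";      // explicit pads are used when NOTSET
    std::vector<std::size_t> dilations;
    std::size_t group = 1;               // does not change the spatial shape
    std::vector<std::size_t> kernelShape;
    std::vector<std::size_t> pads;       // ONNX layout: [x1_begin, ..., xn_begin, x1_end, ..., xn_end]
    std::vector<std::size_t> strides;
};

// Per spatial axis, for autopad == NOTSET:
// out = floor((in + pad_begin + pad_end - dilation * (k - 1) - 1) / stride) + 1
std::vector<std::size_t> ConvOutputSpatialShape(const std::vector<std::size_t>& in,
                                                const ConvAttributes& a) {
    const std::size_t n = in.size();
    if (a.autopad != "NOTSET" || a.kernelShape.size() != n || a.dilations.size() != n ||
        a.strides.size() != n || a.pads.size() != 2 * n)
        throw std::invalid_argument("sketch only handles explicit (NOTSET) padding");
    std::vector<std::size_t> out(n);
    for (std::size_t i = 0; i < n; ++i) {
        const std::size_t padded = in[i] + a.pads[i] + a.pads[i + n];
        const std::size_t effKernel = a.dilations[i] * (a.kernelShape[i] - 1) + 1;
        if (padded < effKernel) throw std::invalid_argument("kernel larger than padded input");
        out[i] = (padded - effKernel) / a.strides[i] + 1;  // unsigned division floors
    }
    return out;
}
```

For a 32 x 32 input with a 3 x 3 kernel, pads of 1 on every side, and unit dilation, this gives 32 x 32 with stride 1 and 16 x 16 with stride 2.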

Definition at line 276 of file RModelParser_PyTorch.cxx.

◆ MakePyTorchGemm()

std::unique_ptr< ROperator > TMVA::Experimental::SOFIE::PyTorch::INTERNAL::MakePyTorchGemm (PyObject *fNode)

Prepares a ROperator_Gemm object.

Parameters
[in] fNode  Python dictionary describing a PyTorch ONNX graph node
Returns
Unique pointer to ROperator object

For a PyTorch Linear layer, which appears as a Gemm operation in its ONNX graph, the names of the input and output tensors are extracted and passed, together with the required attributes, to instantiate a ROperator_Gemm object. fInputs is a list of tensor names that includes the name of the input tensor and the names of the weight tensors.
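A Linear layer computes y = x * W^T + b. Assuming the common torch.onnx export of such a layer as Gemm with alpha = beta = 1 and transB = 1, this is the computation the resulting ROperator_Gemm represents. The naive row-major function below (LinearAsGemm, a hypothetical name, not SOFIE code) sketches it; its three inputs correspond to the input tensor and, presumably, the weight matrix and bias that the text refers to as the weight tensors in fInputs.

```cpp
#include <cstddef>
#include <stdexcept>
#include <vector>

// Naive row-major Gemm specialised to the Linear-layer case:
// Y[m][n] = sum_k X[m][k] * W[n][k] + b[n]   (alpha = beta = 1, transB = 1)
// X is M x K, W is N x K (PyTorch's weight layout), b has N entries.
std::vector<float> LinearAsGemm(const std::vector<float>& X, const std::vector<float>& W,
                                const std::vector<float>& b, std::size_t M, std::size_t K,
                                std::size_t N) {
    if (X.size() != M * K || W.size() != N * K || b.size() != N)
        throw std::invalid_argument("shape mismatch");
    std::vector<float> Y(M * N);
    for (std::size_t m = 0; m < M; ++m) {
        for (std::size_t n = 0; n < N; ++n) {
            float acc = b[n];  // bias term, broadcast over rows
            for (std::size_t k = 0; k < K; ++k) acc += X[m * K + k] * W[n * K + k];
            Y[m * N + n] = acc;
        }
    }
    return Y;
}
```

With X = {1, 2}, W = {3, 4, 5, 6} and b = {0.5, -1}, the sketch returns {11.5, 16}.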

Definition at line 114 of file RModelParser_PyTorch.cxx.

◆ MakePyTorchNode()

std::unique_ptr< ROperator > TMVA::Experimental::SOFIE::PyTorch::INTERNAL::MakePyTorchNode (PyObject *fNode)

Prepares the ROperator equivalent to a PyTorch ONNX node.

Parameters
[in] fNode  Python dictionary describing a PyTorch ONNX graph node
Returns
Unique pointer to ROperator object

The function looks up the type of the passed PyTorch ONNX graph node in the map (mapPyTorchNode), calls the matching preparatory function, and returns the resulting ROperator object.
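The lookup-and-dispatch step can be sketched in plain C++ as follows. Node (carrying only a nodeType field in place of the Python dictionary), the operator types, MakeRelu, MakeSigmoid, kNodeMap and MakeNode are all hypothetical simplifications. Throwing std::runtime_error for an unknown type is a choice made for this sketch only; this page does not say how the real parser reports unsupported nodes.

```cpp
#include <memory>
#include <stdexcept>
#include <string>
#include <unordered_map>

// Hypothetical stand-ins: a minimal operator base class and a C++ struct
// in place of the Python fNode dictionary.
struct ROperator {
    virtual ~ROperator() = default;
    virtual std::string Kind() const = 0;
};
struct Node {
    std::string nodeType;  // mirrors fNode['nodeType'], e.g. "onnx::Relu"
};

struct OpRelu : ROperator { std::string Kind() const override { return "Relu"; } };
struct OpSigmoid : ROperator { std::string Kind() const override { return "Sigmoid"; } };

// Hypothetical preparatory functions, one per supported node type.
std::unique_ptr<ROperator> MakeRelu(const Node&) { return std::make_unique<OpRelu>(); }
std::unique_ptr<ROperator> MakeSigmoid(const Node&) { return std::make_unique<OpSigmoid>(); }

using MethodMap = std::unordered_map<std::string, std::unique_ptr<ROperator> (*)(const Node&)>;

const MethodMap kNodeMap = {
    {"onnx::Relu", &MakeRelu},
    {"onnx::Sigmoid", &MakeSigmoid},
};

// Look up the node type and forward the node to the matching function.
std::unique_ptr<ROperator> MakeNode(const Node& node) {
    auto it = kNodeMap.find(node.nodeType);
    if (it == kNodeMap.end())
        throw std::runtime_error("unsupported node type: " + node.nodeType);
    return it->second(node);
}
```

`MakeNode(Node{"onnx::Relu"})` returns an `OpRelu`, while an unregistered type such as `"onnx::Tanh"` throws in this sketch.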

To support additional PyTorch ONNX graph nodes in the future, a new preparatory function only needs to extract the required properties and attributes from the fNode dictionary, which contains all the information about a PyTorch ONNX graph node, and use them to instantiate the ROperator object.

The fNode dictionary, which holds all the information about a node of a PyTorch ONNX graph, has the following structure:

dict fNode {  'nodeType'        : Type of node (operator)
              'nodeAttributes'  : Attributes of the node
              'nodeInputs'      : List of names of input tensors
              'nodeOutputs'     : List of names of output tensors
              'nodeDType'       : Data-type of the operator node
           }

Definition at line 93 of file RModelParser_PyTorch.cxx.
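Following the recipe described above, a new preparatory function mostly validates and copies fields out of fNode. The sketch below mirrors the dictionary as a hypothetical NodeDict struct (attribute values simplified to strings, and the data type simplified to a single string) and shows the pattern for an equally hypothetical operator, ROperator_NewUnary. For MakePyTorchNode to find such a function, it would also need an entry in mapPyTorchNode.

```cpp
#include <map>
#include <memory>
#include <stdexcept>
#include <string>
#include <vector>

// Hypothetical C++ mirror of the fNode dictionary; types are simplified.
struct NodeDict {
    std::string nodeType;                               // 'nodeType'
    std::map<std::string, std::string> nodeAttributes;  // 'nodeAttributes'
    std::vector<std::string> nodeInputs;                // 'nodeInputs'
    std::vector<std::string> nodeOutputs;               // 'nodeOutputs'
    std::string nodeDType;                              // 'nodeDType'
};

// Hypothetical operator for a new unary node type.
struct ROperator_NewUnary {
    std::string input, output, dtype;
};

// Pattern for a new preparatory function: validate, extract, construct.
std::unique_ptr<ROperator_NewUnary> MakeNewUnary(const NodeDict& node) {
    if (node.nodeInputs.size() != 1 || node.nodeOutputs.size() != 1)
        throw std::invalid_argument(node.nodeType + ": expected one input and one output");
    return std::make_unique<ROperator_NewUnary>(
        ROperator_NewUnary{node.nodeInputs[0], node.nodeOutputs[0], node.nodeDType});
}
```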

◆ MakePyTorchRelu()

std::unique_ptr< ROperator > TMVA::Experimental::SOFIE::PyTorch::INTERNAL::MakePyTorchRelu (PyObject *fNode)

Prepares a ROperator_Relu object.

Parameters
[in] fNode  Python dictionary describing a PyTorch ONNX graph node
Returns
Unique pointer to ROperator object

To instantiate a ROperator_Relu object, the names of the input and output tensors and the data type of the graph node are extracted.
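For reference, the operator being prepared applies ONNX Relu element-wise, y = max(0, x). The function below is a minimal sketch of those semantics on a std::vector<float>; it is not the code SOFIE generates, and the name Relu is illustrative.

```cpp
#include <algorithm>
#include <vector>

// Element-wise Relu: y = max(0, x).
std::vector<float> Relu(const std::vector<float>& x) {
    std::vector<float> y(x.size());
    std::transform(x.begin(), x.end(), y.begin(),
                   [](float v) { return std::max(0.0f, v); });
    return y;
}
```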

Definition at line 160 of file RModelParser_PyTorch.cxx.

◆ MakePyTorchSelu()

std::unique_ptr< ROperator > TMVA::Experimental::SOFIE::PyTorch::INTERNAL::MakePyTorchSelu (PyObject *fNode)

Prepares a ROperator_Selu object.

Parameters
[in] fNode  Python dictionary describing a PyTorch ONNX graph node
Returns
Unique pointer to ROperator object

To instantiate a ROperator_Selu object, the names of the input and output tensors and the data type of the graph node are extracted.
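The Selu operator follows the ONNX definition y = gamma * x for x > 0 and y = gamma * alpha * (exp(x) - 1) otherwise, with ONNX default values alpha of about 1.6733 and gamma of about 1.0507. The sketch below evaluates that formula element-wise using those defaults; it illustrates the semantics only, is not SOFIE's implementation, and this page does not say whether non-default alpha or gamma values are read from the node.

```cpp
#include <cmath>
#include <vector>

// Element-wise Selu with the ONNX default attributes:
//   y = gamma * x                      for x > 0
//   y = gamma * alpha * (exp(x) - 1)   for x <= 0
std::vector<float> Selu(const std::vector<float>& x,
                        float alpha = 1.67326319217681884765625f,
                        float gamma = 1.05070102214813232421875f) {
    std::vector<float> y;
    y.reserve(x.size());
    for (float v : x)
        y.push_back(v > 0.0f ? gamma * v : gamma * alpha * (std::exp(v) - 1.0f));
    return y;
}
```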

Definition at line 187 of file RModelParser_PyTorch.cxx.

◆ MakePyTorchSigmoid()

std::unique_ptr< ROperator > TMVA::Experimental::SOFIE::PyTorch::INTERNAL::MakePyTorchSigmoid (PyObject *fNode)

Prepares a ROperator_Sigmoid object.

Parameters
[in] fNode  Python dictionary describing a PyTorch ONNX graph node
Returns
Unique pointer to ROperator object

To instantiate a ROperator_Sigmoid object, the names of the input and output tensors and the data type of the graph node are extracted.
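The Sigmoid operator applies the logistic function element-wise, y = 1 / (1 + exp(-x)). A minimal sketch of that formula follows; like the other examples on this page, it illustrates the operator's semantics rather than SOFIE's generated code.

```cpp
#include <cmath>
#include <vector>

// Element-wise logistic sigmoid: y = 1 / (1 + exp(-x)).
std::vector<float> Sigmoid(const std::vector<float>& x) {
    std::vector<float> y;
    y.reserve(x.size());
    for (float v : x) y.push_back(1.0f / (1.0f + std::exp(-v)));
    return y;
}
```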

Definition at line 213 of file RModelParser_PyTorch.cxx.

◆ MakePyTorchTranspose()

std::unique_ptr< ROperator > TMVA::Experimental::SOFIE::PyTorch::INTERNAL::MakePyTorchTranspose (PyObject *fNode)

Prepares a ROperator_Transpose object.

Parameters
[in] fNode  Python dictionary describing a PyTorch ONNX graph node
Returns
Unique pointer to ROperator object

For the Transpose operator of PyTorch's ONNX graph, the permutation of the dimensions (the ONNX perm attribute) is extracted and passed to the ROperator_Transpose constructor.
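In ONNX, the perm attribute means that output axis i takes input axis perm[i], so the output shape is inShape[perm[i]]. The row-major sketch below applies such a permutation to a flat buffer, which is the data movement a Transpose operator performs. Transpose and Strides are hypothetical helpers written for this illustration, not SOFIE functions.

```cpp
#include <cstddef>
#include <stdexcept>
#include <vector>

// Row-major strides: stride[i] is the product of shape[i+1 .. rank-1].
std::vector<std::size_t> Strides(const std::vector<std::size_t>& shape) {
    std::vector<std::size_t> s(shape.size(), 1);
    for (std::size_t i = shape.size(); i-- > 1;) s[i - 1] = s[i] * shape[i];
    return s;
}

// Transpose a flat row-major buffer: output axis i is input axis perm[i],
// so outShape[i] = inShape[perm[i]].
std::vector<float> Transpose(const std::vector<float>& in,
                             const std::vector<std::size_t>& inShape,
                             const std::vector<std::size_t>& perm,
                             std::vector<std::size_t>& outShape) {
    const std::size_t rank = inShape.size();
    if (perm.size() != rank) throw std::invalid_argument("perm must have one entry per axis");
    std::vector<bool> seen(rank, false);
    for (std::size_t p : perm) {
        if (p >= rank || seen[p]) throw std::invalid_argument("perm is not a permutation");
        seen[p] = true;
    }
    std::size_t total = 1;
    for (std::size_t d : inShape) total *= d;
    if (in.size() != total) throw std::invalid_argument("buffer size does not match shape");

    outShape.assign(rank, 0);
    for (std::size_t i = 0; i < rank; ++i) outShape[i] = inShape[perm[i]];
    const std::vector<std::size_t> inStrides = Strides(inShape);

    std::vector<float> out(in.size());
    std::vector<std::size_t> idx(rank, 0);  // current multi-index in the output
    for (std::size_t flat = 0; flat < out.size(); ++flat) {
        std::size_t src = 0;
        for (std::size_t i = 0; i < rank; ++i) src += idx[i] * inStrides[perm[i]];
        out[flat] = in[src];
        // Advance the output multi-index in row-major order.
        for (std::size_t i = rank; i-- > 0;) {
            if (++idx[i] < outShape[i]) break;
            idx[i] = 0;
        }
    }
    return out;
}
```

For a 2 x 3 matrix {1, 2, 3, 4, 5, 6} with perm = {1, 0}, the sketch produces the ordinary 3 x 2 transpose {1, 4, 2, 5, 3, 6}.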

Definition at line 239 of file RModelParser_PyTorch.cxx.

Variable Documentation

◆ mapPyTorchNode

const PyTorchMethodMap TMVA::Experimental::SOFIE::PyTorch::INTERNAL::mapPyTorchNode
Maps each supported PyTorch ONNX node type to the preparatory function that builds the corresponding ROperator.

Initial value:
= {
    {"onnx::Gemm", &MakePyTorchGemm},
    {"onnx::Conv", &MakePyTorchConv},
    {"onnx::Relu", &MakePyTorchRelu},
    {"onnx::Selu", &MakePyTorchSelu},
    {"onnx::Sigmoid", &MakePyTorchSigmoid},
    {"onnx::Transpose", &MakePyTorchTranspose}
}

Definition at line 57 of file RModelParser_PyTorch.cxx.