Logo ROOT  
Reference Guide
 
Loading...
Searching...
No Matches
ParseBasicNary.cxx
Go to the documentation of this file.
3#include "onnx_proto3.pb.h"
4#include <memory>
5
6namespace TMVA {
7namespace Experimental {
8namespace SOFIE {
9
10template<EBasicNaryOperator Op>
11std::unique_ptr<ROperator> ParseBasicNary(RModelParser_ONNX& parser, const onnx::NodeProto& nodeproto) {
13 std::vector<std::string> inputs;
14 size_t size = nodeproto.input_size();
15 inputs.reserve(size);
16 for (int i = 0; i < nodeproto.input_size(); ++i) {
17 auto input_name = nodeproto.input(i);
18 if (parser.IsRegisteredTensorType(input_name)) {
19 if (i == 0)
20 input_type = parser.GetTensorType(input_name);
21 else
22 assert(parser.GetTensorType(input_name) == input_type);
23 } else {
24 throw std::runtime_error("TMVA::SOFIE ONNX Parser Max op has input tensor" + input_name +
25 " but its type is not yet registered");
26 }
27 inputs.emplace_back(input_name);
28 }
29
30 std::unique_ptr<ROperator> op;
31 std::string output_name = nodeproto.output(0);
32
33 switch (input_type) {
34 case ETensorType::FLOAT: op.reset(new ROperator_BasicNary<float, Op>(inputs, output_name)); break;
35 default:
36 throw std::runtime_error("TMVA::SOFIE - Unsupported - Operator Max does not yet support input type " + ConvertTypeToString(input_type));
37 }
38
39 if (!parser.IsRegisteredTensorType(output_name)) {
40 parser.RegisterTensorType(output_name, input_type);
41 }
42
43 return op;
44}
45
46
47ParserFuncSignature ParseMax = [](RModelParser_ONNX &parser, const onnx::NodeProto &nodeproto) {
48 return ParseBasicNary<EBasicNaryOperator::Max>(parser, nodeproto);
49};
50
51ParserFuncSignature ParseMin= [](RModelParser_ONNX &parser, const onnx::NodeProto &nodeproto) {
52 return ParseBasicNary<EBasicNaryOperator::Min>(parser, nodeproto);
53};
54
55ParserFuncSignature ParseMean = [](RModelParser_ONNX &parser, const onnx::NodeProto &nodeproto) {
56 return ParseBasicNary<EBasicNaryOperator::Mean>(parser, nodeproto);
57};
58
59ParserFuncSignature ParseSum = [](RModelParser_ONNX &parser, const onnx::NodeProto &nodeproto) {
60 return ParseBasicNary<EBasicNaryOperator::Sum>(parser, nodeproto);
61};
62
63} // namespace SOFIE
64} // namespace Experimental
65} // namespace TMVA
size_t size(const MatrixT &matrix)
retrieve the size of a square matrix
void RegisterTensorType(const std::string &, ETensorType)
ETensorType GetTensorType(const std::string &name)
ParserFuncSignature ParseMax
std::unique_ptr< ROperator > ParseBasicNary(RModelParser_ONNX &parser, const onnx::NodeProto &nodeproto)
std::function< std::unique_ptr< ROperator >(RModelParser_ONNX &, const onnx::NodeProto &)> ParserFuncSignature
std::string ConvertTypeToString(ETensorType type)
ParserFuncSignature ParseMean
ParserFuncSignature ParseSum
ParserFuncSignature ParseMin
create variable transformations