Logo ROOT  
Reference Guide
 
Loading...
Searching...
No Matches
ParseConcat.cxx
Go to the documentation of this file.
#include "onnx_proto3.pb.h"

#include <memory>
#include <stdexcept>
#include <string>
#include <vector>
5namespace TMVA {
6namespace Experimental {
7namespace SOFIE {
8
9ParserFuncSignature ParseConcat = [](RModelParser_ONNX &parser, const onnx::NodeProto &nodeproto) {
11 std::vector<std::string> inputs;
12 size_t size = nodeproto.input_size();
13 inputs.reserve(size);
14 for (int i = 0; i < nodeproto.input_size(); ++i) {
15 auto input_name = nodeproto.input(i);
16 if (parser.IsRegisteredTensorType(input_name)) {
17 if (i == 0)
18 input_type = parser.GetTensorType(input_name);
19 else
20 assert(parser.GetTensorType(input_name) == input_type);
21 } else {
22 throw std::runtime_error("TMVA::SOFIE ONNX Parser Concat op has input tensor" + input_name +
23 " but its type is not yet registered");
24 }
25 inputs.emplace_back(input_name);
26 }
27
28 std::unique_ptr<ROperator> op;
29 std::string output_name = nodeproto.output(0);
30
31 int attr_axis = 0;
32 int attr_new_axis = 0;
33 for (int_t i = 0; i < nodeproto.attribute_size(); i++) {
34 std::string attribute_name = nodeproto.attribute(i).name();
35 if (attribute_name == "axis")
36 attr_axis = nodeproto.attribute(i).i();
37 else if (attribute_name == "new_axis") // this is for ConcatFromSequence (that is equivalent to np.stack)
38 attr_new_axis = nodeproto.attribute(i).i();
39 }
40 switch (input_type) {
41 case ETensorType::FLOAT: op.reset(new ROperator_Concat<float>(inputs, attr_axis, attr_new_axis, output_name)); break;
42 default:
43 throw std::runtime_error("TMVA::SOFIE - Unsupported - Operator Concat does not yet support input type " +
44 std::to_string(static_cast<int>(input_type)));
45 }
46
47 if (!parser.IsRegisteredTensorType(output_name)) {
48 parser.RegisterTensorType(output_name, input_type);
49 }
50 return op;
51};
52
53} // namespace SOFIE
54} // namespace Experimental
55} // namespace TMVA
size_t size(const MatrixT &matrix)
retrieve the size of a square matrix
void RegisterTensorType(const std::string &, ETensorType)
ETensorType GetTensorType(const std::string &name)
ParserFuncSignature ParseConcat
std::function< std::unique_ptr< ROperator >(RModelParser_ONNX &, const onnx::NodeProto &)> ParserFuncSignature
create variable transformations