Logo ROOT  
Reference Guide
 
Loading...
Searching...
No Matches
ROperator_Split.hxx
Go to the documentation of this file.
1#ifndef TMVA_SOFIE_ROPERATOR_Swish
2#define TMVA_SOFIE_ROPERATOR_Swish
3
5#include "TMVA/ROperator.hxx"
6#include "TMVA/RModel.hxx"
7
8#include <sstream>
9
10namespace TMVA{
11namespace Experimental{
12namespace SOFIE{
13
14template <typename T>
15class ROperator_Split final : public ROperator
16{
17
18private:
19
20 std::string fNX;
21 std::string fNS;
22 std::vector<std::string> fNYs;
23 std::vector<std::vector<size_t>> fOutputShapes;
24
25
26public:
28 ROperator_Split(const std::string & nameX, const std::string & nameS, const std::vector<std::string> & namesY):
29 fNX(UTILITY::Clean_name(nameX)), fNS(UTILITY::Clean_name(nameS)){
30 fNYs.reserve(namesY.size());
31 for (auto & name : namesY)
32 fNYs.push_back(UTILITY::Clean_name(name));
33 }
34
35 std::vector<ETensorType> TypeInference(std::vector<ETensorType> input){
36 return input;
37 }
38
39 std::vector<std::vector<size_t>> ShapeInference(std::vector<std::vector<size_t>> input){
40 auto ret = input; //suggest copy to compiler
41 return ret;
42 }
43
44 void Initialize(RModel& model){
45 if (model.CheckIfTensorAlreadyExist(fNX) == false){ //input must be a graph input, or already initialized intermediate tensor
46 throw std::runtime_error("TMVA SOFIE Split Op Input Tensor is not found in model");
47 }
48 auto inputShape = model.GetTensorShape(fNX);
49
50 // support now splitting only of 1D tensors and assuming tensor can be split in equal parts
51 //int splitAxis = 0; // assume split with zero axis
52 int nsplit = fNYs.size();
53 // support now only 1D tensor
54 if (inputShape.size() > 1)
55 throw std::runtime_error("TMVA SOFIE Split Op supports now only 1D tensors");
56 // support only equal splits
57 if (inputShape[0] % nsplit != 0)
58 throw std::runtime_error("TMVA SOFIE Split Op does not support splitting of " + ConvertShapeToString(inputShape)
59 + " into " + std::to_string(nsplit));
60
61 for (size_t i = 0; i < fNYs.size(); i++) {
62 std::vector<size_t> outputShape = { inputShape[0]/nsplit };
63 model.AddIntermediateTensor(fNYs[i], model.GetTensorType(fNX), outputShape);
64 fOutputShapes.push_back(outputShape); // need for generating code
65 }
66 if (model.Verbose()) {
67 std::cout << "Split - input shape " << ConvertShapeToString(inputShape) << " --> ";
68 for (auto & s : fOutputShapes)
69 std::cout << ConvertShapeToString(s) << " ";
70 std::cout << std::endl;
71 }
72 }
73
74
75 std::string Generate(std::string OpName){
76 OpName = "op_" + OpName;
77 if (fOutputShapes.empty()){
78 throw std::runtime_error("TMVA SOFIE Operator Split called to Generate without being initialized first");
79 }
80 std::stringstream out;
81 out << "\n//------ Split\n";
82 out << "size_t offset = 0;\n";
83 for (size_t i = 0; i < fNYs.size(); i++) {
85 out << SP << "for (int id = 0; id < " << length << " ; id++){\n";
86 out << SP << SP << "tensor_" << fNYs[i] << "[id] = tensor_" << fNX <<"[offset+id];\n";
87 out << SP << "}\n";
88 if (i < fNYs.size()-1) out << SP << "offset += " << length << ";\n";
89 }
90 return out.str();
91 }
92
93};
94
95}//SOFIE
96}//Experimental
97}//TMVA
98
99
100#endif //TMVA_SOFIE_ROPERATOR_Swish
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void input
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void char Point_t Rectangle_t WindowAttributes_t Float_t Float_t Float_t Int_t Int_t UInt_t UInt_t Rectangle_t Int_t Int_t Window_t TString Int_t GCValues_t GetPrimarySelectionOwner GetDisplay GetScreen GetColormap GetNativeEvent const char const char dpyName wid window const char font_name cursor keysym reg const char only_if_exist regb h Point_t winding char text const char depth char const char Int_t count const char ColorStruct_t color const char Pixmap_t Pixmap_t PictureAttributes_t attr const char char ret_data h unsigned char height h length
char name[80]
Definition TGX11.cxx:110
const ETensorType & GetTensorType(std::string name)
Definition RModel.cxx:94
void AddIntermediateTensor(std::string tensor_name, ETensorType type, std::vector< Dim > dim_shape)
Definition RModel.cxx:203
bool CheckIfTensorAlreadyExist(std::string tensor_name)
Definition RModel.cxx:122
const std::vector< size_t > & GetTensorShape(std::string name)
Definition RModel.cxx:56
std::vector< std::vector< size_t > > fOutputShapes
std::vector< std::vector< size_t > > ShapeInference(std::vector< std::vector< size_t > > input)
std::vector< ETensorType > TypeInference(std::vector< ETensorType > input)
ROperator_Split(const std::string &nameX, const std::string &nameS, const std::vector< std::string > &namesY)
std::string Generate(std::string OpName)
const std::string SP
space used to correctly indent the generated C++ code
Definition ROperator.hxx:41
std::string Clean_name(std::string input_tensor_name)
std::string ConvertShapeToString(std::vector< size_t > shape)
std::size_t ConvertShapeToLength(std::vector< size_t > shape)
create variable transformations