Logo ROOT  
Reference Guide
 
Loading...
Searching...
No Matches
RFunction_MLP.cxx
Go to the documentation of this file.
2
3
9
10namespace TMVA {
11namespace Experimental {
12namespace SOFIE {
13
// Constructor: forwards `target` and `gType` to the RFunction_Update base, stores the layer
// count, the activation choice and the "activate final layer" flag, then registers the name of
// the tensor this block outputs on function_block.
// NOTE(review): this listing has lost original source line 16 -- the condition that guards the
// throw below is not visible here. Confirm it against the real file before editing.
14RFunction_MLP::RFunction_MLP(FunctionTarget target, Int_t numLayers, Activation activation_function, bool activate_final, GraphType gType):
15 RFunction_Update(target, gType), fNumLayers(numLayers), fActivationFunction(activation_function), fActivateFinal(activate_final) {
// Rejects an activation function this class does not handle (guard line missing, see note above).
17 throw std::runtime_error("TMVA SOFIE GNN doesn't currently supports the provided activation function for " + fFuncName + " update.");
18 }
19
20 // Assumes every linear layer has an initialized kernel tensor and an initialized bias tensor.
// The output name is suffixed with fNumLayers: "Relu" when the final layer is activated,
// otherwise "Gemm".
21 if(fActivateFinal) {
22 function_block->AddOutputTensorNameList({fFuncName+"Relu"+std::to_string(fNumLayers)});
23 } else {
24 function_block->AddOutputTensorNameList({fFuncName+"Gemm"+std::to_string(fNumLayers)});
25 }
26}
27
29
// Body that appends the chain of operators (optional Concat, Gemm per layer, optional Relu,
// plus any queued extra operators) to function_block.
// NOTE(review): the signature line (original line 28) is missing from this listing; presumably
// this is RFunction_MLP::Initialize -- confirm. Original lines 31, 37, 46 and 58 (branch
// headers) are missing as well, so the conditions of the branches below are not visible.
// Name of the tensor fed into the next Gemm operator; updated as operators are appended.
30 std::string fGemmInput;
// Branch 1 (condition not visible): concatenate all input tensors into
// fFuncName+"InputConcat" and use that as the first Gemm input.
32 std::unique_ptr<ROperator> op_concat;
33 op_concat.reset(new ROperator_Concat<float>(fInputTensors,1,0,fFuncName+"InputConcat"));
34 function_block->AddOperator(std::move(op_concat));
35 fGemmInput = fFuncName+"InputConcat";
36
// Branch 2 (condition not visible): use the first input tensor directly.
38 fGemmInput = fInputTensors[0];
39 }
40
// One Gemm per layer except the last, using kernel/bias tensor i; the output tensor is named
// fFuncName+"Gemm"+std::to_string(i).
41 std::unique_ptr<ROperator> op_gemm;
42 for(int i=0; i<fNumLayers-1; ++i) {
43 op_gemm.reset(new ROperator_Gemm<float>(1.0,1.0,0,0,fGemmInput,UTILITY::Clean_name(fKernelTensors[i]),UTILITY::Clean_name(fBiasTensors[i]),fFuncName+"Gemm"+std::to_string(i)));
44 function_block->AddOperator(std::move(op_gemm));
// NOTE(review): this uses `+i`, not `+std::to_string(i)` as in the operator's output name
// above, so the name stored here probably does not match that output tensor -- confirm and fix.
45 fGemmInput = fFuncName+"Gemm"+i;
// Activation step (guard on original line 46 not visible): Relu from Gemm i to Relu i.
47 std::unique_ptr<ROperator> op_relu;
48 op_relu.reset(new ROperator_Relu<float>(fFuncName+"Gemm"+std::to_string(i), fFuncName+"Relu"+std::to_string(i)));
49 function_block->AddOperator(std::move(op_relu));
// NOTE(review): same `+i` versus `std::to_string(i)` mismatch as above.
50 fGemmInput = fFuncName+"Relu"+i;
51
52 }
53 }
54
// Final layer: uses the last kernel/bias tensors and is suffixed with fNumLayers (not
// fNumLayers-1), matching the output tensor names passed to AddOutputTensorNameList.
55 op_gemm.reset(new ROperator_Gemm<float>(1.0,1.0,0,0,fGemmInput,UTILITY::Clean_name(fKernelTensors.back()),UTILITY::Clean_name(fBiasTensors.back()),fFuncName+"Gemm"+std::to_string(fNumLayers)));
56 function_block->AddOperator(std::move(op_gemm));
// Optional Relu on the final Gemm output (inner guard on original line 58 not visible).
57 if(fActivateFinal) {
59 std::unique_ptr<ROperator> op_relu;
60 op_relu.reset(new ROperator_Relu<float>(fFuncName+"Gemm"+std::to_string(fNumLayers), fFuncName+"Relu"+std::to_string(fNumLayers)));
61 function_block->AddOperator(std::move(op_relu));
62 }
63 }
64
65
// Hand every queued extra operator to function_block. Each stored pointer is wrapped in a
// unique_ptr, so ownership moves to function_block; the pointers are not removed from fAddlOp
// in the visible code, so running this body twice would wrap the same pointers again.
66 if(fAddlOp.size()) {
67 for(auto &i:fAddlOp) {
68 std::unique_ptr<ROperator> tmp(i);
69 function_block->AddOperator(std::move(tmp));
70 }
71 }
72}
73
74void RFunction_MLP::AddLayerNormalization(int axis, float epsilon, size_t stashType, const std::string &nameX,
75 const std::string &nameScale, const std::string &nameB, const std::string &nameY) {
76 auto op_layerNorm = new ROperator_LayerNormalization<float>(axis, epsilon, stashType, nameX,
77 nameScale, nameB, nameY, "", "");
78 fAddlOp.push_back((op_layerNorm));
79}
80
81}
82}
83}
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void char Point_t Rectangle_t WindowAttributes_t Float_t Float_t Float_t Int_t Int_t UInt_t UInt_t Rectangle_t Int_t Int_t Window_t TString Int_t GCValues_t GetPrimarySelectionOwner GetDisplay GetScreen GetColormap GetNativeEvent const char const char dpyName wid window const char font_name cursor keysym reg const char only_if_exist regb h Point_t winding char text const char depth char const char Int_t count const char ColorStruct_t color const char Pixmap_t Pixmap_t PictureAttributes_t attr const char char ret_data h unsigned char height h Atom_t Int_t ULong_t ULong_t unsigned char prop_list Atom_t Atom_t target
std::vector< std::string > fKernelTensors
RFunction_MLP(FunctionTarget target, Int_t numLayers, Activation activation_function=Activation::RELU, bool activate_final=false, GraphType gType=GraphType::GNN)
void AddLayerNormalization(int axis, float epsilon, size_t stashType, const std::string &nameX, const std::string &nameScale, const std::string &nameB, const std::string &nameY)
std::vector< std::string > fBiasTensors
std::shared_ptr< RModel > function_block
Definition RFunction.hxx:35
std::vector< std::string > fInputTensors
Definition RFunction.hxx:38
std::string Clean_name(std::string input_tensor_name)
create variable transformations
static int gType
double epsilon
Definition triangle.c:618