Logo ROOT  
Reference Guide
 
Loading...
Searching...
No Matches
TMVA_SOFIE_ONNX.C
Go to the documentation of this file.
/// \file
/// \ingroup tutorial_tmva
/// \notebook -nodraw
/// This macro provides a simple example of parsing an ONNX file into an
/// RModel object and then generating the .hxx header file for inference.
///
/// \macro_code
/// \macro_output
/// \author Sanjiban Sengupta

11using namespace TMVA::Experimental;
12
13void TMVA_SOFIE_ONNX(std::string inputFile = ""){
14 if (inputFile.empty() )
15 inputFile = std::string(gROOT->GetTutorialsDir()) + "/tmva/Linear_16.onnx";
16
17 //Creating parser object to parse ONNX files
19 SOFIE::RModel model = parser.Parse(inputFile);
20
21 //Generating inference code
22 model.Generate();
23 // write the code in a file (by default Linear_16.hxx and Linear_16.dat
24 model.OutputGenerated();
25
26 //Printing required input tensors
28
29 //Printing initialized tensors (weights)
30 std::cout<<"\n\n";
32
33 //Printing intermediate tensors
34 std::cout<<"\n\n";
36
37 //Checking if tensor already exist in model
38 std::cout<<"\n\nTensor \"16weight\" already exist: "<<std::boolalpha<<model.CheckIfTensorAlreadyExist("16weight")<<"\n\n";
39 std::vector<size_t> tensorShape = model.GetTensorShape("16weight");
40 std::cout<<"Shape of tensor \"16weight\": ";
41 for(auto& it:tensorShape){
42 std::cout<<it<<",";
43 }
44 std::cout<<"\n\nData type of tensor \"16weight\": ";
45 SOFIE::ETensorType tensorType = model.GetTensorType("16weight");
46 std::cout<<SOFIE::ConvertTypeToString(tensorType);
47
48 //Printing generated inference code
49 std::cout<<"\n\n";
50 model.PrintGenerated();
51}
Symbol references:
- `#define gROOT` — defined in TROOT.h:404
- `const ETensorType & GetTensorType(std::string name)` — defined in RModel.cxx:70
- `void Generate(bool useSession=true, bool useWeightFile=true)` — defined in RModel.cxx:175
- `bool CheckIfTensorAlreadyExist(std::string tensor_name)` — defined in RModel.cxx:91
- `void OutputGenerated(std::string filename="")` — defined in RModel.cxx:525
- `const std::vector< size_t > & GetTensorShape(std::string name)` — defined in RModel.cxx:49