Logo ROOT  
Reference Guide
 
Loading...
Searching...
No Matches
TMVA_SOFIE_PyTorch.C
Go to the documentation of this file.
/// \file
/// \ingroup tutorial_tmva
/// \notebook -nodraw
/// This macro provides a simple example of parsing a PyTorch .pt file
/// into an RModel object and then generating the .hxx header file for inference.
///
/// \macro_code
/// \macro_output
/// \author Sanjiban Sengupta

11using namespace TMVA::Experimental;
12
13TString pythonSrc = "\
14import torch\n\
15import torch.nn as nn\n\
16\n\
17model = nn.Sequential(\n\
18 nn.Linear(32,16),\n\
19 nn.ReLU(),\n\
20 nn.Linear(16,8),\n\
21 nn.ReLU()\n\
22 )\n\
23\n\
24criterion = nn.MSELoss()\n\
25optimizer = torch.optim.SGD(model.parameters(),lr=0.01)\n\
26\n\
27x=torch.randn(2,32)\n\
28y=torch.randn(2,8)\n\
29\n\
30for i in range(500):\n\
31 y_pred = model(x)\n\
32 loss = criterion(y_pred,y)\n\
33 optimizer.zero_grad()\n\
34 loss.backward()\n\
35 optimizer.step()\n\
36\n\
37model.eval()\n\
38m = torch.jit.script(model)\n\
39torch.jit.save(m,'PyTorchModel.pt')\n";
40
41
42void TMVA_SOFIE_PyTorch(){
43
44 //Running the Python script to generate PyTorch .pt file
46
47 TMacro m;
48 m.AddLine(pythonSrc);
49 m.SaveSource("make_pytorch_model.py");
50 gSystem->Exec(TMVA::Python_Executable() + " make_pytorch_model.py");
51
52 //Parsing a PyTorch model requires the shape and data-type of input tensor
53 //Data-type of input tensor defaults to Float if not specified
54 std::vector<size_t> inputTensorShapeSequential{2,32};
55 std::vector<std::vector<size_t>> inputShapesSequential{inputTensorShapeSequential};
56
57 //Parsing the saved PyTorch .pt file into RModel object
58 SOFIE::RModel model = SOFIE::PyTorch::Parse("PyTorchModel.pt",inputShapesSequential);
59
60 //Generating inference code
61 model.Generate();
62 model.OutputGenerated("PyTorchModel.hxx");
63
64 //Printing required input tensors
65 std::cout<<"\n\n";
67
68 //Printing initialized tensors (weights)
69 std::cout<<"\n\n";
71
72 //Printing intermediate tensors
73 std::cout<<"\n\n";
75
76 //Checking if tensor already exist in model
77 std::cout<<"\n\nTensor \"0weight\" already exist: "<<std::boolalpha<<model.CheckIfTensorAlreadyExist("0weight")<<"\n\n";
78 std::vector<size_t> tensorShape = model.GetTensorShape("0weight");
79 std::cout<<"Shape of tensor \"0weight\": ";
80 for(auto& it:tensorShape){
81 std::cout<<it<<",";
82 }
83 std::cout<<"\n\nData type of tensor \"0weight\": ";
84 SOFIE::ETensorType tensorType = model.GetTensorType("0weight");
85 std::cout<<SOFIE::ConvertTypeToString(tensorType);
86
87 //Printing generated inference code
88 std::cout<<"\n\n";
89 model.PrintGenerated();
90}
R__EXTERN TSystem * gSystem
Definition TSystem.h:559
const ETensorType & GetTensorType(std::string name)
Definition RModel.cxx:70
void Generate(bool useSession=true, bool useWeightFile=true)
Definition RModel.cxx:175
bool CheckIfTensorAlreadyExist(std::string tensor_name)
Definition RModel.cxx:91
void OutputGenerated(std::string filename="")
Definition RModel.cxx:525
const std::vector< size_t > & GetTensorShape(std::string name)
Definition RModel.cxx:49
static void PyInitialize()
Initialize Python interpreter.
Class supporting a collection of lines with C++ code.
Definition TMacro.h:31
virtual TObjString * AddLine(const char *text)
Add line with text in the list of lines of this macro.
Definition TMacro.cxx:141
Basic string class.
Definition TString.h:136
virtual Int_t Exec(const char *shellcmd)
Execute a command.
Definition TSystem.cxx:656
TString Python_Executable()
Function to find current Python executable used by ROOT If Python2 is installed return "python" Inste...
auto * m
Definition textangle.C:8