Logo ROOT  
Reference Guide
Loading...
Searching...
No Matches
TMVA_SOFIE_PyTorch.C
Go to the documentation of this file.
/// \file
/// \ingroup tutorial_ml
/// \notebook -nodraw
/// This macro provides a simple example of parsing a PyTorch .pt file
/// into an RModel object and then generating the .hxx header file for inference.
///
/// \macro_code
/// \macro_output
/// \author Sanjiban Sengupta
10
11using namespace TMVA::Experimental;
12
13TString pythonSrc = "\
14import torch\n\
15import torch.nn as nn\n\
16\n\
17model = nn.Sequential(\n\
18 nn.Linear(32,16),\n\
19 nn.ReLU(),\n\
20 nn.Linear(16,8),\n\
21 nn.ReLU()\n\
22 )\n\
23\n\
24criterion = nn.MSELoss()\n\
25optimizer = torch.optim.SGD(model.parameters(),lr=0.01)\n\
26\n\
27x=torch.randn(2,32)\n\
28y=torch.randn(2,8)\n\
29\n\
30for i in range(500):\n\
31 y_pred = model(x)\n\
32 loss = criterion(y_pred,y)\n\
33 optimizer.zero_grad()\n\
34 loss.backward()\n\
35 optimizer.step()\n\
36\n\
37model.eval()\n\
38m = torch.jit.script(model)\n\
39torch.jit.save(m,'PyTorchModel.pt')\n";
40
41
42void TMVA_SOFIE_PyTorch(){
43
44 // Running the Python script to generate PyTorch .pt file
45
46 TMacro m;
47 m.AddLine(pythonSrc);
48 m.SaveSource("make_pytorch_model.py");
49 gSystem->Exec("python3 make_pytorch_model.py");
50
51 // Parsing a PyTorch model requires the shape and data-type of input tensor
52 // Data-type of input tensor defaults to Float if not specified
53 std::vector<size_t> inputTensorShapeSequential{2, 32};
54 std::vector<std::vector<size_t>> inputShapesSequential{inputTensorShapeSequential};
55
56 // Parsing the saved PyTorch .pt file into RModel object
58
59 // Generating inference code
60 model.Generate();
61 model.OutputGenerated("PyTorchModel.hxx");
62
63 // Printing required input tensors
64 std::cout << "\n\n";
65 model.PrintRequiredInputTensors();
66
67 // Printing initialized tensors (weights)
68 std::cout << "\n\n";
69 model.PrintInitializedTensors();
70
71 // Printing intermediate tensors
72 std::cout << "\n\n";
73 model.PrintIntermediateTensors();
74
75 // Checking if tensor already exist in model
76 std::cout << "\n\nTensor \"0weight\" already exist: " << std::boolalpha << model.CheckIfTensorAlreadyExist("0weight")
77 << "\n\n";
78 std::vector<size_t> tensorShape = model.GetTensorShape("0weight");
79 std::cout << "Shape of tensor \"0weight\": ";
80 for (auto &it : tensorShape) {
81 std::cout << it << ",";
82 }
83 std::cout<<"\n\nData type of tensor \"0weight\": ";
84 SOFIE::ETensorType tensorType = model.GetTensorType("0weight");
85 std::cout<<SOFIE::ConvertTypeToString(tensorType);
86
87 //Printing generated inference code
88 std::cout<<"\n\n";
89 model.PrintGenerated();
90}
extern TSystem *gSystem
Definition TSystem.h:582
Class supporting a collection of lines with C++ code.
Definition TMacro.h:31
Basic string class.
Definition TString.h:138
RModel Parse(std::string filepath, std::vector< std::vector< size_t > > inputShapes, std::vector< ETensorType > dtype)
Parser function for translating PyTorch .pt model into a RModel object.
std::string ConvertTypeToString(ETensorType type)
TMarker m
Definition textangle.C:8