Logo ROOT  
Reference Guide
 
Loading...
Searching...
No Matches
RModelParser_PyTorch.h
Go to the documentation of this file.
// @(#)root/tmva/pymva $Id$
// Author: Sanjiban Sengupta, 2021

/**********************************************************************************
 * Project: TMVA - a Root-integrated toolkit for multivariate data analysis       *
 * Package: TMVA                                                                  *
 *                                                                                *
 *                                                                                *
 * Description:                                                                   *
 *      Functionality for parsing a saved PyTorch .pt model into an RModel object *
 *                                                                                *
 * Authors (alphabetical):                                                        *
 *      Sanjiban Sengupta <sanjiban.sg@gmail.com>                                 *
 *                                                                                *
 * Copyright (c) 2021:                                                            *
 *      CERN, Switzerland                                                         *
 *                                                                                *
 *                                                                                *
 * Redistribution and use in source and binary forms, with or without             *
 * modification, are permitted according to the terms listed in LICENSE           *
 * (see tmva/doc/LICENSE)                                                         *
 **********************************************************************************/


25#ifndef TMVA_SOFIE_RMODELPARSER_PYTORCH
26#define TMVA_SOFIE_RMODELPARSER_PYTORCH
27
28#include "TMVA/RModel.hxx"
29#include "TMVA/SOFIE_common.hxx"
30#include "TMVA/Types.h"
31#include "TMVA/OperatorList.hxx"
32
33#include "TMVA/PyMethodBase.h"
34
35#include "Rtypes.h"
36#include "TString.h"
37
38
39namespace TMVA{
40namespace Experimental{
41namespace SOFIE{
42namespace PyTorch{
43
44/// Parser function for translating PyTorch .pt model into a RModel object.
45/// Accepts the file location of a PyTorch model, shapes and data-types of input tensors
46/// and returns the equivalent RModel object.
47RModel Parse(std::string filepath,std::vector<std::vector<size_t>> inputShapes, std::vector<ETensorType> dtype);
48
49/// Overloaded Parser function for translating PyTorch .pt model into a RModel object.
50/// Accepts the file location of a PyTorch model and the shapes of input tensors.
51/// Builds the vector of data-types for input tensors and calls the `Parse()` function to
52/// return the equivalent RModel object.
53RModel Parse(std::string filepath,std::vector<std::vector<size_t>> inputShapes);
54
55}//PyTorch
56}//SOFIE
57}//Experimental
58}//TMVA
59
60#endif //TMVA_PYMVA_RMODELPARSER_PYTORCH
RModel Parse(std::string filepath, std::vector< std::vector< size_t > > inputShapes, std::vector< ETensorType > dtype)
Parser function for translating PyTorch .pt model into a RModel object.
create variable transformations