Logo ROOT  
Reference Guide
 
Loading...
Searching...
No Matches
RooONNXFunction.h
Go to the documentation of this file.
1/*
2 * Project: RooFit
3 * Authors:
4 * Jonas Rembser, CERN 04/2026
5 *
6 * Copyright (c) 2026, CERN
7 *
8 * Redistribution and use in source and binary forms,
9 * with or without modification, are permitted according to the terms
10 * listed in LICENSE (http://roofit.sourceforge.net/license.txt)
11 */
12
13#ifndef RooFit_RooONNXFunction_h
14#define RooFit_RooONNXFunction_h
15
16#include <RooAbsReal.h>
17#include <RooListProxy.h>
18
19#include <any>
20
public:
   /// Default constructor. All members are default-initialized, so the object
   /// has no input tensors and no model bytes.
   RooONNXFunction() = default;

   /// Construct a function that evaluates an ONNX model on RooFit inputs.
   /// \param[in] name Name of the object.
   /// \param[in] title Title of the object.
   /// \param[in] inputTensors One RooArgList per model input tensor; each list
   ///            is mapped to one flattened input tensor.
   /// \param[in] onnxFile Path to the ONNX model file.
   ///            NOTE(review): presumably the file content ends up in _onnxBytes
   ///            (the persisted model bytes) — confirm in the implementation.
   /// \param[in] inputNames Optional names of the model inputs (default: empty).
   /// \param[in] inputShapes Optional shapes of the model inputs (default: empty).
   ///            NOTE(review): how empty defaults are interpreted is decided in the
   ///            implementation file — confirm there.
   RooONNXFunction(const char *name, const char *title, const std::vector<RooArgList> &inputTensors,
                   const std::string &onnxFile, const std::vector<std::string> &inputNames = {},
                   const std::vector<std::vector<int>> &inputShapes = {});

   /// Copy constructor.
   /// \param[in] other Object to copy.
   /// \param[in] newName Optional name for the copy (default: nullptr).
   RooONNXFunction(const RooONNXFunction &other, const char *newName = nullptr);
30
31 TObject *clone(const char *newName) const override { return new RooONNXFunction(*this, newName); }
32
33 std::size_t nInputTensors() const { return _inputTensors.size(); }
34 RooArgList const &inputTensorList(int iTensor) const { return *(_inputTensors[iTensor]); }
35
36 std::string funcName() const
37 {
38 initialize();
39 return _funcName;
40 }
41 std::string outerWrapperName() const { return "TMVA_SOFIE_" + funcName() + "::roo_outer_wrapper"; }
42
protected:
   /// Evaluate the wrapped model for the current values of the RooFit inputs.
   double evaluate() const override;

private:
   /// Build transient runtime backend on first use.
   void initialize() const;

   /// Gather current RooFit inputs into a contiguous feature buffer.
   void fillInputBuffer() const;

   /// Runtime state; only forward-declared here, its definition is not in this header.
   struct RuntimeCache;

   // A member comment starting with "///<!" marks the member as transient for ROOT I/O
   // (not persisted); keep the '!' right after "///<".
   std::vector<std::unique_ptr<RooListProxy>> _inputTensors; ///< Inputs mapping to flattened input tensors.
   std::vector<std::uint8_t> _onnxBytes; ///< Persisted ONNX model bytes.
   mutable std::shared_ptr<RuntimeCache> _runtime; ///<! Transient runtime information.
   mutable std::vector<float> _inputBuffer; ///<! Transient contiguous input-feature buffer (presumably filled by fillInputBuffer()).
   mutable std::string _funcName; ///<! Transient cache of the name returned by funcName().
60
62};
63
64namespace RooFit::Detail {
65
67 std::any any;
68 void *ptr = nullptr;
69
70 template <class T>
71 void emplace()
72 {
73 any = std::make_any<T>();
74 ptr = std::any_cast<T>(&any);
75 }
76
77 void emplace(std::string const &typeName);
78};
79
/// Type-erased inference entry point taking the session as an untyped pointer.
/// \tparam Session_t Concrete type of the object `session` points to.
/// \param[in] session Pointer to a live Session_t object; it must not be null
///            and must really point to a Session_t.
/// \param[in] input Pointer to the input values passed on to doInfer().
/// \param[out] out Pointer to the output buffer passed on to doInfer().
/// The call goes to an unqualified `doInfer(Session_t &, float const *, float *)`.
/// That overload is not declared in this header; presumably it is found through
/// argument-dependent lookup when the template is instantiated.
template <class Session_t>
void doInferWithSessionVoidPtr(void *session, float const *input, float *out)
{
   // Converting void* back to the original object pointer type is exactly what
   // static_cast is for; reinterpret_cast is not needed here.
   doInfer(*static_cast<Session_t *>(session), input, out);
}
85
86} // namespace RooFit::Detail
87
88#endif
#define ClassDefOverride(name, id)
Definition Rtypes.h:348
ROOT::Detail::TRangeCast< T, true > TRangeDynCast
TRangeDynCast is an adapter class that allows the typed iteration through a TCollection.
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void input
char name[80]
Definition TGX11.cxx:145
Abstract base class for objects that represent a real value and implements functionality common to al...
Definition RooAbsReal.h:63
RooArgList is a container object that can hold multiple RooAbsArg objects.
Definition RooArgList.h:22
RooONNXFunction wraps an ONNX model as a RooAbsReal, allowing it to be used as a building block in li...
std::shared_ptr< RuntimeCache > _runtime
Transient runtime information (not persisted).
TObject * clone(const char *newName) const override
std::string _funcName
Transient (not persisted) cache of the name returned by funcName().
RooArgList const & inputTensorList(int iTensor) const
std::size_t nInputTensors() const
void initialize() const
Build transient runtime backend on first use.
std::vector< std::unique_ptr< RooListProxy > > _inputTensors
Inputs mapping to flattened input tensors.
std::string outerWrapperName() const
RooONNXFunction()=default
std::string funcName() const
std::vector< std::uint8_t > _onnxBytes
Persisted ONNX model bytes.
double evaluate() const override
Evaluate this PDF / function / constant. Needs to be overridden by all derived classes.
std::vector< float > _inputBuffer
Transient (not persisted) buffer holding the contiguous input features.
void fillInputBuffer() const
Gather current RooFit inputs into a contiguous feature buffer.
Mother of all ROOT objects.
Definition TObject.h:42
void doInferWithSessionVoidPtr(void *session, float const *input, float *out)