98 std::ifstream file(
filePath, std::ios::binary);
100 std::ostringstream os;
101 os <<
"failed to open file '" <<
filePath <<
"'";
102 throw std::runtime_error(os.str());
105 file.seekg(0, std::ios::end);
106 const std::streamsize
size = file.tellg();
107 file.seekg(0, std::ios::beg);
110 std::ostringstream os;
111 os <<
"file '" <<
filePath <<
"' is empty";
112 throw std::runtime_error(os.str());
115 std::vector<std::uint8_t>
bytes(
static_cast<std::size_t
>(
size));
116 file.read(
reinterpret_cast<char *
>(
bytes.data()),
size);
119 std::ostringstream os;
120 os <<
"error while reading file '" <<
filePath <<
"'";
121 throw std::runtime_error(os.str());
127template <
typename Fn>
130 static Fn fn =
nullptr;
131 static std::once_flag
flag;
133 std::call_once(
flag, [&] {
136 throw std::runtime_error(std::string(
"ROOT JIT Declare failed for code defining ") + name);
143 throw std::runtime_error(std::string(
"ROOT JIT failed to resolve symbol: ") +
name);
149 throw std::runtime_error(std::string(
"ROOT JIT produced null function pointer for: ") +
name);
159 return TString::Format(
"reinterpret_cast<%s>(0x%zx)", (
castType +
"*").c_str(),
reinterpret_cast<std::size_t
>(ptr))
172 using Func = void (*)(
void *,
float const *,
float *);
197 const std::string &
onnxFile,
const std::vector<std::string> & ,
198 const std::vector<std::vector<int>> & )
202 std::string
istr = std::to_string(i);
204 std::make_unique<RooListProxy>((
"!inputs_" +
istr).c_str(), (
"Input tensor " +
istr).c_str(),
this));
212 for (std::size_t i = 0; i <
other._inputTensors.size(); ++i) {
213 _inputTensors.emplace_back(std::make_unique<RooListProxy>(
"!inputs",
this, *
other._inputTensors[i]));
235 _runtime = std::make_unique<RuntimeCache>();
240 throw std::runtime_error(
"RooONNXFunction: cannot load ONNX file since SOFIE ONNX parser is missing."
241 " Please build ROOT with tmva-sofie=ON.");
243 using OnnxToCpp = std::string (*)(std::uint8_t
const *, std::size_t,
const char *);
246#include "TMVA/RModelParser_ONNX.hxx"
248std::string _RooONNXFunction_onnxToCppWithSofie(std::uint8_t const *onnxBytes, std::size_t onnxBytesSize, const char *outputName)
250 namespace SOFIE = TMVA::Experimental::SOFIE;
252 std::string buffer{reinterpret_cast<const char *>(onnxBytes), onnxBytesSize};
253 std::istringstream stream{buffer};
255 SOFIE::RModel rmodel = SOFIE::RModelParser_ONNX{}.Parse(stream, outputName);
256 rmodel.SetOptimizationLevel(SOFIE::OptimizationLevel::kBasic);
257 rmodel.Generate(SOFIE::Options::kNoWeightFile);
259 std::stringstream ss{};
260 rmodel.PrintGenerated(ss);
265 static int counter = 0;
266 _funcName =
"roo_onnx_func_" + std::to_string(counter);
289namespace %%NAMESPACE%% {
291float roo_inner_wrapper(Session const &session, float const *input)
294 doInfer(session, input, &out);
298float roo_wrapper(Session const &session, float const *input)
300 return roo_inner_wrapper(session, input);
303} // namespace %%NAMESPACE%%
312 std::stringstream
ss2;
313 ss2 <<
"static_cast<void (*)(void *, float const *, float *)>(RooFit::Detail::doInferWithSessionVoidPtr<"
321 gInterpreter->Declare(
"#include <Math/CladDerivator.h>");
326namespace %%NAMESPACE%% {
328double roo_outer_wrapper(double const *input) {
329 auto &session = *)" +
331 float inputFlt[inputTensorDims[0].total_size()];
332 for (std::size_t i = 0; i < std::size(inputFlt); ++i) {
333 inputFlt[i] = input[i];
335 return roo_inner_wrapper(session, inputFlt);
338} // namespace %%NAMESPACE%%
340namespace clad::custom_derivatives {
342namespace %%NAMESPACE%% {
344void roo_outer_wrapper_pullback(double const *input, double d_y, double *d_input) {
346 using namespace ::%%NAMESPACE%%;
348 float inputFlt[inputTensorDims[0].total_size()];
349 float d_inputFlt[::std::size(inputFlt)];
350 for (::std::size_t i = 0; i < ::std::size(inputFlt); ++i) {
351 inputFlt[i] = input[i];
352 d_inputFlt[i] = d_input[i];
356 auto *d_session = )" +
358 roo_inner_wrapper_pullback(*session, inputFlt, d_y, d_session, d_inputFlt);
359 for (::std::size_t i = 0; i < ::std::size(inputFlt); ++i) {
360 d_input[i] += d_inputFlt[i];
364} // namespace %%NAMESPACE%%
366} // namespace clad::custom_derivatives
378 return static_cast<double>(out);
size_t size(const MatrixT &matrix)
retrieve the size of a square matrix
ROOT::Detail::TRangeCast< T, true > TRangeDynCast
TRangeDynCast is an adapter class that allows the typed iteration through a TCollection.
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void char Point_t Rectangle_t WindowAttributes_t Float_t Float_t Float_t Int_t Int_t UInt_t UInt_t Rectangle_t Int_t Int_t Window_t TString Int_t GCValues_t GetPrimarySelectionOwner GetDisplay GetScreen GetColormap GetNativeEvent const char const char dpyName wid window const char font_name cursor keysym reg const char only_if_exist regb h Point_t winding char text const char depth char const char Int_t count const char ColorStruct_t color const char Pixmap_t Pixmap_t PictureAttributes_t attr const char char ret_data h unsigned char height h Atom_t Int_t ULong_t ULong_t bytes
R__EXTERN TSystem * gSystem
Abstract base class for objects that represent a real value and implements functionality common to al...
RooONNXFunction wraps an ONNX model as a RooAbsReal, allowing it to be used as a building block in li...
std::shared_ptr< RuntimeCache > _runtime
! Transient runtime information.
void initialize() const
Build transient runtime backend on first use.
std::vector< std::unique_ptr< RooListProxy > > _inputTensors
Inputs mapping to flattened input tensors.
RooONNXFunction()=default
std::vector< std::uint8_t > _onnxBytes
Persisted ONNX model bytes.
double evaluate() const override
Evaluate this PDF / function / constant. Needs to be overridden by all derived classes.
std::vector< float > _inputBuffer
!
void fillInputBuffer() const
Gather current RooFit inputs into a contiguous feature buffer.
static TString Format(const char *fmt,...)
Static method which formats a string using a printf style format descriptor and return a TString.
virtual int Load(const char *module, const char *entry="", Bool_t system=kFALSE)
Load a shared library.
RooFit::Detail::AnyWithVoidPtr _session
RooFit::Detail::AnyWithVoidPtr _d_session