/// Read the complete contents of a file into a byte vector.
///
/// \param[in] filePath Path of the file; it is opened in binary mode.
/// \return The file contents, one vector element per byte.
/// \throws std::runtime_error if the file cannot be opened, if its size is
///         reported as zero (or cannot be determined), or if the read fails.
std::vector<std::uint8_t> fileToBytes(std::string const &filePath)
{
   std::ifstream file(filePath, std::ios::binary);
   if (!file) {
      std::ostringstream os;
      os << "failed to open file '" << filePath << "'";
      throw std::runtime_error(os.str());
   }

   // Determine the size by seeking to the end, then rewind for the actual read.
   file.seekg(0, std::ios::end);
   const std::streamsize size = file.tellg();
   file.seekg(0, std::ios::beg);

   // tellg() yields -1 on failure; that case is reported together with the
   // genuinely empty file, since there is nothing to read either way.
   if (size <= 0) {
      std::ostringstream os;
      os << "file '" << filePath << "' is empty";
      throw std::runtime_error(os.str());
   }

   std::vector<std::uint8_t> bytes(static_cast<std::size_t>(size));
   file.read(reinterpret_cast<char *>(bytes.data()), size);

   // read() puts the stream in a failed state when fewer than `size` bytes
   // could be extracted, so a short read is caught here as well.
   if (!file) {
      std::ostringstream os;
      os << "error while reading file '" << filePath << "'";
      throw std::runtime_error(os.str());
   }

   return bytes;
}
128template <
typename Fn>
129Fn resolveLazy(std::string
const &
name,
const char *code)
131 static Fn fn =
nullptr;
132 static std::once_flag flag;
134 std::call_once(flag, [&] {
137 throw std::runtime_error(std::string(
"ROOT JIT Declare failed for code defining ") +
name);
141 void *symbol =
reinterpret_cast<void *
>(
gInterpreter->ProcessLine((
name +
";").c_str()));
144 throw std::runtime_error(std::string(
"ROOT JIT failed to resolve symbol: ") +
name);
147 fn =
reinterpret_cast<Fn
>(symbol);
150 throw std::runtime_error(std::string(
"ROOT JIT produced null function pointer for: ") +
name);
158std::string toPtrString(T *ptr, std::string
const &castType)
160 return TString::Format(
"reinterpret_cast<%s>(0x%zx)", (castType +
"*").c_str(),
reinterpret_cast<std::size_t
>(ptr))
/// Build the source-code expression for the offset of input tensor `i` inside
/// a flat input buffer: the sum of `inputTensorDims[j].total_size()` over all
/// tensors `j` that precede it.
///
/// \param[in] i Index of the input tensor.
/// \return The offset expression as a string; `"0"` for the first tensor.
///
/// NOTE(review): the zero-offset literal and the `" + "` separator were not
/// visible in the reviewed listing and are reconstructed from how the result
/// is spliced into `flat_input + (...)`; confirm against the upstream file.
std::string flatOffsetExpr(std::size_t i)
{
   // The first tensor starts at the beginning of the buffer; an empty string
   // here would make the generated `flat_input + ()` ill-formed.
   if (i == 0) {
      return "0";
   }

   std::string out;
   for (std::size_t j = 0; j < i; ++j) {
      if (j > 0) {
         out += " + ";
      }
      out += "inputTensorDims[" + std::to_string(j) + "].total_size()";
   }
   return out;
}
184 auto anyPtrSession = toPtrString(
this,
"RooFit::Detail::AnyWithVoidPtr");
185 gInterpreter->ProcessLine((anyPtrSession +
"->emplace<" + typeName +
">();").c_str());
191 using Func = void (*)(
void *,
float *,
float const *);
216 const std::string &onnxFile,
const std::vector<std::string> & ,
217 const std::vector<std::vector<int>> & )
222 for (std::size_t i = 0; i < inputTensors.size(); ++i) {
223 std::string istr = std::to_string(i);
225 std::make_unique<RooListProxy>((
"!inputs_" + istr).c_str(), (
"Input tensor " + istr).c_str(),
this));
233 for (std::size_t i = 0; i < other.
_inputTensors.size(); ++i) {
245 _inputBuffer.push_back(
static_cast<float>(real->getVal(tensorList->nset())));
256 _runtime = std::make_unique<RuntimeCache>();
260 if (
gSystem->Load(
"libROOTTMVASofieParser") < 0) {
261 throw std::runtime_error(
"RooONNXFunc: cannot load ONNX file since SOFIE ONNX parser is missing."
262 " Please build ROOT with tmva-sofie=ON.");
264 using OnnxToCpp = std::string (*)(std::uint8_t
const *, std::size_t,
const char *);
265 auto onnxToCppWithSofie = resolveLazy<OnnxToCpp>(
"_RooONNXFunc_onnxToCppWithSofie",
267#include "TMVA/RModelParser_ONNX.hxx"
269std::string _RooONNXFunc_onnxToCppWithSofie(std::uint8_t const *onnxBytes, std::size_t onnxBytesSize, const char *outputName)
271 namespace SOFIE = TMVA::Experimental::SOFIE;
273 std::string buffer{reinterpret_cast<const char *>(onnxBytes), onnxBytesSize};
274 std::istringstream stream{buffer};
276 SOFIE::RModel rmodel = SOFIE::RModelParser_ONNX{}.Parse(stream, outputName);
277 rmodel.SetOptimizationLevel(SOFIE::OptimizationLevel::kBasic);
278 rmodel.Generate(SOFIE::Options::kNoWeightFile);
280 std::stringstream ss{};
281 rmodel.PrintGenerated(ss);
286 static int counter = 0;
287 _funcName =
"roo_onnx_func_" + std::to_string(counter);
288 std::string namespaceName =
"TMVA_SOFIE_" +
_funcName +
"";
295 gInterpreter->ProcessLine((
"std::size(" + namespaceName +
"::inputTensorDims);").c_str()));
298 std::string innerParams;
299 std::string innerArgs;
300 std::string outerDoubleParams;
301 std::string cladInputs;
303 std::string istr = std::to_string(i);
307 outerDoubleParams +=
", ";
310 innerParams +=
"float const *input" + istr;
311 innerArgs +=
"input" + istr;
312 outerDoubleParams +=
"double const *input" + istr;
313 cladInputs +=
"input" + istr;
318 std::ostringstream ss;
319 ss <<
"namespace " << namespaceName <<
" {\n\n"
320 <<
"float roo_inner_wrapper(Session const &session, " << innerParams <<
") {\n"
321 <<
" float out = 0.;\n"
322 <<
" doInfer(session, " << innerArgs <<
", &out);\n"
325 <<
"float roo_wrapper(Session const &session, " << innerParams <<
") {\n"
326 <<
" return roo_inner_wrapper(session, " << innerArgs <<
");\n"
328 <<
"} // namespace " << namespaceName <<
"\n";
336 std::ostringstream ss;
337 ss <<
"namespace " << namespaceName <<
" {\n"
338 <<
"void roo_eval_thunk(void *session_void, float *out, float const *flat_input) {\n"
339 <<
" auto *session = reinterpret_cast<Session *>(session_void);\n"
340 <<
" doInfer(*session";
342 ss <<
", flat_input + (" << flatOffsetExpr(i) <<
")";
346 <<
"} // namespace " << namespaceName <<
"\n";
350 std::string sessionName =
"::TMVA_SOFIE_" +
_funcName +
"::Session";
352 _runtime->_session.emplace(sessionName);
353 auto ptrSession = toPtrString(
_runtime->_session.ptr, sessionName);
356 (
"static_cast<void(*)(void *, float *, float const *)>(" + namespaceName +
"::roo_eval_thunk);").c_str()));
359 _runtime->_d_session.emplace(sessionName);
360 auto ptrDSession = toPtrString(
_runtime->_d_session.ptr, sessionName);
362 gInterpreter->Declare(
"#include <Math/CladDerivator.h>");
364 gInterpreter->ProcessLine((
"clad::gradient(" + namespaceName +
"::roo_wrapper, \"" + cladInputs +
"\");").c_str());
370 std::ostringstream ss;
371 ss <<
"namespace " << namespaceName <<
" {\n\n"
372 <<
"double roo_outer_wrapper(" << outerDoubleParams <<
") {\n"
373 <<
" auto &session = *" << ptrSession <<
";\n";
375 ss <<
" float inputFlt" << i <<
"[inputTensorDims[" << i <<
"].total_size()];\n"
376 <<
" for (std::size_t i = 0; i < std::size(inputFlt" << i <<
"); ++i) {\n"
377 <<
" inputFlt" << i <<
"[i] = input" << i <<
"[i];\n"
380 ss <<
" return roo_inner_wrapper(session";
382 ss <<
", inputFlt" << i;
386 <<
"} // namespace " << namespaceName <<
"\n\n"
387 <<
"namespace clad::custom_derivatives {\n"
388 <<
"namespace " << namespaceName <<
" {\n\n"
389 <<
"void roo_outer_wrapper_pullback(" << outerDoubleParams <<
", double d_y";
391 ss <<
", double *d_input" << i;
394 <<
" using namespace ::" << namespaceName <<
";\n";
396 ss <<
" float inputFlt" << i <<
"[inputTensorDims[" << i <<
"].total_size()];\n"
397 <<
" float d_inputFlt" << i <<
"[::std::size(inputFlt" << i <<
")];\n"
398 <<
" for (::std::size_t i = 0; i < ::std::size(inputFlt" << i <<
"); ++i) {\n"
399 <<
" inputFlt" << i <<
"[i] = input" << i <<
"[i];\n"
400 <<
" d_inputFlt" << i <<
"[i] = 0;\n"
403 ss <<
" auto *session = " << ptrSession <<
";\n"
404 <<
" auto *d_session = " << ptrDSession <<
";\n"
405 <<
" roo_inner_wrapper_pullback(*session";
407 ss <<
", inputFlt" << i;
409 ss <<
", d_y, d_session";
411 ss <<
", d_inputFlt" << i;
415 ss <<
" for (::std::size_t i = 0; i < ::std::size(inputFlt" << i <<
"); ++i) {\n"
416 <<
" d_input" << i <<
"[i] += d_inputFlt" << i <<
"[i];\n"
420 <<
"} // namespace " << namespaceName <<
"\n"
421 <<
"} // namespace clad::custom_derivatives\n";
432 return static_cast<double>(out);
ROOT::RRangeCast< T, false, Range_t > static_range_cast(Range_t &&coll)
size_t size(const MatrixT &matrix)
retrieve the size of a square matrix
RooAbsReal()
coverity[UNINIT_CTOR] Default constructor
double evaluate() const override
Evaluate this PDF / function / constant. Needs to be overridden by all derived classes.
std::vector< std::uint8_t > _onnxBytes
Persisted ONNX model bytes.
std::shared_ptr< RuntimeCache > _runtime
! Transient runtime information.
std::vector< float > _inputBuffer
!
std::size_t nInputTensors() const
std::vector< std::unique_ptr< RooListProxy > > _inputTensors
Inputs mapping to flattened input tensors.
void Streamer(TBuffer &) override
Stream an object of class RooAbsArg.
void initialize()
Build transient runtime backend on first use.
void fillInputBuffer() const
Gather current RooFit inputs into a contiguous feature buffer.
Buffer base class used for serializing objects.
virtual Int_t ReadClassBuffer(const TClass *cl, void *pointer, const TClass *onfile_class=nullptr)=0
virtual Int_t WriteClassBuffer(const TClass *cl, void *pointer)=0
const char * Data() const
static TString Format(const char *fmt,...)
Static method which formats a string using a printf style format descriptor and return a TString.
RooFit::Detail::AnyWithVoidPtr _d_session
void(*)(void *, float *, float const *) Func
Uniform thunk signature regardless of input-tensor count.
RooFit::Detail::AnyWithVoidPtr _session