35template <
class Value_t>
42 for (
int i = 1; i <
nOut; ++i) {
45 for (
int i = 0; i <
nOut; ++i) {
50 for (
int i = 0; i <
nOut; ++i) {
57inline bool isInteger(
const std::string &s)
59 if (s.empty() || ((!
isdigit(s[0])) && (s[0] !=
'-') && (s[0] !=
'+')))
68template <
class NumericType>
69struct NumericAfterSubstrOutput {
70 explicit NumericAfterSubstrOutput()
82template <
class NumericType>
89 std::size_t found = str.find(
substr);
90 if (found != std::string::npos) {
92 std::stringstream
ss(str.substr(found +
substr.size(), str.size() - found +
substr.size()));
112 const std::size_t
rows =
x.GetShape()[0];
113 const std::size_t
cols =
x.GetShape()[1];
131 std::size_t
nOut = fBaseResponses.size() > 2 ? fBaseResponses.size() : 1;
133 throw std::runtime_error(
134 "Error in RBDT::softmax : binary classification models don't support softmax evaluation. Plase set "
135 "the number of classes in the RBDT-creating function if this is a multiclassification model.");
138 for (std::size_t i = 0; i <
nOut; ++i) {
139 out[i] = fBaseScore + fBaseResponses[i];
143 for (
int index : fRootIndices) {
145 int r = fRightIndices[
index];
146 int l = fLeftIndices[
index];
158 std::size_t
nOut = fBaseResponses.size() > 2 ? fBaseResponses.size() : 1;
162 out[0] = EvaluateBinary(array);
164 out[0] = 1.0 / (1.0 + std::exp(-out[0]));
171 Value_t out = fBaseScore + fBaseResponses[0];
177 int r = fRightIndices[
index];
178 int l = fLeftIndices[
index];
181 out += fResponses[-
index];
194 for (
int &idx : indices) {
206 errMsg <<
"RBDT: something is wrong in the node structure - node with index " << idx <<
" doesn't exist";
207 throw std::runtime_error(
errMsg.str());
224 ff.fBaseResponses[
treeNumbers %
ff.fBaseResponses.size()] +=
ff.fResponses.back();
225 ff.fResponses.pop_back();
238 const std::string
info =
"constructing RBDT from " +
txtpath +
": ";
241 throw std::runtime_error(
info +
"file does not exists");
244 std::ifstream file(
txtpath.c_str());
251 const std::string
info =
"constructing RBDT from istream: ";
261 std::unordered_map<std::string, int>
varIndices;
280 while (std::getline(file,
line)) {
285 if (util::isInteger(
subline) && !
ff.fResponses.empty()) {
287 }
else if (!util::isInteger(
subline)) {
288 std::stringstream
ss(
line);
302 throw std::runtime_error(
info +
"feature " +
varName +
" not in list of features");
310 util::NumericAfterSubstrOutput<int>
output = util::numericAfterSubstr<int>(
line,
"yes=");
314 throw std::runtime_error(
info +
"problem while parsing the text dump");
316 output = util::numericAfterSubstr<int>(
output.rest,
"no=");
320 throw std::runtime_error(
info +
"problem while parsing the text dump");
325 ff.fLeftIndices.push_back(
yes);
326 ff.fRightIndices.push_back(
no);
332 util::NumericAfterSubstrOutput<Value_t>
output = util::numericAfterSubstr<Value_t>(
line,
"leaf=");
334 std::stringstream
ss(
line);
339 ff.fResponses.push_back(
output.value);
348 std::stringstream
ss;
349 ss <<
"Error in RBDT construction : Forest has " <<
ff.fRootIndices.size()
350 <<
" trees, which is not compatible with " <<
nClasses <<
"classes!";
351 throw std::runtime_error(
ss.str());
360 if (!file || file->IsZombie()) {
361 throw std::runtime_error(
"Failed to open input file " +
filename);
365 throw std::runtime_error(
"No RBDT with name " + key);
ROOT::Detail::TRangeCast< T, true > TRangeDynCast
TRangeDynCast is an adapter class that allows the typed iteration through a TCollection.
winID h TVirtualViewer3D TVirtualGLPainter p
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void char Point_t Rectangle_t WindowAttributes_t Float_t Float_t Float_t Int_t Int_t UInt_t UInt_t Rectangle_t Int_t Int_t Window_t TString Int_t GCValues_t GetPrimarySelectionOwner GetDisplay GetScreen GetColormap GetNativeEvent const char const char dpyName wid window const char font_name cursor keysym reg const char only_if_exist regb h Point_t winding char text const char depth char const char Int_t count const char ColorStruct_t color const char filename
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void char Point_t Rectangle_t WindowAttributes_t Float_t r
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void char Point_t Rectangle_t WindowAttributes_t index
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void value
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void char Point_t Rectangle_t wmax
R__EXTERN TSystem * gSystem
const_iterator begin() const
const_iterator end() const
static TFile * Open(const char *name, Option_t *option="", const char *ftitle="", Int_t compress=ROOT::RCompressionSetting::EDefaults::kUseCompiledDefault, Int_t netopt=0)
Create / open a file.
static void terminateTree(TMVA::Experimental::RBDT &ff, int &nPreviousNodes, int &nPreviousLeaves, IndexMap &nodeIndices, IndexMap &leafIndices, int &treesSkipped)
RBDT()=default
IO constructor (both for ROOT IO and LoadText()).
static void correctIndices(std::span< int > indices, IndexMap const &nodeIndices, IndexMap const &leafIndices)
RBDT uses a more efficient representation of the BDT in flat arrays.
std::unordered_map< int, int > IndexMap
Map from XGBoost to RBDT indices.
void Softmax(const Value_t *array, Value_t *out) const
Value_t EvaluateBinary(const Value_t *array) const
std::vector< Value_t > fBaseResponses
Vector Compute(const Vector &x) const
Compute model prediction on a single event.
void ComputeImpl(const Value_t *array, Value_t *out) const
static RBDT LoadText(std::string const &txtpath, std::vector< std::string > &features, int nClasses, bool logistic, Value_t baseScore)
RTensor is a container with contiguous memory and shape information.
virtual Bool_t AccessPathName(const char *path, EAccessMode mode=kFileExists)
Returns FALSE if one can access a file using the specified access mode.
std::vector< std::string > Split(std::string_view str, std::string_view delims, bool skipEmpty=false)
Splits a string at each character in delims.