57inline bool isInteger(
const std::string &s)
59 if (s.empty() || ((!isdigit(s[0])) && (s[0] !=
'-') && (s[0] !=
'+')))
63 strtol(s.c_str(), &p, 10);
68template <
class NumericType>
69struct NumericAfterSubstrOutput {
70 explicit NumericAfterSubstrOutput()
82template <
class NumericType>
83inline NumericAfterSubstrOutput<NumericType> numericAfterSubstr(std::string
const &str, std::string
const &substr)
86 NumericAfterSubstrOutput<NumericType> output;
89 std::size_t found = str.find(substr);
90 if (found != std::string::npos) {
92 std::stringstream ss(str.substr(found + substr.size(), str.size() - found + substr.size()));
95 output.failed =
false;
96 output.rest = ss.str();
112 const std::size_t rows =
x.GetShape()[0];
113 const std::size_t cols =
x.GetShape()[1];
115 std::vector<Value_t> xRow(cols);
116 std::vector<Value_t> yRow(nOut);
117 for (std::size_t iRow = 0; iRow < rows; ++iRow) {
118 for (std::size_t iCol = 0; iCol < cols; ++iCol) {
119 xRow[iCol] =
x({iRow, iCol});
122 for (std::size_t iOut = 0; iOut < nOut; ++iOut) {
123 y({iRow, iOut}) = yRow[iOut];
194 for (
int &idx : indices) {
195 auto foundNode = nodeIndices.find(idx);
196 if (foundNode != nodeIndices.end()) {
197 idx = foundNode->second;
200 auto foundLeaf = leafIndices.find(idx);
201 if (foundLeaf != leafIndices.end()) {
202 idx = -foundLeaf->second;
205 std::stringstream errMsg;
206 errMsg <<
"RBDT: something is wrong in the node structure - node with index " << idx <<
" doesn't exist";
207 throw std::runtime_error(errMsg.str());
235 std::vector<std::string> &features,
int nClasses,
236 bool logistic,
Value_t baseScore)
238 const std::string info =
"constructing RBDT from " + txtpath +
": ";
240 if (
gSystem->AccessPathName(txtpath.c_str())) {
241 throw std::runtime_error(info +
"file does not exists");
244 std::ifstream file(txtpath.c_str());
245 return LoadText(file, features, nClasses, logistic, baseScore);
249 int nClasses,
bool logistic,
Value_t baseScore)
251 const std::string info =
"constructing RBDT from istream: ";
258 int treesSkipped = 0;
261 std::unordered_map<std::string, int> varIndices;
262 bool fixFeatures =
false;
264 if (!features.empty()) {
266 nVariables = features.size();
267 for (
int i = 0; i < nVariables; ++i) {
268 varIndices[features[i]] = i;
277 int nPreviousNodes = 0;
278 int nPreviousLeaves = 0;
280 while (std::getline(file,
line)) {
281 std::size_t foundBegin =
line.find(
"[");
282 std::size_t foundEnd =
line.find(
"]");
283 if (foundBegin != std::string::npos) {
284 std::string subline =
line.substr(foundBegin + 1, foundEnd - foundBegin - 1);
285 if (util::isInteger(subline) && !ff.
fResponses.empty()) {
286 terminateTree(ff, nPreviousNodes, nPreviousLeaves, nodeIndices, leafIndices, treesSkipped);
287 }
else if (!util::isInteger(subline)) {
288 std::stringstream ss(
line);
293 std::vector<std::string> splitstring =
ROOT::Split(subline,
"<");
294 std::string
const &varName = splitstring[0];
297 std::stringstream ss1(splitstring[1]);
300 if (!varIndices.count(varName)) {
302 throw std::runtime_error(info +
"feature " + varName +
" not in list of features");
304 varIndices[varName] = nVariables;
305 features.push_back(varName);
310 util::NumericAfterSubstrOutput<int> output = util::numericAfterSubstr<int>(
line,
"yes=");
311 if (!output.failed) {
314 throw std::runtime_error(info +
"problem while parsing the text dump");
316 output = util::numericAfterSubstr<int>(output.rest,
"no=");
317 if (!output.failed) {
320 throw std::runtime_error(info +
"problem while parsing the text dump");
327 std::size_t nNodeIndices = nodeIndices.size();
328 nodeIndices[index] = nNodeIndices + nPreviousNodes;
332 util::NumericAfterSubstrOutput<Value_t> output = util::numericAfterSubstr<Value_t>(
line,
"leaf=");
334 std::stringstream ss(
line);
340 std::size_t nLeafIndices = leafIndices.size();
341 leafIndices[index] = nLeafIndices + nPreviousLeaves;
345 terminateTree(ff, nPreviousNodes, nPreviousLeaves, nodeIndices, leafIndices, treesSkipped);
347 if (nClasses > 2 && (ff.
fRootIndices.size() + treesSkipped) % nClasses != 0) {
348 std::stringstream ss;
349 ss <<
"Error in RBDT construction : Forest has " << ff.
fRootIndices.size()
350 <<
" trees, which is not compatible with " << nClasses <<
"classes!";
351 throw std::runtime_error(ss.str());
static TFile * Open(const char *name, Option_t *option="", const char *ftitle="", Int_t compress=ROOT::RCompressionSetting::EDefaults::kUseCompiledDefault, Int_t netopt=0)
Create / open a file.
static void terminateTree(TMVA::Experimental::RBDT &ff, int &nPreviousNodes, int &nPreviousLeaves, IndexMap &nodeIndices, IndexMap &leafIndices, int &treesSkipped)
static void correctIndices(std::span< int > indices, IndexMap const &nodeIndices, IndexMap const &leafIndices)
RBDT uses a more efficient representation of the BDT in flat arrays.
std::vector< std::string > Split(std::string_view str, std::string_view delims, bool skipEmpty=false)
Splits a string at each character in delims.