Logo ROOT  
Reference Guide
 
Loading...
Searching...
No Matches
RBDT.cxx
Go to the documentation of this file.
1/**********************************************************************************
2 * Project: ROOT - a Root-integrated toolkit for multivariate data analysis *
3 * Package: TMVA *
4 * *
5 * *
6 * Description: *
7 * *
8 * Authors: *
9 * Jonas Rembser (jonas.rembser@cern.ch) *
10 * *
11 * Copyright (c) 2024: *
12 * CERN, Switzerland *
13 * *
14 * Redistribution and use in source and binary forms, with or without *
15 * modification, are permitted according to the terms listed in LICENSE *
16 * (see tmva/doc/LICENSE) *
17 **********************************************************************************/
18
19#include <TMVA/RBDT.hxx>
20
21#include <ROOT/StringUtils.hxx>
22
23#include <TFile.h>
24#include <TSystem.h>
25
26#include <cmath>
27#include <fstream>
28#include <iostream>
29#include <sstream>
30#include <stdexcept>
31#include <stdlib.h>
32
33namespace {
34
/// Apply the softmax transformation to the first `nOut` entries of `out`, in place.
///
/// The arithmetic mimics exactly the Softmax function in the src/common/math.h
/// source file of xgboost: the maximum entry is subtracted before
/// exponentiating, the normalisation is accumulated in double precision and
/// is applied as a `float` divisor.
///
/// \param[in,out] out Pointer to at least `nOut` values; overwritten with the result.
/// \param[in] nOut Number of entries to transform. Nothing is done if it is not positive.
template <class Value_t>
void softmaxTransformInplace(Value_t *out, int nOut)
{
   // The code below reads out[0] unconditionally, so an empty range must be
   // rejected up front.
   if (nOut <= 0) {
      return;
   }

   Value_t wmax = out[0];
   for (int i = 1; i < nOut; ++i) {
      wmax = std::max(out[i], wmax);
   }

   double norm = 0.;
   for (int i = 0; i < nOut; ++i) {
      out[i] = std::exp(out[i] - wmax);
      norm += out[i];
   }

   // The narrowing to float is kept on purpose to reproduce the reference arithmetic.
   const float divisor = static_cast<float>(norm);
   for (int i = 0; i < nOut; ++i) {
      out[i] /= divisor;
   }
}
54
55namespace util {
56
/// Check whether `s` is, in its entirety, a base-10 integer with an optional
/// leading sign.
///
/// \param[in] s String to test.
/// \return `true` if `s` starts with a digit, '+' or '-' and strtol() consumes
///         the complete string; `false` otherwise (including for an empty string).
inline bool isInteger(const std::string &s)
{
   if (s.empty()) {
      return false;
   }

   // isdigit() has undefined behaviour for negative char values, so the
   // character is converted to unsigned char first.
   const unsigned char first = static_cast<unsigned char>(s[0]);
   if (!isdigit(first) && first != '-' && first != '+') {
      return false;
   }

   const char *begin = s.c_str();
   char *end = nullptr;
   strtol(begin, &end, 10);

   // Accept only if the conversion stopped exactly at the end of the string.
   return end == begin + s.size();
}
67
/// Result of numericAfterSubstr().
template <class NumericType>
struct NumericAfterSubstrOutput {
   explicit NumericAfterSubstrOutput() : value(0), found(false), failed(true) {}
   NumericType value; ///< Number read after the substring; initialised to 0.
   bool found;        ///< Whether the substring occurs in the input string.
   bool failed;       ///< `false` only if the substring was found and a number could be read after it.
   std::string rest;  ///< On success the input text that follows the substring (it still starts
                      ///< with the number that was read); otherwise the complete input string.
};

/// Search `str` for `substr` and read the number that directly follows it.
///
/// \param[in] str String to search.
/// \param[in] substr Marker that precedes the number, e.g. "yes=".
/// \return A NumericAfterSubstrOutput describing whether the marker was found,
///         whether a number could be read, the number, and the remaining text.
template <class NumericType>
inline NumericAfterSubstrOutput<NumericType> numericAfterSubstr(std::string const &str, std::string const &substr)
{
   NumericAfterSubstrOutput<NumericType> output;
   output.rest = str;

   const std::size_t pos = str.find(substr);
   if (pos == std::string::npos) {
      return output;
   }
   output.found = true;

   // Everything that follows the matched marker.
   const std::string tail = str.substr(pos + substr.size());
   std::stringstream ss(tail);
   ss >> output.value;
   if (!ss.fail()) {
      output.failed = false;
      output.rest = tail;
   }
   return output;
}
101
102} // namespace util
103
104} // namespace
105
107
108/// Compute model prediction on input RTensor
110{
111 std::size_t nOut = fBaseResponses.size() > 2 ? fBaseResponses.size() : 1;
112 const std::size_t rows = x.GetShape()[0];
113 const std::size_t cols = x.GetShape()[1];
114 RTensor<Value_t> y({rows, nOut}, MemoryLayout::ColumnMajor);
115 std::vector<Value_t> xRow(cols);
116 std::vector<Value_t> yRow(nOut);
117 for (std::size_t iRow = 0; iRow < rows; ++iRow) {
118 for (std::size_t iCol = 0; iCol < cols; ++iCol) {
119 xRow[iCol] = x({iRow, iCol});
120 }
121 ComputeImpl(xRow.data(), yRow.data());
122 for (std::size_t iOut = 0; iOut < nOut; ++iOut) {
123 y({iRow, iOut}) = yRow[iOut];
124 }
125 }
126 return y;
127}
128
130{
131 std::size_t nOut = fBaseResponses.size() > 2 ? fBaseResponses.size() : 1;
132 if (nOut == 1) {
133 throw std::runtime_error(
134 "Error in RBDT::softmax : binary classification models don't support softmax evaluation. Plase set "
135 "the number of classes in the RBDT-creating function if this is a multiclassification model.");
136 }
137
138 for (std::size_t i = 0; i < nOut; ++i) {
139 out[i] = fBaseScore + fBaseResponses[i];
140 }
141
142 int iRootIndex = 0;
143 for (int index : fRootIndices) {
144 do {
145 int r = fRightIndices[index];
146 int l = fLeftIndices[index];
147 index = array[fCutIndices[index]] < fCutValues[index] ? l : r;
148 } while (index > 0);
149 out[fTreeNumbers[iRootIndex] % nOut] += fResponses[-index];
150 ++iRootIndex;
151 }
152
153 softmaxTransformInplace(out, nOut);
154}
155
157{
158 std::size_t nOut = fBaseResponses.size() > 2 ? fBaseResponses.size() : 1;
159 if (nOut > 1) {
160 Softmax(array, out);
161 } else {
162 out[0] = EvaluateBinary(array);
163 if (fLogistic) {
164 out[0] = 1.0 / (1.0 + std::exp(-out[0]));
165 }
166 }
167}
168
170{
171 Value_t out = fBaseScore + fBaseResponses[0];
172
173 for (std::vector<int>::const_iterator indexIter = fRootIndices.begin(); indexIter != fRootIndices.end();
174 ++indexIter) {
175 int index = *indexIter;
176 do {
177 int r = fRightIndices[index];
178 int l = fLeftIndices[index];
179 index = array[fCutIndices[index]] < fCutValues[index] ? l : r;
180 } while (index > 0);
181 out += fResponses[-index];
182 }
183
184 return out;
185}
186
187/// RBDT uses a more efficient representation of the BDT in flat arrays. This
188/// function translates the indices to the RBDT indices. In RBDT, leaf nodes
189/// are stored in separate arrays. To encode this, the sign of the index is
190/// flipped.
191void TMVA::Experimental::RBDT::correctIndices(std::span<int> indices, IndexMap const &nodeIndices,
192 IndexMap const &leafIndices)
193{
194 for (int &idx : indices) {
195 auto foundNode = nodeIndices.find(idx);
196 if (foundNode != nodeIndices.end()) {
197 idx = foundNode->second;
198 continue;
199 }
200 auto foundLeaf = leafIndices.find(idx);
201 if (foundLeaf != leafIndices.end()) {
202 idx = -foundLeaf->second;
203 continue;
204 } else {
205 std::stringstream errMsg;
206 errMsg << "RBDT: something is wrong in the node structure - node with index " << idx << " doesn't exist";
207 throw std::runtime_error(errMsg.str());
208 }
209 }
210}
211
212void TMVA::Experimental::RBDT::terminateTree(TMVA::Experimental::RBDT &ff, int &nPreviousNodes, int &nPreviousLeaves,
213 IndexMap &nodeIndices, IndexMap &leafIndices, int &treesSkipped)
214{
215 correctIndices({ff.fRightIndices.begin() + nPreviousNodes, ff.fRightIndices.end()}, nodeIndices, leafIndices);
216 correctIndices({ff.fLeftIndices.begin() + nPreviousNodes, ff.fLeftIndices.end()}, nodeIndices, leafIndices);
217
218 if (nPreviousNodes != static_cast<int>(ff.fCutValues.size())) {
219 ff.fTreeNumbers.push_back(ff.fRootIndices.size() + treesSkipped);
220 ff.fRootIndices.push_back(nPreviousNodes);
221 } else {
222 int treeNumbers = ff.fRootIndices.size() + treesSkipped;
223 ++treesSkipped;
224 ff.fBaseResponses[treeNumbers % ff.fBaseResponses.size()] += ff.fResponses.back();
225 ff.fResponses.pop_back();
226 }
227
228 nodeIndices.clear();
229 leafIndices.clear();
230 nPreviousNodes = ff.fCutValues.size();
231 nPreviousLeaves = ff.fResponses.size();
232}
233
235 std::vector<std::string> &features, int nClasses,
236 bool logistic, Value_t baseScore)
237{
238 const std::string info = "constructing RBDT from " + txtpath + ": ";
239
240 if (gSystem->AccessPathName(txtpath.c_str())) {
241 throw std::runtime_error(info + "file does not exists");
242 }
243
244 std::ifstream file(txtpath.c_str());
245 return LoadText(file, features, nClasses, logistic, baseScore);
246}
247
248TMVA::Experimental::RBDT TMVA::Experimental::RBDT::LoadText(std::istream &file, std::vector<std::string> &features,
249 int nClasses, bool logistic, Value_t baseScore)
250{
251 const std::string info = "constructing RBDT from istream: ";
252
253 RBDT ff;
254 ff.fLogistic = logistic;
255 ff.fBaseScore = baseScore;
256 ff.fBaseResponses.resize(nClasses <= 2 ? 1 : nClasses);
257
258 int treesSkipped = 0;
259
260 int nVariables = 0;
261 std::unordered_map<std::string, int> varIndices;
262 bool fixFeatures = false;
263
264 if (!features.empty()) {
265 fixFeatures = true;
266 nVariables = features.size();
267 for (int i = 0; i < nVariables; ++i) {
268 varIndices[features[i]] = i;
269 }
270 }
271
272 std::string line;
273
274 IndexMap nodeIndices;
275 IndexMap leafIndices;
276
277 int nPreviousNodes = 0;
278 int nPreviousLeaves = 0;
279
280 while (std::getline(file, line)) {
281 std::size_t foundBegin = line.find("[");
282 std::size_t foundEnd = line.find("]");
283 if (foundBegin != std::string::npos) {
284 std::string subline = line.substr(foundBegin + 1, foundEnd - foundBegin - 1);
285 if (util::isInteger(subline) && !ff.fResponses.empty()) {
286 terminateTree(ff, nPreviousNodes, nPreviousLeaves, nodeIndices, leafIndices, treesSkipped);
287 } else if (!util::isInteger(subline)) {
288 std::stringstream ss(line);
289 int index;
290 ss >> index;
291 line = ss.str();
292
293 std::vector<std::string> splitstring = ROOT::Split(subline, "<");
294 std::string const &varName = splitstring[0];
295 Value_t cutValue;
296 {
297 std::stringstream ss1(splitstring[1]);
298 ss1 >> cutValue;
299 }
300 if (!varIndices.count(varName)) {
301 if (fixFeatures) {
302 throw std::runtime_error(info + "feature " + varName + " not in list of features");
303 }
304 varIndices[varName] = nVariables;
305 features.push_back(varName);
306 ++nVariables;
307 }
308 int yes;
309 int no;
310 util::NumericAfterSubstrOutput<int> output = util::numericAfterSubstr<int>(line, "yes=");
311 if (!output.failed) {
312 yes = output.value;
313 } else {
314 throw std::runtime_error(info + "problem while parsing the text dump");
315 }
316 output = util::numericAfterSubstr<int>(output.rest, "no=");
317 if (!output.failed) {
318 no = output.value;
319 } else {
320 throw std::runtime_error(info + "problem while parsing the text dump");
321 }
322
323 ff.fCutValues.push_back(cutValue);
324 ff.fCutIndices.push_back(varIndices[varName]);
325 ff.fLeftIndices.push_back(yes);
326 ff.fRightIndices.push_back(no);
327 std::size_t nNodeIndices = nodeIndices.size();
328 nodeIndices[index] = nNodeIndices + nPreviousNodes;
329 }
330
331 } else {
332 util::NumericAfterSubstrOutput<Value_t> output = util::numericAfterSubstr<Value_t>(line, "leaf=");
333 if (output.found) {
334 std::stringstream ss(line);
335 int index;
336 ss >> index;
337 line = ss.str();
338
339 ff.fResponses.push_back(output.value);
340 std::size_t nLeafIndices = leafIndices.size();
341 leafIndices[index] = nLeafIndices + nPreviousLeaves;
342 }
343 }
344 }
345 terminateTree(ff, nPreviousNodes, nPreviousLeaves, nodeIndices, leafIndices, treesSkipped);
346
347 if (nClasses > 2 && (ff.fRootIndices.size() + treesSkipped) % nClasses != 0) {
348 std::stringstream ss;
349 ss << "Error in RBDT construction : Forest has " << ff.fRootIndices.size()
350 << " trees, which is not compatible with " << nClasses << "classes!";
351 throw std::runtime_error(ss.str());
352 }
353
354 return ff;
355}
356
357TMVA::Experimental::RBDT::RBDT(const std::string &key, const std::string &filename)
358{
359 std::unique_ptr<TFile> file{TFile::Open(filename.c_str(), "READ")};
360 if (!file || file->IsZombie()) {
361 throw std::runtime_error("Failed to open input file " + filename);
362 }
363 auto *fromFile = file->Get<TMVA::Experimental::RBDT>(key.c_str());
364 if (!fromFile) {
365 throw std::runtime_error("No RBDT with name " + key);
366 }
367 *this = *fromFile;
368}
winID h TVirtualViewer3D TVirtualGLPainter p
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void char Point_t Rectangle_t WindowAttributes_t Float_t Float_t Float_t Int_t Int_t UInt_t UInt_t Rectangle_t Int_t Int_t Window_t TString Int_t GCValues_t GetPrimarySelectionOwner GetDisplay GetScreen GetColormap GetNativeEvent const char const char dpyName wid window const char font_name cursor keysym reg const char only_if_exist regb h Point_t winding char text const char depth char const char Int_t count const char ColorStruct_t color const char filename
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void char Point_t Rectangle_t WindowAttributes_t Float_t r
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void char Point_t Rectangle_t WindowAttributes_t index
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void value
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void char Point_t Rectangle_t wmax
R__EXTERN TSystem * gSystem
Definition TSystem.h:555
static TFile * Open(const char *name, Option_t *option="", const char *ftitle="", Int_t compress=ROOT::RCompressionSetting::EDefaults::kUseCompiledDefault, Int_t netopt=0)
Create / open a file.
Definition TFile.cxx:4089
std::vector< Value_t > fCutValues
Definition RBDT.hxx:83
static void terminateTree(TMVA::Experimental::RBDT &ff, int &nPreviousNodes, int &nPreviousLeaves, IndexMap &nodeIndices, IndexMap &leafIndices, int &treesSkipped)
Definition RBDT.cxx:212
RBDT()=default
IO constructor (both for ROOT IO and LoadText()).
static void correctIndices(std::span< int > indices, IndexMap const &nodeIndices, IndexMap const &leafIndices)
RBDT uses a more efficient representation of the BDT in flat arrays.
Definition RBDT.cxx:191
std::vector< int > fRightIndices
Definition RBDT.hxx:85
std::unordered_map< int, int > IndexMap
Map from XGBoost to RBDT indices.
Definition RBDT.hxx:70
void Softmax(const Value_t *array, Value_t *out) const
Definition RBDT.cxx:129
std::vector< int > fTreeNumbers
Definition RBDT.hxx:87
Value_t EvaluateBinary(const Value_t *array) const
Definition RBDT.cxx:169
std::vector< Value_t > fResponses
Definition RBDT.hxx:86
std::vector< Value_t > fBaseResponses
Definition RBDT.hxx:88
Vector Compute(const Vector &x) const
Compute model prediction on a single event.
Definition RBDT.hxx:52
std::vector< unsigned int > fCutIndices
Definition RBDT.hxx:82
void ComputeImpl(const Value_t *array, Value_t *out) const
Definition RBDT.cxx:156
static RBDT LoadText(std::string const &txtpath, std::vector< std::string > &features, int nClasses, bool logistic, Value_t baseScore)
Definition RBDT.cxx:234
std::vector< int > fRootIndices
Definition RBDT.hxx:81
std::vector< int > fLeftIndices
Definition RBDT.hxx:84
RTensor is a container with contiguous memory and shape information.
Definition RTensor.hxx:162
virtual Bool_t AccessPathName(const char *path, EAccessMode mode=kFileExists)
Returns FALSE if one can access a file using the specified access mode.
Definition TSystem.cxx:1296
TLine * line
Double_t y[n]
Definition legend1.C:17
Double_t x[n]
Definition legend1.C:17
std::vector< std::string > Split(std::string_view str, std::string_view delims, bool skipEmpty=false)
Splits a string at each character in delims.
Definition RBDT.cxx:55
TLine l
Definition textangle.C:4
static void output()