Logo ROOT  
Reference Guide
 
Loading...
Searching...
No Matches
RReader.hxx
Go to the documentation of this file.
1#ifndef TMVA_RREADER
2#define TMVA_RREADER
3
#include "TString.h"
#include "TXMLEngine.h"

#include "TMVA/RTensor.hxx"
#include "TMVA/Reader.h"

#include <cstdlib>   // std::atoi
#include <memory>    // std::unique_ptr
#include <sstream>   // std::stringstream
#include <stdexcept> // std::runtime_error
#include <string>
#include <vector>
12
13namespace TMVA {
14namespace Experimental {
15
16namespace Internal {
17
/// Internal definition of analysis types
///
/// Selects which TMVA::Reader evaluation method is used for a booked model.
enum class AnalysisType : unsigned int {
   Undefined = 0,  ///< Not determined from the XML config
   Classification, ///< Config "AnalysisType" value "Classification"
   Regression,     ///< Config "AnalysisType" value "Regression"
   Multiclass      ///< Config "AnalysisType" value "Multiclass"
};
21/// Container for information extracted from TMVA XML config
22struct XMLConfig {
23 unsigned int numVariables;
24 std::vector<std::string> variables;
25 std::vector<std::string> expressions;
26 unsigned int numClasses;
27 std::vector<std::string> classes;
30 : numVariables(0), variables(std::vector<std::string>(0)), numClasses(0), classes(std::vector<std::string>(0)),
32 {
33 }
34};
35
36/// Parse TMVA XML config
37inline XMLConfig ParseXMLConfig(const std::string &filename)
38{
40
41 // Parse XML file and find root node
42 TXMLEngine xml;
43 auto xmldoc = xml.ParseFile(filename.c_str());
44 if (xmldoc == 0) {
45 std::stringstream ss;
46 ss << "Failed to open TMVA XML file "
47 << filename << ".";
48 throw std::runtime_error(ss.str());
49 }
50 auto mainNode = xml.DocGetRootElement(xmldoc);
51 for (auto node = xml.GetChild(mainNode); node; node = xml.GetNext(node)) {
52 const auto nodeName = std::string(xml.GetNodeName(node));
53 // Read out input variables
54 if (nodeName.compare("Variables") == 0) {
55 c.numVariables = std::atoi(xml.GetAttr(node, "NVar"));
56 c.variables = std::vector<std::string>(c.numVariables);
57 c.expressions = std::vector<std::string>(c.numVariables);
58 for (auto thisNode = xml.GetChild(node); thisNode; thisNode = xml.GetNext(thisNode)) {
59 const auto iVariable = std::atoi(xml.GetAttr(thisNode, "VarIndex"));
60 c.variables[iVariable] = xml.GetAttr(thisNode, "Title");
61 c.expressions[iVariable] = xml.GetAttr(thisNode, "Expression");
62 }
63 }
64 // Read out output classes
65 else if (nodeName.compare("Classes") == 0) {
66 c.numClasses = std::atoi(xml.GetAttr(node, "NClass"));
67 for (auto thisNode = xml.GetChild(node); thisNode; thisNode = xml.GetNext(thisNode)) {
68 c.classes.push_back(xml.GetAttr(thisNode, "Name"));
69 }
70 }
71 // Read out analysis type
72 else if (nodeName.compare("GeneralInfo") == 0) {
73 std::string analysisType = "";
74 for (auto thisNode = xml.GetChild(node); thisNode; thisNode = xml.GetNext(thisNode)) {
75 if (std::string("AnalysisType").compare(xml.GetAttr(thisNode, "name")) == 0) {
76 analysisType = xml.GetAttr(thisNode, "value");
77 }
78 }
79 if (analysisType.compare("Classification") == 0) {
81 } else if (analysisType.compare("Regression") == 0) {
83 } else if (analysisType.compare("Multiclass") == 0) {
85 }
86 }
87 }
88 xml.FreeDoc(xmldoc);
89
90 // Error-handling
91 if (c.numVariables != c.variables.size() || c.numVariables == 0) {
92 std::stringstream ss;
93 ss << "Failed to parse input variables from TMVA config " << filename << ".";
94 throw std::runtime_error(ss.str());
95 }
96 if (c.numClasses != c.classes.size() || c.numClasses == 0) {
97 std::stringstream ss;
98 ss << "Failed to parse output classes from TMVA config " << filename << ".";
99 throw std::runtime_error(ss.str());
100 }
101 if (c.analysisType == Internal::AnalysisType::Undefined) {
102 std::stringstream ss;
103 ss << "Failed to parse analysis type from TMVA config " << filename << ".";
104 throw std::runtime_error(ss.str());
105 }
106
107 return c;
108}
109
110} // namespace Internal
111
112/// TMVA::Reader legacy interface
113class RReader {
114private:
115 std::unique_ptr<Reader> fReader;
116 std::vector<float> fValues;
117 std::vector<std::string> fVariables;
118 std::vector<std::string> fExpressions;
119 unsigned int fNumClasses;
120 const char *name = "RReader";
122
123public:
124 /// Create TMVA model from XML file
125 RReader(const std::string &path)
126 {
127 // Load config
128 auto c = Internal::ParseXMLConfig(path);
129 fVariables = c.variables;
130 fExpressions = c.expressions;
131 fAnalysisType = c.analysisType;
132 fNumClasses = c.numClasses;
133
134 // Setup reader
135 fReader = std::make_unique<Reader>("Silent");
136 const auto numVars = fVariables.size();
137 fValues = std::vector<float>(numVars);
138 for (std::size_t i = 0; i < numVars; i++) {
139 fReader->AddVariable(TString(fExpressions[i]), &fValues[i]);
140 }
141 fReader->BookMVA(name, path.c_str());
142 }
143
144 /// Compute model prediction on vector
145 std::vector<float> Compute(const std::vector<float> &x)
146 {
147 if (x.size() != fVariables.size())
148 throw std::runtime_error("Size of input vector is not equal to number of variables.");
149
150 // Copy over inputs to memory used by TMVA reader
151 for (std::size_t i = 0; i < x.size(); i++) {
152 fValues[i] = x[i];
153 }
154
155 // Take lock to protect model evaluation
157
158 // Evaluate TMVA model
159 // Classification
161 return std::vector<float>({static_cast<float>(fReader->EvaluateMVA(name))});
162 }
163 // Regression
165 return fReader->EvaluateRegression(name);
166 }
167 // Multiclass
169 return fReader->EvaluateMulticlass(name);
170 }
171 // Throw error
172 else {
173 throw std::runtime_error("RReader has undefined analysis type.");
174 return std::vector<float>();
175 }
176 }
177
178 /// Compute model prediction on input RTensor
180 {
181 // Error-handling for input tensor
182 const auto shape = x.GetShape();
183 if (shape.size() != 2)
184 throw std::runtime_error("Can only compute model outputs for input tensor of rank 2.");
185
186 const auto numEntries = shape[0];
187 const auto numVars = shape[1];
188 if (numVars != fVariables.size())
189 throw std::runtime_error("Second dimension of input tensor is not equal to number of variables.");
190
191 // Define shape of output tensor based on analysis type
192 unsigned int numClasses = 1;
194 numClasses = fNumClasses;
195 RTensor<float> y({numEntries * numClasses});
197 y = y.Reshape({numEntries, numClasses});
198
199 // Fill output tensor
200 for (std::size_t i = 0; i < numEntries; i++) {
201 for (std::size_t j = 0; j < numVars; j++) {
202 fValues[j] = x(i, j);
203 }
205 // Classification
207 y(i) = fReader->EvaluateMVA(name);
208 }
209 // Regression
211 y(i) = fReader->EvaluateRegression(name)[0];
212 }
213 // Multiclass
215 const auto p = fReader->EvaluateMulticlass(name);
216 for (std::size_t k = 0; k < numClasses; k++)
217 y(i, k) = p[k];
218 }
219 }
220
221 return y;
222 }
223
224 std::vector<std::string> GetVariableNames() { return fVariables; }
225};
226
227} // namespace Experimental
228} // namespace TMVA
229
230#endif // TMVA_RREADER
#define c(i)
Definition RSha256.hxx:101
#define R__WRITE_LOCKGUARD(mutex)
TMVA::Reader legacy interface.
Definition RReader.hxx:113
RTensor< float > Compute(RTensor< float > &x)
Compute model prediction on input RTensor.
Definition RReader.hxx:179
std::vector< float > fValues
Definition RReader.hxx:116
std::vector< float > Compute(const std::vector< float > &x)
Compute model prediction on vector.
Definition RReader.hxx:145
Internal::AnalysisType fAnalysisType
Definition RReader.hxx:121
std::vector< std::string > fExpressions
Definition RReader.hxx:118
std::vector< std::string > GetVariableNames()
Definition RReader.hxx:224
std::vector< std::string > fVariables
Definition RReader.hxx:117
RReader(const std::string &path)
Create TMVA model from XML file.
Definition RReader.hxx:125
std::unique_ptr< Reader > fReader
Definition RReader.hxx:115
RTensor is a container with contiguous memory and shape information.
Definition RTensor.hxx:162
RTensor< Value_t, Container_t > Reshape(const Shape_t &shape) const
Reshape tensor.
Definition RTensor.hxx:481
Basic string class.
Definition TString.h:136
XMLNodePointer_t GetChild(XMLNodePointer_t xmlnode, Bool_t realnode=kTRUE)
returns first child of xmlnode
void FreeDoc(XMLDocPointer_t xmldoc)
frees allocated document data and deletes document itself
XMLNodePointer_t DocGetRootElement(XMLDocPointer_t xmldoc)
returns root node of document
const char * GetNodeName(XMLNodePointer_t xmlnode)
returns name of xmlnode
const char * GetAttr(XMLNodePointer_t xmlnode, const char *name)
returns value of attribute for xmlnode
XMLDocPointer_t ParseFile(const char *filename, Int_t maxbuf=100000)
Parses content of file and tries to produce xml structures.
XMLNodePointer_t GetNext(XMLNodePointer_t xmlnode, Bool_t realnode=kTRUE)
return next to xmlnode node if realnode==kTRUE, any special nodes in between will be skipped
Double_t y[n]
Definition legend1.C:17
Double_t x[n]
Definition legend1.C:17
R__EXTERN TVirtualRWMutex * gCoreMutex
XMLConfig ParseXMLConfig(const std::string &filename)
Parse TMVA XML config.
Definition RReader.hxx:37
AnalysisType
Internal definition of analysis types.
Definition RReader.hxx:19
create variable transformations
Container for information extracted from TMVA XML config.
Definition RReader.hxx:22
std::vector< std::string > classes
Definition RReader.hxx:27
std::vector< std::string > expressions
Definition RReader.hxx:25
std::vector< std::string > variables
Definition RReader.hxx:24