Logo ROOT  
Reference Guide
DataLoader.cxx
Go to the documentation of this file.
1// @(#)root/tmva/tmva/dnn:$Id$
2// Author: Simon Pfreundschuh 06/06/17
3
4/*************************************************************************
5 * Copyright (C) 2016, Simon Pfreundschuh *
6 * All rights reserved. *
7 * *
8 * For the licensing terms see $ROOTSYS/LICENSE. *
9 * For the list of contributors see $ROOTSYS/README/CREDITS. *
10 *************************************************************************/
11
12/////////////////////////////////////////////////////////////
13// Specializations of Copy functions for the DataLoader //
14// specialized for the reference architecture. //
15/////////////////////////////////////////////////////////////
16
18#include "TMVA/DataSetInfo.h"
19
20namespace TMVA {
21namespace DNN {
22
23//______________________________________________________________________________
24template <>
26{
27 const TMatrixT<Real_t> &input = std::get<0>(fData);
28 Int_t m = matrix.GetNrows();
29 Int_t n = input.GetNcols();
30
31 for (Int_t i = 0; i < m; i++) {
32 Int_t sampleIndex = *sampleIterator;
33 for (Int_t j = 0; j < n; j++) {
34 matrix(i, j) = static_cast<Real_t>(input(sampleIndex, j));
35 }
36 sampleIterator++;
37 }
38}
39
40//______________________________________________________________________________
41template <>
43 IndexIterator_t sampleIterator)
44{
45 const TMatrixT<Real_t> &output = std::get<1>(fData);
46 Int_t m = matrix.GetNrows();
47 Int_t n = output.GetNcols();
48
49 for (Int_t i = 0; i < m; i++) {
50 Int_t sampleIndex = *sampleIterator;
51 for (Int_t j = 0; j < n; j++) {
52 matrix(i, j) = static_cast<Real_t>(output(sampleIndex, j));
53 }
54 sampleIterator++;
55 }
56}
57
58//______________________________________________________________________________
59template <>
61 IndexIterator_t sampleIterator)
62{
63 const TMatrixT<Real_t> &weights = std::get<2>(fData);
64 Int_t m = matrix.GetNrows();
65
66 for (Int_t i = 0; i < m; i++) {
67 Int_t sampleIndex = *sampleIterator;
68 matrix(i, 0) = static_cast<Real_t>(weights(sampleIndex, 0));
69 sampleIterator++;
70 }
71}
72
73//______________________________________________________________________________
74template <>
76 IndexIterator_t sampleIterator)
77{
78 const TMatrixT<Double_t> &input = std::get<0>(fData);
79 Int_t m = matrix.GetNrows();
80 Int_t n = input.GetNcols();
81
82 for (Int_t i = 0; i < m; i++) {
83 Int_t sampleIndex = *sampleIterator;
84 for (Int_t j = 0; j < n; j++) {
85 matrix(i, j) = static_cast<Double_t>(input(sampleIndex, j));
86 }
87 sampleIterator++;
88 }
89}
90
91//______________________________________________________________________________
92template <>
94 IndexIterator_t sampleIterator)
95{
96 const TMatrixT<Double_t> &output = std::get<1>(fData);
97 Int_t m = matrix.GetNrows();
98 Int_t n = output.GetNcols();
99
100 for (Int_t i = 0; i < m; i++) {
101 Int_t sampleIndex = *sampleIterator;
102 for (Int_t j = 0; j < n; j++) {
103 matrix(i, j) = static_cast<Double_t>(output(sampleIndex, j));
104 }
105 sampleIterator++;
106 }
107}
108
109//______________________________________________________________________________
110template <>
112 IndexIterator_t sampleIterator)
113{
114 const TMatrixT<Double_t> &output = std::get<2>(fData);
115 Int_t m = matrix.GetNrows();
116
117 for (Int_t i = 0; i < m; i++) {
118 Int_t sampleIndex = *sampleIterator;
119 matrix(i, 0) = static_cast<Double_t>(output(sampleIndex, 0));
120 sampleIterator++;
121 }
122}
123
124//______________________________________________________________________________
125template <>
127{
128 // short-circuit on empty
129 if (std::get<0>(fData).empty())
130 return;
131 Event *event = nullptr;
132
133 Int_t m = matrix.GetNrows();
134 Int_t n = matrix.GetNcols();
135
136 // Copy input variables.
137
138 for (Int_t i = 0; i < m; i++) {
139 Int_t sampleIndex = *sampleIterator++;
140 event = std::get<0>(fData)[sampleIndex];
141 for (Int_t j = 0; j < n; j++) {
142 matrix(i, j) = event->GetValue(j);
143 }
144 }
145}
146
147//______________________________________________________________________________
148template <>
150{
151 // short-circuit on empty
152 if (std::get<0>(fData).empty())
153 return;
154 Event *event = nullptr;
155 const DataSetInfo &info = std::get<1>(fData);
156 Int_t m = matrix.GetNrows();
157 Int_t n = matrix.GetNcols();
158
159 for (Int_t i = 0; i < m; i++) {
160 Int_t sampleIndex = *sampleIterator++;
161 event = std::get<0>(fData)[sampleIndex];
162 for (Int_t j = 0; j < n; j++) {
163 // Classification
164 if (event->GetNTargets() == 0) {
165 if (n == 1) {
166 // Binary.
167 matrix(i, j) = (info.IsSignal(event)) ? 1.0 : 0.0;
168 } else {
169 // Multiclass.
170 matrix(i, j) = 0.0;
171 if (j == static_cast<Int_t>(event->GetClass())) {
172 matrix(i, j) = 1.0;
173 }
174 }
175 } else {
176 matrix(i, j) = static_cast<Real_t>(event->GetTarget(j));
177 }
178 }
179 }
180}
181
182//______________________________________________________________________________
183template <>
185{
186 // short-circuit on empty
187 if (std::get<0>(fData).empty())
188 return;
189 Event *event = nullptr;
190 for (Int_t i = 0; i < matrix.GetNrows(); i++) {
191 Int_t sampleIndex = *sampleIterator++;
192 event = std::get<0>(fData)[sampleIndex];
193 matrix(i, 0) = event->GetWeight();
194 }
195}
196
197//______________________________________________________________________________
198template <>
200 IndexIterator_t sampleIterator)
201{
202 // short-circuit on empty
203 if (std::get<0>(fData).empty())
204 return;
205 Event *event = nullptr;
206 Int_t m = matrix.GetNrows();
207
208 // Copy input variables.
209
210 for (Int_t i = 0; i < m; i++) {
211 Int_t sampleIndex = *sampleIterator++;
212 event = std::get<0>(fData)[sampleIndex];
213 for (Int_t j = 0; j < static_cast<Int_t>(event ? event->GetNVariables() : 0); j++) {
214 matrix(i, j) = event->GetValue(j);
215 }
216 }
217}
218
219//______________________________________________________________________________
220template <>
222 IndexIterator_t sampleIterator)
223{
224 Event *event = nullptr;
225 const DataSetInfo &info = std::get<1>(fData);
226 Int_t m = matrix.GetNrows();
227 Int_t n = matrix.GetNcols();
228
229 for (Int_t i = 0; i < m; i++) {
230 Int_t sampleIndex = *sampleIterator++;
231 event = std::get<0>(fData)[sampleIndex];
232 for (Int_t j = 0; j < n; j++) {
233 // Classification
234 if (event->GetNTargets() == 0) {
235 if (n == 1) {
236 // Binary.
237 matrix(i, j) = (info.IsSignal(event)) ? 1.0 : 0.0;
238 } else {
239 // Multiclass.
240 matrix(i, j) = 0.0;
241 if (j == static_cast<Int_t>(event->GetClass())) {
242 matrix(i, j) = 1.0;
243 }
244 }
245 } else {
246 matrix(i, j) = static_cast<Real_t>(event->GetTarget(j));
247 }
248 }
249 }
250}
251
252//______________________________________________________________________________
253template <>
255 IndexIterator_t sampleIterator)
256{
257 Event *event = nullptr;
258
259 for (Int_t i = 0; i < matrix.GetNrows(); i++) {
260 Int_t sampleIndex = *sampleIterator++;
261 event = std::get<0>(fData)[sampleIndex];
262 matrix(i, 0) = event->GetWeight();
263 }
264}
265
266// Explicit instantiations.
271
272} // namespace DNN
273} // namespace TMVA
float Real_t
Definition: RtypesCore.h:68
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void input
Class that contains all the data information.
Definition: DataSetInfo.h:62
Bool_t IsSignal(const Event *ev) const
Int_t GetNrows() const
Definition: TMatrixTBase.h:123
Int_t GetNcols() const
Definition: TMatrixTBase.h:126
TMatrixT.
Definition: TMatrixT.h:39
const Int_t n
Definition: legend1.C:16
typename std::vector< size_t >::iterator IndexIterator_t
Definition: DataLoader.h:42
create variable transformations
TMarker m
Definition: textangle.C:8
static void output()