Logo ROOT   6.12/07
Reference Guide
DataLoader.cxx
Go to the documentation of this file.
1 // @(#)root/tmva/tmva/dnn:$Id$
2 // Author: Simon Pfreundschuh 06/06/17
3 
4 /*************************************************************************
5  * Copyright (C) 2016, Simon Pfreundschuh *
6  * All rights reserved. *
7  * *
8  * For the licensing terms see $ROOTSYS/LICENSE. *
9  * For the list of contributors see $ROOTSYS/README/CREDITS. *
10  *************************************************************************/
11 
12 /////////////////////////////////////////////////////////////
13 // Specializations of Copy functions for the DataLoader //
14 // specialized for the reference architecture. //
15 /////////////////////////////////////////////////////////////
16 
18 #include "TMVA/DataSetInfo.h"
19 
20 namespace TMVA {
21 namespace DNN {
22 
23 //______________________________________________________________________________
24 template <> // Copies the rows of the input matrix std::get<0>(fData) selected by sampleIterator into the rows of `matrix`.
26 {
27  const TMatrixT<Real_t> &input = std::get<0>(fData);
28  Int_t m = matrix.GetNrows(); // one sample index is consumed per destination row
29  Int_t n = input.GetNcols(); // every input column is copied, so `matrix` needs at least n columns
30 
31  for (Int_t i = 0; i < m; i++) {
32  Int_t sampleIndex = *sampleIterator; // row of `input` that becomes row i of `matrix`
33  for (Int_t j = 0; j < n; j++) {
34  matrix(i, j) = static_cast<Real_t>(input(sampleIndex, j));
35  }
36  sampleIterator++; // sampleIterator must yield at least m indices
37  }
38 }
39 
40 //______________________________________________________________________________
41 template <> // Copies the rows of the output matrix std::get<1>(fData) selected by sampleIterator into the rows of `matrix`.
43  IndexIterator_t sampleIterator) // taken by value: advancing it below does not move the caller's iterator
44 {
45  const TMatrixT<Real_t> &output = std::get<1>(fData);
46  Int_t m = matrix.GetNrows(); // one sample index is consumed per destination row
47  Int_t n = output.GetNcols(); // every output column is copied, so `matrix` needs at least n columns
48 
49  for (Int_t i = 0; i < m; i++) {
50  Int_t sampleIndex = *sampleIterator; // row of `output` that becomes row i of `matrix`
51  for (Int_t j = 0; j < n; j++) {
52  matrix(i, j) = static_cast<Real_t>(output(sampleIndex, j));
53  }
54  sampleIterator++; // sampleIterator must yield at least m indices
55  }
56 }
57 
58 //______________________________________________________________________________
59 template <> // Copies column 0 of the weight matrix std::get<2>(fData), for the samples selected by sampleIterator, into column 0 of `matrix`.
61  IndexIterator_t sampleIterator) // taken by value: advancing it below does not move the caller's iterator
62 {
63  const TMatrixT<Real_t> &weights = std::get<2>(fData);
64  Int_t m = matrix.GetNrows(); // one sample index is consumed per destination row
65 
66  for (Int_t i = 0; i < m; i++) {
67  Int_t sampleIndex = *sampleIterator;
68  matrix(i, 0) = static_cast<Real_t>(weights(sampleIndex, 0)); // only column 0 is read and written
69  sampleIterator++;
70  }
71 }
72 
73 //______________________________________________________________________________
74 template <> // Copies the rows of the input matrix std::get<0>(fData) selected by sampleIterator into the rows of `matrix`.
76  IndexIterator_t sampleIterator) // taken by value: advancing it below does not move the caller's iterator
77 {
78  const TMatrixT<Double_t> &input = std::get<0>(fData);
79  Int_t m = matrix.GetNrows(); // one sample index is consumed per destination row
80  Int_t n = input.GetNcols(); // every input column is copied, so `matrix` needs at least n columns
81 
82  for (Int_t i = 0; i < m; i++) {
83  Int_t sampleIndex = *sampleIterator; // row of `input` that becomes row i of `matrix`
84  for (Int_t j = 0; j < n; j++) {
85  matrix(i, j) = static_cast<Double_t>(input(sampleIndex, j));
86  }
87  sampleIterator++; // sampleIterator must yield at least m indices
88  }
89 }
90 
91 //______________________________________________________________________________
92 template <> // Copies the rows of the output matrix std::get<1>(fData) selected by sampleIterator into the rows of `matrix`.
94  IndexIterator_t sampleIterator) // taken by value: advancing it below does not move the caller's iterator
95 {
96  const TMatrixT<Double_t> &output = std::get<1>(fData);
97  Int_t m = matrix.GetNrows(); // one sample index is consumed per destination row
98  Int_t n = output.GetNcols(); // every output column is copied, so `matrix` needs at least n columns
99 
100  for (Int_t i = 0; i < m; i++) {
101  Int_t sampleIndex = *sampleIterator; // row of `output` that becomes row i of `matrix`
102  for (Int_t j = 0; j < n; j++) {
103  matrix(i, j) = static_cast<Double_t>(output(sampleIndex, j));
104  }
105  sampleIterator++; // sampleIterator must yield at least m indices
106  }
107 }
108 
109 //______________________________________________________________________________
110 template <> // Copies column 0 of the weight matrix std::get<2>(fData), for the samples selected by sampleIterator, into column 0 of `matrix`.
112  IndexIterator_t sampleIterator) // taken by value: advancing it below does not move the caller's iterator
113 {
114  const TMatrixT<Double_t> &weights = std::get<2>(fData); // element 2 of fData holds the weights (was misleadingly named `output`)
115  Int_t m = matrix.GetNrows(); // one sample index is consumed per destination row
116 
117  for (Int_t i = 0; i < m; i++) {
118  Int_t sampleIndex = *sampleIterator;
119  matrix(i, 0) = static_cast<Double_t>(weights(sampleIndex, 0)); // only column 0 is read and written
120  sampleIterator++;
121  }
122 }
123 
124 //______________________________________________________________________________
125 template <> // Copies the input variables of the events selected by sampleIterator into the rows of `matrix`.
127 {
128  Event *event = std::get<0>(fData).front(); // was nullptr, which was dereferenced two statements below (undefined behaviour); requires a non-empty event vector
129 
130  Int_t m = matrix.GetNrows(); // one sample index is consumed per destination row
131  Int_t n = event->GetNVariables(); // variable count is read from the first event only; assumes all events share it — TODO confirm
132 
133  // Copy input variables.
134 
135  for (Int_t i = 0; i < m; i++) {
136  Int_t sampleIndex = *sampleIterator++;
137  event = std::get<0>(fData)[sampleIndex];
138  for (Int_t j = 0; j < n; j++) {
139  matrix(i, j) = event->GetValue(j);
140  }
141  }
142 }
143 
144 //______________________________________________________________________________
145 template <> // Fills `matrix` with targets for the events selected by sampleIterator: signal flag (binary), one-hot class (multiclass) or regression targets.
147 {
148  Event *event = nullptr; // never read before the loop reassigns it; the former front() initialiser was UB on an empty event vector
149  const DataSetInfo &info = std::get<1>(fData);
150  Int_t m = matrix.GetNrows(); // one sample index is consumed per destination row
151  Int_t n = matrix.GetNcols(); // number of target columns to fill
152 
153  for (Int_t i = 0; i < m; i++) {
154  Int_t sampleIndex = *sampleIterator++;
155  event = std::get<0>(fData)[sampleIndex];
156  for (Int_t j = 0; j < n; j++) {
157  // Classification
158  if (event->GetNTargets() == 0) {
159  if (n == 1) {
160  // Binary.
161  matrix(i, j) = (info.IsSignal(event)) ? 1.0 : 0.0;
162  } else {
163  // Multiclass.
164  matrix(i, j) = 0.0;
165  if (j == static_cast<Int_t>(event->GetClass())) { // one-hot: only the column matching the event's class is set
166  matrix(i, j) = 1.0;
167  }
168  }
169  } else {
170  matrix(i, j) = static_cast<Real_t>(event->GetTarget(j)); // regression: copy target j directly
171  }
172  }
173  }
174 }
175 
176 //______________________________________________________________________________
177 template <> // Writes the weight of each event selected by sampleIterator into column 0 of `matrix`.
179 {
180  Event *event = nullptr; // never read before the loop reassigns it; the former front() initialiser was UB on an empty event vector
181  for (Int_t i = 0; i < matrix.GetNrows(); i++) {
182  Int_t sampleIndex = *sampleIterator++;
183  event = std::get<0>(fData)[sampleIndex];
184  matrix(i, 0) = event->GetWeight();
185  }
186 }
187 
188 //______________________________________________________________________________
189 template <> // Copies the input variables of the events selected by sampleIterator into the rows of `matrix`.
191  IndexIterator_t sampleIterator) // taken by value: advancing it below does not move the caller's iterator
192 {
193  Event *event = std::get<0>(fData).front(); // first event is needed for the variable count; requires a non-empty event vector
194  Int_t m = matrix.GetNrows(); // one sample index is consumed per destination row
195  Int_t n = event->GetNVariables(); // read from the first event only; assumes all events share it — TODO confirm
196 
197  // Copy input variables.
198 
199  for (Int_t i = 0; i < m; i++) {
200  Int_t sampleIndex = *sampleIterator++;
201  event = std::get<0>(fData)[sampleIndex];
202  for (Int_t j = 0; j < n; j++) {
203  matrix(i, j) = event->GetValue(j);
204  }
205  }
206 }
207 
208 //______________________________________________________________________________
209 template <> // Fills `matrix` with targets for the events selected by sampleIterator: signal flag (binary), one-hot class (multiclass) or regression targets.
211  IndexIterator_t sampleIterator) // taken by value: advancing it below does not move the caller's iterator
212 {
213  Event *event = nullptr; // never read before the loop reassigns it; the former front() initialiser was UB on an empty event vector
214  const DataSetInfo &info = std::get<1>(fData);
215  Int_t m = matrix.GetNrows(); // one sample index is consumed per destination row
216  Int_t n = matrix.GetNcols(); // number of target columns to fill
217 
218  for (Int_t i = 0; i < m; i++) {
219  Int_t sampleIndex = *sampleIterator++;
220  event = std::get<0>(fData)[sampleIndex];
221  for (Int_t j = 0; j < n; j++) {
222  // Classification
223  if (event->GetNTargets() == 0) {
224  if (n == 1) {
225  // Binary.
226  matrix(i, j) = (info.IsSignal(event)) ? 1.0 : 0.0;
227  } else {
228  // Multiclass.
229  matrix(i, j) = 0.0;
230  if (j == static_cast<Int_t>(event->GetClass())) { // one-hot: only the column matching the event's class is set
231  matrix(i, j) = 1.0;
232  }
233  }
234  } else {
235  matrix(i, j) = static_cast<Real_t>(event->GetTarget(j)); // regression: copy target j directly
236  }
237  }
238  }
239 }
240 
241 //______________________________________________________________________________
242 template <> // Writes the weight of each event selected by sampleIterator into column 0 of `matrix`.
244  IndexIterator_t sampleIterator) // taken by value: advancing it below does not move the caller's iterator
245 {
246  Event *event = nullptr; // assigned in the loop before every use
247 
248  for (Int_t i = 0; i < matrix.GetNrows(); i++) {
249  Int_t sampleIndex = *sampleIterator++;
250  event = std::get<0>(fData)[sampleIndex];
251  matrix(i, 0) = event->GetWeight();
252  }
253 }
254 
255 // Explicit instantiations.
260 
261 } // namespace DNN
262 } // namespace TMVA
auto * m
Definition: textangle.C:8
typename std::vector< size_t >::iterator IndexIterator_t
Definition: DataLoader.h:42
Int_t GetNcols() const
Definition: TMatrixTBase.h:125
int Int_t
Definition: RtypesCore.h:41
TMatrixT.
Definition: TMatrixDfwd.h:22
Class that contains all the data information.
Definition: DataSetInfo.h:60
TDataLoader.
Definition: DataLoader.h:79
Int_t GetNrows() const
Definition: TMatrixTBase.h:122
double Double_t
Definition: RtypesCore.h:55
float Real_t
Definition: RtypesCore.h:64
Abstract ClassifierFactory template that handles arbitrary types.
Bool_t IsSignal(const Event *ev) const
const Int_t n
Definition: legend1.C:16