26template <
typename AFloat>
31 const AFloat *dataOutput =
output.GetRawDataPointer();
37 auto f = [&dataY, &dataOutput, &dataWeights, &temp,
m](
UInt_t workerID) {
38 AFloat dy = dataY[workerID] - dataOutput[workerID];
39 temp[workerID] = dataWeights[workerID %
m] * dy * dy;
43 auto reduction = [](
const std::vector<AFloat> &
v )
45 return std::accumulate(
v.begin(),
v.end(),AFloat{});
53template <
typename AFloat>
60 const AFloat *dataOutput =
output.GetRawDataPointer();
66 auto f = [&dataDY, &dataY, &dataOutput, &dataWeights,
m, norm](
UInt_t workerID) {
67 dataDY[workerID] = -2.0 * norm * (dataY[workerID] - dataOutput[workerID]);
68 dataDY[workerID] *= dataWeights[workerID %
m];
76template <
typename AFloat>
81 const AFloat *dataOutput =
output.GetRawDataPointer();
88 auto f = [&dataY, &dataOutput, &dataWeights, &temp,
m](
UInt_t workerID) {
89 AFloat
y = dataY[workerID];
93 AFloat
x = dataOutput[workerID];
100 lr = std::log(1. + exp(-
x));
103 temp[workerID] =
y * lr + (1.0 -
y) * (
x +lr);
105 temp[workerID] *= dataWeights[workerID %
m];
109 auto reduction = [](
const std::vector<AFloat> &
v )
111 return std::accumulate(
v.begin(),
v.end(),AFloat{});
119template <
typename AFloat>
125 const AFloat *dataOutput =
output.GetRawDataPointer();
131 auto f = [&dataDY, &dataY, &dataOutput, &dataWeights,
m, norm](
UInt_t workerID) {
132 AFloat
y = dataY[workerID];
133 AFloat sig = 1.0 / (1.0 + exp(- dataOutput[workerID]));
134 dataDY[workerID] = norm * (sig -
y);
135 dataDY[workerID] *= dataWeights[workerID %
m];
143template <
typename AFloat>
148 const AFloat *dataOutput =
output.GetRawDataPointer();
151 std::vector<AFloat> temp(Y.
GetNrows());
154 AFloat norm = 1.0 / ((AFloat)
m);
156 auto f = [&dataY, &dataOutput, &dataWeights, &temp,
n,
m](
UInt_t workerID) {
158 for (
size_t j = 0; j <
n; j++) {
159 sum += exp(dataOutput[workerID + j *
m]);
161 for (
size_t j = 0; j <
n; j++) {
163 dataY[workerID + j *
m] * log(exp(dataOutput[workerID + j *
m]) /
sum);
165 temp[workerID] *= dataWeights[workerID];
169 auto reduction = [](
const std::vector<AFloat> &
v )
171 return std::accumulate(
v.begin(),
v.end(),AFloat{});
179template <
typename AFloat>
185 const AFloat *dataOutput =
output.GetRawDataPointer();
190 AFloat norm = 1.0 / ((AFloat)
m);
192 auto f = [&dataDY, &dataY, &dataOutput, &dataWeights, norm,
n,
m](
UInt_t workerID) {
195 AFloat weight = dataWeights[workerID];
196 for (
size_t j = 0; j <
n; j++) {
197 sum += exp(dataOutput[workerID + j *
m]);
198 sumY += dataY[workerID + j *
m];
200 for (
size_t j = 0; j <
n; j++) {
201 dataDY[workerID + j *
m] =
202 norm * (exp(dataOutput[workerID + j *
m]) /
sum * sumY - dataY[workerID + j *
m]);
203 dataDY[workerID + j *
m] *= weight;
A pseudo container class which is a generator of indices.
AFloat * GetRawDataPointer()
Return raw pointer to the elements stored contiguously in column-major order.
static Executor & GetThreadExecutor()
size_t GetNoElements() const
static void SoftmaxCrossEntropyGradients(Matrix_t &dY, const Matrix_t &Y, const Matrix_t &output, const Matrix_t &weights)
static void CrossEntropyGradients(Matrix_t &dY, const Matrix_t &Y, const Matrix_t &output, const Matrix_t &weights)
static void MeanSquaredErrorGradients(Matrix_t &dY, const Matrix_t &Y, const Matrix_t &output, const Matrix_t &weights)
static Scalar_t MeanSquaredError(const Matrix_t &Y, const Matrix_t &output, const Matrix_t &weights)
static Scalar_t CrossEntropy(const Matrix_t &Y, const Matrix_t &output, const Matrix_t &weights)
Sigmoid transformation is implicitly applied, thus output should hold the linear activations of the l...
static Scalar_t SoftmaxCrossEntropy(const Matrix_t &Y, const Matrix_t &output, const Matrix_t &weights)
Softmax transformation is implicitly applied, thus output should hold the linear activations of the l...
auto Map(F func, unsigned nTimes) -> std::vector< InvokeResult_t< F > >
Wrap TExecutor::Map functions.
auto Reduce(const std::vector< T > &objs, R redfunc) -> decltype(redfunc(objs))
Wrap Reduce function.
create variable transformations
static uint64_t sum(uint64_t i)