Logo ROOT  
Reference Guide
 
Loading...
Searching...
No Matches
CladDerivator.h
Go to the documentation of this file.
1/// \file CladDerivator.h
2///
3/// \brief The file is a bridge between ROOT and clad automatic differentiation
4/// plugin.
5///
6/// \author Vassil Vassilev <vvasilev@cern.ch>
7///
8/// \date July, 2018
9
10/*************************************************************************
11 * Copyright (C) 1995-2018, Rene Brun and Fons Rademakers. *
12 * All rights reserved. *
13 * *
14 * For the licensing terms see $ROOTSYS/LICENSE. *
15 * For the list of contributors see $ROOTSYS/README/CREDITS. *
16 *************************************************************************/
17
18#ifndef CLAD_DERIVATOR
19#define CLAD_DERIVATOR
20
21#ifndef __CLING__
22#error "This file must not be included by compiled programs."
23#endif //__CLING__
24
25#include <plugins/include/clad/Differentiator/Differentiator.h>
26#include "TMath.h"
27
28#include <stdexcept>
29
30namespace clad {
31namespace custom_derivatives {
32namespace TMath {
33template <typename T>
35{
36 return {::TMath::Abs(x), ((x < 0) ? -1 : 1) * d_x};
37}
38
39template <typename T>
41{
42 return {::TMath::ACos(x), (-1. / ::TMath::Sqrt(1 - x * x)) * d_x};
43}
44
45template <typename T>
47{
48 return {::TMath::ACosH(x), (1. / ::TMath::Sqrt(x * x - 1)) * d_x};
49}
50
51template <typename T>
53{
54 return {::TMath::ASin(x), (1. / ::TMath::Sqrt(1 - x * x)) * d_x};
55}
56
57template <typename T>
59{
60 return {::TMath::ASinH(x), (1. / ::TMath::Sqrt(x * x + 1)) * d_x};
61}
62
63template <typename T>
65{
66 return {::TMath::ATan(x), (1. / (x * x + 1)) * d_x};
67}
68
69template <typename T>
71{
72 return {::TMath::ATanH(x), (1. / (1 - x * x)) * d_x};
73}
74
75template <typename T>
80
81template <typename T>
86
87template <typename T>
92
93template <typename T>
98
99template <typename T>
101{
102 return {::TMath::LnGamma(z), ::clad::custom_derivatives::std::clad_digamma(z) * d_z};
103}
104
105template <typename T>
110
111template <typename T>
116
117template <typename T, typename U>
118void Hypot_pullback(T x, T y, U p, T *d_x, T *d_y)
119{
120 T h = ::TMath::Hypot(x, y);
121 *d_x += x / h * p;
122 *d_y += y / h * p;
123}
124
125template <typename T>
127{
128 return {::TMath::Log(x), (1. / x) * d_x};
129}
130
131template <typename T>
133{
134 return {::TMath::Log10(x), (1.0 / (x * ::TMath::Ln10())) * d_x};
135}
136
137template <typename T>
139{
140 return {::TMath::Log2(x), (1.0 / (x * ::TMath::Log(2.0))) * d_x};
141}
142
143template <typename T, typename U>
145{
146 T pushforward = y * ::TMath::Power(x, y - 1) * d_x;
147 if (d_y) {
149 }
150 return {::TMath::Power(x, y), pushforward};
151}
152
/// Pullback of ::TMath::Power(x, y).
///
/// The pow pushforward is evaluated twice: with the tangent (1, 0) it yields d(x^y)/dx, and with (0, 1) it yields
/// d(x^y)/dy. Each is scaled by the incoming adjoint p and added to *d_x and *d_y respectively.
template <typename T, typename U, typename V>
void Power_pullback(T x, U y, V p, T *d_x, U *d_y)
{
   const auto wrtX = pow_pushforward(x, y, 1, 0);
   *d_x += wrtX.pushforward * p;

   const auto wrtY = pow_pushforward(x, y, 0, 1);
   *d_y += wrtY.pushforward * p;
}
161
162template <typename T>
167
168template <typename T>
173
174template <typename T>
176{
177 return {::TMath::Sq(x), 2 * x * d_x};
178}
179
180template <typename T>
185
186template <typename T>
191
192template <typename T>
197
198} // namespace TMath
199
200namespace ROOT {
201namespace Math {
202
/// Pullback (reverse-mode derivative) of ROOT::Math::landau_pdf(x, xi, x0).
///
/// The primal evaluates a piecewise rational approximation g(v) of the Landau density at v = (x - x0) / xi and
/// returns g(v) / xi. This function adds d_out times the partial derivatives of that result to *d_x, *d_xi and
/// *d_x0. Each branch below computes g(v) and dg/dv in closed form. The common tail then applies the chain rule
/// through v and through the explicit 1/xi factor.
///
/// For xi <= 0 nothing is accumulated.
inline void landau_pdf_pullback(double x, double xi, double x0, double d_out, double *d_x, double *d_xi, double *d_x0)
{
   if (xi <= 0) {
      return;
   }
   // clang-format off
   static constexpr double p1[5] = {0.4259894875,-0.1249762550, 0.03984243700, -0.006298287635, 0.001511162253};
   static constexpr double q1[5] = {1.0 ,-0.3388260629, 0.09594393323, -0.01608042283, 0.003778942063};

   static constexpr double p2[5] = {0.1788541609, 0.1173957403, 0.01488850518, -0.001394989411, 0.0001283617211};
   static constexpr double q2[5] = {1.0 , 0.7428795082, 0.3153932961, 0.06694219548, 0.008790609714};

   static constexpr double p3[5] = {0.1788544503, 0.09359161662,0.006325387654, 0.00006611667319,-0.000002031049101};
   static constexpr double q3[5] = {1.0 , 0.6097809921, 0.2560616665, 0.04746722384, 0.006957301675};

   static constexpr double p4[5] = {0.9874054407, 118.6723273, 849.2794360, -743.7792444, 427.0262186};
   static constexpr double q4[5] = {1.0 , 106.8615961, 337.6496214, 2016.712389, 1597.063511};

   static constexpr double p5[5] = {1.003675074, 167.5702434, 4789.711289, 21217.86767, -22324.94910};
   static constexpr double q5[5] = {1.0 , 156.9424537, 3745.310488, 9834.698876, 66924.28357};

   static constexpr double p6[5] = {1.000827619, 664.9143136, 62972.92665, 475554.6998, -5743609.109};
   static constexpr double q6[5] = {1.0 , 651.4101098, 56974.73333, 165917.4725, -2815759.939};

   static constexpr double a1[3] = {0.04166666667,-0.01996527778, 0.02709538966};

   static constexpr double a2[2] = {-1.845568670,-4.284640743};
   // clang-format on
   constexpr double kInvSqrt2Pi = 0.3989422803; // 1/sqrt(2*pi) to 10 digits

   // Evaluates c[0] + c[1]*t + ... + c[n-1]*t^(n-1) with Horner's scheme; stores the t-derivative in dpoly.
   auto horner = [](const double *c, int n, double t, double &dpoly) {
      double poly = c[n - 1];
      double deriv = 0.;
      for (int i = n - 2; i >= 0; --i) {
         deriv = deriv * t + poly;
         poly = poly * t + c[i];
      }
      dpoly = deriv;
      return poly;
   };
   // Evaluates P(t) / Q(t) for n-coefficient polynomials p and q; stores d(P/Q)/dt in dratio.
   auto ratio = [horner](const double *p, const double *q, int n, double t, double &dratio) {
      double dp = 0.;
      double dq = 0.;
      const double pv = horner(p, n, t, dp);
      const double qv = horner(q, n, t, dq);
      const double r = pv / qv;
      dratio = (dp - r * dq) / qv;
      return r;
   };

   const double v = (x - x0) / xi;
   double g = 0.;     // g(v), needed for the derivative of the 1/xi prefactor
   double dg_dv = 0.; // dg/dv

   if (v < -5.5) {
      const double u = ::std::exp(v + 1.); // du/dv = u
      // Below u = 1e-10 exp(-1/u) is exactly zero in double precision: value and gradient both vanish.
      if (u >= 1.e-10) {
         const double ue = ::std::exp(-1 / u);
         const double us = ::std::sqrt(u);
         const double poly = 1 + (a1[0] + (a1[1] + a1[2] * u) * u) * u;
         const double dpoly = a1[0] + (2 * a1[1] + 3 * a1[2] * u) * u;
         g = kInvSqrt2Pi * (ue / us) * poly;
         // d/du [e^(-1/u) u^(-1/2)] = e^(-1/u) u^(-1/2) (1/u^2 - 1/(2u)); multiplied by du/dv = u.
         dg_dv = kInvSqrt2Pi * (ue / us) * (poly * (1 / u - 0.5) + dpoly * u);
      }
   } else if (v < -1) {
      const double u = ::std::exp(-v - 1); // du/dv = -u
      const double pref = ::std::exp(-u) * ::std::sqrt(u);
      double dr = 0.;
      const double r = ratio(p1, q1, 5, v, dr);
      g = pref * r;
      // d/du [e^(-u) u^(1/2)] = e^(-u) u^(1/2) (-1 + 1/(2u)); multiplied by du/dv = -u.
      dg_dv = pref * (dr + r * (u - 0.5));
   } else if (v < 1) {
      g = ratio(p2, q2, 5, v, dg_dv);
   } else if (v < 5) {
      g = ratio(p3, q3, 5, v, dg_dv);
   } else if (v < 300) {
      // Three asymptotic fits in u = 1/v share the form g = u^2 P(u) / Q(u).
      const double *p = v < 12 ? p4 : (v < 50 ? p5 : p6);
      const double *q = v < 12 ? q4 : (v < 50 ? q5 : q6);
      const double u = 1 / v; // du/dv = -u^2
      double dr = 0.;
      const double r = ratio(p, q, 5, u, dr);
      g = u * u * r;
      dg_dv = -u * u * (2 * u * r + u * u * dr);
   } else {
      const double lnv = ::std::log(v);
      const double w = v - v * lnv / (v + 1);
      // d/dv [v ln(v) / (v + 1)] = (ln(v) + v + 1) / (v + 1)^2
      const double dw_dv = 1 - (lnv + v + 1) / ((v + 1) * (v + 1));
      const double u = 1 / w; // du/dv = -u^2 dw/dv
      g = u * u * (1 + (a2[0] + a2[1] * u) * u);
      const double dg_du = u * (2 + (3 * a2[0] + 4 * a2[1] * u) * u);
      dg_dv = dg_du * (-u * u * dw_dv);
   }

   // landau_pdf = g(v) / xi with v = (x - x0) / xi.
   const double d_v = d_out * dg_dv / xi;
   *d_x += d_v / xi;
   *d_x0 += -d_v / xi;
   *d_xi += d_out * -(g / (xi * xi)); // explicit 1/xi factor
   *d_xi += d_v * -((x - x0) / (xi * xi)); // dependence of v on xi
}
460
/// Pullback (reverse-mode derivative) of ROOT::Math::landau_cdf(x, xi, x0).
///
/// The primal evaluates a piecewise rational approximation G(v) of the Landau distribution function at
/// v = (x - x0) / xi. This function adds d_out times the partial derivatives of G(v) with respect to x, xi and x0
/// to *d_x, *d_xi and *d_x0. Each branch below computes dG/dv in closed form, and the common tail applies the
/// chain rule through v.
inline void landau_cdf_pullback(double x, double xi, double x0, double d_out, double *d_x, double *d_xi, double *d_x0)
{
   // clang-format off
   static constexpr double p1[5] = {0.2514091491e+0,-0.6250580444e-1, 0.1458381230e-1,-0.2108817737e-2, 0.7411247290e-3};
   static constexpr double q1[5] = {1.0 ,-0.5571175625e-2, 0.6225310236e-1,-0.3137378427e-2, 0.1931496439e-2};

   static constexpr double p2[4] = {0.2868328584e+0, 0.3564363231e+0, 0.1523518695e+0, 0.2251304883e-1};
   static constexpr double q2[4] = {1.0 , 0.6191136137e+0, 0.1720721448e+0, 0.2278594771e-1};

   static constexpr double p3[4] = {0.2868329066e+0, 0.3003828436e+0, 0.9950951941e-1, 0.8733827185e-2};
   static constexpr double q3[4] = {1.0 , 0.4237190502e+0, 0.1095631512e+0, 0.8693851567e-2};

   static constexpr double p4[4] = {0.1000351630e+1, 0.4503592498e+1, 0.1085883880e+2, 0.7536052269e+1};
   static constexpr double q4[4] = {1.0 , 0.5539969678e+1, 0.1933581111e+2, 0.2721321508e+2};

   static constexpr double p5[4] = {0.1000006517e+1, 0.4909414111e+2, 0.8505544753e+2, 0.1532153455e+3};
   static constexpr double q5[4] = {1.0 , 0.5009928881e+2, 0.1399819104e+3, 0.4200002909e+3};

   static constexpr double p6[4] = {0.1000000983e+1, 0.1329868456e+3, 0.9162149244e+3,-0.9605054274e+3};
   static constexpr double q6[4] = {1.0 , 0.1339887843e+3, 0.1055990413e+4, 0.5532224619e+3};

   static constexpr double a1[4] = {0 ,-0.4583333333e+0, 0.6675347222e+0,-0.1641741416e+1};
   static constexpr double a2[4] = {0 , 1.0 ,-0.4227843351e+0,-0.2043403138e+1};
   // clang-format on
   constexpr double kInvSqrt2Pi = 0.3989422803; // 1/sqrt(2*pi) to 10 digits

   // Evaluates c[0] + c[1]*t + ... + c[n-1]*t^(n-1) with Horner's scheme; stores the t-derivative in dpoly.
   auto horner = [](const double *c, int n, double t, double &dpoly) {
      double poly = c[n - 1];
      double deriv = 0.;
      for (int i = n - 2; i >= 0; --i) {
         deriv = deriv * t + poly;
         poly = poly * t + c[i];
      }
      dpoly = deriv;
      return poly;
   };
   // Evaluates P(t) / Q(t) for n-coefficient polynomials p and q; stores d(P/Q)/dt in dratio.
   auto ratio = [horner](const double *p, const double *q, int n, double t, double &dratio) {
      double dp = 0.;
      double dq = 0.;
      const double pv = horner(p, n, t, dp);
      const double qv = horner(q, n, t, dq);
      const double r = pv / qv;
      dratio = (dp - r * dq) / qv;
      return r;
   };

   const double v = (x - x0) / xi;
   double dG_dv = 0.;

   if (v < -5.5) {
      const double u = ::std::exp(v + 1); // du/dv = u
      // Below u = 1e-10 exp(-1/u) is exactly zero in double precision, so the derivative vanishes. Skipping the
      // evaluation also avoids the 0 * inf = NaN the unguarded formula produces for very negative v, where the
      // 1/u^2 factor overflows.
      if (u >= 1.e-10) {
         const double ue = ::std::exp(-1. / u);
         const double us = ::std::sqrt(u);
         const double poly = 1 + (a1[1] + (a1[2] + a1[3] * u) * u) * u;
         const double dpoly = a1[1] + (2 * a1[2] + 3 * a1[3] * u) * u;
         // d/du [e^(-1/u) u^(1/2)] = e^(-1/u) u^(1/2) (1/u^2 + 1/(2u)); multiplied by du/dv = u.
         dG_dv = kInvSqrt2Pi * ue * us * (poly * (1 / u + 0.5) + dpoly * u);
      }
   } else if (v < -1) {
      const double u = ::std::exp(-v - 1); // du/dv = -u
      const double pref = ::std::exp(-u) / ::std::sqrt(u);
      double dr = 0.;
      const double r = ratio(p1, q1, 5, v, dr);
      // d/du [e^(-u) u^(-1/2)] = e^(-u) u^(-1/2) (-1 - 1/(2u)); multiplied by du/dv = -u.
      dG_dv = pref * (dr + r * (u + 0.5));
   } else if (v < 1) {
      ratio(p2, q2, 4, v, dG_dv); // only the derivative is needed
   } else if (v < 4) {
      ratio(p3, q3, 4, v, dG_dv);
   } else if (v < 300) {
      // Three asymptotic fits in u = 1/v share the form G = P(u) / Q(u).
      const double *p = v < 12 ? p4 : (v < 50 ? p5 : p6);
      const double *q = v < 12 ? q4 : (v < 50 ? q5 : q6);
      const double u = 1. / v; // du/dv = -u^2
      double dr = 0.;
      ratio(p, q, 4, u, dr);
      dG_dv = -u * u * dr;
   } else {
      const double lnv = ::std::log(v);
      const double w = v - v * lnv / (v + 1);
      // d/dv [v ln(v) / (v + 1)] = (ln(v) + v + 1) / (v + 1)^2
      const double dw_dv = 1 - (lnv + v + 1) / ((v + 1) * (v + 1));
      const double u = 1. / w; // du/dv = -u^2 dw/dv
      const double dG_du = -(a2[1] + (2 * a2[2] + 3 * a2[3] * u) * u);
      dG_dv = dG_du * (-u * u * dw_dv);
   }

   // landau_cdf = G(v) with v = (x - x0) / xi.
   const double d_v = d_out * dG_dv;
   *d_x += d_v / xi;
   *d_x0 += -d_v / xi;
   *d_xi += d_v * -((x - x0) / (xi * xi));
}
610
611inline void inc_gamma_c_pullback(double a, double x, double _d_y, double *_d_a, double *_d_x);
612
613inline void inc_gamma_pullback(double a, double x, double _d_y, double *_d_a, double *_d_x)
614{
615 // Synced with SpecFuncCephes.h
616 constexpr double kMACHEP = 1.11022302462515654042363166809e-16;
617 constexpr double kMAXLOG = 709.782712893383973096206318587;
618 constexpr double kMINLOG = -708.396418532264078748994506896;
619 constexpr double kMAXSTIR = 108.116855767857671821730036754;
620 constexpr double kMAXLGM = 2.556348e305;
621 constexpr double kBig = 4.503599627370496e15;
622 constexpr double kBiginv = 2.22044604925031308085e-16;
623
624 double _d_ans = 0, _d_ax = 0, _d_c = 0, _d_r = 0;
625 double _t1;
626 double _t2;
627 double _t3;
628 double _t4;
629 double _t5;
630 clad::tape<double> _t7 = {};
631 clad::tape<double> _t8 = {};
632 clad::tape<double> _t9 = {};
633 double ans, ax, c, r;
634 if (a <= 0)
635 return;
636 if (x <= 0)
637 return;
638 if ((x > 1.) && (x > a)) {
639 double _r0 = 0;
640 double _r1 = 0;
642 *_d_a += _r0;
643 *_d_x += _r1;
644 return;
645 }
646 _t1 = ::std::log(x);
647 ax = a * _t1 - x - ::std::lgamma(a);
648 if (ax < -kMAXLOG) {
649 *_d_x += (a * _d_ax / x) - _d_ax;
650 *_d_a += _d_ax * (_t1 - ::clad::custom_derivatives::std::clad_digamma(
651 a)); // numerical_diff::forward_central_difference(::std::lgamma, a, 0, 0, a);
652 _d_ax = 0.;
653 return;
654 }
655 _t2 = ax;
656 ax = ::std::exp(ax);
657 _t3 = r;
658 r = a;
659 _t4 = c;
660 c = 1.;
661 _t5 = ans;
662 ans = 1.;
663 unsigned long _t6 = 0;
664 do {
665 _t6++;
666 clad::push(_t7, r);
667 r += 1.;
668 clad::push(_t8, c);
669 c *= x / r;
670 clad::push(_t9, ans);
671 ans += c;
672 } while (c / ans > kMACHEP);
673 {
674 _d_ans += _d_y / a * ax;
675 _d_ax += ans * _d_y / a;
676 double _r6 = _d_y * -(ans * ax / (a * a));
677 *_d_a += _r6;
678 }
679 do {
680 {
681 {
682 ans = clad::pop(_t9);
683 double _r_d7 = _d_ans;
684 _d_c += _r_d7;
685 }
686 {
687 c = clad::pop(_t8);
688 double _r_d6 = _d_c;
689 _d_c -= _r_d6;
690 _d_c += _r_d6 * x / r;
691 *_d_x += c * _r_d6 / r;
692 double _r5 = c * _r_d6 * -(x / (r * r));
693 _d_r += _r5;
694 }
695 {
696 r = clad::pop(_t7);
697 double _r_d5 = _d_r;
698 }
699 }
700 _t6--;
701 } while (_t6);
702 {
703 ans = _t5;
704 double _r_d4 = _d_ans;
705 _d_ans -= _r_d4;
706 }
707 {
708 c = _t4;
709 double _r_d3 = _d_c;
710 _d_c -= _r_d3;
711 }
712 {
713 r = _t3;
714 double _r_d2 = _d_r;
715 _d_r -= _r_d2;
716 *_d_a += _r_d2;
717 }
718 {
719 ax = _t2;
720 double _r_d1 = _d_ax;
721 _d_ax -= _r_d1;
722 double _r4 = 0;
723 _r4 += _r_d1 * ::std::exp(ax);
724 _d_ax += _r4;
725 }
726 {
727 *_d_x += (a * _d_ax / x) - _d_ax;
728 *_d_a += _d_ax * (_t1 - ::clad::custom_derivatives::std::clad_digamma(
729 a)); // numerical_diff::forward_central_difference(::std::lgamma, a, 0, 0, a);
730 _d_ax = 0.;
731 }
732}
733
734inline void inc_gamma_c_pullback(double a, double x, double _d_y, double *_d_a, double *_d_x)
735{
736 // Synced with SpecFuncCephes.h
737 constexpr double kMACHEP = 1.11022302462515654042363166809e-16;
738 constexpr double kMAXLOG = 709.782712893383973096206318587;
739 constexpr double kMINLOG = -708.396418532264078748994506896;
740 constexpr double kMAXSTIR = 108.116855767857671821730036754;
741 constexpr double kMAXLGM = 2.556348e305;
742 constexpr double kBig = 4.503599627370496e15;
743 constexpr double kBiginv = 2.22044604925031308085e-16;
744
745 double _d_ans = 0, _d_ax = 0, _d_c = 0, _d_yc = 0, _d_r = 0, _d_t = 0, _d_y0 = 0, _d_z = 0;
746 double _d_pk = 0, _d_pkm1 = 0, _d_pkm2 = 0, _d_qk = 0, _d_qkm1 = 0, _d_qkm2 = 0;
747 double _t1;
748 double _t2;
749 double _t3;
750 double _t4;
751 double _t5;
752 double _t6;
753 double _t7;
754 double _t8;
755 double _t9;
756 double _t10;
757 unsigned long _t11;
758 clad::tape<double> _t12 = {};
759 clad::tape<double> _t13 = {};
760 clad::tape<double> _t14 = {};
761 clad::tape<double> _t15 = {};
762 clad::tape<double> _t16 = {};
763 clad::tape<double> _t17 = {};
764 clad::tape<double> _t19 = {};
765 clad::tape<double> _t20 = {};
766 clad::tape<double> _t21 = {};
767 clad::tape<double> _t22 = {};
768 clad::tape<double> _t23 = {};
769 clad::tape<double> _t24 = {};
770 clad::tape<double> _t25 = {};
771 clad::tape<double> _t26 = {};
772 clad::tape<double> _t27 = {};
773 clad::tape<bool> _t29 = {};
774 clad::tape<double> _t30 = {};
775 clad::tape<double> _t31 = {};
776 clad::tape<double> _t32 = {};
777 clad::tape<double> _t33 = {};
778 double ans, ax, c, yc, r, t, y, z;
779 double pk, pkm1, pkm2, qk, qkm1, qkm2;
780 if (a <= 0)
781 return;
782 if (x <= 0)
783 return;
784 if ((x < 1.) || (x < a)) {
785 double _r0 = 0;
786 double _r1 = 0;
788 *_d_a += _r0;
789 *_d_x += _r1;
790 return;
791 }
792 _t1 = ::std::log(x);
793 ax = a * _t1 - x - ::std::lgamma(a);
794 if (ax < -kMAXLOG) {
795 *_d_x += a * _d_ax / x - _d_ax;
796 *_d_a += _d_ax * (_t1 - ::clad::custom_derivatives::std::clad_digamma(
797 a)); // numerical_diff::forward_central_difference(::std::lgamma, a, 0, 0, a);
798 _d_ax = 0.;
799 return;
800 }
801 _t2 = ax;
802 ax = ::std::exp(ax);
803 _t3 = y;
804 y = 1. - a;
805 _t4 = z;
806 z = x + y + 1.;
807 _t5 = c;
808 c = 0.;
809 _t6 = pkm2;
810 pkm2 = 1.;
811 _t7 = qkm2;
812 qkm2 = x;
813 _t8 = pkm1;
814 pkm1 = x + 1.;
815 _t9 = qkm1;
816 qkm1 = z * x;
817 _t10 = ans;
818 ans = pkm1 / qkm1;
819 _t11 = 0;
820 do {
821 _t11++;
822 clad::push(_t12, c);
823 c += 1.;
824 clad::push(_t13, y);
825 y += 1.;
826 clad::push(_t14, z);
827 z += 2.;
828 clad::push(_t15, yc);
829 yc = y * c;
830 clad::push(_t16, pk);
831 pk = pkm1 * z - pkm2 * yc;
832 clad::push(_t17, qk);
833 qk = qkm1 * z - qkm2 * yc;
834 double _t18 = qk;
835 {
836 if (_t18) {
837 clad::push(_t20, r);
838 r = pk / qk;
839 clad::push(_t21, t);
840 t = ::std::abs((ans - r) / r);
841 clad::push(_t22, ans);
842 ans = r;
843 } else {
844 clad::push(_t23, t);
845 t = 1.;
846 }
847 clad::push(_t19, _t18);
848 }
849 clad::push(_t24, pkm2);
850 pkm2 = pkm1;
851 clad::push(_t25, pkm1);
852 pkm1 = pk;
853 clad::push(_t26, qkm2);
854 qkm2 = qkm1;
855 clad::push(_t27, qkm1);
856 qkm1 = qk;
857 bool _t28 = ::std::abs(pk) > kBig;
858 {
859 if (_t28) {
860 clad::push(_t30, pkm2);
861 pkm2 *= kBiginv;
862 clad::push(_t31, pkm1);
863 pkm1 *= kBiginv;
864 clad::push(_t32, qkm2);
865 qkm2 *= kBiginv;
866 clad::push(_t33, qkm1);
867 qkm1 *= kBiginv;
868 }
869 clad::push(_t29, _t28);
870 }
871 } while (t > kMACHEP);
872 {
873 _d_ans += _d_y * ax;
874 _d_ax += ans * _d_y;
875 }
876 do {
877 {
878 if (clad::pop(_t29)) {
879 {
880 qkm1 = clad::pop(_t33);
881 double _r_d27 = _d_qkm1;
882 _d_qkm1 -= _r_d27;
883 _d_qkm1 += _r_d27 * kBiginv;
884 }
885 {
886 qkm2 = clad::pop(_t32);
887 double _r_d26 = _d_qkm2;
888 _d_qkm2 -= _r_d26;
889 _d_qkm2 += _r_d26 * kBiginv;
890 }
891 {
892 pkm1 = clad::pop(_t31);
893 double _r_d25 = _d_pkm1;
894 _d_pkm1 -= _r_d25;
895 _d_pkm1 += _r_d25 * kBiginv;
896 }
897 {
898 pkm2 = clad::pop(_t30);
899 double _r_d24 = _d_pkm2;
900 _d_pkm2 -= _r_d24;
901 _d_pkm2 += _r_d24 * kBiginv;
902 }
903 }
904 {
905 qkm1 = clad::pop(_t27);
906 double _r_d23 = _d_qkm1;
907 _d_qkm1 -= _r_d23;
908 _d_qk += _r_d23;
909 }
910 {
911 qkm2 = clad::pop(_t26);
912 double _r_d22 = _d_qkm2;
913 _d_qkm2 -= _r_d22;
914 _d_qkm1 += _r_d22;
915 }
916 {
917 pkm1 = clad::pop(_t25);
918 double _r_d21 = _d_pkm1;
919 _d_pkm1 -= _r_d21;
920 _d_pk += _r_d21;
921 }
922 {
923 pkm2 = clad::pop(_t24);
924 double _r_d20 = _d_pkm2;
925 _d_pkm2 -= _r_d20;
926 _d_pkm1 += _r_d20;
927 }
928 if (clad::pop(_t19)) {
929 {
930 ans = clad::pop(_t22);
931 double _r_d18 = _d_ans;
932 _d_ans -= _r_d18;
933 _d_r += _r_d18;
934 }
935 {
936 t = clad::pop(_t21);
937 double _r_d17 = _d_t;
938 _d_t -= _r_d17;
939 double _r7 = 0;
940 _r7 += _r_d17 * clad::custom_derivatives::std::abs_pushforward((ans - r) / r, 1.).pushforward;
941 _d_ans += _r7 / r;
942 _d_r += -_r7 / r;
943 double _r8 = _r7 * -((ans - r) / (r * r));
944 _d_r += _r8;
945 }
946 {
947 r = clad::pop(_t20);
948 double _r_d16 = _d_r;
949 _d_r -= _r_d16;
950 _d_pk += _r_d16 / qk;
951 double _r6 = _r_d16 * -(pk / (qk * qk));
952 _d_qk += _r6;
953 }
954 } else {
955 t = clad::pop(_t23);
956 double _r_d19 = _d_t;
957 _d_t -= _r_d19;
958 }
959 {
960 qk = clad::pop(_t17);
961 double _r_d15 = _d_qk;
962 _d_qk -= _r_d15;
963 _d_qkm1 += _r_d15 * z;
964 _d_z += qkm1 * _r_d15;
965 _d_qkm2 += -_r_d15 * yc;
966 _d_yc += qkm2 * -_r_d15;
967 }
968 {
969 pk = clad::pop(_t16);
970 double _r_d14 = _d_pk;
971 _d_pk -= _r_d14;
972 _d_pkm1 += _r_d14 * z;
973 _d_z += pkm1 * _r_d14;
974 _d_pkm2 += -_r_d14 * yc;
975 _d_yc += pkm2 * -_r_d14;
976 }
977 {
978 yc = clad::pop(_t15);
979 double _r_d13 = _d_yc;
980 _d_yc -= _r_d13;
981 _d_y0 += _r_d13 * c;
982 _d_c += y * _r_d13;
983 }
984 {
985 z = clad::pop(_t14);
986 double _r_d12 = _d_z;
987 }
988 {
989 y = clad::pop(_t13);
990 double _r_d11 = _d_y0;
991 }
992 {
993 c = clad::pop(_t12);
994 double _r_d10 = _d_c;
995 }
996 }
997 _t11--;
998 } while (_t11);
999 {
1000 ans = _t10;
1001 double _r_d9 = _d_ans;
1002 _d_ans -= _r_d9;
1003 _d_pkm1 += _r_d9 / qkm1;
1004 double _r5 = _r_d9 * -(pkm1 / (qkm1 * qkm1));
1005 _d_qkm1 += _r5;
1006 }
1007 {
1008 qkm1 = _t9;
1009 double _r_d8 = _d_qkm1;
1010 _d_qkm1 -= _r_d8;
1011 _d_z += _r_d8 * x;
1012 *_d_x += z * _r_d8;
1013 }
1014 {
1015 pkm1 = _t8;
1016 double _r_d7 = _d_pkm1;
1017 _d_pkm1 -= _r_d7;
1018 *_d_x += _r_d7;
1019 }
1020 {
1021 qkm2 = _t7;
1022 double _r_d6 = _d_qkm2;
1023 _d_qkm2 -= _r_d6;
1024 *_d_x += _r_d6;
1025 }
1026 {
1027 pkm2 = _t6;
1028 double _r_d5 = _d_pkm2;
1029 _d_pkm2 -= _r_d5;
1030 }
1031 {
1032 c = _t5;
1033 double _r_d4 = _d_c;
1034 _d_c -= _r_d4;
1035 }
1036 {
1037 z = _t4;
1038 double _r_d3 = _d_z;
1039 _d_z -= _r_d3;
1040 *_d_x += _r_d3;
1041 _d_y0 += _r_d3;
1042 }
1043 {
1044 y = _t3;
1045 double _r_d2 = _d_y0;
1046 _d_y0 -= _r_d2;
1047 *_d_a += -_r_d2;
1048 }
1049 {
1050 ax = _t2;
1051 double _r_d1 = _d_ax;
1052 _d_ax -= _r_d1;
1053 double _r4 = _r_d1 * ::std::exp(ax);
1054 _d_ax += _r4;
1055 }
1056 {
1057 *_d_x += a * _d_ax / x - _d_ax;
1058 *_d_a += _d_ax * (_t1 - ::clad::custom_derivatives::std::clad_digamma(
1059 a)); // numerical_diff::forward_central_difference(::std::lgamma, a, 0, 0, a);
1060 _d_ax = 0.;
1061 }
1062}
1063
1064} // namespace Math
1065} // namespace ROOT
1066
1067} // namespace custom_derivatives
1068} // namespace clad
1069
1070// Forward declare BLAS functions.
1071extern "C" void sgemm_(const char *transa, const char *transb, const int *m, const int *n, const int *k,
1072 const float *alpha, const float *A, const int *lda, const float *B, const int *ldb,
1073 const float *beta, float *C, const int *ldc);
1074
1075namespace clad::custom_derivatives {
1076
1078
1079inline void Gemm_Call_pullback(float *output, bool transa, bool transb, int m, int n, int k, float alpha,
1080 const float *A, const float *B, float beta, const float *C, float *_d_output, bool *,
1081 bool *, int *, int *, int *, float *_d_alpha, float *_d_A, float *_d_B, float *_d_beta,
1082 float *_d_C)
1083{
1084 using ::TMVA::Experimental::SOFIE::Gemm_Call;
1085
1086 // TODO:
1087 // - fix and test the implementation for alpha != 1.0
1088 if (alpha != 1.0f) {
1089 return;
1090 }
1091
1092 // beta needs to be one because we want to add to _d_A and _d_B instead of
1093 // overwriting it.
1094 float one = 1.;
1095
1096 // ---- dA ----
1097 if (!transa) {
1098 // dA += dY * op(B)^T
1099 Gemm_Call(_d_A, false, !transb, m, k, n, one, _d_output, B, one, _d_A);
1100 } else {
1101 // dA += op(B) * dY^T
1102 Gemm_Call(_d_A, transb, true, k, m, n, one, B, _d_output, one, _d_A);
1103 }
1104
1105 // ---- dB ----
1106 if (!transb) {
1107 // dB += op(A)^T * dY
1108 Gemm_Call(_d_B, !transa, false, k, n, m, one, A, _d_output, one, _d_B);
1109 } else {
1110 // dB += dY^T * op(A)
1111 Gemm_Call(_d_B, true, transa, n, k, m, one, _d_output, A, one, _d_B);
1112 }
1113
1114 int sizeC = n * m;
1115
1116 for (int i = 0; i < sizeC; ++i) {
1117 if (C) {
1118 *_d_alpha += _d_output[i] * (output[i] - beta * C[i]);
1119 *_d_beta += _d_output[i] * C[i];
1120 } else {
1121 *_d_alpha += _d_output[i] * output[i];
1122 }
1123 if (_d_C)
1124 _d_C[i] += _d_output[i] * beta;
1125 }
1126}
1127
/// Pullback of an element-wise `Copy(output, input, size)` helper (presumably the SOFIE one; confirm).
///
/// Replays the forward copy into `output`, then moves each output adjoint onto the matching input adjoint
/// and clears the output adjoint. The trailing `int *` is the unused adjoint slot for `size`.
inline void Copy_pullback(float *output, const float *input, int size, float *_d_output, float *_d_input, int *)
{
   int idx = 0;
   while (idx < size) {
      output[idx] = input[idx];
      _d_input[idx] += _d_output[idx];
      _d_output[idx] = 0.F;
      ++idx;
   }
}
1136
/// Pullback of an element-wise `Fill(output, value, size)` helper (presumably the SOFIE one; confirm).
///
/// Replays the forward fill of `output` with `value`. It adds every output adjoint into *_d_value, in ascending
/// index order, and clears each output adjoint after use. The trailing `int *` is the unused adjoint slot for
/// `size`.
inline void Fill_pullback(float *output, float value, int size, float *_d_output, float *_d_value, int *)
{
   int idx = 0;
   while (idx < size) {
      output[idx] = value;
      *_d_value += _d_output[idx];
      _d_output[idx] = 0.F;
      ++idx;
   }
}
1145
/// Pullback of an element-wise `Relu(output, input, size)` helper (presumably the SOFIE one; confirm).
///
/// Replays the forward ReLU into `output` and clears each output adjoint. The adjoint is forwarded to the input
/// adjoint only where input > 0; at input == 0 the gradient is taken as zero. The trailing `int *` is the unused
/// adjoint slot for `size`.
inline void Relu_pullback(float *output, const float *input, int size, float *_d_output, float *_d_input, int *)
{
   for (int idx = 0; idx < size; ++idx) {
      const bool active = input[idx] > 0.F;
      output[idx] = active ? input[idx] : 0.F;

      const float grad = _d_output[idx];
      _d_output[idx] = 0.F;
      if (active)
         _d_input[idx] += grad;
   }
}
1156
1157} // namespace TMVA::Experimental::SOFIE
1158
1159} // namespace clad::custom_derivatives
1160
1161#endif // CLAD_DERIVATOR
void sgemm_(const char *transa, const char *transb, const int *m, const int *n, const int *k, const float *alpha, const float *A, const int *lda, const float *B, const int *ldb, const float *beta, float *C, const int *ldc)
#define c(i)
Definition RSha256.hxx:101
#define a(i)
Definition RSha256.hxx:99
#define h(i)
Definition RSha256.hxx:106
size_t size(const MatrixT &matrix)
retrieve the size of a square matrix
#define kMACHEP
#define kMAXLOG
#define kMAXLGM
#define kMAXSTIR
#define kMINLOG
ROOT::Detail::TRangeCast< T, true > TRangeDynCast
TRangeDynCast is an adapter class that allows the typed iteration through a TCollection.
winID h TVirtualViewer3D TVirtualGLPainter p
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void input
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void char Point_t Rectangle_t WindowAttributes_t Float_t r
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void value
Double_t y[n]
Definition legend1.C:17
Double_t x[n]
Definition legend1.C:17
const Int_t n
Definition legend1.C:16
Namespace for new Math classes and functions.
TMath.
Definition TMathBase.h:35
Double_t CosH(Double_t)
Returns the hyperbolic cosine of x.
Definition TMath.h:623
Double_t ACos(Double_t)
Returns the principal value of the arc cosine of x, expressed in radians.
Definition TMath.h:643
Double_t ASin(Double_t)
Returns the principal value of the arc sine of x, expressed in radians.
Definition TMath.h:635
Double_t Log2(Double_t x)
Returns the binary (base-2) logarithm of x.
Definition TMath.cxx:107
Double_t Exp(Double_t x)
Returns the base-e exponential function of x, which is e raised to the power x.
Definition TMath.h:720
Double_t Erf(Double_t x)
Computation of the error function erf(x).
Definition TMath.cxx:190
Double_t ATan(Double_t)
Returns the principal value of the arc tangent of x, expressed in radians.
Definition TMath.h:651
Double_t ASinH(Double_t)
Returns the area hyperbolic sine of x.
Definition TMath.cxx:67
Double_t TanH(Double_t)
Returns the hyperbolic tangent of x.
Definition TMath.h:629
Double_t ACosH(Double_t)
Returns the nonnegative area hyperbolic cosine of x.
Definition TMath.cxx:81
Double_t Log(Double_t x)
Returns the natural logarithm of x.
Definition TMath.h:767
Double_t Erfc(Double_t x)
Computes the complementary error function erfc(x).
Definition TMath.cxx:199
Double_t Sq(Double_t x)
Returns x*x.
Definition TMath.h:667
Double_t Sqrt(Double_t x)
Returns the square root of x.
Definition TMath.h:673
LongDouble_t Power(LongDouble_t x, LongDouble_t y)
Returns x raised to the power y.
Definition TMath.h:732
constexpr Double_t Ln10()
Natural log of 10 (to convert log to ln)
Definition TMath.h:103
Double_t Hypot(Double_t x, Double_t y)
Returns sqrt(x*x + y*y)
Definition TMath.cxx:59
Double_t Cos(Double_t)
Returns the cosine of an angle of x radians.
Definition TMath.h:605
constexpr Double_t Pi()
Definition TMath.h:40
Double_t LnGamma(Double_t z)
Computation of ln[gamma(z)] for all z.
Definition TMath.cxx:509
Double_t Sin(Double_t)
Returns the sine of an angle of x radians.
Definition TMath.h:599
Double_t Tan(Double_t)
Returns the tangent of an angle of x radians.
Definition TMath.h:611
Double_t ATanH(Double_t)
Returns the area hyperbolic tangent of x.
Definition TMath.cxx:95
Double_t Log10(Double_t x)
Returns the common (base-10) logarithm of x.
Definition TMath.h:773
Short_t Abs(Short_t d)
Returns the absolute value of parameter Short_t d.
Definition TMathBase.h:122
Double_t SinH(Double_t)
Returns the hyperbolic sine of x.
Definition TMath.h:617
void landau_pdf_pullback(double x, double xi, double x0, double d_out, double *d_x, double *d_xi, double *d_x0)
void inc_gamma_c_pullback(double a, double x, double _d_y, double *_d_a, double *_d_x)
void landau_cdf_pullback(double x, double xi, double x0, double d_out, double *d_x, double *d_xi, double *d_x0)
void inc_gamma_pullback(double a, double x, double _d_y, double *_d_a, double *_d_x)
void Relu_pullback(float *output, const float *input, int size, float *_d_output, float *_d_input, int *)
void Fill_pullback(float *output, float value, int size, float *_d_output, float *_d_value, int *)
void Copy_pullback(float *output, const float *input, int size, float *_d_output, float *_d_input, int *)
void Gemm_Call_pullback(float *output, bool transa, bool transb, int m, int n, int k, float alpha, const float *A, const float *B, float beta, const float *C, float *_d_output, bool *, bool *, int *, int *, int *, float *_d_alpha, float *_d_A, float *_d_B, float *_d_beta, float *_d_C)
ValueAndPushforward< T, T > CosH_pushforward(T x, T d_x)
ValueAndPushforward< T, T > Abs_pushforward(T x, T d_x)
void Power_pullback(T x, U y, V p, T *d_x, U *d_y)
ValueAndPushforward< T, T > Sq_pushforward(T x, T d_x)
void Hypot_pullback(T x, T y, U p, T *d_x, T *d_y)
ValueAndPushforward< T, T > Erf_pushforward(T x, T d_x)
ValueAndPushforward< T, T > Erfc_pushforward(T x, T d_x)
ValueAndPushforward< T, T > Sin_pushforward(T x, T d_x)
ValueAndPushforward< T, T > Hypot_pushforward(T x, T y, T d_x, T d_y)
ValueAndPushforward< T, T > ASinH_pushforward(T x, T d_x)
ValueAndPushforward< T, T > LnGamma_pushforward(T z, T d_z)
ValueAndPushforward< T, T > ACosH_pushforward(T x, T d_x)
ValueAndPushforward< T, T > ASin_pushforward(T x, T d_x)
ValueAndPushforward< T, T > Power_pushforward(T x, U y, T d_x, U d_y)
ValueAndPushforward< T, T > Cos_pushforward(T x, T d_x)
ValueAndPushforward< T, T > Sqrt_pushforward(T x, T d_x)
ValueAndPushforward< T, T > Tan_pushforward(T x, T d_x)
ValueAndPushforward< T, T > Log_pushforward(T x, T d_x)
ValueAndPushforward< T, T > Log10_pushforward(T x, T d_x)
ValueAndPushforward< T, T > TanH_pushforward(T x, T d_x)
ValueAndPushforward< T, T > ACos_pushforward(T x, T d_x)
ValueAndPushforward< T, T > SinH_pushforward(T x, T d_x)
ValueAndPushforward< T, T > Exp_pushforward(T x, T d_x)
ValueAndPushforward< T, T > Log2_pushforward(T x, T d_x)
ValueAndPushforward< T, T > ATanH_pushforward(T x, T d_x)
ValueAndPushforward< T, T > ATan_pushforward(T x, T d_x)
TMarker m
Definition textangle.C:8
static void output()