Logo ROOT  
Reference Guide
 
Loading...
Searching...
No Matches
CladDerivator.h
Go to the documentation of this file.
1/// \file CladDerivator.h
2///
3/// \brief The file is a bridge between ROOT and clad automatic differentiation
4/// plugin.
5///
6/// \author Vassil Vassilev <vvasilev@cern.ch>
7///
8/// \date July, 2018
9
10/*************************************************************************
11 * Copyright (C) 1995-2018, Rene Brun and Fons Rademakers. *
12 * All rights reserved. *
13 * *
14 * For the licensing terms see $ROOTSYS/LICENSE. *
15 * For the list of contributors see $ROOTSYS/README/CREDITS. *
16 *************************************************************************/
17
18#ifndef CLAD_DERIVATOR
19#define CLAD_DERIVATOR
20
21#ifndef __CLING__
22#error "This file must not be included by compiled programs."
23#endif //__CLING__
24
25#include <plugins/include/clad/Differentiator/Differentiator.h>
26#include "TMath.h"
27
28// For the digamma function, that is the derivative of lgamma. We get it via
29// mathmore from the GSL, so the pullbacks that use digamma are only available
30// with mathmore=ON.
31#ifdef R__HAS_MATHMORE
33#endif
34
#include <stdexcept>
#include <vector>
36
37namespace clad {
38namespace custom_derivatives {
39namespace TMath {
40template <typename T>
42{
43 return {::TMath::Abs(x), ((x < 0) ? -1 : 1) * d_x};
44}
45
46template <typename T>
48{
49 return {::TMath::ACos(x), (-1. / ::TMath::Sqrt(1 - x * x)) * d_x};
50}
51
52template <typename T>
54{
55 return {::TMath::ACosH(x), (1. / ::TMath::Sqrt(x * x - 1)) * d_x};
56}
57
58template <typename T>
60{
61 return {::TMath::ASin(x), (1. / ::TMath::Sqrt(1 - x * x)) * d_x};
62}
63
64template <typename T>
66{
67 return {::TMath::ASinH(x), (1. / ::TMath::Sqrt(x * x + 1)) * d_x};
68}
69
70template <typename T>
72{
73 return {::TMath::ATan(x), (1. / (x * x + 1)) * d_x};
74}
75
76template <typename T>
78{
79 return {::TMath::ATanH(x), (1. / (1 - x * x)) * d_x};
80}
81
82template <typename T>
87
88template <typename T>
93
94template <typename T>
99
100template <typename T>
102{
103 return {::TMath::Erfc(x), -Erf_pushforward(x, d_x).pushforward};
104}
105
106#ifdef R__HAS_MATHMORE
107
108template <typename T>
110{
112}
113
114#endif
115
116template <typename T>
121
122template <typename T>
127
128template <typename T, typename U>
129void Hypot_pullback(T x, T y, U p, clad::array_ref<T> d_x, clad::array_ref<T> d_y)
130{
131 T h = ::TMath::Hypot(x, y);
132 *d_x += x / h * p;
133 *d_y += y / h * p;
134}
135
136template <typename T>
138{
139 return {::TMath::Log(x), (1. / x) * d_x};
140}
141
142template <typename T>
144{
145 return {::TMath::Log10(x), (1.0 / (x * ::TMath::Ln10())) * d_x};
146}
147
148template <typename T>
150{
151 return {::TMath::Log2(x), (1.0 / (x * ::TMath::Log(2.0))) * d_x};
152}
153
154template <typename T>
156{
157 T pushforward = y * ::TMath::Power(x, y - 1) * d_x;
158 if (d_y) {
160 }
161 return {::TMath::Power(x, y), pushforward};
162}
163
164template <typename T, typename U>
165void Power_pullback(T x, T y, U p, clad::array_ref<T> d_x, clad::array_ref<T> d_y)
166{
167 auto t = pow_pushforward(x, y, 1, 0);
168 *d_x += t.pushforward * p;
169 t = pow_pushforward(x, y, 0, 1);
170 *d_y += t.pushforward * p;
171}
172
173template <typename T>
178
179template <typename T>
184
185template <typename T>
187{
188 return {::TMath::Sq(x), 2 * x * d_x};
189}
190
191template <typename T>
196
197template <typename T>
202
203template <typename T>
208
209#ifdef WIN32
210// Additional custom derivatives that can be removed
211// after Issue #12108 in ROOT is resolved
212// constexpr is removed
214{
215 return {3.1415926535897931, 0.};
216}
217// constexpr is removed
219{
220 return {2.3025850929940459, 0.};
221}
222#endif
223} // namespace TMath
224
225namespace ROOT {
226namespace Math {
227
/// Pullback (reverse-mode derivative) of ROOT::Math::landau_pdf(x, xi, x0).
///
/// The derivative formulas correspond to a primal that evaluates a piecewise
/// approximation DENLAN(v) at v = (x - x0) / xi and returns DENLAN(v) / xi,
/// returning 0 when xi <= 0 or in the far lower tail (exp(v + 1) < 1e-10).
/// This pullback recomputes DENLAN(v) and dDENLAN/dv on the same branch and
/// propagates the adjoint through the chain rule.
///
/// \param x, xi, x0        Arguments of the primal call.
/// \param d_out            Adjoint of the primal result.
/// \param d_x, d_xi, d_x0  Adjoint accumulators; incremented, never overwritten.
inline void landau_pdf_pullback(double x, double xi, double x0, double d_out, double *d_x, double *d_xi, double *d_x0)
{
   // The primal is the constant 0 for a non-positive width: nothing to propagate.
   if (xi <= 0) {
      return;
   }
   // Coefficients of the piecewise approximation (identical to the primal).
   // clang-format off
   static constexpr double p1[5] = {0.4259894875,-0.1249762550, 0.03984243700, -0.006298287635, 0.001511162253};
   static constexpr double q1[5] = {1.0 ,-0.3388260629, 0.09594393323, -0.01608042283, 0.003778942063};

   static constexpr double p2[5] = {0.1788541609, 0.1173957403, 0.01488850518, -0.001394989411, 0.0001283617211};
   static constexpr double q2[5] = {1.0 , 0.7428795082, 0.3153932961, 0.06694219548, 0.008790609714};

   static constexpr double p3[5] = {0.1788544503, 0.09359161662,0.006325387654, 0.00006611667319,-0.000002031049101};
   static constexpr double q3[5] = {1.0 , 0.6097809921, 0.2560616665, 0.04746722384, 0.006957301675};

   static constexpr double p4[5] = {0.9874054407, 118.6723273, 849.2794360, -743.7792444, 427.0262186};
   static constexpr double q4[5] = {1.0 , 106.8615961, 337.6496214, 2016.712389, 1597.063511};

   static constexpr double p5[5] = {1.003675074, 167.5702434, 4789.711289, 21217.86767, -22324.94910};
   static constexpr double q5[5] = {1.0 , 156.9424537, 3745.310488, 9834.698876, 66924.28357};

   static constexpr double p6[5] = {1.000827619, 664.9143136, 62972.92665, 475554.6998, -5743609.109};
   static constexpr double q6[5] = {1.0 , 651.4101098, 56974.73333, 165917.4725, -2815759.939};

   static constexpr double a1[3] = {0.04166666667,-0.01996527778, 0.02709538966};

   static constexpr double a2[2] = {-1.845568670,-4.284640743};
   // clang-format on
   // Prefactor of the lower-tail branch (same constant as the primal).
   constexpr double c0 = 0.3989422803;

   // Returns c[0] + c[1]*t + ... + c[n-1]*t^(n-1) via Horner's scheme and
   // stores the derivative with respect to t in `deriv`.
   auto poly = [](const double *c, int n, double t, double &deriv) {
      double value = c[n - 1];
      deriv = 0.;
      for (int i = n - 2; i >= 0; --i) {
         deriv = deriv * t + value;
         value = value * t + c[i];
      }
      return value;
   };
   // Returns P(t) / Q(t) for the 5-coefficient polynomials p, q and stores
   // d(P/Q)/dt in `deriv`.
   auto ratio = [&poly](const double *p, const double *q, double t, double &deriv) {
      double dnum = 0.;
      double dden = 0.;
      const double num = poly(p, 5, t, dnum);
      const double den = poly(q, 5, t, dden);
      deriv = (dnum - num / den * dden) / den;
      return num / den;
   };

   const double v = (x - x0) / xi;
   double denlan = 0.;  // DENLAN(v), the primal result before the division by xi
   double ddenlan = 0.; // dDENLAN/dv

   if (v < -5.5) {
      const double u = ::std::exp(v + 1.);
      // The primal returns exactly 0 here, so every derivative vanishes.
      if (u < 1.e-10) {
         return;
      }
      const double ue = ::std::exp(-1. / u);
      const double us = ::std::sqrt(u);
      // 1 + (a1[0] + (a1[1] + a1[2]*u)*u)*u  ==  1 + u*A(u)
      double dA = 0.;
      const double A = poly(a1, 3, u, dA);
      const double s = 1. + A * u;
      const double ds = A + u * dA;
      denlan = c0 * (ue / us) * s;
      // d(ue)/du = ue / u^2 and d(us)/du = 1 / (2 us).
      const double ddenlan_du = c0 * ((ue / (u * u)) / us * s - ue / (us * us) * (0.5 / us) * s + ue / us * ds);
      ddenlan = ddenlan_du * u; // du/dv = exp(v + 1) = u
   } else if (v < -1) {
      const double u = ::std::exp(-v - 1.);
      const double eu = ::std::exp(-u);
      const double su = ::std::sqrt(u);
      double dr = 0.;
      const double r = ratio(p1, q1, v, dr);
      denlan = eu * su * r;
      // d(exp(-u) * sqrt(u))/du; the rational factor depends on v directly.
      const double dprefactor_du = -eu * su + eu * (0.5 / su);
      ddenlan = dprefactor_du * (-u) * r + eu * su * dr; // du/dv = -u
   } else if (v < 1) {
      denlan = ratio(p2, q2, v, ddenlan);
   } else if (v < 5) {
      denlan = ratio(p3, q3, v, ddenlan);
   } else if (v < 300) {
      // Three sub-ranges share the form u^2 * P(u) / Q(u) with u = 1 / v.
      const double *p = v < 12 ? p4 : (v < 50 ? p5 : p6);
      const double *q = v < 12 ? q4 : (v < 50 ? q5 : q6);
      const double u = 1. / v;
      double dr = 0.;
      const double r = ratio(p, q, u, dr);
      denlan = u * u * r;
      ddenlan = (2. * u * r + u * u * dr) * (-1. / (v * v)); // du/dv = -1/v^2
   } else {
      const double lv = ::std::log(v);
      const double t = v - v * lv / (v + 1.);
      const double u = 1. / t;
      // u^2 * (1 + (a2[0] + a2[1]*u)*u)
      const double s = 1. + (a2[0] + a2[1] * u) * u;
      const double ds = a2[0] + 2. * a2[1] * u;
      denlan = u * u * s;
      const double ddenlan_du = 2. * u * s + u * u * ds;
      // Derivative of t = v - v*log(v)/(v + 1) with respect to v.
      const double dt_dv = 1. - (lv + 1.) / (v + 1.) + v * lv / ((v + 1.) * (v + 1.));
      ddenlan = ddenlan_du * (-1. / (t * t)) * dt_dv;
   }

   // pdf = denlan / xi with v = (x - x0) / xi.
   const double d_v = d_out / xi * ddenlan;
   *d_xi += d_out * -(denlan / (xi * xi)); // explicit 1/xi factor
   *d_x += d_v / xi;
   *d_x0 += -d_v / xi;
   *d_xi += d_v * -((x - x0) / (xi * xi)); // dependence through v
}
485
/// Pullback (reverse-mode derivative) of ROOT::Math::landau_cdf(x, xi, x0).
///
/// The derivative formulas correspond to a primal that evaluates a piecewise
/// approximation LAN(v) at v = (x - x0) / xi and returns LAN(v) itself (no
/// 1/xi factor). This pullback recomputes dLAN/dv on the same branch and
/// accumulates d_out * dLAN/dv * dv/d(arg) into the adjoints.
///
/// \param x, xi, x0        Arguments of the primal call.
/// \param d_out            Adjoint of the primal result.
/// \param d_x, d_xi, d_x0  Adjoint accumulators; incremented, never overwritten.
inline void landau_cdf_pullback(double x, double xi, double x0, double d_out, double *d_x, double *d_xi, double *d_x0)
{
   // Coefficients of the piecewise approximation (identical to the primal).
   // clang-format off
   static constexpr double p1[5] = {0.2514091491e+0,-0.6250580444e-1, 0.1458381230e-1,-0.2108817737e-2, 0.7411247290e-3};
   static constexpr double q1[5] = {1.0 ,-0.5571175625e-2, 0.6225310236e-1,-0.3137378427e-2, 0.1931496439e-2};

   static constexpr double p2[4] = {0.2868328584e+0, 0.3564363231e+0, 0.1523518695e+0, 0.2251304883e-1};
   static constexpr double q2[4] = {1.0 , 0.6191136137e+0, 0.1720721448e+0, 0.2278594771e-1};

   static constexpr double p3[4] = {0.2868329066e+0, 0.3003828436e+0, 0.9950951941e-1, 0.8733827185e-2};
   static constexpr double q3[4] = {1.0 , 0.4237190502e+0, 0.1095631512e+0, 0.8693851567e-2};

   static constexpr double p4[4] = {0.1000351630e+1, 0.4503592498e+1, 0.1085883880e+2, 0.7536052269e+1};
   static constexpr double q4[4] = {1.0 , 0.5539969678e+1, 0.1933581111e+2, 0.2721321508e+2};

   static constexpr double p5[4] = {0.1000006517e+1, 0.4909414111e+2, 0.8505544753e+2, 0.1532153455e+3};
   static constexpr double q5[4] = {1.0 , 0.5009928881e+2, 0.1399819104e+3, 0.4200002909e+3};

   static constexpr double p6[4] = {0.1000000983e+1, 0.1329868456e+3, 0.9162149244e+3,-0.9605054274e+3};
   static constexpr double q6[4] = {1.0 , 0.1339887843e+3, 0.1055990413e+4, 0.5532224619e+3};

   static constexpr double a1[4] = {0 ,-0.4583333333e+0, 0.6675347222e+0,-0.1641741416e+1};
   static constexpr double a2[4] = {0 , 1.0 ,-0.4227843351e+0,-0.2043403138e+1};
   // clang-format on
   // Prefactor of the lower-tail branch (same constant as the primal).
   constexpr double c0 = 0.3989422803;

   // Returns c[0] + c[1]*t + ... + c[n-1]*t^(n-1) via Horner's scheme and
   // stores the derivative with respect to t in `deriv`.
   auto poly = [](const double *c, int n, double t, double &deriv) {
      double value = c[n - 1];
      deriv = 0.;
      for (int i = n - 2; i >= 0; --i) {
         deriv = deriv * t + value;
         value = value * t + c[i];
      }
      return value;
   };
   // Returns P(t) / Q(t) for n-coefficient polynomials p, q and stores
   // d(P/Q)/dt in `deriv`.
   auto ratio = [&poly](const double *p, const double *q, int n, double t, double &deriv) {
      double dnum = 0.;
      double dden = 0.;
      const double num = poly(p, n, t, dnum);
      const double den = poly(q, n, t, dden);
      deriv = (dnum - num / den * dden) / den;
      return num / den;
   };

   const double v = (x - x0) / xi;
   double dlan = 0.; // dLAN/dv

   if (v < -5.5) {
      const double u = ::std::exp(v + 1);
      const double ue = ::std::exp(-1. / u);
      // Once exp(-1/u) underflows to 0 the primal is exactly 0 and so is its
      // derivative. Evaluating the chain rule there would compute 0 * (1/(u*u))
      // with u*u underflowed to 0, i.e. 0 * inf = NaN.
      if (ue > 0.) {
         const double us = ::std::sqrt(u);
         // 1 + (a1[1] + (a1[2] + a1[3]*u)*u)*u  ==  1 + u*A(u)   (a1[0] is unused)
         double dA = 0.;
         const double A = poly(a1 + 1, 3, u, dA);
         const double s = 1. + A * u;
         const double ds = A + u * dA;
         // d(ue)/du = ue / u^2 and d(us)/du = 1 / (2 us).
         const double dlan_du = c0 * ((ue / (u * u)) * us * s + ue * (0.5 / us) * s + ue * us * ds);
         dlan = dlan_du * u; // du/dv = exp(v + 1) = u
      }
   } else if (v < -1) {
      const double u = ::std::exp(-v - 1);
      const double eu = ::std::exp(-u);
      const double su = ::std::sqrt(u);
      double dr = 0.;
      const double r = ratio(p1, q1, 5, v, dr);
      // d(exp(-u) / sqrt(u))/du; the rational factor depends on v directly.
      const double dprefactor_du = -eu / su - (eu / (su * su)) * (0.5 / su);
      dlan = dprefactor_du * (-u) * r + (eu / su) * dr; // du/dv = -u
   } else if (v < 1) {
      ratio(p2, q2, 4, v, dlan);
   } else if (v < 4) {
      ratio(p3, q3, 4, v, dlan);
   } else if (v < 300) {
      // Three sub-ranges share the form P(u) / Q(u) with u = 1 / v.
      const double *p = v < 12 ? p4 : (v < 50 ? p5 : p6);
      const double *q = v < 12 ? q4 : (v < 50 ? q5 : q6);
      const double u = 1. / v;
      double dr = 0.;
      ratio(p, q, 4, u, dr);
      dlan = dr * (-1. / (v * v)); // du/dv = -1/v^2
   } else {
      const double lv = ::std::log(v);
      const double t = v - v * lv / (v + 1);
      const double u = 1. / t;
      // LAN = 1 - (a2[1] + (a2[2] + a2[3]*u)*u)*u  ==  1 - u*B(u)   (a2[0] is unused)
      double dB = 0.;
      const double B = poly(a2 + 1, 3, u, dB);
      const double dlan_du = -(B + u * dB);
      // Derivative of t = v - v*log(v)/(v + 1) with respect to v.
      const double dt_dv = 1. - (lv + 1.) / (v + 1.) + v * lv / ((v + 1.) * (v + 1.));
      dlan = dlan_du * (-1. / (t * t)) * dt_dv;
   }

   // cdf = LAN(v) with v = (x - x0) / xi.
   const double d_v = d_out * dlan;
   *d_x += d_v / xi;
   *d_x0 += -d_v / xi;
   *d_xi += d_v * -((x - x0) / (xi * xi));
}
635
636#ifdef R__HAS_MATHMORE
637
638inline void inc_gamma_c_pullback(double a, double x, double _d_y, double *_d_a, double *_d_x);
639
640inline void inc_gamma_pullback(double a, double x, double _d_y, double *_d_a, double *_d_x)
641{
642 // Synced with SpecFuncCephes.h
643 constexpr double kMACHEP = 1.11022302462515654042363166809e-16;
644 constexpr double kMAXLOG = 709.782712893383973096206318587;
645 constexpr double kMINLOG = -708.396418532264078748994506896;
646 constexpr double kMAXSTIR = 108.116855767857671821730036754;
647 constexpr double kMAXLGM = 2.556348e305;
648 constexpr double kBig = 4.503599627370496e15;
649 constexpr double kBiginv = 2.22044604925031308085e-16;
650
651 double _d_ans = 0, _d_ax = 0, _d_c = 0, _d_r = 0;
652 double _t1;
653 double _t2;
654 double _t3;
655 double _t4;
656 double _t5;
657 clad::tape<double> _t7 = {};
658 clad::tape<double> _t8 = {};
659 clad::tape<double> _t9 = {};
660 double ans, ax, c, r;
661 if (a <= 0)
662 return;
663 if (x <= 0)
664 return;
665 if ((x > 1.) && (x > a)) {
666 double _r0 = 0;
667 double _r1 = 0;
669 *_d_a += _r0;
670 *_d_x += _r1;
671 return;
672 }
673 _t1 = ::std::log(x);
674 ax = a * _t1 - x - ::std::lgamma(a);
675 if (ax < -kMAXLOG) {
676 *_d_x += (a * _d_ax / x) - _d_ax;
677 *_d_a +=
678 _d_ax *
679 (_t1 - ::ROOT::Math::digamma(a)); // numerical_diff::forward_central_difference(::std::lgamma, a, 0, 0, a);
680 _d_ax = 0.;
681 return;
682 }
683 _t2 = ax;
684 ax = ::std::exp(ax);
685 _t3 = r;
686 r = a;
687 _t4 = c;
688 c = 1.;
689 _t5 = ans;
690 ans = 1.;
691 unsigned long _t6 = 0;
692 do {
693 _t6++;
694 clad::push(_t7, r);
695 r += 1.;
696 clad::push(_t8, c);
697 c *= x / r;
698 clad::push(_t9, ans);
699 ans += c;
700 } while (c / ans > kMACHEP);
701 {
702 _d_ans += _d_y / a * ax;
703 _d_ax += ans * _d_y / a;
704 double _r6 = _d_y * -(ans * ax / (a * a));
705 *_d_a += _r6;
706 }
707 do {
708 {
709 {
710 ans = clad::pop(_t9);
711 double _r_d7 = _d_ans;
712 _d_c += _r_d7;
713 }
714 {
715 c = clad::pop(_t8);
716 double _r_d6 = _d_c;
717 _d_c -= _r_d6;
718 _d_c += _r_d6 * x / r;
719 *_d_x += c * _r_d6 / r;
720 double _r5 = c * _r_d6 * -(x / (r * r));
721 _d_r += _r5;
722 }
723 {
724 r = clad::pop(_t7);
725 double _r_d5 = _d_r;
726 }
727 }
728 _t6--;
729 } while (_t6);
730 {
731 ans = _t5;
732 double _r_d4 = _d_ans;
733 _d_ans -= _r_d4;
734 }
735 {
736 c = _t4;
737 double _r_d3 = _d_c;
738 _d_c -= _r_d3;
739 }
740 {
741 r = _t3;
742 double _r_d2 = _d_r;
743 _d_r -= _r_d2;
744 *_d_a += _r_d2;
745 }
746 {
747 ax = _t2;
748 double _r_d1 = _d_ax;
749 _d_ax -= _r_d1;
750 double _r4 = 0;
751 _r4 += _r_d1 * ::std::exp(ax);
752 _d_ax += _r4;
753 }
754 {
755 *_d_x += (a * _d_ax / x) - _d_ax;
756 *_d_a +=
757 _d_ax *
758 (_t1 - ::ROOT::Math::digamma(a)); // numerical_diff::forward_central_difference(::std::lgamma, a, 0, 0, a);
759 _d_ax = 0.;
760 }
761}
762
763inline void inc_gamma_c_pullback(double a, double x, double _d_y, double *_d_a, double *_d_x)
764{
765 // Synced with SpecFuncCephes.h
766 constexpr double kMACHEP = 1.11022302462515654042363166809e-16;
767 constexpr double kMAXLOG = 709.782712893383973096206318587;
768 constexpr double kMINLOG = -708.396418532264078748994506896;
769 constexpr double kMAXSTIR = 108.116855767857671821730036754;
770 constexpr double kMAXLGM = 2.556348e305;
771 constexpr double kBig = 4.503599627370496e15;
772 constexpr double kBiginv = 2.22044604925031308085e-16;
773
774 double _d_ans = 0, _d_ax = 0, _d_c = 0, _d_yc = 0, _d_r = 0, _d_t = 0, _d_y0 = 0, _d_z = 0;
775 double _d_pk = 0, _d_pkm1 = 0, _d_pkm2 = 0, _d_qk = 0, _d_qkm1 = 0, _d_qkm2 = 0;
776 double _t1;
777 double _t2;
778 double _t3;
779 double _t4;
780 double _t5;
781 double _t6;
782 double _t7;
783 double _t8;
784 double _t9;
785 double _t10;
786 unsigned long _t11;
787 clad::tape<double> _t12 = {};
788 clad::tape<double> _t13 = {};
789 clad::tape<double> _t14 = {};
790 clad::tape<double> _t15 = {};
791 clad::tape<double> _t16 = {};
792 clad::tape<double> _t17 = {};
793 clad::tape<double> _t19 = {};
794 clad::tape<double> _t20 = {};
795 clad::tape<double> _t21 = {};
796 clad::tape<double> _t22 = {};
797 clad::tape<double> _t23 = {};
798 clad::tape<double> _t24 = {};
799 clad::tape<double> _t25 = {};
800 clad::tape<double> _t26 = {};
801 clad::tape<double> _t27 = {};
802 clad::tape<bool> _t29 = {};
803 clad::tape<double> _t30 = {};
804 clad::tape<double> _t31 = {};
805 clad::tape<double> _t32 = {};
806 clad::tape<double> _t33 = {};
807 double ans, ax, c, yc, r, t, y, z;
808 double pk, pkm1, pkm2, qk, qkm1, qkm2;
809 if (a <= 0)
810 return;
811 if (x <= 0)
812 return;
813 if ((x < 1.) || (x < a)) {
814 double _r0 = 0;
815 double _r1 = 0;
817 *_d_a += _r0;
818 *_d_x += _r1;
819 return;
820 }
821 _t1 = ::std::log(x);
822 ax = a * _t1 - x - ::std::lgamma(a);
823 if (ax < -kMAXLOG) {
824 *_d_x += a * _d_ax / x - _d_ax;
825 *_d_a +=
826 _d_ax *
827 (_t1 - ::ROOT::Math::digamma(a)); // numerical_diff::forward_central_difference(::std::lgamma, a, 0, 0, a);
828 _d_ax = 0.;
829 return;
830 }
831 _t2 = ax;
832 ax = ::std::exp(ax);
833 _t3 = y;
834 y = 1. - a;
835 _t4 = z;
836 z = x + y + 1.;
837 _t5 = c;
838 c = 0.;
839 _t6 = pkm2;
840 pkm2 = 1.;
841 _t7 = qkm2;
842 qkm2 = x;
843 _t8 = pkm1;
844 pkm1 = x + 1.;
845 _t9 = qkm1;
846 qkm1 = z * x;
847 _t10 = ans;
848 ans = pkm1 / qkm1;
849 _t11 = 0;
850 do {
851 _t11++;
852 clad::push(_t12, c);
853 c += 1.;
854 clad::push(_t13, y);
855 y += 1.;
856 clad::push(_t14, z);
857 z += 2.;
858 clad::push(_t15, yc);
859 yc = y * c;
860 clad::push(_t16, pk);
861 pk = pkm1 * z - pkm2 * yc;
862 clad::push(_t17, qk);
863 qk = qkm1 * z - qkm2 * yc;
864 double _t18 = qk;
865 {
866 if (_t18) {
867 clad::push(_t20, r);
868 r = pk / qk;
869 clad::push(_t21, t);
870 t = ::std::abs((ans - r) / r);
871 clad::push(_t22, ans);
872 ans = r;
873 } else {
874 clad::push(_t23, t);
875 t = 1.;
876 }
877 clad::push(_t19, _t18);
878 }
879 clad::push(_t24, pkm2);
880 pkm2 = pkm1;
881 clad::push(_t25, pkm1);
882 pkm1 = pk;
883 clad::push(_t26, qkm2);
884 qkm2 = qkm1;
885 clad::push(_t27, qkm1);
886 qkm1 = qk;
887 bool _t28 = ::std::abs(pk) > kBig;
888 {
889 if (_t28) {
890 clad::push(_t30, pkm2);
891 pkm2 *= kBiginv;
892 clad::push(_t31, pkm1);
893 pkm1 *= kBiginv;
894 clad::push(_t32, qkm2);
895 qkm2 *= kBiginv;
896 clad::push(_t33, qkm1);
897 qkm1 *= kBiginv;
898 }
899 clad::push(_t29, _t28);
900 }
901 } while (t > kMACHEP);
902 {
903 _d_ans += _d_y * ax;
904 _d_ax += ans * _d_y;
905 }
906 do {
907 {
908 if (clad::pop(_t29)) {
909 {
910 qkm1 = clad::pop(_t33);
911 double _r_d27 = _d_qkm1;
912 _d_qkm1 -= _r_d27;
913 _d_qkm1 += _r_d27 * kBiginv;
914 }
915 {
916 qkm2 = clad::pop(_t32);
917 double _r_d26 = _d_qkm2;
918 _d_qkm2 -= _r_d26;
919 _d_qkm2 += _r_d26 * kBiginv;
920 }
921 {
922 pkm1 = clad::pop(_t31);
923 double _r_d25 = _d_pkm1;
924 _d_pkm1 -= _r_d25;
925 _d_pkm1 += _r_d25 * kBiginv;
926 }
927 {
928 pkm2 = clad::pop(_t30);
929 double _r_d24 = _d_pkm2;
930 _d_pkm2 -= _r_d24;
931 _d_pkm2 += _r_d24 * kBiginv;
932 }
933 }
934 {
935 qkm1 = clad::pop(_t27);
936 double _r_d23 = _d_qkm1;
937 _d_qkm1 -= _r_d23;
938 _d_qk += _r_d23;
939 }
940 {
941 qkm2 = clad::pop(_t26);
942 double _r_d22 = _d_qkm2;
943 _d_qkm2 -= _r_d22;
944 _d_qkm1 += _r_d22;
945 }
946 {
947 pkm1 = clad::pop(_t25);
948 double _r_d21 = _d_pkm1;
949 _d_pkm1 -= _r_d21;
950 _d_pk += _r_d21;
951 }
952 {
953 pkm2 = clad::pop(_t24);
954 double _r_d20 = _d_pkm2;
955 _d_pkm2 -= _r_d20;
956 _d_pkm1 += _r_d20;
957 }
958 if (clad::pop(_t19)) {
959 {
960 ans = clad::pop(_t22);
961 double _r_d18 = _d_ans;
962 _d_ans -= _r_d18;
963 _d_r += _r_d18;
964 }
965 {
966 t = clad::pop(_t21);
967 double _r_d17 = _d_t;
968 _d_t -= _r_d17;
969 double _r7 = 0;
970 _r7 += _r_d17 * clad::custom_derivatives::std::abs_pushforward((ans - r) / r, 1.).pushforward;
971 _d_ans += _r7 / r;
972 _d_r += -_r7 / r;
973 double _r8 = _r7 * -((ans - r) / (r * r));
974 _d_r += _r8;
975 }
976 {
977 r = clad::pop(_t20);
978 double _r_d16 = _d_r;
979 _d_r -= _r_d16;
980 _d_pk += _r_d16 / qk;
981 double _r6 = _r_d16 * -(pk / (qk * qk));
982 _d_qk += _r6;
983 }
984 } else {
985 t = clad::pop(_t23);
986 double _r_d19 = _d_t;
987 _d_t -= _r_d19;
988 }
989 {
990 qk = clad::pop(_t17);
991 double _r_d15 = _d_qk;
992 _d_qk -= _r_d15;
993 _d_qkm1 += _r_d15 * z;
994 _d_z += qkm1 * _r_d15;
995 _d_qkm2 += -_r_d15 * yc;
996 _d_yc += qkm2 * -_r_d15;
997 }
998 {
999 pk = clad::pop(_t16);
1000 double _r_d14 = _d_pk;
1001 _d_pk -= _r_d14;
1002 _d_pkm1 += _r_d14 * z;
1003 _d_z += pkm1 * _r_d14;
1004 _d_pkm2 += -_r_d14 * yc;
1005 _d_yc += pkm2 * -_r_d14;
1006 }
1007 {
1008 yc = clad::pop(_t15);
1009 double _r_d13 = _d_yc;
1010 _d_yc -= _r_d13;
1011 _d_y0 += _r_d13 * c;
1012 _d_c += y * _r_d13;
1013 }
1014 {
1015 z = clad::pop(_t14);
1016 double _r_d12 = _d_z;
1017 }
1018 {
1019 y = clad::pop(_t13);
1020 double _r_d11 = _d_y0;
1021 }
1022 {
1023 c = clad::pop(_t12);
1024 double _r_d10 = _d_c;
1025 }
1026 }
1027 _t11--;
1028 } while (_t11);
1029 {
1030 ans = _t10;
1031 double _r_d9 = _d_ans;
1032 _d_ans -= _r_d9;
1033 _d_pkm1 += _r_d9 / qkm1;
1034 double _r5 = _r_d9 * -(pkm1 / (qkm1 * qkm1));
1035 _d_qkm1 += _r5;
1036 }
1037 {
1038 qkm1 = _t9;
1039 double _r_d8 = _d_qkm1;
1040 _d_qkm1 -= _r_d8;
1041 _d_z += _r_d8 * x;
1042 *_d_x += z * _r_d8;
1043 }
1044 {
1045 pkm1 = _t8;
1046 double _r_d7 = _d_pkm1;
1047 _d_pkm1 -= _r_d7;
1048 *_d_x += _r_d7;
1049 }
1050 {
1051 qkm2 = _t7;
1052 double _r_d6 = _d_qkm2;
1053 _d_qkm2 -= _r_d6;
1054 *_d_x += _r_d6;
1055 }
1056 {
1057 pkm2 = _t6;
1058 double _r_d5 = _d_pkm2;
1059 _d_pkm2 -= _r_d5;
1060 }
1061 {
1062 c = _t5;
1063 double _r_d4 = _d_c;
1064 _d_c -= _r_d4;
1065 }
1066 {
1067 z = _t4;
1068 double _r_d3 = _d_z;
1069 _d_z -= _r_d3;
1070 *_d_x += _r_d3;
1071 _d_y0 += _r_d3;
1072 }
1073 {
1074 y = _t3;
1075 double _r_d2 = _d_y0;
1076 _d_y0 -= _r_d2;
1077 *_d_a += -_r_d2;
1078 }
1079 {
1080 ax = _t2;
1081 double _r_d1 = _d_ax;
1082 _d_ax -= _r_d1;
1083 double _r4 = _r_d1 * ::std::exp(ax);
1084 _d_ax += _r4;
1085 }
1086 {
1087 *_d_x += a * _d_ax / x - _d_ax;
1088 *_d_a +=
1089 _d_ax *
1090 (_t1 - ::ROOT::Math::digamma(a)); // numerical_diff::forward_central_difference(::std::lgamma, a, 0, 0, a);
1091 _d_ax = 0.;
1092 }
1093}
1094
1095#endif // R__HAS_MATHMORE
1096
1097} // namespace Math
1098} // namespace ROOT
1099
1100} // namespace custom_derivatives
1101} // namespace clad
1102
1103// Forward declare BLAS functions.
1104extern "C" void sgemm_(const char *transa, const char *transb, const int *m, const int *n, const int *k,
1105 const float *alpha, const float *A, const int *lda, const float *B, const int *ldb,
1106 const float *beta, float *C, const int *ldc);
1107
1108namespace clad::custom_derivatives {
1109
1111
1112inline void Gemm_Call_pullback(float *output, bool transa, bool transb, int m, int n, int k, float alpha,
1113 const float *A, const float *B, float beta, const float *C, float *_d_output, bool *,
1114 bool *, int *, int *, int *, float *_d_alpha, float *_d_A, float *_d_B, float *_d_beta,
1115 float *_d_C)
1116{
1117 // TODO:
1118 // - handle transa and transb cases correctly
1119 if (transa || transb) {
1120 return;
1121 }
1122
1123 char ct = 't';
1124 char cn = 'n';
1125
1126 // beta needs to be one because we want to add to _d_A and _d_B instead of
1127 // overwriting it.
1128 float one = 1.;
1129
1130 // _d_A, _d_B
1131 // note: beta needs to be one because we want to add to _d_A and _d_B instead of overwriting it.
1132 ::sgemm_(&cn, &ct, &m, &k, &n, &alpha, _d_output, &m, B, &k, &one, _d_A, &m);
1133 ::sgemm_(&ct, &cn, &k, &n, &m, &alpha, A, &m, _d_output, &m, &one, _d_B, &k);
1134
1135 // _d_alpha, _d_beta, _d_C
1136 int sizeC = n * m;
1137 for (int i = 0; i < sizeC; ++i) {
1138 *_d_alpha += _d_output[i] * (output[i] - beta * C[i]);
1139 *_d_beta += _d_output[i] * C[i];
1140 _d_C[i] += _d_output[i] * beta;
1141 }
1142}
1143
1144} // namespace TMVA::Experimental::SOFIE
1145
1146} // namespace clad::custom_derivatives
1147
1148#endif // CLAD_DERIVATOR
void sgemm_(const char *transa, const char *transb, const int *m, const int *n, const int *k, const float *alpha, const float *A, const int *lda, const float *B, const int *ldb, const float *beta, float *C, const int *ldc)
#define c(i)
Definition RSha256.hxx:101
#define a(i)
Definition RSha256.hxx:99
#define h(i)
Definition RSha256.hxx:106
#define kMACHEP
#define kMAXLOG
#define kMAXLGM
#define kMAXSTIR
#define kMINLOG
ROOT::Detail::TRangeCast< T, true > TRangeDynCast
TRangeDynCast is an adapter class that allows the typed iteration through a TCollection.
winID h TVirtualViewer3D TVirtualGLPainter p
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void char Point_t Rectangle_t WindowAttributes_t Float_t r
double digamma(double x)
Double_t y[n]
Definition legend1.C:17
Double_t x[n]
Definition legend1.C:17
const Int_t n
Definition legend1.C:16
Namespace for new Math classes and functions.
TMath.
Definition TMathBase.h:35
Double_t CosH(Double_t)
Returns the hyperbolic cosine of x.
Definition TMath.h:623
Double_t ACos(Double_t)
Returns the principal value of the arc cosine of x, expressed in radians.
Definition TMath.h:643
Double_t ASin(Double_t)
Returns the principal value of the arc sine of x, expressed in radians.
Definition TMath.h:635
Double_t Log2(Double_t x)
Returns the binary (base-2) logarithm of x.
Definition TMath.cxx:107
Double_t Exp(Double_t x)
Returns the base-e exponential function of x, which is e raised to the power x.
Definition TMath.h:720
Double_t Erf(Double_t x)
Computation of the error function erf(x).
Definition TMath.cxx:190
Double_t ATan(Double_t)
Returns the principal value of the arc tangent of x, expressed in radians.
Definition TMath.h:651
Double_t ASinH(Double_t)
Returns the area hyperbolic sine of x.
Definition TMath.cxx:67
Double_t TanH(Double_t)
Returns the hyperbolic tangent of x.
Definition TMath.h:629
Double_t ACosH(Double_t)
Returns the nonnegative area hyperbolic cosine of x.
Definition TMath.cxx:81
Double_t Log(Double_t x)
Returns the natural logarithm of x.
Definition TMath.h:767
Double_t Erfc(Double_t x)
Computes the complementary error function erfc(x).
Definition TMath.cxx:199
Double_t Sq(Double_t x)
Returns x*x.
Definition TMath.h:667
Double_t Sqrt(Double_t x)
Returns the square root of x.
Definition TMath.h:673
LongDouble_t Power(LongDouble_t x, LongDouble_t y)
Returns x raised to the power y.
Definition TMath.h:732
constexpr Double_t Ln10()
Natural log of 10 (to convert log to ln)
Definition TMath.h:103
Double_t Hypot(Double_t x, Double_t y)
Returns sqrt(x*x + y*y)
Definition TMath.cxx:59
Double_t Cos(Double_t)
Returns the cosine of an angle of x radians.
Definition TMath.h:605
constexpr Double_t Pi()
Definition TMath.h:40
Double_t LnGamma(Double_t z)
Computation of ln[gamma(z)] for all z.
Definition TMath.cxx:509
Double_t Sin(Double_t)
Returns the sine of an angle of x radians.
Definition TMath.h:599
Double_t Tan(Double_t)
Returns the tangent of an angle of x radians.
Definition TMath.h:611
Double_t ATanH(Double_t)
Returns the area hyperbolic tangent of x.
Definition TMath.cxx:95
Double_t Log10(Double_t x)
Returns the common (base-10) logarithm of x.
Definition TMath.h:773
Short_t Abs(Short_t d)
Returns the absolute value of parameter Short_t d.
Definition TMathBase.h:124
Double_t SinH(Double_t)
Returns the hyperbolic sine of x.
Definition TMath.h:617
void landau_pdf_pullback(double x, double xi, double x0, double d_out, double *d_x, double *d_xi, double *d_x0)
void landau_cdf_pullback(double x, double xi, double x0, double d_out, double *d_x, double *d_xi, double *d_x0)
void Gemm_Call_pullback(float *output, bool transa, bool transb, int m, int n, int k, float alpha, const float *A, const float *B, float beta, const float *C, float *_d_output, bool *, bool *, int *, int *, int *, float *_d_alpha, float *_d_A, float *_d_B, float *_d_beta, float *_d_C)
ValueAndPushforward< T, T > CosH_pushforward(T x, T d_x)
ValueAndPushforward< T, T > Abs_pushforward(T x, T d_x)
ValueAndPushforward< T, T > Sq_pushforward(T x, T d_x)
ValueAndPushforward< T, T > Erf_pushforward(T x, T d_x)
ValueAndPushforward< T, T > Erfc_pushforward(T x, T d_x)
ValueAndPushforward< T, T > Sin_pushforward(T x, T d_x)
ValueAndPushforward< T, T > Hypot_pushforward(T x, T y, T d_x, T d_y)
ValueAndPushforward< T, T > ASinH_pushforward(T x, T d_x)
ValueAndPushforward< T, T > ACosH_pushforward(T x, T d_x)
ValueAndPushforward< T, T > ASin_pushforward(T x, T d_x)
ValueAndPushforward< T, T > Cos_pushforward(T x, T d_x)
ValueAndPushforward< T, T > Sqrt_pushforward(T x, T d_x)
ValueAndPushforward< T, T > Tan_pushforward(T x, T d_x)
void Hypot_pullback(T x, T y, U p, clad::array_ref< T > d_x, clad::array_ref< T > d_y)
ValueAndPushforward< T, T > Power_pushforward(T x, T y, T d_x, T d_y)
void Power_pullback(T x, T y, U p, clad::array_ref< T > d_x, clad::array_ref< T > d_y)
ValueAndPushforward< T, T > Log_pushforward(T x, T d_x)
ValueAndPushforward< T, T > Log10_pushforward(T x, T d_x)
ValueAndPushforward< T, T > TanH_pushforward(T x, T d_x)
ValueAndPushforward< T, T > ACos_pushforward(T x, T d_x)
ValueAndPushforward< T, T > SinH_pushforward(T x, T d_x)
ValueAndPushforward< T, T > Exp_pushforward(T x, T d_x)
ValueAndPushforward< T, T > Log2_pushforward(T x, T d_x)
ValueAndPushforward< T, T > ATanH_pushforward(T x, T d_x)
ValueAndPushforward< T, T > ATan_pushforward(T x, T d_x)
TMarker m
Definition textangle.C:8
static void output()