ROOT   Reference Guide
mulmod.h
Go to the documentation of this file.
1// @(#)root/mathcore:$Id$
2// Author: Jonas Hahnfeld 11/2020
3
4/*************************************************************************
5 * Copyright (C) 1995-2020, Rene Brun and Fons Rademakers. *
7 * *
8 * For the licensing terms see $ROOTSYS/LICENSE. *
9 * For the list of contributors see $ROOTSYS/README/CREDITS. *
10 *************************************************************************/
11
12#ifndef RANLUXPP_MULMOD_H
13#define RANLUXPP_MULMOD_H
14
15#include "helpers.h"
16
17#include <cstdint>
18
19/// Multiply two 576 bit numbers, stored as 9 numbers of 64 bits each
20///
21/// \param[in] in1 first factor as 9 numbers of 64 bits each
22/// \param[in] in2 second factor as 9 numbers of 64 bits each
23/// \param[out] out result with 18 numbers of 64 bits each
24static void multiply9x9(const uint64_t *in1, const uint64_t *in2, uint64_t *out)
25{
26 uint64_t next = 0;
27 unsigned nextCarry = 0;
28
29#if defined(__clang__) || defined(__INTEL_COMPILER) || defined(__CUDACC__)
30#pragma unroll
31#elif defined(__GNUC__) && __GNUC__ >= 8
32// This pragma was introduced in GCC version 8.
33#pragma GCC unroll 18
34#endif
35 for (int i = 0; i < 18; i++) {
36 uint64_t current = next;
37 unsigned carry = nextCarry;
38
39 next = 0;
40 nextCarry = 0;
41
42#if defined(__clang__) || defined(__INTEL_COMPILER) || defined(__CUDACC__)
43#pragma unroll
44#elif defined(__GNUC__) && __GNUC__ >= 8
45// This pragma was introduced in GCC version 8.
46#pragma GCC unroll 9
47#endif
48 for (int j = 0; j < 9; j++) {
49 int k = i - j;
50 if (k < 0 || k >= 9)
51 continue;
52
53 uint64_t fac1 = in1[j];
54 uint64_t fac2 = in2[k];
55#if defined(__SIZEOF_INT128__) && !defined(ROOT_NO_INT128)
56 unsigned __int128 prod = fac1;
57 prod = prod * fac2;
58
59 uint64_t upper = prod >> 64;
60 uint64_t lower = static_cast<uint64_t>(prod);
61#else
62 uint64_t upper1 = fac1 >> 32;
63 uint64_t lower1 = static_cast<uint32_t>(fac1);
64
65 uint64_t upper2 = fac2 >> 32;
66 uint64_t lower2 = static_cast<uint32_t>(fac2);
67
68 // Multiply 32-bit parts, each product has a maximum value of
69 // (2 ** 32 - 1) ** 2 = 2 ** 64 - 2 * 2 ** 32 + 1.
70 uint64_t upper = upper1 * upper2;
71 uint64_t middle1 = upper1 * lower2;
72 uint64_t middle2 = lower1 * upper2;
73 uint64_t lower = lower1 * lower2;
74
75 // When adding the two products, the maximum value for middle is
76 // 2 * 2 ** 64 - 4 * 2 ** 32 + 2, which exceeds a uint64_t.
77 unsigned overflow;
78 uint64_t middle = add_overflow(middle1, middle2, overflow);
79 // Handling the overflow by a multiplication with 0 or 1 is cheaper
80 // than branching with an if statement, which the compiler does not
81 // optimize to this equivalent code. Note that we could do entirely
82 // without this overflow handling when summing up the intermediate
83 // products differently as described in the following SO answer:
84 // https://stackoverflow.com/a/51587262
85 // However, this approach takes at least the same amount of thinking
86 // why a) the code gives the same results without b) overflowing due
87 // to the mixture of 32 bit arithmetic. Moreover, my tests show that
88 // the scheme implemented here is actually slightly more performant.
89 uint64_t overflow_add = overflow * (uint64_t(1) << 32);
90 // This addition can never overflow because the maximum value of upper
91 // is 2 ** 64 - 2 * 2 ** 32 + 1 (see above). When now adding another
92 // 2 ** 32, the result is 2 ** 64 - 2 ** 32 + 1 and still smaller than
93 // the maximum 2 ** 64 - 1 that can be stored in a uint64_t.
94 upper += overflow_add;
95
96 uint64_t middle_upper = middle >> 32;
97 uint64_t middle_lower = middle << 32;
98
99 lower = add_overflow(lower, middle_lower, overflow);
100 upper += overflow;
101
102 // This still can't overflow since the maximum of middle_upper is
103 // - 2 ** 32 - 4 if there was an overflow for middle above, bringing
104 // the maximum value of upper to 2 ** 64 - 2.
105 // - otherwise upper still has the initial maximum value given above
106 // and the addition of a value smaller than 2 ** 32 brings it to
107 // a maximum value of 2 ** 64 - 2 ** 32 + 2.
108 // (Both cases include the increment to handle the overflow in lower.)
109 //
110 // All the reasoning makes perfect sense given that the product of two
111 // 64 bit numbers is smaller than or equal to
112 // (2 ** 64 - 1) ** 2 = 2 ** 128 - 2 * 2 ** 64 + 1
113 // with the upper bits matching the 2 ** 64 - 2 of the first case.
114 upper += middle_upper;
115#endif
116
117 // Add to current, remember carry.
118 current = add_carry(current, lower, carry);
119
120 // Add to next, remember nextCarry.
121 next = add_carry(next, upper, nextCarry);
122 }
123
124 next = add_carry(next, carry, nextCarry);
125
126 out[i] = current;
127 }
128}
129
130/// Compute a value congruent to mul modulo m less than 2 ** 576
131///
132/// \param[in] mul product from multiply9x9 with 18 numbers of 64 bits each
133/// \param[out] out result with 9 numbers of 64 bits each
134///
135/// \f$m = 2^{576} - 2^{240} + 1 \f$
136///
137/// The result in out is guaranteed to be smaller than the modulus.
138static void mod_m(const uint64_t *mul, uint64_t *out)
139{
140 uint64_t r[9];
141 // Assign r = t0
142 for (int i = 0; i < 9; i++) {
143 r[i] = mul[i];
144 }
145
146 int64_t c = compute_r(mul + 9, r);
147
148 // To update r = r - c * m, it suffices to know c * (-2 ** 240 + 1)
149 // because the 2 ** 576 will cancel out. Also note that c may be zero, but
150 // the operation is still performed to avoid branching.
151
152 // c * (-2 ** 240 + 1) in 576 bits looks as follows, depending on c:
153 // - if c = 0, the number is zero.
154 // - if c = 1: bits 576 to 240 are set,
155 // bits 239 to 1 are zero, and
156 // the last one is set
157 // - if c = -1, which corresponds to all bits set (signed int64_t):
158 // bits 576 to 240 are zero and the rest is set.
159 // Note that all bits except the last are exactly complimentary (unless c = 0)
160 // and the last byte is conveniently represented by c already.
161 // Now construct the three bit patterns from c, their names correspond to the
162 // assembly implementation by Alexei Sibidanov.
163
164 // c = 0 -> t0 = 0; c = 1 -> t0 = 0; c = -1 -> all bits set (sign extension)
165 // (The assembly implementation shifts by 63, which gives the same result.)
166 int64_t t0 = c >> 1;
167
168 // Left shifting negative values is undefined behavior until C++20, cast to
169 // unsigned.
170 uint64_t c_unsigned = static_cast<uint64_t>(c);
171
172 // c = 0 -> t2 = 0; c = 1 -> upper 16 bits set; c = -1 -> lower 48 bits set
173 int64_t t2 = t0 - (c_unsigned << 48);
174
175 // c = 0 -> t1 = 0; c = 1 -> all bits set (sign extension); c = -1 -> t1 = 0
176 // (The assembly implementation shifts by 63, which gives the same result.)
177 int64_t t1 = t2 >> 48;
178
179 unsigned carry = 0;
180 {
181 uint64_t r_0 = r[0];
182
183 uint64_t out_0 = sub_carry(r_0, c, carry);
184 out[0] = out_0;
185 }
186 for (int i = 1; i < 3; i++) {
187 uint64_t r_i = r[i];
188 r_i = sub_overflow(r_i, carry, carry);
189
190 uint64_t out_i = sub_carry(r_i, t0, carry);
191 out[i] = out_i;
192 }
193 {
194 uint64_t r_3 = r[3];
195 r_3 = sub_overflow(r_3, carry, carry);
196
197 uint64_t out_3 = sub_carry(r_3, t2, carry);
198 out[3] = out_3;
199 }
200 for (int i = 4; i < 9; i++) {
201 uint64_t r_i = r[i];
202 r_i = sub_overflow(r_i, carry, carry);
203
204 uint64_t out_i = sub_carry(r_i, t1, carry);
205 out[i] = out_i;
206 }
207}
208
209/// Combine multiply9x9 and mod_m with internal temporary storage
210///
211/// \param[in] in1 first factor with 9 numbers of 64 bits each
212/// \param[inout] inout second factor and also the output of the same size
213///
214/// The result in inout is guaranteed to be smaller than the modulus.
215static void mulmod(const uint64_t *in1, uint64_t *inout)
216{
217 uint64_t mul[2 * 9] = {0};
218 multiply9x9(in1, inout, mul);
219 mod_m(mul, inout);
220}
221
222/// Compute base to the n modulo m
223///
224/// \param[in] base with 9 numbers of 64 bits each
225/// \param[out] res output with 9 numbers of 64 bits each
226/// \param[in] n exponent
227///
228/// The arguments base and res may point to the same location.
229static void powermod(const uint64_t *base, uint64_t *res, uint64_t n)
230{
231 uint64_t fac[9] = {0};
232 fac[0] = base[0];
233 res[0] = 1;
234 for (int i = 1; i < 9; i++) {
235 fac[i] = base[i];
236 res[i] = 0;
237 }
238
239 uint64_t mul[18] = {0};
240 while (n) {
241 if (n & 1) {
242 multiply9x9(res, fac, mul);
243 mod_m(mul, res);
244 }
245 n >>= 1;
246 if (!n)
247 break;
248 multiply9x9(fac, fac, mul);
249 mod_m(mul, fac);
250 }
251}
252
253#endif
#define c(i)
Definition: RSha256.hxx:101
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void char Point_t Rectangle_t WindowAttributes_t Float_t r
static uint64_t sub_overflow(uint64_t a, uint64_t b, unsigned &overflow)
Compute a - b and set overflow accordingly.
Definition: helpers.h:37
static uint64_t sub_carry(uint64_t a, uint64_t b, unsigned &carry)
Compute a - b and increment carry if there was an overflow.
Definition: helpers.h:45
static uint64_t add_carry(uint64_t a, uint64_t b, unsigned &carry)
Compute a + b and increment carry if there was an overflow.
Definition: helpers.h:26
static int64_t compute_r(const uint64_t *upper, uint64_t *r)
Update r = r - (t1 + t2) + (t3 + t2) * b ** 10.
Definition: helpers.h:62
static uint64_t add_overflow(uint64_t a, uint64_t b, unsigned &overflow)
Compute a + b and set overflow accordingly.
Definition: helpers.h:18
const Int_t n
Definition: legend1.C:16
static void mulmod(const uint64_t *in1, uint64_t *inout)
Combine multiply9x9 and mod_m with internal temporary storage.
Definition: mulmod.h:215
static void powermod(const uint64_t *base, uint64_t *res, uint64_t n)
Compute base to the n modulo m.
Definition: mulmod.h:229
static void mod_m(const uint64_t *mul, uint64_t *out)
Compute a value congruent to mul modulo m less than 2 ** 576.
Definition: mulmod.h:138
static void multiply9x9(const uint64_t *in1, const uint64_t *in2, uint64_t *out)
Multiply two 576 bit numbers, stored as 9 numbers of 64 bits each.
Definition: mulmod.h:24
auto * t1
Definition: textangle.C:20