關於高效大數除法的那些事

contents

  1. 1. 前情提要
  2. 2. 大數除法
    1. 2.1. 牛頓迭代法
    2. 2.2. 倒數計算
    3. 2.3. 運行範例
    4. 2.4. 實作細節
      1. 2.4.1. 精準度
      2. 2.4.2. 牛頓迭代
      3. 2.4.3. 加速優化
      4. 2.4.4. 誤差修正
  3. 3. 參考題目/資料
  4. 4. b960 的亂碼

前情提要

從國小開始學起加減乘除,腦海裡計算加法比減法簡單,減法可以換成加法,乘法則可以換成數個加法。即使到了中國最強的利器——珠心算,我們學習這古老的快速運算時,也都採取這類方法,而除法最為複雜,需要去估算商,嘗試去扣完,再做細微調整。然而,當整數位數為 $N$ 時,加減法效能為 $N$ 基礎運算,而乘除法為 $N^2$ 次基礎運算,一次基礎運算通常定義在查表,例如背誦的九九乘法表,使得我們在借位和乘積運算達到非常快速。

這些方法在計算機發展後,硬體實作大致使用這些算法,直到了快速傅立葉 (FFT) 算法能在 $O(N \log N)$ 時間內完成旋積 (卷積) 計算,順利地解決了多項式乘法的問題,使得大數乘法能在 $O(N \log N)$ 時間內完成。

一開始我們知道乘法和除法都可以在 $O(N^2)$ 時間內完成,有了 FFT 之後,除法是不是也能跟乘法一樣在 $O(N \log N)$ 內完成?

大數除法

就目前研究來看,除法可以轉換成乘法運算,在數論的整數域中,若存在乘法反元素,除一個數相當於成一個數,而我們發展的計算機,有效運算為 32/64-bit,其超過的部分視為溢位 (overflow),溢位則可視為模數。因此,早期 CPU 設計中,除法需要的運行次數比乘法多一些,編譯器會將除法轉換乘法運算 (企圖找到反元素),來加速運算。現在,由於 intel 的黑魔法,導致幾乎一模一樣快而沒人注意到這些瑣事。

回過頭來,我們介紹一個藉由 FFT 達到的除法運算 $O(N \log N \log N)$ 的算法。而那多的 $O(\log N)$ 來自於牛頓迭代法。目標算出 $C = \lfloor A/B \rfloor$,轉換成 $C = \lfloor A \ast (1/B) \rfloor$,快速算出 $1/B$ 的值,採用牛頓迭代法。

額外要求整數長度 $n = \text{length}(A)$$m = \text{length}(B)$,當計算 $1/B$ 時,要求精準度到小數點後 $n$ 位才行。

牛頓迭代法

目標求 $f(x) = 0$ 的一組解。

$$\begin{align} x_{n+1} = x_n - \frac{f(x_n)}{f'(x_n)} \end{align}$$

不斷地迭代逼近找解。若有多組解,則依賴初始值 $x_0$ 決定優先找到哪一組解。

這部分也應用在快速開根號——反平方根快速演算法 Fast inverse square root

倒數計算

藉由下述定義,找到 $x = 1/B$

$$\begin{align} f(x) = \frac{1}{x} - B, \; f'(x) = -\frac{1}{x^2} \end{align}$$

接著推導牛頓迭代法的公式

$$\begin{align} x_{n+1} &= x_n - \frac{f(x_n)}{f'(x_n)} \\ &= x_n - \frac{\frac{1}{x} - B}{-\frac{1}{x^2}} = (2 - x_n B) x_n \end{align}$$

運行範例

$A = 7123456, \; B = 123$,計算倒數 $1/B = 1/123$,由於 $A$$n=7$ 位數,小數點後精準 7 位,意即迭代比較小數點後 7 位不再變動時停止。

以下描述皆以 十進制 (在程式運行時我們可能會偏向萬進制,甚至億進制來充分利用 32-bit 暫存器)

  • 決定 $x_0$ 是很重要的一步,這嚴重影響著迭代次數。
  • $x_0$$B$ 的最高位數 1 進行大數除小數,得到 $x_0 = 0.01$
  • $x_1 = (2 - x_0 \ast B) \ast x_0 = 0.0077$
  • $x_2 = (2 - x_1 \ast B) \ast x_1 = 0.0081073$
  • $x_3 = (2 - x_2 \ast B) \ast x_2 = 0.00813001$
  • $x_4 = (2 - x_3 \ast B) \ast x_3 = 0.00813008$

接下來進行移位到整數部分,進行乘法後再移回去即可。

實作細節

精準度

使用浮點數實作的 FFT,需要小心精準度,網路上有些代碼利用合角公式取代建表。這類型的優化,在精準度不需要很高的時候提供較好的效能,卻無法提供精準的值,誤差修正成了一道難題。

牛頓迭代

由於有 FFT 造成的誤差,檢查收斂必須更加地保守,我們可能需要保留小數點下 $n+10$ 位,在收斂時檢查 $n+5$ 位確保收斂。通常收斂可以在 20 次左右完成,若發現迭代次數過多,有可能是第一個精準度造成的影響,盡可能使用內建函數得到的三角函數值。

加速優化

一般使用 IEEE-754 的浮點數格式,根據 FFT 的長度,要避開超過的震級,因此,在 Zerojudge b960 中,最多使用十萬進制進行加速。

誤差修正

儘管算出的商 $C=A \ast (1/B)$,得到的 $C$ 可能會差個 1,需要在後續乘法中檢驗。

參考題目/資料

b960 的亂碼

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
#include <bits/stdc++.h>
using namespace std;
template<typename T> class TOOL_FFT {
public:
struct complex {
T x, y;
complex(T x = 0, T y = 0):
x(x), y(y) {}
complex operator+(const complex &A) {
return complex(x+A.x,y+A.y);
}
complex operator-(const complex &A) {
return complex(x-A.x,y-A.y);
}
complex operator*(const complex &A) {
return complex(x*A.x-y*A.y,x*A.y+y*A.x);
}
};
T PI;
static const int MAXN = 1<<17;
complex p[2][MAXN];
int reversePos[MAXN];
TOOL_FFT() {
PI = acos(-1);
preprocessing();
}
int NumberOfBitsNeeded(int PowerOfTwo) {
for (int i = 0;; ++i) {
if (PowerOfTwo & (1 << i)) {
return i;
}
}
}
inline uint32_t FastReverseBits(uint32_t a, int NumBits) {
a = ( ( a & 0x55555555U ) << 1 ) | ( ( a & 0xAAAAAAAAU ) >> 1 ) ;
a = ( ( a & 0x33333333U ) << 2 ) | ( ( a & 0xCCCCCCCCU ) >> 2 ) ;
a = ( ( a & 0x0F0F0F0FU ) << 4 ) | ( ( a & 0xF0F0F0F0U ) >> 4 ) ;
a = ( ( a & 0x00FF00FFU ) << 8 ) | ( ( a & 0xFF00FF00U ) >> 8 ) ;
a = ( ( a & 0x0000FFFFU ) << 16 ) | ( ( a & 0xFFFF0000U ) >> 16 ) ;
return a >> (32 - NumBits);
}
void FFT(bool InverseTransform, complex In[], complex Out[], int n) {
// simultaneous data copy and bit-reversal ordering into outputs
int NumSamples = n;
int NumBits = NumberOfBitsNeeded(NumSamples);
for (int i = 0; i < NumSamples; ++i)
Out[reversePos[i]] = In[i];
// the FFT process
for (int i = 1; i <= NumBits; i++) {
int BlockSize = 1<<i, BlockEnd = BlockSize>>1, BlockCnt = NumSamples/BlockSize;
for (int j = 0; j < NumSamples; j += BlockSize) {
complex *t = p[InverseTransform];
int k = 0;
#define UNLOOP_SIZE 8
for (; k+UNLOOP_SIZE < BlockEnd; ) {
#define UNLOOP { \
complex a = (*t) * Out[k+j+BlockEnd]; \
Out[k+j+BlockEnd] = Out[k+j] - a; \
Out[k+j] = Out[k+j] + a;\
k++, t += BlockCnt;\
}
#define UNLOOP4 {UNLOOP UNLOOP UNLOOP UNLOOP;}
#define UNLOOP8 {UNLOOP4 UNLOOP4;}
UNLOOP8;
}
for (; k < BlockEnd;)
UNLOOP;
}
}
// normalize if inverse transform
if (InverseTransform) {
for (int i = 0; i < NumSamples; ++i) {
Out[i] = Out[i].x / NumSamples;
}
}
}
void convolution(T *a, T *b, int n, T *c) {
static complex s[MAXN], d1[MAXN], d2[MAXN], y[MAXN];
n = MAXN;
for (int i = 0; i < n; ++i)
s[i] = complex(a[i], 0);
FFT(false, s, d1, n);
s[0] = complex(b[0], 0);
for (int i = 1; i < n; ++i)
s[i] = complex(b[n - i], 0);
FFT(false, s, d2, n);
for (int i = 0; i < n; ++i)
y[i] = d1[i] * d2[i];
FFT(true, y, s, n);
for (int i = 0; i < n; ++i)
c[i] = s[i].x;
}
void preprocessing() {
int n = MAXN;
for (int i = 0; i < n; i ++) {
p[0][i] = complex(cos(2*i*PI / n), sin(2*i*PI / n));
p[1][i] = complex(p[0][i].x, -p[0][i].y);
}
int NumBits = NumberOfBitsNeeded(n);
for (int i = 0; i < n; i++)
reversePos[i] = FastReverseBits(i, NumBits);
}
};
TOOL_FFT<double> tool;
struct BigInt {
long long *v;
int size;
static const int DIGITS = 5;
static const int MAXN = 1<<17;
static int compare(const BigInt a, const BigInt b) {
for (int i = MAXN-1; i >= 0; i--) {
if (a.v[i] < b.v[i])
return -1;
if (a.v[i] > b.v[i])
return 1;
}
return 0;
}
void str2int(char *s, long long buf[]) {
int n = strlen(s);
size = (n+DIGITS-1) / DIGITS;
int cnt = n%DIGITS == 0 ? DIGITS : n%DIGITS;
int x = 0, pos = size-1;
v = buf;
for (int i = 0; i < n; i++) {
x = x*10 + s[i] - '0';
if (--cnt == 0) {
cnt = DIGITS;
v[pos--] = x, x = 0;
}
}
}
void println() {
printf("%lld", v[size-1]);
for (int pos = size-2; pos >= 0; pos--)
printf("%05lld", v[pos]);
puts("");
}
BigInt multiply(const BigInt &other, long long buf[]) const {
int m = MAXN;
static double na[MAXN], nb[MAXN];
static double tmp[MAXN];
memset(na+size, 0, sizeof(v[0])*(m-size));
for (int i = 0; i < size; i++)
na[i] = v[i];
memset(nb+1, 0, sizeof(v[0])*(m-other.size));
for (int i = 1, j = m-1; i < other.size; i++, j--)
nb[j] = other.v[i];
nb[0] = other.v[0];
tool.convolution(na, nb, m, tmp);
BigInt ret;
ret.size = m;
ret.v = buf;
for (int i = 0; i < m; i++)
buf[i] = (long long) (tmp[i] + 1.5e-1);
for (int i = 0; i < m; i++) {
if (buf[i] >= 100000)
buf[i+1] += buf[i]/100000, buf[i] %= 100000;
}
ret.reduce();
return ret;
}
void reduce() {
while (size > 1 && fabs(v[size-1]) < 5e-1)
size--;
}
BigInt divide(const BigInt &other, long long buf[]) const {
{
int cmp = compare(*this, other);
BigInt ret;
ret.size = MAXN-1, ret.v = buf;
if (cmp == 0) {
memset(buf, 0, sizeof(buf[0])*MAXN);
buf[0] = 1;
ret.reduce();
return ret;
} else if (cmp < 0) {
memset(buf, 0, sizeof(buf[0])*MAXN);
buf[0] = 0;
ret.reduce();
return ret;
}
}
// A / B = A * (1/B)
// x' = (2 - x * B) * x
int m = MAXN;
static double na[MAXN], nb[MAXN];
static double invB[MAXN], netB[MAXN], tmpB[MAXN];
static long long bufB[MAXN];
int PRECISION = size+10;
memset(nb+1, 0, sizeof(v[0])*(m-other.size));
for (int i = 1, j = m-1; i < other.size; i++, j--)
nb[j] = other.v[i];
nb[0] = other.v[0];
memset(invB, 0, sizeof(invB[0])*m);
{
long long t = 100000, a = other.v[other.size-1];
if (other.size-2 >= 0)
t = t*100000, a = a*100000+other.v[other.size-2];
for (int i = PRECISION-other.size; i >= 0; i--) {
invB[i] = t/a;
t = (t%a)*100000;
}
}
for (int it = 0; it < 100; it++) {
// netB = xi * B
tool.convolution(invB, nb, m, netB);
long long carry = 0;
for (int i = 0; i <= PRECISION*2; i++) {
bufB[i] = carry + (long long) (netB[i] + 1.5e-1);
if (bufB[i] >= 100000)
carry = bufB[i]/100000, bufB[i] %= 100000;
else
carry = 0;
bufB[i] = -bufB[i];
}
// tmpB = 2 - xi * B
bufB[PRECISION] += 2;
memset(tmpB, 0, sizeof(tmpB[0])*m);
for (int i = 0; i <= PRECISION*2; i++) {
if (bufB[i] < 0)
bufB[i] += 100000, bufB[i+1]--;
if (i != 0)
tmpB[m-i] = bufB[i];
else
tmpB[i] = bufB[i];
}
// netB = tmpB * invB
tool.convolution(invB, tmpB, m, netB);
{
long long carry = 0;
memset(bufB, 0, sizeof(bufB[0])*m);
for (int i = 0; i <= PRECISION*2; i++) {
bufB[i] = carry + (long long) (netB[i] + 1.5e-1);
if (bufB[i] >= 100000)
carry = bufB[i]/100000, bufB[i] %= 100000;
else
carry = 0;
}
}
{
int same = 1;
for (int i = PRECISION; same && i >= 5; i--)
same &= ((long long) (invB[i]) == bufB[i+PRECISION]);
if (same)
break;
}
memset(invB, 0, sizeof(invB[0])*m);
for (int i = 0; i+PRECISION <= PRECISION*2; i++)
invB[i] = bufB[i+PRECISION];
}
memset(na+1, 0, sizeof(v[0])*(m-size));
for (int i = 1, j = m-1; i < size; i++, j--)
na[j] = v[i];
na[0] = v[0];
tool.convolution(invB, na, m, netB);
BigInt ret;
ret.size = m-1;
ret.v = buf;
long long carry = 0;
for (int i = 0; i < m; i++) {
buf[i] = carry + (long long) (netB[i] + 1.5e-1);
if (buf[i] >= 100000)
carry = buf[i]/100000, buf[i] %= 100000;
else
carry = 0;
}
for (int i = 0; i+PRECISION < m; i++)
buf[i] = buf[i+PRECISION];
memset(buf+PRECISION, 0, sizeof(buf[0])*(m-PRECISION));
{
memset(na, 0, sizeof(na[0])*m);
for (int i = 1, j = m-1; i < m-1; i++, j--)
na[j] = buf[i];
na[0] = buf[0];
for (int i = 0; i < m; i++)
nb[i] = other.v[i];
tool.convolution(nb, na, m, netB);
long long carry = 0;
for (int i = 0; i < m; i++) {
bufB[i] = (carry + (long long) (netB[i] + 1e-1));
if (bufB[i] >= 100000)
carry = bufB[i]/100000, bufB[i] %= 100000;
else
carry = 0;
}
carry = 0;
for (int i = 0; i < m; i++) {
bufB[i] = v[i] - bufB[i] + carry;
if (bufB[i] < 0)
bufB[i] += 100000, carry = -1;
else
carry = 0;
}
BigInt R;
R.size = m-1, R.v = bufB;
if (compare(other, R) <= 0) {
buf[0]++;
for (int i = 0; buf[i] >= 100000; i++)
buf[i+1]++, buf[i] -= 100000;
}
}
ret.reduce();
return ret;
}
};
int main() {
static char sa[1<<20], sb[1<<20];
while (scanf("%s %s", sa, sb) == 2) {
static long long da[1<<19], db[1<<19], dc[1<<19];
BigInt a, b, c;
a.str2int(sa, da);
b.str2int(sb, db);
c = a.divide(b, dc);
c.println();
}
return 0;
}