b444. 期望試驗 快速冪次

contents

  1. 1. Problem
    1. 1.1. 背景
    2. 1.2. 問題描述
  2. 2. Sample Input
  3. 3. Sample Output
  4. 4. Solution
    1. 4.1. R-to-L
    2. 4.2. L-to-R-2bits
    3. 4.3. L-to-R-sliding
    4. 4.4. L-to-R-sliding-cheat

Problem

背景

曾經某 M 被期望值坑,就只是在計算 $x^y \mod z$ 時偷偷替換成 $x^{y-1} \times x \mod z$,結果得到 Time Limit Exceeded。

根據分析,$y = 16$ 時用二進制表示為 $(10000)_{2}$,若變成 $y = 15$,就會變成 $(01111)_{2}$。通常快速求冪的乘法次數與二進制中 1 的個數成正比,所以速度就慢非常多。

typedef unsigned long long UINT64;
// Modular multiplication (a * b) % mod by shift-and-add, so no 128-bit
// intermediate is needed. Both operands are first reduced below mod; after
// that every doubling of `a` and every addition into the accumulator stays
// below 2 * mod, which fits in 64 bits as long as mod <= 2^63 (the problem
// bounds mod by 10^18).
UINT64 mul(UINT64 a, UINT64 b, UINT64 mod) {
    if (a >= mod) a %= mod;
    if (b >= mod) b %= mod;
    UINT64 acc = 0;
    while (b != 0) {
        if (b & 1) {
            acc += a;
            if (acc >= mod) acc -= mod;
        }
        b >>= 1;
        a <<= 1;                 // a = 2a mod `mod`, via one conditional subtraction
        if (a >= mod) a -= mod;
    }
    return acc;
}
// Right-to-left binary exponentiation over a 64-bit exponent: returns x^y mod `mod`.
// Bits of y are consumed from least significant upward; x is squared between
// bits so that it holds x^(2^k) when bit k is examined.
UINT64 mpow(UINT64 x, UINT64 y, UINT64 mod) {
    UINT64 ret = mod == 1 ? 0 : 1;  // x^0 mod 1 is 0, not 1
    while (y) {
        if (y & 1)
            ret = mul(ret, x, mod);
        y >>= 1;
        if (y)                      // the square after the top bit would never be used
            x = mul(x, x, mod);
    }
    return ret;
}

問題描述

讓我們來一場 $x^y \mod z$ 的期望值試驗吧,基礎目標是減少乘法次數。

其中 $1 \le x, z \le 10^{18}$,$0 \le y \le 2^{2^{20}}$,且 $y$ 以二進制字串(最高位在前)給出。在這場試驗中展現你的優化吧。

Sample Input

5 01 7
5 10 7
5 0101 7
3 0110 7

Sample Output

5
4
3
1

Solution

詳細可以參考《資訊安全 - 近代加密 快速冪次計算》那一篇。

Algorithm Table Size #squaring Average #Multiplication
Right-To-Left 1: $x^{2^i}$ $n$ $n/2$
Left-To-Right 1: $x$ $n$ $n/2$
Left-To-Right(2-bits) 3: $x$, $x^2$, $x^3$ $n$ $3n/8$
Left-To-Right(sliding) 2: $x$, $x^3$ $n$ $n/3$

減少乘法次數,但以上期望乘法次數是跟 1 的個數有關,雖然最好是從 $n/2$ 降到 $n/3$,並不表示速度會真的快上 $1.5$ 倍左右,畢竟還有所謂的基礎乘法次數需求,根據實驗下來大約能快個 10% 到 20% 之間,加上 -Ofast 編譯此時的差異又會再少一點,看起來實作方法影響很嚴重。

例如在不加編譯優化參數下

if a[i] == 0 && a[i+1] == 0
else if a[i] == 0 && a[i+1] == 1
else if a[i] == 1 && a[i+1] == 0
else

上述做法會比下述來得快上許多

if a[i] == 0
if a[i+1] == 0
else
else
if a[i+1] == 0
else

最後,產出一個 cheat 版本:將 L-to-R-2bits 的概念擴充為一次處理 8 個位元(預先建表 $x^0, x^1, \dots, x^{255}$),並使用 loop unrolling 進行加速。由於短字串的長度可能無法被 8 整除,長度小於 $2^8$ 的小測資就改用 L-to-R-sliding 的方案去解決。

在 zerojudge 主機平台上,隨機測資下的運作情況如下:

Algorithm Time
Right-To-Left 5.4s
Left-To-Right(2-bits) 4.9s
Left-To-Right(sliding) 4.8s
Left-To-Right-sliding-cheat 4.2s

R-to-L

#include <bits/stdc++.h>
using namespace std;
typedef unsigned long long UINT64;
// Modular multiplication (a * b) % mod by shift-and-add, so no 128-bit
// intermediate is needed. Both operands are first reduced below mod; after
// that every doubling of `a` and every addition into the accumulator stays
// below 2 * mod, which fits in 64 bits as long as mod <= 2^63 (the problem
// bounds mod by 10^18).
UINT64 mul(UINT64 a, UINT64 b, UINT64 mod) {
    if (a >= mod) a %= mod;
    if (b >= mod) b %= mod;
    UINT64 acc = 0;
    while (b != 0) {
        if (b & 1) {
            acc += a;
            if (acc >= mod) acc -= mod;
        }
        b >>= 1;
        a <<= 1;                 // a = 2a mod `mod`, via one conditional subtraction
        if (a >= mod) a -= mod;
    }
    return acc;
}
// Right-to-left binary exponentiation: returns x^y mod z, where y is a
// NUL-terminated string of '0'/'1' digits, most significant digit first.
// Digits are scanned from the least significant end; x is squared between
// digits so that it holds x^(2^k) when the k-th digit from the right is read.
UINT64 mpowR2L(UINT64 x, char y[], UINT64 z) {
    // Use strlen: probing y[n] with a doubling n only finds the length when it
    // is a power of two; otherwise it reads past the terminator (e.g. leftover
    // bytes of a longer previous input) and squares x too many times.
    int n = (int) strlen(y);
    UINT64 ret = z == 1 ? 0 : 1;  // x^0 mod 1 is 0, not 1
    for (int i = n - 1; i >= 0; i--) {
        if (y[i] == '1')
            ret = mul(ret, x, z);
        if (i > 0)                 // the square after the top digit would never be used
            x = mul(x, x, z);
    }
    return ret;
}
// Exponent digit buffer: 2^20 digits plus slack for the terminator
// (presumably sized for the y <= 2^(2^20) bound).
char y[(1<<20) + 5];
// Reads "x y z" triples until input is exhausted and prints x^y mod z for each.
int main() {
    long long x, z;
    for (;;) {
        if (scanf("%lld %s %lld", &x, y, &z) != 3)
            break;
        printf("%llu\n", mpowR2L(x, y, z));
    }
    return 0;
}

L-to-R-2bits

#include <bits/stdc++.h>
using namespace std;
typedef unsigned long long UINT64;
// Modular multiplication (a * b) % mod by shift-and-add, so no 128-bit
// intermediate is needed. Both operands are first reduced below mod; after
// that every doubling of `a` and every addition into the accumulator stays
// below 2 * mod, which fits in 64 bits as long as mod <= 2^63 (the problem
// bounds mod by 10^18).
UINT64 mul(UINT64 a, UINT64 b, UINT64 mod) {
    if (a >= mod) a %= mod;
    if (b >= mod) b %= mod;
    UINT64 acc = 0;
    while (b != 0) {
        if (b & 1) {
            acc += a;
            if (acc >= mod) acc -= mod;
        }
        b >>= 1;
        a <<= 1;                 // a = 2a mod `mod`, via one conditional subtraction
        if (a >= mod) a -= mod;
    }
    return acc;
}
// Left-to-right exponentiation with 2-digit windows: returns x^y mod z, where
// y is a NUL-terminated '0'/'1' string, most significant digit first.
// Each window costs two squarings plus at most one multiplication by one of
// the precomputed powers x, x^2, x^3.
UINT64 mpowL2R2(UINT64 x, char y[], UINT64 z) {
    UINT64 x2 = mul(x, x, z);
    UINT64 x3 = mul(x2, x, z);
    UINT64 ret = z == 1 ? 0 : 1;  // x^0 mod 1 is 0, not 1
    int i = 0;
    // Windows must be aligned to the least significant digit. With an odd digit
    // count, consume the leading digit on its own; otherwise the last window
    // would pair a digit with the terminator and `i += 2` would step past it.
    if (strlen(y) % 2 == 1) {
        if (y[0] == '1')
            ret = mul(ret, x, z);
        i = 1;
    }
    for (; y[i]; i += 2) {
        ret = mul(ret, ret, z);
        ret = mul(ret, ret, z);
        if (y[i] == '1' && y[i+1] == '1') {
            ret = mul(ret, x3, z);       // window "11"
        } else if (y[i] == '1' && y[i+1] == '0') {
            ret = mul(ret, x2, z);       // window "10"
        } else if (y[i+1] == '1') {
            ret = mul(ret, x, z);        // window "01"
        }                                // window "00": squarings only
    }
    return ret;
}
// Exponent digit buffer: 2^20 digits plus slack for the terminator
// (presumably sized for the y <= 2^(2^20) bound).
char y[(1<<20) + 5];
// Reads "x y z" triples until input is exhausted and prints x^y mod z for each.
int main() {
    long long x, z;
    for (;;) {
        if (scanf("%lld %s %lld", &x, y, &z) != 3)
            break;
        printf("%llu\n", mpowL2R2(x, y, z));
    }
    return 0;
}

L-to-R-sliding

#include <bits/stdc++.h>
using namespace std;
typedef unsigned long long UINT64;
// Modular multiplication (a * b) % mod by shift-and-add, so no 128-bit
// intermediate is needed. Both operands are first reduced below mod; after
// that every doubling of `a` and every addition into the accumulator stays
// below 2 * mod, which fits in 64 bits as long as mod <= 2^63 (the problem
// bounds mod by 10^18).
UINT64 mul(UINT64 a, UINT64 b, UINT64 mod) {
    if (a >= mod) a %= mod;
    if (b >= mod) b %= mod;
    UINT64 acc = 0;
    while (b != 0) {
        if (b & 1) {
            acc += a;
            if (acc >= mod) acc -= mod;
        }
        b >>= 1;
        a <<= 1;                 // a = 2a mod `mod`, via one conditional subtraction
        if (a >= mod) a -= mod;
    }
    return acc;
}
// Left-to-right sliding-window exponentiation with windows "1" and "11":
// returns x^y mod z, where y is a NUL-terminated '0'/'1' string, most
// significant digit first.
//   "11"            -> two squarings, one multiply by x^3, advance 2
//   "1" (lone)      -> one squaring,  one multiply by x,   advance 1
//   "00"            -> two squarings,                       advance 2
//   "0" then 1/NUL  -> one squaring,                        advance 1
UINT64 mpowL2RS(UINT64 x, char y[], UINT64 z) {
    UINT64 cube = mul(mul(x, x, z), x, z);
    UINT64 acc = 1;
    int pos = 0;
    while (y[pos] != '\0') {
        acc = mul(acc, acc, z);
        const char cur = y[pos];
        const char nxt = y[pos + 1];   // at worst the terminator, since cur != '\0'
        int step;
        if (cur == '1' && nxt == '1') {
            acc = mul(acc, acc, z);
            acc = mul(acc, cube, z);
            step = 2;
        } else if (cur == '1') {
            acc = mul(acc, x, z);
            step = 1;
        } else if (nxt == '0') {
            acc = mul(acc, acc, z);
            step = 2;
        } else {
            step = 1;
        }
        pos += step;
    }
    return acc;
}
// Exponent digit buffer: 2^20 digits plus slack for the terminator
// (presumably sized for the y <= 2^(2^20) bound).
char y[(1<<20) + 5];
// Reads "x y z" triples until input is exhausted and prints x^y mod z for each.
int main() {
    long long x, z;
    for (;;) {
        if (scanf("%lld %s %lld", &x, y, &z) != 3)
            break;
        printf("%llu\n", mpowL2RS(x, y, z));
    }
    return 0;
}

L-to-R-sliding-cheat

#include <bits/stdc++.h>
using namespace std;
typedef unsigned long long UINT64;
// Modular multiplication (a * b) % mod by shift-and-add, so no 128-bit
// intermediate is needed. Both operands are first reduced below mod; after
// that every doubling of `a` and every addition into the accumulator stays
// below 2 * mod, which fits in 64 bits as long as mod <= 2^63 (the problem
// bounds mod by 10^18).
UINT64 mul(UINT64 a, UINT64 b, UINT64 mod) {
    if (a >= mod) a %= mod;
    if (b >= mod) b %= mod;
    UINT64 acc = 0;
    while (b != 0) {
        if (b & 1) {
            acc += a;
            if (acc >= mod) acc -= mod;
        }
        b >>= 1;
        a <<= 1;                 // a = 2a mod `mod`, via one conditional subtraction
        if (a >= mod) a -= mod;
    }
    return acc;
}
// Left-to-right sliding-window exponentiation with windows "1" and "11":
// returns x^y mod z, where y is a NUL-terminated '0'/'1' string, most
// significant digit first.
//   "11"            -> two squarings, one multiply by x^3, advance 2
//   "1" (lone)      -> one squaring,  one multiply by x,   advance 1
//   "00"            -> two squarings,                       advance 2
//   "0" then 1/NUL  -> one squaring,                        advance 1
UINT64 mpowL2RS(UINT64 x, char y[], UINT64 z) {
    UINT64 cube = mul(mul(x, x, z), x, z);
    UINT64 acc = 1;
    int pos = 0;
    while (y[pos] != '\0') {
        acc = mul(acc, acc, z);
        const char cur = y[pos];
        const char nxt = y[pos + 1];   // at worst the terminator, since cur != '\0'
        int step;
        if (cur == '1' && nxt == '1') {
            acc = mul(acc, acc, z);
            acc = mul(acc, cube, z);
            step = 2;
        } else if (cur == '1') {
            acc = mul(acc, x, z);
            step = 1;
        } else if (nxt == '0') {
            acc = mul(acc, acc, z);
            step = 2;
        } else {
            step = 1;
        }
        pos += step;
    }
    return acc;
}
// Number of exponent digits consumed per window; the table in mpowCHEAT holds
// x^0 .. x^(2^PREPROC - 1).
#define PREPROC 8
// The squarings and the digit-packing expression below are unrolled by hand
// for exactly 8 digits, so changing PREPROC alone would silently break them.
#if PREPROC != 8
#error "mpowCHEAT is hand-unrolled for PREPROC == 8"
#endif
// Fixed-window (PREPROC-digit) left-to-right exponentiation: returns x^y mod z,
// where y is a NUL-terminated '0'/'1' string, most significant digit first.
// Exponents with fewer than 2^PREPROC digits are delegated to mpowL2RS.
UINT64 mpowCHEAT(UINT64 x, char y[], UINT64 z) {
    // strlen rather than a doubling probe, which is only exact for
    // power-of-two lengths and otherwise reads past the terminator.
    int n = (int) strlen(y);
    if (n < 1<<PREPROC)
        return mpowL2RS(x, y, z);
    UINT64 X[1<<PREPROC] = {1};
    for (int i = 1; i < (1<<PREPROC); i++)
        X[i] = mul(X[i-1], x, z);
    // Align windows to the least significant digit: the leading n % PREPROC
    // digits form one short window, so every remaining window is full. Reading
    // a full window across the terminator would make its value negative and
    // index X out of bounds.
    int head = n % PREPROC;
    int v = 0;
    for (int i = 0; i < head; i++)
        v = v << 1 | (y[i] - '0');
    UINT64 ret = X[v];  // X[0] == 1 when there is no short window
    for (int i = head; i < n; i += PREPROC) {
        ret = mul(ret, ret, z);
        ret = mul(ret, ret, z);
        ret = mul(ret, ret, z);
        ret = mul(ret, ret, z);
        ret = mul(ret, ret, z);
        ret = mul(ret, ret, z);
        ret = mul(ret, ret, z);
        ret = mul(ret, ret, z);
        v = (y[i]-'0')<<7|(y[i+1]-'0')<<6|(y[i+2]-'0')<<5|(y[i+3]-'0')<<4|(y[i+4]-'0')<<3|(y[i+5]-'0')<<2|(y[i+6]-'0')<<1|(y[i+7]-'0');
        ret = mul(ret, X[v], z);
    }
    return ret;
}
// Exponent digit buffer: 2^20 digits plus slack for the terminator
// (presumably sized for the y <= 2^(2^20) bound).
char y[(1<<20) + 5];
// Reads "x y z" triples until input is exhausted and prints x^y mod z for each.
int main() {
    long long x, z;
    for (;;) {
        if (scanf("%lld %s %lld", &x, y, &z) != 3)
            break;
        printf("%llu\n", mpowCHEAT(x, y, z));
    }
    return 0;
}