b449. 加速策略 圈出角點

Problem

背景

影像處理中,給定一張圖,準確地找到點、線、邊、角都是相當困難的,由於圖片會受到干擾、顏色屬性的差異,使得擷取特徵相當困難。

問題描述

對於 $N \times M$ 的像素圖片,方便起見只由黑白影像構成,0 表示暗、1 表示亮,對於每一個像素位置判斷是否可能是角點。

在角點偵測的算法中,有一個由 Rosten and Drummond 提出的 FAST (Features from Accelerated Segment Test) 方法。概念由一個 $7 \times 7$ 的遮罩,待測點 $p$ 位於遮罩中心,由遮罩內圈上的 16 個像素的灰階判斷 $p$ 是否為角點。遮罩樣子如下所示:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
+--------------------+
| | |16| 1| 2| | |
+--------------------+
| |15| | | | 3| |
+--------------------+
|14| | | | | | 4|
+--------------------+
|13| | | p| | | 5|
+--------------------+
|12| | | | | | 6|
+--------------------+
| |11| | | | 7| |
+--------------------+
| | |10| 9| 8| | |
+---------------------

只要這個圈上出現連續大於等於 12 個相同的暗像素或者是亮像素,則 $p$ 就被視為一個角點。

不幸地,這會造成在一個角上出現很多角點,通常會根據掃描的順序找到角點,當找到一個角點後,會抑制鄰近區域不可以是角點。此題不考慮抑制情況,對於每一個角點必須在 16 個像素在圖片上才進行判斷,圖片邊界不進行偵測。

輸出一個 $N \times M$ 的矩陣,按照原圖片位置,若該點是角點則為 1,反之為 0。

Sample Input

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
7 7
0011100
0100010
1000001
1000001
1000001
0100010
0011100
7 7
0011100
0100010
1000001
0000001
1000001
0100010
0010100

Sample Output

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
Case #1:
0000000
0000000
0000000
0001000
0000000
0000000
0000000
Case #2:
0000000
0000000
0000000
0000000
0000000
0000000
0000000

Solution

這一題的技巧不是演算法,而是實作的加速細節。

同時,來實驗老師上課說的優化策略的用處,由於要連續 12 個,根據鴿籠原理的方式,挑選位置 1, 5, 9, 13 出來,若連續三個狀態相同再進行 $O(16)$ 的判斷。然而挑出這四個位置,可以加速 50%,相較於只有 $O(16)$ 的判斷效能,可以參考下方的樸素解。

樸素寫法並不是最快的,因為 branch 太多,導致速度至少為 800ms,去處理一張 $1920 \times 1080$ 的影像,更好的方案是使用 bitmask,預處理在 16bits 下,連續 9 個相同狀態的位元情況,搭配 loop unrolling 的方式去撰寫,直接 $O(16)$ 判斷,為了減少代碼量,採用巨集的前處理展開。請參考 bitmask 版本。速度來到 140ms,加速幾乎 8 倍。

單純的 bitmask 還不是最快,直接建表 $O(2^{16} \times 16)$ 得到 16bits 是否是角點,建表消耗時間,但單一判斷變成 $O(1)$。請參考 bitmask2 版本。速度來到 76ms,直接翻了快兩倍。

最終 bitmask3 版本 56ms,採用以下的方案:

  1. 減少型別轉換 movz 的出現,用補數來抽換判斷。
    意即 (a&mask) == 0 || (a&mask) == mask) 將只會有 (a&mask) == mask,需要 (a&mask) == 0 的判斷,則先 a = ~a 再進行 (a&mask) == mask
  2. 利用編譯的常數展開,減少二維陣列取址時的一次乘法。
    意即 g[x][y] 取址使用時,會動用到一次乘法和一次加法,對於每一個角點偵測,動用到 16 次的乘法運算。若矩陣大小事先已知,那麼對於某一行的角點,g[x] 可以用一次乘法計算,接著該行所有角點偵測,只會剩下 16 次的加法。
  3. 建表太慢,用 __builtin_popcount() 提供剪枝。

當然,以上變態至極的作法,倒不如樸素解直接開 g++ -O3 或者是 g++ -Ofast 來的省事,速度慢一點也是沒問題的對吧。

1
#pragma GCC optimize ("O3")

樸素解

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
#include <bits/stdc++.h>
using namespace std;
const int dx[] = {-3, -3, -2, -1, 0, 1, 2, 3, 3, 3, 2, 1, 0, -1, -2, -3};
const int dy[] = {0, 1, 2, 3, 3, 3, 2, 1, 0, -1, -2, -3, -3, -3, -2, -1};
char g[2048][2048], ret[2048][2048];
int FAST(int x, int y) {
char c[4] = {g[x+dx[0]][y+dy[0]], g[x+dx[4]][y+dy[4]],
g[x+dx[8]][y+dy[8]], g[x+dx[12]][y+dy[12]]};
if (c[0] == c[1] && c[1] == c[2] ||
c[1] == c[2] && c[2] == c[3] ||
c[2] == c[3] && c[3] == c[0] ||
c[3] == c[0] && c[0] == c[1]) {
int cc = -1, p = 1;
for (int it = 1, i = 1, j = 0; it < 16; it++, i++, i = i >= 16 ? 0 : i) {
if (g[x+dx[i]][y+dy[i]] == g[x+dx[j]][y+dy[j]])
j ++, j = j >= 16 ? 0 : j, p++;
else {
if (cc == -1)
cc = p;
j = i, p = 1;
}
if (p >= 12)
return 1;
}
if (g[x+dx[0]][y+dy[0]] == g[x+dx[15]][y+dy[15]] && p+cc >= 12)
return 1;
}
return 0;
}
int main() {
int N, M, cases = 0;
while (scanf("%d %d", &N, &M) == 2) {
while (getchar() != '\n');
for (int i = 0; i < N; i++)
fgets(g[i], 2000, stdin);
for (int i = 0; i < N; i++) {
for (int j = 0; j < M; j++) {
ret[i][j] = '0';
if (i-3 >= 0 && j-3 >= 0 && i+3 < N && j+3 < M)
ret[i][j] = FAST(i, j) + '0';
}
}
printf("Case #%d:\n", ++cases);
for (int i = 0; i < N; i++) {
ret[i][M] = '\0';
puts(ret[i]);
}
}
return 0;
}

bitmask

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
#include <bits/stdc++.h>
using namespace std;
const int dx[] = {-3, -3, -2, -1, 0, 1, 2, 3, 3, 3, 2, 1, 0, -1, -2, -3};
const int dy[] = {0, 1, 2, 3, 3, 3, 2, 1, 0, -1, -2, -3, -3, -3, -2, -1};
const int MAXH = 2048, MAXW = 2048;
const int MAXN = MAXH * MAXW;
char g[MAXH][MAXW], ret[MAXH*MAXW];
int n, m;
#define T(x, y, z) ((g[x+dx[z]][y+dy[z]])<<z)
#define UNLOOPX(i) (val&corner[i]) == 0 || (val&corner[i]) == corner[i] || \
(val&corner[i+1]) == 0 || (val&corner[i+1]) == corner[i+1] || \
(val&corner[i+2]) == 0 || (val&corner[i+2]) == corner[i+2] || \
(val&corner[i+3]) == 0 || (val&corner[i+3]) == corner[i+3]
#define UNLOOPXALL UNLOOPX(0) || UNLOOPX(4) || UNLOOPX(8) || UNLOOPX(12)
#define UNLOOPY(i) T(x, y, i) | T(x, y, i+1) | T(x, y, i+2) | T(x, y, i+3)
#define UNLOOPYALL UNLOOPY(0) | UNLOOPY(4) | UNLOOPY(8) | UNLOOPY(12)
typedef unsigned short int UINT16;
UINT16 corner[16] = {};
UINT16 rotate_left(UINT16 x, UINT16 n) {
return (x << n) | (x >> (16-n));
}
inline int FAST(int x, int y) {
UINT16 val = UNLOOPYALL;
return UNLOOPXALL;
}
int main() {
for (int i = 0, j = (1<<12)-1; i < 16; i++, j = rotate_left(j, 1))
corner[i] = j;
int cases = 0;
while (scanf("%d %d", &n, &m) == 2) {
while (getchar() != '\n');
for (int i = 0; i < n; i++) {
fgets(g[i], MAXW, stdin);
for (int j = 0; j < m; j++)
g[i][j] -= '0';
}
int bn = n-3, bm = m-3;
char *p = ret;
for (int i = 0; i < 3; i++) {
memset(p, '0', m), p += m;
*p = '\n', p++;
}
for (int i = 3; i < bn; i++) {
*p = '0', p++;
*p = '0', p++;
*p = '0', p++;
for (int j = 3; j < bm; j++)
*p = FAST(i, j) + '0', p++;
*p = '0', p++;
*p = '0', p++;
*p = '0', p++;
*p = '\n', p++;
}
for (int i = 0; i < 3; i++) {
memset(p, '0', m), p += m;
*p = '\n', p++;
}
*p = '\0', p++;
printf("Case #%d:\n", ++cases);
puts(ret);
}
return 0;
}

bitmask2

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
#include <bits/stdc++.h>
#pragma GCC optimize ("O3")
using namespace std;
//const int dx[] = {-3, -3, -2, -1, 0, 1, 2, 3, 3, 3, 2, 1, 0, -1, -2, -3};
//const int dy[] = {0, 1, 2, 3, 3, 3, 2, 1, 0, -1, -2, -3, -3, -3, -2, -1};
const int MAXH = 2048, MAXW = 2048;
const int MAXN = MAXH * MAXW;
char g[MAXH][MAXW], ret[MAXH*MAXW];
int n, m;
#define dx_00 -3
#define dx_01 -3
#define dx_02 -2
#define dx_03 -1
#define dx_40 0
#define dx_41 1
#define dx_42 2
#define dx_43 3
#define dx_80 3
#define dx_81 3
#define dx_82 2
#define dx_83 1
#define dx_120 0
#define dx_121 -1
#define dx_122 -2
#define dx_123 -3
#define dy_00 0
#define dy_01 1
#define dy_02 2
#define dy_03 3
#define dy_40 3
#define dy_41 3
#define dy_42 2
#define dy_43 1
#define dy_80 0
#define dy_81 -1
#define dy_82 -2
#define dy_83 -3
#define dy_120 -3
#define dy_121 -3
#define dy_122 -2
#define dy_123 -1
#define T(x, y, z, w) ((g[x + dx_##z##w][y + dy_##z##w])<<(z+w))
#define UNLOOPX(i) ((val&cor[i]) == cor[i] || (val&cor[i+1]) == cor[i+1] || \
(val&cor[i+2]) == cor[i+2] || (val&cor[i+3]) == cor[i+3])
#define UNLOOPXALL UNLOOPX(0) || UNLOOPX(4) || UNLOOPX(8) || UNLOOPX(12)
#define UNLOOPY(i) T(x, y, i, 0) | T(x, y, i, 1) | T(x, y, i, 2) | T(x, y, i, 3)
#define UNLOOPYALL UNLOOPY(0) | UNLOOPY(4) | UNLOOPY(8) | UNLOOPY(12)
typedef unsigned short int UINT16;
UINT16 cor[16] = {};
UINT16 rotate_left(UINT16 x, UINT16 n) {
return (x << n) | (x >> (16-n));
}
int f[1<<16];
int main() {
UINT16 val;
for (int i = 0, j = (1<<12)-1; i < 16; i++, j = rotate_left(j, 1))
cor[i] = j;
for (int i = 0; i < 1<<16; i++) {
val = i;
f[i] = UNLOOPXALL ? 1 : (val = ~val, UNLOOPXALL);
}
int cases = 0;
char c;
while (scanf("%d %d", &n, &m) == 2) {
while (getchar() != '\n');
for (int i = 0; i < n; i++) {
fgets(g[i], 2000, stdin);
for (int j = 0; j < m; j++)
g[i][j] -= '0';
}
int bn = n-3, bm = m-3;
char *p = ret;
for (int i = 0; i < 3; i++) {
memset(p, '0', m), p += m;
*p = '\n', p++;
}
for (int x = 3; x < bn; x++) {
*p = '0', p++, *p = '0', p++, *p = '0', p++;
#define UNLOOP { \
val = UNLOOPYALL; \
*p = f[val] | '0'; \
p++, y++; \
}
#define UNLOOP4 UNLOOP UNLOOP UNLOOP UNLOOP
#define UNLOOP8 {UNLOOP4 UNLOOP4}
int y = 3;
for (; y+8 < bm; )
UNLOOP8;
for (; y < bm; )
UNLOOP;
*p = '0', p++, *p = '0', p++, *p = '0', p++;
*p = '\n', p++;
}
for (int i = 0; i < 3; i++) {
memset(p, '0', m), p += m;
*p = '\n', p++;
}
*p = '\0', p++;
printf("Case #%d:\n", ++cases);
puts(ret);
}
return 0;
}

bitmask3

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
#include <bits/stdc++.h>
#pragma GCC optimize ("O3")
using namespace std;
//const int dx[] = {-3, -3, -2, -1, 0, 1, 2, 3, 3, 3, 2, 1, 0, -1, -2, -3};
//const int dy[] = {0, 1, 2, 3, 3, 3, 2, 1, 0, -1, -2, -3, -3, -3, -2, -1};
#define MAXH 2048
#define MAXW 2048
char g[MAXH][MAXW], ret[MAXH*MAXW];
char *ptr_g;
int n, m;
#define AB00 -3*MAXW
#define AB01 -3*MAXW+1
#define AB02 -2*MAXW+2
#define AB03 -MAXW+3
#define AB40 3
#define AB41 MAXW+3
#define AB42 2*MAXW+2
#define AB43 3*MAXW+1
#define AB80 3*MAXW
#define AB81 3*MAXW-1
#define AB82 2*MAXW-2
#define AB83 MAXW-3
#define AB120 -3
#define AB121 -MAXW-3
#define AB122 -2*MAXW-2
#define AB123 -3*MAXW-1
#define T(x, y, z, w) (*(ptr_g + AB##z##w)<<(z+w))
#define UNLOOPX(i) ((val&cor[i]) == cor[i] || (val&cor[i+1]) == cor[i+1] || \
(val&cor[i+2]) == cor[i+2] || (val&cor[i+3]) == cor[i+3])
#define UNLOOPXALL UNLOOPX(0) || UNLOOPX(4) || UNLOOPX(8) || UNLOOPX(12)
#define UNLOOPY(i) T(x, y, i, 0) | T(x, y, i, 1) | T(x, y, i, 2) | T(x, y, i, 3)
#define UNLOOPYALL UNLOOPY(0) | UNLOOPY(4) | UNLOOPY(8) | UNLOOPY(12)
typedef unsigned short int UINT16;
UINT16 cor[16] = {};
UINT16 rotate_left(UINT16 x, UINT16 n) {
return (x << n) | (x >> (16-n));
}
char f[1<<16];
int main() {
UINT16 val;
for (int i = 0, j = (1<<12)-1; i < 16; i++, j = rotate_left(j, 1))
cor[i] = j;
for (int i = 0, one; i < 1<<16; i++) {
val = i, one = __builtin_popcount(val);
if (one < 12 && one > 4)
f[i] = 0;
else
f[i] = UNLOOPXALL ? 1 : (val = ~val, UNLOOPXALL);
f[i] |= '0';
}
int cases = 0;
char c;
while (scanf("%d %d", &n, &m) == 2) {
while (getchar() != '\n');
for (int i = 0; i < n; i++) {
fgets(g[i], 2000, stdin);
for (int j = 0; j < m; j++)
g[i][j] -= '0';
}
int bn = n-3, bm = m-3;
char *p = ret;
for (int i = 0; i < 3; i++) {
memset(p, '0', m), p += m;
*p = '\n', p++;
}
for (int x = 3, y; x < bn; x++) {
*p = '0', p++, *p = '0', p++, *p = '0', p++;
ptr_g = g[x]+3;
#define UNLOOP { \
val = UNLOOPYALL; \
*p = f[val]; \
p++, y++, ptr_g++; \
}
#define UNLOOP4 UNLOOP UNLOOP UNLOOP UNLOOP
#define UNLOOP8 {UNLOOP4 UNLOOP4}
for (y = 3; y+8 < bm; )
UNLOOP8;
for (; y < bm; )
UNLOOP;
*p = '0', p++, *p = '0', p++, *p = '0', p++;
*p = '\n', p++;
}
for (int i = 0; i < 3; i++) {
memset(p, '0', m), p += m;
*p = '\n', p++;
}
*p = '\0', p++;
printf("Case #%d:\n", ++cases);
puts(ret);
}
return 0;
}
Read More +

b446. 搜索美學 史蒂芙的煩惱

Problem

背景

動畫 遊戲人生《No Game No Life》中,史蒂芙 (Stephanie Dola) 常常被欺負,儘管她以學院第一畢業,對於遊戲一竅不通的她在這個世界常常被欺負。現在就交給你來幫幫她。

問題描述

兩個人輪流在一個大棋盤上下棋,每一步棋的得分根據這一步棋與最鄰近的敵方棋子的曼哈頓距離。

對於兩個點 $p, q$ 座標 $(p_x, p_y), (q_x, q_y)$,曼哈頓距離 (Manhattan distance) 為 $|p_x - q_x| + |p_y - q_y|$

Sample Input

1
2
3
4
5
6
7
3
1 1
5 5
4 4
3 2
2 4
2 3

Sample Output

1
2
3
4
5
8
2
3
3
1

Solution

把鄰近搜索問題做個總結,普遍處理的是靜態資料跟單一詢問,而在最近餐館那一題已經用 KD-tree 處理過 KNN 問題。這一題是採用動態插入和詢問以及數學性質較強的曼哈頓距離,離線處理也是個選擇。

此問題限制在 $n = 50000$ 的情況下,進行測試討論,除了分桶、方格法外,探討三種思路:

  • Dynamic KD-tree 利用替罪羊樹的概念完成,看著卦長的代碼以及卡車口述概要,終於敲敲打打拼湊起來,掛上啟發式的搭配具有不錯的成效。空間複雜度 $O(n)$,插入複雜度 $O(\log^2 n)$,查詢 $O(\log n)$ (據說是在曼哈頓距離下的緣故),速度是暴力法 $O(n^2)$ 二十倍左右。

  • Segment tree + 平衡樹,空間複雜度 $O(n \log n)$,時間複雜度 $O(\log^3 n)$,使用座標轉換將菱形轉換成正方形,套上二分邊長去查找區域內部是否有點。由於 $n$ 的緣故,速度比 Dynamic KD-tree 慢上許多,若用暴力法 $O(n^2)$ 只快兩倍之多。實作測試提供者 liouzhou_101。

  • 離線處理 CDQ 分治,空間複雜度 $O(n)$,總時間複雜度 $O(n \log^2 n)$,採用思路為曼哈頓距離切割成四個象限進行極值查找。比暴力法快十倍所右。

前兩個作法比較裸,在此特別補充 CDQ 分治,曼哈頓距離可以考慮成四個象限,詢問 $(x, y)$ 的最鄰近點,首先考慮左下角 $(x', y')$,亦即 $x' \le x, \; y' \le y$,則曼哈頓距離 $dist = (x - x') + (y - y') = (x + y) - (x' + y')$,明顯地求最近距離要讓 $x' + y'$ 最大化。同理其他象限。

為了解決這詢問,套用 CDQ 分治,按照 $x$ 座標排序,接著二分操作順序,切割操作 $[l, mid], [mid+1, r]$,在左右兩塊仍然按照 $x$ 排序。單獨看 $[mid+1, r]$ 的操作會受 $[l, mid]$ 和自己本身影響,對於前者而言,採用歸併排序那樣,按照 $x$ 座標慢慢合併 (概念上),合併過程套用 Binary indexed tree 進行極值查找。對於後者,就進行遞迴求解,明顯地 $[l, mid]$ 只會受 $[l, mid]$ 影響。

CDQ 分治的概要,按照其中一個關鍵排序,接著二分操作順序進行分置處理。國外是有論文在描述這個 Online to Offline 的算法,CDQ 命名就是人名,會給國外看笑話吧。

1
2
3
4
5
6
sort(key)
solve(l, r)
solve(l, mid)
process([l, mid], [mid+1, r])
solve(mid+1, r)

備註「欸欸,加上悔棋的話,是不是持久化 kd-tree」

實作探討

關於 kd-tree 實作細節探討,與通常會犯的錯誤,關係到速度有常數差異。

closest() 中,常犯的錯誤是 探索順序 ,盡可能先靠近,啟發式才能更加快速,別像我打出錯誤的搜索順序如下:

1
2
3
4
5
6
7
8
9
10
11
if (x.d[k] <= u->pid.d[k]) {
closest(u->lson, (k+1)%kD, x, h, mndist);
h[k] = abs(x.d[k] - u->pid.d[k]);
closest(u->rson, (k+1)%kD, x, h, mndist);
h[k] = old;
} else {
h[k] = abs(x.d[k] - u->pid.d[k]);
closest(u->lson, (k+1)%kD, x, h, mndist);
h[k] = old;
closest(u->rson, (k+1)%kD, x, h, mndist);
}

實作的順序應該如下,別總是先探訪左子樹、在去探訪右子樹,kd-tree 必須注意順序。

1
2
3
4
5
6
7
8
9
10
11
if (x.d[k] <= u->pid.d[k]) {
closest(u->lson, (k+1)%kD, x, h, mndist);
h[k] = abs(x.d[k] - u->pid.d[k]);
closest(u->rson, (k+1)%kD, x, h, mndist);
h[k] = old;
} else {
closest(u->rson, (k+1)%kD, x, h, mndist);
h[k] = abs(x.d[k] - u->pid.d[k]);
closest(u->lson, (k+1)%kD, x, h, mndist);
h[k] = old;
}

假設資料大小 sizeof(dim_element) 很大,通常會利用指針陣列來進行排序,這樣可以降低搬運大型資料複製時間,但奇怪的是由於指針陣列佔有一定空間,索引資料時又會佔據一段空間,估計是快取方面出了點問題 (或者是我寫不好),導致直接搬運資料是來得比較快速,這個修改在 b348. 最近餐館 也有進行測試,速度有提升。

接著可以藉由函數參數少量,來拉快程式在堆疊參數所需要的時間,在節點內部宣告採用維度 d

1
2
3
4
5
struct Node {
Node *lson, *rson;
Point pid;
int size, d;
};

這個修改造成詢問時,不僅僅在走訪傳遞參數少了一個,還少 k+1 的計算。測試結果中,光靠這一點速度沒有明顯提升。kd tree 還有一個靠臉吃飯的邊界分割,要是相同時分左分右,這一點是最痛苦的,在此就不去討論,當然可以利用隨機擾動來解決這問題。

至於要使用 sort() 進行 $O(n \log n)$、還是使用 nth_element()$O(n)$ 找到中位數,根據兩題的測試,由於 $n$ 都不大,照理來講 nth_element() 快於 sort(),但根據實際測試於 liouzhou_101 的代碼,sort() 的速度會比較快,其一是運氣、其二是未知情況。就從以下代碼中,差異並不明顯。

Dynamic Kd tree

沒有提供垃圾回收,靠內存持運作,要是 RE 就放大一點。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
#include <bits/stdc++.h>
using namespace std;
const int MAXN = 131072;
const int MAXM = 50005;
const int MAXD = 2;
const int INF = INT_MAX;
const double ALPHA = 0.75;
const double LOG_ALPHA = log2(1.0 / ALPHA);
class KD_TREE {
public:
struct Point {
static int kD;
int d[MAXD], pid;
int dist(Point &x) {
int ret = 0;
for (int i = 0; i < kD; i++)
ret += abs(d[i] - x.d[i]);
return ret;
}
void read(int id = 0) {
for (int i = 0; i < kD; i++)
scanf("%d", &d[i]);
pid = id;
}
static int sortIdx;
bool operator<(const Point &x) const {
return d[sortIdx] < x.d[sortIdx];
}
};
struct Node {
Node *lson, *rson;
Point pid;
int size;
Node() {
lson = rson = NULL;
size = 1;
}
void update() {
size = 1;
if (lson) size += lson->size;
if (rson) size += rson->size;
}
} nodes[MAXN];
Node *root;
Point A[MAXM];
int bufsize, size, kD;
void init(int kd) {
size = bufsize = 0;
root = NULL;
Point::sortIdx = 0;
Point::kD = kD = kd;
}
void insert(Point x) {
insert(root, 0, x, log2int(size) / LOG_ALPHA);
}
int closest(Point x) {
int mndist = INF, h[MAXD] = {};
closest(root, 0, x, h, mndist);
return mndist;
}
private:
int log2int(int x){
return __builtin_clz((int)1)-__builtin_clz(x);
}
inline int isbad(Node *u) {
if (u->lson && u->lson->size > u->size * ALPHA)
return 1;
if (u->rson && u->rson->size > u->size * ALPHA)
return 1;
return 0;
}
Node* newNode() {
Node *ret = &nodes[bufsize++];
*ret = Node();
return ret;
}
Node* build(int k, int l, int r) {
if (l > r) return NULL;
if (k == kD) k = 0;
Node *ret = newNode();
int mid = (l + r)>>1;
Point::sortIdx = k;
sort(A+l, A+r+1);
ret->pid = A[mid];
ret->lson = build(k+1, l, mid-1);
ret->rson = build(k+1, mid+1, r);
ret->update();
return ret;
}
void flatten(Node *u, Point* &buf) {
if (u == NULL) return ;
flatten(u->lson, buf);
*buf = u->pid, buf++;
flatten(u->rson, buf);
}
bool insert(Node* &u, int k, Point &x, int dep) {
if (u == NULL) {
u = newNode(), u->pid = x;
return dep <= 0;
}
u->size++;
int t = 0;
if (x.d[k] <= u->pid.d[k])
t = insert(u->lson, (k+1)%kD, x, dep-1);
else
t = insert(u->rson, (k+1)%kD, x, dep-1);
if (t && !isbad(u))
return 1;
if (t) {
Point *ptr = &A[0];
flatten(u, ptr);
u = build(k, 0, u->size-1);
}
return 0;
}
int heuristic(int h[]) {
int ret = 0;
for (int i = 0; i < kD; i++)
ret += h[i];
return ret;
}
void closest(Node *u, int k, Point &x, int h[], int &mndist) {
if (u == NULL || heuristic(h) >= mndist)
return ;
int dist = u->pid.dist(x), old;
mndist = min(mndist, dist), old = h[k];
if (x.d[k] <= u->pid.d[k]) {
closest(u->lson, (k+1)%kD, x, h, mndist);
h[k] = abs(x.d[k] - u->pid.d[k]);
closest(u->rson, (k+1)%kD, x, h, mndist);
h[k] = old;
} else {
closest(u->rson, (k+1)%kD, x, h, mndist);
h[k] = abs(x.d[k] - u->pid.d[k]);
closest(u->lson, (k+1)%kD, x, h, mndist);
h[k] = old;
}
}
} A, B;
int KD_TREE::Point::sortIdx = 0, KD_TREE::Point::kD = 2;
int main() {
int N;
KD_TREE::Point pt;
while (scanf("%d", &N) == 1) {
A.init(2), B.init(2);
for (int i = 0; i < N; i++) {
pt.read(i);
if (i) printf("%d\n", B.closest(pt));
A.insert(pt);
pt.read(i);
printf("%d\n", A.closest(pt));
B.insert(pt);
}
}
return 0;
}

CDQ 分治

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
#include <bits/stdc++.h>
using namespace std;
const int MAXQ = 131072;
const int MAXN = 40000;
const int INF = INT_MAX;
class Offline {
public:
struct Point {
int x, y;
Point(int a = 0, int b = 0):
x(a), y(b) {}
bool operator<(const Point &a) const {
return x < a.x || (x == a.x && y < a.y);
}
void read() {
scanf("%d %d", &x, &y);
x++, y++;
}
};
struct Event {
Point p;
int qtype, qid;
Event(int a = 0, int b = 0, Point c = Point()):
qtype(a), qid(b), p(c) {}
bool operator<(const Event &e) const {
if (p.x != e.p.x) return p.x < e.p.x;
return qid < e.qid;
}
};
vector<Event> event;
int ret[MAXQ], N;
void init(int n) {
event.clear();
N = n;
}
void addEvent(int qtype, int qid, Point x) {
event.push_back(Event(qtype, qid, x));
}
void run() {
for (int i = 0; i < event.size(); i++)
ret[i] = 0x3f3f3f3f;
cases = 0;
for (int i = 0; i <= N; i++)
used[i] = 0;
sort(event.begin(), event.end());
CDQ(0, event.size()-1);
}
private:
Event ebuf[MAXQ];
int BIT[MAXN], used[MAXN];
int cases = 0;
void modify(int x, int val, int dir) {
for (; x && x <= N; x += (x&(-x)) * dir) {
if (used[x] != cases)
BIT[x] = -0x3f3f3f3f, used[x] = cases;
BIT[x] = max(BIT[x], val);
}
}
int query(int x, int dir) {
int ret = -0x3f3f3f3f;
for (; x && x <= N; x += (x&(-x)) * dir) {
if (used[x] == cases)
ret = max(ret, BIT[x]);
}
return ret;
}
void merge(int l, int mid, int r) {
cases++;
for (int i = mid+1, j = l; i <= r; i++) {
if (event[i].qtype == 0) {
for (; j <= mid && event[j].p.x <= event[i].p.x; j++) {
if (event[j].qtype == 1)
modify(event[j].p.y, event[j].p.x+event[j].p.y, 1);
}
ret[event[i].qid] = min(ret[event[i].qid], event[i].p.x+event[i].p.y-query(event[i].p.y, -1));
}
}
cases++;
for (int i = mid+1, j = l; i <= r; i++) {
if (event[i].qtype == 0) {
for (; j <= mid && event[j].p.x <= event[i].p.x; j++) {
if (event[j].qtype == 1)
modify(event[j].p.y, event[j].p.x-event[j].p.y, -1);
}
ret[event[i].qid] = min(ret[event[i].qid], event[i].p.x-event[i].p.y-query(event[i].p.y, 1));
}
}
cases++;
for (int i = r, j = mid; i > mid; i--) {
if (event[i].qtype == 0) {
for (; j >= l && event[j].p.x >= event[i].p.x; j--) {
if (event[j].qtype == 1)
modify(event[j].p.y, event[j].p.y-event[j].p.x, 1);
}
ret[event[i].qid] = min(ret[event[i].qid], event[i].p.y-event[i].p.x-query(event[i].p.y, -1));
}
}
cases++;
for (int i = r, j = mid; i > mid; i--) {
if (event[i].qtype == 0) {
for (; j >= l && event[j].p.x >= event[i].p.x; j--) {
if (event[j].qtype == 1)
modify(event[j].p.y, -event[j].p.x-event[j].p.y, -1);
}
ret[event[i].qid] = min(ret[event[i].qid], -event[i].p.x-event[i].p.y-query(event[i].p.y, 1));
}
}
}
void CDQ(int l, int r) {
if (l == r)
return ;
int mid = (l + r)/2, lidx, ridx;
lidx = l, ridx = mid+1;
for (int i = l; i <= r; i++) {
if (event[i].qid <= mid)
ebuf[lidx++] = event[i];
else
ebuf[ridx++] = event[i];
}
for (int i = l; i <= r; i++)
event[i] = ebuf[i];
CDQ(l, mid);
merge(l, mid, r);
CDQ(mid+1, r);
lidx = l, ridx = mid+1;
for (int i = l; i <= r; i++) {
if ((lidx <= mid && event[lidx] < event[ridx]) || ridx > r)
ebuf[i] = event[lidx++];
else
ebuf[i] = event[ridx++];
}
for (int i = l; i <= r; i++)
event[i] = ebuf[i];
}
} A, B;
int main() {
int N;
Offline::Point pt;
while (scanf("%d", &N) == 1) {
A.init(65536), B.init(65536);
int max_y = 0;
for (int i = 0; i < N; i++) {
pt.read(), max_y = max(max_y, pt.y);
B.addEvent(0, 2*i, pt);
A.addEvent(1, 2*i, pt);
pt.read(), max_y = max(max_y, pt.y);
A.addEvent(0, 2*i+1, pt);
B.addEvent(1, 2*i+1, pt);
}
A.N = B.N = max_y; // y in [1, max_y]
A.run(), B.run();
for (int i = 0; i < N; i++) {
if (i) printf("%d\n", B.ret[2*i]);
printf("%d\n", A.ret[2*i+1]);
}
}
return 0;
}
Read More +

b443. 我愛 Fibonacci

Problem

$$F_n=\begin{cases} n, & n=0,1 \\ F_{n-1}+F_{n-2}, & n \geq 2 \end{cases}$$

求出$F_{2^n} \mod m$ 的結果。

Sample Input

1
2
3
4
3
1 1000000007
2 1000000007
3 1000000007

Sample Output

1
2
3
1
3
21

Solution

一般的矩陣計算,利用 $M^n$ 求出$F_n$,其中

$$M = \begin{bmatrix} 1 & 1\\ 1 & 0 \end{bmatrix}$$

如果是求$F_n$ 時間複雜度 $O(\log n)$,而這一題求的是$F_{2^n}$,時間複雜度 $O(n)$

為了加速運算,目標是要找到 $\mod p$ 下的循環長度 $L$,最後求出$F_{2^n \mod L}$,根據待會的證明,保證 $L \le p$,那複雜度就可以回到 $O(\log L)$ 解決。但為了要找到 $L$ 又是一段很長的故事,總時間複雜度為 $O(\sqrt{p})$,不用保證 $p$ 是質數。

參考資料

故事

數列$F_0 = 1, F_1 = 1, F_2 = 2, \cdots$,循環是連續兩項出現重複,而費氏數列會完全循環,也就是出現連續兩項$F_i = 0, F_{i-1} = 1$。下方是一個 $\mod 4$ 的情況。

1
1 1 2 3 1 0 | 1 2 3 1 0 | ...

要找到恰好連續兩項$F_i = 0, F_{i-1} = 1$ 是困難的,考慮去找到$F_i = 0$ 即可,接著再去想辦法讓$F_{i-1} = 1$

假設最小的 $k$ 滿足$F_k = 0 \mod p$,而$F_{k-1} = a \mod p$,那麼之後的序列$F_{i} = a^j F_{i+j \times k} \mod p$。從矩陣乘法的概念中可以理解,是一個常數為 $a$ 的初始項,第二輪循環常數就會變成 $a^2$,類推。

接下來

  • 考慮一個嚴重的問題「 何種模 $p$ 情況一定循環,即從$F_0$ 再次循環。 」答案是 質數 $p$
    原因是$F_{i} = a^j F_{i+j \times k} \mod p$,由於 $a$$p$ 互質,$a^j \mod p \neq 0$ 恆成立,同時還是一個 $ord_{p}(a) = p-1$,這部分從歐拉定理中可以了解,那麼只有可能在$F_{i} = 0$ 的情況成立,就是 $k$ 的倍數之外,不發生$F_i = 0$ 的出現。
  • 接續上一個問題「模 $p$ 不是質數怎麼處理?」
    進行質因數分解,對於每一個質因子找到模循環長度,模 $p$ 循環長度就是所有質因子循環長度的最小公倍數 lcm。

現在問題落在 $k$ 怎麼找到,若能找到 $k$,其循環長度落在 $k$ 的倍數,或者有更好的獲取方式。

  • 若模質數 $p$ 且滿足 $p > 5$,5 是 $p$ 的二次剩餘 (quadratic residue),意即滿足 $\exists \; x^2 \equiv 5 \mod p$,循環長度為 $p-1$ 的因數。反之,循環長度是 $2(p+1)$ 的因數。

關於二次剩餘的判斷,在模質數 $p$ 下,對於 $gcd(x, p) = 1$,藉由歐拉定理得到 $x^{p-1} \equiv 1 \mod p$,以下不保證是正確的說法,提供理解的一個方案。

  • $d$$p$ 的二次剩餘,則滿足 $d^{(p-1)/2} \equiv 1 \mod p$,因為$x^{2 \times (p-1)/2} \equiv 1 \mod p \Rightarrow x^{p-1} \equiv 1 \mod p$
  • 若非二次剩餘,則滿足 $d^{(p-1)/2} \equiv -1 \mod p$,因為 $\left [ d^{(p-1)/2} \right ]^2 \equiv 1 \mod p \Rightarrow d^{p-1} \equiv 1 \mod p$
  • 數學上用 Legendre symbol 來表示這個判斷 wiki

回過頭來,看一下費氏數列的公式解

$F_n = \frac{1}{\sqrt{5}} \left [ \left ( \frac{1+\sqrt{5}}{2} \right )^n - \left (\frac{1-\sqrt{5}}{2} \right )^n \right ]$

藉由展開公式,儘管它是實數、根號,展開之後一定只會剩下整數冪次的總和。令 $a = \sqrt{5}$,觀察二次剩餘與否和滿足$F_n = 0$ 的關係。

在模質數 $p$ 下,滿足二次剩餘$F_n \equiv 0 \mod p$,當 $n = p-1$ 的時候成立,可以藉由噁心的展開式得到。同理在非二次剩餘情況,$n = 2(p+1)$,找到一個最大的倍數情況,答案一定落在其因數下。詳細推導請看參考資料,太噁心就不提。

參考資料中有特別提到,有一個地方還 沒有確認 ,對於模數 $p^k$ 的循環長度 $g(p) \times p^{k-1}$ 如何證明。但我想根據中國餘式定理,能了解循環長度倍數的模關係吧。接著由於大整數分解期望是 $O(\sqrt{n})$,中間也要找到所有因數來得到循環長度的驗證,還要搭配快速矩陣乘法,最後也是 $O(\sqrt{n})$

這一題是更進階的費氏數列計算,從建表、矩陣乘法,最後來到循環搜尋,這些證明不久之後就會忘記了,二次剩餘判斷的想法還會留著,其他就忘得一乾二淨吧。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
#include <bits/stdc++.h>
using namespace std;
#define MILLER_BABIN 4
typedef unsigned long long UINT64;
UINT64 mul(UINT64 a, UINT64 b, UINT64 mod) {
UINT64 ret = 0;
for (a = a >= mod ? a%mod : a, b = b >= mod ? b%mod : b; b != 0; b>>=1, a <<= 1, a = a >= mod ? a - mod : a) {
if (b&1) {
ret += a;
if (ret >= mod)
ret -= mod;
}
}
return ret;
}
struct Matrix {
UINT64 v[2][2];
int row, col; // row x col
Matrix(int n, int m, int a = 0) {
memset(v, 0, sizeof(v));
row = n, col = m;
for(int i = 0; i < row && i < col; i++)
v[i][i] = a;
}
Matrix multiply(const Matrix& x, const long long mod) const {
Matrix ret(row, x.col);
for(int i = 0; i < row; i++) {
for(int k = 0; k < col; k++) {
if (!v[i][k])
continue;
for(int j = 0; j < x.col; j++) {
ret.v[i][j] += mul(v[i][k], x.v[k][j], mod);
if (ret.v[i][j] >= mod)
ret.v[i][j] -= mod;
}
}
}
return ret;
}
Matrix pow(const long long& n, const long long mod) const {
Matrix ret(row, col, 1), x = *this;
long long y = n;
while(y) {
if(y&1) ret = ret.multiply(x, mod);
y = y>>1, x = x.multiply(x, mod);
}
return ret;
}
} FibA(2, 2, 0);
#define MAXL (50000>>5)+1
#define GET(x) (mark[x>>5]>>(x&31)&1)
#define SET(x) (mark[x>>5] |= 1<<(x&31))
int mark[MAXL], P[50000], Pt = 0;
void sieve() {
register int i, j, k;
SET(1);
int n = 46340;
for (i = 2; i <= n; i++) {
if (!GET(i)) {
for (k = n/i, j = i*k; k >= i; k--, j -= i)
SET(j);
P[Pt++] = i;
}
}
}
UINT64 mpow(UINT64 x, UINT64 y, UINT64 mod) { // mod < 2^32
UINT64 ret = 1;
while (y) {
if (y&1)
ret = (ret * x)%mod;
y >>= 1, x = (x * x)%mod;
}
return ret % mod;
}
UINT64 mpow2(UINT64 x, UINT64 y, UINT64 mod) {
UINT64 ret = 1;
while (y) {
if (y&1)
ret = mul(ret, x, mod);
y >>= 1, x = mul(x, x, mod);
}
return ret;
}
void exgcd(long long x, long long y, long long &g, long long &a, long long &b) {
if (y == 0)
g = x, a = 1, b = 0;
else
exgcd(y, x%y, g, b, a), b -= (x/y) * a;
}
long long llgcd(long long x, long long y) {
if (x < 0) x = -x;
if (y < 0) y = -y;
if (!x || !y) return x + y;
long long t;
while (x%y)
t = x, x = y, y = t%y;
return y;
}
long long inverse(long long x, long long p) {
long long g, b, r;
exgcd(x, p, g, r, b);
if (g < 0) r = -r;
return (r%p + p)%p;
}
int isPrime(long long p) { // implements by miller-babin
if (p < 2 || !(p&1)) return 0;
if (p == 2) return 1;
long long q = p-1, a, t;
int k = 0, b = 0;
while (!(q&1)) q >>= 1, k++;
for (int it = 0; it < MILLER_BABIN; it++) {
a = rand()%(p-4) + 2;
t = mpow2(a, q, p);
b = (t == 1) || (t == p-1);
for (int i = 1; i < k && !b; i++) {
t = mul(t, t, p);
if (t == p-1)
b = 1;
}
if (b == 0)
return 0;
}
return 1;
}
long long pollard_rho(long long n, long long c) {
long long x = 2, y = 2, i = 1, k = 2, d;
while (true) {
x = (mul(x, x, n) + c);
if (x >= n) x -= n;
d = llgcd(x - y, n);
if (d > 1) return d;
if (++i == k) y = x, k <<= 1;
}
return n;
}
void factorize(int n, vector<long long> &f) {
for (int i = 0; i < Pt && P[i]*P[i] <= n; i++) {
if (n%P[i] == 0) {
while (n%P[i] == 0)
f.push_back(P[i]), n /= P[i];
}
}
if (n != 1) f.push_back(n);
}
void llfactorize(long long n, vector<long long> &f) {
if (n == 1)
return ;
if (n < 1e+9) {
factorize(n, f);
return ;
}
if (isPrime(n)) {
f.push_back(n);
return ;
}
long long d = n;
for (int i = 2; d == n; i++)
d = pollard_rho(n, i);
llfactorize(d, f);
llfactorize(n/d, f);
}
// above largest factor
// ---------------------- //
int legendre_symbol(UINT64 d, UINT64 p) {
if (d%p == 0) return 0;
return mpow2(d, (p-1)>>1, p) == 1 ? 1 : -1;
}
void factor_gen(int idx, long long x, vector< pair<long long, int> > &f, vector<long long> &ret) {
if (idx == f.size()) {
ret.push_back(x);
return ;
}
for (long long i = 0, a = 1; i <= f[idx].second; i++, a *= f[idx].first)
factor_gen(idx+1, x*a, f, ret);
}
void factor_gen(long long n, vector<long long> &ret) {
vector<long long> f;
vector< pair<long long, int> > f2;
llfactorize(n, f);
sort(f.begin(), f.end());
int cnt = 1;
for (int i = 1; i <= f.size(); i++) {
if (i == f.size() || f[i] != f[i-1])
f2.push_back(make_pair(f[i-1], cnt)), cnt = 1;
else
cnt ++;
}
factor_gen(0, 1, f2, ret);
sort(ret.begin(), ret.end());
}
UINT64 cycleInFib(UINT64 p) {
if (p == 2) return 3;
if (p == 3) return 8;
if (p == 5) return 20;
vector<long long> f;
if (legendre_symbol(5, p) == 1)
factor_gen(p-1, f);
else
factor_gen(2*(p+1), f);
long long f1, f2;
for (int i = 0; i < f.size(); i++) {
Matrix t = FibA.pow(f[i]-1, p);
f1 = (t.v[0][0] + t.v[0][1])%p;
f2 = (t.v[1][0] + t.v[1][1])%p;
if (f1 == 1 && f2 == 0)
return f[i];
}
return 0;
}
UINT64 cycleInFib(UINT64 p, int k) {
UINT64 s = cycleInFib(p);
for (int i = 1; i < k; i++)
s = s * p;
return s;
}
int main() {
sieve();
FibA.v[0][0] = 1, FibA.v[0][1] = 1;
FibA.v[1][0] = 1, FibA.v[1][1] = 0;
int testcase;
scanf("%d", &testcase);
while (testcase--) {
long long n, m;
scanf("%lld %lld", &n, &m);
vector<long long> f;
map<long long, int> r;
llfactorize(m, f);
for (auto &x : f)
r[x]++;
UINT64 cycle = 1;
for (auto &x : r) {
UINT64 t = cycleInFib(x.first, x.second);
cycle = cycle / llgcd(t, cycle) * t;
}
n = mpow2(2, n, cycle);
Matrix t = FibA.pow(n, m);
long long fn = t.v[1][0];
printf("%lld\n", fn);
}
return 0;
}
Read More +

b444. 期望試驗 快速冪次

Problem

背景

曾經某 M 被期望值坑,就只是在計算 $x^y \mod z$ 時偷偷替換成 $x^{y-1} \times x \mod z$,結果得到 Time Limit Exceeded。

根據分析 $y = 16$ 時,用二進制表示為$(10000)_{2}$,若變成 $y = 15$,就會變成$(01111)_{2}$,通常快速求冪的乘法次數與二進制的 1 個數成正比,所以速度就慢非常多。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
typedef unsigned long long UINT64;
UINT64 mul(UINT64 a, UINT64 b, UINT64 mod) {
UINT64 ret = 0;
for (a = a >= mod ? a%mod : a, b = b >= mod ? b%mod : b; b != 0; b>>=1, a <<= 1, a = a >= mod ? a - mod : a) {
if (b&1) {
ret += a;
if (ret >= mod)
ret -= mod;
}
}
return ret;
}
UINT64 mpow(UINT64 x, UINT64 y, UINT64 mod) {
UINT64 ret = 1;
while (y) {
if (y&1)
ret = mul(ret, x, mod);
y >>= 1, x = mul(x, x, mod);
}
return ret;
}

問題描述

讓我們來一場 $x^y \mod z$ 的期望值試驗吧,基礎目標是減少乘法次數。

其中 $1 \le x, z \le 10^{18}$, $0 \le y \le 2^{2^{20}}$,在這場試驗中展現你的優化吧。

Sample Input

1
2
3
4
5 01 7
5 10 7
5 0101 7
3 0110 7

Sample Output

1
2
3
4
5
4
3
1

Solution

詳細可以參考《資訊安全 - 近代加密 快速冪次計算》那一篇。

Algorithm Table Size #squaring Average #Multiplication
Right-To-Left 1: $x^{2^i}$ $n$ $n/2$
Left-To-Right 1: $x$ $n$ $n/2$
Left-To-Right(2-bits) 3: $x$, $x^2$, $x^3$ $n$ $3n/8$
Left-To-Right(sliding) 2: $x$, $x^3$ $n$ $n/3$

減少乘法次數,但以上期望乘法次數是跟 1 的個數有關,雖然最好是從 $n/2$ 降到 $n/3$,並不表示速度會真的快上 $1.5$ 倍左右,畢竟還有所謂的基礎乘法次數需求,根據實驗下來大約能快個 10% 到 20% 之間,加上 -Ofast 編譯此時的差異又會再少一點,看起來實作方法影響很嚴重。

例如在不加編譯優化參數下

1
2
3
4
if a[i] == 0 && a[i+1] == 0
else if a[i] == 0 && a[i+1] == 1
else if a[i] == 1 && a[i+1] == 0
else

上述做法會比下述來得快上許多

1
2
3
4
5
6
if a[i] == 0
if a[i+1] == 0
else
else
if a[i+1] == 0
else

最後,產出一個 cheat 版本,使用 L-to-R-2bits 的概念下去擴充,使用 loop unrolling 進行加速,由於會發生不被整除的問題,小測資就靠 L-to-R-sliding 的方案去解決。

在 zerojudge 主機上平台上,隨機測資下的運作情況如下:

Algorithm Time
Right-To-Left 5.4s
Left-To-Right(2-bits) 4.9s
Left-To-Right(sliding) 4.8s
Left-To-Right-sliding-cheat 4.2s

R-to-L

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
#include <bits/stdc++.h>
using namespace std;
typedef unsigned long long UINT64;
UINT64 mul(UINT64 a, UINT64 b, UINT64 mod) {
UINT64 ret = 0;
for (a = a >= mod ? a%mod : a, b = b >= mod ? b%mod : b; b != 0; b>>=1, a <<= 1, a = a >= mod ? a - mod : a) {
if (b&1) {
ret += a;
if (ret >= mod)
ret -= mod;
}
}
return ret;
}
UINT64 mpowR2L(UINT64 x, char y[], UINT64 z) {
int n;
for (n = 1; y[n]; n <<= 1);
UINT64 ret = 1;
for (int i = n-1; i >= 0; i--) {
if (y[i] == '1')
ret = mul(ret, x, z);
x = mul(x, x, z);
}
return ret;
}
char y[(1<<20) + 5];
int main() {
long long x, z;
while (scanf("%lld %s %lld", &x, y, &z) == 3) {
printf("%llu\n", mpowR2L(x, y, z));
}
return 0;
}

L-to-R-2bits

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
#include <bits/stdc++.h>
using namespace std;
typedef unsigned long long UINT64;
UINT64 mul(UINT64 a, UINT64 b, UINT64 mod) {
UINT64 ret = 0;
for (a = a >= mod ? a%mod : a, b = b >= mod ? b%mod : b; b != 0; b>>=1, a <<= 1, a = a >= mod ? a - mod : a) {
if (b&1) {
ret += a;
if (ret >= mod)
ret -= mod;
}
}
return ret;
}
UINT64 mpowL2R2(UINT64 x, char y[], UINT64 z) {
UINT64 x2 = mul(x, x, z);
UINT64 x3 = mul(x2, x, z);
UINT64 ret = 1;
for (int i = 0; y[i]; i += 2) {
ret = mul(ret, ret, z);
ret = mul(ret, ret, z);
if (y[i] == '1' && y[i+1] == '1') {
ret = mul(ret, x3, z);
} else if (y[i] == '1' && y[i+1] == '0') {
ret = mul(ret, x2, z);
} else if (y[i+1] == '1') {
ret = mul(ret, x, z);
}
}
return ret;
}
char y[(1<<20) + 5];
int main() {
long long x, z;
while (scanf("%lld %s %lld", &x, y, &z) == 3) {
printf("%llu\n", mpowL2R2(x, y, z));
}
return 0;
}

L-to-R-sliding

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
#include <bits/stdc++.h>
using namespace std;
typedef unsigned long long UINT64;
UINT64 mul(UINT64 a, UINT64 b, UINT64 mod) {
UINT64 ret = 0;
for (a = a >= mod ? a%mod : a, b = b >= mod ? b%mod : b; b != 0; b>>=1, a <<= 1, a = a >= mod ? a - mod : a) {
if (b&1) {
ret += a;
if (ret >= mod)
ret -= mod;
}
}
return ret;
}
UINT64 mpowL2RS(UINT64 x, char y[], UINT64 z) {
UINT64 x3 = mul(mul(x, x, z), x, z);
UINT64 ret = 1;
for (int i = 0; y[i]; ) {
ret = mul(ret, ret, z);
if (y[i] == '1' && y[i+1] == '1') {
ret = mul(ret, ret, z);
ret = mul(ret, x3, z);
i += 2;
} else if (y[i] == '1') {
ret = mul(ret, x, z);
i ++;
} else if (y[i+1] == '0') {
ret = mul(ret, ret, z);
i += 2;
} else {
i++;
}
}
return ret;
}
char y[(1<<20) + 5];
int main() {
long long x, z;
while (scanf("%lld %s %lld", &x, y, &z) == 3) {
printf("%llu\n", mpowL2RS(x, y, z));
}
return 0;
}

L-to-R-sliding-cheat

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
#include <bits/stdc++.h>
using namespace std;
typedef unsigned long long UINT64;
UINT64 mul(UINT64 a, UINT64 b, UINT64 mod) {
UINT64 ret = 0;
for (a = a >= mod ? a%mod : a, b = b >= mod ? b%mod : b; b != 0; b>>=1, a <<= 1, a = a >= mod ? a - mod : a) {
if (b&1) {
ret += a;
if (ret >= mod)
ret -= mod;
}
}
return ret;
}
UINT64 mpowL2RS(UINT64 x, char y[], UINT64 z) {
UINT64 x3 = mul(mul(x, x, z), x, z);
UINT64 ret = 1;
for (int i = 0; y[i]; ) {
ret = mul(ret, ret, z);
if (y[i] == '1' && y[i+1] == '1') {
ret = mul(ret, ret, z);
ret = mul(ret, x3, z);
i += 2;
} else if (y[i] == '1') {
ret = mul(ret, x, z);
i ++;
} else if (y[i+1] == '0') {
ret = mul(ret, ret, z);
i += 2;
} else {
i++;
}
}
return ret;
}
#define PREPROC 8
UINT64 mpowCHEAT(UINT64 x, char y[], UINT64 z) {
int n;
for (n = 1; y[n]; n <<= 1);
if (n < 1<<PREPROC)
return mpowL2RS(x, y, z);
UINT64 X[1<<PREPROC] = {1};
for (int i = 1; i < (1<<PREPROC); i++)
X[i] = mul(X[i-1], x, z);
UINT64 ret = 1;
for (int i = 0, v; y[i]; i += PREPROC) {
ret = mul(ret, ret, z);
ret = mul(ret, ret, z);
ret = mul(ret, ret, z);
ret = mul(ret, ret, z);
ret = mul(ret, ret, z);
ret = mul(ret, ret, z);
ret = mul(ret, ret, z);
ret = mul(ret, ret, z);
v = (y[i]-'0')<<7|(y[i+1]-'0')<<6|(y[i+2]-'0')<<5|(y[i+3]-'0')<<4|(y[i+4]-'0')<<3|(y[i+5]-'0')<<2|(y[i+6]-'0')<<1|(y[i+7]-'0');
ret = mul(ret, X[v], z);
}
return ret;
}
char y[(1<<20) + 5];
int main() {
long long x, z;
while (scanf("%lld %s %lld", &x, y, &z) == 3) {
printf("%llu\n", mpowCHEAT(x, y, z));
}
return 0;
}
Read More +

b440. 互質對

Problem

給定 $n$$m$,請你統計有序對 $(a,b)$ 的個數,其中 $1 \le a \le n, 1 \le b \le m$$a$$b$ 互質。

$a$$b$ 互質的定義是:$a$$b$ 的最大公約數等於 $1$

> count coprime pair

## Sample Input ##
1
2
3
2
3 4
10000000 10000000

Sample Output

1
2
9
60792712854483

Solution

前言

柳柳州出數論基礎-「莫比烏斯反演 (Möbius inversion formula)」

之前放棄看完的內容又要撿回來看,百般痛苦,之前沒看懂啊。但能知道的是莫比烏斯反演類似排容原理,就像錯排依樣。先不管莫比烏斯,看到「線性時間求出所有乘法反元素」真的可以嗎?那也是件很有趣的事情。

參考資源

  • 莫比烏斯參考《線性篩法與積性函數》-賈志鵬 link
  • 分塊優化參考《POI XIV Stage.1 Queries Zap 解题报告》-Kwc Oliver link

基礎定義

了解莫比烏斯反演之前,要先介紹積性函數

積性函數

  • 何謂積性函數 (Multiplicative function)?
    對於定義域 $\mathbb{N}^+$ 的函數 $f$ 而言,任兩個互質的 $gcd(a, b) = 1$ 正整數 $a, b$ 滿足 $f(ab) = f(a)f(b)$
  • 何謂完全積性函數?
    $f$ 是一個積性函數,同時滿足 $f(p^n) = f(p)^n$
  • $f(n), g(n)$ 是積性函數,則 $h(n) = f(n) g(n)$ 也是積性函數。
  • $f(n)$ 為積性函數,則函數 $g(n) = \sum_{d|n} f(d)$ 也是積性函數。

歐拉函數

回顧歐拉函數 (Euler’s totient function) $\phi$

  • 定義 $\phi(n)$$1 \cdots n$ 中與 $n$ 互質的個數。
  • $\phi(n)$ 是一個積性函數,但不是完全積性函數。根據中國餘數定理 或者 基本算術可以證明之。

  • 歐拉定理 $a^{\phi(n)} \equiv 1 \mod n, \text{ when } gcd(a, n) = 1$

  • 特性 1:$\sum_{d|n} \phi(d) = n$。當作把互質個數 $\phi(n)$ 相加不互質個數 $s$,由於 $d$$n$ 的因數,則與 $d$ 互質 $x$ 個數 $\phi(d)$,把那些 $x' = x \times d/n$ 就會補上那些不互質的個數 $s = |set(x')|$

  • 特性 2:$1 \cdots n$$n$ 互質的數和為 $n\phi(n)/2$。原因很簡單,若 $gcd(x, n) = 1$,則會滿足 $gcd(n-x, n) = 1$,看起來就是一個對稱總和。

莫比烏斯

根據積性函數的性質,我們得到莫比烏斯反演的基礎:

莫比烏斯反演公式 $f(n)$

$f(n) = \sum_{d|n} \mu(d) g(\frac{n}{d})$

莫比烏斯函數 $\mu(n)$

$$\mu(n) = \left\{\begin{matrix} 1 && n = 1\\ (-1)^k && n = p_1 p_2 \cdots p_k \\ 0 && \text{otherwise} \end{matrix}\right.$$
  • 特性 1:$\sum_{d|n} \mu(d) = [n = 1]$

在此簡單說一次,莫比烏斯反演公式就像排容原理,而莫比烏斯函數 $\mu(n)$ 就像一加一減,之所以在 $n = p_1 p_2 \cdots p_k$$\mu(n) = (-1)^k$,也就是說當 $n$ 只全由質數相乘 (不允許冪次方大於 1,否則為 0),相當於一般認知,奇數次要扣,偶數次要補回來的排容口訣。而特定自然不必多說,其中數學表達 $[n = 1]$ 相當於程式中的 n == 1 ? 1 : 0

簡單展示歐拉函數 $\phi(n)$

  • $f(n) = \sum_{d|n} \phi(d) = n$,由歐拉函數特性 2 得知。
    *$\phi(n) = \sum_{d|n} \mu(d) f(\frac{n}{d}) = \sum_{d|n} \frac{\mu(d) n}{d}$,套用莫比烏斯公式後整理。

關於公式$\phi(n) = \sum_{d|n} \frac{\mu(d) n}{d}$ 可以這麼理解。

  • 求出 $gcd(n, k) = 1$ 的個數,其餘要捨棄掉。
  • 計算 $gcd(n, k) = d$ 的個數,利用排容原理即莫比烏斯函數,顯而易見排容方案只當 $d$$n$ 個質因數組合而成。

應用此題

利用莫比烏斯函數 特性 1:$\sum_{d|n} \mu(d) = [n = 1]$

$$\begin{align*} & \sum_{a = 1}^{N} \sum_{b = 1}^{M} [gcd(a, b) = 1] \\ & = \sum_{a = 1}^{N} \sum_{b = 1}^{M} \sum_{d|gcd(a, b)} \mu(d) \\ & = \sum \mu(d) \left ( \sum_{1 \le a \le N \text{ and } d | a} \left ( \sum_{1 \le b \le M \text{ and } d | b} 1 \right ) \right )\\ & = \sum \mu(d) \left \lfloor \frac{n}{d} \right \rfloor \left \lfloor \frac{m}{d} \right \rfloor \end{align*}$$

第二行就是抽換莫比烏斯,第三行則是交換順序,第四行則是可以快速找到 $d|a$ 的總數為 $\lfloor n/d \rfloor$ 所導致。

若直接窮舉 $d$ 時間複雜度為 $O(min(n, m))$,可以利用塊狀的概念優化到 $O(\sqrt{n} + \sqrt{m})$。因為 $\lfloor n/d \rfloor$ 的值只有 $2\lfloor \sqrt{n} \rfloor$ 種可能,同理 $\lfloor m/d \rfloor$ 也是,那麼對於同一個 $d$ 的計數會是一個連續區間不斷地移動。

以下是一個簡單的 $\lfloor n/d \rfloor$ 示意圖,用 $\lfloor n/d \rfloor$ 不同當作劃分。

1
2
3
4
5
6
7
8
9
10
11
X X X X X X X XX X X
X X X X X X X XX X X
+----------------------------------------------------+
| | | | | | | | | n
+----------------------------------------------------+
012345678901234567890123456789012345678901234567890123
+---------------------------------------------------------------------------+
| | | | | | | | | | | | m
+---------------------------------------------------------------------------+
X X X X X X X XX X X
X X X X X X X XX X X

程式只要處理打 X 的位置即可,時間複雜度 $O(\sqrt{n} + \sqrt{m})$,可以參照一般解。代碼短,但除法次數多。

加速一般解由 liouzhou_101 提供。開一次根號 sqrt(),省下 $O(\sqrt{n})$ 次的除法,賺了 400 ms,代碼長了 800 B。加速 25% 的效能,可謂除法的可怕,根據研究差異後才發現到這一點。

一般解

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
#include <bits/stdc++.h>
using namespace std;
#define GET(x) (mark[(x)>>5]>>((x)&31)&1)
#define SET(x) (mark[(x)>>5] |= 1<<((x)&31))
const int MAXN = 10000005;
const int MAXL = (MAXN>>5)+1;
int mark[MAXL];
int P[700000], Pt = 0;
short mu[MAXN], sum[MAXN];
void sieve_mobius() {
register int i, j, k;
SET(1), mu[1] = 1;
int n = 10000000;
for (i = 2; i <= n; i++) {
if (!GET(i))
P[Pt++] = i, mu[i] = -1;
for (j = 0; j < Pt && (k = i*P[j]) <= n; j++) {
SET(k);
if (i%P[j] == 0) {
mu[k] = 0;
break;
}
mu[k] = -mu[i];
}
}
}
long long coprime_pair(int n, int m) {
long long ret = 0;
if (n > m) swap(n, m);
for (int d = 1, r; d <= n; d = r+1) {
r = min(n / (n/d), m / (m/d));
ret += (long long)(sum[r] - sum[d-1]) * (n/d) * (m/d);
}
return ret;
}
int main() {
sieve_mobius();
for (int i = 1; i < MAXN; i++)
sum[i] = sum[i-1] + mu[i];
int testcase, N, M;
scanf("%d", &testcase);
while (testcase--) {
scanf("%d %d", &N, &M);
printf("%lld\n", coprime_pair(N, M));
}
return 0;
}

加速運算解

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
#include <bits/stdc++.h>
using namespace std;
#define GET(x) (mark[(x)>>5]>>((x)&31)&1)
#define SET(x) (mark[(x)>>5] |= 1<<((x)&31))
const int MAXN = 10000005;
const int SQRN = 3200 * 2;
const int MAXL = (MAXN>>5)+1;
int mark[MAXL];
int P[700000], Pt = 0;
short mu[MAXN], sum[MAXN];
void sieve_mobius() {
register int i, j, k;
SET(1), mu[1] = 1;
int n = 10000000;
for (i = 2; i <= n; i++) {
if (!GET(i))
P[Pt++] = i, mu[i] = -1;
for (j = 0; j < Pt && (k = i*P[j]) <= n; j++) {
SET(k);
if (i%P[j] == 0) {
mu[k] = 0;
break;
}
mu[k] = -mu[i];
}
}
}
long long coprime_pair(int n, int m) {
static int An[SQRN][2], Bn[SQRN][2];
if (n > m) swap(n, m);
long long ret = 0;
int aidx = 0, bidx = 0, sq;
sq = sqrt(n);
for (int i = 1; i <= sq; i++, aidx++) An[aidx][1] = n/(An[aidx][0] = n / i);
if (sq * sq == n) aidx--;
for (int i = sq; i >= 1; i--, aidx++) An[aidx][1] = n/(An[aidx][0] = i);
sq = sqrt(m);
for (int i = 1; i <= sq; i++, bidx++) Bn[bidx][1] = m/(Bn[bidx][0] = m / i);
if (sq * sq == m) bidx--;
for (int i = sq; i >= 1; i--, bidx++) Bn[bidx][1] = m/(Bn[bidx][0] = i);
for (int l = 1, r, *a = &An[0][1], *b = &Bn[0][1], *A = &An[0][0], *B = &Bn[0][0]; l <= n; l = r+1) {
if (*a < *b)
r = *a, ret += (long long) (sum[r] - sum[l-1]) * *A * *B, A+=2, a+=2;
else if (*a > *b)
r = *b, ret += (long long) (sum[r] - sum[l-1]) * *A * *B, B+=2, b+=2;
else
r = *a, ret += (long long) (sum[r] - sum[l-1]) * *A * *B, A+=2, B+=2, a+=2, b+=2;
}
return ret;
}
int main() {
sieve_mobius();
for (int i = 1; i < MAXN; i++)
sum[i] = sum[i-1] + mu[i];
int testcase, N, M;
scanf("%d", &testcase);
while (testcase--) {
scanf("%d %d", &N, &M);
printf("%lld\n", coprime_pair(N, M));
}
return 0;
}
Read More +

b441. 延展尺寸

Problem

將 N 張小正方形拼成長條圖,並且每一次挑選一個正方形區域黏貼在長條後,黏貼的條件是比較重疊一半的區域,利用能量最少路徑進行裁剪後併在一起。

如果最少能量大於等於某個值,則再隨機挑選一個正方形區域進行黏貼,捨棄掉這次黏貼操作,若嘗試 50 次沒有成功低於閥值,則取消全部的黏貼操作。

當初的問題在於 連續 50 次 看得似懂非懂,然後還以為是要碎形圖,只找一個正方形去重疊得到 N 個。

Sample Input

1
2
3
4
2 1
2 2
1 2 3 4 5 6
7 8 9 10 11 12

Sample Output

1
2
3
2 2
1 2 3 4 5 6
7 8 9 10 11 12

Solution

接續上一題 b438. 裁剪尺寸 的作法,現在多一個串接和重疊的操作。

感謝蔡星 asas 對此題描述的解說。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
#include <bits/stdc++.h>
using namespace std;
class IMAGE {
public:
struct Pixel {
int r, g, b;
Pixel(int x = 0, int y = 0, int z = 0):
r(x), g(y), b(z) {}
void read() {
scanf("%d %d %d", &r, &g, &b);
}
Pixel operator-(const Pixel &x) const {
return Pixel(r-x.r, g-x.g, b-x.b);
}
Pixel operator+(const Pixel &x) const {
return Pixel(r+x.r, g+x.g, b+x.b);
}
Pixel operator*(const double x) const {
return Pixel(r*x, g*x, b*x);
}
Pixel operator/(const double x) const {
return Pixel(r/x, g/x, b/x);
}
void print() {
printf("%d %d %d", r, g, b);
}
int length() {
return abs(r) + abs(g) + abs(b);
}
double dist(Pixel x) {
return sqrt((r-x.r)*(r-x.r)+(g-x.g)*(g-x.g)+(b-x.b)*(b-x.b));
}
};
int W, H;
static const int MAXN = 256;
Pixel data[MAXN][MAXN*3], tmp[MAXN][MAXN*3];
int energy[MAXN][MAXN], dp[MAXN][MAXN];
long long seed;
int random() {
return seed = ( seed * 9301 + 49297 ) % 233280;
}
void getSquarePosition(int &x, int &y, int L) {
y = (W <= L) ? 0 : random() % (W - L);
x = (H <= L) ? 0 : random() % (H - L);
}
void read() {
scanf("%d %d", &W, &H);
for (int i = 0; i < H; i++)
for (int j = 0; j < W; j++)
data[i][j].read();
seed = 0;
}
int isValid(int x, int y) {
return x >= 0 && y >= 0 && x < H && y < W;
}
void pattern(int L, int N) {
int ERROR_TRY = 50;
int threshold = round(L*255/8.0);
int overlap = round(L/2.0);
int lx, ly, tW = 0, path[MAXN] = {};
for (int n = 0, it; n < N; n++) {
for (it = 0; it < ERROR_TRY; it++) {
getSquarePosition(lx, ly, L);
if (n == 0) {
for (int i = 0; i < L; i++)
for (int j = 0; j < L; j++)
tmp[i][j] = data[lx+i][ly+j];
tW = L;
break;
}
for (int i = 0; i < L; i++) {
for (int j = 0; j < overlap; j++) {
energy[i][j] = round(data[lx+i][ly+j].dist(tmp[i][tW-overlap+j]));
}
}
int cost = shrink(path, L, overlap);
if (cost < threshold) {
for (int i = 0; i < L; i++) {
for (int j = L-1; j >= path[i]; j--)
tmp[i][tW-overlap+j] = data[lx+i][ly+j];
}
tW += L - overlap;
break;
}
}
if (it == ERROR_TRY)
return ;
}
W = tW, H = L;
for (int i = 0; i < H; i++)
for (int j = 0; j < W; j++)
data[i][j] = tmp[i][j];
}
int shrink(int path[], int H, int W) {
for (int i = 0; i < H; i++) {
for (int j = 0; j < W; j++) {
int &val = dp[i][j];
if (i == 0) val = energy[i][j];
else {
val = dp[i-1][j];
if (j-1 >= 0)
val = min(val, dp[i-1][j-1]);
if (j+1 < W)
val = min(val, dp[i-1][j+1]);
val += energy[i][j];
}
}
}
int st = 0, cost;
for (int i = 0; i < W; i++)
if (dp[H-1][i] < dp[H-1][st])
st = i;
cost = dp[H-1][st];
for (int i = H-1; i >= 0; i--) {
path[i] = st;
if (i == 0) continue;
int val = dp[i][st] - energy[i][st];
if (st-1 >= 0 && val == dp[i-1][st-1])
st = st-1;
else if (val == dp[i-1][st])
st = st;
else
st = st+1;
}
return cost;
}
void print() {
printf("%d %d\n", W, H);
for (int i = 0; i < H; i++)
for (int j = 0; j < W; j++)
data[i][j].print(), printf("%c", j == W-1 ? '\n' : ' ');
}
} test;
int main() {
int n, L, N;
scanf("%d %d", &L, &N);
test.read();
test.pattern(L, N);
test.print();
return 0;
}
Read More +

b438. 裁剪尺寸

Problem

要讓影像的寬度減少,其一的方案就是每一列刪除一個像素,為了讓刪除更加地完善、盡可能看不出來突兀的地方,採用一個一條由上而下路徑進行刪除。

刪除路徑採用最少費用,這個費用採用 sobel operator,也就是說費用越高表示可能是邊緣像素,減少突兀就是盡可能不要去刪除到邊緣像素,這是顯而易見的方案。接著就是採用貪心方案,依序刪除 n 次路徑。

Sample Input

1
2
3
0
1 1
255 255 255

Sample Output

1
2
1 1
255 255 255

Solution

題目說明最好的選擇方案是一次 n 條不相交的由上而下的路徑,但這個定義是總花費,還是最小化最大花費路徑這是有疑惑的,若單純總花費可以採用最小費用流去完成,但複雜度對這個影像處理會可怕,點數跟邊數都是破萬,複雜度 $O(VE)$ 的算法要跑非常久,因此貪心是個好選擇。

找尋花費最小的路徑是採用 dynamic programming 的方式得到,並且要回溯得到左側最小的一條路徑。每刪除一條路徑後,sobel operator 要重新計算每一個像素的異動,刷新這一塊可以指更動路徑的右側和鄰近路徑的像素即可,這可以大幅度地增加速度。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
#include <bits/stdc++.h>
using namespace std;
class IMAGE {
public:
struct Pixel {
int r, g, b;
Pixel(int x = 0, int y = 0, int z = 0):
r(x), g(y), b(z) {}
void read() {
scanf("%d %d %d", &r, &g, &b);
}
Pixel operator-(const Pixel &x) const {
return Pixel(r-x.r, g-x.g, b-x.b);
}
Pixel operator+(const Pixel &x) const {
return Pixel(r+x.r, g+x.g, b+x.b);
}
Pixel operator*(const int x) const {
return Pixel(r*x, g*x, b*x);
}
Pixel operator/(const int x) const {
return Pixel(r/x, g/x, b/x);
}
void print() {
printf("%d %d %d", r, g, b);
}
int length() {
return abs(r) + abs(g) + abs(b);
}
};
int W, H;
static const int MAXN = 256;
Pixel data[MAXN][MAXN];
int energy[MAXN][MAXN], dp[MAXN][MAXN];
void read() {
scanf("%d %d", &W, &H);
for (int i = 0; i < H; i++)
for (int j = 0; j < W; j++)
data[i][j].read();
}
Pixel pabs(Pixel x) {
return Pixel(fabs(x.r), fabs(x.g), fabs(x.b));
}
inline Pixel getPixel(int x, int y) {
if (x >= 0 && y >= 0 && x < H && y < W)
return data[x][y];
if (y < 0) return data[min(max(x, 0), H-1)][0];
if (y >= W) return data[min(max(x, 0), H-1)][W-1];
if (x < 0) return data[0][min(max(y, 0), W-1)];
if (x >= H) return data[H-1][min(max(y, 0), W-1)];
return Pixel(0, 0, 0);
}
int sobel(int i, int j) {
const static int dx[] = {-1, -1, -1, 0, 0, 0, 1, 1, 1};
const static int dy[] = {-1, 0, 1, -1, 0, 1, -1, 0, 1};
const static int xw[] = {-1, 0, 1, -2, 0, 2, -1, 0, 1};
const static int yw[] = {-1, -2, -1, 0, 0, 0, 1, 2, 1};
Pixel Dx(0, 0, 0), Dy(0, 0, 0);
for (int k = 0; k < 9; k++) {
if (xw[k])
Dx = Dx + getPixel(i+dx[k], j+dy[k]) * xw[k];
if (yw[k])
Dy = Dy + getPixel(i+dx[k], j+dy[k]) * yw[k];
}
return Dx.length() + Dy.length();
}
void shrink(int n) {
for (int i = 0; i < H; i++)
for (int j = 0; j < W; j++)
energy[i][j] = sobel(i, j);
int path[MAXN];
for (int it = 0; it < n; it++) {
shrink(path);
for (int i = 0; i < H; i++) {
int y = path[i];
for (int j = y - 3; j < W; j++) {
if (j >= 0)
energy[i][j] = sobel(i, j);
}
}
}
}
void shrink(int path[]) {
for (int i = 0; i < H; i++) {
for (int j = 0; j < W; j++) {
int &val = dp[i][j];
if (i == 0) val = energy[i][j];
else {
val = dp[i-1][j];
if (j-1 >= 0)
val = min(val, dp[i-1][j-1]);
if (j+1 < W)
val = min(val, dp[i-1][j+1]);
val += energy[i][j];
}
}
}
int st = 0;
for (int i = 0; i < W; i++)
if (dp[H-1][i] < dp[H-1][st])
st = i;
for (int i = H-1; i >= 0; i--) {
path[i] = st;
int v = dp[i][st] - energy[i][st];
for (Pixel *p = &data[i][st], *q = &data[i][st+1], *end = &data[i][W]; q != end; p++, q++)
*p = *q;
if (i == 0) continue;
if (st-1 >= 0 && v == dp[i-1][st-1])
st = st-1;
else if (v == dp[i-1][st])
st = st;
else
st = st+1;
}
W--;
}
void print() {
printf("%d %d\n", W, H);
for (int i = 0; i < H; i++)
for (int j = 0; j < W; j++)
data[i][j].print(), printf("%c", j == W-1 ? '\n' : ' ');
}
} test;
int main() {
int n;
scanf("%d", &n);
test.read();
test.shrink(n);
test.print();
return 0;
}
Read More +

b442. 快取實驗 矩陣乘法

Problem

背景

記得 b439: 快取置換機制 提到的快取置換機制嗎?現在來一場實驗吧!

題目描述

相信不少人都已經實作所謂的矩陣乘法,計算兩個方陣大小為 $N \times N$ 的矩陣 $A, B$。為了方便起見,提供一個偽隨機數的生成,減少在輸入處理浪費的時間,同時也減少上傳測資的辛苦。

根據種子 $c = S1$ 生成矩陣 $A$,種子 $c = S2$ 生成矩陣 $B$,計算矩陣相乘 $A \times B$,為了方便起見,使用 hash 函數進行簽章,最後輸出一個值。由於會牽涉到 overflow 問題,此題作為快取實驗就不考慮這個,overflow 問題大家都會相同。

請利用快取優勢修改代碼如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
#include <bits/stdc++.h>
using namespace std;
class Matrix {
public:
vector< vector<int> > data;
int row, col;
Matrix(int n = 1, int m = 1) {
data = vector< vector<int> >(n, vector<int>(m, 0));
row = n, col = m;
}
void rand_gen(int c) {
int x = 2, n = row * col;
for (int i = 0; i < row; i++) {
for (int j = 0; j < col; j++) {
x = ((long long) x * x + c + i + j)%n;
data[i][j] = x;
}
}
}
Matrix multiply(Matrix x) {
Matrix ret(row, x.col);
for (int i = 0; i < row; i++) {
for (int j = 0; j < x.col; j++) {
int sum = 0; // overflow
for (int k = 0; k < col; k++)
sum += data[i][k] * x.data[k][j];
ret.data[i][j] = sum;
}
}
return ret;
}
void print() {
for (int i = 0; i < row; i++) {
printf("[");
for (int j = 0; j < col; j++) {
printf(" %d", data[i][j]);
}
printf(" ]\n");
}
}
unsigned long signature() {
unsigned long h = 0;
for (int i = 0; i < row; i++) {
for (int j = 0; j < col; j++) {
h = hash(h + data[i][j]);
}
}
return h;
}
private:
unsigned long hash(unsigned long x) {
return (x * 2654435761LU);
}
};
int main() {
int N, S1, S2;
while (scanf("%d %d %d", &N, &S1, &S2) == 3) {
Matrix A(N, N), B(N, N), C;
A.rand_gen(S1);
B.rand_gen(S2);
C = A.multiply(B);
// A.print();
// B.print();
// C.print();
printf("%lu\n", C.signature());
}
return 0;
}

Sample Input

1
2
2 2 5
2 2 5

Sample Output

1
2
573770929
573770929

Solution

在越大的矩陣中,快取記憶體置換是很慢的,即便用 L1, L2, main memory 分層快取,速度差異很明顯,盡可能在線性 $O(n)$ 運算上都不造成記憶體置換 (搬移),雖然複雜度都是 $O(n^3)$,相信沒人會去實作 $O(n^{2.807})$ 的 Strassen 算法,又或者是現今更快的 Coppersmith-Winograd $O(n^{2.3727})$

作業系統題目第二彈,考驗快取應用。在 ZJ 主機上速度可以快二十倍,在本機上只快了十倍。

  1. 蔡神 asas 修改輸入,而我選擇了轉置,速度落後。
  2. 柳州 liouzhou_101 同步簽章,而我選擇分開函數,速度落後。
  3. 廖氏如神 pcsh710742 下刀計組,等等,我看見了什麼。

現在已經快了四十倍,就只是因為編譯器的優化參數變成手動。詳細實作和效能分析,需要的知識不只是作業系統,同時涵蓋計算機組織的 CPU stall 問題,參考 Optimizing Large Matrix-Vector Multiplications

先來個一般解,在計算矩陣 $A \times B$ 前,先將 $B$ 轉置,接著修改計算方式

1
2
for (int k = 0; k < col; k++)
sum += data[i][k] * x.data[j][k];

在這一個迴圈中,陣列採用 row-major 的方式儲存,所以第二維度的區域基本上都在快取中,miss 的機會大幅度降低。

接下來要提案的做法是結合上述三位的做法,代碼快,但不能當 Matrix 模板使用,僅僅做為一個效能體現。共用記憶體,採用指針的方式減少複製,一樣進行轉置,但這會破壞原有的 $B$ 矩陣配置。同步進行簽章,捨棄掉 $C$ 矩陣的儲存,使用 LOOP UNROLLING,由於 Zerojudge 的優化沒有開到 -O2-O3-Ofast,關於這一部分的優化由使用者自己來。

LOOP UNROLLING 的優化在於 branch 指令,由於 pipeline 會讓效能提升,但遇到 branch 時必須捨棄掉偷偷載入的指令、算到一半的結果,效能會下降,因此使用 LOOP UNROLLING 的好處在於減少 branch 次數。

關於 data prefetch 可以用 __builtin_prefetch() 來完成,根據廖氏如神所言,這個概念可以預先載入防止 stall (pipeline hazard) 的拖延,速度並不會提升太多,有可能是硬體已經完成這一部分,甚至用別的架構去克服。

轉置解

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
#include <bits/stdc++.h>
using namespace std;
class Matrix {
public:
vector< vector<int> > data;
int row, col;
Matrix(int n = 1, int m = 1) {
data = vector< vector<int> >(n, vector<int>(m, 0));
row = n, col = m;
}
void rand_gen(int c) {
int x = 2, n = row * col;
for (int i = 0; i < row; i++) {
for (int j = 0; j < col; j++) {
x = ((long long) x * x + c + i + j)%n;
data[i][j] = x;
}
}
}
Matrix multiply(Matrix &x) {
Matrix ret(row, x.col);
x.transpose();
for (int i = 0; i < row; i++) {
for (int j = 0; j < x.col; j++) {
int sum = 0; // overflow
int *a = &data[i][0], *b = &x.data[j][0];
for (int k = 0; k < col; k++)
sum += *a * *b, a++, b++;
ret.data[i][j] = sum;
}
}
x.transpose();
return ret;
}
void transpose() {
for (int i = 0; i < row; i++)
for (int j = i+1; j < col; j++)
swap(data[i][j], data[j][i]);
}
void print() {
for (int i = 0; i < row; i++) {
printf("[");
for (int j = 0; j < col; j++)
printf(" %d", data[i][j]);
printf(" ]\n");
}
}
unsigned long signature() {
unsigned long h = 0;
for (int i = 0; i < row; i++)
for (int j = 0; j < col; j++)
h = hash(h + data[i][j]);
return h;
}
private:
inline unsigned long hash(unsigned long x) {
return (x * 2654435761LU);
}
};
int main() {
int N, S1, S2;
while (scanf("%d %d %d", &N, &S1, &S2) == 3) {
Matrix A(N, N), B(N, N), C;
A.rand_gen(S1);
B.rand_gen(S2);
C = A.multiply(B);
printf("%lu\n", C.signature());
}
return 0;
}

LOOP UNROLL 解

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
#include <bits/stdc++.h>
using namespace std;
#define LOOP_UNROLL 8
class Matrix {
public:
vector< vector<int> > data;
int row, col;
Matrix(int n = 1, int m = 1) {
row = n, col = m;
data = vector< vector<int> >(n, vector<int>(m, 0));
}
void rand_gen(int c) {
int x = 2, n = row * col;
for (int i = 0; i < row; i++) {
for (int j = 0; j < col; j++) {
x = ((long long) x * x + c + i + j)%n;
data[i][j] = x;
}
}
}
unsigned long multiply(Matrix &x) {
x.transpose();
unsigned long h = 0;
for (int i = 0; i < row; i++) {
for (int j = 0; j < x.col; j++) {
register int sum = 0;
int *a = &data[i][0], *b = &x.data[j][0], k;
for (k = 0; k+LOOP_UNROLL < col; k += LOOP_UNROLL) {
sum += *a * *b, a++, b++;
sum += *a * *b, a++, b++;
sum += *a * *b, a++, b++;
sum += *a * *b, a++, b++;
sum += *a * *b, a++, b++;
sum += *a * *b, a++, b++;
sum += *a * *b, a++, b++;
sum += *a * *b, a++, b++;
}
for (; k < col; k++)
sum += *a * *b, a++, b++;
h = hash(h + sum);
}
}
return h;
}
void transpose() {
for (int i = 0; i < row; i++)
for (int j = i+1; j < col; j++)
swap(data[i][j], data[j][i]);
}
void print() {
for (int i = 0; i < row; i++) {
printf("[");
for (int j = 0; j < col; j++)
printf(" %d", data[i][j]);
printf(" ]\n");
}
}
private:
unsigned long hash(unsigned long x) {
return (x * 2654435761LU);
}
} A(1000, 1000), B(1000, 1000);
int main() {
int N, S1, S2;
while (scanf("%d %d %d", &N, &S1, &S2) == 3) {
A.row = A.col = B.row = B.col = N;
A.rand_gen(S1);
B.rand_gen(S2);
printf("%lu\n", A.multiply(B));
}
return 0;
}
Read More +

b439. 快取置換機制

Problem

背景

編寫程式不僅僅是在數學分析上達到高效率,儘管數學分析的複雜度不是最好,理解電腦運作模式也能有效地讓程式變快。例如 資料結構 Unrolled linked list (常翻譯成 塊狀鏈表) 便是利用此一優勢,讓速度顯著地提升,之所以能追上不少常數大的平衡樹操作運用的技巧就是快取效能改善。

講一個更簡單的例子,宣告整數陣列的兩個方案:

方案一

1
const int LARGE_SIZE = 1<<30; int A[LARGE_SIZE], B[LARGE_SIZE];

方案二

1
const int LARGE_SIZE = 1<<30; struct DATA { int A, B; } E[LARGE_SIZE];

演算法的複雜度倘若一樣,若在 A, B 相當高頻率的替換,則快取操作必須不斷地將內容置換,若 A, B 在算法中是獨立運算,則方案一的寫法會來得更好,反之取用方案二會比較好。最常見的運作可以在軟體模擬矩陣乘法中見到,預先將矩陣轉置,利用快取優勢速度可以快上 8 倍以上。

問題描述

給予一個記憶體空間大小為 $M$,使用 Least Recently Used (LRU) 策略進行置換,LRU 的策略為將最久沒有被使用過的空間替換掉,也就是需要從硬碟讀取 $disk[i]$$memory$ 時,發現記憶體都已經用光,則把不常用的 $mem[j]$ 寫入 $disk[k]$,再將 $disk[i]$ 內容寫入 $mem[j]$

下圖是一個 $M = 4$ 簡單的範例:

1
2
3
4
5
6
7
8
9
10
+---+ +---+ +---+ +---+ +---+ +---+ +---+ +---+
mem[0] | 1 | | 1 | | 1 | | 1 | | 1 | | 1 | | 1 | | 1 |
+---+ +---+ +---+ +---+ +---+ +---+ +---+ +---+
mem[1] | | | 2 | | 2 | | 2 | | 2 | | 2 | | 2 | | 2 |
+---+ +-> +---+ +-> +---+ +-> +---+ +-> +---+ +-> +---+ +-> +---+ +-> +---+
mem[2] | | | | | 3 | | 3 | | 3 | | 3 | | 5 | | 5 |
+---+ +---+ +---+ +---+ +---+ +---+ +---+ +---+
mem[3] | | | | | | | 4 | | 4 | | 4 | | 4 | | 3 |
+---+ +---+ +---+ +---+ +---+ +---+ +---+ +---+
1 2 3 4 1 2 5 3

依序使用 1, 2, 3, 4, 1, 2, 5, 3 的配置情況,特別是在 5 使用的時候,會將記憶體中上一次最晚使用的 3 替換掉。

現在寫程式去模擬置換情況。

Sample Input

1
2
3
4
5
6
7
8
9
10
11
12
4
0 1
1 2 514
1 3 101
1 4 50216
0 1
0 2
1 5 6
0 3
0 5
0 2

Sample Output

1
2
3
4
5
6
0 0
0 0
1 514
3 0
2 6
1 514

Solution

軟體仿作,使用一個 hash 和一個 list 進行維護,而不是使用 priority queue 來維護,用 list 取代之。當進行取址時,順道把對應 list 指針移動,所有步驟期望複雜度都是在 $O(1)$ 完成,若使用 priority queue 會掉到 $O(\log n)$ 去進行更新。

換出作業系統題目,來一個最常見到的快取 LRU 算法。這個問題在 Leetcode OJ 上面也有,公司面試有機會遇到。

但實作時被自己坑,比較柳柳州給予的測試,速度居然慢個二十多倍。原因是這樣子的

1
2
map[key] = value;
return map.find(key);

替換成以下寫法,速度就起飛了。

1
return map.insert({key, value}).first;

也就是說插入的時候順便把指針拿到,避免第二次搜索。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
#include <bits/stdc++.h>
using namespace std;
typedef pair<int, list<int>::iterator> PIT_TYPE;
typedef unordered_map<int, pair<int, list<int>::iterator> >::iterator HIT_TYPE;
class LRUCache {
public:
int memIdx, size;
list<int> time;
vector<int> mem;
unordered_map<int, PIT_TYPE> addr;
LRUCache(int capacity) {
size = capacity;
mem.resize(capacity, 0);
addr.clear(), time.clear();
memIdx = 0;
}
pair<int, int> get(int key) {
it = addr.find(key);
if (it == addr.end())
it = replace(key), mem[it->second.first] = 0;
it->second.second = recent(it->second.second);
return make_pair(it->second.first, mem[it->second.first]);
}
HIT_TYPE set(int key, int value) {
it = addr.find(key);
if (it == addr.end())
it = replace(key);
it->second.second = recent(it->second.second);
mem[it->second.first] = value;
return it;
}
private:
HIT_TYPE it;
HIT_TYPE replace(int key) {
int mpos = -1, trash;
list<int>::iterator lpos;
if (addr.size() == size) {
trash = time.front(), time.pop_front();
it = addr.find(trash);
mpos = it->second.first;
addr.erase(it);
} else {
mpos = memIdx++;
}
lpos = time.insert(time.end(), key);
return addr.insert({key, make_pair(mpos, lpos)}).first;
}
list<int>::iterator recent(list<int>::iterator p) {
int key = *p;
time.erase(p);
return time.insert(time.end(), key);
}
};
int main() {
int M, cmd, x, y;
pair<int, int> t;
scanf("%d", &M);
LRUCache LL(M);
while (scanf("%d", &cmd) == 1) {
if (cmd == 0) {
scanf("%d", &x);
pair<int, int> t = LL.get(x);
printf("%d %d\n", t.first, t.second);
} else {
scanf("%d %d", &x, &y);
LL.set(x, y);
}
}
return 0;
}
Read More +

b435. 尋找原根

Problem

給定一個模數 $m$,找到其 primitive root $g$

primitive root $g$ 滿足

$$g^{\phi(m)} \equiv 1 \mod m \\ g^{\gamma} \not\equiv 1 \mod m \text{ , for } 1 \le \gamma < \phi(m)$$

意即在模 $m$ 下,循環節 $ord_m(g) = \phi(m)$,若 $m$ 是質數,則 $\phi(m) = m-1$,也就是說循環節長度是 $m-1$,更可怕的是組成一個 $[1, m-1]$ 的數字排列。

現在假定 $m$ 是質數,請求出一個最小的 primitive root $g$

Sample Input

1
2
3
4
2
3
5
7

Sample Output

1
2
3
4
1
2
2
3

Solution

primitive root 在密碼學中,用在 Diffie-Hellman 的選定中的穩定度,倘若基底是一個 primitive root 將會提升破解的難度,因為對應情況大幅增加 (值域比較廣,因為不循環)。

由於歐拉定理,可以得知 $a^{\phi(p)} \equiv 1 \mod p$,那麼可以堆導得到若 $a^x \equiv 1 \mod p$,則滿足 $x | \phi(p)$,否則與歐拉矛盾,藉此可以作為一個篩選的檢查手法,對於所有 $phi(p)$ 的因數都進行檢查,最後我們可以得到 一般檢測 的作法。

為了加速驗證,由於 $x | \phi(p)$,若最小的 $x$ 存在 $a^x \equiv 1 \mod p$,那麼 $kx$ 也會符合等式,因此只需要檢查 $x = (p-1)/\text{prime-factor}$ 的情況,如此一來速度就快上非常多。最後得到 加速版本 ,此做法由 liouzhou_101 提供想法。

一般檢測

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
#include <bits/stdc++.h>
using namespace std;
#define MAXL (50000>>5)+1
#define GET(x) (mark[(x)>>5]>>((x)&31)&1)
#define SET(x) (mark[(x)>>5] |= 1<<((x)&31))
int mark[MAXL];
int P[5500], Pt = 0;
void sieve() {
register int i, j, k;
SET(1);
int n = 45825;
for(i = 2; i <= n; i++) {
if(!GET(i)) {
for(k = n/i, j = i*k; k >= i; k--, j -= i)
SET(j);
P[Pt++] = i;
}
}
}
vector< pair<int, int> > factor(int n) {
vector< pair<int, int> > R;
for(int i = 0, j; i < Pt && P[i] * P[i] <= n; i++) {
if(n%P[i] == 0) {
for(j = 0; n%P[i] == 0; n /= P[i], j++);
R.push_back(make_pair(P[i], j));
}
}
if(n != 1) R.push_back(make_pair(n, 1));
return R;
}
void gen_factor(int idx, int x, vector< pair<int, int> > &f, vector<int> &g) {
if (idx == f.size()) {
g.push_back(x);
return ;
}
for (int i = 0, a = 1; i <= f[idx].second; i++, a *= f[idx].first)
gen_factor(idx+1, x*a, f, g);
}
long long mpow(long long x, long long y, long long mod) {
long long ret = 1;
while (y) {
if (y&1)
ret = (ret * x)%mod;
y >>= 1, x = (x * x)%mod;
}
return ret % mod;
}
int primitive_root(int p) {
if (p == 2) return 1;
vector< pair<int, int> > f = factor(p-1);
vector<int> g;
gen_factor(0, 1, f, g);
g.erase(g.begin()), g.pop_back(); // remove 1, p-1
random_shuffle(g.begin(), g.end());
for (int i = 2; i <= p; i++) {
int ok = 1;
for (auto &x: g) {
if (mpow(i, x, p) == 1) {
ok = 0;
break;
}
}
if (ok) return i;
}
return -1;
}
int main() {
sieve();
int p;
while (scanf("%d", &p) == 1)
printf("%d\n", primitive_root(p));
return 0;
}

加速版本

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
#include <bits/stdc++.h>
using namespace std;
#define MAXL (50000>>5)+1
#define GET(x) (mark[(x)>>5]>>((x)&31)&1)
#define SET(x) (mark[(x)>>5] |= 1<<((x)&31))
int mark[MAXL];
int P[5500], Pt = 0;
void sieve() {
register int i, j, k;
SET(1);
int n = 45825;
for(i = 2; i <= n; i++) {
if(!GET(i)) {
for(k = n/i, j = i*k; k >= i; k--, j -= i)
SET(j);
P[Pt++] = i;
}
}
}
vector< pair<int, int> > factor(int n) {
vector< pair<int, int> > R;
for(int i = 0, j; i < Pt && P[i] * P[i] <= n; i++) {
if(n%P[i] == 0) {
for(j = 0; n%P[i] == 0; n /= P[i], j++);
R.push_back(make_pair(P[i], j));
}
}
if(n != 1) R.push_back(make_pair(n, 1));
return R;
}
long long mpow(long long x, long long y, long long mod) {
long long ret = 1;
while (y) {
if (y&1)
ret = (ret * x)%mod;
y >>= 1, x = (x * x)%mod;
}
return ret % mod;
}
int primitive_root(int p) {
if (p == 2) return 1;
vector< pair<int, int> > f = factor(p-1);
for (int i = 2; i <= p; i++) {
int ok = 1;
for (auto &x: f) {
if (mpow(i, (p-1)/x.first, p) == 1) {
ok = 0;
break;
}
}
if (ok) return i;
}
return -1;
}
int main() {
sieve();
int p;
while (scanf("%d", &p) == 1)
printf("%d\n", primitive_root(p));
return 0;
}
Read More +