b440. 互質對

Problem

給定 $n$$m$,請你統計有序對 $(a,b)$ 的個數,其中 $1 \le a \le n, 1 \le b \le m$$a$$b$ 互質。

$a$$b$ 互質的定義是:$a$$b$ 的最大公約數等於 $1$

> count coprime pair

## Sample Input ##
1
2
3
2
3 4
10000000 10000000

Sample Output

1
2
9
60792712854483

Solution

前言

柳柳州出數論基礎-「莫比烏斯反演 (Möbius inversion formula)」

之前放棄看完的內容又要撿回來看,百般痛苦,之前沒看懂啊。但能知道的是莫比烏斯反演類似排容原理,就像錯排依樣。先不管莫比烏斯,看到「線性時間求出所有乘法反元素」真的可以嗎?那也是件很有趣的事情。

參考資源

  • 莫比烏斯參考《線性篩法與積性函數》-賈志鵬 link
  • 分塊優化參考《POI XIV Stage.1 Queries Zap 解题报告》-Kwc Oliver link

基礎定義

了解莫比烏斯反演之前,要先介紹積性函數

積性函數

  • 何謂積性函數 (Multiplicative function)?
    對於定義域 $\mathbb{N}^+$ 的函數 $f$ 而言,任兩個互質的 $gcd(a, b) = 1$ 正整數 $a, b$ 滿足 $f(ab) = f(a)f(b)$
  • 何謂完全積性函數?
    $f$ 是一個積性函數,同時滿足 $f(p^n) = f(p)^n$
  • $f(n), g(n)$ 是積性函數,則 $h(n) = f(n) g(n)$ 也是積性函數。
  • $f(n)$ 為積性函數,則函數 $g(n) = \sum_{d|n} f(d)$ 也是積性函數。

歐拉函數

回顧歐拉函數 (Euler’s totient function) $\phi$

  • 定義 $\phi(n)$$1 \cdots n$ 中與 $n$ 互質的個數。
  • $\phi(n)$ 是一個積性函數,但不是完全積性函數。根據中國餘數定理 或者 基本算術可以證明之。

  • 歐拉定理 $a^{\phi(n)} \equiv 1 \mod n, \text{ when } gcd(a, n) = 1$

  • 特性 1:$\sum_{d|n} \phi(d) = n$。當作把互質個數 $\phi(n)$ 相加不互質個數 $s$,由於 $d$$n$ 的因數,則與 $d$ 互質 $x$ 個數 $\phi(d)$,把那些 $x' = x \times d/n$ 就會補上那些不互質的個數 $s = |set(x')|$

  • 特性 2:$1 \cdots n$$n$ 互質的數和為 $n\phi(n)/2$。原因很簡單,若 $gcd(x, n) = 1$,則會滿足 $gcd(n-x, n) = 1$,看起來就是一個對稱總和。

莫比烏斯

根據積性函數的性質,我們得到莫比烏斯反演的基礎:

莫比烏斯反演公式 $f(n)$

$f(n) = \sum_{d|n} \mu(d) g(\frac{n}{d})$

莫比烏斯函數 $\mu(n)$

$$\mu(n) = \left\{\begin{matrix} 1 && n = 1\\ (-1)^k && n = p_1 p_2 \cdots p_k \\ 0 && \text{otherwise} \end{matrix}\right.$$
  • 特性 1:$\sum_{d|n} \mu(d) = [n = 1]$

在此簡單說一次,莫比烏斯反演公式就像排容原理,而莫比烏斯函數 $\mu(n)$ 就像一加一減,之所以在 $n = p_1 p_2 \cdots p_k$$\mu(n) = (-1)^k$,也就是說當 $n$ 只全由質數相乘 (不允許冪次方大於 1,否則為 0),相當於一般認知,奇數次要扣,偶數次要補回來的排容口訣。而特定自然不必多說,其中數學表達 $[n = 1]$ 相當於程式中的 n == 1 ? 1 : 0

簡單展示歐拉函數 $\phi(n)$

  • $f(n) = \sum_{d|n} \phi(d) = n$,由歐拉函數特性 2 得知。
    *$\phi(n) = \sum_{d|n} \mu(d) f(\frac{n}{d}) = \sum_{d|n} \frac{\mu(d) n}{d}$,套用莫比烏斯公式後整理。

關於公式$\phi(n) = \sum_{d|n} \frac{\mu(d) n}{d}$ 可以這麼理解。

  • 求出 $gcd(n, k) = 1$ 的個數,其餘要捨棄掉。
  • 計算 $gcd(n, k) = d$ 的個數,利用排容原理即莫比烏斯函數,顯而易見排容方案只當 $d$$n$ 個質因數組合而成。

應用此題

利用莫比烏斯函數 特性 1:$\sum_{d|n} \mu(d) = [n = 1]$

$$\begin{align*} & \sum_{a = 1}^{N} \sum_{b = 1}^{M} [gcd(a, b) = 1] \\ & = \sum_{a = 1}^{N} \sum_{b = 1}^{M} \sum_{d|gcd(a, b)} \mu(d) \\ & = \sum \mu(d) \left ( \sum_{1 \le a \le N \text{ and } d | a} \left ( \sum_{1 \le b \le M \text{ and } d | b} 1 \right ) \right )\\ & = \sum \mu(d) \left \lfloor \frac{n}{d} \right \rfloor \left \lfloor \frac{m}{d} \right \rfloor \end{align*}$$

第二行就是抽換莫比烏斯,第三行則是交換順序,第四行則是可以快速找到 $d|a$ 的總數為 $\lfloor n/d \rfloor$ 所導致。

若直接窮舉 $d$ 時間複雜度為 $O(min(n, m))$,可以利用塊狀的概念優化到 $O(\sqrt{n} + \sqrt{m})$。因為 $\lfloor n/d \rfloor$ 的值只有 $2\lfloor \sqrt{n} \rfloor$ 種可能,同理 $\lfloor m/d \rfloor$ 也是,那麼對於同一個 $d$ 的計數會是一個連續區間不斷地移動。

以下是一個簡單的 $\lfloor n/d \rfloor$ 示意圖,用 $\lfloor n/d \rfloor$ 不同當作劃分。

1
2
3
4
5
6
7
8
9
10
11
X X X X X X X XX X X
X X X X X X X XX X X
+----------------------------------------------------+
| | | | | | | | | n
+----------------------------------------------------+
012345678901234567890123456789012345678901234567890123
+---------------------------------------------------------------------------+
| | | | | | | | | | | | m
+---------------------------------------------------------------------------+
X X X X X X X XX X X
X X X X X X X XX X X

程式只要處理打 X 的位置即可,時間複雜度 $O(\sqrt{n} + \sqrt{m})$,可以參照一般解。代碼短,但除法次數多。

加速一般解由 liouzhou_101 提供。開一次根號 sqrt(),省下 $O(\sqrt{n})$ 次的除法,賺了 400 ms,代碼長了 800 B。加速 25% 的效能,可謂除法的可怕,根據研究差異後才發現到這一點。

一般解

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
#include <bits/stdc++.h>
using namespace std;
#define GET(x) (mark[(x)>>5]>>((x)&31)&1)
#define SET(x) (mark[(x)>>5] |= 1<<((x)&31))
const int MAXN = 10000005;
const int MAXL = (MAXN>>5)+1;
int mark[MAXL];
int P[700000], Pt = 0;
short mu[MAXN], sum[MAXN];
void sieve_mobius() {
register int i, j, k;
SET(1), mu[1] = 1;
int n = 10000000;
for (i = 2; i <= n; i++) {
if (!GET(i))
P[Pt++] = i, mu[i] = -1;
for (j = 0; j < Pt && (k = i*P[j]) <= n; j++) {
SET(k);
if (i%P[j] == 0) {
mu[k] = 0;
break;
}
mu[k] = -mu[i];
}
}
}
long long coprime_pair(int n, int m) {
long long ret = 0;
if (n > m) swap(n, m);
for (int d = 1, r; d <= n; d = r+1) {
r = min(n / (n/d), m / (m/d));
ret += (long long)(sum[r] - sum[d-1]) * (n/d) * (m/d);
}
return ret;
}
int main() {
sieve_mobius();
for (int i = 1; i < MAXN; i++)
sum[i] = sum[i-1] + mu[i];
int testcase, N, M;
scanf("%d", &testcase);
while (testcase--) {
scanf("%d %d", &N, &M);
printf("%lld\n", coprime_pair(N, M));
}
return 0;
}

加速運算解

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
#include <bits/stdc++.h>
using namespace std;
#define GET(x) (mark[(x)>>5]>>((x)&31)&1)
#define SET(x) (mark[(x)>>5] |= 1<<((x)&31))
const int MAXN = 10000005;
const int SQRN = 3200 * 2;
const int MAXL = (MAXN>>5)+1;
int mark[MAXL];
int P[700000], Pt = 0;
short mu[MAXN], sum[MAXN];
void sieve_mobius() {
register int i, j, k;
SET(1), mu[1] = 1;
int n = 10000000;
for (i = 2; i <= n; i++) {
if (!GET(i))
P[Pt++] = i, mu[i] = -1;
for (j = 0; j < Pt && (k = i*P[j]) <= n; j++) {
SET(k);
if (i%P[j] == 0) {
mu[k] = 0;
break;
}
mu[k] = -mu[i];
}
}
}
long long coprime_pair(int n, int m) {
static int An[SQRN][2], Bn[SQRN][2];
if (n > m) swap(n, m);
long long ret = 0;
int aidx = 0, bidx = 0, sq;
sq = sqrt(n);
for (int i = 1; i <= sq; i++, aidx++) An[aidx][1] = n/(An[aidx][0] = n / i);
if (sq * sq == n) aidx--;
for (int i = sq; i >= 1; i--, aidx++) An[aidx][1] = n/(An[aidx][0] = i);
sq = sqrt(m);
for (int i = 1; i <= sq; i++, bidx++) Bn[bidx][1] = m/(Bn[bidx][0] = m / i);
if (sq * sq == m) bidx--;
for (int i = sq; i >= 1; i--, bidx++) Bn[bidx][1] = m/(Bn[bidx][0] = i);
for (int l = 1, r, *a = &An[0][1], *b = &Bn[0][1], *A = &An[0][0], *B = &Bn[0][0]; l <= n; l = r+1) {
if (*a < *b)
r = *a, ret += (long long) (sum[r] - sum[l-1]) * *A * *B, A+=2, a+=2;
else if (*a > *b)
r = *b, ret += (long long) (sum[r] - sum[l-1]) * *A * *B, B+=2, b+=2;
else
r = *a, ret += (long long) (sum[r] - sum[l-1]) * *A * *B, A+=2, B+=2, a+=2, b+=2;
}
return ret;
}
int main() {
sieve_mobius();
for (int i = 1; i < MAXN; i++)
sum[i] = sum[i-1] + mu[i];
int testcase, N, M;
scanf("%d", &testcase);
while (testcase--) {
scanf("%d %d", &N, &M);
printf("%lld\n", coprime_pair(N, M));
}
return 0;
}
Read More +

b441. 延展尺寸

Problem

將 N 張小正方形拼成長條圖,並且每一次挑選一個正方形區域黏貼在長條後,黏貼的條件是比較重疊一半的區域,利用能量最少路徑進行裁剪後併在一起。

如果最少能量大於等於某個值,則再隨機挑選一個正方形區域進行黏貼,捨棄掉這次黏貼操作,若嘗試 50 次沒有成功低於閥值,則取消全部的黏貼操作。

當初的問題在於 連續 50 次 看得似懂非懂,然後還以為是要碎形圖,只找一個正方形去重疊得到 N 個。

Sample Input

1
2
3
4
2 1
2 2
1 2 3 4 5 6
7 8 9 10 11 12

Sample Output

1
2
3
2 2
1 2 3 4 5 6
7 8 9 10 11 12

Solution

接續上一題 b438. 裁剪尺寸 的作法,現在多一個串接和重疊的操作。

感謝蔡星 asas 對此題描述的解說。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
#include <bits/stdc++.h>
using namespace std;
class IMAGE {
public:
struct Pixel {
int r, g, b;
Pixel(int x = 0, int y = 0, int z = 0):
r(x), g(y), b(z) {}
void read() {
scanf("%d %d %d", &r, &g, &b);
}
Pixel operator-(const Pixel &x) const {
return Pixel(r-x.r, g-x.g, b-x.b);
}
Pixel operator+(const Pixel &x) const {
return Pixel(r+x.r, g+x.g, b+x.b);
}
Pixel operator*(const double x) const {
return Pixel(r*x, g*x, b*x);
}
Pixel operator/(const double x) const {
return Pixel(r/x, g/x, b/x);
}
void print() {
printf("%d %d %d", r, g, b);
}
int length() {
return abs(r) + abs(g) + abs(b);
}
double dist(Pixel x) {
return sqrt((r-x.r)*(r-x.r)+(g-x.g)*(g-x.g)+(b-x.b)*(b-x.b));
}
};
int W, H;
static const int MAXN = 256;
Pixel data[MAXN][MAXN*3], tmp[MAXN][MAXN*3];
int energy[MAXN][MAXN], dp[MAXN][MAXN];
long long seed;
int random() {
return seed = ( seed * 9301 + 49297 ) % 233280;
}
void getSquarePosition(int &x, int &y, int L) {
y = (W <= L) ? 0 : random() % (W - L);
x = (H <= L) ? 0 : random() % (H - L);
}
void read() {
scanf("%d %d", &W, &H);
for (int i = 0; i < H; i++)
for (int j = 0; j < W; j++)
data[i][j].read();
seed = 0;
}
int isValid(int x, int y) {
return x >= 0 && y >= 0 && x < H && y < W;
}
void pattern(int L, int N) {
int ERROR_TRY = 50;
int threshold = round(L*255/8.0);
int overlap = round(L/2.0);
int lx, ly, tW = 0, path[MAXN] = {};
for (int n = 0, it; n < N; n++) {
for (it = 0; it < ERROR_TRY; it++) {
getSquarePosition(lx, ly, L);
if (n == 0) {
for (int i = 0; i < L; i++)
for (int j = 0; j < L; j++)
tmp[i][j] = data[lx+i][ly+j];
tW = L;
break;
}
for (int i = 0; i < L; i++) {
for (int j = 0; j < overlap; j++) {
energy[i][j] = round(data[lx+i][ly+j].dist(tmp[i][tW-overlap+j]));
}
}
int cost = shrink(path, L, overlap);
if (cost < threshold) {
for (int i = 0; i < L; i++) {
for (int j = L-1; j >= path[i]; j--)
tmp[i][tW-overlap+j] = data[lx+i][ly+j];
}
tW += L - overlap;
break;
}
}
if (it == ERROR_TRY)
return ;
}
W = tW, H = L;
for (int i = 0; i < H; i++)
for (int j = 0; j < W; j++)
data[i][j] = tmp[i][j];
}
int shrink(int path[], int H, int W) {
for (int i = 0; i < H; i++) {
for (int j = 0; j < W; j++) {
int &val = dp[i][j];
if (i == 0) val = energy[i][j];
else {
val = dp[i-1][j];
if (j-1 >= 0)
val = min(val, dp[i-1][j-1]);
if (j+1 < W)
val = min(val, dp[i-1][j+1]);
val += energy[i][j];
}
}
}
int st = 0, cost;
for (int i = 0; i < W; i++)
if (dp[H-1][i] < dp[H-1][st])
st = i;
cost = dp[H-1][st];
for (int i = H-1; i >= 0; i--) {
path[i] = st;
if (i == 0) continue;
int val = dp[i][st] - energy[i][st];
if (st-1 >= 0 && val == dp[i-1][st-1])
st = st-1;
else if (val == dp[i-1][st])
st = st;
else
st = st+1;
}
return cost;
}
void print() {
printf("%d %d\n", W, H);
for (int i = 0; i < H; i++)
for (int j = 0; j < W; j++)
data[i][j].print(), printf("%c", j == W-1 ? '\n' : ' ');
}
} test;
int main() {
int n, L, N;
scanf("%d %d", &L, &N);
test.read();
test.pattern(L, N);
test.print();
return 0;
}
Read More +

b438. 裁剪尺寸

Problem

要讓影像的寬度減少,其一的方案就是每一列刪除一個像素,為了讓刪除更加地完善、盡可能看不出來突兀的地方,採用一個一條由上而下路徑進行刪除。

刪除路徑採用最少費用,這個費用採用 sobel operator,也就是說費用越高表示可能是邊緣像素,減少突兀就是盡可能不要去刪除到邊緣像素,這是顯而易見的方案。接著就是採用貪心方案,依序刪除 n 次路徑。

Sample Input

1
2
3
0
1 1
255 255 255

Sample Output

1
2
1 1
255 255 255

Solution

題目說明最好的選擇方案是一次 n 條不相交的由上而下的路徑,但這個定義是總花費,還是最小化最大花費路徑這是有疑惑的,若單純總花費可以採用最小費用流去完成,但複雜度對這個影像處理會可怕,點數跟邊數都是破萬,複雜度 $O(VE)$ 的算法要跑非常久,因此貪心是個好選擇。

找尋花費最小的路徑是採用 dynamic programming 的方式得到,並且要回溯得到左側最小的一條路徑。每刪除一條路徑後,sobel operator 要重新計算每一個像素的異動,刷新這一塊可以指更動路徑的右側和鄰近路徑的像素即可,這可以大幅度地增加速度。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
#include <bits/stdc++.h>
using namespace std;
class IMAGE {
public:
struct Pixel {
int r, g, b;
Pixel(int x = 0, int y = 0, int z = 0):
r(x), g(y), b(z) {}
void read() {
scanf("%d %d %d", &r, &g, &b);
}
Pixel operator-(const Pixel &x) const {
return Pixel(r-x.r, g-x.g, b-x.b);
}
Pixel operator+(const Pixel &x) const {
return Pixel(r+x.r, g+x.g, b+x.b);
}
Pixel operator*(const int x) const {
return Pixel(r*x, g*x, b*x);
}
Pixel operator/(const int x) const {
return Pixel(r/x, g/x, b/x);
}
void print() {
printf("%d %d %d", r, g, b);
}
int length() {
return abs(r) + abs(g) + abs(b);
}
};
int W, H;
static const int MAXN = 256;
Pixel data[MAXN][MAXN];
int energy[MAXN][MAXN], dp[MAXN][MAXN];
void read() {
scanf("%d %d", &W, &H);
for (int i = 0; i < H; i++)
for (int j = 0; j < W; j++)
data[i][j].read();
}
Pixel pabs(Pixel x) {
return Pixel(fabs(x.r), fabs(x.g), fabs(x.b));
}
inline Pixel getPixel(int x, int y) {
if (x >= 0 && y >= 0 && x < H && y < W)
return data[x][y];
if (y < 0) return data[min(max(x, 0), H-1)][0];
if (y >= W) return data[min(max(x, 0), H-1)][W-1];
if (x < 0) return data[0][min(max(y, 0), W-1)];
if (x >= H) return data[H-1][min(max(y, 0), W-1)];
return Pixel(0, 0, 0);
}
int sobel(int i, int j) {
const static int dx[] = {-1, -1, -1, 0, 0, 0, 1, 1, 1};
const static int dy[] = {-1, 0, 1, -1, 0, 1, -1, 0, 1};
const static int xw[] = {-1, 0, 1, -2, 0, 2, -1, 0, 1};
const static int yw[] = {-1, -2, -1, 0, 0, 0, 1, 2, 1};
Pixel Dx(0, 0, 0), Dy(0, 0, 0);
for (int k = 0; k < 9; k++) {
if (xw[k])
Dx = Dx + getPixel(i+dx[k], j+dy[k]) * xw[k];
if (yw[k])
Dy = Dy + getPixel(i+dx[k], j+dy[k]) * yw[k];
}
return Dx.length() + Dy.length();
}
void shrink(int n) {
for (int i = 0; i < H; i++)
for (int j = 0; j < W; j++)
energy[i][j] = sobel(i, j);
int path[MAXN];
for (int it = 0; it < n; it++) {
shrink(path);
for (int i = 0; i < H; i++) {
int y = path[i];
for (int j = y - 3; j < W; j++) {
if (j >= 0)
energy[i][j] = sobel(i, j);
}
}
}
}
void shrink(int path[]) {
for (int i = 0; i < H; i++) {
for (int j = 0; j < W; j++) {
int &val = dp[i][j];
if (i == 0) val = energy[i][j];
else {
val = dp[i-1][j];
if (j-1 >= 0)
val = min(val, dp[i-1][j-1]);
if (j+1 < W)
val = min(val, dp[i-1][j+1]);
val += energy[i][j];
}
}
}
int st = 0;
for (int i = 0; i < W; i++)
if (dp[H-1][i] < dp[H-1][st])
st = i;
for (int i = H-1; i >= 0; i--) {
path[i] = st;
int v = dp[i][st] - energy[i][st];
for (Pixel *p = &data[i][st], *q = &data[i][st+1], *end = &data[i][W]; q != end; p++, q++)
*p = *q;
if (i == 0) continue;
if (st-1 >= 0 && v == dp[i-1][st-1])
st = st-1;
else if (v == dp[i-1][st])
st = st;
else
st = st+1;
}
W--;
}
void print() {
printf("%d %d\n", W, H);
for (int i = 0; i < H; i++)
for (int j = 0; j < W; j++)
data[i][j].print(), printf("%c", j == W-1 ? '\n' : ' ');
}
} test;
int main() {
int n;
scanf("%d", &n);
test.read();
test.shrink(n);
test.print();
return 0;
}
Read More +

b442. 快取實驗 矩陣乘法

Problem

背景

記得 b439: 快取置換機制 提到的快取置換機制嗎?現在來一場實驗吧!

題目描述

相信不少人都已經實作所謂的矩陣乘法,計算兩個方陣大小為 $N \times N$ 的矩陣 $A, B$。為了方便起見,提供一個偽隨機數的生成,減少在輸入處理浪費的時間,同時也減少上傳測資的辛苦。

根據種子 $c = S1$ 生成矩陣 $A$,種子 $c = S2$ 生成矩陣 $B$,計算矩陣相乘 $A \times B$,為了方便起見,使用 hash 函數進行簽章,最後輸出一個值。由於會牽涉到 overflow 問題,此題作為快取實驗就不考慮這個,overflow 問題大家都會相同。

請利用快取優勢修改代碼如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
#include <bits/stdc++.h>
using namespace std;
class Matrix {
public:
vector< vector<int> > data;
int row, col;
Matrix(int n = 1, int m = 1) {
data = vector< vector<int> >(n, vector<int>(m, 0));
row = n, col = m;
}
void rand_gen(int c) {
int x = 2, n = row * col;
for (int i = 0; i < row; i++) {
for (int j = 0; j < col; j++) {
x = ((long long) x * x + c + i + j)%n;
data[i][j] = x;
}
}
}
Matrix multiply(Matrix x) {
Matrix ret(row, x.col);
for (int i = 0; i < row; i++) {
for (int j = 0; j < x.col; j++) {
int sum = 0; // overflow
for (int k = 0; k < col; k++)
sum += data[i][k] * x.data[k][j];
ret.data[i][j] = sum;
}
}
return ret;
}
void print() {
for (int i = 0; i < row; i++) {
printf("[");
for (int j = 0; j < col; j++) {
printf(" %d", data[i][j]);
}
printf(" ]\n");
}
}
unsigned long signature() {
unsigned long h = 0;
for (int i = 0; i < row; i++) {
for (int j = 0; j < col; j++) {
h = hash(h + data[i][j]);
}
}
return h;
}
private:
unsigned long hash(unsigned long x) {
return (x * 2654435761LU);
}
};
int main() {
int N, S1, S2;
while (scanf("%d %d %d", &N, &S1, &S2) == 3) {
Matrix A(N, N), B(N, N), C;
A.rand_gen(S1);
B.rand_gen(S2);
C = A.multiply(B);
// A.print();
// B.print();
// C.print();
printf("%lu\n", C.signature());
}
return 0;
}

Sample Input

1
2
2 2 5
2 2 5

Sample Output

1
2
573770929
573770929

Solution

在越大的矩陣中,快取記憶體置換是很慢的,即便用 L1, L2, main memory 分層快取,速度差異很明顯,盡可能在線性 $O(n)$ 運算上都不造成記憶體置換 (搬移),雖然複雜度都是 $O(n^3)$,相信沒人會去實作 $O(n^{2.807})$ 的 Strassen 算法,又或者是現今更快的 Coppersmith-Winograd $O(n^{2.3727})$

作業系統題目第二彈,考驗快取應用。在 ZJ 主機上速度可以快二十倍,在本機上只快了十倍。

  1. 蔡神 asas 修改輸入,而我選擇了轉置,速度落後。
  2. 柳州 liouzhou_101 同步簽章,而我選擇分開函數,速度落後。
  3. 廖氏如神 pcsh710742 下刀計組,等等,我看見了什麼。

現在已經快了四十倍,就只是因為編譯器的優化參數變成手動。詳細實作和效能分析,需要的知識不只是作業系統,同時涵蓋計算機組織的 CPU stall 問題,參考 Optimizing Large Matrix-Vector Multiplications

先來個一般解,在計算矩陣 $A \times B$ 前,先將 $B$ 轉置,接著修改計算方式

1
2
for (int k = 0; k < col; k++)
sum += data[i][k] * x.data[j][k];

在這一個迴圈中,陣列採用 row-major 的方式儲存,所以第二維度的區域基本上都在快取中,miss 的機會大幅度降低。

接下來要提案的做法是結合上述三位的做法,代碼快,但不能當 Matrix 模板使用,僅僅做為一個效能體現。共用記憶體,採用指針的方式減少複製,一樣進行轉置,但這會破壞原有的 $B$ 矩陣配置。同步進行簽章,捨棄掉 $C$ 矩陣的儲存,使用 LOOP UNROLLING,由於 Zerojudge 的優化沒有開到 -O2-O3-Ofast,關於這一部分的優化由使用者自己來。

LOOP UNROLLING 的優化在於 branch 指令,由於 pipeline 會讓效能提升,但遇到 branch 時必須捨棄掉偷偷載入的指令、算到一半的結果,效能會下降,因此使用 LOOP UNROLLING 的好處在於減少 branch 次數。

關於 data prefetch 可以用 __builtin_prefetch() 來完成,根據廖氏如神所言,這個概念可以預先載入防止 stall (pipeline hazard) 的拖延,速度並不會提升太多,有可能是硬體已經完成這一部分,甚至用別的架構去克服。

轉置解

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
#include <bits/stdc++.h>
using namespace std;
class Matrix {
public:
vector< vector<int> > data;
int row, col;
Matrix(int n = 1, int m = 1) {
data = vector< vector<int> >(n, vector<int>(m, 0));
row = n, col = m;
}
void rand_gen(int c) {
int x = 2, n = row * col;
for (int i = 0; i < row; i++) {
for (int j = 0; j < col; j++) {
x = ((long long) x * x + c + i + j)%n;
data[i][j] = x;
}
}
}
Matrix multiply(Matrix &x) {
Matrix ret(row, x.col);
x.transpose();
for (int i = 0; i < row; i++) {
for (int j = 0; j < x.col; j++) {
int sum = 0; // overflow
int *a = &data[i][0], *b = &x.data[j][0];
for (int k = 0; k < col; k++)
sum += *a * *b, a++, b++;
ret.data[i][j] = sum;
}
}
x.transpose();
return ret;
}
void transpose() {
for (int i = 0; i < row; i++)
for (int j = i+1; j < col; j++)
swap(data[i][j], data[j][i]);
}
void print() {
for (int i = 0; i < row; i++) {
printf("[");
for (int j = 0; j < col; j++)
printf(" %d", data[i][j]);
printf(" ]\n");
}
}
unsigned long signature() {
unsigned long h = 0;
for (int i = 0; i < row; i++)
for (int j = 0; j < col; j++)
h = hash(h + data[i][j]);
return h;
}
private:
inline unsigned long hash(unsigned long x) {
return (x * 2654435761LU);
}
};
int main() {
int N, S1, S2;
while (scanf("%d %d %d", &N, &S1, &S2) == 3) {
Matrix A(N, N), B(N, N), C;
A.rand_gen(S1);
B.rand_gen(S2);
C = A.multiply(B);
printf("%lu\n", C.signature());
}
return 0;
}

LOOP UNROLL 解

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
#include <bits/stdc++.h>
using namespace std;
#define LOOP_UNROLL 8
class Matrix {
public:
vector< vector<int> > data;
int row, col;
Matrix(int n = 1, int m = 1) {
row = n, col = m;
data = vector< vector<int> >(n, vector<int>(m, 0));
}
void rand_gen(int c) {
int x = 2, n = row * col;
for (int i = 0; i < row; i++) {
for (int j = 0; j < col; j++) {
x = ((long long) x * x + c + i + j)%n;
data[i][j] = x;
}
}
}
unsigned long multiply(Matrix &x) {
x.transpose();
unsigned long h = 0;
for (int i = 0; i < row; i++) {
for (int j = 0; j < x.col; j++) {
register int sum = 0;
int *a = &data[i][0], *b = &x.data[j][0], k;
for (k = 0; k+LOOP_UNROLL < col; k += LOOP_UNROLL) {
sum += *a * *b, a++, b++;
sum += *a * *b, a++, b++;
sum += *a * *b, a++, b++;
sum += *a * *b, a++, b++;
sum += *a * *b, a++, b++;
sum += *a * *b, a++, b++;
sum += *a * *b, a++, b++;
sum += *a * *b, a++, b++;
}
for (; k < col; k++)
sum += *a * *b, a++, b++;
h = hash(h + sum);
}
}
return h;
}
void transpose() {
for (int i = 0; i < row; i++)
for (int j = i+1; j < col; j++)
swap(data[i][j], data[j][i]);
}
void print() {
for (int i = 0; i < row; i++) {
printf("[");
for (int j = 0; j < col; j++)
printf(" %d", data[i][j]);
printf(" ]\n");
}
}
private:
unsigned long hash(unsigned long x) {
return (x * 2654435761LU);
}
} A(1000, 1000), B(1000, 1000);
int main() {
int N, S1, S2;
while (scanf("%d %d %d", &N, &S1, &S2) == 3) {
A.row = A.col = B.row = B.col = N;
A.rand_gen(S1);
B.rand_gen(S2);
printf("%lu\n", A.multiply(B));
}
return 0;
}
Read More +

b439. 快取置換機制

Problem

背景

編寫程式不僅僅是在數學分析上達到高效率,儘管數學分析的複雜度不是最好,理解電腦運作模式也能有效地讓程式變快。例如 資料結構 Unrolled linked list (常翻譯成 塊狀鏈表) 便是利用此一優勢,讓速度顯著地提升,之所以能追上不少常數大的平衡樹操作運用的技巧就是快取效能改善。

講一個更簡單的例子,宣告整數陣列的兩個方案:

方案一

1
const int LARGE_SIZE = 1<<30; int A[LARGE_SIZE], B[LARGE_SIZE];

方案二

1
const int LARGE_SIZE = 1<<30; struct DATA { int A, B; } E[LARGE_SIZE];

演算法的複雜度倘若一樣,若在 A, B 相當高頻率的替換,則快取操作必須不斷地將內容置換,若 A, B 在算法中是獨立運算,則方案一的寫法會來得更好,反之取用方案二會比較好。最常見的運作可以在軟體模擬矩陣乘法中見到,預先將矩陣轉置,利用快取優勢速度可以快上 8 倍以上。

問題描述

給予一個記憶體空間大小為 $M$,使用 Least Recently Used (LRU) 策略進行置換,LRU 的策略為將最久沒有被使用過的空間替換掉,也就是需要從硬碟讀取 $disk[i]$$memory$ 時,發現記憶體都已經用光,則把不常用的 $mem[j]$ 寫入 $disk[k]$,再將 $disk[i]$ 內容寫入 $mem[j]$

下圖是一個 $M = 4$ 簡單的範例:

1
2
3
4
5
6
7
8
9
10
+---+ +---+ +---+ +---+ +---+ +---+ +---+ +---+
mem[0] | 1 | | 1 | | 1 | | 1 | | 1 | | 1 | | 1 | | 1 |
+---+ +---+ +---+ +---+ +---+ +---+ +---+ +---+
mem[1] | | | 2 | | 2 | | 2 | | 2 | | 2 | | 2 | | 2 |
+---+ +-> +---+ +-> +---+ +-> +---+ +-> +---+ +-> +---+ +-> +---+ +-> +---+
mem[2] | | | | | 3 | | 3 | | 3 | | 3 | | 5 | | 5 |
+---+ +---+ +---+ +---+ +---+ +---+ +---+ +---+
mem[3] | | | | | | | 4 | | 4 | | 4 | | 4 | | 3 |
+---+ +---+ +---+ +---+ +---+ +---+ +---+ +---+
1 2 3 4 1 2 5 3

依序使用 1, 2, 3, 4, 1, 2, 5, 3 的配置情況,特別是在 5 使用的時候,會將記憶體中上一次最晚使用的 3 替換掉。

現在寫程式去模擬置換情況。

Sample Input

1
2
3
4
5
6
7
8
9
10
11
12
4
0 1
1 2 514
1 3 101
1 4 50216
0 1
0 2
1 5 6
0 3
0 5
0 2

Sample Output

1
2
3
4
5
6
0 0
0 0
1 514
3 0
2 6
1 514

Solution

軟體仿作,使用一個 hash 和一個 list 進行維護,而不是使用 priority queue 來維護,用 list 取代之。當進行取址時,順道把對應 list 指針移動,所有步驟期望複雜度都是在 $O(1)$ 完成,若使用 priority queue 會掉到 $O(\log n)$ 去進行更新。

換出作業系統題目,來一個最常見到的快取 LRU 算法。這個問題在 Leetcode OJ 上面也有,公司面試有機會遇到。

但實作時被自己坑,比較柳柳州給予的測試,速度居然慢個二十多倍。原因是這樣子的

1
2
map[key] = value;
return map.find(key);

替換成以下寫法,速度就起飛了。

1
return map.insert({key, value}).first;

也就是說插入的時候順便把指針拿到,避免第二次搜索。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
#include <bits/stdc++.h>
using namespace std;
typedef pair<int, list<int>::iterator> PIT_TYPE;
typedef unordered_map<int, pair<int, list<int>::iterator> >::iterator HIT_TYPE;
class LRUCache {
public:
int memIdx, size;
list<int> time;
vector<int> mem;
unordered_map<int, PIT_TYPE> addr;
LRUCache(int capacity) {
size = capacity;
mem.resize(capacity, 0);
addr.clear(), time.clear();
memIdx = 0;
}
pair<int, int> get(int key) {
it = addr.find(key);
if (it == addr.end())
it = replace(key), mem[it->second.first] = 0;
it->second.second = recent(it->second.second);
return make_pair(it->second.first, mem[it->second.first]);
}
HIT_TYPE set(int key, int value) {
it = addr.find(key);
if (it == addr.end())
it = replace(key);
it->second.second = recent(it->second.second);
mem[it->second.first] = value;
return it;
}
private:
HIT_TYPE it;
HIT_TYPE replace(int key) {
int mpos = -1, trash;
list<int>::iterator lpos;
if (addr.size() == size) {
trash = time.front(), time.pop_front();
it = addr.find(trash);
mpos = it->second.first;
addr.erase(it);
} else {
mpos = memIdx++;
}
lpos = time.insert(time.end(), key);
return addr.insert({key, make_pair(mpos, lpos)}).first;
}
list<int>::iterator recent(list<int>::iterator p) {
int key = *p;
time.erase(p);
return time.insert(time.end(), key);
}
};
int main() {
int M, cmd, x, y;
pair<int, int> t;
scanf("%d", &M);
LRUCache LL(M);
while (scanf("%d", &cmd) == 1) {
if (cmd == 0) {
scanf("%d", &x);
pair<int, int> t = LL.get(x);
printf("%d %d\n", t.first, t.second);
} else {
scanf("%d %d", &x, &y);
LL.set(x, y);
}
}
return 0;
}
Read More +

b435. 尋找原根

Problem

給定一個模數 $m$,找到其 primitive root $g$

primitive root $g$ 滿足

$$g^{\phi(m)} \equiv 1 \mod m \\ g^{\gamma} \not\equiv 1 \mod m \text{ , for } 1 \le \gamma < \phi(m)$$

意即在模 $m$ 下,循環節 $ord_m(g) = \phi(m)$,若 $m$ 是質數,則 $\phi(m) = m-1$,也就是說循環節長度是 $m-1$,更可怕的是組成一個 $[1, m-1]$ 的數字排列。

現在假定 $m$ 是質數,請求出一個最小的 primitive root $g$

Sample Input

1
2
3
4
2
3
5
7

Sample Output

1
2
3
4
1
2
2
3

Solution

primitive root 在密碼學中,用在 Diffie-Hellman 的選定中的穩定度,倘若基底是一個 primitive root 將會提升破解的難度,因為對應情況大幅增加 (值域比較廣,因為不循環)。

由於歐拉定理,可以得知 $a^{\phi(p)} \equiv 1 \mod p$,那麼可以堆導得到若 $a^x \equiv 1 \mod p$,則滿足 $x | \phi(p)$,否則與歐拉矛盾,藉此可以作為一個篩選的檢查手法,對於所有 $phi(p)$ 的因數都進行檢查,最後我們可以得到 一般檢測 的作法。

為了加速驗證,由於 $x | \phi(p)$,若最小的 $x$ 存在 $a^x \equiv 1 \mod p$,那麼 $kx$ 也會符合等式,因此只需要檢查 $x = (p-1)/\text{prime-factor}$ 的情況,如此一來速度就快上非常多。最後得到 加速版本 ,此做法由 liouzhou_101 提供想法。

一般檢測

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
#include <bits/stdc++.h>
using namespace std;
#define MAXL (50000>>5)+1
#define GET(x) (mark[(x)>>5]>>((x)&31)&1)
#define SET(x) (mark[(x)>>5] |= 1<<((x)&31))
int mark[MAXL];
int P[5500], Pt = 0;
void sieve() {
register int i, j, k;
SET(1);
int n = 45825;
for(i = 2; i <= n; i++) {
if(!GET(i)) {
for(k = n/i, j = i*k; k >= i; k--, j -= i)
SET(j);
P[Pt++] = i;
}
}
}
vector< pair<int, int> > factor(int n) {
vector< pair<int, int> > R;
for(int i = 0, j; i < Pt && P[i] * P[i] <= n; i++) {
if(n%P[i] == 0) {
for(j = 0; n%P[i] == 0; n /= P[i], j++);
R.push_back(make_pair(P[i], j));
}
}
if(n != 1) R.push_back(make_pair(n, 1));
return R;
}
void gen_factor(int idx, int x, vector< pair<int, int> > &f, vector<int> &g) {
if (idx == f.size()) {
g.push_back(x);
return ;
}
for (int i = 0, a = 1; i <= f[idx].second; i++, a *= f[idx].first)
gen_factor(idx+1, x*a, f, g);
}
long long mpow(long long x, long long y, long long mod) {
long long ret = 1;
while (y) {
if (y&1)
ret = (ret * x)%mod;
y >>= 1, x = (x * x)%mod;
}
return ret % mod;
}
int primitive_root(int p) {
if (p == 2) return 1;
vector< pair<int, int> > f = factor(p-1);
vector<int> g;
gen_factor(0, 1, f, g);
g.erase(g.begin()), g.pop_back(); // remove 1, p-1
random_shuffle(g.begin(), g.end());
for (int i = 2; i <= p; i++) {
int ok = 1;
for (auto &x: g) {
if (mpow(i, x, p) == 1) {
ok = 0;
break;
}
}
if (ok) return i;
}
return -1;
}
int main() {
sieve();
int p;
while (scanf("%d", &p) == 1)
printf("%d\n", primitive_root(p));
return 0;
}

加速版本

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
#include <bits/stdc++.h>
using namespace std;
#define MAXL (50000>>5)+1
#define GET(x) (mark[(x)>>5]>>((x)&31)&1)
#define SET(x) (mark[(x)>>5] |= 1<<((x)&31))
int mark[MAXL];
int P[5500], Pt = 0;
void sieve() {
register int i, j, k;
SET(1);
int n = 45825;
for(i = 2; i <= n; i++) {
if(!GET(i)) {
for(k = n/i, j = i*k; k >= i; k--, j -= i)
SET(j);
P[Pt++] = i;
}
}
}
vector< pair<int, int> > factor(int n) {
vector< pair<int, int> > R;
for(int i = 0, j; i < Pt && P[i] * P[i] <= n; i++) {
if(n%P[i] == 0) {
for(j = 0; n%P[i] == 0; n /= P[i], j++);
R.push_back(make_pair(P[i], j));
}
}
if(n != 1) R.push_back(make_pair(n, 1));
return R;
}
long long mpow(long long x, long long y, long long mod) {
long long ret = 1;
while (y) {
if (y&1)
ret = (ret * x)%mod;
y >>= 1, x = (x * x)%mod;
}
return ret % mod;
}
int primitive_root(int p) {
if (p == 2) return 1;
vector< pair<int, int> > f = factor(p-1);
for (int i = 2; i <= p; i++) {
int ok = 1;
for (auto &x: f) {
if (mpow(i, (p-1)/x.first, p) == 1) {
ok = 0;
break;
}
}
if (ok) return i;
}
return -1;
}
int main() {
sieve();
int p;
while (scanf("%d", &p) == 1)
printf("%d\n", primitive_root(p));
return 0;
}
Read More +

b432. 趣味加分

Problem

背景

廖氏如神 (pcsh710742) 在作業中遇到一題,用手算會算到天昏地暗,而問題是這樣子的:

$a^{13337} \equiv n \mod 2^{64}$

其中 $n= 2015 + 2 \times (\text{The last 2 digit of your student ID number})$,請找到 $a < 2^{64}$ 的其中一組解。

問題描述

模數 $2^{64}$ 看起來不大,但對於 Ghz 為 CPU 運算速度單位的電腦而言還是要跑非常久的。因此將問題簡化:

$a^{23333} \equiv n \mod 2^{20}$

現在給予一個 $n$,請求出一組 $a$,測資中保證答案唯一。

Sample Input

1
2
3
4
5
268275
888817
89215
63495
976477

Sample Output

1
2
3
4
5
387
817
639
487
909

Solution

除了偶數無解外,奇數都至少有一個解。而這一題的題目數據恰好奇數都只有一解,那麼就不必處理多組解或者是字典順序最小的,只要專心找到符合的解即可。

  • 暴力建表 $O(2^{20})$ 建完,之後直接查找。
  • 篩選正解,依序窮舉從最低位到最高位 (二進制下) 為 0 還是 1,由於次方會不斷地推移,高位結果不影響低位的對應。窮舉時保留低位符合的解,並且不斷地篩選掉不可能的解方案,複雜度 $O(20 k)$$k$ 是難以估計的數字。
  • 快速假解,類似篩選正解的做法,但只保留其中一組解進行,複雜度 $O(20)$。這個解法是有點毛病的,但目前找不到反例。

快速假解

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
#include <bits/stdc++.h>
using namespace std;
long long mul(long long a, long long b, long long mod) {
long long ret = 0;
for (a %= mod, b %= mod; b != 0; b>>=1, a <<= 1, a = a >= mod ? a - mod : a) {
if (b&1) {
ret += a;
if (ret >= mod) ret -= mod;
}
}
return ret;
}
unsigned long long mpow(unsigned long long x, unsigned long long y, unsigned long long mod) {
unsigned long long ret = 1;
while (y) {
if (y&1)
ret = (ret * x)%mod;
y >>= 1, x = (x * x)%mod;
}
return ret % mod;
}
// find a^23333 = n \mod 2^32
int main() {
const long long M = 1LL<<20;
const long long E = 23333;
long long n;
while (scanf("%lld", &n) == 1) {
long long a = 0;
for (int i = 0; i < 32; i++) {
long long t = mpow(a, E, M), mask = (1LL<<(i+1))-1;
if ((t&mask) == (n&mask)) {
} else {
a |= 1LL<<i;
}
}
printf("%lld\n", a);
}
return 0;
}

篩選正解

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
#include <bits/stdc++.h>
using namespace std;
unsigned long long mpow(unsigned long long x, unsigned long long y, unsigned long long mod) {
unsigned long long ret = 1;
while (y) {
if (y&1)
ret = (ret * x)%mod;
y >>= 1, x = (x * x)%mod;
}
return ret % mod;
}
// find a^23333 = n \mod 2^32
int main() {
const long long M = 1LL<<20;
const long long E = 23333;
long long n;
while (scanf("%lld", &n) == 1) {
vector<long long> A(1, 0);
for (int i = 0; i < 20; i++) {
vector<long long> nA;
for (auto &a : A) {
long long t, mask;
t = mpow(a, E, M), mask = (1LL<<(i+1))-1;
if ((t&mask) == (n&mask)) {
nA.push_back(a);
}
t = mpow(a|(1LL<<i), E, M), mask = (1LL<<(i+1))-1;
if ((t&mask) == (n&mask)) {
nA.push_back(a|(1LL<<i));
}
}
A = nA;
}
assert(A.size() > 0 && mpow(A[0], E, M) == n);
printf("%lld\n", A[0]);
}
return 0;
}
Read More +

b431. 中國餘數

Problem

背景

中國餘數定理 (Chinese Remainder Theorem,簡稱 CRT) 經常是工程學裡面常用的一種轉換域,很多人不知道當初在大學離散數學中學這個做什麼,但是在不少計算的設計都會運用到 CRT。由於電腦 CPU 架構中的運算單位是 32-bits 或者 64-bits (也許在未來會更長),但值域高達 128-bits 或者 512-bits 以上模擬運算成了麻煩之處。

回顧中國餘數定理 CRT

$$(S): \left\{\begin{matrix} x \equiv a_1 \mod m_1 \\ x \equiv a_2 \mod m_2 \\ \vdots \\ x \equiv a_n \mod m_n \end{matrix}\right.$$
  • $m_1, m_2, \cdots , m_n$ 任兩數互質,意即 $\forall i \neq j, gcd(m_i, m_j) = 1$
  • 對於任意整數 $a_1, a_2, \cdots , a_n$ 方程組 $(S)$ 均有解,意即找得到一個 $x$ 的參數解。

構造法解 CRT

  1. $M = m_1 \times m_2 \times \cdots \times m_n = \prod_{i=1}^{n} m_i$
  2. $M_i = M / m_i$
  3. $t_i = M_i^{-1} \mod m_i<span>$,意即$</span><!-- Has MathJax -->t_i M_i \equiv 1 \mod m_i$
  4. 方程組 $(S)$ 的通解形式為: $x = a_1 t_1 M_1 + a_2 t_2 M_2 + \cdots + a_n t_n M_n + kM = kM + \sum_{i = 1}^{n} a_i t_i M_i, k \in \mathbb{Z}$
  5. 若限定 $0 \le x < M$,則 $x$ 只有一解。

很多人會納悶通解為什麼長那樣,原因很簡單,要滿足方程組每一條式子,勢必對於$a_i t_i M_i$ 要滿足$x \equiv a_i \mod m_i$ ,因此 $a_i t_i M_i \equiv a_i (t_i M_i) \mod m_i \equiv a_i \mod m_i$ 成立,但是$\forall i \neq j$,滿足$a_i t_i M_i \equiv a_i (t_i M_i) \mod m_j \equiv 0 \mod m_j$

問題描述

來個簡單運用,來計算簡單的 RSA 加解密,特化其中的數學運算。

$M \equiv C^d \mod n$ $n = p \times q$,其中 $p, q$ 是兩個不同的質數,已知 $C, d, p, q$,請求出 $M$

## Sample Input ##
1
2
88 7 17 11
11 23 17 11

Sample Output

1
2
11
88

Solution

RSA 可以預先處理

  • $c_p = q \times (q^{-1} \mod p)$
  • $c_q = p \times (p^{-1} \mod q)$

還原的算法則是 $M = Mp \times c_p + Mq \times c_q \mod N$

由於拆分後的 bit length 少一半,乘法速度快 4 倍,快速冪次方快 2 倍 (次方的 bit length 少一半),但是要算 2 次,最後共計快 4 倍。CPU 的乘法想必不會用快速傅立葉 FFT 來達到乘法速度為 $O(n \log n)$

特別小心「 bit length 少一半 」必須在 $gcd(C, p) = 1$ 時才成立,互質機率機率非常高,仍然有不成立的時候,這情況下速度不是加快 400%。請參照一般解。

例如 C = 27522, d = 17132, p = 2, q = 17293,若使用歐拉定理計算 $Mp = C^{d \mod (p-1)} \mod p = 1$ 事實上 $Mp = C^d \mod p = 0$。有一個特性尚未被利用,對於模質數 $p$ 而言,所有數 $x$ 在模 $p$ 下的循環長度 $L | (p-1)$,最後可以套用 $Mp = C^{d \mod (p-1) + (p-1)} \mod p$ 來完成。請參照循環解,如此一來就不必先判斷 $gcd(C, p) = 1$

番外

但是利用 CRT 計算容易受到硬體上攻擊,因為會造成 $p, q$ 在分解過程中獨立出現,當初利用 $N$ 很難被分解的特性來達到資訊安全,但是卻因為加速把 $p, q$ 存放道不同時刻的暫存器中。

其中一種攻擊,計算得到 $M = Mp \times q \times (q^{-1} \mod p) + Mq \times p \times (p^{-1} \mod q) \mod N$ 當擾亂後面的式子 (提供不穩定的計算結果)。得到 $M' = Mp \times q \times (q^{-1} \mod p) + (Mq + \Delta) \times p \times (p^{-1} \mod q) \mod N$

接著 $(M' - M) = (\Delta' \times p) \mod N$,若要求 $p$ 的方法為 $gcd(M' - M, N) = gcd(\Delta' \times p, N) = p$,輾轉相除的概念呼之欲出,原來 $p$ 會被這麼夾出,當得到兩個 $p, q$,RSA 算法就會被破解。

一般解

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
#include <bits/stdc++.h>
using namespace std;
long long mul(long long a, long long b, long long mod) {
long long ret = 0;
for (a %= mod, b %= mod; b != 0; b>>=1, a <<= 1, a = a >= mod ? a - mod : a) {
if (b&1) {
ret += a;
if (ret >= mod) ret -= mod;
}
}
return ret;
}
void exgcd(long long x, long long y, long long &g, long long &a, long long &b) {
if (y == 0)
g = x, a = 1, b = 0;
else
exgcd(y, x%y, g, b, a), b -= (x/y) * a;
}
long long llgcd(long long x, long long y) {
long long t;
while (x%y)
t = x, x = y, y = t%y;
return y;
}
long long inverse(long long x, long long p) {
long long g, b, r;
exgcd(x, p, g, r, b);
if (g < 0) r = -r;
return (r%p + p)%p;
}
long long mpow(long long x, long long y, long long mod) {
long long ret = 1;
while (y) {
if (y&1)
ret = (ret * x)%mod;
y >>= 1, x = (x * x)%mod;
}
return ret % mod;
}
int main() {
long long C, d, p, q;
long long N, M, Cp, Cq, Mp, Mq;
while (scanf("%lld %lld %lld %lld", &C, &d, &p, &q) == 4) {
N = p * q;
Mp = mpow(C%p, llgcd(C, p) == 1 ? d%(p-1) : d, p);
Mq = mpow(C%q, llgcd(C, q) == 1 ? d%(q-1) : d, q);
Cp = q*inverse(q, p)%N;
Cq = p*inverse(p, q)%N;
M = (mul(Mp, Cp, N) + mul(Mq, Cq, N))%N;
printf("%lld\n", M);
}
return 0;
}

循環解

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
#include <bits/stdc++.h>
using namespace std;
typedef unsigned long long UINT64;
UINT64 mul(UINT64 a, UINT64 b, UINT64 mod) {
UINT64 ret = 0;
for (a = a >= mod ? a%mod : a, b = b >= mod ? b%mod : b; b != 0; b>>=1, a <<= 1, a = a >= mod ? a - mod : a) {
if (b&1) {
ret += a;
if (ret >= mod)
ret -= mod;
}
}
return ret;
}
void exgcd(long long x, long long y, long long &g, long long &a, long long &b) {
if (y == 0)
g = x, a = 1, b = 0;
else
exgcd(y, x%y, g, b, a), b -= (x/y) * a;
}
long long inverse(long long x, long long p) {
long long g, b, r;
exgcd(x, p, g, r, b);
if (g < 0) r = -r;
return (r%p + p)%p;
}
long long mpow(long long x, long long y, long long mod) {
long long ret = 1;
while (y) {
if (y&1)
ret = (ret * x)%mod;
y >>= 1, x = (x * x)%mod;
}
return ret % mod;
}
int main() {
long long C, d, p, q;
long long N, M, Cp, Cq, Mp, Mq;
while (scanf("%lld %lld %lld %lld", &C, &d, &p, &q) == 4) {
N = p * q;
Mp = mpow(C%p, d%(p-1) + (p-1), p);
Mq = mpow(C%q, d%(q-1) + (q-1), q);
Cp = mul(q, inverse(q, p), N);
Cq = mul(p, inverse(p, q), N);
M = (mul(Mp, Cp, N) + mul(Mq, Cq, N))%N;
printf("%lld\n", M);
}
return 0;
}
Read More +

b430. 簡單乘法

Problem

背景

在早期密碼世界中, 各種運算都先講求速度,不管是在硬體、軟體利用各種數學定義來設計加密算法就為了加快幾倍的速度,但在近代加密中,加速方法會造成硬體實作攻擊,速度和安全,你選擇哪一個呢。

題目描述

$$a b \equiv x \mod n$$

已知 $a, b, n$,求出 $x$

Sample Input

1
2
3
4
3 5 7
2 4 3
2 0 2
5 1 4

Sample Output

1
2
3
4
1
2
0
1

Solution

利用加法代替乘法避免溢位,加法取模換成減法加快速度。

由於這一題範圍在在 $10^{18}$,用 long long 型態沒有問題,若在 $2^{63}$ 請替換成 unsigned long long

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
#include <bits/stdc++.h>
using namespace std;
long long mul(long long a, long long b, long long mod) {
long long ret = 0;
for (a %= mod, b %= mod; b != 0; b>>=1, a <<= 1, a = a >= mod ? a - mod : a) {
if (b&1) {
ret += a;
if (ret >= mod)
ret -= mod;
}
}
return ret;
}
int main() {
long long a, b, n;
while (scanf("%lld %lld %lld", &a, &b, &n) == 3)
printf("%lld\n", mul(a, b, n));
return 0;
}
Read More +

b429. 離散對數

Problem

背景

高中時上過對數,了解 $a^x = b$,則 $x = \log_{a}{b}$。這個問題很簡單,但是 log() 又是怎麼運作,當時是用查表法,不久在大學就會學到泰勒級數,藉由電腦運算,計算越多次就能更逼近真正的結果。

離散對數的形式如下:

$$a^x \equiv b \mod n$$

已知 $a, b, n$,通常會設定 $0 \le a, b < n$。這問題的難處在於要怎麼解出 $x$,沒有 log() 可以這麼迅速得知。

為什麼需要離散對數?不少的近代加密算法的安全強度由這個問題的難度決定,例如 RSA 加密、Diffie-Hellman 金鑰交換 … 等,實際運用需要套用許多數論原理。然而,加密機制要保證解得回來,通常會保證 $gcd(a, n) = 1$,讓乘法反元素 (逆元) 存在。

問題描述

$$a^x \equiv b \mod n$$

已知 $a, b, n$,解出最小的 $x$,若不存在解則輸出 NOT FOUND

Sample Input

1
2
3
4
2 1 5
2 2 5
3 5 17
4 2 17

Sample Output

1
2
3
4
0
1
5
NOT FOUND

Solution

解決問題 $y = g^x \mod p$,當已知 $y, g, p$,要解出 $x$ 的難度大為提升,不像國高中學的指數計算,可以藉由 log() 運算來完成,離散對數可以的複雜度相當高,當 $p$ 是一個相當大的整數時,通常會取用到 256 bits 以上,複雜度則會在 $O(2^{100})$ 以上。

實際上有一個有趣的算法 body-step, giant-step algorithm,中文翻譯為 小步大步算法 ,在 ACM-ICPC 競賽中也可以找到一些題目來玩玩,算法的時間複雜度為 $O(\sqrt{p})$,空間複雜度也是 $O(\sqrt{p})$。相信除了這個外,還有更好的算法可以完成。

小步大步算法其實很類似塊狀表的算法,分塊處理,每一塊的大小為 $\sqrt{p}$,為了找尋答案計算其中一塊的所有資訊,每一塊就是一小步,接著就是利用數學運算,拉動數線,把這一塊往前推動 (或者反過來把目標搜尋結果相對往塊的地方推動)。因此需要走 $\sqrt{p}$ 大步完成。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
#include <bits/stdc++.h>
using namespace std;
// Baby-step Giant-step Algorithm
// a x + by = g
void exgcd(long long x, long long y, long long &g, long long &a, long long &b) {
if (y == 0)
g = x, a = 1, b = 0;
else
exgcd(y, x%y, g, b, a), b -= (x/y) * a;
}
long long inverse(long long x, long long p) {
long long g, b, r;
exgcd(x, p, g, r, b);
if (g < 0) r = -r;
return (r%p + p)%p;
}
long long BSGS(long long P, long long B, long long N) {
unordered_map<long long, int> R;
long long sq = (long long) sqrt(P);
long long t = 1, f;
for (int i = 0; i < sq; i++) {
if (t == N)
return i;
if (!R.count(t))
R[t] = i;
t = (t * B) % P;
}
f = inverse(t, P);
for (int i = 0; i <= sq+1; i++) {
if (R.count(N))
return i * sq + R[N];
N = (N * f) % P;
}
return -1;
}
int main() {
long long P, B, N; // find B^L = N mod P
while (scanf("%lld %lld %lld", &B, &N, &P) == 3) {
long long L = BSGS(P, B, N);
if (L == -1)
puts("NOT FOUND");
else
printf("%lld\n", L);
}
return 0;
}
Read More +