b454. 困難版 輸出 RGB 數值

Problem

給一張 $45 \times 45$ 的經典 Lena 彩色影像,共有 $2025$ 個像素,程式碼長度上限 10K,避開打表的長度限制,輸出一模一樣的圖形。

Sample Input

1
NO INPUT

Sample Output

1
2
45 45
225 134 119 223 130 108 220 126 105 (後面省略)

Solution

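本題由 Zerojudge b454.「請輸出這張圖片的 RGB 數值」改編,強迫使用 base64 的方案,因為一般的 16 進制編碼會超過長度限制。以 16 進制表示時,共有 $45 \times 45 \times 3 = 6075$ 個 0 - 255 的整數,每個整數需要 2 個字元,共計 $12150$ 個可視字元;改用 base64 時,每 3 個位元組編成 4 個字元,只需要 $6075 / 3 \times 4 = 8100$ 個字元。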

根據前一題的實驗,雖然霍夫曼編碼壓縮後的資料會比較短,但是還要附加解壓縮的代碼一起上傳,除非寫短碼,否則很容易虧本。lz77 雖然是不錯的壓縮方案,但用在影像上也很容易虧本,因為影像中重複片段的比例並不高。因此最後選擇直接使用 base64,可視字元數量就可以降到 10K 以下,接下來就比誰的短碼能力好。

base64 只用到 0-9a-zA-Z+/ 這 64 個字元(外加填充用的 =)。在一般 C/C++ 的字串宣告語法中,跳逸字元如 \\、\t、\n … 等必須用兩個以上的字元表示,需要跳逸的通常是 ASCII [0, 31] 的控制字元,以及 " 與 \ 本身。蔡神 asas 直接使用 ASCII 中一段連續的 64 個字元來編碼,因此會出現一些跳逸字元;儘管每個跳逸字元占用兩個以上的字符,總長度會多好幾個字元,但解碼時只要把字元減去起始值即可,不必編寫繁複的映射表,代碼居然比較短。

產生器

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
#include <stdio.h>
#include <stdlib.h>
#include <algorithm>
#include <string.h>
#include <math.h>
#include <map>
#include <vector>
#include <set>
using namespace std;
// LZ77-style compressor with a fixed 16-byte sliding window.
// Output layout: the first 16 input bytes verbatim, followed by a stream of tokens:
//   * one non-zero byte (offset << 4) | (length - 1): copy `length` (2..16) bytes that
//     start `offset` bytes into the current 16-byte window;
//   * the byte pair 0x00, c: emit the literal byte c.
// Inputs of at most 16 bytes are copied unchanged. Sets m to the number of bytes written.
void lz77_encode(unsigned char *in, int n, unsigned char *out, int &m) {
    m = 0;
    if (n <= 16) {
        while (m < n) {
            out[m] = in[m];
            m++;
        }
        return;
    }
    memcpy(out, in, 16);
    m = 16;
    // `in` always points at the start of the window; the lookahead begins at in[16].
    while (n > 16) {
        int best_len = 0, best_off = 0;
        for (int off = 0; off < 16; off++) {
            int len = 0;
            // the match must stay inside the window and inside the remaining input
            while (off + len < 16 && len + 16 < n && in[off + len] == in[len + 16])
                len++;
            if (len > best_len) {
                best_len = len;
                best_off = off;
            }
        }
        if (best_len > 1) {
            out[m++] = (unsigned char)((best_off << 4) | (best_len - 1));
            in += best_len;
            n -= best_len;
        } else {
            // matches of length 0/1 are stored as literals; 0x00 marks a literal
            out[m++] = 0;
            out[m++] = in[16];
            in++;
            n--;
        }
    }
}
// Inverse of lz77_encode: expands the token stream in[0..n) into out and sets m to the
// number of bytes produced. Inputs of at most 16 bytes are copied unchanged.
// The stream is assumed to be well-formed; no bounds checking is performed.
void lz77_decode(unsigned char *in, int n, unsigned char *out, int &m) {
    m = 0;
    if (n <= 16) {
        for (; m < n; m++)
            out[m] = in[m];
        return;
    }
    memcpy(out, in, 16);
    m = 16;
    unsigned char *p = in + 16;
    int left = n - 16;
    while (left > 0) {
        unsigned char token = *p;
        if (token == 0) {
            // literal: marker byte followed by the raw value
            out[m++] = p[1];
            p += 2;
            left -= 2;
            continue;
        }
        // back-reference into the 16 most recently produced bytes (never overlaps)
        int off = token >> 4, len = (token & 0xf) + 1;
        memcpy(out + m, out + m - 16 + off, len);
        m += len;
        p++;
        left--;
    }
}
// Encodes n bytes from in as standard base64 (alphabet A-Z a-z 0-9 + /, '=' padding)
// into out. Writes exactly m = 4 * ceil(n / 3) characters; no NUL terminator is added.
void base64_encode(unsigned char *in, int n, unsigned char *out, int &m) {
    static const char alphabet[] =
        "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
        "abcdefghijklmnopqrstuvwxyz"
        "0123456789+/";
    m = 4 * ((n + 2) / 3);
    int pos = 0;
    for (int i = 0; i < n; i += 3) {
        // pack up to three bytes into a 24-bit group; missing bytes count as zero
        unsigned int group = (unsigned int)in[i] << 16;
        if (i + 1 < n) group |= (unsigned int)in[i + 1] << 8;
        if (i + 2 < n) group |= in[i + 2];
        for (int shift = 18; shift >= 0; shift -= 6)
            out[pos++] = alphabet[(group >> shift) & 0x3f];
    }
    // the characters that only encode the zero fill become '=' padding
    int pad = (3 - n % 3) % 3;
    for (int k = 1; k <= pad; k++)
        out[m - k] = '=';
}
// Decodes n base64 characters from in into out and sets m to the number of bytes written.
// Characters outside the alphabet (other than '=') decode as 0. Trailing NUL bytes are
// ignored, so passing sizeof(string literal) is safe. Reads never go past in[n-1]: a short
// final group is treated as if padded. m is n/4*3 minus the trailing '=' count (never < 0).
void base64_decode(unsigned char *in, int n, unsigned char *out, int &m) {
    static int decode_table[256];
    for (int i = 'A'; i <= 'Z'; i++) decode_table[i] = i - 'A';
    for (int i = 'a'; i <= 'z'; i++) decode_table[i] = i - 'a' + 26;
    for (int i = '0'; i <= '9'; i++) decode_table[i] = i - '0' + 52;
    decode_table['+'] = 62, decode_table['/'] = 63;
    // drop a terminator counted by sizeof; it would otherwise hide the '=' padding
    while (n > 0 && in[n-1] == '\0')
        n--;
    m = n/4*3;
    if (n >= 1 && in[n-1] == '=') m--;
    if (n >= 2 && in[n-2] == '=') m--;
    if (m < 0) m = 0;
    for (int i = 0, j = 0; i < n; ) {
        unsigned int a, val = 0;
        for (int k = 3; k >= 0; i++, k--) {
            // positions past the end count as padding, so a partial group never over-reads
            a = (i >= n || in[i] == '=') ? 0 : decode_table[in[i]];
            val |= a<<(k*6);
        }
        if (j < m) out[j++] = (val>>16)&0xff;
        if (j < m) out[j++] = (val>>8)&0xff;
        if (j < m) out[j++] = (val>>0)&0xff;
    }
}
// Huffman-compresses n bytes from in into out and sets m to the number of bytes written.
// Output layout (the 4-byte count is copied in host byte order):
//   bits_cnt : 4 bytes, number of payload bits
//   leaf     : 1 byte, number of distinct symbols
//   leaf entries of { symbol (1 byte), code length bn (1 byte), code bits (ceil(bn/8) bytes) }
//   payload  : the code of every input byte, packed LSB-first, last byte zero-padded
// Code bits are assigned from the lowest bit up: a left child adds a 1, a right child a 0.
// Special cases: empty input emits only the 5-byte header (both counts 0); a single
// distinct symbol gets the 1-bit code 0 so that it still round-trips through the decoder.
// Prints statistics to stderr. Exits with status 1 when the format cannot represent the
// input: more than 255 distinct symbols (leaf is one byte) or a code of 31 or more bits.
void huffman_encode(unsigned char *in, int n, unsigned char *out, int &m) {
    m = 0;
    struct Table {
        int f, bit, bn;
    } tables[512];
    struct Node {
        int f, tid;
        Node *l, *r;
        Node(int a = 0, int b = 0, Node *ls = NULL, Node *rs = NULL):
            f(a), tid(b), l(ls), r(rs) {}
    } nodes[512];
    struct Stack {
        Node *node;
        int bit, bn;
        Stack(Node *a = NULL, int b = 0, int c = 0):
            node(a), bit(b), bn(c) {}
    } stacks[1024];
    int size = 0, leaf;
    multiset< pair<int, Node*> > S;
    memset(tables, 0, sizeof(tables));
    for (int i = 0; i < n; i++)
        tables[in[i]].f++;
    for (int i = 0; i < 256; i++) {
        if (tables[i].f) {
            nodes[size] = Node(tables[i].f, i);
            S.insert({nodes[size].f, &nodes[size]});
            size++;
        }
    }
    leaf = size;
    if (leaf > 255) {
        // the header stores the symbol count in a single byte
        fprintf(stderr, "huffman error: too many distinct symbols (%d)\n", leaf);
        exit(1);
    }
    if (leaf == 1) {
        // a lone symbol would otherwise get a 0-bit code and vanish from the payload
        tables[nodes[0].tid].bit = 0;
        tables[nodes[0].tid].bn = 1;
    } else if (leaf > 1) {
        // repeatedly merge the two least frequent subtrees
        while (S.size() >= 2) {
            pair<int, Node*> u = *S.begin();
            S.erase(S.begin());
            pair<int, Node*> v = *S.begin();
            S.erase(S.begin());
            nodes[size] = Node(u.second->f + v.second->f, -1, u.second, v.second);
            S.insert({nodes[size].f, &nodes[size]});
            size++;
        }
        // iterative DFS from the root (last node built) assigning codes to the leaves
        int stkIdx = 0;
        stacks[stkIdx++] = Stack(&nodes[size-1], 0, 0);
        while (stkIdx) {
            Stack u = stacks[--stkIdx];
            if (u.bn >= 31) {
                fprintf(stderr, "huffman error: bit length exceeded\n");
                exit(1);
            }
            if (u.node->l == NULL) {
                tables[u.node->tid].bit = u.bit;
                tables[u.node->tid].bn = u.bn;
            } else {
                Stack v(u.node->l, u.bit | (1<<u.bn), u.bn+1);
                u.node = u.node->r, u.bn++;
                stacks[stkIdx++] = u;
                stacks[stkIdx++] = v;
            }
        }
    }
    int bits_cnt = 0;
    for (int i = 0; i < 256; i++)
        bits_cnt += tables[i].f * tables[i].bn;
    fprintf(stderr, "huffman: %d bit length\n", bits_cnt);
    memcpy(out+m, &bits_cnt, 4), m += 4;
    out[m++] = (unsigned char)leaf;
    for (int i = 0; i < leaf; i++) {
        Table t = tables[nodes[i].tid];
        int code_bytes = (t.bn + 7)/8;
        out[m++] = (unsigned char)nodes[i].tid;
        out[m++] = (unsigned char)t.bn;
        memcpy(out+m, &t.bit, code_bytes), m += code_bytes;
    }
    int cnt = 0, mask = 0;
    for (int i = 0; i < n; i++) {
        int bit = tables[in[i]].bit;
        int bn = tables[in[i]].bn;
        while (bn) { // append LSB-first, flushing every full byte
            int j = min(bn, 8 - cnt);
            mask |= (bit&((1<<j)-1))<<cnt;
            cnt += j, bn -= j, bit >>= j;
            if (cnt == 8) {
                out[m++] = (unsigned char)mask;
                cnt = 0, mask = 0;
            }
        }
    }
    if (cnt)
        out[m++] = (unsigned char)mask;
    fprintf(stderr, "huffman encode length %d\n", m);
}
// Inverse of huffman_encode: reads the header (4-byte bit count, 1-byte symbol count,
// then one {symbol, code length, code bytes} entry per symbol) and decodes bits_cnt
// LSB-first payload bits into out. Sets m to the number of decoded bytes.
// Header integers are copied in host byte order, matching the encoder. n is only
// decremented while parsing the header; the payload length comes from bits_cnt.
// Malformed input (a code longer than 30 bits, or 30 bits matching no code) stops
// decoding with a message on stderr instead of indexing past the code tables.
void huffman_decode(unsigned char *in, int n, unsigned char *out, int &m) {
    const int MAX_BITS = 30;          // huffman_encode never emits longer codes
    m = 0;
    map<int, int> R[MAX_BITS + 1];    // R[length][code] = symbol
    int bits_cnt = 0, leaf = 0;
    memcpy(&bits_cnt, in, 4), in += 4, n -= 4;
    memcpy(&leaf, in, 1), in += 1, n -= 1;
    fprintf(stderr, "huffman: %d bit length\n", bits_cnt);
    fprintf(stderr, "huffman: %d leaves\n", leaf);
    for (int i = 0; i < leaf; i++) {
        int tid = 0, bn = 0, bit = 0;
        memcpy(&tid, in, 1), in += 1, n -= 1;
        memcpy(&bn, in, 1), in += 1, n -= 1;
        if (bn > MAX_BITS) {
            // also keeps the memcpy below within the 4 bytes of `bit`
            fprintf(stderr, "huffman error: code length %d too long\n", bn);
            return;
        }
        memcpy(&bit, in, (bn+7)/8), in += (bn+7)/8, n -= (bn+7)/8;
        R[bn][bit] = tid;
    }
    int mask = 0, cnt = 0;
    for (int i = 0; i < bits_cnt; i++) {
        mask |= ((in[i>>3] >> (i&7)) & 1) << cnt;
        cnt++;
        map<int, int>::iterator it = R[cnt].find(mask);
        if (it != R[cnt].end()) {
            out[m++] = (unsigned char)it->second;
            cnt = 0, mask = 0;
        } else if (cnt == MAX_BITS) {
            fprintf(stderr, "huffman error: undecodable bit sequence\n");
            return;
        }
    }
}
// Encoder for the 45x45 RGB variant. Reads "n m" followed by n*m "r g b" triples
// (row-major) from stdin, repeatedly until EOF, and prints one base64 line per image.
// Active experiment (Case 6): each channel is replaced by its difference (mod 256) from
// the same channel of the previous pixel, then Huffman-coded and base64-encoded.
// The commented-out cases are the other experiments described in the write-up.
// Exits with status 1 if the image does not fit the buffers or pixel data is missing.
int main() {
    int n, m;
    while (scanf("%d %d", &n, &m) == 2) {
        unsigned char in[32767], in2[32767], *pin;
        if (n <= 0 || m <= 0 || (long long)n * m * 3 > (long long)sizeof(in)) {
            fprintf(stderr, "image size %d x %d does not fit the input buffer\n", n, m);
            return 1;
        }
        pin = in;
        for (int i = 0; i < n; i++) {
            for (int j = 0; j < m; j++) {
                int r, g, b;
                if (scanf("%d %d %d", &r, &g, &b) != 3) {
                    fprintf(stderr, "missing pixel data at row %d, column %d\n", i, j);
                    return 1;
                }
                *pin++ = r;
                *pin++ = g;
                *pin++ = b;
            }
        }
        int lz77n, b64n, elz77n, tn, hufn, ehufn;
        unsigned char buf[4][32767] = {}, test[32767] = {};
        // Case 1
        // base64_encode(in, n*m*3, buf[1], b64n);
        // Case 2
        // huffman_encode(in, n*m*3, buf[0], hufn);
        // base64_encode(buf[0], hufn, buf[1], b64n);
        // base64_decode(buf[1], b64n, buf[2], ehufn);
        // huffman_decode(buf[2], ehufn, test, tn);
        // Case 3
        // lz77_encode(in, n*m*3, buf[0], lz77n);
        // base64_encode(buf[0], lz77n, buf[1], b64n);
        // base64_decode(buf[1], b64n, buf[2], elz77n);
        // lz77_decode(buf[2], elz77n, test, tn);
        // Case 4
        // huffman_encode(in, n*m*3, buf[0], hufn);
        // lz77_encode(buf[0], hufn, buf[2], lz77n);
        // base64_encode(buf[2], lz77n, buf[1], b64n);
        // Case 5
        // lz77_encode(in, n*m*3, buf[0], lz77n);
        // huffman_encode(buf[0], lz77n, buf[2], hufn);
        // base64_encode(buf[2], hufn, buf[1], b64n);
        // Case 6: predict every channel from the same channel of the previous pixel
        // (3 bytes back); the first pixel is stored as is.
        for (int i = 0; i < n*m*3; i += 3) {
            for (int c = 0; c < 3; c++)
                in2[i+c] = in[i+c] - (i >= 3 ? in[i+c-3] : 0);
        }
        huffman_encode(in2, n*m*3, buf[0], hufn);
        base64_encode(buf[0], hufn, buf[1], b64n);
        for (int i = 0; i < b64n; i++)
            printf("%c", buf[1][i]);
        puts("");
    }
    return 0;
}

base64

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
#include <bits/stdc++.h>
using namespace std;
// Embedded payload: base64 text of the 45x45 image's bytes (R, G, B per pixel, row-major,
// as indexed by main()); the actual literal is elided in this listing.
unsigned char data[] = "4YZ334Js3H5p5o1z8JZ0t01cpEBWsklUsE ... ignore";
// Decodes n base64 characters from in into out and sets m to the number of bytes written.
// Characters outside the alphabet (other than '=') decode as 0. Trailing NUL bytes are
// ignored, so passing sizeof(string literal) is safe. Reads never go past in[n-1]: a short
// final group is treated as if padded. m is n/4*3 minus the trailing '=' count (never < 0).
void base64_decode(unsigned char *in, int n, unsigned char *out, int &m) {
    static int R[256];
    for (int i = 'A'; i <= 'Z'; i++) R[i] = i - 'A';
    for (int i = 'a'; i <= 'z'; i++) R[i] = i - 'a' + 26;
    for (int i = '0'; i <= '9'; i++) R[i] = i - '0' + 52;
    R['+'] = 62, R['/'] = 63;
    // drop a terminator counted by sizeof; it would otherwise hide the '=' padding
    while (n > 0 && in[n-1] == '\0')
        n--;
    m = n/4*3;
    if (n >= 1 && in[n-1] == '=') m--;
    if (n >= 2 && in[n-2] == '=') m--;
    if (m < 0) m = 0;
    for (int i = 0, j = 0; i < n; ) {
        unsigned int a, val = 0;
        for (int k = 3; k >= 0; i++, k--) {
            // positions past the end count as padding, so a partial group never over-reads
            a = (i >= n || in[i] == '=') ? 0 : R[in[i]];
            val |= a<<(k*6);
        }
        if (j < m) out[j++] = (val>>16)&0xff;
        if (j < m) out[j++] = (val>>8)&0xff;
        if (j < m) out[j++] = (val>>0)&0xff;
    }
}
// Decodes the embedded base64 payload and prints the 45x45 RGB image: a "45 45" header,
// then one line per row with the R, G and B values of each pixel.
int main() {
    unsigned char d[32767] = {};
    int n;
    // sizeof(data) includes the string's terminating NUL; pass only the base64 characters.
    base64_decode(data, sizeof(data) - 1, d, n);
    printf("%d %d\n", 45, 45);
    for (int i = 0; i < 45; i++) {
        for (int j = 0; j < 45; j++) {
            int v = (i*45+j)*3;
            printf("%d %d %d%c", d[v], d[v+1], d[v+2], " \n"[j == 44]);
        }
    }
    return 0;
}

base64 短碼

利用巨集展開重複的指令

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
#include <bits/stdc++.h>
typedef unsigned char UINT8;
// Embedded payload: base64 text of the 45x45 image's bytes (R, G, B per pixel, row-major,
// as indexed by main()); the actual literal is elided in this listing.
UINT8 data[] = "4YZ334Js3H5p5o1z8JZ0t01cpEBWsklUsE ... ignore";
// Decode buffer; a zero-initialized global, so the 32 KB array is not on the stack.
UINT8 d[32767] = {};
// Compact base64 decoder (macros keep the submission short): decodes n characters of in
// into out and sets m to the number of bytes produced. Trailing NULs (from passing sizeof
// of a string literal) are ignored, reads stay inside in[0..n), and characters outside the
// alphabet decode as 0.
void base64_decode(UINT8 *in, int n, UINT8 *out, int &m) {
#define MK(st, ed, b) for (int i = st; i <= ed; i++) R[i] = i-st+b;
#define MG(k) b |= (i >= n || in[i] == '=' ? 0 : R[in[i]])<<(k*6), i++;
#define MP(k) if (j < m) out[j++] = (b>>(k*8))&0xff;
    int R[256] = {};  // zeroed: bytes outside the alphabet also index this table
    MK('A', 'Z', 0); MK('a', 'z', 26); MK('0', '9', 52);
    R['+'] = 62, R['/'] = 63;
    while (n > 0 && !in[n-1]) n--;
    m = n/4*3;
    if (n > 0 && in[n-1] == '=') m--;
    if (n > 1 && in[n-2] == '=') m--;
    if (m < 0) m = 0;
    for (int i = 0, j = 0; i < n; ) {
        int b = 0;
        MG(3); MG(2); MG(1); MG(0);
        MP(2); MP(1); MP(0);
    }
#undef MK
#undef MG
#undef MP
}
// Decodes the embedded base64 payload into d and prints the 45x45 RGB image.
int main() {
    int n, v;
    // sizeof(data) includes the terminating NUL; pass only the base64 characters.
    base64_decode(data, sizeof(data) - 1, d, n);
    printf("%d %d\n", 45, 45);
    for (int i = 0; i < 45; i++) {
        for (int j = 0; j < 45; j++) {
            v = (i*45+j)*3;
            printf("%d %d %d%c", d[v], d[v+1], d[v+2], " \n"[j == 44]);
        }
    }
    return 0;
}
Read More +

a458. Beats of the Angel 加強版

Problem

給一張圖,給定起點 S、終點 T,把任意邊的權重放大兩倍,請問最多能拉長最短道路長度為何。

意即修改一條邊找替代道路,使得替代道路越長越好。

Sample Input

1
2
3
4
5
6
7
8
9
5 7
2 1 5
1 3 1
3 2 8
3 5 7
3 4 3
2 4 7
4 5 2
1 5

Sample Output

1
2

Solution

終於向 liouzhou101 問到解法,類似 BZOJ 2725 Violet 6 故乡的梦,過了三年多,才把 Angel Beats 加強版寫出來,終於完成了,代碼是如此地迷人,一堆掃描線算法,找瓶頸也用掃描線代替感覺萌萌哒。

移除掉一條邊,找到 S-T 最短路徑的替代道路 $O(E \log E)$ 離線處理。

  1. 先在 DAG 上找瓶頸,只有瓶頸發生變化才受影響。
  2. 計算 DAG 上,經過瓶頸的最小次數。
  3. 接著,對瓶頸由近至遠進行掃描,用一個 multiset 當作最小堆,維護替代道路的候選長度;並保證掃描線左右兩側各自經過的瓶頸數接起來後,替代道路不會經過當前的瓶頸。
1
2
3
4
5
      B -- C       H -- I
     /      \     /      \
S - A        D -- G ------ J - T
     \      /
      E -- F

變數說明

  • Ds(u) 表示從起點 S 到 u 的最短路徑
  • De(u) 表示從 u 到終點 T 的最短路徑
  • Bs(u) 表示從起點 S 到 u 保持在最短路徑上,經過最少的瓶頸數
  • Be(u) 表示從 u 到終點 T 保持在最短路徑上,經過最少的瓶頸數

例如先找到 shortest-path DAG 如上圖,則 S - A、D - G、J - T 都是瓶頸。瓶頸有兩種找法:第一種是類似最大流的最小割;第二種是把每條邊 e(u, v) 轉換成區間 [Ds(u), Ds(v)],然後做掃描線算法,若某一段範圍只被一個區間覆蓋,則那個區間對應的邊就是瓶頸。

得到所有瓶頸後,依序由離 S 近到遠掃描瓶頸,並維護一個 multiset<long long>。當窮舉到瓶頸 e(u, v) 時,替代道路為 min(Ds(a) + w(a, b) + De(b)),其中邊 e(a, b) 須滿足 Bs(a) <= Bs(u) 且 Be(b) <= Be(v)。這兩個條件是為了保證替代道路不經過這個瓶頸。

隨著掃描線移動,門檻 Bs(u) 遞增、Be(v) 遞減,因此先將節點依 Bs[] 與 Be[] 排序。每當掃描到一個瓶頸,就用新滿足 Bs(a) <= Bs(u) 的邊鬆弛出新的路徑並加入最小堆,同時移除最小堆中所有 Be(b) > Be(v) 的 Ds(a) + w(a, b) + De(b)。

英文名稱 The selfish-edges Shortest-Path problem

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
#include <bits/stdc++.h>
using namespace std;
// http://zrt.github.io/2014/08/13/bzoj2725/
// Static capacity limits: vertices are 1-based (adj[1..n]); every undirected input edge
// occupies two slots of edge[] and keye[].
const int MAXV = 200005;
const int MAXE = 200005<<1;
const long long INF = 1LL<<60;  // "unreachable" distance; INF + INF still fits in long long
// Directed arc in a singly linked adjacency list; eid is its own index in edge[], and
// eid ^ 1 is the opposite arc of the same undirected edge.
struct Edge {
int to, eid;
long long w;
Edge *next;
};
Edge edge[MAXE], *adj[MAXV];
int e = 0;  // next free slot in edge[]
long long Ds[MAXV], De[MAXV];  // shortest distances from st / to ed (filled by dijkstra in main)
// Bs/Be: fewest bottleneck arcs on a tight path from st / to ed (filled by bfs in solve);
// keye: 1 = bottleneck arc, -1 = its reverse arc, 0 = other (filled by bottleneck)
int Bs[MAXV], Be[MAXV], keye[MAXE];
void addEdge(int x, int y, long long v) {
edge[e].to = y, edge[e].w = v, edge[e].eid = e;
edge[e].next = adj[x], adj[x] = &edge[e++];
edge[e].to = x, edge[e].w = v, edge[e].eid = e;
edge[e].next = adj[y], adj[y] = &edge[e++];
}
void dijkstra(int st, long long dist[], int n) {
typedef pair<long long, int> PLL;
for (int i = 0; i <= n; i++)
dist[i] = INF;
set<PLL> pQ;
PLL u;
pQ.insert(PLL(0, st)), dist[st] = 0;
while (!pQ.empty()) {
u = *pQ.begin(), pQ.erase(pQ.begin());
for (Edge *p = adj[u.second]; p; p = p->next) {
if (dist[p->to] > dist[u.second] + p->w) {
if (dist[p->to] != INF)
pQ.erase(pQ.find(PLL(dist[p->to], p->to)));
dist[p->to] = dist[u.second] + p->w;
pQ.insert(PLL(dist[p->to], p->to));
}
}
}
}
// Marks the bottleneck arcs of the st-ed shortest-path DAG, i.e. arcs no other DAG arc
// overlaps in Ds range. Every DAG arc u->v becomes the interval [Ds[u], Ds[v]]. At each
// distinct start position l, intervals ending at or before l are dropped and the ones
// starting at l are added; if exactly one interval remains, it is the arc that was just
// added (A[i-1]), and it alone spans the stretch after l.
// Side effects: resets keye[0..e), then sets keye[eid] = 1 for each bottleneck arc and
// keye[eid ^ 1] = -1 for its reverse. Returns the number of bottleneck arcs.
// `st` is not used here; Ds already encodes the source.
// NOTE(review): zero-weight edges give empty intervals; behaviour with them is unverified.
int bottleneck(int st, int ed, long long Ds[], long long De[], int n) {
typedef pair<long long, pair<long long, int> > PLL;
vector<PLL> A;
// shortest-path DAG st-ed: arc i->to is tight w.r.t. Ds and `to` lies on a shortest st-ed path
// (A entries: start Ds, end Ds, arc id)
long long d = Ds[ed];
for (int i = 1; i <= n; i++) {
for (Edge *p = adj[i]; p; p = p->next) {
if (Ds[p->to] == Ds[i] + p->w && Ds[p->to] + De[p->to] == d) {
A.push_back(PLL(Ds[i], {Ds[p->to], p->eid}));
}
}
}
// bottleneck edge st-ed: sweep the intervals by start, pQ holds the ends of open intervals
sort(A.begin(), A.end());
priority_queue<long long, vector<long long>, std::greater<long long>> pQ;
int cnt = 0;
for (int i = 0; i < e; i++)
keye[i] = 0;
for (int i = 0; i < A.size(); ) {
long long l = A[i].first;
while (!pQ.empty() && pQ.top() <= l)
pQ.pop();
while (i < A.size() && A[i].first <= l)
pQ.push(A[i].second.first), i++;
// only one open interval: the arc just added is unavoidable
if (pQ.size() == 1) {
keye[A[i-1].second.second] = 1;
keye[A[i-1].second.second^1] = -1;
cnt++;
}
}
return cnt;
}
// Counts key arcs along tight paths: visits vertices 1..n in increasing Ds order and
// relaxes every tight arc x->to (Ds[to] == Ds[x] + w), adding 1 when keye[arc] == f.
// dist[v] ends as the fewest such arcs on a tight path from st to v (0x3f3f3f3f if none).
// solve() calls it with (st, Ds, f = 1) and with (ed, De, f = -1).
void bfs(int st, int dist[], long long Ds[], int n, int f) {
    typedef pair<long long, int> PLL;
    vector<PLL> order;
    for (int v = 1; v <= n; v++) {
        order.push_back(PLL(Ds[v], v));
        dist[v] = 0x3f3f3f3f;
    }
    dist[st] = 0;
    sort(order.begin(), order.end());
    for (size_t k = 0; k < order.size(); k++) {
        int x = order[k].second;
        for (Edge *p = adj[x]; p; p = p->next) {
            if (Ds[p->to] != Ds[x] + p->w)
                continue;
            int via = dist[x] + (keye[p->eid] == f ? 1 : 0);
            if (via < dist[p->to])
                dist[p->to] = via;
        }
    }
}
// For each bottleneck arc x->to of the st-ed shortest paths (visited in increasing Ds[x]
// order), computes the shortest st-ed length after doubling that arc's weight and returns
// the largest increase over d = Ds[ed].
// Detour candidates are Ds[u] + w(u, q) + De[q] over arcs u->q with Bs[u] <= Bs[x] and
// Be[q] <= Be[to] (excluding the bottleneck arc itself); using the doubled arc costs d + w.
// Side effects: fills the global Bs[] / Be[] through bfs().
// Bidx and Cidx only move forward, so each vertex's arcs are inserted and erased at most
// once. NOTE(review): this relies on Bs[x] never decreasing and Be[to] never increasing
// from one bottleneck to the next (the argument in the write-up); it is not checked here.
long long solve(int st, int ed, long long Ds[], long long De[], int n) {
typedef pair<long long, int> PLL;
typedef pair<int, int> PII;
typedef multiset<long long>::iterator MIT;
bfs(st, Bs, Ds, n, 1);
bfs(ed, Be, De, n, -1);
vector<PLL> A;
vector<PII> B, C;
for (int i = 1; i <= n; i++) {
A.push_back(PLL(Ds[i], i));
B.push_back(PII(Bs[i], i));
C.push_back(PII(Be[i], i));
}
sort(A.begin(), A.end());
sort(B.begin(), B.end());
sort(C.begin(), C.end());
long long d = Ds[ed];
int Bidx = 0, Cidx = n-1;
// S: lengths of the currently admissible detours (minimum = best replacement path)
multiset<long long> S;
// RM[v]: S entries for arcs ending at v, erased once Be[v] exceeds the current cut
vector< vector<MIT> > RM(n+1, vector<MIT>());
long long ret = 0;
for (int i = 0; i < n; i++) {
int x = A[i].second;
for (Edge *p = adj[x]; p; p = p->next) {
if (Ds[p->to] == Ds[x] + p->w && Ds[p->to] + De[p->to] == d) {
if (keye[p->eid]) {
int bb = Bs[x], cut = Be[p->to];
// relax: admit arcs leaving vertices whose Bs is within the new bound bb
for (; Bidx < B.size() && B[Bidx].first <= bb; Bidx++) {
int u = B[Bidx].second;
for (Edge *q = adj[u]; q; q = q->next) {
if (Be[q->to] <= cut && p != q) {
MIT it = S.insert(Ds[u] + q->w + De[q->to]);
RM[q->to].push_back(it);
}
}
}
// remove: drop detours ending at vertices whose Be exceeds the new cut
// (the loop variable e shadows the global edge counter e)
for (; Cidx >= 0 && C[Cidx].first > cut; Cidx--) {
int u = C[Cidx].second;
for (auto e : RM[u])
S.erase(e);
}
long long replace_path = INF;
if (!S.empty())
replace_path = *S.begin();
replace_path = min(replace_path, d + p->w); // this edge double weight
ret = max(ret, replace_path - d);
}
}
}
}
return ret;
}
int main() {
int n, m, x, y;
int st, ed;
long long v;
while (scanf("%d %d", &n, &m) == 2) {
for (int i = 1; i <= n; i++)
adj[i] = NULL;
for (int i = 0; i < m; i++) {
scanf("%d %d %lld", &x, &y, &v);
addEdge(x, y, v);
}
scanf("%d %d", &st, &ed);
dijkstra(st, Ds, n);
dijkstra(ed, De, n);
int cnt = bottleneck(st, ed, Ds, De, n);
if (cnt == 0)
puts("0");
else
printf("%lld\n", solve(st, ed, Ds, De, n));
}
return 0;
}
Read More +

b454. 請輸出這張圖片的 RGB 數值

Problem

給一張 $64 \times 64$ 的經典 Lena 影像,共有 $4096$ 個像素,程式碼長度上限 10K,避開打表的長度限制,輸出一模一樣的圖形。

Sample Input

1
NO INPUT

Sample Output

1
2
64 64
153 153 153 151 151 151 148 148 148 (後面省略)

Solution

前處理

首先遇到的問題是將灰階圖片取出,變成一般常看到的數字格式。這裡直接用 python 來完成,必要時需要安裝 Python Imaging Library (PIL);這是前處理,也可以改用 OpenCV 等其他工具來完成。

1
$ python img2txt.py b454.png >in.txt
img2txt.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
# img2txt.py: for every image file named on the command line, print "height width" and
# then one line per pixel row holding the gray level of each pixel (channel 0 of the
# image converted to 'LA'), separated by spaces.
# Files that cannot be opened or decoded are reported on stderr and skipped.
# Runs under Python 2 and Python 3.
import sys
import codecs

try:
    import Image  # classic PIL module layout
except ImportError:
    from PIL import Image  # Pillow / packaged PIL layout

if sys.version_info[0] < 3:
    # Python 2: wrap stdout so that unicode output is encoded as UTF-8.
    sys.stdout = codecs.getwriter('utf8')(sys.stdout)
for infile in sys.argv[1:]:
    try:
        im = Image.open(infile)
        gray = im.convert('LA')
        width = gray.size[0]
        height = gray.size[1]
        pix = gray.load()
    except IOError as exc:
        sys.stderr.write('%s: skipped (%s)\n' % (infile, exc))
        continue
    print('%d %d' % (height, width))
    for i in range(height):
        print(' '.join(str(pix[j, i][0]) for j in range(width)))

建表壓縮

由於 4096 個像素,若使用十六進制表示每一個 0 - 255 像素,只需要 2 個字元 00 - FF,程式碼長度大約會落在 8192 bytes,若使用十進制,需要使用 3 個字元表示,那麼大約是在 13KB,所以至少要用十六進制表示法。

還有更好的方式,如一般常在網頁上瀏覽的 base64,它充分利用 64 個可視字元進行編碼,相較於十六進制只使用 16 個可視字元來比較會更加地短。利用傳統的霍夫曼編碼 (huffman coding) 來壓縮,結果會短個 10% 到 20%,當每種像素次數分布相當懸殊時,效果就越好,但上傳時還要附加解壓縮的代碼,所以沒辦法短太多,還不如直接 base64。由於是圖片,漸層效果比較普遍,改用與前一個相鄰像素的差值來轉換,得到的分布會比較極端,帶入 huffman 的效果就會不錯。

最後,還有一種比較容易實作的 lz77 壓縮,概念類似最長平台:每一次的訊息為 (起始位置, 重疊長度, 補尾字元),設定一個 window 長度,然後 sliding 滑動,但是起始位置、長度需要選好。代碼中嘗試用 window size = 16,則起始位置和重疊長度各用 4 bits,合起來剛好可以用 8 bits 表示。window 越大不見得越好,太小也不是好事;但不管怎麼做,由於影像的重複 pattern 並不多,沒有像一般數學性質的數列來得強,壓縮實驗不管怎樣都大於 10KB。

關於實作細節,霍夫曼編碼的儲存格式是 壓縮位元長度 bits length + 字典表 + 壓縮資料。字典表的儲存方式有很多種:由於霍夫曼樹是 full binary tree (每個節點只會有兩個或零個子節點),可以用一個 0 / 1 前序走訪來完成,但這會讓解壓縮代碼變長而有點虧本,所以代碼中直接使用一般的打表。因此要保證每個編碼的最大 bit-length 小於 32(程式中超過 30 bits 就報錯),方便用 int32 操作。

| 方法 | 代碼長度 (bytes) |
| --- | --- |
| base64 | 6485 |
| huffman+base64 | 7965 |
| diff+huffman+base64 | 7717 |
| lz77+base64 | > 10K |

產生器

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
#include <bits/stdc++.h>
using namespace std;
// LZ77-style compressor with a fixed 16-byte sliding window.
// Output layout: the first 16 input bytes verbatim, followed by a stream of tokens:
//   * one non-zero byte (offset << 4) | (length - 1): copy `length` (2..16) bytes that
//     start `offset` bytes into the current 16-byte window;
//   * the byte pair 0x00, c: emit the literal byte c.
// Inputs of at most 16 bytes are copied unchanged. Sets m to the number of bytes written.
void lz77_encode(unsigned char *in, int n, unsigned char *out, int &m) {
    m = 0;
    if (n <= 16) {
        while (m < n) {
            out[m] = in[m];
            m++;
        }
        return;
    }
    memcpy(out, in, 16);
    m = 16;
    // `in` always points at the start of the window; the lookahead begins at in[16].
    while (n > 16) {
        int best_len = 0, best_off = 0;
        for (int off = 0; off < 16; off++) {
            int len = 0;
            // the match must stay inside the window and inside the remaining input
            while (off + len < 16 && len + 16 < n && in[off + len] == in[len + 16])
                len++;
            if (len > best_len) {
                best_len = len;
                best_off = off;
            }
        }
        if (best_len > 1) {
            out[m++] = (unsigned char)((best_off << 4) | (best_len - 1));
            in += best_len;
            n -= best_len;
        } else {
            // matches of length 0/1 are stored as literals; 0x00 marks a literal
            out[m++] = 0;
            out[m++] = in[16];
            in++;
            n--;
        }
    }
}
// Inverse of lz77_encode: expands the token stream in[0..n) into out and sets m to the
// number of bytes produced. Inputs of at most 16 bytes are copied unchanged.
// The stream is assumed to be well-formed; no bounds checking is performed.
void lz77_decode(unsigned char *in, int n, unsigned char *out, int &m) {
    m = 0;
    if (n <= 16) {
        for (; m < n; m++)
            out[m] = in[m];
        return;
    }
    memcpy(out, in, 16);
    m = 16;
    unsigned char *p = in + 16;
    int left = n - 16;
    while (left > 0) {
        unsigned char token = *p;
        if (token == 0) {
            // literal: marker byte followed by the raw value
            out[m++] = p[1];
            p += 2;
            left -= 2;
            continue;
        }
        // back-reference into the 16 most recently produced bytes (never overlaps)
        int off = token >> 4, len = (token & 0xf) + 1;
        memcpy(out + m, out + m - 16 + off, len);
        m += len;
        p++;
        left--;
    }
}
// Encodes n bytes from in as standard base64 (alphabet A-Z a-z 0-9 + /, '=' padding)
// into out. Writes exactly m = 4 * ceil(n / 3) characters; no NUL terminator is added.
void base64_encode(unsigned char *in, int n, unsigned char *out, int &m) {
    static const char alphabet[] =
        "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
        "abcdefghijklmnopqrstuvwxyz"
        "0123456789+/";
    m = 4 * ((n + 2) / 3);
    int pos = 0;
    for (int i = 0; i < n; i += 3) {
        // pack up to three bytes into a 24-bit group; missing bytes count as zero
        unsigned int group = (unsigned int)in[i] << 16;
        if (i + 1 < n) group |= (unsigned int)in[i + 1] << 8;
        if (i + 2 < n) group |= in[i + 2];
        for (int shift = 18; shift >= 0; shift -= 6)
            out[pos++] = alphabet[(group >> shift) & 0x3f];
    }
    // the characters that only encode the zero fill become '=' padding
    int pad = (3 - n % 3) % 3;
    for (int k = 1; k <= pad; k++)
        out[m - k] = '=';
}
// Decodes n base64 characters from in into out and sets m to the number of bytes written.
// Characters outside the alphabet (other than '=') decode as 0. Trailing NUL bytes are
// ignored, so passing sizeof(string literal) is safe. Reads never go past in[n-1]: a short
// final group is treated as if padded. m is n/4*3 minus the trailing '=' count (never < 0).
void base64_decode(unsigned char *in, int n, unsigned char *out, int &m) {
    static int decode_table[256];
    for (int i = 'A'; i <= 'Z'; i++) decode_table[i] = i - 'A';
    for (int i = 'a'; i <= 'z'; i++) decode_table[i] = i - 'a' + 26;
    for (int i = '0'; i <= '9'; i++) decode_table[i] = i - '0' + 52;
    decode_table['+'] = 62, decode_table['/'] = 63;
    // drop a terminator counted by sizeof; it would otherwise hide the '=' padding
    while (n > 0 && in[n-1] == '\0')
        n--;
    m = n/4*3;
    if (n >= 1 && in[n-1] == '=') m--;
    if (n >= 2 && in[n-2] == '=') m--;
    if (m < 0) m = 0;
    for (int i = 0, j = 0; i < n; ) {
        unsigned int a, val = 0;
        for (int k = 3; k >= 0; i++, k--) {
            // positions past the end count as padding, so a partial group never over-reads
            a = (i >= n || in[i] == '=') ? 0 : decode_table[in[i]];
            val |= a<<(k*6);
        }
        if (j < m) out[j++] = (val>>16)&0xff;
        if (j < m) out[j++] = (val>>8)&0xff;
        if (j < m) out[j++] = (val>>0)&0xff;
    }
}
// Huffman-compresses n bytes from in into out and sets m to the number of bytes written.
// Output layout (the 4-byte count is copied in host byte order):
//   bits_cnt : 4 bytes, number of payload bits
//   leaf     : 1 byte, number of distinct symbols
//   leaf entries of { symbol (1 byte), code length bn (1 byte), code bits (ceil(bn/8) bytes) }
//   payload  : the code of every input byte, packed LSB-first, last byte zero-padded
// Code bits are assigned from the lowest bit up: a left child adds a 1, a right child a 0.
// Special cases: empty input emits only the 5-byte header (both counts 0); a single
// distinct symbol gets the 1-bit code 0 so that it still round-trips through the decoder.
// Prints statistics to stderr. Exits with status 1 when the format cannot represent the
// input: more than 255 distinct symbols (leaf is one byte) or a code of 31 or more bits.
void huffman_encode(unsigned char *in, int n, unsigned char *out, int &m) {
    m = 0;
    struct Table {
        int f, bit, bn;
    } tables[512];
    struct Node {
        int f, tid;
        Node *l, *r;
        Node(int a = 0, int b = 0, Node *ls = NULL, Node *rs = NULL):
            f(a), tid(b), l(ls), r(rs) {}
    } nodes[512];
    struct Stack {
        Node *node;
        int bit, bn;
        Stack(Node *a = NULL, int b = 0, int c = 0):
            node(a), bit(b), bn(c) {}
    } stacks[1024];
    int size = 0, leaf;
    multiset< pair<int, Node*> > S;
    memset(tables, 0, sizeof(tables));
    for (int i = 0; i < n; i++)
        tables[in[i]].f++;
    for (int i = 0; i < 256; i++) {
        if (tables[i].f) {
            nodes[size] = Node(tables[i].f, i);
            S.insert({nodes[size].f, &nodes[size]});
            size++;
        }
    }
    leaf = size;
    if (leaf > 255) {
        // the header stores the symbol count in a single byte
        fprintf(stderr, "huffman error: too many distinct symbols (%d)\n", leaf);
        exit(1);
    }
    if (leaf == 1) {
        // a lone symbol would otherwise get a 0-bit code and vanish from the payload
        tables[nodes[0].tid].bit = 0;
        tables[nodes[0].tid].bn = 1;
    } else if (leaf > 1) {
        // repeatedly merge the two least frequent subtrees
        while (S.size() >= 2) {
            pair<int, Node*> u = *S.begin();
            S.erase(S.begin());
            pair<int, Node*> v = *S.begin();
            S.erase(S.begin());
            nodes[size] = Node(u.second->f + v.second->f, -1, u.second, v.second);
            S.insert({nodes[size].f, &nodes[size]});
            size++;
        }
        // iterative DFS from the root (last node built) assigning codes to the leaves
        int stkIdx = 0;
        stacks[stkIdx++] = Stack(&nodes[size-1], 0, 0);
        while (stkIdx) {
            Stack u = stacks[--stkIdx];
            if (u.bn >= 31) {
                fprintf(stderr, "huffman error: bit length exceeded\n");
                exit(1);
            }
            if (u.node->l == NULL) {
                tables[u.node->tid].bit = u.bit;
                tables[u.node->tid].bn = u.bn;
            } else {
                Stack v(u.node->l, u.bit | (1<<u.bn), u.bn+1);
                u.node = u.node->r, u.bn++;
                stacks[stkIdx++] = u;
                stacks[stkIdx++] = v;
            }
        }
    }
    int bits_cnt = 0;
    for (int i = 0; i < 256; i++)
        bits_cnt += tables[i].f * tables[i].bn;
    fprintf(stderr, "huffman: %d bit length\n", bits_cnt);
    memcpy(out+m, &bits_cnt, 4), m += 4;
    out[m++] = (unsigned char)leaf;
    for (int i = 0; i < leaf; i++) {
        Table t = tables[nodes[i].tid];
        int code_bytes = (t.bn + 7)/8;
        out[m++] = (unsigned char)nodes[i].tid;
        out[m++] = (unsigned char)t.bn;
        memcpy(out+m, &t.bit, code_bytes), m += code_bytes;
    }
    int cnt = 0, mask = 0;
    for (int i = 0; i < n; i++) {
        int bit = tables[in[i]].bit;
        int bn = tables[in[i]].bn;
        while (bn) { // append LSB-first, flushing every full byte
            int j = min(bn, 8 - cnt);
            mask |= (bit&((1<<j)-1))<<cnt;
            cnt += j, bn -= j, bit >>= j;
            if (cnt == 8) {
                out[m++] = (unsigned char)mask;
                cnt = 0, mask = 0;
            }
        }
    }
    if (cnt)
        out[m++] = (unsigned char)mask;
    fprintf(stderr, "huffman encode length %d\n", m);
}
// Inverse of huffman_encode: reads the header (4-byte bit count, 1-byte symbol count,
// then one {symbol, code length, code bytes} entry per symbol) and decodes bits_cnt
// LSB-first payload bits into out. Sets m to the number of decoded bytes.
// Header integers are copied in host byte order, matching the encoder. n is only
// decremented while parsing the header; the payload length comes from bits_cnt.
// Malformed input (a code longer than 30 bits, or 30 bits matching no code) stops
// decoding with a message on stderr instead of indexing past the code tables.
void huffman_decode(unsigned char *in, int n, unsigned char *out, int &m) {
    const int MAX_BITS = 30;          // huffman_encode never emits longer codes
    m = 0;
    map<int, int> R[MAX_BITS + 1];    // R[length][code] = symbol
    int bits_cnt = 0, leaf = 0;
    memcpy(&bits_cnt, in, 4), in += 4, n -= 4;
    memcpy(&leaf, in, 1), in += 1, n -= 1;
    fprintf(stderr, "huffman: %d bit length\n", bits_cnt);
    fprintf(stderr, "huffman: %d leaves\n", leaf);
    for (int i = 0; i < leaf; i++) {
        int tid = 0, bn = 0, bit = 0;
        memcpy(&tid, in, 1), in += 1, n -= 1;
        memcpy(&bn, in, 1), in += 1, n -= 1;
        if (bn > MAX_BITS) {
            // also keeps the memcpy below within the 4 bytes of `bit`
            fprintf(stderr, "huffman error: code length %d too long\n", bn);
            return;
        }
        memcpy(&bit, in, (bn+7)/8), in += (bn+7)/8, n -= (bn+7)/8;
        R[bn][bit] = tid;
    }
    int mask = 0, cnt = 0;
    for (int i = 0; i < bits_cnt; i++) {
        mask |= ((in[i>>3] >> (i&7)) & 1) << cnt;
        cnt++;
        map<int, int>::iterator it = R[cnt].find(mask);
        if (it != R[cnt].end()) {
            out[m++] = (unsigned char)it->second;
            cnt = 0, mask = 0;
        } else if (cnt == MAX_BITS) {
            fprintf(stderr, "huffman error: undecodable bit sequence\n");
            return;
        }
    }
}
// Encoder for the 64x64 grayscale variant. Reads "n m" followed by n*m gray values
// (row-major, e.g. the output of img2txt.py) repeatedly until EOF and prints one base64
// line per image. Active experiment: Case 1 (plain base64); the commented-out cases are
// the other experiments compared in the write-up's table.
// Exits with status 1 if the image does not fit the buffers or pixel data is missing.
int main() {
    int n, m;
    while (scanf("%d %d", &n, &m) == 2) {
        unsigned char in[32767], in2[32767];
        // values are stored straight into in[] (row-major), so any n x m that fits is accepted
        if (n <= 0 || m <= 0 || (long long)n * m > (long long)sizeof(in)) {
            fprintf(stderr, "image size %d x %d does not fit the input buffer\n", n, m);
            return 1;
        }
        for (int i = 0; i < n*m; i++) {
            int v;
            if (scanf("%d", &v) != 1) {
                fprintf(stderr, "missing pixel data at index %d\n", i);
                return 1;
            }
            in[i] = v;
        }
        int lz77n, b64n, elz77n, tn, hufn, ehufn;
        unsigned char buf[4][32767] = {}, test[32767] = {};
        // Case 1
        base64_encode(in, n*m, buf[1], b64n);
        // Case 2
        // huffman_encode(in, n*m, buf[0], hufn);
        // base64_encode(buf[0], hufn, buf[1], b64n);
        // base64_decode(buf[1], b64n, buf[2], ehufn);
        // huffman_decode(buf[2], ehufn, test, tn);
        // Case 3
        // lz77_encode(in, n*m, buf[0], lz77n);
        // base64_encode(buf[0], lz77n, buf[1], b64n);
        // base64_decode(buf[1], b64n, buf[2], elz77n);
        // lz77_decode(buf[2], elz77n, test, tn);
        // Case 4
        // huffman_encode(in, n*m, buf[0], hufn);
        // lz77_encode(buf[0], hufn, buf[2], lz77n);
        // base64_encode(buf[2], lz77n, buf[1], b64n);
        // Case 5
        // lz77_encode(in, n*m, buf[0], lz77n);
        // huffman_encode(buf[0], lz77n, buf[2], hufn);
        // base64_encode(buf[2], hufn, buf[1], b64n);
        // Case 6
        // for (int i = 0; i < n*m; i++)
        //     in2[i] = in[i] - (i ? in[i-1] : 0);
        // huffman_encode(in2, n*m, buf[0], hufn);
        // base64_encode(buf[0], hufn, buf[1], b64n);
        for (int i = 0; i < b64n; i++)
            printf("%c", buf[1][i]);
        puts("");
    }
    return 0;
}

base64

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
#include <bits/stdc++.h>
using namespace std;
// Embedded payload: base64 text of the 64x64 image's 4096 gray values (row-major, as
// indexed by main()); the actual literal is elided in this listing.
unsigned char data[] = "mZeUkpampXpOX2FhYGlxd ... ignore";
// Decodes n base64 characters from in into out and sets m to the number of bytes written.
// Characters outside the alphabet (other than '=') decode as 0. Trailing NUL bytes are
// ignored, so passing sizeof(string literal) is safe. Reads never go past in[n-1]: a short
// final group is treated as if padded. m is n/4*3 minus the trailing '=' count (never < 0).
void base64_decode(unsigned char *in, int n, unsigned char *out, int &m) {
    static int R[256];
    for (int i = 'A'; i <= 'Z'; i++) R[i] = i - 'A';
    for (int i = 'a'; i <= 'z'; i++) R[i] = i - 'a' + 26;
    for (int i = '0'; i <= '9'; i++) R[i] = i - '0' + 52;
    R['+'] = 62, R['/'] = 63;
    // drop a terminator counted by sizeof; it would otherwise hide the '=' padding
    while (n > 0 && in[n-1] == '\0')
        n--;
    m = n/4*3;
    if (n >= 1 && in[n-1] == '=') m--;
    if (n >= 2 && in[n-2] == '=') m--;
    if (m < 0) m = 0;
    for (int i = 0, j = 0; i < n; ) {
        unsigned int a, val = 0;
        for (int k = 3; k >= 0; i++, k--) {
            // positions past the end count as padding, so a partial group never over-reads
            a = (i >= n || in[i] == '=') ? 0 : R[in[i]];
            val |= a<<(k*6);
        }
        if (j < m) out[j++] = (val>>16)&0xff;
        if (j < m) out[j++] = (val>>8)&0xff;
        if (j < m) out[j++] = (val>>0)&0xff;
    }
}
// Decodes the embedded base64 payload and prints the 64x64 gray image, repeating each
// gray value for R, G and B.
int main() {
    unsigned char d[32767] = {};
    int n;
    // sizeof(data) includes the string's terminating NUL; pass only the base64 characters.
    base64_decode(data, sizeof(data) - 1, d, n);
    printf("%d %d\n", 64, 64);
    for (int i = 0; i < 64; i++) {
        for (int j = 0; j < 64; j++) {
            printf("%d %d %d%c", d[i*64+j], d[i*64+j], d[i*64+j], j == 63 ? '\n' : ' ');
        }
    }
    return 0;
}

huffman+base64

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
#include <bits/stdc++.h>
using namespace std;
// Embedded payload: the Huffman-coded 64x64 gray image in base64 (main() decodes base64,
// then Huffman); the actual literal is elided in this listing.
unsigned char data[] = "RHkAAOoAC4kCAguWBwYMFggIDBYA ... ignore";
// Decodes n base64 characters from in into out and sets m to the number of bytes written.
// Characters outside the alphabet (other than '=') decode as 0. Trailing NUL bytes are
// ignored, so passing sizeof(string literal) is safe. Reads never go past in[n-1]: a short
// final group is treated as if padded. m is n/4*3 minus the trailing '=' count (never < 0).
void base64_decode(unsigned char *in, int n, unsigned char *out, int &m) {
    static int R[256];
    for (int i = 'A'; i <= 'Z'; i++) R[i] = i - 'A';
    for (int i = 'a'; i <= 'z'; i++) R[i] = i - 'a' + 26;
    for (int i = '0'; i <= '9'; i++) R[i] = i - '0' + 52;
    R['+'] = 62, R['/'] = 63;
    // drop a terminator counted by sizeof; it would otherwise hide the '=' padding
    while (n > 0 && in[n-1] == '\0')
        n--;
    m = n/4*3;
    if (n >= 1 && in[n-1] == '=') m--;
    if (n >= 2 && in[n-2] == '=') m--;
    if (m < 0) m = 0;
    for (int i = 0, j = 0; i < n; ) {
        unsigned int a, val = 0;
        for (int k = 3; k >= 0; i++, k--) {
            // positions past the end count as padding, so a partial group never over-reads
            a = (i >= n || in[i] == '=') ? 0 : R[in[i]];
            val |= a<<(k*6);
        }
        if (j < m) out[j++] = (val>>16)&0xff;
        if (j < m) out[j++] = (val>>8)&0xff;
        if (j < m) out[j++] = (val>>0)&0xff;
    }
}
// Inverse of huffman_encode: reads the header (4-byte bit count, 1-byte symbol count,
// then one {symbol, code length, code bytes} entry per symbol) and decodes bits_cnt
// LSB-first payload bits into out. Sets m to the number of decoded bytes.
// Header integers are copied in host byte order, matching the encoder. n is only
// decremented while parsing the header; the payload length comes from bits_cnt.
// Malformed input (a code longer than 30 bits, or 30 bits matching no code) stops
// decoding with a message on stderr instead of indexing past the code tables.
void huffman_decode(unsigned char *in, int n, unsigned char *out, int &m) {
    const int MAX_BITS = 30;          // huffman_encode never emits longer codes
    m = 0;
    map<int, int> R[MAX_BITS + 1];    // R[length][code] = symbol
    int bits_cnt = 0, leaf = 0;
    memcpy(&bits_cnt, in, 4), in += 4, n -= 4;
    memcpy(&leaf, in, 1), in += 1, n -= 1;
    for (int i = 0; i < leaf; i++) {
        int tid = 0, bn = 0, bit = 0;
        memcpy(&tid, in, 1), in += 1, n -= 1;
        memcpy(&bn, in, 1), in += 1, n -= 1;
        if (bn > MAX_BITS) {
            // also keeps the memcpy below within the 4 bytes of `bit`
            fprintf(stderr, "huffman error: code length %d too long\n", bn);
            return;
        }
        memcpy(&bit, in, (bn+7)/8), in += (bn+7)/8, n -= (bn+7)/8;
        R[bn][bit] = tid;
    }
    int mask = 0, cnt = 0;
    for (int i = 0; i < bits_cnt; i++) {
        mask |= ((in[i>>3] >> (i&7)) & 1) << cnt;
        cnt++;
        map<int, int>::iterator it = R[cnt].find(mask);
        if (it != R[cnt].end()) {
            out[m++] = (unsigned char)it->second;
            cnt = 0, mask = 0;
        } else if (cnt == MAX_BITS) {
            fprintf(stderr, "huffman error: undecodable bit sequence\n");
            return;
        }
    }
}
// Decodes the embedded payload (base64, then Huffman) and prints the 64x64 gray image,
// repeating each gray value for R, G and B.
int main() {
    unsigned char d[2][32767] = {};
    int n, m;
    // sizeof(data) includes the string's terminating NUL; pass only the base64 characters.
    base64_decode(data, sizeof(data) - 1, d[0], n);
    huffman_decode(d[0], n, d[1], m);
    printf("%d %d\n", 64, 64);
    for (int i = 0; i < 64; i++) {
        for (int j = 0; j < 64; j++) {
            printf("%d %d %d%c", d[1][i*64+j], d[1][i*64+j], d[1][i*64+j], j == 63 ? '\n' : ' ');
        }
    }
    return 0;
}

diff+huffman+base64

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
#include <bits/stdc++.h>
using namespace std;
// Embedded payload: Huffman-coded differences between consecutive gray values (mod 256)
// in base64; main() decodes them and restores the pixels with a running sum.
// The actual literal is elided in this listing.
unsigned char data[] = "nG0AAPkABQIBBRICBRoDBRMEBQ0FB ... ignore";
// Decodes n base64 characters from in into out and sets m to the number of bytes written.
// Characters outside the alphabet (other than '=') decode as 0. Trailing NUL bytes are
// ignored, so passing sizeof(string literal) is safe. Reads never go past in[n-1]: a short
// final group is treated as if padded. m is n/4*3 minus the trailing '=' count (never < 0).
void base64_decode(unsigned char *in, int n, unsigned char *out, int &m) {
    static int R[256];
    for (int i = 'A'; i <= 'Z'; i++) R[i] = i - 'A';
    for (int i = 'a'; i <= 'z'; i++) R[i] = i - 'a' + 26;
    for (int i = '0'; i <= '9'; i++) R[i] = i - '0' + 52;
    R['+'] = 62, R['/'] = 63;
    // drop a terminator counted by sizeof; it would otherwise hide the '=' padding
    while (n > 0 && in[n-1] == '\0')
        n--;
    m = n/4*3;
    if (n >= 1 && in[n-1] == '=') m--;
    if (n >= 2 && in[n-2] == '=') m--;
    if (m < 0) m = 0;
    for (int i = 0, j = 0; i < n; ) {
        unsigned int a, val = 0;
        for (int k = 3; k >= 0; i++, k--) {
            // positions past the end count as padding, so a partial group never over-reads
            a = (i >= n || in[i] == '=') ? 0 : R[in[i]];
            val |= a<<(k*6);
        }
        if (j < m) out[j++] = (val>>16)&0xff;
        if (j < m) out[j++] = (val>>8)&0xff;
        if (j < m) out[j++] = (val>>0)&0xff;
    }
}
// Inverse of huffman_encode: reads the header (4-byte bit count, 1-byte symbol count,
// then one {symbol, code length, code bytes} entry per symbol) and decodes bits_cnt
// LSB-first payload bits into out. Sets m to the number of decoded bytes.
// Header integers are copied in host byte order, matching the encoder. n is only
// decremented while parsing the header; the payload length comes from bits_cnt.
// Malformed input (a code longer than 30 bits, or 30 bits matching no code) stops
// decoding with a message on stderr instead of indexing past the code tables.
void huffman_decode(unsigned char *in, int n, unsigned char *out, int &m) {
    const int MAX_BITS = 30;          // huffman_encode never emits longer codes
    m = 0;
    map<int, int> R[MAX_BITS + 1];    // R[length][code] = symbol
    int bits_cnt = 0, leaf = 0;
    memcpy(&bits_cnt, in, 4), in += 4, n -= 4;
    memcpy(&leaf, in, 1), in += 1, n -= 1;
    for (int i = 0; i < leaf; i++) {
        int tid = 0, bn = 0, bit = 0;
        memcpy(&tid, in, 1), in += 1, n -= 1;
        memcpy(&bn, in, 1), in += 1, n -= 1;
        if (bn > MAX_BITS) {
            // also keeps the memcpy below within the 4 bytes of `bit`
            fprintf(stderr, "huffman error: code length %d too long\n", bn);
            return;
        }
        memcpy(&bit, in, (bn+7)/8), in += (bn+7)/8, n -= (bn+7)/8;
        R[bn][bit] = tid;
    }
    int mask = 0, cnt = 0;
    for (int i = 0; i < bits_cnt; i++) {
        mask |= ((in[i>>3] >> (i&7)) & 1) << cnt;
        cnt++;
        map<int, int>::iterator it = R[cnt].find(mask);
        if (it != R[cnt].end()) {
            out[m++] = (unsigned char)it->second;
            cnt = 0, mask = 0;
        } else if (cnt == MAX_BITS) {
            fprintf(stderr, "huffman error: undecodable bit sequence\n");
            return;
        }
    }
}
// Decodes the embedded payload (base64, then Huffman), undoes the delta coding with a
// running sum (mod 256, via unsigned char wrap-around) and prints the 64x64 gray image,
// repeating each gray value for R, G and B.
int main() {
    unsigned char d[2][32767] = {};
    int n, m;
    // sizeof(data) includes the string's terminating NUL; pass only the base64 characters.
    base64_decode(data, sizeof(data) - 1, d[0], n);
    huffman_decode(d[0], n, d[1], m);
    for (int i = 1; i < m; i++)
        d[1][i] = d[1][i] + d[1][i-1];
    printf("%d %d\n", 64, 64);
    for (int i = 0; i < 64; i++) {
        for (int j = 0; j < 64; j++) {
            printf("%d %d %d%c", d[1][i*64+j], d[1][i*64+j], d[1][i*64+j], j == 63 ? '\n' : ' ');
        }
    }
    return 0;
}
Read More +

b451. 圖片匹配

Problem

背景

圖片匹配和字串匹配有一點不同,字串匹配通常要求其子字串與搜尋字串完全相符,而圖片匹配則用相似度為依據,當圖片大、複雜且具有干擾,或者需要匹配數量非常多,更先進的領域會利用特徵擷取,用機率統計的方式來篩選可能的匹配數量,篩選過後才進行圖片的細節匹配。

題目描述

給予兩個圖片 $A, B$,圖片格式為灰階影像,每個像素 $\mathit{pixel}(x, y)$ 採用 8-bits 表示,範圍為 $\mathit{pixel}(x, y) \in [0, 255]$

舉一個例子,有一個 $3 \times 3$ 的圖片 $A$ 和一個 $2 \times 2$ 的圖片 $B$,用矩陣表示如下:

$$A := \begin{bmatrix} a1 & a2 & a3 \\ a4 & a5 & a6 \\ a7 & a8 & a9 \end{bmatrix} ,\; B := \begin{bmatrix} b1 & b2\\ b3 & b4\\ \end{bmatrix}$$

假設左上角座標 $(1, 1)$ 為 $a1$ 的位置、$(1, 2)$ 為 $a2$ 的位置。

  • 把影像 $B$ 左上角對齊 $A$ 的 $(1, 1)$ 位置,其差異程度 $\mathit{diff}(A, B) = (a1 - b1)^2 + (a2 - b2)^2 + (a4 - b3)^2 + (a5 - b4)^2$
  • 相同地,對齊 $(2, 1)$ 位置,其差異程度 $\mathit{diff}(A, B) = (a4 - b1)^2 + (a5 - b2)^2 + (a7 - b3)^2 + (a8 - b4)^2$

比較時,整張 $B$ 都要落在 $A$ 中。現在要找到一個對齊位置 $(x, y)$,使得 $\mathit{diff}(A, B)$ 最小。

Sample Input

1
2
3
4
5
6
7
8
9
10
11
12
4 4 2 2
0 3 3 0
0 3 3 0
0 0 5 5
0 0 5 5
5 5
5 5
3 5 1 2
2 3 0 4 5
0 0 7 0 0
0 0 7 0 0
3 4

Sample Output

1
2
3 3
1 1

Solution

參考資料

相同樸素 FFT 題目 ZOJ 1637 - Fast Image Match

小記

雖然不是很明白 FFT,參考上面的資料,了解 FFT 的旋積 (convolution) 可以在 $O(n \log n)$ 完成就好,接著套模板。

從差異公式中得到 $\mathit{diff}(A, B) = \sum (a_i - b_i)^2 = \sum a_{i}^{2} - \sum 2 a_i b_i + \sum b_{i}^{2}$。其中 $\sum a_{i}^{2}$ 與 $\sum b_{i}^{2}$ 都可以各自獨立計算($\sum a_{i}^{2}$ 用二維前綴和取得,$\sum b_{i}^{2}$ 對所有位置都相同),麻煩的是 $\sum 2 a_i b_i$ 的向量內積。若樸素地計算需要 $O(H^4)$,套用 FFT 旋積計算則可降到 $O(H^2 \log H)$。FFT 有一個缺點:它在浮點數的複數域下運算,計算時會失去精準度,必須四捨五入回整數。儘管如此,由於不用像 NTT 那樣做大量模運算,速度是最快的。

數論變換 (NTT) / 快速數論變換 (FNT),採用費馬數數論變換,取代複數根的疊加,利用原根的性質來完成。NTT/FNT 處理整數域內積時不存在誤差,所有計算皆在整數,但是要取模變得非常慢。

為了加速計算,丟 CRT 下去降一半的 bits,乘法速度只能提升兩倍,但要做兩次計算,速度稍微快一點點而已。在實作時,特別要注意到,CRT 運作時,挑選兩個質數 $P1, \; P2$ 分別計算 FNT,最後 CRT 逆推回去。

其他實作細節

  1. Fast Reverse Bit 使用位元運算,取代建表或迴圈的方案;除非測資組數很多,否則建表並不划算。
  2. 將 std::complex<double> 改用自訂的 struct complex 取代,減少 method interface 的拷貝開銷。
  3. sin() cos() 不考慮建表,採用乘法和加法疊加,這會損失一點點精度。若固定測資考慮內存池去搞。

在 O1 下編譯跟 O3 一樣快。主要由於第二點的貢獻,速度直接快了兩倍多:0.3s -> 90ms。

FFT

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
#include <bits/stdc++.h>
using namespace std;
/*
 * Iterative radix-2 FFT over std::complex<T>, plus a circular
 * cross-correlation helper built on it. Transform lengths must be powers of two.
 */
template<typename T> class TOOL_FFT {
public:
    typedef unsigned int UINT32;
    // log2 of a power of two (index of its lowest set bit); 0 for 0.
    int NumberOfBitsNeeded(int PowerOfTwo) {
        for (int i = 0; i < 31; ++i) {
            if (PowerOfTwo & (1 << i)) {
                return i;
            }
        }
        return 0;
    }
    // Reverse the lowest NumBits bits of a using branch-free bit swaps.
    inline UINT32 FastReverseBits(UINT32 a, int NumBits) {
        // A 1-point transform has NumBits == 0. The final shift would then be
        // by 32, which is undefined for a 32-bit operand.
        if (NumBits == 0) return 0;
        a = ( ( a & 0x55555555U ) << 1 ) | ( ( a & 0xAAAAAAAAU ) >> 1 ) ;
        a = ( ( a & 0x33333333U ) << 2 ) | ( ( a & 0xCCCCCCCCU ) >> 2 ) ;
        a = ( ( a & 0x0F0F0F0FU ) << 4 ) | ( ( a & 0xF0F0F0F0U ) >> 4 ) ;
        a = ( ( a & 0x00FF00FFU ) << 8 ) | ( ( a & 0xFF00FF00U ) >> 8 ) ;
        a = ( ( a & 0x0000FFFFU ) << 16 ) | ( ( a & 0xFFFF0000U ) >> 16 ) ;
        return a >> (32 - NumBits);
    }
    // Transform In into Out (same length, power of two; Out must not alias In).
    // The inverse transform also divides by the length.
    void FFT(bool InverseTransform, vector<complex<T> >& In, vector<complex<T> >& Out) {
        // simultaneous data copy and bit-reversal ordering into outputs
        int NumSamples = In.size();
        int NumBits = NumberOfBitsNeeded(NumSamples);
        for (int i = 0; i < NumSamples; ++i) {
            Out[FastReverseBits(i, NumBits)] = In[i];
        }
        // the FFT process
        T angle_numerator = acos(-1.) * (InverseTransform ? -2 : 2);
        for (int BlockEnd = 1, BlockSize = 2; BlockSize <= NumSamples; BlockSize <<= 1) {
            T delta_angle = angle_numerator / BlockSize;
            T sin1 = sin(-delta_angle);
            T cos1 = cos(-delta_angle);
            T sin2 = sin(-delta_angle * 2);
            T cos2 = cos(-delta_angle * 2);
            for (int i = 0; i < NumSamples; i += BlockSize) {
                complex<T> a1(cos1, sin1), a2(cos2, sin2), a0;
                for (int j = i, n = 0; n < BlockEnd; ++j, ++n) {
                    // Twiddle factors via w_next = 2*cos(delta)*w - w_prev
                    // instead of calling sin/cos per butterfly.
                    a0 = complex<T>(2 * cos1 * a1.real() - a2.real(), 2 * cos1 * a1.imag() - a2.imag());
                    a2 = a1;
                    a1 = a0;
                    complex<T> a = a0 * Out[j + BlockEnd];
                    Out[j + BlockEnd] = Out[j] - a;
                    Out[j] += a;
                }
            }
            BlockEnd = BlockSize;
        }
        // normalize if inverse transform
        if (InverseTransform) {
            for (int i = 0; i < NumSamples; ++i) {
                Out[i] /= NumSamples;
            }
        }
    }
    /*
     * Circular cross-correlation of a and b (length n, a power of two):
     * ret[k] = sum_i a[(i + k) mod n] * b[i]. b is index-reversed and then
     * convolved with a.
     */
    vector<T> convolution(T *a, T *b, int n) {
        vector<std::complex<T>> s(n), d1(n), d2(n), y(n);
        vector<T> ret(n);
        for (int i = 0; i < n; ++i) {
            s[i] = complex<T>(a[i], 0);
        }
        FFT(false, s, d1);
        s[0] = complex<T>(b[0], 0);
        for (int i = 1; i < n; ++i) {
            s[i] = complex<T>(b[n - i], 0);
        }
        FFT(false, s, d2);
        for (int i = 0; i < n; ++i) {
            y[i] = d1[i] * d2[i];
        }
        FFT(true, y, s);
        for (int i = 0; i < n; ++i) {
            ret[i] = s[i].real();
        }
        return ret;
    }
};
TOOL_FFT<double> tool;
// Zero-padded image buffers handed to the FFT correlation.
double a[262144], b[262144];
// sum[i][j]: 2-D prefix sum of squared pixels of A over rows 0..i, cols 0..j.
long long sum[512][512];
// Sum over the inclusive rectangle rows lx..rx, cols ly..ry, by inclusion-exclusion.
long long getArea(int lx, int ly, int rx, int ry) {
    long long above = lx > 0 ? sum[lx - 1][ry] : 0;
    long long left = ly > 0 ? sum[rx][ly - 1] : 0;
    long long corner = (lx > 0 && ly > 0) ? sum[lx - 1][ly - 1] : 0;
    return sum[rx][ry] - above - left + corner;
}
inline int readchar() {
const int N = 1048576;
static char buf[N];
static char *p = buf, *end = buf;
if(p == end) {
if((end = buf + fread(buf, 1, N, stdin)) == buf) return EOF;
p = buf;
}
return *p++;
}
inline int ReadInt(int *x) {
static char c, neg;
while((c = readchar()) < '-') {if(c == EOF) return 0;}
neg = (c == '-') ? -1 : 1;
*x = (neg == 1) ? c-'0' : 0;
while((c = readchar()) >= '0')
*x = (*x << 3) + (*x << 1) + c-'0';
*x *= neg;
return 1;
}
/*
 * For each test case: read image A (m x n) and template B (p x q). Print the
 * 1-based top-left position where sum (a - b)^2 over B's footprint is
 * smallest. sum b^2 is the same for every position, so only
 * sum a^2 - 2 * sum ab is compared.
 */
int main() {
    int m, n, p, q, x, N, M, S;
    while (ReadInt(&m)) {
        ReadInt(&n), ReadInt(&p), ReadInt(&q);
        // Both images are stored row-major with the common row stride M.
        N = max(m, p), M = max(n, q);
        // Smallest power of two covering the N x M layout. Assumed to fit the
        // 262144-element buffers; TODO confirm against the problem limits.
        S = 1;
        for (; S < N*M; S <<= 1);
        memset(a, 0, sizeof(a[0]) * S);
        memset(b, 0, sizeof(b[0]) * S);
        for (int i = 0; i < m; i++) {
            for (int j = 0; j < n; j++) {
                ReadInt(&x);
                a[i*M+j] = x;
            }
        }
        for (int i = 0; i < p; i++) {
            for (int j = 0; j < q; j++) {
                ReadInt(&x);
                b[i*M+j] = x;
            }
        }
        // 2-D prefix sums of a^2 for O(1) window sums.
        for (int i = 0; i < m; i++) {
            long long s = 0;
            for (int j = 0; j < n; j++) {
                x = a[i*M+j];
                s += (long long)x*x;
                sum[i][j] = (i > 0 ? sum[i-1][j] : 0) + s;
            }
        }
        // r[i*M + j] = sum of a*b with B's top-left corner at (i, j).
        vector<double> r = tool.convolution(a, b, S);
        int qx = m - p, qy = n - q, bX = 0, bY = 0;
        // LLONG_MAX, not LONG_MAX: long is only 32 bits on some platforms.
        long long diff = LLONG_MAX;
        for (int i = 0; i <= qx; i++) {
            for (int j = 0; j <= qy; j++) {
                long long v = round(getArea(i, j, i+p-1, j+q-1) - 2*r[i*M + j]);
                if (v < diff) {
                    diff = v, bX = i, bY = j;
                }
            }
        }
        printf("%d %d\n", bX+1, bY+1);
    }
    return 0;
}

FFT 編譯優化

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
#include <bits/stdc++.h>
using namespace std;
/*
 * Iterative radix-2 FFT on a minimal complex type, plus a circular
 * cross-correlation helper. Transform lengths must be powers of two.
 */
template<typename T> class TOOL_FFT {
public:
    // Lightweight complex number: x is the real part, y the imaginary part.
    struct complex {
        T x, y;
        complex(T x = 0, T y = 0):
            x(x), y(y) {}
        complex operator+(const complex &A) const {
            return complex(x+A.x,y+A.y);
        }
        complex operator-(const complex &A) const {
            return complex(x-A.x,y-A.y);
        }
        complex operator*(const complex &A) const {
            return complex(x*A.x-y*A.y,x*A.y+y*A.x);
        }
    };
    typedef unsigned int UINT32;
    // log2 of a power of two (index of its lowest set bit); 0 for 0.
    int NumberOfBitsNeeded(int PowerOfTwo) {
        for (int i = 0; i < 31; ++i) {
            if (PowerOfTwo & (1 << i)) {
                return i;
            }
        }
        return 0;
    }
    // Reverse the lowest NumBits bits of a using branch-free bit swaps.
    inline UINT32 FastReverseBits(UINT32 a, int NumBits) {
        // NumBits == 0 (1-point transform) would shift a 32-bit value by 32: UB.
        if (NumBits == 0) return 0;
        a = ( ( a & 0x55555555U ) << 1 ) | ( ( a & 0xAAAAAAAAU ) >> 1 ) ;
        a = ( ( a & 0x33333333U ) << 2 ) | ( ( a & 0xCCCCCCCCU ) >> 2 ) ;
        a = ( ( a & 0x0F0F0F0FU ) << 4 ) | ( ( a & 0xF0F0F0F0U ) >> 4 ) ;
        a = ( ( a & 0x00FF00FFU ) << 8 ) | ( ( a & 0xFF00FF00U ) >> 8 ) ;
        a = ( ( a & 0x0000FFFFU ) << 16 ) | ( ( a & 0xFFFF0000U ) >> 16 ) ;
        return a >> (32 - NumBits);
    }
    // Transform In into Out (power-of-two length; Out must not alias In).
    // The inverse transform divides by the length and keeps only the real part.
    void FFT(bool InverseTransform, vector<complex>& In, vector<complex>& Out) {
        // simultaneous data copy and bit-reversal ordering into outputs
        int NumSamples = In.size();
        int NumBits = NumberOfBitsNeeded(NumSamples);
        for (int i = 0; i < NumSamples; ++i) {
            Out[FastReverseBits(i, NumBits)] = In[i];
        }
        // the FFT process
        T angle_numerator = acos(-1.) * (InverseTransform ? -2 : 2);
        for (int BlockEnd = 1, BlockSize = 2; BlockSize <= NumSamples; BlockSize <<= 1) {
            T delta_angle = angle_numerator / BlockSize;
            T sin1 = sin(-delta_angle);
            T cos1 = cos(-delta_angle);
            T sin2 = sin(-delta_angle * 2);
            T cos2 = cos(-delta_angle * 2);
            for (int i = 0; i < NumSamples; i += BlockSize) {
                complex a1(cos1, sin1), a2(cos2, sin2), a0, a;
                int j, n;
                // One butterfly. The twiddle factor advances by
                // w_next = 2*cos(delta)*w - w_prev.
#define UNLOOP {\
    a0 = complex(2 * cos1 * a1.x - a2.x, 2 * cos1 * a1.y - a2.y); \
    a2 = a1, a1 = a0; \
    a = a0 * Out[j + BlockEnd]; \
    Out[j + BlockEnd] = Out[j] - a; \
    Out[j] = Out[j] + a; \
    ++j, ++n; }
#define UNLOOP8 {UNLOOP UNLOOP UNLOOP UNLOOP UNLOOP UNLOOP UNLOOP UNLOOP}
                // Unrolled by 8, then finish the remaining butterflies one by one.
                for (j = i, n = 0; n+8 < BlockEnd; ) {
                    UNLOOP8;
                }
                for (; n < BlockEnd; )
                    UNLOOP;
#undef UNLOOP8
#undef UNLOOP
            }
            BlockEnd = BlockSize;
        }
        // normalize if inverse transform (imaginary part is discarded)
        if (InverseTransform) {
            for (int i = 0; i < NumSamples; ++i) {
                Out[i] = Out[i].x / NumSamples;
            }
        }
    }
    /*
     * Circular cross-correlation of a and b (length n, a power of two):
     * c[k] = sum_i a[(i + k) mod n] * b[i].
     */
    void convolution(T *a, T *b, int n, T *c) {
        vector<complex> s(n), d1(n), d2(n), y(n);
        for (int i = 0; i < n; ++i)
            s[i] = complex(a[i], 0);
        FFT(false, s, d1);
        s[0] = complex(b[0], 0);
        for (int i = 1; i < n; ++i)
            s[i] = complex(b[n - i], 0);
        FFT(false, s, d2);
        for (int i = 0; i < n; ++i)
            y[i] = d1[i] * d2[i];
        FFT(true, y, s);
        for (int i = 0; i < n; ++i)
            c[i] = s[i].x;
    }
};
TOOL_FFT<double> tool;
// FFT input buffers (A, B) and correlation output.
double a[262144], b[262144], c[262144];
// sum[i][j]: 2-D prefix sum of squared pixels of A over rows 0..i, cols 0..j.
long long sum[512][512];
// Inclusion-exclusion sum over rows lx..rx, cols ly..ry (inclusive).
inline long long getArea(int lx, int ly, int rx, int ry) {
    long long whole = sum[rx][ry];
    long long top = lx ? sum[lx - 1][ry] : 0;
    long long left = ly ? sum[rx][ly - 1] : 0;
    long long overlap = (lx && ly) ? sum[lx - 1][ly - 1] : 0;
    return whole - top - left + overlap;
}
inline int readchar() {
const int N = 1048576;
static char buf[N];
static char *p = buf, *end = buf;
if(p == end) {
if((end = buf + fread(buf, 1, N, stdin)) == buf) return EOF;
p = buf;
}
return *p++;
}
inline int ReadInt(int *x) {
static char c, neg;
while((c = readchar()) < '-') {if(c == EOF) return 0;}
neg = (c == '-') ? -1 : 1;
*x = (neg == 1) ? c-'0' : 0;
while((c = readchar()) >= '0')
*x = (*x << 3) + (*x << 1) + c-'0';
*x *= neg;
return 1;
}
/*
 * For each test case: read image A (m x n) and template B (p x q). Print the
 * 1-based top-left position minimizing sum (a - b)^2 over B's footprint. Only
 * sum a^2 - 2 * sum ab is compared, because sum b^2 is position-independent.
 */
int main() {
    int m, n, p, q, x, N, M, S;
    while (ReadInt(&m)) {
        ReadInt(&n), ReadInt(&p), ReadInt(&q);
        // Both images are stored row-major with the common row stride M.
        N = max(m, p), M = max(n, q);
        // Smallest power of two covering the layout. Assumed to fit the
        // 262144-element buffers; TODO confirm against the problem limits.
        S = 1;
        for (; S < N*M; S <<= 1);
        memset(a, 0, sizeof(a[0]) * S);
        memset(b, 0, sizeof(b[0]) * S);
        for (int i = 0; i < m; i++) {
            long long s = 0;
            for (int j = 0; j < n; j++) {
                ReadInt(&x);
                a[i*M+j] = x;
                s += (long long)x*x;
                sum[i][j] = (i > 0 ? sum[i-1][j] : 0) + s;
            }
        }
        for (int i = 0; i < p; i++) {
            for (int j = 0; j < q; j++) {
                ReadInt(&x);
                b[i*M+j] = x;
            }
        }
        tool.convolution(a, b, S, c);
        int qx = m - p, qy = n - q, bX = 0, bY = 0;
        // LLONG_MAX, not LONG_MAX: long is only 32 bits on some platforms.
        long long diff = LLONG_MAX;
        for (int i = 0; i <= qx; i++) {
            for (int j = 0; j <= qy; j++) {
                // The score can be negative. llround rounds to nearest for
                // both signs; truncating (v + 0.5) would map a negative v to
                // v + 1 and make -1 tie with 0.
                long long v = llround(getArea(i, j, i+p-1, j+q-1) - 2*c[i*M + j]);
                if (v < diff)
                    diff = v, bX = i, bY = j;
            }
        }
        fprintf(stderr, "%lld\n", diff);
        printf("%d %d\n", bX+1, bY+1);
    }
    return 0;
}

NTT/FNT

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
#include <bits/stdc++.h>
using namespace std;
typedef unsigned int UINT32;
typedef long long INT64;
/*
 * Number-theoretic transform modulo the prime P, plus a circular
 * cross-correlation helper. Arithmetic is exact, so results need no rounding.
 */
class TOOL_NTT {
public:
#define MAXN 262144
    // P - 1 = 2^18 * 190734863287, so transforms of up to 2^18 points are possible.
    const INT64 P = 50000000001507329LL; // prime m = kn+1
    const INT64 G = 3;
    // wn[i] = G^((P-1)/2^i) mod P: the twiddle root for blocks of size 2^i
    // (only i <= 18 is meaningful).
    INT64 wn[20];
    // Scratch buffers for convolution(); transforms are limited to MAXN points.
    INT64 s[MAXN], d1[MAXN], d2[MAXN], y[MAXN];
    TOOL_NTT() {
        for (int i = 0; i < 20; i++)
            wn[i] = mod_pow(G, (P-1) / (1<<i), P);
    }
    /*
     * a * b mod `mod` for 0 <= a, b < mod, without 128-bit arithmetic.
     * The quotient is estimated in floating point and the remainder corrected
     * afterwards. This needs the true remainder a*b - q*mod to fit in a signed
     * 64-bit value (the original note: fast for P < 2^58). The products are
     * formed in unsigned arithmetic, which wraps modulo 2^64 by definition;
     * in signed arithmetic the overflow would be undefined behavior.
     */
    INT64 mod_mul(INT64 a, INT64 b, INT64 mod) {
        long long q = (long long)((double)a*b/mod+0.5);
        unsigned long long diff = (unsigned long long)a * (unsigned long long)b
                                - (unsigned long long)q * (unsigned long long)mod;
        long long r = (long long)diff % mod;
        return r < 0 ? r + mod : r;
    }
    // n^e mod m by binary exponentiation.
    INT64 mod_pow(INT64 n, INT64 e, INT64 m) {
        INT64 x = 1;
        for (n = n >= m ? n%m : n; e; e >>= 1) {
            if (e&1)
                x = mod_mul(x, n, m);
            n = mod_mul(n, n, m);
        }
        return x;
    }
    // log2 of a power of two (index of its lowest set bit); 0 for 0.
    int NumberOfBitsNeeded(int PowerOfTwo) {
        for (int i = 0; i < 31; ++i) {
            if (PowerOfTwo & (1 << i)) {
                return i;
            }
        }
        return 0;
    }
    // Reverse the lowest NumBits bits of a using branch-free bit swaps.
    inline UINT32 FastReverseBits(UINT32 a, int NumBits) {
        // NumBits == 0 (1-point transform) would shift a 32-bit value by 32: UB.
        if (NumBits == 0) return 0;
        a = ( ( a & 0x55555555U ) << 1 ) | ( ( a & 0xAAAAAAAAU ) >> 1 ) ;
        a = ( ( a & 0x33333333U ) << 2 ) | ( ( a & 0xCCCCCCCCU ) >> 2 ) ;
        a = ( ( a & 0x0F0F0F0FU ) << 4 ) | ( ( a & 0xF0F0F0F0U ) >> 4 ) ;
        a = ( ( a & 0x00FF00FFU ) << 8 ) | ( ( a & 0xFF00FF00U ) >> 8 ) ;
        a = ( ( a & 0x0000FFFFU ) << 16 ) | ( ( a & 0xFFFF0000U ) >> 16 ) ;
        return a >> (32 - NumBits);
    }
    // Transform In[0..n) into Out (n a power of two; Out must not alias In).
    // on == 1 computes the inverse transform, including the 1/n factor.
    void NTT(int on, INT64 *In, INT64 *Out, int n) {
        int NumBits = NumberOfBitsNeeded(n);
        for (int i = 0; i < n; ++i)
            Out[FastReverseBits(i, NumBits)] = In[i];
        for(int h = 2, id = 1; h <= n; h <<= 1, id++) {
            for(int j = 0; j < n; j += h) {
                INT64 w = 1, u, t;
                int block = h/2, blockEnd = j + h/2;
                for(int k = j; k < blockEnd; k++) {
                    u = Out[k], t = mod_mul(w, Out[k+block], P);
                    Out[k] = u + t;
                    Out[k+block] = u - t + P;
                    if (Out[k] >= P) Out[k] -= P;
                    if (Out[k+block] >= P) Out[k+block] -= P;
                    w = mod_mul(w, wn[id], P);
                }
            }
        }
        if (on == 1) {
            // Inverse = forward result with indices 1..n-1 reversed, times n^-1.
            for (int i = 1; i < n/2; i++)
                swap(Out[i], Out[n-i]);
            INT64 invn = mod_pow(n, P-2, P);
            for (int i = 0; i < n; i++)
                Out[i] = mod_mul(Out[i], invn, P);
        }
    }
    // Circular cross-correlation: c[k] = sum_i a[(i + k) mod n] * b[i] (mod P).
    void convolution(INT64 *a, INT64 *b, int n, INT64 *c) {
        NTT(0, a, d1, n);
        s[0] = b[0];
        for (int i = 1; i < n; ++i)
            s[i] = b[n-i];
        NTT(0, s, d2, n);
        for (int i = 0; i < n; i++)
            s[i] = mod_mul(d1[i], d2[i], P);
        NTT(1, s, c, n);
    }
} tool;
INT64 a[262144], b[262144], c[262144];
// sum[i][j]: 2-D prefix sum of squared pixels of A over rows 0..i, cols 0..j.
long long sum[512][512];
// Sum over rows lx..rx and cols ly..ry (inclusive) via inclusion-exclusion.
inline long long getArea(int lx, int ly, int rx, int ry) {
    long long result = sum[rx][ry];
    if (lx != 0) {
        result -= sum[lx - 1][ry];
        if (ly != 0)
            result += sum[lx - 1][ly - 1];
    }
    if (ly != 0)
        result -= sum[rx][ly - 1];
    return result;
}
inline int readchar() {
const int N = 1048576;
static char buf[N];
static char *p = buf, *end = buf;
if(p == end) {
if((end = buf + fread(buf, 1, N, stdin)) == buf) return EOF;
p = buf;
}
return *p++;
}
inline int ReadInt(int *x) {
static char c, neg;
while((c = readchar()) < '-') {if(c == EOF) return 0;}
neg = (c == '-') ? -1 : 1;
*x = (neg == 1) ? c-'0' : 0;
while((c = readchar()) >= '0')
*x = (*x << 3) + (*x << 1) + c-'0';
*x *= neg;
return 1;
}
/*
 * For each test case: read image A (m x n) and template B (p x q). Print the
 * 1-based top-left position minimizing sum (a - b)^2 over B's footprint. The
 * NTT gives sum ab exactly; sum b^2 is position-independent and omitted.
 */
int main() {
    int m, n, p, q, x, N, M, S;
    while (ReadInt(&m)) {
        ReadInt(&n), ReadInt(&p), ReadInt(&q);
        // Both images are stored row-major with the common row stride M.
        N = max(m, p), M = max(n, q);
        // Smallest power of two covering the layout. Assumed to fit the
        // 262144-element buffers; TODO confirm against the problem limits.
        S = 1;
        for (; S < N*M; S <<= 1);
        memset(a, 0, sizeof(a[0]) * S);
        memset(b, 0, sizeof(b[0]) * S);
        for (int i = 0; i < m; i++) {
            long long s = 0;
            for (int j = 0; j < n; j++) {
                ReadInt(&x);
                a[i*M+j] = x;
                s += (long long)x*x;
                sum[i][j] = (i > 0 ? sum[i-1][j] : 0) + s;
            }
        }
        for (int i = 0; i < p; i++) {
            for (int j = 0; j < q; j++) {
                ReadInt(&x);
                b[i*M+j] = x;
            }
        }
        tool.convolution(a, b, S, c);
        int qx = m - p, qy = n - q, bX = 0, bY = 0;
        // LLONG_MAX, not LONG_MAX: long is only 32 bits on some platforms.
        long long diff = LLONG_MAX;
        for (int i = 0; i <= qx; i++) {
            for (int j = 0; j <= qy; j++) {
                long long v = getArea(i, j, i+p-1, j+q-1) - 2*c[i*M + j];
                if (v < diff)
                    diff = v, bX = i, bY = j;
            }
        }
        fprintf(stderr, "diff = %lld\n", diff);
        printf("%d %d\n", bX+1, bY+1);
    }
    return 0;
}

NTT/FNT CRT 加速

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
#include <bits/stdc++.h>
using namespace std;
typedef uint_fast32_t UINT32;
typedef long long INT64;
typedef uint_fast32_t INT32;
/*
 * NTT-based circular cross-correlation for results larger than one ~30-bit
 * prime. The transform runs modulo P1 and modulo P2, and the two residues are
 * combined by CRT into a value modulo MM = P1 * P2.
 */
class TOOL_NTT {
public:
#define MAXN 262144
    // Active modulus and generator; switched by reset().
    INT32 P = 3, G = 2;
    // wn[i] = G^((P-1)/2^i) mod P for the active modulus.
    INT32 wn[20];
    INT32 s[MAXN], d1[MAXN], d2[MAXN], c1[MAXN], c2[MAXN];
    const INT32 P1 = 998244353; // P1 = 2^23 * 7 * 17 + 1
    const INT32 G1 = 3;
    const INT32 P2 = 995622913; // P2 = 2^19 *3*3*211 + 1
    const INT32 G2 = 5;
    // CRT weights; note M1 + M2 == MM + 1. Presumably M1 = 1 (mod P1),
    // 0 (mod P2) and M2 the reverse; TODO confirm.
    const INT64 M1 = 397550359381069386LL;
    const INT64 M2 = 596324591238590904LL;
    const INT64 MM = 993874950619660289LL; // MM = P1*P2
    TOOL_NTT() {
        for (int i = 0; i < 20; i++)
            wn[i] = mod_pow(G, (P-1) / (1<<i), P);
    }
    // Switch the active modulus/generator and rebuild the root table.
    void reset(INT32 p, INT32 g) {
        P = p, G = g;
        for (int i = 0; i < 20; i++)
            wn[i] = mod_pow(G, (P-1) / (1<<i), P);
    }
    /*
     * a * b mod `mod` without 128-bit arithmetic. The quotient is estimated
     * in floating point and the remainder corrected afterwards. This needs
     * the true remainder a*b - q*mod to fit in a signed 64-bit value. The
     * products are formed in unsigned arithmetic, which wraps modulo 2^64 by
     * definition; in signed arithmetic the overflow would be undefined.
     */
    INT64 mod_mul(INT64 a, INT64 b, INT64 mod) {
        long long q = (long long)((double)a*b/mod+0.5);
        unsigned long long diff = (unsigned long long)a * (unsigned long long)b
                                - (unsigned long long)q * (unsigned long long)mod;
        long long r = (long long)diff % mod;
        return r < 0 ? r + mod : r;
    }
    // n^e mod m by binary exponentiation.
    INT64 mod_pow(INT64 n, INT64 e, INT64 m) {
        INT64 x = 1;
        for (n = n >= m ? n%m : n; e; e >>= 1) {
            if (e&1)
                x = mod_mul(x, n, m);
            n = mod_mul(n, n, m);
        }
        return x;
    }
    // log2 of a power of two (index of its lowest set bit); 0 for 0.
    int NumberOfBitsNeeded(int PowerOfTwo) {
        for (int i = 0; i < 31; ++i) {
            if (PowerOfTwo & (1 << i)) {
                return i;
            }
        }
        return 0;
    }
    // Reverse the lowest NumBits bits of a using branch-free bit swaps.
    inline UINT32 FastReverseBits(UINT32 a, int NumBits) {
        // NumBits == 0: shifting by 32 is UB where UINT32 is 32 bits wide.
        if (NumBits == 0) return 0;
        a = ( ( a & 0x55555555U ) << 1 ) | ( ( a & 0xAAAAAAAAU ) >> 1 ) ;
        a = ( ( a & 0x33333333U ) << 2 ) | ( ( a & 0xCCCCCCCCU ) >> 2 ) ;
        a = ( ( a & 0x0F0F0F0FU ) << 4 ) | ( ( a & 0xF0F0F0F0U ) >> 4 ) ;
        a = ( ( a & 0x00FF00FFU ) << 8 ) | ( ( a & 0xFF00FF00U ) >> 8 ) ;
        a = ( ( a & 0x0000FFFFU ) << 16 ) | ( ( a & 0xFFFF0000U ) >> 16 ) ;
        return a >> (32 - NumBits);
    }
    // Transform In[0..n) into Out modulo the active P (Out must not alias In).
    // on == 1 computes the inverse transform, including the 1/n factor.
    void NTT(int on, INT32 *In, INT32 *Out, int n) {
        int NumBits = NumberOfBitsNeeded(n);
        for (int i = 0; i < n; ++i)
            Out[FastReverseBits(i, NumBits)] = In[i];
        for (int h = 2, id = 1; h <= n; h <<= 1, id++) {
            for (int j = 0; j < n; j += h) {
                INT32 w = 1, u, t;
                int block = h/2, blockEnd = j + h/2;
                for (int k = j; k < blockEnd; k++) {
                    u = Out[k], t = (INT64)w*Out[k+block]%P;
                    Out[k] = (u + t)%P;
                    Out[k+block] = (u - t + P)%P;
                    w = (INT64)w * wn[id]%P;
                }
            }
        }
        if (on == 1) {
            // Inverse = forward result with indices 1..n-1 reversed, times n^-1.
            for (int i = 1; i < n/2; i++)
                swap(Out[i], Out[n-i]);
            INT32 invn = mod_pow(n, P-2, P);
            for (int i = 0; i < n; i++)
                Out[i] = (INT64)Out[i]*invn%P;
        }
    }
    // Combine residue a (mod P1) and residue b (mod P2) into a value mod MM.
    INT64 crt(INT32 a, INT32 b) {
        return (mod_mul(a, M1, MM) + mod_mul(b, M2, MM))%MM;
    }
    // Circular cross-correlation c[k] = sum_i a[(i + k) mod n] * b[i], exact
    // while the true value is below MM.
    void convolution(INT32 *a, INT32 *b, int n, INT64 *c) {
        reset(P1, G1);
        NTT(0, a, d1, n);
        s[0] = b[0]; for (int i = 1; i < n; ++i) s[i] = b[n-i];
        NTT(0, s, d2, n);
        for (int i = 0; i < n; i++) s[i] = (INT64)d1[i] * d2[i]%P;
        NTT(1, s, c1, n);
        reset(P2, G2);
        NTT(0, a, d1, n);
        s[0] = b[0]; for (int i = 1; i < n; ++i) s[i] = b[n-i];
        NTT(0, s, d2, n);
        for (int i = 0; i < n; i++) s[i] = (INT64)d1[i] * d2[i]%P;
        NTT(1, s, c2, n);
        for (int i = 0; i < n; i++)
            c[i] = crt(c1[i], c2[i]);
    }
} tool;
INT32 a[262144], b[262144];
INT64 c[262144];
// sum[i][j]: 2-D prefix sum of squared pixels of A over rows 0..i, cols 0..j.
long long sum[512][512];
// Sum over rows lx..rx and cols ly..ry (inclusive) via inclusion-exclusion.
inline long long getArea(int lx, int ly, int rx, int ry) {
    long long below = sum[rx][ry];
    if (lx && ly)
        return below - sum[lx - 1][ry] - sum[rx][ly - 1] + sum[lx - 1][ly - 1];
    if (lx)
        return below - sum[lx - 1][ry];
    if (ly)
        return below - sum[rx][ly - 1];
    return below;
}
inline int readchar() {
const int N = 1048576;
static char buf[N];
static char *p = buf, *end = buf;
if(p == end) {
if((end = buf + fread(buf, 1, N, stdin)) == buf) return EOF;
p = buf;
}
return *p++;
}
inline int ReadInt(int *x) {
static char c, neg;
while((c = readchar()) < '-') {if(c == EOF) return 0;}
neg = (c == '-') ? -1 : 1;
*x = (neg == 1) ? c-'0' : 0;
while((c = readchar()) >= '0')
*x = (*x << 3) + (*x << 1) + c-'0';
*x *= neg;
return 1;
}
/*
 * For each test case: read image A (m x n) and template B (p x q). Print the
 * 1-based top-left position minimizing sum (a - b)^2 over B's footprint.
 * sum ab comes exactly from the two-prime NTT + CRT; sum b^2 is
 * position-independent and omitted.
 */
int main() {
    int m, n, p, q, x, N, M, S;
    while (ReadInt(&m)) {
        ReadInt(&n), ReadInt(&p), ReadInt(&q);
        // Both images are stored row-major with the common row stride M.
        N = max(m, p), M = max(n, q);
        // Smallest power of two covering the layout. Assumed to fit the
        // 262144-element buffers; TODO confirm against the problem limits.
        S = 1;
        for (; S < N*M; S <<= 1);
        memset(a, 0, sizeof(a[0]) * S);
        memset(b, 0, sizeof(b[0]) * S);
        for (int i = 0; i < m; i++) {
            long long s = 0;
            for (int j = 0; j < n; j++) {
                ReadInt(&x);
                a[i*M+j] = x;
                s += (long long)x*x;
                sum[i][j] = (i > 0 ? sum[i-1][j] : 0) + s;
            }
        }
        for (int i = 0; i < p; i++) {
            for (int j = 0; j < q; j++) {
                ReadInt(&x);
                b[i*M+j] = x;
            }
        }
        tool.convolution(a, b, S, c);
        int qx = m - p, qy = n - q, bX = 0, bY = 0;
        // LLONG_MAX, not LONG_MAX: long is only 32 bits on some platforms.
        long long diff = LLONG_MAX;
        for (int i = 0; i <= qx; i++) {
            for (int j = 0; j <= qy; j++) {
                long long v = getArea(i, j, i+p-1, j+q-1) - 2*c[i*M + j];
                if (v < diff)
                    diff = v, bX = i, bY = j;
            }
        }
        fprintf(stderr, "diff = %lld\n", diff);
        printf("%d %d\n", bX+1, bY+1);
    }
    return 0;
}
Read More +

b325. 人格分裂

Problem

某 M 現在正在平面座標上的原點 $(0, 0)$,現在四周被擺放了很多很多鏡子,某 M 可以藉由鏡子與他的人格小夥伴對話,請問那些鏡子可以見到小夥伴。

鏡子可以當作一個線段,線段之間不會交任何一點,只要能見到該鏡子中一小段區域就算可見到。

備註:不考慮反射看到,保證鏡子不會通過原點。

Sample Input

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
1
5 -5 5 5
4
4 2 5 -2
2 4 6 1
5 5 8 1
3 -4 7 -1
7
-1 2 3 1
2 4 5 -1
-3 -1 1 -2
-1 -4 3 -2
-2 -4 1 -5
-4 1 -1 4
-3 4 -4 3
1
1 1 2 2
2
-1 3 -2 -2
-2 -1 -3 -1

Sample Output

1
2
3
4
5
1
1 1 0 1
1 1 1 1 0 1 0
0
1 0

Solution

對於每一個線段,可以化作極角座標上的一個角度區間$[\theta_{start}, \theta_{end}]$,做一次極角排序,維護從原點射出的射線,找到該射線交到的所有角度區間,意即維護射線和線段交的最近距離,用一個平衡樹 set<Seg> 維護。由於線段之間不會相交,平衡樹靠遠近當作權重比較,遠近關係是單調的,故不影響插入和刪除。若發生遠近問題可採用 multiset<Seg> 來進行。

由於每一個線段可能拆分好幾個可視線段,而大多數的線段全是不可視的,可以考慮分治去處理,但我的實作效果並不好,其原因在於並沒有維護極角的 skyline,而是保留整個線段。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
#include <bits/stdc++.h>
using namespace std;
// Tolerance used by the floating-point sign/equality tests below.
#define eps 1e-12
// 2-D point / vector.
struct Pt {
    double x, y;
    Pt(double a = 0, double b = 0):
        x(a), y(b) {}
    Pt operator-(const Pt &a) const {
        return Pt(x - a.x, y - a.y);
    }
    Pt operator+(const Pt &a) const {
        return Pt(x + a.x, y + a.y);
    }
    Pt operator*(const double a) const {
        return Pt(x * a, y * a);
    }
    // Lexicographic (x, then y) order with eps tolerance.
    bool operator<(const Pt &a) const {
        if (fabs(x - a.x) > eps)
            return x < a.x;
        if (fabs(y - a.y) > eps)
            return y < a.y;
        return false;
    }
    // Squared Euclidean distance to a (const so it works on const points).
    double dist2(const Pt &a) const {
        return (x - a.x)*(x - a.x)+(y - a.y)*(y - a.y);
    }
};
double dot(Pt a, Pt b) {
    return a.x * b.x + a.y * b.y;
}
// Cross product of (a - o) and (b - o): > 0 when o->a->b turns counterclockwise.
double cross(Pt o, Pt a, Pt b) {
    return (a.x-o.x)*(b.y-o.y)-(a.y-o.y)*(b.x-o.x);
}
double cross2(Pt a, Pt b) {
    return a.x * b.y - a.y * b.x;
}
// Whether c projects onto segment ab (between its endpoints, with tolerance).
int between(Pt a, Pt b, Pt c) {
    return dot(c - a, b - a) >= -eps && dot(c - b, a - b) >= -eps;
}
// Whether c lies on segment ab.
int onSeg(Pt a, Pt b, Pt c) {
    return between(a, b, c) && fabs(cross(a, b, c)) < eps;
}
// Intersection of line (as, ae) with line (bs, be). Parallel lines divide by
// zero and give inf/nan coordinates.
Pt getIntersect(Pt as, Pt ae, Pt bs, Pt be) {
    Pt u = as - bs;
    double t = cross2(be - bs, u)/cross2(ae - as, be - bs);
    return as + (ae - as) * t;
}
// Mirror segment s-e with its 1-based input index.
struct Seg {
    Pt s, e;
    int id;
    Seg(Pt a = Pt(), Pt b = Pt(), int c = 0):
        s(a), e(b), id(c) {}
};
// Polar-angle order around the origin, counterclockwise starting from the
// positive x-axis; collinear points are ordered by smaller |x| first.
bool polar_cmp(const Pt& p1, const Pt& p2) {
    if (p1.y == 0 && p2.y == 0 && p1.x * p2.x <= 0) return p1.x > p2.x;
    if (p1.y == 0 && p1.x >= 0 && p2.y != 0) return true;
    if (p2.y == 0 && p2.x >= 0 && p1.y != 0) return false;
    if (p1.y * p2.y < 0) return p1.y > p2.y;
    double c = cross2(p1, p2);
    return c > 0 || (c == 0 && fabs(p1.x) < fabs(p2.x));
}
bool polar_cmp2(const pair<Pt, int> &x, const pair<Pt, int> &y) {
    return polar_cmp(x.first, y.first);
}
// Sign of x with eps tolerance: -1, 0 or 1.
int cmpZero(double x) {
    if (fabs(x) < eps) return 0;
    return x < 0 ? -1 : 1;
}
/*
 * std::set comparator: orders segments by the distance from ray_s to the
 * point where each one meets the line through ray_s and ray_e. The current
 * ray lives in static members and is updated by the sweep.
 */
struct CMP {
    static Pt ray_s, ray_e;
    // const: std::set requires a comparator invocable as const (LWG 2542);
    // libstdc++ rejects a non-const one in C++17 mode.
    bool operator()(const Seg &x, const Seg &y) const {
        Pt v1 = getIntersect(ray_s, ray_e, x.s, x.e);
        Pt v2 = getIntersect(ray_s, ray_e, y.s, y.e);
        return cmpZero(ray_s.dist2(v1) - ray_s.dist2(v2)) < 0;
    }
    // Whether the ray hits segment x: x's endpoints lie strictly on opposite
    // sides of the ray's line, and two further orientation tests place the
    // crossing on the forward side. The tests use ray_s + ray_e, which equals
    // ray_e only when ray_s is the origin (as in this program).
    static bool ray2seg(Seg x) {
        if (cmpZero(cross(ray_s, ray_e, x.s))*cmpZero(cross(ray_s, ray_e, x.e)) < 0) {
            return cmpZero(cross(x.s, ray_s, ray_s+ray_e))*cmpZero(cross(x.s, ray_s, x.e)) >= 0 &&
                   cmpZero(cross(x.e, ray_s, ray_s+ray_e))*cmpZero(cross(x.e, ray_s, x.s)) >= 0;
        }
        return false;
    }
};
Pt CMP::ray_s, CMP::ray_e;
const int MAXN = 32768;
int visual[MAXN];
Seg segs[MAXN];
int main() {
int N, sx, sy, ex, ey;
while (scanf("%d", &N) == 1) {
vector< pair<Pt, int> > A;
set<Seg, CMP> S;
for (int i = 1; i <= N; i++) {
scanf("%d %d %d %d", &sx, &sy, &ex, &ey);
A.push_back(make_pair(Pt(sx, sy), i));
A.push_back(make_pair(Pt(ex, ey), -i));
segs[i] = Seg(Pt(sx, sy), Pt(ex, ey), i);
visual[i] = 0;
}
sort(A.begin(), A.end(), polar_cmp2);
CMP::ray_s = Pt(0, 0), CMP::ray_e = A[0].first;
for (int i = 1; i <= N; i++) {
if (CMP::ray2seg(segs[i]))
S.insert(segs[i]);
}
for (int i = 0; i < A.size(); ) {
CMP::ray_e = A[i].first;
while (i < A.size() && cmpZero(cross(CMP::ray_s, CMP::ray_e, A[i].first)) == 0) {
int clockwise, id = abs(A[i].second);
if (A[i].second > 0)
clockwise = cmpZero(cross(CMP::ray_s, segs[id].s, segs[id].e));
else
clockwise = cmpZero(cross(CMP::ray_s, segs[id].e, segs[id].s));
if (clockwise) {
if (clockwise > 0)
S.insert(segs[id]);
else
S.erase(segs[id]);
}
i++;
}
if (S.size() > 0)
visual[S.begin()->id] = 1;
}
for (int i = 1; i <= N; i++)
printf("%d%c", visual[i], i == N ? '\n' : ' ');
}
return 0;
}

附錄 DC

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
#include <bits/stdc++.h>
using namespace std;
// Tolerance used by the floating-point sign/equality tests below.
#define eps 1e-9
// 2-D point / vector.
struct Pt {
    double x, y;
    Pt(double a = 0, double b = 0):
        x(a), y(b) {}
    Pt operator-(const Pt &a) const {
        return Pt(x - a.x, y - a.y);
    }
    Pt operator+(const Pt &a) const {
        return Pt(x + a.x, y + a.y);
    }
    Pt operator*(const double a) const {
        return Pt(x * a, y * a);
    }
    // Lexicographic (x, then y) order with eps tolerance.
    bool operator<(const Pt &a) const {
        if (fabs(x - a.x) > eps)
            return x < a.x;
        if (fabs(y - a.y) > eps)
            return y < a.y;
        return false;
    }
    // Squared Euclidean distance to a (const so it works on const points).
    double dist2(const Pt &a) const {
        return (x - a.x)*(x - a.x)+(y - a.y)*(y - a.y);
    }
};
double dot(Pt a, Pt b) {
    return a.x * b.x + a.y * b.y;
}
// Cross product of (a - o) and (b - o): > 0 when o->a->b turns counterclockwise.
double cross(Pt o, Pt a, Pt b) {
    return (a.x-o.x)*(b.y-o.y)-(a.y-o.y)*(b.x-o.x);
}
double cross2(Pt a, Pt b) {
    return a.x * b.y - a.y * b.x;
}
// Whether c projects onto segment ab (between its endpoints, with tolerance).
int between(Pt a, Pt b, Pt c) {
    return dot(c - a, b - a) >= -eps && dot(c - b, a - b) >= -eps;
}
// Whether c lies on segment ab.
int onSeg(Pt a, Pt b, Pt c) {
    return between(a, b, c) && fabs(cross(a, b, c)) < eps;
}
// Intersection of line (as, ae) with line (bs, be). Parallel lines divide by
// zero and give inf/nan coordinates.
Pt getIntersect(Pt as, Pt ae, Pt bs, Pt be) {
    Pt u = as - bs;
    double t = cross2(be - bs, u)/cross2(ae - as, be - bs);
    return as + (ae - as) * t;
}
// Mirror segment s-e with its 1-based input index.
struct Seg {
    Pt s, e;
    int id;
    Seg(Pt a = Pt(), Pt b = Pt(), int c = 0):
        s(a), e(b), id(c) {}
};
// Polar-angle order around the origin, counterclockwise starting from the
// positive x-axis; collinear points are ordered by smaller |x| first.
bool polar_cmp(const Pt& p1, const Pt& p2) {
    if (p1.y == 0 && p2.y == 0 && p1.x * p2.x <= 0) return p1.x > p2.x;
    if (p1.y == 0 && p1.x >= 0 && p2.y != 0) return true;
    if (p2.y == 0 && p2.x >= 0 && p1.y != 0) return false;
    if (p1.y * p2.y < 0) return p1.y > p2.y;
    double c = cross2(p1, p2);
    return c > 0 || (c == 0 && fabs(p1.x) < fabs(p2.x));
}
bool polar_cmp2(const pair<Pt, int> &x, const pair<Pt, int> &y) {
    return polar_cmp(x.first, y.first);
}
// Sign of x with eps tolerance: -1, 0 or 1.
int cmpZero(double x) {
    if (fabs(x) < eps) return 0;
    return x < 0 ? -1 : 1;
}
/*
 * std::set comparator: orders segments by the distance from ray_s to the
 * point where each one meets the line through ray_s and ray_e. The current
 * ray lives in static members and is updated by the sweep.
 */
struct CMP {
    static Pt ray_s, ray_e;
    // const: std::set requires a comparator invocable as const (LWG 2542);
    // libstdc++ rejects a non-const one in C++17 mode.
    bool operator()(const Seg &x, const Seg &y) const {
        Pt v1 = getIntersect(ray_s, ray_e, x.s, x.e);
        Pt v2 = getIntersect(ray_s, ray_e, y.s, y.e);
        return cmpZero(ray_s.dist2(v1) - ray_s.dist2(v2)) < 0;
    }
    // Whether the ray hits segment x: x's endpoints lie strictly on opposite
    // sides of the ray's line, and two further orientation tests place the
    // crossing on the forward side. The tests use ray_s + ray_e, which equals
    // ray_e only when ray_s is the origin (as in this program).
    static bool ray2seg(Seg x) {
        if (cmpZero(cross(ray_s, ray_e, x.s))*cmpZero(cross(ray_s, ray_e, x.e)) < 0) {
            return cmpZero(cross(x.s, ray_s, ray_s+ray_e))*cmpZero(cross(x.s, ray_s, x.e)) >= 0 &&
                   cmpZero(cross(x.e, ray_s, ray_s+ray_e))*cmpZero(cross(x.e, ray_s, x.s)) >= 0;
        }
        return false;
    }
};
Pt CMP::ray_s, CMP::ray_e;
const int MAXN = 32768;
int visual[MAXN];
// mirror[] is reordered for divide and conquer; sm[id] keeps each segment by id.
Seg mirror[MAXN], sm[MAXN];
// Angular sweep around the origin over `segs`.
// For each segment in `segs`, visual[id] is first reset to 0. It is then set
// to 1 if that segment is the first element of S (ordered by CMP, i.e. by
// distance along the current ray) after some group of events.
// Returns the segments of `segs` marked visible, in their original order.
// NOTE(review): insert/erase use the global copies sm[id] rather than the
// elements of `segs`; this relies on sm[id] holding the same endpoints.
vector<Seg> computePolar(vector<Seg> segs) {
if (segs.size() == 0)
return vector<Seg>();
// Events: (s, +id) and (e, -id) for every segment.
vector< pair<Pt, int> > A;
set<Seg, CMP> S;
for (int i = 0; i < segs.size(); i++) {
A.push_back(make_pair(segs[i].s, segs[i].id));
A.push_back(make_pair(segs[i].e, -segs[i].id));
visual[segs[i].id] = 0;
}
sort(A.begin(), A.end(), polar_cmp2);
// Start at the first event direction and seed S with every segment that
// already straddles that ray.
CMP::ray_s = Pt(0, 0), CMP::ray_e = A[0].first;
for (int i = 0; i < segs.size(); i++) {
if (CMP::ray2seg(segs[i]))
S.insert(segs[i]);
}
for (int i = 0; i < A.size(); ) {
CMP::ray_e = A[i].first;
// Consume all events whose point is collinear with the origin and the
// current ray direction.
while (i < A.size() && cmpZero(cross(CMP::ray_s, CMP::ray_e, A[i].first)) == 0) {
int clockwise, id = abs(A[i].second);
// Sign of cross(origin, this endpoint, other endpoint):
// > 0 the segment extends counterclockwise from here (insert),
// < 0 the sweep is leaving it (erase), 0 it points at the origin (ignored).
if (A[i].second > 0)
clockwise = cmpZero(cross(CMP::ray_s, sm[id].s, sm[id].e));
else
clockwise = cmpZero(cross(CMP::ray_s, sm[id].e, sm[id].s));
if (clockwise) {
if (clockwise > 0)
S.insert(sm[id]);
else
S.erase(sm[id]);
}
i++;
}
// The nearest segment on this ray is visible.
if (S.size() > 0)
visual[S.begin()->id] = 1;
}
vector<Seg> ret;
for (int i = 0; i < segs.size(); i++) {
if (visual[segs[i].id])
ret.push_back(segs[i]);
}
return ret;
}
vector<Seg> dfs(int l, int r) {
vector<Seg> L, R;
if (l > r) return L;
if (l == r)
return computePolar(vector<Seg>(mirror+l, mirror+l+1));
int mid = (l+r)/2;
L = dfs(l, mid);
R = dfs(mid+1, r);
L.insert(L.end(), R.begin(), R.end());
return computePolar(L);
}
// Order segments by the polar angle of their start point.
bool cmp(Seg a, Seg b) {
    return polar_cmp(a.s, b.s);
}
/*
 * For each test case read N mirrors, solve visibility by divide and conquer,
 * and print visual[1..N].
 */
int main() {
    int N, sx, sy, ex, ey;
    while (scanf("%d", &N) == 1) {
        for (int i = 1; i <= N; i++) {
            scanf("%d %d %d %d", &sx, &sy, &ex, &ey);
            Pt s(sx, sy), e(ex, ey);
            // Swap so that s does not precede e in polar order.
            if (polar_cmp(s, e))
                swap(s, e);
            sm[i] = mirror[i] = Seg(s, e, i);
        }
        // mirror[] is 1-based: the range is [mirror+1, mirror+N+1). The old
        // end mirror+N left the last mirror out of the sort.
        sort(mirror+1, mirror+N+1, cmp);
        dfs(1, N);
        for (int i = 1; i <= N; i++)
            printf("%d%c", visual[i], i == N ? '\n' : ' ');
    }
    return 0;
}
Read More +

a739. 道路架設

Problem

給定一個有根樹,詢問兩個節點 $(u, v)$ 的距離,點數最多 $V = 2000000$

Sample Input

1
2
3
4
5
6
7
8
4
1 3
1 7
2 5
3
1 4
2 3
1 4

Sample Output

1
2
8
10

Solution

明顯地不可能使用 LCA-RMQ 以 $O(Q \log V)$ 在線詢問,要套用 tarjan 算法中的 Offline-LCA。但 tarjan 算法靠的是遞迴遍歷,在 $V = 2000000$ 時很容易發生 stack overflow,為了解決這一點採用非遞迴的方式實作。

這一題還有特別卡的地方,有向邊的儲存,之前 tarjan 都用在無向樹上,所以普遍都會儲存 $2 E$ 條邊,這一題必須只儲存 $E$ 條邊,常數卡得緊,不掛上優化輸入時間限制是在 TLE 邊緣。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
#include <bits/stdc++.h>
using namespace std;
const int MAXV = 2000005;
const int MAXQ = 100005;
const int MAXE = 2000005;
/*
 * Offline LCA (Tarjan) on a rooted tree stored as directed parent->child edges.
 * Usage: init(n); addDedge(parent, child) per edge; addQuery(u, v, id) per
 * query; offline(root). The answer for query id is then in LCA[id].
 * Both the traversal and the union-find lookup are iterative, so very deep
 * trees do not overflow the call stack.
 */
class LCA {
public:
    struct Edge {
        int v;
        Edge *next;
    };
    struct QEdge {
        int qid, u;
        QEdge *next;
    };
    Edge edge[MAXE], *adj[MAXV], *arc[MAXV];
    QEdge qedge[MAXQ<<1], *qadj[MAXV], *qarc[MAXV];
    int e, eq, n;
    int parent[MAXV], weight[MAXV], visited[MAXV], LCA[MAXQ];
    // Reset adjacency and query lists for nodes 0..n-1.
    void init(int n) {
        e = eq = 0, this->n = n;
        for (int i = 0; i < n; i++)
            adj[i] = NULL, qadj[i] = NULL;
    }
    // Add the directed edge x -> y (x is y's parent).
    void addDedge(int x, int y) {
        edge[e].v = y, edge[e].next = adj[x], adj[x] = &edge[e++];
    }
    // Register query qid on both endpoints.
    void addQuery(int x, int y, int qid) {
        qedge[eq].qid = qid, qedge[eq].u = y, qedge[eq].next = qadj[x], qadj[x] = &qedge[eq++];
        qedge[eq].qid = qid, qedge[eq].u = x, qedge[eq].next = qadj[y], qadj[y] = &qedge[eq++];
    }
    void offline(int root) {
        tarjan(root);
    }
private:
    // Union-find root lookup with full path compression, done iteratively.
    // Tarjan links each finished node to its parent, so a chain can be as long
    // as the tree is deep; recursion here would overflow the stack.
    int findp(int x) {
        int root = x;
        while (parent[root] != root)
            root = parent[root];
        while (parent[x] != root) {
            int next = parent[x];
            parent[x] = root;
            x = next;
        }
        return root;
    }
    int joint(int x, int y) {
        x = findp(x), y = findp(y);
        if(x == y) return 0;
        if(weight[x] > weight[y])
            weight[x] += weight[y], parent[y] = x;
        else
            weight[y] += weight[x], parent[x] = y;
        return 1;
    }
    // Explicit-stack frame: node u, its parent p (-1 for the root), and the
    // phase (0 = visiting children, 1 = answering queries).
    struct Node {
        int u, p, line;
        Node(int a = 0, int b = 0, int c = 0):
            u(a), p(b), line(c) {}
    };
    void tarjan(int root) {
        for (int i = 0; i < n; i++)
            arc[i] = adj[i], qarc[i] = qadj[i], visited[i] = 0;
        stack<Node> stk;
        Node u;
        int x, y;
        stk.push(Node(root, -1, 0));
        parent[root] = root;
        while (!stk.empty()) {
            u = stk.top(), stk.pop();
            if (u.line == 0) {
                if (arc[u.u]) {
                    // Descend into the next unvisited child.
                    y = arc[u.u]->v, arc[u.u] = arc[u.u]->next;
                    stk.push(u);
                    parent[y] = y;
                    stk.push(Node(y, u.u, 0));
                } else {
                    // All children done: mark visited, then answer queries.
                    visited[u.u] = 1;
                    u.line++;
                    stk.push(u);
                }
            } else {
                if (qarc[u.u]) {
                    x = qarc[u.u]->qid, y = qarc[u.u]->u, qarc[u.u] = qarc[u.u]->next;
                    stk.push(u);
                    if (visited[y])
                        LCA[x] = findp(y);
                } else {
                    // Merge this subtree into its parent's set.
                    if (u.p != -1)
                        parent[findp(u.u)] = u.p;
                }
            }
        }
    }
} lca;
// A[i], B[i]: 0-based endpoints of query i (filled by main()).
// dist[v]: edge count from the root; C[v]: weighted distance from the root.
int A[MAXQ], B[MAXQ], dist[MAXV];
long long C[MAXV];
// Frame of the explicit DFS stack (`line` is always 0 in dfs() below).
struct Node {
    int u, line;
    Node(int a = 0, int b = 0):
        u(a), line(b) {}
};
/*
 * Iterative DFS from `root` over lca's parent -> child edges. For every
 * vertex v reached:
 *   dist[v] = number of edges on the path root -> v,
 *   C[v]    = total weight of that path. On entry C[v] must hold the weight
 *             of the edge parent(v) -> v.
 */
void dfs(int root) {
    for (int i = 0; i < lca.n; i++)
        lca.arc[i] = lca.adj[i];
    stack<Node> stk;
    dist[root] = 0, C[root] = 0;   // the root's own depth and path weight
    stk.push(Node(root, 0));
    while (!stk.empty()) {
        Node u = stk.top();
        stk.pop();
        if (lca.arc[u.u] == NULL)
            continue;              // every child of u.u has been expanded
        int y = lca.arc[u.u]->v;
        lca.arc[u.u] = lca.arc[u.u]->next;
        stk.push(u);               // come back to u.u for its next child
        dist[y] = dist[u.u] + 1;
        C[y] += C[u.u];
        stk.push(Node(y, 0));
    }
}
namespace mLocalStream {
    // Returns the next byte of stdin, or EOF; input is read through a
    // 1 MiB static buffer.
    inline int readchar() {
        const int N = 1048576;
        static char buf[N];
        static char *p = buf, *end = buf;
        if(p == end) {
            if((end = buf + fread(buf, 1, N, stdin)) == buf) return EOF;
            p = buf;
        }
        return *p++;
    }
    /*
     * Reads the next decimal integer (optionally preceded by '-') into *x.
     * Bytes below '-' (whitespace, control characters) are skipped first.
     * Returns 1 on success, 0 if EOF is hit before a number starts.
     * `c` is an int so EOF compares correctly even where plain char is
     * unsigned (a char there turns EOF into 255 and the loops never end).
     */
    inline int ReadInt(int *x) {
        int c;
        while ((c = readchar()) < '-') {
            if (c == EOF) return 0;
        }
        int neg = (c == '-');
        *x = neg ? 0 : c - '0';
        while ((c = readchar()) >= '0')
            *x = *x * 10 + (c - '0');
        if (neg)
            *x = -*x;
        return 1;
    }
    // Buffered stdout writer; whatever is still buffered is printed by the
    // destructor.
    class Print {
    public:
        static const int N = 1048576;
        char buf[N], *p, *end;
        Print() {
            p = buf, end = buf + N - 1;   // last byte reserved for '\0'
        }
        // Appends x in decimal followed by `padding`; negatives get a '-'.
        void printInt(int x, char padding) {
            char stk[16];
            int idx = 0;
            // Magnitude as unsigned so INT_MIN is handled too.
            unsigned int ux = x < 0 ? 0u - (unsigned int)x : (unsigned int)x;
            stk[idx++] = padding;                 // popped last
            do {
                stk[idx++] = (char)('0' + ux % 10);
                ux /= 10;
            } while (ux);
            if (x < 0)
                stk[idx++] = '-';                 // popped first
            while (idx) {
                if (p == end) {
                    *p = '\0';
                    printf("%s", buf), p = buf;
                }
                *p = stk[--idx], p++;
            }
        }
        // Writes x in decimal straight to stdout, bypassing the buffer.
        static inline void online_printInt(int x) {
            char ch[16];
            int idx = 0;
            unsigned int ux = x < 0 ? 0u - (unsigned int)x : (unsigned int)x;
            do {
                ch[idx++] = (char)('0' + ux % 10);
                ux /= 10;
            } while (ux);
            if (x < 0)
                putchar('-');
            while (idx)
                putchar(ch[--idx]);
        }
        ~Print() {
            *p = '\0';
            printf("%s", buf);
        }
    } bprint;
}
int main() {
int N, Q, P, c;
mLocalStream::ReadInt(&N);
lca.init(N);
for (int i = 1; i < N; i++) {
mLocalStream::ReadInt(&P);
mLocalStream::ReadInt(&c);
C[i] = c, P--;
lca.addDedge(P, i);
}
mLocalStream::ReadInt(&Q);
for (int i = 0; i < Q; i++) {
mLocalStream::ReadInt(A+i);
mLocalStream::ReadInt(B+i);
A[i]--, B[i]--;
lca.addQuery(A[i], B[i], i);
}
dfs(0);
lca.offline(0);
int lazy = 0;
long long d1, d2;
for (int i = 0; i < Q; i++) {
if (A[i] == B[i] || lca.LCA[i] != A[i])
lazy++;
else {
d1 = dist[A[i]] + dist[B[i]] - 2*dist[lca.LCA[i]];
d2 = C[A[i]] + C[B[i]] - 2*C[lca.LCA[i]];
printf("%lld\n", d1*lazy + d2);
}
}
return 0;
}
Read More +

b449. 加速策略 圈出角點

Problem

背景

影像處理中,給定一張圖,準確地找到點、線、邊、角都是相當困難的,由於圖片會受到干擾、顏色屬性的差異,使得擷取特徵相當困難。

問題描述

對於 $N \times M$ 的像素圖片,方便起見只由黑白影像構成,0 表示暗、1 表示亮,對於每一個像素位置判斷是否可能是角點。

在角點偵測的算法中,有一個由 Rosten and Drummond 提出的 FAST (Features from Accelerated Segment Test) 方法。概念由一個 $7 \times 7$ 的遮罩,待測點 $p$ 位於遮罩中心,由遮罩內圈上的 16 個像素的灰階判斷 $p$ 是否為角點。遮罩樣子如下所示:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
+--------------------+
| | |16| 1| 2| | |
+--------------------+
| |15| | | | 3| |
+--------------------+
|14| | | | | | 4|
+--------------------+
|13| | | p| | | 5|
+--------------------+
|12| | | | | | 6|
+--------------------+
| |11| | | | 7| |
+--------------------+
| | |10| 9| 8| | |
+--------------------+

只要這個圈上出現連續大於等於 12 個相同的暗像素或者是亮像素,則 $p$ 就被視為一個角點。

不幸地,這會造成在一個角上出現很多角點。通常會根據掃描的順序找到角點,當找到一個角點後,會抑制鄰近區域,使其不可以是角點。此題不考慮抑制情況。只有當圈上的 16 個像素都落在圖片內時才進行判斷,因此距離圖片邊界不足 3 格的像素不進行偵測,一律視為非角點。

輸出一個 $N \times M$ 的矩陣,按照原圖片位置,若該點是角點則為 1,反之為 0。

Sample Input

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
7 7
0011100
0100010
1000001
1000001
1000001
0100010
0011100
7 7
0011100
0100010
1000001
0000001
1000001
0100010
0010100

Sample Output

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
Case #1:
0000000
0000000
0000000
0001000
0000000
0000000
0000000
Case #2:
0000000
0000000
0000000
0000000
0000000
0000000
0000000

Solution

這一題的技巧不是演算法,而是實作的加速細節。

同時,來實驗老師上課說的優化策略的用處,由於要連續 12 個,根據鴿籠原理的方式,挑選位置 1, 5, 9, 13 出來,若連續三個狀態相同再進行 $O(16)$ 的判斷。然而挑出這四個位置,可以加速 50%,相較於只有 $O(16)$ 的判斷效能,可以參考下方的樸素解。

樸素寫法並不是最快的,因為 branch 太多,導致處理一張 $1920 \times 1080$ 的影像至少需要 800ms。更好的方案是使用 bitmask,預處理在 16bits 下連續 12 個相同狀態的位元情況,搭配 loop unrolling 的方式去撰寫,直接 $O(16)$ 判斷;為了減少代碼量,採用巨集的前處理展開。請參考 bitmask 版本。速度來到 140ms,加速幾乎 8 倍。

單純的 bitmask 還不是最快,直接建表 $O(2^{16} \times 16)$ 得到 16bits 是否是角點,建表消耗時間,但單一判斷變成 $O(1)$。請參考 bitmask2 版本。速度來到 76ms,直接翻了快兩倍。

最終 bitmask3 版本 56ms,採用以下的方案:

  1. 減少型別轉換 movz 的出現,用補數來抽換判斷。
    意即將 ((a&mask) == 0 || (a&mask) == mask) 改為只判斷 (a&mask) == mask;需要 (a&mask) == 0 的判斷時,則先令 a = ~a,再判斷 (a&mask) == mask。
  2. 利用編譯的常數展開,減少二維陣列取址時的一次乘法。
    意即 g[x][y] 取址使用時,會動用到一次乘法和一次加法,對於每一個角點偵測,動用到 16 次的乘法運算。若矩陣大小事先已知,那麼對於某一行的角點,g[x] 可以用一次乘法計算,接著該行所有角點偵測,只會剩下 16 次的加法。
  3. 建表太慢,用 __builtin_popcount() 提供剪枝。

當然,以上變態至極的作法,倒不如樸素解直接開 g++ -O3 或者是 g++ -Ofast 來的省事,速度慢一點也是沒問題的對吧。

1
// Ask GCC to optimise the rest of this translation unit as with -O3.
#pragma GCC optimize ("O3")

樸素解

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
#include <bits/stdc++.h>
using namespace std;
// Ring offsets around the centre pixel: index k is mask position k+1,
// clockwise, starting three rows above the centre.
const int dx[] = {-3, -3, -2, -1, 0, 1, 2, 3, 3, 3, 2, 1, 0, -1, -2, -3};
const int dy[] = {0, 1, 2, 3, 3, 3, 2, 1, 0, -1, -2, -3, -3, -3, -2, -1};
char g[2048][2048], ret[2048][2048];
/*
 * Returns 1 when the 16-pixel ring around (x, y) contains a circular run of
 * at least 12 equal values, otherwise 0. The whole 7x7 window around (x, y)
 * must lie inside g.
 */
int FAST(int x, int y) {
    // Quick reject: a circular run of 12 always covers three cyclically
    // consecutive samples among ring positions 0, 4, 8 and 12.
    const char s0 = g[x+dx[0]][y+dy[0]], s1 = g[x+dx[4]][y+dy[4]];
    const char s2 = g[x+dx[8]][y+dy[8]], s3 = g[x+dx[12]][y+dy[12]];
    const bool candidate = (s0 == s1 && s1 == s2) || (s1 == s2 && s2 == s3) ||
                           (s2 == s3 && s3 == s0) || (s3 == s0 && s0 == s1);
    if (!candidate)
        return 0;
    char ring[16];
    for (int k = 0; k < 16; k++)
        ring[k] = g[x+dx[k]][y+dy[k]];
    // Start scanning right after a value change, so no run is cut in two by
    // the wrap-around from position 15 back to 0.
    int start = -1;
    for (int k = 0; k < 16 && start < 0; k++)
        if (ring[k] != ring[(k + 15) % 16])
            start = k;
    if (start < 0)
        return 1;                        // the whole ring holds one value
    int run = 0;
    for (int t = 0; t < 16; t++) {
        const int k = (start + t) % 16;
        run = (t > 0 && ring[k] == ring[(k + 15) % 16]) ? run + 1 : 1;
        if (run >= 12)
            return 1;
    }
    return 0;
}
/*
 * Reads test cases "N M" followed by N rows of '0'/'1' and prints, per case,
 * an N x M map with '1' where FAST() reports a corner. Pixels closer than 3
 * to any border are never tested and print '0'.
 */
int main() {
    int N, M, cases = 0;
    while (scanf("%d %d", &N, &M) == 2) {
        // Drop the rest of the header line; stop at EOF instead of spinning.
        int ch;
        while ((ch = getchar()) != '\n' && ch != EOF)
            ;
        for (int i = 0; i < N; i++)
            fgets(g[i], sizeof(g[i]), stdin);
        for (int i = 0; i < N; i++) {
            for (int j = 0; j < M; j++) {
                ret[i][j] = '0';
                if (i-3 >= 0 && j-3 >= 0 && i+3 < N && j+3 < M)
                    ret[i][j] = FAST(i, j) + '0';
            }
        }
        printf("Case #%d:\n", ++cases);
        for (int i = 0; i < N; i++) {
            ret[i][M] = '\0';
            puts(ret[i]);
        }
    }
    return 0;
}

bitmask

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
#include <bits/stdc++.h>
using namespace std;
// Ring offsets around the centre pixel: index k is mask position k+1,
// clockwise, starting three rows above the centre.
const int dx[] = {-3, -3, -2, -1, 0, 1, 2, 3, 3, 3, 2, 1, 0, -1, -2, -3};
const int dy[] = {0, 1, 2, 3, 3, 3, 2, 1, 0, -1, -2, -3, -3, -3, -2, -1};
const int MAXH = 2048, MAXW = 2048;
// Not referenced by this program's main().
const int MAXN = MAXH * MAXW;
// g: input image (main() converts '0'/'1' to 0/1); ret: the whole text output.
char g[MAXH][MAXW], ret[MAXH*MAXW];
int n, m;
// Ring pixel z of (x, y), shifted into bit z of the packed 16-bit ring word.
#define T(x, y, z) ((g[x+dx[z]][y+dy[z]])<<z)
// True if one of the windows corner[i..i+3] is all zeros or all ones in val.
#define UNLOOPX(i) (val&corner[i]) == 0 || (val&corner[i]) == corner[i] || \
(val&corner[i+1]) == 0 || (val&corner[i+1]) == corner[i+1] || \
(val&corner[i+2]) == 0 || (val&corner[i+2]) == corner[i+2] || \
(val&corner[i+3]) == 0 || (val&corner[i+3]) == corner[i+3]
// The same test for all 16 window positions: 12 consecutive equal ring bits.
#define UNLOOPXALL UNLOOPX(0) || UNLOOPX(4) || UNLOOPX(8) || UNLOOPX(12)
// Packs ring pixels i..i+3 of (x, y) into their bit positions.
#define UNLOOPY(i) T(x, y, i) | T(x, y, i+1) | T(x, y, i+2) | T(x, y, i+3)
// Packs the whole 16-pixel ring of (x, y).
#define UNLOOPYALL UNLOOPY(0) | UNLOOPY(4) | UNLOOPY(8) | UNLOOPY(12)
typedef unsigned short int UINT16;
// corner[i]: twelve set bits starting at bit i, wrapping around; main() fills it.
UINT16 corner[16] = {};
// Rotates the 16-bit value x left by n bits (main() uses n = 1).
UINT16 rotate_left(UINT16 x, UINT16 n) {
return (x << n) | (x >> (16-n));
}
// Returns non-zero iff the ring around (x, y) has 12 consecutive equal
// pixels; the 7x7 window around (x, y) must lie inside the image.
inline int FAST(int x, int y) {
UINT16 val = UNLOOPYALL;
return UNLOOPXALL;
}
/*
 * Bitmask FAST driver: reads "n m" plus n rows of '0'/'1' per case and prints
 * an n x m corner map. Pixels whose 7x7 window leaves the image print '0'.
 */
int main() {
    // corner[i]: twelve consecutive ring bits starting at bit i (cyclic).
    for (int i = 0, j = (1<<12)-1; i < 16; i++, j = rotate_left(j, 1))
        corner[i] = j;
    int cases = 0;
    while (scanf("%d %d", &n, &m) == 2) {
        // Drop the rest of the header line; stop at EOF instead of spinning.
        int ch;
        while ((ch = getchar()) != '\n' && ch != EOF)
            ;
        for (int i = 0; i < n; i++) {
            fgets(g[i], MAXW, stdin);
            for (int j = 0; j < m; j++)
                g[i][j] -= '0';
        }
        // A row/column is testable only if it is at least 3 away from every
        // border. Writing whole border rows per row keeps the output exactly
        // n rows of m characters even when n or m is 6 or less.
        int bn = n-3, bm = m-3;
        char *p = ret;
        for (int i = 0; i < n; i++) {
            if (i < 3 || i >= bn || bm <= 3) {
                memset(p, '0', m), p += m;
            } else {
                *p = '0', p++;
                *p = '0', p++;
                *p = '0', p++;
                for (int j = 3; j < bm; j++)
                    *p = FAST(i, j) + '0', p++;
                *p = '0', p++;
                *p = '0', p++;
                *p = '0', p++;
            }
            *p = '\n', p++;
        }
        *p = '\0';
        printf("Case #%d:\n", ++cases);
        // ret already ends with '\n'; puts() would add a blank line.
        fputs(ret, stdout);
    }
    return 0;
}

bitmask2

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
#include <bits/stdc++.h>
#pragma GCC optimize ("O3")
using namespace std;
// Reference ring offsets; the dx_/dy_ macros below spell them out as constants.
//const int dx[] = {-3, -3, -2, -1, 0, 1, 2, 3, 3, 3, 2, 1, 0, -1, -2, -3};
//const int dy[] = {0, 1, 2, 3, 3, 3, 2, 1, 0, -1, -2, -3, -3, -3, -2, -1};
const int MAXH = 2048, MAXW = 2048;
// Not referenced by this program's main().
const int MAXN = MAXH * MAXW;
// g: input image (main() converts '0'/'1' to 0/1); ret: the whole text output.
char g[MAXH][MAXW], ret[MAXH*MAXW];
int n, m;
// dx_<z><w> / dy_<z><w>: offset of ring pixel z+w (z in {0,4,8,12}, w in 0..3),
// written as macros so that T() below pastes compile-time constants.
#define dx_00 -3
#define dx_01 -3
#define dx_02 -2
#define dx_03 -1
#define dx_40 0
#define dx_41 1
#define dx_42 2
#define dx_43 3
#define dx_80 3
#define dx_81 3
#define dx_82 2
#define dx_83 1
#define dx_120 0
#define dx_121 -1
#define dx_122 -2
#define dx_123 -3
#define dy_00 0
#define dy_01 1
#define dy_02 2
#define dy_03 3
#define dy_40 3
#define dy_41 3
#define dy_42 2
#define dy_43 1
#define dy_80 0
#define dy_81 -1
#define dy_82 -2
#define dy_83 -3
#define dy_120 -3
#define dy_121 -3
#define dy_122 -2
#define dy_123 -1
// Ring pixel z+w of (x, y), shifted into bit z+w of the packed ring word.
#define T(x, y, z, w) ((g[x + dx_##z##w][y + dy_##z##w])<<(z+w))
// True if val has all ones in one of the windows cor[i..i+3]; main() covers
// the all-zeros case by repeating the test on ~val.
#define UNLOOPX(i) ((val&cor[i]) == cor[i] || (val&cor[i+1]) == cor[i+1] || \
(val&cor[i+2]) == cor[i+2] || (val&cor[i+3]) == cor[i+3])
#define UNLOOPXALL UNLOOPX(0) || UNLOOPX(4) || UNLOOPX(8) || UNLOOPX(12)
// Packs ring pixels i..i+3 (and, via UNLOOPYALL, the whole ring) of (x, y).
#define UNLOOPY(i) T(x, y, i, 0) | T(x, y, i, 1) | T(x, y, i, 2) | T(x, y, i, 3)
#define UNLOOPYALL UNLOOPY(0) | UNLOOPY(4) | UNLOOPY(8) | UNLOOPY(12)
typedef unsigned short int UINT16;
// cor[i]: twelve set bits starting at bit i, wrapping around; main() fills it.
UINT16 cor[16] = {};
// Rotates the 16-bit value x left by n bits (main() uses n = 1).
UINT16 rotate_left(UINT16 x, UINT16 n) {
return (x << n) | (x >> (16-n));
}
// f[mask] = 1 iff the ring word `mask` has 12 consecutive equal bits (filled by main()).
int f[1<<16];
int main() {
UINT16 val;
for (int i = 0, j = (1<<12)-1; i < 16; i++, j = rotate_left(j, 1))
cor[i] = j;
for (int i = 0; i < 1<<16; i++) {
val = i;
f[i] = UNLOOPXALL ? 1 : (val = ~val, UNLOOPXALL);
}
int cases = 0;
char c;
while (scanf("%d %d", &n, &m) == 2) {
while (getchar() != '\n');
for (int i = 0; i < n; i++) {
fgets(g[i], 2000, stdin);
for (int j = 0; j < m; j++)
g[i][j] -= '0';
}
int bn = n-3, bm = m-3;
char *p = ret;
for (int i = 0; i < 3; i++) {
memset(p, '0', m), p += m;
*p = '\n', p++;
}
for (int x = 3; x < bn; x++) {
*p = '0', p++, *p = '0', p++, *p = '0', p++;
#define UNLOOP { \
val = UNLOOPYALL; \
*p = f[val] | '0'; \
p++, y++; \
}
#define UNLOOP4 UNLOOP UNLOOP UNLOOP UNLOOP
#define UNLOOP8 {UNLOOP4 UNLOOP4}
int y = 3;
for (; y+8 < bm; )
UNLOOP8;
for (; y < bm; )
UNLOOP;
*p = '0', p++, *p = '0', p++, *p = '0', p++;
*p = '\n', p++;
}
for (int i = 0; i < 3; i++) {
memset(p, '0', m), p += m;
*p = '\n', p++;
}
*p = '\0', p++;
printf("Case #%d:\n", ++cases);
puts(ret);
}
return 0;
}

bitmask3

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
#include <bits/stdc++.h>
#pragma GCC optimize ("O3")
using namespace std;
// Reference ring offsets; the AB macros below encode them as linear offsets.
//const int dx[] = {-3, -3, -2, -1, 0, 1, 2, 3, 3, 3, 2, 1, 0, -1, -2, -3};
//const int dy[] = {0, 1, 2, 3, 3, 3, 2, 1, 0, -1, -2, -3, -3, -3, -2, -1};
#define MAXH 2048
#define MAXW 2048
// g: input image (main() converts '0'/'1' to 0/1); ret: the whole text output.
char g[MAXH][MAXW], ret[MAXH*MAXW];
// Points at g[x][y] for the pixel being tested (set and advanced in main()).
char *ptr_g;
int n, m;
// AB<z><w>: dx*MAXW + dy of ring pixel z+w, so one addition to ptr_g replaces
// the row multiplication of g[x+dx][y+dy].
#define AB00 -3*MAXW
#define AB01 -3*MAXW+1
#define AB02 -2*MAXW+2
#define AB03 -MAXW+3
#define AB40 3
#define AB41 MAXW+3
#define AB42 2*MAXW+2
#define AB43 3*MAXW+1
#define AB80 3*MAXW
#define AB81 3*MAXW-1
#define AB82 2*MAXW-2
#define AB83 MAXW-3
#define AB120 -3
#define AB121 -MAXW-3
#define AB122 -2*MAXW-2
#define AB123 -3*MAXW-1
// Ring pixel z+w relative to ptr_g, shifted into bit z+w (x and y are unused).
#define T(x, y, z, w) (*(ptr_g + AB##z##w)<<(z+w))
// True if val has all ones in one of the windows cor[i..i+3]; main() covers
// the all-zeros case by repeating the test on ~val.
#define UNLOOPX(i) ((val&cor[i]) == cor[i] || (val&cor[i+1]) == cor[i+1] || \
(val&cor[i+2]) == cor[i+2] || (val&cor[i+3]) == cor[i+3])
#define UNLOOPXALL UNLOOPX(0) || UNLOOPX(4) || UNLOOPX(8) || UNLOOPX(12)
// Packs ring pixels i..i+3 (and, via UNLOOPYALL, the whole ring).
#define UNLOOPY(i) T(x, y, i, 0) | T(x, y, i, 1) | T(x, y, i, 2) | T(x, y, i, 3)
#define UNLOOPYALL UNLOOPY(0) | UNLOOPY(4) | UNLOOPY(8) | UNLOOPY(12)
typedef unsigned short int UINT16;
// cor[i]: twelve set bits starting at bit i, wrapping around; main() fills it.
UINT16 cor[16] = {};
// Rotates the 16-bit value x left by n bits (main() uses n = 1).
UINT16 rotate_left(UINT16 x, UINT16 n) {
return (x << n) | (x >> (16-n));
}
// f[mask]: output character '1' or '0' for ring word `mask` (filled by main()).
char f[1<<16];
int main() {
UINT16 val;
for (int i = 0, j = (1<<12)-1; i < 16; i++, j = rotate_left(j, 1))
cor[i] = j;
for (int i = 0, one; i < 1<<16; i++) {
val = i, one = __builtin_popcount(val);
if (one < 12 && one > 4)
f[i] = 0;
else
f[i] = UNLOOPXALL ? 1 : (val = ~val, UNLOOPXALL);
f[i] |= '0';
}
int cases = 0;
char c;
while (scanf("%d %d", &n, &m) == 2) {
while (getchar() != '\n');
for (int i = 0; i < n; i++) {
fgets(g[i], 2000, stdin);
for (int j = 0; j < m; j++)
g[i][j] -= '0';
}
int bn = n-3, bm = m-3;
char *p = ret;
for (int i = 0; i < 3; i++) {
memset(p, '0', m), p += m;
*p = '\n', p++;
}
for (int x = 3, y; x < bn; x++) {
*p = '0', p++, *p = '0', p++, *p = '0', p++;
ptr_g = g[x]+3;
#define UNLOOP { \
val = UNLOOPYALL; \
*p = f[val]; \
p++, y++, ptr_g++; \
}
#define UNLOOP4 UNLOOP UNLOOP UNLOOP UNLOOP
#define UNLOOP8 {UNLOOP4 UNLOOP4}
for (y = 3; y+8 < bm; )
UNLOOP8;
for (; y < bm; )
UNLOOP;
*p = '0', p++, *p = '0', p++, *p = '0', p++;
*p = '\n', p++;
}
for (int i = 0; i < 3; i++) {
memset(p, '0', m), p += m;
*p = '\n', p++;
}
*p = '\0', p++;
printf("Case #%d:\n", ++cases);
puts(ret);
}
return 0;
}
Read More +

b446. 搜索美學 史蒂芙的煩惱

Problem

背景

動畫 遊戲人生《No Game No Life》中,史蒂芙 (Stephanie Dola) 常常被欺負,儘管她以學院第一畢業,對於遊戲一竅不通的她在這個世界常常被欺負。現在就交給你來幫幫她。

問題描述

兩個人輪流在一個大棋盤上下棋,每一步棋的得分根據這一步棋與最鄰近的敵方棋子的曼哈頓距離。

對於兩個點 $p, q$ 座標 $(p_x, p_y), (q_x, q_y)$,曼哈頓距離 (Manhattan distance) 為 $|p_x - q_x| + |p_y - q_y|$

Sample Input

1
2
3
4
5
6
7
3
1 1
5 5
4 4
3 2
2 4
2 3

Sample Output

1
2
3
4
5
8
2
3
3
1

Solution

把鄰近搜索問題做個總結,普遍處理的是靜態資料跟單一詢問,而在最近餐館那一題已經用 KD-tree 處理過 KNN 問題。這一題是採用動態插入和詢問以及數學性質較強的曼哈頓距離,離線處理也是個選擇。

此問題限制在 $n = 50000$ 的情況下,進行測試討論,除了分桶、方格法外,探討三種思路:

  • Dynamic KD-tree 利用替罪羊樹的概念完成,看著卦長的代碼以及卡車口述概要,終於敲敲打打拼湊起來,掛上啟發式的搭配具有不錯的成效。空間複雜度 $O(n)$,插入複雜度 $O(\log^2 n)$,查詢 $O(\log n)$ (據說是在曼哈頓距離下的緣故),速度是暴力法 $O(n^2)$ 二十倍左右。

  • Segment tree + 平衡樹,空間複雜度 $O(n \log n)$,時間複雜度 $O(\log^3 n)$,使用座標轉換將菱形轉換成正方形,套上二分邊長去查找區域內部是否有點。由於 $n$ 的緣故,速度比 Dynamic KD-tree 慢上許多,若用暴力法 $O(n^2)$ 只快兩倍之多。實作測試提供者 liouzhou_101。

  • 離線處理 CDQ 分治,空間複雜度 $O(n)$,總時間複雜度 $O(n \log^2 n)$,採用思路為曼哈頓距離切割成四個象限進行極值查找。比暴力法快十倍左右。

前兩個作法比較裸,在此特別補充 CDQ 分治,曼哈頓距離可以考慮成四個象限,詢問 $(x, y)$ 的最鄰近點,首先考慮左下角 $(x', y')$,亦即 $x' \le x, \; y' \le y$,則曼哈頓距離 $dist = (x - x') + (y - y') = (x + y) - (x' + y')$,明顯地求最近距離要讓 $x' + y'$ 最大化。同理其他象限。

為了解決這詢問,套用 CDQ 分治,按照 $x$ 座標排序,接著二分操作順序,切割操作 $[l, mid], [mid+1, r]$,在左右兩塊仍然按照 $x$ 排序。單獨看 $[mid+1, r]$ 的操作會受 $[l, mid]$ 和自己本身影響,對於前者而言,採用歸併排序那樣,按照 $x$ 座標慢慢合併 (概念上),合併過程套用 Binary indexed tree 進行極值查找。對於後者,就進行遞迴求解,明顯地 $[l, mid]$ 只會受 $[l, mid]$ 影響。

CDQ 分治的概要,按照其中一個關鍵排序,接著二分操作順序進行分置處理。國外是有論文在描述這個 Online to Offline 的算法,CDQ 命名就是人名,會給國外看笑話吧。

1
2
3
4
5
6
sort(key)
solve(l, r)
solve(l, mid)
process([l, mid], [mid+1, r])
solve(mid+1, r)

備註「欸欸,加上悔棋的話,是不是持久化 kd-tree」

實作探討

關於 kd-tree 實作細節探討,與通常會犯的錯誤,關係到速度有常數差異。

closest() 中,常犯的錯誤是 探索順序 ,盡可能先靠近,啟發式才能更加快速,別像我打出錯誤的搜索順序如下:

1
2
3
4
5
6
7
8
9
10
11
// Wrong visiting order (shown as a counter-example): when the query lies on
// the right of the split (the else branch), this still searches lson -- the
// far side -- first, so the near subtree is explored late and pruning is weak.
if (x.d[k] <= u->pid.d[k]) {
closest(u->lson, (k+1)%kD, x, h, mndist);
h[k] = abs(x.d[k] - u->pid.d[k]);
closest(u->rson, (k+1)%kD, x, h, mndist);
h[k] = old;
} else {
// far side first: lson with h[k] raised to the split distance
h[k] = abs(x.d[k] - u->pid.d[k]);
closest(u->lson, (k+1)%kD, x, h, mndist);
h[k] = old;
closest(u->rson, (k+1)%kD, x, h, mndist);
}

實作的順序應該如下,別總是先探訪左子樹、再去探訪右子樹,kd-tree 必須注意順序。

1
2
3
4
5
6
7
8
9
10
11
// Correct order: descend first into the subtree on the query's side of the
// split, then into the far subtree with h[k] raised to the split distance.
if (x.d[k] <= u->pid.d[k]) {
closest(u->lson, (k+1)%kD, x, h, mndist);
h[k] = abs(x.d[k] - u->pid.d[k]);
closest(u->rson, (k+1)%kD, x, h, mndist);
h[k] = old;
} else {
closest(u->rson, (k+1)%kD, x, h, mndist);
h[k] = abs(x.d[k] - u->pid.d[k]);
closest(u->lson, (k+1)%kD, x, h, mndist);
h[k] = old;
}

假設資料大小 sizeof(dim_element) 很大,通常會利用指針陣列來進行排序,這樣可以降低搬運大型資料複製時間,但奇怪的是由於指針陣列佔有一定空間,索引資料時又會佔據一段空間,估計是快取方面出了點問題 (或者是我寫不好),導致直接搬運資料是來得比較快速,這個修改在 b348. 最近餐館 也有進行測試,速度有提升。

接著可以藉由函數參數少量,來拉快程式在堆疊參數所需要的時間,在節點內部宣告採用維度 d

1
2
3
4
5
// k-d tree node variant that stores its own split dimension.
struct Node {
Node *lson, *rson;
Point pid;
// size: number of points in this subtree; d: split dimension of this node,
// so the search no longer passes k (and computes k+1) on every call.
int size, d;
};

這個修改造成詢問時,不僅僅在走訪傳遞參數少了一個,還少 k+1 的計算。測試結果中,光靠這一點速度沒有明顯提升。kd tree 還有一個靠臉吃飯的邊界分割,要是相同時分左分右,這一點是最痛苦的,在此就不去討論,當然可以利用隨機擾動來解決這問題。

至於要使用 sort() 進行 $O(n \log n)$、還是使用 nth_element()$O(n)$ 找到中位數,根據兩題的測試,由於 $n$ 都不大,照理來講 nth_element() 快於 sort(),但根據實際測試於 liouzhou_101 的代碼,sort() 的速度會比較快,其一是運氣、其二是未知情況。就從以下代碼中,差異並不明顯。

Dynamic Kd tree

沒有提供垃圾回收,靠預先配置的內存池 (nodes[MAXN]) 運作,要是 RE 就把 MAXN 放大一點。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
#include <bits/stdc++.h>
using namespace std;
const int MAXN = 131072;   // node pool size per tree
const int MAXM = 50005;    // scratch buffer for rebuilding a subtree
const int MAXD = 2;
const int INF = INT_MAX;
const double ALPHA = 0.75;
const double LOG_ALPHA = log2(1.0 / ALPHA);
/*
 * Dynamic k-d tree answering nearest-neighbour queries under the Manhattan
 * metric, balanced scapegoat-style: when an insertion ends deeper than
 * log_{1/ALPHA}(size), the deepest ancestor on the insertion path with a
 * child holding more than ALPHA of its subtree is flattened and rebuilt.
 *
 * Nodes detached by a rebuild go onto a free list and are reused, so no more
 * than `size` entries of nodes[] are ever in use (size must stay below MAXN,
 * and any rebuilt subtree below MAXM).
 */
class KD_TREE {
public:
    struct Point {
        static int kD;
        int d[MAXD], pid;
        // Manhattan distance over the first kD coordinates.
        int dist(Point &x) {
            int ret = 0;
            for (int i = 0; i < kD; i++)
                ret += abs(d[i] - x.d[i]);
            return ret;
        }
        // Reads kD coordinates from stdin and tags the point with `id`.
        void read(int id = 0) {
            for (int i = 0; i < kD; i++)
                scanf("%d", &d[i]);
            pid = id;
        }
        static int sortIdx;   // coordinate compared by operator<
        bool operator<(const Point &x) const {
            return d[sortIdx] < x.d[sortIdx];
        }
    };
    struct Node {
        Node *lson, *rson;
        Point pid;
        int size;
        Node() {
            lson = rson = NULL;
            size = 1;
        }
        void update() {
            size = 1;
            if (lson) size += lson->size;
            if (rson) size += rson->size;
        }
    } nodes[MAXN];
    Node *root;
    Point A[MAXM];
    int bufsize, size, kD;
    // Empties the tree for kd-dimensional points.
    void init(int kd) {
        size = bufsize = freeTop = 0;
        root = NULL;
        Point::sortIdx = 0;
        Point::kD = kD = kd;
    }
    // Inserts x; the depth budget is derived from the tree's real size.
    void insert(Point x) {
        size++;
        insert(root, 0, x, log2int(size) / LOG_ALPHA);
    }
    // Smallest Manhattan distance from x to a stored point (INF if empty).
    int closest(Point x) {
        int mndist = INF, h[MAXD] = {};
        closest(root, 0, x, h, mndist);
        return mndist;
    }
private:
    Node *freeList[MAXN];   // nodes released by flatten(), reused by newNode()
    int freeTop;
    // floor(log2(x)) for x > 0, else 0 (__builtin_clz(0) is undefined).
    int log2int(int x) {
        if (x <= 0) return 0;
        return __builtin_clz((unsigned) 1) - __builtin_clz((unsigned) x);
    }
    // 1 if one child holds more than ALPHA of u's subtree.
    inline int isbad(Node *u) {
        if (u->lson && u->lson->size > u->size * ALPHA)
            return 1;
        if (u->rson && u->rson->size > u->size * ALPHA)
            return 1;
        return 0;
    }
    // Hands out a fresh node, preferring recycled ones.
    Node* newNode() {
        Node *ret = freeTop > 0 ? freeList[--freeTop] : &nodes[bufsize++];
        *ret = Node();
        return ret;
    }
    // Builds a balanced subtree over A[l..r], splitting on dimension k and
    // cycling through the dimensions level by level.
    Node* build(int k, int l, int r) {
        if (l > r) return NULL;
        if (k == kD) k = 0;
        Node *ret = newNode();
        int mid = (l + r)>>1;
        Point::sortIdx = k;
        sort(A+l, A+r+1);
        ret->pid = A[mid];
        ret->lson = build(k+1, l, mid-1);
        ret->rson = build(k+1, mid+1, r);
        ret->update();
        return ret;
    }
    // In-order copy of u's points into buf. Each visited node is released to
    // the free list; it is not modified until newNode() hands it out again.
    void flatten(Node *u, Point* &buf) {
        if (u == NULL) return ;
        flatten(u->lson, buf);
        *buf = u->pid, buf++;
        freeList[freeTop++] = u;
        flatten(u->rson, buf);
    }
    // Inserts x below u (split dimension k). Returns 1 while the new leaf is
    // "too deep" (dep exhausted) and no ancestor has been rebuilt yet.
    bool insert(Node* &u, int k, Point &x, int dep) {
        if (u == NULL) {
            u = newNode(), u->pid = x;
            return dep <= 0;
        }
        u->size++;
        int t;
        if (x.d[k] <= u->pid.d[k])
            t = insert(u->lson, (k+1)%kD, x, dep-1);
        else
            t = insert(u->rson, (k+1)%kD, x, dep-1);
        if (!t)
            return 0;
        if (!isbad(u))
            return 1;
        // u is the scapegoat: rebuild its subtree from scratch.
        int cnt = u->size;
        Point *ptr = &A[0];
        flatten(u, ptr);
        u = build(k, 0, cnt - 1);
        return 0;
    }
    // Lower bound on the distance to any point of the current region.
    int heuristic(int h[]) {
        int ret = 0;
        for (int i = 0; i < kD; i++)
            ret += h[i];
        return ret;
    }
    // Branch-and-bound search; visits the query's side of each split first.
    void closest(Node *u, int k, Point &x, int h[], int &mndist) {
        if (u == NULL || heuristic(h) >= mndist)
            return ;
        int dist = u->pid.dist(x), old;
        mndist = min(mndist, dist), old = h[k];
        if (x.d[k] <= u->pid.d[k]) {
            closest(u->lson, (k+1)%kD, x, h, mndist);
            h[k] = abs(x.d[k] - u->pid.d[k]);
            closest(u->rson, (k+1)%kD, x, h, mndist);
            h[k] = old;
        } else {
            closest(u->rson, (k+1)%kD, x, h, mndist);
            h[k] = abs(x.d[k] - u->pid.d[k]);
            closest(u->lson, (k+1)%kD, x, h, mndist);
            h[k] = old;
        }
    }
} A, B;
int KD_TREE::Point::sortIdx = 0, KD_TREE::Point::kD = 2;
/*
 * Each turn: player one's stone is scored against player two's stones (tree
 * B, skipped on the first turn) and stored in A; player two's stone is then
 * scored against A and stored in B.
 */
int main() {
    int N;
    KD_TREE::Point pt;
    while (scanf("%d", &N) == 1) {
        A.init(2);
        B.init(2);
        for (int turn = 0; turn < N; turn++) {
            pt.read(turn);
            if (turn > 0)
                printf("%d\n", B.closest(pt));
            A.insert(pt);
            pt.read(turn);
            printf("%d\n", A.closest(pt));
            B.insert(pt);
        }
    }
    return 0;
}

CDQ 分治

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
#include <bits/stdc++.h>
using namespace std;
const int MAXQ = 131072;   // capacity of ret[] (one slot per event timestamp)
const int MAXN = 40000;    // legacy bound; the BIT is now sized from N at run time
const int INF = INT_MAX;
/*
 * Offline nearest-neighbour (Manhattan) over a timeline of insert/query
 * events, solved by CDQ divide and conquer on time with a max-BIT over y for
 * each of the four quadrants.
 */
class Offline {
public:
    struct Point {
        int x, y;
        Point(int a = 0, int b = 0):
            x(a), y(b) {}
        bool operator<(const Point &a) const {
            return x < a.x || (x == a.x && y < a.y);
        }
        // Reads a 0-based coordinate pair and shifts it to 1-based for the BIT.
        void read() {
            scanf("%d %d", &x, &y);
            x++, y++;
        }
    };
    // qtype 1 inserts p, qtype 0 asks for the nearest earlier insertion.
    // qid is the timestamp; the qids must be exactly 0 .. event.size()-1.
    struct Event {
        Point p;
        int qtype, qid;
        Event(int a = 0, int b = 0, Point c = Point()):
            p(c), qtype(a), qid(b) {}
        bool operator<(const Event &e) const {
            if (p.x != e.p.x) return p.x < e.p.x;
            return qid < e.qid;
        }
    };
    vector<Event> event;
    int ret[MAXQ], N;   // ret[qid]: answer of query qid; N: largest y
    void init(int n) {
        event.clear();
        N = n;
    }
    void addEvent(int qtype, int qid, Point x) {
        event.push_back(Event(qtype, qid, x));
    }
    // Solves all events; y coordinates must lie in [1, N].
    void run() {
        for (size_t i = 0; i < event.size(); i++)
            ret[i] = 0x3f3f3f3f;
        cases = 0;
        BIT.assign(N + 1, 0);
        used.assign(N + 1, 0);
        ebuf.resize(event.size());
        if (event.empty())
            return;   // CDQ(0, -1) would never terminate
        sort(event.begin(), event.end());
        CDQ(0, (int)event.size() - 1);
    }
private:
    vector<Event> ebuf;
    vector<int> BIT, used;   // used[x] == cases marks BIT[x] as valid this round
    int cases = 0;
    // Max-update at x, walking up (dir = 1) or down (dir = -1).
    void modify(int x, int val, int dir) {
        for (; x && x <= N; x += (x&(-x)) * dir) {
            if (used[x] != cases)
                BIT[x] = -0x3f3f3f3f, used[x] = cases;
            BIT[x] = max(BIT[x], val);
        }
    }
    // Max over the nodes met walking from x in direction dir.
    int query(int x, int dir) {
        int ret = -0x3f3f3f3f;
        for (; x && x <= N; x += (x&(-x)) * dir) {
            if (used[x] == cases)
                ret = max(ret, BIT[x]);
        }
        return ret;
    }
    // Applies insertions of [l, mid] to queries of [mid+1, r]; both halves
    // are sorted by x. One pass per quadrant of the query point.
    void merge(int l, int mid, int r) {
        cases++;   // x' <= x, y' <= y: maximise x' + y'
        for (int i = mid+1, j = l; i <= r; i++) {
            if (event[i].qtype == 0) {
                for (; j <= mid && event[j].p.x <= event[i].p.x; j++) {
                    if (event[j].qtype == 1)
                        modify(event[j].p.y, event[j].p.x+event[j].p.y, 1);
                }
                ret[event[i].qid] = min(ret[event[i].qid], event[i].p.x+event[i].p.y-query(event[i].p.y, -1));
            }
        }
        cases++;   // x' <= x, y' >= y: maximise x' - y'
        for (int i = mid+1, j = l; i <= r; i++) {
            if (event[i].qtype == 0) {
                for (; j <= mid && event[j].p.x <= event[i].p.x; j++) {
                    if (event[j].qtype == 1)
                        modify(event[j].p.y, event[j].p.x-event[j].p.y, -1);
                }
                ret[event[i].qid] = min(ret[event[i].qid], event[i].p.x-event[i].p.y-query(event[i].p.y, 1));
            }
        }
        cases++;   // x' >= x, y' <= y: maximise y' - x'
        for (int i = r, j = mid; i > mid; i--) {
            if (event[i].qtype == 0) {
                for (; j >= l && event[j].p.x >= event[i].p.x; j--) {
                    if (event[j].qtype == 1)
                        modify(event[j].p.y, event[j].p.y-event[j].p.x, 1);
                }
                ret[event[i].qid] = min(ret[event[i].qid], event[i].p.y-event[i].p.x-query(event[i].p.y, -1));
            }
        }
        cases++;   // x' >= x, y' >= y: maximise -(x' + y')
        for (int i = r, j = mid; i > mid; i--) {
            if (event[i].qtype == 0) {
                for (; j >= l && event[j].p.x >= event[i].p.x; j--) {
                    if (event[j].qtype == 1)
                        modify(event[j].p.y, -event[j].p.x-event[j].p.y, -1);
                }
                ret[event[i].qid] = min(ret[event[i].qid], -event[i].p.x-event[i].p.y-query(event[i].p.y, 1));
            }
        }
    }
    // CDQ on time: split [l, r] by timestamp (stable, so each half stays
    // x-sorted), recurse, and finally merge the halves back in x order.
    void CDQ(int l, int r) {
        if (l >= r)
            return ;
        int mid = (l + r)/2, lidx, ridx;
        lidx = l, ridx = mid+1;
        for (int i = l; i <= r; i++) {
            if (event[i].qid <= mid)
                ebuf[lidx++] = event[i];
            else
                ebuf[ridx++] = event[i];
        }
        for (int i = l; i <= r; i++)
            event[i] = ebuf[i];
        CDQ(l, mid);
        merge(l, mid, r);
        CDQ(mid+1, r);
        lidx = l, ridx = mid+1;
        for (int i = l; i <= r; i++) {
            // Test ridx first so event[ridx] is never read past r.
            if (ridx > r || (lidx <= mid && event[lidx] < event[ridx]))
                ebuf[i] = event[lidx++];
            else
                ebuf[i] = event[ridx++];
        }
        for (int i = l; i <= r; i++)
            event[i] = ebuf[i];
    }
} A, B;
/*
 * Turn i uses timestamps 2i (player one) and 2i+1 (player two). A move is a
 * query in the opponent's structure and an insertion into its own: A stores
 * player one's stones, B player two's.
 */
int main() {
    int N;
    Offline::Point pt;
    while (scanf("%d", &N) == 1) {
        A.init(65536), B.init(65536);
        int max_y = 0;
        for (int i = 0; i < N; i++) {
            const int first = 2 * i, second = 2 * i + 1;
            pt.read();
            if (pt.y > max_y) max_y = pt.y;
            B.addEvent(0, first, pt);
            A.addEvent(1, first, pt);
            pt.read();
            if (pt.y > max_y) max_y = pt.y;
            A.addEvent(0, second, pt);
            B.addEvent(1, second, pt);
        }
        A.N = max_y;   // shifted y lies in [1, max_y]
        B.N = max_y;
        A.run();
        B.run();
        for (int i = 0; i < N; i++) {
            if (i > 0)
                printf("%d\n", B.ret[2 * i]);
            printf("%d\n", A.ret[2 * i + 1]);
        }
    }
    return 0;
}
Read More +

b443. 我愛 Fibonacci

Problem

$$F_n=\begin{cases} n, & n=0,1 \\ F_{n-1}+F_{n-2}, & n \geq 2 \end{cases}$$

求出$F_{2^n} \mod m$ 的結果。

Sample Input

1
2
3
4
3
1 1000000007
2 1000000007
3 1000000007

Sample Output

1
2
3
1
3
21

Solution

一般的矩陣計算,利用 $M^n$ 求出$F_n$,其中

$$M = \begin{bmatrix} 1 & 1\\ 1 & 0 \end{bmatrix}$$

如果是求$F_n$ 時間複雜度 $O(\log n)$,而這一題求的是$F_{2^n}$,時間複雜度 $O(n)$

為了加速運算,目標是要找到模 $p$ 下的循環長度 $L$,最後求出 $F_{2^n \mod L}$。根據待會的證明,循環長度不會太大(對質數 $p > 5$ 有 $L \le 2(p+1)$,一般模數也有 $L \le 6p$),那複雜度就可以回到 $O(\log L)$ 解決。但為了要找到 $L$ 又是一段很長的故事,總時間複雜度為 $O(\sqrt{p})$,不用保證 $p$ 是質數。

參考資料

故事

數列 $F_1 = 1, F_2 = 1, F_3 = 2, \cdots$ 在取模後,循環是指連續兩項出現重複;而費氏數列會完全循環,也就是出現連續兩項 $F_i = 0, F_{i-1} = 1$。下方是一個模 $4$ 的情況。

1
1 1 2 3 1 0 | 1 1 2 3 1 0 | ...

要找到恰好連續兩項$F_i = 0, F_{i-1} = 1$ 是困難的,考慮去找到$F_i = 0$ 即可,接著再去想辦法讓$F_{i-1} = 1$

假設最小的 $k$ 滿足 $F_k \equiv 0 \pmod p$,而 $F_{k-1} \equiv a \pmod p$,那麼 $F_{k+1} \equiv a$,之後的序列滿足 $F_{i+j \times k} \equiv a^j F_i \pmod p$。從矩陣乘法的概念中可以理解:第 $k$ 項之後相當於以 $(F_k, F_{k+1}) \equiv (0, a)$ 為初始項重新開始,也就是原序列乘上常數 $a$;第二輪循環常數就會變成 $a^2$,依此類推。

接下來

  • 考慮一個嚴重的問題「 何種模 $p$ 情況一定循環,即從$F_0$ 再次循環。 」答案是 質數 $p$
    原因是 $F_{i+j \times k} \equiv a^j F_i \pmod p$,由於 $a$ 與 $p$ 互質,$a^j \not\equiv 0 \pmod p$ 恆成立,因此除了 $k$ 的倍數之外,不會出現 $F_i \equiv 0$。再由歐拉定理 $a^{p-1} \equiv 1 \pmod p$,可知 $a$ 的階 $ord_{p}(a)$ 整除 $p-1$,所以經過 $k \cdot ord_{p}(a)$ 項後必定回到 $F_0, F_1$,序列完全循環。
  • 接續上一個問題「模 $p$ 不是質數怎麼處理?」
    進行質因數分解,對於每一個質因子找到模循環長度,模 $p$ 循環長度就是所有質因子循環長度的最小公倍數 lcm。

現在問題落在 $k$ 怎麼找到,若能找到 $k$,其循環長度落在 $k$ 的倍數,或者有更好的獲取方式。

  • 若模質數 $p$ 且滿足 $p > 5$,5 是 $p$ 的二次剩餘 (quadratic residue),意即滿足 $\exists \; x^2 \equiv 5 \mod p$,循環長度為 $p-1$ 的因數。反之,循環長度是 $2(p+1)$ 的因數。

關於二次剩餘的判斷,在模質數 $p$ 下,對於 $gcd(x, p) = 1$,藉由歐拉定理得到 $x^{p-1} \equiv 1 \mod p$,以下不保證是正確的說法,提供理解的一個方案。

  • 若 $d$ 是 $p$ 的二次剩餘,即存在 $x$ 使 $d \equiv x^2 \pmod p$,則滿足 $d^{(p-1)/2} \equiv 1 \pmod p$,因為 $d^{(p-1)/2} \equiv x^{2 \times (p-1)/2} = x^{p-1} \equiv 1 \pmod p$
  • 若非二次剩餘,則滿足 $d^{(p-1)/2} \equiv -1 \pmod p$:因為 $\left [ d^{(p-1)/2} \right ]^2 = d^{p-1} \equiv 1 \pmod p$,所以 $d^{(p-1)/2} \equiv \pm 1$;而模質數下 $d^{(p-1)/2} \equiv 1$ 至多只有 $(p-1)/2$ 個解,恰好被所有二次剩餘佔滿,故非二次剩餘只能得到 $-1$
  • 數學上用 Legendre symbol 來表示這個判斷 wiki

回過頭來,看一下費氏數列的公式解

$F_n = \frac{1}{\sqrt{5}} \left [ \left ( \frac{1+\sqrt{5}}{2} \right )^n - \left (\frac{1-\sqrt{5}}{2} \right )^n \right ]$

藉由展開公式,儘管它是實數、根號,展開之後一定只會剩下整數冪次的總和。令 $a = \sqrt{5}$,觀察二次剩餘與否和滿足$F_n = 0$ 的關係。

在模質數 $p$ 下,滿足二次剩餘$F_n \equiv 0 \mod p$,當 $n = p-1$ 的時候成立,可以藉由噁心的展開式得到。同理在非二次剩餘情況,$n = 2(p+1)$,找到一個最大的倍數情況,答案一定落在其因數下。詳細推導請看參考資料,太噁心就不提。

參考資料中有特別提到,有一個地方還沒有確認:對於模數 $p^k$,循環長度為 $g(p) \times p^{k-1}$ 要如何證明。但我想根據中國餘式定理,能了解循環長度倍數的模關係吧。接著大整數分解若以試除法估計,最差為 $O(\sqrt{n})$(Pollard's rho 的期望約為 $O(n^{1/4})$),中間也要找到所有因數來驗證循環長度,還要搭配快速矩陣乘法,最後也是 $O(\sqrt{n})$。

這一題是更進階的費氏數列計算,從建表、矩陣乘法,最後來到循環搜尋,這些證明不久之後就會忘記了,二次剩餘判斷的想法還會留著,其他就忘得一乾二淨吧。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
#include <bits/stdc++.h>
using namespace std;
#define MILLER_BABIN 4   // number of random bases tried by isPrime()
typedef unsigned long long UINT64;
/*
 * (a * b) % mod by shift-and-add, so no product wider than 64 bits is needed.
 * Doubling a value below mod must not overflow, so mod must not exceed 2^63.
 */
UINT64 mul(UINT64 a, UINT64 b, UINT64 mod) {
    if (a >= mod) a %= mod;
    if (b >= mod) b %= mod;
    UINT64 acc = 0;
    while (b != 0) {
        if (b & 1) {
            acc += a;
            if (acc >= mod)
                acc -= mod;
        }
        b >>= 1;
        a <<= 1;
        if (a >= mod)
            a -= mod;
    }
    return acc;
}
/*
 * Matrix of at most 2x2 entries over the integers modulo the `mod` passed to
 * each operation (used here for Fibonacci powers).
 */
struct Matrix {
    UINT64 v[2][2];
    int row, col; // row x col
    // row x col matrix with `a` on the main diagonal and zeros elsewhere.
    Matrix(int n, int m, int a = 0) {
        memset(v, 0, sizeof(v));
        row = n, col = m;
        for (int d = 0; d < row && d < col; d++)
            v[d][d] = a;
    }
    // Returns (*this) * x with entries reduced modulo `mod`.
    Matrix multiply(const Matrix& x, const long long mod) const {
        Matrix prod(row, x.col);
        for (int i = 0; i < row; i++) {
            for (int k = 0; k < col; k++) {
                if (v[i][k] == 0)
                    continue;   // zero term contributes nothing
                for (int j = 0; j < x.col; j++) {
                    UINT64 s = prod.v[i][j] + mul(v[i][k], x.v[k][j], mod);
                    prod.v[i][j] = s >= (UINT64)mod ? s - mod : s;
                }
            }
        }
        return prod;
    }
    // Returns (*this)^n modulo `mod` by binary exponentiation (n >= 0).
    Matrix pow(const long long& n, const long long mod) const {
        Matrix result(row, col, 1), base = *this;
        for (long long e = n; e != 0; e >>= 1) {
            if (e & 1)
                result = result.multiply(base, mod);
            base = base.multiply(base, mod);
        }
        return result;
    }
} FibA(2, 2, 0);
#define MAXL ((50000>>5)+1)
#define GET(x) ((mark[(x)>>5]>>((x)&31))&1)
#define SET(x) (mark[(x)>>5] |= 1<<((x)&31))
int mark[MAXL], P[50000], Pt = 0;
/*
 * Sieve of Eratosthenes over [2, 46340] (46340 = floor(sqrt(2^31 - 1))).
 * Composites are marked in the bitset `mark`; every prime is appended, in
 * increasing order, to P[0..Pt). (No `register`: it is ill-formed in C++17.)
 */
void sieve() {
    const int n = 46340;
    SET(1);
    for (int i = 2; i <= n; i++) {
        if (GET(i))
            continue;
        // Mark i*k for k = n/i down to i; smaller multiples are already set.
        for (int k = n/i, j = i*k; k >= i; k--, j -= i)
            SET(j);
        P[Pt++] = i;
    }
}
/* x^y % mod with plain 64-bit products: exact only while mod < 2^32 (x is
 * not reduced first). Not called in this program. */
UINT64 mpow(UINT64 x, UINT64 y, UINT64 mod) { // mod < 2^32
    UINT64 acc = 1;
    for (; y != 0; y >>= 1) {
        if (y & 1)
            acc = (acc * x) % mod;
        x = (x * x) % mod;
    }
    return acc % mod;
}
/* x^y % mod through mul(), so any mod up to 2^63 works; returns 1 (not
 * reduced) when y == 0. */
UINT64 mpow2(UINT64 x, UINT64 y, UINT64 mod) {
    UINT64 acc = 1;
    for (; y != 0; y >>= 1) {
        if (y & 1)
            acc = mul(acc, x, mod);
        x = mul(x, x, mod);
    }
    return acc;
}
/* Extended Euclid: sets g = gcd(x, y) and a, b with a*x + b*y == g. */
void exgcd(long long x, long long y, long long &g, long long &a, long long &b) {
    if (y != 0) {
        exgcd(y, x % y, g, b, a);
        b -= (x / y) * a;
        return;
    }
    g = x, a = 1, b = 0;
}
/* gcd(|x|, |y|); when either is 0 the other's absolute value is returned. */
long long llgcd(long long x, long long y) {
    if (x < 0) x = -x;
    if (y < 0) y = -y;
    if (x == 0 || y == 0)
        return x + y;
    for (long long r = x % y; r != 0; r = x % y) {
        x = y;
        y = r;
    }
    return y;
}
/* Inverse of x modulo p via exgcd, normalised into [0, p); assumes
 * gcd(x, p) == 1. Not called in this program. */
long long inverse(long long x, long long p) {
    long long g, coef_x, coef_p;
    exgcd(x, p, g, coef_x, coef_p);
    if (g < 0)
        coef_x = -coef_x;
    return (coef_x % p + p) % p;
}
/*
 * Miller-Rabin probable-prime test with MILLER_BABIN random bases (the macro
 * keeps the original spelling). Returns 1 if p is probably prime, 0 if p < 2
 * or a base proves p composite. p must be below 2^63 (limit of mul()).
 */
int isPrime(long long p) { // Miller-Rabin
    if (p < 2) return 0;
    if (p < 4) return 1;          // 2 and 3 (2 used to fail the even test)
    if (!(p&1)) return 0;
    // Write p - 1 = q * 2^k with q odd.
    long long q = p-1, a, t;
    int k = 0, b = 0;
    while (!(q&1)) q >>= 1, k++;
    for (int it = 0; it < MILLER_BABIN; it++) {
        a = rand()%(p-4) + 2;     // base in [2, p-3]; p >= 5 here
        t = mpow2(a, q, p);
        b = (t == 1) || (t == p-1);
        // Square up to k-1 times looking for p-1.
        for (int i = 1; i < k && !b; i++) {
            t = mul(t, t, p);
            if (t == p-1)
                b = 1;
        }
        if (b == 0)
            return 0;             // a is a witness: p is composite
    }
    return 1;
}
/*
 * Pollard's rho with f(x) = x^2 + c (mod n), saving a checkpoint y whenever
 * the step count reaches a power of two (Brent's cycle detection). Returns a
 * divisor d > 1 of n, possibly n itself; the caller retries with another c.
 */
long long pollard_rho(long long n, long long c) {
    long long x = 2, y = 2, d;
    long long steps = 1, checkpoint = 2;
    for (;;) {
        x = (mul(x, x, n) + c);
        if (x >= n)
            x -= n;
        d = llgcd(x - y, n);
        if (d > 1)
            return d;
        if (++steps == checkpoint) {
            y = x;
            checkpoint <<= 1;
        }
    }
}
// Trial division by the sieved primes: appends the prime factors of n (with
// multiplicity, non-decreasing) to f. n must fit in an int.
void factorize(int n, vector<long long> &f) {
    for (int i = 0; i < Pt && P[i]*P[i] <= n; i++) {
        while (n % P[i] == 0) {
            f.push_back(P[i]);
            n /= P[i];
        }
    }
    if (n != 1)
        f.push_back(n);
}
// Appends the prime factorisation of n (with multiplicity, not sorted) to f:
// trial division below 1e9, otherwise Miller-Rabin plus Pollard's rho.
void llfactorize(long long n, vector<long long> &f) {
    if (n == 1)
        return ;
    if (n < 1e+9) {
        factorize(n, f);
        return ;
    }
    if (isPrime(n)) {
        f.push_back(n);
        return ;
    }
    long long d = n;
    for (int c = 2; d == n; c++)   // retry with a new constant until d is proper
        d = pollard_rho(n, c);
    llfactorize(d, f);
    llfactorize(n / d, f);
}
// Everything above: primality testing and integer factorisation.
// ----------------------------------------------------------- //
// Legendre symbol (d | p) by Euler's criterion: 0 when p divides d, 1 when
// d^((p-1)/2) == 1 (mod p), -1 otherwise. p is presumably an odd prime.
int legendre_symbol(UINT64 d, UINT64 p) {
    if (d % p == 0)
        return 0;
    const UINT64 euler = mpow2(d, (p - 1) >> 1, p);
    return euler == 1 ? 1 : -1;
}
// Depth-first enumeration of divisors from a prime-power list.
// f[j] = (prime, exponent). At level idx the exponent of f[idx].first runs
// from 0 to f[idx].second, and each power multiplies x. Every complete
// product is appended to ret, smallest exponent of f[0] first.
void factor_gen(int idx, long long x, vector< pair<long long, int> > &f, vector<long long> &ret) {
    if (static_cast<size_t>(idx) == f.size()) {
        ret.push_back(x);
        return;
    }
    long long power = 1;
    for (int e = 0; e <= f[idx].second; e++) {
        factor_gen(idx + 1, x * power, f, ret);
        power *= f[idx].first;
    }
}
// Appends every positive divisor of n to ret and then sorts ret ascending
// (including anything ret already held). The prime factors from
// llfactorize are sorted, run-length grouped into (prime, exponent) pairs,
// and expanded by the recursive overload.
void factor_gen(long long n, vector<long long> &ret) {
    vector<long long> primes;
    vector< pair<long long, int> > groups;
    llfactorize(n, primes);
    sort(primes.begin(), primes.end());
    size_t i = 0;
    while (i < primes.size()) {
        size_t j = i;
        while (j < primes.size() && primes[j] == primes[i])
            j++;
        groups.push_back(make_pair(primes[i], (int)(j - i)));
        i = j;
    }
    factor_gen(0, 1, groups, ret);
    sort(ret.begin(), ret.end());
}
// Pisano period pi(p): the period of the Fibonacci sequence modulo the
// prime p (p is assumed prime -- callers pass factors from llfactorize).
// p = 2, 3, 5 are special-cased. Otherwise the candidate periods are the
// divisors of p-1 when 5 is a quadratic residue mod p, and of 2(p+1) when it
// is not. They are tried in ascending order (factor_gen sorts them), and
// the first one at which the sequence returns to (0, 1) is reported.
// Returns 0 if no candidate matches.
UINT64 cycleInFib(UINT64 p) {
if (p == 2) return 3;
if (p == 3) return 8;
if (p == 5) return 20;
vector<long long> f;
if (legendre_symbol(5, p) == 1)
factor_gen(p-1, f);
else
factor_gen(2*(p+1), f);
long long f1, f2;
for (int i = 0; i < f.size(); i++) {
// NOTE(review): assumes FibA^k == [[F(k+1), F(k)], [F(k), F(k-1)]] and
// that pow(e, m) reduces mod m -- confirm against Matrix. Under that
// assumption, with k = f[i]-1, f1 == F(f[i]+1) and f2 == F(f[i]) (mod p).
Matrix t = FibA.pow(f[i]-1, p);
f1 = (t.v[0][0] + t.v[0][1])%p;
f2 = (t.v[1][0] + t.v[1][1])%p;
// (F(k), F(k+1)) == (0, 1) means the sequence has restarted.
if (f1 == 1 && f2 == 0)
return f[i];
}
return 0;
}
// Pisano period of the prime power p^k, taken as pi(p) * p^(k-1).
// NOTE: this relies on pi(p^k) == p^(k-1) * pi(p) (Wall's conjecture),
// which is unproven in general.
UINT64 cycleInFib(UINT64 p, int k) {
    UINT64 period = cycleInFib(p);
    for (int e = k; e > 1; e--)
        period *= p;
    return period;
}
// Reads a test-case count, then for each case reads n and m and prints
// F(2^n) mod m. The exponent 2^n is reduced modulo the Pisano period of m,
// built as the lcm of the periods of m's prime-power factors.
int main() {
// NOTE(review): sieve() presumably fills the prime table P/Pt used by factorize().
sieve();
// Fibonacci Q-matrix [[1, 1], [1, 0]].
FibA.v[0][0] = 1, FibA.v[0][1] = 1;
FibA.v[1][0] = 1, FibA.v[1][1] = 0;
int testcase;
scanf("%d", &testcase);
while (testcase--) {
long long n, m;
scanf("%lld %lld", &n, &m);
vector<long long> f;
map<long long, int> r;
llfactorize(m, f);
// r: prime -> exponent in m.
for (auto &x : f)
r[x]++;
UINT64 cycle = 1;
for (auto &x : r) {
UINT64 t = cycleInFib(x.first, x.second);
// cycle = lcm(cycle, t)
cycle = cycle / llgcd(t, cycle) * t;
}
// Reduce the huge index 2^n modulo the period of F mod m.
n = mpow2(2, n, cycle);
Matrix t = FibA.pow(n, m);
// NOTE(review): assumes FibA^n == [[F(n+1), F(n)], [F(n), F(n-1)]] mod m,
// so v[1][0] is F(n) -- confirm against Matrix::pow.
long long fn = t.v[1][0];
printf("%lld\n", fn);
}
return 0;
}
Read More +

b444. 期望試驗 快速冪次

Problem

背景

曾經某 M 被期望值坑,就只是在計算 $x^y \mod z$ 時偷偷替換成 $x^{y-1} \times x \mod z$,結果得到 Time Limit Exceeded。

根據分析,$y = 16$ 的二進制表示為 $(10000)_{2}$,若變成 $y = 15$,則是 $(01111)_{2}$。快速冪的乘法次數通常與二進制表示中 1 的個數成正比($16$ 只有一個 1,$15$ 卻有四個 1),所以速度就慢非常多。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
typedef unsigned long long UINT64;
// (a * b) mod `mod` by shift-and-add, so no 128-bit product is needed.
// Intermediates stay below 2 * mod, which fits in 64 bits while
// mod <= 2^63 (the problem guarantees z <= 10^18).
UINT64 mul(UINT64 a, UINT64 b, UINT64 mod) {
    UINT64 ret = 0;
    for (a = a >= mod ? a%mod : a, b = b >= mod ? b%mod : b; b != 0; b>>=1, a <<= 1, a = a >= mod ? a - mod : a) {
        if (b&1) {
            ret += a;
            if (ret >= mod)
                ret -= mod;
        }
    }
    return ret;
}
// x^y mod `mod` by right-to-left binary exponentiation: one squaring per
// bit of y plus one multiplication per 1 bit.
// The accumulator starts at 1 % mod, so y == 0 with mod == 1 gives 0
// (x^0 mod 1), not an unreduced 1.
UINT64 mpow(UINT64 x, UINT64 y, UINT64 mod) {
    UINT64 ret = 1 % mod;
    while (y) {
        if (y&1)
            ret = mul(ret, x, mod);
        y >>= 1, x = mul(x, x, mod);
    }
    return ret;
}

問題描述

讓我們來一場 $x^y \mod z$ 的期望值試驗吧,基礎目標是減少乘法次數。

其中 $1 \le x, z \le 10^{18}$,$0 \le y \le 2^{2^{20}}$,且 $y$ 以二進制字串(最高位在前)輸入,在這場試驗中展現你的優化吧。

Sample Input

1
2
3
4
5 01 7
5 10 7
5 0101 7
3 0110 7

Sample Output

1
2
3
4
5
4
3
1

Solution

詳細可以參考《資訊安全 - 近代加密 快速冪次計算》那一篇。

Algorithm Table Size #squaring Average #Multiplication
Right-To-Left 1: $x^{2^i}$ $n$ $n/2$
Left-To-Right 1: $x$ $n$ $n/2$
Left-To-Right(2-bits) 3: $x$, $x^2$, $x^3$ $n$ $3n/8$
Left-To-Right(sliding) 2: $x$, $x^3$ $n$ $n/3$

目標是減少乘法次數,但上表的期望乘法次數只與 1 的個數有關。雖然最好能從 $n/2$ 降到 $n/3$,並不表示速度真的會快上 $1.5$ 倍左右,畢竟每個位元都還有一次平方運算(共 $n$ 次)的基礎乘法成本。根據實驗下來大約能快 10% 到 20% 之間,加上 -Ofast 編譯後差異又會再少一點,看起來實作方法的影響很大。

例如在不加編譯優化參數下

1
2
3
4
if a[i] == 0 && a[i+1] == 0
else if a[i] == 0 && a[i+1] == 1
else if a[i] == 1 && a[i+1] == 0
else

上述做法會比下述來得快上許多

1
2
3
4
5
6
if a[i] == 0
if a[i+1] == 0
else
else
if a[i+1] == 0
else

最後,產出一個 cheat 版本,把 L-to-R-2bits 的概念擴充成一次處理 8 個位元,並使用 loop unrolling 進行加速;由於位元長度不一定能被 8 整除,小測資就靠 L-to-R-sliding 的方案去解決。

在 zerojudge 主機平台上,隨機測資下的運作情況如下:

Algorithm Time
Right-To-Left 5.4s
Left-To-Right(2-bits) 4.9s
Left-To-Right(sliding) 4.8s
Left-To-Right-sliding-cheat 4.2s

R-to-L

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
#include <bits/stdc++.h>
using namespace std;
typedef unsigned long long UINT64;
// (a * b) mod `mod` by shift-and-add; intermediates stay below 2 * mod,
// which fits in 64 bits while mod <= 2^63.
UINT64 mul(UINT64 a, UINT64 b, UINT64 mod) {
    UINT64 ret = 0;
    for (a = a >= mod ? a%mod : a, b = b >= mod ? b%mod : b; b != 0; b>>=1, a <<= 1, a = a >= mod ? a - mod : a) {
        if (b&1) {
            ret += a;
            if (ret >= mod)
                ret -= mod;
        }
    }
    return ret;
}
// x^y mod z, with y given as a binary string (most significant bit first,
// only '0'/'1'). Bits are consumed right to left: x is squared once per
// bit, and the running power is multiplied in for every '1'.
UINT64 mpowR2L(UINT64 x, char y[], UINT64 z) {
    // The least significant bit is y[n-1], so n must be the exact length.
    // Probing y[1], y[2], y[4], ... for the terminator is only right when
    // the length is a power of two; otherwise it reads past the NUL.
    int n = (int) strlen(y);
    // Reduced start value so that an all-zero exponent with z == 1 gives 0.
    UINT64 ret = 1 % z;
    for (int i = n-1; i >= 0; i--) {
        if (y[i] == '1')
            ret = mul(ret, x, z);
        x = mul(x, x, z);
    }
    return ret;
}
// Exponent buffer: room for 2^20 + 4 binary digits plus the terminator.
char y[(1<<20) + 5];
// Reads "x y z" triples until input ends or fails to parse, and prints
// x^y mod z for each one.
// NOTE(review): "%s" has no field width, so an over-long exponent would
// overflow y.
int main() {
    long long x, z;
    for (;;) {
        if (scanf("%lld %s %lld", &x, y, &z) != 3)
            break;
        printf("%llu\n", mpowR2L(x, y, z));
    }
    return 0;
}

L-to-R-2bits

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
#include <bits/stdc++.h>
using namespace std;
typedef unsigned long long UINT64;
// (a * b) mod `mod` by shift-and-add; intermediates stay below 2 * mod,
// which fits in 64 bits while mod <= 2^63.
UINT64 mul(UINT64 a, UINT64 b, UINT64 mod) {
    UINT64 ret = 0;
    for (a = a >= mod ? a%mod : a, b = b >= mod ? b%mod : b; b != 0; b>>=1, a <<= 1, a = a >= mod ? a - mod : a) {
        if (b&1) {
            ret += a;
            if (ret >= mod)
                ret -= mod;
        }
    }
    return ret;
}
// x^y mod z, with y a binary string (MSB first). The string is scanned
// left to right in 2-bit windows: each window costs two squarings plus at
// most one multiplication by x, x^2 or x^3.
UINT64 mpowL2R2(UINT64 x, char y[], UINT64 z) {
    UINT64 x2 = mul(x, x, z);
    UINT64 x3 = mul(x2, x, z);
    UINT64 ret = 1 % z;   // reduced so that y == "0" with z == 1 gives 0
    int i = 0;
    // Windows must be aligned to the end of the string. With an odd length
    // the leading bit is a 1-bit window of its own; otherwise the last
    // "pair" would be a digit plus the terminator (5^"011" came out as 5^4).
    if (strlen(y) & 1) {
        if (y[0] == '1')
            ret = mul(ret, x, z);
        i = 1;
    }
    for (; y[i]; i += 2) {
        ret = mul(ret, ret, z);
        ret = mul(ret, ret, z);
        // Flat condition chain kept on purpose: the article measured it as
        // faster than nested ifs without optimisation flags.
        if (y[i] == '1' && y[i+1] == '1') {
            ret = mul(ret, x3, z);
        } else if (y[i] == '1' && y[i+1] == '0') {
            ret = mul(ret, x2, z);
        } else if (y[i+1] == '1') {
            ret = mul(ret, x, z);
        }
    }
    return ret;
}
// Exponent buffer: room for 2^20 + 4 binary digits plus the terminator.
char y[(1<<20) + 5];
// Driver: for each "x y z" triple on stdin, prints x^y mod z; stops at the
// first triple that does not parse completely.
// NOTE(review): "%s" is unbounded; an exponent longer than the buffer
// would overflow y.
int main() {
    long long x, z;
    while (true) {
        int fields = scanf("%lld %s %lld", &x, y, &z);
        if (fields != 3)
            return 0;
        printf("%llu\n", mpowL2R2(x, y, z));
    }
}

L-to-R-sliding

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
#include <bits/stdc++.h>
using namespace std;
typedef unsigned long long UINT64;
// Modular product (a * b) % mod without 128-bit arithmetic: add the
// doubling multiples of a selected by the set bits of b.
UINT64 mul(UINT64 a, UINT64 b, UINT64 mod) {
    UINT64 addend = (a >= mod) ? a % mod : a;
    UINT64 bits = (b >= mod) ? b % mod : b;
    UINT64 sum = 0;
    for (; bits != 0; bits >>= 1) {
        if (bits & 1) {
            sum += addend;
            if (sum >= mod)
                sum -= mod;
        }
        addend <<= 1;
        if (addend >= mod)
            addend -= mod;
    }
    return sum;
}
// Left-to-right sliding-window power x^y mod z, where y is a binary string
// (MSB first). Windows are "11" (multiply by x^3 after two squarings),
// "1" (multiply by x), "00" (two squarings) or a lone "0" (one squaring),
// so only x and x^3 need to be precomputed.
UINT64 mpowL2RS(UINT64 x, char y[], UINT64 z) {
    const UINT64 cube = mul(mul(x, x, z), x, z);
    UINT64 acc = 1;
    int pos = 0;
    while (y[pos] != '\0') {
        const char hi = y[pos];
        const char lo = y[pos + 1];   // at worst the terminator
        acc = mul(acc, acc, z);
        if (hi == '1' && lo == '1') {
            acc = mul(acc, acc, z);
            acc = mul(acc, cube, z);
            pos += 2;
        } else if (hi == '1') {
            acc = mul(acc, x, z);
            pos += 1;
        } else if (lo == '0') {
            acc = mul(acc, acc, z);
            pos += 2;
        } else {
            pos += 1;
        }
    }
    return acc;
}
// Exponent buffer: room for 2^20 + 4 binary digits plus the terminator.
char y[(1<<20) + 5];
// Reads "x y z" triples until EOF or a malformed triple and prints
// x^y mod z for each.
// NOTE(review): "%s" carries no width limit for y.
int main() {
    long long x, z;
    int parsed;
    while ((parsed = scanf("%lld %s %lld", &x, y, &z)) == 3)
        printf("%llu\n", mpowL2RS(x, y, z));
    return 0;
}

L-to-R-sliding-cheat

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
#include <bits/stdc++.h>
using namespace std;
typedef unsigned long long UINT64;
// (a * b) mod `mod` by shift-and-add; intermediates stay below 2 * mod,
// which fits in 64 bits while mod <= 2^63.
UINT64 mul(UINT64 a, UINT64 b, UINT64 mod) {
    UINT64 ret = 0;
    for (a = a >= mod ? a%mod : a, b = b >= mod ? b%mod : b; b != 0; b>>=1, a <<= 1, a = a >= mod ? a - mod : a) {
        if (b&1) {
            ret += a;
            if (ret >= mod)
                ret -= mod;
        }
    }
    return ret;
}
// Left-to-right sliding-window power x^y mod z (y binary, MSB first) using
// windows "11" (x^3), "1" (x), "00" and "0"; used for short exponents.
UINT64 mpowL2RS(UINT64 x, char y[], UINT64 z) {
    UINT64 x3 = mul(mul(x, x, z), x, z);
    UINT64 ret = 1;
    for (int i = 0; y[i]; ) {
        ret = mul(ret, ret, z);
        if (y[i] == '1' && y[i+1] == '1') {
            ret = mul(ret, ret, z);
            ret = mul(ret, x3, z);
            i += 2;
        } else if (y[i] == '1') {
            ret = mul(ret, x, z);
            i ++;
        } else if (y[i+1] == '0') {
            ret = mul(ret, ret, z);
            i += 2;
        } else {
            i++;
        }
    }
    return ret;
}
#define PREPROC 8
// Fixed-window (2^PREPROC-ary) left-to-right power x^y mod z for long
// exponents given as '0'/'1' strings (MSB first). The table X[v] = x^v mod z
// is built once; each PREPROC-bit window then costs PREPROC squarings plus
// one table multiplication.
// NOTE: the squarings and the window decode in the main loop are
// hand-unrolled and assume PREPROC == 8; update both if PREPROC changes.
UINT64 mpowCHEAT(UINT64 x, char y[], UINT64 z) {
    // Exact length. A power-of-two probe for the terminator over- or
    // under-estimates it for other lengths.
    int n = (int) strlen(y);
    // Short exponents cannot amortise the 2^PREPROC-entry table.
    if (n < 1<<PREPROC)
        return mpowL2RS(x, y, z);
    UINT64 X[1<<PREPROC] = {1};
    for (int i = 1; i < (1<<PREPROC); i++)
        X[i] = mul(X[i-1], x, z);
    // The length need not be a multiple of PREPROC. The leading n % PREPROC
    // bits form one short window, so the full windows never read the
    // terminator (which gave negative shifts and out-of-range X[v]). ret is
    // still 1 there, so only the lookup is needed; mul() reduces X[0] == 1
    // modulo z.
    int head = n % PREPROC, v = 0;
    for (int i = 0; i < head; i++)
        v = v << 1 | (y[i] - '0');
    UINT64 ret = mul(1, X[v], z);
    for (int i = head; i < n; i += PREPROC) {
        ret = mul(ret, ret, z);
        ret = mul(ret, ret, z);
        ret = mul(ret, ret, z);
        ret = mul(ret, ret, z);
        ret = mul(ret, ret, z);
        ret = mul(ret, ret, z);
        ret = mul(ret, ret, z);
        ret = mul(ret, ret, z);
        v = (y[i]-'0')<<7|(y[i+1]-'0')<<6|(y[i+2]-'0')<<5|(y[i+3]-'0')<<4|(y[i+4]-'0')<<3|(y[i+5]-'0')<<2|(y[i+6]-'0')<<1|(y[i+7]-'0');
        ret = mul(ret, X[v], z);
    }
    return ret;
}
// Exponent buffer: room for 2^20 + 4 binary digits plus the terminator.
char y[(1<<20) + 5];
// Reads "x y z" triples from stdin and prints x^y mod z per triple,
// stopping at EOF or at the first triple that does not fully parse.
// NOTE(review): "%s" is not width-limited; an over-long y overflows.
int main() {
    long long x, z;
    for (;;) {
        const int ok = scanf("%lld %s %lld", &x, y, &z);
        if (ok != 3)
            break;
        printf("%llu\n", mpowCHEAT(x, y, z));
    }
    return 0;
}
Read More +