ADA 2020 Fall P3. ADA Party

Algorithm Design and Analysis (NTU CSIE, Fall 2020)

Problem

$N$ 個堆,每個堆有 $a_i$ 個糖果,現在邀請 $K$ 個人,現在問有多少種挑選區間的方法,滿足扣掉最大堆和最小堆後,區間內的糖果總數可以被 $K$ 整除。

Sample Input

1
2
10 2
6 9 3 4 5 6 1 7 8 3

Sample Output

1
25

Solution

由於沒辦法參與課程,就測測自己產的測試資料,正確性有待確認。

分治處理可行解的組合,每一次剖半計算,統計跨區間的答案個數。

討論項目分別為

  • 最大值、最小值嚴格都在左側
  • 最大值、最小值嚴格都在右側
  • 最大值在左側、最小值在右側
  • 最大值在右側、最小值在左側

最後兩項會有交集部分,則扣除 在左側的最大最小值接等於右側的最大最小值。對於每一項回答,搭配單調運行的滑動窗口解決。

時間複雜度 $\mathcal{O}(n \log n)$、空間複雜度 $\mathcal{O}(n)$

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
#include <bits/stdc++.h>
using namespace std;
// Algorithm Design and Analysis (NTU CSIE, Fall 2020)
// Problem 3. ADA Party
const int MAXN = 500005;
const int32_t MIN = LONG_MIN;
const int32_t MAX = LONG_MAX;
int32_t a[MAXN];
int32_t lsum[MAXN], rsum[MAXN];
int32_t lmin[MAXN], lmax[MAXN];
int32_t rmin[MAXN], rmax[MAXN];
int cases = 0;
int mark[MAXN];
int counter[MAXN];
int n, k;
void inc(int32_t val) {
val = (val%k+k)%k;
if (mark[val] != cases)
mark[val] = cases, counter[val] = 0;
counter[val]++;
}
void dec(int32_t val) {
val = (val%k+k)%k;
if (mark[val] != cases)
mark[val] = cases, counter[val] = 0;
counter[val]--;
}
int get(int32_t val) {
val = (val%k+k)%k;
if (mark[val] != cases)
mark[val] = cases, counter[val] = 0;
return counter[val];
}
int64_t common(int l, int m, int r) {
int64_t ret = 0;
cases++; // max and min is same on both end
for (int i = m, j = m+1, jl = m+1; i >= l; i--) {
while (j <= r && (rmax[j] <= lmax[i] && rmin[j] >= lmin[i])) {
inc(rsum[j]-rmin[j]);
j++;
}
while (jl < j && (rmin[jl] > lmin[i] || rmax[jl] < lmax[i])) {
dec(rsum[jl]-rmin[jl]);
jl++;
}
if (j > m+1 && lmin[i] == rmin[j-1] && lmax[i] == rmax[j-1])
ret += get(k-(lsum[i]-lmax[i]));
}
return ret;
}
int64_t divide(int l, int r) {
if (l >= r)
return 0;
int m = (l+r)/2;
int32_t sum = 0;
int32_t mn = MAX, mx = MIN;
for (int i = m; i >= l; i--) {
sum += a[i], mn = min(mn, a[i]), mx = max(mx, a[i]);
if (sum >= k) sum %= k;
lsum[i] = sum, lmin[i] = mn, lmax[i] = mx;
}
sum = 0, mn = MAX, mx = MIN;
for (int i = m+1; i <= r; i++) {
sum += a[i], mn = min(mn, a[i]), mx = max(mx, a[i]);
if (sum >= k) sum %= k;
rsum[i] = sum, rmin[i] = mn, rmax[i] = mx;
}
int64_t c1 = 0, c2 = 0, c3 = 0, c4 = 0;
cases++; // min max on the left
for (int i = m, j = m+1; i >= l; i--) {
while (j <= r && lmin[i] < a[j] && a[j] < lmax[i]) {
inc(rsum[j]);
j++;
}
if (i < m)
c1 += get(k-(lsum[i]-lmin[i]-lmax[i]));
}
cases++; // min max on the right
for (int i = m+1, j = m; i <= r; i++) {
while (j >= l && rmin[i] < a[j] && a[j] < rmax[i]) {
inc(lsum[j]);
j--;
}
if (i > m+1)
c2 += get(k-(rsum[i]-rmin[i]-rmax[i]));
}
cases++; // min on the left, max on the right
for (int i = m, j = m+1, jl = m+1; i >= l; i--) {
while (j <= r && rmin[j] >= lmin[i]) {
inc(rsum[j]-rmax[j]);
j++;
}
while (jl < j && rmax[jl] < lmax[i]) {
dec(rsum[jl]-rmax[jl]);
jl++;
}
c3 += get(k-(lsum[i]-lmin[i]));
}
cases++; // min on the right, max on the left
for (int i = m+1, j = m, jl = m; i <= r; i++) {
while (j >= l && lmin[j] >= rmin[i]) {
inc(lsum[j]-lmax[j]);
j--;
}
while (jl > j && lmax[jl] < rmax[i]) {
dec(lsum[jl]-lmax[jl]);
jl--;
}
c4 += get(k-(rsum[i]-rmin[i]));
}
int64_t local = c1 + c2 + c3 + c4 - common(l, m, r);
return local + divide(l, m) + divide(m+1, r);
}
int main() {
while (scanf("%d %d", &n, &k) == 2) {
for (int i = 0; i < n; i++)
scanf("%d", &a[i]);
memset(counter, 0, sizeof(counter[0])*k);
int64_t ret = divide(0, n-1);
printf("%lld\n", ret);
}
return 0;
}
Read More +

ADA 2020 Fall P2. Bomb Game

Algorithm Design and Analysis (NTU CSIE, Fall 2020)

Problem

有數名玩家依序抵達遊戲,並且落在位置 $c_i$,並且具有防禦力 $d_i$,過程中會有炸彈發生於 $[l_i, r_i]$,對防禦力小於等於 $p_i$ 造成 $k_i$ 點傷害。

回報遊戲最後每一名玩家所受的傷害總額。

Sample Input

1
2
3
4
5
6
7
8
9
10
11
10 10
P 3 5
A 2 8 15 5
P 7 10
A 4 10 5 3
A 1 9 10 7
P 6 20
P 5 1
A 4 9 17 2
A 1 2 20 4
P 9 5

Sample Output

1
2
3
4
5
12
9
0
2
0

Solution

由於沒辦法參與課程,就測測自己產的測試資料,正確性有待確認。

如果這一題目強制在線對答,則需要一個樹套樹在 $\mathcal{O}(\log^2 n)$ 內回答每一個結果,需要一個動態開區間的實作方法。

如果採用離線處理,則可以透過逆序處理來回答,可以透過二維空間的 BIT 結構來完成,這時候空間上會是另一個問題,即使使用懶惰標記,預期可能會達到 $\mathcal{O}(C \; D)$,通常是不允許的。

從分治算法切入,預想防禦能力高影響不受到攻擊力低的炸彈影響,無論時間與否都不受到影響。接下來,對防禦能力和攻擊力統稱力量。在分治的時候,對力量低的累計出答案,在合併階段受時間順序的影響才能回答。最後:

  1. 對力量從小到大排序
  2. 分治算法
    1. 對左區間和右區間按照時間由大到小排序
    2. 對於每一個左區間的詢問,插入所有滿足的右區間

時間複雜度 $\mathcal{O}(n \log^2 n)$、空間複雜度 $\mathcal{O}(n)$

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
#include <bits/stdc++.h>
using namespace std;
// Algorithm Design and Analysis (NTU CSIE, Fall 2020)
// Problem 2. Bomb Game
const int MAXN = 200005;
struct BIT {
int64_t a[MAXN];
int l[MAXN];
int cases = 0;
void add(int x, int val, int n) {
while (x <= n) {
if (l[x] != cases)
l[x] = cases, a[x] = 0;
a[x] += val, x += x&(-x);
}
}
int64_t sum(int x) {
int64_t ret = 0;
while (x) {
if (l[x] != cases)
l[x] = cases, a[x] = 0;
ret += a[x], x -= x&(-x);
}
return ret;
}
void reset(int n) {
cases++;
}
void add(int l, int r, int k, int n) {
add(l, k, n);
add(r+1, -k, n);
}
} bit;
int n;
struct PEvent {
int c, d, i;
};
struct AEvent {
int l, r, p, k;
};
struct Event {
int type; // 'P' 0 or 'A' 1
int time; // input order
union {
PEvent p;
AEvent a;
} data;
int power() {
if (type == 0)
return data.p.d;
else
return data.a.p;
}
void println() {
if (type == 0)
printf("P %d %d\n", data.p.c, data.p.d);
else
printf("A %d %d %d %d\n", data.a.l, data.a.r, data.a.p, data.a.k);
}
} events[MAXN];
static bool cmp_p(Event &a, Event &b) {
int pa = a.power();
int pb = b.power();
if (pa != pb)
return pa < pb;
return a.time < b.time;
}
static bool cmp_t(Event &a, Event &b) {
return a.time > b.time;
}
int ret[MAXN];
void resolve(int l, int m, int r) {
sort(events+l, events+m+1, cmp_t);
sort(events+m+1, events+r+1, cmp_t);
// printf("resolve %d %d =========\n", l, r);
// for (int i = l; i <= m; i++)
// events[i].println();
// puts("---");
// for (int i = m+1; i <= r; i++)
// events[i].println();
bit.reset(n);
int j = m+1;
for (int i = l; i <= m; i++) {
if (events[i].type)
continue;
int qtime = events[i].time;
while (j <= r && events[j].time > qtime) {
if (events[j].type) {
bit.add(events[j].data.a.l,
events[j].data.a.r,
events[j].data.a.k,
n);
// printf("add %d %d %d %d\n", events[j].data.a.l,
// events[j].data.a.r,
// events[j].data.a.p,
// events[j].data.a.k);
}
j++;
}
// printf("%d --- %d\n", events[i].data.p.i, bit.sum(events[i].data.p.c));
ret[events[i].data.p.i] += bit.sum(events[i].data.p.c);
}
}
void divide(int l, int r) {
if (l >= r)
return;
int m = (l+r)/2;
divide(l, m);
divide(m+1, r);
resolve(l, m, r);
}
int main() {
int m;
char s[2];
scanf("%d %d", &n, &m);
int id = 0;
for (int i = 0; i < m; i++) {
scanf("%s", s);
events[i].time = i;
if (s[0] == 'P') {
events[i].type = 0;
events[i].data.p.i = id++;
scanf("%d %d",
&events[i].data.p.c,
&events[i].data.p.d);
} else {
events[i].type = 1;
scanf("%d %d %d %d",
&events[i].data.a.l,
&events[i].data.a.r,
&events[i].data.a.p,
&events[i].data.a.k);
}
}
sort(events, events+m, cmp_p);
divide(0, m-1);
for (int i = 0; i < id; i++)
printf("%d\n", ret[i]);
return 0;
}
Read More +

動態幾何 史蒂芙的泡泡 (解法 2)

題目描述

在處理完數以百計的政事後,受盡折磨的史蒂芙,打算回家好好地休息。 拖著疲倦的身軀,再也無法再容納任何一點複雜計算。從王宮走回寢居的路上, 發現身邊所見的事物都不再圓滑,看起來就像是粗糙的幾何多邊形構成的一切。

打算享受著泡泡浴的史蒂芙,看著眼前的多邊形泡泡,失去原本應有的色澤,那透涼的心境更蒙上了一層灰影

「為什麼是我呢?」感嘆道

伸出手戳著眼前的泡泡,卻飄了過去

「區區的泡泡也跟我作對,嗚嗚」

將一個泡泡視為一個簡單多邊形 $A$,方便起見用一個序列 $a_0, a_1, ..., a_{n-1}$ 表示多邊形 $A$ 的每一個頂點,則會有 $n$ 個線段 $\overline{a_0 a_1}, \overline{a_1 a_2}, \cdots, \overline{a_{n-1} a_0}$

Sample Input

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
5 14
0 0
10 0
10 10
0 10
5 5
1 0 0
1 1 0
1 2 1
1 3 2
3 5.5 5
3 10 -1
3 10 5
1 4 1
3 5.5 5
3 10 -1
3 10 5
2 3
3 5 7.5
3 5 2.5

Sample Output

1
2
3
4
5
6
7
8
1
0
0
0
0
0
0
1

Solution

參閱 動態幾何 史蒂芙的泡泡 (解法 1)

相較於先前的解法 $\mathcal{O}(\log n) - \mathcal{O}(\ast \sqrt{n})$,相當不穩定的嵌套 KD-BRH 的解法,實際上如果單純針對這一題,可以拋開 region search 的操作,只有完全的詢問點是否在多邊形內部,則可以做到 $\mathcal{O}(\log^2 n)$

如同一般的射線法,對詢問點拉出無限延長的線,找到與多邊形的邊相交個數。如果單純把每一條邊拿出來看,最暴力的複雜度為 $\mathcal{O}(n)$,現在要減少查閱的邊數,且透過 rank tree 在 $\mathcal{O}(\log n)$ 累計相交數量。

由於給定的座標不需要動態,則著手離散化 $X = x_0, x_1, \cdots, x_n$,線段樹中的每一個點 $u$ 維護一個 rank tree,把包含者個區段 $[x_l, x_r]$ 的線段通通放入,且保證放入的線段彼此之間不相交 (除了兩端點)。如此一來,當詢問一個點,需要探訪 $\mathcal{O}(\log n)$ 個節點,每個節點 $u$ 需要 $\mathcal{O}(\log n)$ 時間來計算相交數量,最後詢問的複雜度為 $\mathcal{O}(\log^2 n)$。同理,修改點的操作也會在 $\mathcal{O}(\log^2 n)$

實作時,特別注意到與射線方向相同的線段不予處理,按照下面的寫法則是不處理垂直的線段,一來是射線法也不會去計算,二來是線段樹劃分的時候會造成一些邊界問題,由於我們對於點離散,父節點 $[x_l, x_r]$,左右子樹控制的分別為 $[x_l, x_\text{mid}]$$[x_\text{mid}, x_r]$,劃分時會共用中間點。

即使有了上述的概念來解題,我們仍需要維護 rope data structrue 來維護節點的相鄰關係,可以自己實作任何的 binary tree 來達到 $\mathcal{O}(\log n)$,這裡採用 splay tree 示範。


接下來介紹內建黑魔法 PBDS (policy-based data structure) 和 rope。很多人都提及到這些非正式的 STL 函式庫,只有在 gcc/g++ 裡面才有附錄,如果是 clang 等其他編譯器可能是沒有辦法的。所以上傳相關代碼要看 OJ 是否採用純正的 gcc/g++。

參考資料:

PBDS 比較常用在 rank tree 和 heap,由於代碼量非常多,用內建防止 code length exceed limitation 的出現,也不妨是個好辦法。用 rank tree 的每一個詢問操作皆在 $\mathcal{O}(\log n)$,而 heap 選擇 thin heap 或 pairing heap,除了 pop 操作外,皆為 $\mathcal{O}(1)$,在對最短路徑問題上別有優勢。

而這一題不太適用 SGI rope,原因在於雖為 $\mathcal{O}(\log n)$ 操作,但它原本就為了仿作 string 而強迫變成可持久化的引數結構,導致每一個操作需要額外的開銷來減少內存使用。由於此題經常在單一元素上操作,SGI rope 對於單一元素效能不彰,以造成嚴重的逾時。

這裡仍提及和示範這些概念的資料結構,哪天正式編入標準函式庫,想必效能問題都已解決。


  • KD BRH AC(0.2s, 10 MB)
  • PBDS + SPLAY AC(0.3s, 38 MB)
  • PBDS + SGI rope AC(0.5s, 41 MB)
    本機 MINGW gcc 4.9 c++ 11 編譯完運行最大的測資需要 20 倍慢於其他寫法,但是上傳到 ZJ 的 gcc 7 又追回去了。於是下載 CYGWIN gcc 7 測試結果與 ZJ 運行結果相當,對於需要調適的人,建議在新版上測試。

用 splay tree 模擬 rope 的時候,會遇到循序非遞減的走訪,這時候若不斷地旋轉至根,會在之後的操作造成退化,常數稍微大一點,雖然在迭代過程中造就好的快取效能,很快地在 $\mathcal{O}(1)$ 訪問到相鄰的節點,卻在下一次的插入/刪除操作中變慢。一些變異的版本中,如只在插入/刪除操作中旋轉至根,在查找操作中不旋轉,或者限制旋轉的深度。雖然有一些特別的退化操作,splay 仍舊是某些 OS 場景中使用的資料結構,現實總是很特別的。

In 2000, Danny Sleator and Robert Tarjan won the ACM Kanellakis Theory and Practice Award for their papers on splay trees and amortized analysis. Splay trees are used in Windows NT (in the virtual memory, networking, and file system code), the gcc compiler and GNU C++ library, the sed string editor, Fore Systems network routers, the most popular implementation of Unix malloc, Linux loadable kernel modules, and in much other software.

PBDS + SPLAY

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
#include <bits/stdc++.h>
using namespace std;
#include <ext/pb_ds/assoc_container.hpp>
#include <ext/pb_ds/tree_policy.hpp>
using namespace __gnu_pbds;
struct Mesh {
static const int MAXN = 1e5 + 5;
int pt[MAXN][2];
vector<int> X;
void read(int n) {
for (int i = 0; i < n; i++)
scanf("%d %d", &pt[i][0], &pt[i][1]);
X.clear(); X.reserve(n);
for (int i = 0; i < n; i++)
X.push_back(pt[i][0]);
sort(X.begin(), X.end());
X.erase(unique(X.begin(), X.end()), X.end());
}
} mesh;
class SplayTree {
public:
struct Node {
Node *ch[2], *fa;
int size; int data;
Node() {
ch[0] = ch[1] = fa = NULL;
size = 1;
}
bool is_root() {
return fa->ch[0] != this && fa->ch[1] != this;
}
};
Node *root, *EMPTY;
void pushdown(Node *u) {}
void pushup(Node *u) {
if (u->ch[0] != EMPTY) pushdown(u->ch[0]);
if (u->ch[1] != EMPTY) pushdown(u->ch[1]);
u->size = 1 + u->ch[0]->size + u->ch[1]->size;
}
void setch(Node *p, Node *u, int i) {
if (p != EMPTY) p->ch[i] = u;
if (u != EMPTY) u->fa = p;
}
SplayTree() {
EMPTY = new Node();
EMPTY->fa = EMPTY->ch[0] = EMPTY->ch[1] = EMPTY;
EMPTY->size = 0;
}
void init() {
root = EMPTY;
}
Node* newNode() {
Node *u = new Node();
u->fa = u->ch[0] = u->ch[1] = EMPTY;
return u;
}
void rotate(Node *x) {
Node *y;
int d;
y = x->fa, d = y->ch[1] == x ? 1 : 0;
x->ch[d^1]->fa = y, y->ch[d] = x->ch[d^1];
x->ch[d^1] = y;
if (!y->is_root())
y->fa->ch[y->fa->ch[1] == y] = x;
x->fa = y->fa, y->fa = x;
pushup(y);
}
void deal(Node *x) {
if (!x->is_root()) deal(x->fa);
pushdown(x);
}
void splay(Node *x, Node *below) {
if (x == EMPTY) return ;
Node *y, *z;
deal(x);
while (!x->is_root() && x->fa != below) {
y = x->fa, z = y->fa;
if (!y->is_root() && y->fa != below) {
if (y->ch[0] == x ^ z->ch[0] == y)
rotate(x);
else
rotate(y);
}
rotate(x);
}
pushup(x);
if (x->fa == EMPTY) root = x;
}
Node* prevNode(Node *u) {
splay(u, EMPTY);
return maxNode(u->ch[0]);
}
Node* nextNode(Node *u) {
splay(u, EMPTY);
return minNode(u->ch[1]);
}
Node* minNode(Node *u) {
Node *p = u->fa;
for (; pushdown(u), u->ch[0] != EMPTY; u = u->ch[0]);
splay(u, p);
return u;
}
Node* maxNode(Node *u) {
Node *p = u->fa;
for (; pushdown(u), u->ch[1] != EMPTY; u = u->ch[1]);
splay(u, p);
return u;
}
Node* findPos(int pos) { // [0...]
for (Node *u = root; u != EMPTY;) {
pushdown(u);
int t = u->ch[0]->size;
if (t == pos) {
splay(u, EMPTY);
return u;
}
if (t > pos)
u = u->ch[0];
else
u = u->ch[1], pos -= t + 1;
}
return EMPTY;
}
tuple<int, int, int> insert(int data, int pos) { // make [pos] = data
Node *p, *q = findPos(pos);
Node *x = newNode(); x->data = data;
if (q == EMPTY) {
p = maxNode(root), splay(p, EMPTY);
setch(x, p, 0);
splay(x, EMPTY);
} else {
splay(q, EMPTY), p = q->ch[0];
setch(x, p, 0), setch(x, q, 1);
setch(q, EMPTY, 0);
splay(q, EMPTY);
p = prevNode(x);
}
if (p == EMPTY) p = maxNode(root);
if (q == EMPTY) q = minNode(root);
return make_tuple(p->data, data, q->data);
}
tuple<int, int, int> remove(int pos) {
Node *x = findPos(pos), *p, *q;
p = prevNode(x), q = nextNode(x);
if (p != EMPTY && q != EMPTY) {
setch(p, q, 1);
p->fa = EMPTY, splay(q, EMPTY);
} else if (p != EMPTY) {
p->fa = EMPTY, root = p;
} else {
q->fa = EMPTY, root = q;
}
int del = x->data;
delete x;
if (p == EMPTY) p = maxNode(root);
if (q == EMPTY) q = minNode(root);
return make_tuple(p->data, del, q->data);
}
int size() {
return root == EMPTY ? 0 : root->size;
}
} mlist;
struct Pt {
double x, y;
Pt() {}
Pt(int xy[]):Pt(xy[0], xy[1]) {}
Pt(double x, double y):x(x), y(y) {}
bool operator<(const Pt &o) const {
if (x != o.x) return x < o.x;
return y < o.y;
}
};
struct PtP {
static double x;
Pt p, q;
PtP(Pt a, Pt b) {
p = a, q = b;
if (q < p)
swap(p, q);
}
double interpolate(const Pt& p1, const Pt& p2, double& x) const {
if (p1.x == p2.x) return min(p1.y, p2.y);
return p1.y + (p2.y - p1.y) / (p2.x - p1.x) * (x - p1.x);
}
bool operator<(const PtP &o) const {
return interpolate(p, q, x) < interpolate(o.p, o.q, x);
}
};
double PtP::x = 1;
struct SegSeg {
struct Node {
Node *lson, *rson;
tree<pair<PtP, int>, null_type, less<pair<PtP, int>>, rb_tree_tag, tree_order_statistics_node_update> segs;
Node() {
lson = rson = NULL;
}
};
Node *root;
int xn;
Node* newNode() {
return new Node();
}
void freeNode(Node *u) {
free(u);
}
void init() {
root = NULL;
xn = mesh.X.size();
}
void insert(tuple<int, int, int> e) {
int p, q, r; tie(p, q, r) = e;
remove(0, xn-1, {PtP(Pt(mesh.pt[p]), Pt(mesh.pt[r])), r});
insert(0, xn-1, {PtP(Pt(mesh.pt[p]), Pt(mesh.pt[q])), q});
insert(0, xn-1, {PtP(Pt(mesh.pt[q]), Pt(mesh.pt[r])), r});
}
void remove(tuple<int, int, int> e) {
int p, q, r; tie(p, q, r) = e;
remove(0, xn-1, {PtP(Pt(mesh.pt[p]), Pt(mesh.pt[q])), q});
remove(0, xn-1, {PtP(Pt(mesh.pt[q]), Pt(mesh.pt[r])), r});
insert(0, xn-1, {PtP(Pt(mesh.pt[p]), Pt(mesh.pt[r])), r});
}
int inside(double x, double y) {
PtP::x = x;
return count(root, 0, xn-1, x, y)&1;
}
int count(Node* u, int l, int r, double x, double y) {
if (u == NULL)
return 0;
int ret = 0;
if ((mesh.X[l] > x) != (mesh.X[r] > x))
ret += u->segs.order_of_key({PtP(Pt(x, y), Pt(x, y)), -1});
int m = (l+r)/2;
if (x <= mesh.X[m])
ret += count(u->lson, l, m, x, y);
if (x >= mesh.X[m])
ret += count(u->rson, m, r, x, y);
return ret;
}
void insert(int l, int r, pair<PtP, int> s) {
if (s.first.p.x != s.first.q.x)
insert(root, l, r, s);
}
void remove(int l, int r, pair<PtP, int> s) {
if (s.first.p.x != s.first.q.x)
remove(root, l, r, s);
}
void insert(Node* &u, int l, int r, pair<PtP, int> s) {
if (u == NULL)
u = newNode();
if (s.first.p.x <= mesh.X[l] && mesh.X[r] <= s.first.q.x) {
PtP::x = (mesh.X[l] + mesh.X[r])/2.0;
u->segs.insert(s);
return;
}
int m = (l+r)/2;
if (s.first.q.x <= mesh.X[m]) insert(u->lson, l, m, s);
else if (s.first.p.x >= mesh.X[m]) insert(u->rson, m, r, s);
else insert(u->lson, l, m, s), insert(u->rson, m, r, s);
}
void remove(Node* u, int l, int r, pair<PtP, int> s) {
if (u == NULL)
return;
if (s.first.p.x <= mesh.X[l] && mesh.X[r] <= s.first.q.x) {
PtP::x = (mesh.X[l] + mesh.X[r])/2.0;
u->segs.erase(s);
return;
}
int m = (l+r)/2;
if (s.first.q.x <= mesh.X[m]) remove(u->lson, l, m, s);
else if (s.first.p.x >= mesh.X[m]) remove(u->rson, m, r, s);
else remove(u->lson, l, m, s), remove(u->rson, m, r, s);
}
} mbrh;
int main() {
int n, m, cmd, x, pos;
double px, py;
scanf("%d %d", &n, &m);
mesh.read(n);
mlist.init(), mbrh.init();
for (int i = 0; i < m; i++) {
scanf("%d", &cmd);
if (cmd == 1) {
scanf("%d %d", &x, &pos);
mbrh.insert(mlist.insert(x, pos));
} else if (cmd == 2) {
scanf("%d", &x);
mbrh.remove(mlist.remove(x));
} else {
scanf("%lf %lf", &px, &py);
puts(mbrh.inside(px, py) ? "1" : "0");
}
}
return 0;
}

PBDS + ROPE

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
#include <bits/stdc++.h>
using namespace std;
#include <ext/rope>
using namespace __gnu_cxx;
#include <ext/pb_ds/assoc_container.hpp>
#include <ext/pb_ds/tree_policy.hpp>
using namespace __gnu_pbds;
struct Mesh {
static const int MAXN = 1e5 + 5;
int pt[MAXN][2];
vector<int> X;
void read(int n) {
for (int i = 0; i < n; i++)
scanf("%d %d", &pt[i][0], &pt[i][1]);
X.clear(); X.reserve(n);
for (int i = 0; i < n; i++)
X.push_back(pt[i][0]);
sort(X.begin(), X.end());
X.erase(unique(X.begin(), X.end()), X.end());
}
} mesh;
class Rope {
public:
rope<int> r;
void init() {
r.clear();
}
int next(rope<int>::const_iterator it) {
it++;
if (it == r.end())
return *r.begin();
return *it;
}
int prev(rope<int>::const_iterator it) {
if (it == r.begin())
return *r.rbegin();
it--;
return *it;
}
tuple<int, int, int> insert(int data, int pos) {
r.insert(pos, data);
auto it = r.begin() + pos;
int p = prev(it);
int q = next(it);
return make_tuple(p, data, q);
}
tuple<int, int, int> remove(int pos) {
auto it = r.begin() + pos;
int del = *it;
int p = prev(it);
int q = next(it);
r.erase(pos, 1);
return make_tuple(p, del, q);
}
} mlist;
struct Pt {
double x, y;
Pt() {}
Pt(int xy[]):Pt(xy[0], xy[1]) {}
Pt(double x, double y):x(x), y(y) {}
bool operator<(const Pt &o) const {
if (x != o.x) return x < o.x;
return y < o.y;
}
};
struct PtP {
static double x;
Pt p, q;
PtP(Pt a, Pt b) {
p = a, q = b;
if (q < p)
swap(p, q);
}
double interpolate(const Pt& p1, const Pt& p2, double& x) const {
if (p1.x == p2.x) return min(p1.y, p2.y);
return p1.y + (p2.y - p1.y) / (p2.x - p1.x) * (x - p1.x);
}
bool operator<(const PtP &o) const {
return interpolate(p, q, x) < interpolate(o.p, o.q, x);
}
};
double PtP::x = 1;
struct SegSeg {
struct Node {
Node *lson, *rson;
tree<pair<PtP, int>, null_type, less<pair<PtP, int>>, rb_tree_tag, tree_order_statistics_node_update> segs;
Node() {
lson = rson = NULL;
}
};
Node *root;
int xn;
Node* newNode() {
return new Node();
}
void init() {
root = NULL;
xn = mesh.X.size();
}
void insert(tuple<int, int, int> e) {
int p, q, r; tie(p, q, r) = e;
remove(0, xn-1, {PtP(Pt(mesh.pt[p]), Pt(mesh.pt[r])), r});
insert(0, xn-1, {PtP(Pt(mesh.pt[p]), Pt(mesh.pt[q])), q});
insert(0, xn-1, {PtP(Pt(mesh.pt[q]), Pt(mesh.pt[r])), r});
}
void remove(tuple<int, int, int> e) {
int p, q, r; tie(p, q, r) = e;
remove(0, xn-1, {PtP(Pt(mesh.pt[p]), Pt(mesh.pt[q])), q});
remove(0, xn-1, {PtP(Pt(mesh.pt[q]), Pt(mesh.pt[r])), r});
insert(0, xn-1, {PtP(Pt(mesh.pt[p]), Pt(mesh.pt[r])), r});
}
int inside(double x, double y) {
PtP::x = x;
return count(root, 0, xn-1, x, y)&1;
}
int count(Node* u, int l, int r, double x, double y) {
if (u == NULL)
return 0;
int ret = 0;
if ((mesh.X[l] > x) != (mesh.X[r] > x))
ret += u->segs.order_of_key({PtP(Pt(x, y), Pt(x, y)), -1});
if (l == r)
return ret;
int m = (l+r)/2;
if (x <= mesh.X[m])
ret += count(u->lson, l, m, x, y);
if (x >= mesh.X[m])
ret += count(u->rson, m, r, x, y);
return ret;
}
void insert(int l, int r, pair<PtP, int> s) {
if (s.first.p.x != s.first.q.x)
insert(root, l, r, s);
}
void remove(int l, int r, pair<PtP, int> s) {
if (s.first.p.x != s.first.q.x)
remove(root, l, r, s);
}
void insert(Node* &u, int l, int r, pair<PtP, int> s) {
if (u == NULL)
u = newNode();
if (s.first.p.x <= mesh.X[l] && mesh.X[r] <= s.first.q.x) {
PtP::x = (mesh.X[l] + mesh.X[r])/2.0;
u->segs.insert(s);
return;
}
if (l == r)
return;
int m = (l+r)/2;
if (s.first.q.x <= mesh.X[m]) insert(u->lson, l, m, s);
else if (s.first.p.x >= mesh.X[m]) insert(u->rson, m, r, s);
else insert(u->lson, l, m, s), insert(u->rson, m, r, s);
}
void remove(Node* &u, int l, int r, pair<PtP, int> s) {
if (u == NULL)
return;
if (s.first.p.x <= mesh.X[l] && mesh.X[r] <= s.first.q.x) {
PtP::x = (mesh.X[l] + mesh.X[r])/2.0;
u->segs.erase(s);
return;
}
if (l == r)
return;
int m = (l+r)/2;
if (s.first.q.x <= mesh.X[m]) remove(u->lson, l, m, s);
else if (s.first.p.x >= mesh.X[m]) remove(u->rson, m, r, s);
else remove(u->lson, l, m, s), remove(u->rson, m, r, s);
}
} mbrh;
int main() {
int n, m, cmd, x, pos;
double px, py;
scanf("%d %d", &n, &m);
mesh.read(n);
mlist.init(), mbrh.init();
for (int i = 0; i < m; i++) {
scanf("%d", &cmd);
if (cmd == 1) {
scanf("%d %d", &x, &pos);
mbrh.insert(mlist.insert(x, pos));
} else if (cmd == 2) {
scanf("%d", &x);
mbrh.remove(mlist.remove(x));
} else {
scanf("%lf %lf", &px, &py);
puts(mbrh.inside(px, py) ? "1" : "0");
}
}
return 0;
}
Read More +

動態幾何 史蒂芙的泡泡 (解法 1)

題目描述請至 Zerojudge e021: 史蒂芙的泡泡 查閱詳細內容

題目描述

在處理完數以百計的政事後,受盡折磨的史蒂芙,打算回家好好地休息。 拖著疲倦的身軀,再也無法再容納任何一點複雜計算。從王宮走回寢居的路上, 發現身邊所見的事物都不再圓滑,看起來就像是粗糙的幾何多邊形構成的一切。

打算享受著泡泡浴的史蒂芙,看著眼前的多邊形泡泡,失去原本應有的色澤,那透涼的心境更蒙上了一層灰影

「為什麼是我呢?」感嘆道

伸出手戳著眼前的泡泡,卻飄了過去

「區區的泡泡也跟我作對,嗚嗚」

將一個泡泡視為一個簡單多邊形 $A$,方便起見用一個序列 $a_0, a_1, ..., a_{n-1}$ 表示多邊形 $A$ 的每一個頂點,則會有 $n$ 個線段 $\overline{a_0 a_1}, \overline{a_1 a_2}, \cdots, \overline{a_{n-1} a_0}$

解法

從能找到的論文「A Unified Approach to Dynamic Point Location, Ray Shooting, and Shortest Paths in Planar Maps」 得知操作更新 $\mathcal{O}(\log^3 n)$,詢問操作 $\mathcal{O}(\log n)$,這一個實作難度估計沒辦法在 10K 限制下完成 (代碼上傳長度上限)。

從動態梯形剖分開始,複雜度就相當高了,而論文後要有類似輕重鏈剖分的概念,將對偶圖產生的節點以輕重邊保存,接著再維護雙線輪廓,讓相連的連續重邊可以保存左右兩側的輪廓,… 資訊量龐大,想到一些可能未提及的邊界案例,實作相當困難。而從其他題目的經驗中,當設計到 $\mathcal{O}(\log^2 n)$ 的操作時,由於操作常數大,運行時間可能無法匹敵 $\mathcal{O}(\sqrt{n})$ 的算法。

那麼有沒有其他的替代方案,只需要跑得比暴力法還要快就行?特別小心,暴力法找一個點是否在多邊形內,需要 $\mathcal{O}(n)$ 的時間,且快取效果非常好,很容易做到 SIMD 的平行技術。

首先,解決維護多邊形點集序列,可以使用任何的平衡樹完成,若採用 splay tree 在均攤 $\text{Amortized} \; \mathcal{O}(\log n)$。使用 treap 也可以完成這類操作,還可以做到額外的持久化功能。接著,修改點的操作,可以轉換成替換一個點的下一個點,因此將點放在 KD tree 上,而維護下一個點相當於把每一個節點擴充成 BRH。

接著,當解決點是否在多邊形內時,可走訪 BRH 找射線穿過的交點個數,此時複雜度至少為 $\mathcal{O}(\sqrt{n} + k)$,其中 $k$ 為交點個數。我們可以額外維護多邊形的順逆時針來優化搜尋空間,最後一個接觸的線段方向,額外提供奇偶數來判斷該點是否在內側,然後把射線縮短到交點到詢問點之間。不斷地縮短搜尋空間下,大部分的情況為 $\mathcal{O}(\sqrt{n})$,拋開了原本存在的 $k$

由於是封閉的簡單多邊形,所以當邊的疏密程度不一時,KD tree 轉換 BRH 會造成額外的負擔,bounding box 無法有效提供分割空間的功能。那麼在實際搜尋情況,額外走訪的節點與平面分布有關,這部分就不好分析。如果有更好的想法,歡迎分享。

現在實作動態 KD tree 必須有替罪羊樹 Scapegoat tree 的概念,再掛上 BRH 的想法,提供了相對有效率的動態方法來查詢點是否在多邊形內部。若把維護點序列的 splay tree 換成 treap,便可以做到持久化的動態幾何操作,這便是一開始想要的目標。

如果梯形剖分也可以持久化?暫時還不想去思考,那思考的容器要多大呢?

參考解法

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
#include <bits/stdc++.h>
using namespace std;
struct Mesh {
static const int MAXN = 1e5 + 5;
int pt[MAXN][2];
void read(int n) {
for (int i = 0; i < n; i++)
scanf("%d %d", &pt[i][0], &pt[i][1]);
}
} mesh;
class SplayTree {
public:
struct Node {
Node *ch[2], *fa;
int size; int data;
Node() {
ch[0] = ch[1] = fa = NULL;
size = 1;
}
bool is_root() {
return fa->ch[0] != this && fa->ch[1] != this;
}
};
Node *root, *EMPTY;
void pushdown(Node *u) {}
void pushup(Node *u) {
if (u->ch[0] != EMPTY) pushdown(u->ch[0]);
if (u->ch[1] != EMPTY) pushdown(u->ch[1]);
u->size = 1 + u->ch[0]->size + u->ch[1]->size;
}
void setch(Node *p, Node *u, int i) {
if (p != EMPTY) p->ch[i] = u;
if (u != EMPTY) u->fa = p;
}
SplayTree() {
EMPTY = new Node();
EMPTY->fa = EMPTY->ch[0] = EMPTY->ch[1] = EMPTY;
EMPTY->size = 0;
}
void init() {
root = EMPTY;
}
Node* newNode() {
Node *u = new Node();
u->fa = u->ch[0] = u->ch[1] = EMPTY;
return u;
}
void rotate(Node *x) {
Node *y;
int d;
y = x->fa, d = y->ch[1] == x ? 1 : 0;
x->ch[d^1]->fa = y, y->ch[d] = x->ch[d^1];
x->ch[d^1] = y;
if (!y->is_root())
y->fa->ch[y->fa->ch[1] == y] = x;
x->fa = y->fa, y->fa = x;
pushup(y);
}
void deal(Node *x) {
if (!x->is_root()) deal(x->fa);
pushdown(x);
}
void splay(Node *x, Node *below) {
if (x == EMPTY) return ;
Node *y, *z;
deal(x);
while (!x->is_root() && x->fa != below) {
y = x->fa, z = y->fa;
if (!y->is_root() && y->fa != below) {
if (y->ch[0] == x ^ z->ch[0] == y)
rotate(x);
else
rotate(y);
}
rotate(x);
}
pushup(x);
if (x->fa == EMPTY) root = x;
}
Node* prevNode(Node *u) {
splay(u, EMPTY);
return maxNode(u->ch[0]);
}
Node* nextNode(Node *u) {
splay(u, EMPTY);
return minNode(u->ch[1]);
}
Node* minNode(Node *u) {
Node *p = u->fa;
for (; pushdown(u), u->ch[0] != EMPTY; u = u->ch[0]);
splay(u, p);
return u;
}
Node* maxNode(Node *u) {
Node *p = u->fa;
for (; pushdown(u), u->ch[1] != EMPTY; u = u->ch[1]);
splay(u, p);
return u;
}
Node* findPos(int pos) { // [0...]
for (Node *u = root; u != EMPTY;) {
pushdown(u);
int t = u->ch[0]->size;
if (t == pos) {
splay(u, EMPTY);
return u;
}
if (t > pos)
u = u->ch[0];
else
u = u->ch[1], pos -= t + 1;
}
return EMPTY;
}
tuple<int, int, int> insert(int data, int pos) { // make [pos] = data
Node *p, *q = findPos(pos);
Node *x = newNode(); x->data = data;
if (q == EMPTY) {
p = maxNode(root), splay(p, EMPTY);
setch(x, p, 0);
splay(x, EMPTY);
} else {
splay(q, EMPTY), p = q->ch[0];
setch(x, p, 0), setch(x, q, 1);
setch(q, EMPTY, 0);
splay(q, EMPTY);
p = prevNode(x);
}
if (p == EMPTY) p = maxNode(root);
if (q == EMPTY) q = minNode(root);
return make_tuple(p->data, data, q->data);
}
tuple<int, int, int> remove(int pos) {
Node *x = findPos(pos), *p, *q;
p = prevNode(x), q = nextNode(x);
if (p != EMPTY && q != EMPTY) {
setch(p, q, 1);
p->fa = EMPTY, splay(q, EMPTY);
} else if (p != EMPTY) {
p->fa = EMPTY, root = p;
} else {
q->fa = EMPTY, root = q;
}
int del = x->data;
free(x);
if (p == EMPTY) p = maxNode(root);
if (q == EMPTY) q = minNode(root);
return make_tuple(p->data, del, q->data);
}
int size() {
return root == EMPTY ? 0 : root->size;
}
} mlist;
static inline int log2int(int x) {
return 31 - __builtin_clz(x);
}
static inline int64_t h(int p, int q) {
return (int64_t) mesh.pt[p][0]*mesh.pt[q][1] - (int64_t) mesh.pt[p][1]*mesh.pt[q][0];
}
struct KDBRH {
static constexpr double ALPHA = 0.75;
static constexpr double LOG_ALPHA = log2(1.0 / ALPHA);
struct Pt {
int d[2];
Pt() {}
Pt(int xy[]):Pt(xy[0], xy[1]) {}
Pt(int x, int y) {d[0] = x, d[1] = y;}
bool operator==(const Pt &x) const {
return d[0] == x.d[0] && d[1] == x.d[1];
}
static Pt NaN() {return Pt(INT_MIN, INT_MIN);}
int isNaN() {return d[0] == INT_MIN;}
};
struct PtP {
Pt p, q;
PtP(Pt p, Pt q): p(p), q(q) {}
};
struct cmpAxis {
int k;
cmpAxis(int k): k(k) {}
bool operator() (const PtP &x, const PtP &y) const {
return x.p.d[k] < y.p.d[k];
}
};
struct BBox {
#define KDMIN(a, b, c) {a[0] = min(b[0], c[0]), a[1] = min(b[1], c[1]);}
#define KDMAX(a, b, c) {a[0] = max(b[0], c[0]), a[1] = max(b[1], c[1]);}
int l[2], r[2];
BBox() {}
BBox(int a[], int b[]) {
KDMIN(l, a, b); KDMAX(r, a, b);
}
void expand(Pt p) {
KDMIN(l, l, p.d); KDMAX(r, r, p.d);
}
void expand(BBox b) {
KDMIN(l, l, b.l); KDMAX(r, r, b.r);
}
inline int raycast(double x, double fx, double y) {
return l[1] <= y && y <= r[1] && r[0] >= x && l[0] <= fx;
}
static BBox init() {
BBox b; b.l[0] = b.l[1] = INT_MAX, b.r[0] = b.r[1] = INT_MIN;
return b;
}
};
struct Node {
Node *lson, *rson;
Pt pt, qt;
BBox box;
int size; int8_t used;
Node() {}
void init() {
lson = rson = NULL;
size = 1, used = 1;
pt = qt = Pt::NaN();
}
bool hasBox() { return box.l[0] <= box.r[0]; }
void pushup() {
size = used;
if (lson) size += lson->size;
if (rson) size += rson->size;
pushupBox();
}
void pushupBox() {
BBox t = BBox::init();
if (!qt.isNaN())
t.expand(pt), t.expand(qt);
if (lson && lson->hasBox())
t.expand(lson->box);
if (rson && rson->hasBox())
t.expand(rson->box);
box = t;
}
double interpolate(double y) {
if (pt.d[1] == qt.d[1]) return pt.d[0];
return pt.d[0] + (qt.d[0] - pt.d[0]) * (y - pt.d[1]) / (qt.d[1] - pt.d[1]);
}
};
Node *root;
vector<PtP> A;
int64_t area;
Node _mem[262144];
int gc[262144], gci, memi;
Node* newNode() {
Node *u = gci >= 0 ? &_mem[gc[gci--]] : &_mem[memi++];
u->init();
return u;
}
void freeNode(Node *u) {
gc[++gci] = u-_mem;
}
void init() {
root = NULL, area = 0;
gci = -1, memi = 0;
}
void insert(tuple<int, int, int> e) {
int p, q, r; tie(p, q, r) = e;
insert(root, 0, Pt(mesh.pt[q]), Pt(mesh.pt[p]), log2int(size()) / LOG_ALPHA);
changeNode(root, 0, Pt(mesh.pt[r]), Pt(mesh.pt[q]));
area += h(p, q) + h(q, r) - h(p, r);
}
void remove(tuple<int, int, int> e) {
int p, q, r; tie(p, q, r) = e;
remove(root, 0, Pt(mesh.pt[q]), log2int(size()) / LOG_ALPHA);
changeNode(root, 0, Pt(mesh.pt[r]), Pt(mesh.pt[p]));
area -= h(p, q) + h(q, r) - h(p, r);
}
int RAY_T;
double RAY_X;
vector<double> X;
int inside(double x, double y) {
if (area == 0) return 0;
X.clear(), RAY_X = 1e+12, RAY_T = -1;
raycast(root, x, y);
if (RAY_T < 0)
return X.size()&1;
int pass = (area > 0) == RAY_T;
for (auto &x : X)
pass += x <= RAY_X;
return pass&1;
}
int size() { return root == NULL ? 0 : root->size; }
inline int isbad(Node *u, Node *v) {
if (u == root) return 1;
int l = v ? v->size : 0;
l = max(l, u->size-u->used-l);
return l > u->size * ALPHA;
}
Node* build(int k, int l, int r) {
if (l > r) return NULL;
int mid = (l + r)>>1;
Node *ret = newNode();
sort(A.begin()+l, A.begin()+r+1, cmpAxis(k));
while (mid > l && A[mid].p.d[k] == A[mid-1].p.d[k])
mid--;
tie(ret->pt, ret->qt) = tie(A[mid].p, A[mid].q);
ret->lson = build(!k, l, mid-1);
ret->rson = build(!k, mid+1, r);
ret->pushup();
return ret;
}
void flatten(Node *u) {
if (!u) return ;
flatten(u->lson);
flatten(u->rson);
if (u->used) A.emplace_back(u->pt, u->qt);
freeNode(u);
}
void changeNode(Node *u, int k, Pt x, Pt qt) {
if (!u) return;
if (x == u->pt) {
u->qt = qt, u->pushupBox();
return;
}
changeNode(x.d[k] < u->pt.d[k] ? u->lson : u->rson, !k, x, qt);
u->pushupBox();
}
void rebuild(Node* &u, int k) {
A.clear(), A.reserve(u->size);
flatten(u);
u = build(k, 0, A.size()-1);
}
bool insert(Node* &u, int k, Pt x, Pt y, int d) {
if (!u) {
u = newNode(), u->pt = x, u->qt = y, u->pushup();
return d <= 0;
}
if (x == u->pt) {
u->used = 1, u->qt = y, u->pushup();
return d <= 0;
}
auto &v = x.d[k] < u->pt.d[k] ? u->lson : u->rson;
int t = insert(v, !k, x, y, d-1);
u->pushup();
if (t && !isbad(u, v))
return 1;
if (t) rebuild(u, k);
return 0;
}
bool remove(Node* &u, int k, Pt x, int d) {
if (!u)
return d <= 0;
if (x == u->pt) {
if (u->lson || u->rson)
u->used = 0, u->qt = Pt::NaN(), u->pushup();
else
freeNode(u), u = NULL;
return d <= 0;
}
auto &v = x.d[k] < u->pt.d[k] ? u->lson : u->rson;
int t = remove(v, !k, x, d-1);
u->pushup();
if (t && !isbad(u, v))
return 1;
if (t) rebuild(u, k);
return 0;
}
inline int cast(Node *u, double x, double y) {
if (u->qt.isNaN() || (u->pt.d[1] > y) == (u->qt.d[1] > y))
return 0;
double tx = u->interpolate(y);
if (tx <= x || tx > RAY_X)
return 0;
RAY_X = tx, RAY_T = u->pt.d[1] < u->qt.d[1];
X.emplace_back(tx);
return 1;
}
Node* stk[128];
void raycast(Node *u, double x, double y) {
#define pushstk(u) {*p++ = u;}
Node **p = stk;
pushstk(u);
while (p > stk) {
u = *--p;
if (!u || !u->size || !u->box.raycast(x, RAY_X, y))
continue;
cast(u, x, y);
pushstk(u->rson);
pushstk(u->lson);
}
}
} mbrh;
int main() {
int n, m, cmd, x, pos;
double px, py;
scanf("%d %d", &n, &m);
mesh.read(n);
mlist.init(), mbrh.init();
for (int i = 0; i < m; i++) {
scanf("%d", &cmd);
if (cmd == 1) {
scanf("%d %d", &x, &pos);
mbrh.insert(mlist.insert(x, pos));
} else if (cmd == 2) {
scanf("%d", &x);
mbrh.remove(mlist.remove(x));
} else {
scanf("%lf %lf", &px, &py);
puts(mbrh.inside(px, py) ? "1" : "0");
}
}
return 0;
}
Read More +

動態樹 樹形避難所

題目描述請至 Zerojudge e003: 樹形避難所 I、e004: 樹形避難所 II 查閱詳細內容

樹形避難所 I

在一個樹形避難所中有 $N$ 個房間,待在充滿監視器房間的你,透過監視器的顯示發現存在一些未知的入侵者出現在某些房間。為了保護同伴,你可以選擇開啟或關閉房間之間的通道,而你也會收到來自於某個房間的同伴求救訊號,此時給予所有可能遇見的入侵者數量,以便同伴做好萬全的作戰準備。然而,操縱通道的控制器已不受限制,你只能眼睜睜地看著同伴與入侵者對抗,現在的你 … 做好準備了嗎?

  • 操作 1 $u$ $v$:將房間 $u$$v$ 的通道開啟
  • 操作 2 $u$ $v$:將房間 $u$$v$ 的通道關閉
  • 操作 3 $u$ $w$:更正房間 $u$$w$ 個入侵者
  • 操作 4 $u$:回答來自 $u$ 的求救信號,告知與其可能面臨到的入侵者個數

樹形避難所 II

由於上一個樹形避難所已經不再安全,全員轉移到下一個避難所,新的地方將不再是先前的平面構造,新的避難所建構在地下水層中,每一個房間可以在水中移動,並且打通到上一層的某一個房間。不幸地,新的入侵者更加地難纏,想保護大家的你,想藉由破壞某一個房間,將其相連的下層房間的入侵者一同殲滅,情局不斷地變化,哪一個才是最好的破壞手段呢 …

  • 操作 1 $u$ $v$:將房間 $u$ 與上層房間 $v$ 的通道開啟
  • 操作 2 $u$:關閉房間 $u$ 與上層的通道
  • 操作 3 $u$ $w$:更正房間 $u$$w$ 個入侵者
  • 操作 4 $u$:估算摧毀房間 $u$,可以殲滅的入侵者個數

分析

這一題對於樹的操作,牽涉到修改邊,修改點權,詢問整個連通的樹大小、以及子樹大小。可以考慮使用 Link/Cut Tree 完成。也有高手使用了離線算法,對操作分治後計算答案,將一個邊的存在時間記錄在某個時間戳記上,整體落在 $\mathcal{O}(M \log M)$,由上而下合併所有存在的邊來完成單一詢問,需要搭配啟發式合併,來指查找根造成的退化,這一點相當地聰明。若強制在線,仍需要使用動態樹完成之。

對於第一題,只有包含前三個操作,可以當作一個無根樹操作,因此在 LCT 中的 $\text{MakeRoot}$ 打上反轉標記,無視原本的父子關係,額外維護一個虛邊上的子樹大小,如此一來就可以計算整個子樹大小。其餘的點權修改就相較於容易許多。對於第二題,便要求有根樹,因此在轉到根節點時,不能打上反轉標記,此時的左子樹為父節點,扣除掉父節點的大小後,便可以得到子樹大小。

  • 動態樹解法:空間複雜度 $\mathcal{O}(N)$,操作複雜度 $\text{Amortized} \; \mathcal{O}(\log N)$
  • 離線解法:空間複雜度 $\mathcal{O}(M \log M)$,整體時間複雜度 $\mathcal{O}(M \log M)$

延伸問題

如果子樹存在等價關係,意即他們會被一個操作同時受到影響,那麼更新會退化嗎?

如果這個問題能被解決,那麼使用相同的概念,便能解決更多的高壓縮架構問題。工作就能輕鬆一些了。

參考解答

樹形避難所 I

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
#include <bits/stdc++.h>
using namespace std;
class LCT { // Link-Cut Tree
public:
static const int MAXN = 30005;
struct Node;
static Node *EMPTY;
static Node _mem[MAXN];
static int bufIdx;
struct Node {
Node *ch[2], *fa;
int rev;
int vsize, size, val;
void init() {
ch[0] = ch[1] = fa = NULL;
size = 0, vsize = 0, rev = 0;
}
bool is_root() {
return fa->ch[0] != this && fa->ch[1] != this;
}
void pushdown() {
if (rev) {
ch[0]->rev ^= 1;
ch[1]->rev ^= 1;
swap(ch[0], ch[1]);
rev = 0;
}
}
void pushup() {
if (this == EMPTY)
return;
size = ch[0]->size + ch[1]->size + val + vsize;
}
};
LCT() {
EMPTY = &_mem[0];
EMPTY->fa = EMPTY->ch[0] = EMPTY->ch[1] = EMPTY;
EMPTY->size = 0;
bufIdx = 1;
}
void init() {
bufIdx = 1;
}
Node* newNode() {
Node *u = &_mem[bufIdx++];
u->init();
u->fa = u->ch[0] = u->ch[1] = EMPTY;
return u;
}
void rotate(Node *x) {
Node *y;
int d;
y = x->fa, d = y->ch[1] == x ? 1 : 0;
x->ch[d^1]->fa = y, y->ch[d] = x->ch[d^1];
x->ch[d^1] = y;
if (!y->is_root())
y->fa->ch[y->fa->ch[1] == y] = x;
x->fa = y->fa, y->fa = x;
y->pushup(), x->pushup();
}
void deal(Node *x) {
if (!x->is_root()) deal(x->fa);
x->pushdown();
}
void splay(Node *x) {
Node *y, *z;
deal(x);
while (!x->is_root()) {
y = x->fa, z = y->fa;
if (!y->is_root()) {
if (y->ch[0] == x ^ z->ch[0] == y)
rotate(x);
else
rotate(y);
}
rotate(x);
}
x->pushdown();
}
Node* access(Node *u) {
Node *v = EMPTY;
for (; u != EMPTY; u = u->fa) {
splay(u);
u->vsize += u->ch[1] != EMPTY ? u->ch[1]->size : 0;
u->vsize -= v != EMPTY ? v->size : 0;
u->ch[1] = v;
u->pushup();
v = u;
}
return v;
}
void mk_root(Node *u) {
access(u)->rev ^= 1, splay(u);
}
void cut(Node *x, Node *y) {
mk_root(x);
access(y), splay(y);
// debug(10);
assert(y->ch[0] == x);
y->ch[0] = x->fa = EMPTY;
y->pushup();
}
void link(Node *x, Node *y) {
mk_root(y);
access(x), splay(x);
y->fa = x;
x->vsize += y->size;
x->pushup();
}
Node* find(Node *x) {
access(x), splay(x);
for (; x->ch[0] != EMPTY; x = x->ch[0]);
return x;
}
void set(Node *x, int val) {
mk_root(x);
x->val = val;
x->pushup();
}
int get(Node *u) {
mk_root(u);
return u->size;
}
int same(Node *x, Node *y) {
return find(x) == find(y);
}
void debug(int n) {
return;
puts("==================");
for (int i = 1; i <= n; i++) {
Node *u = &_mem[i];
printf("[%d] %d, %d %d, %d %d %d\n", i, u->fa-_mem, u->ch[0]-_mem, u->ch[1]-_mem, u->size, u->vsize, u->val);
}
}
} lct;
LCT::Node *LCT::EMPTY, LCT::_mem[LCT::MAXN];
int LCT::bufIdx;
LCT::Node *node[LCT::MAXN];
int main() {
int n, m;
while (scanf("%d %d", &n, &m) == 2) {
lct.init();
int cmd, u, v, w;
for (int i = 1; i <= n; i++) {
scanf("%d", &w);
node[i] = lct.newNode();
lct.set(node[i], w);
}
for (int i = 0; i < m; i++) {
scanf("%d", &cmd);
if (cmd == 1) {
scanf("%d %d", &u, &v);
lct.link(node[u], node[v]);
} else if (cmd == 2) {
scanf("%d %d", &u, &v);
lct.cut(node[u], node[v]);
} else if (cmd == 3) {
scanf("%d %d", &u, &w);
lct.set(node[u], w);
} else if (cmd == 4) {
scanf("%d", &u);
int p = lct.get(node[u]);
printf("%d\n", p);
} else {
scanf("%d %d", &u, &v);
int f = lct.same(node[u], node[v]);
printf("%d\n", f);
}
lct.debug(10);
}
}
return 0;
}

樹形避難所 II

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
#include <bits/stdc++.h>
using namespace std;
class LCT { // Link-Cut Tree
public:
static const int MAXN = 30005;
struct Node;
static Node *EMPTY;
static Node _mem[MAXN];
static int bufIdx;
struct Node {
Node *ch[2], *fa;
int vsize, size, val;
void init() {
ch[0] = ch[1] = fa = NULL;
size = 0, vsize = 0;
}
bool is_root() {
return fa->ch[0] != this && fa->ch[1] != this;
}
void pushdown() {
}
void pushup() {
if (this == EMPTY)
return;
size = ch[0]->size + ch[1]->size + val + vsize;
}
};
LCT() {
EMPTY = &_mem[0];
EMPTY->fa = EMPTY->ch[0] = EMPTY->ch[1] = EMPTY;
EMPTY->size = 0;
bufIdx = 1;
}
void init() {
bufIdx = 1;
}
Node* newNode() {
Node *u = &_mem[bufIdx++];
u->init();
u->fa = u->ch[0] = u->ch[1] = EMPTY;
return u;
}
void rotate(Node *x) {
Node *y;
int d;
y = x->fa, d = y->ch[1] == x ? 1 : 0;
x->ch[d^1]->fa = y, y->ch[d] = x->ch[d^1];
x->ch[d^1] = y;
if (!y->is_root())
y->fa->ch[y->fa->ch[1] == y] = x;
x->fa = y->fa, y->fa = x;
y->pushup(), x->pushup();
}
void deal(Node *x) {
if (!x->is_root()) deal(x->fa);
x->pushdown();
}
void splay(Node *x) {
Node *y, *z;
deal(x);
while (!x->is_root()) {
y = x->fa, z = y->fa;
if (!y->is_root()) {
if (y->ch[0] == x ^ z->ch[0] == y)
rotate(x);
else
rotate(y);
}
rotate(x);
}
x->pushdown();
}
Node* access(Node *u) {
Node *v = EMPTY;
for (; u != EMPTY; u = u->fa) {
splay(u);
u->vsize += u->ch[1] != EMPTY ? u->ch[1]->size : 0;
u->vsize -= v != EMPTY ? v->size : 0;
u->ch[1] = v;
u->pushup();
v = u;
}
return v;
}
void mk_root(Node *u) {
access(u), splay(u);
}
void cut(Node *x) {
access(x), splay(x);
x->ch[0]->fa = EMPTY;
x->ch[0] = EMPTY;
x->pushup();
}
void link(Node *x, Node *y) {
access(x), splay(x);
access(y), splay(y);
x->fa = y;
y->vsize += x->size;
y->pushup();
}
Node* find(Node *x) {
access(x), splay(x);
for (; x->ch[0] != EMPTY; x = x->ch[0]);
return x;
}
void set(Node *x, int val) {
mk_root(x);
x->val = val;
x->pushup();
}
int get(Node *u) {
mk_root(u);
int ret = u->size;
if (u->ch[0] != EMPTY)
ret -= u->ch[0]->size;
return ret;
}
int same(Node *x, Node *y) {
return find(x) == find(y);
}
void debug(int n) {
return;
puts("==================");
for (int i = 1; i <= n; i++) {
Node *u = &_mem[i];
printf("[%d] %d, %d %d, %d %d %d\n", i, u->fa-_mem, u->ch[0]-_mem, u->ch[1]-_mem, u->size, u->vsize, u->val);
}
}
} lct;
LCT::Node *LCT::EMPTY, LCT::_mem[LCT::MAXN];
int LCT::bufIdx;
LCT::Node *node[LCT::MAXN];
int main() {
int n, m;
while (scanf("%d %d", &n, &m) == 2) {
lct.init();
int cmd, u, v, w;
for (int i = 1; i <= n; i++) {
scanf("%d", &w);
node[i] = lct.newNode();
lct.set(node[i], w);
}
for (int i = 0; i < m; i++) {
scanf("%d", &cmd);
if (cmd == 1) {
scanf("%d %d", &u, &v);
lct.link(node[u], node[v]);
} else if (cmd == 2) {
scanf("%d", &u);
lct.cut(node[u]);
} else if (cmd == 3) {
scanf("%d %d", &u, &w);
lct.set(node[u], w);
} else if (cmd == 4) {
scanf("%d", &u);
int p = lct.get(node[u]);
printf("%d\n", p);
} else {
scanf("%d %d", &u, &v);
int f = lct.same(node[u], node[v]);
printf("%d\n", f);
}
lct.debug(4);
}
}
return 0;
}
Read More +

UVa 12254 - Electricity Connection

Problem

給一個 $8 \times 8$ 平面圖,上面至多有八個住家,我們目標要從發電廠出發,牽電到所有的住家,拉線跨過水路的花費 $pw$、一般陸路為 $pl$,求最少花費。

經典的斯坦納樹問題,但有別於一般的平面圖,使用歐基里德距離或者曼哈頓距離作為花費函數。接著讓我們細談如何進行常數優化。

Sample Input

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
2
0 10
H.W.WH..
..W.W...
..WGW...
........
........
........
........
........
0 0
H.W.WH..
..W.W...
..WGW...
........
........
........
........
........

Sample Output

1
2
Case 1: 12
Case 2: 7

Solution

斯坦納樹為 NP-hard 問題,因此沒有多項式解法,而要得到確切的最小化解,則必須透過動態規劃來完成。由於題目給定的範圍很小,首先將目標要連通完成的點集,壓縮成 $N$ 個位元,接著紀錄這個聯通元件的其中一個節點視為根。最後,得到狀態數為 $M \cdot 2^N$,可以參考 《「Steiner tree problem in graphs」斯坦納樹》 -日月卦長 的說明。

公式可以拆成兩種情況,第一種為從子集合併中著手,另一種為拓展連通元件 (替換根節點,但不改變目前已經連到的目標集合)。定義 dp[s][i] 為根 i,連通集合 s 的最小花費

  • $dp[S][i] = \min(dp[T][j]+dp[S−T][j]+\text{dist}(i, j):j \in V,T \subset S)$
  • $dp[S][i] = \min(dp[S][k] + \text{dist}(S, k))$

由上述的公式,我們便可知道複雜度為 $O(M \cdot 3^N)$

實作細節

  • 對於內存布局,有兩個選擇 dp[2^N][M] 或者是 dp[M][2^N],其中以 dp[2^N][M] 最為適合,在撰寫迴圈的時候,最內層的迴圈為替換根,這麼一來 cache miss 的機會就非常低。更容易透過 unroll loop 和向量化來運作。

    • 如果內存布局使用 dp[M][2^N],在撰寫向量化時,需要使用 gather 相關的指令,這部分只有在 AVX2 有,並不是每一個 online judge 都支援,而 latency 也算挺高的,等到哪天 CPU 架構換了,這解法可能才會快得起來。
  • 當我們窮舉子集合時,發現到公式有對稱性,便可只窮舉上半部。這樣可以加速 20% 的效能。

1
2
3
4
5
6
7
8
9
10
for (int i = 1, h; i < (1<<n); i++) {
if ((i&(-i)) == i)
h = i-1;
for (int k = (i-1)&i; k > h; k = (k-1)&i) {
for (int j = 0; j < 64; j++)
dp[i][j] = min(dp[i][j], dp[k][j]+dp[i^k][j]-w[j]);
}
relax(i);
}

基礎篇

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
#pragma GCC target("avx")
#pragma GCC optimize ("O3")
#include <bits/stdc++.h>
using namespace std;
char g[8][16], w[64];
static const int dx[] = {0, 0, 1, -1};
static const int dy[] = {1, -1, 0, 0};
static int32_t dp[1<<8][64];
const int INF = 0x3f3f3f3f;
void relax(int s) {
static int8_t Q[1024];
uint64_t inq = 0;
int Qn = 0;
for (int i = 0; i < 64; i++) {
if (dp[s][i] != INF)
Q[Qn++] = i, inq |= 1ULL<<i;
}
for (int i = 0; i < Qn; i++) {
int u = Q[i];
int x = u>>3, y = u&7;
inq ^= 1ULL<<u;
for (int k = 0; k < 4; k++) {
int tx = x+dx[k], ty = y+dy[k];
if (tx < 0 || ty < 0 || tx >= 8 || ty >= 8)
continue;
int v = tx<<3|ty;
if (dp[s][v] > dp[s][u]+1+w[v]) {
dp[s][v] = dp[s][u]+1+w[v];
if (((inq>>v)&1) == 0) {
inq |= 1ULL<<v;
Q[Qn++] = v;
}
}
}
}
}
int main() {
int testcase, cases = 0;
int pl, pw;
scanf("%d", &testcase);
while (testcase--) {
scanf("%d %d", &pl, &pw);
for (int i = 0; i < 8; i++)
scanf("%s", &g[i]);
int n = 0, root = 0;
int A[8];
for (int i = 0; i < 8; i++) {
for (int j = 0; j < 8; j++) {
int u = i<<3|j;
if (g[i][j] == 'H')
A[n++] = u, w[u] = 0;
else if (g[i][j] == 'G')
w[u] = 0, root = u;
else if (g[i][j] == 'W')
w[u] = pw;
else if (g[i][j] == '.')
w[u] = pl;
}
}
memset(dp, 0x3f, sizeof(dp));
for (int i = 0; i < n; i++)
dp[1<<i][A[i]] = 0;
for (int i = 1, h; i < (1<<n); i++) {
if ((i&(-i)) == i)
h = i-1;
for (int k = (i-1)&i; k > h; k = (k-1)&i) {
for (int j = 0; j < 64; j++)
dp[i][j] = min(dp[i][j], dp[k][j]+dp[i^k][j]-w[j]);
}
relax(i);
}
int ret = dp[(1<<n)-1][root];
printf("Case %d: %d\n", ++cases, ret);
}
return 0;
}

SSE 向量化

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
#pragma GCC target("avx")
#pragma GCC optimize ("O3")
#include <bits/stdc++.h>
#include <x86intrin.h>
using namespace std;
char g[8][16];
static const int dx[] = {0, 0, 1, -1};
static const int dy[] = {1, -1, 0, 0};
static int32_t w[64] __attribute__ ((aligned(16)));
static int32_t dp[1<<8][64] __attribute__ ((aligned(16)));
const int INF = 0x3f3f3f3f;
void relax(int s) {
static int8_t Q[1024];
uint64_t inq = 0;
int Qn = 0;
for (int i = 0; i < 64; i++) {
if (dp[s][i] != INF)
Q[Qn++] = i, inq |= 1ULL<<i;
}
for (int i = 0; i < Qn; i++) {
int u = Q[i];
int x = u>>3, y = u&7;
inq ^= 1ULL<<u;
for (int k = 0; k < 4; k++) {
int tx = x+dx[k], ty = y+dy[k];
if (tx < 0 || ty < 0 || tx >= 8 || ty >= 8)
continue;
int v = tx<<3|ty;
if (dp[s][v] > dp[s][u]+1+w[v]) {
dp[s][v] = dp[s][u]+1+w[v];
if (((inq>>v)&1) == 0) {
inq |= 1ULL<<v;
Q[Qn++] = v;
}
}
}
}
}
int main() {
int testcase, cases = 0;
int pl, pw;
scanf("%d", &testcase);
while (testcase--) {
scanf("%d %d", &pl, &pw);
for (int i = 0; i < 8; i++)
scanf("%s", &g[i]);
int n = 0, root = 0;
int A[8];
for (int i = 0; i < 8; i++) {
for (int j = 0; j < 8; j++) {
int u = i<<3|j;
if (g[i][j] == 'H')
A[n++] = u, w[u] = 0;
else if (g[i][j] == 'G')
w[u] = 0, root = u;
else if (g[i][j] == 'W')
w[u] = pw;
else if (g[i][j] == '.')
w[u] = pl;
}
}
memset(dp, 0x3f, sizeof(dp));
for (int i = 0; i < n; i++)
dp[1<<i][A[i]] = 0;
for (int i = 1, h; i < (1<<n); i++) {
if ((i&(-i)) == i)
h = i-1;
for (int k = (i-1)&i; k > h; k = (k-1)&i) {
for (int j = 0; j+4 <= 64; j += 4) {
__m128i mv = _mm_load_si128((__m128i*) (dp[i]+j));
__m128i a = _mm_load_si128((__m128i*) (dp[k]+j));
__m128i b = _mm_load_si128((__m128i*) (dp[i^k]+j));
__m128i tm = _mm_add_epi32(a, b);
__m128i c = _mm_load_si128((__m128i*) (w+j));
__m128i tn = _mm_sub_epi32(tm, c);
__m128i mn = _mm_min_epi32(mv, tn);
_mm_store_si128((__m128i*) (dp[i]+j), mn);
}
// dp[i][j] = min(dp[i][j], dp[k][j]+dp[i^k][j]-w[j]);
}
relax(i);
}
int ret = dp[(1<<n)-1][root];
printf("Case %d: %d\n", ++cases, ret);
}
return 0;
}

AVX2 Gather

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
/*
It doesn't work at UVA 2018/08/13 because server CPU does not support
AVX2 instruction set. Although we could pass the compiler, you still
get runtime error during executing an illegal instruction.
*/
#pragma GCC target("avx")
#pragma GCC target("avx2")
#pragma GCC optimize ("O3")
#include <bits/stdc++.h>
#include <x86intrin.h>
#include <avx2intrin.h>
using namespace std;
char g[8][16], w[64];
static const int dx[] = {0, 0, 1, -1};
static const int dy[] = {1, -1, 0, 0};
static int32_t dp[64][1<<8] __attribute__ ((aligned(16)));
const int INF = 0x3f3f3f3f;
void relax(int s) {
static int8_t Q[1024];
uint64_t inq = 0;
int Qn = 0;
for (int i = 0; i < 64; i++) {
if (dp[i][s] != INF)
Q[Qn++] = i, inq |= 1ULL<<i;
}
for (int i = 0; i < Qn; i++) {
int u = Q[i];
int x = u>>3, y = u&7;
inq ^= 1ULL<<u;
for (int k = 0; k < 4; k++) {
int tx = x+dx[k], ty = y+dy[k];
if (tx < 0 || ty < 0 || tx >= 8 || ty >= 8)
continue;
int v = tx<<3|ty;
if (dp[v][s] > dp[u][s]+1+w[v]) {
dp[v][s] = dp[u][s]+1+w[v];
if (((inq>>v)&1) == 0) {
inq |= 1ULL<<v;
Q[Qn++] = v;
}
}
}
}
}
int main() {
int testcase, cases = 0;
int pl, pw;
scanf("%d", &testcase);
while (testcase--) {
scanf("%d %d", &pl, &pw);
for (int i = 0; i < 8; i++)
scanf("%s", &g[i]);
int n = 0, root = 0;
int A[8];
for (int i = 0; i < 8; i++) {
for (int j = 0; j < 8; j++) {
int u = i<<3|j;
if (g[i][j] == 'H')
A[n++] = u, w[u] = 0;
else if (g[i][j] == 'G')
w[u] = 0, root = u;
else if (g[i][j] == 'W')
w[u] = pw;
else if (g[i][j] == '.')
w[u] = pl;
}
}
memset(dp, 0x3f, sizeof(dp));
for (int i = 0; i < n; i++)
dp[A[i]][1<<i] = 0;
for (int i = 1, h; i < (1<<n); i++) {
if ((i&(-i)) == i)
h = i-1;
__attribute__ ((aligned(16))) static int subset1[1<<8] = {};
__attribute__ ((aligned(16))) static int subset2[1<<8] = {};
int sn = 0;
for (int k = (i-1)&i; k > h; k = (k-1)&i)
subset1[sn] = k, subset2[sn] = i^k, sn++;
while (sn&4)
subset1[sn] = 0, subset2[sn] = i, sn++;
for (int j = 0; j < 64; j++) {
int32_t mn = dp[j][i]+w[j];
__m128i mv = _mm_setr_epi32(mn, mn, mn, mn);
for (int t = 0; t <= sn; t += 4) {
int k;
__m128i a = _mm_load_si128((__m128i*) subset1+t);
__m128i b = _mm_load_si128((__m128i*) subset2+t);
__m128i t1 = _mm_i32gather_epi32(dp[j], a, 4);
__m128i t2 = _mm_i32gather_epi32(dp[j], b, 4);
__m128i tm = _mm_add_epi32(t1, t2);
mv = _mm_min_epi32(mv, tm);
}
__attribute__ ((aligned(16))) int32_t mr[4];
_mm_store_si128((__m128i*) mr, mv);
mn = min(min(mr[0], mr[1]), min(mr[2], mr[3]));
dp[j][i] = mn-w[j];
// mn = min(mn, dp[j][k]+dp[j][i^k]-ww);
}
relax(i);
}
int ret = dp[root][(1<<n)-1];
printf("Case %d: %d\n", ++cases, ret);
}
return 0;
}
Read More +

優化技巧 - 排序/優先隊列

主要問題

在算法問題中,時常結合到排序,而排序的常數也大大影響到我們整體的效能。

基數排序降低常數

若是一般原始型別的排序,可以透過基數排序 (radix sort)。從排序的範圍來決定是否要劃分 8-bit 一組一組。若範圍介於可容忍的 $[0, v]$,當 $v < n$ 時,直接開 $n$ 個 bucket 的方式更好。因為大多數的複雜度都大於線性 $O(n)$

非負整數基數排序

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
void _radix_sort(Pt *A, int n) {
static Pt _tmp[MAXN];
const int CHUNK = 256;
static int C[CHUNK];
Pt *B = _tmp, *T;
for (int x = 0; x < 8; x++) {
const int d = x*8;
memset(C, 0, sizeof(C));
for (int i = 0; i < n; i++)
C[(A[i].x>>d)&(CHUNK-1)]++;
for (int i = 1; i < CHUNK; i++)
C[i] += C[i-1];
for (int i = n-1; i >= 0; i--)
B[--C[(A[i].x>>d)&(CHUNK-1)]] = A[i];
T = A, A = B, B = T;
}
}

浮點數基數排序

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
void radix_sort(Pt *A, int n) {
for (int i = 0; i < n; i++) {
int32_t &v = *((int32_t *) &(A[i].s));
if ((v>>31)&1)
v = ~v;
else
v = v | 0x80000000;
}
static Pt _tmp[MAXN*2];
static const int CHUNK = 256;
static int C[1<<8];
Pt *B = _tmp, *T;
for (int x = 0; x < 4; x++) {
const int d = x*8;
memset(C, 0, sizeof(C));
for (int i = 0; i < n; i++)
C[((*((int32_t *) &(A[i].s)))>>d)&(CHUNK-1)]++;
for (int i = 1; i < CHUNK; i++)
C[i] += C[i-1];
for (int i = n-1; i >= 0; i--)
B[--C[((*((int32_t *) &(A[i].s)))>>d)&(CHUNK-1)]] = A[i];
T = A, A = B, B = T;
}
for (int i = 0; i < n; i++) {
int32_t &v = *((int32_t *) &(A[i].s));
if ((v>>31)&1)
v = v & 0x7fffffff;
else
v = ~v;
}
}

慎選內建排序

內建排序常見的有 qsort, sort, stable_sort,我不推薦使用 qsort,因為它在很久以前找到了一個退化情況 A Killer Adversary for Quicksort - 1995,有些題目的測資會出這種特別案例,導致當年的萌新內心受創。除非只有 C 的情況,否則請避開 qsort。如果算法屬於調整某些元素後,再對整體進行排序,這時候 sortstable_sort 慢很多。當情況非常接近已經排序的時候,就使用 stable_sort

優先隊列

使用 priority queue 的方法通常有三種 set, priority_queue, heap

  • 功能最多的 set、次少的 priority_queue,最少的 heap
  • 效能方面則是 heap 最快、次著 priority_queue、最後 set
  • 代碼維護最容易的 set、最差的 heap

之前很排斥使用 priority_queue,原因在於撰寫 operator< 與我的邏輯相反,因此都偏愛使用 set,但是效能被拉出來的時候,又必須退回類似 C 的 make_heappop_heap … 的接口。

Read More +

UVa 12310 - Point Location

Problem

給定一平面上 $n$ 個點,接著拉 $m$ 個線段拼湊成不相交的區域,並將標記數個點位置所在的區域。接著有數個詢問「點在哪一個先前標記的同一區域內」。

這一問題常在幾何處理中出現,詳細可查閱 Wikipedia - Point location,在此不額外說明。

Sample Input

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
14 16 5 5
0 0
30 0
40 0
60 0
60 50
0 50
20 40
10 10
30 10
30 20
40 20
50 30
50 40
30 40
1 2
2 9
9 8
8 7
7 9
9 10
10 11
11 3
3 4
4 5
5 6
6 1
12 13
13 14
14 12
2 3
20 20
10 20
35 10
45 39
1 60
28 11
29 14
34 7
40 38
70 1
0 0 0 0

Sample Output

1
2
3
4
5
1
2
3
4
5

Solution

Previously

在大學修計算幾何的時候,知道這一類的題目要透過梯形剖分的方式,搭配的資料結構類似四分樹,分隔的基準是一個線段,將線段兩個端點分別拉出無限延長的垂直線,但這方法對於線段本身是垂直時,處理情況比較複雜,那時構思了好幾天沒個底,只好放棄。現在,在 Cadence 公司上班快滿一年,解決幾何操作不在少數,成長的我應該能解決這一題了吧。

Detail

題目已經保證給予的線段除了端點外,彼此之間不會相交任何一點。便可以對所有線段建立空間分割相關的資料結構,好讓我們對於解決,從詢問點出發往 $y$ 無窮遠的射線,最後會打到哪一個線段。

處理的流程如下:

Step 1. 輸入線段

Step 2. 套上邊界,需要能包含所有詢問點

Step 3. 對每一個相連的線段集合,找出 $y$ 軸最高的點,放出輔助射線

Step 4. 對每一個頂點進行極角排序,進行繞行分組

首先,對於幾何剖分的問題,為處理方便,都會套上一個無限大的外框,把所有頂點包在裡面。你可以透過 GeoGebra 這套數學軟體,將處理過程中的資料丟進去,方便你除錯。如果要批次輸入一大筆,請透過上方的建立按鈕 (Button),將 command 的語法以行為單位放入,語法類似 javascript。

接著,我們需要對每一個線段集合的最高點放出輔助射線 (這部分已經縮減了不少,原則上對每一個頂點放出射線也可以),這些輔助射線用來解決內部的孤立形,要把整個圖串成一個連通圖。否則,當詢問屬於內部的孤立形狀的外部時,我們會缺少足夠的資訊連到包含這個孤立形的外框。

最後,對每一個頂點相連的邊進行極角排序,接著才能決定相鄰的邊,而任兩個相鄰的邊所夾的區域屬於同一個集合,因此我們需要對每一個有方向的邊進行編號,將邊與邊合併到同一個集合。對於每一個詢問,只需要對詢問點放出射線找到接觸的線段,便可以知道所在的集合。

  • 若使用 KD tree,即使交替選擇 $x$-$y$ 軸分割,也沒辦法保證樹高。因為處理的是線段,而不是點,每一次挑選的分隔軸,將會產生三個子樹,中間的子樹為與分隔軸相交的線段。在大部分的情況下,整體時間複雜度落在 $O(n \log n)$,空間複雜度 $O(n)$
  • 若使用線段樹,將在 $x$ 軸上劃分,每一個線段至多被拆成 $O(\log n)$ 個,詢問射線第一個碰觸的線段時,單一詢問的時間複雜度 $O(\log n)$,常數相較於 KD tree 多。整體時間複雜度落在 $O(n \log n)$,空間複雜度 $O(n \log n)$

還有許多 spatial data structure 可以考慮,這裡只挑選兩個出來實驗,上述皆為在線算法。也可以將所有操作離線,這時候將套用掃描線算法,但對於垂直線段如何在算法中維護二元樹,目前遇到一些實作上的問題,這將在未來才能給各位參考。

KD tree

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
#include <bits/stdc++.h>
using namespace std;
const float eps = 1e-6;
const float BOX_MX = 2000000;
struct Pt {
float x, y;
Pt() {}
Pt(float a, float b): x(a), y(b) {}
Pt operator-(const Pt &a) const { return Pt(x - a.x, y - a.y); }
Pt operator+(const Pt &a) const { return Pt(x + a.x, y + a.y); }
Pt operator*(const double a) const { return Pt(x * a, y * a); }
bool operator<(const Pt &a) const {
if (fabs(x - a.x) > eps) return x < a.x;
if (fabs(y - a.y) > eps) return y < a.y;
return false;
}
bool operator==(const Pt &a) const {
return fabs(x - a.x) < eps && fabs(y - a.y) < eps;
}
void println() const {
printf("(%lf, %lf)\n", x, y);
}
};
struct Seg {
Pt s, e;
int i;
Seg() {}
Seg(Pt a, Pt b, int i):s(a), e(b), i(i) {}
void println() {
printf("Segment((%lf, %lf), (%lf, %lf))\n", s.x, s.y, e.x, e.y);
}
};
double dot(Pt a, Pt b) {
return a.x * b.x + a.y * b.y;
}
double cross(Pt o, Pt a, Pt b) {
return (a.x-o.x)*(b.y-o.y)-(a.y-o.y)*(b.x-o.x);
}
double cross2(Pt a, Pt b) {
return a.x * b.y - a.y * b.x;
}
int between(Pt a, Pt b, Pt c) {
return dot(c - a, b - a) >= -eps && dot(c - b, a - b) >= -eps;
}
int onSeg(Pt a, Pt b, Pt c) {
return between(a, b, c) && fabs(cross(a, b, c)) < eps;
}
int cmpZero(float v) {
if (fabs(v) > eps) return v > 0 ? 1 : -1;
return 0;
}
Pt getIntersect(Seg a, Seg b) {
Pt u = a.s - b.s;
double t = cross2(b.e - b.s, u)/cross2(a.e - a.s, b.e - b.s);
return a.s + (a.e - a.s) * t;
}
struct AngleCmp {
Pt o;
AngleCmp(Pt o = Pt()):o(o) {}
bool operator() (const pair<Pt, int>& ppa, const pair<Pt, int>& ppb) {
Pt pa = ppa.first, pb = ppb.first;
Pt p1 = pa - o, p2 = pb - o;
if (p1.y == 0 && p2.y == 0 && p1.x * p2.x <= 0) return p1.x > p2.x;
if (p1.y == 0 && p1.x >= 0 && p2.y != 0) return true;
if (p2.y == 0 && p2.x >= 0 && p1.y != 0) return false;
if (p1.y * p2.y < 0) return p1.y > p2.y;
double c = cross2(p1, p2);
return c > 0 || (c == 0 && fabs(p1.x) < fabs(p2.x));
}
};
static Pt pts[10005];
static Seg segs[30005];
struct BSP {
static const int MAXN = 60005;
static const int MAXNODE = 60005;
struct Node {
float lx, ly, rx, ry;
Node *ls, *ms, *rs;
void extend(Node *u) {
if (u == NULL)
return ;
lx = min(lx, u->lx), ly = min(ly, u->ly);
rx = max(rx, u->rx), ry = max(ry, u->ry);
}
} nodes[MAXNODE];
int sn[MAXNODE];
Seg *seg[MAXNODE];
float axis[MAXN];
Node *root;
int size;
Node* newNode() {
Node *p = &nodes[size++];
assert(size < MAXNODE);
p->ls = p->ms = p->rs = NULL;
sn[p-nodes] = 0;
return p;
}
Node* _build(int k, Seg segs[], int n) {
if (n == 0)
return NULL;
if (k == 2)
k = 0;
int Lsize = 0, Msize = 0, Rsize = 0;
Seg *L = NULL, *M = NULL, *R = NULL;
if (k == 0) {
for (int i = 0; i < n; i++)
axis[i<<1] = segs[i].s.x, axis[i<<1|1] = segs[i].e.x;
nth_element(axis, axis+n, axis+2*n);
const float mval = axis[n];
L = segs;
R = std::partition(segs, segs+n, [mval](const Seg &s) {
return max(s.s.x, s.e.x) <= mval;
});
M = std::partition(R, segs+n, [mval](const Seg &s) {
return min(s.s.x, s.e.x) <= mval;
});
Msize = segs+n - M;
Rsize = M - R;
Lsize = R - segs;
} else {
for (int i = 0; i < n; i++)
axis[i<<1] = segs[i].s.y, axis[i<<1|1] = segs[i].e.y;
nth_element(axis, axis+n, axis+2*n);
const float mval = axis[n];
L = segs;
R = std::partition(segs, segs+n, [mval](const Seg &s) {
return max(s.s.y, s.e.y) <= mval;
});
M = std::partition(R, segs+n, [mval](const Seg &s) {
return min(s.s.y, s.e.y) <= mval;
});
Msize = segs+n - M;
Rsize = M - R;
Lsize = R - segs;
}
Node *u = newNode();
u->lx = BOX_MX, u->ly = BOX_MX;
u->rx = -BOX_MX, u->ry = -BOX_MX;
if (Lsize == n || Rsize == n || Msize == n) {
sn[u - nodes] = n, seg[u - nodes] = segs;
for (int i = 0; i < n; i++) {
u->lx = min(u->lx, min(segs[i].s.x, segs[i].e.x));
u->rx = max(u->rx, max(segs[i].s.x, segs[i].e.x));
u->ly = min(u->ly, min(segs[i].s.y, segs[i].e.y));
u->ry = max(u->ry, max(segs[i].s.y, segs[i].e.y));
}
} else {
u->ls = _build(k+1, L, Lsize), u->extend(u->ls);
u->ms = _build(k+1, M, Msize), u->extend(u->ms);
u->rs = _build(k+1, R, Rsize), u->extend(u->rs);
}
return u;
}
void build_tree(Seg s[], int m) {
size = 0;
root = _build(0, s, m);
}
Pt q_st, q_ed;
int q_si;
void rayhit(Seg &seg) {
if (seg.s.x == seg.e.x) {
if (cmpZero(seg.s.x - q_st.x) == 0) {
double low = min(seg.s.y, seg.e.y);
if (low > q_st.y && low < q_ed.y) {
q_ed.y = low;
q_si = seg.i;
}
}
return ;
}
if (max(seg.s.x, seg.e.x) < q_st.x || min(seg.s.x, seg.e.x) > q_st.x)
return ;
float y = seg.s.y + (float) (seg.e.y - seg.s.y) * (q_st.x - seg.s.x) / (seg.e.x - seg.s.x);
if (y > q_st.y && y < q_ed.y) {
q_ed.y = y;
q_si = seg.i;
}
}
void search(Node *u) {
if (u == NULL)
return ;
if (u->lx > q_st.x || u->rx < q_st.x || u->ry <= q_st.y || u->ly >= q_ed.y)
return ;
for (int i = 0; i < sn[u - nodes]; i++)
rayhit(seg[u - nodes][i]);
search(u->ls);
search(u->ms);
search(u->rs);
}
pair<int, Pt> raycast(Pt st) {
q_st = st;
q_ed = Pt(st.x, BOX_MX+1);
q_si = -1;
search(root);
return {q_si, q_ed};
}
} tree;
struct Disjoint {
static const int MAXN = 65536;
uint16_t parent[MAXN], weight[MAXN];
void init(int n) {
if (n >= MAXN)
exit(0);
for (int i = 0; i <= n; i++)
parent[i] = i, weight[i] = 1;
}
int findp(int x) {
return parent[x] == x ? x : parent[x] = findp(parent[x]);
}
int joint(int x, int y) {
x = findp(x), y = findp(y);
if (weight[x] >= weight[y])
parent[y] = x, weight[x] += weight[y];
else
parent[x] = y, weight[y] += weight[x];
}
} egroup, sgroup;
int main() {
int n, m, p, q;
while (scanf("%d %d %d %d", &n, &m, &p, &q) == 4 && n) {
for (int i = 0; i < n; i++) {
int x, y;
scanf("%d %d", &x, &y);
pts[i] = Pt(x, y);
}
sgroup.init(n);
for (int i = 0; i < m; i++) {
int st_i, ed_i;
scanf("%d %d", &st_i, &ed_i);
st_i--, ed_i--;
segs[i] = Seg(pts[st_i], pts[ed_i], i);
sgroup.joint(st_i, ed_i);
}
segs[m] = Seg(Pt(BOX_MX, BOX_MX), Pt(BOX_MX, -BOX_MX), m), m++;
segs[m] = Seg(Pt(BOX_MX, -BOX_MX), Pt(-BOX_MX, -BOX_MX), m), m++;
segs[m] = Seg(Pt(-BOX_MX, -BOX_MX), Pt(-BOX_MX, BOX_MX), m), m++;
segs[m] = Seg(Pt(-BOX_MX, BOX_MX), Pt(BOX_MX, BOX_MX), m), m++;
static map<Pt, vector<pair<Pt, uint8_t>>> g; g.clear();
static vector<vector<Pt>> on_seg; on_seg.clear(); on_seg.resize(m);
for (int i = 0; i < m; i++) {
on_seg[segs[i].i].reserve(4);
on_seg[segs[i].i].push_back(segs[i].s);
on_seg[segs[i].i].push_back(segs[i].e);
}
// for (int i = 0; i < n; i++)
// pts[i].println();
// for (int i = 0; i < m; i++)
// segs[i].println();
tree.build_tree(segs, m);
{
Pt top[10005];
for (int i = 0; i < n; i++)
top[i] = pts[i];
for (int i = 0; i < n; i++) {
int gid = sgroup.findp(i);
if (top[gid].y < pts[i].y)
top[gid] = pts[i];
}
for (int i = 0; i < n; i++) {
if (sgroup.findp(i) != i)
continue;
auto p = top[i];
auto hit = tree.raycast(p);
if (hit.first >= 0) {
on_seg[hit.first].emplace_back(hit.second);
g[p].emplace_back(hit.second, 1);
g[hit.second].emplace_back(p, 1);
}
}
}
for (int i = 0; i < m; i++) {
vector<Pt> &a = on_seg[i];
sort(a.begin(), a.end());
a.resize(unique(a.begin(), a.end()) - a.begin());
auto *prev = &g[a[0]];
for (int j = 1; j < a.size(); j++) {
prev->emplace_back(a[j], 0);
prev = &g[a[j]];
prev->emplace_back(a[j-1], 0);
}
}
for (auto &e : g)
sort(e.second.begin(), e.second.end(), AngleCmp(e.first));
static map<Pt, map<Pt, int>> R; R.clear();
int Rsize = 0;
for (auto &e : g) {
int sz = e.second.size();
map<Pt, int> &Rg = R[e.first];
for (auto &f : e.second) {
int &eid = Rg[f.first];
if (eid == 0)
eid = ++Rsize;
}
}
egroup.init(Rsize);
for (auto &e : g) {
int sz = e.second.size();
map<Pt, int> &Rg = R[e.first];
for (int i = sz-1, j = 0; j < sz; i = j++) {
int l = R[e.second[i].first][e.first];
int r = Rg[e.second[j].first];
egroup.joint(l, r);
if (e.second[i].second != 0) {
r = Rg[e.second[i].first];
assert(l > 0 && r > 0);
egroup.joint(l, r);
}
}
}
for (auto &e : g) {
int sz = e.second.size();
int n = 0;
for (int i = 0; i < sz; i++) {
if (e.second[i].second == 0)
e.second[n++] = e.second[i];
}
e.second.resize(n);
}
static int region[65536]; memset(region, 0, sizeof(0)*Rsize);
for (int i = 0; i < p; i++) {
float x, y;
scanf("%f %f", &x, &y);
pair<int, Pt> hit = tree.raycast(Pt(x, y));
if (hit.first < 0)
continue;
if (g.find(hit.second) != g.end()) {
auto &adj = g[hit.second];
AngleCmp cmp(hit.second);
int pos = 0;
pair<Pt, int> q(Pt(x, y), i+1);
pos = lower_bound(adj.begin(), adj.end(), q, cmp) - adj.begin() - 1;
assert(pos >= 0);
int l = R[adj[pos].first][hit.second];
assert(l > 0);
l = egroup.findp(l);
region[l] = i+1;
} else {
auto &adj = on_seg[hit.first];
Pt lpt = adj[0], rpt = adj[1];
int l;
if (cmpZero(cross(lpt, rpt, Pt(x, y))) <= 0)
l = R[lpt][rpt];
else
l = R[rpt][lpt];
assert(l > 0);
l = egroup.findp(l);
region[l] = i+1;
}
}
for (int i = 0; i < q; i++) {
float x, y;
scanf("%f %f", &x, &y);
pair<int, Pt> hit = tree.raycast(Pt(x, y));
if (hit.first < 0) {
printf("0\n");
continue;
}
if (g.find(hit.second) != g.end()) {
auto &adj = g[hit.second];
AngleCmp cmp(hit.second);
int pos = 0;
pair<Pt, int> q(Pt(x, y), i+1);
pos = lower_bound(adj.begin(), adj.end(), q, cmp) - adj.begin() - 1;
assert(pos >= 0);
int l = R[adj[pos].first][hit.second];
assert(l > 0);
l = egroup.findp(l);
printf("%d\n", region[l]);
} else {
auto &adj = on_seg[hit.first];
Pt lpt = adj[0], rpt = adj[1];
int l;
if (cmpZero(cross(lpt, rpt, Pt(x, y))) <= 0)
l = R[lpt][rpt];
else
l = R[rpt][lpt];
assert(l > 0);
l = egroup.findp(l);
printf("%d\n", region[l]);
}
}
}
return 0;
}

線段樹

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
#include <bits/stdc++.h>
using namespace std;
const double eps = 1e-8;
const double BOX_MX = (3000000);
const int AUXID = 200000;
struct Pt {
double x, y;
Pt(double a = 0, double b = 0): x(a), y(b) {}
Pt operator-(const Pt &a) const { return Pt(x - a.x, y - a.y); }
Pt operator+(const Pt &a) const { return Pt(x + a.x, y + a.y); }
Pt operator*(const double a) const { return Pt(x * a, y * a); }
bool operator<(const Pt &a) const {
if (fabs(x - a.x) > eps) return x < a.x;
if (fabs(y - a.y) > eps) return y < a.y;
return false;
}
bool operator==(const Pt &a) const {
return fabs(x - a.x) < eps && fabs(y - a.y) < eps;
}
void println() const {
printf("(%lf, %lf)\n", x, y);
}
};
struct Seg {
Pt s, e;
int i;
Seg() {}
Seg(Pt a, Pt b, int i):s(a), e(b), i(i) {}
void println() const {
printf("Segment((%lf, %lf), (%lf, %lf))\n", s.x, s.y, e.x, e.y);
}
};
double dot(Pt a, Pt b) {
return a.x * b.x + a.y * b.y;
}
double cross(Pt o, Pt a, Pt b) {
return (a.x-o.x)*(b.y-o.y)-(a.y-o.y)*(b.x-o.x);
}
double cross2(Pt a, Pt b) {
return a.x * b.y - a.y * b.x;
}
int between(Pt a, Pt b, Pt c) {
return dot(c - a, b - a) >= -eps && dot(c - b, a - b) >= -eps;
}
int onSeg(Pt a, Pt b, Pt c) {
return between(a, b, c) && fabs(cross(a, b, c)) < eps;
}
int cmpZero(double v) {
if (fabs(v) > eps) return v > 0 ? 1 : -1;
return 0;
}
Pt getIntersect(Seg a, Seg b) {
Pt u = a.s - b.s;
double t = cross2(b.e - b.s, u)/cross2(a.e - a.s, b.e - b.s);
return a.s + (a.e - a.s) * t;
}
struct AngleCmp {
Pt o;
AngleCmp(Pt o = Pt()):o(o) {}
bool operator() (const pair<Pt, int>& ppa, const pair<Pt, int>& ppb) {
Pt pa = ppa.first, pb = ppb.first;
Pt p1 = pa - o, p2 = pb - o;
if (p1.y == 0 && p2.y == 0 && p1.x * p2.x <= 0) return p1.x > p2.x;
if (p1.y == 0 && p1.x >= 0 && p2.y != 0) return true;
if (p2.y == 0 && p2.x >= 0 && p1.y != 0) return false;
if (p1.y * p2.y < 0) return p1.y > p2.y;
double c = cross2(p1, p2);
return c > 0 || (c == 0 && fabs(p1.x) < fabs(p2.x));
}
};
static Pt pts[10005];
static Seg segs[30005];
static double interpolate(const Pt& p1, const Pt& p2, double& x) {
if (cmpZero(p1.x - p2.x) == 0) return min(p1.y, p2.y);
return p1.y + (double)(p2.y - p1.y) / (p2.x - p1.x) * (x - p1.x);
}
struct CMP {
static double x;
bool operator() (const Seg &i, const Seg &j) {
double l = interpolate(i.s, i.e, x);
double r = interpolate(j.s, j.e, x);
return l < r;
}
};
double CMP::x;
struct BSP {
struct Node {
double ly, ry;
void init() {
ly = BOX_MX, ry = -BOX_MX;
}
void extend(Node *u) {
if (u == NULL)
return ;
ly = min(ly, u->ly);
ry = max(ry, u->ry);
}
} bbox[262144];
vector<double> x;
set<Seg, CMP> tree[524288];
Seg segs[262144];
void clear(int k, int l, int r) {
tree[k].clear();
bbox[k].init();
if (l+1 >= r)
return ;
int m = (l+r)/2;
clear(k<<1, l, m);
clear(k<<1|1, m, r);
}
void insert(Seg &s, int k, int l, int r) {
if (s.s.x <= x[l] && x[r] <= s.e.x) {
CMP::x = (x[l] + x[r])/2;
tree[k].insert(s);
double ly = interpolate(s.s, s.e, x[l]);
double ry = interpolate(s.s, s.e, x[r]);
bbox[k].ly = min(bbox[k].ly, min(ly, ry));
bbox[k].ry = max(bbox[k].ry, max(ly, ry));
return ;
}
if (l+1 >= r)
return ;
int m = (l+r)/2;
if (s.s.x <= x[m]) {
insert(s, k<<1, l, m);
bbox[k].extend(&bbox[k<<1]);
}
if (s.e.x > x[m]) {
insert(s, k<<1|1, m, r);
bbox[k].extend(&bbox[k<<1|1]);
}
}
void build_tree(Seg s[], int m) {
memcpy(segs, s, sizeof(segs[0])*m);
x.clear();
for (int i = 0; i < m; i++)
x.push_back(segs[i].s.x), x.push_back(segs[i].e.x);
sort(x.begin(), x.end());
x.resize(unique(x.begin(), x.end()) - x.begin());
clear(1, 0, x.size()-1);
for (int i = 0; i < m; i++) {
if (segs[i].s.x > segs[i].e.x)
swap(segs[i].s, segs[i].e);
if (segs[i].s.x < segs[i].e.x)
insert(segs[i], 1, 0, x.size()-1);
}
}
Pt q_st, q_ed;
int q_si;
void search(int k, int l, int r) {
if (bbox[k].ly >= q_ed.y || bbox[k].ry <= q_st.y)
return ;
if (x[l] <= q_st.x && q_st.x <= x[r]) {
CMP::x = q_st.x;
auto it = tree[k].upper_bound(Seg(q_st, q_st, 0));
while (it != tree[k].end()) {
double y = interpolate(it->s, it->e, CMP::x);
if (y > q_st.y) {
if (y < q_ed.y) {
q_ed.y = y;
q_si = it->i;
}
break;
}
it++;
}
}
if (l+1 >= r)
return ;
int m = (l+r)/2;
if (q_st.x <= x[m])
search(k<<1, l, m);
if (q_st.x >= x[m])
search(k<<1|1, m, r);
}
pair<int, Pt> raycast(Pt st) {
q_st = st;
q_ed = Pt(st.x, BOX_MX+1);
q_si = -1;
search(1, 0, x.size()-1);
return {q_si, q_ed};
}
} tree;
struct Disjoint {
static const int MAXN = 65536;
uint16_t parent[MAXN], weight[MAXN];
void init(int n) {
if (n >= MAXN)
exit(0);
for (int i = 0; i <= n; i++)
parent[i] = i, weight[i] = 1;
}
int findp(int x) {
return parent[x] == x ? x : parent[x] = findp(parent[x]);
}
int joint(int x, int y) {
x = findp(x), y = findp(y);
if (weight[x] >= weight[y])
parent[y] = x, weight[x] += weight[y];
else
parent[x] = y, weight[y] += weight[x];
}
} egroup, sgroup;
int main() {
int n, m, p, q;
while (scanf("%d %d %d %d", &n, &m, &p, &q) == 4 && n) {
for (int i = 0; i < n; i++) {
int x, y;
scanf("%d %d", &x, &y);
pts[i] = Pt(x, y);
}
sgroup.init(n);
for (int i = 0; i < m; i++) {
int st_i, ed_i;
scanf("%d %d", &st_i, &ed_i);
st_i--, ed_i--;
segs[i] = Seg(pts[st_i], pts[ed_i], i);
sgroup.joint(st_i, ed_i);
}
segs[m] = Seg(Pt(BOX_MX, BOX_MX), Pt(BOX_MX, -BOX_MX), m), m++;
segs[m] = Seg(Pt(BOX_MX, -BOX_MX), Pt(-BOX_MX, -BOX_MX), m), m++;
segs[m] = Seg(Pt(-BOX_MX, -BOX_MX), Pt(-BOX_MX, BOX_MX), m), m++;
segs[m] = Seg(Pt(-BOX_MX, BOX_MX), Pt(BOX_MX, BOX_MX), m), m++;
static map<Pt, vector<pair<Pt, uint8_t>>> g; g.clear();
static vector<vector<Pt>> on_seg; on_seg.clear(); on_seg.resize(m);
for (int i = 0; i < m; i++) {
on_seg[segs[i].i].reserve(4);
on_seg[segs[i].i].push_back(segs[i].s);
on_seg[segs[i].i].push_back(segs[i].e);
}
// for (int i = 0; i < n; i++)
// pts[i].println();
// for (int i = 0; i < m; i++)
// segs[i].println();
tree.build_tree(segs, m);
{
Pt top[10005];
for (int i = 0; i < n; i++)
top[i] = pts[i];
for (int i = 0; i < n; i++) {
int gid = sgroup.findp(i);
if (top[gid].y < pts[i].y)
top[gid] = pts[i];
}
for (int i = 0; i < n; i++) {
if (sgroup.findp(i) != i)
continue;
auto p = top[i];
auto hit = tree.raycast(p);
if (hit.first >= 0) {
on_seg[hit.first].emplace_back(hit.second);
g[p].emplace_back(hit.second, 1);
g[hit.second].emplace_back(p, 1);
}
}
}
for (int i = 0; i < m; i++) {
vector<Pt> &a = on_seg[i];
sort(a.begin(), a.end());
a.resize(unique(a.begin(), a.end()) - a.begin());
auto *prev = &g[a[0]];
for (int j = 1; j < a.size(); j++) {
prev->emplace_back(a[j], 0);
prev = &g[a[j]];
prev->emplace_back(a[j-1], 0);
}
}
for (auto &e : g)
sort(e.second.begin(), e.second.end(), AngleCmp(e.first));
static map<Pt, map<Pt, int>> R; R.clear();
int Rsize = 0;
for (auto &e : g) {
int sz = e.second.size();
map<Pt, int> &Rg = R[e.first];
for (auto &f : e.second) {
int &eid = Rg[f.first];
if (eid == 0)
eid = ++Rsize;
}
}
egroup.init(Rsize);
for (auto &e : g) {
int sz = e.second.size();
map<Pt, int> &Rg = R[e.first];
for (int i = sz-1, j = 0; j < sz; i = j++) {
int l = R[e.second[i].first][e.first];
int r = Rg[e.second[j].first];
egroup.joint(l, r);
if (e.second[i].second != 0) {
r = Rg[e.second[i].first];
assert(l > 0 && r > 0);
egroup.joint(l, r);
}
}
}
for (auto &e : g) {
int sz = e.second.size();
int n = 0;
for (int i = 0; i < sz; i++) {
if (e.second[i].second == 0)
e.second[n++] = e.second[i];
}
e.second.resize(n);
}
static int region[65536]; memset(region, 0, sizeof(0)*Rsize);
for (int i = 0; i < p; i++) {
float x, y;
scanf("%f %f", &x, &y);
pair<int, Pt> hit = tree.raycast(Pt(x, y));
if (hit.first < 0)
continue;
if (g.find(hit.second) != g.end()) {
auto &adj = g[hit.second];
AngleCmp cmp(hit.second);
int pos = 0;
pair<Pt, int> q(Pt(x, y), i+1);
pos = lower_bound(adj.begin(), adj.end(), q, cmp) - adj.begin() - 1;
assert(pos >= 0);
int l = R[adj[pos].first][hit.second];
assert(l > 0);
l = egroup.findp(l);
region[l] = i+1;
} else {
auto &adj = on_seg[hit.first];
Pt lpt = adj[0], rpt = adj[1];
int l;
if (cmpZero(cross(lpt, rpt, Pt(x, y))) <= 0)
l = R[lpt][rpt];
else
l = R[rpt][lpt];
assert(l > 0);
l = egroup.findp(l);
region[l] = i+1;
}
}
for (int i = 0; i < q; i++) {
float x, y;
scanf("%f %f", &x, &y);
pair<int, Pt> hit = tree.raycast(Pt(x, y));
if (hit.first < 0) {
printf("0\n");
continue;
}
if (g.find(hit.second) != g.end()) {
auto &adj = g[hit.second];
AngleCmp cmp(hit.second);
int pos = 0;
pair<Pt, int> q(Pt(x, y), i+1);
pos = lower_bound(adj.begin(), adj.end(), q, cmp) - adj.begin() - 1;
assert(pos >= 0);
int l = R[adj[pos].first][hit.second];
assert(l > 0);
l = egroup.findp(l);
printf("%d\n", region[l]);
} else {
auto &adj = on_seg[hit.first];
Pt lpt = adj[0], rpt = adj[1];
int l;
if (cmpZero(cross(lpt, rpt, Pt(x, y))) <= 0)
l = R[lpt][rpt];
else
l = R[rpt][lpt];
assert(l > 0);
l = egroup.findp(l);
printf("%d\n", region[l]);
}
}
}
return 0;
}
Read More +

UVa 13154 - Extreme XOR Sum

Problem

給一長度為 $N$ 的整數序列,詢問任意區間的極端異或和 Extreme XOR Sum。

定義 Extreme XOR Sum 為一系列操作的最後一個值

  • 當長度 $n>1$ 時,將陣列縮小為 $n-1$
  • $[a_0, a_1, a_2, \cdots, a_{n-1}]$,每一個元素與後一個元素運行互斥或,將會被轉換成 $[a_0 \oplus a_1, a_1\oplus a_2, a_2 \oplus a_3, \cdots, a_{n-2}\oplus a_{n-1}]$
  • 直到只剩下一個元素,即為 Extreme XOR Sum

Sample Input

1
2
3
4
5
6
7
1
5
1 4 6 7 8
3
0 0
0 1
2 4

Sample Output

1
2
3
4
Case 1:
1
5
14

Solution

這題詢問次數非常多,一般運行將對每一個詢問達到 $O(N)$ 的複雜度,這很容易得到 TLE。從大多的數據結構,如線段樹、塊狀表 … 等,他們提供高效率的查找效能,但也必須符合某些條件才能使用。因此,在此題若要符合結合律將變得相當困難。

觀察

假設要計算陣列 $[1, 4, 6, 7, 8]$ 的值時

  • 第一步,$[1 \oplus 4, 4 \oplus 6, 6 \oplus 7, 7 \oplus 8]$
  • 第二步,$[1 \oplus 4 \oplus 4 \oplus 6, \cdots]$
  • 如此類推下去,XOR 有結合律,我們發現到各別使用了 1 次 $a_0$、4 次 $a_1$、6 次 $a_2$、4 次 $a_3$ 和 1 次 $a_4$

對於不同的長度,我們發現到是二項係數的配對情況。由於偶數次的 XOR 會互消,只需要計算出現奇數次的即可,因此我們列出二項次數模二的情況,進而得到 Sierpinski triangle/Sierpinski sieve。即使知道 Sierpinski sieve 是二項係數模二的結果,我們仍不知道要怎麼套用結合律達到剖分加速的條件。

二項係數的公式如下

$$\begin{align*} \binom{n}{m} &= \binom{n-1}{m-1} + \binom{n-1}{m} \\ &= \frac{n!}{m!(n-m)!} \end{align*}$$

階層運算在數學運算上的性質並不多,逼得我們只好往碎形上觀察,以下列出前幾項的結果

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
1
1 1
1 0 1
1 1 1 1
1 0 0 0 1
1 1 0 0 1 1
1 0 1 0 1 0 1
1 1 1 1 1 1 1 1
1 0 0 0 0 0 0 0 1
1 1 0 0 0 0 0 0 1 1
1 0 1 0 0 0 0 0 1 0 1
1 1 1 1 0 0 0 0 1 1 1 1
1 0 0 0 1 0 0 0 1 0 0 0 1
1 1 0 0 1 1 0 0 1 1 0 0 1 1
1 0 1 0 1 0 1 0 1 0 1 0 1 0 1
1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1

發現它是一個很有趣的碎形,每個三角形大小都是以二的冪次的。我們按照 $2^3 = 8$ 切割一下上圖,並且把右邊斜的補上 0 得到下圖。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
1
1 1
1 0 1
1 1 1 1
1 0 0 0 1
1 1 0 0 1 1
1 0 1 0 1 0 1
1 1 1 1 1 1 1 1
^---------------
1 0 0 0 0 0 0 0 |1 0 0 0 0 0 0 0
1 1 0 0 0 0 0 0 |1 1 0 0 0 0 0 0
1 0 1 0 0 0 0 0 |1 0 1 0 0 0 0 0
1 1 1 1 0 0 0 0 |1 1 1 1 0 0 0 0
1 0 0 0 1 0 0 0 |1 0 0 0 1 0 0 0
1 1 0 0 1 1 0 0 |1 1 0 0 1 1 0 0
1 0 1 0 1 0 1 0 |1 0 1 0 1 0 1 0
1 1 1 1 1 1 1 1 |1 1 1 1 1 1 1 1
^----------------^--------------
1 0 0 0 0 0 0 0 |0 0 0 0 0 0 0 0 |1 0 0 0 0 0 0 0
1 1 0 0 0 0 0 0 |0 0 0 0 0 0 0 0 |1 1 0 0 0 0 0 0
1 0 1 0 0 0 0 0 |0 0 0 0 0 0 0 0 |1 0 1 0 0 0 0 0
1 1 1 1 0 0 0 0 |0 0 0 0 0 0 0 0 |1 1 1 1 0 0 0 0
1 0 0 0 1 0 0 0 |0 0 0 0 0 0 0 0 |1 0 0 0 1 0 0 0
1 1 0 0 1 1 0 0 |0 0 0 0 0 0 0 0 |1 1 0 0 1 1 0 0
1 0 1 0 1 0 1 0 |0 0 0 0 0 0 0 0 |1 0 1 0 1 0 1 0
1 1 1 1 1 1 1 1 |0 0 0 0 0 0 0 0 |1 1 1 1 1 1 1 1
^ ^ ^
箭頭表示本身也是 Sierpinski sieve
區塊縮影得到 miniature pattern 也是 Sierpinski sieve
1
1 1
1 0 1

得到數個一模一樣的子圖,上述全零和非零的區塊,又恰好構成 Sierpinski sieve。這告訴我們任何操作全都要以二的冪次為基準,且合併區段時須以二項係數為係數。設定 pattern 大小為 $M=2^{\lceil \log_2 N\rceil}$,最後得到 miniature pattern。在同一層中,非零構成的條紋都是相同的模式,例如上述得圖中,最後一層的箭號組合必然是 101 或者是 000,最後得到下列公式計算條紋。

$A_{i, j} = A_{i-1}{j} \oplus A_{i-1,j+M}$

接下來,我們將需要確定每一個條紋 (stripe) 是否使用全零或者非零,只需要查找 miniature pattern 相應的係數即可。

如何在 Sierpinski sieve 找到非零係數的位置

$\binom{n}{i} \mod 2 = 1$,必滿足 $n\&i = i$。其證明從數學歸納法來,由二冪次的長度碎形著手,移除最高位的 1 得到 $i'$,從 $i'$ 舊有位置集合,保留此集合,並對每一個元素增加二的冪次得到碎形的另一邊。

故可利用下述算法,準確地找到每一個非零的係數位置

1
2
for (int pos = n; pos; pos = (pos-1)&n)
C[n][pos] mod 2 = 1

最後附上優化後得到 Rank 1 的程序 0.040 s

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
#include <bits/stdc++.h>
using namespace std;
static const int M = (1<<7);
static const int MAXN = 10005;
static int A[M+1][MAXN];
void miniature(int n) {
for (int i = 1; i*M < n; i++) {
for (int j = 0; j+i*M < n; j++)
A[i][j] = A[i-1][j] ^ A[i-1][j+M];
}
}
int extract(int l, int r) {
const int n = r-l;
const int m = n/M;
const int o = n%M;
int ret = A[m][l];
for (int i = o; i; i = (i-1)&o)
ret ^= A[m][l+i];
return ret;
}
namespace MM {
inline int readchar() {
const int N = 1048576;
static char buf[N];
static char *p = buf, *end = buf;
if(p == end) {
if((end = buf + fread(buf, 1, N, stdin)) == buf) return EOF;
p = buf;
}
return *p++;
}
inline int ReadInt(int *x) {
static char c, neg;
while((c = readchar()) < '-') {if(c == EOF) return 0;}
neg = (c == '-') ? -1 : 1;
*x = (neg == 1) ? c-'0' : 0;
while((c = readchar()) >= '0')
*x = (*x << 3) + (*x << 1) + c-'0';
*x *= neg;
return 1;
}
class Print {
public:
static const int N = 1048576;
char buf[N], *p, *end;
Print() {
p = buf, end = buf + N - 1;
}
void printInt(int x, char padding) {
static char stk[16];
int idx = 0;
stk[idx++] = padding;
if (!x)
stk[idx++] = '0';
while (x)
stk[idx++] = x%10 + '0', x /= 10;
while (idx) {
if (p == end) {
*p = '\0';
printf("%s", buf), p = buf;
}
*p = stk[--idx], p++;
}
}
void flush() {
*p = '\0', p = buf;
printf("%s", buf);
}
static inline void online_printInt(int x) {
static char ch[16];
static int idx;
idx = 0;
if (x == 0) ch[++idx] = 0;
while (x > 0) ch[++idx] = x % 10, x /= 10;
while (idx)
putchar(ch[idx--]+48);
}
~Print() {
*p = '\0';
printf("%s", buf);
}
} bprint;
}
int main() {
int testcase, cases = 0;
int n, m;
// scanf("%d", &testcase);
MM::ReadInt(&testcase);
while (testcase--) {
// scanf("%d", &n);
MM::ReadInt(&n);
for (int i = 0; i < n; i++)
// scanf("%d", &A[0][i]);
MM::ReadInt(&A[0][i]);
miniature(n);
// scanf("%d", &m);
MM::ReadInt(&m);
printf("Case %d:\n", ++cases);
for (int i = 0; i < m; i++) {
int l, r;
// scanf("%d %d", &l, &r);
MM::ReadInt(&l), MM::ReadInt(&r);
// printf("%d\n", extract(l, r));
MM::bprint.printInt(extract(l, r), '\n');
}
MM::bprint.flush();
}
return 0;
}
/*
1
5
1 4 6 7 8
3
0 0
0 1
2 4
*/
Read More +

關於高效大數除法的那些事

前情提要

從國小開始學起加減乘除,腦海裡計算加法比減法簡單,減法可以換成加法,乘法則可以換成數個加法。即使到了中國最強的利器——珠心算,我們學習這古老的快速運算時,也都採取這類方法,而除法最為複雜,需要去估算商,嘗試去扣完,再做細微調整。然而,當整數位數為 $N$ 時,加減法效能為 $N$ 基礎運算,而乘除法為 $N^2$ 次基礎運算,一次基礎運算通常定義在查表,例如背誦的九九乘法表,使得我們在借位和乘積運算達到非常快速。

這些方法在計算機發展後,硬體實作大致使用這些算法,直到了快速傅立葉 (FFT) 算法能在 $O(N \log N)$ 時間內完成旋積 (卷積) 計算,順利地解決了多項式乘法的問題,使得大數乘法能在 $O(N \log N)$ 時間內完成。

一開始我們知道乘法和除法都可以在 $O(N^2)$ 時間內完成,有了 FFT 之後,除法是不是也能跟乘法一樣在 $O(N \log N)$ 內完成?

大數除法

就目前研究來看,除法可以轉換成乘法運算,在數論的整數域中,若存在乘法反元素,除一個數相當於成一個數,而我們發展的計算機,有效運算為 32/64-bit,其超過的部分視為溢位 (overflow),溢位則可視為模數。因此,早期 CPU 設計中,除法需要的運行次數比乘法多一些,編譯器會將除法轉換乘法運算 (企圖找到反元素),來加速運算。現在,由於 intel 的黑魔法,導致幾乎一模一樣快而沒人注意到這些瑣事。

回過頭來,我們介紹一個藉由 FFT 達到的除法運算 $O(N \log N \log N)$ 的算法。而那多的 $O(\log N)$ 來自於牛頓迭代法。目標算出 $C = \lfloor A/B \rfloor$,轉換成 $C = \lfloor A \ast (1/B) \rfloor$,快速算出 $1/B$ 的值,採用牛頓迭代法。

額外要求整數長度 $n = \text{length}(A)$$m = \text{length}(B)$,當計算 $1/B$ 時,要求精準度到小數點後 $n$ 位才行。

牛頓迭代法

目標求 $f(x) = 0$ 的一組解。

$$\begin{align} x_{n+1} = x_n - \frac{f(x_n)}{f'(x_n)} \end{align}$$

不斷地迭代逼近找解。若有多組解,則依賴初始值 $x_0$ 決定優先找到哪一組解。

這部分也應用在快速開根號——反平方根快速演算法 Fast inverse square root

倒數計算

藉由下述定義,找到 $x = 1/B$

$$\begin{align} f(x) = \frac{1}{x} - B, \; f'(x) = -\frac{1}{x^2} \end{align}$$

接著推導牛頓迭代法的公式

$$\begin{align} x_{n+1} &= x_n - \frac{f(x_n)}{f'(x_n)} \\ &= x_n - \frac{\frac{1}{x} - B}{-\frac{1}{x^2}} = (2 - x_n B) x_n \end{align}$$

運行範例

$A = 7123456, \; B = 123$,計算倒數 $1/B = 1/123$,由於 $A$$n=7$ 位數,小數點後精準 7 位,意即迭代比較小數點後 7 位不再變動時停止。

以下描述皆以 十進制 (在程式運行時我們可能會偏向萬進制,甚至億進制來充分利用 32-bit 暫存器)

  • 決定 $x_0$ 是很重要的一步,這嚴重影響著迭代次數。
  • $x_0$$B$ 的最高位數 1 進行大數除小數,得到 $x_0 = 0.01$
  • $x_1 = (2 - x_0 \ast B) \ast x_0 = 0.0077$
  • $x_2 = (2 - x_1 \ast B) \ast x_1 = 0.0081073$
  • $x_3 = (2 - x_2 \ast B) \ast x_2 = 0.00813001$
  • $x_4 = (2 - x_3 \ast B) \ast x_3 = 0.00813008$

接下來進行移位到整數部分,進行乘法後再移回去即可。

實作細節

精準度

使用浮點數實作的 FFT,需要小心精準度,網路上有些代碼利用合角公式取代建表。這類型的優化,在精準度不需要很高的時候提供較好的效能,卻無法提供精準的值,誤差修正成了一道難題。

牛頓迭代

由於有 FFT 造成的誤差,檢查收斂必須更加地保守,我們可能需要保留小數點下 $n+10$ 位,在收斂時檢查 $n+5$ 位確保收斂。通常收斂可以在 20 次左右完成,若發現迭代次數過多,有可能是第一個精準度造成的影響,盡可能使用內建函數得到的三角函數值。

加速優化

一般使用 IEEE-754 的浮點數格式,根據 FFT 的長度,要避開超過的震級,因此,在 Zerojudge b960 中,最多使用十萬進制進行加速。

誤差修正

儘管算出的商 $C=A \ast (1/B)$,得到的 $C$ 可能會差個 1,需要在後續乘法中檢驗。

參考題目/資料

b960 的亂碼

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
#include <bits/stdc++.h>
using namespace std;
template<typename T> class TOOL_FFT {
public:
struct complex {
T x, y;
complex(T x = 0, T y = 0):
x(x), y(y) {}
complex operator+(const complex &A) {
return complex(x+A.x,y+A.y);
}
complex operator-(const complex &A) {
return complex(x-A.x,y-A.y);
}
complex operator*(const complex &A) {
return complex(x*A.x-y*A.y,x*A.y+y*A.x);
}
};
T PI;
static const int MAXN = 1<<17;
complex p[2][MAXN];
int reversePos[MAXN];
TOOL_FFT() {
PI = acos(-1);
preprocessing();
}
int NumberOfBitsNeeded(int PowerOfTwo) {
for (int i = 0;; ++i) {
if (PowerOfTwo & (1 << i)) {
return i;
}
}
}
inline uint32_t FastReverseBits(uint32_t a, int NumBits) {
a = ( ( a & 0x55555555U ) << 1 ) | ( ( a & 0xAAAAAAAAU ) >> 1 ) ;
a = ( ( a & 0x33333333U ) << 2 ) | ( ( a & 0xCCCCCCCCU ) >> 2 ) ;
a = ( ( a & 0x0F0F0F0FU ) << 4 ) | ( ( a & 0xF0F0F0F0U ) >> 4 ) ;
a = ( ( a & 0x00FF00FFU ) << 8 ) | ( ( a & 0xFF00FF00U ) >> 8 ) ;
a = ( ( a & 0x0000FFFFU ) << 16 ) | ( ( a & 0xFFFF0000U ) >> 16 ) ;
return a >> (32 - NumBits);
}
void FFT(bool InverseTransform, complex In[], complex Out[], int n) {
// simultaneous data copy and bit-reversal ordering into outputs
int NumSamples = n;
int NumBits = NumberOfBitsNeeded(NumSamples);
for (int i = 0; i < NumSamples; ++i)
Out[reversePos[i]] = In[i];
// the FFT process
for (int i = 1; i <= NumBits; i++) {
int BlockSize = 1<<i, BlockEnd = BlockSize>>1, BlockCnt = NumSamples/BlockSize;
for (int j = 0; j < NumSamples; j += BlockSize) {
complex *t = p[InverseTransform];
int k = 0;
#define UNLOOP_SIZE 8
for (; k+UNLOOP_SIZE < BlockEnd; ) {
#define UNLOOP { \
complex a = (*t) * Out[k+j+BlockEnd]; \
Out[k+j+BlockEnd] = Out[k+j] - a; \
Out[k+j] = Out[k+j] + a;\
k++, t += BlockCnt;\
}
#define UNLOOP4 {UNLOOP UNLOOP UNLOOP UNLOOP;}
#define UNLOOP8 {UNLOOP4 UNLOOP4;}
UNLOOP8;
}
for (; k < BlockEnd;)
UNLOOP;
}
}
// normalize if inverse transform
if (InverseTransform) {
for (int i = 0; i < NumSamples; ++i) {
Out[i] = Out[i].x / NumSamples;
}
}
}
void convolution(T *a, T *b, int n, T *c) {
static complex s[MAXN], d1[MAXN], d2[MAXN], y[MAXN];
n = MAXN;
for (int i = 0; i < n; ++i)
s[i] = complex(a[i], 0);
FFT(false, s, d1, n);
s[0] = complex(b[0], 0);
for (int i = 1; i < n; ++i)
s[i] = complex(b[n - i], 0);
FFT(false, s, d2, n);
for (int i = 0; i < n; ++i)
y[i] = d1[i] * d2[i];
FFT(true, y, s, n);
for (int i = 0; i < n; ++i)
c[i] = s[i].x;
}
void preprocessing() {
int n = MAXN;
for (int i = 0; i < n; i ++) {
p[0][i] = complex(cos(2*i*PI / n), sin(2*i*PI / n));
p[1][i] = complex(p[0][i].x, -p[0][i].y);
}
int NumBits = NumberOfBitsNeeded(n);
for (int i = 0; i < n; i++)
reversePos[i] = FastReverseBits(i, NumBits);
}
};
TOOL_FFT<double> tool;
struct BigInt {
long long *v;
int size;
static const int DIGITS = 5;
static const int MAXN = 1<<17;
static int compare(const BigInt a, const BigInt b) {
for (int i = MAXN-1; i >= 0; i--) {
if (a.v[i] < b.v[i])
return -1;
if (a.v[i] > b.v[i])
return 1;
}
return 0;
}
void str2int(char *s, long long buf[]) {
int n = strlen(s);
size = (n+DIGITS-1) / DIGITS;
int cnt = n%DIGITS == 0 ? DIGITS : n%DIGITS;
int x = 0, pos = size-1;
v = buf;
for (int i = 0; i < n; i++) {
x = x*10 + s[i] - '0';
if (--cnt == 0) {
cnt = DIGITS;
v[pos--] = x, x = 0;
}
}
}
void println() {
printf("%lld", v[size-1]);
for (int pos = size-2; pos >= 0; pos--)
printf("%05lld", v[pos]);
puts("");
}
BigInt multiply(const BigInt &other, long long buf[]) const {
int m = MAXN;
static double na[MAXN], nb[MAXN];
static double tmp[MAXN];
memset(na+size, 0, sizeof(v[0])*(m-size));
for (int i = 0; i < size; i++)
na[i] = v[i];
memset(nb+1, 0, sizeof(v[0])*(m-other.size));
for (int i = 1, j = m-1; i < other.size; i++, j--)
nb[j] = other.v[i];
nb[0] = other.v[0];
tool.convolution(na, nb, m, tmp);
BigInt ret;
ret.size = m;
ret.v = buf;
for (int i = 0; i < m; i++)
buf[i] = (long long) (tmp[i] + 1.5e-1);
for (int i = 0; i < m; i++) {
if (buf[i] >= 100000)
buf[i+1] += buf[i]/100000, buf[i] %= 100000;
}
ret.reduce();
return ret;
}
void reduce() {
while (size > 1 && fabs(v[size-1]) < 5e-1)
size--;
}
BigInt divide(const BigInt &other, long long buf[]) const {
{
int cmp = compare(*this, other);
BigInt ret;
ret.size = MAXN-1, ret.v = buf;
if (cmp == 0) {
memset(buf, 0, sizeof(buf[0])*MAXN);
buf[0] = 1;
ret.reduce();
return ret;
} else if (cmp < 0) {
memset(buf, 0, sizeof(buf[0])*MAXN);
buf[0] = 0;
ret.reduce();
return ret;
}
}
// A / B = A * (1/B)
// x' = (2 - x * B) * x
int m = MAXN;
static double na[MAXN], nb[MAXN];
static double invB[MAXN], netB[MAXN], tmpB[MAXN];
static long long bufB[MAXN];
int PRECISION = size+10;
memset(nb+1, 0, sizeof(v[0])*(m-other.size));
for (int i = 1, j = m-1; i < other.size; i++, j--)
nb[j] = other.v[i];
nb[0] = other.v[0];
memset(invB, 0, sizeof(invB[0])*m);
{
long long t = 100000, a = other.v[other.size-1];
if (other.size-2 >= 0)
t = t*100000, a = a*100000+other.v[other.size-2];
for (int i = PRECISION-other.size; i >= 0; i--) {
invB[i] = t/a;
t = (t%a)*100000;
}
}
for (int it = 0; it < 100; it++) {
// netB = xi * B
tool.convolution(invB, nb, m, netB);
long long carry = 0;
for (int i = 0; i <= PRECISION*2; i++) {
bufB[i] = carry + (long long) (netB[i] + 1.5e-1);
if (bufB[i] >= 100000)
carry = bufB[i]/100000, bufB[i] %= 100000;
else
carry = 0;
bufB[i] = -bufB[i];
}
// tmpB = 2 - xi * B
bufB[PRECISION] += 2;
memset(tmpB, 0, sizeof(tmpB[0])*m);
for (int i = 0; i <= PRECISION*2; i++) {
if (bufB[i] < 0)
bufB[i] += 100000, bufB[i+1]--;
if (i != 0)
tmpB[m-i] = bufB[i];
else
tmpB[i] = bufB[i];
}
// netB = tmpB * invB
tool.convolution(invB, tmpB, m, netB);
{
long long carry = 0;
memset(bufB, 0, sizeof(bufB[0])*m);
for (int i = 0; i <= PRECISION*2; i++) {
bufB[i] = carry + (long long) (netB[i] + 1.5e-1);
if (bufB[i] >= 100000)
carry = bufB[i]/100000, bufB[i] %= 100000;
else
carry = 0;
}
}
{
int same = 1;
for (int i = PRECISION; same && i >= 5; i--)
same &= ((long long) (invB[i]) == bufB[i+PRECISION]);
if (same)
break;
}
memset(invB, 0, sizeof(invB[0])*m);
for (int i = 0; i+PRECISION <= PRECISION*2; i++)
invB[i] = bufB[i+PRECISION];
}
memset(na+1, 0, sizeof(v[0])*(m-size));
for (int i = 1, j = m-1; i < size; i++, j--)
na[j] = v[i];
na[0] = v[0];
tool.convolution(invB, na, m, netB);
BigInt ret;
ret.size = m-1;
ret.v = buf;
long long carry = 0;
for (int i = 0; i < m; i++) {
buf[i] = carry + (long long) (netB[i] + 1.5e-1);
if (buf[i] >= 100000)
carry = buf[i]/100000, buf[i] %= 100000;
else
carry = 0;
}
for (int i = 0; i+PRECISION < m; i++)
buf[i] = buf[i+PRECISION];
memset(buf+PRECISION, 0, sizeof(buf[0])*(m-PRECISION));
{
memset(na, 0, sizeof(na[0])*m);
for (int i = 1, j = m-1; i < m-1; i++, j--)
na[j] = buf[i];
na[0] = buf[0];
for (int i = 0; i < m; i++)
nb[i] = other.v[i];
tool.convolution(nb, na, m, netB);
long long carry = 0;
for (int i = 0; i < m; i++) {
bufB[i] = (carry + (long long) (netB[i] + 1e-1));
if (bufB[i] >= 100000)
carry = bufB[i]/100000, bufB[i] %= 100000;
else
carry = 0;
}
carry = 0;
for (int i = 0; i < m; i++) {
bufB[i] = v[i] - bufB[i] + carry;
if (bufB[i] < 0)
bufB[i] += 100000, carry = -1;
else
carry = 0;
}
BigInt R;
R.size = m-1, R.v = bufB;
if (compare(other, R) <= 0) {
buf[0]++;
for (int i = 0; buf[i] >= 100000; i++)
buf[i+1]++, buf[i] -= 100000;
}
}
ret.reduce();
return ret;
}
};
int main() {
static char sa[1<<20], sb[1<<20];
while (scanf("%s %s", sa, sb) == 2) {
static long long da[1<<19], db[1<<19], dc[1<<19];
BigInt a, b, c;
a.str2int(sa, da);
b.str2int(sb, db);
c = a.divide(b, dc);
c.println();
}
return 0;
}
Read More +