b446. 搜索美學 史蒂芙的煩惱

contents

  1. 1. Problem
    1. 1.1. 背景
    2. 1.2. 問題描述
  2. 2. Sample Input
  3. 3. Sample Output
  4. 4. Solution
    1. 4.1. 實作探討
    2. 4.2. Dynamic Kd tree
    3. 4.3. CDQ 分治

Problem

背景

動畫 遊戲人生《No Game No Life》中,史蒂芙 (Stephanie Dola) 常常被欺負,儘管她以學院第一畢業,對於遊戲一竅不通的她在這個世界常常被欺負。現在就交給你來幫幫她。

問題描述

兩個人輪流在一個大棋盤上下棋,每一步棋的得分根據這一步棋與最鄰近的敵方棋子的曼哈頓距離。

對於兩個點 $p, q$ 座標 $(p_x, p_y), (q_x, q_y)$,曼哈頓距離 (Manhattan distance) 為 $|p_x - q_x| + |p_y - q_y|$

Sample Input

1
2
3
4
5
6
7
3
1 1
5 5
4 4
3 2
2 4
2 3

Sample Output

1
2
3
4
5
8
2
3
3
1

Solution

把鄰近搜索問題做個總結,普遍處理的是靜態資料跟單一詢問,而在最近餐館那一題已經用 KD-tree 處理過 KNN 問題。這一題是採用動態插入和詢問以及數學性質較強的曼哈頓距離,離線處理也是個選擇。

此問題限制在 $n = 50000$ 的情況下,進行測試討論,除了分桶、方格法外,探討三種思路:

  • Dynamic KD-tree 利用替罪羊樹的概念完成,看著卦長的代碼以及卡車口述概要,終於敲敲打打拼湊起來,掛上啟發式的搭配具有不錯的成效。空間複雜度 $O(n)$,插入複雜度 $O(\log^2 n)$,查詢 $O(\log n)$ (據說是在曼哈頓距離下的緣故),速度是暴力法 $O(n^2)$ 二十倍左右。

  • Segment tree + 平衡樹,空間複雜度 $O(n \log n)$,時間複雜度 $O(\log^3 n)$,使用座標轉換將菱形轉換成正方形,套上二分邊長去查找區域內部是否有點。由於 $n$ 的緣故,速度比 Dynamic KD-tree 慢上許多,若用暴力法 $O(n^2)$ 只快兩倍之多。實作測試提供者 liouzhou_101。

  • 離線處理 CDQ 分治,空間複雜度 $O(n)$,總時間複雜度 $O(n \log^2 n)$,採用思路為曼哈頓距離切割成四個象限進行極值查找。比暴力法快十倍所右。

前兩個作法比較裸,在此特別補充 CDQ 分治,曼哈頓距離可以考慮成四個象限,詢問 $(x, y)$ 的最鄰近點,首先考慮左下角 $(x', y')$,亦即 $x' \le x, \; y' \le y$,則曼哈頓距離 $dist = (x - x') + (y - y') = (x + y) - (x' + y')$,明顯地求最近距離要讓 $x' + y'$ 最大化。同理其他象限。

為了解決這詢問,套用 CDQ 分治,按照 $x$ 座標排序,接著二分操作順序,切割操作 $[l, mid], [mid+1, r]$,在左右兩塊仍然按照 $x$ 排序。單獨看 $[mid+1, r]$ 的操作會受 $[l, mid]$ 和自己本身影響,對於前者而言,採用歸併排序那樣,按照 $x$ 座標慢慢合併 (概念上),合併過程套用 Binary indexed tree 進行極值查找。對於後者,就進行遞迴求解,明顯地 $[l, mid]$ 只會受 $[l, mid]$ 影響。

CDQ 分治的概要,按照其中一個關鍵排序,接著二分操作順序進行分置處理。國外是有論文在描述這個 Online to Offline 的算法,CDQ 命名就是人名,會給國外看笑話吧。

1
2
3
4
5
6
sort(key)
solve(l, r)
solve(l, mid)
process([l, mid], [mid+1, r])
solve(mid+1, r)

備註「欸欸,加上悔棋的話,是不是持久化 kd-tree」

實作探討

關於 kd-tree 實作細節探討,與通常會犯的錯誤,關係到速度有常數差異。

closest() 中,常犯的錯誤是 探索順序 ,盡可能先靠近,啟發式才能更加快速,別像我打出錯誤的搜索順序如下:

1
2
3
4
5
6
7
8
9
10
11
if (x.d[k] <= u->pid.d[k]) {
closest(u->lson, (k+1)%kD, x, h, mndist);
h[k] = abs(x.d[k] - u->pid.d[k]);
closest(u->rson, (k+1)%kD, x, h, mndist);
h[k] = old;
} else {
h[k] = abs(x.d[k] - u->pid.d[k]);
closest(u->lson, (k+1)%kD, x, h, mndist);
h[k] = old;
closest(u->rson, (k+1)%kD, x, h, mndist);
}

實作的順序應該如下,別總是先探訪左子樹、在去探訪右子樹,kd-tree 必須注意順序。

1
2
3
4
5
6
7
8
9
10
11
if (x.d[k] <= u->pid.d[k]) {
closest(u->lson, (k+1)%kD, x, h, mndist);
h[k] = abs(x.d[k] - u->pid.d[k]);
closest(u->rson, (k+1)%kD, x, h, mndist);
h[k] = old;
} else {
closest(u->rson, (k+1)%kD, x, h, mndist);
h[k] = abs(x.d[k] - u->pid.d[k]);
closest(u->lson, (k+1)%kD, x, h, mndist);
h[k] = old;
}

假設資料大小 sizeof(dim_element) 很大,通常會利用指針陣列來進行排序,這樣可以降低搬運大型資料複製時間,但奇怪的是由於指針陣列佔有一定空間,索引資料時又會佔據一段空間,估計是快取方面出了點問題 (或者是我寫不好),導致直接搬運資料是來得比較快速,這個修改在 b348. 最近餐館 也有進行測試,速度有提升。

接著可以藉由函數參數少量,來拉快程式在堆疊參數所需要的時間,在節點內部宣告採用維度 d

1
2
3
4
5
struct Node {
Node *lson, *rson;
Point pid;
int size, d;
};

這個修改造成詢問時,不僅僅在走訪傳遞參數少了一個,還少 k+1 的計算。測試結果中,光靠這一點速度沒有明顯提升。kd tree 還有一個靠臉吃飯的邊界分割,要是相同時分左分右,這一點是最痛苦的,在此就不去討論,當然可以利用隨機擾動來解決這問題。

至於要使用 sort() 進行 $O(n \log n)$、還是使用 nth_element()$O(n)$ 找到中位數,根據兩題的測試,由於 $n$ 都不大,照理來講 nth_element() 快於 sort(),但根據實際測試於 liouzhou_101 的代碼,sort() 的速度會比較快,其一是運氣、其二是未知情況。就從以下代碼中,差異並不明顯。

Dynamic Kd tree

沒有提供垃圾回收,靠內存持運作,要是 RE 就放大一點。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
#include <bits/stdc++.h>
using namespace std;
const int MAXN = 131072;
const int MAXM = 50005;
const int MAXD = 2;
const int INF = INT_MAX;
const double ALPHA = 0.75;
const double LOG_ALPHA = log2(1.0 / ALPHA);
class KD_TREE {
public:
struct Point {
static int kD;
int d[MAXD], pid;
int dist(Point &x) {
int ret = 0;
for (int i = 0; i < kD; i++)
ret += abs(d[i] - x.d[i]);
return ret;
}
void read(int id = 0) {
for (int i = 0; i < kD; i++)
scanf("%d", &d[i]);
pid = id;
}
static int sortIdx;
bool operator<(const Point &x) const {
return d[sortIdx] < x.d[sortIdx];
}
};
struct Node {
Node *lson, *rson;
Point pid;
int size;
Node() {
lson = rson = NULL;
size = 1;
}
void update() {
size = 1;
if (lson) size += lson->size;
if (rson) size += rson->size;
}
} nodes[MAXN];
Node *root;
Point A[MAXM];
int bufsize, size, kD;
void init(int kd) {
size = bufsize = 0;
root = NULL;
Point::sortIdx = 0;
Point::kD = kD = kd;
}
void insert(Point x) {
insert(root, 0, x, log2int(size) / LOG_ALPHA);
}
int closest(Point x) {
int mndist = INF, h[MAXD] = {};
closest(root, 0, x, h, mndist);
return mndist;
}
private:
int log2int(int x){
return __builtin_clz((int)1)-__builtin_clz(x);
}
inline int isbad(Node *u) {
if (u->lson && u->lson->size > u->size * ALPHA)
return 1;
if (u->rson && u->rson->size > u->size * ALPHA)
return 1;
return 0;
}
Node* newNode() {
Node *ret = &nodes[bufsize++];
*ret = Node();
return ret;
}
Node* build(int k, int l, int r) {
if (l > r) return NULL;
if (k == kD) k = 0;
Node *ret = newNode();
int mid = (l + r)>>1;
Point::sortIdx = k;
sort(A+l, A+r+1);
ret->pid = A[mid];
ret->lson = build(k+1, l, mid-1);
ret->rson = build(k+1, mid+1, r);
ret->update();
return ret;
}
void flatten(Node *u, Point* &buf) {
if (u == NULL) return ;
flatten(u->lson, buf);
*buf = u->pid, buf++;
flatten(u->rson, buf);
}
bool insert(Node* &u, int k, Point &x, int dep) {
if (u == NULL) {
u = newNode(), u->pid = x;
return dep <= 0;
}
u->size++;
int t = 0;
if (x.d[k] <= u->pid.d[k])
t = insert(u->lson, (k+1)%kD, x, dep-1);
else
t = insert(u->rson, (k+1)%kD, x, dep-1);
if (t && !isbad(u))
return 1;
if (t) {
Point *ptr = &A[0];
flatten(u, ptr);
u = build(k, 0, u->size-1);
}
return 0;
}
int heuristic(int h[]) {
int ret = 0;
for (int i = 0; i < kD; i++)
ret += h[i];
return ret;
}
void closest(Node *u, int k, Point &x, int h[], int &mndist) {
if (u == NULL || heuristic(h) >= mndist)
return ;
int dist = u->pid.dist(x), old;
mndist = min(mndist, dist), old = h[k];
if (x.d[k] <= u->pid.d[k]) {
closest(u->lson, (k+1)%kD, x, h, mndist);
h[k] = abs(x.d[k] - u->pid.d[k]);
closest(u->rson, (k+1)%kD, x, h, mndist);
h[k] = old;
} else {
closest(u->rson, (k+1)%kD, x, h, mndist);
h[k] = abs(x.d[k] - u->pid.d[k]);
closest(u->lson, (k+1)%kD, x, h, mndist);
h[k] = old;
}
}
} A, B;
int KD_TREE::Point::sortIdx = 0, KD_TREE::Point::kD = 2;
int main() {
int N;
KD_TREE::Point pt;
while (scanf("%d", &N) == 1) {
A.init(2), B.init(2);
for (int i = 0; i < N; i++) {
pt.read(i);
if (i) printf("%d\n", B.closest(pt));
A.insert(pt);
pt.read(i);
printf("%d\n", A.closest(pt));
B.insert(pt);
}
}
return 0;
}

CDQ 分治

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
#include <bits/stdc++.h>
using namespace std;
const int MAXQ = 131072;
const int MAXN = 40000;
const int INF = INT_MAX;
class Offline {
public:
struct Point {
int x, y;
Point(int a = 0, int b = 0):
x(a), y(b) {}
bool operator<(const Point &a) const {
return x < a.x || (x == a.x && y < a.y);
}
void read() {
scanf("%d %d", &x, &y);
x++, y++;
}
};
struct Event {
Point p;
int qtype, qid;
Event(int a = 0, int b = 0, Point c = Point()):
qtype(a), qid(b), p(c) {}
bool operator<(const Event &e) const {
if (p.x != e.p.x) return p.x < e.p.x;
return qid < e.qid;
}
};
vector<Event> event;
int ret[MAXQ], N;
void init(int n) {
event.clear();
N = n;
}
void addEvent(int qtype, int qid, Point x) {
event.push_back(Event(qtype, qid, x));
}
void run() {
for (int i = 0; i < event.size(); i++)
ret[i] = 0x3f3f3f3f;
cases = 0;
for (int i = 0; i <= N; i++)
used[i] = 0;
sort(event.begin(), event.end());
CDQ(0, event.size()-1);
}
private:
Event ebuf[MAXQ];
int BIT[MAXN], used[MAXN];
int cases = 0;
void modify(int x, int val, int dir) {
for (; x && x <= N; x += (x&(-x)) * dir) {
if (used[x] != cases)
BIT[x] = -0x3f3f3f3f, used[x] = cases;
BIT[x] = max(BIT[x], val);
}
}
int query(int x, int dir) {
int ret = -0x3f3f3f3f;
for (; x && x <= N; x += (x&(-x)) * dir) {
if (used[x] == cases)
ret = max(ret, BIT[x]);
}
return ret;
}
void merge(int l, int mid, int r) {
cases++;
for (int i = mid+1, j = l; i <= r; i++) {
if (event[i].qtype == 0) {
for (; j <= mid && event[j].p.x <= event[i].p.x; j++) {
if (event[j].qtype == 1)
modify(event[j].p.y, event[j].p.x+event[j].p.y, 1);
}
ret[event[i].qid] = min(ret[event[i].qid], event[i].p.x+event[i].p.y-query(event[i].p.y, -1));
}
}
cases++;
for (int i = mid+1, j = l; i <= r; i++) {
if (event[i].qtype == 0) {
for (; j <= mid && event[j].p.x <= event[i].p.x; j++) {
if (event[j].qtype == 1)
modify(event[j].p.y, event[j].p.x-event[j].p.y, -1);
}
ret[event[i].qid] = min(ret[event[i].qid], event[i].p.x-event[i].p.y-query(event[i].p.y, 1));
}
}
cases++;
for (int i = r, j = mid; i > mid; i--) {
if (event[i].qtype == 0) {
for (; j >= l && event[j].p.x >= event[i].p.x; j--) {
if (event[j].qtype == 1)
modify(event[j].p.y, event[j].p.y-event[j].p.x, 1);
}
ret[event[i].qid] = min(ret[event[i].qid], event[i].p.y-event[i].p.x-query(event[i].p.y, -1));
}
}
cases++;
for (int i = r, j = mid; i > mid; i--) {
if (event[i].qtype == 0) {
for (; j >= l && event[j].p.x >= event[i].p.x; j--) {
if (event[j].qtype == 1)
modify(event[j].p.y, -event[j].p.x-event[j].p.y, -1);
}
ret[event[i].qid] = min(ret[event[i].qid], -event[i].p.x-event[i].p.y-query(event[i].p.y, 1));
}
}
}
void CDQ(int l, int r) {
if (l == r)
return ;
int mid = (l + r)/2, lidx, ridx;
lidx = l, ridx = mid+1;
for (int i = l; i <= r; i++) {
if (event[i].qid <= mid)
ebuf[lidx++] = event[i];
else
ebuf[ridx++] = event[i];
}
for (int i = l; i <= r; i++)
event[i] = ebuf[i];
CDQ(l, mid);
merge(l, mid, r);
CDQ(mid+1, r);
lidx = l, ridx = mid+1;
for (int i = l; i <= r; i++) {
if ((lidx <= mid && event[lidx] < event[ridx]) || ridx > r)
ebuf[i] = event[lidx++];
else
ebuf[i] = event[ridx++];
}
for (int i = l; i <= r; i++)
event[i] = ebuf[i];
}
} A, B;
int main() {
int N;
Offline::Point pt;
while (scanf("%d", &N) == 1) {
A.init(65536), B.init(65536);
int max_y = 0;
for (int i = 0; i < N; i++) {
pt.read(), max_y = max(max_y, pt.y);
B.addEvent(0, 2*i, pt);
A.addEvent(1, 2*i, pt);
pt.read(), max_y = max(max_y, pt.y);
A.addEvent(0, 2*i+1, pt);
B.addEvent(1, 2*i+1, pt);
}
A.N = B.N = max_y; // y in [1, max_y]
A.run(), B.run();
for (int i = 0; i < N; i++) {
if (i) printf("%d\n", B.ret[2*i]);
printf("%d\n", A.ret[2*i+1]);
}
}
return 0;
}