UVa 12254 - Electricity Connection

contents

  1. 1. Problem
  2. 2. Sample Input
  3. 3. Sample Output
  4. 4. Solution
    1. 4.1. 實作細節
    2. 4.2. 基礎篇
    3. 4.3. SSE 向量化
    4. 4.4. AVX2 Gather

Problem

給一個 $8 \times 8$ 的格子地圖,上面至多有八個住家 (H) 與一座發電廠 (G)。目標是從發電廠出發,牽電線連到所有的住家:電線每延伸一格花費 $1$,若進入的格子是水路 (W) 則額外花費 $pw$,一般陸路 (.) 則額外花費 $pl$,求最少花費。

這是經典的斯坦納樹 (Steiner tree) 問題。有別於一般在平面上以歐幾里得距離或曼哈頓距離作為花費函數的版本,本題是在格子圖上,以每一步的固定花費加上格子權重作為花費。接著讓我們細談如何進行常數優化。

Sample Input

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
2
0 10
H.W.WH..
..W.W...
..WGW...
........
........
........
........
........
0 0
H.W.WH..
..W.W...
..WGW...
........
........
........
........
........

Sample Output

1
2
Case 1: 12
Case 2: 7

Solution

斯坦納樹為 NP-hard 問題,目前沒有已知的多項式時間解法;要得到確切的最小解,可以透過動態規劃來完成。由於題目給定的範圍很小,首先將目標要連通的點集壓縮成 $N$ 個位元,接著記錄這個連通元件的其中一個節點作為根。如此得到的狀態數為 $M \cdot 2^N$ (其中 $M$ 為節點數,本題 $M = 64$),可以參考 《「Steiner tree problem in graphs」斯坦納樹》 -日月卦長 的說明。

公式可以拆成兩種情況:第一種是合併兩個子集合的解,另一種是拓展連通元件 (替換根節點,但不改變目前已經連到的目標集合)。定義 $dp[S][i]$ 為以 $i$ 為根、已連通目標集合 $S$ 的最小花費。

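  • $dp[S][i] = \min_{\emptyset \neq T \subsetneq S} \left( dp[T][i] + dp[S \setminus T][i] - w_i \right)$ (兩棵以 $i$ 為根的樹在 $i$ 合併,根的點權 $w_i$ 被兩邊各算了一次,需扣回一次)
  • $dp[S][i] = \min_{k \in V} \left( dp[S][k] + \text{dist}(k, i) \right)$ (以最短路徑鬆弛替換根,即程式中的 relax)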

由上述的公式可知,合併子集合的部分複雜度為 $O(M \cdot 3^N)$;此外每個集合還要做一次最短路徑鬆弛 (程式中的 relax)。

實作細節

  • 對於記憶體布局,有兩個選擇:dp[2^N][M] 或 dp[M][2^N],其中以 dp[2^N][M] 最為適合。撰寫合併迴圈時,最內層的迴圈走訪根節點 (共 $M$ 個),記憶體存取是連續的,cache miss 的機會非常低,也更容易透過 loop unrolling 和向量化來加速。

    • 如果記憶體布局使用 dp[M][2^N],在撰寫向量化時就需要使用 gather 相關的指令,而這只有 AVX2 才有,並不是每一個 online judge 都支援,latency 也相當高;等到哪天 CPU 架構改進了,這個解法才可能快得起來。
  • 窮舉子集合時,由於 $T$ 與 $S \setminus T$ 的地位對稱,只需窮舉包含 $S$ 最高位元的那一半子集合即可,這樣可以加速約 20% 的效能。

1
2
3
4
5
6
7
8
9
10
// Subset-merge step: for every target set i, try each split of i into k and
// i^k and merge the two trees at a common root j. w[j] is subtracted because
// both halves already include the cost of cell j.
for (int i = 1, h; i < (1<<n); i++) {
// h is reset at every power of two, so h+1 is the highest set bit of i.
// Keeping only k > h means k always contains that bit, which visits each
// unordered split {k, i^k} exactly once (the symmetry trick above).
if ((i&(-i)) == i)
h = i-1;
for (int k = (i-1)&i; k > h; k = (k-1)&i) {
for (int j = 0; j < 64; j++)
dp[i][j] = min(dp[i][j], dp[k][j]+dp[i^k][j]-w[j]);
}
// Then spread the merged values across the grid (root replacement).
relax(i);
}

基礎篇

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
#pragma GCC target("avx")
#pragma GCC optimize ("O3")
#include <bits/stdc++.h>
using namespace std;
// Raw grid rows read with %s; 16 bytes leaves room for 8 cells plus '\0'.
char g[8][16];
// Extra cost of entering each cell: 0 for H/G, pl for '.', pw for 'W'.
// Kept as int32_t (not char) so larger pl/pw values are not truncated.
int32_t w[64];
// 4-neighbour moves as (row, column) deltas.
static const int dx[] = {0, 0, 1, -1};
static const int dy[] = {1, -1, 0, 0};
// dp[S][v]: cheapest tree rooted at cell v that connects the house set S.
static int32_t dp[1<<8][64];
const int INF = 0x3f3f3f3f;
// Root-replacement transition for subset s: a FIFO SPFA over the 8x8 grid,
// seeded with every cell whose dp[s][*] is finite, where moving into a
// 4-neighbour v costs 1 + w[v].
void relax(int s) {
    // Ring buffer of queued cell ids. The inq bitmask keeps at most one copy
    // of each of the 64 cells in the queue, so 64 slots always suffice.
    // A never-rewound linear buffer is not safe here because SPFA can push
    // the same cell many times.
    static int8_t Q[64];
    uint64_t inq = 0;
    unsigned head = 0, tail = 0;
    for (int i = 0; i < 64; i++) {
        if (dp[s][i] != INF) {
            Q[tail++ & 63] = i;
            inq |= 1ULL<<i;
        }
    }
    while (head != tail) {
        int u = Q[head++ & 63];
        int x = u>>3, y = u&7;
        inq ^= 1ULL<<u;
        for (int k = 0; k < 4; k++) {
            int tx = x+dx[k], ty = y+dy[k];
            if (tx < 0 || ty < 0 || tx >= 8 || ty >= 8)
                continue;
            int v = tx<<3|ty;
            int nd = dp[s][u]+1+w[v];
            if (dp[s][v] > nd) {
                dp[s][v] = nd;
                if (((inq>>v)&1) == 0) {
                    inq |= 1ULL<<v;
                    Q[tail++ & 63] = v;
                }
            }
        }
    }
}
// Reads each test case (pl pw, then 8 grid rows of H/G/W/.), runs the subset
// DP and prints the cheapest tree linking the plant (G) to every house (H).
int main() {
    int testcase, cases = 0;
    int pl, pw;
    if (scanf("%d", &testcase) != 1)
        return 0;
    while (testcase-- > 0) {
        if (scanf("%d %d", &pl, &pw) != 2)
            break;
        // g[i] decays to char*, which is what %s expects; width keeps it bounded.
        bool ok = true;
        for (int i = 0; i < 8 && ok; i++)
            ok = scanf("%15s", g[i]) == 1;
        if (!ok)
            break;
        int n = 0, root = 0;
        int A[8];
        for (int i = 0; i < 8; i++) {
            for (int j = 0; j < 8; j++) {
                int u = i<<3|j;
                if (g[i][j] == 'H') {
                    if (n < 8)  // the problem allows at most 8 houses; never overrun A/dp
                        A[n++] = u;
                    w[u] = 0;
                } else if (g[i][j] == 'G') {
                    w[u] = 0, root = u;
                } else if (g[i][j] == 'W') {
                    w[u] = pw;
                } else if (g[i][j] == '.') {
                    w[u] = pl;
                }
            }
        }
        memset(dp, 0x3f, sizeof(dp));
        for (int i = 0; i < n; i++)
            dp[1<<i][A[i]] = 0;
        // h+1 is the highest set bit of i; only splits k containing it are
        // tried, so each unordered pair {k, i^k} is merged once.
        for (int i = 1, h = 0; i < (1<<n); i++) {
            if ((i&(-i)) == i)
                h = i-1;
            for (int k = (i-1)&i; k > h; k = (k-1)&i) {
                for (int j = 0; j < 64; j++)
                    dp[i][j] = min(dp[i][j], dp[k][j]+dp[i^k][j]-w[j]);
            }
            relax(i);
        }
        // With no houses nothing has to be connected; dp[0] is never filled.
        int ret = n ? dp[(1<<n)-1][root] : 0;
        printf("Case %d: %d\n", ++cases, ret);
    }
    return 0;
}

SSE 向量化

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
#pragma GCC target("avx")
#pragma GCC optimize ("O3")
#include <bits/stdc++.h>
#include <x86intrin.h>
using namespace std;
// Grid rows as read by scanf: 8 cells plus terminator, with slack.
char g[8][16];

// 4-neighbour moves, expressed as (row, column) deltas.
static const int dx[4] = {0, 0, 1, -1};
static const int dy[4] = {1, -1, 0, 0};

// Per-cell entry cost and the DP table dp[subset][cell]. Both are 16-byte
// aligned so the SSE loop can use aligned 128-bit loads and stores.
static int32_t w[64] __attribute__((aligned(16)));
static int32_t dp[1 << 8][64] __attribute__((aligned(16)));

const int INF = 0x3F3F3F3F;
// Root-replacement transition for subset s: a FIFO SPFA over the 8x8 grid,
// seeded with every cell whose dp[s][*] is finite, where moving into a
// 4-neighbour v costs 1 + w[v].
void relax(int s) {
    // Ring buffer of queued cell ids. The inq bitmask keeps at most one copy
    // of each of the 64 cells in the queue, so 64 slots always suffice.
    // A never-rewound linear buffer is not safe here because SPFA can push
    // the same cell many times.
    static int8_t Q[64];
    uint64_t inq = 0;
    unsigned head = 0, tail = 0;
    for (int i = 0; i < 64; i++) {
        if (dp[s][i] != INF) {
            Q[tail++ & 63] = i;
            inq |= 1ULL<<i;
        }
    }
    while (head != tail) {
        int u = Q[head++ & 63];
        int x = u>>3, y = u&7;
        inq ^= 1ULL<<u;
        for (int k = 0; k < 4; k++) {
            int tx = x+dx[k], ty = y+dy[k];
            if (tx < 0 || ty < 0 || tx >= 8 || ty >= 8)
                continue;
            int v = tx<<3|ty;
            int nd = dp[s][u]+1+w[v];
            if (dp[s][v] > nd) {
                dp[s][v] = nd;
                if (((inq>>v)&1) == 0) {
                    inq |= 1ULL<<v;
                    Q[tail++ & 63] = v;
                }
            }
        }
    }
}
// Reads each test case (pl pw, then 8 grid rows of H/G/W/.), runs the subset
// DP with an SSE-vectorised merge step, and prints the cheapest tree linking
// the plant (G) to every house (H).
int main() {
    int testcase, cases = 0;
    int pl, pw;
    if (scanf("%d", &testcase) != 1)
        return 0;
    while (testcase-- > 0) {
        if (scanf("%d %d", &pl, &pw) != 2)
            break;
        // g[i] decays to char*, which is what %s expects; width keeps it bounded.
        bool ok = true;
        for (int i = 0; i < 8 && ok; i++)
            ok = scanf("%15s", g[i]) == 1;
        if (!ok)
            break;
        int n = 0, root = 0;
        int A[8];
        for (int i = 0; i < 8; i++) {
            for (int j = 0; j < 8; j++) {
                int u = i<<3|j;
                if (g[i][j] == 'H') {
                    if (n < 8)  // the problem allows at most 8 houses; never overrun A/dp
                        A[n++] = u;
                    w[u] = 0;
                } else if (g[i][j] == 'G') {
                    w[u] = 0, root = u;
                } else if (g[i][j] == 'W') {
                    w[u] = pw;
                } else if (g[i][j] == '.') {
                    w[u] = pl;
                }
            }
        }
        memset(dp, 0x3f, sizeof(dp));
        for (int i = 0; i < n; i++)
            dp[1<<i][A[i]] = 0;
        // h+1 is the highest set bit of i; only splits k containing it are
        // tried, so each unordered pair {k, i^k} is merged once.
        for (int i = 1, h = 0; i < (1<<n); i++) {
            if ((i&(-i)) == i)
                h = i-1;
            for (int k = (i-1)&i; k > h; k = (k-1)&i) {
                // Vector form of dp[i][j] = min(dp[i][j], dp[k][j]+dp[i^k][j]-w[j]),
                // four cells per step; every row of dp and w is 16-byte aligned.
                for (int j = 0; j+4 <= 64; j += 4) {
                    __m128i mv = _mm_load_si128((const __m128i*) (dp[i]+j));
                    __m128i a = _mm_load_si128((const __m128i*) (dp[k]+j));
                    __m128i b = _mm_load_si128((const __m128i*) (dp[i^k]+j));
                    __m128i c = _mm_load_si128((const __m128i*) (w+j));
                    __m128i tn = _mm_sub_epi32(_mm_add_epi32(a, b), c);
                    _mm_store_si128((__m128i*) (dp[i]+j), _mm_min_epi32(mv, tn));
                }
            }
            relax(i);
        }
        // With no houses nothing has to be connected; dp[0] is never filled.
        int ret = n ? dp[(1<<n)-1][root] : 0;
        printf("Case %d: %d\n", ++cases, ret);
    }
    return 0;
}

AVX2 Gather

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
/*
It doesn't work at UVA 2018/08/13 because server CPU does not support
AVX2 instruction set. Although we could pass the compiler, you still
get runtime error during executing an illegal instruction.
*/
#pragma GCC target("avx")
#pragma GCC target("avx2")
#pragma GCC optimize ("O3")
#include <bits/stdc++.h>
#include <x86intrin.h>
#include <avx2intrin.h>
using namespace std;
// Raw grid rows read with %s; 16 bytes leaves room for 8 cells plus '\0'.
char g[8][16];
// Extra cost of entering each cell: 0 for H/G, pl for '.', pw for 'W'.
// Kept as int32_t (not char) so larger pl/pw values are not truncated.
int32_t w[64];
// 4-neighbour moves as (row, column) deltas.
static const int dx[] = {0, 0, 1, -1};
static const int dy[] = {1, -1, 0, 0};
// Transposed table dp[cell][subset], so one row dp[j] can serve as the base
// address for gathering dp[j][k] over several subsets k.
static int32_t dp[64][1<<8] __attribute__ ((aligned(16)));
const int INF = 0x3f3f3f3f;
// Root-replacement transition for subset s on the transposed table
// dp[cell][subset]: a FIFO SPFA over the 8x8 grid, seeded with every cell
// whose dp[*][s] is finite, where moving into a 4-neighbour v costs 1 + w[v].
void relax(int s) {
    // Ring buffer of queued cell ids. The inq bitmask keeps at most one copy
    // of each of the 64 cells in the queue, so 64 slots always suffice.
    // A never-rewound linear buffer is not safe here because SPFA can push
    // the same cell many times.
    static int8_t Q[64];
    uint64_t inq = 0;
    unsigned head = 0, tail = 0;
    for (int i = 0; i < 64; i++) {
        if (dp[i][s] != INF) {
            Q[tail++ & 63] = i;
            inq |= 1ULL<<i;
        }
    }
    while (head != tail) {
        int u = Q[head++ & 63];
        int x = u>>3, y = u&7;
        inq ^= 1ULL<<u;
        for (int k = 0; k < 4; k++) {
            int tx = x+dx[k], ty = y+dy[k];
            if (tx < 0 || ty < 0 || tx >= 8 || ty >= 8)
                continue;
            int v = tx<<3|ty;
            int nd = dp[u][s]+1+w[v];
            if (dp[v][s] > nd) {
                dp[v][s] = nd;
                if (((inq>>v)&1) == 0) {
                    inq |= 1ULL<<v;
                    Q[tail++ & 63] = v;
                }
            }
        }
    }
}
// Reads each test case (pl pw, then 8 grid rows of H/G/W/.), runs the subset
// DP on the transposed table dp[cell][subset], and prints the cheapest tree
// linking the plant (G) to every house (H). The merge step uses AVX2 gathers:
// for a fixed cell j it fetches dp[j][k] and dp[j][i^k] for four splits at once.
int main() {
    int testcase, cases = 0;
    int pl, pw;
    if (scanf("%d", &testcase) != 1)
        return 0;
    while (testcase-- > 0) {
        if (scanf("%d %d", &pl, &pw) != 2)
            break;
        // g[i] decays to char*, which is what %s expects; width keeps it bounded.
        bool ok = true;
        for (int i = 0; i < 8 && ok; i++)
            ok = scanf("%15s", g[i]) == 1;
        if (!ok)
            break;
        int n = 0, root = 0;
        int A[8];
        for (int i = 0; i < 8; i++) {
            for (int j = 0; j < 8; j++) {
                int u = i<<3|j;
                if (g[i][j] == 'H') {
                    if (n < 8)  // the problem allows at most 8 houses; never overrun A/dp
                        A[n++] = u;
                    w[u] = 0;
                } else if (g[i][j] == 'G') {
                    w[u] = 0, root = u;
                } else if (g[i][j] == 'W') {
                    w[u] = pw;
                } else if (g[i][j] == '.') {
                    w[u] = pl;
                }
            }
        }
        memset(dp, 0x3f, sizeof(dp));
        for (int i = 0; i < n; i++)
            dp[A[i]][1<<i] = 0;
        // Split list for the current set i: subset1[t] = k, subset2[t] = i^k.
        __attribute__ ((aligned(16))) static int subset1[1<<8];
        __attribute__ ((aligned(16))) static int subset2[1<<8];
        // h+1 is the highest set bit of i; only splits k containing it are
        // listed, so each unordered pair {k, i^k} appears once.
        for (int i = 1, h = 0; i < (1<<n); i++) {
            if ((i&(-i)) == i)
                h = i-1;
            int sn = 0;
            for (int k = (i-1)&i; k > h; k = (k-1)&i)
                subset1[sn] = k, subset2[sn] = i^k, sn++;
            // Pad up to a multiple of 4 lanes (sn&3) so every vector below holds
            // only splits of i. The pad pair (0, i) is harmless: dp[j][0] stays
            // INF, so dp[j][0]+dp[j][i] never beats the starting value mn.
            while (sn&3)
                subset1[sn] = 0, subset2[sn] = i, sn++;
            for (int j = 0; j < 64; j++) {
                int32_t mn = dp[j][i]+w[j];
                __m128i mv = _mm_setr_epi32(mn, mn, mn, mn);
                // Visit exactly the sn filled entries, four at a time.
                for (int t = 0; t < sn; t += 4) {
                    // Offset in ints before the cast: (subset1+t), not
                    // (__m128i*)subset1+t, which would skip t whole vectors.
                    __m128i a = _mm_load_si128((const __m128i*) (subset1+t));
                    __m128i b = _mm_load_si128((const __m128i*) (subset2+t));
                    __m128i t1 = _mm_i32gather_epi32(dp[j], a, 4);
                    __m128i t2 = _mm_i32gather_epi32(dp[j], b, 4);
                    mv = _mm_min_epi32(mv, _mm_add_epi32(t1, t2));
                }
                // Horizontal min of the four lanes.
                __attribute__ ((aligned(16))) int32_t mr[4];
                _mm_store_si128((__m128i*) mr, mv);
                mn = min(min(mr[0], mr[1]), min(mr[2], mr[3]));
                // Scalar meaning: dp[j][i] = min(dp[j][i], dp[j][k]+dp[j][i^k]-w[j]).
                dp[j][i] = mn-w[j];
            }
            relax(i);
        }
        // With no houses nothing has to be connected; dp[*][0] is never filled.
        int ret = n ? dp[root][(1<<n)-1] : 0;
        printf("Case %d: %d\n", ++cases, ret);
    }
    return 0;
}