b449. Speedup Strategies: Circling Corner Points

Contents

  1. Problem
    1.1. Background
    1.2. Problem Description
  2. Sample Input
  3. Sample Output
  4. Solution
    4.1. Naive Solution
    4.2. bitmask
    4.3. bitmask2
    4.4. bitmask3

Problem

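Background

In image processing, accurately locating points, lines, edges, and corners in a given image is quite difficult: images are affected by noise and by differences in color properties, which makes feature extraction hard.

Problem Description

Consider an $N \times M$ image that, for simplicity, is purely black and white: 0 means dark and 1 means bright. For each pixel position, determine whether it could be a corner.

One corner detection algorithm is FAST (Features from Accelerated Segment Test), proposed by Rosten and Drummond. It uses a $7 \times 7$ mask with the pixel under test, $p$, at its center, and decides whether $p$ is a corner from the gray levels of the 16 pixels on the circle inside the mask. The mask looks like this: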

+--------------------+
|  |  |16| 1| 2|  |  |
+--------------------+
|  |15|  |  |  | 3|  |
+--------------------+
|14|  |  |  |  |  | 4|
+--------------------+
|13|  |  | p|  |  | 5|
+--------------------+
|12|  |  |  |  |  | 6|
+--------------------+
|  |11|  |  |  | 7|  |
+--------------------+
|  |  |10| 9| 8|  |  |
+--------------------+

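If the circle contains 12 or more contiguous pixels that are all dark or all bright, $p$ is considered a corner. The circle is cyclic: position 16 is adjacent to position 1, so such a run may wrap around.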
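As a minimal sketch of this rule (not taken from the original solutions), assume the 16 circle pixels have already been copied into an array ring in mask order, so ring[0] is position 1 and ring[15] is position 16. The function name is_fast_corner_ring is hypothetical. It tries every start position and extends the run cyclically:

```cpp
#include <array>
#include <cassert>

// Returns true if the cyclic 16-pixel ring contains at least `need`
// contiguous pixels with the same value (all 0 or all 1).
// ring[k] holds mask position k+1, so ring[15] is adjacent to ring[0].
bool is_fast_corner_ring(const std::array<int, 16>& ring, int need = 12) {
  for (int start = 0; start < 16; start++) {
    int len = 1;
    // Extend the run that begins at `start`, wrapping past position 16.
    while (len < 16 && ring[(start + len) % 16] == ring[start])
      len++;
    if (len >= need)
      return true;
  }
  return false;
}
```

With all 16 pixels equal, as around the center of the first sample, it returns true; with positions 9 and 13 dark and the rest bright, as in the second sample, the longest run is 11 and it returns false.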

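Unfortunately, this produces many corner points around a single corner. Corners are usually found in scan order, and once a corner is found, its neighborhood is suppressed so that nearby pixels cannot also be corners. This problem ignores suppression. A pixel is tested only when all 16 of its circle pixels lie inside the image; pixels near the image border are not tested.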

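Output an $N \times M$ matrix aligned with the original image: 1 where the pixel is a corner, 0 otherwise. As the samples show, the input may contain several images, each given as $N$ and $M$ followed by $N$ rows of 0/1 characters, and each output matrix is preceded by a "Case #k:" line.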
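For checking faster versions against something obviously correct, a slow reference detector could look like the sketch below. It assumes the image is held as a std::vector<std::string> of '0'/'1' rows; detect_corners, kDx and kDy are illustrative names, with kDx/kDy copied from the dx/dy tables used in the solutions below. It applies the border rule (only 3 <= x < N-3 and 3 <= y < M-3 are tested) and the 12-contiguous rule:

```cpp
#include <cassert>
#include <string>
#include <vector>

// Circle offsets in mask order: index k is position k+1.
const int kDx[16] = {-3, -3, -2, -1, 0, 1, 2, 3, 3, 3, 2, 1, 0, -1, -2, -3};
const int kDy[16] = {0, 1, 2, 3, 3, 3, 2, 1, 0, -1, -2, -3, -3, -3, -2, -1};

// Slow reference detector: returns an N x M grid of '0'/'1' characters.
std::vector<std::string> detect_corners(const std::vector<std::string>& img) {
  const int n = static_cast<int>(img.size());
  const int m = n > 0 ? static_cast<int>(img[0].size()) : 0;
  std::vector<std::string> out(n, std::string(m, '0'));
  // Only pixels whose whole circle fits inside the image are tested.
  for (int x = 3; x + 3 < n; x++) {
    for (int y = 3; y + 3 < m; y++) {
      char ring[16];
      for (int k = 0; k < 16; k++) ring[k] = img[x + kDx[k]][y + kDy[k]];
      // Corner if some cyclic run of >= 12 ring pixels has the same value.
      for (int s = 0; s < 16 && out[x][y] == '0'; s++) {
        int len = 1;
        while (len < 16 && ring[(s + len) % 16] == ring[s]) len++;
        if (len >= 12) out[x][y] = '1';
      }
    }
  }
  return out;
}
```

On the first sample it marks only the center pixel (row 3 becomes 0001000), and on the second sample it returns all zeros, matching the expected output.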

Sample Input

7 7
0011100
0100010
1000001
1000001
1000001
0100010
0011100
7 7
0011100
0100010
1000001
0000001
1000001
0100010
0010100

Sample Output

Case #1:
0000000
0000000
0000000
0001000
0000000
0000000
0000000
Case #2:
0000000
0000000
0000000
0000000
0000000
0000000
0000000

Solution

The trick in this problem is not the algorithm but the implementation details that speed it up.

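This is also a chance to test how useful the optimization the instructor described in class really is. Since 12 contiguous pixels are required, the pigeonhole principle suggests first examining positions 1, 5, 9, and 13: any run of 12 contiguous positions covers at least three of these four samples, and they are cyclically consecutive. So the full $O(16)$ check is run only when three cyclically consecutive samples share the same state. Picking out these four positions gives about a 50% speedup compared with always running the $O(16)$ check; see the naive solution below.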
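The pigeonhole claim can be checked exhaustively. In this sketch (helper names are hypothetical) the ring is encoded as a 16-bit integer whose bit k is position k+1, so positions 1, 5, 9, 13 are bits 0, 4, 8, 12. prefilter_is_safe confirms that every pattern with a cyclic run of at least 12 equal bits passes the cheap test, so the prefilter only skips pixels that cannot be corners:

```cpp
#include <cassert>
#include <cstdint>

// Bit k of v is the pixel at ring position k+1; k wraps modulo 16.
int ring_bit(uint32_t v, int k) { return (v >> (k % 16)) & 1; }

// Brute force: does the cyclic ring hold >= 12 contiguous equal bits?
bool has_long_run(uint32_t v) {
  for (int s = 0; s < 16; s++) {
    int len = 1;
    while (len < 16 && ring_bit(v, s + len) == ring_bit(v, s)) len++;
    if (len >= 12) return true;
  }
  return false;
}

// The cheap test: sample positions 1, 5, 9, 13 (bits 0, 4, 8, 12) and
// require three cyclically consecutive samples to be equal.
bool passes_prefilter(uint32_t v) {
  int c[4] = {ring_bit(v, 0), ring_bit(v, 4), ring_bit(v, 8), ring_bit(v, 12)};
  for (int i = 0; i < 4; i++)
    if (c[i] == c[(i + 1) % 4] && c[(i + 1) % 4] == c[(i + 2) % 4]) return true;
  return false;
}

// Exhaustive check: the prefilter never rejects a pattern with a long run.
bool prefilter_is_safe() {
  for (uint32_t v = 0; v < (1u << 16); v++)
    if (has_long_run(v) && !passes_prefilter(v)) return false;
  return true;
}
```

The converse does not hold: a pattern such as 0xEEEE has all four samples dark yet no run longer than 3, which is why the full $O(16)$ check is still needed after the prefilter.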

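The naive version is not the fastest because it has too many branches: processing a $1920 \times 1080$ image takes at least 800 ms. A better approach uses bitmasks. Precompute the 16-bit masks that have 12 contiguous bits set (the 16 cyclic rotations), then, writing the test with loop unrolling, decide directly in $O(16)$; to keep the code short, the unrolling is done with preprocessor macros. See the bitmask version. This brings the time down to 140 ms, a speedup of at least 5.7 times.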
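The bitmask idea can be sketched as follows, assuming the 16 circle pixels are packed into a uint16_t with bit k holding position k+1 (the same layout the T macro in the bitmask version produces). The masks are the 16 cyclic rotations of 0x0FFF, i.e. twelve contiguous one bits; a pattern is a corner when, for some mask, the masked bits are all 0 or all 1. The names rotl16, make_corner_masks and is_corner_bits are illustrative:

```cpp
#include <array>
#include <cassert>
#include <cstdint>

// Rotates a 16-bit value left by n bits (0 < n < 16).
uint16_t rotl16(uint16_t x, int n) {
  return static_cast<uint16_t>((x << n) | (x >> (16 - n)));
}

// The 16 cyclic rotations of 0x0FFF: mask i has 12 contiguous ones
// starting at bit i (bit k of a pattern = ring position k+1).
std::array<uint16_t, 16> make_corner_masks() {
  std::array<uint16_t, 16> masks{};
  uint16_t mask = 0x0FFF;
  for (int i = 0; i < 16; i++) {
    masks[i] = mask;
    mask = rotl16(mask, 1);
  }
  return masks;
}

// O(16) bitmask test: corner iff some 12-bit window is all 0 or all 1.
bool is_corner_bits(uint16_t v) {
  static const std::array<uint16_t, 16> masks = make_corner_masks();
  for (uint16_t mask : masks)
    if ((v & mask) == 0 || (v & mask) == mask) return true;
  return false;
}
```

This is the loop that the UNLOOPX macros unroll by hand. 0xEEFF (positions 9 and 13 dark, as in the second sample) is rejected, while 0xF000 (twelve dark pixels) is accepted.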

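Plain bitmasks are still not the fastest. Instead, build a table in $O(2^{16} \times 16)$ that records whether each 16-bit pattern is a corner. Building the table costs time, but each test becomes $O(1)$. See the bitmask2 version. This brings the time down to 76 ms, nearly twice as fast again.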
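A sketch of the lookup-table step, under the same bit layout. build_table (a hypothetical name) spends up to 16 mask tests on each of the 2^16 patterns, and afterwards classifying a pixel is a single array read. As in bitmask2, the all-zeros case reuses the all-ones test on the complemented pattern:

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// True if some cyclic 12-bit window of the 16-bit pattern v is all ones.
bool has_12_ones(uint16_t v) {
  uint32_t mask = 0x0FFF;  // twelve contiguous ones starting at bit 0
  for (int i = 0; i < 16; i++) {
    if ((v & mask) == mask) return true;
    mask = ((mask << 1) | (mask >> 15)) & 0xFFFF;  // rotate left by one
  }
  return false;
}

// table[v] == 1 iff pattern v is a corner: 12 contiguous ones, or 12
// contiguous zeros (= 12 contiguous ones in the complement ~v).
std::vector<uint8_t> build_table() {
  std::vector<uint8_t> table(1u << 16);
  for (uint32_t v = 0; v < (1u << 16); v++) {
    uint16_t bits = static_cast<uint16_t>(v);
    table[v] = has_12_ones(bits) || has_12_ones(static_cast<uint16_t>(~bits));
  }
  return table;
}
```

The table has one byte per entry, 64 KiB in total; bitmask3 goes one step further and stores the output characters '0'/'1' in it directly.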

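The final bitmask3 version runs in 56 ms, using the following techniques:

  1. Reduce the zero-extension (movz) instructions generated by type conversions by using the complement to replace one of the comparisons.
    That is, (a&mask) == 0 || (a&mask) == mask becomes just (a&mask) == mask; when the (a&mask) == 0 case is needed, first set a = ~a and then test (a&mask) == mask.
  2. Use compile-time constant folding to remove one multiplication from each 2-D array access.
    Addressing g[x][y] takes one multiplication and one addition, so each corner test performs 16 multiplications. If the array dimensions are known in advance, then for a given row, g[x] is computed once with a single multiplication, and every corner test along that row needs only 16 additions.
  3. Building the table is too slow, so __builtin_popcount() is used for pruning: a pattern with 5 to 11 one bits can contain neither 12 contiguous ones nor 12 contiguous zeros, so it is marked as not a corner without testing any mask.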
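The three tricks in the list above can be sanity-checked in isolation. This sketch (all names hypothetical) exhaustively verifies that (a&mask) == 0 is equivalent to (~a&mask) == mask, and that a pattern with 5 to 11 one bits is never a corner, since a corner needs at least 12 ones or at least 12 zeros (at most 4 ones). It also computes the flat offsets dx*W + dy at compile time, which reproduce the AB macros in bitmask3 for W = 2048. __builtin_popcount is a GCC/Clang builtin; C++20 std::popcount in <bit> is a portable alternative:

```cpp
#include <cassert>
#include <cstdint>

constexpr int kW = 2048;  // row stride, MAXW in bitmask3
constexpr int kDx[16] = {-3, -3, -2, -1, 0, 1, 2, 3, 3, 3, 2, 1, 0, -1, -2, -3};
constexpr int kDy[16] = {0, 1, 2, 3, 3, 3, 2, 1, 0, -1, -2, -3, -3, -3, -2, -1};

// Trick 2: flat offset of ring position k+1 from the center pixel, known at
// compile time, so g[x+dx][y+dy] becomes *(center + constant).
constexpr int ring_offset(int k) { return kDx[k] * kW + kDy[k]; }
static_assert(ring_offset(0) == -3 * kW, "matches AB00");
static_assert(ring_offset(4) == 3, "matches AB40");
static_assert(ring_offset(13) == -kW - 3, "matches AB121");

// Rotates a 16-bit mask (held in a uint32_t) left by one bit.
uint32_t rotl1(uint32_t m) { return ((m << 1) | (m >> 15)) & 0xFFFF; }

// True if some cyclic 12-bit window of v is all ones.
bool has_12_ones(uint32_t v) {
  uint32_t m = 0x0FFF;
  for (int i = 0; i < 16; i++, m = rotl1(m))
    if ((v & m) == m) return true;
  return false;
}

// Exhaustively checks tricks 1 and 3 over all 16-bit patterns.
bool tricks_hold() {
  for (uint32_t v = 0; v < (1u << 16); v++) {
    uint32_t nv = ~v & 0xFFFF;  // 16-bit complement
    uint32_t m = 0x0FFF;
    for (int i = 0; i < 16; i++, m = rotl1(m)) {
      // Trick 1: (v & m) == 0 exactly when (~v & m) == m.
      if (((v & m) == 0) != ((nv & m) == m)) return false;
    }
    // Trick 3: 5..11 one bits can be neither 12 ones nor 12 zeros.
    int ones = __builtin_popcount(v);
    if (ones > 4 && ones < 12 && (has_12_ones(v) || has_12_ones(nv)))
      return false;
  }
  return true;
}
```

The popcount pruning only speeds up table construction; the per-pixel work is unchanged, which is why it pays off mainly for small images where table building dominates.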

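Of course, rather than all these extreme tricks, it is far less trouble to compile the naive solution with g++ -O3 or g++ -Ofast, or to put the following pragma at the top of the source file; being a little slower is fine, right?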

#pragma GCC optimize ("O3")

Naive Solution

#include <bits/stdc++.h>
using namespace std;
// Circle offsets: index k is position k+1 of the mask.
const int dx[] = {-3, -3, -2, -1, 0, 1, 2, 3, 3, 3, 2, 1, 0, -1, -2, -3};
const int dy[] = {0, 1, 2, 3, 3, 3, 2, 1, 0, -1, -2, -3, -3, -3, -2, -1};
char g[2048][2048], ret[2048][2048];
int FAST(int x, int y) {
    // Prefilter: sample positions 1, 5, 9, 13; a run of 12 must cover
    // three cyclically consecutive samples with the same value.
    char c[4] = {g[x+dx[0]][y+dy[0]], g[x+dx[4]][y+dy[4]],
                 g[x+dx[8]][y+dy[8]], g[x+dx[12]][y+dy[12]]};
    if ((c[0] == c[1] && c[1] == c[2]) ||
        (c[1] == c[2] && c[2] == c[3]) ||
        (c[2] == c[3] && c[3] == c[0]) ||
        (c[3] == c[0] && c[0] == c[1])) {
        // Full O(16) scan. p = length of the run ending at i,
        // cc = length of the first run (the one starting at position 1).
        int cc = -1, p = 1;
        for (int it = 1, i = 1, j = 0; it < 16; it++, i++, i = i >= 16 ? 0 : i) {
            if (g[x+dx[i]][y+dy[i]] == g[x+dx[j]][y+dy[j]])
                j ++, j = j >= 16 ? 0 : j, p++;
            else {
                if (cc == -1)
                    cc = p;
                j = i, p = 1;
            }
            if (p >= 12)
                return 1;
        }
        // The ring is cyclic: the last run may continue into the first one.
        if (g[x+dx[0]][y+dy[0]] == g[x+dx[15]][y+dy[15]] && p+cc >= 12)
            return 1;
    }
    return 0;
}
int main() {
    int N, M, cases = 0;
    while (scanf("%d %d", &N, &M) == 2) {
        int ch;
        while ((ch = getchar()) != '\n' && ch != EOF);  // rest of header line
        for (int i = 0; i < N; i++)
            fgets(g[i], 2000, stdin);
        for (int i = 0; i < N; i++) {
            for (int j = 0; j < M; j++) {
                ret[i][j] = '0';
                // Only test pixels whose whole circle lies inside the image.
                if (i-3 >= 0 && j-3 >= 0 && i+3 < N && j+3 < M)
                    ret[i][j] = FAST(i, j) + '0';
            }
        }
        printf("Case #%d:\n", ++cases);
        for (int i = 0; i < N; i++) {
            ret[i][M] = '\0';
            puts(ret[i]);
        }
    }
    return 0;
}

bitmask

#include <bits/stdc++.h>
using namespace std;
// Circle offsets: index k is position k+1 of the mask.
const int dx[] = {-3, -3, -2, -1, 0, 1, 2, 3, 3, 3, 2, 1, 0, -1, -2, -3};
const int dy[] = {0, 1, 2, 3, 3, 3, 2, 1, 0, -1, -2, -3, -3, -3, -2, -1};
const int MAXH = 2048, MAXW = 2048;
char g[MAXH][MAXW], ret[MAXH*MAXW];
int n, m;
// T: pixel (0/1) at ring position z+1, shifted into bit z.
#define T(x, y, z) ((g[x+dx[z]][y+dy[z]])<<z)
// UNLOOPX: is any of the windows corner[i..i+3] all 0 or all 1 in val?
#define UNLOOPX(i) (val&corner[i]) == 0 || (val&corner[i]) == corner[i] || \
    (val&corner[i+1]) == 0 || (val&corner[i+1]) == corner[i+1] || \
    (val&corner[i+2]) == 0 || (val&corner[i+2]) == corner[i+2] || \
    (val&corner[i+3]) == 0 || (val&corner[i+3]) == corner[i+3]
#define UNLOOPXALL UNLOOPX(0) || UNLOOPX(4) || UNLOOPX(8) || UNLOOPX(12)
// UNLOOPY: pack ring positions i+1..i+4 into bits i..i+3.
#define UNLOOPY(i) T(x, y, i) | T(x, y, i+1) | T(x, y, i+2) | T(x, y, i+3)
#define UNLOOPYALL UNLOOPY(0) | UNLOOPY(4) | UNLOOPY(8) | UNLOOPY(12)
typedef unsigned short int UINT16;
// corner[i]: 12 contiguous one bits starting at bit i (cyclic).
UINT16 corner[16] = {};
UINT16 rotate_left(UINT16 x, UINT16 n) {
    return (x << n) | (x >> (16-n));
}
inline int FAST(int x, int y) {
    UINT16 val = UNLOOPYALL;  // the 16 ring pixels as a bitmask
    return UNLOOPXALL;
}
int main() {
    for (int i = 0, j = (1<<12)-1; i < 16; i++, j = rotate_left(j, 1))
        corner[i] = j;
    int cases = 0;
    while (scanf("%d %d", &n, &m) == 2) {
        int ch;
        while ((ch = getchar()) != '\n' && ch != EOF);  // rest of header line
        for (int i = 0; i < n; i++) {
            fgets(g[i], MAXW, stdin);
            for (int j = 0; j < m; j++)
                g[i][j] -= '0';  // '0'/'1' -> 0/1
        }
        char *p = ret;
        if (n < 7 || m < 7) {
            // No pixel has its whole circle inside the image.
            for (int i = 0; i < n; i++) {
                memset(p, '0', m), p += m;
                *p = '\n', p++;
            }
        } else {
            int bn = n-3, bm = m-3;
            for (int i = 0; i < 3; i++) {  // top border rows
                memset(p, '0', m), p += m;
                *p = '\n', p++;
            }
            for (int i = 3; i < bn; i++) {
                *p = '0', p++;  // left border
                *p = '0', p++;
                *p = '0', p++;
                for (int j = 3; j < bm; j++)
                    *p = FAST(i, j) + '0', p++;
                *p = '0', p++;  // right border
                *p = '0', p++;
                *p = '0', p++;
                *p = '\n', p++;
            }
            for (int i = 0; i < 3; i++) {  // bottom border rows
                memset(p, '0', m), p += m;
                *p = '\n', p++;
            }
        }
        *p = '\0', p++;
        printf("Case #%d:\n", ++cases);
        fputs(ret, stdout);  // ret already ends with '\n'
    }
    return 0;
}

bitmask2

#include <bits/stdc++.h>
#pragma GCC optimize ("O3")
using namespace std;
//const int dx[] = {-3, -3, -2, -1, 0, 1, 2, 3, 3, 3, 2, 1, 0, -1, -2, -3};
//const int dy[] = {0, 1, 2, 3, 3, 3, 2, 1, 0, -1, -2, -3, -3, -3, -2, -1};
const int MAXH = 2048, MAXW = 2048;
char g[MAXH][MAXW], ret[MAXH*MAXW];
int n, m;
// dx/dy from the tables above as constants: dx_zw / dy_zw is the offset of
// ring position z+w+1, so T() can paste them in at compile time.
#define dx_00 -3
#define dx_01 -3
#define dx_02 -2
#define dx_03 -1
#define dx_40 0
#define dx_41 1
#define dx_42 2
#define dx_43 3
#define dx_80 3
#define dx_81 3
#define dx_82 2
#define dx_83 1
#define dx_120 0
#define dx_121 -1
#define dx_122 -2
#define dx_123 -3
#define dy_00 0
#define dy_01 1
#define dy_02 2
#define dy_03 3
#define dy_40 3
#define dy_41 3
#define dy_42 2
#define dy_43 1
#define dy_80 0
#define dy_81 -1
#define dy_82 -2
#define dy_83 -3
#define dy_120 -3
#define dy_121 -3
#define dy_122 -2
#define dy_123 -1
// T: pixel (0/1) at ring position z+w+1, shifted into bit z+w.
#define T(x, y, z, w) ((g[x + dx_##z##w][y + dy_##z##w])<<(z+w))
// UNLOOPX: all-ones test only; the all-zeros case is handled by testing ~val.
#define UNLOOPX(i) ((val&cor[i]) == cor[i] || (val&cor[i+1]) == cor[i+1] || \
    (val&cor[i+2]) == cor[i+2] || (val&cor[i+3]) == cor[i+3])
#define UNLOOPXALL UNLOOPX(0) || UNLOOPX(4) || UNLOOPX(8) || UNLOOPX(12)
#define UNLOOPY(i) T(x, y, i, 0) | T(x, y, i, 1) | T(x, y, i, 2) | T(x, y, i, 3)
#define UNLOOPYALL UNLOOPY(0) | UNLOOPY(4) | UNLOOPY(8) | UNLOOPY(12)
typedef unsigned short int UINT16;
// cor[i]: 12 contiguous one bits starting at bit i (cyclic).
UINT16 cor[16] = {};
UINT16 rotate_left(UINT16 x, UINT16 n) {
    return (x << n) | (x >> (16-n));
}
// f[v] = 1 iff the 16-bit ring pattern v is a corner.
int f[1<<16];
int main() {
    UINT16 val;
    for (int i = 0, j = (1<<12)-1; i < 16; i++, j = rotate_left(j, 1))
        cor[i] = j;
    // Build the table once: O(2^16 x 16).
    for (int i = 0; i < 1<<16; i++) {
        val = i;
        f[i] = UNLOOPXALL ? 1 : (val = ~val, UNLOOPXALL);
    }
    int cases = 0;
    while (scanf("%d %d", &n, &m) == 2) {
        int ch;
        while ((ch = getchar()) != '\n' && ch != EOF);  // rest of header line
        for (int i = 0; i < n; i++) {
            fgets(g[i], MAXW, stdin);
            for (int j = 0; j < m; j++)
                g[i][j] -= '0';  // '0'/'1' -> 0/1
        }
        char *p = ret;
        if (n < 7 || m < 7) {
            // No pixel has its whole circle inside the image.
            for (int i = 0; i < n; i++) {
                memset(p, '0', m), p += m;
                *p = '\n', p++;
            }
        } else {
            int bn = n-3, bm = m-3;
            for (int i = 0; i < 3; i++) {  // top border rows
                memset(p, '0', m), p += m;
                *p = '\n', p++;
            }
            for (int x = 3; x < bn; x++) {
                *p = '0', p++, *p = '0', p++, *p = '0', p++;
// UNLOOP: pack the ring of (x, y), look it up in f, advance one pixel.
#define UNLOOP { \
    val = UNLOOPYALL; \
    *p = f[val] | '0'; \
    p++, y++; \
}
#define UNLOOP4 UNLOOP UNLOOP UNLOOP UNLOOP
#define UNLOOP8 {UNLOOP4 UNLOOP4}
                int y = 3;
                for (; y+8 < bm; )  // 8 pixels per iteration
                    UNLOOP8;
                for (; y < bm; )    // remaining pixels
                    UNLOOP;
                *p = '0', p++, *p = '0', p++, *p = '0', p++;
                *p = '\n', p++;
            }
            for (int i = 0; i < 3; i++) {  // bottom border rows
                memset(p, '0', m), p += m;
                *p = '\n', p++;
            }
        }
        *p = '\0', p++;
        printf("Case #%d:\n", ++cases);
        fputs(ret, stdout);  // ret already ends with '\n'
    }
    return 0;
}

bitmask3

#include <bits/stdc++.h>
#pragma GCC optimize ("O3")
using namespace std;
//const int dx[] = {-3, -3, -2, -1, 0, 1, 2, 3, 3, 3, 2, 1, 0, -1, -2, -3};
//const int dy[] = {0, 1, 2, 3, 3, 3, 2, 1, 0, -1, -2, -3, -3, -3, -2, -1};
#define MAXH 2048
#define MAXW 2048
char g[MAXH][MAXW], ret[MAXH*MAXW];
char *ptr_g;  // address of the pixel currently being tested
int n, m;
// ABzw = dx*MAXW + dy for ring position z+w+1: a compile-time constant
// offset from ptr_g, so no row multiplication is needed per access.
#define AB00 -3*MAXW
#define AB01 -3*MAXW+1
#define AB02 -2*MAXW+2
#define AB03 -MAXW+3
#define AB40 3
#define AB41 MAXW+3
#define AB42 2*MAXW+2
#define AB43 3*MAXW+1
#define AB80 3*MAXW
#define AB81 3*MAXW-1
#define AB82 2*MAXW-2
#define AB83 MAXW-3
#define AB120 -3
#define AB121 -MAXW-3
#define AB122 -2*MAXW-2
#define AB123 -3*MAXW-1
// T: pixel (0/1) at ring position z+w+1, shifted into bit z+w.
#define T(x, y, z, w) (*(ptr_g + AB##z##w)<<(z+w))
#define UNLOOPX(i) ((val&cor[i]) == cor[i] || (val&cor[i+1]) == cor[i+1] || \
    (val&cor[i+2]) == cor[i+2] || (val&cor[i+3]) == cor[i+3])
#define UNLOOPXALL UNLOOPX(0) || UNLOOPX(4) || UNLOOPX(8) || UNLOOPX(12)
#define UNLOOPY(i) T(x, y, i, 0) | T(x, y, i, 1) | T(x, y, i, 2) | T(x, y, i, 3)
#define UNLOOPYALL UNLOOPY(0) | UNLOOPY(4) | UNLOOPY(8) | UNLOOPY(12)
typedef unsigned short int UINT16;
// cor[i]: 12 contiguous one bits starting at bit i (cyclic).
UINT16 cor[16] = {};
UINT16 rotate_left(UINT16 x, UINT16 n) {
    return (x << n) | (x >> (16-n));
}
// f[v] = '1' iff the 16-bit ring pattern v is a corner, else '0'.
char f[1<<16];
int main() {
    UINT16 val;
    for (int i = 0, j = (1<<12)-1; i < 16; i++, j = rotate_left(j, 1))
        cor[i] = j;
    for (int i = 0, one; i < 1<<16; i++) {
        val = i, one = __builtin_popcount(val);
        // A corner needs >= 12 ones or <= 4 ones: prune everything between.
        if (one < 12 && one > 4)
            f[i] = 0;
        else
            f[i] = UNLOOPXALL ? 1 : (val = ~val, UNLOOPXALL);
        f[i] |= '0';  // store the output character directly
    }
    int cases = 0;
    while (scanf("%d %d", &n, &m) == 2) {
        int ch;
        while ((ch = getchar()) != '\n' && ch != EOF);  // rest of header line
        for (int i = 0; i < n; i++) {
            fgets(g[i], MAXW, stdin);
            for (int j = 0; j < m; j++)
                g[i][j] -= '0';  // '0'/'1' -> 0/1
        }
        char *p = ret;
        if (n < 7 || m < 7) {
            // No pixel has its whole circle inside the image.
            for (int i = 0; i < n; i++) {
                memset(p, '0', m), p += m;
                *p = '\n', p++;
            }
        } else {
            int bn = n-3, bm = m-3;
            for (int i = 0; i < 3; i++) {  // top border rows
                memset(p, '0', m), p += m;
                *p = '\n', p++;
            }
            for (int x = 3, y; x < bn; x++) {
                *p = '0', p++, *p = '0', p++, *p = '0', p++;
                ptr_g = g[x]+3;  // one row-address computation per row
// UNLOOP: pack the ring around ptr_g, look it up in f, advance one pixel.
#define UNLOOP { \
    val = UNLOOPYALL; \
    *p = f[val]; \
    p++, y++, ptr_g++; \
}
#define UNLOOP4 UNLOOP UNLOOP UNLOOP UNLOOP
#define UNLOOP8 {UNLOOP4 UNLOOP4}
                for (y = 3; y+8 < bm; )  // 8 pixels per iteration
                    UNLOOP8;
                for (; y < bm; )         // remaining pixels
                    UNLOOP;
                *p = '0', p++, *p = '0', p++, *p = '0', p++;
                *p = '\n', p++;
            }
            for (int i = 0; i < 3; i++) {  // bottom border rows
                memset(p, '0', m), p += m;
                *p = '\n', p++;
            }
        }
        *p = '\0', p++;
        printf("Case #%d:\n", ++cases);
        fputs(ret, stdout);  // ret already ends with '\n'
    }
    return 0;
}