批改娘 10027. Fast Sudoku

contents

  1. 1. 背景
  2. 2. 題目描述
  3. 3. 輸入格式
  4. 4. 輸出格式
  5. 5. 範例輸入
  6. 6. 範例輸出
  7. 7. 範例輸入 2
  8. 8. 範例輸出 2
  9. 9. Solution
    1. 9.1. IDA
    2. 9.2. Task
    3. 9.3. DLX

背景

不過,還想做得更好

但這個世界卻要求神通廣大、無所不知,那不聰明的我又該何去何從 …

題目描述

給一張 $9 \times 9$ 數獨,必須滿足每行每列的 1 到 9 的數字不重複,同時在九個 $3 \times 3$ 方格內的 1 到 9 數字也不重複。請計算可填入的方法數。

輸入格式

測資只有一筆,共有九行,每一行上有九個整數,第 $i$ 行上的第 $j$ 個整數表示 $\text{grid}[i][j]$ 填入的數字,若 $\text{grid}[i][j] = 0$ 表示尚未填入。

輸出格式

對於每一組測資輸出一行一個整數,表示數獨共有幾種填法。

範例輸入

0 6 0 0 0 4 0 5 0
0 0 8 3 0 5 6 0 0
2 0 0 0 0 0 0 0 1
8 0 0 4 0 7 0 0 6
0 0 6 0 0 0 3 0 0
7 0 0 9 0 1 0 0 4
5 0 0 0 0 0 0 0 2
0 0 7 2 0 6 9 0 0
0 4 0 5 0 8 0 7 0

範例輸出

2

範例輸入 2

0 0 0 0 0 0 0 0 0
0 0 3 6 0 0 0 0 0
0 7 0 0 9 0 2 0 0
0 5 0 0 0 7 0 0 0
0 0 0 0 4 5 7 0 0
0 0 0 1 0 0 0 3 0
0 0 1 0 0 0 0 6 8
0 0 8 5 0 0 0 1 0
0 9 0 0 0 0 4 0 0

範例輸出 2

292

Solution

牽涉到 load balance 問題,大部分都要先廣度搜尋一次,把狀態展開後,再利用 dynamic scheduling 進行分配工作,效果會好上許多。為了解決狀態展開而後不浪費記憶體空間,IDA* 會是一種好的選擇。當然在 OpenMP 3.0 以上有提供 task 來幫忙做到類似的事情,但搜索展開會變成寫死,估計上會比較困難。

  • IDA: Accepted (946 ms, 10752 KB)
  • OpenMP task: Accepted (1321 ms, 2560 KB)
  • DLX: Accepted (1208 ms, 56940 KB)

當然 DLX 在建表處理上會變得很麻煩,overhead 相較於其他算法高出許多,因此狀態展開不可以太多,畢竟它本身就搜得非常快。撰寫這一類搜索問題時,要特別小心 OpenMP `private()` 子句的使用:複製操作原則上是根據 `sizeof()` 決定複製空間大小,因此若複製目標為函數參數(陣列參數會退化為指標),很容易只複製到指標本身。

IDA

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <stdbool.h>
#include <limits.h>
#include <assert.h>
/* incomplete_bfs() stops deepening once at least MAXN partial states exist. */
#define MAXN 65536
/* C[k] = {row, column} of the k-th empty cell; N = number of empty cells. */
int C[81][2], N;
/*
 * R[i][0], R[i][1], R[i][2]: bitmasks of the digits already present in
 * row i, column i and 3x3 box i of the input (bit d set <=> digit d+1 used).
 * P stores the partial states emitted by IDA(); Pidx counts how many.
 */
int R[9][3] = {}, P[MAXN*9][9][3], Pidx;
/* Number of empty cells IDA() fills before it records a state into P. */
int ida_dep;
/*
 * Depth-limited enumeration of partial fillings.
 *
 * idx   - index into C of the next empty cell to fill.
 * state - row/column/box digit masks (same layout as R); modified while
 *         recursing and restored before returning.
 *
 * Every consistent assignment of the first ida_dep empty cells is appended
 * to P (and Pidx is advanced).
 */
void IDA(int idx, int state[][3]) {
    if (idx == ida_dep) {
        /* Depth limit reached: snapshot the masks as one frontier state. */
        memcpy(P[Pidx], state, sizeof(P[0]));
        Pidx++;
        return;
    }

    const int row = C[idx][0];
    const int col = C[idx][1];
    const int box = row / 3 * 3 + col / 3;

    for (int digit = 0; digit < 9; digit++) {
        const int bit = 1 << digit;
        const int taken = (state[row][0] | state[col][1] | state[box][2]) & bit;
        if (!taken) {
            /* Place the digit, recurse, then take it back. */
            state[row][0] |= bit;
            state[col][1] |= bit;
            state[box][2] |= bit;

            IDA(idx + 1, state);

            state[row][0] &= ~bit;
            state[col][1] &= ~bit;
            state[box][2] &= ~bit;
        }
    }
}
/*
 * Count the ways to fill empty cells idx .. N-1 given the digit masks in
 * state (same layout as R).
 *
 * state is modified during the search and restored before returning.
 * Returns the number of complete, conflict-free fillings.
 */
int dfs(int idx, int state[][3]) {
    if (idx == N)
        return 1;

    const int row = C[idx][0];
    const int col = C[idx][1];
    const int box = row / 3 * 3 + col / 3;
    int total = 0;

    for (int digit = 0; digit < 9; digit++) {
        const int bit = 1 << digit;
        const int taken = (state[row][0] | state[col][1] | state[box][2]) & bit;
        if (!taken) {
            state[row][0] |= bit;
            state[col][1] |= bit;
            state[box][2] |= bit;

            total += dfs(idx + 1, state);

            /* The bit was clear before we set it, so clearing restores it. */
            state[box][2] &= ~bit;
            state[col][1] &= ~bit;
            state[row][0] &= ~bit;
        }
    }
    return total;
}
/*
 * Count all completions of the puzzle described by the globals C, N and R.
 *
 * The search is deepened one empty cell at a time with IDA() until at least
 * MAXN partial states have been collected in P; those states are then
 * finished independently by dfs() inside an OpenMP parallel loop.
 *
 * Returns the number of completions.  A grid with no empty cells has exactly
 * one completion (itself); the givens are not checked for conflicts here.
 */
int incomplete_bfs() {
    /*
     * With N == 0 the deepening loop below would never run and the stale
     * Pidx (0) would be returned, reporting zero completions for a grid
     * that is already full.
     */
    if (N == 0)
        return 1;

    for (ida_dep = 1; ida_dep <= N; ida_dep++) {
        Pidx = 0;
        IDA(0, R);
        if (Pidx >= MAXN)
            break;
    }

    /* The loop ran to depth N without overflowing: P holds full solutions. */
    if (ida_dep == N+1)
        return Pidx;

    int ret = 0;
    /* Each iteration works on its own P[i], so no state is shared. */
    #pragma omp parallel for schedule(guided) reduction(+: ret)
    for (int i = 0; i < Pidx; i++)
        ret += dfs(ida_dep, P[i]);
    return ret;
}
/*
 * Read a 9x9 grid from stdin (0 = empty cell), build the row/column/box
 * digit masks in R and the empty-cell list in C, then print the number of
 * completions.
 *
 * Returns 0 on success, 1 if the input is truncated, non-numeric or holds a
 * value outside 0..9.
 */
int main() {
    int g[9][9], n = 9;
    for (int i = 0; i < n; i++) {
        for (int j = 0; j < n; j++) {
            /*
             * Reject bad input up front: an unread cell would be used
             * uninitialised and an out-of-range digit would feed 1<<g below.
             */
            if (scanf("%d", &g[i][j]) != 1 || g[i][j] < 0 || g[i][j] > 9) {
                fprintf(stderr, "invalid sudoku input\n");
                return 1;
            }
        }
    }

    N = 0;
    memset(R, 0, sizeof(R));
    for (int i = 0; i < n; i++) {
        for (int j = 0; j < n; j++) {
            if (g[i][j]) {
                /* Digits are stored zero-based so digit d maps to bit d-1. */
                g[i][j]--;
                R[i][0] |= 1<<g[i][j];
                R[j][1] |= 1<<g[i][j];
                R[i/3*3+j/3][2] |= 1<<g[i][j];
            } else {
                C[N][0] = i, C[N][1] = j;
                N++;
            }
        }
    }

    int ret = incomplete_bfs();
    printf("%d\n", ret);
    return 0;
}

Task

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <stdbool.h>
#include <limits.h>
#include <assert.h>
/* Not referenced by the functions in this listing. */
#define MAXN 65536
/* C[k] = {row, column} of the k-th empty cell; N = number of empty cells. */
int C[81][2], N;
/*
 * R[i][0], R[i][1], R[i][2]: bitmasks of the digits already present in
 * row i, column i and 3x3 box i of the input (bit d set <=> digit d+1 used).
 */
int R[9][3] = {};
/* Solutions found so far; declared threadprivate on the next line. */
int threadcount = 0;
#pragma omp threadprivate(threadcount)
/*
 * Sequential backtracking over empty cells idx .. N-1.
 *
 * state holds the row/column/box digit masks (same layout as R); it is
 * modified while recursing and restored before returning.  Each complete
 * filling increments threadcount.
 */
void dfs_serial(int idx, int state[][3]) {
    if (idx == N) {
        threadcount++;
        return;
    }

    const int row = C[idx][0];
    const int col = C[idx][1];
    const int box = row / 3 * 3 + col / 3;

    for (int digit = 0; digit < 9; digit++) {
        const int bit = 1 << digit;
        const int taken = (state[row][0] | state[col][1] | state[box][2]) & bit;
        if (!taken) {
            state[row][0] |= bit;
            state[col][1] |= bit;
            state[box][2] |= bit;

            dfs_serial(idx + 1, state);

            /* The bit was clear before we set it, so clearing restores it. */
            state[box][2] &= ~bit;
            state[col][1] &= ~bit;
            state[row][0] &= ~bit;
        }
    }
}
/*
 * Task-based search over empty cells idx .. N-1.
 *
 * state holds the row/column/box digit masks (same layout as R) and is only
 * read here; every branch works on a private copy.  Each complete filling
 * increments threadcount.
 *
 * NOTE(review): threadcount is threadprivate while the tasks below are
 * untied - confirm against the OpenMP specification that this combination
 * is safe on the target runtime.
 */
void dfs(int idx, int state[][3]) {
if (idx == N) {
threadcount++;
return ;
}
int x = C[idx][0], y = C[idx][1];
int c = x/3*3 + y/3;
// One task is created per digit; the conflict test runs inside the task,
// so a task is created even for digits that turn out to be unusable.
for (int i = 0; i < 9; i++)
#pragma omp task untied
{
int mask = 1<<i;
if ((state[x][0]&mask) || (state[y][1]&mask) ||
(state[c][2]&mask)) {
} else {
int S[9][3] = {};
memcpy(S, state, sizeof(S));
// S is a local copy of the masks, so this branch never writes to the
// caller's state.
S[x][0] |= mask;
S[y][1] |= mask;
S[c][2] |= mask;
// From the 7th empty cell on (idx >= 6) stop creating tasks and finish
// the branch with the sequential search.
if (idx >= 6) {
dfs_serial(idx+1, S);
} else {
dfs(idx+1, S);
}
}
}
// Wait for the tasks created above before returning; they read `state`,
// which may be a local array (S) of the calling task.
#pragma omp taskwait
}
/*
 * Count all completions of the puzzle described by the globals C, N and R.
 *
 * One thread starts the task-based search; afterwards every thread adds its
 * own threadcount to the total.  threadcount is not reset here, so the
 * result is only meaningful if the counters are still zero on entry.
 */
int incomplete_bfs() {
int ret = 0;
#pragma omp parallel
{
// Only one thread seeds the search; the tasks it creates may be run by
// any thread of the team.
#pragma omp single
{
dfs(0, R);
}
// Executed by every thread after the single construct: fold the
// per-thread counters into ret one thread at a time.
#pragma omp critical
ret += threadcount;
}
return ret;
}
/*
 * Read a 9x9 grid from stdin (0 = empty cell), build the row/column/box
 * digit masks in R and the empty-cell list in C, then print the number of
 * completions.
 *
 * Returns 0 on success, 1 if the input is truncated, non-numeric or holds a
 * value outside 0..9.
 */
int main() {
    int g[9][9], n = 9;
    for (int i = 0; i < n; i++) {
        for (int j = 0; j < n; j++) {
            /*
             * Reject bad input up front: an unread cell would be used
             * uninitialised and an out-of-range digit would feed 1<<g below.
             */
            if (scanf("%d", &g[i][j]) != 1 || g[i][j] < 0 || g[i][j] > 9) {
                fprintf(stderr, "invalid sudoku input\n");
                return 1;
            }
        }
    }

    N = 0;
    memset(R, 0, sizeof(R));
    for (int i = 0; i < n; i++) {
        for (int j = 0; j < n; j++) {
            if (g[i][j]) {
                /* Digits are stored zero-based so digit d maps to bit d-1. */
                g[i][j]--;
                R[i][0] |= 1<<g[i][j];
                R[j][1] |= 1<<g[i][j];
                R[i/3*3+j/3][2] |= 1<<g[i][j];
            } else {
                C[N][0] = i, C[N][1] = j;
                N++;
            }
        }
    }

    int ret = incomplete_bfs();
    printf("%d\n", ret);
    return 0;
}

DLX

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <stdbool.h>
#include <limits.h>
#include <assert.h>
// Upper bound on node indices, asserted in DLX_newnode.
#define MAXN 65536
// Size of the per-column counter array; also extra slack in the node pool.
#define MAXR 1024
typedef unsigned char byte;
// One cell of the linked matrix. All links are indices into the node array.
typedef struct {
int left, right, up, down;
// Index of the column header this node belongs to.
int ch;
} Node;
// DLX header
// Table of entry points; `sudoku` takes a 9x9 board and returns an int count.
typedef struct {
int (* const sudoku) (byte [][9]);
} DLX_namespace;
// DLX body begin
/*
 * Cover column c: unlink its header from the header row, then unlink every
 * node of every row that has an entry in c from that node's own column,
 * decrementing the affected column counters in s.
 */
void DLX_remove(int c, Node DL[], int s[]) {
    const int hdr_left = DL[c].left;
    const int hdr_right = DL[c].right;
    DL[hdr_right].left = hdr_left;
    DL[hdr_left].right = hdr_right;

    int r = DL[c].down;
    while (r != c) {
        int e = DL[r].right;
        while (e != r) {
            const int above = DL[e].up;
            const int below = DL[e].down;
            DL[below].up = above;
            DL[above].down = below;
            s[DL[e].ch]--;
            e = DL[e].right;
        }
        r = DL[r].down;
    }
}
/*
 * Uncover column c: relink every node of every row that has an entry in c
 * into that node's own column (walking each row leftwards), incrementing
 * the column counters in s, then relink the header of c into the header row.
 */
void DLX_resume(int c, Node DL[], int s[]) {
    int r = DL[c].down;
    while (r != c) {
        int e = DL[r].left;
        while (e != r) {
            DL[DL[e].down].up = e;
            DL[DL[e].up].down = e;
            s[DL[e].ch]++;
            e = DL[e].left;
        }
        r = DL[r].down;
    }

    DL[DL[c].right].left = c;
    DL[DL[c].left].right = c;
}
/*
 * Count the exact covers of the matrix currently linked under `head`.
 *
 * k    - recursion depth; passed along but not otherwise used here.
 * DL   - node array; s - per-column node counts; head - root header index.
 *
 * Returns the number of solutions reachable from the current state. The
 * matrix is restored to its state on entry before returning.
 */
int DLX_dfs(int k, Node DL[], int s[], int head) {
// No columns left to cover: one complete solution.
if (DL[head].right == head)
return 1;
// Choose the first remaining column with the smallest count.
int t = INT_MAX, c;
for (int i = DL[head].right; i != head; i = DL[i].right) {
if (s[i] < t) {
t = s[i], c = i;
}
}
int ans = 0;
DLX_remove(c, DL, s);
// Try every row that still has an entry in column c.
for (int i = DL[c].down; i != c; i = DL[i].down) {
// Cover the other columns of this row ...
for (int j = DL[i].right; j != i; j = DL[j].right)
DLX_remove(DL[j].ch, DL, s);
ans += DLX_dfs(k+1, DL, s, head);
// ... and uncover them walking the row in the opposite direction.
for (int j = DL[i].left; j != i; j = DL[j].left)
DLX_resume(DL[j].ch, DL, s);
}
DLX_resume(c, DL, s);
return ans;
}
/*
 * Allocate node number *size, link it between the given neighbours and
 * advance *size.
 *
 * up/down/left/right - indices of the neighbours; passing the new node's own
 *                      index makes it self-linked in that direction.
 * Returns the index of the new node.
 */
int DLX_newnode(int up, int down, int left, int right, Node DL[], int *size) {
    /* Check capacity before touching DL[*size], not after writing to it. */
    assert(*size < MAXN);

    const int id = *size;
    DL[id].up = up;
    DL[id].down = down;
    DL[id].left = left;
    DL[id].right = right;

    /* Splice the node into both of its circular lists. */
    DL[up].down = id;
    DL[down].up = id;
    DL[left].right = id;
    DL[right].left = id;

    *size = id + 1;
    return id;
}
/*
 * Append one matrix row with an entry in each of the n columns in Row.
 *
 * Row  - column header indices, one per entry.
 * s    - per-column node counts; incremented for each listed column.
 * size - next free node index; advanced by n.
 *
 * Each entry is added at the bottom of its column. The first entry is
 * created self-linked and the rest are inserted to its left, which keeps
 * the row circular.
 */
void DLX_newrow(int n, int Row[], Node DL[], int s[], int *size) {
    int first = -1;
    for (int a = 0; a < n; a++) {
        const int col = Row[a];
        /* ch is set before DLX_newnode links the node and bumps *size. */
        DL[*size].ch = col;
        s[col]++;
        if (first == -1) {
            first = DLX_newnode(DL[DL[col].ch].up, DL[col].ch,
                                *size, *size, DL, size);
        } else {
            DLX_newnode(DL[DL[col].ch].up, DL[col].ch,
                        DL[first].left, first, DL, size);
        }
    }
}
/*
 * Reset the node pool and create the root header plus m column headers
 * (indices 1..m), linked into one circular header row with zeroed counters.
 *
 * On return *head is the root's index and *size the next free node index.
 */
void DLX_init(int m, Node DL[], int s[], int *size, int *head) {
    *size = 0;
    /* The root starts out linked to itself in all four directions. */
    *head = DLX_newnode(0, 0, 0, 0, DL, size);

    int col = 1;
    while (col <= m) {
        /* Append the header at the right end of the header row. */
        DLX_newnode(col, col, DL[*head].left, *head, DL, size);
        DL[col].ch = col;
        s[col] = 0;
        col++;
    }
}
/*
 * Count the completions of a 9x9 sudoku by building an exact-cover matrix
 * and running DLX_dfs on it.
 *
 * g[i][j] is the digit already placed in cell (i, j), or 0 if the cell is
 * empty.
 * Returns 0 if two givens clash within a row, column or box; otherwise the
 * value returned by DLX_dfs.
 */
int DLX_sudoku(byte g[][9]) {
// Node pool and column counters are automatic (per-call) arrays.
Node DL[MAXN + MAXR];
int s[MAXR], head, size;
int row[10];
// used[0][i][j]: cell (i, j) is a given.
// used[1][i][x], used[2][j][x], used[3][y][x]: digit x is already present
// in row i, column j, box y.
int used[4][10][10] = {}, isValid = 1;
// tn is not used below.
int n = 9, tn = 3;
for (int i = 0; i < n; i++) {
for (int j = 0; j < n; j++) {
if (g[i][j]) {
int x = g[i][j];
int y = i/3*3 + j/3;
// A digit seen twice in the same unit makes the puzzle unsolvable.
if (used[1][i][x]) isValid = 0;
if (used[2][j][x]) isValid = 0;
if (used[3][y][x]) isValid = 0;
used[0][i][j] = 1;
used[1][i][x] = used[2][j][x] = used[3][y][x] = 1;
}
}
}
if (!isValid)
return 0;
// label[k][..][..] gives each still-open constraint a 1-based column id.
// OFF[k] is the running total, so ids of group k continue after group k-1.
int OFF[4] = {};
int label[4][10][10] = {};
// Group 0: one column per empty cell.
for (int i = 0; i < n; i++) {
for (int j = 0; j < n; j++) {
if (used[0][i][j] == 0)
label[0][i][j] = ++OFF[0];
}
}
// Groups 1..3: one column per (row, digit), (column, digit), (box, digit)
// pair that the givens have not already satisfied.
for (int k = 1; k < 4; k++) {
OFF[k] = OFF[k-1];
for (int i = 0; i < n; i++) {
for (int j = 1; j <= n; j++) {
if (used[k][i][j] == 0) {
label[k][i][j] = ++OFF[k];
}
}
}
}
DLX_init(OFF[3], DL, s, &size, &head);
// One matrix row per (empty cell, digit) candidate that no given rules out;
// it covers the cell column and the three digit columns.
for (int i = 0; i < n; i++) {
for (int j = 0; j < n; j++) {
if (g[i][j]) continue;
for (int k = 1; k <= n; k++) {
int x = k;
int y = i/3*3 + j/3;
if (used[1][i][x]) continue;
if (used[2][j][x]) continue;
if (used[3][y][x]) continue;
row[0] = label[0][i][j];
row[1] = label[1][i][k];
row[2] = label[2][j][k];
row[3] = label[3][y][k];
DLX_newrow(4, row, DL, s, &size);
}
}
}
return DLX_dfs(0, DL, s, head);
}
// Single instance of the entry-point table; DLX.sudoku is DLX_sudoku.
DLX_namespace const DLX = {DLX_sudoku};
// DLX body end
// parallel utils
// A partially filled grid together with its occupancy flags.
typedef struct {
byte b[9][9]; // cell values; 0 means the cell is empty
// r[x][d], c[y][d], g[k][d]: digit d+1 is present in row x, column y,
// 3x3 box k.
bool r[9][9], c[9][9], g[9][9];
} Board;
/* Index (0..8, row-major) of the 3x3 box containing cell (x, y). */
static inline int gridID(int x, int y) {
    const int band = x / 3;   /* which horizontal band of boxes */
    const int stack = y / 3;  /* which vertical stack of boxes */
    return band * 3 + stack;
}
/* Clear a board: every cell empty and every occupancy flag false. */
void board_init(Board *b) {
    const size_t nbytes = sizeof *b;
    memset(b, 0, nbytes);
}
/*
 * Can digit v (1..9) be placed at cell (x, y)?
 * True only if the cell is empty and v is not yet present in the cell's
 * row, column or 3x3 box.
 */
bool board_test(const Board *b, int x, int y, int v) {
    const int d = v - 1;

    if (b->b[x][y] != 0)
        return false;
    if (b->r[x][d] != 0)
        return false;
    if (b->c[y][d] != 0)
        return false;
    return b->g[gridID(x, y)][d] == 0;
}
/*
 * Write digit v (1..9) into cell (x, y) and mark it as present in the
 * cell's row, column and 3x3 box. No conflict check is performed.
 *
 * Always returns true. The original had no return statement despite the
 * bool return type.
 */
bool board_fill(Board *b, int x, int y, int v) {
    b->b[x][y] = v;
    b->r[x][v-1] = b->c[y][v-1] = b->g[gridID(x, y)][v-1] = 1;
    return true;
}
/*
 * Count the completions of grid g (0 = empty cell).
 *
 * Empty cells are filled one at a time, breadth first, keeping every
 * consistent partial board in a double-buffered queue. As soon as a newly
 * built level holds at least MAXQMEM boards, the boards of the level it was
 * built from are each solved with DLX.sudoku in parallel and the counts are
 * summed. If that never happens, the size of the final level is the answer.
 */
int incomplete_bfs(int g[][9]) {
Board origin;
int n = 9;
// A[k] = {row, column} of the k-th empty cell; m = number of empty cells.
int A[81][2], m = 0;
board_init(&origin);
for (int i = 0; i < n; i++) {
for (int j = 0; j < n; j++) {
if (g[i][j]) {
board_fill(&origin, i, j, g[i][j]);
} else {
A[m][0] = i, A[m][1] = j, m++;
}
}
}
#define MAXQMEM 8192
// Two levels of boards; Qflag selects the current one, !Qflag the next.
// A level smaller than MAXQMEM expands to fewer than MAXQMEM*9 boards.
static Board Q[2][MAXQMEM*9];
int Qcnt[2] = {}, Qflag = 0;
Q[Qflag][Qcnt[Qflag]] = origin, Qcnt[Qflag]++;
for (int it = 0; it < m; it++) {
const int x = A[it][0], y = A[it][1];
Qcnt[!Qflag] = 0;
// Expand every board of the current level by each digit that fits (x, y).
#pragma omp parallel for
for (int i = 0; i < Qcnt[Qflag]; i++) {
for (int j = 1; j <= 9; j++) {
if (!board_test(&Q[Qflag][i], x, y, j))
continue;
int pIdx;
// Reserve a unique slot in the next level; the copy itself happens
// outside the critical section.
#pragma omp critical
{
pIdx = Qcnt[!Qflag]++;
}
Q[!Qflag][pIdx] = Q[Qflag][i];
board_fill(&Q[!Qflag][pIdx], x, y, j);
}
}
// Enough work items: solve the boards of the current level (Q[Qflag], not
// the level just built) and sum the per-board counts.
if (Qcnt[!Qflag] >= MAXQMEM) {
int ret = 0;
#pragma omp parallel for reduction(+:ret)
for (int i = 0; i < Qcnt[Qflag]; i++) {
ret += DLX.sudoku(Q[Qflag][i].b);
}
return ret;
}
Qflag = !Qflag;
}
// Every empty cell was filled by the expansion; each board is a solution.
return Qcnt[Qflag];
}
/*
 * Read a 9x9 grid from stdin (0 = empty cell) and print the number of
 * completions.
 *
 * Returns 0 on success, 1 if the input is truncated, non-numeric or holds a
 * value outside 0..9.
 */
int main() {
    int g[9][9], n = 9;
    for (int i = 0; i < n; i++) {
        for (int j = 0; j < n; j++) {
            /*
             * Reject bad input up front: an unread cell would be used
             * uninitialised, and a digit above 9 would index past the
             * per-digit arrays of Board.
             */
            if (scanf("%d", &g[i][j]) != 1 || g[i][j] < 0 || g[i][j] > 9) {
                fprintf(stderr, "invalid sudoku input\n");
                return 1;
            }
        }
    }

    int ret = incomplete_bfs(g);
    printf("%d\n", ret);
    return 0;
}