批改娘 10111. Longest Common Subsequence II (OpenMP)

contents

  1. 1. 題目描述
  2. 2. 輸入格式
  3. 3. 輸出格式
  4. 4. 範例輸入
  5. 5. 範例輸出
  6. 6. Solution
    1. 6.1. fine-grain
    2. 6.2. coarse-grain

題目描述

給兩個字串 $X, \; Y$,在兩個字串中都有出現且最長的子序列 (subsequence),就是最長共同子序列 (Longest Common Subsequence, LCS)。

輸入格式

有多組測資,每組測資有兩行字串 $X, \; Y$,$X, \; Y$ 只由 A T C G 四個字母構成。

  • $1 \le |X|, |Y| \le 60000$

輸出格式

針對每一組測資,輸出兩行:第一行為 $X, \; Y$ 的最長共同子序列長度,第二行為任何一組最長共同子序列。

範例輸入

TCA
GTA
TGGAC
TATCT

範例輸出

2
TA
3
TAC

Solution

數據小時,可以建立 $O(N^2)$ 大小的表來協助回溯找解;但當 $N$ 非常大時,就無法儲存這張表。因此,可以採用分治算法,在只保留線性空間的情況下逐段還原回溯路徑,也就是所謂的 Hirschberg’s algorithm,相關的 demo 在此。

時間複雜度為 $O(N^2)$,空間複雜度為 $O(N)$。平行部分可以採用 fine-grain 設計,也就是平行化每一列的 DP 計算,那麼就會有不斷建立 thread 的 overhead,但更容易達到負載平衡 (workload balance)。另一種 coarse-grain 設計則是在遞迴分治下,拆成好幾個 thread 完成各自的子問題,這樣比較容易利用 cache 的性質。

從實作中,fine-grain 比 coarse-grain 好上許多。

fine-grain

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <assert.h>
#include <omp.h>
/* Capacity of every per-column table; the stated input limit is 60000. */
#define MAXN 60005
/* Inputs shorter than this on both sides are solved directly by find_lcs_seq. */
#define MAXSEQ 500
/* Alphabet size: A, C, G, T. */
#define CHARSET 4
/* Fully parenthesised so the macro composes safely inside larger expressions
 * (the unparenthesised form turned `2 * max(1, 2)` into 2).
 * Arguments are still evaluated more than once: do not pass side effects. */
#define max(x, y) ((x) > (y) ? (x) : (y))
/* Maps a nucleotide letter to its row index; every other character maps to 0. */
static __thread int c2i[128] = {['A'] = 0, ['C'] = 1, ['G'] = 2, ['T'] = 3};
/**
 * Computes the last row of the LCS table of A (length na) against B
 * (length nb), parallelising each row over the columns with OpenMP.
 *
 * With 1-based positions in B:
 *   inv_flag == 0: dpf[j] = LCS(A, B[1..j])   (prefixes of B)
 *   inv_flag != 0: dpf[j] = LCS(A, B[j..nb])  (suffixes of B)
 * for j in 1..nb; dpf[0] and dpf[nb+1] are always set to 0, so dpf must
 * have room for nb + 2 entries (nb + 2 <= MAXN).
 *
 * Instead of the classic three-way recurrence the row update uses a
 * "nearest occurrence" table P, which removes the dependency between
 * neighbouring columns and makes the inner loop embarrassingly parallel.
 *
 * Returns the LCS length of A and B in both modes.
 * Not reentrant: the work tables have static storage.
 */
int lcs_len(const char *A, int na, const char *B, int nb, int inv_flag, int dpf[]) {
    static int P[CHARSET][MAXN];
    /* Static like P: keeps ~480 KB off the stack of a recursive caller. */
    static int dp[2][MAXN];

    /* Only columns 0..nb+1 are ever touched. */
    for (int r = 0; r < 2; r++)
        memset(dp[r], 0, sizeof(int) * (size_t) (nb + 2));

    if (!inv_flag) {
        /* P[c][j] = (last position <= j holding letter c) - 1,
         * or the sentinel column nb+1 when there is none. */
        #pragma omp parallel for
        for (int i = 0; i < CHARSET; i++) {
            P[i][0] = nb + 1;
            for (int j = 1; j <= nb; j++)
                P[i][j] = (B[j-1] == "ACGT"[i]) ? j - 1 : P[i][j-1];
        }
        /* Sentinel column: -1 + 1 == 0 means "no match available". */
        dp[0][nb+1] = dp[1][nb+1] = -1;
        /* Every thread walks the rows; the columns of each row are shared
         * out by the inner worksharing loop, whose implicit barrier
         * separates consecutive rows. */
        #pragma omp parallel
        for (int i = 1; i <= na; i++) {
            /* Mask keeps the index inside the 128-entry c2i table. */
            const int *Pv = P[c2i[A[i-1] & 0x7F]];
            const int *dpIn = dp[(i & 1) ^ 1];
            int *dpOut = dp[i & 1];
            #pragma omp for
            for (int j = 1; j <= nb; j++) {
                int t1 = dpIn[Pv[j]] + 1;
                int t2 = dpIn[j];
                dpOut[j] = t1 > t2 ? t1 : t2;
            }
        }
        memcpy(dpf, dp[na & 1], sizeof(int) * (size_t) (nb + 1));
        dpf[nb+1] = dpf[0] = 0;
        return dpf[nb];
    }

    /* Inverse version: rows run from na down to 1 and
     * P[c][j] = (first position >= j holding letter c) + 1,
     * or the sentinel column 0 when there is none. */
    #pragma omp parallel for
    for (int i = 0; i < CHARSET; i++) {
        P[i][nb+1] = 0;
        for (int j = nb; j >= 1; j--)
            P[i][j] = (B[j-1] == "ACGT"[i]) ? j + 1 : P[i][j+1];
    }
    dp[0][0] = dp[1][0] = -1;
    #pragma omp parallel
    for (int i = na; i >= 1; i--) {
        const int *Pv = P[c2i[A[i-1] & 0x7F]];
        const int *dpIn = dp[(i & 1) ^ 1];
        int *dpOut = dp[i & 1];
        #pragma omp for
        for (int j = 1; j <= nb; j++) {
            int t1 = dpIn[Pv[j]] + 1;
            int t2 = dpIn[j];
            dpOut[j] = t1 > t2 ? t1 : t2;
        }
    }
    /* The last row written is i == 1, i.e. dp[1]. */
    memcpy(dpf, dp[1], sizeof(int) * (size_t) (nb + 1));
    dpf[nb+1] = dpf[0] = 0;
    /* The full-string answer sits in column 1 here (suffix starting at 1);
     * the previous code returned dpf[nb], the value for the last letter only. */
    return dpf[1];
}
/**
 * Allocates a zero-filled character buffer of sz bytes (at least one byte,
 * so the result is always a valid empty string).
 *
 * Never returns NULL: callers in this file write through the result without
 * checking it, so an allocation failure terminates the program with a
 * diagnostic instead. The caller owns the buffer and must free() it.
 */
char* alloc_str(int sz) {
    char *buf = calloc(sz > 0 ? (size_t) sz : 1, sizeof *buf);
    if (buf == NULL) {
        perror("alloc_str");
        exit(EXIT_FAILURE);
    }
    return buf;
}
/**
 * Returns a newly allocated, NUL-terminated copy of the len characters of s
 * that start at offset pos. The caller frees the result.
 */
char* substr(const char *s, int pos, int len) {
    const char *first = s + pos;
    /* One extra zeroed byte from alloc_str serves as the terminator. */
    char *piece = alloc_str(len + 1);

    memcpy(piece, first, len);
    return piece;
}
/**
 * Returns a newly allocated string holding sa followed by sb.
 * Neither input is modified or freed; the caller frees the result.
 */
char* cat(const char *sa, const char *sb) {
    int left = strlen(sa);
    int right = strlen(sb);
    /* alloc_str zero-fills, so the trailing byte is already the terminator. */
    char *joined = alloc_str(left + right + 1);

    memcpy(joined, sa, left);
    memcpy(joined + left, sb, right);
    return joined;
}
/**
 * Sequential base case: returns one longest common subsequence of A
 * (length na) and B (length nb) as a newly allocated string.
 *
 * Requires na < MAXSEQ and nb < MAXSEQ. The full decision table fa is kept
 * so the answer can be rebuilt by walking back from (na, nb).
 * Not reentrant: P and fa have static storage. The caller frees the result.
 */
char* find_lcs_seq(const char *A, int na, const char *B, int nb) {
    static int P[CHARSET][MAXSEQ];
    static char fa[MAXSEQ][MAXSEQ];
    /* Columns 0..nb+1 are used (nb+1 is the sentinel), so with nb up to
     * MAXSEQ-1 a row needs MAXSEQ+1 entries; MAXSEQ alone overflowed. */
    int dp[2][MAXSEQ + 2] = {{0}};

    /* P[c][j] = (last position <= j of letter c in B) - 1, or nb+1 if none. */
    for (int i = 0; i < CHARSET; i++) {
        P[i][0] = nb + 1;
        for (int j = 1; j <= nb; j++)
            P[i][j] = (B[j-1] == "ACGT"[i]) ? j - 1 : P[i][j-1];
    }
    /* Sentinel column: -1 + 1 == 0 means "no match available". */
    for (int i = 0; i < 2; i++)
        dp[i][nb+1] = -1;

    for (int i = 1; i <= na; i++) {
        /* Mask keeps the index inside the 128-entry c2i table. */
        const int *Pv = P[c2i[A[i-1] & 0x7F]];
        const int *dpIn = dp[(i & 1) ^ 1];
        int *dpOut = dp[i & 1];
        for (int j = 1; j <= nb; j++) {
            int t1 = dpIn[Pv[j]] + 1;   /* use A[i] with its nearest match */
            int t2 = dpIn[j];           /* skip A[i] */
            if (t1 > t2)
                dpOut[j] = t1, fa[i][j] = 0;
            else
                dpOut[j] = t2, fa[i][j] = 1;
        }
    }

    int sz = dp[na & 1][nb];
    char *ret = alloc_str(sz + 1);
    /* Walk the rows backwards; each "use" decision emits one letter (filled
     * from the end) and jumps to the column just before the matched one. */
    for (int i = na, j = nb; sz && i >= 1; i--) {
        if (fa[i][j] == 0)
            ret[--sz] = A[i-1], j = P[c2i[A[i-1] & 0x7F]][j];
    }
    return ret;
}
/**
 * Returns one longest common subsequence of a (length na) and b (length nb)
 * as a newly allocated string, using divide and conquer in linear space.
 *
 * The shorter string is split in the middle; a forward pass over the left
 * half and an inverse pass over the right half give, for every cut position
 * of the longer string, the LCS length on each side. The best cut is chosen
 * and both sides are solved recursively. The caller frees the result.
 */
char* find_lcs(const char *a, int na, const char *b, int nb) {
    /* Keep a as the shorter string: it is the one that gets halved. */
    if (na > nb) {
        const char *c; int t;
        c = a, a = b, b = c;
        t = na, na = nb, nb = t;
    }
    if (na == 0)
        return alloc_str(1);
    /* A single letter cannot be halved: with half_len == 0 the best cut is
     * always 0 and the right-hand call would repeat with identical arguments
     * forever whenever nb >= MAXSEQ. Answer it directly. */
    if (na == 1)
        return memchr(b, a[0], (size_t) nb) ? substr(a, 0, 1) : alloc_str(1);
    if (na < MAXSEQ && nb < MAXSEQ)
        return find_lcs_seq(a, na, b, nb);

    int half_len = na / 2;
    char *la = substr(a, 0, half_len);
    char *ra = substr(a, half_len, na - half_len);

    /* Heap instead of two MAXN-sized stack arrays per recursion level;
     * lcs_len writes indices 0..nb+1. */
    int *t1 = malloc(sizeof *t1 * ((size_t) nb + 2));
    int *t2 = malloc(sizeof *t2 * ((size_t) nb + 2));
    if (t1 == NULL || t2 == NULL) {
        perror("find_lcs");
        exit(EXIT_FAILURE);
    }
    lcs_len(la, half_len, b, nb, 0, t1);
    lcs_len(ra, na - half_len, b, nb, 1, t2);

    /* Cut b after position i: left part is b[1..i], right part b[i+1..nb]. */
    int len = -1;
    int split = -1;
    for (int i = 0; i <= nb; i++) {
        if (t1[i] + t2[i+1] > len)
            split = i, len = t1[i] + t2[i+1];
    }
    assert(split != -1);
    int left_len = t1[split];
    int right_len = t2[split+1];
    /* Released before recursing so the tables of different levels never
     * coexist. */
    free(t1), free(t2);

    char *ret;
    if (len == 0) {
        ret = alloc_str(1);
    } else {
        char *lb = substr(b, 0, split);
        char *rb = substr(b, split, nb - split);
        char *sl = left_len ? find_lcs(la, half_len, lb, split) : alloc_str(1);
        char *sr = right_len ? find_lcs(ra, na - half_len, rb, nb - split) : alloc_str(1);
        ret = cat(sl, sr);
        free(lb), free(rb);
        free(sl), free(sr);
    }
    /* Freed on every path; the early return for len == 0 used to leak them. */
    free(la), free(ra);
    return ret;
}
/**
 * Reads pairs of whitespace-separated strings until input ends and prints,
 * for each pair, the LCS length on one line and one LCS on the next.
 */
int main() {
    static char A[MAXN], B[MAXN];
    /* Field widths derived from MAXN so an over-long token cannot overrun
     * the buffers (MAXN - 1 characters plus the terminator). */
    char fmt[32];
    snprintf(fmt, sizeof fmt, "%%%ds %%%ds", MAXN - 1, MAXN - 1);

    while (scanf(fmt, A, B) == 2) {
        int na = (int) strlen(A);
        int nb = (int) strlen(B);
        char *str = find_lcs(A, na, B, nb);
        /* strlen yields size_t: %zu is the matching conversion. */
        printf("%zu\n", strlen(str));
        printf("%s\n", str);
        free(str);
    }
    return 0;
}

coarse-grain

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <omp.h>
/* Capacity of every per-column table; the stated input limit is 60000. */
#define MAXN 60005
/* Alphabet size: A, C, G, T. */
#define CHARSET 4
/* Fully parenthesised so the macro composes safely inside larger expressions
 * (the unparenthesised form turned `2 * max(1, 2)` into 2).
 * Arguments are still evaluated more than once: do not pass side effects. */
#define max(x, y) ((x) > (y) ? (x) : (y))
/* Compact cell type for the DP rows; LCS lengths never exceed 60000. */
typedef unsigned short uint16;
/**
 * Classic two-row LCS length computation, used for short B.
 *
 * On return dpf[j] holds the LCS length of A (length na) and the first j
 * characters of B, for j in 0..nb. Returns dpf[nb].
 */
int lcs_len_seq(const char *A, int na, const char *B, int nb, int dpf[]) {
    uint16 dp[2][MAXN];
    uint16 *prev = dp[0];
    uint16 *cur = dp[1];
    const size_t row_bytes = sizeof(uint16) * (nb + 1);

    memset(prev, 0, row_bytes);
    cur[0] = 0;
    for (int row = 0; row < na; row++) {
        const char ch = A[row];
        for (int col = 0; col < nb; col++) {
            if (ch == B[col]) {
                /* Match: extend the diagonal. */
                cur[col + 1] = prev[col] + 1;
            } else {
                /* Mismatch: best of "left" and "above". */
                cur[col + 1] = cur[col] > prev[col + 1] ? cur[col] : prev[col + 1];
            }
        }
        /* The finished row becomes the "previous" row of the next pass. */
        memcpy(prev, cur, row_bytes);
    }
    for (int k = 0; k <= nb; k++)
        dpf[k] = prev[k];
    return dpf[nb];
}
int lcs_len(const char *A, int na, const char *B, int nb, int dpf[]) {
if (nb < 256)
return lcs_len_seq(A, na, B, nb, dpf);
int c2i[128] = {['A'] = 0, ['C'] = 1, ['G'] = 2, ['T'] = 3};
char i2c[CHARSET] = {'A', 'C', 'G', 'T'};
int P[CHARSET][MAXN];
uint16 dp[2][MAXN];
A--, B--;
for (int i = 0; i < CHARSET; i++) {
P[i][0] = nb+1;
for (int j = 1; j <= nb; j++)
P[i][j] = (B[j] == i2c[i]) ? j-1 : P[i][j-1];
}
for (int i = 0; i < 2; i++) {
memset(dp[i], 0, sizeof(uint16)*(nb+1));
dp[i][nb+1] = -1;
}
for (int i = 1; i <= na; i++) {
int *Pv = P[c2i[A[i]]];
uint16 *dpIn = dp[i&1^1];
uint16 *dpOut = dp[i&1];
for (int j = 1; j <= nb; j++) {
uint16 t1 = dpIn[Pv[j]]+1;
uint16 t2 = dpIn[j];
dpOut[j] = t1 > t2 ? t1 : t2;
}
}
for (int i = 0; i <= nb; i++)
dpf[i] = dp[na&1][i];
return dpf[nb];
}
/**
 * Allocates a zero-filled character buffer of sz bytes (at least one byte,
 * so the result is always a valid empty string).
 *
 * Never returns NULL: callers in this file write through the result without
 * checking it, so an allocation failure terminates the program with a
 * diagnostic instead. The caller owns the buffer and must free() it.
 */
char* alloc_str(int sz) {
    char *buf = calloc(sz > 0 ? (size_t) sz : 1, sizeof *buf);
    if (buf == NULL) {
        perror("alloc_str");
        exit(EXIT_FAILURE);
    }
    return buf;
}
/**
 * Returns a newly allocated, NUL-terminated copy of the len characters of s
 * that start at offset pos. The caller frees the result.
 */
char* substr(const char *s, int pos, int len) {
    const char *first = s + pos;
    /* One extra zeroed byte from alloc_str serves as the terminator. */
    char *piece = alloc_str(len + 1);

    memcpy(piece, first, len);
    return piece;
}
/**
 * Returns a newly allocated string holding sa followed by sb.
 * Neither input is modified or freed; the caller frees the result.
 */
char* cat(const char *sa, const char *sb) {
    int left = strlen(sa);
    int right = strlen(sb);
    /* alloc_str zero-fills, so the trailing byte is already the terminator. */
    char *joined = alloc_str(left + right + 1);

    memcpy(joined, sa, left);
    memcpy(joined + left, sb, right);
    return joined;
}
/**
 * Returns a newly allocated copy of the first len characters of s in
 * reverse order. The caller frees the result.
 *
 * Index-based so that len == 0 is well defined; the pointer form computed
 * `t + len - 1`, which points before the buffer for an empty input.
 */
char* reverse(const char *s, int len) {
    char *t = substr(s, 0, len);
    for (int lo = 0, hi = len - 1; lo < hi; lo++, hi--) {
        char tmp = t[lo];
        t[lo] = t[hi];
        t[hi] = tmp;
    }
    return t;
}
/**
 * Returns one longest common subsequence of a (length na) and b (length nb)
 * as a newly allocated string, using divide and conquer in linear space with
 * OpenMP tasks for the independent halves.
 *
 * The shorter string is split in the middle. The LCS row of the left half
 * against b, and of the reversed right half against reversed b, give the
 * best cut position of b; both sides are then solved recursively.
 * dep is the recursion depth, passed down unchanged in meaning.
 * Must be called from inside a parallel region for the tasks to run
 * concurrently. The caller frees the result.
 */
char* find_lcs(const char *a, int na, const char *b, int nb, int dep) {
    /* Keep a as the shorter string: it is the one that gets halved. */
    if (na > nb) {
        const char *c; int t;
        c = a, a = b, b = c;
        t = na, na = nb, nb = t;
    }
    if (na == 0)
        return alloc_str(1);
    if (na == 1) {
        for (int i = 0; i < nb; i++) {
            if (a[0] == b[i])
                return substr(a, 0, 1);
        }
        return alloc_str(1);
    }

    int half_len = na / 2;
    char *la = substr(a, 0, half_len);
    char *ra = substr(a, half_len, na - half_len);
    char *tb = reverse(b, nb);
    char *ta = reverse(ra, na - half_len);

    /* Heap instead of two MAXN-sized arrays per recursion level on a task
     * stack; lcs_len writes indices 0..nb. */
    int *t1 = malloc(sizeof *t1 * ((size_t) nb + 2));
    int *t2 = malloc(sizeof *t2 * ((size_t) nb + 2));
    if (t1 == NULL || t2 == NULL) {
        perror("find_lcs");
        exit(EXIT_FAILURE);
    }
    /* The tasks only need the pointer values, so a private copy suffices. */
    #pragma omp task untied firstprivate(t1)
    lcs_len(la, half_len, b, nb, t1);
    #pragma omp task untied firstprivate(t2)
    lcs_len(ta, na - half_len, tb, nb, t2);
    #pragma omp taskwait

    /* Cut b after position i: t1[i] covers b[0..i), t2[nb-i] the rest
     * (computed on the reversed strings). */
    int len = -1;
    int split = -1;
    for (int i = 0; i <= nb; i++) {
        if (t1[i] + t2[nb-i] > len)
            split = i, len = t1[i] + t2[nb-i];
    }
    /* Released before recursing so the tables of different levels never
     * coexist. */
    free(t1), free(t2);

    char *ret;
    if (len == 0) {
        ret = alloc_str(1);
    } else {
        char *lb = substr(b, 0, split);
        char *rb = substr(b, split, nb - split);
        char *sl = NULL, *sr = NULL;
        /* Results are written by the child tasks, hence shared. */
        #pragma omp task untied shared(sl)
        sl = find_lcs(la, half_len, lb, split, dep+1);
        #pragma omp task untied shared(sr)
        sr = find_lcs(ra, na - half_len, rb, nb - split, dep+1);
        #pragma omp taskwait
        ret = cat(sl, sr);
        free(lb), free(rb);
        free(sl), free(sr);
    }
    /* Freed on every path; the early return for len == 0 used to leak them. */
    free(la), free(ra), free(ta), free(tb);
    return ret;
}
/**
 * Reads pairs of whitespace-separated strings until input ends and prints,
 * for each pair, the LCS length on one line and one LCS on the next.
 * Each pair is solved inside a parallel region whose single thread seeds
 * the task tree created by find_lcs.
 */
int main() {
    static char A[MAXN], B[MAXN];
    /* Field widths derived from MAXN so an over-long token cannot overrun
     * the buffers (MAXN - 1 characters plus the terminator). */
    char fmt[32];
    snprintf(fmt, sizeof fmt, "%%%ds %%%ds", MAXN - 1, MAXN - 1);

    while (scanf(fmt, A, B) == 2) {
        int na = (int) strlen(A);
        int nb = (int) strlen(B);
        char *str = NULL;
        #pragma omp parallel
        {
            #pragma omp single
            str = find_lcs(A, na, B, nb, 0);
        }
        /* strlen yields size_t: %zu is the matching conversion. */
        printf("%zu\n", strlen(str));
        printf("%s\n", str);
        free(str);
    }
    return 0;
}