b500. 子字串集合

contents

  1. Problem
  2. Sample Input
  3. Sample Output
  4. Solution
    4.1. 後綴自動機
    4.2. set ver 1
    4.3. set ver 2
    4.4. trie
    4.5. DA trie

Problem

給予一個字串 $S$,求出所有子字串的集合,將集合內除了空字串以外的字串照字典順序輸出。

Sample Input

SLEEP

Sample Output

E
EE
EEP
EP
L
LE
LEE
LEEP
P
S
SL
SLE
SLEE
SLEEP

Solution

這是一題新手上路的練習題,沒打算考很難的操作。即便如此,出題者的信用似乎早已破產,大家大概以為這又是一道變態題,結果釣不到人來寫。

直觀作法是窮舉所有子字串,再去掉重複即可,時間複雜度 $O(N^3)$。由於光是輸出所有子字串就需要 $O(N^3)$ 的成本,$N$ 不可能出太大(本題 $N = 500$)。儘管如此,還是來比較各種算法之間的空間和時間用量。

首先是最常見的 set 作法。若直接儲存字串,空間用量為 $O(N^3)$;為了避免使用太多空間,可以自己寫一個 compare function,讓 set 只記錄子字串的起始位置和長度。共插入 $O(N^2)$ 個子字串,每次比較最多 $O(N)$,時間複雜度為 $O(N^3 \log N)$,空間則降到 $O(N^2)$。

接著是進階一點、常用到的字典樹 trie:把所有後綴插入到 trie 中(每個子字串都是某個後綴的前綴,所以 trie 中除了根以外的每個節點恰好對應一個相異子字串),接著走訪 trie 輸出所有結果即可。不計輸出成本,時間複雜度 $O(N^2)$,空間 $O(N^2 \times 26)$。可以使用 double-array trie 犧牲一點時間,降低節點的空間使用量。

最後,比較強悍的後綴自動機,在劉汝佳的書上主要是用 DAWG (directed acyclic word graph) 來描述 suffix automaton,後綴自動機可以在線構造,時間和空間複雜度都是 $O(N)$,後綴自動機可以接受 $S$ 所有的後綴,關於建造時間和狀態總數的證明可以參考 《Suffix Automaton 杭州外国语学校 陈立杰》 的簡報。

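若不想這麼詳細地證明狀態總數和時間複雜度,可以從 AC 自動機的建構概念來思考:一樣有 fail 指針(程式中的 `pre`),用來維護當一個後綴失去匹配、移除前綴之後要移動到的狀態。特別的是,每加入一個字元最多增加兩個節點。一個是代表目前整個字串的新狀態 `np`,它一定是 accept state。另一個是必要時從既有狀態 `q` 複製出來的 `nq`:當 `q->step` 不等於 `p->step + 1` 時,用它把 `q` 切開,讓轉移保持正確。注意後綴自動機的 accept state 不只一兩個:從最後一個狀態沿著 fail 指針走回根,路徑上的所有狀態都是 accept state(例如 `AAA` 的每一個狀態都是)。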

根據這一題的需求,走訪一次後綴自動機就能印出所有子字串。

  • set 解法一 AC (0.2s, 25.2MB)
  • set 解法二 AC (0.4s, 3.8MB)
  • trie AC (0.1s, 12.1MB)
  • Double-array trie AC (0.2s, 7.7MB)
  • 後綴自動機 suffix automaton AC (64ms, 240KB)

後綴自動機

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
#include <bits/stdc++.h>
/// Suffix automaton (DAWG) over the uppercase alphabet 'A'..'Z'.
///
/// Every path that starts at `root` spells a distinct substring of the string
/// given to build(), and every distinct substring is spelled by exactly one
/// such path. A DFS from the root that tries letters in alphabetical order
/// therefore prints each substring once, in lexicographic order.
class SuffixAutomaton {
public:
    // A suffix automaton of a length-n string has at most 2n-1 states, so this
    // pool fits strings of up to 500 characters; newNode() asserts on overflow.
    static const int MAXN = 500<<1;
    static const int MAXC = 26;
    struct Node {
        Node *next[MAXC];  // outgoing transitions, indexed by toIndex()
        Node *pre;         // suffix link ("fail" pointer)
        int step;          // length of the longest string in this state
        Node() {
            pre = NULL, step = 0;
            memset(next, 0, sizeof(next));
        }
    } _mem[MAXN];
    int size;           // number of pool nodes in use
    Node *root, *tail;  // initial state / state for the whole current string

    /// Resets the pool and creates an empty automaton (root only).
    void init() {
        size = 0;
        root = tail = newNode();
    }

    /// Hands out the next pool node, reset to its default state.
    Node* newNode() {
        // Same guard as Trie::newNode: fail loudly instead of writing past _mem.
        assert(size < MAXN);
        Node *p = &_mem[size++];
        *p = Node();
        return p;
    }

    int toIndex(char c) { return c - 'A'; }
    char toChar(int c) { return c + 'A'; }

    /// Appends character `c` to the automaton; `len` is the new string length.
    void add(char c, int len) {
        const int idx = toIndex(c);
        assert(0 <= idx && idx < MAXC);  // only 'A'..'Z' are supported
        Node *p = tail;
        Node *np = newNode();
        np->step = len;
        // Every suffix state without a `c` transition now reaches np.
        for (; p && p->next[idx] == NULL; p = p->pre)
            p->next[idx] = np;
        tail = np;
        if (p == NULL) {
            np->pre = root;
            return;
        }
        Node *q = p->next[idx];
        if (q->step == p->step + 1) {
            np->pre = q;
            return;
        }
        // q also holds longer strings: split it by cloning q with a shorter step.
        Node *nq = newNode();
        *nq = *q;
        nq->step = p->step + 1;
        q->pre = np->pre = nq;
        for (; p && p->next[idx] == q; p = p->pre)
            p->next[idx] = nq;
    }

    /// Rebuilds the automaton for the NUL-terminated string `s`.
    void build(const char *s) {
        init();
        for (int i = 0; s[i]; i++)
            add(s[i], i+1);
    }

    /// Prints (via puts) every path below `u`; `path[0..idx)` is the prefix so far.
    void dfs(Node *u, int idx, char path[]) {
        for (int i = 0; i < MAXC; i++) {
            if (u->next[i]) {
                path[idx] = toChar(i);
                path[idx+1] = '\0';
                puts(path);
                dfs(u->next[i], idx+1, path);
            }
        }
    }

    /// Prints every distinct non-empty substring, one per line, in lexicographic order.
    void print() {
        char s[1024];
        dfs(root, 0, s);
    }
} SAM;
/// Reads whitespace-separated strings from stdin and prints every distinct
/// non-empty substring of each one, one per line, in lexicographic order.
int main() {
    char s[1024];
    // The field width keeps an over-long token from overflowing `s`
    // (the automaton pool itself is sized for strings of up to 500 characters).
    while (scanf("%1023s", s) == 1) {
        SAM.build(s);
        SAM.print();
    }
    return 0;
}

set ver 1

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
#include <bits/stdc++.h>
using namespace std;
/// Reads whitespace-separated strings from stdin and prints every distinct
/// non-empty substring of each one, one per line, in lexicographic order.
/// Stores the substrings themselves in a std::set (O(N^3) memory).
int main() {
    char s[512];
    // The field width keeps an over-long token from overflowing `s`.
    while (scanf("%511s", s) == 1) {
        set<string> S;
        int n = strlen(s);
        for (int i = 0; i < n; i++)
            for (int len = 1; i + len <= n; len++)
                S.emplace(s + i, len);  // string(ptr, count): no buffer patching needed
        for (auto &x : S)
            puts(x.c_str());
    }
    return 0;
}

set ver 2

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
#include <bits/stdc++.h>
using namespace std;
char s[512];
// Strict weak ordering on substrings of the global buffer `s`.
// A substring is encoded as (start index, length). Two encodings compare by
// the characters they cover, so equal substrings taken from different
// positions are equivalent, and a std::set keeps only one of them.
struct cmp {
    bool operator() (const pair<int, int> &a, const pair<int, int> &b) const {
        const char *lhs = s + a.first;
        const char *rhs = s + b.first;
        // Character-wise comparison; when one is a prefix of the other, the shorter one is smaller.
        return lexicographical_compare(lhs, lhs + a.second, rhs, rhs + b.second);
    }
};
/// Reads whitespace-separated strings into the global `s` and prints every
/// distinct non-empty substring, one per line, in lexicographic order.
/// The set stores (start, length) pairs ordered by `cmp`, so memory is O(N^2).
int main() {
    // The field width keeps an over-long token from overflowing the global `s`.
    while (scanf("%511s", s) == 1) {
        set< pair<int, int>, cmp > S;
        int n = strlen(s);
        for (int i = 0; i < n; i++)
            for (int len = 1; i + len <= n; len++)
                S.insert({i, len});
        // %.*s prints exactly `len` characters, so `s` never has to be patched.
        for (auto &x : S)
            printf("%.*s\n", x.second, s + x.first);
    }
    return 0;
}

trie

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
#include <bits/stdc++.h>
using namespace std;
/// Pointer-based trie over the uppercase alphabet 'A'..'Z', backed by a
/// fixed node pool. Every node except the root is reached by a unique path,
/// so a DFS prints each inserted prefix once, in lexicographic order.
class Trie {
public:
    // Pool size; 130000 covers 500*501/2 distinct substrings plus the root.
    static const int MAXN = 130000;
    static const int MAXC = 26;
    struct Node {
        Node *next[MAXC];
        // Clears all children.
        void init() {
            for (int k = 0; k < MAXC; k++)
                next[k] = NULL;
        }
    } nodes[MAXN];
    Node *root;
    int size;  // pool nodes handed out so far

    // Takes the next free pool node and clears it.
    Node* newNode() {
        assert(size < MAXN);
        Node *fresh = nodes + size;
        ++size;
        fresh->init();
        return fresh;
    }

    // Empties the trie, leaving only a fresh root.
    void init() {
        size = 0;
        root = newNode();
    }

    inline int toIndex(char c) { return c - 'A'; }
    inline int toChar(char c) { return c + 'A'; }

    // Adds every prefix of the NUL-terminated `str` to the trie.
    void insert(const char str[]) {
        Node *cur = root;
        for (const char *it = str; *it != '\0'; ++it) {
            Node *&child = cur->next[toIndex(*it)];
            if (!child)
                child = newNode();
            cur = child;
        }
    }

    // Prints (via puts) every path below `u`; `path[0..idx)` holds the prefix so far.
    void dfs(Node *u, int idx, char path[]) {
        for (int ch = 0; ch < MAXC; ch++) {
            Node *v = u->next[ch];
            if (v == NULL)
                continue;
            path[idx] = toChar(ch);
            path[idx+1] = '\0';
            puts(path);
            dfs(v, idx + 1, path);
        }
    }

    // Prints all stored strings, one per line, in lexicographic order.
    void print() {
        char buf[1024];
        dfs(root, 0, buf);
    }
} tree;
char s[1024];
/// Reads whitespace-separated strings and prints every distinct non-empty
/// substring of each one, one per line, in lexicographic order.
int main() {
    // The field width keeps an over-long token from overflowing `s`.
    while (scanf("%1023s", s) == 1) {
        tree.init();
        // Every substring is a prefix of some suffix, and insert() creates all
        // prefixes. Inserting just the n suffixes therefore builds the same trie
        // as inserting all O(n^2) substrings, in O(n^2) steps instead of O(n^3).
        for (char *suffix = s; *suffix; ++suffix)
            tree.insert(suffix);
        tree.print();
    }
    return 0;
}

DA trie

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
#include <bits/stdc++.h>
using namespace std;
// Double-array trie over the uppercase alphabet; letters 'A'..'Z' map to labels 1..26.
//
// Layout, as implemented below:
//  - The child of state `st` with label `c` lives at index abs(A[st].base) + c.
//    It belongs to `st` iff abs(A[child].check) == st.
//  - insert() gives the final state of each inserted string a negative base.
//    The sign is the "terminal" flag read by find(); abs() recovers the base.
//  - Unused slots form a circular doubly linked free list encoded with
//    negative numbers: -A[x].check is the next free slot, -A[x].base the previous one.
//  - Index 0 is the root (init() sets check = -1, base = 0).
class DATrie {
public:
// Addressable slots; indices MAXN .. MAXN+MAXC-1 are sentinels set up by init().
static const int MAXN = 500000;
// Labels run 1..MAXC-1 (= 1..26); label 0 is never used.
static const int MAXC = 27;
struct Node {
// `fail` and `val` are not read or written anywhere in this class.
int check, base, fail, val;
} A[MAXN + MAXC];
// Counters: node_size = slots taken off the free list, mem_size = highest index
// taken, emp_size = slots retired by ad_node(). None is read elsewhere in this class.
int node_size, mem_size, emp_size;
// son_pos: scratch list of child labels filled by find_empty() and consumed by
// relocate(). find_pos: free-list cursor where the next search for space starts.
int son_pos[MAXC], find_pos;
// 'A' -> 1 ... 'Z' -> 26.
inline int toIndex(char c) {
return c - 'A' + 1;
}
// Inverse of toIndex: label 1 -> 'A'. (Callers pass an int label.)
inline int toChar(char c) {
return c + 'A' - 1;
}
// A slot is free iff it is below MAXN and both fields hold free-list (negative) links.
// Slots retired by ad_node() have check = base = MAXN, so they never count as free again.
inline bool isEMPTY(int u) {
return u < MAXN && A[u].check < 0 && A[u].base < 0;
}
// Resets the trie to a lone root. Links slots 1..MAXN-1 into one circular free
// list, and fills the sentinel slots with check = INT_MAX so that a lookup past
// MAXN never matches a real parent. Costs O(MAXN) per call.
void init() {
for (int i = 1; i < MAXN; i++)
A[i].check = -(i+1), A[i].base = -(i-1);
for (int i = MAXN; i < MAXN + MAXC; i++)
A[i].check = INT_MAX;
// Close the ring: next(MAXN-1) = 1, prev(1) = MAXN-1.
A[MAXN-1].check = -1, A[1].base = -(MAXN-1);
A[0].check = -1, A[0].base = 0;
node_size = mem_size = emp_size = 0, find_pos = 1;
}
// Unlinks free slot x from the free list (prev.next = x.next, next.prev = x.prev),
// moving the search cursor forward if it was parked on x.
inline void rm_node(int x) {
if (find_pos == x) find_pos = abs(A[x].check);
A[-A[x].base].check = A[x].check;
A[-A[x].check].base = A[x].base;
node_size++;
mem_size = max(mem_size, x);
}
// Retires slot x after relocate() moved its contents away.
// NOTE(review): x is NOT linked back into the free list, so vacated slots are
// never reused; MAXN has to be large enough to absorb them.
inline void ad_node(int x) {
A[x].check = MAXN, A[x].base = MAXN;
emp_size++;
}
// Inserts the NUL-terminated string s, creating missing states along its path,
// and marks the final state terminal. When the slot for a new child is owned by
// another state, all children of `st` move to a new base, and the same character
// is retried (i--). Always returns 1; `flag` is unused.
bool insert(const char *s) {
int st = 0, to, c;
int flag = 0;
for (int i = 0; s[i]; i++) {
c = toIndex(s[i]);
to = abs(A[st].base) + c;
if (st == abs(A[to].check)) {
st = to;
} else if (isEMPTY(to)) {
rm_node(to);
A[to].check = st, A[to].base = to;
st = to;
} else {
int son_sz = 0;
int pos = find_empty(st, c, son_sz);
// son_sz - 1: move only the existing children; label c is created on retry.
relocate(st, pos, son_sz-1);
i--;
}
}
A[st].base = -abs(A[st].base);
return 1;
}
// Returns 1 iff s was inserted as a complete string (its final state is terminal),
// and 0 if its path does not exist or ends at a non-terminal state.
int find(const char *s) {
int st = 0, to, c;
for (int i = 0; s[i]; i++) {
c = toIndex(s[i]);
to = abs(A[st].base) + c;
if (st == abs(A[to].check))
st = to;
else
return 0;
}
return A[st].base < 0;
}
// Finds a new base for `st` at which its existing children plus label `c` all
// land on free slots.
// On return, son_pos[0..sz) holds the existing child labels in ascending order
// followed by c, so sz = (#children + 1).
// The candidate base is find_pos - mn_pos, which puts the smallest label at find_pos + 1.
// If the search ends with find_pos == 0, it prints "Memory Leak -- ..." and exits.
// NOTE(review): the free list is circular and free slots never link to index 0,
// so when space runs out this may spin instead of reaching the exit — confirm.
int find_empty(int st, int c, int &sz) {
sz = 0;
int bs = abs(A[st].base);
for (int i = 1, j = bs+1; i < MAXC; i++, j++) {
if (abs(A[j].check) == st)
son_pos[sz++] = i;
}
son_pos[sz++] = c;
int mn_pos = min(son_pos[0], c) - 1;
// Permanently advance the cursor past free slots below the old base or below mn_pos.
for (; find_pos && (find_pos < bs || find_pos < mn_pos); find_pos = abs(A[find_pos].check));
for (; find_pos; find_pos = abs(A[find_pos].check)) {
int ok = 1;
for (int i = 0; i < sz && ok; i++)
ok &= isEMPTY(find_pos + son_pos[i] - mn_pos);
if (ok)
return find_pos - mn_pos;
}
printf("Memory Leak -- %d\n", find_pos);
exit(0);
return -1;
}
// Moves the children of `st` with labels son_pos[0..sz) (as filled by find_empty)
// from the old base abs(A[st].base) to the new base `to`. Each child copies its
// base (keeping its terminal sign), its own children are re-pointed to the new
// slot, and the old slot is retired. Finally st's base becomes `to`, keeping st's sign.
void relocate(int st, int to, int sz) { // move st's children to the new base `to`
for (int i = sz-1; i >= 0; i--) {
int a = abs(A[st].base) + son_pos[i]; // old slot
int b = to + son_pos[i]; // new slot
rm_node(b);
A[b].check = st, A[b].base = A[a].base;
int vs = abs(A[a].base);
for (int j = 1, k = vs+1; j < MAXC; j++, k++) {
if (abs(A[k].check) == a)
A[k].check = b;
}
ad_node(a);
}
A[st].base = (A[st].base < 0 ? -1 : 1) * to;
}
// Prints (via puts) every path below state u, trying labels in ascending order;
// path[0..idx) holds the prefix spelled so far.
void dfs(int u, int idx, char path[]) {
for (int i = 1; i < MAXC; i++) {
int to = abs(A[u].base) + i;
if (u == abs(A[to].check)) {
path[idx] = toChar(i);
path[idx+1] = '\0';
puts(path);
dfs(to, idx+1, path);
}
}
}
// Prints every path from the root, one per line, in lexicographic order.
void print() {
char s[1024];
dfs(0, 0, s);
}
} tree;
char s[1024];
/// Reads whitespace-separated strings and prints every distinct non-empty
/// substring of each one, one per line, in lexicographic order.
int main() {
    // The field width keeps an over-long token from overflowing `s`.
    while (scanf("%1023s", s) == 1) {
        tree.init();
        // Every substring is a prefix of some suffix, so inserting only the n
        // suffixes creates the same trie states, in the same order, as inserting
        // all O(n^2) substrings. That is O(n^2) character steps instead of O(n^3).
        // Only whole suffixes get the terminal flag now, but print() does not read it.
        for (char *suffix = s; *suffix; ++suffix)
            tree.insert(suffix);
        tree.print();
    }
    return 0;
}