b498. 史蒂芙的單詞統計

contents

  1. Problem
    1.1. 背景
    1.2. 題目描述
  2. Sample Input
  3. Sample Output
  4. Solution
    4.1. DA trie
    4.2. vector + base

Problem

背景

史蒂芙手上有本小字典,這本字典專門針對某種類型的需求而收集,收錄了非常多特定單詞。為了辨識一段語句是否屬於該類型,判斷方法就是統計一句話中出現多少次字典中的單詞,再依照比例數量去偵測。

題目描述

給定一個含 $N$ 個單詞的字典,每個單詞只由大小寫字母和數字構成;接著有 $Q$ 個詢問,每個詢問為一個只由大小寫字母和數字構成的字串。對於每個詢問,輸出字典中每一種單詞在該字串中出現次數的總和(出現位置可以重疊;字典中重複的單詞只算一種)。

例如經典遊戲《魂斗羅》的秘技 UUDDLRLRBA,若字典中只有兩個單詞 LR 和 BA,由於 LR 出現 2 次,BA 出現 1 次,統計結果為 2+1 = 3。同理,對於 BANANANA,若字典中只有 ANA 和 BA,則 ANA 出現 3 次(出現位置可重疊),BA 出現 1 次,統計結果為 3+1 = 4。

這個工作需要極致的效率,就麻煩各位。由於史蒂芙經費有限,只能給予 128MB 的空間。

Sample Input

2
01
01
1
01001
3
LR
BA
ANA
2
UUDDLRLRBA
BANANANA

Sample Output

Case #1:
2
Case #2:
3
4

Solution

史蒂芙系列迎來第七題,這次考驗的是 Trie 與 Aho-Corasick automaton 的空間優化。在一般的 Trie 中,每個節點都帶有 $|\Sigma|$ 個指標($\Sigma$ 為字元集),空間使用量非常浪費;尤其單詞的後段(後綴)彼此重疊的機率極低,這些節點的子指標幾乎全是空的。

因此,雖然字串只包含大小寫英文字母和數字($|\Sigma| = 62$),每個節點約 64 個指標的空間仍然很大。改用平衡樹存放子節點是一種解法,但千萬不能使用內建的 std::map,因為其紅黑樹實作的每個節點都帶有容器額外開銷(container overhead),會直接佔用大量空間。

  • 解法一:Double-Array Trie + Aho-Corasick automaton。DA Trie 原則上是在找 perfect hash,一旦位置衝突就必須放棄原配置,讓整代子節點一起搬家;因此宣告一個偌大的記憶體池,用雙向鏈結串列維護空位,讓它們盡可能找到安居之地。
  • 解法二:原生作法,但子節點改用 vector<Node*> link 加上偏移量 base 來維護,保證 link 的首尾都有實體子節點,中間若有空位就隨它去。目的是省下後綴部分的空間:運氣好的話,後綴上的節點都只需要 size = 1 的 link。
  • 解法三:穩妥地寫平衡樹。

看到解法二就知道解法一是白做工。DA Trie 的插入效率本來就很低,所以這裡修改了找 perfect hash 的方法:不追求高效率的壓縮,多留一點垃圾空間也沒關係,改仿照作業系統的磁碟排程,用 C-LOOK 或 SCAN 的方式讓搜尋游標一路往前掃。總之,如果每次都從頭開始找空位,會找到天昏地暗。

  • 解法一:AC (2.5s, 76.8MB)
  • 解法二:AC (2.5s, 53.3MB)
  • 解法三:釣魚中

DA trie

#include <bits/stdc++.h>
using namespace std;
// Double-Array Trie with an Aho-Corasick automaton layered on top.
//
// Cell layout of A[i]:
//   * used cell:    check = parent index (>= 0); |base| is the child offset,
//                   i.e. the child for symbol c lives at |base| + c;
//                   base < 0 marks the end of a dictionary word.
//   * free cell:    check = -(next free index), base = -(prev free index);
//                   free cells form one circular doubly linked list.
//   * retired cell: check = base = MAXN (left behind by relocate(); not
//                   reused until the next init()).
//   * A[0] is the root; A[MAXN .. MAXN+MAXC-1] is padding (check = INT_MAX)
//     so that |base| + c never indexes past the array.
class DATrie {
public:
    static const int MAXN = 5000000;
    static const int MAXC = 63;  // symbols produced by toIndex are 1..62
    struct Node {
        int check, base, fail, val;  // fail/val are filled in by build()
    } A[MAXN + MAXC];
    // Statistics only: node_size counts cells taken off the free list,
    // emp_size counts cells retired by relocate(), mem_size is the highest
    // cell index taken so far.
    int node_size, mem_size, emp_size;
    // son_pos: scratch list of child symbols filled by find_empty().
    // find_pos: roving cursor into the free list, advanced between searches.
    int son_pos[MAXC], find_pos;

    // Maps [0-9A-Za-z] to symbols 1..62.
    inline int toIndex(char c) {
        unsigned char uc = (unsigned char)c;  // <cctype> requires a non-negative value
        if (isdigit(uc)) return c - '0' + 1;
        if (isupper(uc)) return c - 'A' + 10 + 1;
        if (islower(uc)) return c - 'a' + 26 + 10 + 1;
        assert(false);
    }

    // True when u is a cell on the free list (padding cells never are).
    inline bool isEMPTY(int u) {
        return u < MAXN && A[u].check < 0 && A[u].base < 0;
    }

    // Resets the trie: root at A[0], cells 1..MAXN-1 linked into a single
    // circular free list, padding cells marked with check = INT_MAX.
    void init() {
        for (int i = 1; i < MAXN; i++)
            A[i].check = -(i+1), A[i].base = -(i-1);
        for (int i = MAXN; i < MAXN + MAXC; i++)
            A[i].check = INT_MAX;
        A[MAXN-1].check = -1, A[1].base = -(MAXN-1);  // close the ring
        A[0].check = -1, A[0].base = 0;
        node_size = mem_size = emp_size = 0, find_pos = 1;
    }

    // Unlinks free cell x from the free list, moving the cursor off x first.
    inline void rm_node(int x) {
        if (find_pos == x) find_pos = abs(A[x].check);
        A[-A[x].base].check = A[x].check;
        A[-A[x].check].base = A[x].base;
        node_size++;
        mem_size = max(mem_size, x);
    }

    // Retires cell x: it is neither free nor a live node afterwards.
    inline void ad_node(int x) {
        A[x].check = MAXN, A[x].base = MAXN;
        emp_size++;
    }

    // Inserts word s (characters must be in [0-9A-Za-z]). Always returns true.
    bool insert(const char *s) {
        int st = 0, to, c;
        for (int i = 0; s[i]; i++) {
            c = toIndex(s[i]);
            to = abs(A[st].base) + c;
            if (st == abs(A[to].check)) {
                st = to;  // edge already exists
            } else if (isEMPTY(to)) {
                rm_node(to);  // claim the free cell as a new child of st
                A[to].check = st, A[to].base = to;
                st = to;
            } else {
                // Collision: move st's existing children to a base that also
                // leaves room for c, then retry the same character.
                int son_sz = 0;
                int pos = find_empty(st, c, son_sz);
                relocate(st, pos, son_sz-1);
                i--;
            }
        }
        // if (A[st].base > 0) words++;
        A[st].base = -abs(A[st].base);  // mark end of word
        return 1;
    }

    // Returns 1 if s was inserted as a whole word, 0 otherwise.
    int find(const char *s) {
        int st = 0, to, c;
        for (int i = 0; s[i]; i++) {
            c = toIndex(s[i]);
            to = abs(A[st].base) + c;
            if (st == abs(A[to].check))
                st = to;
            else
                return 0;
        }
        return A[st].base < 0;
    }

    // Finds a new base for node st such that every existing child of st and
    // the new symbol c land on free cells. On return son_pos[0..sz-1] holds
    // the existing child symbols (ascending) followed by c. The returned base
    // is always >= 1. Prints "Memory Leak" and exits if no such base exists.
    int find_empty(int st, int c, int &sz) {
        sz = 0;
        int bs = abs(A[st].base);
        for (int i = 1, j = bs+1; i < MAXC; i++, j++) {
            if (abs(A[j].check) == st)
                son_pos[sz++] = i;
        }
        son_pos[sz++] = c;
        int mn_pos = min(son_pos[0], c) - 1;
        // The free list is circular, so an unbounded walk never terminates
        // when no suitable cell exists. Each walk is therefore capped at
        // MAXN steps, which covers a full lap of the list.
        int steps = 0;
        for (; find_pos && steps < MAXN && (find_pos < bs || find_pos < mn_pos); steps++)
            find_pos = abs(A[find_pos].check);
        for (steps = 0; find_pos && steps < MAXN; steps++, find_pos = abs(A[find_pos].check)) {
            int base = find_pos - mn_pos;
            // A base of 0 cannot carry the word-end flag (-0 == 0), and a
            // negative base (after the cursor wraps to a low index) would be
            // misread through abs(); reject both.
            if (base < 1)
                continue;
            int ok = 1;
            for (int i = 0; i < sz && ok; i++)
                ok &= isEMPTY(base + son_pos[i]);
            if (ok)
                return base;
        }
        printf("Memory Leak -- %d\n", find_pos);
        exit(0);
        return -1;
    }

    // Moves the children of st listed in son_pos[0..sz-1] to base `to`.
    // Grandchildren are re-parented to the new cells, and the old cells are
    // retired. The word-end sign of st's base is preserved.
    void relocate(int st, int to, int sz) { // move ::st -> ::to
        for (int i = sz-1; i >= 0; i--) {
            int a = abs(A[st].base) + son_pos[i]; // old
            int b = to + son_pos[i]; // new
            rm_node(b);
            A[b].check = st, A[b].base = A[a].base;
            int vs = abs(A[a].base);
            for (int j = 1, k = vs+1; j < MAXC; j++, k++) {
                if (abs(A[k].check) == a)
                    A[k].check = b;
            }
            ad_node(a);
        }
        A[st].base = (A[st].base < 0 ? -1 : 1) * to;
    }

    // Breadth-first pass computing fail links. It also sets val[v] =
    // val[fail[v]] + (v ends a word), i.e. the number of dictionary words
    // ending at v along its fail chain.
    void build() { // AC automation
        queue<int> Q;
        int u, p, to, pto;
        Q.push(0), A[0].fail = -1;
        while (!Q.empty()) {
            u = Q.front(), Q.pop();
            for (int i = 1; i < MAXC; i++) {
                to = abs(A[u].base) + i;
                if (u != abs(A[to].check))
                    continue;
                Q.push(to);
                p = A[u].fail;
                while (p != -1) {
                    pto = abs(A[p].base) + i;
                    if (p != abs(A[pto].check))
                        p = A[p].fail;
                    else
                        break;
                }
                if (p == -1)
                    A[to].fail = 0;
                else
                    A[to].fail = abs(A[p].base) + i;
                A[to].val = A[A[to].fail].val + (A[to].base < 0);
            }
        }
    }

    // Runs s through the automaton and returns the sum of val over every
    // position: the total number of dictionary-word occurrences in s.
    int query(const char *s) {
        int st = 0, c, to;
        int matched = 0;
        for (int i = 0; s[i]; i++) {
            c = toIndex(s[i]);
            do {
                to = abs(A[st].base) + c;
                if (st != abs(A[to].check) && st != 0)
                    st = A[st].fail;
                else
                    break;
            } while (true);
            to = abs(A[st].base) + c;
            if (st != abs(A[to].check))
                st = 0;
            else
                st = to;
            matched += A[st].val;
        }
        return matched;
    }
} tree;
char s[1048576];
// Reads test cases until EOF. Each case is n dictionary words followed by m
// query strings. For every query it prints the total number of
// dictionary-word occurrences reported by tree.query().
int main() {
    int n, m, cases = 0;
    while (scanf("%d", &n) == 1) {
        printf("Case #%d:\n", ++cases);
        tree.init();
        // The field width keeps an over-long token from overflowing s.
        for (int i = 0; i < n && scanf("%1048575s", s) == 1; i++)
            tree.insert(s);
        tree.build();
        if (scanf("%d", &m) != 1)
            break;  // truncated input: m would otherwise be used uninitialized
        for (int i = 0; i < m && scanf("%1048575s", s) == 1; i++)
            printf("%d\n", tree.query(s));
    }
    return 0;
}

vector + base

#include <assert.h>
#include <ctype.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <algorithm>
#include <iostream>
#include <map>
#include <queue>
#include <vector>
#define MAXCHAR 62
#define MAXS (3000005)
#define MAXNODE 1200000
using namespace std;
// Aho-Corasick automaton over the alphabet [0-9A-Za-z] (symbols 0..61).
// Each node stores only the window [base, base + link.size()) of its child
// table, so a node with one child needs a single pointer slot.
class ACmachine {
public:
    struct Node {
        // link[i] is the child for symbol base + i (NULL gaps are allowed
        // inside the window); base == 128 marks a node with no children yet.
        vector<Node*> link;
        Node *fail;
        // cnt: 1 if a dictionary word ends here.
        // val: number of words ending here along the fail chain (set by build()).
        int cnt, val, base;
        Node() {
            fail = NULL;
            cnt = val = 0;
            base = 128;
        }
        // Child for symbol c, or NULL if there is none.
        Node* next(int c) {
            int off = c - base;
            if (off < 0 || off >= (int)link.size()) return NULL;
            return link[off];
        }
        // Sets the child for symbol c to p, widening the window when c lies
        // outside it.
        void change(int c, Node *p) {
            if (c < base || c - base >= (int)link.size()) {
                int nb = min(base, c), mx = max(base + (int)link.size() - 1, c);
                if (base == 128) mx = c;  // empty window: 128 is only a sentinel
                vector<Node*> co(mx - nb + 1);  // value-initialized to NULL
                for (int i = 0; i < (int)link.size(); i++)
                    co[i + base - nb] = link[i];
                link.swap(co);  // take the widened table without copying it
                base = nb;
            }
            link[c - base] = p;
        }
    } nodes[MAXNODE];
    Node *root;
    int size;
    // Hands out the next pooled node, reset to a fresh state.
    Node* getNode() {
        assert(size < MAXNODE);
        Node *p = &nodes[size++];
        *p = Node();
        return p;
    }
    // Empties the automaton, recycling the node pool.
    void init() {
        size = 0;
        root = getNode();
    }
    // Maps [0-9A-Za-z] to symbols 0..61.
    inline int toIndex(char c) {
        unsigned char uc = (unsigned char)c;  // <cctype> requires a non-negative value
        if (isdigit(uc)) return c - '0';
        if (isupper(uc)) return c - 'A' + 10;
        if (islower(uc)) return c - 'a' + 26 + 10;
        assert(!"toIndex: character outside [0-9A-Za-z]");
        return -1;  // NDEBUG builds: a symbol outside 0..MAXCHAR-1 instead of UB
    }
    // Adds str to the dictionary. Re-inserting a word leaves cnt at 1.
    void insert(const char str[]) {
        Node *p = root;
        for (int i = 0; str[i]; i++) {
            int idx = toIndex(str[i]);
            if (p->next(idx) == NULL)
                p->change(idx, getNode());
            p = p->next(idx);
        }
        p->cnt = 1;
    }
    // Breadth-first pass computing fail links and val for every node.
    void build() { // AC automation
        queue<Node*> Q;
        Q.push(root), root->fail = NULL;
        while (!Q.empty()) {
            Node *u = Q.front();
            Q.pop();
            // Only symbols inside u's link window can have children, so
            // there is no need to probe all MAXCHAR symbols.
            int hi = u->base + (int)u->link.size();
            for (int i = u->base; i < hi; i++) {
                Node *v = u->next(i);
                if (v == NULL)
                    continue;
                Q.push(v);
                Node *p = u->fail;
                while (p != NULL && p->next(i) == NULL)
                    p = p->fail;
                v->fail = (p == NULL) ? root : p->next(i);
                v->val = v->fail->val + v->cnt;
            }
        }
    }
    // Returns the total number of dictionary-word occurrences in str.
    int query(const char str[]) {
        Node *u = root;
        int matched = 0;
        for (int i = 0; str[i]; i++) {
            int idx = toIndex(str[i]);
            while (u->next(idx) == NULL && u != root)
                u = u->fail;
            Node *v = u->next(idx);
            u = (v == NULL) ? root : v;
            matched += u->val;
        }
        return matched;
    }
    // Nothing to release: getNode() resets pooled nodes (and their link
    // vectors) when they are reused.
    void free() {
        return ;
    }
};
ACmachine disk;
char s[1048576];
// Reads test cases until EOF. Each case is n dictionary words followed by m
// query strings. For every query it prints the total number of
// dictionary-word occurrences reported by disk.query().
int main() {
    int n, m, cases = 0;
    while (scanf("%d", &n) == 1) {
        printf("Case #%d:\n", ++cases);
        disk.init();
        // The field width keeps an over-long token from overflowing s.
        for (int i = 0; i < n && scanf("%1048575s", s) == 1; i++)
            disk.insert(s);
        disk.build();
        if (scanf("%d", &m) != 1)
            break;  // truncated input: m would otherwise be used uninitialized
        for (int i = 0; i < m && scanf("%1048575s", s) == 1; i++)
            printf("%d\n", disk.query(s));
        disk.free();
    }
    return 0;
}