批改娘 10116. Fast Dynamic Programming Computing I (OpenMP)

contents

  1. 1. 題目描述
    1. 1.1. sequence.c
  2. 2. 輸入格式
  3. 3. 輸出格式
  4. 4. 範例輸入
  5. 5. 範例輸出
  6. 6. Solution
    1. 6.1. 方法一
    2. 6.2. 方法二
    3. 6.3. 基礎平行
    4. 6.4. 進階平行
    5. 6.5. 忘卻快取平行

題目描述

給定一序列矩陣,期望求出相乘這些矩陣的最有效方法的乘法次數。

sequence.c

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
#include <stdio.h>
#define MAXN 4096
#define INF (1LL<<60)
int N;
long long dp[MAXN*MAXN], SZ[MAXN+5];
int main() {
while (scanf("%d", &N) == 1) {
for (int i = 0; i <= N; i++)
scanf("%lld", &SZ[i]);
for (int i = 1; i <= N; i++) {
for (int j = 0; j < N-i; j++) {
int l = j, r = j+i;
long long local = INF;
for (int k = l; k < r; k++) {
long long t = dp[l*N+k] + dp[(k+1)*N+r] + SZ[l] * SZ[k+1] * SZ[r+1];
if (t < local)
local = t;
}
dp[l*N+r] = local;
}
}
printf("%lld\n", dp[0*N+N-1]);
}
return 0;
}

輸入格式

有多組測資,每組第一行會有一個整數 $N$ 表示矩陣鏈上有 $N$ 個矩陣,第二行上會有 $N+1$ 個整數 $Z_i$,表示矩陣鏈的每一個行列大小,例如當 $N = 3$ 時,輸入 10 30 5 60 表示矩陣 $A_{10, 30} B_{30, 5} C_{5, 60}$ 相乘。

  • $1 \le N \le 4096$
  • $1 \le Z_i \le 4096$

輸出格式

對於每組測資輸出一行一個整數 $M$ 為最少乘法次數。

範例輸入

1
2
3
4
5
6
7
8
9
10
11
12
13
14
2
2 2 2
3
10 30 5 60
3
1 5 20 1
3
5 10 20 35
6
30 35 15 5 10 20 25

範例輸出

1
2
3
4
5
8
4500
105
4500
15125

Solution

假設有 $p$ 個處理單元,矩陣大小為 $n \times n$,分析一般平行運算時間 $T_p$ 如下所示:

$$\begin{align*} T_p &= \sum\limits_{k=1}^{n} \left\lceil \frac{k}{p} \right\rceil \times (n - k) \end{align*}$$

針對地 $k$ 次迭代,將第二層迴圈平行化,每一個執行緒處理 $\left\lceil \frac{k}{p} \right\rceil$ 個狀態計算,每個狀態繼續需要 $n-k$ 次計算。

$$\begin{align*} \sum\limits_{k=1}^{n} \left\lceil \frac{k}{p} \right\rceil \times k &= 1 \times (1 + 2 + 3 + \cdots + p) + \\ & \qquad 2 \times ((p+1) + (p+2) + (p+3) + \cdots + 2p) + \cdots + \\ & \qquad (n \bmod{p}) \times \left\lceil \frac{n}{p}\right\rceil (\lfloor n/p \rfloor \times p + 1 + \cdots + n) \\ &= \sum_{k=1}^{\lfloor n/p \rfloor} k \cdot \frac{p (2k+1) p}{2} + (n \bmod{p}) \times \left\lceil \frac{n}{p}\right\rceil \frac{(n \bmod{p})(n + \left\lfloor \frac{n}{p}\right\rfloor p + 1)}{2} \\ &= p^2 \left[ \frac{\left\lfloor n/p \right\rfloor(\left\lfloor n/p \right\rfloor+1)(2 \left\lfloor n/p \right\rfloor + 1)}{6} + \frac{\left\lfloor n/p \right\rfloor(\left\lfloor n/p \right\rfloor+1)}{4} \right] + (n \bmod{p})^2 \left\lceil n/p \right\rceil \frac{n + \left\lfloor n/p \right\rfloor p + 1}{2} \\ \sum\limits_{k=1}^{n} \left\lceil \frac{k}{p} \right\rceil \times n &= n \cdot p \cdot \frac{\left\lfloor n/p \right\rfloor (\left\lfloor n/p \right\rfloor + 1)}{2} + n \cdot (n - p \cdot \left\lfloor n/p \right\rfloor) \left\lceil n/p \right\rceil \end{align*}$$

總結一下 $T_p = \sum\nolimits_{k=1}^{n} \left\lceil \frac{k}{p} \right\rceil \times n - \sum\nolimits_{k=1}^{n} \left\lceil \frac{k}{p} \right\rceil \times k = O(n^3 / p)$

針對前半段 $\sum\limits_{k=1}^{n} \left\lceil \frac{k}{p} \right\rceil$ 有以下兩種推法可供參考。

方法一

$$\begin{align*} \sum\limits_{k=1}^{n} \left\lceil \frac{k}{p} \right\rceil &= p \cdot 1 + p \cdot 2 + p \cdot 3 + \cdots + p \cdot \left\lfloor n/p \right\rfloor + (n \bmod{p}) \cdot \left\lceil n/p \right\rceil \phantom{0\over0}\\ &= \sum\limits_{k=1}^{\left\lfloor n/p \right\rfloor} k \cdot p + (n - p \cdot \left\lfloor n/p \right\rfloor) \left\lceil n/p \right\rceil \\ &= p \cdot \frac{\left\lfloor n/p \right\rfloor (\left\lfloor n/p \right\rfloor + 1)}{2} + (n - p \cdot \left\lfloor n/p \right\rfloor) \left\lceil n/p \right\rceil && \blacksquare \end{align*}$$

方法二

套用定理如下:(解釋: $n$ 個人分成 $p$,每組人數最多差一。)

$$\begin{align} n = \left\lceil \frac{n}{p} \right\rceil + \left\lceil \frac{n-1}{p} \right\rceil + \cdots + \left\lceil \frac{n-p+1}{p} \right\rceil \end{align}$$

套用上式從 $k = n$ 往回推。

$$\begin{align*} \sum\limits_{k=1}^{n} \left\lceil \frac{k}{p} \right\rceil &= \left\lceil \frac{n}{p} \right\rceil + \left\lceil \frac{n-1}{p} \right\rceil + \cdots + \left\lceil \frac{n-p+1}{p} \right\rceil + \sum_{k=1}^{n-p} \left\lceil \frac{k}{p} \right\rceil \\ &= n + (n - p) + (n - 2p) + \cdots + (n - \lfloor n/p \rfloor \cdot p) + \sum_{k=1}^{n \bmod{p}} \left\lceil \frac{k}{p} \right\rceil \\ &= \sum\limits_{k=0}^{\left\lfloor n/p \right\rfloor - 1} (n - k p) + (n \bmod{p}) \\ &= n \left\lfloor n / p \right\rfloor - \frac{\left\lfloor n/p \right\rfloor (\left\lfloor n/p \right\rfloor - 1)}{2} \times p + (n \bmod{p}) && \blacksquare \end{align*}$$

基礎平行

  • Accepted (50883 ms, 132224 KB)

平行地方呼之欲出,針對第二層迴圈直接平行化即可,下述代碼會帶入一點累贅,其原因在於做了各種實驗所致,但不影響正確性。只做了減少 thread 建立時的 overhead 的處理。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
#include <stdio.h>
#include <omp.h>
#define MAXN 4096
#define INF (1LL<<60)
int N;
long long dp[MAXN*MAXN], SZ[MAXN+5];
int main() {
while (scanf("%d", &N) == 1) {
for (int i = 0; i <= N; i++)
scanf("%lld", &SZ[i]);
for (int i = 0; i < N; i++)
dp[i*N+i] = 0;
#pragma omp parallel firstprivate(N)
{
for (int i = 1; i <= N; i++) {
#pragma omp for
for (int j = 0; j < N-i; j++) {
int l = j, r = j+i;
long long local = INF;
dp[l*N+r] = INF;
for (int k = l; k < r; k++) {
long long t = dp[l*N+k] + dp[(k+1)*N+r] + SZ[l] * SZ[k+1] * SZ[r+1];
if (t < local)
local = t;
}
if (local < dp[l*N+r])
dp[l*N+r] = local;
}
}
}
printf("%lld\n", dp[0*N+N-1]);
}
return 0;
}

進階平行

  • Accepted (9890 ms, 136320 KB)

從一般平行實驗結果中,發現明顯地加速倍率不對,明明使用 24 個核心,跑起來不到兩倍加速的原因到底在哪?是的,在第三層存取順序需要 dp[l][k]dp[k+1][r] 隨著 $k$ 變大,在前者存取順序是連續的,而在後者存取順序每一次跳躍長度為 $N$,導致 dp[k+1][r] 常常會發生 cache miss,一旦發生 cache miss,需要數百的 cycle 等待資料被帶進記憶體中。

由於矩陣鍊乘積計算時只用了矩陣的上三角,那複製一份到下三角矩陣去吧!dp[l][r] = dp[r][l],如此一來在第三層迴圈發生 cache miss 的機會就大幅度下降。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
#include <stdio.h>
#include <omp.h>
#define MAXN 4096
#define INF (1LL<<60)
int N, SZ[MAXN+5];
long long dp[MAXN*MAXN];
int main() {
while (scanf("%d", &N) == 1) {
for (int i = 0; i <= N; i++)
scanf("%d", &SZ[i]);
for (int i = 0; i < N; i++)
dp[i*N+i] = 0;
#pragma omp parallel firstprivate(N)
{
for (int i = 1; i <= N; i++) {
#pragma omp for
for (int j = 0; j < N-i; j++) {
int l = j, r = j+i;
long long local = INF, base = 1LL * SZ[l] * SZ[r+1];
long long *dp1 = dp + l*N, *dp2 = dp + r*N;
for (int k = l; k < r; k++) {
long long t = dp1[k] + dp2[k+1] + SZ[k+1] * base;
if (t < local)
local = t;
}
dp1[r] = dp2[l] = local;
}
}
}
printf("%lld\n", dp[0*N+N-1]);
}
return 0;
}

忘卻快取平行

  • Accepted (4264 ms, 134400 KB)

  • 參考論文 High-perofrmance Energy-efficient Recursive Dynamic Programming with Matrix-multiplication-like Flexible Kernels

最主要的精神 cache-oblivious algorithm 設計,利用大方陣切割成數個小方陣,矩陣跟矩陣之間在足夠小的情況下進行合併計算。算法概述如下:

cache-oblivious algorithm

每一個函數參數中的藍色區塊答案算出,而其他參數則是要合併的左矩陣和下矩陣,直到計算大小足夠小時,直接跑序列版本的程式,此時所有陣列可以全部帶進快取,這時候計算變得更加有利。再加上很多層的快取,每一個 CPU 的 cache 重複使用率相當高。

一般的平行度最高只有 $N$,因為我們只針對其中一個迴圈平行,而在 cache-oblivious algorithm 中,平行度最高為 $N^{1.415}$。這些都只是理論分析,實作時為了考慮硬體架構是沒辦法達到理想狀態的。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <assert.h>
#define MAXN 4096
#define MAXD 6
#define MAXC 64
#define INF (1LL<<62)
int N;
long long dp[MAXN*MAXN], SZ[MAXN+5];
typedef struct MtxHead {
long long *A;
int bx, by, n;
} MtxHead;
MtxHead subMtx(MtxHead X, int i, int j) {
MtxHead T = X;
T.n >>= 1;
if (i == 2) T.bx += T.n;
if (j == 2) T.by += T.n;
return T;
}
static inline int min(int x, int y) {
return x < y ? x : y;
}
static inline int max(int x, int y) {
return x > y ? x : y;
}
static inline long long loop(MtxHead X, MtxHead U, MtxHead V, int tl, int tr, int l, int r) {
long long v = INF, t;
long long comSZ = SZ[l] * SZ[r+1];
for (int k = tl; k < tr; k++) {
t = U.A[l*N+k] + V.A[(k+1)+r*N] + SZ[k+1] * comSZ;
if (t < v) v = t;
}
return v;
}
void Cloop(MtxHead X, MtxHead U, MtxHead V) {
int n = X.n, l, r, tl, tr;
long long v, t;
for (int i = n-1; i >= 0; i--) {
for (int j = 0; j < n; j++) {
l = X.bx + i, r = X.by + j;
v = X.A[l*N+r];
tl = max(U.by, X.bx+i), tr = min(U.by+n, X.by+j);
t = loop(X, U, V, tl, tr, l, r);
if (t < v) v = t;
tl = max(V.bx, X.bx+i), tr = min(V.bx+n, X.by+j);
t = loop(X, U, V, tl, tr, l, r);
if (t < v) v = t;
X.A[l*N+r] = X.A[r*N+l] = v;
}
}
}
void Bloop(MtxHead X, MtxHead U, MtxHead V) {
int n = X.n, l, r, tl, tr;
long long v, t;
for (int i = n-1; i >= 0; i--) {
for (int j = 0; j < n; j++) {
l = X.bx + i, r = X.by + j;
v = X.A[l*N+r];
tl = max(U.by+i, X.bx+i), tr = min(U.by+n, X.by+j);
t = loop(X, U, V, tl, tr, l, r);
if (t < v) v = t;
tl = max(V.bx, X.bx+i), tr = min(V.bx+n, X.by+j);
t = loop(X, U, V, tl, tr, l, r);
if (t < v) v = t;
X.A[l*N+r] = X.A[r*N+l] = v;
}
}
}
void Aloop(MtxHead X) {
int n = X.n;
for (int i = 1; i <= n; i++) {
for (int j = 0; j+i < n; j++) {
int l = X.bx + j, r = X.by + j+i;
long long v = X.A[l*N+r], t;
long long comSZ = SZ[l] * SZ[r+1];
for (int k = l; k < r; k++) {
t = X.A[l*N+k] + X.A[(k+1)+r*N] + SZ[k+1] * comSZ;
if (t < v)
v = t;
}
X.A[l*N+r] = X.A[r*N+l] = v;
}
}
}
void Cpar(MtxHead X, MtxHead U, MtxHead V, int dep) {
if (X.n <= MAXC) {
Cloop(X, U, V);
return ;
}
#pragma omp task if (dep < MAXD)
Cpar(subMtx(X, 1, 1), subMtx(U, 1, 1), subMtx(V, 1, 1), dep+1);
#pragma omp task if (dep < MAXD)
Cpar(subMtx(X, 1, 2), subMtx(U, 1, 1), subMtx(V, 1, 2), dep+1);
#pragma omp task if (dep < MAXD)
Cpar(subMtx(X, 2, 1), subMtx(U, 2, 1), subMtx(V, 1, 1), dep+1);
#pragma omp task if (dep < MAXD)
Cpar(subMtx(X, 2, 2), subMtx(U, 2, 1), subMtx(V, 1, 2), dep+1);
#pragma omp taskwait
//
#pragma omp task if (dep < MAXD)
Cpar(subMtx(X, 1, 1), subMtx(U, 1, 2), subMtx(V, 2, 1), dep+1);
#pragma omp task if (dep < MAXD)
Cpar(subMtx(X, 1, 2), subMtx(U, 1, 2), subMtx(V, 2, 2), dep+1);
#pragma omp task if (dep < MAXD)
Cpar(subMtx(X, 2, 1), subMtx(U, 2, 2), subMtx(V, 2, 1), dep+1);
#pragma omp task if (dep < MAXD)
Cpar(subMtx(X, 2, 2), subMtx(U, 2, 2), subMtx(V, 2, 2), dep+1);
#pragma omp taskwait
}
void Bpar(MtxHead X, MtxHead U, MtxHead V, int dep) {
if (X.n <= MAXC) {
Bloop(X, U, V);
return ;
}
Bpar(subMtx(X, 2, 1), subMtx(U, 2, 2), subMtx(V, 1, 1), dep+1);
//
#pragma omp task if (dep < MAXD)
Cpar(subMtx(X, 1, 1), subMtx(U, 1, 2), subMtx(X, 2, 1), dep+1);
#pragma omp task if (dep < MAXD)
Cpar(subMtx(X, 2, 2), subMtx(X, 2, 1), subMtx(V, 1, 2), dep+1);
#pragma omp taskwait
//
#pragma omp task if (dep < MAXD)
Bpar(subMtx(X, 1, 1), subMtx(U, 1, 1), subMtx(V, 1, 1), dep+1);
#pragma omp task if (dep < MAXD)
Bpar(subMtx(X, 2, 2), subMtx(U, 2, 2), subMtx(V, 2, 2), dep+1);
#pragma omp taskwait
//
Cpar(subMtx(X, 1, 2), subMtx(U, 1, 2), subMtx(X, 2, 2), dep+1);
Cpar(subMtx(X, 1, 2), subMtx(X, 1, 1), subMtx(V, 1, 2), dep+1);
Bpar(subMtx(X, 1, 2), subMtx(U, 1, 1), subMtx(V, 2, 2), dep+1);
}
void Apar(MtxHead X, int dep) {
if (X.n <= MAXC) {
Aloop(X);
return ;
}
#pragma omp task if (dep < MAXD)
Apar(subMtx(X, 1, 1), dep+1);
#pragma omp task if (dep < MAXD)
Apar(subMtx(X, 2, 2), dep+1);
#pragma omp taskwait
Bpar(subMtx(X, 1, 2), subMtx(X, 1, 1), subMtx(X, 2, 2), dep+1);
}
void Psmall(int N) {
for (int i = 0; i < N; i++)
dp[i*N+i] = 0;
#pragma omp parallel
{
for (int i = N-1; i > 0; i--) {
int comN = N-i;
#pragma omp for
for (int j = 0; j < i; j++) {
int l = j, r = comN+j;
long long local = INF;
long long *dp1 = dp+l*N, *dp2 = dp+r*N;
long long comSZ = SZ[l] * SZ[r+1];
for (int k = l; k < r; k++) {
long long t = dp1[k] + dp2[(k+1)] + comSZ * SZ[k+1];
if (t < local)
local = t;
}
dp[l*N+r] = dp[r*N+l] = local;
}
}
}
printf("%lld\n", dp[0*N+N-1]);
}
int main() {
while (scanf("%d", &N) == 1) {
for (int i = 0; i <= N; i++)
scanf("%lld", &SZ[i]);
if (N <= 2048) {
Psmall(N);
continue;
}
int ON = N;
while ((N&(-N)) != N)
N++;
for (int i = 0; i < N; i++) {
for (int j = i+1; j < N; j++)
dp[i*N+j] = INF;
dp[i*N+i] = 0;
}
MtxHead X;
X.n = N, X.bx = X.by = 0, X.A = dp;
#pragma omp parallel
{
#pragma omp single
Apar(X, 0);
}
printf("%lld\n", dp[0*N+ON-1]);
}
return 0;
}