批改娘 10093. Fast Matrix Chain Multiplication (OpenMP)

contents

  1. 1. 題目描述
    1. 1.1. sample.c
  2. 2. 輸入格式
  3. 3. 輸出格式
  4. 4. 範例輸入
  5. 5. 範例輸出
  6. 6. 備註
  7. 7. Solution

題目描述

計算矩陣鏈乘積 $A_{r_1, c_1} B_{r_2, c_2} \cdots$ 的值。

sample.c

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
/*
 * Generate a pseudo-random R x C matrix in row-major order.
 *
 * Every element is the next state of the recurrence
 *     x = (x * x + seed + i + j) % (R * C)
 * started from x = 2 and visited row by row (i = row, j = column), so each
 * value lies in [0, R*C).  All arithmetic is unsigned 32-bit.
 *
 * seed  per-matrix generator seed
 * R, C  number of rows / columns
 *
 * Returns a malloc'ed buffer of R*C elements that the caller must free(),
 * or NULL when the allocation fails.
 */
uint32_t* rand_gen(uint32_t seed, int R, int C) {
    uint32_t *m = malloc(sizeof *m * R * C);
    if (m == NULL)
        return NULL;    /* report failure instead of writing through NULL */

    uint32_t x = 2, n = R * C;
    for (int i = 0; i < R; i++) {
        for (int j = 0; j < C; j++) {
            x = (x * x + seed + i + j) % n;
            m[i * C + j] = x;
        }
    }
    return m;
}
/* Scramble x by multiplying it with 2654435761 (0x9E3779B1) modulo 2^32. */
uint32_t hash(uint32_t x) {
    const uint32_t multiplier = UINT32_C(2654435761);
    /* Unsigned 32-bit multiplication wraps, which is the intended reduction. */
    return (uint32_t) (x * multiplier);
}
/*
 * Output helper: fold an r x c row-major matrix into one 32-bit signature.
 *
 * Starting from 0, the running value is replaced by hash(value + element)
 * for every element, visited row by row.  The matrix is only read.
 */
uint32_t signature(uint32_t *A, int r, int c) {
    uint32_t digest = 0;
    for (int row = 0; row < r; row++) {
        const uint32_t *line = A + row * c;   /* first element of this row */
        for (int col = 0; col < c; col++) {
            digest = hash(digest + line[col]);
        }
    }
    return digest;
}

輸入格式

有多組測資,每組第一行會有一個整數 $N$ 表示矩陣鏈上有 $N$ 個矩陣,第二行上會有 $N+1$ 個整數 $Z_i$,表示矩陣鏈的每一個行列大小,例如當 $N = 3$ 時,輸入 10 30 5 60 表示矩陣 $A_{10, 30} B_{30, 5} C_{5, 60}$ 相乘。第三行會有 $N$ 個整數,第 $i$ 個整數 $S_i$ 為第 $i$ 個矩陣生成種子。

  • $1 \le N \le 100$
  • $1 \le Z_i \le 1000$
  • $0 \le S_i \le 32767$

輸出格式

對於每組測資輸出一行,將最後的矩陣結果輸出雜湊值。

範例輸入

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
2
2 2 2
2 5
3
10 30 5 60
0 0 0
3
1 5 20 1
0 0 0
3
5 10 20 35
0 0 0
6
30 35 15 5 10 20 25
0 0 0 0 0 0

範例輸出

1
2
3
4
5
573770929
1762797124
1738984832
354147713
3544048495

備註

輸出請用 printf("%u", answer);,計算 Dynamic Programming 時,請使用 64-bit 型態紀錄,因為最慘情況下會超過 32-bit 所能容納的範圍。

Solution

充分運用在演算法課程學到的技巧:先計算矩陣鏈乘積的最少乘法次數,再針對最佳化後的乘法順序進行平行化。平行化可以從單一矩陣乘法內部著手,或者從可以同時進行的矩陣乘法操作著手。甚至可以套用編譯器課程學到的最少暫存器演算法,設法以少量的空間換取更好的快取效果。

下述程式只在單一矩陣乘法內部進行平行化,而不讓兩個乘法同時進行,原因之一是後者很難保證負載平衡 (load balance)。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <stdint.h>
#include <assert.h>
/* Capacity of the per-chain tables below (sizes, seeds and DP tables). */
#define MAXN 128
/* Unroll factor for the hand-unrolled inner-product loop. */
#define LOOP_UNROLL 8
/* "No split found yet" cost used to initialise each DP cell before minimising. */
#define INF (1LL<<60)

/* Current test case: chain length N, the N+1 dimensions in SZ[0..N]
 * (matrix i is SZ[i] x SZ[i+1]) and the generator seed of each matrix. */
int N, SZ[MAXN], SEED[MAXN];

/* dp[l][r]: minimum scalar multiplications needed for matrices l..r.
 * Static storage is zero-initialised, so no explicit initialiser is needed
 * (an empty "= {}" initialiser is only valid from C23 onwards). */
long long dp[MAXN][MAXN];

/* argdp[l][r]: split index k that achieves dp[l][r], i.e. (l..k)(k+1..r). */
int argdp[MAXN][MAXN];
/*
 * Build a pseudo-random R x C matrix stored row-major.
 *
 * The generator state starts at 2 and, for every cell (row, col) in
 * row-major order, advances as
 *     state = (state * state + c + row + col) % (R * C)
 * using unsigned 32-bit arithmetic; the new state is the cell's value.
 *
 * c     generator seed for this matrix
 * R, C  row and column counts
 *
 * Returns a malloc'ed buffer owned by the caller.  Allocation failure is
 * caught by assert().
 */
uint32_t* rand_gen(uint32_t c, int R, int C) {
    uint32_t *cells = (uint32_t *) malloc(sizeof(uint32_t) * R * C);
    assert(cells != NULL);

    const uint32_t modulus = R * C;
    uint32_t state = 2;
    uint32_t *out = cells;              /* write cursor, advances row-major */
    for (int row = 0; row < R; row++) {
        for (int col = 0; col < C; col++) {
            state = (state * state + c + row + col) % modulus;
            *out++ = state;
        }
    }
    return cells;
}
/*
 * Compute C = A * B for row-major matrices and release both operands.
 *
 * A   r  x rc matrix, malloc'ed; freed before returning
 * B   rc x c  matrix, malloc'ed; freed before returning (must not alias A)
 * r   rows of A and of the result
 * rc  shared dimension (columns of A, rows of B); 0 yields an all-zero result
 * c   columns of B and of the result
 *
 * Arithmetic is unsigned 32-bit, i.e. every element is reduced modulo 2^32.
 * Returns a malloc'ed r x c matrix owned by the caller.  Allocation failure
 * is caught by assert().
 */
uint32_t* multiplyAndDel(uint32_t *A, uint32_t *B, int r, int rc, int c) {
    const size_t rows = (size_t) r, inner = (size_t) rc, cols = (size_t) c;
    const size_t outCount = rows * cols, tCount = inner * cols;

    /* Ask for at least one element so that an empty matrix cannot trip the
     * assertions on platforms where malloc(0) returns NULL. */
    uint32_t *C = malloc(sizeof *C * (outCount ? outCount : 1));
    uint32_t *tB = malloc(sizeof *tB * (tCount ? tCount : 1));
    assert(C != NULL);
    assert(tB != NULL);

    /* Transpose B so that each inner product below walks two contiguous
     * rows instead of striding down a column of B. */
    for (int i = 0; i < rc; i++) {
        for (int j = 0; j < c; j++)
            tB[(size_t) j * inner + i] = B[(size_t) i * cols + j];
    }
    free(B);

    /* Rows of the result are independent, so they are shared among threads.
     * The inner product is a plain counted loop: it is correct for any
     * rc >= 0 (including 0) and leaves unrolling/vectorising to the compiler. */
    #pragma omp parallel for
    for (int i = 0; i < r; i++) {
        const uint32_t *a = A + (size_t) i * inner;
        for (int j = 0; j < c; j++) {
            const uint32_t *b = tB + (size_t) j * inner;
            uint32_t sum = 0;
            for (int k = 0; k < rc; k++)
                sum += a[k] * b[k];
            C[(size_t) i * cols + j] = sum;
        }
    }

    free(A);
    free(tB);
    return C;
}
/* Mix x by multiplying it with the constant 2654435761, keeping the low 32 bits. */
uint32_t hash(uint32_t x) {
    uint32_t mixed = x;
    mixed *= UINT32_C(2654435761);   /* unsigned wrap-around is intended */
    return mixed;
}
/*
 * Reduce an r x c row-major matrix to a 32-bit signature, then free it.
 *
 * The signature starts at 0 and becomes hash(signature + element) for each
 * element in row-major order.
 *
 * A     malloc'ed matrix; ownership is taken and the buffer is released
 * r, c  row and column counts
 *
 * Returns the final signature.
 */
uint32_t signatureAndDel(uint32_t *A, int r, int c) {
    uint32_t digest = 0;
    const uint32_t *cursor = A;          /* read position, advances row-major */
    for (int row = 0; row < r; row++) {
        for (int col = 0; col < c; col++) {
            digest = hash(digest + *cursor);
            cursor++;
        }
    }
    free(A);
    return digest;
}
/*
 * Evaluate the product of matrices l..r following the split points in argdp.
 *
 * l, r  inclusive range of matrix indices in the chain
 * mR    out: row count of the returned matrix
 * mC    out: column count of the returned matrix
 *
 * A single matrix (l == r) is generated from SEED[l] with dimensions
 * SZ[l] x SZ[l+1].  Otherwise both halves around argdp[l][r] are evaluated
 * recursively (left half first) and multiplied; multiplyAndDel releases the
 * two intermediate matrices.
 *
 * Returns a malloc'ed row-major matrix owned by the caller.
 */
uint32_t* dfs(int l, int r, int *mR, int *mC) {
    if (l != r) {
        const int mid = argdp[l][r];
        int leftRows, leftCols, rightRows, rightCols;

        uint32_t *left = dfs(l, mid, &leftRows, &leftCols);
        uint32_t *right = dfs(mid + 1, r, &rightRows, &rightCols);
        assert(leftCols == rightRows);   /* adjacent factors must be conformable */

        *mR = leftRows;
        *mC = rightCols;
        return multiplyAndDel(left, right, leftRows, leftCols, rightCols);
    }

    /* Leaf: materialise the l-th input matrix. */
    *mR = SZ[l];
    *mC = SZ[l + 1];
    return rand_gen(SEED[l], *mR, *mC);
}
int main() {
while (scanf("%d", &N) == 1) {
for (int i = 0; i <= N; i++)
scanf("%d", &SZ[i]);
for (int i = 0; i < N; i++)
scanf("%d", &SEED[i]);
memset(dp, 0, sizeof(dp));
for (int i = 1; i <= N; i++) {
for (int j = 0; j+i < N; j++) {
int l = j, r = j+i;
dp[l][r] = INF;
for (int k = l; k < r; k++) {
long long t = dp[l][k] + dp[k+1][r] + (long long) SZ[l] * SZ[k+1] * SZ[r+1];
if (t < dp[l][r])
dp[l][r] = t, argdp[l][r] = k;
}
}
}
int retR, retC;
uint32_t *retM;
uint32_t hval;
retM = dfs(0, N-1, &retR, &retC);
hval = signatureAndDel(retM, retR, retC);
printf("%u\n", hval);
long long test = 0;
for (int i = 1; i < N; i++) {
test += SZ[0] * SZ[i] * SZ[i+1];
}
fprintf(stderr, "best %lld, origin %lld, %lf\n", dp[0][N-1], test, dp[0][N-1]*1.f / test);
}
return 0;
}