淺談多重背包問題 (Multiple Knapsack Problem) 優化那些事

contents

  1. 1. 輸入格式
  2. 2. 輸出格式
  3. 3. 範例輸入 1
  4. 4. 範例輸出 1
  5. 5. Solution
    1. 5.1. 優化初夜
    2. 5.2. 優化二夜
    3. 5.3. 優化三夜
    4. 5.4. 優化四夜

收錄於 批改娘 20008. Fast Multiple Knapsack Problem

之所以出了這一題,源自於實驗室另一名同學跑實驗太久,進而撰寫優化程序。聽說原本跑了十分鐘的實驗,改善後提升了到一分鐘跑完。

輸入格式

每組測資第一行包含兩個正整數,分別代表背包大小 $M$ ($\leq 10^6$) 和物品個數 $N$ ($\leq 1000$),下一行開始每行包含兩個正整數,分別代表物品價值 $P_i$ ($\leq 10^3$)、物品重量 $W_i$ ($ \leq 10^5$) 以及物品最多可以挑 $C_i$ 個 ($\le 100$)。

輸出格式

對於每組測資,請輸出最大收益。

範例輸入 1

1
2
3
4
5
6
7
8
50 7
66 31 1
232 10 4
49 20 1
54 19 1
426 4 3
589 3 10
10 6 4

範例輸出 1

1
7178

Solution

不管是 0/1 背包或者多重背包,兩者都屬於 bounded knapsack problem 問題。即便如此,優化上仍有些許的不同,請讓我緩緩道來。

在此之前,您必須先理解上一篇《淺談背包問題 (0/1 Knapsack Problem) 優化那些事》的部分,不然會造成閱讀上的困難。

多重背包有一個二進制優化,也就是當物品限制最多拿 $C$ 個時,我們可以利用二進制組合的方式,轉換到 0/1 背包問題,因此我們會得到新的 $N \log C$ 個物品跑一次 0/1 背包,因此複雜度落在 $O(N \log C \; W)$

然而,從公式定義上,在好幾年前的論文中,使用斜率優化降到 $O(N \; W)$,推倒過程如下,

$j = k \; w_i + r, \; 0 \le r \le w_i - 1$ $$\begin{align*} dp[i][j] &= \max\left\{dp[i-1][j], dp[i-1][j-w_i]+p_i, \cdots, dp[i-1][j-c_i \; w_i] + c_i \; p_i\right\} \\ &= \max\left\{dp[i-1][k \; w_i + r], dp[i-1][(k-1) \; w_i + r] + p_i, \cdots , dp[i-1][(k-c_i) \; w_i + r] + c_i p_i\right\} \\ &= \max\left\{dp[i-1][k \; w_i + r] - k \; p_i, dp[i-1][(k-1) \; w_i + r] + (k-1) \; p_i, \cdots , dp[i-1][(k-c_i) \; w_i + r] - (k-c_i) p_i\right\} + k \; p_i\\ \end{align*}$$

隨著式子的轉移,我們發現每一個取值將不依賴相對位置,只跟自身的位置有關,那麼可以使用單調堆 (monotone queue) 運行 $O(1)$ 的 sliding windows 查找極值。最後,將相同餘數分堆處理,單調堆中最多存在 $O(c)$ 個元素。

優化初夜

如果只使用二進制優化,套上我們的 0/1 優化方案,將有大幅度地提升。

加入 0/1 背包的優化策略,再套上最簡單的斜率優化算法,得到下面的程式。這裡很懶惰地,由於單調堆最多入隊 $W$ 次,不外乎地直接只用大小為 $W$ 的方式實作。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
#include <bits/stdc++.h>
using namespace std;
namespace {
static const int MAXW = 1000005;
static const int MAXN = 1005;
struct BB {
int w, v, c;
BB(int w = 0, int v = 0, int c = 0):
w(w), v(v), c(c) {}
bool operator<(const BB &x) const {
return w * c < x.w * x.c;
}
};
static int run(BB A[], int dp[], int W, int N) {
static int MQ[MAXW][2];
for (int i = 0, sum = 0; i < N; i++) {
int w = A[i].w, v = A[i].v, c = A[i].c;
sum = min(sum + w*c, W);
for (int j = 0; j < w; j++) {
int l = 0, r = 0;
MQ[l][0] = 0, MQ[l][1] = dp[j];
for (int k = 1; k*w+j <= sum; k++) {
if (k - MQ[l][0] > c)
l++;
int dpv = dp[k*w+j] - k*v;
while (l <= r && MQ[r][1] <= dpv)
r--;
r++;
MQ[r][0] = k, MQ[r][1] = dpv;
dp[k*w+j] = max(dp[k*w+j], MQ[l][1] + k*v);
}
}
}
}
static int knapsack(int C[][3], int N, int W) {
vector<BB> A;
for (int i = 0; i < N; i++) {
int w = C[i][0], v = C[i][1], c = C[i][2];
A.push_back(BB(w, v, c));
}
assert(N < MAXN);
static int dp1[MAXW+1], dp2[MAXW+1];
BB Ar[2][MAXN];
int ArN[2] = {};
memset(dp1, 0, sizeof(dp1[0])*(W+1));
memset(dp2, 0, sizeof(dp2[0])*(W+1));
sort(A.begin(), A.end());
int sum[2] = {};
for (int i = 0; i < N; i++) {
int ch = sum[1] < sum[0];
Ar[ch][ArN[ch]] = A[i];
ArN[ch]++;
sum[ch] = min(sum[ch] + A[i].w*A[i].c, W);
}
run(Ar[0], dp1, W, ArN[0]);
run(Ar[1], dp2, W, ArN[1]);
int ret = 0;
for (int i = 0, j = W, mx = 0; i <= W; i++, j--) {
mx = max(mx, dp2[i]);
ret = max(ret, dp1[j] + mx);
}
return ret;
}
}
int main() {
int W, N;
assert(scanf("%d %d", &W, &N) == 2);
int C[MAXN][3];
for (int i = 0; i < N; i++)
assert(scanf("%d %d %d", &C[i][1], &C[i][0], &C[i][2]) == 3);
printf("%d\n", knapsack(C, N, W));
return 0;
}

不幸地,相較於一般的斜率優化寫法,並沒有太大的改善。

優化二夜

運行 sliding windows 操作時,前 $c$ 次,是不會進行 pop_front() 操作的,因此把迴圈分兩堆處理,增加 branch predict。以及在乘數運算上,使用強度減少 (strength reduction) 的技術,將乘法換成加法。

只能些許地改善 5% 的效能。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
static int run(BB A[], int dp[], int W, int N) {
static int MQ[MAXW][2];
for (int i = 0, sum = 0; i < N; i++) {
int w = A[i].w, v = A[i].v, c = A[i].c;
sum = min(sum + w*c, W);
for (int j = 0; j < w; j++) {
int l = 0, r = 0;
MQ[l][0] = 0, MQ[l][1] = dp[j];
for (int k = 1, tw = w+j, tv = v; tw <= sum && k <= c; k++, tw += w, tv += v) {
int dpv = dp[tw] - tv;
while (l <= r && MQ[r][1] <= dpv)
r--;
r++;
MQ[r][0] = k, MQ[r][1] = dpv;
dp[tw] = max(dp[tw], MQ[l][1] + tv);
}
for (int k = c+1, tw = (c+1)*w+j, tv = (c+1)*v; tw <= sum; k++, tw += w, tv += v) {
if (k - MQ[l][0] > c)
l++;
int dpv = dp[tw] - tv;
while (l <= r && MQ[r][1] <= dpv)
r--;
r++;
MQ[r][0] = k, MQ[r][1] = dpv;
dp[tw] = max(dp[tw], MQ[l][1] + tv);
}
}
}
}

優化三夜

後來發現,sliding windows 滑動時,我們常常看前看後,因此常常會發生 cache miss,因為他要跳躍一大段記憶體空間查找數值,所以可以考慮花點操作將極值放在 stack 上,視為一種 software cache 來加速,來減少 cache miss 的懲罰。

接著,在迴圈邊界比較時,我們可以算得更精準些,回到一般的 i++ 的 format pattern,讓編譯器幫我們做常見的迴圈優化。

改善了 10% 效能

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
static int run(BB A[], int dp[], int W, int N) {
static int MQ[MAXW][2];
for (int i = 0, sum = 0; i < N; i++) {
int w = A[i].w, v = A[i].v, c = A[i].c;
sum = min(sum + w*c, W);
for (int j = 0; j < w; j++) {
int l = 0, r = 0;
MQ[l][0] = 0, MQ[l][1] = dp[j];
int cache_max = MQ[l][1], cache_idx = MQ[l][0];
int k_bound;
k_bound = min((sum-j)/w, c);
for (int k = 1, tw = w+j, tv = v; k <= k_bound; k++, tw += w, tv += v) {
// tw = k*w+j, tv = k*v;
int dpv = dp[tw] - tv;
while (l <= r && MQ[r][1] <= dpv)
r--;
r++;
MQ[r][0] = k, MQ[r][1] = dpv;
if (r == l) cache_max = dpv, cache_idx = k;
dp[tw] = max(dp[tw], cache_max + tv);
}
k_bound = (sum-j)/w;
for (int k = c+1, tw = (c+1)*w+j, tv = (c+1)*v; k <= k_bound; k++, tw += w, tv += v) {
int dpv = dp[tw] - tv;
while (l <= r && MQ[r][1] <= dpv)
r--;
r++;
MQ[r][0] = k, MQ[r][1] = dpv;
if (r == l)
cache_max = dpv, cache_idx = k;
else if (k - cache_idx > c)
l++, cache_idx = MQ[l][0], cache_max = MQ[l][1];
dp[tw] = max(dp[tw], cache_max + tv);
}
}
}
}

優化四夜

儘管上面使用的 software cache 的方式減少 cache miss,但 DP table 仍與數據結構的記憶體位置相當遙遠,為了使他們貼近,應使用環狀隊列的實作,空間從 $O(W)$ 將到 $O(N)$,實作時,將大小限制在 $2^k$,方便運行時使用 AND 運算取代耗時的模數運算。

由於限制個數分佈上,很容易造成貪心算法有解,因此先跑一次貪心,如果貪心沒辦法達到剛好大小,那麼再跑 DP 找解。DP 找解時,可以將物品嘗試進行二進制轉換,將等價物品合併,來觸發計算邊界的優化。完成的程序如下:

最終加速,改善 10%,期待你我的分享增進。根據鴿籠原理,cp 直種類不多時,可以高達 10 倍以上的加速。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
#include <bits/stdc++.h>
using namespace std;
namespace {
static const int MAXW = 1000005;
static const int MAXN = 1005;
static const int MAXC = 1<<12;
struct BB {
int w, v, c;
BB(int w = 0, int v = 0, int c = 0):
w(w), v(v), c(c) {}
bool operator<(const BB &x) const {
return w * c < x.w * x.c;
}
};
static bool cmpByWeight(BB x, BB y) {
return x.w < y.w;
}
static int run(BB A[], int dp[], int W, int N) {
static int MQ[MAXC][2];
for (int i = 0, sum = 0; i < N; i++) {
int w = A[i].w, v = A[i].v, c = A[i].c;
assert(c < MAXC);
sum = min(sum + w*c, W);
if (c != 1) {
for (int j = 0; j < w; j++) {
int l = 0, r = 0;
MQ[r][0] = 0, MQ[r][1] = dp[j];
int cache_max = MQ[r][1], cache_idx = MQ[r][0];
int k_bound;
r = (r+1)&(MAXC-1);
k_bound = min((sum-j)/w, c);
for (int k = 1, tw = w+j, tv = v; k <= k_bound; k++, tw += w, tv += v) {
// tw = k*w+j, tv = k*v;
int dpv = dp[tw] - tv;
while (l != r && MQ[(r-1+MAXC)&(MAXC-1)][1] <= dpv)
r = (r-1+MAXC)&(MAXC-1);
MQ[r][0] = k, MQ[r][1] = dpv;
if (l == r) cache_max = dpv, cache_idx = k;
r = (r+1)&(MAXC-1);
dp[tw] = max(dp[tw], cache_max + tv);
}
k_bound = (sum-j)/w;
for (int k = c+1, tw = (c+1)*w+j, tv = (c+1)*v; k <= k_bound; k++, tw += w, tv += v) {
int dpv = dp[tw] - tv;
while (l != r && MQ[(r-1+MAXC)&(MAXC-1)][1] <= dpv)
r--;
MQ[r][0] = k, MQ[r][1] = dpv;
if (l == r) cache_max = dpv, cache_idx = k;
else if (k - cache_idx > c)
l = (l+1)&(MAXC-1), cache_idx = MQ[l][0], cache_max = MQ[l][1];
r = (r+1)&(MAXC-1);
dp[tw] = max(dp[tw], cache_max + tv);
}
}
} else if (c == 1) {
for (int j = sum; j >= w; j--)
dp[j] = max(dp[j], dp[j-w]+v);
}
}
}
static int greedy(int C[][3], int N, int W) {
struct GB {
int w, v, c;
GB(int w = 0, int v = 0, int c = 0):
w(w), v(v), c(c) {}
bool operator<(const GB &x) const {
if (v * x.w != x.v * w)
return v * x.w > x.v * w;
return c > x.c;
}
};
vector<GB> A;
for (int i = 0; i < N; i++) {
int w = C[i][0], v = C[i][1], c = C[i][2];
A.push_back(GB(w, v, c));
}
sort(A.begin(), A.end());
int ret = 0;
for (int i = 0; i < N; i++) {
int t = min(A[i].c, W/A[i].w);
if (t == 0)
return -1;
W -= t*A[i].w;
ret += t*A[i].v;
if (W == 0)
return ret;
}
return ret;
}
static int knapsack(int C[][3], int N, int W) {
// filter
{
int filter = greedy(C, N, W);
if (filter != -1)
return filter;
}
vector<BB> A;
for (int i = 0; i < N; i++) {
int w = C[i][0], v = C[i][1], c = C[i][2];
A.push_back(BB(w, v, c));
}
// reduce
{
sort(A.begin(), A.end(), cmpByWeight);
map<pair<int, int>, int> R;
for (int i = 0; i < N; i++)
R[make_pair(A[i].w, A[i].v)] = i;
for (int i = 0; i < N; i++) {
int c = A[i].c;
map<pair<int, int>, int>::iterator it;
for (int k = 1; k <= c; k <<= 1) {
int w = A[i].w * k, v = A[i].v * k;
it = R.find(make_pair(w, v));
if (it != R.end() && i != it->second) {
int j = it->second;
A[j].c ++;
A[i].c -= k;
}
c -= k;
}
if (c > 0) {
int w = A[i].w * c, v = A[i].v * c;
it = R.find(make_pair(w, v));
if (it != R.end() && i != it->second) {
int j = it->second;
A[j].c ++;
A[i].c -= c;
}
}
}
}
static int dp1[MAXW+1], dp2[MAXW+1];
BB Ar[2][MAXN];
int ArN[2] = {};
memset(dp1, 0, sizeof(dp1[0])*(W+1));
memset(dp2, 0, sizeof(dp2[0])*(W+1));
sort(A.begin(), A.end());
int sum[2] = {};
for (int i = 0; i < N; i++) {
int ch = sum[1] < sum[0];
Ar[ch][ArN[ch]] = A[i];
ArN[ch]++;
sum[ch] = min(sum[ch] + A[i].w*A[i].c, W);
}
run(Ar[0], dp1, W, ArN[0]);
run(Ar[1], dp2, W, ArN[1]);
int ret = 0;
for (int i = 0, j = W, mx = 0; i <= W; i++, j--) {
mx = max(mx, dp2[i]);
ret = max(ret, dp1[j] + mx);
}
return ret;
}
}
int main() {
int W, N;
assert(scanf("%d %d", &W, &N) == 2);
int C[MAXN][3];
for (int i = 0; i < N; i++)
assert(scanf("%d %d %d", &C[i][1], &C[i][0], &C[i][2]) == 3);
printf("%d\n", knapsack(C, N, W));
return 0;
}