ADA 2020 Fall P3. ADA Party

contents

  1. 1. Problem
  2. 2. Sample Input
  3. 3. Sample Output
  4. 4. Solution

Algorithm Design and Analysis (NTU CSIE, Fall 2020)

Problem

$N$ 個堆,每個堆有 $a_i$ 個糖果,現在邀請 $K$ 個人,現在問有多少種挑選區間的方法,滿足扣掉最大堆和最小堆後,區間內的糖果總數可以被 $K$ 整除。

Sample Input

1
2
10 2
6 9 3 4 5 6 1 7 8 3

Sample Output

1
25

Solution

由於沒辦法參與課程,就測測自己產的測試資料,正確性有待確認。

分治處理可行解的組合,每一次剖半計算,統計跨區間的答案個數。

討論項目分別為

  • 最大值、最小值嚴格都在左側
  • 最大值、最小值嚴格都在右側
  • 最大值在左側、最小值在右側
  • 最大值在右側、最小值在左側

最後兩項會有交集部分,則扣除 在左側的最大最小值接等於右側的最大最小值。對於每一項回答,搭配單調運行的滑動窗口解決。

時間複雜度 $\mathcal{O}(n \log n)$、空間複雜度 $\mathcal{O}(n)$

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
#include <bits/stdc++.h>
using namespace std;
// Algorithm Design and Analysis (NTU CSIE, Fall 2020)
// Problem 3. ADA Party
const int MAXN = 500005;
const int32_t MIN = LONG_MIN;
const int32_t MAX = LONG_MAX;
int32_t a[MAXN];
int32_t lsum[MAXN], rsum[MAXN];
int32_t lmin[MAXN], lmax[MAXN];
int32_t rmin[MAXN], rmax[MAXN];
int cases = 0;
int mark[MAXN];
int counter[MAXN];
int n, k;
void inc(int32_t val) {
val = (val%k+k)%k;
if (mark[val] != cases)
mark[val] = cases, counter[val] = 0;
counter[val]++;
}
void dec(int32_t val) {
val = (val%k+k)%k;
if (mark[val] != cases)
mark[val] = cases, counter[val] = 0;
counter[val]--;
}
int get(int32_t val) {
val = (val%k+k)%k;
if (mark[val] != cases)
mark[val] = cases, counter[val] = 0;
return counter[val];
}
int64_t common(int l, int m, int r) {
int64_t ret = 0;
cases++; // max and min is same on both end
for (int i = m, j = m+1, jl = m+1; i >= l; i--) {
while (j <= r && (rmax[j] <= lmax[i] && rmin[j] >= lmin[i])) {
inc(rsum[j]-rmin[j]);
j++;
}
while (jl < j && (rmin[jl] > lmin[i] || rmax[jl] < lmax[i])) {
dec(rsum[jl]-rmin[jl]);
jl++;
}
if (j > m+1 && lmin[i] == rmin[j-1] && lmax[i] == rmax[j-1])
ret += get(k-(lsum[i]-lmax[i]));
}
return ret;
}
int64_t divide(int l, int r) {
if (l >= r)
return 0;
int m = (l+r)/2;
int32_t sum = 0;
int32_t mn = MAX, mx = MIN;
for (int i = m; i >= l; i--) {
sum += a[i], mn = min(mn, a[i]), mx = max(mx, a[i]);
if (sum >= k) sum %= k;
lsum[i] = sum, lmin[i] = mn, lmax[i] = mx;
}
sum = 0, mn = MAX, mx = MIN;
for (int i = m+1; i <= r; i++) {
sum += a[i], mn = min(mn, a[i]), mx = max(mx, a[i]);
if (sum >= k) sum %= k;
rsum[i] = sum, rmin[i] = mn, rmax[i] = mx;
}
int64_t c1 = 0, c2 = 0, c3 = 0, c4 = 0;
cases++; // min max on the left
for (int i = m, j = m+1; i >= l; i--) {
while (j <= r && lmin[i] < a[j] && a[j] < lmax[i]) {
inc(rsum[j]);
j++;
}
if (i < m)
c1 += get(k-(lsum[i]-lmin[i]-lmax[i]));
}
cases++; // min max on the right
for (int i = m+1, j = m; i <= r; i++) {
while (j >= l && rmin[i] < a[j] && a[j] < rmax[i]) {
inc(lsum[j]);
j--;
}
if (i > m+1)
c2 += get(k-(rsum[i]-rmin[i]-rmax[i]));
}
cases++; // min on the left, max on the right
for (int i = m, j = m+1, jl = m+1; i >= l; i--) {
while (j <= r && rmin[j] >= lmin[i]) {
inc(rsum[j]-rmax[j]);
j++;
}
while (jl < j && rmax[jl] < lmax[i]) {
dec(rsum[jl]-rmax[jl]);
jl++;
}
c3 += get(k-(lsum[i]-lmin[i]));
}
cases++; // min on the right, max on the left
for (int i = m+1, j = m, jl = m; i <= r; i++) {
while (j >= l && lmin[j] >= rmin[i]) {
inc(lsum[j]-lmax[j]);
j--;
}
while (jl > j && lmax[jl] < rmax[i]) {
dec(lsum[jl]-lmax[jl]);
jl--;
}
c4 += get(k-(rsum[i]-rmin[i]));
}
int64_t local = c1 + c2 + c3 + c4 - common(l, m, r);
return local + divide(l, m) + divide(m+1, r);
}
int main() {
while (scanf("%d %d", &n, &k) == 2) {
for (int i = 0; i < n; i++)
scanf("%d", &a[i]);
memset(counter, 0, sizeof(counter[0])*k);
int64_t ret = divide(0, n-1);
printf("%lld\n", ret);
}
return 0;
}