HDU 5307 - He is Flying

contents

  1. Problem
  2. Sample Input
  3. Sample Output
  4. Solution
     4.1. FFT
     4.2. NTT
     4.3. NTT CRT

Problem

題目連結。題目要求加速以下程序的計算:下方程式需要 $O(N^3)$,即使用前綴和維護區間總和也仍需要 $O(N^2)$。

for l = 0 to n-1
    for r = l to n-1
        sum = 0
        for k = l to r
            sum += A[k]
        ret[sum] += r-l+1

最後依序輸出 ret[0], ret[1], ..., ret[S],其中 $S$ 為整個序列的總和。

Sample Input

2
3
1 2 3
4
0 1 2 3

Sample Output

0
1
1
3
0
2
3
1
3
1
6
0
2
7

Solution

原程式是很樸素的計算,要加速就必須借助適當的資料結構與演算法。由於每一種總和的結果都要輸出,可以想到利用快速傅立葉轉換 (FFT) 計算旋積 (convolution),接下來就要思考如何構造多項式 (向量)。

令前 $i$ 個數字的前綴和為 $s_i$ ($s_0 = 0$)。為了讓計數反映在係數上、總和反映在次方上,考慮兩個 $x$ 的多項式相乘:區間 $[j+1, i]$ 的總和為 $s_i - s_j$、長度為 $i - j$,對應到 $(i - j)\, x^{s_i} \times x^{-s_j} = (i-j)\, x^{s_i - s_j}$。但係數 $(i-j)$ 同時依賴兩邊的索引,無法用一次乘法完成,因此拆成兩次計算,分別得到 $i\, x^{s_i - s_j}$ 與 $-j\, x^{s_i - s_j}$。

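前者構造 $\left(\sum_i i\, x^{s_i}\right) \times \left(\sum_j x^{-s_j}\right)$,後者構造 $\left(\sum_i x^{s_i}\right) \times \left(\sum_j j\, x^{-s_j}\right)$,利用快速傅立葉轉換在 $O(S \log S)$ ($S$ 為總和) 內完成多項式相乘,最後將前者減去後者即可。實作上負次方以長度 $m > 2S$ 的循環旋積處理:把第二個向量反轉後相乘,差值為負的部分會落在 $m - S$ 以上的索引,不會與答案所在的 $0 \ldots S$ 混淆。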

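特別注意總和為 0 的情況要另外處理:當 $s_i = s_j$ 時,$(i, j)$ 與 $(j, i)$ 兩種順序都會被算到,$(i-j)$ 與 $(j-i)$ 互相抵消,因此構造法無法得到正確結果。改為統計每一段連續 0 的長度 $z$,其貢獻為 $\sum_{L=1}^{z} L\,(z-L+1) = \frac{z(z+1)(z+2)}{6}$。此外這題非常講究精準度,可以利用 NTT/FNT 全程以整數運算,或者使用 FFT 在 double 型態下完成。特別小心 FFT 通常會利用角度疊加 (和角公式) 遞推單位根來加速運算,但這裡會累積精準度誤差,必須對 cos、sin 全部建表。其他人常需要改用 long double 取代 double,就是因為採用了這種遞推寫法。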

FFT

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <math.h>
#include <vector>
#include <algorithm>
#include <complex>
using namespace std;
/**
 * Radix-2 complex FFT with fully tabulated twiddle factors.
 *
 * Every root of unity is computed directly with cos/sin in prework() instead of
 * being accumulated by repeated multiplication, which keeps the rounding error
 * small for the large integer coefficients this solution transforms.
 *
 * T is the floating-point type of the samples (double in this file).
 * Transform lengths must be powers of two no larger than MAXN.
 */
template<typename T> class TOOL_FFT {
public:
    typedef unsigned int UINT32;
#define MAXN 262144
    // p[0][k] = exp(+2*pi*i*k/n) (used by the forward transform),
    // p[1][k] = exp(-2*pi*i*k/n) (used by the inverse transform),
    // for the length n most recently passed to prework().
    complex<T> p[2][MAXN];
    int pre_n;  // length the tables in p were built for (0 = not built yet)
    T PI;

    TOOL_FFT() {
        pre_n = 0;
        PI = acos(-1);
    }

    /// Returns log2(PowerOfTwo), i.e. the index of its lowest set bit.
    /// PowerOfTwo must be a positive power of two.
    int NumberOfBitsNeeded(int PowerOfTwo) {
        for (int i = 0; ; ++i) {
            if (PowerOfTwo & (1 << i)) {
                return i;
            }
        }
    }

    /// Reverses the lowest NumBits bits of a (a must be < 2^NumBits, NumBits <= 32).
    inline UINT32 FastReverseBits(UINT32 a, int NumBits) {
        // Shifting a 32-bit value right by 32 is undefined behaviour; with zero
        // bits the only valid index is 0 (transform length 1).
        if (NumBits == 0)
            return 0;
        a = ( ( a & 0x55555555U ) << 1 ) | ( ( a & 0xAAAAAAAAU ) >> 1 ) ;
        a = ( ( a & 0x33333333U ) << 2 ) | ( ( a & 0xCCCCCCCCU ) >> 2 ) ;
        a = ( ( a & 0x0F0F0F0FU ) << 4 ) | ( ( a & 0xF0F0F0F0U ) >> 4 ) ;
        a = ( ( a & 0x00FF00FFU ) << 8 ) | ( ( a & 0xFF00FF00U ) >> 8 ) ;
        a = ( ( a & 0x0000FFFFU ) << 16 ) | ( ( a & 0xFFFF0000U ) >> 16 ) ;
        return a >> (32 - NumBits);
    }

    /**
     * Iterative radix-2 FFT of In into Out.
     *
     * Out must already hold In.size() elements, and In.size() must equal the
     * length last passed to prework(). InverseTransform selects the p[1]
     * twiddle table and divides every output by the length.
     */
    void FFT(bool InverseTransform, vector<complex<T> >& In, vector<complex<T> >& Out) {
        // simultaneous data copy and bit-reversal ordering into outputs
        int NumSamples = In.size();
        int NumBits = NumberOfBitsNeeded(NumSamples);
        for (int i = 0; i < NumSamples; ++i) {
            Out[FastReverseBits(i, NumBits)] = In[i];
        }
        // Butterfly passes. In a block of size 2^i the k-th twiddle is
        // table[k * (NumSamples / BlockSize)], so t advances by BlockCnt.
        // (`register` was dropped: it is ill-formed since C++17.)
        for (int i = 1; i <= NumBits; i++) {
            int BlockSize = 1<<i, BlockEnd = BlockSize>>1, BlockCnt = NumSamples/BlockSize;
            for (int j = 0; j < NumSamples; j += BlockSize) {
                complex<T> *t = p[InverseTransform];
                for (int k = 0; k < BlockEnd; k++, t += BlockCnt) {
                    complex<T> a = (*t) * Out[k+j+BlockEnd];
                    Out[k+j+BlockEnd] = Out[k+j] - a;
                    Out[k+j] += a;
                }
            }
        }
        // normalize if inverse transform
        if (InverseTransform) {
            for (int i = 0; i < NumSamples; ++i) {
                Out[i] /= NumSamples;
            }
        }
    }

    /// Builds both twiddle tables for length n with direct cos/sin calls;
    /// does nothing if the tables were already built for n.
    void prework(int n) {
        if (pre_n == n)
            return ;
        pre_n = n;
        p[0][0] = complex<T>(1, 0);
        p[1][0] = complex<T>(1, 0);
        for (int i = 1; i < n; i++) {
            p[0][i] = complex<T>(cos(2*i*PI / n ) , sin(2*i*PI / n ));
            p[1][i] = complex<T>(cos(2*i*PI / n ) , -sin(2*i*PI / n ));
        }
    }

    /**
     * Cyclic cross-correlation of a[0..n-1] and b[0..n-1] (n a power of two):
     *   ret[k] = real( sum over i of a[i] * b[(i - k) mod n] ).
     * It is computed as the cyclic convolution of a with b reversed
     * (index j -> -j mod n). The values are not rounded.
     */
    vector<T> convolution(complex<T> *a, complex<T> *b, int n) {
        prework(n);
        vector< complex<T> > s(a, a+n), d1(n), d2(n), y(n);
        vector<T> ret(n);
        FFT(false, s, d1);
        s[0] = b[0];
        for (int i = 1, j = n-1; i < n; ++i, --j)
            s[i] = b[j];
        FFT(false, s, d2);
        for (int i = 0; i < n; ++i) {
            y[i] = d1[i] * d2[i];
        }
        FFT(true, y, s);
        for (int i = 0; i < n; ++i) {
            ret[i] = s[i].real();
        }
        return ret;
    }
};
TOOL_FFT<double> tool;
complex<double> a[MAXN], b[MAXN];
vector<double> c;
int A[MAXN], sum[MAXN];
long long ret[MAXN], pr[262144];
int main() {
pr[0] = 0;
for (int i = 1; i < MAXN; i++)
pr[i] = pr[i-1] + (long long)i*(i+1)/2;
int testcase, n, m, s;
scanf("%d", &testcase);
while (testcase--) {
scanf("%d", &n);
for (int i = 1; i <= n; i++)
scanf("%d", &A[i]);
sum[0] = 0;
for (int i = 1; i <= n; i++)
sum[i] = sum[i-1] + A[i];
s = sum[n];
memset(ret, 0, sizeof(ret[0]) * (s+1));
for (m = 1; m <= (s<<1); m <<= 1);
memset(a, 0, sizeof(a[0])*m);
memset(b, 0, sizeof(b[0])*m);
for (int i = 1; i <= n; i++) {
a[sum[i]] += complex<double>(i, 0);
b[sum[i-1]] += complex<double>(1, 0);
}
c = tool.convolution(a, b, m);
for (int i = 1; i < m; i++)
ret[i] += round(c[i]);
memset(a, 0, sizeof(a[0])*m);
memset(b, 0, sizeof(b[0])*m);
for (int i = 1; i <= n; i++) {
a[sum[i]] += complex<double>(1, 0);
b[sum[i-1]] += complex<double>(i-1, 0);
}
c = tool.convolution(a, b, m);
for (int i = 1; i <= s; i++)
ret[i] -= round(c[i]);
for (int i = 1, z = 0; i <= n+1; i++) {
if (i == n+1 || A[i] != 0)
ret[0] += pr[z], z = 0;
else
z++;
}
for (int i = 0; i <= s; i++) {
printf("%lld\n", ret[i]);
}
}
return 0;
}

NTT

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <math.h>
#include <vector>
#include <algorithm>
using namespace std;
typedef unsigned int UINT32;
typedef long long INT64;
/**
 * Number-theoretic transform modulo the single prime P = 50000000001507329.
 *
 * P - 1 = 190734863287 * 2^18 with 190734863287 odd, so transform lengths up to
 * 2^18 are supported. (wn[19] is still computed, but it is not a root of unity of
 * order 2^19.) P is about 5e16, so a product of two residues does not fit in
 * 64 bits; mod_mul uses a long double quotient estimate instead.
 */
class TOOL_NTT {
public:
#define MAXN 262144
    const INT64 P = 50000000001507329LL; // prime m = kn+1
    const INT64 G = 3;
    INT64 wn[20];  // wn[i] = G^((P-1)/2^i) mod P
    INT64 s[MAXN], d1[MAXN], d2[MAXN], y[MAXN];  // scratch buffers (y is not used here)

    TOOL_NTT() {
        for (int i = 0; i < 20; i++)
            wn[i] = mod_pow(G, (P-1) / (1<<i), P);
    }

    /**
     * Returns a*b mod `mod`, for 0 <= a, b < mod, without 128-bit arithmetic.
     *
     * The quotient floor(a*b/mod) is estimated in long double (biased up by 1e-3),
     * and the remainder a*b - q*mod + mod is then formed in wrapping unsigned
     * 64-bit arithmetic. Forming a*b in signed 64-bit would overflow, which is
     * undefined behaviour. The final % absorbs an estimate that is off by one.
     * NOTE: this relies on long double having a 64-bit mantissa (x87); with a
     * 53-bit long double (e.g. MSVC) it is not exact for moduli this large.
     */
    INT64 mod_mul(INT64 a, INT64 b, INT64 mod) {
        unsigned long long q = (unsigned long long)(long long)(a/(long double)mod*b+1e-3);
        unsigned long long r = (unsigned long long)a*(unsigned long long)b
                             - q*(unsigned long long)mod + (unsigned long long)mod;
        return (long long)r % mod;
    }

    /// Returns n^e mod m by binary exponentiation.
    INT64 mod_pow(INT64 n, INT64 e, INT64 m) {
        INT64 x = 1;
        for (n = n >= m ? n%m : n; e; e >>= 1) {
            if (e&1)
                x = mod_mul(x, n, m);
            n = mod_mul(n, n, m);
        }
        return x;
    }

    /// Returns log2(PowerOfTwo); PowerOfTwo must be a positive power of two.
    int NumberOfBitsNeeded(int PowerOfTwo) {
        for (int i = 0;; ++i) {
            if (PowerOfTwo & (1 << i)) {
                return i;
            }
        }
    }

    /// Reverses the lowest NumBits bits of a (a must be < 2^NumBits).
    inline UINT32 FastReverseBits(UINT32 a, int NumBits) {
        // a >> 32 would be undefined behaviour; length-1 transforms only use index 0.
        if (NumBits == 0)
            return 0;
        a = ( ( a & 0x55555555U ) << 1 ) | ( ( a & 0xAAAAAAAAU ) >> 1 ) ;
        a = ( ( a & 0x33333333U ) << 2 ) | ( ( a & 0xCCCCCCCCU ) >> 2 ) ;
        a = ( ( a & 0x0F0F0F0FU ) << 4 ) | ( ( a & 0xF0F0F0F0U ) >> 4 ) ;
        a = ( ( a & 0x00FF00FFU ) << 8 ) | ( ( a & 0xFF00FF00U ) >> 8 ) ;
        a = ( ( a & 0x0000FFFFU ) << 16 ) | ( ( a & 0xFFFF0000U ) >> 16 ) ;
        return a >> (32 - NumBits);
    }

    /**
     * Iterative NTT of In[0..n-1] into Out[0..n-1] (n a power of two, n <= 2^18).
     * on == 1 computes the inverse transform, including the division by n.
     * Inputs must be non-negative, and In and Out must not alias.
     */
    void NTT(int on, INT64 *In, INT64 *Out, int n) {
        int NumBits = NumberOfBitsNeeded(n);
        // Reduce while copying so the butterflies below can assume values < P.
        for (int i = 0; i < n; ++i)
            Out[FastReverseBits(i, NumBits)] = In[i] % P;
        for (int h = 2, id = 1; h <= n; h <<= 1, id++) {
            for (int j = 0; j < n; j += h) {
                INT64 w = 1, u, t;
                int block = h/2, blockEnd = j + h/2;
                for (int k = j; k < blockEnd; k++) {
                    u = Out[k], t = mod_mul(w, Out[k+block], P);
                    Out[k] = u + t;
                    Out[k + block] = u - t + P;
                    if (Out[k] >= P) Out[k] -= P;
                    if (Out[k+block] >= P) Out[k+block] -= P;
                    w = mod_mul(w, wn[id], P);
                }
            }
        }
        if (on == 1) {
            // Inverse = forward transform with indices 1..n-1 reversed, divided by n.
            for (int i = 1; i < n/2; i++)
                swap(Out[i], Out[n-i]);
            INT64 invn = mod_pow(n, P-2, P);
            for (int i = 0; i < n; i++)
                Out[i] = mod_mul(Out[i], invn, P);
        }
    }

    /**
     * Cyclic cross-correlation modulo P:
     *   c[k] = sum over i of a[i] * b[(i - k) mod n]   (mod P).
     * It is computed as the cyclic convolution of a with b reversed (j -> -j mod n).
     */
    void convolution(INT64 *a, INT64 *b, int n, INT64 *c) {
        NTT(0, a, d1, n);
        s[0] = b[0];
        for (int i = 1; i < n; ++i)
            s[i] = b[n-i];
        NTT(0, s, d2, n);
        for (int i = 0; i < n; i++)
            s[i] = mod_mul(d1[i], d2[i], P);
        NTT(1, s, c, n);
    }
} tool;
INT64 a[MAXN], b[MAXN], c[MAXN];
int A[MAXN], sum[MAXN];
long long ret[MAXN], pr[262144];
int main() {
pr[0] = 0;
for (int i = 1; i < MAXN; i++)
pr[i] = pr[i-1] + (long long)i*(i+1)/2;
int testcase, n, m, s;
scanf("%d", &testcase);
while (testcase--) {
scanf("%d", &n);
for (int i = 1; i <= n; i++)
scanf("%d", &A[i]);
sum[0] = 0;
for (int i = 1; i <= n; i++)
sum[i] = sum[i-1] + A[i];
s = sum[n];
memset(ret, 0, sizeof(ret[0]) * (s+1));
for (m = 1; m <= (s<<1); m <<= 1);
memset(a, 0, sizeof(a[0])*m);
memset(b, 0, sizeof(b[0])*m);
for (int i = 1; i <= n; i++) {
a[sum[i]] += i;
b[sum[i-1]] ++;
}
tool.convolution(a, b, m, c);
for (int i = 1; i < m; i++)
ret[i] += c[i];
memset(a, 0, sizeof(a[0])*m);
memset(b, 0, sizeof(b[0])*m);
for (int i = 1; i <= n; i++) {
a[sum[i]] ++;
b[sum[i-1]] += i-1;
}
tool.convolution(a, b, m, c);
for (int i = 1; i <= s; i++)
ret[i] -= c[i];
for (int i = 1, z = 0; i <= n+1; i++) {
if (i == n+1 || A[i] != 0)
ret[0] += pr[z], z = 0;
else
z++;
}
for (int i = 0; i <= s; i++) {
printf("%lld\n", ret[i]);
}
}
return 0;
}

NTT CRT

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <math.h>
#include <vector>
#include <algorithm>
using namespace std;
typedef uint_fast32_t UINT32;
typedef long long INT64;
// NOTE: despite its name, INT32 is an unsigned "fast" 32-bit type. It is 64 bits
// wide on common LP64 platforms. The globals and main below also use it.
typedef uint_fast32_t INT32;
/**
 * Exact integer correlation via two NTTs and the Chinese remainder theorem.
 *
 * The correlation is computed once modulo P1 = 998244353 and once modulo
 * P2 = 995622913. The residues are combined into a value modulo
 * MM = P1 * P2 (about 9.9e17), so results are exact whenever every true
 * coefficient lies in [0, MM). Transform lengths up to 2^19 are supported
 * (P2 - 1 has only 2^19 as its power-of-two factor).
 */
class TOOL_NTT {
public:
#define MAXN 262144
    // INT64 P = 50000000001507329LL; // prime m = kn+1
    // INT64 G = 3;
    INT32 P = 3, G = 2;  // modulus / generator currently in use; switched by reset()
    INT32 wn[20];        // wn[i] = G^((P-1)/2^i) mod P for the current P
    INT32 s[MAXN], d1[MAXN], d2[MAXN], c1[MAXN], c2[MAXN];
    const INT32 P1 = 998244353; // P1 = 2^23 * 7 * 17 + 1
    const INT32 G1 = 3;
    const INT32 P2 = 995622913; // P2 = 2^19 *3*3*211 + 1
    const INT32 G2 = 5;
    // CRT idempotents: M1 is a multiple of P2 with M1 = 1 (mod P1), and M2 is a
    // multiple of P1 with M2 = 1 (mod P2). Note M1 + M2 = MM + 1.
    const INT64 M1 = 397550359381069386LL;
    const INT64 M2 = 596324591238590904LL;
    const INT64 MM = 993874950619660289LL; // MM = P1*P2

    TOOL_NTT() {
        reset(P, G);
    }

    /// Switches the working modulus to p with generator g and rebuilds wn.
    void reset(INT32 p, INT32 g) {
        P = p, G = g;
        for (int i = 0; i < 20; i++)
            wn[i] = mod_pow(G, (P-1) / (1<<i), P);
    }

    /**
     * Returns a*b mod `mod`, for 0 <= a, b < mod, without 128-bit arithmetic.
     *
     * The quotient is estimated in long double (biased up by 1e-3), and the
     * remainder a*b - q*mod + mod is formed in wrapping unsigned 64-bit
     * arithmetic. Forming a*b in signed 64-bit would overflow, which is
     * undefined behaviour. The final % absorbs an estimate that is off by one.
     */
    INT64 mod_mul(INT64 a, INT64 b, INT64 mod) {
        unsigned long long q = (unsigned long long)(long long)(a/(long double)mod*b+1e-3);
        unsigned long long r = (unsigned long long)a*(unsigned long long)b
                             - q*(unsigned long long)mod + (unsigned long long)mod;
        return (long long)r % mod;
    }

    /// Returns n^e mod m by binary exponentiation.
    INT64 mod_pow(INT64 n, INT64 e, INT64 m) {
        INT64 x = 1;
        for (n = n >= m ? n%m : n; e; e >>= 1) {
            if (e&1)
                x = mod_mul(x, n, m);
            n = mod_mul(n, n, m);
        }
        return x;
    }

    /// Returns log2(PowerOfTwo); PowerOfTwo must be a positive power of two.
    int NumberOfBitsNeeded(int PowerOfTwo) {
        for (int i = 0;; ++i) {
            if (PowerOfTwo & (1 << i)) {
                return i;
            }
        }
    }

    /// Reverses the lowest NumBits bits of a (a must be < 2^NumBits, NumBits <= 32).
    inline UINT32 FastReverseBits(UINT32 a, int NumBits) {
        // A length-1 transform only uses index 0. This also avoids a >> 32, which is
        // undefined if UINT32 happens to be exactly 32 bits wide.
        if (NumBits == 0)
            return 0;
        a = ( ( a & 0x55555555U ) << 1 ) | ( ( a & 0xAAAAAAAAU ) >> 1 ) ;
        a = ( ( a & 0x33333333U ) << 2 ) | ( ( a & 0xCCCCCCCCU ) >> 2 ) ;
        a = ( ( a & 0x0F0F0F0FU ) << 4 ) | ( ( a & 0xF0F0F0F0U ) >> 4 ) ;
        a = ( ( a & 0x00FF00FFU ) << 8 ) | ( ( a & 0xFF00FF00U ) >> 8 ) ;
        a = ( ( a & 0x0000FFFFU ) << 16 ) | ( ( a & 0xFFFF0000U ) >> 16 ) ;
        return a >> (32 - NumBits);
    }

    /**
     * Iterative NTT of In[0..n-1] into Out[0..n-1] modulo the current P.
     * on == 1 computes the inverse transform, including the division by n.
     * In and Out must not alias.
     */
    void NTT(int on, INT32 *In, INT32 *Out, int n) {
        int NumBits = NumberOfBitsNeeded(n);
        // Reduce while copying: callers may pass coefficients >= P.
        for (int i = 0; i < n; ++i)
            Out[FastReverseBits(i, NumBits)] = In[i] % P;
        for (int h = 2, id = 1; h <= n; h <<= 1, id++) {
            for (int j = 0; j < n; j += h) {
                INT32 w = 1, u, t;
                int block = h/2, blockEnd = j + h/2;
                for (int k = j; k < blockEnd; k++) {
                    u = Out[k], t = (INT64)w*Out[k+block]%P;
                    Out[k] = (u + t)%P;
                    // u - t may wrap in unsigned arithmetic; the sum u - t + P is in [1, 2P).
                    Out[k+block] = (u - t + P)%P;
                    w = (INT64)w * wn[id]%P;
                }
            }
        }
        if (on == 1) {
            // Inverse = forward transform with indices 1..n-1 reversed, divided by n.
            for (int i = 1; i < n/2; i++)
                swap(Out[i], Out[n-i]);
            INT32 invn = mod_pow(n, P-2, P);
            for (int i = 0; i < n; i++)
                Out[i] = (INT64)Out[i]*invn%P;
        }
    }

    /// Combines a = x mod P1 and b = x mod P2 into x mod MM.
    INT64 crt(INT32 a, INT32 b) {
        return (mod_mul(a, M1, MM) + mod_mul(b, M2, MM))%MM;
    }

    /// Cyclic correlation of a and b modulo the prime p (generator g), written to out.
    /// Switches the working modulus to p as a side effect.
    void correlate_mod(INT32 p, INT32 g, INT32 *a, INT32 *b, int n, INT32 *out) {
        reset(p, g);
        NTT(0, a, d1, n);
        // Reverse b (index j -> -j mod n) so a product of transforms yields a correlation.
        s[0] = b[0];
        for (int i = 1; i < n; ++i)
            s[i] = b[n-i];
        NTT(0, s, d2, n);
        for (int i = 0; i < n; i++)
            s[i] = (INT64)d1[i] * d2[i]%P;
        NTT(1, s, out, n);
    }

    /**
     * Exact cyclic cross-correlation:
     *   c[k] = sum over i of a[i] * b[(i - k) mod n],
     * valid while every true value lies in [0, MM).
     */
    void convolution(INT32 *a, INT32 *b, int n, INT64 *c) {
        correlate_mod(P1, G1, a, b, n, c1);
        correlate_mod(P2, G2, a, b, n, c2);
        for (int i = 0; i < n; i++)
            c[i] = crt(c1[i], c2[i]);
    }
} tool;
INT32 a[262144], b[262144];
INT64 c[262144];
int A[MAXN], sum[MAXN];
long long ret[MAXN], pr[262144];
int main() {
pr[0] = 0;
for (int i = 1; i < MAXN; i++)
pr[i] = pr[i-1] + (long long)i*(i+1)/2;
int testcase, n, m, s;
scanf("%d", &testcase);
while (testcase--) {
scanf("%d", &n);
for (int i = 1; i <= n; i++)
scanf("%d", &A[i]);
sum[0] = 0;
for (int i = 1; i <= n; i++)
sum[i] = sum[i-1] + A[i];
s = sum[n];
memset(ret, 0, sizeof(ret[0]) * (s+1));
for (m = 1; m <= (s<<1); m <<= 1);
memset(a, 0, sizeof(a[0])*m);
memset(b, 0, sizeof(b[0])*m);
for (int i = 1; i <= n; i++) {
a[sum[i]] += i;
b[sum[i-1]] ++;
}
tool.convolution(a, b, m, c);
for (int i = 1; i < m; i++)
ret[i] += c[i];
memset(a, 0, sizeof(a[0])*m);
memset(b, 0, sizeof(b[0])*m);
for (int i = 1; i <= n; i++) {
a[sum[i]] ++;
b[sum[i-1]] += i-1;
}
tool.convolution(a, b, m, c);
for (int i = 1; i <= s; i++)
ret[i] -= c[i];
for (int i = 1, z = 0; i <= n+1; i++) {
if (i == n+1 || A[i] != 0)
ret[0] += pr[z], z = 0;
else
z++;
}
for (int i = 0; i <= s; i++) {
printf("%lld\n", ret[i]);
}
}
return 0;
}