b452. Foolishly Counting Money for Someone

Problem

A color image contains several coins of the same size. The coins do not overlap, but some of them may touch. How many coins are in the image?

Sample Input

10 9
159 138 133 150 129 124 153 132 127 154 133 128 151 132 126 151 132 126 152 133 129 152 133 129 143 125 125 136 118 118
147 126 121 163 142 137 146 125 120 86 65 60 60 41 35 57 38 32 68 49 45 111 92 88 143 125 123 138 120 118
145 125 118 152 132 125 63 44 37 53 34 27 54 35 28 55 36 29 56 39 32 50 32 28 105 87 83 132 114 110
150 130 123 106 86 79 53 34 27 57 38 31 53 34 27 64 45 38 55 38 31 56 39 32 63 45 41 131 113 109
151 132 125 87 68 61 52 33 26 56 37 30 58 41 33 49 32 24 57 40 33 53 36 29 43 28 23 131 116 111
136 117 111 103 84 78 56 37 31 59 40 34 54 37 30 57 40 33 47 30 23 53 35 31 54 39 34 141 126 121
142 124 120 140 122 118 52 34 30 54 36 32 51 33 29 53 35 31 53 38 33 46 31 28 82 67 64 136 121 118
145 127 127 125 107 107 116 98 98 54 36 36 53 35 33 54 36 34 46 30 30 65 49 49 138 122 122 145 129 129
135 119 120 137 121 122 137 121 122 133 117 118 103 87 88 95 79 80 111 95 96 138 122 123 132 118 118 132 118 118

Sample Output

1

Solution

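If the circles never touched each other, binarization + flood fill + a size check on each connected component would be enough. Here, however, the circles may be connected (problem setter, please attach your reference solution so the two of us can race …). Anyway, the steps are: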

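  1. Binarize the image by brightness.
  2. Apply the Sobel operator and, after choosing a threshold on the gradient magnitude, find the edges.
  3. Enumerate the circle radius and use the Hough transform to push every edge point toward its circle center, accumulating votes.
  4. Scan a $7 \times 7$ window around every cell with at least 3 votes: the window must hold more than 56 votes in total, and either more than 36 of its 49 cells must be non-empty or a single cell must hold more than 4 votes. Each accepted center clears the votes within $r+1$ around it, so the same coin is not counted twice.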
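A minimal sketch of the step-4 acceptance test, assuming the vote accumulator is a plain 2D int grid. The names WindowStats, windowStats and looksLikeCenter are hypothetical; the thresholds (at least 3 votes at the cell, more than 56 in the window, more than 36 non-empty cells or a peak above 4) are copied from hough_circle().

```cpp
#include <algorithm>
#include <cassert>
#include <vector>
using namespace std;

struct WindowStats { int sum, mx, nonzero; };

// Statistics of the (2C+1)x(2C+1) window centred at (i, j); cells outside the
// accumulator are skipped. C = 3 gives the 7x7 window.
WindowStats windowStats(const vector<vector<int> >& used, int i, int j, int C = 3) {
    int H = used.size(), W = used[0].size();
    WindowStats s = {0, 0, 0};
    for (int p = -C; p <= C; p++)
        for (int q = -C; q <= C; q++) {
            int x = i + p, y = j + q;
            if (x < 0 || y < 0 || x >= H || y >= W) continue;
            s.sum += used[x][y];
            s.mx = max(s.mx, used[x][y]);
            s.nonzero += used[x][y] > 0;
        }
    return s;
}

// Acceptance rule: enough votes overall, spread widely or with a clear peak.
bool looksLikeCenter(const vector<vector<int> >& used, int i, int j) {
    if (used[i][j] < 3) return false;  // same pre-filter as the solution's loop
    WindowStats s = windowStats(used, i, j);
    return s.sum > 56 && (s.nonzero > 36 || s.mx > 4);
}
```

The two alternatives in the rule accept both a broad blob of votes (many non-empty cells) and a sharp peak (one tall cell), which is how rounding scatters votes for slightly imperfect circles.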

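For the Hough transform, $x_c, \; y_c$ are the coordinates of the circle center that an edge point $(x, y)$ is pushed to, $r$ is the enumerated radius, $gx$ is the Sobel derivative in the x direction, $gy$ likewise in the y direction, and $g = \sqrt{gx^2 + gy^2}$. Because the coins are darker than the background, the gradient at a coin's edge points outward, so stepping a distance $r$ against it lands on the center (in the code, x is the row index i and y the column index j):

$x_c = x - r \times (gx / g)$, $y_c = y - r \times (gy / g)$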
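A sketch of one voting pass for a fixed radius, assuming a single-channel binary image where the coin is 0 and the background is 1 (as after the binarization step). The helpers pixelAt and houghVote and the edge threshold minMag = 1 are assumptions for illustration; the solution instead thresholds gxy > 4 on summed RGB channels. The Sobel weights and the clamp-to-edge border match sobel() and getPixel().

```cpp
#include <cassert>
#include <cmath>
#include <cstdlib>
#include <vector>
using namespace std;

typedef vector<vector<int> > Grid;

// Clamp-to-edge access, the same border policy as getPixel().
int pixelAt(const Grid& img, int x, int y) {
    int H = img.size(), W = img[0].size();
    x = x < 0 ? 0 : (x >= H ? H - 1 : x);
    y = y < 0 ? 0 : (y >= W ? W - 1 : y);
    return img[x][y];
}

// One Hough pass: every edge pixel votes for the point r steps against its gradient.
Grid houghVote(const Grid& img, double r, double minMag = 1.0) {
    int H = img.size(), W = img[0].size();
    Grid acc(H, vector<int>(W, 0));
    for (int x = 0; x < H; x++)
        for (int y = 0; y < W; y++) {
            // 3x3 Sobel: gx differentiates along rows (x), gy along columns (y).
            int gx = pixelAt(img, x+1, y-1) + 2*pixelAt(img, x+1, y) + pixelAt(img, x+1, y+1)
                   - pixelAt(img, x-1, y-1) - 2*pixelAt(img, x-1, y) - pixelAt(img, x-1, y+1);
            int gy = pixelAt(img, x-1, y+1) + 2*pixelAt(img, x, y+1) + pixelAt(img, x+1, y+1)
                   - pixelAt(img, x-1, y-1) - 2*pixelAt(img, x, y-1) - pixelAt(img, x+1, y-1);
            double g = hypot((double)gx, (double)gy);
            if (g < minMag) continue;  // flat region, not an edge
            int xc = (int)lround(x - r * gx / g);
            int yc = (int)lround(y - r * gy / g);
            if (xc >= 0 && yc >= 0 && xc < H && yc < W) acc[xc][yc]++;
        }
    return acc;
}
```

For a dark disk of radius 5 on a bright background, the votes cluster within a few cells of the true center, which is why the acceptance test looks at a 7x7 window rather than a single cell.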

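Of course, the thresholds are still not quite right; I hand-tested several versions, and OpenCV's implementation is worth tracing through. The code is just for fun and is not guaranteed to work in other situations.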

#include <stdio.h>
#include <math.h>
#include <string.h>
#include <stdlib.h>
#include <limits.h>
#include <vector>
#include <algorithm>
using namespace std;
class IMAGE {
public:
struct Pixel {
int r, g, b;
Pixel(int x = 0, int y = 0, int z = 0):
r(x), g(y), b(z) {}
void read() {
scanf("%d %d %d", &r, &g, &b);
}
Pixel operator-(const Pixel &x) const {
return Pixel(r-x.r, g-x.g, b-x.b);
}
Pixel operator+(const Pixel &x) const {
return Pixel(r+x.r, g+x.g, b+x.b);
}
Pixel operator*(const double x) const {
return Pixel(r*x, g*x, b*x);
}
Pixel operator/(const double x) const {
return Pixel(r/x, g/x, b/x);
}
bool operator==(const Pixel &x) const {
return r == x.r && g == x.g && b == x.b;
}
void print() {
printf("%3d", length());
}
int sum() {
return r + g + b;
}
int length() {
return abs(r) + abs(g) + abs(b);
}
int dist(Pixel x) {
return abs((r + g + b) - (x.r + x.g + x.b));
}
};
int W, H;
static const int MAXN = 256;
Pixel data[MAXN][MAXN];
void read() {
scanf("%d %d", &W, &H);
for (int i = 0; i < H; i++)
for (int j = 0; j < W; j++)
data[i][j].read();
}
void print() {
printf("%d %d\n", W, H);
for (int i = 0; i < H; i++)
for (int j = 0; j < W; j++)
data[i][j].print(), printf("%c", j == W-1 ? '\n' : ' ');
}
inline Pixel getPixel(int x, int y) {
if (x >= 0 && y >= 0 && x < H && y < W)
return data[x][y];
if (y < 0) return data[min(max(x, 0), H-1)][0];
if (y >= W) return data[min(max(x, 0), H-1)][W-1];
if (x < 0) return data[0][min(max(y, 0), W-1)];
if (x >= H) return data[H-1][min(max(y, 0), W-1)];
return Pixel(0, 0, 0);
}
void sobel(int i, int j, int &gx, int &gy) {
static int dx[] = {-1, -1, -1, 0, 0, 0, 1, 1, 1};
static int dy[] = {-1, 0, 1, -1, 0, 1, -1, 0, 1};
static int yw[] = {-1, 0, 1, -2, 0, 2, -1, 0, 1};
static int xw[] = {-1, -2, -1, 0, 0, 0, 1, 2, 1};
Pixel Dx(0, 0, 0), Dy(0, 0, 0);
for (int k = 0; k < 9; k++) {
if (xw[k])
Dx = Dx + getPixel(i+dx[k], j+dy[k]) * xw[k];
if (yw[k])
Dy = Dy + getPixel(i+dx[k], j+dy[k]) * yw[k];
}
gx = Dx.sum(), gy = Dy.sum();
}
int used[MAXN][MAXN];
int gx[MAXN][MAXN], gy[MAXN][MAXN];
double gxy[MAXN][MAXN];
int isValid(int x, int y) {
return x >= 0 && y >= 0 && x < H && y < W;
}
int hough_circle() {
int mxb = INT_MIN, mnb = INT_MAX;
for (int i = 0; i < H; i++) {
for (int j = 0; j < W; j++) {
int b = data[i][j].length();
mxb = max(mxb, b), mnb = min(mnb, b);
}
}
if (mxb - mnb < 300)
return 0;
int threshold = 250;
for (int i = 0; i < H; i++) {
for (int j = 0; j < W; j++) {
if (data[i][j].length() >= threshold)
data[i][j] = Pixel(1, 1, 1);
else
data[i][j] = Pixel(0, 0, 0);
}
}
for (int i = 0; i < H; i++) {
for (int j = 0; j < W; j++) {
sobel(i, j, gx[i][j], gy[i][j]), gxy[i][j] = hypot(gx[i][j], gy[i][j]);
}
}
int ret = 0;
for (double r = 4; r <= min(H, W)/2; r += 1) {
for (int i = 0; i < H; i++)
for (int j = 0; j < W; j++)
used[i][j] = 0;
for (int i = 0; i < H; i++) {
for (int j = 0; j < W; j++) {
if (gxy[i][j] > 4) {
int xc, yc;
xc = round(i - r * (gx[i][j] / gxy[i][j]));
yc = round(j - r * (gy[i][j] / gxy[i][j]));
if (isValid(xc, yc))
used[xc][yc]++;
}
}
}
int coins = 0;
const int C = 3;
for (int i = 0; i < H; i++) {
for (int j = 0; j < W; j++) {
if (used[i][j] < 3)
continue;
int sum = 0, mx = 0, has = 0;
for (int p = -C; p <= C; p++) {
for (int q = -C; q <= C; q++) {
if (isValid(i+p, j+q))
sum += used[i+p][j+q], mx = max(mx, used[i+p][j+q]), has += used[i+p][j+q] > 0;
}
}
if (sum > 56 && (has > 36 || mx > 4)) {
coins++;
int cx = i, cy = j;
for (int p = -r-1; p <= r+1; p++) {
for (int q = -r-1; q <= r+1; q++) {
if (isValid(cx+p, cy+q))
used[cx+p][cy+q] = 0;
}
}
}
}
}
if (coins < ret - 2)
break;
ret = max(ret, coins);
}
return ret;
}
} image;
int main() {
image.read();
printf("%d\n", image.hough_circle());
return 0;
}

b448. Funhouse Mirror

Problem

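Distort the picture to get an effect like a funhouse mirror.

About the image's horizontal midline, apply the coordinate transform y' = sqrt(y), pulling the pixels near the midline toward the middle as much as possible to produce a stretched distortion.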

Sample Input

1 1
1 2 3

Sample Output

1 1
1 2 3

Solution

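Formulas

  • Above the midline ($i < H/2$): $x = \frac{H}{2} - \frac{(i - H/2)^2}{H/2}$
  • Below the midline: $x = \frac{(i - H/2)^2}{H/2} + \frac{H}{2}$

Apply the matching formula to each output row $i$ to map it back onto the source image (inverse mapping), where $x$ is the source row it samples, then take the nearest-neighbor pixel.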
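A sketch of the inverse row mapping with nearest-neighbor sampling, on a single-channel image for brevity (the names mirrorSourceRows and distort are hypothetical). Ties at exactly .5 go to the lower row, matching the floor/ceil comparison in distorting_mirror(), and rows past the bottom are clamped, which is what its isValid() check amounts to.

```cpp
#include <cassert>
#include <cmath>
#include <vector>
using namespace std;

// Source row sampled by each output row i, for an image of height H.
vector<int> mirrorSourceRows(int H) {
    double ch = H / 2.0;
    vector<int> src(H);
    for (int i = 0; i < H; i++) {
        double d = (i - ch) * (i - ch) / ch;     // distance from the midline in the source
        double x = i < ch ? ch - d : ch + d;
        int lo = (int)floor(x);
        int row = (x - lo > 0.5) ? lo + 1 : lo;  // nearest row, ties to the lower one
        if (row > H - 1) row = H - 1;            // stay inside the image
        src[i] = row;
    }
    return src;
}

// Only rows move; every column keeps its index.
vector<vector<int> > distort(const vector<vector<int> >& img) {
    int H = img.size();
    vector<int> src = mirrorSourceRows(H);
    vector<vector<int> > out(H);
    for (int i = 0; i < H; i++) out[i] = img[src[i]];
    return out;
}
```

For H = 10 the output rows sample source rows 0, 2, 3, 4, 5, 5, 5, 6, 7, 8: the band around the midline is repeated (stretched) while the top and bottom are squeezed.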

#include <bits/stdc++.h>
using namespace std;
class IMAGE {
public:
struct Pixel {
int r, g, b;
Pixel(int x = 0, int y = 0, int z = 0):
r(x), g(y), b(z) {}
void read() {
scanf("%d %d %d", &r, &g, &b);
}
Pixel operator-(const Pixel &x) const {
return Pixel(r-x.r, g-x.g, b-x.b);
}
Pixel operator+(const Pixel &x) const {
return Pixel(r+x.r, g+x.g, b+x.b);
}
Pixel operator*(const double x) const {
return Pixel(r*x, g*x, b*x);
}
Pixel operator/(const double x) const {
return Pixel(r/x, g/x, b/x);
}
bool operator==(const Pixel &x) const {
return r == x.r && g == x.g && b == x.b;
}
void print() {
printf("%d %d %d", r, g, b);
}
int sum() {
return r + g + b;
}
int length() {
return abs(r) + abs(g) + abs(b);
}
int dist(Pixel x) {
return abs((r + g + b) - (x.r + x.g + x.b));
}
};
int W, H;
static const int MAXN = 256;
Pixel data[MAXN][MAXN], tmp[MAXN][MAXN];
void read() {
scanf("%d %d", &W, &H);
for (int i = 0; i < H; i++)
for (int j = 0; j < W; j++)
data[i][j].read();
}
void print() {
printf("%d %d\n", W, H);
for (int i = 0; i < H; i++)
for (int j = 0; j < W; j++)
data[i][j].print(), printf("%c", j == W-1 ? '\n' : ' ');
}
int isValid(int x, int y) {
return x >= 0 && y >= 0 && x < H && y < W;
}
void distorting_mirror() {
int rW = W, rH = H;
double ch = H / 2.0;
for (int i = 0; i < rH; i++) {
double x, y;
if (i < ch)
x = ch - pow(i-ch, 2)/ch;
else
x = pow(i-ch, 2)/ch + ch;
for (int j = 0; j < rW; j++) {
y = j;
int lx, rx, ly, ry;
lx = floor(x), rx = ceil(x);
ly = floor(y), ry = ceil(y);
int px[] = {lx, lx, rx, rx};
int py[] = {ly, ry, ly, ry};
int c = -1;
double mndist = 1e+30;
for (int k = 0; k < 4; k++) {
if (!isValid(px[k], py[k]))
continue;
double d = (x-px[k])*(x-px[k])+(y-py[k])*(y-py[k]);
if (c == -1 || mndist > d)
c = k, mndist = d;
}
assert (c >= 0);
tmp[i][j] = data[px[c]][py[c]];
}
}
W = rW, H = rH;
for (int i = 0; i < H; i++) {
for (int j = 0; j < W; j++) {
data[i][j] = tmp[i][j];
}
}
}
} image;
int main() {
image.read();
image.distorting_mirror();
image.print();
return 0;
}

HDU 5307 - He is Flying

Problem

Problem link. Speed up the procedure below: as written it takes $O(N^3)$, and even maintaining the sums with prefix sums it still takes $O(N^2)$.

for l = 0 to n-1
    for r = l to n-1
        sum = 0
        for k = l to r
            sum += A[k]
        ret[sum] += r-l+1

Finally, output every ret[sum], from ret[0] up to ret[S], where S is the total of all elements.
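A brute-force reference in $O(N^2)$ using prefix sums, useful for checking a fast solution on small inputs (the name bruteForce is hypothetical). It assumes non-negative elements, so every interval sum lies in [0, S], and it counts single-element intervals, which the sample output requires.

```cpp
#include <cassert>
#include <vector>
using namespace std;

// ret[v] = total length of all intervals whose sum is v, for v = 0..S.
vector<long long> bruteForce(const vector<int>& A) {
    int n = A.size();
    vector<long long> pre(n + 1, 0);  // pre[k] = A[0] + ... + A[k-1]
    for (int k = 0; k < n; k++) pre[k + 1] = pre[k] + A[k];
    vector<long long> ret(pre[n] + 1, 0);
    for (int l = 0; l < n; l++)
        for (int r = l; r < n; r++)
            ret[pre[r + 1] - pre[l]] += r - l + 1;  // interval [l, r] in O(1)
    return ret;
}
```

On the two sample cases it returns 0 1 1 3 0 2 3 and 1 3 1 6 0 2 7, matching the expected output.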

Sample Input

2
3
1 2 3
4
0 1 2 3

Sample Output

0
1
1
3
0
2
3
1
3
1
6
0
2
7

Solution

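The computation itself is naive, so speeding it up is impossible without some data structures and algorithms. Since every result has to be output, this suggests convolution via the fast Fourier transform (FFT); the next question is how to construct the polynomials (vectors).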

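Let $s_i$ be the prefix sum of the first $i$ numbers ($s_0 = 0$). The counts should show up in the coefficients and the sums in the exponents, so we multiply two polynomials in $x$. The interval made of elements $j+1, \dots, i$ has sum $s_i - s_j$ and length $i - j$, so it should contribute $(i - j)\, x^{s_i} \times x^{-s_j} = (i - j)\, x^{s_i - s_j}$. A single product cannot produce the weight $i - j$, since it depends on both ends, so the computation is split into two that yield $i\, x^{s_i - s_j}$ and $-j\, x^{s_i - s_j}$ respectively.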

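The former comes from the product $\left(\sum_{i} i\, x^{s_i}\right) \times \left(\sum_{j} x^{-s_j}\right)$ and the latter from $\left(\sum_{i} x^{s_i}\right) \times \left(\sum_{j} j\, x^{-s_j}\right)$, with $i = 1, \dots, n$ and $j = 0, \dots, n-1$. Compute both products with the FFT in $O(m \log m)$, where $m$ is the smallest power of two greater than twice the total sum, and subtract the second from the first. The negative exponents are handled by reversing the second operand (s[i] = b[n-i] in convolution()), which turns the product into a cyclic correlation.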
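To check the two-product decomposition without any FFT, here is a sketch that evaluates both correlations directly in $O(S^2)$, where S is the total sum. This is a deliberately slow stand-in: corr() and viaPolynomials() are hypothetical names, and a real solution replaces corr() with an FFT or NTT. It assumes non-negative elements and handles sum 0 separately by counting zero runs, as the solution does.

```cpp
#include <cassert>
#include <vector>
using namespace std;

typedef long long ll;

// c[k] = sum_t a[t] * b[t - k]: the coefficient of x^k in
// (sum_t a[t] x^t) * (sum_u b[u] x^{-u}), kept for k >= 0 only.
vector<ll> corr(const vector<ll>& a, const vector<ll>& b) {
    int m = a.size();
    vector<ll> c(m, 0);
    for (int k = 0; k < m; k++)
        for (int t = k; t < m; t++) c[k] += a[t] * b[t - k];
    return c;
}

vector<ll> viaPolynomials(const vector<int>& A) {
    int n = A.size();
    vector<int> s(n + 1, 0);  // s[i] = A[0] + ... + A[i-1]
    for (int i = 1; i <= n; i++) s[i] = s[i - 1] + A[i - 1];
    int S = s[n];
    vector<ll> a1(S + 1), b1(S + 1), a2(S + 1), b2(S + 1);
    for (int i = 1; i <= n; i++) { a1[s[i]] += i; a2[s[i]] += 1; }  // right ends i
    for (int j = 0; j < n; j++) { b1[s[j]] += 1; b2[s[j]] += j; }   // left ends j
    vector<ll> c1 = corr(a1, b1), c2 = corr(a2, b2), ret(S + 1, 0);
    for (int k = 1; k <= S; k++) ret[k] = c1[k] - c2[k];  // sum of (i - j)
    // k = 0 would also pick up pairs with j >= i, so count zero runs instead:
    // a run of z zeros contributes z(z+1)(z+2)/6 (the pr[z] table).
    ll z = 0;
    for (int i = 0; i <= n; i++) {
        if (i == n || A[i] != 0) { ret[0] += z * (z + 1) * (z + 2) / 6; z = 0; }
        else z++;
    }
    return ret;
}
```

Both sample cases come out as 0 1 1 3 0 2 3 and 1 3 1 6 0 2 7, so the subtraction of the two products reproduces the $(i - j)$ weights.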

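Sum 0 must be handled separately, because the construction cannot compute it: pairs with $j \ge i$ produce non-positive exponents, which wrap around past the total sum and are never read, except when $s_i = s_j$ (for example $i = j$), which lands on exponent 0 together with the genuine zero-sum intervals. Instead, each maximal run of $z$ zeros contributes $\sum_{L=1}^{z} (z-L+1)L = z(z+1)(z+2)/6$, precomputed as pr[z].

This problem is also very demanding on precision. You can use NTT/FNT so that everything stays in integer arithmetic, or use FFT in double. With FFT, be careful: implementations often generate the twiddle factors by angle addition (repeatedly multiplying by the unit root) to save time, but here that accumulates rounding error, so tabulate every cos and sin value directly. Others who end up replacing double with long double usually run into this because of that implementation choice.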
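A small experiment illustrating the precision point: it builds the twiddle factors of size n both ways and measures the largest deviation from a long double reference (the names twiddleTable, twiddleAccumulated and maxError are hypothetical). The exact magnitudes depend on the platform and on how wide long double is, so only the comparison matters here.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <complex>
#include <vector>
using namespace std;

typedef complex<double> cd;

// Full table: each entry comes from its own cos/sin call, so its error
// stays around one ulp no matter how large k is (as in prework()).
vector<cd> twiddleTable(int n) {
    const double PI = acos(-1.0);
    vector<cd> w(n);
    for (int k = 0; k < n; k++)
        w[k] = cd(cos(2 * PI * k / n), sin(2 * PI * k / n));
    return w;
}

// Angle addition: w[k] = w[k-1] * w[1]. Cheaper, but every multiplication
// adds rounding error and the error accumulates along k.
vector<cd> twiddleAccumulated(int n) {
    const double PI = acos(-1.0);
    cd step(cos(2 * PI / n), sin(2 * PI / n));
    vector<cd> w(n);
    w[0] = cd(1, 0);
    for (int k = 1; k < n; k++) w[k] = w[k - 1] * step;
    return w;
}

// Largest distance from the long double value of exp(2*pi*i*k/n).
double maxError(const vector<cd>& w) {
    const long double PI = acos(-1.0L);
    int n = w.size();
    double worst = 0;
    for (int k = 0; k < n; k++) {
        long double a = 2 * PI * k / n;
        double dr = (double)(w[k].real() - cos(a));
        double di = (double)(w[k].imag() - sin(a));
        worst = max(worst, hypot(dr, di));
    }
    return worst;
}
```

With n = 2^18 (the MAXN used above), the accumulated version drifts noticeably further from the reference than the table, and that drift is amplified when products of large coefficients are rounded back to integers.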

FFT

#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <math.h>
#include <vector>
#include <algorithm>
#include <complex>
using namespace std;
template<typename T> class TOOL_FFT {
public:
typedef unsigned int UINT32;
#define MAXN 262144
complex<T> p[2][MAXN];
int pre_n;
T PI;
TOOL_FFT() {
pre_n = 0;
PI = acos(-1);
}
int NumberOfBitsNeeded(int PowerOfTwo) {
for (int i = 0; ; ++i) {
if (PowerOfTwo & (1 << i)) {
return i;
}
}
}
inline UINT32 FastReverseBits(UINT32 a, int NumBits) {
a = ( ( a & 0x55555555U ) << 1 ) | ( ( a & 0xAAAAAAAAU ) >> 1 ) ;
a = ( ( a & 0x33333333U ) << 2 ) | ( ( a & 0xCCCCCCCCU ) >> 2 ) ;
a = ( ( a & 0x0F0F0F0FU ) << 4 ) | ( ( a & 0xF0F0F0F0U ) >> 4 ) ;
a = ( ( a & 0x00FF00FFU ) << 8 ) | ( ( a & 0xFF00FF00U ) >> 8 ) ;
a = ( ( a & 0x0000FFFFU ) << 16 ) | ( ( a & 0xFFFF0000U ) >> 16 ) ;
return a >> (32 - NumBits);
}
void FFT(bool InverseTransform, vector<complex<T> >& In, vector<complex<T> >& Out) {
// simultaneous data copy and bit-reversal ordering into outputs
int NumSamples = In.size();
int NumBits = NumberOfBitsNeeded(NumSamples);
for (int i = 0; i < NumSamples; ++i) {
Out[FastReverseBits(i, NumBits)] = In[i];
}
// the FFT process
for (register int i = 1; i <= NumBits; i++) {
int BlockSize = 1<<i, BlockEnd = BlockSize>>1, BlockCnt = NumSamples/BlockSize;
for (register int j = 0; j < NumSamples; j += BlockSize) {
complex<T> *t = p[InverseTransform];
for (register int k = 0; k < BlockEnd; k++, t += BlockCnt) {
complex<T> a = (*t) * Out[k+j+BlockEnd];
Out[k+j+BlockEnd] = Out[k+j] - a;
Out[k+j] += a;
}
}
}
// normalize if inverse transform
if (InverseTransform) {
for (int i = 0; i < NumSamples; ++i) {
Out[i] /= NumSamples;
}
}
}
void prework(int n) {
if (pre_n == n)
return ;
pre_n = n;
p[0][0] = complex<T>(1, 0);
p[1][0] = complex<T>(1, 0);
for (register int i = 1; i < n; i++) {
p[0][i] = complex<T>(cos(2*i*PI / n ) , sin(2*i*PI / n ));
p[1][i] = complex<T>(cos(2*i*PI / n ) , -sin(2*i*PI / n ));
}
}
vector<T> convolution(complex<T> *a, complex<T> *b, int n) {
prework(n);
vector< complex<T> > s(a, a+n), d1(n), d2(n), y(n);
vector<T> ret(n);
FFT(false, s, d1);
s[0] = b[0];
for (int i = 1, j = n-1; i < n; ++i, --j)
s[i] = b[j];
FFT(false, s, d2);
for (int i = 0; i < n; ++i) {
y[i] = d1[i] * d2[i];
}
FFT(true, y, s);
for (int i = 0; i < n; ++i) {
ret[i] = s[i].real();
}
return ret;
}
};
TOOL_FFT<double> tool;
complex<double> a[MAXN], b[MAXN];
vector<double> c;
int A[MAXN], sum[MAXN];
long long ret[MAXN], pr[262144];
int main() {
pr[0] = 0;
for (int i = 1; i < MAXN; i++)
pr[i] = pr[i-1] + (long long)i*(i+1)/2;
int testcase, n, m, s;
scanf("%d", &testcase);
while (testcase--) {
scanf("%d", &n);
for (int i = 1; i <= n; i++)
scanf("%d", &A[i]);
sum[0] = 0;
for (int i = 1; i <= n; i++)
sum[i] = sum[i-1] + A[i];
s = sum[n];
memset(ret, 0, sizeof(ret[0]) * (s+1));
for (m = 1; m <= (s<<1); m <<= 1);
memset(a, 0, sizeof(a[0])*m);
memset(b, 0, sizeof(b[0])*m);
for (int i = 1; i <= n; i++) {
a[sum[i]] += complex<double>(i, 0);
b[sum[i-1]] += complex<double>(1, 0);
}
c = tool.convolution(a, b, m);
for (int i = 1; i < m; i++)
ret[i] += round(c[i]);
memset(a, 0, sizeof(a[0])*m);
memset(b, 0, sizeof(b[0])*m);
for (int i = 1; i <= n; i++) {
a[sum[i]] += complex<double>(1, 0);
b[sum[i-1]] += complex<double>(i-1, 0);
}
c = tool.convolution(a, b, m);
for (int i = 1; i <= s; i++)
ret[i] -= round(c[i]);
for (int i = 1, z = 0; i <= n+1; i++) {
if (i == n+1 || A[i] != 0)
ret[0] += pr[z], z = 0;
else
z++;
}
for (int i = 0; i <= s; i++) {
printf("%lld\n", ret[i]);
}
}
return 0;
}

NTT

#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <math.h>
#include <vector>
#include <algorithm>
using namespace std;
typedef unsigned int UINT32;
typedef long long INT64;
class TOOL_NTT {
public:
#define MAXN 262144
const INT64 P = 50000000001507329LL; // prime m = kn+1
const INT64 G = 3;
INT64 wn[20];
INT64 s[MAXN], d1[MAXN], d2[MAXN], y[MAXN];
TOOL_NTT() {
for (int i = 0; i < 20; i++)
wn[i] = mod_pow(G, (P-1) / (1<<i), P);
}
INT64 mod_mul(INT64 a, INT64 b, INT64 mod) {
return (a*b - (long long)(a/(long double)mod*b+1e-3)*mod+mod)%mod;
// INT64 ret = 0;
// for (a = a >= mod ? a%mod : a, b = b >= mod ? b%mod : b; b != 0; b>>=1, a <<= 1, a = a >= mod ? a - mod : a) {
// if (b&1) {
// ret += a;
// if (ret >= mod)
// ret -= mod;
// }
// }
// return ret;
}
INT64 mod_pow(INT64 n, INT64 e, INT64 m) {
INT64 x = 1;
for (n = n >= m ? n%m : n; e; e >>= 1) {
if (e&1)
x = mod_mul(x, n, m);
n = mod_mul(n, n, m);
}
return x;
}
int NumberOfBitsNeeded(int PowerOfTwo) {
for (int i = 0;; ++i) {
if (PowerOfTwo & (1 << i)) {
return i;
}
}
}
inline UINT32 FastReverseBits(UINT32 a, int NumBits) {
a = ( ( a & 0x55555555U ) << 1 ) | ( ( a & 0xAAAAAAAAU ) >> 1 ) ;
a = ( ( a & 0x33333333U ) << 2 ) | ( ( a & 0xCCCCCCCCU ) >> 2 ) ;
a = ( ( a & 0x0F0F0F0FU ) << 4 ) | ( ( a & 0xF0F0F0F0U ) >> 4 ) ;
a = ( ( a & 0x00FF00FFU ) << 8 ) | ( ( a & 0xFF00FF00U ) >> 8 ) ;
a = ( ( a & 0x0000FFFFU ) << 16 ) | ( ( a & 0xFFFF0000U ) >> 16 ) ;
return a >> (32 - NumBits);
}
void NTT(int on, INT64 *In, INT64 *Out, int n) {
int NumBits = NumberOfBitsNeeded(n);
for (int i = 0; i < n; ++i)
Out[FastReverseBits(i, NumBits)] = In[i];
for(int h = 2, id = 1; h <= n; h <<= 1, id++) {
for(int j = 0; j < n; j += h) {
INT64 w = 1, u, t;
int block = h/2, blockEnd = j + h/2;
for(int k = j; k < blockEnd; k++) {
u = Out[k], t = mod_mul(w, Out[k+block], P);
Out[k] = u + t;
Out[k + block] = u - t + P;
if (Out[k] >= P) Out[k] -= P;
if (Out[k+block] >= P) Out[k+block] -= P;
w = mod_mul(w, wn[id], P);
}
}
}
if (on == 1) {
for (int i = 1; i < n/2; i++)
swap(Out[i], Out[n-i]);
INT64 invn = mod_pow(n, P-2, P);
for (int i = 0; i < n; i++)
Out[i] = mod_mul(Out[i], invn, P);
}
}
void convolution(INT64 *a, INT64 *b, int n, INT64 *c) {
NTT(0, a, d1, n);
s[0] = b[0];
for (int i = 1; i < n; ++i)
s[i] = b[n-i];
NTT(0, s, d2, n);
for (int i = 0; i < n; i++)
s[i] = mod_mul(d1[i], d2[i], P);
NTT(1, s, c, n);
}
} tool;
INT64 a[MAXN], b[MAXN], c[MAXN];
int A[MAXN], sum[MAXN];
long long ret[MAXN], pr[262144];
int main() {
pr[0] = 0;
for (int i = 1; i < MAXN; i++)
pr[i] = pr[i-1] + (long long)i*(i+1)/2;
int testcase, n, m, s;
scanf("%d", &testcase);
while (testcase--) {
scanf("%d", &n);
for (int i = 1; i <= n; i++)
scanf("%d", &A[i]);
sum[0] = 0;
for (int i = 1; i <= n; i++)
sum[i] = sum[i-1] + A[i];
s = sum[n];
memset(ret, 0, sizeof(ret[0]) * (s+1));
for (m = 1; m <= (s<<1); m <<= 1);
memset(a, 0, sizeof(a[0])*m);
memset(b, 0, sizeof(b[0])*m);
for (int i = 1; i <= n; i++) {
a[sum[i]] += i;
b[sum[i-1]] ++;
}
tool.convolution(a, b, m, c);
for (int i = 1; i < m; i++)
ret[i] += c[i];
memset(a, 0, sizeof(a[0])*m);
memset(b, 0, sizeof(b[0])*m);
for (int i = 1; i <= n; i++) {
a[sum[i]] ++;
b[sum[i-1]] += i-1;
}
tool.convolution(a, b, m, c);
for (int i = 1; i <= s; i++)
ret[i] -= c[i];
for (int i = 1, z = 0; i <= n+1; i++) {
if (i == n+1 || A[i] != 0)
ret[0] += pr[z], z = 0;
else
z++;
}
for (int i = 0; i <= s; i++) {
printf("%lld\n", ret[i]);
}
}
return 0;
}

NTT CRT

#include <stdio.h>
#include <stdint.h>  // uint_fast32_t
#include <stdlib.h>
#include <string.h>
#include <math.h>
#include <vector>
#include <algorithm>
using namespace std;
typedef uint_fast32_t UINT32;
typedef long long INT64;
typedef uint_fast32_t INT32;
class TOOL_NTT {
public:
#define MAXN 262144
// INT64 P = 50000000001507329LL; // prime m = kn+1
// INT64 G = 3;
INT32 P = 3, G = 2;
INT32 wn[20];
INT32 s[MAXN], d1[MAXN], d2[MAXN], c1[MAXN], c2[MAXN];
const INT32 P1 = 998244353; // P1 = 2^23 * 7 * 17 + 1
const INT32 G1 = 3;
const INT32 P2 = 995622913; // P2 = 2^19 *3*3*211 + 1
const INT32 G2 = 5;
const INT64 M1 = 397550359381069386LL;
const INT64 M2 = 596324591238590904LL;
const INT64 MM = 993874950619660289LL; // MM = P1*P2
TOOL_NTT() {
for (int i = 0; i < 20; i++)
wn[i] = mod_pow(G, (P-1) / (1<<i), P);
}
void reset(INT32 p, INT32 g) {
P = p, G = g;
for (int i = 0; i < 20; i++)
wn[i] = mod_pow(G, (P-1) / (1<<i), P);
}
INT64 mod_mul(INT64 a, INT64 b, INT64 mod) {
return (a*b - (long long)(a/(long double)mod*b+1e-3)*mod+mod)%mod;
// INT64 ret = 0;
// for (a = a >= mod ? a%mod : a, b = b >= mod ? b%mod : b; b != 0; b>>=1, a <<= 1, a = a >= mod ? a - mod : a) {
// if (b&1) {
// ret += a;
// if (ret >= mod)
// ret -= mod;
// }
// }
// return ret;
}
INT64 mod_pow(INT64 n, INT64 e, INT64 m) {
INT64 x = 1;
for (n = n >= m ? n%m : n; e; e >>= 1) {
if (e&1)
x = mod_mul(x, n, m);
n = mod_mul(n, n, m);
}
return x;
}
int NumberOfBitsNeeded(int PowerOfTwo) {
for (int i = 0;; ++i) {
if (PowerOfTwo & (1 << i)) {
return i;
}
}
}
inline UINT32 FastReverseBits(UINT32 a, int NumBits) {
a = ( ( a & 0x55555555U ) << 1 ) | ( ( a & 0xAAAAAAAAU ) >> 1 ) ;
a = ( ( a & 0x33333333U ) << 2 ) | ( ( a & 0xCCCCCCCCU ) >> 2 ) ;
a = ( ( a & 0x0F0F0F0FU ) << 4 ) | ( ( a & 0xF0F0F0F0U ) >> 4 ) ;
a = ( ( a & 0x00FF00FFU ) << 8 ) | ( ( a & 0xFF00FF00U ) >> 8 ) ;
a = ( ( a & 0x0000FFFFU ) << 16 ) | ( ( a & 0xFFFF0000U ) >> 16 ) ;
return a >> (32 - NumBits);
}
void NTT(int on, INT32 *In, INT32 *Out, int n) {
int NumBits = NumberOfBitsNeeded(n);
for (int i = 0; i < n; ++i)
Out[FastReverseBits(i, NumBits)] = In[i];
for (int h = 2, id = 1; h <= n; h <<= 1, id++) {
for (int j = 0; j < n; j += h) {
INT32 w = 1, u, t;
int block = h/2, blockEnd = j + h/2;
for (int k = j; k < blockEnd; k++) {
u = Out[k], t = (INT64)w*Out[k+block]%P;
Out[k] = (u + t)%P;
Out[k+block] = (u - t + P)%P;
w = (INT64)w * wn[id]%P;
}
}
}
if (on == 1) {
for (int i = 1; i < n/2; i++)
swap(Out[i], Out[n-i]);
INT32 invn = mod_pow(n, P-2, P);
for (int i = 0; i < n; i++)
Out[i] = (INT64)Out[i]*invn%P;
}
}
INT64 crt(INT32 a, INT32 b) {
return (mod_mul(a, M1, MM) + mod_mul(b, M2, MM))%MM;
}
void convolution(INT32 *a, INT32 *b, int n, INT64 *c) {
reset(P1, G1);
NTT(0, a, d1, n);
s[0] = b[0]; for (int i = 1; i < n; ++i) s[i] = b[n-i];
NTT(0, s, d2, n);
for (int i = 0; i < n; i++) s[i] = (INT64)d1[i] * d2[i]%P;
NTT(1, s, c1, n);
reset(P2, G2);
NTT(0, a, d1, n);
s[0] = b[0]; for (int i = 1; i < n; ++i) s[i] = b[n-i];
NTT(0, s, d2, n);
for (int i = 0; i < n; i++) s[i] = (INT64)d1[i] * d2[i]%P;
NTT(1, s, c2, n);
for (int i = 0; i < n; i++)
c[i] = crt(c1[i], c2[i]);
}
} tool;
INT32 a[262144], b[262144];
INT64 c[262144];
int A[MAXN], sum[MAXN];
long long ret[MAXN], pr[262144];
int main() {
pr[0] = 0;
for (int i = 1; i < MAXN; i++)
pr[i] = pr[i-1] + (long long)i*(i+1)/2;
int testcase, n, m, s;
scanf("%d", &testcase);
while (testcase--) {
scanf("%d", &n);
for (int i = 1; i <= n; i++)
scanf("%d", &A[i]);
sum[0] = 0;
for (int i = 1; i <= n; i++)
sum[i] = sum[i-1] + A[i];
s = sum[n];
memset(ret, 0, sizeof(ret[0]) * (s+1));
for (m = 1; m <= (s<<1); m <<= 1);
memset(a, 0, sizeof(a[0])*m);
memset(b, 0, sizeof(b[0])*m);
for (int i = 1; i <= n; i++) {
a[sum[i]] += i;
b[sum[i-1]] ++;
}
tool.convolution(a, b, m, c);
for (int i = 1; i < m; i++)
ret[i] += c[i];
memset(a, 0, sizeof(a[0])*m);
memset(b, 0, sizeof(b[0])*m);
for (int i = 1; i <= n; i++) {
a[sum[i]] ++;
b[sum[i-1]] += i-1;
}
tool.convolution(a, b, m, c);
for (int i = 1; i <= s; i++)
ret[i] -= c[i];
for (int i = 1, z = 0; i <= n+1; i++) {
if (i == n+1 || A[i] != 0)
ret[0] += pr[z], z = 0;
else
z++;
}
for (int i = 0; i <= s; i++) {
printf("%lld\n", ret[i]);
}
}
return 0;
}

a994. 10325 - The Lottery

Problem

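Simulate the draw: from the numbers 1 to n, delete every multiple of any of the m given numbers. How many numbers remain to choose from?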

Sample Input

10 2
2 3
20 2
2 4
100 3
3 5 7

Sample Output

3
10
45

Solution

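This is a very easy problem: inclusion-exclusion over the $2^m$ subsets solves it, so there is little room to improve the asymptotic complexity. Written with a plain bitmask that recomputes each subset's lcm from scratch, the cost is $O(m \times 2^m)$ with a large number of gcd(a, b) calls; Euclid's algorithm has a small constant, $O(\log n)$ per call.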

To speed this up, build each subset's lcm from a smaller subset: for the chosen set 1011, lcm(1011) = lcm(lcm(1010), A[0]). Then only $O(2^m)$ gcd() calls are needed, saving a full factor of m.

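This needs the bit trick lowbit = i&(-i). To turn that power of two into its exponent (the bit index), use the compiler's __builtin family, which the compiler optimizes. For gcd the code also uses the built-in __gcd(a, b). The bit built-ins (__builtin_ffs, __builtin_popcount) work on 32-bit int / unsigned int, and __gcd is called here on 32-bit unsigned values, so guard against overflow: the lcm product is computed in long long.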
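A sketch of the incremental-lcm inclusion-exclusion, assuming GCC or Clang for __builtin_ctz and __builtin_popcount; countSurvivors and gcdll are hypothetical names, and the gcd is hand-written to avoid depending on __gcd. Any lcm above n is capped at n + 1: it contributes n / lcm = 0, and the lcm of a superset can only grow.

```cpp
#include <cassert>
#include <vector>
using namespace std;

typedef long long ll;

ll gcdll(ll a, ll b) {
    while (b) { ll t = a % b; a = b; b = t; }
    return a;
}

// Counts x in [1, n] divisible by none of A. L[mask] is the lcm of the chosen
// numbers, built from L[mask without its lowest bit]: one gcd per subset.
ll countSurvivors(ll n, const vector<ll>& A) {
    int m = A.size();
    vector<ll> L(1u << m);
    L[0] = 1;
    ll ret = n;
    for (unsigned mask = 1; mask < (1u << m); mask++) {
        unsigned low = mask & (0u - mask);             // lowest set bit
        ll prev = L[mask ^ low], b = A[__builtin_ctz(mask)];
        if (prev > n) { L[mask] = n + 1; continue; }   // already too large
        ll l = prev / gcdll(prev, b) * b;              // prev <= n, so no overflow
        L[mask] = l > n ? n + 1 : l;
        if (l <= n) ret += (__builtin_popcount(mask) & 1) ? -(n / l) : n / l;
    }
    return ret;
}
```

On the samples it gives 3, 10 and 45; for example with n = 100 and {3, 5, 7}: 100 - 33 - 20 - 14 + 6 + 4 + 2 - 0 = 45, where the last term is 0 because lcm(3, 5, 7) = 105 > 100.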

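Some suggest precomputing lookup tables so that each query is $O(1)$, but the extra memory traffic is a burden: two tables of $2^{15}$ entries compete for cache, plus the cost of building them beforehand. In my tests, the table-based version was slower.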

#include <bits/stdc++.h>
using namespace std;
int main() {
int n, m, A[16];
unsigned int dp[1<<15], a, b;
while (scanf("%d %d", &n, &m) == 2) {
for (int i = 1; i <= m; i++)
scanf("%d", &A[i]);
int ret = n;
dp[0] = 1;
for (int i = 1; i < (1<<m); i++) {
long long val;
a = dp[i-(i&(-i))], b = A[__builtin_ffs(i&(-i))];
if (a > n) {dp[i] = n+1; continue;}
val = b / __gcd(a, b) * (long long) a;
if (val <= n) {
dp[i] = val;
if (__builtin_popcount(i)&1)
ret -= n / val;
else
ret += n / val;
} else {
dp[i] = n+1;
}
}
printf("%d\n", ret);
}
return 0;
}

b454. Hard Version: Output the RGB Values

Problem

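Given the classic $45 \times 45$ Lena color image, $2025$ pixels in total, output an identical image. The source code is limited to 10K, which rules out simply embedding the table of values.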

Sample Input

NO INPUT

Sample Output

45 45
225 134 119 223 130 108 220 126 105 (rest omitted)

Solution

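Adapted from Zerojudge b454, "Please Output the RGB Values of This Image". This version forces a base64 scheme, because ordinary hexadecimal encoding exceeds the limit: in hexadecimal, the $6075$ integers in the range 0 to 255 need $12150$ printable characters in total.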

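From experiments on the previous problem, Huffman coding produces shorter output, but the decompressor has to be submitted along with it, so unless it is written very compactly it easily costs more than it saves. LZ77 is a decent compression scheme, but on images it rarely pays off because repeated fragments are uncommon. So I ended up using base64 directly, which brings the number of printable characters below 10K (6075 bytes become 8100 base64 characters); after that, it is a contest of who writes the shortest code.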

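base64 uses only the 64 characters 0-9, a-z, A-Z, + and / (plus = for padding). In an ordinary C/C++ string literal, escape sequences such as \\, \t and \n … each take two characters, and the characters that need escaping are usually ASCII [0, 31]. The great Tsai (asas) instead used a contiguous range of characters directly, so some escape sequences appear; even though each escape takes two or more characters and adds several characters overall, no decoder has to be written for an elaborate mapping, and the code actually came out shorter.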
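One possible reading of the contiguous-range idea, sketched under assumptions (asas's actual code is not shown here): map each 6-bit value v to the character '0' + v, i.e. ASCII 48 to 111, so decoding is just c - '0' with no lookup table. The cost is that the backslash (ASCII 92) falls inside that range and must be escaped in a string literal. The names encode6, decode6 and toCLiteral are hypothetical.

```cpp
#include <cassert>
#include <string>
#include <vector>
using namespace std;

const char BASE = '0';  // assumed start of the contiguous 64-character range

// Packs every 3 bytes into 4 characters, 6 bits each, as base64 does.
string encode6(const vector<unsigned char>& in) {
    string out;
    for (size_t i = 0; i < in.size(); i += 3) {
        unsigned v = in[i] << 16;
        if (i + 1 < in.size()) v |= in[i + 1] << 8;
        if (i + 2 < in.size()) v |= in[i + 2];
        for (int k = 3; k >= 0; k--) out += char(BASE + ((v >> (6 * k)) & 63));
    }
    return out;
}

// Decoding needs only subtraction; nbytes drops the zero padding at the end.
vector<unsigned char> decode6(const string& s, size_t nbytes) {
    vector<unsigned char> out;
    for (size_t i = 0; i + 3 < s.size() && out.size() < nbytes; i += 4) {
        unsigned v = 0;
        for (int k = 0; k < 4; k++) v = (v << 6) | unsigned(s[i + k] - BASE);
        for (int k = 2; k >= 0 && out.size() < nbytes; k--) out.push_back((v >> (8 * k)) & 255);
    }
    return out;
}

// Escapes the characters that cannot appear raw inside a C string literal.
string toCLiteral(const string& s) {
    string lit = "\"";
    for (char c : s) {
        if (c == '\\' || c == '"') lit += '\\';
        lit += c;
    }
    return lit + "\"";
}
```

For this image the size is the same as base64 (6075 bytes become 8100 characters); each backslash in the data then costs one extra character in the source.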

Generator

#include <stdio.h>
#include <stdlib.h>
#include <algorithm>
#include <string.h>
#include <math.h>
#include <map>
#include <vector>
#include <set>
using namespace std;
void lz77_encode(unsigned char *in, int n, unsigned char *out, int &m) {
m = 0;
if (n <= 16) {
for (int i = 0; i < n; i++)
out[m++] = in[i];
return ;
}
memcpy(out, in, 16);
m += 16;
while (n > 16) {
int mx = 0, offset, i, j;
for (i = 0; i < 16; i++) {
for (j = 0; i+j < 16 && j+16 < n && in[i+j] == in[j+16]; j++);
if (j > mx)
mx = j, offset = i;
}
if (mx <= 1) {
out[m++] = 0;
out[m++] = in[16];
in++, n--;
} else {
out[m++] = (offset<<4)|(mx-1);
in += mx, n -= mx;
}
}
}
void lz77_decode(unsigned char *in, int n, unsigned char *out, int &m) {
m = 0;
if (n <= 16) {
for (int i = 0; i < n; i++)
out[m++] = in[i];
return ;
}
memcpy(out, in, 16);
in += 16, n -= 16, m += 16;
int offset, mx;
while (n > 0) {
if (*in) {
offset = (*in)>>4, mx = ((*in)&0xf)+1;
memcpy(out + m, out + m - 16 + offset, mx);
m += mx;
in++, n -= 1;
} else {
out[m++] = in[1];
in += 2, n -= 2;
}
}
}
void base64_encode(unsigned char *in, int n, unsigned char *out, int &m) {
static char encode_table[] = {'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H',
'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P',
'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X',
'Y', 'Z', 'a', 'b', 'c', 'd', 'e', 'f',
'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n',
'o', 'p', 'q', 'r', 's', 't', 'u', 'v',
'w', 'x', 'y', 'z', '0', '1', '2', '3',
'4', '5', '6', '7', '8', '9', '+', '/'};
static int mod_table[] = {0, 2, 1};
m = 4 * ((n + 2) / 3);
for (int i = 0, j = 0; i < n; ) {
unsigned int a, b, c, d;
a = i < n ? in[i++] : 0;
b = i < n ? in[i++] : 0;
c = i < n ? in[i++] : 0;
d = (a<<16)|(b<<8)|c;
out[j++] = encode_table[(d>>18)&0x3f];
out[j++] = encode_table[(d>>12)&0x3f];
out[j++] = encode_table[(d>>6)&0x3f];
out[j++] = encode_table[d&0x3f];
}
for (int i = 0; i < mod_table[n%3]; i++)
out[m - 1 - i] = '=';
}
void base64_decode(unsigned char *in, int n, unsigned char *out, int &m) {
static int decode_table[256];
for (int i = 'A'; i <= 'Z'; i++) decode_table[i] = i - 'A';
for (int i = 'a'; i <= 'z'; i++) decode_table[i] = i - 'a' + 26;
for (int i = '0'; i <= '9'; i++) decode_table[i] = i - '0' + 52;
decode_table['+'] = 62, decode_table['/'] = 63;
m = n/4*3;
if (in[n-1] == '=') m--;
if (in[n-2] == '=') m--;
for (int i = 0, j = 0; i < n; ) {
unsigned int a, val = 0;
for (int k = 3; k >= 0; i++, k--) {
a = in[i] == '=' ? 0 : decode_table[in[i]];
val |= a<<(k*6);
}
if (j < m) out[j++] = (val>>16)&0xff;
if (j < m) out[j++] = (val>>8)&0xff;
if (j < m) out[j++] = (val>>0)&0xff;
}
}
void huffman_encode(unsigned char *in, int n, unsigned char *out, int &m) {
m = 0;
struct Table {
int f, bit, bn;
} tables[512];
struct Node {
int f, tid;
Node *l, *r;
Node(int a = 0, int b = 0, Node *ls = NULL, Node *rs = NULL):
f(a), tid(b), l(ls), r(rs) {}
bool operator<(const Node &x) const {
return f < x.f;
}
} nodes[512];
struct Stack {
Node *node;
int bit, bn;
Stack(Node *a = NULL, int b = 0, int c = 0):
node(a), bit(b), bn(c) {}
} stacks[1024];
int size = 0, leaf;
multiset< pair<int, Node*> > S;
memset(tables, 0, sizeof(tables));
for (int i = 0; i < n; i++)
tables[in[i]].f++;
for (int i = 0; i < 256; i++) {
if (tables[i].f) {
nodes[size] = Node(tables[i].f, i);
S.insert({nodes[size].f, &nodes[size]});
size++;
}
}
leaf = size;
while (S.size() >= 2) {
pair<int, Node*> u, v;
u = *S.begin(), S.erase(S.begin());
v = *S.begin(), S.erase(S.begin());
int f = u.second->f + v.second->f;
nodes[size] = Node(f, -1, u.second, v.second);
S.insert({nodes[size].f, &nodes[size]});
size++;
}
int stkIdx = 0;
stacks[stkIdx++] = Stack(&nodes[size-1], 0, 0);
while (stkIdx) {
Stack u = stacks[--stkIdx], v;
if (u.bn >= 31) {
fprintf(stderr, "huffman error: bit length exceeded\n");
exit(0);
}
if (u.node->l == NULL) {
tables[u.node->tid].bit = u.bit;
tables[u.node->tid].bn = u.bn;
} else {
v = Stack(u.node->l, u.bit | (1<<u.bn), u.bn+1);
u.node = u.node->r, u.bn++;
stacks[stkIdx++] = u;
stacks[stkIdx++] = v;
}
}
int bits_cnt = 0;
for (int i = 0; i < 256; i++)
bits_cnt += tables[i].f * tables[i].bn;
fprintf(stderr, "huffman: %d bit length\n", bits_cnt);
memcpy(out+m, &bits_cnt, 4), m += 4;
memcpy(out+m, &leaf, 1), m += 1;
for (int i = 0; i < leaf; i++) {
Table t = tables[nodes[i].tid];
memcpy(out+m, &nodes[i].tid, 1), m += 1;
memcpy(out+m, &t.bn, 1), m += 1;
memcpy(out+m, &t.bit, (t.bn + 7)/8), m += (t.bn + 7)/8;
}
int cnt = 0, mask = 0;
for (int i = 0; i < n; i++) {
int bit = tables[in[i]].bit;
int bn = tables[in[i]].bn;
while (bn) { // emit LSB first
int j = min(bn, 8 - cnt);
mask |= (bit&((1<<j)-1))<<cnt;
cnt += j, bn -= j, bit >>= j;
if (cnt == 8) {
memcpy(out+m, &mask, 1), m += 1;
cnt = 0, mask = 0;
}
}
}
if (cnt)
memcpy(out+m, &mask, 1), m += 1;
fprintf(stderr, "huffman encode length %d\n", m);
}
void huffman_decode(unsigned char *in, int n, unsigned char *out, int &m) {
m = 0;
map<int, int> R[64];
int bits_cnt = 0, leaf = 0;
memcpy(&bits_cnt, in, 4), in += 4, n -= 4;
memcpy(&leaf, in, 1), in += 1, n -= 1;
fprintf(stderr, "huffman: %d bit length\n", bits_cnt);
fprintf(stderr, "huffman: %d leaves\n", leaf);
for (int i = 0; i < leaf; i++) {
int tid = 0, bn = 0, bit = 0;
memcpy(&tid, in, 1), in += 1, n -= 1;
memcpy(&bn, in, 1), in += 1, n -= 1;
memcpy(&bit, in, (bn+7)/8), in += (bn+7)/8, n -= (bn+7)/8;
R[bn][bit] = tid;
}
int mask = 0, cnt = 0;
for (int i = 0; i < bits_cnt; i++) {
mask |= ((in[(i>>3)]>>(i&7))&1)<<cnt;
cnt++;
if (R[cnt].count(mask)) {
out[m++] = R[cnt][mask];
cnt = 0, mask = 0;
}
}
}
int main() {
int n, m;
while (scanf("%d %d", &n, &m) == 2) {
unsigned char in[32767], in2[32767], *pin;
pin = in;
for (int i = 0; i < n; i++) {
for (int j = 0; j < m; j++) {
int r, g, b;
scanf("%d %d %d", &r, &g, &b);
*pin = r, pin++;
*pin = g, pin++;
*pin = b, pin++;
}
}
int lz77n, b64n, elz77n, tn, hufn, ehufn;
unsigned char buf[4][32767] = {}, test[32767] = {};
// Case 1
// base64_encode(in, n*m*3, buf[1], b64n);
// Case 2
// huffman_encode(in, n*m, buf[0], hufn);
// base64_encode(buf[0], hufn, buf[1], b64n);
// base64_decode(buf[1], b64n, buf[2], ehufn);
// huffman_decode(buf[2], ehufn, test, tn);
// Case 3
// lz77_encode(in, n*m, buf[0], lz77n);
// base64_encode(buf[0], lz77n, buf[1], b64n);
// base64_decode(buf[1], b64n, buf[2], elz77n);
// lz77_decode(buf[2], elz77n, test, tn);
// Case 4
// huffman_encode(in, n*m, buf[0], hufn);
// lz77_encode(buf[0], hufn, buf[2], lz77n);
// base64_encode(buf[2], lz77n, buf[1], b64n);
// Case 5
// lz77_encode(in, n*m, buf[0], lz77n);
// huffman_encode(buf[0], hufn, buf[2], hufn);
// base64_encode(buf[2], lz77n, buf[1], b64n);
// Case 6
for (int i = 0; i < n*m*3; i += 3) {
in2[i] = in[i] - (i >= 3 ? in[i-3] : 0);
in2[i+1] = in[i+1] - (i >= 4 ? in[i-4] : 0);
in2[i+2] = in[i+2] - (i >= 5 ? in[i-5] : 0);
}
huffman_encode(in2, n*m*3, buf[0], hufn);
base64_encode(buf[0], hufn, buf[1], b64n);
for (int i = 0; i < b64n; i++)
printf("%c", buf[1][i]);
puts("");
}
return 0;
}

base64

#include <bits/stdc++.h>
using namespace std;
unsigned char data[] = "4YZ334Js3H5p5o1z8JZ0t01cpEBWsklUsE ... ignore";
void base64_decode(unsigned char *in, int n, unsigned char *out, int &m) {
static int R[256];
for (int i = 'A'; i <= 'Z'; i++) R[i] = i - 'A';
for (int i = 'a'; i <= 'z'; i++) R[i] = i - 'a' + 26;
for (int i = '0'; i <= '9'; i++) R[i] = i - '0' + 52;
R['+'] = 62, R['/'] = 63;
m = n/4*3;
if (in[n-1] == '=') m--;
if (in[n-2] == '=') m--;
for (int i = 0, j = 0; i < n; ) {
unsigned int a, val = 0;
for (int k = 3; k >= 0; i++, k--) {
a = in[i] == '=' ? 0 : R[in[i]];
val |= a<<(k*6);
}
if (j < m) out[j++] = (val>>16)&0xff;
if (j < m) out[j++] = (val>>8)&0xff;
if (j < m) out[j++] = (val>>0)&0xff;
}
}
int main() {
unsigned char d[32767] = {};
int n;
base64_decode(data, sizeof(data) - 1, d, n); // skip the trailing '\0' of the string literal
printf("%d %d\n", 45, 45);
for (int i = 0; i < 45; i++) {
for (int j = 0; j < 45; j++) {
int v = (i*45+j)*3;
printf("%d %d %d%c", d[v], d[v+1], d[v+2], " \n"[j == 44]);
}
}
return 0;
}

base64 (shorter code)

Use macros to expand the repeated statements.

#include <bits/stdc++.h>
typedef unsigned char UINT8;
UINT8 data[] = "4YZ334Js3H5p5o1z8JZ0t01cpEBWsklUsE ... ignore";
UINT8 d[32767] = {};
void base64_decode(UINT8 *in, int n, UINT8 *out, int &m) {
#define MK(st, ed, b) for (int i = st; i <= ed; i++) R[i] = i-st+b;
#define MG(k) b |= (in[i] == '=' ? 0 : R[in[i]])<<(k*6), i++;
#define MP(k) if (j < m) out[j++] = (b>>(k*8))&0xff;
int R[256];
MK('A', 'Z', 0); MK('a', 'z', 26); MK('0', '9', 52);
R['+'] = 62, R['/'] = 63;
m = n/4*3;
if (in[n-1] == '=') m--;
if (in[n-2] == '=') m--;
for (int i = 0, j = 0; i < n; ) {
int b = 0;
MG(3); MG(2); MG(1); MG(0);
MP(2); MP(1); MP(0);
}
}
int main() {
int n, v;
base64_decode(data, sizeof(data) - 1, d, n); // skip the trailing '\0'
printf("%d %d\n", 45, 45);
for (int i = 0; i < 45; i++) {
for (int j = 0; j < 45; j++) {
v = (i*45+j)*3;
printf("%d %d %d%c", d[v], d[v+1], d[v+2], " \n"[j == 44]);
}
}
return 0;
}

a458. Beats of the Angel (Enhanced Version)

Problem

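Given a graph with a source S and a destination T, you may double the weight of any one edge. What is the largest amount by which this can lengthen the shortest S-T path?

In other words, choose one edge to modify so that the resulting detour is as long as possible.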

Sample Input

5 7
2 1 5
1 3 1
3 2 8
3 5 7
3 4 3
2 4 7
4 5 2
1 5

Sample Output

2
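The following brute-force reference sketch makes the statement concrete; it is not the author's method. For every edge it doubles the weight, reruns Dijkstra and keeps the largest increase of the S-T distance. At O(E · E log V) it is only meant for checking small inputs such as the sample above.

In the sample, the shortest path 1-3-4-5 has length 6. Doubling 3-4 or 4-5 raises the distance to 8, so the answer is 2. The `WEdge` edge-list type and the function names are assumptions.

```cpp
#include <algorithm>
#include <cassert>
#include <climits>
#include <functional>
#include <queue>
#include <tuple>
#include <utility>
#include <vector>
using namespace std;

typedef tuple<int, int, long long> WEdge;  // (x, y, weight), 1-indexed vertices

// Plain Dijkstra from s on an undirected graph; returns dist(s, t).
long long shortest(int n, const vector<WEdge>& E, int s, int t) {
    vector<vector<pair<int, long long>>> adj(n + 1);
    for (const auto& [x, y, w] : E) {
        adj[x].push_back({y, w});
        adj[y].push_back({x, w});
    }
    vector<long long> dist(n + 1, LLONG_MAX);
    priority_queue<pair<long long, int>, vector<pair<long long, int>>,
                   greater<pair<long long, int>>> pq;
    dist[s] = 0;
    pq.push({0, s});
    while (!pq.empty()) {
        auto [d, u] = pq.top();
        pq.pop();
        if (d > dist[u]) continue;                // stale queue entry
        for (auto [v, w] : adj[u])
            if (d + w < dist[v]) {
                dist[v] = d + w;
                pq.push({dist[v], v});
            }
    }
    return dist[t];
}

// Hypothetical checker: double each edge in turn and keep the largest increase.
long long bruteForce(int n, vector<WEdge> E, int s, int t) {
    long long base = shortest(n, E, s, t), best = 0;
    for (auto& e : E) {
        get<2>(e) *= 2;                           // double this edge
        best = max(best, shortest(n, E, s, t) - base);
        get<2>(e) /= 2;                           // restore the original weight
    }
    return best;
}
```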

Solution

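I finally got the solution from liouzhou101. It is similar to BZOJ 2725 Violet 6 故乡的梦 ("Dream of Hometown"). More than three years on, I have at last written the enhanced version of Angel Beats. The code is rather charming: it is a pile of sweep-line algorithms, and even the bottleneck search is done with a sweep line, which feels quite cute.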

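For each removed edge, the replacement for the S-T shortest path can be found offline in $O(E \log E)$ overall. Doubling an edge of weight w yields the new distance min(shortest path avoiding that edge, d + w), where d is the original shortest distance. So it suffices to solve the edge-removal version and cap the result at d + w, which the code does with `d + p->w`.

  1. First find the bottlenecks on the shortest-path DAG. Only doubling a bottleneck can change the answer, because any other edge can be bypassed by another shortest path.
  2. Compute, on the DAG, the minimum number of bottlenecks that must be passed.
  3. Then sweep the bottlenecks from nearest to farthest, maintaining a min-heap (a multiset) of candidate replacement paths. Keep only candidates whose part left of the sweep line and part right of it, joined together, do not pass through the current bottleneck.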
      B -- C        H -- I
     /      \      /      \
S - A        D -- G ------ J - T
     \      /
      E -- F

Notation

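  • Ds(u): the shortest distance from the source S to u
  • De(u): the shortest distance from u to the destination T
  • Bs(u): the minimum number of bottlenecks passed going from S to u while staying on shortest paths
  • Be(u): the minimum number of bottlenecks passed going from u to T while staying on shortest paths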
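Bs and Be are 0/1-weighted shortest distances inside the shortest-path DAG. Crossing a bottleneck edge costs 1 and any other DAG edge costs 0. With strictly positive edge weights, every DAG edge goes from a smaller Ds to a larger Ds, so visiting vertices in increasing Ds order is a topological order. The author's `bfs` relies on the same idea.

Below is a minimal sketch for Bs. It assumes the DAG edges and their bottleneck flags are already known, and uses a hypothetical `DagEdge` record with 0-indexed vertices. Be is the mirror image, run from T with De and reversed edges.

```cpp
#include <algorithm>
#include <cassert>
#include <climits>
#include <numeric>
#include <utility>
#include <vector>
using namespace std;

struct DagEdge { int u, v; bool key; };  // hypothetical: key = bottleneck edge

// Bs[x] = fewest bottleneck edges on a shortest-path-DAG path from st to x.
// Assumes strictly positive weights, so sorting by Ds gives a topological order.
vector<int> countBottlenecks(int n, int st, const vector<long long>& Ds,
                             const vector<DagEdge>& dag) {
    vector<vector<pair<int, int>>> out(n);        // (to, cost 0 or 1)
    for (const auto& e : dag) out[e.u].push_back({e.v, e.key ? 1 : 0});
    vector<int> order(n);
    iota(order.begin(), order.end(), 0);
    sort(order.begin(), order.end(), [&](int a, int b) { return Ds[a] < Ds[b]; });
    vector<int> Bs(n, INT_MAX);
    Bs[st] = 0;
    for (int x : order) {
        if (Bs[x] == INT_MAX) continue;           // not reachable inside the DAG
        for (auto [y, c] : out[x]) Bs[y] = min(Bs[y], Bs[x] + c);
    }
    return Bs;
}
```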

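For example, if the shortest-path DAG is the one shown above, then D - G and J - T are both bottlenecks. There are two ways to find bottlenecks:

  • The first resembles finding a minimum cut in max-flow.
  • The second turns each DAG edge e(u, v) into the interval [Ds(u), Ds(v)] and runs a sweep line. Wherever a point is covered by only one interval, that interval's edge is a bottleneck.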
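The interval method can be sketched as follows. Assume strictly positive weights and that every DAG edge lies on some S-T shortest path. Then every shortest path covers [0, Ds(T)] with the intervals of its edges, and no two of those intervals overlap. So an edge is a bottleneck exactly when no other interval overlaps its interior, which is when the coverage count just to the right of its left endpoint is 1.

This variant counts coverage with a difference array over compressed coordinates, rather than the priority-queue sweep in the author's `bottleneck`. `findBottlenecks` is a hypothetical name.

```cpp
#include <algorithm>
#include <cassert>
#include <utility>
#include <vector>
using namespace std;

// iv[i] = [Ds(u), Ds(v)] for the i-th edge of the shortest-path DAG (l < r).
vector<bool> findBottlenecks(const vector<pair<long long, long long>>& iv) {
    vector<long long> xs;                         // compressed coordinates
    for (const auto& p : iv) { xs.push_back(p.first); xs.push_back(p.second); }
    sort(xs.begin(), xs.end());
    xs.erase(unique(xs.begin(), xs.end()), xs.end());
    auto id = [&](long long x) {
        return int(lower_bound(xs.begin(), xs.end(), x) - xs.begin());
    };
    // Difference array; after the prefix sum, cover[k] is the number of
    // intervals covering the elementary segment [xs[k], xs[k+1]).
    vector<int> cover(xs.size() + 1, 0);
    for (const auto& p : iv) { cover[id(p.first)]++; cover[id(p.second)]--; }
    for (size_t k = 1; k < cover.size(); k++) cover[k] += cover[k - 1];
    vector<bool> res(iv.size());
    for (size_t i = 0; i < iv.size(); i++)
        res[i] = (cover[id(iv[i].first)] == 1);   // alone right after its start
    return res;
}
```

For example, the DAG intervals {0,2}, {2,5}, {2,4}, {4,5}, {5,7} (two parallel routes between distances 2 and 5) mark only the first and last edges as bottlenecks.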

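With all bottlenecks found, scan them in order of distance from S, nearest first, while maintaining a multiset<long long>. For bottleneck e(u, v), the replacement path is min(Ds(a) + w(a, b) + De(b)) over edges (a, b) satisfying Bs(a) <= Bs(u) and Be(b) <= Be(v). These two conditions guarantee that the path avoids this bottleneck.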

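As the sweep line moves, the bound Bs(u) on Bs(a) only increases and the bound Be(v) on Be(b) only decreases, so sort the vertices by Bs[] and by Be[] beforehand. Each time a bottleneck is reached, relax the newly allowed edges. At the same time, remove from the min-heap every entry Ds(a) + w(a, b) + De(b) with Be(b) > Be(v).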
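The removal step relies on `std::multiset::insert` returning an iterator to the element it just inserted. The author's code stores these iterators per endpoint b (the `RM` array), so exactly those copies can be erased later even when several candidates have equal length. A minimal sketch of that bookkeeping, with hypothetical names `CandidatePool`, `addCandidate` and `expire`:

```cpp
#include <cassert>
#include <set>
#include <vector>
using namespace std;

// Candidate replacement-path lengths, bucketed by their endpoint b so that all
// candidates ending at b can be dropped once Be(b) exceeds the current bound.
struct CandidatePool {
    multiset<long long> S;
    vector<vector<multiset<long long>::iterator>> byEnd;
    explicit CandidatePool(int n) : byEnd(n) {}
    void addCandidate(int b, long long len) { byEnd[b].push_back(S.insert(len)); }
    void expire(int b) {                   // erase only the copies added for b
        for (auto it : byEnd[b]) S.erase(it);
        byEnd[b].clear();
    }
    long long best() const { return S.empty() ? -1 : *S.begin(); }  // -1: none
};
```

Erasing by iterator removes one specific copy in O(1) amortized time. Erasing by value would delete every equal length, including candidates that belong to other endpoints.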

In English this is known as the selfish-edges Shortest-Path problem.

#include <bits/stdc++.h>
using namespace std;
// http://zrt.github.io/2014/08/13/bzoj2725/
const int MAXV = 200005;
const int MAXE = 200005<<1;
const long long INF = 1LL<<60;
struct Edge {
int to, eid;
long long w;
Edge *next;
};
Edge edge[MAXE], *adj[MAXV];
int e = 0;
long long Ds[MAXV], De[MAXV];
int Bs[MAXV], Be[MAXV], keye[MAXE];
void addEdge(int x, int y, long long v) {
edge[e].to = y, edge[e].w = v, edge[e].eid = e;
edge[e].next = adj[x], adj[x] = &edge[e++];
edge[e].to = x, edge[e].w = v, edge[e].eid = e;
edge[e].next = adj[y], adj[y] = &edge[e++];
}
void dijkstra(int st, long long dist[], int n) {
typedef pair<long long, int> PLL;
for (int i = 0; i <= n; i++)
dist[i] = INF;
set<PLL> pQ;
PLL u;
pQ.insert(PLL(0, st)), dist[st] = 0;
while (!pQ.empty()) {
u = *pQ.begin(), pQ.erase(pQ.begin());
for (Edge *p = adj[u.second]; p; p = p->next) {
if (dist[p->to] > dist[u.second] + p->w) {
if (dist[p->to] != INF)
pQ.erase(pQ.find(PLL(dist[p->to], p->to)));
dist[p->to] = dist[u.second] + p->w;
pQ.insert(PLL(dist[p->to], p->to));
}
}
}
}
int bottleneck(int st, int ed, long long Ds[], long long De[], int n) {
typedef pair<long long, pair<long long, int> > PLL;
vector<PLL> A;
// short path DAG st-ed
long long d = Ds[ed];
for (int i = 1; i <= n; i++) {
for (Edge *p = adj[i]; p; p = p->next) {
if (Ds[p->to] == Ds[i] + p->w && Ds[p->to] + De[p->to] == d) {
A.push_back(PLL(Ds[i], {Ds[p->to], p->eid}));
}
}
}
// bottleneck edge st-ed
sort(A.begin(), A.end());
priority_queue<long long, vector<long long>, std::greater<long long>> pQ;
int cnt = 0;
for (int i = 0; i < e; i++)
keye[i] = 0;
for (int i = 0; i < A.size(); ) {
long long l = A[i].first;
while (!pQ.empty() && pQ.top() <= l)
pQ.pop();
while (i < A.size() && A[i].first <= l)
pQ.push(A[i].second.first), i++;
if (pQ.size() == 1) {
keye[A[i-1].second.second] = 1;
keye[A[i-1].second.second^1] = -1;
cnt++;
}
}
return cnt;
}
void bfs(int st, int dist[], long long Ds[], int n, int f) {
typedef pair<long long, int> PLL;
// work for short-path DAG, weight: key edge
vector<PLL> A;
for (int i = 1; i <= n; i++)
A.push_back(PLL(Ds[i], i)), dist[i] = 0x3f3f3f3f;
dist[st] = 0;
sort(A.begin(), A.end());
for (int i = 0; i < n; i++) {
int x = A[i].second;
for (Edge *p = adj[x]; p; p = p->next) {
if (Ds[p->to] == Ds[x] + p->w)
dist[p->to] = min(dist[p->to], dist[x] + (keye[p->eid] == f));
}
}
}
long long solve(int st, int ed, long long Ds[], long long De[], int n) {
typedef pair<long long, int> PLL;
typedef pair<int, int> PII;
typedef multiset<long long>::iterator MIT;
bfs(st, Bs, Ds, n, 1);
bfs(ed, Be, De, n, -1);
vector<PLL> A;
vector<PII> B, C;
for (int i = 1; i <= n; i++) {
A.push_back(PLL(Ds[i], i));
B.push_back(PII(Bs[i], i));
C.push_back(PII(Be[i], i));
}
sort(A.begin(), A.end());
sort(B.begin(), B.end());
sort(C.begin(), C.end());
long long d = Ds[ed];
int Bidx = 0, Cidx = n-1;
multiset<long long> S;
vector< vector<MIT> > RM(n+1, vector<MIT>());
long long ret = 0;
for (int i = 0; i < n; i++) {
int x = A[i].second;
for (Edge *p = adj[x]; p; p = p->next) {
if (Ds[p->to] == Ds[x] + p->w && Ds[p->to] + De[p->to] == d) {
if (keye[p->eid]) {
int bb = Bs[x], cut = Be[p->to];
// relax
for (; Bidx < B.size() && B[Bidx].first <= bb; Bidx++) {
int u = B[Bidx].second;
for (Edge *q = adj[u]; q; q = q->next) {
if (Be[q->to] <= cut && p != q) {
MIT it = S.insert(Ds[u] + q->w + De[q->to]);
RM[q->to].push_back(it);
}
}
}
// remove
for (; Cidx >= 0 && C[Cidx].first > cut; Cidx--) {
int u = C[Cidx].second;
for (auto e : RM[u])
S.erase(e);
}
long long replace_path = INF;
if (!S.empty())
replace_path = *S.begin();
replace_path = min(replace_path, d + p->w); // this edge double weight
ret = max(ret, replace_path - d);
}
}
}
}
return ret;
}
int main() {
int n, m, x, y;
int st, ed;
long long v;
while (scanf("%d %d", &n, &m) == 2) {
for (int i = 1; i <= n; i++)
adj[i] = NULL;
for (int i = 0; i < m; i++) {
scanf("%d %d %lld", &x, &y, &v);
addEdge(x, y, v);
}
scanf("%d %d", &st, &ed);
dijkstra(st, Ds, n);
dijkstra(ed, De, n);
int cnt = bottleneck(st, ed, Ds, De, n);
if (cnt == 0)
puts("0");
else
printf("%lld\n", solve(st, ed, Ds, De, n));
}
return 0;
}

b454. Print the RGB Values of This Image

Problem

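Given the classic $64 \times 64$ Lena image, $4096$ pixels in total, output exactly the same image. Source code is limited to 10K, which is too short for a naive hard-coded table, so the pixel data has to be packed more compactly.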

Sample Input

NO INPUT

Sample Output

64 64
153 153 153 151 151 151 148 148 148 (rest omitted)

Solution

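Preprocessing

The first step is to turn the grayscale image into the plain numeric format used here. Python does the job; if needed, install the Python Imaging Library (PIL) or its maintained fork, Pillow. This is only preprocessing, and OpenCV would work just as well.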

$ python img2txt.py b454.png >in.txt
img2txt.py
import sys
from PIL import Image  # Pillow; the original used the legacy Python 2 `import Image`

for infile in sys.argv[1:]:
    try:
        im = Image.open(infile)
        gray = im.convert('LA')  # grayscale + alpha; channel 0 is the luminance
        width, height = gray.size
        pix = gray.load()
        print(height, width)
        for i in range(height):
            # one image row per output line, pixels separated by spaces
            print(' '.join(str(pix[j, i][0]) for j in range(width)))
    except IOError:
        pass

Building a compressed table

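There are 4096 pixels. In hexadecimal each 0 - 255 value takes exactly 2 characters (00 - FF), so the source lands at about 8192 bytes. In decimal each value needs up to 3 characters, about 13KB in total, so hexadecimal is the minimum acceptable representation.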
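These size estimates follow from simple arithmetic on the 4096 payload bytes, ignoring the decoder and any separators. The small sketch below compares the two options. Base64, discussed next, emits 4 characters per 3 input bytes, rounded up. That is the same `4 * ((n + 2) / 3)` formula the generator's `base64_encode` uses. The helper names are hypothetical.

```cpp
#include <cassert>

// Characters needed to embed n raw bytes in a string literal (payload only).
constexpr long long hexChars(long long n) { return 2 * n; }
constexpr long long base64Chars(long long n) { return 4 * ((n + 2) / 3); }  // '='-padded

static_assert(hexChars(4096) == 8192, "64x64 grayscale as hex");
static_assert(base64Chars(4096) == 5464, "64x64 grayscale as base64");
```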

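Base64, commonly seen on web pages, does better still. It encodes with 64 printable characters, so it is shorter than hexadecimal, which uses only 16.

Traditional Huffman coding shrinks the data by another 10% to 20%, and the more skewed the distribution of pixel values, the better it works. However, the decompression code has to be submitted as well, so the net saving is small and plain base64 actually wins.

Because this is a photograph, smooth gradients are common. Replacing each pixel by its difference from the previous pixel therefore gives a much more skewed distribution, on which Huffman coding performs well.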
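Here is a sketch of the difference transform. It assumes 8-bit wraparound arithmetic, as in the generator's commented-out Case 6 and the diff+huffman decoder (`d[1][i] + d[1][i-1]`). Encoding stores each byte minus its predecessor modulo 256, and decoding is a running sum. The function names are hypothetical.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>
using namespace std;

// Replace each byte by its difference from the previous one (mod 256).
// Smooth gradients become many values near 0 (or 255 for -1, 254 for -2, ...),
// a skewed distribution that Huffman coding compresses well.
vector<uint8_t> deltaEncode(const vector<uint8_t>& px) {
    vector<uint8_t> out(px.size());
    for (size_t i = 0; i < px.size(); i++)
        out[i] = uint8_t(px[i] - (i ? px[i - 1] : 0));
    return out;
}

// Inverse transform: running sum mod 256.
vector<uint8_t> deltaDecode(const vector<uint8_t>& d) {
    vector<uint8_t> out(d);
    for (size_t i = 1; i < out.size(); i++)
        out[i] = uint8_t(out[i] + out[i - 1]);
    return out;
}
```

For the first pixels of the sample output, {153, 151, 148, 148, 150} becomes {153, 254, 253, 0, 2}.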

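Finally, LZ77 is easy to implement and somewhat resembles the longest-plateau problem. In its classic form each token is (start offset, match length, trailing literal), found in a window that slides over the input.

The code below uses a variant: a match is a single byte (offset<<4)|(length-1), and a literal is a 0 byte followed by the character. The offset and length widths need tuning. The code uses window size = 16, so the offset and match length fit together in 8 bits, 4 bits each. Bigger is not necessarily better, and too small is not good either.

Whatever the setting, the image has few repeated patterns, far fewer than a sequence with mathematical structure, so every experiment came out above 10KB.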

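For Huffman coding, the storage format is: compressed bit length + dictionary + compressed data. The dictionary can be stored in many ways. A Huffman tree is a full binary tree (every node has either two children or none), so a 0/1 preorder traversal would describe it. However, that makes the decompressor longer, which costs more than it saves, so the code stores a plain table instead. This requires every codeword to be shorter than 32 bits, so that it fits in an int32.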
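For comparison, here is a sketch of the 0/1 preorder encoding the author decided against; it is not what the submitted code does. Write 1 for an internal node, and 0 for a leaf followed by its 8-bit symbol. A tree with L leaves has L - 1 internal nodes, so this costs (2L - 1) + 8L bits. The saving in table size is paid for by a longer recursive decoder. The node layout and function names are hypothetical.

```cpp
#include <cassert>
#include <memory>
#include <vector>
using namespace std;

// Full binary tree: every node has 0 or 2 children, as in a Huffman tree.
struct Node {
    int sym = -1;                         // leaf symbol 0..255, -1 for internal
    unique_ptr<Node> l, r;
};

// Preorder: bit 1 = internal node, bit 0 = leaf followed by 8 symbol bits (MSB first).
void serialize(const Node* t, vector<int>& bits) {
    if (!t->l) {
        bits.push_back(0);
        for (int k = 7; k >= 0; k--) bits.push_back((t->sym >> k) & 1);
        return;
    }
    bits.push_back(1);
    serialize(t->l.get(), bits);
    serialize(t->r.get(), bits);
}

// Rebuilds the tree, advancing pos past the bits it consumes.
unique_ptr<Node> deserialize(const vector<int>& bits, size_t& pos) {
    auto t = make_unique<Node>();
    if (bits[pos++] == 0) {
        t->sym = 0;
        for (int k = 0; k < 8; k++) t->sym = (t->sym << 1) | bits[pos++];
        return t;
    }
    t->l = deserialize(bits, pos);
    t->r = deserialize(bits, pos);
    return t;
}
```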

Method                 Code length (bytes)
base64                 6485
huffman+base64         7965
diff+huffman+base64    7717
lz77+base64            > 10K

Generator

#include <bits/stdc++.h>
using namespace std;
void lz77_encode(unsigned char *in, int n, unsigned char *out, int &m) {
m = 0;
if (n <= 16) {
for (int i = 0; i < n; i++)
out[m++] = in[i];
return ;
}
memcpy(out, in, 16);
m += 16;
while (n > 16) {
int mx = 0, offset, i, j;
for (i = 0; i < 16; i++) {
for (j = 0; i+j < 16 && j+16 < n && in[i+j] == in[j+16]; j++);
if (j > mx)
mx = j, offset = i;
}
if (mx <= 1) {
out[m++] = 0;
out[m++] = in[16];
in++, n--;
} else {
out[m++] = (offset<<4)|(mx-1);
in += mx, n -= mx;
}
}
}
void lz77_decode(unsigned char *in, int n, unsigned char *out, int &m) {
m = 0;
if (n <= 16) {
for (int i = 0; i < n; i++)
out[m++] = in[i];
return ;
}
memcpy(out, in, 16);
in += 16, n -= 16, m += 16;
int offset, mx;
while (n > 0) {
if (*in) {
offset = (*in)>>4, mx = ((*in)&0xf)+1;
memcpy(out + m, out + m - 16 + offset, mx);
m += mx;
in++, n -= 1;
} else {
out[m++] = in[1];
in += 2, n -= 2;
}
}
}
void base64_encode(unsigned char *in, int n, unsigned char *out, int &m) {
static char encode_table[] = {'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H',
'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P',
'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X',
'Y', 'Z', 'a', 'b', 'c', 'd', 'e', 'f',
'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n',
'o', 'p', 'q', 'r', 's', 't', 'u', 'v',
'w', 'x', 'y', 'z', '0', '1', '2', '3',
'4', '5', '6', '7', '8', '9', '+', '/'};
static int mod_table[] = {0, 2, 1};
m = 4 * ((n + 2) / 3);
for (int i = 0, j = 0; i < n; ) {
unsigned int a, b, c, d;
a = i < n ? in[i++] : 0;
b = i < n ? in[i++] : 0;
c = i < n ? in[i++] : 0;
d = (a<<16)|(b<<8)|c;
out[j++] = encode_table[(d>>18)&0x3f];
out[j++] = encode_table[(d>>12)&0x3f];
out[j++] = encode_table[(d>>6)&0x3f];
out[j++] = encode_table[d&0x3f];
}
for (int i = 0; i < mod_table[n%3]; i++)
out[m - 1 - i] = '=';
}
void base64_decode(unsigned char *in, int n, unsigned char *out, int &m) {
static int decode_table[256];
for (int i = 'A'; i <= 'Z'; i++) decode_table[i] = i - 'A';
for (int i = 'a'; i <= 'z'; i++) decode_table[i] = i - 'a' + 26;
for (int i = '0'; i <= '9'; i++) decode_table[i] = i - '0' + 52;
decode_table['+'] = 62, decode_table['/'] = 63;
m = n/4*3;
if (in[n-1] == '=') m--;
if (in[n-2] == '=') m--;
for (int i = 0, j = 0; i < n; ) {
unsigned int a, val = 0;
for (int k = 3; k >= 0; i++, k--) {
a = in[i] == '=' ? 0 : decode_table[in[i]];
val |= a<<(k*6);
}
if (j < m) out[j++] = (val>>16)&0xff;
if (j < m) out[j++] = (val>>8)&0xff;
if (j < m) out[j++] = (val>>0)&0xff;
}
}
void huffman_encode(unsigned char *in, int n, unsigned char *out, int &m) {
m = 0;
struct Table {
int f, bit, bn;
} tables[512];
struct Node {
int f, tid;
Node *l, *r;
Node(int a = 0, int b = 0, Node *ls = NULL, Node *rs = NULL):
f(a), tid(b), l(ls), r(rs) {}
bool operator<(const Node &x) const {
return f < x.f;
}
} nodes[512];
struct Stack {
Node *node;
int bit, bn;
Stack(Node *a = NULL, int b = 0, int c = 0):
node(a), bit(b), bn(c) {}
} stacks[1024];
int size = 0, leaf;
multiset< pair<int, Node*> > S;
memset(tables, 0, sizeof(tables));
for (int i = 0; i < n; i++)
tables[in[i]].f++;
for (int i = 0; i < 256; i++) {
if (tables[i].f) {
nodes[size] = Node(tables[i].f, i);
S.insert({nodes[size].f, &nodes[size]});
size++;
}
}
leaf = size;
while (S.size() >= 2) {
pair<int, Node*> u, v;
u = *S.begin(), S.erase(S.begin());
v = *S.begin(), S.erase(S.begin());
int f = u.second->f + v.second->f;
nodes[size] = Node(f, -1, u.second, v.second);
S.insert({nodes[size].f, &nodes[size]});
size++;
}
int stkIdx = 0;
stacks[stkIdx++] = Stack(&nodes[size-1], 0, 0);
while (stkIdx) {
Stack u = stacks[--stkIdx], v;
if (u.bn >= 31) {
fprintf(stderr, "huffman error: bit length exceeded\n");
exit(0);
}
if (u.node->l == NULL) {
tables[u.node->tid].bit = u.bit;
tables[u.node->tid].bn = u.bn;
} else {
v = Stack(u.node->l, u.bit | (1<<u.bn), u.bn+1);
u.node = u.node->r, u.bn++;
stacks[stkIdx++] = u;
stacks[stkIdx++] = v;
}
}
int bits_cnt = 0;
for (int i = 0; i < 256; i++)
bits_cnt += tables[i].f * tables[i].bn;
fprintf(stderr, "huffman: %d bit length\n", bits_cnt);
memcpy(out+m, &bits_cnt, 4), m += 4;
memcpy(out+m, &leaf, 1), m += 1;
for (int i = 0; i < leaf; i++) {
Table t = tables[nodes[i].tid];
memcpy(out+m, &nodes[i].tid, 1), m += 1;
memcpy(out+m, &t.bn, 1), m += 1;
memcpy(out+m, &t.bit, (t.bn + 7)/8), m += (t.bn + 7)/8;
}
int cnt = 0, mask = 0;
for (int i = 0; i < n; i++) {
int bit = tables[in[i]].bit;
int bn = tables[in[i]].bn;
while (bn) { // LBS
int j = min(bn, 8 - cnt);
mask |= (bit&((1<<j)-1))<<cnt;
cnt += j, bn -= j, bit >>= j;
if (cnt == 8) {
memcpy(out+m, &mask, 1), m += 1;
cnt = 0, mask = 0;
}
}
}
if (cnt)
memcpy(out+m, &mask, 1), m += 1;
fprintf(stderr, "huffman encode length %d\n", m);
}
void huffman_decode(unsigned char *in, int n, unsigned char *out, int &m) {
m = 0;
map<int, int> R[64];
int bits_cnt = 0, leaf = 0;
memcpy(&bits_cnt, in, 4), in += 4, n -= 4;
memcpy(&leaf, in, 1), in += 1, n -= 1;
fprintf(stderr, "huffman: %d bit length\n", bits_cnt);
fprintf(stderr, "huffman: %d leaves\n", leaf);
for (int i = 0; i < leaf; i++) {
int tid = 0, bn = 0, bit = 0;
memcpy(&tid, in, 1), in += 1, n -= 1;
memcpy(&bn, in, 1), in += 1, n -= 1;
memcpy(&bit, in, (bn+7)/8), in += (bn+7)/8, n -= (bn+7)/8;
R[bn][bit] = tid;
}
int mask = 0, cnt = 0;
for (int i = 0; i < bits_cnt; i++) {
mask |= ((in[(i>>3)]>>(i&7))&1)<<cnt;
cnt++;
if (R[cnt].count(mask)) {
out[m++] = R[cnt][mask];
cnt = 0, mask = 0;
}
}
}
int main() {
int n, m;
int data[64][64];
while (scanf("%d %d", &n, &m) == 2) {
unsigned char in[32767], in2[32767];
int stat[256] = {};
for (int i = 0; i < n; i++) {
for (int j = 0; j < m; j++) {
scanf("%d", &data[i][j]);
in[i*m+j] = data[i][j];
stat[data[i][j]]++;
}
}
int lz77n, b64n, elz77n, tn, hufn, ehufn;
unsigned char buf[4][32767] = {}, test[32767] = {};
// Case 1
base64_encode(in, n*m, buf[1], b64n);
// Case 2
// huffman_encode(in, n*m, buf[0], hufn);
// base64_encode(buf[0], hufn, buf[1], b64n);
// base64_decode(buf[1], b64n, buf[2], ehufn);
// huffman_decode(buf[2], ehufn, test, tn);
// Case 3
// lz77_encode(in, n*m, buf[0], lz77n);
// base64_encode(buf[0], lz77n, buf[1], b64n);
// base64_decode(buf[1], b64n, buf[2], elz77n);
// lz77_decode(buf[2], elz77n, test, tn);
// Case 4
// huffman_encode(in, n*m, buf[0], hufn);
// lz77_encode(buf[0], hufn, buf[2], lz77n);
// base64_encode(buf[2], lz77n, buf[1], b64n);
// Case 5
// lz77_encode(in, n*m, buf[0], lz77n);
// huffman_encode(buf[0], hufn, buf[2], hufn);
// base64_encode(buf[2], lz77n, buf[1], b64n);
// Case 6
// for (int i = 0; i < n*m; i++)
// in2[i] = in[i] - (i ? in[i-1] : 0);
// huffman_encode(in2, n*m, buf[0], hufn);
// base64_encode(buf[0], hufn, buf[1], b64n);
for (int i = 0; i < b64n; i++)
printf("%c", buf[1][i]);
puts("");
}
return 0;
}

base64

#include <bits/stdc++.h>
using namespace std;
unsigned char data[] = "mZeUkpampXpOX2FhYGlxd ... ignore";
void base64_decode(unsigned char *in, int n, unsigned char *out, int &m) {
static int R[256];
for (int i = 'A'; i <= 'Z'; i++) R[i] = i - 'A';
for (int i = 'a'; i <= 'z'; i++) R[i] = i - 'a' + 26;
for (int i = '0'; i <= '9'; i++) R[i] = i - '0' + 52;
R['+'] = 62, R['/'] = 63;
m = n/4*3;
if (in[n-1] == '=') m--;
if (in[n-2] == '=') m--;
for (int i = 0, j = 0; i < n; ) {
unsigned int a, val = 0;
for (int k = 3; k >= 0; i++, k--) {
a = in[i] == '=' ? 0 : R[in[i]];
val |= a<<(k*6);
}
if (j < m) out[j++] = (val>>16)&0xff;
if (j < m) out[j++] = (val>>8)&0xff;
if (j < m) out[j++] = (val>>0)&0xff;
}
}
int main() {
unsigned char d[32767] = {};
int n;
base64_decode(data, sizeof(data) - 1, d, n); // skip the trailing '\0' of the string literal
printf("%d %d\n", 64, 64);
for (int i = 0; i < 64; i++) {
for (int j = 0; j < 64; j++) {
printf("%d %d %d%c", d[i*64+j], d[i*64+j], d[i*64+j], j == 63 ? '\n' : ' ');
}
}
return 0;
}

huffman+base64

#include <bits/stdc++.h>
using namespace std;
unsigned char data[] = "RHkAAOoAC4kCAguWBwYMFggIDBYA ... ignore";
void base64_decode(unsigned char *in, int n, unsigned char *out, int &m) {
static int R[256];
for (int i = 'A'; i <= 'Z'; i++) R[i] = i - 'A';
for (int i = 'a'; i <= 'z'; i++) R[i] = i - 'a' + 26;
for (int i = '0'; i <= '9'; i++) R[i] = i - '0' + 52;
R['+'] = 62, R['/'] = 63;
m = n/4*3;
if (in[n-1] == '=') m--;
if (in[n-2] == '=') m--;
for (int i = 0, j = 0; i < n; ) {
unsigned int a, val = 0;
for (int k = 3; k >= 0; i++, k--) {
a = in[i] == '=' ? 0 : R[in[i]];
val |= a<<(k*6);
}
if (j < m) out[j++] = (val>>16)&0xff;
if (j < m) out[j++] = (val>>8)&0xff;
if (j < m) out[j++] = (val>>0)&0xff;
}
}
void huffman_decode(unsigned char *in, int n, unsigned char *out, int &m) {
m = 0;
map<int, int> R[64];
int bits_cnt = 0, leaf = 0;
memcpy(&bits_cnt, in, 4), in += 4, n -= 4;
memcpy(&leaf, in, 1), in += 1, n -= 1;
for (int i = 0; i < leaf; i++) {
int tid = 0, bn = 0, bit = 0;
memcpy(&tid, in, 1), in += 1, n -= 1;
memcpy(&bn, in, 1), in += 1, n -= 1;
memcpy(&bit, in, (bn+7)/8), in += (bn+7)/8, n -= (bn+7)/8;
R[bn][bit] = tid;
}
int mask = 0, cnt = 0;
for (int i = 0; i < bits_cnt; i++) {
mask |= ((in[(i>>3)]>>(i&7))&1)<<cnt;
cnt++;
if (R[cnt].count(mask)) {
out[m++] = R[cnt][mask];
cnt = 0, mask = 0;
}
}
}
int main() {
unsigned char d[2][32767] = {};
int n, m;
base64_decode(data, sizeof(data) - 1, d[0], n); // skip the trailing '\0'
huffman_decode(d[0], n, d[1], m);
printf("%d %d\n", 64, 64);
for (int i = 0; i < 64; i++) {
for (int j = 0; j < 64; j++) {
printf("%d %d %d%c", d[1][i*64+j], d[1][i*64+j], d[1][i*64+j], j == 63 ? '\n' : ' ');
}
}
return 0;
}

diff+huffman+base64

#include <bits/stdc++.h>
using namespace std;
unsigned char data[] = "nG0AAPkABQIBBRICBRoDBRMEBQ0FB ... ignore";
void base64_decode(unsigned char *in, int n, unsigned char *out, int &m) {
static int R[256];
for (int i = 'A'; i <= 'Z'; i++) R[i] = i - 'A';
for (int i = 'a'; i <= 'z'; i++) R[i] = i - 'a' + 26;
for (int i = '0'; i <= '9'; i++) R[i] = i - '0' + 52;
R['+'] = 62, R['/'] = 63;
m = n/4*3;
if (in[n-1] == '=') m--;
if (in[n-2] == '=') m--;
for (int i = 0, j = 0; i < n; ) {
unsigned int a, val = 0;
for (int k = 3; k >= 0; i++, k--) {
a = in[i] == '=' ? 0 : R[in[i]];
val |= a<<(k*6);
}
if (j < m) out[j++] = (val>>16)&0xff;
if (j < m) out[j++] = (val>>8)&0xff;
if (j < m) out[j++] = (val>>0)&0xff;
}
}
void huffman_decode(unsigned char *in, int n, unsigned char *out, int &m) {
m = 0;
map<int, int> R[64];
int bits_cnt = 0, leaf = 0;
memcpy(&bits_cnt, in, 4), in += 4, n -= 4;
memcpy(&leaf, in, 1), in += 1, n -= 1;
for (int i = 0; i < leaf; i++) {
int tid = 0, bn = 0, bit = 0;
memcpy(&tid, in, 1), in += 1, n -= 1;
memcpy(&bn, in, 1), in += 1, n -= 1;
memcpy(&bit, in, (bn+7)/8), in += (bn+7)/8, n -= (bn+7)/8;
R[bn][bit] = tid;
}
int mask = 0, cnt = 0;
for (int i = 0; i < bits_cnt; i++) {
mask |= ((in[(i>>3)]>>(i&7))&1)<<cnt;
cnt++;
if (R[cnt].count(mask)) {
out[m++] = R[cnt][mask];
cnt = 0, mask = 0;
}
}
}
int main() {
unsigned char d[2][32767] = {};
int n, m;
base64_decode(data, sizeof(data) - 1, d[0], n); // skip the trailing '\0'
huffman_decode(d[0], n, d[1], m);
for (int i = 1; i < m; i++)
d[1][i] = d[1][i] + d[1][i-1];
printf("%d %d\n", 64, 64);
for (int i = 0; i < 64; i++) {
for (int j = 0; j < 64; j++) {
printf("%d %d %d%c", d[1][i*64+j], d[1][i*64+j], d[1][i*64+j], j == 63 ? '\n' : ' ');
}
}
return 0;
}

b451. Image Matching

Problem

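Background

Image matching differs somewhat from string matching. String matching usually requires a substring to match the pattern exactly, whereas image matching is based on similarity. When images are large, complex or noisy, or when a very large number of matches must be checked, more advanced methods first extract features. They then use probabilistic or statistical filtering to narrow down the candidate matches, and only after that compare the images in detail.

Problem Description

Given two grayscale images $A, B$, each pixel $\mathit{pixel}(x, y)$ is stored in 8 bits, so $\mathit{pixel}(x, y) \in [0, 255]$.

For example, take a $3 \times 3$ image $A$ and a $2 \times 2$ image $B$, written as matrices: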

$$A := \begin{bmatrix} a1 & a2 & a3 \\ a4 & a5 & a6 \\ a7 & a8 & a9 \end{bmatrix} ,\; B := \begin{bmatrix} b1 & b2\\ b3 & b4\\ \end{bmatrix}$$

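Assume the top-left coordinate $(1, 1)$ is the position of $a1$, and $(1, 2)$ is the position of $a2$.

  • Aligning the top-left corner of $B$ with position $(1, 1)$ of $A$ gives the difference $\mathit{diff}(A, B) = (a1 - b1)^2 + (a2 - b2)^2 + (a4 - b3)^2 + (a5 - b4)^2$
  • Likewise, aligning it with position $(2, 1)$ gives $\mathit{diff}(A, B) = (a4 - b1)^2 + (a5 - b2)^2 + (a7 - b3)^2 + (a8 - b4)^2$

Every comparison must keep all of $B$ inside $A$. Find an alignment position $(x, y)$ that minimizes $\mathit{diff}(A, B)$.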

Sample Input

4 4 2 2
0 3 3 0
0 3 3 0
0 0 5 5
0 0 5 5
5 5
5 5
3 5 1 2
2 3 0 4 5
0 0 7 0 0
0 0 7 0 0
3 4

Sample Output

3 3
1 1
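A brute-force reference is handy for checking the FFT solution on small inputs. This sketch tries every placement in O(m·n·p·q). Like the author's code, it keeps the first minimum in row-major order; the statement does not say how ties are broken, so that is an assumption. `bestMatch` and the `Img` type are hypothetical.

```cpp
#include <cassert>
#include <climits>
#include <utility>
#include <vector>
using namespace std;
typedef vector<vector<int>> Img;

// Brute-force reference: try every placement of B inside A.
// Returns the 1-indexed (row, col) of the first minimum in row-major order.
pair<int, int> bestMatch(const Img& A, const Img& B) {
    int m = A.size(), n = A[0].size(), p = B.size(), q = B[0].size();
    long long best = LLONG_MAX;
    pair<int, int> pos{1, 1};
    for (int x = 0; x + p <= m; x++)
        for (int y = 0; y + q <= n; y++) {
            long long d = 0;
            for (int u = 0; u < p; u++)
                for (int v = 0; v < q; v++) {
                    long long t = A[x + u][y + v] - B[u][v];
                    d += t * t;
                }
            if (d < best) {
                best = d;
                pos = {x + 1, y + 1};
            }
        }
    return pos;
}
```

On the two sample cases it returns (3, 3) and (1, 1), matching the expected output.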

Solution

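References

ZOJ 1637 - Fast Image Match (the same problem, solved with plain FFT)

Notes

I don't fully understand FFT. From the reference above, it is enough to know that FFT computes a convolution in $O(n \log n)$; after that, just apply a template.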

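Expanding the difference formula gives $\mathit{diff}(A, B) = \sum (a_i - b_i)^2 = \sum a_{i}^{2} - \sum 2 a_i b_i + \sum b_{i}^{2}$. The terms $\sum a_{i}^{2}$ and $\sum b_{i}^{2}$ are independent of each other. The first comes from a 2D prefix sum over $A$. The second is the same for every position, so the code leaves it out of the comparison.

The hard part is the inner product $\sum 2 a_i b_i$. Computed naively for every position it takes $O(H^4)$, while an FFT convolution brings it down to $O(H^2 \log H)$. The drawback of FFT is that it works over floating-point complex numbers and loses precision, so results must be rounded to integers. Even so, it avoids the many modular operations NTT needs, which makes it the fastest option.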
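The inner-product term for every placement at once is a cross-correlation. The author's code flattens both images row by row with a common stride M = max(n, q), into zero-padded arrays of length S (a power of two, at least N·M). It then feeds B, cyclically reversed, to the FFT convolution, so entry i·M + j of the result equals the sum of a(i+u, j+v)·b(u, v).

No row or array wrap-around occurs for valid placements, because j + v < n <= M and i + u < m. The sketch below checks that index mapping with an O(S²) direct cyclic correlation instead of FFT. The function names are hypothetical.

```cpp
#include <cassert>
#include <vector>
using namespace std;

// r[k] = sum over t of a[(k + t) mod S] * b[t]: cyclic cross-correlation.
// Convolving a with b' (b'[0] = b[0], b'[i] = b[S - i]) gives the same values.
vector<long long> cyclicCorrelation(const vector<long long>& a, const vector<long long>& b) {
    size_t S = a.size();
    vector<long long> r(S, 0);
    for (size_t k = 0; k < S; k++)
        for (size_t t = 0; t < S; t++)
            r[k] += a[(k + t) % S] * b[t];
    return r;
}

// Flatten an image into a zero-padded array of length S with row stride M.
vector<long long> flatten(const vector<vector<long long>>& img, int M, size_t S) {
    vector<long long> f(S, 0);
    for (size_t i = 0; i < img.size(); i++)
        for (size_t j = 0; j < img[i].size(); j++)
            f[i * M + j] = img[i][j];
    return f;
}
```

On the first sample (M = 4, S = 16), entry 2·4 + 2 is 100. Combined with 100 from the prefix sum of $a^2$ and 100 from $\sum b^2$, this gives diff = 100 - 200 + 100 = 0 at position (3, 3).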

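The number-theoretic transform (NTT), also called the fast number-theoretic transform (FNT), here follows the Fermat-number-transform approach. Instead of summing complex roots of unity, it relies on the properties of a primitive root modulo a prime. NTT/FNT computes integer inner products with no error, since every step stays in the integers, but the modular reductions make it much slower.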
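Below is a minimal iterative NTT sketch. It swaps in the common prime 998244353 = 119 · 2^23 + 1, with primitive root 3, instead of the author's 50000000001507329.

For this problem, a single 30-bit prime is not enough on its own. A correlation value can reach about 255² · p · q ≈ 1.7 · 10^10 for images up to 512 × 512 (the size of the `sum` array). That is why a much larger modulus, or CRT over two primes, is needed.

```cpp
#include <cassert>
#include <utility>
#include <vector>
using namespace std;
typedef long long ll;
const ll MOD = 998244353, G = 3;   // 998244353 = 119 * 2^23 + 1, primitive root 3

ll power(ll b, ll e) {
    ll r = 1;
    b %= MOD;
    for (; e; e >>= 1, b = b * b % MOD)
        if (e & 1) r = r * b % MOD;
    return r;
}

// In-place iterative NTT; a.size() must be a power of two (at most 2^23).
void ntt(vector<ll>& a, bool invert) {
    int n = a.size();
    for (int i = 1, j = 0; i < n; i++) {          // bit-reversal permutation
        int bit = n >> 1;
        for (; j & bit; bit >>= 1) j ^= bit;
        j ^= bit;
        if (i < j) swap(a[i], a[j]);
    }
    for (int len = 2; len <= n; len <<= 1) {
        ll w = power(G, (MOD - 1) / len);         // primitive len-th root of unity
        if (invert) w = power(w, MOD - 2);
        for (int i = 0; i < n; i += len) {
            ll wn = 1;
            for (int k = 0; k < len / 2; k++, wn = wn * w % MOD) {
                ll u = a[i + k], v = a[i + k + len / 2] * wn % MOD;
                a[i + k] = (u + v) % MOD;
                a[i + k + len / 2] = (u - v + MOD) % MOD;
            }
        }
    }
    if (invert) {
        ll inv = power(n, MOD - 2);
        for (ll& x : a) x = x * inv % MOD;
    }
}

// Exact polynomial product, valid while every true coefficient is below MOD.
vector<ll> multiply(vector<ll> a, vector<ll> b) {
    size_t need = a.size() + b.size() - 1, n = 1;
    while (n < need) n <<= 1;
    a.resize(n);
    b.resize(n);
    ntt(a, false);
    ntt(b, false);
    for (size_t i = 0; i < n; i++) a[i] = a[i] * b[i] % MOD;
    ntt(a, true);
    a.resize(need);
    return a;
}
```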

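To speed things up, CRT can be used to halve the bit width. That only makes each multiplication about twice as fast, while the transform must be run twice, so the overall gain is small. When implementing it, remember that CRT requires choosing two primes $P1, \; P2$, computing the FNT modulo each, and finally recombining the results with CRT.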
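Here is a sketch of the final CRT step. It assumes the two NTT-friendly primes 998244353 and 1004535809, both of the form k · 2^m + 1 with primitive root 3; the author does not name the primes used.

Their product is about 1.0 · 10^18, which is comfortably above the correlation values in this problem and still fits in a signed 64-bit integer. So Garner's formula x = r1 + P1 · ((r2 - r1) · P1^{-1} mod P2) needs no 128-bit arithmetic.

```cpp
#include <cassert>
typedef long long ll;

const ll P1 = 998244353, P2 = 1004535809;   // assumed prime pair

ll power(ll b, ll e, ll m) {
    ll r = 1;
    b %= m;
    for (; e; e >>= 1, b = b * b % m)
        if (e & 1) r = r * b % m;
    return r;
}

// Given r1 = x mod P1 and r2 = x mod P2 with 0 <= x < P1 * P2, recover x.
ll crt(ll r1, ll r2) {
    ll inv = power(P1 % P2, P2 - 2, P2);          // P1^{-1} mod P2 (P2 is prime)
    ll t = (r2 - r1 % P2 + P2) % P2 * inv % P2;   // x = r1 + P1 * t
    return r1 + P1 * t;                           // < P1 * P2, about 1.0e18
}
```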

Other implementation details

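  1. Fast Reverse Bits uses bit operations instead of a lookup table or a loop; a table only pays off with many test cases.
  2. std::complex<double> is replaced by a hand-written struct complex to speed up copying through the method interface.
  3. sin() and cos() are not tabulated. The twiddle factors are accumulated with multiplications and additions instead, which loses a little precision. With fixed test data, a memory pool would be worth considering.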
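Point 3 refers to how the FFT code generates twiddle factors. Instead of calling sin()/cos() for each element, it uses the three-term recurrence $w_k = 2\cos\delta \, w_{k-1} - w_{k-2}$. This holds for $w_k = e^{ik\delta}$ because $e^{ik\delta} + e^{i(k-2)\delta} = 2\cos\delta \, e^{i(k-1)\delta}$. Seeding it with $e^{-i\delta}$ and $e^{-2i\delta}$, as the code does, produces $1, e^{i\delta}, e^{2i\delta}, \dots$

The sketch below, with hypothetical helpers `twiddles` and `maxError`, measures how far the recurrence drifts from directly computed values.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <complex>
#include <vector>
using namespace std;

// First `count` twiddles e^{i k delta}, using the recurrence from the FFT loop.
vector<complex<double>> twiddles(double delta, int count) {
    vector<complex<double>> w;
    double c = cos(delta);
    complex<double> a2(cos(-2 * delta), sin(-2 * delta));  // e^{-2 i delta}
    complex<double> a1(cos(-delta), sin(-delta));          // e^{-i delta}
    for (int k = 0; k < count; k++) {
        complex<double> a0 = 2 * c * a1 - a2;   // next power of e^{i delta}
        w.push_back(a0);
        a2 = a1;
        a1 = a0;
    }
    return w;
}

// Largest distance between the recurrence and polar(1, k * delta).
double maxError(double delta, int count) {
    vector<complex<double>> w = twiddles(delta, count);
    double err = 0;
    for (int k = 0; k < count; k++)
        err = max(err, abs(w[k] - polar(1.0, k * delta)));
    return err;
}
```

For a block of size 2048 (1024 twiddles), the drift stays far below the 0.5 margin that the final rounding to integers can tolerate.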

Compiled with -O1 it runs as fast as with -O3. Mainly thanks to point 2, the running time dropped from 0.3s to 90ms.

FFT

#include <bits/stdc++.h>
using namespace std;
template<typename T> class TOOL_FFT {
public:
typedef unsigned int UINT32;
int NumberOfBitsNeeded(int PowerOfTwo) {
for (int i = 0;; ++i) {
if (PowerOfTwo & (1 << i)) {
return i;
}
}
}
inline UINT32 FastReverseBits(UINT32 a, int NumBits) {
a = ( ( a & 0x55555555U ) << 1 ) | ( ( a & 0xAAAAAAAAU ) >> 1 ) ;
a = ( ( a & 0x33333333U ) << 2 ) | ( ( a & 0xCCCCCCCCU ) >> 2 ) ;
a = ( ( a & 0x0F0F0F0FU ) << 4 ) | ( ( a & 0xF0F0F0F0U ) >> 4 ) ;
a = ( ( a & 0x00FF00FFU ) << 8 ) | ( ( a & 0xFF00FF00U ) >> 8 ) ;
a = ( ( a & 0x0000FFFFU ) << 16 ) | ( ( a & 0xFFFF0000U ) >> 16 ) ;
return a >> (32 - NumBits);
}
void FFT(bool InverseTransform, vector<complex<T> >& In, vector<complex<T> >& Out) {
// simultaneous data copy and bit-reversal ordering into outputs
int NumSamples = In.size();
int NumBits = NumberOfBitsNeeded(NumSamples);
for (int i = 0; i < NumSamples; ++i) {
Out[FastReverseBits(i, NumBits)] = In[i];
}
// the FFT process
T angle_numerator = acos(-1.) * (InverseTransform ? -2 : 2);
for (int BlockEnd = 1, BlockSize = 2; BlockSize <= NumSamples; BlockSize <<= 1) {
T delta_angle = angle_numerator / BlockSize;
T sin1 = sin(-delta_angle);
T cos1 = cos(-delta_angle);
T sin2 = sin(-delta_angle * 2);
T cos2 = cos(-delta_angle * 2);
for (int i = 0; i < NumSamples; i += BlockSize) {
complex<T> a1(cos1, sin1), a2(cos2, sin2), a0;
for (int j = i, n = 0; n < BlockEnd; ++j, ++n) {
a0 = complex<T>(2 * cos1 * a1.real() - a2.real(), 2 * cos1 * a1.imag() - a2.imag());
a2 = a1;
a1 = a0;
complex<T> a = a0 * Out[j + BlockEnd];
Out[j + BlockEnd] = Out[j] - a;
Out[j] += a;
}
}
BlockEnd = BlockSize;
}
// normalize if inverse transform
if (InverseTransform) {
for (int i = 0; i < NumSamples; ++i) {
Out[i] /= NumSamples;
}
}
}
vector<T> convolution(T *a, T *b, int n) {
vector<std::complex<T>> s(n), d1(n), d2(n), y(n);
vector<T> ret(n);
for (int i = 0; i < n; ++i) {
s[i] = complex<T>(a[i], 0);
}
FFT(false, s, d1);
s[0] = complex<T>(b[0], 0);
for (int i = 1; i < n; ++i) {
s[i] = complex<T>(b[n - i], 0);
}
FFT(false, s, d2);
for (int i = 0; i < n; ++i) {
y[i] = d1[i] * d2[i];
}
FFT(true, y, s);
for (int i = 0; i < n; ++i) {
ret[i] = s[i].real();
}
return ret;
}
};
TOOL_FFT<double> tool;
double a[262144], b[262144];
long long sum[512][512];
long long getArea(int lx, int ly, int rx, int ry) {
long long ret = sum[rx][ry];
if(lx-1 >= 0)
ret -= sum[lx-1][ry];
if(ly-1 >= 0)
ret -= sum[rx][ly-1];
if(lx-1 >= 0 && ly-1 >= 0)
ret += sum[lx-1][ly-1];
return ret;
}
inline int readchar() {
const int N = 1048576;
static char buf[N];
static char *p = buf, *end = buf;
if(p == end) {
if((end = buf + fread(buf, 1, N, stdin)) == buf) return EOF;
p = buf;
}
return *p++;
}
inline int ReadInt(int *x) {
static char c, neg;
while((c = readchar()) < '-') {if(c == EOF) return 0;}
neg = (c == '-') ? -1 : 1;
*x = (neg == 1) ? c-'0' : 0;
while((c = readchar()) >= '0')
*x = (*x << 3) + (*x << 1) + c-'0';
*x *= neg;
return 1;
}
int main() {
int m, n, p, q, x, N, M, S;
while (ReadInt(&m)) {
ReadInt(&n), ReadInt(&p), ReadInt(&q);
N = max(m, p), M = max(n, q);
S = 1;
for (; S < N*M; S <<= 1);
memset(a, 0, sizeof(a[0]) * S);
memset(b, 0, sizeof(b[0]) * S);
for (int i = 0; i < m; i++) {
for (int j = 0; j < n; j++) {
ReadInt(&x);
a[i*M+j] = x;
}
}
for (int i = 0; i < p; i++) {
for (int j = 0; j < q; j++) {
ReadInt(&x);
b[i*M+j] = x;
}
}
for (int i = 0; i < m; i++) {
long long s = 0;
for (int j = 0; j < n; j++) {
x = a[i*M+j];
s += x*x;
sum[i][j] = (i > 0 ? sum[i-1][j] : 0) + s;
}
}
vector<double> r = tool.convolution(a, b, S);
int qx = m - p, qy = n - q, bX = 0, bY = 0;
long long diff = LLONG_MAX;
for (int i = 0; i <= qx; i++) {
for (int j = 0; j <= qy; j++) {
long long v = round(getArea(i, j, i+p-1, j+q-1) - 2*r[i*M + j]);
if (v < diff) {
diff = v, bX = i, bY = j;
}
}
}
printf("%d %d\n", bX+1, bY+1);
}
return 0;
}

FFT (optimized)

#include <bits/stdc++.h>
using namespace std;
template<typename T> class TOOL_FFT {
public:
struct complex {
T x, y;
complex(T x = 0, T y = 0):
x(x), y(y) {}
complex operator+(const complex &A) {
return complex(x+A.x,y+A.y);
}
complex operator-(const complex &A) {
return complex(x-A.x,y-A.y);
}
complex operator*(const complex &A) {
return complex(x*A.x-y*A.y,x*A.y+y*A.x);
}
};
typedef unsigned int UINT32;
int NumberOfBitsNeeded(int PowerOfTwo) {
for (int i = 0;; ++i) {
if (PowerOfTwo & (1 << i)) {
return i;
}
}
}
inline UINT32 FastReverseBits(UINT32 a, int NumBits) {
a = ( ( a & 0x55555555U ) << 1 ) | ( ( a & 0xAAAAAAAAU ) >> 1 ) ;
a = ( ( a & 0x33333333U ) << 2 ) | ( ( a & 0xCCCCCCCCU ) >> 2 ) ;
a = ( ( a & 0x0F0F0F0FU ) << 4 ) | ( ( a & 0xF0F0F0F0U ) >> 4 ) ;
a = ( ( a & 0x00FF00FFU ) << 8 ) | ( ( a & 0xFF00FF00U ) >> 8 ) ;
a = ( ( a & 0x0000FFFFU ) << 16 ) | ( ( a & 0xFFFF0000U ) >> 16 ) ;
return a >> (32 - NumBits);
}
void FFT(bool InverseTransform, vector<complex>& In, vector<complex>& Out) {
// simultaneous data copy and bit-reversal ordering into outputs
int NumSamples = In.size();
int NumBits = NumberOfBitsNeeded(NumSamples);
for (int i = 0; i < NumSamples; ++i) {
Out[FastReverseBits(i, NumBits)] = In[i];
}
// the FFT process
T angle_numerator = acos(-1.) * (InverseTransform ? -2 : 2);
for (int BlockEnd = 1, BlockSize = 2; BlockSize <= NumSamples; BlockSize <<= 1) {
T delta_angle = angle_numerator / BlockSize;
T sin1 = sin(-delta_angle);
T cos1 = cos(-delta_angle);
T sin2 = sin(-delta_angle * 2);
T cos2 = cos(-delta_angle * 2);
for (int i = 0; i < NumSamples; i += BlockSize) {
complex a1(cos1, sin1), a2(cos2, sin2), a0, a;
int j, n;
for (j = i, n = 0; n+8 < BlockEnd; ) {
#define UNLOOP {\
a0 = complex(2 * cos1 * a1.x - a2.x, 2 * cos1 * a1.y - a2.y); \
a2 = a1, a1 = a0; \
a = a0 * Out[j + BlockEnd]; \
Out[j + BlockEnd] = Out[j] - a; \
Out[j] = Out[j] + a; \
++j, ++n; }
#define UNLOOP8 {UNLOOP UNLOOP UNLOOP UNLOOP UNLOOP UNLOOP UNLOOP UNLOOP}
UNLOOP8;
}
for (; n < BlockEnd; )
UNLOOP;
}
BlockEnd = BlockSize;
}
// normalize if inverse transform
if (InverseTransform) {
for (int i = 0; i < NumSamples; ++i) {
Out[i] = Out[i].x / NumSamples;
}
}
}
void convolution(T *a, T *b, int n, T *c) {
vector<complex> s(n), d1(n), d2(n), y(n);
for (int i = 0; i < n; ++i)
s[i] = complex(a[i], 0);
FFT(false, s, d1);
s[0] = complex(b[0], 0);
for (int i = 1; i < n; ++i)
s[i] = complex(b[n - i], 0);
FFT(false, s, d2);
for (int i = 0; i < n; ++i)
y[i] = d1[i] * d2[i];
FFT(true, y, s);
for (int i = 0; i < n; ++i)
c[i] = s[i].x;
}
};
TOOL_FFT<double> tool;
double a[262144], b[262144], c[262144];
long long sum[512][512];
inline long long getArea(int lx, int ly, int rx, int ry) {
long long ret = sum[rx][ry];
if(lx) ret -= sum[lx-1][ry];
if(ly) ret -= sum[rx][ly-1];
if(lx && ly) ret += sum[lx-1][ly-1];
return ret;
}
inline int readchar() {
const int N = 1048576;
static char buf[N];
static char *p = buf, *end = buf;
if(p == end) {
if((end = buf + fread(buf, 1, N, stdin)) == buf) return EOF;
p = buf;
}
return *p++;
}
inline int ReadInt(int *x) {
static char c, neg;
while((c = readchar()) < '-') {if(c == EOF) return 0;}
neg = (c == '-') ? -1 : 1;
*x = (neg == 1) ? c-'0' : 0;
while((c = readchar()) >= '0')
*x = (*x << 3) + (*x << 1) + c-'0';
*x *= neg;
return 1;
}
int main() {
int m, n, p, q, x, N, M, S;
while (ReadInt(&m)) {
ReadInt(&n), ReadInt(&p), ReadInt(&q);
N = max(m, p), M = max(n, q);
S = 1;
for (; S < N*M; S <<= 1);
memset(a, 0, sizeof(a[0]) * S);
memset(b, 0, sizeof(b[0]) * S);
for (int i = 0; i < m; i++) {
long long s = 0;
for (int j = 0; j < n; j++) {
ReadInt(&x);
a[i*M+j] = x;
s += x*x;
sum[i][j] = (i > 0 ? sum[i-1][j] : 0) + s;
}
}
for (int i = 0; i < p; i++) {
for (int j = 0; j < q; j++) {
ReadInt(&x);
b[i*M+j] = x;
}
}
tool.convolution(a, b, S, c);
int qx = m - p, qy = n - q, bX = 0, bY = 0;
long long diff = LLONG_MAX;
for (int i = 0; i <= qx; i++) {
for (int j = 0; j <= qy; j++) {
// llround rounds to nearest; (long long)(x + 0.5) misrounds negative values
long long v = llround(getArea(i, j, i+p-1, j+q-1) - 2*c[i*M + j]);
if (v < diff)
diff = v, bX = i, bY = j;
}
}
fprintf(stderr, "%lld\n", diff);
printf("%d %d\n", bX+1, bY+1);
}
return 0;
}
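The FFT solution above looks for the offset $(i, j)$ minimizing the sum of squared differences $\sum (a - b)^2 = \sum a^2 - 2\sum ab + \sum b^2$. Because $\sum b^2$ does not depend on the offset, it only compares getArea(...) - 2*c[...]: the prefix sums supply $\sum a^2$ and the FFT correlation supplies $\sum ab$ for every offset at once. The brute-force $O(mnpq)$ sketch below computes the same score directly, as a hypothetical cross-check for small inputs (the names bestMatch and Mat are illustrative, not part of the original solution). It returns 0-based coordinates, whereas the solution prints them 1-based.

```cpp
#include <climits>
#include <utility>
#include <vector>
using namespace std;

typedef vector<vector<long long>> Mat;

// Brute-force template matching: minimize sum a^2 - 2 * sum a*b over each window,
// which differs from the full SSD only by the constant sum b^2.
// Ties keep the first offset in row-major order, like the loop above.
pair<int, int> bestMatch(const Mat& a, const Mat& b) {
    int m = a.size(), n = a[0].size(), p = b.size(), q = b[0].size();
    long long best = LLONG_MAX;
    pair<int, int> pos(0, 0);
    for (int i = 0; i + p <= m; i++)
        for (int j = 0; j + q <= n; j++) {
            long long sa2 = 0, sab = 0;
            for (int u = 0; u < p; u++)
                for (int v = 0; v < q; v++) {
                    sa2 += a[i + u][j + v] * a[i + u][j + v];
                    sab += a[i + u][j + v] * b[u][v];
                }
            long long score = sa2 - 2 * sab;
            if (score < best) best = score, pos = make_pair(i, j);
        }
    return pos;
}
```

Running both this checker and the FFT version on small random images and comparing the reported offsets is a cheap way to catch rounding problems in the FFT path.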

NTT/FNT

#include <bits/stdc++.h>
using namespace std;
typedef unsigned int UINT32;
typedef long long INT64;
class TOOL_NTT {
public:
#define MAXN 262144
const INT64 P = 50000000001507329LL; // prime, P - 1 = 190734863287 * 2^18, so lengths up to 2^18 = MAXN work
const INT64 G = 3;
INT64 wn[20];
INT64 s[MAXN], d1[MAXN], d2[MAXN], y[MAXN];
TOOL_NTT() {
for (int i = 0; i < 20; i++)
wn[i] = mod_pow(G, (P-1) / (1<<i), P);
}
INT64 mod_mul(INT64 a, INT64 b, INT64 mod) {
long long y = (long long)((double)a*b/mod+0.5); // fast for P < 2^58
long long r = (a*b-y*mod)%mod;
return r < 0 ? r + mod : r;
// INT64 ret = 0;
// for (a = a >= mod ? a%mod : a, b = b >= mod ? b%mod : b; b != 0; b>>=1, a <<= 1, a = a >= mod ? a - mod : a) {
// if (b&1) {
// ret += a;
// if (ret >= mod)
// ret -= mod;
// }
// }
// return ret;
}
INT64 mod_pow(INT64 n, INT64 e, INT64 m) {
INT64 x = 1;
for (n = n >= m ? n%m : n; e; e >>= 1) {
if (e&1)
x = mod_mul(x, n, m);
n = mod_mul(n, n, m);
}
return x;
}
int NumberOfBitsNeeded(int PowerOfTwo) {
for (int i = 0;; ++i) {
if (PowerOfTwo & (1 << i)) {
return i;
}
}
}
inline UINT32 FastReverseBits(UINT32 a, int NumBits) {
a = ( ( a & 0x55555555U ) << 1 ) | ( ( a & 0xAAAAAAAAU ) >> 1 ) ;
a = ( ( a & 0x33333333U ) << 2 ) | ( ( a & 0xCCCCCCCCU ) >> 2 ) ;
a = ( ( a & 0x0F0F0F0FU ) << 4 ) | ( ( a & 0xF0F0F0F0U ) >> 4 ) ;
a = ( ( a & 0x00FF00FFU ) << 8 ) | ( ( a & 0xFF00FF00U ) >> 8 ) ;
a = ( ( a & 0x0000FFFFU ) << 16 ) | ( ( a & 0xFFFF0000U ) >> 16 ) ;
return a >> (32 - NumBits);
}
void NTT(int on, INT64 *In, INT64 *Out, int n) {
int NumBits = NumberOfBitsNeeded(n);
for (int i = 0; i < n; ++i)
Out[FastReverseBits(i, NumBits)] = In[i];
for(int h = 2, id = 1; h <= n; h <<= 1, id++) {
for(int j = 0; j < n; j += h) {
INT64 w = 1, u, t;
int block = h/2, blockEnd = j + h/2;
for(int k = j; k < blockEnd; k++) {
u = Out[k], t = mod_mul(w, Out[k+block], P);
Out[k] = u + t;
Out[k+block] = u - t + P;
if (Out[k] >= P) Out[k] -= P;
if (Out[k+block] >= P) Out[k+block] -= P;
w = mod_mul(w, wn[id], P);
}
}
}
if (on == 1) {
for (int i = 1; i < n/2; i++)
swap(Out[i], Out[n-i]);
INT64 invn = mod_pow(n, P-2, P);
for (int i = 0; i < n; i++)
Out[i] = mod_mul(Out[i], invn, P);
}
}
void convolution(INT64 *a, INT64 *b, int n, INT64 *c) {
NTT(0, a, d1, n);
s[0] = b[0];
for (int i = 1; i < n; ++i)
s[i] = b[n-i];
NTT(0, s, d2, n);
for (int i = 0; i < n; i++)
s[i] = mod_mul(d1[i], d2[i], P);
NTT(1, s, c, n);
}
} tool;
INT64 a[262144], b[262144], c[262144];
long long sum[512][512];
inline long long getArea(int lx, int ly, int rx, int ry) {
long long ret = sum[rx][ry];
if(lx) ret -= sum[lx-1][ry];
if(ly) ret -= sum[rx][ly-1];
if(lx && ly) ret += sum[lx-1][ly-1];
return ret;
}
inline int readchar() {
const int N = 1048576;
static char buf[N];
static char *p = buf, *end = buf;
if(p == end) {
if((end = buf + fread(buf, 1, N, stdin)) == buf) return EOF;
p = buf;
}
return *p++;
}
inline int ReadInt(int *x) {
static char c, neg;
while((c = readchar()) < '-') {if(c == EOF) return 0;}
neg = (c == '-') ? -1 : 1;
*x = (neg == 1) ? c-'0' : 0;
while((c = readchar()) >= '0')
*x = (*x << 3) + (*x << 1) + c-'0';
*x *= neg;
return 1;
}
int main() {
int m, n, p, q, x, N, M, S;
while (ReadInt(&m)) {
ReadInt(&n), ReadInt(&p), ReadInt(&q);
N = max(m, p), M = max(n, q);
S = 1;
for (; S < N*M; S <<= 1);
memset(a, 0, sizeof(a[0]) * S);
memset(b, 0, sizeof(b[0]) * S);
for (int i = 0; i < m; i++) {
long long s = 0;
for (int j = 0; j < n; j++) {
ReadInt(&x);
a[i*M+j] = x;
s += x*x;
sum[i][j] = (i > 0 ? sum[i-1][j] : 0) + s;
}
}
for (int i = 0; i < p; i++) {
for (int j = 0; j < q; j++) {
ReadInt(&x);
b[i*M+j] = x;
}
}
tool.convolution(a, b, S, c);
int qx = m - p, qy = n - q, bX = 0, bY = 0;
long long diff = LLONG_MAX;
for (int i = 0; i <= qx; i++) {
for (int j = 0; j <= qy; j++) {
long long v = getArea(i, j, i+p-1, j+q-1) - 2*c[i*M + j];
if (v < diff)
diff = v, bX = i, bY = j;
}
}
fprintf(stderr, "diff = %lld\n", diff);
printf("%d %d\n", bX+1, bY+1);
}
return 0;
}
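The NTT version above replaces complex roots of unity with powers of G modulo a prime P, so the correlation is computed exactly instead of with floating-point rounding. For comparison, here is a minimal textbook iterative NTT (polynomial multiplication) over the widely used prime $998244353 = 119 \cdot 2^{23} + 1$ with primitive root 3. The modulus and the names ntt and multiply are assumptions of this sketch, not the code above. A single 30-bit modulus is only exact while every true coefficient stays below 998244353, which image correlations easily exceed; that is why the code above uses a much larger P and the next variant combines two primes.

```cpp
#include <cstdint>
#include <utility>
#include <vector>
using namespace std;

// Sketch only: modulus 998244353 with primitive root 3 (not the P used above).
const uint64_t MOD = 998244353, ROOT = 3;

uint64_t power(uint64_t b, uint64_t e) {
    uint64_t r = 1;
    for (b %= MOD; e; e >>= 1, b = b * b % MOD)
        if (e & 1) r = r * b % MOD;
    return r;
}

// In-place transform; a.size() must be a power of two no larger than 2^23.
void ntt(vector<uint64_t>& a, bool invert) {
    size_t n = a.size();
    for (size_t i = 1, j = 0; i < n; i++) {      // bit-reversal permutation
        size_t bit = n >> 1;
        for (; j & bit; bit >>= 1) j ^= bit;
        j ^= bit;
        if (i < j) swap(a[i], a[j]);
    }
    for (size_t len = 2; len <= n; len <<= 1) {
        uint64_t w = power(ROOT, (MOD - 1) / len); // primitive len-th root of unity
        if (invert) w = power(w, MOD - 2);
        for (size_t i = 0; i < n; i += len) {
            uint64_t wk = 1;
            for (size_t k = 0; k < len / 2; k++, wk = wk * w % MOD) {
                uint64_t u = a[i + k], v = a[i + k + len / 2] * wk % MOD;
                a[i + k] = (u + v) % MOD;             // butterfly
                a[i + k + len / 2] = (u + MOD - v) % MOD;
            }
        }
    }
    if (invert) {                                  // divide by n
        uint64_t inv_n = power(n, MOD - 2);
        for (auto& x : a) x = x * inv_n % MOD;
    }
}

// Product of two polynomials with coefficients mod MOD.
vector<uint64_t> multiply(vector<uint64_t> a, vector<uint64_t> b) {
    size_t need = a.size() + b.size() - 1, n = 1;
    while (n < need) n <<= 1;
    a.resize(n), b.resize(n);
    ntt(a, false), ntt(b, false);
    for (size_t i = 0; i < n; i++) a[i] = a[i] * b[i] % MOD;
    ntt(a, true);
    a.resize(need);
    return a;
}
```

For example, multiply({1, 2, 3}, {4, 5}) gives {4, 13, 22, 15}, the coefficients of $(1 + 2x + 3x^2)(4 + 5x)$.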

NTT/FNT with CRT speedup

#include <bits/stdc++.h>
using namespace std;
typedef uint_fast32_t UINT32;
typedef long long INT64;
typedef uint_fast32_t INT32;
class TOOL_NTT {
public:
#define MAXN 262144
// INT64 P = 50000000001507329LL; // prime m = kn+1
// INT64 G = 3;
INT32 P = 3, G = 2;
INT32 wn[20];
INT32 s[MAXN], d1[MAXN], d2[MAXN], c1[MAXN], c2[MAXN];
const INT32 P1 = 998244353; // P1 = 2^23 * 7 * 17 + 1
const INT32 G1 = 3;
const INT32 P2 = 995622913; // P2 = 2^19 *3*3*211 + 1
const INT32 G2 = 5;
const INT64 M1 = 397550359381069386LL;
const INT64 M2 = 596324591238590904LL;
const INT64 MM = 993874950619660289LL; // MM = P1*P2
TOOL_NTT() {
for (int i = 0; i < 20; i++)
wn[i] = mod_pow(G, (P-1) / (1<<i), P);
}
void reset(INT32 p, INT32 g) {
P = p, G = g;
for (int i = 0; i < 20; i++)
wn[i] = mod_pow(G, (P-1) / (1<<i), P);
}
INT64 mod_mul(INT64 a, INT64 b, INT64 mod) {
long long y = (long long)((double)a*b/mod+0.5); // fast for P < 2^58
long long r = (a*b-y*mod)%mod;
return r < 0 ? r + mod : r;
// INT64 ret = 0;
// for (a = a >= mod ? a%mod : a, b = b >= mod ? b%mod : b; b != 0; b>>=1, a <<= 1, a = a >= mod ? a - mod : a) {
// if (b&1) {
// ret += a;
// if (ret >= mod)
// ret -= mod;
// }
// }
// return ret;
}
INT64 mod_pow(INT64 n, INT64 e, INT64 m) {
INT64 x = 1;
for (n = n >= m ? n%m : n; e; e >>= 1) {
if (e&1)
x = mod_mul(x, n, m);
n = mod_mul(n, n, m);
}
return x;
}
int NumberOfBitsNeeded(int PowerOfTwo) {
for (int i = 0;; ++i) {
if (PowerOfTwo & (1 << i)) {
return i;
}
}
}
inline UINT32 FastReverseBits(UINT32 a, int NumBits) {
a = ( ( a & 0x55555555U ) << 1 ) | ( ( a & 0xAAAAAAAAU ) >> 1 ) ;
a = ( ( a & 0x33333333U ) << 2 ) | ( ( a & 0xCCCCCCCCU ) >> 2 ) ;
a = ( ( a & 0x0F0F0F0FU ) << 4 ) | ( ( a & 0xF0F0F0F0U ) >> 4 ) ;
a = ( ( a & 0x00FF00FFU ) << 8 ) | ( ( a & 0xFF00FF00U ) >> 8 ) ;
a = ( ( a & 0x0000FFFFU ) << 16 ) | ( ( a & 0xFFFF0000U ) >> 16 ) ;
return a >> (32 - NumBits);
}
void NTT(int on, INT32 *In, INT32 *Out, int n) {
int NumBits = NumberOfBitsNeeded(n);
for (int i = 0; i < n; ++i)
Out[FastReverseBits(i, NumBits)] = In[i];
for (int h = 2, id = 1; h <= n; h <<= 1, id++) {
for (int j = 0; j < n; j += h) {
INT32 w = 1, u, t;
int block = h/2, blockEnd = j + h/2;
for (int k = j; k < blockEnd; k++) {
u = Out[k], t = (INT64)w*Out[k+block]%P;
Out[k] = (u + t)%P;
Out[k+block] = (u - t + P)%P;
w = (INT64)w * wn[id]%P;
}
}
}
if (on == 1) {
for (int i = 1; i < n/2; i++)
swap(Out[i], Out[n-i]);
INT32 invn = mod_pow(n, P-2, P);
for (int i = 0; i < n; i++)
Out[i] = (INT64)Out[i]*invn%P;
}
}
// combine a = x mod P1 and b = x mod P2 into x mod MM (= P1*P2) via precomputed CRT constants
INT64 crt(INT32 a, INT32 b) {
return (mod_mul(a, M1, MM) + mod_mul(b, M2, MM))%MM;
}
void convolution(INT32 *a, INT32 *b, int n, INT64 *c) {
reset(P1, G1);
NTT(0, a, d1, n);
s[0] = b[0]; for (int i = 1; i < n; ++i) s[i] = b[n-i];
NTT(0, s, d2, n);
for (int i = 0; i < n; i++) s[i] = (INT64)d1[i] * d2[i]%P;
NTT(1, s, c1, n);
reset(P2, G2);
NTT(0, a, d1, n);
s[0] = b[0]; for (int i = 1; i < n; ++i) s[i] = b[n-i];
NTT(0, s, d2, n);
for (int i = 0; i < n; i++) s[i] = (INT64)d1[i] * d2[i]%P;
NTT(1, s, c2, n);
for (int i = 0; i < n; i++)
c[i] = crt(c1[i], c2[i]);
}
} tool;
INT32 a[262144], b[262144];
INT64 c[262144];
long long sum[512][512];
inline long long getArea(int lx, int ly, int rx, int ry) {
long long ret = sum[rx][ry];
if(lx) ret -= sum[lx-1][ry];
if(ly) ret -= sum[rx][ly-1];
if(lx && ly) ret += sum[lx-1][ly-1];
return ret;
}
inline int readchar() {
const int N = 1048576;
static char buf[N];
static char *p = buf, *end = buf;
if(p == end) {
if((end = buf + fread(buf, 1, N, stdin)) == buf) return EOF;
p = buf;
}
return *p++;
}
inline int ReadInt(int *x) {
static char c, neg;
while((c = readchar()) < '-') {if(c == EOF) return 0;}
neg = (c == '-') ? -1 : 1;
*x = (neg == 1) ? c-'0' : 0;
while((c = readchar()) >= '0')
*x = (*x << 3) + (*x << 1) + c-'0';
*x *= neg;
return 1;
}
int main() {
int m, n, p, q, x, N, M, S;
while (ReadInt(&m)) {
ReadInt(&n), ReadInt(&p), ReadInt(&q);
N = max(m, p), M = max(n, q);
S = 1;
for (; S < N*M; S <<= 1);
memset(a, 0, sizeof(a[0]) * S);
memset(b, 0, sizeof(b[0]) * S);
for (int i = 0; i < m; i++) {
long long s = 0;
for (int j = 0; j < n; j++) {
ReadInt(&x);
a[i*M+j] = x;
s += x*x;
sum[i][j] = (i > 0 ? sum[i-1][j] : 0) + s;
}
}
for (int i = 0; i < p; i++) {
for (int j = 0; j < q; j++) {
ReadInt(&x);
b[i*M+j] = x;
}
}
tool.convolution(a, b, S, c);
int qx = m - p, qy = n - q, bX = 0, bY = 0;
long long diff = LLONG_MAX;
for (int i = 0; i <= qx; i++) {
for (int j = 0; j <= qy; j++) {
long long v = getArea(i, j, i+p-1, j+q-1) - 2*c[i*M + j];
if (v < diff)
diff = v, bX = i, bY = j;
}
}
fprintf(stderr, "diff = %lld\n", diff);
printf("%d %d\n", bX+1, bY+1);
}
return 0;
}
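The CRT version runs the NTT twice, modulo P1 = 998244353 and P2 = 995622913, and recombines each coefficient with the precomputed constants M1 and M2 modulo MM = P1 * P2. An equivalent recombination is Garner's formula $x = r_1 + P_1 \cdot \big((r_2 - r_1) P_1^{-1} \bmod P_2\big)$, valid for $0 \le x < P_1 P_2$; it keeps every intermediate product below $2^{60}$, so plain 64-bit arithmetic suffices. The sketch below (invMod and crt2 are hypothetical names, not the author's code) computes $P_1^{-1} \bmod P_2$ with the extended Euclidean algorithm, which only needs the two moduli to be coprime.

```cpp
#include <cstdint>
using namespace std;

const uint64_t P1 = 998244353, P2 = 995622913;

// Inverse of a modulo m by the extended Euclidean algorithm; assumes gcd(a, m) == 1.
int64_t invMod(int64_t a, int64_t m) {
    int64_t oldR = a, r = m, oldS = 1, s = 0;   // invariant: oldS*a = oldR, s*a = r (mod m)
    while (r != 0) {
        int64_t q = oldR / r, t;
        t = oldR - q * r; oldR = r; r = t;
        t = oldS - q * s; oldS = s; s = t;
    }
    return (oldS % m + m) % m;                  // oldR == 1 here
}

// Unique x in [0, P1*P2) with x = r1 (mod P1) and x = r2 (mod P2).
uint64_t crt2(uint64_t r1, uint64_t r2) {
    static const uint64_t inv = invMod(P1 % P2, P2);        // P1^{-1} mod P2
    uint64_t k = (r2 + P2 - r1 % P2) % P2 * inv % P2;       // products stay below 2^60
    return r1 + P1 * k;                                     // at most P1*P2 - 1
}
```

In the convolution above, crt2(c1[i], c2[i]) would play the role of crt(c1[i], c2[i]).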

b325. Split Personality (人格分裂)

Problem

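M is standing at the origin $(0, 0)$ of the plane, surrounded by a great many mirrors. Through a mirror, M can talk to his alter-ego companion. Determine which mirrors let M see the companion.

Each mirror is a line segment, and no two segments share any point. A mirror counts as visible if any small portion of it can be seen.

Note: seeing a mirror via reflection does not count, and no mirror passes through the origin.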

Sample Input

1
5 -5 5 5
4
4 2 5 -2
2 4 6 1
5 5 8 1
3 -4 7 -1
7
-1 2 3 1
2 4 5 -1
-3 -1 1 -2
-1 -4 3 -2
-2 -4 1 -5
-4 1 -1 4
-3 4 -4 3
1
1 1 2 2
2
-1 3 -2 -2
-2 -1 -3 -1

Sample Output

1
1 1 0 1
1 1 1 1 0 1 0
0
1 0

Solution

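Each segment maps to an angular interval $[\theta_{start}, \theta_{end}]$ in polar coordinates. Sort all endpoints by polar angle and sweep a ray from the origin, maintaining the intervals (segments) the ray currently crosses; concretely, we need the segment whose intersection with the ray is nearest to the origin, kept in a balanced tree set<Seg>. Because no two segments intersect, the tree orders segments by their distance along the current ray, and this relative order does not change as the ray rotates, so insertions and deletions remain consistent. If equal-distance comparisons cause problems, multiset<Seg> can be used instead.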
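The sweep relies on one observation: because mirrors never intersect, the nearest mirror along a ray can change only at an endpoint angle. The $O(N^2)$ brute-force sketch below uses that directly. It is a hypothetical checker, not the set-based sweep above, and Mirror, hitDistance and visibleMirrors are illustrative names. It casts one ray strictly inside each gap between consecutive endpoint angles (including the gap that wraps around $2\pi$) and marks the nearest mirror that ray hits.

```cpp
#include <algorithm>
#include <cmath>
#include <vector>
using namespace std;

struct Mirror { double ax, ay, bx, by; };

// Distance t > 0 along direction (dx, dy) from the origin to mirror s, or -1 if missed.
// Solves t*d = a + u*e with e = b - a via cross products.
double hitDistance(double dx, double dy, const Mirror& s) {
    double ex = s.bx - s.ax, ey = s.by - s.ay;
    double den = dx * ey - dy * ex;                 // cross(d, e)
    if (fabs(den) < 1e-12) return -1;               // ray parallel to the mirror
    double t = (s.ax * ey - s.ay * ex) / den;       // cross(a, e) / cross(d, e)
    double u = (s.ax * dy - s.ay * dx) / den;       // cross(a, d) / cross(d, e)
    if (t <= 0 || u < -1e-12 || u > 1 + 1e-12) return -1;
    return t;
}

// vis[k] == 1 iff mirror k is the nearest hit for some probe ray.
vector<int> visibleMirrors(const vector<Mirror>& ms) {
    if (ms.empty()) return vector<int>();
    const double PI = acos(-1.0);
    vector<double> ang;
    for (const Mirror& s : ms) {
        ang.push_back(atan2(s.ay, s.ax));
        ang.push_back(atan2(s.by, s.bx));
    }
    sort(ang.begin(), ang.end());
    vector<double> probes;                          // one ray inside each angular gap
    for (size_t i = 0; i + 1 < ang.size(); i++)
        if (ang[i + 1] - ang[i] > 1e-9) probes.push_back((ang[i] + ang[i + 1]) / 2);
    probes.push_back((ang.back() + ang.front() + 2 * PI) / 2);  // gap that wraps around
    vector<int> vis(ms.size(), 0);
    for (double th : probes) {
        int best = -1;
        double bestT = 1e300;
        for (size_t k = 0; k < ms.size(); k++) {
            double t = hitDistance(cos(th), sin(th), ms[k]);
            if (t > 0 && t < bestT) bestT = t, best = (int)k;
        }
        if (best >= 0) vis[best] = 1;
    }
    return vis;
}
```

A mirror lying along a ray through the origin, such as 1 1 2 2 in the fourth sample, has zero angular width and is never marked, which matches the expected output 0. Comparing this checker against the sweep on small random inputs is a quick way to test the set comparator.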

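Since a single segment may be visible in several separate pieces while most segments are completely hidden, a divide-and-conquer approach is worth considering: solve each half, keep only the segments that are still visible, and merge. My implementation did not perform well, however, because it keeps each surviving segment whole instead of maintaining the polar-angle skyline (only the visible angular portions).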

#include <bits/stdc++.h>
using namespace std;
#define eps 1e-12
struct Pt {
double x, y;
Pt(double a = 0, double b = 0):
x(a), y(b) {}
Pt operator-(const Pt &a) const {
return Pt(x - a.x, y - a.y);
}
Pt operator+(const Pt &a) const {
return Pt(x + a.x, y + a.y);
}
Pt operator*(const double a) const {
return Pt(x * a, y * a);
}
bool operator<(const Pt &a) const {
if (fabs(x - a.x) > eps)
return x < a.x;
if (fabs(y - a.y) > eps)
return y < a.y;
return false;
}
double dist2(Pt a) {
return (x - a.x)*(x - a.x)+(y - a.y)*(y - a.y);
}
};
double dot(Pt a, Pt b) {
return a.x * b.x + a.y * b.y;
}
double cross(Pt o, Pt a, Pt b) {
return (a.x-o.x)*(b.y-o.y)-(a.y-o.y)*(b.x-o.x);
}
double cross2(Pt a, Pt b) {
return a.x * b.y - a.y * b.x;
}
int between(Pt a, Pt b, Pt c) {
return dot(c - a, b - a) >= -eps && dot(c - b, a - b) >= -eps;
}
int onSeg(Pt a, Pt b, Pt c) {
return between(a, b, c) && fabs(cross(a, b, c)) < eps;
}
Pt getIntersect(Pt as, Pt ae, Pt bs, Pt be) {
Pt u = as - bs;
double t = cross2(be - bs, u)/cross2(ae - as, be - bs);
return as + (ae - as) * t;
}
struct Seg {
Pt s, e;
int id;
Seg(Pt a = Pt(), Pt b = Pt(), int c = 0):
s(a), e(b), id(c) {}
};
bool polar_cmp(const Pt& p1, const Pt& p2) {
if (p1.y == 0 && p2.y == 0 && p1.x * p2.x <= 0) return p1.x > p2.x;
if (p1.y == 0 && p1.x >= 0 && p2.y != 0) return true;
if (p2.y == 0 && p2.x >= 0 && p1.y != 0) return false;
if (p1.y * p2.y < 0) return p1.y > p2.y;
double c = cross2(p1, p2);
return c > 0 || (c == 0 && fabs(p1.x) < fabs(p2.x));
}
bool polar_cmp2(pair<Pt, int> x, pair<Pt, int> y) {
return polar_cmp(x.first, y.first);
}
int cmpZero(double x) {
if (fabs(x) < eps) return 0;
return x < 0 ? -1 : 1;
}
struct CMP {
static Pt ray_s, ray_e;
bool operator()(const Seg &x, const Seg &y) const { // const: std::set may require a const-callable comparator
Pt v1 = getIntersect(ray_s, ray_e, x.s, x.e);
Pt v2 = getIntersect(ray_s, ray_e, y.s, y.e);
return cmpZero(ray_s.dist2(v1) - ray_s.dist2(v2)) < 0;
}
static bool ray2seg(Seg x) {
if (cmpZero(cross(ray_s, ray_e, x.s))*cmpZero(cross(ray_s, ray_e, x.e)) < 0) {
return cmpZero(cross(x.s, ray_s, ray_s+ray_e))*cmpZero(cross(x.s, ray_s, x.e)) >= 0 &&
cmpZero(cross(x.e, ray_s, ray_s+ray_e))*cmpZero(cross(x.e, ray_s, x.s)) >= 0;
}
return false;
}
};
Pt CMP::ray_s, CMP::ray_e;
const int MAXN = 32768;
int visual[MAXN];
Seg segs[MAXN];
int main() {
int N, sx, sy, ex, ey;
while (scanf("%d", &N) == 1) {
vector< pair<Pt, int> > A;
set<Seg, CMP> S;
for (int i = 1; i <= N; i++) {
scanf("%d %d %d %d", &sx, &sy, &ex, &ey);
A.push_back(make_pair(Pt(sx, sy), i));
A.push_back(make_pair(Pt(ex, ey), -i));
segs[i] = Seg(Pt(sx, sy), Pt(ex, ey), i);
visual[i] = 0;
}
sort(A.begin(), A.end(), polar_cmp2);
CMP::ray_s = Pt(0, 0), CMP::ray_e = A[0].first;
for (int i = 1; i <= N; i++) {
if (CMP::ray2seg(segs[i]))
S.insert(segs[i]);
}
for (int i = 0; i < A.size(); ) {
CMP::ray_e = A[i].first;
while (i < A.size() && cmpZero(cross(CMP::ray_s, CMP::ray_e, A[i].first)) == 0) {
int clockwise, id = abs(A[i].second);
if (A[i].second > 0)
clockwise = cmpZero(cross(CMP::ray_s, segs[id].s, segs[id].e));
else
clockwise = cmpZero(cross(CMP::ray_s, segs[id].e, segs[id].s));
if (clockwise) {
if (clockwise > 0)
S.insert(segs[id]);
else
S.erase(segs[id]);
}
i++;
}
if (S.size() > 0)
visual[S.begin()->id] = 1;
}
for (int i = 1; i <= N; i++)
printf("%d%c", visual[i], i == N ? '\n' : ' ');
}
return 0;
}

Appendix: divide and conquer (DC)

#include <bits/stdc++.h>
using namespace std;
#define eps 1e-9
struct Pt {
double x, y;
Pt(double a = 0, double b = 0):
x(a), y(b) {}
Pt operator-(const Pt &a) const {
return Pt(x - a.x, y - a.y);
}
Pt operator+(const Pt &a) const {
return Pt(x + a.x, y + a.y);
}
Pt operator*(const double a) const {
return Pt(x * a, y * a);
}
bool operator<(const Pt &a) const {
if (fabs(x - a.x) > eps)
return x < a.x;
if (fabs(y - a.y) > eps)
return y < a.y;
return false;
}
double dist2(Pt a) {
return (x - a.x)*(x - a.x)+(y - a.y)*(y - a.y);
}
};
double dot(Pt a, Pt b) {
return a.x * b.x + a.y * b.y;
}
double cross(Pt o, Pt a, Pt b) {
return (a.x-o.x)*(b.y-o.y)-(a.y-o.y)*(b.x-o.x);
}
double cross2(Pt a, Pt b) {
return a.x * b.y - a.y * b.x;
}
int between(Pt a, Pt b, Pt c) {
return dot(c - a, b - a) >= -eps && dot(c - b, a - b) >= -eps;
}
int onSeg(Pt a, Pt b, Pt c) {
return between(a, b, c) && fabs(cross(a, b, c)) < eps;
}
Pt getIntersect(Pt as, Pt ae, Pt bs, Pt be) {
Pt u = as - bs;
double t = cross2(be - bs, u)/cross2(ae - as, be - bs);
return as + (ae - as) * t;
}
struct Seg {
Pt s, e;
int id;
Seg(Pt a = Pt(), Pt b = Pt(), int c = 0):
s(a), e(b), id(c) {}
};
bool polar_cmp(const Pt& p1, const Pt& p2) {
if (p1.y == 0 && p2.y == 0 && p1.x * p2.x <= 0) return p1.x > p2.x;
if (p1.y == 0 && p1.x >= 0 && p2.y != 0) return true;
if (p2.y == 0 && p2.x >= 0 && p1.y != 0) return false;
if (p1.y * p2.y < 0) return p1.y > p2.y;
double c = cross2(p1, p2);
return c > 0 || (c == 0 && fabs(p1.x) < fabs(p2.x));
}
bool polar_cmp2(pair<Pt, int> x, pair<Pt, int> y) {
return polar_cmp(x.first, y.first);
}
int cmpZero(double x) {
if (fabs(x) < eps) return 0;
return x < 0 ? -1 : 1;
}
struct CMP {
static Pt ray_s, ray_e;
bool operator()(const Seg &x, const Seg &y) const { // const: std::set may require a const-callable comparator
Pt v1 = getIntersect(ray_s, ray_e, x.s, x.e);
Pt v2 = getIntersect(ray_s, ray_e, y.s, y.e);
return cmpZero(ray_s.dist2(v1) - ray_s.dist2(v2)) < 0;
}
static bool ray2seg(Seg x) {
if (cmpZero(cross(ray_s, ray_e, x.s))*cmpZero(cross(ray_s, ray_e, x.e)) < 0) {
return cmpZero(cross(x.s, ray_s, ray_s+ray_e))*cmpZero(cross(x.s, ray_s, x.e)) >= 0 &&
cmpZero(cross(x.e, ray_s, ray_s+ray_e))*cmpZero(cross(x.e, ray_s, x.s)) >= 0;
}
return false;
}
};
Pt CMP::ray_s, CMP::ray_e;
const int MAXN = 32768;
int visual[MAXN];
Seg mirror[MAXN], sm[MAXN];
vector<Seg> computePolar(vector<Seg> segs) {
if (segs.size() == 0)
return vector<Seg>();
vector< pair<Pt, int> > A;
set<Seg, CMP> S;
for (int i = 0; i < segs.size(); i++) {
A.push_back(make_pair(segs[i].s, segs[i].id));
A.push_back(make_pair(segs[i].e, -segs[i].id));
visual[segs[i].id] = 0;
}
sort(A.begin(), A.end(), polar_cmp2);
CMP::ray_s = Pt(0, 0), CMP::ray_e = A[0].first;
for (int i = 0; i < segs.size(); i++) {
if (CMP::ray2seg(segs[i]))
S.insert(segs[i]);
}
for (int i = 0; i < A.size(); ) {
CMP::ray_e = A[i].first;
while (i < A.size() && cmpZero(cross(CMP::ray_s, CMP::ray_e, A[i].first)) == 0) {
int clockwise, id = abs(A[i].second);
if (A[i].second > 0)
clockwise = cmpZero(cross(CMP::ray_s, sm[id].s, sm[id].e));
else
clockwise = cmpZero(cross(CMP::ray_s, sm[id].e, sm[id].s));
if (clockwise) {
if (clockwise > 0)
S.insert(sm[id]);
else
S.erase(sm[id]);
}
i++;
}
if (S.size() > 0)
visual[S.begin()->id] = 1;
}
vector<Seg> ret;
for (int i = 0; i < segs.size(); i++) {
if (visual[segs[i].id])
ret.push_back(segs[i]);
}
return ret;
}
vector<Seg> dfs(int l, int r) {
vector<Seg> L, R;
if (l > r) return L;
if (l == r)
return computePolar(vector<Seg>(mirror+l, mirror+l+1));
int mid = (l+r)/2;
L = dfs(l, mid);
R = dfs(mid+1, r);
L.insert(L.end(), R.begin(), R.end());
return computePolar(L);
}
bool cmp(Seg a, Seg b) {
return polar_cmp(a.s, b.s);
}
int main() {
int N, sx, sy, ex, ey;
while (scanf("%d", &N) == 1) {
for (int i = 1; i <= N; i++) {
scanf("%d %d %d %d", &sx, &sy, &ex, &ey);
Pt s(sx, sy), e(ex, ey);
if (polar_cmp(s, e))
swap(s, e);
sm[i] = mirror[i] = Seg(s, e, i);
}
sort(mirror+1, mirror+N+1, cmp); // mirrors occupy indices 1..N
dfs(1, N);
for (int i = 1; i <= N; i++)
printf("%d%c", visual[i], i == N ? '\n' : ' ');
}
return 0;
}

a739. Road Construction (道路架設)

Problem

Given a rooted tree, answer queries for the distance between two nodes $(u, v)$. The tree has at most $V = 2000000$ nodes.

Sample Input

4
1 3
1 7
2 5
3
1 4
2 3
1 4

Sample Output

8
10

Solution

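Online queries with LCA-RMQ in $O(Q \log V)$ are clearly not feasible here, so Tarjan's offline LCA algorithm is used instead. Tarjan's algorithm, however, relies on a recursive traversal, which easily overflows the stack when $V = 2000000$ (a path-shaped tree would recurse two million levels deep); to avoid this, it is implemented non-recursively.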
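The non-recursive structure can be seen in a compact form below. This is a sketch with hypothetical names (offlineLCA, findRoot) that takes child lists and a query list as input, rather than the author's pointer-based adjacency. It keeps an explicit stack and descends into one unvisited child per step. When a node's children are exhausted, it marks the node finished, answers its queries whose other endpoint is already finished, and links its DSU set to its parent. As in the author's code, the DSU has no union by rank, so the root of each set is always the nearest ancestor that is still on the stack, which is exactly the LCA being asked for.

```cpp
#include <numeric>
#include <utility>
#include <vector>
using namespace std;

// children[u] lists the children of u; returns the LCA of each (u, v) query.
vector<int> offlineLCA(const vector<vector<int>>& children, int root,
                       const vector<pair<int, int>>& queries) {
    int n = children.size();
    vector<vector<pair<int, int>>> qs(n);           // qs[u] = (other endpoint, query id)
    for (int i = 0; i < (int)queries.size(); i++) {
        qs[queries[i].first].push_back(make_pair(queries[i].second, i));
        qs[queries[i].second].push_back(make_pair(queries[i].first, i));
    }
    vector<int> dsu(n), par(n, -1), nextChild(n, 0), done(n, 0), ans(queries.size(), -1);
    iota(dsu.begin(), dsu.end(), 0);
    auto findRoot = [&](int x) {
        while (dsu[x] != x) x = dsu[x] = dsu[dsu[x]];  // path halving, no recursion
        return x;
    };
    vector<int> stk(1, root);
    while (!stk.empty()) {
        int u = stk.back();
        if (nextChild[u] < (int)children[u].size()) {  // descend into the next child
            int v = children[u][nextChild[u]++];
            par[v] = u;
            stk.push_back(v);
            continue;
        }
        stk.pop_back();                                 // all children finished
        done[u] = 1;
        for (size_t k = 0; k < qs[u].size(); k++)
            if (done[qs[u][k].first]) ans[qs[u][k].second] = findRoot(qs[u][k].first);
        if (par[u] != -1) dsu[u] = par[u];              // u is still its own root here
    }
    return ans;
}
```

Because the explicit stack is a heap-allocated vector, a path of hundreds of thousands of nodes only grows that vector instead of the call stack.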

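Another tight spot in this problem is edge storage. I had always used Tarjan's algorithm on undirected trees before, which usually means storing $2E$ edges; this problem requires storing only the $E$ directed (parent-to-child) edges. The constant factor is tight: without optimized input, the running time is on the edge of TLE.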
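One compact way to store exactly $E$ directed edges is a compressed child list (CSR): count the children of each parent, take prefix sums, then fill a single int array of children. This is an alternative layout, not the author's code, which uses a pointer-linked list with one Edge per parent-to-child edge. It assumes each node's parent is known when building, as in this input. ChildCSR and buildChildren are hypothetical names.

```cpp
#include <vector>
using namespace std;

// Rooted tree as a compressed child list: one int per directed parent->child edge.
struct ChildCSR {
    vector<int> start;  // children of u are kids[start[u] .. start[u + 1] - 1]
    vector<int> kids;
};

// parent[i] is the parent of node i, or -1 for the root.
ChildCSR buildChildren(const vector<int>& parent) {
    int n = parent.size();
    ChildCSR g;
    g.start.assign(n + 1, 0);
    for (int i = 0; i < n; i++)
        if (parent[i] >= 0) g.start[parent[i] + 1]++;          // count children per node
    for (int u = 0; u < n; u++) g.start[u + 1] += g.start[u];  // prefix sums
    g.kids.resize(g.start[n]);
    vector<int> pos(g.start.begin(), g.start.end() - 1);       // next free slot per node
    for (int i = 0; i < n; i++)
        if (parent[i] >= 0) g.kids[pos[parent[i]]++] = i;
    return g;
}
```

Iterating the children of u is then a plain index loop from start[u] to start[u + 1], with no per-edge next pointer.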

#include <bits/stdc++.h>
using namespace std;
const int MAXV = 2000005;
const int MAXQ = 100005;
const int MAXE = 2000005;
class LCA {
public:
struct Edge {
int v;
Edge *next;
};
struct QEdge {
int qid, u;
QEdge *next;
};
Edge edge[MAXE], *adj[MAXV], *arc[MAXV];
QEdge qedge[MAXQ<<1], *qadj[MAXV], *qarc[MAXV];
int e, eq, n;
int parent[MAXV], weight[MAXV], visited[MAXV], LCA[MAXQ];
void init(int n) {
e = eq = 0, this->n = n;
for (int i = 0; i < n; i++)
adj[i] = NULL, qadj[i] = NULL;
}
void addDedge(int x, int y) {
edge[e].v = y, edge[e].next = adj[x], adj[x] = &edge[e++];
}
void addQuery(int x, int y, int qid) {
qedge[eq].qid = qid, qedge[eq].u = y, qedge[eq].next = qadj[x], qadj[x] = &qedge[eq++];
qedge[eq].qid = qid, qedge[eq].u = x, qedge[eq].next = qadj[y], qadj[y] = &qedge[eq++];
}
void offline(int root) {
tarjan(root);
}
private:
int findp(int x) {
return parent[x] == x ? x : (parent[x] = findp(parent[x]));
}
int joint(int x, int y) {
x = findp(x), y = findp(y);
if(x == y) return 0;
if(weight[x] > weight[y])
weight[x] += weight[y], parent[y] = x;
else
weight[y] += weight[x], parent[x] = y;
return 1;
}
struct Node {
int u, p, line;
Node(int a = 0, int b = 0, int c = 0):
u(a), p(b), line(c) {}
};
void tarjan(int root) {
for (int i = 0; i < n; i++)
arc[i] = adj[i], qarc[i] = qadj[i], visited[i] = 0;
stack<Node> stk;
Node u;
int x, y;
stk.push(Node(root, -1, 0));
parent[root] = root;
while (!stk.empty()) {
u = stk.top(), stk.pop();
if (u.line == 0) {
if (arc[u.u]) {
y = arc[u.u]->v, arc[u.u] = arc[u.u]->next;
stk.push(u);
parent[y] = y;
stk.push(Node(y, u.u, 0));
} else {
visited[u.u] = 1;
u.line++;
stk.push(u);
}
} else {
if (qarc[u.u]) {
x = qarc[u.u]->qid, y = qarc[u.u]->u, qarc[u.u] = qarc[u.u]->next;
stk.push(u);
if (visited[y])
LCA[x] = findp(y);
} else {
if (u.p != -1)
parent[findp(u.u)] = u.p;
}
}
}
}
} lca;
int A[MAXQ], B[MAXQ], dist[MAXV];
long long C[MAXV];
struct Node {
int u, line;
Node(int a = 0, int b = 0):
u(a), line(b) {}
};
void dfs(int root) {
for (int i = 0; i < lca.n; i++)
lca.arc[i] = lca.adj[i];
stack<Node> stk;
Node u;
int x, y;
dist[0] = 0, C[0] = 0;
stk.push(Node(root, 0));
while (!stk.empty()) {
u = stk.top(), stk.pop();
if (lca.arc[u.u] != NULL) {
y = lca.arc[u.u]->v;
lca.arc[u.u] = lca.arc[u.u]->next;
stk.push(u);
dist[y] = dist[u.u]+1;
C[y] += C[u.u];
stk.push(Node(y, 0));
}
}
}
namespace mLocalStream {
inline int readchar() {
const int N = 1048576;
static char buf[N];
static char *p = buf, *end = buf;
if(p == end) {
if((end = buf + fread(buf, 1, N, stdin)) == buf) return EOF;
p = buf;
}
return *p++;
}
inline int ReadInt(int *x) {
static char c, neg;
while((c = readchar()) < '-') {if(c == EOF) return 0;}
neg = (c == '-') ? -1 : 1;
*x = (neg == 1) ? c-'0' : 0;
while((c = readchar()) >= '0')
*x = (*x << 3) + (*x << 1) + c-'0';
*x *= neg;
return 1;
}
class Print {
public:
static const int N = 1048576;
char buf[N], *p, *end;
Print() {
p = buf, end = buf + N - 1;
}
void printInt(int x, char padding) {
static char stk[16];
int idx = 0;
stk[idx++] = padding;
if (!x)
stk[idx++] = '0';
while (x)
stk[idx++] = x%10 + '0', x /= 10;
while (idx) {
if (p == end) {
*p = '\0';
printf("%s", buf), p = buf;
}
*p = stk[--idx], p++;
}
}
static inline void online_printInt(int x) {
static char ch[16];
static int idx;
idx = 0;
if (x == 0) ch[++idx] = 0;
while (x > 0) ch[++idx] = x % 10, x /= 10;
while (idx)
putchar(ch[idx--]+48);
}
~Print() {
*p = '\0';
printf("%s", buf);
}
} bprint;
}
int main() {
int N, Q, P, c;
mLocalStream::ReadInt(&N);
lca.init(N);
for (int i = 1; i < N; i++) {
mLocalStream::ReadInt(&P);
mLocalStream::ReadInt(&c);
C[i] = c, P--;
lca.addDedge(P, i);
}
mLocalStream::ReadInt(&Q);
for (int i = 0; i < Q; i++) {
mLocalStream::ReadInt(A+i);
mLocalStream::ReadInt(B+i);
A[i]--, B[i]--;
lca.addQuery(A[i], B[i], i);
}
dfs(0);
lca.offline(0);
int lazy = 0;
long long d1, d2;
for (int i = 0; i < Q; i++) {
if (A[i] == B[i] || lca.LCA[i] != A[i])
lazy++;
else {
d1 = dist[A[i]] + dist[B[i]] - 2*dist[lca.LCA[i]];
d2 = C[A[i]] + C[B[i]] - 2*C[lca.LCA[i]];
printf("%lld\n", d1*lazy + d2);
}
}
return 0;
}