批改娘 20020. Dot Product

contents

  1. 1. 題目描述
    1. 1.1. main.c
  2. 2. 輸入格式
  3. 3. 輸出格式
  4. 4. 範例輸入
  5. 5. 範例輸出
  6. 6. 編譯參數
  7. 7. 參考資料
  8. 8. Solution
    1. 8.1. AVX
    2. 8.2. SSE

題目描述

請嘗試使用 SIMD 技術 AVX/SSE/MMX 來加速以下的純數值計算。

main.c

最低需求加速兩倍快

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
#include <stdio.h>
#include <assert.h>
#include <inttypes.h>
#include <stdint.h>

static inline uint32_t rotate_left(uint32_t x, uint32_t n) {
return (x << n) | (x >> (32-n));
}
static inline uint32_t encrypt(uint32_t m, uint32_t key) {
return (rotate_left(m, key&31) + key)^key;
}

static uint32_t f(int N, int off, uint32_t key1, uint32_t key2) {
uint32_t sum = 0;
for (int i = 0, j = off; i < N; i++, j++)
sum += encrypt(j, key1) * encrypt(j, key2), i++, j++;
return sum;
}
int main() {
int N;
uint32_t key1, key2;
while (scanf("%d %" PRIu32 " %" PRIu32, &N, &key1, &key2) == 3) {
uint32_t sum = f(N, 0, key1, key2);
printf("%" PRIu32 "\n", sum);
}
return 0;
}

輸入格式

有多組測資,每組一行包含三個整數 $N, \; \text{key1}, \; \text{key2}$,表示向量長度 $N$、向量 $\vec{A}$ 由亂數種子 $\text{key1}$ 產生、向量 $\vec{B}$ 由亂數種子 $\text{key2}$ 產生。

  • $1 \le N \le 16777216$

輸出格式

對於每組測資輸出一行整數,為 $\vec{A} \cdot \vec{B}$ 的 unsigned 32-bit integer 結果。

範例輸入

1
2
16777216 1 2
16777216 3 5

範例輸出

1
2
2885681152
2147483648

編譯參數

1
gcc -std=c99 -O3 -march=native main.c -lm

參考資料

Solution

對於數值計算時,SIMD 能充分地加速程序,每一個元素皆經過一連串的數學函數計算,那麼把一連串的數學式拆分,化成最簡的邏輯計算,並且找出常數向量存放到暫存器中。如在影像處理的程序,即使沒有 GPU 幫忙,搭配 SIMD 也是個不錯的選擇,可以加速 2~4 倍之多。

為了凸顯效能差異,題目設計時必須在計算單一元素結果上複雜些,防止大部分的加速效果是來自於減少 branch 操作 (loop unrolling)。

AVX

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
#include <stdio.h>
#include <assert.h>
#include <inttypes.h>
#include <stdint.h>
#include <x86intrin.h>

static inline uint32_t rotate_left(uint32_t x, uint32_t n) {
return (x << n) | (x >> (32-n));
}
static inline uint32_t encrypt(uint32_t m, uint32_t key) {
return (rotate_left(m, key&31) + key)^key;
}

static uint32_t SSE(int N, int off, uint32_t key1, uint32_t key2) {
uint32_t sum = 0;
for (int i = (N>>3)<<3; i < N; i++)
sum += encrypt(i+off, key1) * encrypt(i+off, key2);
__m256i s_i = _mm256_set_epi32(off, off+1, off+2, off+3, off+4, off+5, off+6, off+7);
__m256i s_4 = _mm256_set_epi32(8, 8, 8, 8, 8, 8, 8, 8);
__m256i s_k1 = _mm256_set_epi32(key1, key1, key1, key1, key1, key1, key1, key1);
__m256i s_k2 = _mm256_set_epi32(key2, key2, key2, key2, key2, key2, key2, key2);
uint32_t modk1 = key1&31;
uint32_t modk2 = key2&31;
uint32_t cmodk1 = 32 - modk1;
uint32_t cmodk2 = 32 - modk2;
__m256i s_ret = _mm256_set_epi32(0, 0, 0, 0, 0, 0, 0, 0);
N >>= 3;
for (int it = 0; it < N; it++) {
__m256i r_1 = _mm256_or_si256(_mm256_slli_epi32(s_i, modk1), _mm256_srli_epi32(s_i, cmodk1));
r_1 = _mm256_xor_si256(_mm256_add_epi32(r_1, s_k1), s_k1);
__m256i r_2 = _mm256_or_si256(_mm256_slli_epi32(s_i, modk2), _mm256_srli_epi32(s_i, cmodk2));
r_2 = _mm256_xor_si256(_mm256_add_epi32(r_2, s_k2), s_k2);
__m256i r_m = _mm256_mullo_epi32(r_1, r_2);
s_ret = _mm256_add_epi32(s_ret, r_m);
s_i = _mm256_add_epi32(s_i, s_4);
}
{
static int32_t tmp[8] __attribute__ ((aligned (32)));
_mm256_store_si256((__m256i*) &tmp[0], s_ret);
sum += tmp[0] + tmp[1] + tmp[2] + tmp[3];
sum += tmp[4] + tmp[5] + tmp[6] + tmp[7];
}
return sum;
}
int main() {
int N;
uint32_t key1, key2;
while (scanf("%d %" PRIu32 " %" PRIu32, &N, &key1, &key2) == 3) {
uint32_t sum = SSE(N, 0, key1, key2);
printf("%" PRIu32 "\n", sum);

}
return 0;
}

SSE

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
#include <stdio.h>
#include <assert.h>
#include <inttypes.h>
#include <stdint.h>
#include <x86intrin.h>

static inline uint32_t rotate_left(uint32_t x, uint32_t n) {
return (x << n) | (x >> (32-n));
}
static inline uint32_t encrypt(uint32_t m, uint32_t key) {
return (rotate_left(m, key&31) + key)^key;
}

static uint32_t SSE(int N, int off, uint32_t key1, uint32_t key2) {
uint32_t sum = 0;
for (int i = N/4*4; i < N; i++)
sum += encrypt(i+off, key1) * encrypt(i+off, key2);
__m128i s_i = _mm_set_epi32(off, off+1, off+2, off+3);
__m128i s_4 = _mm_set_epi32(4, 4, 4, 4);
__m128i s_k1 = _mm_set_epi32(key1, key1, key1, key1);
__m128i s_k2 = _mm_set_epi32(key2, key2, key2, key2);
uint32_t modk1 = key1&31;
uint32_t modk2 = key2&31;
uint32_t cmodk1 = 32 - modk1;
uint32_t cmodk2 = 32 - modk2;
__m128i s_ret = _mm_set_epi32(0, 0, 0, 0);
N >>= 2;
for (int it = 0; it < N; it++) {
__m128i r_1 = _mm_or_si128(_mm_slli_epi32(s_i, modk1), _mm_srli_epi32(s_i, cmodk1));
r_1 = _mm_xor_si128(_mm_add_epi32(r_1, s_k1), s_k1);
__m128i r_2 = _mm_or_si128(_mm_slli_epi32(s_i, modk2), _mm_srli_epi32(s_i, cmodk2));
r_2 = _mm_xor_si128(_mm_add_epi32(r_2, s_k2), s_k2);
__m128i r_m = _mm_mullo_epi32(r_1, r_2);
s_ret = _mm_add_epi32(s_ret, r_m);
s_i = _mm_add_epi32(s_i, s_4);
}
{
static int32_t tmp[4] __attribute__ ((aligned (16)));
_mm_store_si128((__m128i*) &tmp[0], s_ret);
sum += tmp[0] + tmp[1] + tmp[2] + tmp[3];
}
return sum;
}
int main() {
int N;
uint32_t key1, key2;
while (scanf("%d %" PRIu32 " %" PRIu32, &N, &key1, &key2) == 3) {
uint32_t sum = SSE(N, 0, key1, key2);
printf("%" PRIu32 "\n", sum);

}
return 0;
}