b431. 中國餘數

contents

  1. 1. Problem
    1. 1.1. 背景
    2. 1.2. 問題描述
  2. 2. Sample Output
  3. 3. Solution
    1. 3.1. 番外
    2. 3.2. 一般解
    3. 3.3. 循環解

Problem

背景

中國餘數定理 (Chinese Remainder Theorem,簡稱 CRT) 經常是工程學裡面常用的一種轉換域,很多人不知道當初在大學離散數學中學這個做什麼,但是在不少計算的設計都會運用到 CRT。由於電腦 CPU 架構中的運算單位是 32-bits 或者 64-bits (也許在未來會更長),但值域高達 128-bits 或者 512-bits 以上模擬運算成了麻煩之處。

回顧中國餘數定理 CRT

$$(S): \left\{\begin{matrix} x \equiv a_1 \mod m_1 \\ x \equiv a_2 \mod m_2 \\ \vdots \\ x \equiv a_n \mod m_n \end{matrix}\right.$$
  • $m_1, m_2, \cdots , m_n$ 任兩數互質,意即 $\forall i \neq j, gcd(m_i, m_j) = 1$
  • 對於任意整數 $a_1, a_2, \cdots , a_n$ 方程組 $(S)$ 均有解,意即找得到一個 $x$ 的參數解。

構造法解 CRT

  1. $M = m_1 \times m_2 \times \cdots \times m_n = \prod_{i=1}^{n} m_i$
  2. $M_i = M / m_i$
  3. $t_i = M_i^{-1} \mod m_i<span>$,意即$</span><!-- Has MathJax -->t_i M_i \equiv 1 \mod m_i$
  4. 方程組 $(S)$ 的通解形式為: $x = a_1 t_1 M_1 + a_2 t_2 M_2 + \cdots + a_n t_n M_n + kM = kM + \sum_{i = 1}^{n} a_i t_i M_i, k \in \mathbb{Z}$
  5. 若限定 $0 \le x < M$,則 $x$ 只有一解。

很多人會納悶通解為什麼長那樣,原因很簡單,要滿足方程組每一條式子,勢必對於$a_i t_i M_i$ 要滿足$x \equiv a_i \mod m_i$ ,因此 $a_i t_i M_i \equiv a_i (t_i M_i) \mod m_i \equiv a_i \mod m_i$ 成立,但是$\forall i \neq j$,滿足$a_i t_i M_i \equiv a_i (t_i M_i) \mod m_j \equiv 0 \mod m_j$

問題描述

來個簡單運用,來計算簡單的 RSA 加解密,特化其中的數學運算。

$M \equiv C^d \mod n$ $n = p \times q$,其中 $p, q$ 是兩個不同的質數,已知 $C, d, p, q$,請求出 $M$

## Sample Input ##
1
2
88 7 17 11
11 23 17 11

Sample Output

1
2
11
88

Solution

RSA 可以預先處理

  • $c_p = q \times (q^{-1} \mod p)$
  • $c_q = p \times (p^{-1} \mod q)$

還原的算法則是 $M = Mp \times c_p + Mq \times c_q \mod N$

由於拆分後的 bit length 少一半,乘法速度快 4 倍,快速冪次方快 2 倍 (次方的 bit length 少一半),但是要算 2 次,最後共計快 4 倍。CPU 的乘法想必不會用快速傅立葉 FFT 來達到乘法速度為 $O(n \log n)$

特別小心「 bit length 少一半 」必須在 $gcd(C, p) = 1$ 時才成立,互質機率機率非常高,仍然有不成立的時候,這情況下速度不是加快 400%。請參照一般解。

例如 C = 27522, d = 17132, p = 2, q = 17293,若使用歐拉定理計算 $Mp = C^{d \mod (p-1)} \mod p = 1$ 事實上 $Mp = C^d \mod p = 0$。有一個特性尚未被利用,對於模質數 $p$ 而言,所有數 $x$ 在模 $p$ 下的循環長度 $L | (p-1)$,最後可以套用 $Mp = C^{d \mod (p-1) + (p-1)} \mod p$ 來完成。請參照循環解,如此一來就不必先判斷 $gcd(C, p) = 1$

番外

但是利用 CRT 計算容易受到硬體上攻擊,因為會造成 $p, q$ 在分解過程中獨立出現,當初利用 $N$ 很難被分解的特性來達到資訊安全,但是卻因為加速把 $p, q$ 存放道不同時刻的暫存器中。

其中一種攻擊,計算得到 $M = Mp \times q \times (q^{-1} \mod p) + Mq \times p \times (p^{-1} \mod q) \mod N$ 當擾亂後面的式子 (提供不穩定的計算結果)。得到 $M' = Mp \times q \times (q^{-1} \mod p) + (Mq + \Delta) \times p \times (p^{-1} \mod q) \mod N$

接著 $(M' - M) = (\Delta' \times p) \mod N$,若要求 $p$ 的方法為 $gcd(M' - M, N) = gcd(\Delta' \times p, N) = p$,輾轉相除的概念呼之欲出,原來 $p$ 會被這麼夾出,當得到兩個 $p, q$,RSA 算法就會被破解。

一般解

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
#include <bits/stdc++.h>
using namespace std;
long long mul(long long a, long long b, long long mod) {
long long ret = 0;
for (a %= mod, b %= mod; b != 0; b>>=1, a <<= 1, a = a >= mod ? a - mod : a) {
if (b&1) {
ret += a;
if (ret >= mod) ret -= mod;
}
}
return ret;
}
void exgcd(long long x, long long y, long long &g, long long &a, long long &b) {
if (y == 0)
g = x, a = 1, b = 0;
else
exgcd(y, x%y, g, b, a), b -= (x/y) * a;
}
long long llgcd(long long x, long long y) {
long long t;
while (x%y)
t = x, x = y, y = t%y;
return y;
}
long long inverse(long long x, long long p) {
long long g, b, r;
exgcd(x, p, g, r, b);
if (g < 0) r = -r;
return (r%p + p)%p;
}
long long mpow(long long x, long long y, long long mod) {
long long ret = 1;
while (y) {
if (y&1)
ret = (ret * x)%mod;
y >>= 1, x = (x * x)%mod;
}
return ret % mod;
}
int main() {
long long C, d, p, q;
long long N, M, Cp, Cq, Mp, Mq;
while (scanf("%lld %lld %lld %lld", &C, &d, &p, &q) == 4) {
N = p * q;
Mp = mpow(C%p, llgcd(C, p) == 1 ? d%(p-1) : d, p);
Mq = mpow(C%q, llgcd(C, q) == 1 ? d%(q-1) : d, q);
Cp = q*inverse(q, p)%N;
Cq = p*inverse(p, q)%N;
M = (mul(Mp, Cp, N) + mul(Mq, Cq, N))%N;
printf("%lld\n", M);
}
return 0;
}

循環解

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
#include <bits/stdc++.h>
using namespace std;
typedef unsigned long long UINT64;
UINT64 mul(UINT64 a, UINT64 b, UINT64 mod) {
UINT64 ret = 0;
for (a = a >= mod ? a%mod : a, b = b >= mod ? b%mod : b; b != 0; b>>=1, a <<= 1, a = a >= mod ? a - mod : a) {
if (b&1) {
ret += a;
if (ret >= mod)
ret -= mod;
}
}
return ret;
}
void exgcd(long long x, long long y, long long &g, long long &a, long long &b) {
if (y == 0)
g = x, a = 1, b = 0;
else
exgcd(y, x%y, g, b, a), b -= (x/y) * a;
}
long long inverse(long long x, long long p) {
long long g, b, r;
exgcd(x, p, g, r, b);
if (g < 0) r = -r;
return (r%p + p)%p;
}
long long mpow(long long x, long long y, long long mod) {
long long ret = 1;
while (y) {
if (y&1)
ret = (ret * x)%mod;
y >>= 1, x = (x * x)%mod;
}
return ret % mod;
}
int main() {
long long C, d, p, q;
long long N, M, Cp, Cq, Mp, Mq;
while (scanf("%lld %lld %lld %lld", &C, &d, &p, &q) == 4) {
N = p * q;
Mp = mpow(C%p, d%(p-1) + (p-1), p);
Mq = mpow(C%q, d%(q-1) + (q-1), q);
Cp = mul(q, inverse(q, p), N);
Cq = mul(p, inverse(p, q), N);
M = (mul(Mp, Cp, N) + mul(Mq, Cq, N))%N;
printf("%lld\n", M);
}
return 0;
}