-
백준 이항 계수 시리즈 풀이 (이항 계수 2, 이항 계수3, 이항 계수와 쿼리)Programming/PS 2024. 2. 27. 00:53
백준 11051번 이항 계수 2, 11401번 이항 계수 3, 13977번 이항 계수와 쿼리 문제에 대한 C언어 해설입니다.
페르마의 소정리로 모듈러 곱셈 역원을 구하고, 거듭제곱을 최적화하여 큰 수의 이항 계수를 효율적으로 계산하는 방법을 단계적으로 설명합니다.11051 이항 계수 2
백준 11051번 이항 계수 2
자료 구조 말고 다른 유형도 풀어볼까 하다가 마주친 문제문제를 풀기 위해서 이항 계수의 정의에 따라 다음을 계산해야 한다.
${n \choose k} = \frac{n!}{r!(n-r)!}$문제는 입력 결과값이 엄청 크다는 것이다.
n = 1000, k = 500 일때 결과값만 해도 $10^{299}$ 를 넘는다.
C의 unsigned long long 값, $2^{64}- 1$ 을 아득히 초월하는 수치이다.다행히 문제는 결과를 10007로 나눈 값을 구하라고 한다. 근데 이걸 어떻게 이용해야 할까?
숫자 a와 b가 있다고 하자. 둘을 곱한 ab값이 너무 커서 대신 ab를 p로 나눈 나머지를 구해야하는 상황이다.
이는 a를 p로 나눈 나머지와 b를 p로 나눈 나머지를 곱한뒤, 다시 p로 나눈 나머지를 구하면 된다.아래 식을 보면 직관적으로 이해가 가능하다.
$$ a = a_{q}p +a_{r}, \quad b = b_{q}p + b_{r} $$
$$ ab = (a_{q}b_{q}+a_{r}+b_{r})p + a_{r}b_{r}$$문제는 나눗셈이다. 위와 같은 계산이 성립하지 않는다.
다행히 마침 학교에서 수강중이던 선형대수학 수업에서 과제로 증명했던 내용이 도움이 되었다.0 이 아닌 정수 a와 소수 p에 대해 다음이 성립한다.
$$ a^{p-1} \equiv 1 \pmod{p}$$
이를 페르마의 소정리 라고하는데, 이를 통해 다음을 알 수 있다.
$$a \times a^{p-2} \equiv 1 \pmod{p}$$
$$a^{p-2} \equiv a^{-1} \pmod{p}$$따라서 b 나누기 a 를 10007로 나눈 나머지는 $ b \times a^{10007-2} $ 를 10007로 나눈 나머지와 같다.
곱셈은 위에서 설명한 방법으로 계산할 수 있으니, 해결이 되었다.코드
#include <stdio.h> #define PRIME_MODE 10007 int main(void) { int n, k; scanf("%d %d", &n, &k); int numerator = 1; for (int i = n; i > n - k; i--) { numerator *= i; numerator %= PRIME_MODE; } int denominator = 1; for (int i = 1; i <= k; i++) { denominator *= i; denominator %= PRIME_MODE; } // Get inverse in Z_PRIME_MODE // Fermat's little theorem int inverse_denominator = 1; for (int i = 0; i < PRIME_MODE - 2; i++) { inverse_denominator *= denominator; inverse_denominator %= PRIME_MODE; } printf("%d\n", (numerator * inverse_denominator) % PRIME_MODE); }
11401 이항 계수 3
이항 계수 2와 거의 동일하다. 대신 이번엔 나머지가 1000000007이다.
위 문제와 같은 방법으로 역원을 구하려면 곱셈을 1000000005 번 해야한다.아무 이유없이 난이도가 올라간 게 아니었다. 정직하게 이만큼 곱하다간 시간 초과가 나게 된다.
찾아보다가 거듭제곱을 빠르게 하는 다음 테크닉을 알게 되었다.
Fast modular exponential
원리는 다음과 같다.a의 21승을 구한다고 가정해보자,
21을 이진수로 나타내면 10101이 된다.
$$a^{21} = a^{10101_2}$$
$$a^{10101_{2}} = a^{10000_{2}} \times a^{100_{2}} \times a^{1_{2}}$$
$$a^{21} = 1 \cdot a^{16} \times 0 \cdot a^{8}\times 1 \cdot a^{4}\times 0 \cdot a^{2} \times 1 \cdot a^{1}$$a의 21승은 a의 16승, 4승, 1승을 곱한 값이다. 이걸 응용하면 곱셈 횟수와 p로 나누는 연산을 획기적으로 단축할 수 있다.
구현
long mode_pow(long base, long e, long mode) { long res = 1; while (e > 0) { // Multiply res by base if e is odd res = res * (1 + (e & 1) * (base - 1)); res %= mode; e >>= 1; base *= base; base %= mode; } return res; }
e & 1은 지수 e가 홀수면 (1로 끝나면) 1, 아니면 0이 된다.
결과 값에 0이 곱해지는 것을 방지하기 위해 + 1 하나를 따로 빼주었다.이후 각 iteration 마다 지수를 2로 나누고 (오른쪽으로 shift) 밑을 제곱 해준다.
코드
#include <stdio.h> #define PRIME_MODE 1000000007 // Fast exponentiation in Z_PRIME_MODE long mode_pow(long base, long e, long mode); int main(void) { long n, k; scanf("%ld %ld", &n, &k); if (k > n / 2) { k = n - k; } long numerator = 1; for (long i = n; i > n - k; i--) { numerator *= i; numerator %= PRIME_MODE; } long denominator = 1; for (long i = 1; i <= k; i++) { denominator *= i; denominator %= PRIME_MODE; } long inverse_denominator = mode_pow(denominator, PRIME_MODE - 2, PRIME_MODE); printf("%ld\n", (numerator * inverse_denominator) % PRIME_MODE); } long mode_pow(long base, long e, long mode) { long res = 1; while (e > 0) { // Multiply res by base if e is odd res = res * (1 + (e & 1) * (base - 1)); res %= mode; e >>= 1; base *= base; base %= mode; } return res; }
추가적으로 k와 n-k 중 더 작은 값을 k로 두고 계산하는 로직을 추가하였다.
13977 이항 계수와 쿼리
이번엔 입력 횟수가 1개가 아니다.
팩토리얼 계산을 매번 다시 하는건 비효율적이니 이전에 계산한 값이 있으면 활용할 수 있도록 하였다.코드
#include <stdio.h> #define PRIME_MODE 1000000007 #define MAX_NUM 4000000 #define NOT_COMPUTED 0 long factorials[MAX_NUM + 1]; // Fast exponentiation in Z_PRIME_MODE long mode_pow(long base, long e, long mode); int main(void) { factorials[0] = 1; factorials[1] = 1; int num_cases; scanf("%d", &num_cases); for (int i = 0; i < num_cases; i++) { long n, k; scanf("%ld %ld", &n, &k); if (factorials[n] == NOT_COMPUTED) { int curr = 1; while (factorials[curr] != NOT_COMPUTED) { curr++; } for (; curr <= n; curr++) { factorials[curr] = factorials[curr - 1] * curr; factorials[curr] %= PRIME_MODE; } } long numerator = factorials[n]; long denominator = (factorials[k] * factorials[n - k]) % PRIME_MODE; long inverse_denominator = mode_pow(denominator, PRIME_MODE - 2, PRIME_MODE); printf("%ld\n", (numerator * inverse_denominator) % PRIME_MODE); } } long mode_pow(long base, long e, long mode) { long res = 1; while (e > 0) { // Multiply res by base if e is odd res = res * (1 + (e & 1) * (base - 1)); res %= mode; e >>= 1; base *= base; base %= mode; } return res; }
이후 다른 사람들 한 거 보니 그냥 처음부터 미리 죄다 계산하는 편이 효율적인 것 같다.
// Precompute factorials[0] = 1; for (int i = 1; i <= MAX_NUM; i++) { factorials[i] = factorials[i - 1] * i; factorials[i] %= PRIME_MODE; }
끝.