组合

Comb

我们知道计算组合数的公式:

$$C_n^k = \cfrac{n!}{k!\,(n-k)!}$$

在数很大的情况下需要对 $MOD = 10^9 + 7$ 取模。取模运算只对加法、减法、乘法(+、-、×)保持同余,不能直接用于除法(/),故需要用到乘法逆元。

乘法逆元

$$\cfrac{b}{a} \equiv b \cdot a^{-1} \pmod{m}$$

即在取模运算中,只要找到一个数 $q$,使 $a \cdot q \equiv 1 \pmod{m}$,则有 $\cfrac{b}{a} \equiv b \cdot q \pmod{m}$。计算乘法逆元需要用到费马小定理。

费马小定理:若 $p$ 是质数,且 $a \bmod p \ne 0$,则有

$$a^{p-1} \equiv 1 \pmod{p}$$

由于模数 $m = 10^9 + 7$ 是质数,把 $m$ 代入 $p$ 得:

$$
\begin{align*}
a^{m-1} &\equiv 1 \pmod{m} \\
\Leftrightarrow a^{m-1} a^{-1} &\equiv a^{-1} \pmod{m} \\
\Leftrightarrow a^{m-2} a a^{-1} &\equiv a^{-1} \pmod{m} \\
\Leftrightarrow a^{-1} &\equiv a^{m-2} \pmod{m}
\end{align*}
$$

故只需用快速幂计算 $a^{m-2} \bmod m$ 即可(前提是 $a \bmod m \ne 0$;当 $n < m$ 时,$n!$ 满足这一条件)。

TypeScript
export function useComb(size: number) {
    /**
     * Builds a binomial-coefficient function backed by factorial and
     * inverse-factorial tables precomputed for indices 0..size.
     *
     * @param size largest n the returned function supports
     * @returns comb(n, k) = C(n, k) mod 1e9+7; returns 0 when k < 0 or k > n
     */
    const MOD = 1e9 + 7
    const $ = BigInt
    const BIG_MOD = $(MOD)

    // Products of two residues reach ~1e18, far beyond Number.MAX_SAFE_INTEGER,
    // so every modular multiplication goes through BigInt to stay exact.
    const mulMod = (a: number, b: number): number =>
        Number(($(a) * $(b)) % BIG_MOD)

    const fac: number[] = new Array(size + 1).fill(0) // fac[i] = i! mod MOD
    const inv: number[] = new Array(size + 1).fill(0) // inv[i] = (i!)^-1 mod MOD
    fac[0] = inv[0] = 1

    /** a^n mod MOD by binary exponentiation (n is a non-negative integer). */
    function fastPow(a: number, n: number): number {
        let b = $(a) % BIG_MOD
        let res = 1n
        while (n) {
            if (n % 2 === 1) res = (res * b) % BIG_MOD
            b = (b * b) % BIG_MOD
            n = Math.trunc(n / 2)
        }
        return Number(res)
    }

    for (let i = 1; i <= size; i++) {
        fac[i] = mulMod(fac[i - 1], i)
    }

    if (size >= 1) {
        // One Fermat inversion of size!, then walk down using
        // (i-1)!^-1 = i!^-1 * i, instead of one fastPow per index.
        inv[size] = fastPow(fac[size], MOD - 2)
        for (let i = size; i > 1; i--) {
            inv[i - 1] = mulMod(inv[i], i)
        }
    }

    function comb(n: number, k: number): number {
        // There are no ways to choose k items when k is out of [0, n].
        if (k < 0 || k > n) return 0
        return mulMod(mulMod(fac[n], inv[k]), inv[n - k])
    }

    return comb
}
JavaScript
/**
 * Builds a binomial-coefficient function backed by factorial and
 * inverse-factorial tables precomputed for indices 0..size.
 *
 * @param {number} size largest n the returned function supports
 * @returns {(n: number, k: number) => number} comb(n, k) = C(n, k) mod 1e9+7;
 *     returns 0 when k < 0 or k > n
 */
export function useComb(size) {
    const MOD = 1e9 + 7
    const $ = BigInt
    const BIG_MOD = $(MOD)

    // Products of two residues reach ~1e18, far beyond Number.MAX_SAFE_INTEGER,
    // so every modular multiplication goes through BigInt to stay exact.
    const mulMod = (a, b) => Number(($(a) * $(b)) % BIG_MOD)

    const fac = new Array(size + 1).fill(0) // fac[i] = i! mod MOD
    const inv = new Array(size + 1).fill(0) // inv[i] = (i!)^-1 mod MOD
    fac[0] = inv[0] = 1

    /** a^n mod MOD by binary exponentiation (n is a non-negative integer). */
    function fastPow(a, n) {
        let b = $(a) % BIG_MOD
        let res = 1n
        while (n) {
            if (n % 2 === 1) res = (res * b) % BIG_MOD
            b = (b * b) % BIG_MOD
            n = Math.trunc(n / 2)
        }
        return Number(res)
    }

    for (let i = 1; i <= size; i++) {
        fac[i] = mulMod(fac[i - 1], i)
    }

    if (size >= 1) {
        // One Fermat inversion of size!, then walk down using
        // (i-1)!^-1 = i!^-1 * i, instead of one fastPow per index.
        inv[size] = fastPow(fac[size], MOD - 2)
        for (let i = size; i > 1; i--) {
            inv[i - 1] = mulMod(inv[i], i)
        }
    }

    function comb(n, k) {
        // There are no ways to choose k items when k is out of [0, n].
        if (k < 0 || k > n) return 0
        return mulMod(mulMod(fac[n], inv[k]), inv[n - k])
    }

    return comb
}
Rust
/// Binomial coefficients C(n, k) modulo 1_000_000_007, backed by factorial
/// and inverse-factorial tables that grow on demand.
pub struct Comb {
    /// `fac[i]` = i! mod `modulo`.
    fac: Vec<i64>,
    /// `inv[i]` = (i!)^-1 mod `modulo`.
    inv: Vec<i64>,
    modulo: i64,
    /// Largest n currently covered by the tables.
    capacity: usize,
}

impl Comb {
    /// Creates a table with every n in `0..=capacity` precomputed.
    pub fn with_capacity(capacity: usize) -> Self {
        let mut comb = Self::new();
        comb.grow_to(capacity);
        comb
    }

    /// Creates an empty table (only 0! is known); it grows lazily in `comb`.
    pub fn new() -> Self {
        Self {
            fac: vec![1],
            inv: vec![1],
            modulo: 1_000_000_007,
            capacity: 0,
        }
    }

    /// Computes `a^n mod m` by binary exponentiation (returns `1 % m` for n <= 0).
    fn fast_pow(a: i64, mut n: i64, m: i64) -> i64 {
        let mut base = a % m;
        let mut res = 1 % m;
        while n > 0 {
            if n & 1 == 1 {
                res = res * base % m;
            }
            base = base * base % m;
            n >>= 1;
        }
        res
    }

    /// Extends `fac`/`inv` so they cover every index up to `n`; no-op if already covered.
    fn grow_to(&mut self, n: usize) {
        if n <= self.capacity {
            return;
        }
        let m = self.modulo;
        let start = self.capacity + 1;
        for i in start..=n {
            let next = self.fac[i - 1] * i as i64 % m;
            self.fac.push(next);
        }
        // One Fermat inversion of n!, then walk down using (i-1)!^-1 = i!^-1 * i
        // instead of calling fast_pow for every new index.
        self.inv.resize(n + 1, 0);
        self.inv[n] = Self::fast_pow(self.fac[n], m - 2, m);
        for i in (start + 1..=n).rev() {
            self.inv[i - 1] = self.inv[i] * i as i64 % m;
        }
        self.capacity = n;
    }

    /// Returns C(n, k) mod 1_000_000_007, or 0 when `k > n`.
    ///
    /// Extends the internal tables first if `n` exceeds the current capacity.
    pub fn comb(&mut self, n: usize, k: usize) -> i64 {
        if k > n {
            return 0;
        }
        self.grow_to(n);
        self.fac[n] * self.inv[k] % self.modulo * self.inv[n - k] % self.modulo
    }
}
Java

/**
 * Binomial coefficients C(n, k) modulo 1e9+7, using factorials and inverse
 * factorials precomputed for indices 0..size.
 */
class Comb {

    long[] fac; // fac[i] = i! mod MOD
    long[] inv; // inv[i] = (i!)^-1 mod MOD
    int MOD = (int) 1e9 + 7;
    int size;

    /**
     * @param size largest n (and k) supported by {@link #comb(int, int)}
     */
    public Comb(int size) {
        this.size = size;
        fac = new long[size + 1];
        inv = new long[size + 1];
        fac[0] = 1;
        for (int i = 1; i <= size; i++) {
            fac[i] = fac[i - 1] * i % MOD;
        }
        // One Fermat inversion of size! via fast power, then walk down using
        // (i-1)!^-1 = i!^-1 * i instead of one fastPow per index.
        inv[size] = fastPow(fac[size], MOD - 2);
        for (int i = size; i > 0; i--) {
            inv[i - 1] = inv[i] * i % MOD;
        }
    }

    /**
     * Returns C(n, k) mod 1e9+7, or 0 when k < 0 or k > n.
     *
     * @throws Error if n or k exceeds the precomputed size
     */
    public long comb(int n, int k) {
        if (n > size || k > size) {
            throw new Error("超出范围");
        }
        // There are no ways to choose k items when k is out of [0, n].
        if (k < 0 || k > n) {
            return 0;
        }
        return fac[n] * inv[k] % MOD * inv[n - k] % MOD;
    }

    /** Computes a^b mod MOD by binary exponentiation. */
    private long fastPow(long a, long b) {
        long base = a % MOD, res = 1;
        while (b != 0) {
            if ((b & 1) == 1) {
                res = res * base % MOD;
            }
            base = base * base % MOD;
            b >>>= 1;
        }
        return res % MOD;
    }
}
Python
class Comb:
    """Binomial coefficients C(n, k) modulo 10**9 + 7.

    Factorials and inverse factorials are precomputed for indices 0..size.
    """

    def __init__(self, size: int):
        self.MOD = 10**9 + 7
        # fac[i] = i! mod MOD; inv[i] = (i!)^-1 mod MOD
        self.fac = [1] * (size + 1)
        self.inv = [1] * (size + 1)
        self._init(size)

    def _fast_pow(self, a: int, n: int) -> int:
        """Return a**n mod MOD (1 when n <= 0)."""
        if n <= 0:
            # Built-in pow would compute a modular inverse for negative n;
            # keep the plain "empty product" result instead.
            return 1
        return pow(a, n, self.MOD)

    def _init(self, size: int) -> None:
        """Fill fac[1..size] and inv[1..size]."""
        mod = self.MOD
        for i in range(1, size + 1):
            self.fac[i] = self.fac[i - 1] * i % mod
        if size < 1:
            return
        # One Fermat inversion of size!, then walk down using
        # (i-1)!^-1 = i!^-1 * i instead of one exponentiation per index.
        self.inv[size] = self._fast_pow(self.fac[size], mod - 2)
        for i in range(size, 1, -1):
            self.inv[i - 1] = self.inv[i] * i % mod

    def comb(self, n: int, k: int) -> int:
        """Return C(n, k) mod MOD, or 0 when k < 0 or k > n.

        Raises IndexError when 0 <= k <= n and n exceeds the precomputed size.
        """
        if k < 0 or k > n:
            return 0
        return (self.fac[n] * self.inv[k] % self.MOD) * self.inv[n - k] % self.MOD