LGV引理, NTT模板

LGV引理

给出网格图上两个起点$a_1,a_2$和两个终点$b_1,b_2$,每次只能往右或往上走一步,问$a_1\rightarrow b_1$和$a_2\rightarrow b_2$的所有路径方案中,两条路径没有相交点的方案数

用$e(a,b)$表示从a走到b的方案总数,则$ans=e(a_1,b_1)e(a_2,b_2)-e(a_1,b_2)e(a_2,b_1)$

理解:利用容斥,从$e(a_1,b_1)e(a_2,b_2)$中减去那些存在相交的方案。已知$e(a_1,b_2)e(a_2,b_1)$中每种方案的路径一定相交,对于每个交点,两条路径交点要么在交点处交换位置,要么像是“在交点处反弹了”一样没有交换位置(如下图)

image-20210815210826601

对于存在n个交点的路径,根据每个交点是否交换位置,有$2^n$种方案,若根据交点奇偶性来区分,则奇数和偶数的方案是相同的。且$e(a_1,b_1)e(a_2,b_2)$中,交换位置的交点一定为偶数个,$e(a_1,b_2)e(a_2,b_1)$中一定为奇数个,因此$e(a_1,b_2)e(a_2,b_1)$就等于$e(a_1,b_1)e(a_2,b_2)$中存在相交的方案数。

若写成行列式则为

感性理解完两组点的情况之后直接扩展到n个点的情况,已知有n组点$(a_1,\cdots,a_n)$和$(b_1,\cdots,b_n)$,则$a_1\rightarrow b_1,\cdots,a_n\rightarrow b_n$路径不相交的方案为$|A|$,其中$A_{(i,j)}=e(a_i,b_j)$

其实还可以扩展到一些非路径计数的问题上,即$e(a,b)$可以理解为$\sum_\limits{P:a\rightarrow b}\prod\limits_{e\in P}\omega_e$,即所有路径边权乘积之和,路径计数是将所有路径边权置为1后的结果

NTT模板

#include <bits/stdc++.h>
#define debug(x) cerr << #x << " = " << x << ' '
#define debugl(x) cerr << #x << " = " << x << '\n'
#define all(x) (x).begin(), (x).end()
#define ms(x, y) memset(x, y, sizeof(x))
#define rep(x,y,z) for (int x=y;x<=z;x++)
#define per(x,y,z) for (int x=y;x>=z;x--)
#define lowbit(x) ((x) & (-x))
#define x1 xx1
#define y1 yy1
#define x2 xx2
#define y2 yy2

using namespace std;
typedef long long ll;
typedef long long LL;
typedef pair<int, int> pii;

const int N = 3e6 + 5, G = 3, P = (119 << 23) + 1, offset = 1e6 + 5, MOD = 998244353;

int add(int a, int b) { return a + b >= MOD ? a + b % MOD : a + b; }
int sub(int a, int b) { return a - b < 0 ? a - b + MOD : a - b; }
int mul(int a, int b) { return 1ll * a * b % MOD; }
int bin(int x, int p) {
    int res = 1;
    for (int b = x; p; p >>= 1, b = mul(b, b))
        if (p & 1) res = mul(res, b);
    return res;
}
int n, A[N], B[N], fac, ans;

int L, R[N], lim;
void NTT(int* a,int f){
    for (int i = 0; i < lim; i++)
        if (i < R[i]) swap(a[i],a[R[i]]);
    for (int i = 1; i < lim; i <<= 1){
        int gn = bin(G,(P - 1) / (i << 1));
        for (int j = 0; j < lim; j += (i << 1)){
            int g = 1;
            for (int k = 0; k < i; k++,g = 1ll * g * gn % P){
                int x = a[j + k],y = 1ll * g * a[j + k + i] % P;
                a[j + k] = (x + y) % P; a[j + k + i] = (x - y + P) % P;
            }
        }
    }
    if (f == 1) return;
    int nv = bin(lim, P - 2); reverse(a + 1,a + lim);
    for (int i = 0; i < lim; i++) a[i] = 1ll * a[i] * nv % P;
}

int main() {
    scanf("%d", &n);
    ans = 1;
    fac = 1;
    for (int i = 1; i <= n; i++) fac = mul(fac, i);
    fac = bin(fac, MOD - 2);
    for (int i = n; i >= 1; i--) {
        ans = mul(ans, fac);
        fac = mul(fac, i);
    }
    for (int i = 1; i <= n; i++) {
        int x; scanf("%d", &x);
        ans = mul(ans, x + 1);
        A[x] = 1;
        B[offset - x] = 1;
    }
    for (lim = 1; lim <= 2000100; lim <<= 1) L++;
    for (int i = 0; i < lim; i++) R[i] = (R[i >> 1] >> 1) | ((i & 1) << (L - 1));
    NTT(A, 1); NTT(B, 1);
    for (int i = 0; i < lim; i++) A[i] = mul(A[i], B[i]);
    NTT(A, -1);
    for (int i = offset + 1; i < lim; i++) {
        ans = mul(ans, bin(i - offset, A[i]));
    }
    printf("%d\n", ans);
    return 0;
}