LGV引理, NTT模板
LGV引理
给出网格图上两个起点$a_1,a_2$和两个终点$b_1,b_2$,每次只能往右或往上走一步,问$a_1\rightarrow b_1$和$a_2\rightarrow b_2$的所有路径方案中,两条路径没有相交点的方案数
用$e(a,b)$表示从a走到b的方案总数,则$ans=e(a_1,b_1)e(a_2,b_2)-e(a_1,b_2)e(a_2,b_1)$
理解:利用容斥,从$e(a_1,b_1)e(a_2,b_2)$中减去那些存在相交的方案。已知$e(a_1,b_2)e(a_2,b_1)$中每种方案的路径一定相交,对于每个交点,两条路径交点要么在交点处交换位置,要么像是“在交点处反弹了”一样没有交换位置(如下图)
对于存在n个交点的路径,根据每个交点是否交换位置,有$2^n$种方案,若根据交点奇偶性来区分,则奇数和偶数的方案是相同的。且$e(a_1,b_1)e(a_2,b_2)$中,交换位置的交点一定为偶数个,$e(a_1,b_2)e(a_2,b_1)$中一定为奇数个,因此$e(a_1,b_2)e(a_2,b_1)$就等于$e(a_1,b_1)e(a_2,b_2)$中存在相交的方案数。
若写成行列式则为
感性理解完两组点的情况之后直接扩展到n个点的情况,已知有n组点$(a_1,\cdots,a_n)$和$(b_1,\cdots,b_n)$,则$a_1\rightarrow b_1,\cdots,a_n\rightarrow b_n$路径不相交的方案为$|A|$,其中$A_{(i,j)}=e(a_i,b_j)$
其实还可以扩展到一些非路径计数的问题上,即$e(a,b)$可以理解为$\sum_\limits{P:a\rightarrow b}\prod\limits_{e\in P}\omega_e$,即所有路径边权乘积之和,路径计数是将所有路径边权置为1后的结果
NTT模板
#include <bits/stdc++.h>
#define debug(x) cerr << #x << " = " << x << ' '
#define debugl(x) cerr << #x << " = " << x << '\n'
#define all(x) (x).begin(), (x).end()
#define ms(x, y) memset(x, y, sizeof(x))
#define rep(x,y,z) for (int x=y;x<=z;x++)
#define per(x,y,z) for (int x=y;x>=z;x--)
#define lowbit(x) ((x) & (-x))
#define x1 xx1
#define y1 yy1
#define x2 xx2
#define y2 yy2
using namespace std;
typedef long long ll;
typedef long long LL;
typedef pair<int, int> pii;
const int N = 3e6 + 5, G = 3, P = (119 << 23) + 1, offset = 1e6 + 5, MOD = 998244353;
int add(int a, int b) { return a + b >= MOD ? a + b % MOD : a + b; }
int sub(int a, int b) { return a - b < 0 ? a - b + MOD : a - b; }
int mul(int a, int b) { return 1ll * a * b % MOD; }
int bin(int x, int p) {
int res = 1;
for (int b = x; p; p >>= 1, b = mul(b, b))
if (p & 1) res = mul(res, b);
return res;
}
int n, A[N], B[N], fac, ans;
int L, R[N], lim;
void NTT(int* a,int f){
for (int i = 0; i < lim; i++)
if (i < R[i]) swap(a[i],a[R[i]]);
for (int i = 1; i < lim; i <<= 1){
int gn = bin(G,(P - 1) / (i << 1));
for (int j = 0; j < lim; j += (i << 1)){
int g = 1;
for (int k = 0; k < i; k++,g = 1ll * g * gn % P){
int x = a[j + k],y = 1ll * g * a[j + k + i] % P;
a[j + k] = (x + y) % P; a[j + k + i] = (x - y + P) % P;
}
}
}
if (f == 1) return;
int nv = bin(lim, P - 2); reverse(a + 1,a + lim);
for (int i = 0; i < lim; i++) a[i] = 1ll * a[i] * nv % P;
}
int main() {
scanf("%d", &n);
ans = 1;
fac = 1;
for (int i = 1; i <= n; i++) fac = mul(fac, i);
fac = bin(fac, MOD - 2);
for (int i = n; i >= 1; i--) {
ans = mul(ans, fac);
fac = mul(fac, i);
}
for (int i = 1; i <= n; i++) {
int x; scanf("%d", &x);
ans = mul(ans, x + 1);
A[x] = 1;
B[offset - x] = 1;
}
for (lim = 1; lim <= 2000100; lim <<= 1) L++;
for (int i = 0; i < lim; i++) R[i] = (R[i >> 1] >> 1) | ((i & 1) << (L - 1));
NTT(A, 1); NTT(B, 1);
for (int i = 0; i < lim; i++) A[i] = mul(A[i], B[i]);
NTT(A, -1);
for (int i = offset + 1; i < lim; i++) {
ans = mul(ans, bin(i - offset, A[i]));
}
printf("%d\n", ans);
return 0;
}