点分治(树分治)专题
简介
如果处理“所有经过某一个顶点的链对答案的贡献”的时间复杂度为$O(n)$或者$O(nlogn)$,那么运用点分治的思想可以把问题规模降为$O(nlogn)$或$O(nlog^2n)$,而非暴力枚举顶点计算答案的$O(n^2)$。
所以说,点分治是一种在树上统计合法链个数的思想。显而易见的,对于当前顶点$x$,任意一条链要么经过$x$,要么不经过$x$。于是我们只需要计算那些经过$x$的链,剩下不经过$x$的链则统统放在子树中递归计算,此为分治。
那么,怎么分治?按照常规的递推思路来考虑,处理完当前顶点后,递归进入所有与顶点相邻的顶点并分别处理它们?不行,因为很容易可以找到一条退化成链的树(见下图),每一层递归只能使问题规模减少1(去掉一个顶点),却要递归到第n层,显然,时间复杂度是$O(n^2)$的。
比起点分治,这似乎更像树形dp,但树形dp的前提是每一个点的状态都可以借助其子结点$O(1)$处理(在例题中可以看到一个既能用点分治,又能用树形dp来解决的例子)。但如果每一条链的状态都不能够合并,无法用dp来降低复杂度呢?实际上,点分治可以说就是为了解决一些无法保存状态的树形dp而发明的一种“优雅的暴力”。
点分治维护其复杂度的关键在于“找重心”。依然从最难处理的链来考虑问题,如果每次递归进入一颗子树时,先找子树的重心,再以重心来分割子树,可以发现每一次分割都使问题的规模减少为原来的一半,那么经过$O(logn)$次分割,问题规模一定被减少到1,也就是每个顶点都已被考虑到。如果用递归树来刻画这一过程(见下图),可以发现树高为$log(n)$,每一层的时间之和都为$T(n)$,于是,我们成功将$O(n^2)$的复杂度,通过找重心降低到$O(nlogn)$。
分治
点分治的题目中,“分治”部分几乎是一成不变的,无非是找到重心后,以重心为根重新计算子树大小,并利用计算结果进一步去找子树的重心。简而言之就是两个函数Getroot和Getsz而已。
int Getroot(int x, int f) {
int sum = 0, mx = 0, tmp;
for (auto T : G[x]) {
int to = T.to;
if (vis[to] || to == f) continue;
tmp = Getroot(to, x);
sum += tmp;
mx = max(mx, tmp);
}
sum++;
mx = max(mx, tot - sum);
if (mx < MN) {
MN = mx;
rt = x;
}
return sum;
}
void Getsz(int x, int f) {
sz[x] = 0;
for (auto T : G[x]) {
int to = T.to;
if (vis[to] || to == f) continue;
Getsz(to, x);
sz[x] += sz[to];
}
sz[x]++;
}
计算贡献
但这只是第一步。点分治中最重要,最灵活,也最需要投入思考的点,其实是一开始提到的如何计算“所有经过某一个顶点的链对答案的贡献”。
以这道最基本的模板题来说。相对而言比较经典的计算过程是这样的。
1、一次性收集所有链信息,存放在一个vector或数组中。
void Getdis(int x, int len, int f) {
D.push_back(len);
for (auto T : G[x]) {
int to = T.to, v = T.v;
if (vis[to] || to == f) continue;
Getdis(to, len + v, x);
}
}
2、利用数据结构(这里用的是桶)计算链对答案的贡献
void calc(int x, int len, int type) {
D.clear();
Getdis(x, len, -1);
for (int y : D) {
if (y > M) continue;
for (int i = 1; i <= m; i++) {
if (q[i] - y < 0) continue;
cnt[i] += type * bucket[q[i] - y];
}
bucket[y]++;
}
for (int y : D) {
if (y > M) continue;
bucket[y]--;
}
}
3、先计算子树内所有链的贡献(cal(x, 0, 1)),再利用容斥原理,去除不合法的链(calc(to, v, -1))
void solve(int x) {
vis[x] = 1;
calc(x, 0, 1);
for (auto T : G[x]) {
int to = T.to, v = T.v;
if (vis[to]) continue;
calc(to, v, -1);
tot = sz[to], MN = INF;
Getroot(to, -1);
Getsz(rt, -1);
solve(rt);
}
}
关于第三点或许需要再举个例子说明一下。
若当前以A为根收集链信息,一共可以得到五条链:
A
A->B
A->B->D
A->B->E
A->C
而用于计算答案的链,实际上是从收集到的链中任选两条,组合成一条新链再进行计算的。比如选择A->B和A->C,实际上是B->A->C这条链,选择A和A->B实际上就是A->B这条链,这两条都是合法的。但如果选择的是A->B->D和A->B->E时,选择的其实是D->B->A->B->E,但D到E的简单路径是D->B->E,也就是这条链不存在,这是不合法的。实际上,同时经过A和B的两条链组合之后总是不合法的,而cal(to, v, -1)的意义就在于选出这样的不合法的链,并把它们对答案的贡献消除。
完整的代码贴在这里
#include <bits/stdc++.h>
#define debug(x) cerr << #x << " : " << x << endl
using namespace std;
typedef long long LL;
const int N = 1e4 + 5, M = 1e7 + 5, P = 1e2 + 5, INF = 0x3f3f3f3f;
struct Edge {
int to, v;
};
int n, m, k, q[P], cnt[P], bucket[M];
vector<int> D;
/**variables of tree divide*/
int tot, sz[N], MN, rt;
bool vis[N];
vector<Edge> G[N];
int Getroot(int x, int f) {
int sum = 0, mx = 0, tmp;
for (auto T : G[x]) {
int to = T.to;
if (vis[to] || to == f) continue;
tmp = Getroot(to, x);
sum += tmp;
mx = max(mx, tmp);
}
sum++;
mx = max(mx, tot - sum);
if (mx < MN) {
MN = mx;
rt = x;
}
return sum;
}
void Getsz(int x, int f) {
sz[x] = 0;
for (auto T : G[x]) {
int to = T.to;
if (vis[to] || to == f) continue;
Getsz(to, x);
sz[x] += sz[to];
}
sz[x]++;
}
void Getdis(int x, int len, int f) {
D.push_back(len);
for (auto T : G[x]) {
int to = T.to, v = T.v;
if (vis[to] || to == f) continue;
Getdis(to, len + v, x);
}
}
void calc(int x, int len, int type) {
D.clear();
Getdis(x, len, -1);
for (int y : D) {
if (y > M) continue;
for (int i = 1; i <= m; i++) {
if (q[i] - y < 0) continue;
cnt[i] += type * bucket[q[i] - y];
}
bucket[y]++;
}
for (int y : D) {
if (y > M) continue;
bucket[y]--;
}
}
void solve(int x) {
vis[x] = 1;
calc(x, 0, 1);
for (auto T : G[x]) {
int to = T.to, v = T.v;
if (vis[to]) continue;
calc(to, v, -1);
tot = sz[to], MN = INF;
Getroot(to, -1);
Getsz(rt, -1);
solve(rt);
}
}
int main() {
cin >> n >> m;
for (int i = 1, u, v, w; i <= n - 1; i++) {
scanf("%d %d %d", &u, &v, &w);
G[u].push_back({v, w});
G[v].push_back({u, w});
}
for (int i = 1; i <= m; i++) scanf("%d", &q[i]);
tot = n, MN = INF;
Getroot(1, -1);
Getsz(rt, -1);
solve(rt);
for (int i = 1; i <= m; i++) {
printf("%s\n", cnt[i] ? "AYE" : "NAY");
}
return 0;
}
至此,一个完整的点分治就结束了。
但这并不是唯一的计算贡献的方法。同样以这道题目来说,我们其实不需要一次性收集所有的链,而是可以先收集某个子树中的链,根据桶的信息修改答案,把收集到的链存入桶中,再去下一个子树收集链信息。如此一来就可以避免计算不合法的链,也就不需要利用容斥来消除它们。
void calc(int x) {
judge[0] = 1;
for (auto T : G[x]) {
int to = T.to;
if (vis[to]) continue;
D.clear();
Getdis(to, T.v, x);
for (int i : D) {
for (int j = 1; j <= m; j++)
if (query[j] >= i) ans[j] |= judge[query[j] - i];
}
for (int i : D) {
if (i < 10000010 && !judge[i]) {
C.push_back(i);
judge[i] = 1;
}
}
}
for (int i : C) judge[i] = 0;
C.clear();
}
我个人认为点分治只能说是一种思想,而不能称为一种算法,就在于每道题计算贡献的方法都不一样。因此,考虑是否要运用点分治解决问题,最关键的就是能否找到一个计算贡献的方法。许多计算方法都要套一个树状数组或者线段树之类的结构,这也都会对最终的时间复杂度产生影响。
例题
最简单的点分治,$O(n)$收集信息,$O(1)$修改答案即可。
#include <bits/stdc++.h>
#define debug(x) cerr << #x << " : " << x << endl
using namespace std;
typedef long long LL;
const int N = 2e5 + 5, INF = 0x3f3f3f3f;
struct Edge {
int to, v;
};
vector<Edge> G[N];
bool vis[N];
int n, tot, sz[N], MN, rt, a[N];
LL ans[3], cnt[3];
int add(int x, int y) {return x + y >= 3 ? x + y - 3 : x + y;}
int Getroot(int x, int f) {
int sum = 0, mx = 0, tmp;
for (auto T : G[x]) {
int to = T.to;
if (vis[to] || to == f) continue;
tmp = Getroot(to, x);
sum += tmp;
mx = max(mx, tmp);
}
sum++;
mx = max(mx, tot - sum);
if (mx < MN) {
MN = mx;
rt = x;
}
return sum;
}
void Getsz(int x, int f) {
sz[x] = 0;
for (auto T : G[x]) {
int to = T.to;
if (vis[to] || to == f) continue;
Getsz(to, x);
sz[x] += sz[to];
}
sz[x]++;
}
void Getdis(int x, int len, int f) {
cnt[len]++;
for (auto T : G[x]) {
int to = T.to, v = T.v;
if (vis[to] || to == f) continue;
Getdis(to, add(len, v), x);
}
}
void calc(int x, int len, int type) {
memset(cnt, 0, sizeof(cnt));
Getdis(x, len, -1);
for (int i = 0; i <= 2; i++) {
for (int j = 0; j <=2; j++) {
ans[add(i, j)] += cnt[i] * cnt[j] * type;
}
}
}
void solve(int x) {
vis[x] = 1;
calc(x, 0, 1);
for (auto T : G[x]) {
int to = T.to, v = T.v;
if (vis[to]) continue;
calc(to, v, -1);
tot = sz[to], MN = INF;
Getroot(to, -1);
Getsz(rt, -1);
solve(rt);
}
}
int main() {
cin >> n;
for (int i = 1, u, v, w; i <= n - 1; i++) {
scanf("%d %d %d", &u, &v, &w);
w %= 3;
G[u].push_back({v, w});
G[v].push_back({u, w});
}
tot = n, MN = INF;
Getroot(1, -1);
Getsz(rt, -1);
solve(rt);
LL gcd = __gcd(ans[0], ans[0] + ans[1] + ans[2]);
printf("%lld/%lld", ans[0] / gcd, (ans[0] + ans[1] + ans[2]) / gcd);
return 0;
}
可能因为要处理的信息太简单了,用树形dp一样可以过这道题,复杂度还更低。
#include <bits/stdc++.h>
using namespace std;
typedef long long LL;
const int N = 2e4 + 5;
int n;
LL dp[N][3], sum[N][3], ans[3];
struct Edge {
int to, v;
};
vector<Edge> G[N];
int add(int x, int y) {return x + y >= 3 ? x + y - 3 : x + y;}
int sub(int x, int y) {return x - y < 0 ? x - y + 3 : x - y;}
void dfs(int x, int f) {
dp[x][0]++;
for (auto T : G[x]) {
int to = T.to, v = T.v;
if (to == f) continue;
dfs(to, x);
for (int i = 0; i <= 2; i++) {
for (int j = 0; j <= 2; j++) {
ans[add(i, j)] += dp[to][sub(i, v)] * dp[x][j];
}
}
for (int i = 0; i <= 2; i++) {
dp[x][i] += dp[to][sub(i, v)];
}
}
}
int main() {
cin >> n;
for (int i = 1, u, v, w; i <= n - 1; i++) {
scanf("%d %d %d", &u, &v, &w);
w %= 3;
G[u].push_back({v,w});
G[v].push_back({u,w});
}
dfs(1, -1);
ans[0] *= 2;
ans[1] *= 2;
ans[2] *= 2;
ans[0] += n;
LL gcd = __gcd(ans[0], ans[0] + ans[1] + ans[2]);
printf("%lld/%lld", ans[0] / gcd, (ans[0] + ans[1] + ans[2]) / gcd);
return 0;
}
利用树状数组统计答案,非常经典。
#include <bits/stdc++.h>
using namespace std;
const int N = 4e4 + 5, M = 2e4 + 5, INF = 0x3f3f3f3f;
struct Edge {
int to, v;
};
struct BIT {
int t[N];
int lowbit(int x) {return x &(-x);}
void add(int x, int v) {
x += 1;
while(x < N) {
t[x] += v;
x += lowbit(x);
}
}
int ask(int x) {
x += 1;
int res = 0;
while(x) {
res += t[x];
x -= lowbit(x);
}
return res;
}
}bit;
vector<Edge> G[N];
vector<int> D;
bool vis[N];
int n, k, tot, sz[N], MN, rt, ans;
int Getroot(int x, int f) {
int sum = 0, mx = 0, tmp;
for (auto T : G[x]) {
int to = T.to;
if (vis[to] || to == f) continue;
tmp = Getroot(to, x);
sum += tmp;
mx = max(mx, tmp);
}
sum++;
mx = max(mx, tot - sum);
if (mx < MN) {
MN = mx;
rt = x;
}
return sum;
}
void Getsz(int x, int f) {
sz[x] = 0;
for (Edge T : G[x]) {
int to = T.to;
if (vis[to] || to == f) continue;
Getsz(to, x);
sz[x] += sz[to];
}
sz[x]++;
}
void Getdis(int x,int len, int f) {
D.push_back(len);
for (Edge T : G[x]) {
int to = T.to, v = T.v;
if (vis[to] || to == f) continue;
Getdis(to, len + v, x);
}
}
int calc(int x, int len) {
int res = 0;
Getdis(x, len, -1);
D.push_back(0);
for (int y : D) {
if (y > k) continue;
res += bit.ask(k - y);
bit.add(y, 1);
}
for (int y : D) {
if (y > k) continue;
bit.add(y, -1);
}
D.clear();
return res;
}
void solve(int x) {
vis[x] = 1;
ans += calc(x, 0);
for (Edge T : G[x]) {
int to = T.to, v = T.v;
if (vis[to]) continue;
ans -= calc(to, v);
tot = sz[to], MN = INF;
Getroot(to, -1);
Getsz(rt, -1);
solve(rt);
}
}
int main() {
cin >> n;
for (int i = 1, u, v, w; i <= n - 1; i++) {
scanf("%d %d %d", &u, &v, &w);
G[u].push_back({v, w});
G[v].push_back({u, w});
}
cin >> k;
tot = n, MN = INF;
Getroot(1, -1);
Getsz(rt, -1);
solve(rt);
printf("%d\n", ans - n);
return 0;
}
需要利用排序、离散化、树状数组来计算贡献,可能还要卡卡常,比较清奇。
#include <bits/stdc++.h>
using namespace std;
typedef long long LL;
const int N = 2e5 + 5, INF = 0x3f3f3f3f;
struct BIT {
int t[N];
int lowbit(int x) {return x &(-x);}
void add(int x, int v) {
x++;
while(x < N) {
t[x] += v;
x += lowbit(x);
}
}
int ask(int x) {
x++;
int res = 0;
while(x) {
res += t[x];
x -= lowbit(x);
}
return res;
}
}bit;
struct Edge{
LL len;
int mx;
};
vector<int> G[N], dc, add;
vector<Edge> D;
unordered_map<LL, int> pos;
bool vis[N];
int T, n, tot, sz[N], MN, rt, a[N];
LL ans;
int Getroot(int x, int f) {
int sum = 0, mx = 0, tmp;
for (int to : G[x]) {
if (vis[to] || to == f) continue;
tmp = Getroot(to, x);
sum += tmp;
mx = max(mx, tmp);
}
sum++;
mx = max(mx, tot - sum);
if (mx < MN) {
MN = mx;
rt = x;
}
return sum;
}
void Getsz(int x, int f) {
sz[x] = 0;
for (int to : G[x]) {
if (vis[to] || to == f) continue;
Getsz(to, x);
sz[x] += sz[to];
}
sz[x]++;
}
void Getdis(int x,LL len, int mx, int f) {
D.push_back({len, mx});
for (int to : G[x]) {
if (vis[to] || to == f) continue;
Getdis(to, len + a[to], max(mx, a[to]), x);
}
}
LL calc(int x, LL len, int mx, int top) {
dc.clear(); pos.clear(); D.clear(); add.clear();
if (top == 0) top = a[x];
LL res = 0;
Getdis(x, len, mx, -1);
sort(D.begin(), D.end(), [&](Edge u, Edge v) {
return u.mx > v.mx;
});
for (auto y : D) {
dc.push_back(y.len - top);
}
sort(dc.begin(), dc.end());
dc.erase(unique(dc.begin(), dc.end()), dc.end());
for (auto y : D) {
int idx = lower_bound(dc.begin(), dc.end(), y.len - top) - dc.begin();
res += bit.ask(idx);
idx = lower_bound(dc.begin(), dc.end(), y.mx * 2 - y.len + 1) - dc.begin();
if (idx >= 0) {
bit.add(idx, 1);
add.push_back(idx);
}
}
for (auto y : add) bit.add(y, -1);
return res;
}
void solve(int x) {
vis[x] = 1;
ans += calc(x, a[x], a[x], 0);
for (int to : G[x]) {
if (vis[to]) continue;
ans -= calc(to, a[x] + a[to], max(a[x], a[to]), a[x]);
tot = sz[to], MN = INF;
Getroot(to, -1);
Getsz(rt, -1);
solve(rt);
}
}
int main() {
cin >> T;
while(T--) {
cin >> n;
for (int i = 1; i <= n; i++) {
vis[i] = 0;
scanf("%d", &a[i]);
G[i].clear();
}
ans = 0;
for (int i = 1, u, v, w; i <= n - 1; i++) {
scanf("%d %d", &u, &v);
G[u].push_back(v);
G[v].push_back(u);
}
tot = n, MN = INF;
Getroot(1, -1);
Getsz(rt, -1);
solve(rt);
printf("%lld\n", ans);
}
return 0;
}