树链剖分 学习笔记

树链剖分是个很强的 OIer

树链剖分是用来解决一系列树上问题的利器。

定义

  • 链:树上不拐弯的一条路径
  • 重儿子:子树大小最大的儿子
  • 轻儿子:其他儿子
  • 重链:由重儿子组成的链

结果和过程

结果

这是原树

剖分之后的树

这里,红色的边构成重链,蓝色的点为重儿子,绿色的点为轻儿子。

注意,这里的 7 号与 8 号节点子树大小相同,因此我们选择编号靠前的儿子为重儿子。

实现

一般来说,树链剖分通过两遍 dfs 来实现。

定义

1
2
3
4
5
siz[x] //子树x的大小
top[x] //x所在链的顶端
fat[x] //x的父亲
dep[x] //x的深度
son[x] //x的重儿子

第一遍 dfs ,我们先处理出每个节点的父亲深度子树大小重儿子

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
void dfs1(int u,int f) {
siz[u] = 1;
dep[u] = dep[f] + 1;
son[u] = 0;
fat[u] = f;
for (int i = hd[u];i;i= nxt[i]) {
int v = to[i];
if (v != f) {
dfs1(v,u);
siz[u] += siz[v];
if (siz[v] > siz[son[u]]) {
son[u] = v;
}
}
}
}

第二遍 dfs,我们处理出每条链链顶的节点

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
void dfs2(int u,int f) {
if (son[f] == u) {
top[u] = top[f];
}
else {
top[u] = u;
}
if (son[u]) {
dfs2(son[u],u);
}
for (int i = hd[u];i;i = nxt[i]) {
int v = to[i];
if (v != f && v != son[u]) {
dfs2(v,u);
}
}
}

这样树链剖分的基本结构就写完辣!是不是很简单

应用

LCA

显然,我们可以发现,在一条重链上的两个节点的 LCA 显然就是深度更浅的那个节点。

所以我们可以先将两个节点跳到同一条链上,求出深度浅的那个节点即可,于是有了如下代码:

1
2
3
4
5
6
7
8
9
10
11
int lca(int a,int b) {
for (int ta = top[a],tb = top[b];ta != tb;) {
if (dep[ta] > dep[tb]) {
ta = top[a = fat[ta]];
}
else {
tb = top[b = fat[tb]];
}
}
return dep[a] < dep[b] ? a : b;
}

时间复杂度$O(\log{n})$

例题:板子题

随手套个板子写掉。

代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
#include <cstdio>
#include <cctype>
#include <cstring>

const int N = 500010;
const int M = N << 1;

inline int read() {
char v = getchar();int x = 0,f = 1;
while (!isdigit(v)) {if (v == '-')f = -1;v = getchar();}
while (isdigit(v)) {x = x * 10 + v - 48;v = getchar();}
return x * f;
}

int hd[N],edg[M],nxt[M],to[M],n,m,tot,cnt;

inline void add(int u,int v,int w) {
to[++tot] = v;edg[tot] = w;nxt[tot] = hd[u];hd[u] = tot;
}

inline void addedge(int u,int v,int w) {
add(u,v,w);add(v,u,w);
}

int dfn[N],top[N],fat[N],siz[N],son[N],end[N],dep[N],s;

void dfs1(int s,int f) {
son[s] = 0;
dep[s] = dep[f] + 1;
siz[s] = 1;
fat[s] = f;
for (int i = hd[s];i;i = nxt[i]) {
if (to[i] != f) {
dfs1(to[i],s);
siz[s] += siz[to[i]];
if (siz[to[i]] > siz[son[s]]) {
son[s] = to[i];
}
}
}
}

void dfs2(int u,int f) {
dfn[u] = ++cnt;
if (u == son[f]) {
top[u] = top[f];
}
else {
top[u] = u;
}
if (son[u]) {
dfs2(son[u],u);
}
for (int i = hd[u];i;i = nxt[i]) {
int v = to[i];
if (v != f && v != son[u]) {
dfs2(v,u);
}
}
end[u] = cnt;
}

int lca(int a,int b) {
for (int ta = top[a],tb = top[b];ta != tb;) {
if (dep[ta] > dep[tb]) {
ta = top[a = fat[ta]];
}
else {
tb = top[b = fat[tb]];
}
}
return dep[a] < dep[b] ? a : b;
}

int main() {
n = read(),m = read(),s = read();
for (int i = 1;i < n;++i) {
int u = read(),v = read(),w = 1;
addedge(u,v,w);
}
dfs1(s,0);
dfs2(s,0);
for (int i = 1;i <= m;++i) {
int u = read(),v = read();
printf("%d\n",lca(u,v));
}
return 0;
}

链上修改,子树修改

这里就必须引出一个新东西了:dfs 序

指的是第几次 dfs 遍历到的这个节点

定义数组dfn[x]为节点 x 的 dfs 序。

我们可以在 dfs2 中顺手维护一下。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
void dfs2(int u,int f) {
dfn[u] = ++cnt;
if (son[f] == u) {
top[u] = top[f];
}
else {
top[u] = u;
}
if (son)
for (int i = hd[u];i;i = nxt[i]) {
int v = to[i];
if (v != f) {
dfs2(v,u);
}
}
end[cnt] = u;
}

这里给出每个节点加上 dfn 的图

我们可以发现,在每棵子树,每条上的dfn都是连续的!

这样看可能不太明朗,我们把它转化成区间来看(这里是子树的,链同理)

所以我们就可以利用数据结构维护一下节点了。

对子树修改就是 change(dfn[x],dfn[x]+siz[x]-1)

链上修改就是 change(dfn[x],dfn[y])

这里的核心思想是:将树上问题转化成序列问题来处理

例题:[ZJOI2008]树的统计

用线段树维护最大值最小值。

代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
6#include <cstdio>
#include <cctype>
#include <iostream>
#include <string>
const int N = 100010;
const int M = N << 1;
const int INF = 0x3f3f3f3f;
using namespace std;
inline int read() {
char v = getchar();int x = 0,f = 1;
while (!isdigit(v)) {if (v == '-')f = -1;v = getchar();}
while (isdigit(v)) {x = x * 10 + v - 48;v = getchar();}
return x * f;
}

int to[M],hd[N],nxt[M],tot;
string s;

inline void add(int u,int v) {
to[++tot] = v;nxt[tot] = hd[u];hd[u] = tot;
}

inline void addedge(int u,int v) {
add(u,v);add(v,u);
}

int dfn[N],top[N],fat[N],siz[N],dep[N],son[N],rk[N],cnt,n;

void dfs1(int u,int f) {
fat[u] = f;
dep[u] = dep[f] + 1;
son[u] = 0;
siz[u] = 1;
for (int i = hd[u];i;i = nxt[i]) {
int v = to[i];
if (v != f) {
dfs1(v,u);
siz[u] += siz[v];
if (siz[v] > siz[son[u]]) {
son[u] = v;
}
}
}
}

void dfs2(int u,int f) {
dfn[u] = ++cnt;rk[cnt] = u;
if (son[f] == u) {
top[u] = top[f];
}
else {
top[u] = u;
}
if (son[u]) {
dfs2(son[u],u);
}
for (int i = hd[u];i;i = nxt[i]) {
int v = to[i];
if (v != f && v != son[u]) {
dfs2(v,u);
}
}
}

int num[N];

struct node {
int l,r,big,sum;
}tree[N<<2];

inline void build(int p,int l,int r) {
tree[p].l = l;tree[p].r = r;
if (l == r) {
tree[p].big = tree[p].sum = num[rk[l]];
return ;
}
int mid = (l + r) >> 1;
build(p << 1,l,mid);
build(p << 1|1,mid+1,r);
tree[p].sum = tree[p<<1].sum + tree[p<<1|1].sum;
tree[p].big = max(tree[p<<1].big,tree[p<<1|1].big);
return ;
}

inline void modify(int p,int x,int v) {
if (tree[p].l == tree[p].r) {
tree[p].sum = tree[p].big = v;
return ;
}
int mid = (tree[p].l + tree[p].r) >> 1;
if (x <= mid) modify(p<<1,x,v);
if (x > mid) modify(p<<1|1,x,v);
tree[p].sum = tree[p<<1].sum + tree[p<<1|1].sum;
tree[p].big = max(tree[p<<1].big,tree[p<<1|1].big);
return ;
}

inline int querym(int p,int x,int y) {
if (tree[p].l >= x && tree[p].r <= y) {
return tree[p].big;
}
int mid = (tree[p].l + tree[p].r) >> 1,ans = -INF;
if (x <= mid) {
ans = max(ans,querym(p<<1,x,y));
}
if (y > mid) {
ans = max(ans,querym(p<<1|1,x,y));
}
return ans;
}

inline int querys(int p,int x,int y) {
if (tree[p].l >= x && tree[p].r <= y) {
return tree[p].sum;
}
int mid = (tree[p].l + tree[p].r) >> 1,ans = 0;
if (x <= mid) {
ans += querys(p<<1,x,y);
}
if (y > mid) {
ans += querys(p<<1|1,x,y);
}
return ans;
}

inline void change(int u,int t) {
modify(1,dfn[u],t);
}

inline int qmax(int u,int v) {
int ans = -INF;
while (top[u] != top[v]) {
if (dep[top[u]] < dep[top[v]]) {
swap(u,v);
}
ans = max(ans,querym(1,dfn[top[u]],dfn[u]));
u = fat[top[u]];
}
if (dep[u] > dep[v]) {
swap(u,v);
}
ans = max(ans,querym(1,dfn[u],dfn[v]));
return ans;
}

inline int qsum(int u,int v) {
int ans = 0;
while (top[u] != top[v]) {
if (dep[top[u]] < dep[top[v]]) {
swap(u,v);
}
ans += querys(1,dfn[top[u]],dfn[u]);
u = fat[top[u]];
}
if (dep[u] > dep[v]) {
swap(u,v);
}
ans += querys(1,dfn[u],dfn[v]);
return ans;
}

signed main() {
n = read();
for (int i = 1;i < n;++i) {
addedge(read(),read());
}
for (int i = 1;i <= n;++i) {
num[i] = read();
}
dfs1(1,0);dfs2(1,0);build(1,1,n);
int m = read();
for (int i = 1;i <= m;++i) {
cin >> s;
int u = read(),v = read();
if (s == "CHANGE") {
change(u,v);
}
if (s == "QSUM") {
printf("%d\n",qsum(u,v));
}
if (s == "QMAX") {
printf("%d\n",qmax(u,v));
}
}
return 0;
}

练习

Qtree 把边权变成点权,巧妙的做法

[HAOI2015]树上操作 练手,区间修改,区间查询

[JLOI2014]松鼠的新家 巧妙的做法,也可以用树上差分来写

[NOI2015]程序包管理器 稍微转换一下问题

后记

感谢 $\color{black}M\color{red}{inagami}$ 给我提供了极大的帮助,感谢 $\color{black}Y\color{red}{ouSiKi}$ 的教导

作者

Doubeecat

发布于

2019-08-25

更新于

2025-09-18

许可协议