树上问题的构图方式

树上问题,往往伴随着求LCA(a,b)LCA(a, b),在构图的时候
需要构造出如下数据结构

  • fa(u,d)fa(u, d),用于倍增,表示从uu 开始往根节点走2d2^d 的步长能够到的节点
  • dep(u)dep(u) 表示uu 这个节点的深度,对于(u,v), dep(v)=dep(u)+1(u, v), \ dep(v) = dep(u) + 1
1
2
3
4
5
6
7
8
9
10
11
12
vector<int> G[maxn];
const int LOG = 25;
int fa[maxn][LOG], dep[maxn];

int main() {
while (m--) {
int u, v;
cin >> u >> v;
G[u].push_back(v);
G[v].push_back(u)
}
}

dfs 遍历

1
2
3
4
5
6
7
8
9
10
void dfs(int u, int pa) {
for (auto v : G[u]) {
if (v == pa) continue;
fa[v][0] = u, dep[v] = dep[u] + 1;
dfs(v, u);
}
}

int main:
dfs(1, -1)

bfs遍历

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
void bfs(int s) {
dep[s] = 1;
queue<int> q; q.push(s);
while (q.size()) {
int x = q.front(); q.pop();
for (auto y : G[x]) {
if (dep[y]) continue;
dep[y] = dep[x] + 1, fa[y][0] = x;
q.push(y);
}
}
}
int main() {
memset(dep, 0, sizeof dep);
memset(fa, 0, sizeof fa);
bfs(1);
}

倍增法求 LCA

  • 先根据 dfs 或者 bfs 求出来的fa[],dep[]\bold{fa}[\cdots], \bold{dep}[\cdots] 数组,打出ST\textbf{ST}

     k[1,logn]: fa(u,k)=fa(fa(u,k1),k1)\forall \ k \in [1, \log n]: \ fa(u, k) = fa(fa_{(u, k-1)}, k-1)

算法实现过程,对于询问LCA(x,y)\textbf{LCA}(x, y),不妨设dep(y)>dep(x)dep(y) > dep(x) (不满足的话交换x,yx, y

  • 对于yy,将其调整到和xx 同一深度
    yy 这个节点依次尝试往上走t[2logn,2logn1,20]t \in [2^{\log n}, 2^{\log n - 1}, \cdots 2^0]
    只要dep(fa(y,t))dep(x),yfa(y,t)dep(fa_{(y, t)}) \geqslant dep(x), \quad y \leftarrow fa(y, t) (让yy 往上走)
  • 此时x,yx, y 在同一深度,注意特判,x=yx = y 就说明我们找到了LCA=yLCA = y
    否则,对同一深度的x,yx, y,同步往上走,看是否相遇
    依旧是尝试t[2logn,2logn1,20]t \in [2^{\log n}, 2^{\log n - 1}, \cdots 2^0]
    只要fa(y,t)fa(x,t)fa(y, t) \neq fa(x, t),就同步向上调整xfa(x,t), yfa(y,t)x \leftarrow fa(x, t), \ y \leftarrow fa(y, t)
  • 最后他们一定只差一步相遇fa(x,0)fa(x, 0) 就是答案

How far away

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
class Graph {
public:
int n;
int tot;
vector<int> head, ver, e, ne;

Graph() = default;
Graph(int _n) : n(_n) {
tot = 1;
head.resize(n), ver.resize(n << 1), e.resize(n << 1), ne.resize(n << 1);
}
void clear() {
tot = 1;
fill(head.begin(), head.end(), 0);
}

void add(int a, int b, int c) {
ver[++tot] = b, e[tot] = c, ne[tot] = head[a], head[a] = tot;
}

};

const int maxn = 40000 + 10;
const int LOG = 25;
int n, m;
Graph G(maxn);

int fa[maxn][LOG+1], dep[maxn];
void dfs(int u, int pa) {
for (int i = G.head[u]; i; i = G.ne[i]) {
int v = G.ver[i];
if (v == pa) continue;
dep[v] = dep[u] + G.e[i], fa[v][0] = u;
dfs(v, u);
}
}

void LCA() {
for (int t = 1; t <= LOG; t++) {
for (int u = 1; u <= n; u++)
fa[u][t] = fa[fa[u][t-1]][t-1];
}
}
int LCA(int x, int y) {
if (dep[x] > dep[y]) swap(x, y);
for (int i = LOG; i >= 0; i--) {
if (dep[fa[y][i]] >= dep[x]) y = fa[y][i];
}
if (y == x) return y;
for (int i = LOG; i >= 0; i--) {
if (fa[y][i] != fa[x][i]) y = fa[y][i], x = fa[x][i];
}
return fa[x][0];
}

int main() {
freopen("input.txt", "r", stdin);
int T;
scanf("%d", &T);
while (T--) {
G.clear();
memset(fa, 0, sizeof fa);
memset(dep, 0, sizeof dep);

scanf("%d%d", &n, &m);
for (int i = 0; i < n-1; i++) {
int x, y, z;
scanf("%d%d%d", &x, &y, &z);
G.add(x, y, z), G.add(y, x, z);
}

// dfs and solve
dep[1] = 1;
dfs(1, -1);
LCA();

while (m--) {
int u, v;
scanf("%d%d", &u, &v);
int w = LCA(u, v);
int res = dep[u] + dep[v] - 2*dep[w];
printf("%d\n", res);
}
}
}

其实代码可以更简洁,将LCA()\textbf{LCA}() 部分的代码
合并到dfs\textbf{dfs} 中,也就是说,我们一边dfs\textbf{dfs} 一边求出ST\textbf{ST}

1
2
3
4
5
6
7
8
9
10
void dfs(int u, int pa) {
for (int i = G.head[u]; i; i = G.ne[i]) {
int v = G.ver[i];
if (v == pa) continue;
dep[v] = dep[u] + G.e[i];
fa[v][0] = u;
for (int t = 1; t <= LOG; t++) fa[v][t] = fa[fa[v][t-1]][t-1];
dfs(v, u);
}
}

tarjan 离线求 LCA

tarjan-lca

  • 将所有查询全部读入,用三元组(x,y,id)(x, y, id) 来表示
    • dfn[]\textbf{dfn}[\cdots] 记录节点属于哪一类(未访问,已访问未回溯,已回溯)
    • fa[]\textbf{fa}[\cdots] 并查集来维护节点,每一次完成回溯时,将黑色节点并入灰色节点
  • 执行dfs\bold{dfs},在任意时刻,将节点分为三类
    • 对于已经完成回溯的节点,dfn[x]=2\textbf{dfn}[x] = 2
    • 正在访问,没有发生回溯的节点,dfn[x]=1\textbf{dfn}[x] = 1
    • 未访问的节点,dfn[x]=0\textbf{dfn}[x] = 0
  • 任意时刻,对于当前正在访问的节点xx,路径xLCArootx \to \text{LCA} \to \text{root}
    一定没有开始回溯,也就是说,路径上所有点的dfn=1\textbf{dfn} = 1
    此时检查所有和xx 相关联的询问yy如果yy 已经开始回溯,即dfn[y]=2\textbf{dfn}[y] = 2
    u=LCA(x,y)u = \text{LCA}(x, y) 一定是yy 往上走第一个遇到的dfn[u]=1\textbf{dfn}[u] = 1 的点
    • 如图中加入并查集优化,我们是不断把黑色节点vv 并入其父节点uu 中,直到父节点为灰色,令pa[v]=upa[v] = u
    • 所以yy 往上走第一个遇到的灰色节点,其实就是并查集的查询结果get(y)\textbf{get}(y)
  • tarjan LCA 是一个离线算法,并且不像倍增算法x,yx, y 弄到同一深度,再同时往上走,tarjan LCA 执行的时候,
    对于同一询问LCA(x,y)\text{LCA}(x, y)(x,y)(x, y) 并不对称
    也就是说,LCA(x,y)\text{LCA}(x, y) 可能和LCA(y,x)\text{LCA}(y, x) 不一定相等
    举个例子,比如(x,y), x(x, y), \ xyy 的子节点
    在执行dfs(x)dfs(x) 的时候,yy 为灰色,并不会执行并查集的合并,此时LCA(x,y)=fa(x)\text{LCA}(x, y) = fa(x),不等于yy
    只有dfs(y)dfs(y) 的时候,xx 为黑色,yy 为灰色,此时才会更新LCA(x,y)=y\text{LCA}(x, y) = y
    解决这个问题的办法其实很简单,就是将询问离线读入的时候,记录询问idid
    (x,y), (y,x)(x, y), \ (y, x) 映射到同一个idid,然后dfs\textbf{dfs} 的时候,更新res[id]res[id] 即可
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
const int maxn = 50000 + 10;
const int inf = 0x3f3f3f3f;
int n, m, fa[maxn], dfn[maxn], d[maxn], ans[maxn];
typedef pair<int, int> PII;
vector<PII> qry[maxn];

class Graph {
public:
int n;
int tot;
vector<int> head, ver, e, ne;

Graph() = default;
Graph(int _n) : n(_n) {
tot = 1;
head.resize(n), ver.resize(n<<1), e.resize(n<<1), ne.resize(n<<1);
}
void clear() {
tot = 1;
fill(head.begin(), head.end(), 0);
}

void add(int x, int y, int z) {
ver[++tot] = y; e[tot] = z; ne[tot] = head[x]; head[x] = tot;
}
};

Graph G(maxn);
inline int get(int x) {
return x == fa[x] ? x : fa[x] = get(fa[x]);
}

void tarjan(int x) {
dfn[x] = 1;
for (int i = G.head[x]; i; i = G.ne[i]) {
int y = G.ver[i];
if (dfn[y]) continue;
d[y] = d[x] + G.e[i];
tarjan(y);
fa[y] = x;
}
for (auto u : qry[x]) {
int y = u.first, id = u.second;
if (dfn[y] == 2) {
int lca = get(y);
ans[id] = min(ans[id], d[x] + d[y] - 2*d[lca]);
}
}
dfn[x] = 2;
}

int main() {
freopen("input.txt", "r", stdin);
ios::sync_with_stdio(false);
cin.tie(0);

int T;
cin >> T;
while (T--) {
cin >> n >> m;
G.clear();
memset(dfn, 0, sizeof dfn);
memset(d, 0, sizeof d);
for (int i = 1; i <= n; i++) fa[i] = i, qry[i].clear();

for (int i = 0; i < n-1; i++) {
int x, y, z;
cin >> x >> y >> z;
G.add(x, y, z), G.add(y, x, z);
}
for (int i = 1; i <= m; i++) {
int x, y;
cin >> x >> y;
if (x == y) {
ans[i] = 0;
continue;
}
qry[x].push_back({y, i}), qry[y].push_back({x, i});
ans[i] = inf;
}

// then tarjan
tarjan(1);
for (int i = 1; i <= m; i++) printf("%d\n", ans[i]);
}
}