基礎樹論

基本概念

樹就是一個無向無環聯通圖。

基本名詞

  • 根 (Root):有根樹的起始節點(通常指最「上面」的節點)。
  • 度(Degree):連接某個節點的邊數。
  • 葉子(Leaf):任何度數為
    1
    的節點(若同時為根節點,則視題目而定)。
  • 父節點(Parent):當前節點往上(根)走碰到的第一個節點,根節點沒有父節點。
  • 祖先(Ancestors):所有在該節點以上的點(含)。
  • 子樹(Subtree):所有視當前節點為祖先的節點集合。
  • 深度 (Depth):節點到根的邊數距離。
  • 高度 (Height):節點到最深葉子的最大邊數距離。

定義

  • 若且唯若有
    N
    個節點和
    N1
    條邊時,該圖為樹。
  • 若且唯若每個節點中間洽有一簡單路徑時,該圖為樹。
  • 若且唯若圖為聯通且沒有環,該圖為樹。

要注意若沒有規定,任何節點都可以成為根,此時樹為無根樹

存圖&遍歷

最常用的是鄰接陣列(adjacency list)

int n; vector<vector<int>> adj(n + 1); for (int i = 0; i < n - 1; i++) { int u, v; cin >> u >> v; adj[u].push_back(v); adj[v].push_back(u); // 無向邊要建雙向 }

DFS

遞迴或 stack,用來計算子樹資訊、深度等:

void dfs(int u, int p) { // process node u for (int v : adj[u]) { if (v == p) continue; // 防止往回走 dfs(v, u); } } // main dfs(root, -1); // 通常根節點是 1 可以直接 dfs(1, -1);
lambda 寫法
auto dfs = [&](auto self, int u, int p) -> void { // process node u for (int v : adj[u]) { if (v == p) continue; self(self, v, u); } }; // 記得加分號 // main dfs(dfs, root, -1);

BFS

用 queue,適合計算最短路或節點距離:

vector<int> dist(n + 1, -1); queue<int> q; dist[start] = 0; q.push(start); while (!q.empty()) { int u = q.front(); q.pop(); for (int v : adj[u]) { if (dist[v] == -1) { dist[v] = dist[u] + 1; q.push(v); } } }

以上計算出所有節點相對於 start 的距離。

子樹大小

題目:CSES - Subordinates

給定每個人的上司,計算出每個人的下屬有幾個。

將所有人的上司 下屬關係用一條邊連著,可以構成一顆樹。用 DFS 遞迴在樹上 DP:

void dfs(int u, int p) { cnt[u] = 1; for (int v : adj[u]) { if (v == p) continue; dfs(v, u); // 要先處理完子節點再更新狀態 cnt[u] += cnt[v]; } }

上面程式計算出每個節點的子樹大小(cnt[u] 代表節點

u 的子樹大小)。
最後再把每個人的子樹大小減去一就是下屬的數量。

AC Code
// complexity: O(n) #include <iostream> #include <vector> using namespace std; int main() { int n; cin >> n; vector<vector<int>> adj(n + 1); for (int i = 2; i <= n; i++) { int fa; cin >> fa; adj[fa].push_back(i); adj[i].push_back(fa); } vector<int> cnt(n + 1); auto dfs = [&](auto self, int u, int p) -> void { cnt[u] = 1; for (int v : adj[u]) if (v != p) { self(self, v, u); cnt[u] += cnt[v]; } }; dfs(dfs, 1, -1); for (int i = 1; i <= n; i++) { cout << cnt[i] - 1 << " \n"[i == n]; } }

樹直徑

題目:CSES - Tree Diameter

找出樹直徑。

樹上任意兩點之間的最長簡單路徑長稱為樹的直徑(可能有多條相同長度的路徑),可以用兩次 DFS || 樹 DP 在

O(n) 時間內求出。

兩次 DFS

先隨便找出一個點

a,找出距離點
a
最遠的點
b
,再找距離點
b
最遠的點
c
。則
dist(b,c)
為樹直徑。證明

AC Code
// complexity: O(n) #include <iostream> #include <vector> using namespace std; int main() { int n; cin >> n; vector<vector<int>> adj(n + 1); for (int i = 1; i < n; i++) { int a, b; cin >> a >> b; adj[a].push_back(b); adj[b].push_back(a); } int s = 1, dia = 0; vector<int> depth(n + 1); auto dfs = [&](auto self, int u, int p) -> void { if (depth[u] > dia) { dia = depth[u]; s = u; } for (int v : adj[u]) if (v != p) { depth[v] = depth[u] + 1; self(self, v, u); } }; dfs(dfs, 1, -1); depth[s] = 0; dfs(dfs, s, -1); cout << dia << '\n'; }

樹 DP

隨便選擇一個點作為根,計算出每一個點往下的最大深度與次大深度,合併之後為通過當前點的最大路徑,將所有點取最大值即為樹直徑。

AC Code
// complexity: O(n) #include <iostream> #include <vector> using namespace std; int main() { int n; cin >> n; vector<vector<int>> adj(n + 1); for (int i = 1; i < n; i++) { int a, b; cin >> a >> b; adj[a].push_back(b); adj[b].push_back(a); } int dia = 0; auto dfs = [&](auto self, int u, int p) -> int { int mx = 0, mx2 = 0; for (int v : adj[u]) if (v != p) { int h = self(self, v, u) + 1; if (h > mx) { mx2 = mx; mx = h; } else if (h > mx2) { mx2 = h; } } dia = max(dia, mx + mx2); return mx; }; dfs(dfs, 1, -1); cout << dia << '\n'; }

全源最長路

題目:CSES - Tree Distances I

找出對於每個節點最遠點的距離。

對於每一個節點,可以發現離他最遠的點一定是往他的子節點們或是父節點其中一條路徑之內。
一樣用 DFS 遞迴下去處理,對於每一個節點

u,記錄往子節點 path[u] 可以達成最長長度 maxDist[u],和第二長的長度 secMax[u]
接著重頭開始跑另外一次遞迴,如果當前節點
u
maxDist[u] 不是往欲處理節點
v
,則可以得知 maxDist[v] 一定小於等於 maxDist[u] + 1(否則在計算 maxDist[u] 時會選擇 maxDist[v] + 1)。
而如果 maxDist[u] 是往
v
,檢查是否要更新 maxDist[v]secMax[v]secMax[u] + 1
最後處理完,maxDist[k] 即為節點
k
到任何節點最遠距離。

AC Code
// complexity: O(n) #include <iostream> #include <vector> using namespace std; int main() { int n; cin >> n; vector<vector<int>> adj(n + 1); for (int i = 1; i < n; i++) { int a, b; cin >> a >> b; adj[a].push_back(b); adj[b].push_back(a); } vector<int> maxDist(n + 1), secMax(n + 1), path(n + 1); auto dfs = [&](auto self, int u, int p) -> void { for (int v : adj[u]) if (v != p) { self(self, v, u); if (maxDist[v] + 1 > maxDist[u]) { secMax[u] = maxDist[u]; maxDist[u] = maxDist[v] + 1; path[u] = v; } else if (maxDist[v] + 1 > secMax[u]) { secMax[u] = maxDist[v] + 1; } } }; dfs(dfs, 1, -1); auto dfs2 = [&](auto self, int u, int p) -> void { for (int v : adj[u]) if (v != p) { if (path[u] == v) { // maxDist[u] is the path toward v if (secMax[u] + 1 > maxDist[v]) { secMax[v] = maxDist[v]; maxDist[v] = secMax[u] + 1; path[v] = u; } else if (secMax[u] + 1 > secMax[v]) { secMax[v] = secMax[u] + 1; } } else { // maxDist[u] + 1 >= maxDist[v] in this case // cause if not, then maxDist[u] should calculated as maxDist[v] + 1 in first dfs secMax[v] = maxDist[v]; maxDist[v] = maxDist[u] + 1; path[v] = u; } self(self, v, u); } }; dfs2(dfs2, 1, -1); for (int i = 1; i <= n; i++) { cout << maxDist[i] << " \n"[i == n]; } }

換根 DP

題目:CSES - Tree Distances II

對於每一個節點,輸出他到其他所有點的距離和。

我們可以在

O(n) 時間內找到一個點到所有其他點的距離和,但如果要計算所有的點,
O(n2)
會炸。
其實可以發現如果計算出了一個人的答案,則他的子節點答案可以
O(1)
時間內用父節點求出。
記錄 sz[v] 為節點
v
的子樹大小,若要將當前節點
u
往子節點
v
推答案,所有在
u
之上的點相對於
v
的距離會加一,而相對於所有
v
以下的則會加一,列出轉移式
dp[v]=dp[u]+(nsz[v])sz[v]
,就能在
O(n)
時間得出所有點的答案。

AC Code
// complexity: O(n) #include <iostream> #include <vector> using namespace std; int main() { int n; cin >> n; vector<vector<int>> adj(n + 1); for (int i = 1; i < n; i++) { int a, b; cin >> a >> b; adj[a].push_back(b); adj[b].push_back(a); } vector<long long> ans(n + 1), sz(n + 1); // O(n) 算出節點 1 到所有點的距離和 auto dfs = [&](auto self, int u, int p, int dist) -> void { ans[1] += dist; sz[u] = 1; // 計算子樹大小 for (int v : adj[u]) if (v != p) { self(self, v, u, dist + 1); sz[u] += sz[v]; } }; dfs(dfs, 1, -1, 0); auto dfs2 = [&](auto self, int u, int p) -> void { for (int v : adj[u]) if (v != p) { ans[v] = ans[u] - sz[v] + (n - sz[v]); self(self, v, u); } }; dfs2(dfs2, 1, -1); for (int i = 1; i <= n; i++) { cout << ans[i] << " \n"[i == n]; } }

Binary Lifting

題目:CSES - Company Queries I

給定每個人的從屬關係和

Q 個詢問:
x
上面
k
層的的上司是誰?

如果每次查詢

O(n) 總複雜度
O(nq)
會炸掉。
可以利用之前 Sparse Table 類似的思想,記錄每一個人往上
2k
層的上司是誰,就可以在
O(logn)
時間內建表,
O(logn)
時間內查詢。

定義 succ[k][u] 為節點

u 往上走
2k
的人,可以得到轉移式
succ[k][u]=succ[k1][succ[k1][u]]
,即為往上走
2k1
再往上走
2k1
,等同於
2k1×2=2k

vector<vector<int>> succ(__lg(n) + 1, vector<int>(n + 1)); vector<int> depth(n + 1); auto dfs = [&](auto self, int u) -> void { for (int v : adj[u]) { succ[0][v] = u; depth[v] = depth[u] + 1; for (int i = 1; i <= __lg(depth[v]); i++) { succ[i][v] = succ[i - 1][succ[i - 1][v]]; } self(self, v); } };

查詢時也只需要把

k 拆分成
2
的冪次相加,往上跳到答案即可。

for (int i = __lg(k); i >= 0; i--) { if (k & (1 << i)) { x = succ[i][x]; } }
AC Code
// complexity: O(q log n) #include <iostream> #include <vector> using namespace std; int lg(int a) { return 31 - __builtin_clz(a); } int main() { int n, q; cin >> n >> q; vector<vector<int>> adj(n + 1); for (int i = 2; i <= n; i++) { int e; cin >> e; adj[e].push_back(i); } vector<vector<int>> succ(lg(n) + 1, vector<int>(n + 1)); vector<int> depth(n + 1); auto dfs = [&](auto self, int u) -> void { for (int v : adj[u]) { succ[0][v] = u; depth[v] = depth[u] + 1; for (int i = 1; i <= lg(depth[v]); i++) { succ[i][v] = succ[i - 1][succ[i - 1][v]]; } self(self, v); } }; dfs(dfs, 1); while (q--) { int x, k; cin >> x >> k; for (int i = lg(k); i >= 0; i--) { if (k & (1 << i)) { x = succ[i][x]; } } cout << (x == 0 ? -1 : x) << '\n'; } }

LCA

題目:CSES - Company Queries II

給定每個人的從屬關係和

Q 個詢問:距離
a
b
最近的共通祖先為何?

求最低共通祖先(Lowest Common Ancestor)通常有兩種寫法:倍增法和歐拉迴路。

倍增法

用剛剛講到的 Binary Lifting,先建好表和計算深度,給定兩個節點

a
b
後,先把比較深的節點提高,讓兩節點深度一樣,可以用 x = succ[lg(depth[x] - depth[y])][x] 達成(
depth[x]depth[y]
的情況,若反過來可以 swap 就好)。
接著只需要一直往上直到同個節點就好,為了實作方便可以先定位到 LCA 下面一格再往上就好。

auto lca = [&](int x, int y) -> int { if (depth[x] < depth[y]) swap(x, y); while (depth[x] > depth[y]) { x = succ[lg(depth[x] - depth[y])][x]; } if (x == y) return x; // 如果定位到同深度時已經找到 要先返回 for (int i = lg(depth[x]); i >= 0; i--) { if (succ[i][x] != succ[i][y]) { x = succ[i][x], y = succ[i][y]; } } return succ[0][x]; };
倍增法 AC Code
// complexity: O(q log n) #include <iostream> #include <vector> using namespace std; int lg(int a) { return 31 - __builtin_clz(a); } int main() { int n, q; cin >> n >> q; vector<vector<int>> adj(n + 1); for (int i = 2; i <= n; i++) { int fa; cin >> fa; adj[fa].push_back(i); adj[i].push_back(fa); } vector<vector<int>> succ(lg(n) + 1, vector<int>(n + 1)); vector<int> depth(n + 1); auto dfs = [&](auto self, int u, int p) -> void { depth[u] = depth[p] + 1; succ[0][u] = p; for (int i = 1; i <= lg(depth[u]); i++) { succ[i][u] = succ[i - 1][succ[i - 1][u]]; } for (int v : adj[u]) if (v != p) { self(self, v, u); } }; dfs(dfs, 1, 0); auto lca = [&](int x, int y) -> int { if (depth[x] < depth[y]) swap(x, y); while (depth[x] > depth[y]) { x = succ[lg(depth[x] - depth[y])][x]; } if (x == y) return x; for (int i = lg(depth[x]); i >= 0; i--) { if (succ[i][x] != succ[i][y]) { x = succ[i][x], y = succ[i][y]; } } return succ[0][x]; }; while (q--) { int a, b; cin >> a >> b; cout << lca(a, b) << '\n'; } }

歐拉迴路

核心思想是做一次 DFS,記錄每個點進出的時間,之後會變成一個一維的時間線,兩個點之中的某個時間點,其深度最小就等同於兩點的 LCA。記錄 timer[cnt] 是進/出時間為 cnt 的點,id[u] 為節點

u 進入的時間點。處理完之後可以在
O(nlogn)
時間內建好 Sparse Table,就可以在
O(1)
時間查詢 LCA。

歐拉迴路 AC Code
// complexity: O(n log n) #include <iostream> #include <vector> using namespace std; int lg(int a) { return 31 - __builtin_clz(a); } int main() { int n, q; cin >> n >> q; vector<vector<int>> adj(n + 1); for (int i = 2; i <= n; i++) { int e; cin >> e; adj[e].push_back(i); } int cnt = 1; vector<int> timer(2 * n + 1), id(n + 1), depth(n + 1); auto dfs = [&](auto self, int u) -> void { id[u] = cnt; timer[cnt++] = u; for (int v : adj[u]) { depth[v] = depth[u] + 1; self(self, v); timer[cnt++] = u; } }; dfs(dfs, 1); vector<vector<int>> st(lg(cnt) + 1, vector<int>(cnt + 1)); for (int i = 1; i <= cnt; i++) st[0][i] = i; for (int i = 1; i <= lg(cnt); i++) { int len = 1 << (i - 1); for (int j = 1; j + len <= cnt; j++) { int a = st[i - 1][j], b = st[i - 1][j + len]; if (depth[timer[a]] <= depth[timer[b]]) st[i][j] = a; else st[i][j] = b; } } while (q--) { int a, b; cin >> a >> b; int l = id[a], r = id[b]; if (l > r) swap(l, r); int k = lg(r - l + 1); a = st[k][l], b = st[k][r - (1 << k) + 1]; if (depth[timer[a]] <= depth[timer[b]]) { cout << timer[a] << '\n'; } else cout << timer[b] << '\n'; }

像歐拉迴路這種把樹「壓平」成一維的操作就是樹壓平。LCA 還有其他各種方式可以求,有一種更簡單的樹壓平求 LCA 方式可以參考這篇文章

很多題目都會需要求出 LCA,例如樹上兩點求最小距離:CSES - Distance Queries

樹壓平

至於為什麼要壓平,當然是套資料結構(剛剛求 LCA 時是套 Sparse Table),根據題目也可以套 BIT、線段樹之類的。
題目:CSES - Subtree Queries

給定一顆樹,有兩種操作:將節點

s 的值修改成
x
/ 查詢節點
s
的子樹值之和。

首先,這看起來就很 BIT,於是我們就把它壓平,記錄 st[u] 為節點

u 在 DFS 下出現的時間,en[u] 則為離開的時間。肉眼可見的發現了他的子樹時間序就全部在 st[u]en[u] 之間了,把
n
個時間點丟進 BIT 裡面(結束的時間點可以重複,所以只需要
n
個時間點),就可以在
O(logn)
時間內改值跟查詢了。

AC Code
// complexity: O(q log n) #include <iostream> #include <vector> using namespace std; using ll = long long; struct BIT { int n; vector<ll> arr, tree; BIT(int N) : n(N) { arr.resize(n + 1); tree.resize(n + 1); } void modify(int k, int x) { int dif = x - arr[k]; arr[k] = x; while (k <= n) { tree[k] += dif; k += k & -k; } } ll query(int k) { ll ret = 0; while (k) { ret += tree[k]; k -= k & -k; } return ret; } }; int main() { int n, q; cin >> n >> q; vector<int> v(n + 1); for (int i = 1; i <= n; i++) { cin >> v[i]; } vector<vector<int>> adj(n + 1); for (int i = 1; i < n; i++) { int a, b; cin >> a >> b; adj[a].push_back(b); adj[b].push_back(a); } int cnt = 0; vector<int> st(n + 1), en(n + 1); auto dfs = [&](auto self, int u, int p) -> void { st[u] = ++cnt; for (int v : adj[u]) if (v != p) { self(self, v, u); } en[u] = cnt; }; dfs(dfs, 1, -1); BIT tree(n); for (int i = 1; i <= n; i++) { tree.modify(st[i], v[i]); } while (q--) { int o, s; cin >> o >> s; if (o == 1) { // modify int x; cin >> x; tree.modify(st[s], x); } else { cout << tree.query(en[s]) - tree.query(st[s] - 1) << '\n'; } } }

壓平後也可以對樹上的一條鏈做查詢(如 SUM),入時間戳設正值出時間戳設負值。

題目

下面的都是好題,還有 CSES 的 Tree Algorithm,樹跟 DP 一樣就是要多刷題,盡量寫,不懂的可以問。

樹直徑

P2195 HXY造公园
P5536 【XR-3】核心城市
P1099 NOIP 2007 提高组 树网的核

LCA

P3938 斐波那契

樹 DP

P1352 没有上司的舞会
P1040 NOIP 2003 提高组 加分二叉树
P1122 最大子树和
P1273 有线电视网
P2014 CTSC1997 选课

參考資料

  1. Competitive Programmer's Handbook
    Laaksonen, A. (2018, July 3). Competitive Programmer's Handbook [PDF]. https://cses.fi/book/book.pdf