--- tags: 進階班 --- # DSU & 最小生成樹 (MST) ## DSU 並查集 全名 `disjoint set union` 它是一種資料結構,就跟線段樹一樣 ### 例子 假設今天有個班上有 $n$ 人,且一開始沒有任何好友關係, 而題目有兩種操作: 1. 讓點 `a, b` 成為好友 2. 若秉持「好友的好友也是好友」的規則,求任意兩人 `p, q` 是不是好友? 示意圖: ![](https://i.imgur.com/I8mFz6c.png) 圖中用線相連的兩人就代表互為好友 ### 題目作法 #### DFS 只要從點 `p` 經過一些線走到點 `q`,那答案就是 `yes`,反之則為 `no` 而修改的情況就很直觀地在圖中加一條線就可以了 #### 用陣列存「父節點」 這題目如果有好友關係就能連線,有連線就可以畫出關係「圖」 $\Rightarrow$ 可設定父、子節點 假設我們使用 `anc[i]` 來存取自己的父節點,而如果自己是根節點則 `anc[i] = i` 而一開始沒有任何好友關係,所以所有人的 `anc[i] = i` ```cpp= vector<int> anc; anc.assign(n + 1, 1);//編號 1-base iota(all(anc), 0); //等價於 for(int i = 0; i <= n; i++) anc[i] = i; ``` 找出一個人的根節點可以 $O(n)$ 搜 ```cpp=4 int find(int x) {return anc[x] == x ? x : find(anc[x]);} ``` 假設 `find(p) == find(q)`,那麼兩人就是好友。 接下來是「建立關係」的部分。 建立關係時有兩種情況: 1. 本來就是好友了 (因為好友的好友是好友的規則) 2. 建立關係後才是好友 如果是 1. 就不用任何修改,2. 才需要,因此需要判斷: ```cpp=5 void merge(int p, int q) { if (find(p) == find(q)) return; //這行要有合併的操作 } ``` 而合併的操作則是「讓一個人目前的根節點的父節點 `anc[find(p)]` 變成另外一人的根節點 `find(q)`」 ```cpp=7 void merge(int p, int q) { if (find(p) == find(q)) return; anc[find(p)] = find(q); } ``` 這個操作也是 $O(n)$ 如果出現下列極端情況: ![](https://i.imgur.com/NdEhHqB.png) 求 $5$ 和 $8$ 是不是好朋友? 時間複雜度是 $O(n)$,詢問 $t$ 次就是 $O(nt)$,很慢。 所以就需要 `DSU` 來幫忙解決這類問題。 ### DSU 理論 & 實作 基本上剛剛的陣列存父節點就是 `DSU`,不過可以有很多優化 看關係圖可以發現基本上就是問兩個點是否屬於同個點集,**無關乎自己的父節點是誰** 那麼剛剛的 `anc[i]` 可以試著把「存取根節點」的速度優化 具體來看: ```cpp=4 int find(int x) {return anc[x] == x ? x : anc[x] = find(anc[x]);} ``` 在找尋 `anc[x]` 的父節點 `find(anc[x])` 時,如果順便把 `anc[x]` 修改掉, 那麼後續的查找就壓掉一次遞迴了!(如果是長鏈型的關係圖,效果更驚人) 這樣會使查找複雜度變為 $O(lg\;n)$ (內含複雜數學證明) 接下來是 `merge()` 的部分,由於是兩個點集的合併,點集當然有大小之分, 如果拿小點集去合併大點集就很耗時間, 所以可以多建一個 `size[i]` 來存 「認 `i` 當根節點的人有幾個」 然後用點集大的去合併點集小的 而為了打字方便 (~~10th 進階教學打字有夠慢~~),所以通常都打 `sz[i]` :poop: ```cpp= void merge(int p, int q) { if (find(p) == find(q)) return; if (sz[find(p)] < sz[find(q)]){ sz[find(q)] += sz[find(p)]; anc[find(p)] = find(q); } else sz[find(p)] += sz[find(q)], anc[find(q)] = find(p); } ``` 這樣複雜度也是 $O(lg\;n)$ 但如果合併、查找的優化都有做的話,複雜度會變成 $O(\alpha(n))$ (反阿卡曼函數) ~~它跟巨人沒有任何關係~~ 而 `DSU` 的其中一個應用就是最小生成樹 `Minimum Spanning Tree` ## 什麼是最小生成樹 ### 它是一棵樹 樹的特性是:每個點會有一些邊延伸到其他點,但不能有任何的環 (即一個點經過若干條邊後可以回到自己) 如以下就是一棵樹: ![](https://i.imgur.com/HFxub1p.png) 樹的特性是:假設有 $n$ 個節點,那麼必定只有 $n - 1$ 個邊。 ### 這棵樹的權重最小 一個圖的邊是可以有權重的,我們可以把這種圖稱為「帶權圖」 如以下就是一個帶權圖: ![](https://i.imgur.com/mjG2jvP.png) 而一棵最小生成樹,就是在一個連通圖中選取一些邊,使得這些邊與節點結合之後可以成為一棵樹,且每個邊的權重和最小。 以下是上圖的最小生成樹: ![](https://i.imgur.com/JOTfqfu.png) 邊權和是 $9$,且它是一棵樹! ## 怎麼求最小生成樹 ### Kruskal's algorithm 1. 將每個點分成不同點集 (以下以不同顏色表示不同點集) 2. 把每個邊依大小排序 3. 由最小的邊開始,把兩個點連起來,並**將兩個點整合為同個點集** 4. 如果一個邊連通的兩點是同個點集的,那麼就不須相連 (不然生成出的圖就不是樹了) 以圖來看: ![](https://i.imgur.com/96mtiLN.png) ![](https://i.imgur.com/zjQB2yc.png) ![](https://i.imgur.com/iopikoS.png) ![](https://i.imgur.com/ienX9c6.png) ![](https://i.imgur.com/c5AcwCp.png) 那麼剩下操作就比較直觀,建立一個 `struct road{}` 存下各路的左端、右端及權重 接下來執行上述 $4$ 點的步驟,直到所有點皆為同個點集即可。 :::spoiler `code` ```cpp= #include<bits/stdc++.h> #define LL long long #define all(x) x.begin(), x.end() using namespace std; vector<int> sz, anc; struct road { int v; int l, r; bool operator<(road rd) { return v < rd.v; } }; int find(int x) {return anc[x] == x ? x : anc[x] = find(anc[x]);} bool merge(int l, int r) { if (find(l) == find(r)) return false; if (sz[find(l)] < sz[find(r)]){ sz[find(r)] += sz[find(l)]; anc[find(l)] = find(r); } else sz[find(l)] += sz[find(r)], anc[find(r)] = find(l); return true; } int main() { int n, m, ans; //n 個點,m 條邊 ans = 0; cin >> n >> m; vector<road> rd(m); sz.assign(n + 1, 1); anc.assign(n + 1, 1); iota(all(anc), 0); //一開始將所有人的祖先設為自己 for (auto & i : rd) cin >> i.l >> i.r >> i.v; sort(all(rd)); for (int i = 1, x = 0; x < m && i < n; x++) { if (merge(rd[x].l, rd[x].r)) i++, ans += rd[x].v; } cout << ans << '\n'; return 0; } ``` ::: `Time Complexity:` $O(mlgm + m(\alpha(n)))$ 可以看到 `merge()` 函數傳回值變 `bool` 了 因為這樣可以順便判斷兩點是否屬於同個點集~ ### Prim's algorithm 剛剛的 `Kruskal's algorithm` 是以邊的大小出發,而 `Prim's algorithm` 則是以點出發。 步驟如下: 1. 選定一個起始點 `i` (通常就是第 $1$ 個點) 2. 將點 `i` 有連到的邊存下來,選出最小的邊並連起來,得到另一端點 `j` 3. 將 `j` 設為 `i`,重複第 2. 點直到把整個圖接通 (如果最小邊遇到的點已經被接過就不要連) 以圖來看: ![](https://i.imgur.com/EAjOfJZ.png) ![](https://i.imgur.com/dWq0whO.png) ![](https://i.imgur.com/Vf74VGW.png) ![](https://i.imgur.com/J5Enm5Z.png) ![](https://i.imgur.com/Aqh3u0Z.png) 從樹的定義可以知道,第 2. 點要重複 $n - 1$ 次,且我們可以想辦法讓「選邊」這個動作更快, 因此可以使用 `priority_queue`,「讓查詢最小值」這動作從 $O(m)\rightarrow O(lg\;m)$ :::spoiler `code` ```cpp= #include<bits/stdc++.h> #define LL long long using namespace std; int main(){ int t, n, m, l, r, val; cin >> n >> m; vector<vector<pair<int, int>>> v(n + 1, vector<pair<int, int>>()); while(m--){ cin >> l >> r >> val; v[l].push_back(make_pair(val, r)); v[r].push_back(make_pair(val, l)); } priority_queue<pair<int, int>, vector<pair<int, int>>, greater<pair<int, int>>> pq; int sum = 0, tmp = 1; bool chk[n + 1]; memset(chk, 0, sizeof(chk)); for(int i = 1; i < n;i++){ chk[tmp] = 1; for(int j = 0; j < v[tmp].size(); j++){ pq.push(v[tmp][j]); } while(chk[pq.top().second] && !pq.empty()) pq.pop(); sum += pq.top().first, tmp = pq.top().second, pq.pop(); } cout << sum << '\n'; return 0; } ``` ::: `Time Complexity:` $O(nlgm)$ ## 題目練習:DanDanJudge [a604. "Country" Road (Easy Version)](http://203.64.191.163/ShowProblem?problemid=a604) 由於最後要輸出「上屬和下屬城市是誰」, 用 `Prim's algorithm` 會讓 `priority_queue` 的地方很麻煩且冗長,很有可能出 `bug`, 因此可以選用 `Kruskal's algorithm`,並從第 $1$ 點開始建 `MST`。 因為「上屬不能成為下屬」,所以建完 `MST` 後, 可以使用 `bool used[]` 判斷是否為上屬 (即已經遍歷過的點) :::spoiler `code` ```cpp= #include<bits/stdc++.h> #define LL long long #define f first #define s second #define all(x) x.begin(), x.end() using namespace std; vector<int> sz, anc; struct road { LL v; int l, r; bool operator<(road rd) {return v < rd.v;} }; int find(int x) {return anc[x] == x ? x : anc[x] = find(anc[x]);} bool merge(int l, int r) { if (find(l) == find(r)) return false; if (sz[find(l)] < sz[find(r)]) sz[find(r)] += sz[find(l)], anc[find(l)] = find(r); else sz[find(l)] += sz[find(r)], anc[find(r)] = find(l); return true; } int main() { cin.tie(nullptr), ios_base::sync_with_stdio(false); int n, m; cin >> n >> m; vector<vector<int>> v(n + 1, vector<int> ()); vector<road> rd(m); sz.assign(n + 1, 1); anc.assign(n + 1, 1); iota(all(anc), 0); for (auto & i : rd) cin >> i.l >> i.r >> i.v; sort(all(rd)); for (int i = 1, x = 0; x < m && i < n; x++) { if (merge(rd[x].l, rd[x].r)){ v[rd[x].l].push_back(rd[x].r); v[rd[x].r].push_back(rd[x].l); i++; } } bool used[n + 1]; memset(used, 0, sizeof(used)), used[1] = 1; for (int x = 1; x <= n; x++) { cout << x << ' '; for (size_t y = 0; y < v[x].size(); y++) if (!used[v[x][y]]) cout << v[x][y] << ' ', used[v[x][y]] = 1; cout << '\n'; } return 0; } ``` :::