# 分治 ## 3/28 社課 --- ### 什麼是分治? **分**:把大問題分成小問題 **治**:先把小問題解決,再合併回大問題 **分而治之** ---- ### 再分得細一點 - 分割問題 - 解決小問題 - 合併問題 這三點就是分治的主要核心! --- ## Merge sort 給定一個長度為 $N$ 的序列 $a$ 請你排序他 ---- ## Merge sort 又回到了排序問題 這次用分治解決看看 ---- ### 分 **把大問題分成小問題** 大問題:排序 $[1, N]$ 的序列 小問題:排序 $[1, \frac{N}{2}]$ 和 $[\frac{N}{2}+1, N]$ 的序列 ---- ### 治 **先解決小問題,再合併回大問題** 遞迴到最後都會是長度為 $1$ 的序列 也就是這些序列都已經排序好了(解決最小問題) ---- ### 怎麼合併 問題變成 **有兩個已經排好的序列 要合併成一個排序好的序列** ---- 先在兩個序列的開頭各放一個指針 這個指針越向右移,數值就會越大 ---- 每次比較兩個指針所指的數值 將比較小的那個數放到新的(已排序好的)序列 並且移動那個數值比較小的指針 ---- {**1**, 3, 4, 6, 7} {**2**, 4, 5, 8, 9} {} ---- {1, **3**, 4, 6, 7} {**2**, 4, 5, 8, 9} {1} ---- {1, **3**, 4, 6, 7} {2, **4**, 5, 8, 9} {1, 2} ---- {1, 3, **4**, 6, 7} {2, **4**, 5, 8, 9} {1, 2, 3} ---- {1, 3, 4, **6** ,7} {2, **4**, 5, 8, 9} {1, 2, 3, 4} ---- {1, 3, 4, **6**, 7} {2, 4, **5**, 8, 9} {1, 2, 3, 4, 4} ---- {1, 3, 4, **6**, 7} {2, 4, 5, **8**, 9} {1, 2, 3, 4, 4, 5} ---- {1, 3, 4, 6, **7**} {2, 4, 5, **8**, 9} {1, 2, 3, 4, 4, 5, 6} ---- {1, 3, 4, 6, 7} {2, 4, 5, **8**, 9} {1, 2, 3, 4, 4, 5, 6, 7} ---- {1, 3, 4, 6, 7} {2, 4, 5, 8, **9**} {1, 2, 3, 4, 4, 5, 6, 7, 8} ---- {1, 3, 4, 6, 7} {2, 4, 5, 8, 9} {1, 2, 3, 4, 4, 5, 6, 7, 8, 9} ---- 只會操作序列長度的次數 $O(N)$ 每次將序列分成兩半 共有 $O(\log{N})$ 層 複雜度 $O(N\log{N})$! ---- ### Merge sort 程式碼 ```cpp= #include<bits/stdc++.h> #define fastio ios::sync_with_stdio(0); cin.tie(0); using namespace std; const int MAXN=2e5+5; int n; int arr[MAXN], tmp[MAXN]; // tmp: 暫存陣列 void merge_sort(int l, int r){ // 左閉右開區間 if(l>=r-1) return; // 長度小於等於 1 int m=(l+r)/2; // 把陣列分成兩半 merge_sort(l, m); merge_sort(m, r); // 分割成小問題 int i=l, j=m, k=l; // i: 左區間指針,j: 右區間指針,k: 合併區間指針 while(i<m&&j<r){ if(arr[i]<=arr[j]) tmp[k++]=arr[i++]; else tmp[k++]=arr[j++]; } while(i<m) tmp[k++]=arr[i++]; // 左區間還有剩 while(j<r) tmp[k++]=arr[j++]; // 右區間還有剩 for(int p=l; p<r; p++) arr[p]=tmp[p]; // 把排序好的陣列存回原陣列 } int main(){ fastio cin >> n; for(int i=0; i<n; i++) cin >> arr[i]; merge_sort(0, n); for(int i=0; i<n; i++) cout << arr[i] << ' '; } ``` ---- ## [逆序數對](https://tioj.ck.tp.edu.tw/problems/1080) 在一個序列 $a$ 中 若存在 $i<j$ 而 $a_i>a_j$ 則稱 $(i, j)$ 為一個**逆序數對** 給定長度為 $N$ 的序列 $a$ 請求出序列 $a$ 中有多少逆序數對 ---- 一樣大問題先分成小問題 把序列切一半 數量會是 左區間的數量 + 右區間的數量 \+ 橫跨左右區間的數量 ---- 假設左右兩半都求出來了 要怎麼合併? ---- 如果左右的序列都是亂的 那我們只能一個一個排 $O(N^2)$ BAD :( ---- 按照 Merge sort 的思想 我們邊排序,邊求數量 ---- 因為左右兩邊都已經排序好了 左指針到左區間尾部一定是**遞增**的 也就是說 左指針數值如果比右指針數值大 那從左指針一直到左區間尾部一定都比右指針大 因此我們可以在 Merge sort 的過程中 順便用 $O(1)$ 的時間找此時的逆序數對數量 ---- 又是一個 $O(N\log{N})$ ```cpp= #include<bits/stdc++.h> #define fastio ios::sync_with_stdio(0); cin.tie(0); using namespace std; const int MAXN=2e5+5; int n; int arr[MAXN], tmp[MAXN]; long long merge_sort(int l, int r){ if(l>=r-1) return 0; // 長度小於等於 1,不會有逆序數對 int m=(l+r)/2; long long cnt=merge_sort(l, m)+merge_sort(m, r); // 分割成小問題,先把兩半各自逆序數對求出來 // 現在左右半都排序好了 int i=l, j=m, k=l; while(i<m&&j<r){ if(arr[i]<=arr[j]) tmp[k++]=arr[i++]; else{ cnt+=m-i; // 左指針數值比右指針數值大,左指針以後的所有數都比右指針大 tmp[k++]=arr[j++]; } } while(i<m) tmp[k++]=arr[i++]; while(j<r) tmp[k++]=arr[j++]; for(int p=l; p<r; p++) arr[p]=tmp[p]; return cnt; // 回傳此區間逆序數對數量 } int main(){ fastio cin >> n; for(int i=0; i<n; i++) cin >> arr[i]; cout << merge_sort(0, n) << '\n'; } ``` --- ## 主定理 ---- 像是 Merge sort 這種比較直觀的 可以簡單估複雜度 但如果稍微複雜一點就不容易了 ---- 但主定理(Master Theorm) 可以輕鬆直觀的解決 ---- 假設遞迴分治的複雜度 $T(n)=aT(\frac{n}{b})+f(n)$ 另有一常數 $\epsilon>0$ - 若 $f(n)=O(n^{\log_{b}{a}-\epsilon})$,$T(n)=O(n^{\log_{b}{a}})$ - 若 $f(n)=\Theta(n^{\log_{b}{a}}\log^{\epsilon}{n})$,則$T(n)=\Theta(f(n)\log{n})$ - 若 $f(n)=\Omega(n^{\log_{b}{a}+\epsilon})$,$T(n)=\Theta(f(n))$ ---- 簡單來說 $f(n)$ 和 $O(n^{\log_{b}{a}})$ 比大小 如果**多項式**一樣大就在後面加個 $\log{n}$ 否則就取大的 其中 $a$ 是遞迴子問題的數量 $\frac{n}{b}$ 是遞迴子問題的大小 $f(n)$ 是除了遞迴所需的複雜度 (每個問題自己的複雜度) [證明](https://www.cs.cornell.edu/courses/cs3110/2012fa/recitations/mm-proof.pdf) ---- 來估估看 Merge sort 的複雜度 每次分成 $2$ 個子問題 每個子問題大小為 $\frac{n}{2}$ 問題本身所需複雜度 $O(n)$ ---- $T(n)=2T(\frac{n}{2})+O(n)$ $O(n^{\log_{2}{}2})=O(n)$ 所以在後面加個 $\log{n}$ 因此複雜度為 $O(n\log{n})$ --- ## 平面最近點對 ---- ### [平面最近點對](https://cses.fi/problemset/task/2194/) 給定平面上 $N$ 個點 $(x_i, y_i)$ 求最近的兩個點距離平方是多少 距離平方 $=(x_i-x_j)^2+(y_i-y_j)^2$ $2\le N\le 2\times 10^5$ $-10^9\le x_i, y_i\le 10^9$ ---- 同樣的 如果直接枚舉的話 會是 $O(N^2)$ ---- #### 再來考慮分治 分:把平面切成左右兩半 治:已知左右兩半的最短距離 找出橫跨兩邊的最短距離 ---- 假設已知左右兩半的最短距離 左邊是 $d_l$,右邊是 $d_r$ 目前的最短距離就是 $d=\min(d_l, d_r)$ 現在要找的就是 有沒有橫跨左右的點對距離小於 $d$ ---- 最直接的想法 找到中間的那個點 讓他在左右($x$ 座標)範圍 $d$ 之內 看看是否有點對距離小於 $d$ ---- 有可能所有點 $x$ 座標都在這個距離 $d$ 的範圍內 複雜度回歸 $O(N^2)$ ---- 可以[證明](https://www.cs.cmu.edu/afs/cs/academic/class/15451-s20/www/lectures/lec23-closest-pair.pdf) 對於一個點 $y$ 座標 不超過 $8$ 個點會落在 $y\pm d$ 之內 不超過 $5$ 個點[證明](https://sprout.tw/algo2023/ppt_pdf/week06/divide_and_conquer_inclass(tp).pdf) 不超過 $3$ 個點[證明](https://web.ntnu.edu.tw/~algo/Point2.html) ---- 先找出中間點 $x$ 座標左右範圍 $d$ 之內的所有點 再把這些點對 $y$ 座標排序 $O(N\log{N})$ 對於每個點,找他上面的 $8$ 個點 就完成合併了! ---- 分析複雜度 $T(n)=2T(\frac{n}{2})+O(n\log{n})$ $T(n)=O(n\log^2{n})$ ---- $O(n\log^2{n})$ ```cpp= #include<bits/stdc++.h> #define fastio ios::sync_with_stdio(0); cin.tie(0); #define F first #define S second #define ll long long using namespace std; const ll MAX=9e18; int n; pair<ll, ll> p[200005]; ll DQ(int l, int r){ if(l>=r-1) return MAX; // 一個點沒有距離 int m=(l+r)/2; ll cur=min(DQ(l, m), DQ(m, r)); // 先遞迴 vector<int> v; for(int i=l; i<r; i++){ // 找到離中間點 x 座標差距 d 以內的數 if((p[i].F-p[m].F)*(p[i].F-p[m].F)<=cur){ v.push_back(i); } } sort(v.begin(), v.end(), [&](int a, int b){return p[a].S<p[b].S;}); // 按照 y 座標排序 for(int i=0; i<v.size(); i++){ for(int j=i+1; j<min((int)v.size(), i+5); j++){ // 往上找 5 個點(最近點在裡面) ll x=p[v[i]].F-p[v[j]].F, y=p[v[i]].S-p[v[j]].S; cur=min(cur, x*x+y*y); } } return cur; } int main(){ fastio cin >> n; for(int i=0; i<n; i++) cin >> p[i].F >> p[i].S; sort(p, p+n); // 先對 x 座標排序 cout << DQ(0, n) << '\n'; } ``` ---- 還能更快嗎? ---- 那個 $O(n\log{n})$ 是來自排序 $y$ 座標的 可以用 Merge sort 的思想 先用 $O(n)$ 的時間合併兩個已排序的陣列(按 $y$) 這樣在找 $x$ 座標距離小於 $d$ 的點時 就已經自動幫我們把 $y$ 座標排序了 $T(n)=2T(\frac{n}{2})+O(n)$ $T(n)=O(n\log{n})$ ---- 介紹一個好用的合併函式 `std::merge(a+l, a+r, b+l, b+r, c, cmp);` 可以將已排序的 $a$ 的 $[l, r)$ 和 $b$ 的 $[l, r)$ $O(n)$ 合併到 $c$(也是已排序的) ---- $O(n\log{n})$ ```cpp= #include<bits/stdc++.h> #define fastio ios::sync_with_stdio(0); cin.tie(0); #define F first #define S second #define ll long long #define pll pair<ll, ll> using namespace std; const ll MAX=9e18; int n; pll p[200005]; ll DQ(int l, int r){ if(l>=r-1) return MAX; // 一個點沒有距離 int m=(l+r)/2; ll cur=min(DQ(l, m), DQ(m, r)); // 先遞迴 pll tmp[r-l+5]; merge(p+l, p+m, p+m, p+r, tmp, [&](pll a, pll b){return a.S<b.S;}); // 先照 y 排序 for(int i=l; i<r; i++) p[i]=tmp[i-l]; vector<int> v; for(int i=l; i<r; i++){ // 找到離中間點 x 座標差距 d 以內的數 if((p[i].F-p[m].F)*(p[i].F-p[m].F)<=cur){ v.push_back(i); // 這時的 v 已經照 y 排序了 } } for(int i=0; i<v.size(); i++){ for(int j=i+1; j<min((int)v.size(), i+5); j++){ // 往上找 5 個點(最近點在裡面) ll x=p[v[i]].F-p[v[j]].F, y=p[v[i]].S-p[v[j]].S; cur=min(cur, x*x+y*y); } } return cur; } int main(){ fastio cin >> n; for(int i=0; i<n; i++) cin >> p[i].F >> p[i].S; sort(p, p+n); // 先對 x 座標排序 cout << DQ(0, n) << '\n'; } ``` --- ## 小結 分治需要注意的點 就是看看問題能不能分割 並且遞迴有邊界 接下來就剩下思考怎麼把多個已解決的小問題 合併成現在的答案了! ---- ## 感謝大家
{"title":"03/29 C++社課","contributors":"[{\"id\":\"1a0296c8-ce58-4742-acda-22c02ae81a74\",\"add\":8707,\"del\":974}]"}
    108 views