# 分治
## 3/28 社課
---
### 什麼是分治?
**分**:把大問題分成小問題
**治**:先把小問題解決,再合併回大問題
**分而治之**
----
### 再分得細一點
- 分割問題
- 解決小問題
- 合併問題
這三點就是分治的主要核心!
---
## Merge sort
給定一個長度為 $N$ 的序列 $a$
請你排序他
----
## Merge sort
又回到了排序問題
這次用分治解決看看
----
### 分
**把大問題分成小問題**
大問題:排序 $[1, N]$ 的序列
小問題:排序 $[1, \frac{N}{2}]$ 和 $[\frac{N}{2}+1, N]$ 的序列
----
### 治
**先解決小問題,再合併回大問題**
遞迴到最後都會是長度為 $1$ 的序列
也就是這些序列都已經排序好了(解決最小問題)
----
### 怎麼合併
問題變成
**有兩個已經排好的序列
要合併成一個排序好的序列**
----
先在兩個序列的開頭各放一個指針
這個指針越向右移,數值就會越大
----
每次比較兩個指針所指的數值
將比較小的那個數放到新的(已排序好的)序列
並且移動那個數值比較小的指針
----
{**1**, 3, 4, 6, 7}
{**2**, 4, 5, 8, 9}
{}
----
{1, **3**, 4, 6, 7}
{**2**, 4, 5, 8, 9}
{1}
----
{1, **3**, 4, 6, 7}
{2, **4**, 5, 8, 9}
{1, 2}
----
{1, 3, **4**, 6, 7}
{2, **4**, 5, 8, 9}
{1, 2, 3}
----
{1, 3, 4, **6** ,7}
{2, **4**, 5, 8, 9}
{1, 2, 3, 4}
----
{1, 3, 4, **6**, 7}
{2, 4, **5**, 8, 9}
{1, 2, 3, 4, 4}
----
{1, 3, 4, **6**, 7}
{2, 4, 5, **8**, 9}
{1, 2, 3, 4, 4, 5}
----
{1, 3, 4, 6, **7**}
{2, 4, 5, **8**, 9}
{1, 2, 3, 4, 4, 5, 6}
----
{1, 3, 4, 6, 7}
{2, 4, 5, **8**, 9}
{1, 2, 3, 4, 4, 5, 6, 7}
----
{1, 3, 4, 6, 7}
{2, 4, 5, 8, **9**}
{1, 2, 3, 4, 4, 5, 6, 7, 8}
----
{1, 3, 4, 6, 7}
{2, 4, 5, 8, 9}
{1, 2, 3, 4, 4, 5, 6, 7, 8, 9}
----
只會操作序列長度的次數 $O(N)$
每次將序列分成兩半
共有 $O(\log{N})$ 層
複雜度 $O(N\log{N})$!
----
### Merge sort 程式碼
```cpp=
#include<bits/stdc++.h>
#define fastio ios::sync_with_stdio(0); cin.tie(0);
using namespace std;
const int MAXN=2e5+5;
int n;
int arr[MAXN], tmp[MAXN]; // tmp: 暫存陣列
void merge_sort(int l, int r){ // 左閉右開區間
if(l>=r-1) return; // 長度小於等於 1
int m=(l+r)/2; // 把陣列分成兩半
merge_sort(l, m);
merge_sort(m, r); // 分割成小問題
int i=l, j=m, k=l; // i: 左區間指針,j: 右區間指針,k: 合併區間指針
while(i<m&&j<r){
if(arr[i]<=arr[j]) tmp[k++]=arr[i++];
else tmp[k++]=arr[j++];
}
while(i<m) tmp[k++]=arr[i++]; // 左區間還有剩
while(j<r) tmp[k++]=arr[j++]; // 右區間還有剩
for(int p=l; p<r; p++) arr[p]=tmp[p]; // 把排序好的陣列存回原陣列
}
int main(){
fastio
cin >> n;
for(int i=0; i<n; i++) cin >> arr[i];
merge_sort(0, n);
for(int i=0; i<n; i++) cout << arr[i] << ' ';
}
```
----
## [逆序數對](https://tioj.ck.tp.edu.tw/problems/1080)
在一個序列 $a$ 中
若存在 $i<j$ 而 $a_i>a_j$
則稱 $(i, j)$ 為一個**逆序數對**
給定長度為 $N$ 的序列 $a$
請求出序列 $a$ 中有多少逆序數對
----
一樣大問題先分成小問題
把序列切一半
數量會是
左區間的數量 + 右區間的數量
\+ 橫跨左右區間的數量
----
假設左右兩半都求出來了
要怎麼合併?
----
如果左右的序列都是亂的
那我們只能一個一個排
$O(N^2)$
BAD :(
----
按照 Merge sort 的思想
我們邊排序,邊求數量
----
因為左右兩邊都已經排序好了
左指針到左區間尾部一定是**遞增**的
也就是說
左指針數值如果比右指針數值大
那從左指針一直到左區間尾部一定都比右指針大
因此我們可以在 Merge sort 的過程中
順便用 $O(1)$ 的時間找此時的逆序數對數量
----
又是一個 $O(N\log{N})$
```cpp=
#include<bits/stdc++.h>
#define fastio ios::sync_with_stdio(0); cin.tie(0);
using namespace std;
const int MAXN=2e5+5;
int n;
int arr[MAXN], tmp[MAXN];
long long merge_sort(int l, int r){
if(l>=r-1) return 0; // 長度小於等於 1,不會有逆序數對
int m=(l+r)/2;
long long cnt=merge_sort(l, m)+merge_sort(m, r); // 分割成小問題,先把兩半各自逆序數對求出來
// 現在左右半都排序好了
int i=l, j=m, k=l;
while(i<m&&j<r){
if(arr[i]<=arr[j]) tmp[k++]=arr[i++];
else{
cnt+=m-i; // 左指針數值比右指針數值大,左指針以後的所有數都比右指針大
tmp[k++]=arr[j++];
}
}
while(i<m) tmp[k++]=arr[i++];
while(j<r) tmp[k++]=arr[j++];
for(int p=l; p<r; p++) arr[p]=tmp[p];
return cnt; // 回傳此區間逆序數對數量
}
int main(){
fastio
cin >> n;
for(int i=0; i<n; i++) cin >> arr[i];
cout << merge_sort(0, n) << '\n';
}
```
---
## 主定理
----
像是 Merge sort 這種比較直觀的
可以簡單估複雜度
但如果稍微複雜一點就不容易了
----
但主定理(Master Theorm)
可以輕鬆直觀的解決
----
假設遞迴分治的複雜度
$T(n)=aT(\frac{n}{b})+f(n)$
另有一常數 $\epsilon>0$
- 若 $f(n)=O(n^{\log_{b}{a}-\epsilon})$,$T(n)=O(n^{\log_{b}{a}})$
- 若 $f(n)=\Theta(n^{\log_{b}{a}}\log^{\epsilon}{n})$,則$T(n)=\Theta(f(n)\log{n})$
- 若 $f(n)=\Omega(n^{\log_{b}{a}+\epsilon})$,$T(n)=\Theta(f(n))$
----
簡單來說
$f(n)$ 和 $O(n^{\log_{b}{a}})$ 比大小
如果**多項式**一樣大就在後面加個 $\log{n}$
否則就取大的
其中 $a$ 是遞迴子問題的數量
$\frac{n}{b}$ 是遞迴子問題的大小
$f(n)$ 是除了遞迴所需的複雜度
(每個問題自己的複雜度)
[證明](https://www.cs.cornell.edu/courses/cs3110/2012fa/recitations/mm-proof.pdf)
----
來估估看 Merge sort 的複雜度
每次分成 $2$ 個子問題
每個子問題大小為 $\frac{n}{2}$
問題本身所需複雜度 $O(n)$
----
$T(n)=2T(\frac{n}{2})+O(n)$
$O(n^{\log_{2}{}2})=O(n)$
所以在後面加個 $\log{n}$
因此複雜度為 $O(n\log{n})$
---
## 平面最近點對
----
### [平面最近點對](https://cses.fi/problemset/task/2194/)
給定平面上 $N$ 個點 $(x_i, y_i)$
求最近的兩個點距離平方是多少
距離平方 $=(x_i-x_j)^2+(y_i-y_j)^2$
$2\le N\le 2\times 10^5$
$-10^9\le x_i, y_i\le 10^9$
----
同樣的
如果直接枚舉的話
會是 $O(N^2)$
----
#### 再來考慮分治
分:把平面切成左右兩半
治:已知左右兩半的最短距離
找出橫跨兩邊的最短距離
----
假設已知左右兩半的最短距離
左邊是 $d_l$,右邊是 $d_r$
目前的最短距離就是 $d=\min(d_l, d_r)$
現在要找的就是
有沒有橫跨左右的點對距離小於 $d$
----
最直接的想法
找到中間的那個點
讓他在左右($x$ 座標)範圍 $d$ 之內
看看是否有點對距離小於 $d$
----
有可能所有點 $x$ 座標都在這個距離 $d$ 的範圍內
複雜度回歸 $O(N^2)$
----
可以[證明](https://www.cs.cmu.edu/afs/cs/academic/class/15451-s20/www/lectures/lec23-closest-pair.pdf)
對於一個點 $y$ 座標
不超過 $8$ 個點會落在 $y\pm d$ 之內
不超過 $5$ 個點[證明](https://sprout.tw/algo2023/ppt_pdf/week06/divide_and_conquer_inclass(tp).pdf)
不超過 $3$ 個點[證明](https://web.ntnu.edu.tw/~algo/Point2.html)
----
先找出中間點 $x$ 座標左右範圍 $d$ 之內的所有點
再把這些點對 $y$ 座標排序 $O(N\log{N})$
對於每個點,找他上面的 $8$ 個點
就完成合併了!
----
分析複雜度
$T(n)=2T(\frac{n}{2})+O(n\log{n})$
$T(n)=O(n\log^2{n})$
----
$O(n\log^2{n})$
```cpp=
#include<bits/stdc++.h>
#define fastio ios::sync_with_stdio(0); cin.tie(0);
#define F first
#define S second
#define ll long long
using namespace std;
const ll MAX=9e18;
int n;
pair<ll, ll> p[200005];
ll DQ(int l, int r){
if(l>=r-1) return MAX; // 一個點沒有距離
int m=(l+r)/2;
ll cur=min(DQ(l, m), DQ(m, r)); // 先遞迴
vector<int> v;
for(int i=l; i<r; i++){ // 找到離中間點 x 座標差距 d 以內的數
if((p[i].F-p[m].F)*(p[i].F-p[m].F)<=cur){
v.push_back(i);
}
}
sort(v.begin(), v.end(), [&](int a, int b){return p[a].S<p[b].S;}); // 按照 y 座標排序
for(int i=0; i<v.size(); i++){
for(int j=i+1; j<min((int)v.size(), i+5); j++){ // 往上找 5 個點(最近點在裡面)
ll x=p[v[i]].F-p[v[j]].F, y=p[v[i]].S-p[v[j]].S;
cur=min(cur, x*x+y*y);
}
}
return cur;
}
int main(){
fastio
cin >> n;
for(int i=0; i<n; i++) cin >> p[i].F >> p[i].S;
sort(p, p+n); // 先對 x 座標排序
cout << DQ(0, n) << '\n';
}
```
----
還能更快嗎?
----
那個 $O(n\log{n})$ 是來自排序 $y$ 座標的
可以用 Merge sort 的思想
先用 $O(n)$ 的時間合併兩個已排序的陣列(按 $y$)
這樣在找 $x$ 座標距離小於 $d$ 的點時
就已經自動幫我們把 $y$ 座標排序了
$T(n)=2T(\frac{n}{2})+O(n)$
$T(n)=O(n\log{n})$
----
介紹一個好用的合併函式
`std::merge(a+l, a+r, b+l, b+r, c, cmp);`
可以將已排序的 $a$ 的 $[l, r)$ 和 $b$ 的 $[l, r)$
$O(n)$ 合併到 $c$(也是已排序的)
----
$O(n\log{n})$
```cpp=
#include<bits/stdc++.h>
#define fastio ios::sync_with_stdio(0); cin.tie(0);
#define F first
#define S second
#define ll long long
#define pll pair<ll, ll>
using namespace std;
const ll MAX=9e18;
int n;
pll p[200005];
ll DQ(int l, int r){
if(l>=r-1) return MAX; // 一個點沒有距離
int m=(l+r)/2;
ll cur=min(DQ(l, m), DQ(m, r)); // 先遞迴
pll tmp[r-l+5];
merge(p+l, p+m, p+m, p+r, tmp, [&](pll a, pll b){return a.S<b.S;}); // 先照 y 排序
for(int i=l; i<r; i++) p[i]=tmp[i-l];
vector<int> v;
for(int i=l; i<r; i++){ // 找到離中間點 x 座標差距 d 以內的數
if((p[i].F-p[m].F)*(p[i].F-p[m].F)<=cur){
v.push_back(i); // 這時的 v 已經照 y 排序了
}
}
for(int i=0; i<v.size(); i++){
for(int j=i+1; j<min((int)v.size(), i+5); j++){ // 往上找 5 個點(最近點在裡面)
ll x=p[v[i]].F-p[v[j]].F, y=p[v[i]].S-p[v[j]].S;
cur=min(cur, x*x+y*y);
}
}
return cur;
}
int main(){
fastio
cin >> n;
for(int i=0; i<n; i++) cin >> p[i].F >> p[i].S;
sort(p, p+n); // 先對 x 座標排序
cout << DQ(0, n) << '\n';
}
```
---
## 小結
分治需要注意的點
就是看看問題能不能分割
並且遞迴有邊界
接下來就剩下思考怎麼把多個已解決的小問題
合併成現在的答案了!
----
## 感謝大家
{"title":"03/29 C++社課","contributors":"[{\"id\":\"1a0296c8-ce58-4742-acda-22c02ae81a74\",\"add\":8707,\"del\":974}]"}