# Divide and Conquer (分治法) ## 簡介 分治演算法(Divide-and-Conquer algorithm),又稱分而治之演算法,也有人稱為各個擊破法。分治是一種非常重要的演算法思維模式與策略,有很多重要的演算法都是根據分治的思維模式,例如快速排序法、合併排序法、快速傅立葉轉換(FFT)、矩陣乘法、整數乘法以及在一些在計算幾何的知名演算法都是分治的策略。相較於它的重要性,分治出現在題目中的比例卻不是那麼高。 主要原因有二: 1. 很多分治算法的問題都有別的解法的版本,樹狀圖的分治也往往歸類到動態規劃。 2. 一些重要的分治算法已經被納入在庫存函數中,例如排序,還有一些則因為太難而不適合出題。 即使如此,當碰到一個不曾見過的問題,分治往往是第一個思考的重點,因為分治的架構讓我們很容易找到比暴力要好的方法。 ## 核心想法 分治法主要有三個步驟,切割,解決,合併。 - 切割:把問題以「相同問題」切割成更多小的子問題。 - 解決:把切割完的子問題解決。 - 合併:把解決完的子問題合併起來成為問題的答案。 注意,把問題切割並不是將問題切成不同的問題,而是相同的問題,化成很多比較小的輸入資料。 ## 最大值與最小值問題 ### 問題 在一組數據中,找到其最大值與最小值 ### 思路 - 切割:把區間分成好幾個部分 - 解決:分別求出每個分割出來的最大最小值 - 合併:把每個區間的最大最小值回傳,並且比較誰最大(小) PS:這個問題太簡單了,直接解就可以了,這邊用分治只是為了更清楚分治的流程。 ### 實作 (這邊使用struct/pair讓大家多熟練一下) ```cpp= //struct 實作 #include <bits/stdc++.h> using namespace std; struct info { int mx; int mn; }; vector<int> ls; info solve (int left, int right) { info ans ; if (right - left <= 1) { ans.mx = max(ls[left],ls[right]); ans.mn = min(ls[left],ls[right]); return ans; } ans.mx = max(solve(left,(left+right)/2).mx , solve((left+right)/2 + 1,right).mx); ans.mn = min(solve(left,(left+right)/2).mn , solve((left+right)/2 + 1,right).mn); return ans; } int main() { int n; cin >> n; ls.resize(n); for (int i = 0 ; i < n ; i++) { cin >> ls[i]; } info ans = solve(0,n-1); cout << "max = " << ans.mx << " , min = " << ans.mn << '\n'; } ``` ```cpp= //pair實作 #include <bits/stdc++.h> using namespace std; vector<int> ls; //假設pair的第一個是最大值,第二個是最小值 pair<int,int> solve (int left, int right) { pair<int,int> ans ; if (right - left <= 1) { ans.first = max(ls[left],ls[right]); ans.second = min(ls[left],ls[right]); return ans; } ans.first = max(solve(left,(left+right)/2).first , solve((left+right)/2 + 1,right).first); ans.second = min(solve(left,(left+right)/2).second , solve((left+right)/2 + 1,right).second); return ans; } int main() { int n; cin >> n; ls.resize(n); for (int i = 0 ; i < n ; i++) { cin >> ls[i]; } pair<int,int> ans = solve(0,n-1); cout << "max = " << ans.first << " , min = " << ans.second << '\n'; } ``` ## 回顧合併排序法(待補) ### 實作 ```cpp= #include <bits/stdc++.h> using namespace std; vector<int> ls; vector<int> merge_sort (int left, int right) { if (right - left <= 1) { if (left == right) { return {ls[left]}; }else { if (ls[left] < ls[right]) return {ls[left],ls[right]}; else return {ls[right],ls[left]}; } } vector<int> m1 = merge_sort(left,(left+right)/2); vector<int> m2 = merge_sort((left+right)/2 + 1,right); vector<int> merge_list(right - left + 1); int l1 = 0 , l2 = 0, now = 0; while (l1 != m1.size() || l2 != m2.size()) { if (l1 == m1.size()) { merge_list[now++] = m2[l2++]; }else if (l2 == m2.size()) { merge_list[now++] = m1[l1++]; }else { if (m1[l1] < m2[l2]) merge_list[now++] = m1[l1++]; else merge_list[now++] = m2[l2++]; } } return merge_list; } int main() { ios::sync_with_stdio(false),cin.tie(0); int n ; cin >> n; ls.resize(n); for (int i = 0 ; i < n ; i++) cin >> ls[i]; ls = merge_sort(0,n-1); for (int i = 0 ; i < n ; i++) cout << ls[i] << ' '; cout << '\n'; return 0; } ``` ## 回顧快速排序法(待補) ### 實作 ```cpp= #include<bits/stdc++.h> using namespace std; vector<int> vec; void quicksort (int first, int end) { if (first >= end) return; int j = first+1; for (int i = first+1 ; i <= end ; i++) { if (vec[i]<vec[first]) { swap(vec[j],vec[i]); j++; } } swap(vec[first],vec[j-1]); int mid=j-1; quicksort(first,mid); quicksort(mid+1,end); } int main(){ ios::sync_with_stdio(false),cin.tie(0); vector<int> temp; int n; cin >> n; temp.resize(n); for (int i = 0 ; i < n ; i++) cin >> temp[i]; quicksort(0,n-1); for (int i = 0 ; i < n ; i++) cout << temp[i] << " "; } ``` ## 分治的複雜度 分治是一個遞迴的演算法,不像迴圈的複雜度可以用加總的方法或是乘法計算,遞迴的複雜度是由遞迴關係式(recurrence relation)所表達。計算複雜度必須解遞迴關係式。遞迴關係又稱為差分方程式(difference equation),解遞迴關係是個複雜的事情。 分治演算法的常見形式是將一個大小為 n 的問題切成 a 個大小為 b 的子問題此外必須做一些分割與合併的工作。假設大小為 n 的問題的複雜度是 T(n),而分割合併需要的時間是 f(n),我們可以得到以下遞迴關係: T(n) = a * T(n/b) + f(n),這裡我們省略不整除的處理,因為只計算 big-O,通常不會造成影響。這個遞迴式子幾乎是絕大部分分治時間複雜度的共同形式,所以也有人稱為 divide-and-conquer recurrences,對於絕大部分會碰到的狀況 (a,b, f(n)),這個遞迴式都有公式解,有興趣的請上網查詢 [Master theorem](https://en.wikipedia.org/wiki/Master_theorem_(analysis_of_algorithms))。 分治的複雜度算法,主要有兩種方式,對數學有興趣的可以自行查詢。 1. Recursion-Tree Method搭配Substitution Method(「數學歸納法」) 2. Master Method ## 經典例子 反序問題 ### 題目 考慮一個數列 A[1:n]。如果 A 中兩個數字 A[i]和 A[j]滿足 i<j 且 A[i]>A[j],也就是在前面的比較大,則我們說(a[i],a[j])是一個反序對(inversion)。定義W(A)為數列 A 中反序對數量。 例如,在數列 A=(3,1,9,8,9,2)中,一共有(3,1)、(3,2)、(9,8)、(9,2)、(8,2)、(9,2)一共 6 個反序對,所以 W(A)=6。請注意到序列中有兩個 9 都在 2 之前,因此有兩個(9,2)反序對,也就是說,不同位置的反序對都要計算,不管兩對的內容是否一樣。請撰寫一個程式,計算一個數列 A 的反序數量W(A) 。 輸入格式:第一行是一個正整數 n,代表數列長度,第二行有 n 個非負整數,是依序數列內容,數字間以空白隔開。 n 不超過 1e5 數列內容不超過 1e6。 範例輸入: 6 3 1 9 8 9 2 範例結果: 6 ### 思路 可能會想,就每個都判斷他後面的所有數字有沒有吻合就好了,但這樣的時間複雜度為$O(n^2)$,很明顯不夠好。用分治法的思路來看,。要計算一個區間的反序數量,將區間平分為兩段,一個反序對的兩個數可能都在左邊或都在右邊,否則就是跨在左右兩邊。都在同一邊的可以遞迴解,我們只要會計算跨左右兩邊的就行了,也就是對每一個左邊的元素 x,要算出右邊比 x 小的有幾個。假設對左邊的每一個,去檢查右邊的每一個,那會花 $O(n^2)$,不行,我們聚焦的目標是在不到 $O(n^2)$的時間內完成跨左右的反序數量計算。 記得之前學過的排序應用,假設我們將右邊排序,花 $O(nlog(n))$,然後,每個左邊的 x 可以用二分搜就可算出右邊有多少個小於x,這樣只要花$O(nlog(n))$的時間就可以完成合併,那麼,根據分治複雜度 $T(n) = 2\space T(n/2)+O(nlog(n))$ , $T(n)$的結果是 $O(nlog^2(n))$。 ### 實作 這邊在二分搜的部分使用stl的lower_bound。 ```cpp= #include <bits/stdc++.h> using namespace std; int divide (vector<int> ls, int left, int right) { int mid = (left+right)/2; if (left+1 >= right) return 0 ; long long ans = divide(ls,left,mid) + divide(ls,mid,right); sort(ls.begin() + mid , ls.begin() + right); for (int i = left ; i < mid ; i++) { ans += lower_bound(ls.begin()+mid,ls.begin()+right,ls[i]) - (ls.begin()+mid); } return ans ; } int main() { int n; cin >> n; vector<int> ls(n); for (int i = 0 ; i < n ; i++) { cin >> ls[i]; } cout << divide(ls,0,n) << '\n'; return 0; } ``` 也可以自己寫一個bsearch(),這邊一樣用一般的二分搜 ```cpp= #include <bits/stdc++.h> using namespace std; int bsearch (vector<int> ls, int left, int right, int val) { while (left < right){ int mid = (left + right)/2; if (ls[mid] < val) left = mid + 1; else right = mid; } return (left+right)/2; } int divide (vector<int> ls, int left, int right) { int mid = (left+right)/2; if (left+1 >= right) return 0 ; long long ans = divide(ls,left,mid) + divide(ls,mid,right); sort(ls.begin() + mid , ls.begin() + right); for (int i = left ; i < mid ; i++) { ans += bsearch(ls,mid,right,ls[i]) - mid; } return ans ; } int main() { int n; cin >> n; vector<int> ls(n); for (int i = 0 ; i < n ; i++) { cin >> ls[i]; } cout << divide(ls,0,n) << '\n'; return 0; } ``` ### 還能更快? 上面的寫法是$O(nlog^2(n))$,雖然已經比$O(n^2)$好的多了。不過我們可以注意到,我們只需要再回傳的時候sort就好了,不需要每一次都進行單邊排序。 也就是說在分治的部分可以改寫成 ```cpp= int divide (vector<int> &ls, int left, int right) { int mid = (left+right)/2; if (left+1 >= right) return 0 ; long long ans = divide(ls,left,mid) + divide(ls,mid,right); for (int i = left; i < mid ; i++) { ans += lower_bound (ls.begin()+mid,ls.begin()+right,ls[i]) - (ls.begin()+mid); } sort(ls.begin()+left,ls.begin()+right); return ans ; } ``` 在搜尋的時候其實跟一路慢慢逛過去的時間複雜度差不多,所以其實不需要多用二分。 ```cpp= int divide (vector<int> &ls, int left, int right) { int mid = (left+right)/2; if (left+1 >= right) return 0 ; long long ans = divide(ls,left,mid) + divide(ls,mid,right); for (int i = left , j = mid; i < mid ; i++) { while (j<right && ls[j]<ls[i]) j++; ans += j - mid; } sort(ls.begin()+left,ls.begin()+right); return ans ; } ``` 不過這樣會因為排序所以複雜度還是$O(nlog^2(n))$啊,不過上面的這種寫法,是不是很像用merge sort呢,所以我們在一個很像merge sort裡面還用sort(),很多此一舉,只要配合merge sort的寫法,我們就可以將複雜度降到$O(nlog(n))$了。也就是在兩個排序好的資料中,本身就已經是反序數了。 所以就會改寫成 ```cpp= int divide (vector<int> &ls, int left, int right) { int mid = (left+right)/2; if (left+1 >= right) return 0 ; long long ans = divide(ls,left,mid) + divide(ls,mid,right); int temp[right-left], k=0, j = mid; for (int i = left ; i < mid ; i++) { while (j<right && ls[j]<ls[i]) temp[k++] = ls[j++]; temp[k++] = ls[i]; ans += j - mid; } for (k=left; k < j; k++) ls[k] = temp[k-left]; return ans ; } ``` 這樣時間複雜度就從$O(nlog^2(n))$降為$O(nlog(n))$了 ###### tags: `演算法` {%hackmd aPqG0f7uS3CSdeXvHSYQKQ %}