# 二分搜 簡而言之就是用 $O(log(n))$ 的時間複雜度去在一個單調陣列中找到索引值最小(或最大)且滿足題目所求之值 ## 二分搜的基本介紹 先看這道題,想想如果是你會怎麼做 ### 題意 有一個長度為 $n$ 且只由 $0$ 和 $1$ 組成的陣列,且滿足所有 $0$ 都在 $1$ 的左方,問你第一個出現的 $1$ 的 index :::spoiler 如果看不懂題目的 ||對不起我不會敘述反正|| 陣列可以長 ```0 0 0 1 1 1 1``` 或 ```0 0 1 1 1``` 但不能長這樣 ```1 0 0 1 1 1``` 或 ```0 0 0 1 1 0``` 反正就是不能有任何 $1$ 在 $0$ 的左邊,也不能有任何 $0$ 在 $1$ 的右邊,然後問你從左數到右第一個 $1$ 出現的位置 ::: ### why 二分搜 ? 若是使用傳統的方法,直接用`for`迴圈由左找到右,時間複雜度會是 $O(n)$,若 $n$ 太大將會超時 **所以在此我們要介紹二分搜** ### 概念介紹 二分搜的基本精神,就是維護一個閉區間 $[left,\ right]$ 並在每一步驟都慢慢縮小,且確保答案的值在此區間中,直到此區間中只有一個數值,也就是 $left == right$,此值就將會是我們要的答案 以此題為例 我們先定義在索引值 $n$ 的地方有一個必能滿足題目所求之值,所以如果最後搜尋到的索引值為 $n$ 會代表陣列中沒有任何值會符合 (也就是整個陣列都是 $0$ ) 在進行二分搜前,初始化兩個邊界,$left = 0,\ right = n$ 再來,將執行以下步驟直到 $left == right$ 1. 用 $mid$ 將此陣列分為兩半 $[left, mid]$ 和 $[mid+1, right]$ $int \ mid = (right + left)\ /\ 2;$ 2. 分為以下兩種情況討論 * 若是 $arr[\ mid\ ] == 0$ : 代表答案在 $mid$ 的右邊,也就是 $[mid+1, right]$ 區間內,所以使左界 $left = mid+1$ * 若是 $arr[\ mid\ ] == 1$ : 代表答案在 $mid$ 的左邊或是答案 $=mid$,也就是 $[left, mid]$ 區間內,所以使右界 $right = mid$ :::spoiler 圖示 ||對不起我畫好醜|| **$n=8$** **$arr=\{0,0,1,1,1,1,1,1\}$** ![S__2949134_0](https://hackmd.io/_uploads/SyTx0iYebe.jpg) ![S__2949135_0](https://hackmd.io/_uploads/BkbppoKlbl.jpg) ![S__2949136_0](https://hackmd.io/_uploads/SJYA6itebl.jpg) ![S__2949137_0](https://hackmd.io/_uploads/rJv1AiYlZe.jpg) 到這裡我們就可以發現左右邊界相等,所以答案為此值 (2) ::: 提醒一下,不用擔心 $mid$ 在 $arr$ 中會溢出,因為若是跑到迴圈裡,必會滿足 $left<right$,又 $mid = (right + left)\ /\ 2$ 且 $right\ \le\ n$ , 所以 $mid$ 不可能會等於 $n$ ,所以不會有溢位的問題 ### 實作 ```cpp= #include <iostream> using namespace std; int arr[100001]; int main(){ int n; cin >> n; for(int i = 0; i < n; i++){ cin >> arr[i]; } //初始化左界,右界 int left = 0, right = n; while(left != right){ int mid = (left + right) / 2; //藉由 mid 的狀態來判斷答案在其右邊(包含mid)抑或者是左邊 if(arr[mid] == 1){ right = mid; }else{ left = mid + 1; } } cout << right << '\n'; } ``` ## 二分搜可以幹嘛 ? 若你認為二分搜只能處理上面的那種題目,那就大錯特錯了,以下為二分搜最常見之應用 ### 題義 在一個嚴格非遞減的陣列 `arr` 中,定義一數 $x$,我們可以透過二分搜實現以 $O(log n)$ 的時間複雜度找到一滿足下列條件的最小 $y$ 值 : ```arr[y] >= x``` ### 邏輯 根據前面的介紹,你可以將此陣列的值分為兩種 : 1. $arr[\ i\ ] < x$ 2. $arr[\ i\ ] \ge x$ 且把第一種情況當作 $0$,第二種情況當作 $1$ 這樣你就可以按照前一題的思考邏輯,去找到第一個 $1$ 出現的索引值,也就是第一個滿足 $arr[\ i\ ] \ge x$ 的 $i$ 值 ||也就是 $y$|| ### 程式 ```cpp= #include <iostream> #include <vector> #include <algorithm> using namespace std; vector<int> vec; int main(){ //n 為陣列長度, k為想找的值 int n, k; cin >> n >> k; vec.resize(n); for(int i = 0; i < n; i++){ cin >> vec[i]; } //這步很重要!!! 因二分搜只能用在單調性序列上,所以需要排序 sort(vec.begin(), vec.end()); int left = 0, right = n; while(left != right){ int mid = (left + right) / 2; //藉由 mid 的狀態來判斷答案在其右邊(包含mid)抑或者是左邊 if(vec[mid] >= k) right = mid; else left = mid + 1; } if(right == n) cout << "無任一數符合" << '\n'; else{ cout << "排序後的索引值為 " << right << '\n'; cout << "值為 " << vec[right] << '\n'; } } ``` 因此處的right是排序後的索引值,若是要求排序前的,需使用pair儲存,可參考以下程式 :::spoiler Code ```cpp= #include <iostream> #include <vector> #include <algorithm> using namespace std; vector<pair<int, int>> vec; int main(){ //n 為陣列長度, k為想找的值 int n, k; cin >> n >> k; vec.resize(n); for(int i = 0; i < n; i++){ int t; cin >> t; //vec[i].first 存值, second 存排序前索引值 vec[i] = {t, i}; } //這步很重要!!! 因二分搜只能用在單調性序列上,所以需要排序 sort(vec.begin(), vec.end()); int left = 0, right = n; while(left != right){ int mid = (left + right) / 2; //藉由 mid 的狀態來判斷答案在其右邊(包含mid)抑或者是左邊 if(vec[mid].first >= k) right = mid; else left = mid + 1; } if(right == n) cout << "無任一數符合" << '\n'; else{ cout << "索引值為 " << vec[right].second << '\n'; cout << "值為 " << vec[right].first << '\n'; } } ``` ::: ### 練習題 #### [ABC231 C – Counting 2](https://atcoder.jp/contests/abc231/tasks/abc231_c) :::spoiler 題意 給定一個長度為 $N$ 的整數序列 $A$ 接著有 $Q$ 個查詢,每個查詢給一個整數 $x$ 對於每個查詢,你需要回答: 在 $A$ 中,有多少元素的值 $\ge x$? ||$by\ chatGPT$|| ::: :::spoiler 思路 先將陣列排序,透過二分搜尋找第一個 $\ge x$的索引值(排序後的) $y$,$n-y$ 即為所求 ::: :::spoiler AC Code ```cpp= #include <iostream> #include <vector> #include <algorithm> using namespace std; vector<int> vec; int main(){ int n, q; cin >> n >> q; vec.resize(n); for(int i = 0; i < n; i++){ cin >> vec[i]; } sort(vec.begin(), vec.end()); for(int i = 0; i < q; i++){ int x; cin >> x; int left = 0, right = n; while(left != right){ int mid = (left + right) / 2; if(vec[mid] >= x) right = mid; else left = mid + 1; } cout << n - right << '\n'; } } ``` ::: #### [Zerojudge d732. 二分搜尋法](https://zerojudge.tw/ShowProblem?problemid=d732) :::spoiler 思路 利用 ```pair<int, int>``` 儲存 {值, 初始索引值},一樣拿去 ```sort``` 之後用二分搜尋找第一個 ```>=x``` 的 ```vec[i].first``` 再分為三種情況,若 ```i==n``` 或是 ```vec[i].first != x``` 就輸出 $0$,否則輸出 ```vec[i].second``` ::: :::spoiler AC code ```cpp= #include <iostream> #include <vector> #include <algorithm> using namespace std; vector<pair<int, int>> vec; int main() { int n, k; cin >> n >> k; for(int i = 0; i < n; i++){ int t; cin >> t; vec.emplace_back(t, i+1); } sort(vec.begin(), vec.end()); for(int i = 0; i < k; i++){ int x; cin >> x; //初始化左右界 int left = 0, right = n; while(left != right){ int mid = (left + right) / 2; if(vec[mid].first >= x) right = mid; else left = mid + 1; } if(right == n || vec[right].first != x) cout << 0 << '\n'; else cout << vec[right].second << '\n'; } } ``` ::: ## STL 中的二分搜 如果使用二分搜都用手刻的話可能有點麻煩,所以我們可以使用以下兩種函式快速尋找指向陣列中第一個大於或大於等於某個數的迭代器 但首先,我們需要先引用下述之標頭檔 ```cpp #include <algorithm> ``` ### lower_bound ```lower_bound``` 的功能就是回傳指向陣列中第一個 $\ge x$ 的值的迭代器,使用方法如下 : ```cpp auto it = lower_bound(vec.begin(), vec.end(), x); ``` 詳細使用方法請參考程式 ```cpp= #include <iostream> #include <vector> #include <algorithm> using namespace std; vector<int> vec; int main(){ //n 為陣列長度,x為欲找之值 int n, x; cin >> n >> x; vec.resize(n); for(int i = 0; i < n; i++){ cin >> vec[i]; } //還是要記得二分搜只有在單調序列中可以用,所以要排序 sort(vec.begin(), vec.end()); auto it = lower_bound(vec.begin(), vec.end(), x); //若是無任一數符合情況,迭代器將會指向 vec.end() if(it == vec.end()) cout << "無任一數符合\n"; else{ //若需取值,需在迭代器前加上 * 號 cout << "第一個 >= n 的值為 " << *it << '\n'; //若要取得索引值,只需用 it - vec.begin() 即可,因此值之位置和索引值為 0 之位置相減為索引值之位置 cout << "此數排序後之索引值為 " << it - vec.begin() << '\n'; } } ``` :::info 總而言之有以下三點需要注意的 * 若是無任一數符合,回傳的迭代器會指向 $vec.end()$ * 若要取此迭代器指向的值,需輸出 $*it$ * 若要取索引值,需輸出 $it - vec.begin()$ ::: ### upper_bound ```upper_bound``` 的功能就是回傳指向陣列中第一個 $>x$ 的值的迭代器,使用方法如下 : ```cpp auto it = upper_bound(vec.begin(), vec.end(), x); ``` 整體而言和 ```lower_bound``` 大同小異,只差在於一個是取 $\ge x$ 一個是取 $>x$ ```cpp= #include <iostream> #include <vector> #include <algorithm> using namespace std; vector<int> vec; int main(){ //n 為陣列長度,x為欲找之值 int n, x; cin >> n >> x; vec.resize(n); for(int i = 0; i < n; i++){ cin >> vec[i]; } //還是要記得二分搜只有在單調序列中可以用,所以要排序 sort(vec.begin(), vec.end()); auto it = upper_bound(vec.begin(), vec.end(), x); //若是無任一數符合情況,迭代器將會指向 vec.end() if(it == vec.end()) cout << "無任一數符合\n"; else{ //若需取值,需在迭代器前加上 * 號 cout << "第一個 > n 的值為 " << *it << '\n'; //若要取得索引值,只需用 it - vec.begin() 即可,因此值之位置和索引值為 0 之位置相減為索引值之位置 cout << "此數排序後之索引值為 " << it - vec.begin() << '\n'; } } ``` ### 練習題 #### [ABC231 C – Counting 2](https://atcoder.jp/contests/abc231/tasks/abc231_c) 若是用 ```lower_bound``` 寫此題的話要怎麼做呢 ? :::spoiler AC Code ```cpp= #include <iostream> #include <vector> #include <algorithm> using namespace std; vector<int> vec; int main(){ int n, q; cin >> n >> q; vec.resize(n); for(int i = 0; i < n; i++){ cin >> vec[i]; } sort(vec.begin(), vec.end()); for(int i = 0; i < q; i++){ int x; cin >> x; auto it = lower_bound(vec.begin(), vec.end(), x); //index = it - vec.begin() cout << n - (it - vec.begin()) << '\n'; } } ``` ::: #### [ABC077 C – Snuke Festival](https://atcoder.jp/contests/abc077/tasks/arc084_a) ::: spoiler 題意 有三個長度皆為 $N$ 的序列 A, B, C,你需要計算有多少三元組 (a, b, c) 滿足: $a\ \in\ A,\ b\ \in\ B,\ c\ \in\ C$ 且 $a<b<c$ ::: :::spoiler 思路 此題 $N$ 最大可到 $10^5$,若是直接用```for```迴圈硬炸的話,肯定會超時,所以我們可以透過以下方法來達到 $O(nlogn)$ 首先先枚舉 $b$ ,對於每個 $b$,可透過二分搜的用 $A$ 中 $<b$ 的值的數量,在用一次得到 $C$ 中 $>b$ 的值的數量,將此兩值相乘後加到 $ans$ 裡,這樣當我們把所有 $b$ 跑過一次後,$ans$ 即為所求 ::: :::spoiler AC Code ```cpp= #include <iostream> #include <vector> #include <algorithm> using namespace std; #define ll long long vector<int> A, B, C; int main() { int n; cin >> n; A.resize(n); B.resize(n); C.resize(n); for(int i = 0; i < n; i++) cin >> A[i]; for(int i = 0; i < n; i++) cin >> B[i]; for(int i = 0; i < n; i++) cin >> C[i]; sort(A.begin(), A.end()); sort(B.begin(), B.end()); sort(C.begin(), C.end()); ll ans = 0; //枚舉 b for(int i = 0; i < n; i++){ int b = B[i]; // 若 a[0] 不符合, a的個數為 lower_bound 搜到的地方(A.begin()) - A.begin(),符合 na 之定義 // 若 a[0] 符合, a 的範圍可從 0 到 lower_bound 搜到的地方 - 1,個數為 lower_bound 搜到的地方 - 1 - A.begin() + 1 auto na = lower_bound(A.begin(), A.end(), b) - A.begin(); // 若 c[n-1] 不符合, c 的個數為 C.end() - upper_bound 搜到的地方(C.end()),符合 nc 之定義 // 若 c[n-1] 符合, c 的範圍可從 C.end() - 1 到 upper_bound 搜到的地方,個數為 C.end() - 1 - upper_bound + 1 auto nc = C.end() - upper_bound(C.begin(), C.end(), b); // 當中間的數 = b 時,{a, b, c}可能的組數為 na * nc ans += na * nc; } cout << ans << '\n'; } ``` :::