# Two Stacks Sorting 筆記 ## 前言 我從[mocha的這篇文章](https://hackmd.io/@mochaowo/SJ_8Z-IS6)中獲得了不少關於CSES Two Stacks Sorting這題的想法,我希望這篇文能把他沒有說明清楚的部分,用我自己的方法來補充。 ## 題敘 >有一個長度為 $n$ 的permutation,你可以執行以下兩種操作的其中一種: >1. 把輸入序列中尚未處理的第一個數字放入兩個stack的其中一個 >2. 把數字從某個stack中取出來並加到陣列 $b$ 的後方 > >請問有沒有可能使陣列 $b$ 為遞增序列? >你需要輸出 $n$ 個數字,第 $i$ 個數字表示數字 $i$ 要被放進哪一個stack裡面。 >如果無解,只輸出"IMPOSSIBLE"即可。 >($1\leq n\leq 2\cdot {10}^5,$ 陣列 $b$ 和兩個stack一開始為空) 題目連結:https://cses.fi/problemset/task/2402 ## 解題 ### 操作序列 這題的要求是輸出每個元素要放到哪一個stack裡面(為了方便,我們把這個序列稱為答案序列),但我們先不考慮這個,我們先考慮「操作的順序」。 定義操作為以下兩種: 1. push x:把數字 $x$ 推到其中一個stack裡面 2. pop x:把數字 $x$ 從它所在的stack中取出 而一個合法的操作序列長度為 $2n$($n$ 個push和 $n$ 個pop)。 現在的問題是:如果你已經知道其中一種答案序列了,那你要照什麼順序執行操作? :::spoiler Solution 這個題目有一個性質是:如果你已經把數字 $1, 2, 3, \dots, n-1$ 都pop到最後的序列了,而且數字 $n$ 已經在某個stack裡面了,那立刻把數字 $n$ pop出來總是最好的選擇。 因此我們可以用greedy的方法來得到最好的操作序列,請參考以下的code。 ```cpp= #include <bits/stdc++.h> using namespace std; const int mxn = (int)(2e5) + 10; int a[mxn], in[mxn], out[mxn]; bool in_stack[mxn]; int main(){ int n; cin >> n; for(int i = 0; i < n; i++){ cin >> a[i]; } int to_add = 0, to_pop = 1; for(int i = 1; i <= n * 2; i++){ if(in_stack[to_pop]){ cout << " pop " << to_pop << '\n'; out[to_pop] = i; to_pop += 1; } else{ cout << "push " << a[to_add] << '\n'; in_stack[a[to_add]] = true; in[a[to_add]] = i; to_add += 1; } } } ``` ::: <br> 你可能會想,求出最好的操作序列對於解這題有什麼幫助? 事實上,我們可以透過執行操作的順序和一些其他的性質,來決定「哪些數字只能放在同一個stack」和「哪些數字只能放在不同的stack」,而越爛的操作序列對於答案序列的限制越多(例如:你如果在該pop的時候不pop,那這個stack就不能再放任何數字進去了)。 ### 存在區間 我們可以把剛剛的程式碼輸出的操作序列畫出來,例如: ``` 6 5 2 4 1 6 3 ``` 這是一筆$n=6$的測資,把它輸入到上面的程式碼之後,可以畫出下圖: ![test1](https://hackmd.io/_uploads/S1_V5g1k0.png) 註:線段左邊的數字代表這個線段對應到測資中的哪個數字。 如果把這些線段對應到座標軸上,那麼數字 $x$ 對應到的區間為 $[in(x), out(x)]$ (為了方便我們把他稱作線段 $x$),而 $in(x),~out(x)$ 分別代表數字 $x$ 被push跟pop的時間點。 $(1\leq in(x) < out(x) \leq 2n)$ 因為題目要求pop出來數字必須由小到大,所以 $\forall 1\leq x<y\leq n, out(x)< out(y)$ ,知道這個性質之後我們就可以知道任意兩線段的關係實際上只有三種: 1. 包含($x<y, in(y)<in(x)<out(x)<out(y)$) ![two stacks sorting (contained)](https://hackmd.io/_uploads/BJVytYRJ0.png) 2. 部分交集($x<y, in(x)<in(y)<out(x)<out(y)$) ![two stacks sorting (partial intersect)](https://hackmd.io/_uploads/Hkb-uY0yC.png) 3. 無交集($x<y, in(x)<out(x)<in(y)<out(y)$) ![two stacks sorting (no intersection)](https://hackmd.io/_uploads/SJju_tAJA.png) ### 暴力作法 透過觀察可以發現:這三種關係之中,只有部分交集會影響到stack的選擇,如果要把部分交集的兩線段 $x,y$ 都放進同一個stack,就會違反stack的FIFO性質。 也就是說:部分交集的兩線段不能放在同一個stack裡面。 這樣我們就可以得到一個非常直覺的 $O(n^2)$ 作法: > 枚舉所有的數對 $(i, j)~(1\leq i<j\leq n)$ ,如果 $in[i]<in[j]<out[i]<out[j]$ ,那就用2-SAT加一條 $color(i)\neq color(j)$ 的邊($color$ 只能是0或1),最後再用2-SAT構造所有點的 $color$ 。($color$ 即為最終答案) :::spoiler 程式碼 註:這種特殊的2-SAT關係可以用並查集來處理,如果我們要表達 $color(i)\neq color(j)$ 的關係,我們可以用 $color(i)=color(j'), color(i')=color(j)$ 來表達,當任意的 $1\leq i\leq n$ 滿足 $color(i)=color(i')$ 時,目前的2-SAT模型就不可能有合法的解。 ```cpp= #include <bits/stdc++.h> using namespace std; const int mxn = (int)(2e5) + 10; int a[mxn]; bool in_stack[mxn]; int in[mxn], out[mxn], dsu[mxn << 1], dsu_sz[mxn << 1], color[mxn << 1]; int dsu_get(int x){ if(dsu[dsu[x]] == dsu[x]) return dsu[x]; return dsu[x] = dsu_get(dsu[x]); } void merg(int x, int y){ x = dsu_get(x); y = dsu_get(y); if(x == y) return; if(dsu_sz[x] > dsu_sz[y]) swap(x, y); dsu[x] = y; dsu_sz[y] += dsu_sz[x]; } int main(){ ios::sync_with_stdio(false); cin.tie(0); int n; cin >> n; for(int i = 1; i <= n; i++){ cin >> a[i]; } int to_add = 1, to_pop = 1; for(int i = 1; i <= n * 2; i++){ if(in_stack[to_pop]){ out[to_pop] = i; to_pop += 1; } else{ in[a[to_add]] = i; in_stack[a[to_add]] = true; to_add += 1; } } for(int i = 1; i <= n; i++){ dsu[i] = i; dsu_sz[i] = 1; dsu[i + n] = i + n; dsu_sz[i + n] = 1; } for(int i = 1; i <= n; i++){ for(int j = i + 1; j <= n; j++){ if(in[i] < in[j] && in[j] < out[i] && out[i] < out[j]){ merg(i, j + n); merg(j, i + n); } } } for(int i = 1; i <= n; i++){ if(dsu_get(i) == dsu_get(i + n)){ cout << "IMPOSSIBLE\n"; return 0; } } for(int i = 1; i <= n; i++){ if(color[dsu_get(i)] == 0){ color[dsu_get(i)] = 1; color[dsu_get(i + n)] = 2; } } for(int i = 1; i <= n; i++){ cout << color[dsu_get(a[i])] << ' '; } cout << '\n'; } ``` ::: ### 更多的觀察 在剛剛的做法中,我們最多需要建 $O(n^2)$ 條邊(2-SAT的邊),但我們可以用一些觀察把邊數壓到 $O(n)$ 。 如果我們嘗試枚舉三線段 $x,y,z$ (不失一般性假設 $1\leq x<y<z\leq n$ )的所有 $3^3$ 種關係(當然其中有些是不合法的,所以實際上不會剛好是 $3^3$ 種),我們會發現其中有一些具有很強的性質: 1. $x,y,z$ 三者互相部分交集 ![upload_947b9bd184beb9740b82a010d9b1baaf](https://hackmd.io/_uploads/SyY8pjyg0.png) 三個式子 $color(x)\neq color(y), color(y)\neq color(z), color(x)\neq color(z)$ 皆須成立,但 $color$ 的選擇只有兩種,所以不管怎麼選都無法使所有條件成立,此時可以直接輸出 IMPOSSIBLE 並立刻終止程式。 2. $x,y$ 包含、 $y,z$ 部分交集、 $x,z$ 部分交集 ![upload_947b9bd184beb9740b82a010d9b1baaf](https://hackmd.io/_uploads/rJitRo1lR.png) 兩個式子 $color(x)\neq color(z), color(y)\neq color(z)$ 皆須成立,所以 $color(x)=color(y)\neq color(z)$ 。 因為所有和 $x$ 為部分交集的 $k$ ($k>x$)都會同樣和 $y$ 是部分交集(包含關係也是同理),所以我們不需要再讓 $x$ 和其他線段連邊(2-SAT的邊)了。 至此我們可以得出以下做法: ::: spoiler pseudo code ```cpp= for(y 從 1 到 n){ for(x ∈ x_list中所有和區間y有交集的線段){ if(x 和 y 為包含關係){ if(存在z使得 y<z 且 x和z為部分交集 且 y和z為部分交集){ 2_SAT.add_edge(x, y, same_color = true); x_list.remove(x); } } else if(x 和 y 為部分交集){ if(存在z使得 y<z 且 x和z為部分交集 且 y和z為部分交集){ print("IMPOSSIBLE"); // 不可能存在合法的答案序列 } else{ 2_SAT.add_edge(x, y, same_color = false); // 即不同顏色 } } else{ print("程式爛了"); } } x_list.add(y); } ``` ::: ### 實作 剛剛的pseudo code中的`x_list`可以用`set<pair<int,int>>`來實作,pair存的是 $[in(i), i]$ ,而第4行跟第10行的`if`判斷可以用一棵區間查詢的BIT來實作。 如果你真的把這個pseudo code實作出來就差不多會是這個樣子: :::spoiler Code ```cpp= #include <bits/stdc++.h> using namespace std; const int mxn = (int)(2e5) + 10; int a[mxn]; bool in_stack[mxn]; int in[mxn], out[mxn], dsu[mxn << 1], dsu_sz[mxn << 1], bit[mxn << 1], color[mxn << 1]; int lowbit(int x){return (x) & (-x);} void upd(int ind, int x){ for(; ind < (mxn << 1); ind += lowbit(ind)){ bit[ind] += x; } } int qry(int ind){ int ret = 0; for(; ind > 0; ind -= lowbit(ind)){ ret += bit[ind]; } return ret; } int dsu_get(int x){ if(dsu[dsu[x]] == dsu[x]) return dsu[x]; return dsu[x] = dsu_get(dsu[x]); } void merg(int x, int y){ x = dsu_get(x); y = dsu_get(y); if(x == y) return; if(dsu_sz[x] > dsu_sz[y]) swap(x, y); dsu[x] = y; dsu_sz[y] += dsu_sz[x]; } int main(){ ios::sync_with_stdio(false); cin.tie(0); int n; cin >> n; for(int i = 1; i <= n; i++){ cin >> a[i]; } int to_add = 1, to_pop = 1; for(int i = 1; i <= n * 2; i++){ if(in_stack[to_pop]){ out[to_pop] = i; to_pop += 1; } else{ in[a[to_add]] = i; in_stack[a[to_add]] = true; to_add += 1; } } for(int i = 1; i <= n; i++){ upd(in[i], 1); dsu[i] = i; dsu_sz[i] = 1; dsu[i + n] = i + n; dsu_sz[i + n] = 1; } set<pair<int,int>> x_list; for(int y = 1; y <= n; y++){ upd(in[y], -1); auto it = x_list.upper_bound({in[y], -1}); if(it != x_list.begin()){ it = prev(it); } while(it != x_list.end()){ int x = it->second; int intersect_L = max(in[y], in[x]); int intersect_R = min(out[y], out[x]); auto it_to_remove = x_list.end(); if(intersect_R - intersect_L > 0){ if(in[y] > in[x]){ if(qry(intersect_R) - qry(intersect_L-1) > 0){ cout << "IMPOSSIBLE\n"; return 0; } else{ merg(y, x + n); merg(x, y + n); } } else{ if(qry(intersect_R) - qry(intersect_L-1) > 0){ merg(y, x); merg(y + n, x + n); } it_to_remove = it; } } it++; if(it_to_remove != x_list.end()){ x_list.erase(it_to_remove); } } x_list.insert({in[y], y}); } for(int i = 1; i <= n; i++){ if(dsu_get(i) == dsu_get(i + n)){ cout << "IMPOSSIBLE\n"; return 0; } } for(int i = 1; i <= n; i++){ if(color[dsu_get(i)] == 0){ color[dsu_get(i)] = 1; color[dsu_get(i + n)] = 2; } } for(int i = 1; i <= n; i++){ cout << color[dsu_get(a[i])] << ' '; } cout << '\n'; } ``` ::: ### 正確性分析 如果你有嘗試去理解剛剛的code,你可能會發現有一些地方可能不太合理。 具體來說,在剛剛的code的第64行,我用的是`upper_bound`,然後我只讓指針向後退一次,這樣會不會造成沒有枚舉到所有和 $y$ 有交集的線段 $x$ ? 我們可以用反證法來證明這種情況不存在。 假設存在兩個 $x$ (假設為 $x_1, x_2$ 且 $x_1<x_2$ ),滿足 $x_1, x_2$ 都在`x_list`中且都和線段 $y$ 有交集且 $in(x_1), in(x_2) < in(y)$ 。 由 $x, y$ 的大小關係可知 $x_1 < x_2 < y$ ,所以 $out(x_1) < out(x_2) < out(y)$,因此 $x_1, x_2, y$ 三者互相為部分交集關係,則程式應該在更之前($x' = x_1, y' = x_2, z' = y$ 的時候)就終止了,故假設矛盾,因此得證:只需要讓`upper_bound`的指針後退一次就可以枚舉到所有`x_list`中和 $y$ 有交集的 $x$ 。 另一個小問題就留給讀者:為什麼指針可以直接跑到`x_list.end()`而不會TLE? ### 複雜度分析 set 插入與刪除最多 $n$ 個元素的複雜度:$O(\log n)\times 2n = O(n\log n)$ 呼叫 $n$ 次`upper_bound`的複雜度:$O(\log n)\times n = O(n\log n)$ 讓指針後退最多 $n$ 次的複雜度:$O(1)\times n = O(n)$ 讓指針前進最多 $n$ 次的複雜度:$O(1)\times n = O(n)$ 執行 $2n$ 次BIT的update和 $4n$ 次 BIT的query的複雜度:$O(\log n)\times (2n+4n) = O(n\log n)$ 呼叫 $O(n)$ 次並查集的複雜度:$O(\alpha(n))\times O(n) = O(n\alpha(n))$ 總複雜度: $O(n\log n)$ ## 致謝 我要特別感謝mocha 要想出這個做法比寫這篇文章難太多了