# 111 選手班 - 線段樹 ###### tags: `宜中資訊` `CP` 2022.08.19 [110 基礎資節 slide](https://hackmd.io/@Ccucumber12/SkuTo_S1Y#/) [110 進階資節 slide](https://hackmd.io/@Ccucumber12/B1PqClBlY#/) ## RMQ 問題 給定 $N$ 個數字的序列 $a_1, a_2, \cdots, a_N$,請支援以下操作 $Q$ 次: - 單點修改:$a_k := a_k + v$ - 區間修改:$a_i := a_i + v\quad (i \in [l, r])$ - 單點查詢:詢問 $a_k$ 的值 - 區間查詢 - 詢問總和:$\displaystyle\sum_{i=l}^r a_i$ - 詢問極值:$\max(a_i | i \in [l, r])$ - 詢問 XOR:$a_l \oplus a_{l+1} \oplus \cdots \oplus a_r$ - 詢問最大公因數:$\gcd(a_l, a_{l+1}, \cdots, a_r)$ - $\cdots$ 應用實例 - DP優化:$\text{dp}[i] = \max\left(\text{dp}[j]\ |\ j \in [1, i-1]\right)$ ## 基礎 ### 問題 給定一個長度 $N$ 的序列 $A$ 個跟 $Q$ 比操作,操作包含兩種: - 修改:把某位置 $A_i$ 的值改成 $k$ - 查詢:查詢某個區間 $A_l\cdots A_r$ 的最大值 ### 想法 #### trivial 直接修改,暴力查詢 - 修改:$O(1)$ - 查詢:$O(n)$ #### 分塊 每 $k$ 個分成一組,記錄每組的最大值 - 修改:$O(k)$ - 查詢:$O(k+\frac{N}{k})$ 取 $k = \sqrt{n}$ 有最小值 - 修改:$O(\sqrt{N})$ - 查詢:$O(\sqrt{N})$ #### 分更小塊...? 把 $k$ 定得更小,但在上面再加上一層... ### 結構 - 完滿二元樹 - 區間為左右子樹的區間聯集 - 從最基本長度為 1 的區間開始倆倆合併,直到包含所有區間 - 高度為 $O(\log N)$ ![](https://i.imgur.com/OS9uzUQ.png) ### 原理 #### 修改 - 把$A_i$ 改成 $k$ - 調整所有包含 $A_i$ 的區間 - $O(\log N)$ #### 查詢 - 查詢區間 $[l, r]$ - 從最上層開始,如果完整包含就回傳 - 否則遞迴左右子樹 - $O(\log N)$ ### 實作 #### 結構 - 用陣列實作完滿二元樹的結構 - 把序列擴充為二的冪次 $MXN$ - 根是 $seg[1]$,序列 $A_i$ 對應到 $seg[MXN+i]$ (0-based) - 左子樹為 $seg[i*2]$ 右子樹為 $seg[i*2+1]$ #### modify - 把葉節點 $seg[MXN+i]$ 改成 $k$ - 調整該點的所有祖先 $x$ - 透過 $\div 2$ 往父節點移動 - $O(\log N)$ #### query - query $[l, r]$ - 從最上層 $f(1,N)$ 開始 - 如果完整包含則 return - 否則遞迴 $f(Lb, Rb) = \max (f(Lb, mid),f(mid+1,Rb))$ - $O(\log N)$ #### build - 找出最適合的二的冪次長度 $MXN$ - 把 $seg[MXN]\cdots seg[MXN+N-1]$ 填滿原序列 - 由下往上遞迴填滿其他區間 - $O(N)$ ### 注意事項 - $seg[]$ 長度大小 $4 * N$ - 往左右遞迴的條件 - 最小值的初始值 (特別是擴充的部分) - 區間設定 (左閉右閉 / 左閉右開 / ...) - 線段樹沒有絕對的寫法,只有喜歡的寫法。請多練習以找到自己最習慣的實作方式 :::spoiler Code ```cpp= #include <bits/stdc++.h> using namespace std ; int N, MXN ; int seg[4000010] ; // 4 * N void build(int lb, int rb, int idx) { if(lb == rb) return ; int mid = (lb + rb) / 2 ; build(lb, mid, idx*2) ; build(mid + 1, rb, idx*2+1) ; seg[idx] = seg[idx*2] + seg[idx*2+1] ; } void modify(int x, int k) { x = x + MXN - 1 ; seg[x] = k ; while(x > 1) { x /= 2 ; seg[x] = seg[x*2] + seg[x*2+1] ; } } int query(int l, int r, int lb, int rb, int idx) { if(l <= lb && rb <= r) { return seg[idx] ; } int mid = (lb + rb) / 2 ; int ret = 0 ; if(l <= mid) ret += query(l, r, lb, mid, idx*2) ; if(r >= mid+1) ret += query(l, r, mid+1, rb, idx*2+1) ; return ret ; } int main() { int a[100010] ; cin >> N ; MXN = 1 ; while(MXN < N) MXN <<= 1 ; for(int i=1; i<=N; ++i) seg[MXN + i - 1] = a[i] ; build(1, MXN, 1) ; } ``` ::: ## 懶人標記 ### 題敘 (區間修改,區間查詢) 給定數列 $a_1, a_2, a_3, \dots a_N$,請支援以下操作 1. $\text{sum L R}$:計算 $a_L+a_{L+1}+\dots+a_R$ 2. $\text{add L R V}$:將$a_L+a_{L+1}+\dots+a_R$ 加上 $V$ $N, Q \leq 10^5$ ### 線段樹 考慮原本的線段樹 - $\text{sum}$:$O(\log N)$ - $\text{add}$:$O((R - L)\log N)$ - Total: $O(QN\log N)$ \:cry: 顯然太爛了,還不如直接一個一個改。 ### 分塊 考慮分塊的作法,對於每一塊除了 $sum$,還額外紀錄 $tag$,代表整個區間被加了多少。 假設每塊的長度是 $L$ - modify - 如果包含整個區間,直接加在 $tag$ 上面 - 如果不包含,則一個一個加 - query - 如果詢問包含整個區間,return $sum + tag*L$ - 如果不包含整個區間,return $\sum a_i + tag*k$,其中 $k$ 是個數 如此一來便可以回到每次操作都 $\mathcal O(\sqrt{N})$ ### 懶人標記 別名:懶惰標記、懶標、Lazy Tag、Lazy Propagation 如果分塊可以,那線段樹應該也行。 對於每個區間額外紀錄 $tag$,然後都把要加上的 $k$ 都打在盡可能上面的 $tag$。 Query的時候,則是每次往下遞迴時,額外加上這層 $tag$ 造成的貢獻。 #### KEY 任意區間皆可在線段樹上表示為 $O(\log N)$ 個區間 Proof: - 每個區間長度皆為 $2^i$ - 每個長度的區間最多兩個 每次操作只更動 $O(\log N)$ 個區間: - 若不須往下遞迴,則將修改的值暫時放在該區間 - 下次需往下遞迴時,再把他加進答案 Modify ```cpp= void modify(int l, int r, int lb, int rb, int idx, int val) { if(l <= lb && rb <= r) { seg[idx] += val * (rb - lb + 1) ; tag[idx] += val ; return ; } int mid = (lb + rb) / 2 ; if(l <= mid) modify(l, r, lb, mid, idx*2, val) ; if(mid < r) modify(l, r, mid+1, rb, idx*2+1, val) ; seg[idx] = seg[idx*2] + seg[idx*2+1] ; } ``` Query ```cpp= int query(int l, int r, int lb, int rb, int idx) { if(l <= lb && rb <= r) return seg[idx] ; int mid = (lb + rb) / 2 ; int ret = tag[idx] * (min(rb,r) - max(l,lb) + 1) ; if(l <= mid) ret += query(l, r, lb, mid, idx*2) ; if(mid < r) ret += query(l, r, mid+1, rb, idx*2+1) ; return ret ; } ``` ## push & pull ### 題敘 (區間修改,區間查詢) 給定一個長度為 $N$ 的序列與 $M$ 比操作 - 把某個區間 $[l,r]$ 都加上 $k$ - 把某個區間 $[l,r]$ 都乘上 $k$ - 詢問某個區間 $[l,r]$ 的和 ### 懶人標記 現在不能把 $tag$ 留在原地了,因為我們不知道他要往下乘多少。 但我們一樣可以先把他暫時留在那裡,等有需要的時候在往下推。 - 區間不完整包含: - 將該格 $\text{tag}$ 往下推 ($\text{push}$) - 往下遞迴 - 區間完整包含: - 修改該格的 $\text{val}$ - 暫時將修改的值記錄在 $\text{tag}$ 如果能保證每次往下推都是 $\mathcal O(1)$,那就不會增加我們最後的時間複雜度。 我們把**往下推**,跟**重新計算自己這格**寫成兩個函式 **push** ```c= void push(int lb, int rb, int idx) { int len = (rb - lb + 1) / 2 ; seg[idx * 2].value += seg[idx].tag * len ; seg[idx*2+1].value += seg[idx].tag * len ; seg[idx * 2].tag += seg[idx].tag ; seg[idx*2+1].tag += seg[idx].tag ; seg[idx].tag = 0 ; } ``` **pull** ```c= void pull(int idx) { seg[idx].value = seg[idx*2].value + seg[idx*2+1].value ; } ``` :::spoiler Code ```c= const int N = 100000 ; int sum[4*N], add[4*N], mul[4*N] ; int MXN = 1 ; void push(int lb, int rb, int idx) { int l = idx*2, r = idx*2+1 ; int len = (rb - lb + 1)/2 ; sum[l] *= mul[idx] ; sum[r] *= mul[idx] ; sum[l] += add[idx] * len ; sum[r] += add[idx] * len; add[l] *= mul[idx] ; add[r] *= mul[idx] ; add[l] += add[idx] ; add[r] += add[idx] ; mul[l] *= mul[idx] ; mul[r] *= mul[idx] ; add[idx] = 0 ; mul[idx] = 1 ; } void pull(int idx) { sum[idx] = sum[idx*2] + sum[idx*2+1] ; } void rangeAdd(int l, int r, int k, int lb, int rb, int idx) { if(l <= lb && rb <= r) { sum[idx] += k * (rb - lb + 1); add[idx] += k ; return ; } push(lb, rb, idx) ; int mid = lb + rb >> 1 ; if(l <= mid) rangeAdd(l, r, k, lb, mid, idx*2) ; if(mid < r) rangeAdd(l, r, k, mid+1, rb, idx*2+1) ; pull(idx) ; } void rangeMul(int l, int r, int k, int lb, int rb, int idx) { if(l <= lb && rb <= r) { sum[idx] *= k ; add[idx] *= k ; // (x+2) * 4 = x*4 + 2*4 mul[idx] *= k ; return ; } push(lb, rb, idx) ; int mid = lb + rb >> 1 ; if(l <= mid) rangeMul(l, r, k, lb, mid, idx*2) ; if(mid < r) rangeMul(l, r, k, mid+1, rb, idx*2+1) ; pull(idx) ; } int query(int l, int r, int lb, int rb, int idx) { if(l <= lb && rb <= r) return sum[idx] ; push(lb, rb, idx) ; int mid = lb + rb >> 1, ret = 0 ; if(l <= mid) ret += query(l, r, lb, mid, idx*2) ; if(mid < r) ret += query(l, r, mid+1, rb, idx*2+1) ; // pull(idx) ; return ret; } ``` ::: ## 題單 - [Contest](https://vjudge.net/contest/511280) - Password: `111apcs`