# 111 選手班 - 線段樹
###### tags: `宜中資訊` `CP`
2022.08.19
[110 基礎資節 slide](https://hackmd.io/@Ccucumber12/SkuTo_S1Y#/)
[110 進階資節 slide](https://hackmd.io/@Ccucumber12/B1PqClBlY#/)
## RMQ 問題
給定 $N$ 個數字的序列 $a_1, a_2, \cdots, a_N$,請支援以下操作 $Q$ 次:
- 單點修改:$a_k := a_k + v$
- 區間修改:$a_i := a_i + v\quad (i \in [l, r])$
- 單點查詢:詢問 $a_k$ 的值
- 區間查詢
- 詢問總和:$\displaystyle\sum_{i=l}^r a_i$
- 詢問極值:$\max(a_i | i \in [l, r])$
- 詢問 XOR:$a_l \oplus a_{l+1} \oplus \cdots \oplus a_r$
- 詢問最大公因數:$\gcd(a_l, a_{l+1}, \cdots, a_r)$
- $\cdots$
應用實例
- DP優化:$\text{dp}[i] = \max\left(\text{dp}[j]\ |\ j \in [1, i-1]\right)$
## 基礎
### 問題
給定一個長度 $N$ 的序列 $A$ 個跟 $Q$ 比操作,操作包含兩種:
- 修改:把某位置 $A_i$ 的值改成 $k$
- 查詢:查詢某個區間 $A_l\cdots A_r$ 的最大值
### 想法
#### trivial
直接修改,暴力查詢
- 修改:$O(1)$
- 查詢:$O(n)$
#### 分塊
每 $k$ 個分成一組,記錄每組的最大值
- 修改:$O(k)$
- 查詢:$O(k+\frac{N}{k})$
取 $k = \sqrt{n}$ 有最小值
- 修改:$O(\sqrt{N})$
- 查詢:$O(\sqrt{N})$
#### 分更小塊...?
把 $k$ 定得更小,但在上面再加上一層...
### 結構
- 完滿二元樹
- 區間為左右子樹的區間聯集
- 從最基本長度為 1 的區間開始倆倆合併,直到包含所有區間
- 高度為 $O(\log N)$
![](https://i.imgur.com/OS9uzUQ.png)
### 原理
#### 修改
- 把$A_i$ 改成 $k$
- 調整所有包含 $A_i$ 的區間
- $O(\log N)$
#### 查詢
- 查詢區間 $[l, r]$
- 從最上層開始,如果完整包含就回傳
- 否則遞迴左右子樹
- $O(\log N)$
### 實作
#### 結構
- 用陣列實作完滿二元樹的結構
- 把序列擴充為二的冪次 $MXN$
- 根是 $seg[1]$,序列 $A_i$ 對應到 $seg[MXN+i]$ (0-based)
- 左子樹為 $seg[i*2]$ 右子樹為 $seg[i*2+1]$
#### modify
- 把葉節點 $seg[MXN+i]$ 改成 $k$
- 調整該點的所有祖先 $x$
- 透過 $\div 2$ 往父節點移動
- $O(\log N)$
#### query
- query $[l, r]$
- 從最上層 $f(1,N)$ 開始
- 如果完整包含則 return
- 否則遞迴 $f(Lb, Rb) = \max (f(Lb, mid),f(mid+1,Rb))$
- $O(\log N)$
#### build
- 找出最適合的二的冪次長度 $MXN$
- 把 $seg[MXN]\cdots seg[MXN+N-1]$ 填滿原序列
- 由下往上遞迴填滿其他區間
- $O(N)$
### 注意事項
- $seg[]$ 長度大小 $4 * N$
- 往左右遞迴的條件
- 最小值的初始值 (特別是擴充的部分)
- 區間設定 (左閉右閉 / 左閉右開 / ...)
- 線段樹沒有絕對的寫法,只有喜歡的寫法。請多練習以找到自己最習慣的實作方式
:::spoiler Code
```cpp=
#include <bits/stdc++.h>
using namespace std ;
int N, MXN ;
int seg[4000010] ; // 4 * N
void build(int lb, int rb, int idx) {
if(lb == rb) return ;
int mid = (lb + rb) / 2 ;
build(lb, mid, idx*2) ;
build(mid + 1, rb, idx*2+1) ;
seg[idx] = seg[idx*2] + seg[idx*2+1] ;
}
void modify(int x, int k) {
x = x + MXN - 1 ;
seg[x] = k ;
while(x > 1) {
x /= 2 ;
seg[x] = seg[x*2] + seg[x*2+1] ;
}
}
int query(int l, int r, int lb, int rb, int idx) {
if(l <= lb && rb <= r) {
return seg[idx] ;
}
int mid = (lb + rb) / 2 ;
int ret = 0 ;
if(l <= mid) ret += query(l, r, lb, mid, idx*2) ;
if(r >= mid+1) ret += query(l, r, mid+1, rb, idx*2+1) ;
return ret ;
}
int main() {
int a[100010] ;
cin >> N ;
MXN = 1 ;
while(MXN < N)
MXN <<= 1 ;
for(int i=1; i<=N; ++i)
seg[MXN + i - 1] = a[i] ;
build(1, MXN, 1) ;
}
```
:::
## 懶人標記
### 題敘
(區間修改,區間查詢)
給定數列 $a_1, a_2, a_3, \dots a_N$,請支援以下操作
1. $\text{sum L R}$:計算 $a_L+a_{L+1}+\dots+a_R$
2. $\text{add L R V}$:將$a_L+a_{L+1}+\dots+a_R$ 加上 $V$
$N, Q \leq 10^5$
### 線段樹
考慮原本的線段樹
- $\text{sum}$:$O(\log N)$
- $\text{add}$:$O((R - L)\log N)$
- Total: $O(QN\log N)$ \:cry:
顯然太爛了,還不如直接一個一個改。
### 分塊
考慮分塊的作法,對於每一塊除了 $sum$,還額外紀錄 $tag$,代表整個區間被加了多少。
假設每塊的長度是 $L$
- modify
- 如果包含整個區間,直接加在 $tag$ 上面
- 如果不包含,則一個一個加
- query
- 如果詢問包含整個區間,return $sum + tag*L$
- 如果不包含整個區間,return $\sum a_i + tag*k$,其中 $k$ 是個數
如此一來便可以回到每次操作都 $\mathcal O(\sqrt{N})$
### 懶人標記
別名:懶惰標記、懶標、Lazy Tag、Lazy Propagation
如果分塊可以,那線段樹應該也行。
對於每個區間額外紀錄 $tag$,然後都把要加上的 $k$ 都打在盡可能上面的 $tag$。
Query的時候,則是每次往下遞迴時,額外加上這層 $tag$ 造成的貢獻。
#### KEY
任意區間皆可在線段樹上表示為 $O(\log N)$ 個區間
Proof:
- 每個區間長度皆為 $2^i$
- 每個長度的區間最多兩個
每次操作只更動 $O(\log N)$ 個區間:
- 若不須往下遞迴,則將修改的值暫時放在該區間
- 下次需往下遞迴時,再把他加進答案
Modify
```cpp=
void modify(int l, int r, int lb, int rb, int idx, int val) {
if(l <= lb && rb <= r) {
seg[idx] += val * (rb - lb + 1) ;
tag[idx] += val ;
return ;
}
int mid = (lb + rb) / 2 ;
if(l <= mid) modify(l, r, lb, mid, idx*2, val) ;
if(mid < r) modify(l, r, mid+1, rb, idx*2+1, val) ;
seg[idx] = seg[idx*2] + seg[idx*2+1] ;
}
```
Query
```cpp=
int query(int l, int r, int lb, int rb, int idx) {
if(l <= lb && rb <= r)
return seg[idx] ;
int mid = (lb + rb) / 2 ;
int ret = tag[idx] * (min(rb,r) - max(l,lb) + 1) ;
if(l <= mid) ret += query(l, r, lb, mid, idx*2) ;
if(mid < r) ret += query(l, r, mid+1, rb, idx*2+1) ;
return ret ;
}
```
## push & pull
### 題敘
(區間修改,區間查詢)
給定一個長度為 $N$ 的序列與 $M$ 比操作
- 把某個區間 $[l,r]$ 都加上 $k$
- 把某個區間 $[l,r]$ 都乘上 $k$
- 詢問某個區間 $[l,r]$ 的和
### 懶人標記
現在不能把 $tag$ 留在原地了,因為我們不知道他要往下乘多少。
但我們一樣可以先把他暫時留在那裡,等有需要的時候在往下推。
- 區間不完整包含:
- 將該格 $\text{tag}$ 往下推 ($\text{push}$)
- 往下遞迴
- 區間完整包含:
- 修改該格的 $\text{val}$
- 暫時將修改的值記錄在 $\text{tag}$
如果能保證每次往下推都是 $\mathcal O(1)$,那就不會增加我們最後的時間複雜度。
我們把**往下推**,跟**重新計算自己這格**寫成兩個函式
**push**
```c=
void push(int lb, int rb, int idx) {
int len = (rb - lb + 1) / 2 ;
seg[idx * 2].value += seg[idx].tag * len ;
seg[idx*2+1].value += seg[idx].tag * len ;
seg[idx * 2].tag += seg[idx].tag ;
seg[idx*2+1].tag += seg[idx].tag ;
seg[idx].tag = 0 ;
}
```
**pull**
```c=
void pull(int idx) {
seg[idx].value = seg[idx*2].value + seg[idx*2+1].value ;
}
```
:::spoiler Code
```c=
const int N = 100000 ;
int sum[4*N], add[4*N], mul[4*N] ;
int MXN = 1 ;
void push(int lb, int rb, int idx) {
int l = idx*2, r = idx*2+1 ;
int len = (rb - lb + 1)/2 ;
sum[l] *= mul[idx] ; sum[r] *= mul[idx] ;
sum[l] += add[idx] * len ; sum[r] += add[idx] * len;
add[l] *= mul[idx] ; add[r] *= mul[idx] ;
add[l] += add[idx] ; add[r] += add[idx] ;
mul[l] *= mul[idx] ; mul[r] *= mul[idx] ;
add[idx] = 0 ;
mul[idx] = 1 ;
}
void pull(int idx) {
sum[idx] = sum[idx*2] + sum[idx*2+1] ;
}
void rangeAdd(int l, int r, int k, int lb, int rb, int idx) {
if(l <= lb && rb <= r) {
sum[idx] += k * (rb - lb + 1);
add[idx] += k ;
return ;
}
push(lb, rb, idx) ;
int mid = lb + rb >> 1 ;
if(l <= mid) rangeAdd(l, r, k, lb, mid, idx*2) ;
if(mid < r) rangeAdd(l, r, k, mid+1, rb, idx*2+1) ;
pull(idx) ;
}
void rangeMul(int l, int r, int k, int lb, int rb, int idx) {
if(l <= lb && rb <= r) {
sum[idx] *= k ;
add[idx] *= k ; // (x+2) * 4 = x*4 + 2*4
mul[idx] *= k ;
return ;
}
push(lb, rb, idx) ;
int mid = lb + rb >> 1 ;
if(l <= mid) rangeMul(l, r, k, lb, mid, idx*2) ;
if(mid < r) rangeMul(l, r, k, mid+1, rb, idx*2+1) ;
pull(idx) ;
}
int query(int l, int r, int lb, int rb, int idx) {
if(l <= lb && rb <= r)
return sum[idx] ;
push(lb, rb, idx) ;
int mid = lb + rb >> 1, ret = 0 ;
if(l <= mid) ret += query(l, r, lb, mid, idx*2) ;
if(mid < r) ret += query(l, r, mid+1, rb, idx*2+1) ;
// pull(idx) ;
return ret;
}
```
:::
## 題單
- [Contest](https://vjudge.net/contest/511280)
- Password: `111apcs`