# Segment Tree 線段樹 ## 1.單點加值,區間查詢最大值 ```clike= #include <bits/stdc++.h> using namespace std; struct node{ int l, r, data; // 該區間的最大值 }T[40000]; void build(int id, int l, int r, int *data){ T[id].l = l; T[id].r = r; if(l == r){ T[id].data = data[l]; }else{ int mid = (l+r) >> 1; build(id*2+1, l, mid, data); build(id*2+2, mid+1, r, data); T[id].data = max(T[id*2+1].data, T[id*2+2].data); } } void add(int id, int x, int val){ // 單點加值 data[x] += val if(T[id].l == T[id].r){ T[id].data += val; }else{ int mid = (T[id].l+T[id].r) >> 1; if(x<=mid) add(id*2+1, x, val); else add(id*2+2, x, val); T[id].data = max(T[id*2+1].data, T[id*2+2].data); } } int query(int id, int l, int r){ // 區間查詢 maximum of data[l..r] if(l == T[id].l && r == T[id].r){ return T[id].data; }else{ int mid = (T[id].l+T[id].r) >> 1; if(r<=mid) return query(id*2+1, l, r); else if(l>mid) return query(id*2+2, l, r); else return max(query(id*2+1, l, mid), query(id*2+2, mid+1, r)); } } int main(void){ int data[10] = {2,1,3,1,2,4,1,3,5,2}; build(0, 0, 9, data); for(int i=0;i<10;i++) for(int j=i;j<10;j++) cout<<query(0,i,j)<<(j==9?'\n':' '); // query data[i..j] add(0, 3, 2); // data[3] += 2 } ``` ## 2.區間加值,單點查詢數值 ```clike= #include <bits/stdc++.h> using namespace std; struct node{ int l, r, data; // 完整覆蓋當前區間的線段數量 }T[40000]; void build(int id, int l, int r){ T[id].l = l; T[id].r = r; T[id].data = 0; if(l == r){ }else{ int mid = (l+r) >> 1; build(id*2+1, l, mid); build(id*2+2, mid+1, r); } } void add(int id, int l, int r, int val){ // 區間加值 data[l..r] += val if(l == T[id].l && r == T[id].r){ T[id].data += val; }else{ int mid = (T[id].l+T[id].r) >> 1; if(r<=mid) add(id*2+1, l, r, val); else if(l>mid) add(id*2+2, l, r, val); else add(id*2+1, l, mid, val), add(id*2+2, mid+1, r, val); } } int query(int id, int x){ // 單點查詢 data[x] if(T[id].l == T[id].r){ return T[id].data; }else{ int mid = (T[id].l+T[id].r) >> 1; if(x<=mid) return T[id].data + query(id*2+1, x); else return T[id].data + query(id*2+2, x); } } int main(void){ build(0, 0, 9); add(0, 1, 4, 2); // data[1..4] += 2 for(int i=0;i<10;i++) cout<<query(0,i)<<" "; cout<<endl; add(0, 3, 7, 1); // data[3..7] += 1 for(int i=0;i<10;i++) cout<<query(0,i)<<" "; cout<<endl; add(0, 1, 9, 3); // data[1..9] += 3 for(int i=0;i<10;i++) cout<<query(0,i)<<" "; cout<<endl; add(0, 2, 5, -2); // data[2..5] -= 2 for(int i=0;i<10;i++) cout<<query(0,i)<<" "; cout<<endl; } ``` ## 3.區間修改 (加值/改值),區間查詢最大值 - 每個線段樹上的區間,可能被加值,也可能被改值 - 每次遍歷線段樹時,保證走到的區間是正確的 - 加值或是改值就暫時放在那個節點上 - 用到的時候,處理改值和加值的 tag 並正確維護 data 值,並把 tag 往下推一層 - 往下推一層 `push 的設計` (精隨就是懶?) - (1)如果同時存在改值和加值,則先處理改值再處理加值 - (2)若把改值推給後代,強迫後代的加值 tag 變成 0 - (3)若把加值推給後代,只更新後代的加值 tag ```clike= #include <bits/stdc++.h> using namespace std; // 區間修改 (加值/改值),區間查詢 (max) struct node{ int chg1, chg2, data; // chg1 改值 , chg2 加值 int l, r; }T[40000]; void build(int id, int l, int r, int *data){ T[id].l = l; T[id].r = r; T[id].chg1 = 0; T[id].chg2 = 0; if(l == r){ T[id].data = data[l]; }else{ int mid = (l+r) >> 1; build(id*2+1, l, mid, data); build(id*2+2, mid+1, r, data); T[id].data = max(T[id*2+1].data, T[id*2+2].data); } } void push(int id){ // 神奇 lazy tag 函數 if(T[id].l == T[id].r){ T[id].data = T[id].chg2 + ( T[id].chg1 ? T[id].chg1 : T[id].data ); T[id].chg1 = 0; T[id].chg2 = 0; return; } int ll = id*2+1, rr = id*2+2; if(T[id].chg1 != 0){ T[id].data = T[id].chg1; T[ll].chg1 = T[id].chg1, T[ll].chg2 = 0; T[rr].chg1 = T[id].chg1, T[rr].chg2 = 0; T[id].chg1 = 0; } if(T[id].chg2 != 0){ T[id].data += T[id].chg2; T[ll].chg2 += T[id].chg2; T[rr].chg2 += T[id].chg2; T[id].chg2 = 0; } } void add(int id, int l, int r, int val){ // 區間加值 push(id); if(l == T[id].l && r == T[id].r){ T[id].chg2 += val; }else{ int mid = (T[id].l+T[id].r) >> 1; if(r<=mid) add(id*2+1, l, r, val); else if(l>mid) add(id*2+2, l, r, val); else add(id*2+1, l, mid, val), add(id*2+2, mid+1, r, val); push(id*2+1); // revised push(id*2+2); // revised T[id].data = max(T[id*2+1].data, T[id*2+2].data); } } void chg(int id, int l, int r, int val){ // 區間改值 push(id); if(l == T[id].l && r == T[id].r){ T[id].chg1 = val; }else{ int mid = (T[id].l+T[id].r) >> 1; if(r<=mid) chg(id*2+1, l, r, val); else if(l>mid) chg(id*2+2, l, r, val); else chg(id*2+1, l, mid, val), chg(id*2+2, mid+1, r, val); push(id*2+1); // revised push(id*2+2); // revised T[id].data = max(T[id*2+1].data, T[id*2+2].data); } } int query(int id, int l, int r){ // 區間查詢 push(id); if(l == T[id].l && r == T[id].r){ return T[id].data; }else{ int mid = (T[id].l+T[id].r) >> 1; if(r<=mid) return query(id*2+1, l, r); else if(l>mid) return query(id*2+2, l, r); else return max(query(id*2+1,l,mid), query(id*2+2,mid+1,r)); } } void show(int *data){ for(int i=0;i<16;i++) cout<<query(0, i, i)<<" "; cout<<endl; } int main(){ int data[16] = {1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4}; build(0, 0, 15, data); show(data); add(0, 3, 8, 4); show(data); chg(0, 1, 10, 9); show(data); add(0, 4, 14, -2); show(data); chg(0, 7, 12, 5); show(data); add(0, 2, 13, -1); show(data); } ```