# 線段樹 總論 [toc] ## 線段樹原理與Code 本篇是根據[YT REF](https://www.youtube.com/watch?v=Teu-rb4uVVM)做成的筆記 Segment tree supports monoid(半群) operation. Especially fast for those operation that combines efficiently. E.g, if we know $sum(a,b]$ and $sum(b,c]$, then $sum(a,c]$ can be calculated $O(1)$. However, if we know the mode(眾數) for the left and the mode for the right, it is difficult to know the mode for the whole. Such a query is then an example a segment tree can't handle. :::info 見 http://hzwer.com/8053.html 之 分块入门 9 by hzwer > 给出一个长为n的数列,以及n个操作,操作涉及询问区间的最小众数。 这是一道经典难题,其实可以支持修改操作... ::: We use `int t[]`, 1-indexed to store values. For node `i`, the left child is `2*i`, and right child is `2*i+1`. :::info This way of storage is also used in implementation of a heap. ::: 線段樹已經萬用到是一種思維概念或是程式範式,可以根據這種概念展開很多很瘋狂的技巧,延伸之概念可說是子孫繁多。 以下列舉之: * 樹套樹 (for higher dimension) * 可持久化 * 動態開點樹 * 區間查詢 * 極值 * 總和 * gcd/lcm * 各類位運算 * 第一個大於x的數的位置 * 也就是樹上二分 * 區間修改 * 小於<=x的元素個數,有這些做法: * wavelet tree (不是線段樹) * merge-sort-tree (特化的線段樹) * BIT之樹套樹 * Mo's algorithm + sqrt decomposition * 多項式修改 * 包含了設定,加總 總之,可以把線段數想像成或是核心,我們可以隨需求安裝上不同的插件(函數)。 比方說下面是一個最基礎的線段樹,僅支援範圍查詢: ```cpp= struct segtree { ll arr[maxn], ll t[maxn<<2]; build(int l,int r,int i) { if (l==r) { t[l] = arr[l]; return; } int mid = (l+r)>>1; build(l,mid,i<<1); build(mid+1,r,i<<1|1); t[i] = t[i<<1] + t[i<<1|1]; } ll query(int jl,int jr,int l,int r,int i) { if (jl <= l && r <= jr) return t[i]; ll ans = 0; int mid = (l+r)>>1; if (jl <= mid) ans += query(jl,jr,l,mid,i<<1); if (jr > mid) ans += query(jl,jr,mid+1,r,i<<1|1); return ans; } ``` 可以想成 1. 安裝了`build`插件,使支援$O(n)$建樹 2. 安裝了`query`插件,使支援$O(\log n)$區間總和查詢 可以繼續安裝插件,使其具有更多功能,比方說,來一個單點累加的方法: ```cpp=21 void point_add(int idx,int val,int l,int r,int i) { t[i] += val*(r-l+1); if (l==r) return; int mid = (l+r)>>1; if (jl <= mid) point_add(idx,val,l,mid,i<<1); if (jr > mid) point_add(idx,val,mid+1,r,i<<1|1); t[i] = t[i<<1] + t[i<<1|1]; } }; ``` 目前,我們得到了與BIT完全一樣的功能。 此時碼量來到29行。但對於有志之士來說,既然訴諸了線段樹,豈會滿足於點累加?肯定還要來點區間累加才對。 如果套上lazy propagation等模板,代碼便來到58行。 :::spoiler ```cpp= ll arr[200001], t[200001<<2], lazy[200001<<2]; struct segtree { int n; segtree(int n): n(n) { build(0,n-1,1); } void apply(ll val,int l,int r,int i) { t[i] += (r-l+1)*val; lazy[i] += val; } void down(int l,int r,int i) { if (!lazy[i]) return; int mid = (l+r)>>1; apply(lazy[i],l,mid,i<<1); apply(lazy[i],mid+1,r,i<<1|1); lazy[i] = 0; } void up(int i) { t[i] = t[i<<1]+t[i<<1|1]; } void build(int l,int r,int i) { if (l==r) { t[i] = arr[l]; return; } int mid = (l+r)>>1; build(l,mid,i<<1); build(mid+1,r,i<<1|1); up(i); } ll query(int jl,int jr,int l,int r,int i) { if (jl <= l && r <= jr) return t[i]; ll ans = 0; int mid = (l+r)>>1; down(l,r,i); if (jl <= mid) ans+=query(jl,jr,l,mid,i<<1); if (jr > mid) ans+=query(jl,jr,mid+1,r,i<<1|1); up(i); return ans; } ll query(int jl,int jr) { return query(jl,jr,0,n-1,1); } void add(int jl,int jr,ll val,int l,int r,int i) { if (jl <= l && r <= jr) { apply(val,l,r,i); return; } int mid = (l+r)>>1; down(l,r,i); if (jl <= mid) add(jl,jr,val,l,mid,i<<1); if (jr > mid) add(jl,jr,val,mid+1,r,i<<1|1); up(i); } void add(int jl,int jr,ll val) { add(jl,jr,val,0,n-1,1); } }; ``` ::: 我們應該釐清這些新變數的意義。 `lazy`: `lazy[i]`如果有值,表示 `t[i]` 的節點已經按照要求做完累加了,但是他的兒子還沒有。 `down`: 因為兒子的`t[i]`還沒有做累加,所以要下發給左右。這在查詢及修改時都要做。 `up`:左右修改完後,要合併給父節點。其實只在修改時要做(也就是38行可以省去的意思),因為查詢除了由過往`lazy`帶來的左右兒子的修改外,沒有別的修改,而關於這個修改所必須做的修正早就完成了。 此時碼量來到58行。但對於有志之士來說,既然訴諸了線段樹,豈會滿足於區間累加?肯定還要來點區間設定才對。 為了顧及兩種操作的順序及正確性,代碼便來到87行。 :::spoiler ```cpp= #define maxn 200001 int arr[maxn]; ll unset = 1145141919810; ll t[maxn<<2], slazy[maxn<<2], alazy[maxn<<2]; struct segtree { int n; segtree(int n): n(n) { fill(alazy,alazy+(n<<2),0); fill(slazy,slazy+(n<<2),unset); build(0,n-1,1); } void build(int l,int r,int i) { if (l==r) { t[i] = arr[l]; return; } int mid = (l+r)>>1; build(l,mid,i<<1); build(mid+1,r,i<<1|1); up(i); } void down(int l,int r,int i) { int mid = (l+r)>>1; if (slazy[i] != unset) { sapply(slazy[i],l,mid,i<<1); sapply(slazy[i],mid+1,r,i<<1|1); slazy[i] = unset; } if (alazy[i]) { aapply(alazy[i],l,mid,i<<1); aapply(alazy[i],mid+1,r,i<<1|1); alazy[i] = 0; } } void up(int i) { t[i] = t[i<<1] + t[i<<1|1]; } void sapply(ll v,int l,int r,int i) { alazy[i] = 0; t[i] = (r-l+1)*v; slazy[i] = v; } void aapply(ll v,int l,int r,int i) { t[i] += (r-l+1)*v; alazy[i] += v; } ll query(int jl,int jr,int l,int r,int i) { if (jl <= l && r <= jr) return t[i]; ll ans = 0; int mid = (l+r)>>1; down(l,r,i); if (jl <= mid) ans += query(jl,jr,l,mid,i<<1); if (jr > mid) ans += query(jl,jr,mid+1,r,i<<1|1); return ans; } ll query(int jl,int jr) { return query(jl,jr,0,n-1,1); } void add(int jl,int jr,int v,int l,int r,int i) { if (jl <= l && r <= jr) { aapply(v,l,r,i); return; } int mid = (l+r)>>1; down(l,r,i); if (jl <= mid) add(jl,jr,v,l,mid,i<<1); if (jr > mid) add(jl,jr,v,mid+1,r,i<<1|1); up(i); } void add(int jl,int jr,int v) { add(jl,jr,v,0,n-1,1); } void set(int jl,int jr,int v,int l,int r,int i) { if (jl <= l && r <= jr) { sapply(v,l,r,i); return; } int mid = (l+r)>>1; down(l,r,i); if (jl <= mid) set(jl,jr,v,l,mid,i<<1); if (jr > mid) set(jl,jr,v,mid+1,r,i<<1|1); up(i); } void set(int jl,int jr,int v) { set(jl,jr,v,0,n-1,1); } }; ``` ::: :::info 事實上,很多都是重複的代碼跟邏輯。花頂多一個禮拜的時間,不用死背就可以學得得心應手。 ::: ### 一種優雅的寫法: $f-$修改 > 以下的內容是筆者做夢想到的,但應該有人早就有跟我一樣的想法並發到網上了但隨便啦 上面的方法,汲汲營營的把`lazy`, `apply`分成兩種,一種照顧設定,一種照顧加值, 我們就要費很多腦細胞處理好他們的先後順序關西(因為設定操作會覆蓋掉加值操作),寫成了也不好解釋。 我們不妨處理好 > 將區間[l,r]的所有數字根據函數 `f` 做變換。 亦即,本來是 $$a_0, \dots, a_l, \dots, a_r, \dots, a_{n-1}$$ 現在是 $$a_0, \dots, f(a_l), \dots, f(a_r), \dots, a_{n-1}$$ 如果可以處理好這種query,並且當$f(x) = ax+b$ 也就是是一個 affine map 的時候,不就同時處理好了 1. 區間設定 2. 區間加值,甚至是 3. 區間乘 等操作了嗎? :::spoiler ```cpp= using ll = long long; struct poly { ll a, b; }; bool operator==(const poly& l,const poly& r) { return l.a == r.a && l.b == r.b; } poly id{1,0}; #define maxn 200001 int arr[maxn]; ll t[maxn<<2]; poly lz[maxn<<2]; struct segtree { int n; segtree(int n): n(n) { fill(lz,lz+(n<<2),id); build(0,n-1,1); } void build(int l,int r,int i) { if (l==r) { t[i] = arr[l]; return; } int mid = (l+r)>>1; build(l,mid,i<<1); build(mid+1,r,i<<1|1); up(i); } void down(int l,int r,int i) { int mid = (l+r)>>1; if (lz[i] != id) { apply(lz[i],l,mid,i<<1); apply(lz[i],mid+1,r,i<<1|1); lz[i] = id; } } void up(int i) { t[i] = t[i<<1] + t[i<<1|1]; } void apply(poly v,int l,int r,int i) { t[i] = v.a*t[i] + v.b*(r-l+1); lz[i] = poly{v.a*lz[i].a,v.a*lz[i].b + v.b}; } ll query(int jl,int jr,int l,int r,int i) { if (jl <= l && r <= jr) return t[i]; ll ans = 0; int mid = (l+r)>>1; down(l,r,i); if (jl <= mid) ans += query(jl,jr,l,mid,i<<1); if (jr > mid) ans += query(jl,jr,mid+1,r,i<<1|1); return ans; } ll query(int jl,int jr) { return query(jl,jr,0,n-1,1); } void modify(ll a,ll b,int jl,int jr,int l,int r,int i) { if (jl <= l && r <= jr) { apply(poly{a,b},l,r,i); return; } int mid = (l+r)>>1; down(l,r,i); if (jl <= mid) modify(a,b,jl,jr,l,mid,i<<1); if (jr > mid) modify(a,b,jl,jr,mid+1,r,i<<1|1); up(i); } void modify(ll a,ll b,int jl,int jr) { modify(a,b,jl,jr,0,n-1,1); } }; ``` ::: 原來的寫法,我們要謹慎處理他們的先後順序。現在,`先乘除後加減`自動幫我們做好了這部分。 注意程式碼更少但支援的操作反而更多! $f$還可以是別的多項式或是有理式,他們的懶標記累加起來都比較直觀(雖然麻煩)。 更可以是max,min,clamp。但是需要特別處理技巧,見 [A simple introduction to "Segment tree beats"](https://codeforces.com/blog/entry/57319) ## 離散化、二分搜及特別修改 ## 線段樹的擴展 ## 區間合併 為了回答特別的query,我們有必要特別練習設計狀態的方式, 使得區間容易合併出大答案。 現在,以leetcode的 [53. Maximum Subarray](https://leetcode.com/problems/maximum-subarray/description/) 為例。 ### 使用的線段樹介紹 Kadane's alg. 也許優雅,卻不夠潮。現在,我們要用同樣線性時間的線段樹手法解開這題。 為此,我們要構建一棵支援「單點修改」和「區間查詢」的線段樹。在一次查詢中,我們得同時回答: * 從 `jl` 開始的最大前綴和 * 以 `jr` 結尾的最大後綴和 * 區間 `[jl, jr]` 的最大子陣列和 * 區間 `[jl, jr]` 的總和 一旦構建好這棵線段樹,這題就變得相當簡單。 甚至(雖然題目沒問),我們還能$O(\log n)$改值。配上懶標更能範圍改。 `build` 方法和 `set` 方法都是老生常談了 困難的部分是如何設計一個`query`,能從左右子區間的資訊構造出父節點的資訊,也就是 `up` 與 `q` 的功能。 --- ### `up` 與 `q` 方法 圈內人士都知道,`up`(有些人也叫 `pull`)是用來把左右子節點的資訊合併成父節點的資訊。 我們在樹上儲存: ```cpp // sum : 區間總和 // pre : 最大前綴和 // suf : 最大後綴和 // c : 最大子陣列和 ll sum[maxn<<2], pre[maxn<<2], suf[maxn<<2], c[maxn<<2]; ``` 這些資訊都是為了計算 `c`(最大子陣列和)所必需的。 為什麼需要 `pre` 和 `suf`? 你不妨可以做做看 ![image](https://hackmd.io/_uploads/HJ7c5pOFeg.png) 或是看[這個解答](https://leetcode.com/problems/maximum-subarray/solutions/1595195/c-python-7-simple-solutions-w-explanation-brute-force-dp-kadane-divide-conquer/) 就可以略知一二,尤其是第七個解, ```cpp void up(int l,int r,int i) { int il = i<<1, ir = i<<1|1; sum[i] = sum[il] + sum[ir]; pre[i] = max(pre[il], sum[il] + pre[ir]); suf[i] = max(suf[ir], suf[il] + sum[ir]); c[i] = max({c[il], c[ir], suf[il] + pre[ir]}); // 特別是這一行 } ``` 這就是原因。 :::info 可以想一下,在做數學歸納法的時候,有時會遇到想證明的東西無法從過去的東西推出來的情況。 這時候我們就要自己設計每個狀態有哪些命題為真。 像是 ![image](https://hackmd.io/_uploads/B14PnpuKex.png) 就用了五個命題。 好處就在於,條件多了,有興趣的命題就變得好證,但不好就在於 你要把多的這些不感興趣的條件給證了,才能給下個狀態用。 在這邊,也是一樣的意思。 意思就是,注意到數學歸納法跟遞回的密不可分的關西。 ::: --- #### 為什麼還需要 `sum`? 不能直接寫成這樣嗎? ```cpp void up(int l,int r,int i) { int il = i<<1, ir = i<<1|1; pre[i] = max(pre[il], pre[il] + pre[ir]); suf[i] = max(suf[ir], suf[il] + suf[ir]); c[i] = max({c[il], c[ir], suf[il] + pre[ir]}); } ``` 答案:不行。考慮以下例子: `1, -114514, -1919, 1, 0, 0` 在 `[0,2]` 的最大前綴和是 1,在 `[3,5]` 的最大前綴和也是 1,但整體 `[0,5]` 的最大前綴和不是 1+1,如果這樣計算就錯了。 簡單說,我們不能只拿左邊的部分前綴和,然後跳過 `-114514`。 我們需要 `sum` 來判斷是否應該把整個左區間都包含進來。 這就是為什麼它和純粹的 Divide & Conquer 解法不同,在這裡 `sum` 是必須的。 類似地,`q` 方法也是基於同樣的想法:它回傳子區間的所有必要資訊,方便合併出父節點的結果。 --- ### TC 建樹需要 $O(n)$,因為 `up` 是 $O(1)$。 查詢最大子陣列和則需要 $O(\log n)$。 事實上,樹建好之後,答案就儲存在 `c[1]`,因此只要直接回傳 `c[1]`,查詢就變成 $O(1)$。 所以整體:TC = SC = $O(n)$。 --- :::spoiler ac code ```cpp [] using ll = long long; #define maxn 100001 int arr[maxn]; // sum : 區間總和 // pre : 最大前綴和 // suf : 最大後綴和 // c : 最大子陣列和 ll sum[maxn<<2], pre[maxn<<2], suf[maxn<<2], c[maxn<<2]; struct segtree { int n; segtree(int n): n(n) { build(0, n-1, 1); } void up(int l,int r,int i) { int il = i<<1, ir = i<<1|1; sum[i] = sum[il] + sum[ir]; pre[i] = max(pre[il], sum[il] + pre[ir]); suf[i] = max(suf[ir], suf[il] + sum[ir]); c[i] = max({c[il], c[ir], suf[il] + pre[ir]}); } void build(int l,int r,int i) { if (l==r) { pre[i] = suf[i] = c[i] = arr[l]; sum[i] = arr[l]; return; } int mid = midpoint(l,r); build(l,mid,i<<1); build(mid+1,r,i<<1|1); up(l,r,i); } void set(int idx,int v,int l,int r,int i) { if (l==r) { sum[i] = v; pre[i] = suf[i] = c[i] = v; return; } int mid = midpoint(l,r); if (idx <= mid) set(idx,v,l,mid,i<<1); else set(idx,v,mid+1,r,i<<1|1); up(l,r,i); } // 回傳 pre, suf, c, sum tuple<ll,ll,ll,ll> q(int jl,int jr,int l,int r,int i) { if (jl <= l && r <= jr) return {pre[i],suf[i],c[i],sum[i]}; int mid = midpoint(l,r); if (jr<=mid) return q(jl,jr,l,mid,i<<1); if (jl>mid) return q(jl,jr,mid+1,r,i<<1|1); auto [prel,sufl,cl,sl] = q(jl,jr,l,mid,i<<1); auto [prer,sufr,cr,sr] = q(jl,jr,mid+1,r,i<<1|1); return { max(sl+prer,prel), max(sufr, sufl+sr), max({cl,cr,sufl+prer}), sl+sr }; } }; class Solution { public: int maxSubArray(vector<int>& nums) { int n = nums.size(); for (int i=0;i<n;i++) arr[i] = nums[i]; segtree seg(n); return get<2>(seg.q(0,n-1,0,n-1,1)); // 第 3 個元素 (c) 是最大子陣列和 } }; ``` ::: --- ### 備註 潮就潮在,使用這個完全相同的線段樹,然後改一改 `main()`,可以直接 AC CSES 上的這四題: * [Prefix Sum Queries](https://cses.fi/problemset/task/2166) * [Range Interval Queries](https://cses.fi/problemset/task/3163) * [Subarray Sum Queries II](https://cses.fi/problemset/task/3226) * [Dynamic Range Sum Queries](https://cses.fi/problemset/task/1648) ### 再做一題 以[P2572 [SCOI2010] 序列操作](https://www.luogu.com.cn/problem/P2572)這提難題為例(題解來自影片), 現在我們要要求區間不能有任何的0,所以合併時要再考慮一下,並不是總是可以貪心。 不如直接看代碼 ::: spoiler code ```cpp= #include <bits/stdc++.h> #define maxn 100002 using namespace std; // orig int arr[maxn]; // data int sum[maxn<<2]; int pre0[maxn<<2], suf0[maxn<<2], c0[maxn<<2]; int pre1[maxn<<2], suf1[maxn<<2], c1[maxn<<2]; // lazy int set_to[maxn<<2], unset = -1; bool flip[maxn<<2]; struct segtree { int n; segtree(int n): n(n) { // 注意1 : 直接sizeof(set_to) 全部清零。 // 這點小時間不足掛齒 memset(set_to,unset,sizeof(set_to)); memset(flip,0,sizeof(flip)); build(0,n-1,1); } void build(int l,int r,int i) { // if (l==r) { sum[i] = arr[l]; pre0[i] = suf0[i] = c0[i] = (arr[l]==0); pre1[i] = suf1[i] = c1[i] = (arr[l]==1); return; } int mid = midpoint(l,r); build(l,mid,i<<1); build(mid+1,r,i<<1|1); up(l,r,i); } void down(int l,int r,int i) { int mid = midpoint(l,r); // 注意2 : 先處理set再處理flip if (set_to[i] != unset) { apply_set(set_to[i],l,mid,i<<1); apply_set(set_to[i],mid+1,r,i<<1|1); set_to[i] = unset; } if (flip[i]) { apply_flip(l,mid,i<<1); apply_flip(mid+1,r,i<<1|1); flip[i] = false; } } // 注意4 : first encounterence of 三個變數的up方法 void up(int l,int r,int i) { sum[i] = sum[i<<1] + sum[i<<1|1]; int n = (r-l+1), mid = midpoint(l,r); // 注意5 : 用變數所以你不會眼花撩亂 int nl = mid-l+1, nr = r-mid; int lv = i<<1, rv = i<<1|1; // 注意6 : 可以換成 pre0[i] = (c0[lv] < nl)?pre0[lv]:nl+pre0[rv]; // 但這種寫法較雅觀 pre0[i] = (c0[lv] < nl)?pre0[lv]:pre0[lv]+pre0[rv]; suf0[i] = (c0[rv] < nr)?suf0[rv]:suf0[lv]+suf0[rv]; c0[i] = max({c0[lv],c0[rv],suf0[lv]+pre0[rv]}); pre1[i] = (c1[lv] < nl)?pre1[lv]:pre1[lv]+pre1[rv]; suf1[i] = (c1[rv] < nr)?suf1[rv]:suf1[lv]+suf1[rv]; c1[i] = max({c1[lv],c1[rv],suf1[lv]+pre1[rv]}); } void apply_set(int v,int l,int r,int i) { sum[i] = (v==1?r-l+1:0); pre0[i] = suf0[i] = c0[i] = (v==0?r-l+1:0); pre1[i] = suf1[i] = c1[i] = (v==1?r-l+1:0); set_to[i] = v; flip[i] = false; } void apply_flip(int l,int r,int i) { sum[i] = (r-l+1)-sum[i]; swap(pre0[i],pre1[i]), swap(suf0[i],suf1[i]), swap(c0[i],c1[i]); flip[i] = !flip[i]; // 注意7 : 不用寫 // if (set_to[i] != unset) set_to[i] ^= 1; } void set(int v,int jl,int jr,int l,int r,int i) { if (jl <= l && r <= jr) { apply_set(v,l,r,i); return; } int mid = midpoint(l,r); down(l,r,i); if (jl <= mid) set(v,jl,jr,l,mid,i<<1); if (jr > mid) set(v,jl,jr,mid+1,r,i<<1|1); up(l,r,i); } void fl(int jl,int jr,int l,int r,int i) { if (jl <= l && r <= jr) { apply_flip(l,r,i); return; } int mid = midpoint(l,r); down(l,r,i); if (jl <= mid) fl(jl,jr,l,mid,i<<1); if (jr > mid) fl(jl,jr,mid+1,r,i<<1|1); up(l,r,i); } int qsum(int jl,int jr,int l,int r,int i) { if (jl <= l && r <= jr) return sum[i]; int ans = 0, mid = midpoint(l,r); down(l,r,i); if (jl <= mid) ans += qsum(jl,jr,l,mid,i<<1); if (jr > mid) ans += qsum(jl,jr,mid+1,r,i<<1|1); // 注意8 : 不用 up 想想為啥 return ans; } tuple<int,int,int> qc(int jl,int jr,int l,int r,int i) { if (jl <= l && r <= jr) return {pre1[i],suf1[i],c1[i]}; int mid = midpoint(l,r); down(l,r,i); if (jr <= mid) return qc(jl,jr,l,mid,i<<1); if (jl > mid) return qc(jl,jr,mid+1,r,i<<1|1); auto [lpre1,lsuf1,lc1] = qc(jl,jr,l,mid,i<<1); auto [rpre1,rsuf1,rc1] = qc(jl,jr,mid+1,r,i<<1|1); // 注意9 : 為啥是 max, min? int nl = mid-max(l,jl)+1, nr = min(r,jr)-mid; return { (lc1 < nl)?lpre1:lpre1+rpre1, (rc1 < nr)?rsuf1:lsuf1+rsuf1, max({lc1,rc1,lsuf1+rpre1}) }; } }; int main() { cin.tie(0), ios::sync_with_stdio(0); int n,m; cin >> n >> m; for (int i=0;i<n;i++) cin >> arr[i]; segtree seg(n); for (int i=0;i<m;i++) { int t,l,r; cin >> t >> l >> r; if (t==0 || t==1) seg.set(t,l,r,0,n-1,1); else if (t==2) seg.fl(l,r,0,n-1,1); else if (t==3) cout << seg.qsum(l,r,0,n-1,1) << '\n'; else cout << get<2>(seg.qc(l,r,0,n-1,1)) << '\n'; } } ``` ::: ## 動態開點及主席樹(Persistent Segment Tree) ## w/ Line sweep ## 各種變形