Segment Tree

「線段樹」酷酷的資料結構

basic

現在的高中生人手一棵線段樹 BY NPSC裁判

一棵最簡單的線段樹支援以下幾種操作(對陣列)

  • 單點改值
  • 區間求值(區間和、區間最大最小)

示意圖

Image Not Showing Possible Reasons
  • The image file may be corrupted
  • The server hosting the image is unavailable
  • The image path is incorrect
  • The image format is not supported
Learn More →

觀察一下這張圖,一個節點需要記錄什麼資訊?

  • 左界、右界
  • 左節點、右節點 (存指標)
  • 值 (圖中沒畫,儲存當前區間的值)

Why 左閉右開?

有人會覺得左閉右閉比較好,但根本邪教

  • 一個區間的右界就會是下一個接續區間的左界(不用+1-1)
  • size=rl
    (一樣不用+1-1)

所以開一個struct用來記錄每個節點:

struct seg { int l, r, mid, val; seg* ch[2] = {}; //two children };

How to construct a segment tree?

  • 如果
    l+1r
    , 繼續往下開點
  • 如果
    l+1=r
    , 結束

code:

struct seg{ int l, r, mid, val; seg *ch[2] = {}; seg(int _l, int _r):l(_l),r(_r){ mid = (l+r)/2; if(l == r-1){ return; } ch[0] = new seg(l, mid); //left child ch[1] = new seg(mid, r); //right child } };

蓋好線段樹了,如何單點改值?

單點改值
操作 : 將

aidx 改為
aidx+k
(也可以直接重設)

  • 如果
    idx<mid
    ,往左邊更新節點,否則
    idxmid
    ,往右邊更新節點
  • 如果
    l=r1
    ,代表到目標了,更新節點
  • 回來時要更新上面的節點(pull)

pull
當一個節點的的下方有更改過,就要將資訊(更新)上來,就是pull

code:

void pull(){ val = ch[0]->val + ch[1]->val; } void add(int idx, int k){ if(l == r-1){ val += k; return; } if(idx < mid){ ch[0]->add(idx, k); } else{ ch[1]->add(idx, k); } pull(); }

特別注意指標呼叫成員用->

區間求值

操作:給一區間

[a,b) ,求區間和

  • 如果
    a<mid
    ,代表所求區間有跨到左邊,向左邊找答案
  • 如果
    b>mid
    ,代表所求區間有跨到右邊,向右邊找答案
  • 如果
    [a,b)
    完全包含
    [l,r)
    ,回傳值
  • 統整當前所有答案並往上回傳(求區間和,所以要相加)

code:

int query(int a, int b){ if(a <= l && r <= b){ return val; } int ans = 0; if(a < mid){ ans += ch[0]->query(a, b); } if(b > mid){ ans += ch[1]->query(a, b); } return ans; }

大功告成!!

例題

Static Range Minimum Queries
Dynamic Range Sum Queries
Dynamic Range Minimum Queries
Range Xor Queries

Tip

善用pull與參照,直接在建構線段樹時設好初始值,不用n次單點改值

code:

struct seg{ int l, r, mid, val; seg *ch[2] = {}; seg(int _l, int _r, vector<int> &arr):l(_l),r(_r){ mid = (l+r)/2; if(l == r-1){ val = arr[l]; return; } ch[0] = new seg(l, mid, arr); ch[1] = new seg(mid, r, arr); pull(); } void pull(){ val = ch[0]->val + ch[1]->val; } };

Lazy Tag

懶人標記,俗稱懶標,用來記錄當前區間已經做過的操作(子節點尚未更新),等到需要往下修改或查詢的時候再將資訊推下去。

區間加值

操作:將區間

[a,b)內的每一項加
k

跟query(區間查詢)作法很像,當

[a,b)完全包含
[l,r)
:

  • lz
    改為
    lz+k
    (標記該區間每一項都
    +k
    )
  • val
    改為
    val+k×(r1)
  • 做完要
    pull

這時候query(區間和)往下的時候就要記得

push

Push ??
push 怎麼做?

  • 將左右子樹的
    val
    lz
    按照當前區間的
    lz
    修改,並重設
    lz

code:

struct seg{ int l, r, mid, val, lz = 0; seg *ch[2] = {}; seg(int _l, int _r, vector<int> &arr):l(_l),r(_r){ mid = (l+r)/2; if(l == r-1){ val = arr[l]; return; } ch[0] = new seg(l, mid, arr); ch[1] = new seg(mid, r, arr); pull(); } void pull(){ val = ch[0]->val + ch[1]->val; } void push(){ if(!lz) return; ch[0]->lz += lz; ch[1]->lz += lz; ch[0]->val += lz*(mid - l); ch[1]->val += lz*(r - mid); lz = 0; } void add(int a, int b, int k){ if(a <= l && r <= b){ val += k*(r-l); lz += k; return; } if(a < mid){ ch[0]->add(a, b, k); } if(b > mid){ ch[1]->add(a, b, k); } pull(); } int query(int a, int b){ if(a <= l && r <= b){ return val; } push(); int ans = 0; if(a < mid){ ans += ch[0]->query(a, b); } if(b > mid){ ans += ch[1]->query(a, b); } return ans; } };

以上是懶標的基礎運用,懶標的概念就是紀錄做了什麼事,等到要查詢再

push就好

例題

Range Update Queries
Subarray Sum Queries
寫完下面這題代表你會用懶標了
Range Updates and Sums

Persistent Segment Tree

持久化線段樹,神奇的資料結構,可以在修改後還能查詢先前的版本。

示意圖

藍色為第一版本,其他為後續。

如何開一顆新節點? 需有人對照(舊版本)

  • l,r,mid,val,lz
    都照抄
  • 往下改點(
    replace,add
    )或是更新懶標(
    push
    )的時候,將左節點設為參考原左節點而開成的新點,右邊亦然
  • 查詢時直接查,
    push
    記得開點

code:

seg(seg *root){ l = root->l; r = root->r; mid = (l+r)/2; if(l == r-1) return; ch[0] = root->ch[0]; ch[1] = root->ch[1]; pull(); } void pull(){ val = ch[0]->val + ch[1]->val; } void replace(int idx, int k){ if(l == r-1){ val = k; return; } if(idx < mid){ ch[0] = new seg(ch[0]); ch[0]->replace(idx, k); } else{ ch[1] = new seg(ch[1]); ch[1]->replace(idx, k); } pull(); }

constructor,query 一樣,故省略(此為無懶標版本)

補充:如果不想打1~9行的constructor可以這樣寫:

ch[0] = new seg(*ch[0]);

new 代表會新開一個空間,而新的seg的成員變數都會跟原本一樣

例題(裸題,直接套模板就行)

Range Queries and Copies

區間K-th問題

給一陣列

A 每次詢問
[l,r]
間第k大的是誰

想法:
如果每次詢問都滿足

l=0,那我們可以對值域開一棵線段樹,存每個區間的數字出現的次數,再來只要在線段樹上二分搜就行(找第一個前綴和會大於等於k的地方)。

l0 ?

接下來只要開一棵持久化線段樹,第

i個版本代表紀錄前
i
項的線段樹,在查詢
[l,r]
的時候用
r
(l1)
兩個版本相減就好

例題:
Range Kth Smallest

Code
#include<bits/stdc++.h> using namespace std; #define int long long #define all(x) x.begin(),x.end() struct seg{ int l, r, mid, val; seg *lc, *rc; seg(){}; seg(int _l, int _r):l(_l),r(_r),mid((l+r)>>1){ if(l == r-1){ val = 0; return; } lc = new seg(l, mid); rc = new seg(mid, r); pull(); } void pull(){ val = lc->val + rc->val; } void add(int idx, int k){ if(l == r-1){ val += k; return; } if(idx < mid){ lc = new seg(*lc); lc->add(idx, k); } else{ rc = new seg(*rc); rc->add(idx, k); } pull(); } }; int query(seg *a, seg *b, int k){ //find (k+1)-th if(a->l == a->r-1) return a->l; if(b->lc->val-a->lc->val > k) return query(a->lc, b->lc, k); else return query(a->rc, b->rc, k-(b->lc->val-a->lc->val)); } signed main(){ int n, q; cin >> n >> q; vector<int> arr(n), lisan(n); for(int i = 0; i < n; i++) cin >> arr[i]; lisan = arr; sort(all(lisan)); lisan.resize(unique(all(lisan))-lisan.begin()); vector<seg*> sg(n+1); sg[0] = new seg(0, lisan.size()); for(int i = 0; i < n; i++){ sg[i+1] = new seg(*sg[i]); sg[i+1]->add(lower_bound(all(lisan), arr[i])-lisan.begin(), 1); } for(int i = 0,a,b,k; i < q; i++){ cin >> a >> b >> k; cout << lisan[query(sg[a], sg[b], k)] << endl; } }

Li-chao tree

李超線段樹,一棵神奇的線段樹,維護一個集合(存線),給一個

x,求最小的
f(x)
(或最大)

目標:

  • 開一個線段樹在
    x
    軸,每個節點存一條線
    (val)
  • 保證
    query
    x1
    所經過的節點上的線
    f1(x),f2(x)fn(x)
    ,均代入
    x1
    min
    即為答案

假設一節點

[l,r)存著一線
a
(紅色),插入一線
b
(藍色),則有四種情況:

  1. fa(mid)<fb(mid)ma>mb

原節點維護之線不變,將

b插入右子節點更新(可能更優)

  1. fa(mid)<fb(mid)ma<mb

原節點維護之線不變,將

b插入左子節點更新(可能更優)

  1. fa(mid)>fb(mid)ma>mb

原節點維護之線改為

b,將
a
插入左子節點更新(可能更優)

  1. fa(mid)>fb(mid)ma<mb

原節點維護之線改為

b,將
a
插入右子節點更新(可能更優)

總結

  • f(mid)
    較小的線設為當前節點的
    val
  • 另一條線若斜率較大,往左邊更新,反之則往右邊
  • f(mid)
    較大的線完全在另一線上方,直接
    return
  • query
    時,經過的所有線
    (val)
    mid
    即為答案

code:

struct line{ int a, b; line(int _a, int _b):a(_a),b(_b){} int f(int x){ return a*x + b; } }; struct seg{ int l, r, mid; line val = line(0, 1e18); seg *ch[2] = {}; seg(int _l, int _r):l(_l),r(_r){ mid = (l+r)/2; if(l == r-1) return; ch[0] = new seg(l, mid); ch[1] = new seg(mid, r); } void insert(line k){ if(k.f(mid) < val.f(mid)) swap(k, val); if(l == r-1) return; if(k.a > val.a){ ch[0]->insert(k); } else{ ch[1]->insert(k); } } int query(int x){ if(l == r-1){ return val.f(x); } if(x < mid){ return min(val.f(x), ch[0]->query(x)); } else{ return min(val.f(x), ch[1]->query(x)); } } };

例題

Shopping in AtCoder store

解法

看完題目後,可以發現當我們考慮調整第

j物品的價格時,
先將
B
由大到小排序
如果想要讓
i
個人購買物品,那最大的
Pj=Bi+Cj

總獲利為
i×(Bi+Cj)
整理後得
iCj+iBi

但枚舉
i=[1,n]
會TLE,所以我們考慮可以插入
n
條直線
fi(x)=ix+iBi

接下來只要由給定的
Cj
,找到
maxi[1,n]iCj+iBi
就行了

code:
#include<bits/stdc++.h> using namespace std; #define int long long #define all(x) x.begin(),x.end() struct line{ int a, b; line(int _a, int _b):a(_a),b(_b){} int f(int x) { return a*x+b; } }; struct seg{ int l, r, mid; line val = line(0, 0); seg *lc = NULL, *rc = NULL; seg(int _l, int _r):l(_l),r(_r),mid((l+r)>>1){} void newNode(){ if(lc != nullptr || l == r-1)return; lc = new seg(l, mid); rc = new seg(mid, r); } void insert(line k){ if(k.f(mid) > val.f(mid)) swap(k, val); if(l == r-1)return; newNode(); if(k.a > val.a) rc->insert(k); else lc->insert(k); } int query(int idx){ if(l == r-1 || lc == nullptr) return val.f(idx); int ans = val.f(idx); if(idx < mid) ans = max(ans, lc->query(idx)); else ans = max(ans, rc->query(idx)); return ans; } }; signed main(){ int n, m; cin >> n >> m; vector<int> b(n), c(m); seg sg(0, 1e9); for(int i = 0; i < n; i++) cin >> b[i]; for(int i = 0; i < m; i++) cin >> c[i]; sort(all(b)); reverse(all(b)); for(int i = 0; i < n; i++) sg.insert(line(i+1,(i+1)*b[i])); for(int i = 0; i < m; i++) cout << sg.query(c[i]) << " "; cout << endl; }



Monster Game II

用李超線段樹把DP過程優化吧

解法

考慮

dp[i]為打倒第
i
隻怪獸所花費的最小總時間,且打完第
j
之後打
i
所花費的時間為
s[i]×f[j]
,所以
dp[i]=minj[0,i1]s[i]×f[j]+dp[j]

砸李超線段樹就過了

其他應用

Segment Add Get Min
插入直線變成插入線段怎麼辦? 其實原本插入的就是在該區間範圍的直線,所以只需要將要插的直線範圍拆成

logn個區間插入就好:

void insert(int a, int b,line k){ if(a <= l && r <= b) return insert_line(k); if(a < mid) lc->insert(a, b, k); if(b > mid) rc->insert(a, b, k); } void insert_line(line k){ if(k.f(mid) < ln.f(mid)) swap(k, ln); if(l == r-1){ val = min(ln.f(l),ln.f(r-1)); return; } if(k.a > ln.a) lc->insert_line(k); else rc->insert_line(k); }

當然也是可以把兩個函式合併。

誰說一定要直線??

李超線段樹適用於具有優超性的函數 只要函數

f1在某個地方被
f2
超越後,
f1
就不可能再反超。

Li-chao tree extended

在 2021 年 1 月,有人在 Codeforces 上發表了擴充版的李超線段樹。

可以處理以下兩種問題

問題一

有一個大小為

N的陣列
A
,還有
Q
筆操作:

  • 區間插入線,給定
    l,r,a,b
    ,做
    Ai=max(Ai,ai+b)
    ,
    i[l,r]
    in
    O(log2N)
  • 區間加上線,給定
    l,r,a,b
    ,做
    Ai+=ai+b
    ,
    i[l,r]
    in
    O(log2N)
  • 單點查詢值,給定
    i
    ,回傳
    Ai
    in
    O(logN)

問題二

有一個大小為

N的陣列
A
,還有
Q
筆操作:

  • 區間插入線,給定
    l,r,a,b
    ,做
    Ai=max(Ai,ai+b)
    ,
    i[l,r]
    in
    O(log2N)
  • 區間加單值,給定
    l,r,b
    ,做
    Ai+=b
    ,
    i[l,r]
    in
    O(log2N)
  • 區間取極值,給定
    l,r
    ,回傳
    maxi[l,r]Ai
    in
    O(logN)



一般李超線段樹是從可能成為答案的直線取

max,插入順序是具有交換律的。
而有了區間加線後,插入的線是在加線前加入還是在加線後插入就差很多了。

以原版的李超來舉例:

  • [1,N]
    一開始為空
  • [1,N]
    插入
    y=1
    ,然後在
    [2,N]
    加上
    y=
  • 很顯然
    [1,N]
    不包含在
    [2,N]
    所以
    [1,N]
    這個節點不受到加值影響
  • 查詢
    x=2
    時,會從
    [1,N]
    這個節點查到
    y=1
    這條線,而得到錯誤的結果
    1
    而不是


考慮這種情況,答案其實很明顯:
加線段的時候在經過

[l,r]這個節點時,就先把它上面的直線
f
清掉,並對
[l,mid]
[mid,r]
插入
f

這個情況下,加入的線段就不能直接推到葉子,要用懶標存在節點上,才可以保證複雜度依然是
O(Nlog2N)



問題二其實也差不多

但區間加直線會造成區間最大值的位置改變,若最大值發生在底下的直線,會造成懶人標記的不可預測性,因此作者才在第二種維護中將該操作降級成了單純的區間加值。
code 1

struct line{ int a = 0, b = 0; line(){} line(int _a, int _b):a(_a),b(_b){} int f(int x){ return b+a*x; } bool operator ==(const line &next)const{ return (a == next.a && b == next.b); } void operator +=(const line &next){ a += next.a; b += next.b; } }; struct seg{ int l, r, mid; line ln = line(0, 0), lazy = line(0, 0); seg *lc = NULL, *rc = NULL; seg(int _l, int _r):l(_l),r(_r),mid((l+r)>>1){} void newNode(){ if(l == r-1)return; lc = new seg(l, mid); rc = new seg(mid, r); } void push_lazy(){ if(lazy == line(0, 0))return; lc->lazy += lazy; rc->lazy += lazy; lc->ln += lazy; rc->ln += lazy; lazy = line(0, 0); } void push_line(){ lc->insert_line(ln); rc->insert_line(ln); ln = line(0, 0); } void insert_line(line k){ if(k.f(mid) < ln.f(mid)) swap(k, ln); if(l == r-1)return; if(lc == nullptr) newNode(); push_lazy(); if(k.a > ln.a) lc->insert_line(k); else rc->insert_line(k); } void insert(int a, int b,line k){ if(a <= l && r <= b) return insert_line(k); if(lc == nullptr) newNode(); push_lazy(); if(a < mid) lc->insert(a, b, k); if(b > mid) rc->insert(a, b, k); } void add(int a, int b, line k){ if(a <= l && r <= b){ lazy += k; ln += k; return; } if(lc == nullptr) newNode(); push_lazy(); push_line(); if(a < mid) lc->add(a, b, k); if(b > mid) rc->add(a, b, k); } int query(int idx){ if(l == r-1 || lc == nullptr) return ln.f(idx); push_lazy(); int ans = ln.f(idx); if(idx < mid) ans = max(ans, lc->query(idx)); else ans = max(ans, rc->query(idx)); return ans; } };

code 2

struct seg{ int l, r, mid; int val = 0,lazy = 0; seg *lc = NULL, *rc = NULL; line ln = line(0, 0); seg(int _l, int _r):l(_l),r(_r),mid((l+r)>>1){} void newNode(){ if(l == r-1)return; lc = new seg(l, mid); rc = new seg(mid, r); } void push_lazy(){ if(lazy == 0)return; lc->ln.b += lazy; rc->ln.b += lazy; lc->lazy += lazy; rc->lazy += lazy; lc->val += lazy; rc->val += lazy; lazy = 0; } void push_line(){ lc->insert_line(ln); rc->insert_line(ln); ln = line(0, INF); } void pull(){ val = max(max(lc->val,rc->val),max(ln.f(l),ln.f(r-1))); } void insert_line(line k){ if(k.f(mid) < ln.f(mid)) swap(k, ln); if(l == r-1){ val = max(ln.f(l),ln.f(r-1)); return; } if(lc == nullptr) newNode(); push_lazy(); if(k.a > ln.a) lc->insert_line(k); else rc->insert_line(k); pull(); } void insert(int a, int b,line k){ if(a <= l && r <= b) return insert_line(k); if(lc == nullptr) newNode(); push_lazy(); if(a < mid) lc->insert(a, b, k); if(b > mid) rc->insert(a, b, k); pull(); } void add(int a, int b, int k){ if(a <= l && r <= b){ lazy += k; ln.b += k; val += k; return; } if(lc == nullptr) newNode(); push_lazy(); push_line(); if(a < mid) lc->add(a, b, k); if(b > mid) rc->add(a, b, k); pull(); } int query(int a, int b){ if(a <= l && r <= b) return val; if(lc == nullptr) return max(ln.f(a), ln.f(b-1)); push_lazy(); int ans = max(ln.f(a), ln.f(b-1)); if(a < mid) ans = max(ans, lc->query(a, b)); if(b > mid) ans = max(ans, rc->query(a, b)); return ans; } };

因為RMQ要預處理區間

max所以有些細節要注意,有點麻煩。

例題

I hate Shortest Path Problem




參考資料:
[Tutorial] Li Chao Tree Extended
ioic2023 講義