Segment Tree (線段樹)

Usage

  • 主要用在區間序列的問題,區間答案要有 結合性 ,並依 結合性 建樹並將區間答案合併
  • 可以用陣列實現 (記憶體較小、較快)
  • 可以用指標實現 (方便實作動態開點、持久化)
  • 主要用於 RMQ (Range Minimum/Maximim Query) 區間最值查詢、序列 (帶修改) 區間操作

Introduction

- - - - - - - -
index 1 2 3 4 5 6 7
value 5 2 3 4 7 6 1

以一個序列 (1-based) 為例,
以其節點為對應的區間,
可以建構出以下的樹。

graph main {
    "[1, 7]" -- {"[1, 4]" "[5, 7]"}
    "[1, 4]" -- {"[1, 2]" "[3, 4]"}
    "[5, 7]" -- {"[5, 6]" "[7, 7]"}
    "[1, 2]" -- {"[1, 1]" "[2, 2]"}
    "[3, 4]" -- {"[3, 3]" "[4, 4]"}
    "[5, 6]" -- {"[5, 5]" "[6, 6]"}
    "[7, 7]" -- {"n1"[label=""] "n2"[label=""]}
}

而區間對應的編號如下:

graph main {
    "1" -- {"2" "3"}
    "2" -- {"4" "5"}
    "3" -- {"6" "7"}
    "4" -- {"8" "9"}
    "5" -- {"10" "11"}
    "6" -- {"12" "13"}
    "7" -- {"n1"[label=""] "n2"[label=""]}
}

前面的為節點編號,
後面的為對應區間。
可以藉由觀察,
可發現對於一個編號為 id 的節點有以下特質:

  • 左子樹節點編號為 \(id\times 2\)
  • 右子樹節點編號為 \(id\times 2+1\)
  • 左子樹與右子樹的區間答案可結合成該節點對應區間的答案

根據以上特質:
可以用陣列實做,
並完成查詢、更新的操作。
\(n\)\(2\) 的冪次時,
最多可能會用到 \(4\times n\) 個空間,
\(\implies\) 空間複雜度 \(O(n)\)

Build

通常以遞迴來實做建樹,
\(id=1,l=1,r=N\) 開始,
\(m=\frac{l+r}{2}\) 分割兩個區間 \([l,m]\)\([m+1,r]\)
先遞迴左子樹再遞迴右子樹,
最後更新答案。
時間複雜度 \(O(n)\)

#define lc(x) (x << 1) #define rc(x) (x << 1 | 1) constexpr int N = 2e5 + 1; int seg[N << 2]; void build(int id, int l, int r) { if (l == r) { // 葉節點 cin >> seg[id]; // seg[id] = arr[l]; return; } int m = (l + r) >> 1; build(lc(id), l, m); // 遞迴建構左子樹 build(rc(id), m + 1, r); // 遞迴建構右子樹 seg[id] = max(seg[lc(id)], seg[rc(id)]); // 更新答案 }

Query

不必把每個區間存在每一個節點上,
只需要在需要的時候模擬區間即可。
\(id=1,l=1,r=N,ql=查詢左界,qr=查詢右界\) 開始遞迴,
時間複雜度 \(O(\log n)\)
有兩種方式,
喜歡哪種因人而異。

Method 1.

欲查詢 \([l,r]\) 的資料,
從根節點開始遞迴,
\(m=\frac{l+r}{2}\)
可分為以下三種狀況:

  1. \(qr\le m\) (只在左半邊),只要往左半邊遞迴找答案即可。
  2. \(ql>m\) (只在右半邊),只要往右半邊遞迴找答案即可。
  3. \(ql\le m\)\(qr>m\) (跨越左右兩邊),分別遞迴左右兩邊,更新答案並回傳。
#define lc(x) (x << 1) #define rc(x) (x << 1 | 1) constexpr int N = 2e5 + 1; int seg[N << 2]; int qry(int id, int l, int r, int ql, int qr) { if (ql == l && qr == r) // 找到對應區間 return seg[id]; int m = (l + r) >> 1; if (qr <= m) // 只在左半邊 return qry(lc(id), l, m, ql, qr); if (ql > m) // 只在右半邊 return qry(rc(id), m + 1, r, ql, qr); return max(qry(lc(id), l, m, ql, m), qry(rc(id), m + 1, r, m + 1, qr)); }

Method 2.

一樣從根節點開始遞迴,
\(m=\frac{l+r}{2}\)
可分為兩種狀況:

  1. \(ql <= m\) (跨越左半邊,也可能只在左半邊),查詢左邊答案。
  2. \(qr > m\) (跨越右半邊,也可能只右半邊),查詢右邊答案。

最後再把兩邊答案結合在一起並回傳。

#define lc(x) (x << 1) #define rc(x) (x << 1 | 1) constexpr int N = 2e5 + 1, inf = 1e9; int seg[N << 2]; int qry(int id, int l, int r, int ql, int qr) { if (ql <= l && r <= qr) // 區間完全包含在查詢區間 return seg[id]; int m = (l + r) >> 1, ret = -inf; if (ql <= m) // 跨越左半邊 ret = max(ret, qry(lc(id)), l, m, ql, qr); if (qr > m) // 跨越右半邊 ret = max(ret, qry(rc(id)), m + 1, r, ql, qr); return ret;

Update

單點更新,
\(id=1,l=1,r=N,x=更新目標,v=更新值\) 分左右區間遞迴處理。
更新完之後要記得對經過的每個節點,
檢查其值是否需要更新。
時間複雜度 \(O(\log n)\)

#define lc(x) (x << 1) #define rc(x) (x << 1 | 1) constexpr int N = 2e5 + 1; int seg[N << 2]; void upd(int id, int l, int r, int x, int v) { if (l == r) { // 找到對應節點 seg[id] = v; return; } int m = (l + r) >> 1; if (x <= m) upd(lc(id), l, m, x, v); // 遞迴左子樹 else upd(rc(id), m + 1, r, x, v); // 遞迴右子樹 seg[id] = max(seg[lc(id)], seg[rc(id)]); // 更新 }

Pointer Version

另一種實現方式,
僅供參考。

Code

constexpr int N = 2e5 + 1, inf = 1e9; struct node { node *l, *r; int v; void pull() {v = max(l ? l->v : -inf, r ? r->v : -inf);} node(int vv = 0): v(vv), l(nullptr), r(nullptr) {} } *root = nullptr; void build(node* o, int l, int r) { if (l == r) {cin >> o->v; return;} int m = (l + r) >> 1; build(o->l = new node, l, m); build(o->r = new node, m + 1, r); o->pull(); } void upd(node* o, int l, int r, int x, int v) { if (l == r) {o->v = v; return;} int m = (l + r) >> 1; if (x <= m) upd(o->l, l, m, x, v); else upd(o->r, m + 1, r, x, v); o->pull(); } int qry(node* o, int l, int r, int ql, int qr) { if (l == ql && r == qr) return o->v; int m = (l + r) >> 1; if (qr <= m) return qry(o->l, l, m, ql, qr); if (ql > m) return qry(o->r, m + 1, r, ql, qr); return max(qry(o->l, l, m, ql, m), qry(o->r, m + 1, r, m + 1, qr)); } // build(root = new node, 1, N); // upd(root, 1, N, x, v); // qry(root, 1, N, ql, qr);

Lazy Propagation

  • 延遲標記,也有人說 Lazy Tag (懶惰標記)
  • 主要用於 區間更新 或著 降低時間或空間複雜度,更新時不把答案推到葉節點,只更新到對應區間,而是於再次經過該節點時才將標記下放
  • 下放標記有優先順序,要考慮好先做哪個

Code

// 區間更新 using ll = long long; constexpr int N = 2e5 + 1; ll seg[N << 2], tag[N << 2]; #define lc(x) (x << 1) #define rc(x) (x << 1 | 1) void pull(int id) { // pull up seg[id] = seg[lc(id)] + seg[rc(id)]; } void push(int id, int l, int r) { // push down if (tag[id]) { int m = (l + r) >> 1; tag[lc(id)] += tag[id], tag[rc(id)] += tag[id]; // 下放標記 seg[lc(id)] += tag[id] * (m - l + 1); // 更新左子樹區間和 seg[rc(id)] += tag[id] * (r - m); // 更新右子樹區間和 tag[id] = 0; // 歸 0 } } void upd(int id, int l, int r, int ql, int qr, ll v) { if (ql == l && qr == r) { // 找到對應區間 tag[id] += v; // 打上標記 seg[id] += v * (r - l + 1); // 更新該節點區間和 return; } push(id, l, r); // 下放標記 int m = (l + r) >> 1; if (qr <= m) upd(lc(id), l, m, ql, qr, v); else if (ql > m) upd(rc(id), m + 1, r, ql, qr, v); else upd(lc(id), l, m, ql, m, v), upd(rc(id), m + 1, r, m + 1, qr, v); pull(id); } ll qry(int id, int l, int r, int ql, int qr) { if (ql == l && qr == r) return seg[id]; push(id, l, r); // 經過時下放標記 int m = (l + r) >> 1; if (qr <= m) return qry(lc(id), l, m, ql, qr); if (ql > m) return qry(rc(id), m + 1, r, ql, qr); return qry(lc(id), l, m, ql, m) + qry(rc(id), m + 1, r, m + 1, qr); }

Dynamic Allocation

  • 「動態開點」用於值域範圍大,直接建構會 REMLE 的情況,實際上根本不需要這麼多節點,用於降低空間複雜度
  • 每次更新,動態開點 \(\log n\) 個點,可以搭配延遲標記降低空間複雜度
  • 查詢時,如果該點沒有設值,則返回 不影響答案的值 (e.g. 區間和則回傳 \(0\))
  • 可以用陣列實現,但指標比較好實作

Code

// 區間和 & 動態開點 struct node { node *l, *r; int v; void pull() { // 更新 v = (l ? l->v : 0) + (r ? r->v : 0); } node(int vv = 0): v(vv), l(nullptr), r(nullptr) {} } *root = nullptr; void upd(node*& o, int l, int r, int x, int v) { // 因為是動態開點,傳入的指標要加參照 if (!o) o = new node; // 動態開點 if (l == r) {o->v += v; return;} int m = (l + r) >> 1; if (x <= m) upd(o->l, l, m, x, v); else upd(o->r, m + 1, r, x, v); o->pull(); } int qry(node* o, int l, int r, int ql, int qr) { if (!o) return 0; // 沒有這個點,返回不影響答案的值 if (ql <= l && r <= qr) return o->v; // 區間完全包含在查詢區間 int m = (l + r) >> 1, ret = 0; if (ql <= m) ret += qry(o->l, l, m, ql, qr); if (qr > m) ret += qry(o->r, m + 1, r, ql, qr); return ret; }

2-Dimensional Segment Tree

  • 線段樹套線段樹,二維的每個節點存 一維線段樹

Code

constexpr int inf = 1e9; #define lc(x) (x << 1) #define rc(x) (x << 1 | 1) int N, M; // N : 直的 max, M : 橫的 max struct seg { vector<int> st; void pull(int); void merge(const seg&, const seg&, int, int, int); void build(int, int, int); void upd(int, int, int, int, int); int qry(int, int, int, int, int); seg(int size): st(size << 2 | 1) {} }; void seg::pull(int id) { st[id] = max(st[lc(id)], st[rc(id)]); } void seg::merge(const seg& a, const seg& b, int id = 1, int l = 1, int r = M) { st[id] = max(a.st[id], b.st[id]); if (l == r) return; int m = (l + r) >> 1; merge(a, b, lc(id), l, m), merge(a, b, rc(id), m + 1, r); } void seg::build(int id = 1, int l = 1, int r = M) { if (l == r) {cin >> st[id]; return;} int m = (l + r) >> 1; build(lc(id), l, m), build(rc(id), m + 1, r); pull(id); } void seg::upd(int x, int v, int id = 1, int l = 1, int r = M) { if (l == r) {st[id] = v; return;} int m = (l + r) >> 1; if (x <= m) upd(x, v, lc(id), l, m); else upd(x, v, rc(id), m + 1, r); pull(id); } int seg::qry(int ql, int qr, int id = 1, int l = 1, int r = M) { if (ql <= l && r <= qr) return st[id]; int m = (l + r) >> 1, ret = -inf; if (ql <= m) ret = max(ret, qry(ql, qr, lc(id), l, m)); if (qr > m) ret = max(ret, qry(ql, qr, rc(id), m + 1, r)); return ret; } struct segseg { vector<seg> st; void pull(int, int); void build(int, int, int); void upd(int, int, int, int, int, int); int qry(int, int, int, int, int, int, int); segseg(int n, int m): st(n << 2 | 1, seg(m)) {} }; void segseg::pull(int id, int x) { st[id].upd(x, max(st[lc(id)].qry(x, x), st[rc(id)].qry(x, x))); } void segseg::build(int id = 1, int l = 1, int r = N) { if (l == r) {st[id].build(); return;} int m = (l + r) >> 1; build(lc(id), l, m), build(rc(id), m + 1, r); st[id].merge(st[lc(id)], st[rc(id)]); } void segseg::upd(int y, int x, int v, int id = 1, int l = 1, int r = N) { if (l == r) {st[id].upd(x, v); return;} int m = (l + r) >> 1; if (y <= m) upd(y, x, v, lc(id), l, m); else upd(y, x, v, rc(id), m + 1, r); pull(id, x); } int segseg::qry(int y1, int y2, int x1, int x2, int id = 1, int l = 1, int r = N) { if (y1 <= l && r <= y2) return st[id].qry(x1, x2); int m = (l + r) >> 1, ret = -inf; if (y1 <= m) ret = max(ret, qry(y1, y2, x1, x2, lc(id), l, m)); if (y2 > m) ret = max(ret, qry(y1, y2, x1, x2, rc(id), m + 1, r)); return ret; }

通常這記憶體使用很大,
有以下幾個解決方法 :

  1. 對其中一維打懶標,如果區間只有一個數就不必遞迴下去,懶標存位置
  2. 四分樹,複雜度相對較好,空間、時間都少一個 \(\log\)

Persistent

  • 持久化線段樹,又稱主席樹
  • 以左閉右開的圖舉例 :

可以發現每次只有少部分的點更新,因此對於沒有更新的點就共用前一個版本的點,有更新的則新開點。

  • 更新多開 \(\log n\) 的節點更新資訊
  • 沒有更新到的節點共用前一個版本的節點
  • 根節點一定會更新到,只要保存好根節點,就可以做版本之間的查詢

Code

Update

void upd(node* prv, node* cur, int l, int r, int x, int v) { if (l == r) {cur->v = v; return;} int m = (l + r) >> 1; if (x <= m) cur->r = prv->r, upd(prv->l, cur->l = new node, l, m, x, v); else cur->l = prv->l, upd(prv->r, cur->r = new node, m + 1, r, x, v); cur->pull(); // 更新節點資訊 }
  • prv 是前一個版本,cur 是現在的版本
  • 有更新的就 new node 開點,沒更新的就共用前一個版本的節點
  • 不論是 prv 又或者是 cur 都是完整的線段樹 (根節點)

Query

// 以版本間區間合的變化舉例 int qry(node* a, node* b, int l, int r, int ql, int qr) { if (ql <= l && r <= qr) return a->v - b->v; int m = (l + r) >> 1, ret = 0; if (ql <= m) ret += qry(a->l, b->l, l, m, ql, qr); if (qr > m) ret += qry(a->r, b->r, m + 1, r, ql, qr); return ret; }
  • 查詢的話跟一般線段樹一樣,不過可以兩個版本一起查詢,少一個常數時間

Problems

ZJ f315: 4. 低地距離
#include <bits/stdc++.h> #define lc(x) (x << 1) #define rc(x) (x << 1 | 1) #define ll long long #define _ ios::sync_with_stdio(false), cin.tie(nullptr); using namespace std; const int MAXN = 100000; int l[MAXN + 1], r[MAXN + 1]; int seg[MAXN << 3 | 1]; void upd(int id, int l, int r, int x, int v) { if (l == r) { seg[id] += v; return; } int m = (l + r) >> 1; if (x <= m) upd(lc(id), l, m, x, v); else upd(rc(id), m + 1, r, x, v); seg[id] = seg[lc(id)] + seg[rc(id)]; } int qry(int id, int l, int r, int ql, int qr) { if (l == ql && r == qr) return seg[id]; int m = (l + r) >> 1; if (qr <= m) return qry(lc(id), l, m, ql, qr); if (ql > m) return qry(rc(id), m + 1, r, ql, qr); return qry(lc(id), l, m, ql, m) + qry(rc(id), m + 1, r, m + 1, qr); } int main() { _ int n; cin >> n; for (int i = 1, x, lim = n << 1; i <= lim; i++) { cin >> x; if (!l[x]) l[x] = i; else r[x] = i; } ll ans = 0; for (int i = 1, n2 = n << 1; i <= n; i++) { ans += qry(1, 1, n2, l[i], r[i]); upd(1, 1, n2, l[i], 1), upd(1, 1, n2, r[i], 1); } cout << ans << '\n'; }
TIOJ 1836. [IOI 2013] 遊戲 Game

線段樹套樹 + 動態開點 + 懶人標記

#pragma GCC optimize("O2") #include <bits/stdc++.h> #define gcd __gcd // if before C++ 17 using namespace std; #define debug(x) cerr << (#x) << " : " << x << '\n' using ll = long long; int N, M; struct node { node *l, *r; ll v; int tag; void pull() {v = gcd(l ? l->v : 0LL, r ? r->v : 0LL);} void push(int l, int r) { int m = (l + r) >> 1; if (tag <= m) this->l = new node(v, tag); else this->r = new node(v, tag); tag = -2; } node(ll _v = 0, int _tag = -1): v(_v), tag(_tag) {l = r = nullptr;} }; void upd(node*& o, int x, ll v, int l = 0, int r = M - 1) { if (!o) o = new node; if (o->tag == x || l == r) { o->v = v; return; } if (o->tag == -1) { o->tag = x, o->v = v; return; } if (o->tag >= 0) o->push(l, r); int m = (l + r) >> 1; if (x <= m) upd(o->l, x, v, l, m); else upd(o->r, x, v, m + 1, r); o->pull(); } ll qry(node* o, int ql, int qr, int l = 0, int r = M - 1) { if (!o) return 0; if (ql <= l && r <= qr) return o->v; if (o->tag >= 0) return ql <= o->tag && o->tag <= qr ? o->v : 0; int m = (l + r) >> 1; ll ret = 0; if (ql <= m) ret = gcd(ret, qry(o->l, ql, qr, l, m)); if (qr > m) ret = gcd(ret, qry(o->r, ql, qr, m + 1, r)); return ret; } struct seg { seg *l, *r; node* v; void pull(int x) { upd(v, x, gcd(l ? qry(l->v, x, x) : 0LL, r ? qry(r->v, x, x) : 0LL)); } seg(): v(nullptr) {l = r = nullptr;} }; void upd(seg*& o, int y, int x, ll v, int l = 0, int r = N - 1) { if (!o) o = new seg; if (!o->v) o->v = new node; if (l == r) { upd(o->v, x, v); return; } int m = (l + r) >> 1; if (y <= m) upd(o->l, y, x, v, l, m); else upd(o->r, y, x, v, m + 1, r); // update node for 1D segment tree o->pull(x); } ll qry(seg* o, int y1, int y2, int x1, int x2, int l = 0, int r = N - 1) { if (!o) return 0; if (y1 <= l && r <= y2) return qry(o->v, x1, x2); int m = (l + r) >> 1; ll ret = 0; if (y1 <= m) ret = gcd(ret, qry(o->l, y1, y2, x1, x2, l, m)); if (y2 > m) ret = gcd(ret, qry(o->r, y1, y2, x1, x2, m + 1, r)); return ret; } #define _ ios::sync_with_stdio(false), cin.tie(nullptr); int main() { _ int Q; seg *root = new seg; cin >> M >> N >> Q; // N 直的, M 橫的 while (Q--) { int op; cin >> op; if (op == 1) { int x, y; ll v; cin >> x >> y >> v; upd(root, y, x, v); } else { int x1, y1, x2, y2; cin >> x1 >> y1 >> x2 >> y2; if (x1 > x2) swap(x1, x2); if (y1 > y2) swap(y1, y2); cout << qry(root, y1, y2, x1, x2) << '\n'; } } }
ZJ a331: K-th Number
#pragma GCC optimize("O2") #include <bits/stdc++.h> #define ALL(x) (x).begin(), (x).end() using namespace std; struct node { node *l, *r; int v; void pull() {v = (l ? l->v : 0) + (r ? r->v : 0);} node(int _v = 0): v(_v) {l = r = nullptr;} }; void build(node* o, int l, int r) { if (l == r) return; int m = (l + r) >> 1; build(o->l = new node, l, m); build(o->r = new node, m + 1, r); } void upd(node* prv, node* cur, int l, int r, int x) { if (l == r) {cur->v += 1; return;} int m = (l + r) >> 1; if (x <= m) cur->r = prv->r, upd(prv->l, cur->l = new node, l, m, x); else cur->l = prv->l, upd(prv->r, cur->r = new node, m + 1, r, x); cur->pull(); } int find(node* a, node* b, int l, int r, int k) { if (l == r) return l; int m = (l + r) >> 1, add = b->l->v - a->l->v; if (add >= k) return find(a->l, b->l, l, m, k); return find(a->r, b->r, m + 1, r, k - add); } vector<int> a, tmp; vector<node*> ver; #define _ ios::sync_with_stdio(false), cin.tie(nullptr); int main() { _ for (int n, q; cin >> n >> q;) { a.resize(n + 1), tmp.resize(n + 1), ver.resize(n + 1); for (int i = 1, x; i <= n; i++) cin >> x, a[i] = tmp[i] = x; /* descretize */ sort(ALL(tmp)), tmp.resize(unique(ALL(tmp)) - tmp.begin()); int size = tmp.size(); build(ver[0] = new node, 1, size); for (int i = 1; i <= n; i++) { int x = lower_bound(ALL(tmp), a[i]) - tmp.begin() + 1; upd(ver[i - 1], ver[i] = new node, 1, size, x); } while (q--) { int l, r, k; cin >> l >> r >> k; int ans = tmp[find(ver[l - 1], ver[r], 1, size, k) - 1]; cout << ans << '\n'; } } }
Select a repo