owned this note
owned this note
Published
Linked with GitHub
# 線段樹 Segment Tree
## 簡介
線段樹是一種高效的資料結構,常用來處理區間查詢與區間更新的問題。單次查詢或更新的時間複雜度為 O(log n),因此能快速處理區間和、區間最大值、區間最小值等操作。
## 時間複雜度
* 建樹: O(n)。n 個元素的線段樹共有 2n − 1 個節點,每個節點只會被處理一次。
* 查詢: O(log n)。樹高為 O(log n),且每一層最多只有兩個節點與查詢區間「部分重疊」而需要繼續往下遞迴,所以每層只會拜訪常數個節點。
* 更新: O(log n)。單點更新只會修改從葉節點到根的一條路徑;帶懶標的區間更新則與查詢相同,每層只需處理常數個節點。
## Code
### array version(無懶標)
```cpp=
// Point-add / range-sum segment tree over positions 1..maxn (array layout,
// node i has children 2i and 2i+1).
const int maxn = 2e5 + 1;
// arr is 1-indexed and build() reads arr[1..maxn], so it needs maxn + 1 slots
// (declaring arr[maxn] made build() read one element past the end).
int arr[maxn + 1], seg[maxn << 2];
#define m ((l + r) >> 1)
// Recompute node i's sum from its two children.
void pull(int &i) {seg[i] = seg[i << 1] + seg[i << 1 | 1];}
// Build node i over [l, r] from arr[l..r].
void build(int l = 1, int r = maxn, int i = 1) {
    if(l == r) return seg[i] = arr[l], void();
    build(l, m, i << 1), build(m + 1, r, i << 1 | 1);
    pull(i);
}
// Sum of the tree's values over [ql, qr]; positions outside [l, r] contribute 0.
// NOTE(review): sums are int — switch to long long if they can exceed INT_MAX.
int qry(int ql, int qr, int l = 1, int r = maxn, int i = 1) {
    if(ql > r || qr < l) return 0;
    if(ql <= l && qr >= r) return seg[i];
    return qry(ql, qr, l, m, i << 1) + qry(ql, qr, m + 1, r, i << 1 | 1);
}
// Add k to position x (updates seg only; arr itself is left untouched).
void upd(int x, int k, int l = 1, int r = maxn, int i = 1) {
    if(x < l || x > r) return;
    if(l == r) return seg[i] += k, void();
    // x lies in exactly one half, so only that child needs to be visited.
    if(x <= m) upd(x, k, l, m, i << 1);
    else upd(x, k, m + 1, r, i << 1 | 1);
    pull(i);
}
#undef m
```
### array version(帶懶標)
```cpp=
// Range-add / range-sum segment tree over positions 1..maxn with lazy tags.
// tag[i] is an add that has NOT yet been folded into seg[i]; seg[i] is only
// the true sum of its range after push(l, r, i) has run on node i.
const int maxn = 2e5 + 1;
// arr is 1-indexed and build() reads arr[1..maxn], so it needs maxn + 1 slots
// (declaring arr[maxn] made build() read one element past the end).
int arr[maxn + 1], seg[maxn << 2], tag[maxn << 2];
#define m ((l + r) >> 1)
// Recompute node i's sum from its two (already pushed) children.
void pull(int &i) {seg[i] = seg[i << 1] + seg[i << 1 | 1];}
// Fold node i's pending add into seg[i] and hand it down to the children.
// NOTE(review): tag * length is int arithmetic — may overflow for large values.
void push(int &l, int &r, int &i) {
    if(!tag[i]) return;
    seg[i] += tag[i] * (r - l + 1);
    if(l != r) tag[i << 1] += tag[i], tag[i << 1 | 1] += tag[i];
    tag[i] = 0;
}
// Build node i over [l, r] from arr[l..r].
void build(int l = 1, int r = maxn, int i = 1) {
    if(l == r) return seg[i] = arr[l], void();
    build(l, m, i << 1); build(m + 1, r, i << 1 | 1);
    pull(i);
}
// Sum over [ql, qr].
int qry(int ql, int qr, int l = 1, int r = maxn, int i = 1) {
    push(l, r, i);
    if(ql > r || qr < l) return 0;
    if(ql <= l && qr >= r) return seg[i];
    return qry(ql, qr, l, m, i << 1) + qry(ql, qr, m + 1, r, i << 1 | 1);
}
// Add k to every position in [ql, qr].
void upd(int ql, int qr, int k, int l = 1, int r = maxn, int i = 1) {
    // Push even when out of range: the parent's pull() reads this node's seg.
    push(l, r, i);
    if(ql > r || qr < l) return;
    if(ql <= l && qr >= r) return tag[i] += k, push(l, r, i), void();
    upd(ql, qr, k, l, m, i << 1), upd(ql, qr, k, m + 1, r, i << 1 | 1);
    pull(i);
}
#undef m
```
### pointer version(無懶標)
```cpp=
// Dynamically allocated segment tree over [1, maxn] with point add and range
// sum. A null pointer stands for an all-zero subtree.
const int maxn = 1e8 + 1;
#define m ((l + r) >> 1)
struct node {
    int v;
    node *l, *r;
    node(int k = 0) {v = k, l = r = 0;}
    // Recompute v from the children, treating a missing child as 0.
    void pull() {v = (l? l -> v : 0) + (r? r -> v : 0);}
};
// Sum over [ql, qr] within the subtree nd covering [l, r].
int qry(node* &nd, int ql, int qr, int l = 1, int r = maxn) {
    if(!nd) return 0;
    if(ql > r || qr < l) return 0;
    if(ql <= l && qr >= r) return nd -> v;
    return qry(nd -> l, ql, qr, l, m) + qry(nd -> r, ql, qr, m + 1, r);
}
// Add k to position x, allocating nodes along its root-to-leaf path as needed.
void upd(node* &nd, int x, int k, int l = 1, int r = maxn) {
    // Check the range before allocating: otherwise every update would also
    // allocate the untouched sibling on each level.
    if(x > r || x < l) return;
    if(!nd) nd = new node();
    if(l == r) {nd -> v += k; return;}
    upd(nd -> l, x, k, l, m), upd(nd -> r, x, k, m + 1, r);
    nd -> pull();
}
#undef m
```
:::spoiler 微優化但較麻煩
```cpp=
// Dynamically allocated segment tree over [1, maxn] with point add and range
// sum that only descends into children that overlap the operation.
// A null pointer stands for an all-zero subtree.
const int maxn = 1e8 + 1;
#define m ((l + r) >> 1)
struct node {
    int x;
    node *l, *r;
    node(int k = 0) {x = k, l = r = 0;}
    // Recompute x from the children, treating a missing child as 0.
    void pull() {x = (l? l -> x : 0) + (r? r -> x : 0);}
};
// Add x to position p.
void upd(node* &nd, int p, int x, int l = 1, int r = maxn) {
    if(p < l || p > r) return;   // only reachable at the root; don't allocate for it
    if(!nd) nd = new node();
    if(l == r) return nd -> x += x, void();
    // p lies in exactly one half: [l, m] when p <= m, otherwise [m + 1, r].
    // (The old `p < m` test sent p == m into both halves, allocating an empty right child.)
    if(p <= m) upd(nd -> l, p, x, l, m);
    else upd(nd -> r, p, x, m + 1, r);
    nd -> pull();
}
// Sum over [ql, qr] within the subtree nd covering [l, r].
int qry(node* &nd, int ql, int qr, int l = 1, int r = maxn) {
    if(!nd) return 0;
    if(qr < l || ql > r) return 0;
    if(ql <= l && qr >= r) return nd -> x;
    if(qr <= m) return qry(nd -> l, ql, qr, l, m);      // entirely in the left half
    if(ql > m) return qry(nd -> r, ql, qr, m + 1, r);    // entirely in the right half
    return qry(nd -> l, ql, qr, l, m) + qry(nd -> r, ql, qr, m + 1, r);
}
#undef m
```
:::
### pointer version(帶懶標)
```cpp=
// Dynamically allocated segment tree over [1, maxn] with range add and range sum.
// nd->v is the true sum of nd's range only after nd->push(l, r) has run:
// a pending tag has not yet been folded into the node's own v.
// A null pointer stands for an all-zero subtree.
// NOTE(review): v and tag are int; tag * (range length) with lengths up to
// ~1e8 overflows quickly — use long long if values can grow.
const int maxn = 1e8 + 1;
#define m ((l + r) >> 1)
struct node {
    int v, tag;
    node *l, *r;
    node(int k = 0) {
        v = 0;
        tag = k;
        l = r = 0;
    }
    // Recompute v from the (already pushed) children, treating a missing child as 0.
    void pull() {v = (l? l -> v : 0) + (r? r -> v : 0);}
    // Fold the pending tag into v and hand it to both children, creating them if needed.
    void push(int _l, int _r) {
        if(!tag) return;
        if(_l == _r) {v += tag, tag = 0; return;}
        v += tag * (_r - _l + 1);
        if(!l) l = new node();
        if(!r) r = new node();
        l -> tag += tag, r -> tag += tag;
        tag = 0;
    }
};
// Sum over [ql, qr] within the subtree nd covering [l, r].
int qry(node* &nd, int ql, int qr, int l = 1, int r = maxn) {
    if(!nd) return 0;
    nd -> push(l, r);
    if(ql > r || qr < l) return 0;
    if(ql <= l && qr >= r) return nd -> v;
    return qry(nd -> l, ql, qr, l, m) + qry(nd -> r, ql, qr, m + 1, r);
}
// Add k to every position in [ql, qr].
void upd(node* &nd, int ql, int qr, int k, int l = 1, int r = maxn) {
    if(ql > r || qr < l) {
        // Untouched subtree: don't allocate it, but settle an existing node's
        // pending tag because the parent's pull() is about to read its v.
        if(nd) nd -> push(l, r);
        return;
    }
    if(!nd) nd = new node();
    nd -> push(l, r);
    if(ql <= l && qr >= r) {nd -> tag += k, nd -> push(l, r); return;}
    upd(nd -> l, ql, qr, k, l, m), upd(nd -> r, ql, qr, k, m + 1, r);
    nd -> pull();
}
#undef m
```
:::spoiler 微優化但較麻煩
```cpp=
// Dynamically allocated segment tree over [1, maxn] with range add and range
// sum that only descends into children overlapping the operation.
// nd->v is the true sum of nd's range only after nd->push(l, r) has run.
// A null pointer stands for an all-zero subtree.
// NOTE(review): v and tag are int; tag * (range length) can overflow for long ranges.
const int maxn = 1e8 + 1;
#define m ((l + r) >> 1)
struct node {
    int v, tag;
    node *l, *r;
    node(int k = 0) {
        v = 0;
        tag = k;
        l = r = 0;
    }
    // Recompute v from the (already pushed) children, treating a missing child as 0.
    void pull() {v = (l? l -> v : 0) + (r? r -> v : 0);}
    // Fold the pending tag into v and hand it to both children, creating them if needed.
    void push(int _l, int _r) {
        if(!tag) return;
        if(_l == _r) {v += tag, tag = 0; return;}
        v += tag * (_r - _l + 1);
        if(!l) l = new node();
        if(!r) r = new node();
        l -> tag += tag, r -> tag += tag;
        tag = 0;
    }
};
// Sum over [ql, qr] within the subtree nd covering [l, r].
int qry(node* &nd, int ql, int qr, int l = 1, int r = maxn) {
    if(!nd) return 0;
    nd -> push(l, r);
    if(ql > r || qr < l) return 0;
    if(ql <= l && qr >= r) return nd -> v;
    if(qr <= m) return qry(nd -> l, ql, qr, l, m);      // entirely in the left half
    if(ql > m) return qry(nd -> r, ql, qr, m + 1, r);    // entirely in the right half
    return qry(nd -> l, ql, qr, l, m) + qry(nd -> r, ql, qr, m + 1, r);
}
// Add k to every position in [ql, qr].
void upd(node* &nd, int ql, int qr, int k, int l = 1, int r = maxn) {
    if(ql > r || qr < l) {if(nd) nd -> push(l, r); return;}  // only reachable at the root
    if(!nd) nd = new node();
    nd -> push(l, r);
    if(ql <= l && qr >= r) {nd -> tag += k, nd -> push(l, r); return;}
    if(qr <= m) upd(nd -> l, ql, qr, k, l, m);
    else if(ql > m) upd(nd -> r, ql, qr, k, m + 1, r);
    else upd(nd -> l, ql, qr, k, l, m), upd(nd -> r, ql, qr, k, m + 1, r);
    // A skipped child may still hold the tag that nd->push() just handed down;
    // settle both children so pull() reads up-to-date values (without this the
    // parent's sum silently loses the skipped child's pending add).
    if(nd -> l) nd -> l -> push(l, m);
    if(nd -> r) nd -> r -> push(m + 1, r);
    nd -> pull();
}
#undef m
```
:::
### 區間加值+設值
```cpp=
// Range assign + range add, range sum over positions 1..maxn.
// Pending tags on node i mean: "if has_set, first assign tag_set to every
// element; then add tag_add to every element". seg[i].sum is only the true
// sum of node i's range after push(l, r, i) has run on it.
// has_set is a separate flag so that assigning 0 (or an assignment that later
// adds bring to 0) is not mistaken for "no pending assignment".
struct node {int sum, tag_set, tag_add; bool has_set;};
const int maxn = 2e5 + 1;
node seg[maxn << 2];
int arr[maxn + 1];   // 1-indexed; build() reads arr[1..maxn]
#define m ((l + r) >> 1)
#define ll (i << 1)
#define rr (i << 1 | 1)
// Recompute node i's sum from its two (already pushed) children.
void pull(int &i) {seg[i].sum = seg[ll].sum + seg[rr].sum;}
// Compose parent p's pending operation onto child c. Every node is pushed
// before a new tag is written into it, so the child's pending operation is
// older; an assignment from the parent therefore replaces it outright.
void give_tag(int c, const node &p) {
    if(p.has_set) {
        seg[c].has_set = true;
        seg[c].tag_set = p.tag_set;
        seg[c].tag_add = p.tag_add;
    }
    else seg[c].tag_add += p.tag_add;
}
// Fold node i's pending operation into its sum and hand it to the children.
void push(int &l, int &r, int &i) {
    node &cur = seg[i];
    if(!cur.has_set && !cur.tag_add) return;
    int len = r - l + 1;
    if(cur.has_set) cur.sum = cur.tag_set * len;
    cur.sum += cur.tag_add * len;
    if(l != r) give_tag(ll, cur), give_tag(rr, cur);
    cur.has_set = false;
    cur.tag_set = cur.tag_add = 0;
}
// Build node i over [l, r] from arr[l..r].
void build(int l = 1, int r = maxn, int i = 1) {
    if(l == r) return seg[i].sum = arr[l], void();
    build(l, m, ll), build(m + 1, r, rr);
    pull(i);
}
// Sum over [ql, qr].
int qry(int ql, int qr, int l = 1, int r = maxn, int i = 1) {
    push(l, r, i);
    if(ql > r || qr < l) return 0;
    if(ql <= l && qr >= r) return seg[i].sum;
    return qry(ql, qr, l, m, ll) + qry(ql, qr, m + 1, r, rr);
}
// Assign k to every position in [ql, qr] (k may be 0).
void upd_set(int ql, int qr, int k, int l = 1, int r = maxn, int i = 1) {
    push(l, r, i);   // also needed out of range: the parent's pull() reads this sum
    if(ql > r || qr < l) return;
    if(ql <= l && qr >= r) {
        seg[i].has_set = true, seg[i].tag_set = k;
        push(l, r, i);
        return;
    }
    upd_set(ql, qr, k, l, m, ll);
    upd_set(ql, qr, k, m + 1, r, rr);
    pull(i);
}
// Add k to every position in [ql, qr].
void upd_add(int ql, int qr, int k, int l = 1, int r = maxn, int i = 1) {
    push(l, r, i);   // also needed out of range: the parent's pull() reads this sum
    if(ql > r || qr < l) return;
    if(ql <= l && qr >= r) return seg[i].tag_add += k, push(l, r, i), void();
    upd_add(ql, qr, k, l, m, ll);
    upd_add(ql, qr, k, m + 1, r, rr);
    pull(i);
}
#undef m
#undef ll
#undef rr
```
### 持久化
持久化線段樹會保留每次操作前後的版本,讓使用者可以查詢任意歷史狀態。每次更新只會新建 O(log n) 個節點,其餘節點直接與舊版本共用,因此 q 次更新的總空間為 O(n + q log n)。這在需要回溯或處理多版本資料時特別有用。
```cpp=
// Persistent segment tree over [1, maxn] with point assignment and range sum.
// upd() never modifies an existing node: it copies the nodes on the path to x
// and returns a new root that shares every other subtree with the old version.
// A null pointer stands for an all-zero subtree, so build() is optional.
const int maxn = 2e5 + 1;
#define m ((l + r) >> 1)
struct node {
    int v;
    node *l, *r;
    node(int x = 0) : v(x) {l = r = 0;}
    // Recompute v from the children, treating a missing child as 0.
    void pull() {v = (l? l -> v : 0) + (r? r -> v : 0);}
};
// Build a full all-zero tree over [l, r] and return its root.
node* build(int l = 1, int r = maxn) {
    node *nd = new node();
    if(l == r) return nd;
    nd -> l = build(l, m);
    nd -> r = build(m + 1, r);
    return nd;
}
// Return the root of a new version in which position x is set to k.
// nd may be null (treated as an all-zero version).
node* upd(node *nd, int x, int k, int l = 1, int r = maxn) {
    node *tmp = nd ? new node(*nd) : new node();
    if(x < l || x > r) return tmp;
    if(l == r) return tmp -> v = k, tmp;
    // x lies in exactly one half; the other child pointer (copied from nd) is shared.
    // (The old `x < m` test sent x == m into both halves and copied the right subtree's root.)
    if(x <= m) tmp -> l = upd(tmp -> l, x, k, l, m);
    else tmp -> r = upd(tmp -> r, x, k, m + 1, r);
    tmp -> pull();
    return tmp;
}
// Sum over [ql, qr] in the version rooted at nd.
int qry(node *nd, int ql, int qr, int l = 1, int r = maxn) {
    if(!nd || ql > r || qr < l) return 0;
    if(ql <= l && qr >= r) return nd -> v;
    return qry(nd -> l, ql, qr, l, m) + qry(nd -> r, ql, qr, m + 1, r);
}
#undef m
// Presumably intended to hold the root of each version — confirm with callers.
std::vector<node*> v;
```