# Range Queries Sol [toc] ## 前言 本題單有許多題目可以被或必須用線段樹的技巧來解。 請確保熟悉線段樹的基本用法,如`build`, `up`, `lazy tag` 可以看 [線段樹 總論](/Z3vYeKufRzaTZHdekANwRA) 好了,現在 ![s](https://hackmd.io/_uploads/HJirBfDKle.jpg) ## [Static Range Sum Queries](https://cses.fi/problemset/task/1646) 想法:前綴和。 加法是可以“取消”的操作,所以才可以用前綴和。 像是 "區間查 [l,r] 的 product 模 17" 也可以前綴和,因為 17 是質數,所以可以乘個反元素做取消。 但是 "區間查 [l,r] 的 product 模 16" 不可以前綴和。注意 $ax = 1 \pmod m$ 可解若且唯若 $\gcd(a,m) = 1$ :::spoiler ```cpp= #include <bits/stdc++.h> using namespace std; using ll = long long; ll nums[200001]; ll pre[200001]; int main() { cin.tie(0); ios::sync_with_stdio(0); int n,q; cin >> n >> q; for (int i=0;i<n;i++) { cin >> nums[i]; } pre[0] = 0; for (int i=1;i<=n;i++) pre[i] = pre[i-1] + nums[i-1]; for (int i=0;i<q;i++) { int l,r; cin >> l >> r; cout << pre[r]-pre[l-1] << '\n'; } } ``` ::: ## [Static Range Minimum Queries](https://cses.fi/problemset/task/1647) 想法:直接套 sparse table 一個跟 binary lifting 很像的東西 TC: $O(q + n\log n)$ SC: $O(n\log n)$ sparse table 在各種 bitwise and / or 都很好用。 因為 `&` 很多數字是遞減, `|` 很多數字是遞增,所以在一些區間 bitwise 題目上可以跟 binary search 搭配使用。 :::spoiler code ```cpp= #include <bits/stdc++.h> using namespace std; class ST { public: vector<vector<int>> t; vector<int> logs; int n; ST(vector<int>& nums): logs(nums.size()+1), n(nums.size()) { logs[1] = 0; for (int i=2;i<=n;i++) logs[i] = 1+logs[i/2]; t.assign(logs[n]+1,vector<int>(n)); t[0] = nums; for (int i=1;(1<<i)<=n;i++) { for (int j=0;j+(1<<i)<=n;j++) { t[i][j] = min(t[i-1][j],t[i-1][j+(1<<(i-1))]); } } } int query(int l,int r) { int i = logs[r-l+1]; return min(t[i][l],t[i][r-(1<<i)+1]); } }; int main() { cin.tie(0); ios::sync_with_stdio(0); int n,q; cin >> n >> q; vector<int> nums(n); for (int i=0;i<n;i++) { cin >> nums[i]; } ST st(nums); for (int i=0;i<q;i++) { int a,b; cin >> a >> b; a--;b--; cout << st.query(a,b) << '\n'; } ``` ::: 另解:或是用線段樹。最大的差異是最後時間複雜度裡$\log n$乘的地方不同,可以自行評估使用時機。 TC: $O(n + q\log n)$ SC: $O(n)$ :::spoiler ```cpp= #include <bits/stdc++.h> using namespace std; #define maxn 200001 int arr[maxn]; int t[maxn<<2]; struct segtree { int n; segtree(int n): n(n) { build(0,n-1,1); } void build(int l,int r,int i) { if (l==r) { t[i] = arr[l]; return; } int mid = midpoint(l,r); build(l,mid,i<<1); build(mid+1,r,i<<1|1); t[i] = min(t[i<<1],t[i<<1|1]); } int query(int jl,int jr, int l,int r,int i) { if (jl <= l && r <= jr) return t[i]; int ans = INT_MAX; int mid = midpoint(l,r); if (jl <= mid) ans = min(ans,query(jl,jr,l,mid,i<<1)); if (jr > mid) ans = min(ans,query(jl,jr,mid+1,r,i<<1|1)); return ans; } }; int main() { cin.tie(0), ios::sync_with_stdio(0); int n,q; cin >> n >> q; for (int i=0;i<n;i++) cin >> arr[i]; segtree seg(n); for (int i=0;i<q;i++) { int a,b; cin >> a >> b, a--, b--; cout << seg.query(a,b,0,n-1,1) << '\n'; } } ``` ::: ## [Dynamic Range Sum Queries](https://cses.fi/problemset/task/1648) 想法:BIT 模板題。 :::spoiler ```cpp= #include <bits/stdc++.h> using namespace std; using ll = long long; class BIT { public: vector<ll> t,nums; int n; BIT(vector<ll>& nums): t(nums.size()+1), n(nums.size()), nums(nums) { for (int i=1;i<=n;i++) { t[i] += nums[i-1]; int j = i + (i&-i); if (j<=n) t[j]+=t[i]; } } ll query(int i) { ll ans = 0; while (i) ans+=t[i], i-=i&-i; return ans; } void modify(int i,int val) { if (val == nums[i-1]) return; int diff = val-nums[i-1]; nums[i-1] = val; add(i,diff); } void add(int i,int val) { while (i<=n) t[i]+=val, i+=i&-i; } }; int main() { cin.tie(0); ios::sync_with_stdio(0); int n,q; cin >> n >> q; vector<ll> nums(n); for (int i=0;i<n;i++) { cin >> nums[i]; } BIT bit(nums); for (int i=0;i<q;i++) { int t; cin >> t; if (t==1) { int k,u; cin >> k >> u; bit.modify(k,u); } else { int a,b; cin >> a >> b; cout << bit.query(b)-bit.query(a-1) << '\n'; } } } ``` ::: TC: $O(q\log n)$ SC: $O(n)$ ## [Dynamic Range Minimum Queries](https://cses.fi/problemset/task/1649) 一定要用線段樹了,因為要做修改比較方便。 另有一種使用兩個 $BIT$ 的做法。此處略。 :::spoiler ```cpp= #include <bits/stdc++.h> #include <cstring> using namespace std; int arr[200001], t[200001<<2]; struct segtree { int n; segtree(int n): n(n) { build(0,n-1,1); } void up(int i) { t[i] = min(t[i<<1],t[i<<1|1]); } void build(int l,int r,int i) { if (l==r) { t[i] = arr[l]; return; } int mid = (l+r)>>1; build(l,mid,i<<1); build(mid+1,r,i<<1|1); up(i); } int query(int jl,int jr,int l,int r,int i) { if (jl <= l && r <= jr) return t[i]; int ans = INT_MAX, mid = (l+r)>>1; if (jl <= mid) ans = min(ans,query(jl,jr,l,mid,i<<1)); if (jr > mid) ans = min(ans,query(jl,jr,mid+1,r,i<<1|1)); return ans; } int query(int jl,int jr) { return query(jl,jr,0,n-1,1); } void point_set(int idx,int val,int l,int r,int i) { t[i] = val; if (l==r) return; int mid = (l+r)>>1; if (idx <= mid) point_set(idx,val,l,mid,i<<1); else point_set(idx,val,mid+1,r,i<<1|1); up(i); } void point_set(int idx,int val) { point_set(idx,val,0,n-1,1); } }; int main() { int n,q; cin >> n >> q; for (int i=0;i<n;i++) cin >> arr[i]; segtree seg(n); for (int i=0;i<q;i++) { int a,b,c; cin >> a >> b >> c; if (a==1) { b--; seg.point_set(b,c); } else { b--,c--; cout << seg.query(b,c) << '\n'; } } } ``` ::: ## [Range Xor Queries](https://cses.fi/problemset/task/1650) 想法:因為 $x \text{ xor } x = 0$ 而且 $0 \text{ xor } x = x$,所以 xor 可取消,所以使用前綴和。 :::spoiler ```cpp= #include <bits/stdc++.h> using namespace std; using ll = long long; int main() { cin.tie(0); ios::sync_with_stdio(0); int n,q; cin >> n >> q; vector<int> arr(n); for (int i=0;i<n;i++) cin >> arr[i]; vector<int> pre(n+1); pre[0] = 0; for (int i=1;i<=n;i++) pre[i] = pre[i-1]^arr[i-1]; for (int i=0;i<q;i++) { int a,b; cin >> a >> b; cout << (pre[b]^pre[a-1]) << '\n'; } } ``` ::: ## [Range Update Queries](https://cses.fi/problemset/task/1651) 想法:懶標線段樹,實現範圍修改。 TC: $O(n + q\log n)$ SC: $O(n)$ :::spoiler code ```cpp= #include <bits/stdc++.h> using namespace std; using ll = long long; ll arr[200001], t[200001<<2], lazy[200001<<2]; struct segtree { int n; segtree(int n): n(n) { build(0,n-1,1); } void apply(ll val,int l,int r,int i) { t[i] += (r-l+1)*val; lazy[i] += val; } void down(int l,int r,int i) { if (!lazy[i]) return; int mid = (l+r)>>1; apply(lazy[i],l,mid,i<<1); apply(lazy[i],mid+1,r,i<<1|1); lazy[i] = 0; } void up(int i) { t[i] = t[i<<1]+t[i<<1|1]; } void build(int l,int r,int i) { if (l==r) { t[i] = arr[l]; return; } int mid = (l+r)>>1; build(l,mid,i<<1); build(mid+1,r,i<<1|1); up(i); } ll query(int jl,int jr,int l,int r,int i) { if (jl <= l && r <= jr) return t[i]; ll ans = 0; int mid = (l+r)>>1; down(l,r,i); if (jl <= mid) ans+=query(jl,jr,l,mid,i<<1); if (jr > mid) ans+=query(jl,jr,mid+1,r,i<<1|1); up(i); return ans; } ll query(int jl,int jr) { return query(jl,jr,0,n-1,1); } void add(int jl,int jr,ll val,int l,int r,int i) { if (jl <= l && r <= jr) { apply(val,l,r,i); return; } int mid = (l+r)>>1; down(l,r,i); if (jl <= mid) add(jl,jr,val,l,mid,i<<1); if (jr > mid) add(jl,jr,val,mid+1,r,i<<1|1); up(i); } void add(int jl,int jr,ll val) { add(jl,jr,val,0,n-1,1); } }; int main() { int n,q; cin >> n >> q; for (int i=0;i<n;i++) cin >> arr[i]; segtree seg(n); for (int i=0;i<q;i++) { int t; cin >> t; if (t==1) { int a,b,u; cin >> a >> b >> u, a--,b--; seg.add(a,b,u); } else { int k; cin >> k, k--; cout << seg.query(k,k) << '\n'; } } } ``` ::: ## [Forest Queries](https://cses.fi/problemset/task/1652) 想法:用排容原理預處理出二維前綴和,再用排容原理回答每個 subrectangle 的答案。 TC: $O(n^2 + q)$ SC: $O(n^2)$ :::spoiler code ```cpp #include <bits/stdc++.h> using namespace std; using ll = long long; int main() { cin.tie(0); ios::sync_with_stdio(0); int n,q; cin >> n >> q; vector<vector<bool>> nums(n,vector<bool>(n)); for (int i=0;i<n;i++) { for (int j=0;j<n;j++) { char c; cin >> c; nums[i][j] = (c=='*'); } } vector<vector<int>> pre(n+1,vector<int>(n+1)); for (int i=1;i<=n;i++) { for (int j=1;j<=n;j++) { pre[i][j] = pre[i-1][j] + pre[i][j-1] - pre[i-1][j-1] + nums[i-1][j-1]; } } for (int i=0;i<q;i++) { int y1,x1,y2,x2; cin >> y1 >> x1 >> y2 >> x2; int lx = min(x1,x2); int ly = min(y1,y2); int ux = max(x1,x2); int uy = max(y1,y2); cout << pre[uy][ux] + pre[ly-1][lx-1] - pre[uy][lx-1] - pre[ly-1][ux] << '\n'; } } ``` ::: ## [Hotel Queries](https://cses.fi/problemset/task/1143) 想法:[線段樹上二分搜](https://wiki.sam571128.codes/data-structure/segment-tree/segment-tree-2) TC: $O(n + q\log n)$ SC: $O(n)$ :::spoiler code ```cpp= #include <bits/stdc++.h> using namespace std; using ll = long long; int arr[200001], t[200001<<2], lazy[200001<<2]; struct segtree { int n; segtree(int n): n(n) { build(0,n-1,1); } void up(int i) { t[i] = max(t[i<<1],t[i<<1|1]); } void build(int l,int r,int i) { if (l==r) { t[i] = arr[l]; return; } int mid = (l+r)>>1; build(l,mid,i<<1); build(mid+1,r,i<<1|1); up(i); } void padd(int idx,int val,int l,int r,int i) { t[i] += val; if (l==r) return; int mid = (l+r)>>1; if (idx <= mid) padd(idx,val,l,mid,i<<1); else padd(idx,val,mid+1,r,i<<1|1); up(i); } void padd(int idx,int val) { padd(idx,val,0,n-1,1); } int query(int x,int l,int r,int i) { if (t[i] < x) return -1; if (l==r) return l; int mid = (l+r)>>1; if (t[i<<1] >= x) return query(x,l,mid,i<<1); return query(x,mid+1,r,i<<1|1); } int query(int x) { return query(x,0,n-1,1); } }; int main() { int n,m; cin >> n >> m; for (int i=0;i<n;i++) cin >> arr[i]; segtree seg(n); for (int i=0;i<m;i++) { int mv; cin >> mv; int idx = seg.query(mv); if (idx != -1) seg.padd(idx,-mv); cout << idx+1 << '\n'; } } ``` ::: 同場加映:[Luogu P1503](https://www.luogu.com.cn/problem/P1503) 可以用很多方法做。 也可以用線段樹上二分搜。同時需要搜第一個跟最後一個,而且是在區間上而非全局上搜。 :::spoiler code ```cpp= #include <bits/stdc++.h> using ll = long long; using namespace std; #define maxn 50001 bool dested[maxn]; int sta[maxn]; // t bool t[maxn<<2]; struct segtree { int n; segtree(int n): n(n) { build(0,n-1,1); } void build(int l,int r,int i) { if (l==r) { t[i] = false; return; } int mid = midpoint(l,r); build(l,mid,i<<1); build(mid+1,r,i<<1|1); up(i); } void up(int i) { t[i] = t[i<<1] || t[i<<1|1]; } int findFirst(int jl,int jr,int l,int r,int i) { if (r < jl || jr < l || !t[i]) return n; if (l==r) return l; int mid = midpoint(l,r); int v = findFirst(jl,jr,l,mid,i<<1); if (v != n) return v; return findFirst(jl,jr,mid+1,r,i<<1|1); } int findLast(int jl,int jr,int l,int r,int i) { if (r < jl || jr < l || !t[i]) return -1; if (l==r) return l; int mid = midpoint(l,r); int v = findLast(jl,jr,mid+1,r,i<<1|1); if (v != -1) return v; return findLast(jl,jr,l,mid,i<<1); } void modify(int idx,bool v,int l,int r,int i) { if (l==r) { t[i] = v; return; } int mid = midpoint(l,r); if (idx <= mid) modify(idx,v,l,mid,i<<1); else modify(idx,v,mid+1,r,i<<1|1); up(i); } }; int main() { cin.tie(0)->sync_with_stdio(0); memset(dested,0,sizeof(dested)); int n,m; cin >> n >> m; segtree seg(n); int stan = 0; for (int i=0;i<m;i++) { char c; cin >> c; if (c=='D') { int x; cin >> x, x--; sta[stan++] = x; dested[x] = true; seg.modify(x,true,0,n-1,1); } else if (c=='R') { seg.modify(sta[stan-1],false,0,n-1,1); dested[sta[stan-1]] = false; stan--; } else { int x; cin >> x, x--; if (dested[x]) { cout << "0\n"; continue; } int l = seg.findLast(0,x,0,n-1,1), r = seg.findFirst(x,n-1,0,n-1,1); cout << r-l-1 << '\n'; } } } ``` ::: ## [List Removals](https://cses.fi/problemset/model/1749/) 想法:模擬一個 [0,1,2,...,n-1] 的下標族,要移掉第k個時,先找第k個在哪,然後對於k這個點, 把它移掉(設成很小的值),然後k後所有的點的下標都減一。 需要: * 線段樹上二分搜(下標族為遞增) * max range query + range addition 可以用懶標線段樹。 一題考你兩個點,是一個考驗實作能力的好題。 :::spoiler code ```cpp= #include <bits/stdc++.h> using namespace std; #define maxn 200001 int arr[maxn]; int t[maxn<<2]; int lz[maxn<<2]; struct segtree { int n; segtree(int n): n(n) { memset(lz,0,sizeof(lz)); build(0,n-1,1); } void up(int i) { t[i] = max(t[i<<1],t[i<<1|1]); } void down(int l,int r,int i) { int mid = midpoint(l,r); if (lz[i]) { apply(lz[i],l,mid,i<<1); apply(lz[i],mid+1,r,i<<1|1); lz[i] = 0; } } void build(int l,int r,int i) { if (l==r) { t[i] = l; return; } int mid = midpoint(l,r); build(l,mid,i<<1); build(mid+1,r,i<<1|1); up(i); } int query(int jl,int jr, int l,int r,int i) { if (jl <= l && r <= jr) return t[i]; int ans = INT_MIN; int mid = midpoint(l,r); down(l,r,i); if (jl <= mid) ans = max(ans,query(jl,jr,l,mid,i<<1)); if (jr > mid) ans = max(ans,query(jl,jr,mid+1,r,i<<1|1)); return ans; } void apply(int v,int l,int r,int i) { t[i] += v; // imp lz[i] += v; } void add(int jl,int jr,int v,int l,int r,int i) { if (jl <= l && r <= jr) { apply(v,l,r,i); return; } int mid = midpoint(l,r); down(l,r,i); if (jl <= mid) add(jl,jr,v,l,mid,i<<1); if (jr > mid) add(jl,jr,v,mid+1,r,i<<1|1); up(i); } int find(int x,int l,int r,int i) { if (t[i] < x) return -1; if (l==r) return l; int mid = midpoint(l,r); down(l,r,i); if (t[i<<1] >= x) return find(x,l,mid,i<<1); return find(x,mid+1,r,i<<1|1); } }; int main() { cin.tie(0), ios::sync_with_stdio(0); int n; cin >> n; for (int i=0;i<n;i++) cin >> arr[i]; segtree seg(n); for (int i=0;i<n;i++) { int x; cin >> x, x--; int idx = seg.find(x,0,n-1,1); cout << arr[idx] << '\n'; seg.add(idx,idx,-114514,0,n-1,1); seg.add(idx+1,n-1,-1,0,n-1,1); } } ``` ::: ## [Salary Queries](https://cses.fi/problemset/task/1144) 想法:動態開點線段樹。如果有個薪水為 $k$ 的人就把數組的第 $k$ 個欄位 +1,問題就變成區間和查詢了。 另外一個做法是做 coordinate compression 如果要使用開點線段樹,注意不要用 `new` 的空間,會超時的! 請使用 memory pool 吧 :::spoiler ```cpp= #include <bits/stdc++.h> #define umap unordered_map #define uset unordered_set using ll = long long; using namespace std; int mod = 1e9+7; ll modadd(ll a,ll b) {return (a+b)%mod;} ll modmul(ll a,ll b) {return (a*b)%mod;} ll fastpow(ll x,ll p) {ll ans = 1;for (;p;p>>=1, x=modmul(x,x)) if(p&1) ans=modmul(ans,x);return ans;} bool fastio = cin.tie(0)->sync_with_stdio(0); template <typename T> ostream& operator<<(ostream& l,const vector<T>& r) {l << '['; for (int i=0;i<r.size();i++) l << r[i] << " ]"[i==r.size()-1]; return l;} // TEMPLATE END const int N = 200001*60; ll t[N<<2]; int cnt = 1; int lt[N<<2], rt[N<<2]; struct segtree { ll query(int jl,int jr,int l,int r,int i) { if (jl<=l && r<=jr) return t[i]; ll ans = 0; int mid = l+(r-l)/2; if (jl<=mid && lt[i]) ans += query(jl,jr,l,mid,lt[i]); if (mid<jr && rt[i]) ans += query(jl,jr,mid+1,r,rt[i]); return ans; } void up(int i) { t[i] = t[lt[i]] + t[rt[i]]; } void add(int ji,int v,int l,int r,int i) { if (l==r) { t[i] += v; return; } int mid = l+(r-l)/2; if (ji<=mid) { if (!lt[i]) lt[i] = ++cnt; add(ji,v,l,mid,lt[i]); } else { if (!rt[i]) rt[i] = ++cnt; add(ji,v,mid+1,r,rt[i]); } up(i); } }; int arr[200001]; int main() { segtree seg; int n,q; cin >> n >> q; for (int i=0;i<n;i++) cin >> arr[i]; int M = *max_element(arr,arr+n); for (int i=0;i<n;i++) seg.add(arr[i],1,0,M,1); for (int i=0;i<q;i++) { char c; cin >> c; if (c=='!') { int k,x; cin >> k >> x, k--; seg.add(arr[k],-1,0,M,1); seg.add(x,1,0,M,1); arr[k] = x; } else { int a,b; cin >> a >> b; cout << seg.query(a,b,0,M,1) << '\n'; } } } ``` ::: 令 $U = 10^9$ 為值域,則 TC: $O((n+q)\log U)$ SC: $O((n+q)\log U)$ ## [Prefix Sum Queries](https://cses.fi/problemset/task/2166) 想法:請見 [線段樹 壹之型](/Z3vYeKufRzaTZHdekANwRA) 的區間合併部分 :::spoiler ```cpp= #include <bits/stdc++.h> using namespace std; using ll = long long; #define maxn 200001 int arr[maxn]; // t ll sum[maxn<<2], pre[maxn<<2], suf[maxn<<2], c[maxn<<2]; struct segtree { int n; segtree(int n): n(n) { build(0,n-1,1); } void up(int l,int r,int i) { int il = i<<1, ir = i<<1|1; sum[i] = sum[il] + sum[ir]; pre[i] = max(pre[il],sum[il]+pre[ir]); suf[i] = max(suf[ir],suf[il]+sum[ir]); c[i] = max({c[il],c[ir],suf[il]+pre[ir]}); } void build(int l,int r,int i) { if (l==r) { pre[i] = suf[i] = c[i] = arr[l]; sum[i] = arr[l]; return; } int mid = midpoint(l,r); build(l,mid,i<<1); build(mid+1,r,i<<1|1); up(l,r,i); } void set(int idx,int v,int l,int r,int i) { if (l==r) { sum[i] = v; pre[i] = suf[i] = c[i] = v; return; } int mid = midpoint(l,r); if (idx <= mid) set(idx,v,l,mid,i<<1); else set(idx,v,mid+1,r,i<<1|1); up(l,r,i); } // pre, suf, c, sum tuple<ll,ll,ll,ll> q(int jl,int jr,int l,int r,int i) { if (jl <= l && r <= jr) return {pre[i],suf[i],c[i],sum[i]}; int mid = midpoint(l,r); if (jr<=mid) return q(jl,jr,l,mid,i<<1); if (jl>mid) return q(jl,jr,mid+1,r,i<<1|1); auto [prel,sufl,cl,sl] = q(jl,jr,l,mid,i<<1); auto [prer,sufr,cr,sr] = q(jl,jr,mid+1,r,i<<1|1); return { max(sl+prer,prel), max(sufr, sufl+sr), max({cl,cr,sufl+prer}), sl+sr }; } }; int main() { int n,q; cin >> n >> q; for (int i=0;i<n;i++) cin >> arr[i]; segtree seg(n); for (int i=0;i<q;i++) { int t; cin >> t; if (t==1) { int k,u; cin >> k >> u, k--; seg.set(k,u,0,n-1,1); } else if (t==2) { int a,b; cin >> a >> b, a--, b--; cout << max(0ll,get<0>(seg.q(a,b,0,n-1,1))) << '\n'; } } } ``` ::: TC: $O(q\log n)$ SC: $O(n)$ ## [Range Interval Queries](https://cses.fi/problemset/task/3163) 想法:merge sort tree,一種 segment tree。 值域有限時 也可以用更快的wavelet tree或是樹套樹,但我不會。 TC: $O(n\log n + q\log^2 n)$ (build tree + query) SC: $O(n\log n)$ :::spoiler code ```cpp= #include <bits/stdc++.h> using namespace std; struct msTree { vector<int> arr; int* (t[200001 << 2]); int n; msTree(vector<int>& arr): arr(arr), n(arr.size()) { build(0,n-1,1); } void build(int l,int r,int i=1) { t[i] = new int[r-l+1]; if (l==r) { t[i][0]=arr[l]; return; } int mid = (l+r)>>1; build(l,mid,i<<1); build(mid+1,r,i<<1|1); // up merge(t[i<<1],t[i<<1]+(mid-l+1),t[i<<1|1],t[i<<1|1]+(r-mid),t[i]); } int query(int jl,int jr,int lo,int hi,int l,int r,int i=1) { if (jl>r || jr<l) return 0; if (jl<=l && r<=jr) return upper_bound(t[i],t[i]+(r-l+1),hi)-lower_bound(t[i],t[i]+(r-l+1),lo); int mid = (l+r)>>1; return query(jl,jr,lo,hi,l,mid,i<<1) + query(jl,jr,lo,hi,mid+1,r,i<<1|1); } int query(int jl,int jr,int lo,int hi) { return query(jl,jr,lo,hi,0,n-1,1); } }; int main() { cin.tie(0); ios::sync_with_stdio(0); int n,q; cin >> n >> q; vector<int> arr(n); for (int i=0;i<n;i++) cin >> arr[i]; msTree t(arr); for (int i=0;i<q;i++) { int a,b,c,d; cin >> a >> b >> c >> d;a--,b--; cout << t.query(a,b,c,d) << '\n'; } } ``` ::: 為什麼要手撕`new int`?因為用`vector<int>`會吃TLE,這一題很搞剛,時間判得很嚴,還得用io加速。我以0.02秒之差通過全部測資。 下面這個是使用`vector<int>`而超時的例子 :::spoiler 用`vector<int>`的 merge sort tree ```cpp= struct msTree { vector<int> arr; int* (t[200001 << 2]); int n; msTree(vector<int>& arr): arr(arr), n(arr.size()) { build(0,n-1,1); } void build(int l,int r,int i=1) { t[i] = new int[r-l+1]; if (l==r) { t[i][0]=arr[l]; return; } int mid = (l+r)>>1; build(l,mid,i<<1); build(mid+1,r,i<<1|1); // up merge(t[i<<1],t[i<<1]+(mid-l+1),t[i<<1|1],t[i<<1|1]+(r-mid),t[i]); } int query(int jl,int jr,int lo,int hi,int l,int r,int i=1) { if (jl>r || jr<l) return 0; if (jl<=l && r<=jr) return upper_bound(t[i],t[i]+(r-l+1),hi)-lower_bound(t[i],t[i]+(r-l+1),lo); int mid = (l+r)>>1; return query(jl,jr,lo,hi,l,mid,i<<1) + query(jl,jr,lo,hi,mid+1,r,i<<1|1); } int query(int jl,int jr,int lo,int hi) { return query(jl,jr,lo,hi,0,n-1,1); } ``` ::: 所以裸的`new int`比`vector`更快,約50毫秒 順帶一提,在 [這篇](https://codeforces.com/blog/entry/68949) 你會看到另一種寫法。但是會吃TLE。 就算改成`new int`也是,如下。 :::spoiler TLE ver ```cpp struct msTree { int l,r,mid; msTree* lt,* rt; int* part; msTree(vector<int>& nums,int l,int r): l(l), r(r), mid((l+r)>>1) { part = new int[r-l+1]; if (l == r) { part[0] = nums[l]; return; } lt = new msTree(nums,l,mid); rt = new msTree(nums,mid+1,r); merge(lt->part,lt->part+(lt->r-lt->l+1),rt->part,rt->part+(rt->r-rt->l+1),part); } int query(int jl,int jr,int lo,int hi) { if (jl>r || jr<l) return 0; if (jl<=l && r<=jr) return upper_bound(part,part+(r-l+1),hi)-lower_bound(part,part+(r-l+1),lo); return lt->query(jl,jr,lo,hi)+rt->query(jl,jr,lo,hi); } }; ``` ::: ## [Subarray Sum Queries](https://cses.fi/problemset/task/1190) 想法:線段樹的區間合併。 一魚三吃,一題也可以三做,以 [53. Maximum Subarray](https://leetcode.com/problems/maximum-subarray/description/)為例, 我們可以用下列方法解: * Kadane's algorithm (也就是dp) * 用prefix sum轉化成 [121. Best Time to Buy and Sell Stock](https://leetcode.com/problems/best-time-to-buy-and-sell-stock/description/) * Divide and conquer 方法,需O(n),也就是該題的follow up 但是上面只做一次查詢。為了應付多次查詢需求,自然而然就考慮起了線段樹。 並且,線段樹本質上是對 divide and conquer 法的終極運用,自然而然就考慮起了第三個方法。 下面的解法用的也是第三個方法,不過用線段樹包裝了遞歸過程。 :::spoiler code ```cpp= #include <bits/stdc++.h> using namespace std; using ll = long long; #define maxn 200001 int arr[maxn]; // t ll sum[maxn<<2], pre[maxn<<2], suf[maxn<<2], c[maxn<<2]; struct segtree { int n; segtree(int n): n(n) { build(0,n-1,1); } void up(int l,int r,int i) { int il = i<<1, ir = i<<1|1; sum[i] = sum[il] + sum[ir]; pre[i] = max(pre[il],sum[il]+pre[ir]); suf[i] = max(suf[ir],suf[il]+sum[ir]); c[i] = max({c[il],c[ir],suf[il]+pre[ir]}); } void build(int l,int r,int i) { if (l==r) { pre[i] = suf[i] = c[i] = arr[l]; sum[i] = arr[l]; return; } int mid = midpoint(l,r); build(l,mid,i<<1); build(mid+1,r,i<<1|1); up(l,r,i); } void set(int idx,int v,int l,int r,int i) { if (l==r) { sum[i] = v; pre[i] = suf[i] = c[i] = v; return; } int mid = midpoint(l,r); if (idx <= mid) set(idx,v,l,mid,i<<1); else set(idx,v,mid+1,r,i<<1|1); up(l,r,i); } }; int main() { int n,q; cin >> n >> q; for (int i=0;i<n;i++) cin >> arr[i]; segtree seg(n); for (int i=0;i<q;i++) { int k, x; cin >> k >> x, k--; seg.set(k,x,0,n-1,1); cout << max(0ll,c[1]) << '\n'; } } ``` ::: ## [Subarray Sum Queries II](https://cses.fi/problemset/task/3226) 想法:與上一題並沒什麼差別。 :::spoiler ```cpp= #include <bits/stdc++.h> using namespace std; using ll = long long; #define maxn 200001 int arr[maxn]; // t ll sum[maxn<<2], pre[maxn<<2], suf[maxn<<2], c[maxn<<2]; struct segtree { int n; segtree(int n): n(n) { build(0,n-1,1); } void up(int l,int r,int i) { int il = i<<1, ir = i<<1|1; sum[i] = sum[il] + sum[ir]; pre[i] = max(pre[il],sum[il]+pre[ir]); suf[i] = max(suf[ir],suf[il]+sum[ir]); c[i] = max({c[il],c[ir],suf[il]+pre[ir]}); } void build(int l,int r,int i) { if (l==r) { pre[i] = suf[i] = c[i] = arr[l]; sum[i] = arr[l]; return; } int mid = midpoint(l,r); build(l,mid,i<<1); build(mid+1,r,i<<1|1); up(l,r,i); } void set(int idx,int v,int l,int r,int i) { if (l==r) { sum[i] = v; pre[i] = suf[i] = c[i] = v; return; } int mid = midpoint(l,r); if (idx <= mid) set(idx,v,l,mid,i<<1); else set(idx,v,mid+1,r,i<<1|1); up(l,r,i); } // pre, suf, c, sum tuple<ll,ll,ll,ll> q(int jl,int jr,int l,int r,int i) { if (jl <= l && r <= jr) return {pre[i],suf[i],c[i],sum[i]}; int mid = midpoint(l,r); if (jr<=mid) return q(jl,jr,l,mid,i<<1); if (jl>mid) return q(jl,jr,mid+1,r,i<<1|1); auto [prel,sufl,cl,sl] = q(jl,jr,l,mid,i<<1); auto [prer,sufr,cr,sr] = q(jl,jr,mid+1,r,i<<1|1); return { max(sl+prer,prel), max(sufr, sufl+sr), max({cl,cr,sufl+prer}), sl+sr }; } }; int main() { int n,q; cin >> n >> q; for (int i=0;i<n;i++) cin >> arr[i]; segtree seg(n); for (int i=0;i<q;i++) { int a,b; cin >> a >> b, a--, b--; cout << max(0ll,get<2>(seg.q(a,b,0,n-1,1))) << '\n'; } } ``` ::: ## [Forest Queries II](https://cses.fi/problemset/task/1739) 想法:二維前綴和,Exclusion-Exclusion principle :::spoiler ```cpp= #include <bits/stdc++.h> using namespace std; using ll = long long; ll mod = 1e9+7; int t[1001][1001]; struct BIT { int n; BIT(int n): n(n) {} void add(int i_,int j_,int v) { for (int i = i_+1;i<=n;i+=i&-i) { for (int j = j_+1;j<=n;j+=j&-j) t[i][j] += v; } } int query(int i_,int j_) { int ans = 0; for (int i = i_+1;i>0;i-=i&-i) { for (int j = j_+1;j>0;j-=j&-j) ans += t[i][j]; } return ans; } int query(int mi,int mj,int Mi,int Mj) { return query(Mi,Mj) + query(mi-1,mj-1) - query(Mi,mj-1) - query(mi-1,Mj); } }; bool ist[1001][1001]; int main() { cin.tie(0)->sync_with_stdio(0); int n,q; cin >> n >> q; BIT bit(n); for (int i=0;i<n;i++) { for (int j=0;j<n;j++) { char c; cin >> c; if (c=='*') { bit.add(i,j,1); ist[i][j] = true; } } } for (int i=0;i<q;i++) { int t; cin >> t; if (t==1) { int i,j; cin >> i >> j, i--, j--; if (ist[i][j]) { bit.add(i,j,-1); } else { bit.add(i,j,1); } ist[i][j] = !ist[i][j]; } else { int i1,j1,i2,j2; cin >> i1 >> j1 >> i2 >> j2, i1--, j1--, i2--, j2--; int mi = min(i1,i2), Mi = max(i1,i2), mj = min(j1,j2), Mj = max(j1,j2); cout << bit.query(mi,mj,Mi,Mj) << '\n'; } } } ``` ::: TC: $O(很快)$ SC: $O(很少)$ ## [Range Updates and Sums](https://cses.fi/problemset/task/1735) 想法:見 [線段樹 壹之型](/Z3vYeKufRzaTZHdekANwRA) 的 $f-$修改 部分 :::spoiler ```cpp= #include <bits/stdc++.h> #include <cstring> using namespace std; using ll = long long; struct poly { ll a, b; }; bool operator==(const poly& l,const poly& r) { return l.a == r.a && l.b == r.b; } poly id{1,0}; #define maxn 200001 int arr[maxn]; ll t[maxn<<2]; poly lz[maxn<<2]; struct segtree { int n; segtree(int n): n(n) { fill(lz,lz+(n<<2),id); build(0,n-1,1); } void build(int l,int r,int i) { if (l==r) { t[i] = arr[l]; return; } int mid = (l+r)>>1; build(l,mid,i<<1); build(mid+1,r,i<<1|1); up(i); } void down(int l,int r,int i) { int mid = (l+r)>>1; if (lz[i] != id) { apply(lz[i],l,mid,i<<1); apply(lz[i],mid+1,r,i<<1|1); lz[i] = id; } } void up(int i) { t[i] = t[i<<1] + t[i<<1|1]; } void apply(poly v,int l,int r,int i) { t[i] = v.a*t[i] + v.b*(r-l+1); lz[i] = poly{v.a*lz[i].a,v.a*lz[i].b + v.b}; } ll query(int jl,int jr,int l,int r,int i) { if (jl <= l && r <= jr) return t[i]; ll ans = 0; int mid = (l+r)>>1; down(l,r,i); if (jl <= mid) ans += query(jl,jr,l,mid,i<<1); if (jr > mid) ans += query(jl,jr,mid+1,r,i<<1|1); return ans; } ll query(int jl,int jr) { return query(jl,jr,0,n-1,1); } void modify(ll a,ll b,int jl,int jr,int l,int r,int i) { if (jl <= l && r <= jr) { apply(poly{a,b},l,r,i); return; } int mid = (l+r)>>1; down(l,r,i); if (jl <= mid) modify(a,b,jl,jr,l,mid,i<<1); if (jr > mid) modify(a,b,jl,jr,mid+1,r,i<<1|1); up(i); } void modify(ll a,ll b,int jl,int jr) { modify(a,b,jl,jr,0,n-1,1); } }; void pall(segtree& seg) { for (int i=0;i<seg.n;i++) cout << seg.query(i,i) << ' '; cout << endl; } int main() { cin.tie(0); ios::sync_with_stdio(0); int n,q; cin >> n >> q; for (int i=0;i<n;i++) cin >> arr[i]; segtree seg(n); for (int i=0;i<q;i++) { int t; cin >> t; if (t==1 || t==2) { int a,b,x; cin >> a >> b >> x, a--, b--; seg.modify(2-t,x,a,b); } else { int a,b; cin >> a >> b, a--, b--; cout << seg.query(a,b) << '\n'; } } } ``` ::: TC: $O(q\log n)$ SC: $O(n)$