# Range Queries Sol [toc] ## 前言 本題單有許多題目可以被或必須用線段樹的技巧來解。 請確保熟悉線段樹的基本用法,如`build`, `up`, `lazy tag` 可以看 [線段樹 總論](/iQrvDNeyStSAFigl8VJEIg) 好了,現在 ![s](https://hackmd.io/_uploads/HJirBfDKle.jpg) ## [Static Range Minimum Queries](https://cses.fi/problemset/task/1647) 想法:直接套sparse table 一個跟binary lifting很像的東西 TC: $O(q + n\log n)$ SC: $O(n\log n)$ :::spoiler code ```cpp= #include <bits/stdc++.h> using namespace std; class ST { public: vector<vector<int>> t; vector<int> logs; int n; ST(vector<int>& nums): logs(nums.size()+1), n(nums.size()) { logs[1] = 0; for (int i=2;i<=n;i++) logs[i] = 1+logs[i/2]; t.assign(logs[n]+1,vector<int>(n)); t[0] = nums; for (int i=1;(1<<i)<=n;i++) { for (int j=0;j+(1<<i)<=n;j++) { t[i][j] = min(t[i-1][j],t[i-1][j+(1<<(i-1))]); } } } int query(int l,int r) { int i = logs[r-l+1]; return min(t[i][l],t[i][r-(1<<i)+1]); } }; int main() { cin.tie(0); ios::sync_with_stdio(0); int n,q; cin >> n >> q; vector<int> nums(n); for (int i=0;i<n;i++) { cin >> nums[i]; } ST st(nums); for (int i=0;i<q;i++) { int a,b; cin >> a >> b; a--;b--; cout << st.query(a,b) << '\n'; } ``` ::: 另解:或是用線段樹。最大的差異是最後時間複雜度裡$\log n$乘的地方不同,可以自行評估。 TC: $O(n + q\log n)$ SC: $O(n)$ :::spoiler ```cpp= #include <bits/stdc++.h> using namespace std; #define maxn 200001 int arr[maxn]; int t[maxn<<2]; struct segtree { int n; segtree(int n): n(n) { build(0,n-1,1); } void build(int l,int r,int i) { if (l==r) { t[i] = arr[l]; return; } int mid = midpoint(l,r); build(l,mid,i<<1); build(mid+1,r,i<<1|1); t[i] = min(t[i<<1],t[i<<1|1]); } int query(int jl,int jr, int l,int r,int i) { if (jl <= l && r <= jr) return t[i]; int ans = INT_MAX; int mid = midpoint(l,r); if (jl <= mid) ans = min(ans,query(jl,jr,l,mid,i<<1)); if (jr > mid) ans = min(ans,query(jl,jr,mid+1,r,i<<1|1)); return ans; } }; int main() { cin.tie(0), ios::sync_with_stdio(0); int n,q; cin >> n >> q; for (int i=0;i<n;i++) cin >> arr[i]; segtree seg(n); for (int i=0;i<q;i++) { int a,b; cin >> a >> b, a--, b--; cout << seg.query(a,b,0,n-1,1) << '\n'; } } ``` ::: ## [Dynamic Range Minimum Queries](https://cses.fi/problemset/task/1649) 一定要用線段樹了,因為要做修改比較方便。 :::spoiler ```cpp= #include <bits/stdc++.h> #include <cstring> using namespace std; int arr[200001], t[200001<<2]; struct segtree { int n; segtree(int n): n(n) { build(0,n-1,1); } void up(int i) { t[i] = min(t[i<<1],t[i<<1|1]); } void build(int l,int r,int i) { if (l==r) { t[i] = arr[l]; return; } int mid = (l+r)>>1; build(l,mid,i<<1); build(mid+1,r,i<<1|1); up(i); } int query(int jl,int jr,int l,int r,int i) { if (jl <= l && r <= jr) return t[i]; int ans = INT_MAX, mid = (l+r)>>1; if (jl <= mid) ans = min(ans,query(jl,jr,l,mid,i<<1)); if (jr > mid) ans = min(ans,query(jl,jr,mid+1,r,i<<1|1)); return ans; } int query(int jl,int jr) { return query(jl,jr,0,n-1,1); } void point_set(int idx,int val,int l,int r,int i) { t[i] = val; if (l==r) return; int mid = (l+r)>>1; if (idx <= mid) point_set(idx,val,l,mid,i<<1); else point_set(idx,val,mid+1,r,i<<1|1); up(i); } void point_set(int idx,int val) { point_set(idx,val,0,n-1,1); } }; int main() { int n,q; cin >> n >> q; for (int i=0;i<n;i++) cin >> arr[i]; segtree seg(n); for (int i=0;i<q;i++) { int a,b,c; cin >> a >> b >> c; if (a==1) { b--; seg.point_set(b,c); } else { b--,c--; cout << seg.query(b,c) << '\n'; } } } ``` ::: ## [Range Update Queries](https://cses.fi/problemset/task/1651) 想法:懶標線段樹,實現範圍修改。 TC: $O(n + q\log n)$ SC: $O(n)$ :::spoiler code ```cpp= #include <bits/stdc++.h> using namespace std; using ll = long long; ll arr[200001], t[200001<<2], lazy[200001<<2]; struct segtree { int n; segtree(int n): n(n) { build(0,n-1,1); } void apply(ll val,int l,int r,int i) { t[i] += (r-l+1)*val; lazy[i] += val; } void down(int l,int r,int i) { if (!lazy[i]) return; int mid = (l+r)>>1; apply(lazy[i],l,mid,i<<1); apply(lazy[i],mid+1,r,i<<1|1); lazy[i] = 0; } void up(int i) { t[i] = t[i<<1]+t[i<<1|1]; } void build(int l,int r,int i) { if (l==r) { t[i] = arr[l]; return; } int mid = (l+r)>>1; build(l,mid,i<<1); build(mid+1,r,i<<1|1); up(i); } ll query(int jl,int jr,int l,int r,int i) { if (jl <= l && r <= jr) return t[i]; ll ans = 0; int mid = (l+r)>>1; down(l,r,i); if (jl <= mid) ans+=query(jl,jr,l,mid,i<<1); if (jr > mid) ans+=query(jl,jr,mid+1,r,i<<1|1); up(i); return ans; } ll query(int jl,int jr) { return query(jl,jr,0,n-1,1); } void add(int jl,int jr,ll val,int l,int r,int i) { if (jl <= l && r <= jr) { apply(val,l,r,i); return; } int mid = (l+r)>>1; down(l,r,i); if (jl <= mid) add(jl,jr,val,l,mid,i<<1); if (jr > mid) add(jl,jr,val,mid+1,r,i<<1|1); up(i); } void add(int jl,int jr,ll val) { add(jl,jr,val,0,n-1,1); } }; int main() { int n,q; cin >> n >> q; for (int i=0;i<n;i++) cin >> arr[i]; segtree seg(n); for (int i=0;i<q;i++) { int t; cin >> t; if (t==1) { int a,b,u; cin >> a >> b >> u, a--,b--; seg.add(a,b,u); } else { int k; cin >> k, k--; cout << seg.query(k,k) << '\n'; } } } ``` ::: ## [Forest Queries](https://cses.fi/problemset/task/1652) 想法:排容原理 TC: $O(n^2 + q)$ SC: $O(n^2)$ :::spoiler code ```cpp #include <bits/stdc++.h> using namespace std; using ll = long long; int main() { cin.tie(0); ios::sync_with_stdio(0); int n,q; cin >> n >> q; vector<vector<bool>> nums(n,vector<bool>(n)); for (int i=0;i<n;i++) { for (int j=0;j<n;j++) { char c; cin >> c; nums[i][j] = (c=='*'); } } vector<vector<int>> pre(n+1,vector<int>(n+1)); for (int i=1;i<=n;i++) { for (int j=1;j<=n;j++) { pre[i][j] = pre[i-1][j] + pre[i][j-1] - pre[i-1][j-1] + nums[i-1][j-1]; } } for (int i=0;i<q;i++) { int y1,x1,y2,x2; cin >> y1 >> x1 >> y2 >> x2; int lx = min(x1,x2); int ly = min(y1,y2); int ux = max(x1,x2); int uy = max(y1,y2); cout << pre[uy][ux] + pre[ly-1][lx-1] - pre[uy][lx-1] - pre[ly-1][ux] << '\n'; } } ``` ::: ## [Hotel Queries](https://cses.fi/problemset/task/1143) 想法:[線段樹上二分搜](https://wiki.sam571128.codes/data-structure/segment-tree/segment-tree-2) TC: $O(n + q\log n)$ SC: $O(n)$ :::spoiler code ```cpp= #include <bits/stdc++.h> using namespace std; using ll = long long; int arr[200001], t[200001<<2], lazy[200001<<2]; struct segtree { int n; segtree(int n): n(n) { build(0,n-1,1); } void up(int i) { t[i] = max(t[i<<1],t[i<<1|1]); } void build(int l,int r,int i) { if (l==r) { t[i] = arr[l]; return; } int mid = (l+r)>>1; build(l,mid,i<<1); build(mid+1,r,i<<1|1); up(i); } void padd(int idx,int val,int l,int r,int i) { t[i] += val; if (l==r) return; int mid = (l+r)>>1; if (idx <= mid) padd(idx,val,l,mid,i<<1); else padd(idx,val,mid+1,r,i<<1|1); up(i); } void padd(int idx,int val) { padd(idx,val,0,n-1,1); } int query(int x,int l,int r,int i) { if (t[i] < x) return -1; if (l==r) return l; int mid = (l+r)>>1; if (t[i<<1] >= x) return query(x,l,mid,i<<1); return query(x,mid+1,r,i<<1|1); } int query(int x) { return query(x,0,n-1,1); } }; int main() { int n,m; cin >> n >> m; for (int i=0;i<n;i++) cin >> arr[i]; segtree seg(n); for (int i=0;i<m;i++) { int mv; cin >> mv; int idx = seg.query(mv); if (idx != -1) seg.padd(idx,-mv); cout << idx+1 << '\n'; } } ``` ::: 同場加映:[P1503 鬼子进村](https://www.luogu.com.cn/problem/P1503) 可以用很多方法做。 也可以用線段樹上二分搜。同時需要搜第一個跟最後一個,而且是在區間上而非全局上搜。 :::spoiler code ```cpp= #include <bits/stdc++.h> using ll = long long; using namespace std; #define maxn 50001 bool dested[maxn]; int sta[maxn]; // t bool t[maxn<<2]; struct segtree { int n; segtree(int n): n(n) { build(0,n-1,1); } void build(int l,int r,int i) { if (l==r) { t[i] = false; return; } int mid = midpoint(l,r); build(l,mid,i<<1); build(mid+1,r,i<<1|1); up(i); } void up(int i) { t[i] = t[i<<1] || t[i<<1|1]; } int findFirst(int jl,int jr,int l,int r,int i) { if (r < jl || jr < l || !t[i]) return n; if (l==r) return l; int mid = midpoint(l,r); int v = findFirst(jl,jr,l,mid,i<<1); if (v != n) return v; return findFirst(jl,jr,mid+1,r,i<<1|1); } int findLast(int jl,int jr,int l,int r,int i) { if (r < jl || jr < l || !t[i]) return -1; if (l==r) return l; int mid = midpoint(l,r); int v = findLast(jl,jr,mid+1,r,i<<1|1); if (v != -1) return v; return findLast(jl,jr,l,mid,i<<1); } void modify(int idx,bool v,int l,int r,int i) { if (l==r) { t[i] = v; return; } int mid = midpoint(l,r); if (idx <= mid) modify(idx,v,l,mid,i<<1); else modify(idx,v,mid+1,r,i<<1|1); up(i); } }; int main() { cin.tie(0)->sync_with_stdio(0); memset(dested,0,sizeof(dested)); int n,m; cin >> n >> m; segtree seg(n); int stan = 0; for (int i=0;i<m;i++) { char c; cin >> c; if (c=='D') { int x; cin >> x, x--; sta[stan++] = x; dested[x] = true; seg.modify(x,true,0,n-1,1); } else if (c=='R') { seg.modify(sta[stan-1],false,0,n-1,1); dested[sta[stan-1]] = false; stan--; } else { int x; cin >> x, x--; if (dested[x]) { cout << "0\n"; continue; } int l = seg.findLast(0,x,0,n-1,1), r = seg.findFirst(x,n-1,0,n-1,1); cout << r-l-1 << '\n'; } } } ``` ::: ## [List Removals](https://cses.fi/problemset/model/1749/) 想法:模擬一個 [0,1,2,...,n-1] 的下標族,要移掉第k個時,先找第k個在哪,然後對於k這個點, 把它移掉(設成很小的值),然後k後所有的點的下標都減一。 需要: * 線段樹上二分搜(下標族為遞增) * max range query + range addition 可以用懶標線段樹。 一題考你兩個點,是一個考驗實作能力的好題。 :::spoiler code ```cpp= #include <bits/stdc++.h> using namespace std; #define maxn 200001 int arr[maxn]; int t[maxn<<2]; int lz[maxn<<2]; struct segtree { int n; segtree(int n): n(n) { memset(lz,0,sizeof(lz)); build(0,n-1,1); } void up(int i) { t[i] = max(t[i<<1],t[i<<1|1]); } void down(int l,int r,int i) { int mid = midpoint(l,r); if (lz[i]) { apply(lz[i],l,mid,i<<1); apply(lz[i],mid+1,r,i<<1|1); lz[i] = 0; } } void build(int l,int r,int i) { if (l==r) { t[i] = l; return; } int mid = midpoint(l,r); build(l,mid,i<<1); build(mid+1,r,i<<1|1); up(i); } int query(int jl,int jr, int l,int r,int i) { if (jl <= l && r <= jr) return t[i]; int ans = INT_MIN; int mid = midpoint(l,r); down(l,r,i); if (jl <= mid) ans = max(ans,query(jl,jr,l,mid,i<<1)); if (jr > mid) ans = max(ans,query(jl,jr,mid+1,r,i<<1|1)); return ans; } void apply(int v,int l,int r,int i) { t[i] += v; // imp lz[i] += v; } void add(int jl,int jr,int v,int l,int r,int i) { if (jl <= l && r <= jr) { apply(v,l,r,i); return; } int mid = midpoint(l,r); down(l,r,i); if (jl <= mid) add(jl,jr,v,l,mid,i<<1); if (jr > mid) add(jl,jr,v,mid+1,r,i<<1|1); up(i); } int find(int x,int l,int r,int i) { if (t[i] < x) return -1; if (l==r) return l; int mid = midpoint(l,r); down(l,r,i); if (t[i<<1] >= x) return find(x,l,mid,i<<1); return find(x,mid+1,r,i<<1|1); } }; int main() { cin.tie(0), ios::sync_with_stdio(0); int n; cin >> n; for (int i=0;i<n;i++) cin >> arr[i]; segtree seg(n); for (int i=0;i<n;i++) { int x; cin >> x, x--; int idx = seg.find(x,0,n-1,1); cout << arr[idx] << '\n'; seg.add(idx,idx,-114514,0,n-1,1); seg.add(idx+1,n-1,-1,0,n-1,1); } } ``` ::: ## [Range Interval Queries](https://cses.fi/problemset/task/3163) 想法:merge sort tree,一種神秘的segment tree。 值域有限時 也可以用更快的wavelet tree或是樹套樹,但我不會。 TC: $O(n\log n + q\log^2 n)$ (build tree + query) SC: $O(n\log n)$ :::spoiler code ```cpp= #include <bits/stdc++.h> using namespace std; struct msTree { vector<int> arr; int* (t[200001 << 2]); int n; msTree(vector<int>& arr): arr(arr), n(arr.size()) { build(0,n-1,1); } void build(int l,int r,int i=1) { t[i] = new int[r-l+1]; if (l==r) { t[i][0]=arr[l]; return; } int mid = (l+r)>>1; build(l,mid,i<<1); build(mid+1,r,i<<1|1); // up merge(t[i<<1],t[i<<1]+(mid-l+1),t[i<<1|1],t[i<<1|1]+(r-mid),t[i]); } int query(int jl,int jr,int lo,int hi,int l,int r,int i=1) { if (jl>r || jr<l) return 0; if (jl<=l && r<=jr) return upper_bound(t[i],t[i]+(r-l+1),hi)-lower_bound(t[i],t[i]+(r-l+1),lo); int mid = (l+r)>>1; return query(jl,jr,lo,hi,l,mid,i<<1) + query(jl,jr,lo,hi,mid+1,r,i<<1|1); } int query(int jl,int jr,int lo,int hi) { return query(jl,jr,lo,hi,0,n-1,1); } }; int main() { cin.tie(0); ios::sync_with_stdio(0); int n,q; cin >> n >> q; vector<int> arr(n); for (int i=0;i<n;i++) cin >> arr[i]; msTree t(arr); for (int i=0;i<q;i++) { int a,b,c,d; cin >> a >> b >> c >> d;a--,b--; cout << t.query(a,b,c,d) << '\n'; } } ``` ::: 為什麼要手撕`new int`?因為用`vector<int>`會吃TLE,這一題很搞剛,時間判得很嚴,還得用io加速。我以0.02秒之差通過全部測資。 下面這個是使用`vector<int>`而超時的例子 :::spoiler 用`vector<int>`的 merge sort tree ```cpp= struct msTree { vector<int> arr; int* (t[200001 << 2]); int n; msTree(vector<int>& arr): arr(arr), n(arr.size()) { build(0,n-1,1); } void build(int l,int r,int i=1) { t[i] = new int[r-l+1]; if (l==r) { t[i][0]=arr[l]; return; } int mid = (l+r)>>1; build(l,mid,i<<1); build(mid+1,r,i<<1|1); // up merge(t[i<<1],t[i<<1]+(mid-l+1),t[i<<1|1],t[i<<1|1]+(r-mid),t[i]); } int query(int jl,int jr,int lo,int hi,int l,int r,int i=1) { if (jl>r || jr<l) return 0; if (jl<=l && r<=jr) return upper_bound(t[i],t[i]+(r-l+1),hi)-lower_bound(t[i],t[i]+(r-l+1),lo); int mid = (l+r)>>1; return query(jl,jr,lo,hi,l,mid,i<<1) + query(jl,jr,lo,hi,mid+1,r,i<<1|1); } int query(int jl,int jr,int lo,int hi) { return query(jl,jr,lo,hi,0,n-1,1); } ``` ::: 所以裸的`new int`比`vector`更快,約50毫秒 順帶一提,在 [這篇](https://codeforces.com/blog/entry/68949) 你會看到另一種寫法。但是會吃TLE。 就算改成`new int`也是,如下。 :::spoiler TLE ver ```cpp struct msTree { int l,r,mid; msTree* lt,* rt; int* part; msTree(vector<int>& nums,int l,int r): l(l), r(r), mid((l+r)>>1) { part = new int[r-l+1]; if (l == r) { part[0] = nums[l]; return; } lt = new msTree(nums,l,mid); rt = new msTree(nums,mid+1,r); merge(lt->part,lt->part+(lt->r-lt->l+1),rt->part,rt->part+(rt->r-rt->l+1),part); } int query(int jl,int jr,int lo,int hi) { if (jl>r || jr<l) return 0; if (jl<=l && r<=jr) return upper_bound(part,part+(r-l+1),hi)-lower_bound(part,part+(r-l+1),lo); return lt->query(jl,jr,lo,hi)+rt->query(jl,jr,lo,hi); } }; ``` ::: ## [Subarray Sum Queries](https://cses.fi/problemset/task/1190) 想法:線段樹的區間合併。 一魚三吃,一題也可以三做,以 [53. Maximum Subarray](https://leetcode.com/problems/maximum-subarray/description/)為例, 我們可以用下列方法解: * Kadane's algorithm (也就是dp) * 用prefix sum轉化成 [121. Best Time to Buy and Sell Stock](https://leetcode.com/problems/best-time-to-buy-and-sell-stock/description/) * Divide and conquer 方法,需O(n),也就是該題的follow up 但是上面只做一次查詢。為了應負多次查詢需求,自然而然就考慮起了線段樹。 並且,線段樹本質上是對divide and conquer法的終極運用,自然而然就考慮起了第三個方法。 下面的解法用的也是第三個方法。 :::spoiler code ```cpp= #include <bits/stdc++.h> using namespace std; using ll = long long; #define maxn 200001 int arr[maxn]; // t ll sum[maxn<<2], pre[maxn<<2], suf[maxn<<2], c[maxn<<2]; struct segtree { int n; segtree(int n): n(n) { build(0,n-1,1); } void up(int l,int r,int i) { int il = i<<1, ir = i<<1|1; sum[i] = sum[il] + sum[ir]; pre[i] = max(pre[il],sum[il]+pre[ir]); suf[i] = max(suf[ir],suf[il]+sum[ir]); c[i] = max({c[il],c[ir],suf[il]+pre[ir]}); } void build(int l,int r,int i) { if (l==r) { pre[i] = suf[i] = c[i] = arr[l]; sum[i] = arr[l]; return; } int mid = midpoint(l,r); build(l,mid,i<<1); build(mid+1,r,i<<1|1); up(l,r,i); } void set(int idx,int v,int l,int r,int i) { if (l==r) { sum[i] = v; pre[i] = suf[i] = c[i] = v; return; } int mid = midpoint(l,r); if (idx <= mid) set(idx,v,l,mid,i<<1); else set(idx,v,mid+1,r,i<<1|1); up(l,r,i); } }; int main() { int n,q; cin >> n >> q; for (int i=0;i<n;i++) cin >> arr[i]; segtree seg(n); for (int i=0;i<q;i++) { int k, x; cin >> k >> x, k--; seg.set(k,x,0,n-1,1); cout << max(0ll,c[1]) << '\n'; } } ``` :::