# Segment Tree
## [Codeforces](https://codeforces.com/edu/course/2/lesson/4/1/practice/contest/273169/problem/A)
## [A. Segment Tree for the Sum](https://codeforces.com/edu/course/2/lesson/4/1/practice/contest/273169/problem/A)
### [vector.assign](http://www.cplusplus.com/reference/vector/vector/assign/)
> 需要特別注意下面的code實作方式經過了codeforces 測資的測試,可靠度比較高,盡量以下面的模板進行修改就好。
> 另外,在這個範例中的range query方式為 **[l, r)** ,實際query的range為 **l ~ r - 1**,如果題目要求的範圍為 **[l, r]** 則將query範圍改成 **[l, r + 1)** 就行。
> 最後,這是第一次練習時所使用的版本,之後在建segment tree的時候可以使用[build](#Modification-of-the-initializaton-of-the-segment-tree)來優化
```cpp=
#include <bits/stdc++.h>
using namespace std;
#define ll long long
struct segtree
{
int size;
vector<ll> sums;
void init(int n)
{
size = 1;
while (size < n)
size *= 2;
sums.assign(2 * size, 0LL);
}
void set(int i, int v, int x, int lx, int rx)
{
if (rx - lx == 1)
{
sums[x] = v;
return;
}
int m = (lx + rx) / 2;
if (i < m)
{
set(i, v, 2 * x + 1, lx, m);
}
else
{
set(i, v, 2 * x + 2, m, rx);
}
sums[x] = sums[2 * x + 1] + sums[2 * x + 2];
}
void set(int i, int v)
{
set(i, v, 0, 0, size);
}
ll sum(int l, int r, int x, int lx, int rx)
{
if (lx >= r || l >= rx)
return 0;
if (lx >= l && rx <= r)
return sums[x];
int m = (lx + rx) / 2;
ll s1 = sum(l, r, 2 * x + 1, lx, m);
ll s2 = sum(l, r, 2 * x + 2, m, rx);
return s1 + s2;
}
ll sum(int l, int r)
{
return sum(l, r, 0, 0, size);
}
};
int main()
{
ios::sync_with_stdio(false);
int n, m;
cin >> n >> m;
segtree st;
st.init(n);
for (int i = 0; i < n; i++)
{
int v;
cin >> v;
st.set(i, v);
}
while (m--)
{
int op;
cin >> op;
if (op == 1)
{
int i, v;
cin >> i >> v;
st.set(i, v);
}
else
{
int l, r;
cin >> l >> r;
cout << st.sum(l, r) << '\n';
}
}
return 0;
}
```
## Detail Explaination
## void init()
為了建構出完美的tree(leaf都在同一個level)必須讓leaf的個數是2的冪次,所以我們讓size從1開始每次乘2,當size第一次大於n的時候,size就是大於n的最小2的冪次數字(這邊計算出的size只有leaf),在sums.assign的部分則以size * 2作為整個tree的大小(leaf加上internal nodes)。
> 當需要補數字讓size為2的冪次時,需要注意補進去的元素是否**neutral**(不影響結果),如果是要求最大值的query就用最小值補元素,如果是要求和的query就用0補元素。
## void set(int i, v)
將位置在i(最原始array的位置)的值改為v
這邊遞迴呼叫了void set(int i, int v, int x, int lx, int rx)
## void set(int i, int v, int x, int lx, int rx)
在這個引數中
* x是目前正在拜訪的node
* lx是這個node所涵蓋區間的左邊邊界
* rx是這個node所涵蓋區間的右邊邊界的右邊一格
* 特別注意在第一次呼叫這個函式時(`set(i, v, 0, 0, size)`),是以size(leaf數)來當作參數,而不是size * 2,也不是用size - 1實際帶數字進去以後如果用size - 1(奇數)會出現truncate的問題。
當`rx - lx == 1`時,代表這個區間只有一個元素,也就是已經拜訪到要修改的leaf了,於是直接將`sum[x] = v`(note that x 是目前正在拜訪的元素)。
如果區間長度不為1時則可以利用類似binary search的想法,看要搜尋x的左子樹還是右子樹。
這邊用m來代表中點
* 當`i < m`時,代表i的位置在這個區間(lx ~ rx)的左半邊且不包含m,遞迴呼叫`set(i, v, 2 * x + 1, lx, m)`(因為區段的右半是閉區間所以不包含m,固範圍為i < m),搜尋左半子樹,注意*2 * x + 1*代表x左子樹的編號
* 當`i >= m`時,代表i的位置在這個區間的右半邊(因為區段的左半是開區間,所以包含m時要這樣呼叫),遞迴呼叫`set(i, v, 2 * x + 2, m, rx)`,搜尋右半子樹
當這個遞迴一路找到leaf的位置並return後,使用`sums[x] = sums[2 * x + 1] + sums[2 * x + 2]`,來更新internal node的sums值。
## sum函式的概念也用類似的概念向下遞迴。
## [B. Segment Tree for the Minimum](https://codeforces.com/edu/course/2/lesson/4/1/practice/contest/273169/problem/B)
基本上思路與A一樣
```cpp=
#include <bits/stdc++.h>
using namespace std;
struct segTree
{
int size;
vector<int> Min;
void init(int n)
{
size = 1;
while (size < n)
size *= 2;
Min.assign(2 * size, INT_MAX);
}
void set(int i, int v, int x, int lx, int rx)
{
if (rx - lx == 1)
{
Min[x] = v;
return;
}
int m = (rx + lx) / 2;
if (i < m)
{
set(i, v, 2 * x + 1, lx, m);
}
else
{
set(i, v, 2 * x + 2, m, rx);
}
Min[x] = min(Min[2 * x + 1], Min[2 * x + 2]);
}
void set(int i, int v)
{
set(i, v, 0, 0, size);
}
int findMin(int l, int r, int x, int lx, int rx)
{
if (lx >= r || rx <= l)
return INT_MAX;
if (lx >= l && rx <= r)
return Min[x];
int m = (lx + rx) / 2;
int min1 = findMin(l, r, 2 * x + 1, lx, m);
int min2 = findMin(l, r, 2 * x + 2, m, rx);
return min(min1, min2);
}
int findMin(int l, int r)
{
return findMin(l, r, 0, 0, size);
}
};
int main()
{
int n, m;
cin >> n >> m;
segTree st;
st.init(n);
for (int i = 0; i < n; i++)
{
int v;
cin >> v;
st.set(i, v);
}
while (m--)
{
int op;
cin >> op;
if (op == 1)
{
int i, v;
cin >> i >> v;
st.set(i, v);
}
else
{
int l, r;
cin >> l >> r;
cout << st.findMin(l, r) << '\n';
}
}
}
```
## [C. Number of Minimums on a Segment](https://codeforces.com/edu/course/2/lesson/4/1/practice/contest/273169/problem/C)
這題要用稍微複雜一點的資料結構來儲存。
每一個node分別儲存一個Item instance,Item儲存的資料有
* m: 目前子樹的最小值
* c: 目前子樹中值等於最小值的node個數
在建立樹比較的過程中:
* 如果左子點的最小值比右子點的最小值小,則當前的node的資料等於左子點的資料
* 如果右子點的最小值比左子點的最小值小,則當前的node的資料等於右子點的資料
* 如果左右子點的最小值相等,則當前node的最小值等於其值,個數等於左右子點相加
```cpp=
#include <bits/stdc++.h>
using namespace std;
struct Item
{
int m;
int c;
void init(int m, int c)
{
this->m = m;
this->c = c;
}
};
struct segTree
{
int size;
vector<Item> items;
void init(int n)
{
size = 1;
while (size < n)
size *= 2;
Item item;
item.init(INT_MAX, 0);
items.assign(2 * size, item);
}
void set(int i, int v, int x, int lx, int rx)
{
if (rx - lx == 1)
{
Item item;
item.init(v, 1);
items[x] = item;
return;
}
int m = (lx + rx) / 2;
if (i < m)
{
set(i, v, 2 * x + 1, lx, m);
}
else
{
set(i, v, 2 * x + 2, m, rx);
}
Item leftItem = items[2 * x + 1];
Item rightItem = items[2 * x + 2];
if (leftItem.m == rightItem.m)
{
Item result;
result.init(leftItem.m, leftItem.c + rightItem.c);
items[x] = result;
}
else if (leftItem.m < rightItem.m)
{
items[x] = leftItem;
}
else
{
items[x] = rightItem;
}
}
void set(int i, int v)
{
set(i, v, 0, 0, size);
}
Item calc(int l, int r, int x, int lx, int rx)
{
if (lx >= r || rx <= l)
{
Item item;
item.init(INT_MAX, 0);
return item;
}
if (lx >= l && rx <= r)
return items[x];
int m = (lx + rx) / 2;
auto leftItem = calc(l, r, 2 * x + 1, lx, m);
auto rightItem = calc(l, r, 2 * x + 2, m, rx);
if (leftItem.m == rightItem.m)
{
Item item;
item.init(leftItem.m, leftItem.c + rightItem.c);
return item;
}
if (leftItem.m < rightItem.m)
{
return leftItem;
}
else
{
return rightItem;
}
}
Item calc(int l, int r)
{
return calc(l, r, 0, 0, size);
}
};
int main()
{
int n, m;
cin >> n >> m;
segTree st;
st.init(n);
for (int i = 0; i < n; i++)
{
int v;
cin >> v;
st.set(i, v);
}
while (m--)
{
int op;
cin >> op;
if (op == 1)
{
int i, v;
cin >> i >> v;
st.set(i, v);
}
else
{
int l, r;
cin >> l >> r;
auto item = st.calc(l, r);
cout << item.m << ' ' << item.c << '\n';
}
}
}
```
## [12532: Interval Product (CPE)](https://onlinejudge.org/external/125/p12532.pdf)
```cpp=
#include<bits/stdc++.h>
using namespace std;
#define ll long long
int n;
int m;
struct segTree{
int size;
vector<int> signs;
void init(int n)
{
size = 1;
while(size < n)
size *= 2;
signs.assign(2 * size, 1);
}
void set(int i, int v, int x, int lx, int rx)
{
if(rx - lx == 1)
{
signs[x] = v;
return;
}
int m = (lx + rx) / 2;
if(i < m)
{
set(i, v, 2 * x + 1, lx, m);
}
else
{
set(i, v, 2 * x + 2, m, rx);
}
signs[x] = signs[2 * x + 1] * signs[2 * x + 2];
}
void set(int i, int v)
{
if(v > 0)
v = 1;
else if(v < 0)
v = -1;
else
v = 0;
set(i, v, 0, 0, size);
}
int calc(int l, int r, int x, int lx, int rx)
{
if(lx >= r || l >= rx)
return 1;
if(lx >= l && rx <= r)
return signs[x];
int m = (lx + rx) / 2;
int left = calc(l, r, 2 * x + 1, lx, m);
int right = calc(l, r, 2 * x + 2, m, rx);
return left * right;
}
char calc(int l, int r)
{
int result = calc(l, r, 0, 0, size);
if(result == 1)
return '+';
if(result == -1)
return '-';
return '0';
}
};
int main()
{
while(cin >> n >> m)
{
segTree st;
st.init(n);
for(int i = 0; i < n; i++)
{
int v;
cin >> v;
st.set(i, v);
}
while(m--)
{
char c;
cin >> c;
if(c == 'C')
{
int i, v;
cin >> i >> v;
st.set(i - 1, v);
}
else
{
int l, r;
cin >> l >> r;
cout << st.calc(l - 1, r);
}
}
cout << '\n';
}
}
```
## [A. Segment with the Maximum Sum](https://codeforces.com/edu/course/2/lesson/4/2/practice/contest/273278/problem/A)
這題要求最大連續區間的和,每次的query給定一個interval,要求出這個interval的最大連續和,在每個Item(node)中保存四個數據
1. prefix: 前綴最大和
2. suffix: 後綴最大和
3. sum: 總和
4. seg: 這個Item區間的最大練續區間和
leaf為區間長度為1的node,如果leaf值為v:
* 如果 *v > 0*,則這個item保存的資料為{v, v, v, v}
* 如果 *v < 0*,則這個item保存的資料為{0, 0, v, 0}(除了sum以外,其他值相當於不取這個leaf的值)
接著是兩個child的結果merge成其parent的方式,這邊有點類似用divide and conquer對一個陣列求最大subarray的方式,parent的連續最大值為:
#### **max{左child的最大subarray值, 右child的最大subarray值, 橫跨左右child中點index的subarray}**
而橫跨左右child中點的subarray的最大值由:
* 右子點的最大prefix
* 左子點的最大suffix
所組成。
以上是求出parent的seg值的方式。
另外求出sum的方式trivial跳過。
而求出prefix的方式為比較:
* 左子點的prefix
* 左子點的sum + 右子點的prefix(考慮這個的原因在於,兩個子點合併以後可能會有新的更大prefix橫跨過中點index)。
同理可以求出suffix。
```cpp=
#include <bits/stdc++.h>
using namespace std;
int n, m;
struct Item
{
long long prefix, suffix, sum, seg;
};
struct segTree
{
int size;
vector<Item> Max;
Item neutral = {0, 0, 0, 0};
void init(int n)
{
size = 1;
while (size < n)
size *= 2;
Max.assign(2 * size, neutral);
}
Item merge(Item first, Item second)
{
return {
max(first.prefix, first.sum + second.prefix),
max(second.suffix, second.sum + first.suffix),
first.sum + second.sum,
max({first.seg, second.seg, first.suffix + second.prefix})};
}
void set(int i, int v, int x, int lx, int rx)
{
if (rx - lx == 1)
{
if (v > 0)
{
Max[x] = {v, v, v, v};
}
else
{
Max[x] = {0, 0, v, 0};
}
return;
}
int m = (lx + rx) / 2;
if (i < m)
set(i, v, 2 * x + 1, lx, m);
else
{
set(i, v, 2 * x + 2, m, rx);
}
Max[x] = merge(Max[2 * x + 1], Max[2 * x + 2]);
}
void set(int i, int v)
{
set(i, v, 0, 0, size);
}
};
void solve();
int main()
{
cin >> n >> m;
solve();
}
void solve()
{
segTree st;
st.init(n);
for (int i = 0; i < n; i++)
{
int v;
cin >> v;
st.set(i, v);
}
cout << st.Max[0].seg << '\n';
while (m--)
{
int i, v;
cin >> i >> v;
st.set(i, v);
cout << st.Max[0].seg << '\n';
}
}
```
## [B. K-th one](https://codeforces.com/edu/course/2/lesson/4/2/practice/contest/273278/problem/B)
這題中tree的每一個node保存的數值為該區間的1的個數。
當find進行遞迴尋找第k個1的index時:
* 如果左子樹的1個數比k還大,那代表該1在左子樹
* 如果左子樹的1個數比k還小,那代表該1在右子樹,但是往右子樹遞迴進行搜尋時須將k減去左子樹1的個數(相當於對右子樹尋找第 **k-values[2 * x + 1]** 個1)
```cpp=
#include <bits/stdc++.h>
using namespace std;
int n, m;
struct segTree
{
int size;
vector<int> values;
void init(int n)
{
size = 1;
while (size < n)
size *= 2;
values.assign(2 * size, 0);
}
int merge(int left, int right)
{
return left + right;
}
void set(int i, int v, int x, int lx, int rx)
{
if (rx - lx == 1)
{
if (v == 1)
values[x] = 1;
else
{
values[x] = 0;
}
return;
}
int m = (rx + lx) / 2;
if (i < m)
set(i, v, 2 * x + 1, lx, m);
else
set(i, v, 2 * x + 2, m, rx);
values[x] = merge(values[2 * x + 1], values[2 * x + 2]);
}
void set(int i, int v)
{
set(i, v, 0, 0, size);
}
int find(int k, int x, int lx, int rx)
{
if (rx - lx == 1)
{
return lx;
}
int m = (lx + rx) / 2;
if (values[2 * x + 1] > k)
return find(k, 2 * x + 1, lx, m);
else
return find(k - values[2 * x + 1], 2 * x + 2, m, rx);
}
int find(int k)
{
return find(k, 0, 0, size);
}
};
int main()
{
ios::sync_with_stdio(false);
cin.tie(0);
cin >> n >> m;
segTree st;
st.init(n);
vector<int> nums(n);
for (int i = 0; i < n; i++)
{
cin >> nums[i];
st.set(i, nums[i]);
}
while (m--)
{
int op;
cin >> op;
if (op == 1)
{
int i;
cin >> i;
nums[i] = 1 - nums[i];
st.set(i, nums[i]);
}
else
{
int k;
cin >> k;
cout << st.find(k) << '\n';
}
}
}
```
## [C. First element at least X](https://codeforces.com/edu/course/2/lesson/4/2/practice/contest/273278/problem/Cc)
```cpp=
#include <bits/stdc++.h>
using namespace std;
struct segTree
{
int size;
vector<int> Max;
void init(int n)
{
size = 1;
while (size < n)
size *= 2;
Max.assign(size * 2, 0);
}
int merge(int first, int second)
{
return max(first, second);
}
void set(int i, int v, int x, int lx, int rx)
{
if (rx - lx == 1)
{
Max[x] = v;
return;
}
int m = (lx + rx) / 2;
if (i < m)
{
set(i, v, 2 * x + 1, lx, m);
}
else
{
set(i, v, 2 * x + 2, m, rx);
}
Max[x] = merge(Max[2 * x + 1], Max[2 * x + 2]);
}
void set(int i, int v)
{
set(i, v, 0, 0, size);
}
int find(int k, int x, int lx, int rx)
{
if (rx - lx == 1)
{
if (Max[x] >= k)
return lx;
else
return -1;
}
int m = (lx + rx) / 2;
int left = Max[2 * x + 1];
int right = Max[2 * x + 2];
if (left >= k)
return find(k, 2 * x + 1, lx, m);
if (right >= k)
return find(k, 2 * x + 2, m, rx);
else
{
return -1;
}
}
int find(int k)
{
return find(k, 0, 0, size);
}
};
int main()
{
ios::sync_with_stdio(false);
cin.tie(0);
int n, m;
cin >> n >> m;
segTree st;
st.init(n);
for (int i = 0; i < n; i++)
{
int v;
cin >> v;
st.set(i, v);
}
while (m--)
{
int op;
cin >> op;
if (op == 1)
{
int i, v;
cin >> i >> v;
st.set(i, v);
}
else
{
int x;
cin >> x;
cout << st.find(x) << '\n';
}
}
}
```
## [D. First element at least X - 2](https://codeforces.com/edu/course/2/lesson/4/2/practice/contest/273278/problem/D)
```cpp=
#include <bits/stdc++.h>
using namespace std;
struct segTree
{
int size;
vector<int> nums;
void init(int n)
{
size = 1;
while (size < n)
size *= 2;
nums.resize(size * 2);
}
int merge(int a, int b)
{
return max(a, b);
}
void set(int i, int v, int x, int lx, int rx)
{
if (rx - lx == 1)
{
nums[x] = v;
return;
}
int m = (rx + lx) / 2;
if (i < m)
{
set(i, v, 2 * x + 1, lx, m);
}
else
{
set(i, v, 2 * x + 2, m, rx);
}
nums[x] = merge(nums[2 * x + 1], nums[2 * x + 2]);
}
void set(int i, int v)
{
set(i, v, 0, 0, size);
}
int find(int v, int l, int x, int lx, int rx)
{
if (nums[x] < v)
return -1;
if (rx <= l)
return -1;
if (rx - lx == 1)
return lx;
int m = (rx + lx) / 2;
int res = find(v, l, 2 * x + 1, lx, m);
if (res == -1)
{
res = find(v, l, 2 * x + 2, m, rx);
}
return res;
}
int find(int k, int i)
{
return find(k, i, 0, 0, size);
}
};
int main()
{
ios::sync_with_stdio(false);
cin.tie(0);
int n, m;
cin >> n >> m;
segTree st;
st.init(n);
for (int i = 0; i < n; i++)
{
int v;
cin >> v;
st.set(i, v);
}
while (m--)
{
int op;
cin >> op;
if (op == 1)
{
int i, v;
cin >> i >> v;
st.set(i, v);
}
else
{
int index, k;
cin >> k >> index;
cout << st.find(k, index) << '\n';
}
}
}
```
## [A. Inversions](https://codeforces.com/edu/course/2/lesson/4/3/practice/contest/274545/problem/A)
### 一解
這題的query以及建tree的方式都蠻特別的,首先出現的數字為1~n的permutation(題目給予的),如果照given的array順序建tree,比較難看出所謂inversion的關係,所以以題目的testcase為例
```
5
4 1 3 5 2
```
建立的array為照index順序的1 ~ 5的array(index a代表這格元素代表數字a是否有出現過),初始化將array全部的值設為0(全部的數字都還沒出現過),接下來開始依序讀input。
當讀到數字k時,我們從建好的array的index k + 1開始向後query(如果array size為n,query的範圍就為k + 1 ~ n),如果在數字k之前出現了比數字k更大的數,那麼該格元素就會是1,如果還沒出現過就會是0,我們計算1出現次數的總和就可以知道有多少個大於k的數字在k之前就出現過,query結束過後我們在index為k的位置填入1,代表數字k也出現過了。
### 二解(第二次解題時的想法)
首先當我們從頭開始讀數字的時候,比如目前讀到k,則我們在意的是過去有出現過幾筆比k還大的數字(k + 1 ~ n),所以我們必須要記錄這件事,我們可以用一個陣列來記錄這個數字是否出現過,如果數字k出現過,則在這格上填上1代表出現過,反之,目前這格是0的話代表還沒有讀到這一筆數字。
每讀進一個數字時我們要做兩件事:
1. 將這格記為1
2. query k + 1 ~ n,看看出現了幾個1
```cpp=
#include <bits/stdc++.h>
using namespace std;
struct segTree
{
int size;
vector<int> nums;
void init(int n)
{
size = 1;
while (size < n)
size *= 2;
nums.assign(size * 2, 0);
}
void set(int v, int x, int lx, int rx)
{
if (rx - lx == 1)
{
nums[x] = 1;
return;
}
int m = (lx + rx) / 2;
if (v < m)
{
set(v, 2 * x + 1, lx, m);
}
else
{
set(v, 2 * x + 2, m, rx);
}
nums[x] = nums[2 * x + 1] + nums[2 * x + 2];
}
void set(int v)
{
set(v, 0, 0, size);
}
int find(int l, int r, int x, int lx, int rx)
{
if (l <= lx && r >= rx)
return nums[x];
if (l >= rx)
return 0;
int m = (lx + rx) / 2;
int left = find(l, r, 2 * x + 1, lx, m);
int right = find(l, r, 2 * x + 2, m, rx);
return left + right;
}
int find(int i)
{
return find(i, size, 0, 0, size);
}
};
int main()
{
int n;
cin >> n;
segTree st;
st.init(n);
for (int i = 0; i < n; i++)
{
int v;
cin >> v;
v--;
st.set(v);
cout << st.find(v + 1) << ' ';
}
}
```
## Modification of the initializaton of the segment tree
The way we initialize segment tree is to use, *set(i, v)* n times, and each *set(i, v)* takes *logn* time, so we use *nlogn* time to build the segment tree.
Why don't we store the input in an array first and use this array to build the segment tree?
Use range sum as an example.
```cpp=
int main(){
...
...
segTree st;
st.init(n);
vector<int> a(n);
for(int i = 0; i < n; i++)
cin >> a[i];
st.build(a);
...
...
}
```
```cpp=
struct segTree{
...
...
void build(vector<int>& a, int x, int lx, int rx)
{
if(rx - lx == 1)
{
if(lx < (int)a.size())
{
sums[x] = a[lx];
}
return;
}
int m = (lx + rx) / 2;
build(a, 2 * x + 1, lx, m);
build(a, 2 * x + 2, m, rx);
sums[x] = sums[2 * x + 1] + sums[2 * x + 2]
}
void build(vector<int>& a)
{
void(a, 0, 0, size);
}
...
...
}
```
In this build function, when we are at the leaf level, since we might append some null elements at the end of the vector, we first need to check is the index in vector a, if the index is out of bound, then we simply return.
And we build the tree recursively.
### 為什麼在build方法中需要額外check index而在用set建tree的方法則不用?
> Ans: 用set的方法建tree的過程中有傳入i, 並且用這個i決定set的方向為左半還是右半,因此在leaf時自然落在原本的index內。
This function is only called once, and visit every node once. Hence, the time complexity of build is *O(n)*.
In total, the original method with n nodes and m queries need *O(nlogn + mlogn)*, in this optimize version we still need *mlogn* time in queries, but we only need linear time to build the tree, so the time complexity is *O(n + mlogn)*.
## [B. Inversions 2](https://codeforces.com/edu/course/2/lesson/4/3/practice/contest/274545/problem/B)
這題的解題思路與find kth 1差不多,但是有點難看出。
首先必須要從給予的陣列尾巴往回掃,而不是從頭開始掃。**因為從尾巴開始掃,可以唯一決定當前的數字**。
當目前讀到的數字為k時代表這個格數是是剩下m個數字的第m - k + 1大,這樣可以唯一確定這格數字為何。相反的如果從頭開始掃,假設第一個數字就掃到0,那麼這格數字從1~n都有可能。
所以我們的目標就在與找出當前剩下數字的第m - k + 1 大的數,首先我們初始化一個陣列,其內容全部為1,代表這個數列裡面的數字(數字a會被放在index為a的位置)我們都還沒有入最後的結果當中,接者我們從尾巴開始掃input陣列,當掃到k時,我們就去找第m - k + 1大的數,這個部分就相當於找出剩下的1當中的第m - k + 1個1,當我們找到這個數字時,把這格數字改為0,代表這個數字已經被使用過了
```cpp=
#include <bits/stdc++.h>
using namespace std;
struct segTree
{
int size;
vector<int> count;
void init(int n)
{
size = 1;
while (size < n)
size *= 2;
count.assign(size * 2, 0);
}
void build(vector<int> &a, int x, int lx, int rx)
{
if (rx - lx == 1)
{
if (lx < a.size())
count[x] = a[lx];
return;
}
int m = (lx + rx) / 2;
build(a, 2 * x + 1, lx, m);
build(a, 2 * x + 2, m, rx);
count[x] = count[2 * x + 1] + count[2 * x + 2];
}
void build(vector<int> &a)
{
build(a, 0, 0, size);
}
void set(int i, int x, int lx, int rx)
{
if (rx - lx == 1)
{
count[x] = 0;
return;
}
int m = (lx + rx) / 2;
if (i < m)
{
set(i, 2 * x + 1, lx, m);
}
else
{
set(i, 2 * x + 2, m, rx);
}
count[x] = count[2 * x + 1] + count[2 * x + 2];
}
void set(int i)
{
set(i, 0, 0, size);
}
int find(int k, int x, int lx, int rx)
{
if (rx - lx == 1)
{
return lx;
}
int m = (lx + rx) / 2;
int left = count[2 * x + 1];
int right = count[2 * x + 2];
if (left >= k)
{
return find(k, 2 * x + 1, lx, m);
}
else
{
return find(k - left, 2 * x + 2, m, rx);
}
}
int find(int k)
{
return find(k, 0, 0, size);
}
};
int main()
{
int n;
cin >> n;
vector<int> inversions(n);
vector<int> a(n + 1, 1);
a[0] = 0;
segTree st;
st.init(n + 1);
st.build(a);
for (int i = 0; i < n; i++)
cin >> inversions[i];
vector<int> ans;
for (int i = n - 1; i >= 0; i--)
{
int result = st.find((i + 1) - inversions[i]);
ans.push_back(result);
st.set(result);
}
for (int i = n - 1; i >= 0; i--)
cout << ans[i] << ' ';
}
```
## [C. Nested Segments](https://codeforces.com/edu/course/2/lesson/4/3/practice/contest/274545/problem/C)
還蠻有趣的一題,首先需要初始化一個全部為0長度為2 * n的陣列A,
這題主要的解法需要在一個segment右邊的border被讀到的時候(假設這個數字為k),將左邊的border改為1(改1的這個動作發生在A相對應的左邊border index當中),接著query k的左右border中間1的數量,如此一來就可以得到答案。
```cpp=
#include <bits/stdc++.h>
using namespace std;
struct segTree
{
int size;
vector<int> count;
void init(int n)
{
size = 1;
while (size < n)
size *= 2;
count.assign(size * 2, 0);
}
int find(int l, int r, int x, int lx, int rx)
{
if (lx >= l && r >= rx)
return count[x];
if (l >= rx || lx >= r)
return 0;
int m = (lx + rx) / 2;
int left = find(l, r, 2 * x + 1, lx, m);
int right = find(l, r, 2 * x + 2, m, rx);
return left + right;
}
int find(int l, int r)
{
return find(l, r, 0, 0, size);
}
void set(int k, int x, int lx, int rx)
{
if (rx - lx == 1)
{
count[x] = 1;
return;
}
int m = (lx + rx) / 2;
if (k < m)
{
set(k, 2 * x + 1, lx, m);
}
else
{
set(k, 2 * x + 2, m, rx);
}
count[x] = count[2 * x + 1] + count[2 * x + 2];
}
void set(int k)
{
set(k, 0, 0, size);
}
};
int main()
{
int n;
cin >> n;
vector<int> a(2 * n);
vector<int> leftBorder(n + 1, -1);
vector<int> count(n + 1);
segTree st;
st.init(2 * n);
for (int i = 0; i < a.size(); i++)
{
cin >> a[i];
if (leftBorder[a[i]] == -1)
{
leftBorder[a[i]] = i;
}
else
{
count[a[i]] = st.find(leftBorder[a[i]], i + 1);
st.set(leftBorder[a[i]]);
}
}
for (int i = 1; i <= n; i++)
cout << count[i] << ' ';
}
```
Why does this approach work ?
假設目前讀到的右border為rx而我們想要判斷segment y在不在segment x裡面,下面是所有可能發生的情況。

## [D - Intersecting Segments](https://codeforces.com/edu/course/2/lesson/4/3/practice/contest/274545/problem/D)
作法與上一題類似,當讀到left border的時候將left border在segment tree相對應的位置設為1,讀到right border的時候將segment tree對應的left border設為0,如此一來當由左而右進行query的時候,如果query的區間有1代表有某個segment的left border落在這個segment當中,但是right border沒有,這就是我們要找的intersection。
做完一次之後要在reverse整個陣列再做一次。
```cpp=
#include <bits/stdc++.h>
using namespace std;
struct segTree
{
int size;
vector<int> count;
void init(int n)
{
size = 1;
while (size < n)
size *= 2;
count.assign(size * 2, 0);
}
void set(int i, int v, int x, int lx, int rx)
{
if (rx - lx == 1)
{
count[x] = v;
return;
}
int m = (rx + lx) / 2;
if (i < m)
set(i, v, 2 * x + 1, lx, m);
else
{
set(i, v, 2 * x + 2, m, rx);
}
count[x] = count[2 * x + 1] + count[2 * x + 2];
}
void set(int i, int v)
{
set(i, v, 0, 0, size);
}
int find(int l, int r, int x, int lx, int rx)
{
if (lx >= l && rx <= r)
return count[x];
if (lx >= r || l >= rx)
return 0;
int m = (lx + rx) / 2;
int left = find(l, r, 2 * x + 1, lx, m);
int right = find(l, r, 2 * x + 2, m, rx);
return left + right;
}
int find(int l, int r)
{
return find(l, r, 0, 0, size);
}
};
int main()
{
int n;
cin >> n;
vector<int> a(2 * n);
segTree st;
st.init(2 * n);
vector<int> leftBorder(n + 1, -1);
vector<int> count(n + 1);
for (int i = 0; i < 2 * n; i++)
{
cin >> a[i];
if (leftBorder[a[i]] == -1)
{
leftBorder[a[i]] = i;
st.set(i, 1);
}
else
{
count[a[i]] = st.find(leftBorder[a[i]] + 1, i + 1);
st.set(leftBorder[a[i]], 0);
}
}
reverse(a.begin(), a.end());
leftBorder.assign(n + 1, -1);
segTree st2;
st2.init(2 * n);
for (int i = 0; i < 2 * n; i++)
{
if (leftBorder[a[i]] == -1)
{
leftBorder[a[i]] = i;
st2.set(i, 1);
}
else
{
count[a[i]] += st2.find(leftBorder[a[i]] + 1, i + 1);
st2.set(leftBorder[a[i]], 0);
}
}
for (int i = 1; i <= n; i++)
cout << count[i] << '\n';
}
```
## [E. Addition to Segment](https://codeforces.com/edu/course/2/lesson/4/3/practice/contest/274545/problem/E)
這題的重點是add operation,我們以segment為單位,遞迴的紀錄每個segment的累加情況。
get則先一路進到leaf以後,再bottom-up一路trace區間被累加過的值。
```cpp=
#include <bits/stdc++.h>
using namespace std;
#define ll long long
struct segTree
{
int size;
vector<ll> sums;
void init(int n)
{
size = 1;
while (size < n)
size *= 2;
sums.assign(2 * size, 0LL);
}
void add(int l, int r, int v, int x, int lx, int rx)
{
if (lx >= l && rx <= r)
{
sums[x] += v;
return;
}
if (lx >= r || l >= rx)
return;
int m = (lx + rx) / 2;
add(l, r, v, 2 * x + 1, lx, m);
add(l, r, v, 2 * x + 2, m, rx);
}
void add(int l, int r, int v)
{
add(l, r, v, 0, 0, size);
}
ll get(int i, int x, int lx, int rx)
{
if (rx - lx == 1)
{
return sums[x];
}
int m = (lx + rx) / 2;
int result = sums[x];
if (i < m)
{
result += get(i, 2 * x + 1, lx, m);
}
else
{
result += get(i, 2 * x + 2, m, rx);
}
return result;
}
ll get(int i)
{
return get(i, 0, 0, size);
}
};
int main()
{
int n, m;
cin >> n >> m;
segTree st;
st.init(n);
while (m--)
{
int op;
cin >> op;
if (op == 1)
{
int l, r, v;
cin >> l >> r >> v;
st.add(l, r, v);
}
else
{
int i;
cin >> i;
cout << st.get(i) << '\n';
}
}
}
```
# Practice Problems
## [A. Sign alternation](https://codeforces.com/edu/course/2/lesson/4/4/practice/contest/274684/problem/A)
解法: 儲存兩個segment tree分別處理odd index與even index的range sum
```cpp=
#include <bits/stdc++.h>
using namespace std;
#define ll long long
struct SegmentTree
{
int size;
vector<ll> sums;
void init(int n)
{
size = 1;
while (size < n)
{
size *= 2;
}
sums.resize(2 * size);
}
void merge(int x)
{
sums[x] = sums[2 * x + 1] + sums[2 * x + 2];
}
void build(vector<int> &a, int x, int lx, int rx)
{
// cout << "building range " << lx << ' ' << rx << '\n';
if (rx - lx == 1)
{
if (lx < a.size())
{
sums[x] = a[lx];
}
return;
}
int m = (rx + lx) / 2;
build(a, 2 * x + 1, lx, m);
build(a, 2 * x + 2, m, rx);
merge(x);
// cout << "merging " << lx << ", " << m - 1 << " and " << m << ' ' << rx - 1 << ", with result " << sums[x] << '\n';
}
void build(vector<int> &a)
{
build(a, 0, 0, size);
}
void set(int k, int v, int x, int lx, int rx)
{
if (rx - lx == 1)
{
sums[x] = v;
return;
}
int m = (lx + rx) / 2;
if (k < m)
{
set(k, v, 2 * x + 1, lx, m);
}
else
{
set(k, v, 2 * x + 2, m, rx);
}
merge(x);
}
void set(int k, int v)
{
set(k, v, 0, 0, size);
}
ll query(int l, int r, int x, int lx, int rx)
{
if (lx >= r || l >= rx)
{
return 0;
}
if (lx >= l && rx <= r)
{
return sums[x];
}
int m = (rx + lx) / 2;
ll left = query(l, r, 2 * x + 1, lx, m);
ll right = query(l, r, 2 * x + 2, m, rx);
return left + right;
}
ll query(int l, int r)
{
return query(l, r, 0, 0, size);
}
};
int main()
{
ios::sync_with_stdio(false);
cin.tie(0);
int n;
cin >> n;
vector<int> odd;
vector<int> even;
for (int i = 0; i < n; i++)
{
int num;
cin >> num;
if (i % 2 == 0)
{
odd.push_back(num);
}
else
{
even.push_back(num);
}
}
SegmentTree oddTree;
SegmentTree evenTree;
oddTree.init((int)odd.size());
evenTree.init((int)even.size());
oddTree.build(odd);
evenTree.build(even);
int q;
cin >> q;
while (q--)
{
int op;
cin >> op;
if (op == 0)
{
int k, v;
cin >> k >> v;
if (k % 2 == 1)
{
oddTree.set(k / 2, v);
}
else
{
evenTree.set(k / 2 - 1, v);
}
}
else
{
int l, r;
cin >> l >> r;
ll result = 0;
if (l % 2 == 1)
{
result += oddTree.query(l / 2, (r + 1) / 2); // ok
result -= evenTree.query(l / 2, r / 2); // ok
}
else
{
result += evenTree.query((l - 1) / 2, (r / 2));
result -= oddTree.query(l / 2, (r + 1) / 2);
}
cout << result << '\n';
}
}
}
```
## [B. Cryptography](https://codeforces.com/edu/course/2/lesson/4/4/practice/contest/274684/problem/B)
重點比較在struct 的operator overloading本身
```cpp=
#include <bits/stdc++.h>
using namespace std;
#define ll long long
int mod;
struct Matrix
{
/*
[ x1 y1 ]
[ x2 y2 ]
*/
ll x1, x2, y1, y2;
Matrix operator*(Matrix &other)
{
ll x1 = (this->x1 * other.x1) % mod + (this->y1 * other.x2) % mod;
ll x2 = (this->x2 * other.x1) % mod + (this->y2 * other.x2) % mod;
ll y1 = (this->x1 * other.y1) % mod + (this->y1 * other.y2) % mod;
ll y2 = (this->x2 * other.y1) % mod + (this->y2 * other.y2) % mod;
return {x1 % mod, x2 % mod, y1 % mod, y2 % mod};
}
friend ostream &operator<<(ostream &os, Matrix m)
{
os << m.x1 << ' ' << m.y1 << '\n'
<< m.x2 << ' ' << m.y2 << '\n';
return os;
}
};
struct SegmentTree
{
int size;
vector<Matrix> val;
void init(int n)
{
Matrix Identity = {1, 0, 0, 1};
size = 1;
while (size < n)
{
size *= 2;
}
val.assign(2 * size, Identity);
}
void merge(int x)
{
val[x] = val[2 * x + 1] * val[2 * x + 2];
}
void build(vector<Matrix> &a, int x, int lx, int rx)
{
if (rx - lx == 1)
{
if (lx < a.size())
{
val[x] = a[lx];
}
return;
}
int m = (rx + lx) / 2;
build(a, 2 * x + 1, lx, m);
build(a, 2 * x + 2, m, rx);
merge(x);
void build(vector<Matrix> &a)
{
build(a, 0, 0, size);
}
Matrix query(int l, int r, int x, int lx, int rx)
{
if (lx >= r || l >= rx)
{
return {1, 0, 0, 1}; // Identity
}
if (lx >= l && rx <= r)
{
return val[x];
}
int m = (rx + lx) / 2;
Matrix left = query(l, r, 2 * x + 1, lx, m);
Matrix right = query(l, r, 2 * x + 2, m, rx);
return left * right;
}
Matrix query(int l, int r)
{
return query(l, r, 0, 0, size);
}
};
int main()
{
ios::sync_with_stdio(false);
cin.tie(0);
int n, m;
cin >> mod >> n >> m;
vector<Matrix> matrices(n);
for (int i = 0; i < n; i++)
{
int x1, x2, y1, y2;
cin >> x1 >> y1 >> x2 >> y2;
matrices[i] = {x1, x2, y1, y2};
}
SegmentTree st;
st.init(n);
st.build(matrices);
while (m--)
{
int l, r;
cin >> l >> r;
l--;
cout << st.query(l, r) << "\n";
}
}
```
## [C. Number of Inversions on Segment](https://codeforces.com/edu/course/2/lesson/4/4/practice/contest/274684/problem/C)
也算是一個蠻暴力的解法,每個node存兩件事情:
* 該點當前的inversions數
* 該點當前的frequency map
當merge兩個區間時,新的inversions數為下列三者的組合:
1. 左區間本身的inversions
2. 又區間本身的inversions
3. merge所造成的inversions
而因為我們知道在這個陣列中數字大小$\leq 40$,所以這點可以利用聰明一點的暴力法在線性時間內計算inversions的個數。
```cpp=
#include <bits/stdc++.h>
using namespace std;
#define ll long long
struct Data
{
vector<int> freq;
ll count;
void init()
{
freq.resize(41);
count = 0;
}
};
struct SegmentTree
{
int size;
vector<Data> val;
void init(int n)
{
Data empty;
empty.init();
size = 1;
while (size < n)
size *= 2;
val.assign(2 * size, empty);
}
Data merge(int x)
{
auto left = val[2 * x + 1];
auto right = val[2 * x + 2];
Data newData;
newData.init();
ll total = 0;
for (auto &i : left.freq)
total += i;
for (int i = 1; i <= 40; i++)
{
total -= left.freq[i];
if (right.freq[i])
{
newData.count += right.freq[i] * total;
}
newData.freq[i] += (left.freq[i] + right.freq[i]);
}
newData.count += (left.count + right.count);
return newData;
}
void build(vector<int> &a, int x, int lx, int rx)
{
if (rx - lx == 1)
{
if (lx < (int)a.size())
{
Data data;
data.init();
data.freq[a[lx]] = 1;
val[x] = data;
}
return;
}
int m = (lx + rx) / 2;
build(a, 2 * x + 1, lx, m);
build(a, 2 * x + 2, m, rx);
val[x] = merge(x);
}
void build(vector<int> &a)
{
build(a, 0, 0, size);
}
Data query(int l, int r, int x, int lx, int rx)
{
if (lx >= r || l >= rx)
{
Data null;
null.init();
return null;
}
if (lx >= l && rx <= r)
{
return val[x];
}
int m = (lx + rx) / 2;
Data right = query(l, r, 2 * x + 2, m, rx);
Data left = query(l, r, 2 * x + 1, lx, m);
Data newData;
newData.init();
ll total = 0;
for (auto &i : left.freq)
total += i;
for (int i = 1; i <= 40; i++)
{
total -= left.freq[i];
if (right.freq[i])
{
newData.count += right.freq[i] * total;
}
newData.freq[i] += (left.freq[i] + right.freq[i]);
}
newData.count += right.count + left.count;
return newData;
}
void set(int k, int v, int x, int lx, int rx)
{
if (rx - lx == 1)
{
Data data;
data.init();
data.freq[v] = 1;
val[x] = data;
return;
}
int m = (rx + lx) / 2;
if (k < m)
{
set(k, v, 2 * x + 1, lx, m);
}
else
{
set(k, v, 2 * x + 2, m, rx);
}
val[x] = merge(x);
}
void set(int k, int v)
{
set(k, v, 0, 0, size);
}
ll query(int l, int r)
{
return query(l, r, 0, 0, size).count;
}
};
int main()
{
ios::sync_with_stdio(false);
cin.tie(0);
int n, q;
cin >> n >> q;
vector<int> a(n);
for (int i = 0; i < n; i++)
{
cin >> a[i];
}
SegmentTree st;
st.init(n);
st.build(a);
while (q--)
{
int op;
cin >> op;
if (op == 1)
{
int l, r;
cin >> l >> r;
l--;
cout << st.query(l, r) << '\n';
}
else
{
int k, v;
cin >> k >> v;
st.set(k - 1, v);
}
}
}
```
## [D. Number of Different on Segment](https://codeforces.com/edu/course/2/lesson/4/4/practice/contest/274684/problem/D)
每一個node存的資料代表該區間的distinct number數,因為這題的數字不大($\leq 40$),小於long long,所以可以用一個long long來代表,merge兩個node時,可以使用or( | )。
最後,C++的`__builtin_popcount`的long long版本為`__builtin_popcountll`
```cpp=
#include <bits/stdc++.h>
using namespace std;
#define ll long long
struct SegmentTree
{
int size;
vector<ll> val;
void init(int n)
{
size = 1;
while (size < n)
{
size *= 2;
}
val.resize(2 * size);
}
void merge(int x)
{
val[x] = val[2 * x + 1] | val[2 * x + 2];
}
void build(vector<int> &a, int x, int lx, int rx)
{
// cout << "building range " << lx << ' ' << rx << '\n';
if (rx - lx == 1)
{
if (lx < a.size())
{
val[x] = 1LL << (a[lx] - 1);
}
return;
}
int m = (rx + lx) / 2;
build(a, 2 * x + 1, lx, m);
build(a, 2 * x + 2, m, rx);
merge(x);
// cout << "merging " << lx << ", " << m - 1 << " and " << m << ' ' << rx - 1 << ", with result " << val[x] << '\n';
}
void build(vector<int> &a)
{
build(a, 0, 0, size);
}
void set(int k, int v, int x, int lx, int rx)
{
if (rx - lx == 1)
{
val[x] = (1LL << (v - 1));
return;
}
int m = (lx + rx) / 2;
if (k < m)
{
set(k, v, 2 * x + 1, lx, m);
}
else
{
set(k, v, 2 * x + 2, m, rx);
}
merge(x);
}
void set(int k, int v)
{
set(k, v, 0, 0, size);
}
ll query(int l, int r, int x, int lx, int rx)
{
if (lx >= r || l >= rx)
{
return 0;
}
if (lx >= l && rx <= r)
{
return val[x];
}
int m = (rx + lx) / 2;
ll left = query(l, r, 2 * x + 1, lx, m);
ll right = query(l, r, 2 * x + 2, m, rx);
return left | right;
}
ll query(int l, int r)
{
return query(l, r, 0, 0, size);
}
};
int main()
{
ios::sync_with_stdio(false);
cin.tie(0);
int n, q;
cin >> n >> q;
vector<int> a(n);
for (int i = 0; i < n; i++)
cin >> a[i];
SegmentTree st;
st.init(n);
st.build(a);
while (q--)
{
int op;
cin >> op;
if (op == 1)
{
int l, r;
cin >> l >> r;
l--;
cout << __builtin_popcountll(st.query(l, r)) << '\n';
}
else
{
int k, v;
cin >> k >> v;
st.set(--k, v);
}
}
}
```
## [E. Earthquakes](https://codeforces.com/edu/course/2/lesson/4/4/practice/contest/274684/problem/E)
這題其實有點難分析時間複雜度,所以嚴格來說算是硬著頭皮看別人的解法硬做下去,大致上的演算法是建一棵min segment tree,每次要摧毀建築時,如果照暴力法從$l$掃到$r$一一檢查的話時間複雜度會到$O(mn)$,很明顯不行,所以使用segment tree可以幫忙我們找出需要更新的區間,然後一路往leaf找需要更新的點。
在時間複雜度方面,雖然從$l \sim r$可能摧毀的建築數量為$O(n)$但是因為題目的情境下,一棟建築要被摧毀之前要先被建立,所以事實上摧毀建築數量量級為$O(n)$的情況並不是很常見。
```cpp=
#include <bits/stdc++.h>
using namespace std;
struct SegmentTree
{
int size;
vector<int> val;
void init(int n)
{
size = 1;
while (size < n)
size *= 2;
val.assign(2 * size, INT_MAX);
}
void set(int k, int v, int x, int lx, int rx)
{
if (rx - lx == 1)
{
val[x] = v;
return;
}
int m = (lx + rx) / 2;
if (k < m)
{
set(k, v, 2 * x + 1, lx, m);
}
else
{
set(k, v, 2 * x + 2, m, rx);
}
val[x] = min(val[2 * x + 1], val[2 * x + 2]);
}
void set(int k, int v)
{
set(k, v, 0, 0, size);
}
int find(int l, int r, int p, int x, int lx, int rx)
{
if (lx >= r || l >= rx || val[x] > p)
{
return 0;
}
if (lx >= l && rx <= r)
{
if (rx - lx == 1)
{
val[x] = INT_MAX;
return 1;
}
}
// intersect or complete in range but not in leaf node
int m = (lx + rx) / 2;
int left = find(l, r, p, 2 * x + 1, lx, m);
int right = find(l, r, p, 2 * x + 2, m, rx);
val[x] = min(val[2 * x + 1], val[2 * x + 2]);
return left + right;
}
int find(int l, int r, int p)
{
return find(l, r, p, 0, 0, size);
}
};
int main()
{
ios::sync_with_stdio(false);
cin.tie(0);
int n, q;
cin >> n >> q;
SegmentTree st;
st.init(n);
while (q--)
{
int op;
cin >> op;
if (op == 1)
{
int k, v;
cin >> k >> v;
st.set(k, v);
}
else
{
int l, r, p;
cin >> l >> r >> p;
cout << st.find(l, r, p) << '\n';
}
}
}
```
# Segment Tree Part 2
## A跟Part 1的E一樣
## [B. Applying MAX to Segment](https://codeforces.com/edu/course/2/lesson/5/1/practice/contest/279634/problem/B)
在這堂課裡面提到了range update的兩個重要條件
* Commtative
* Associative
### Other commutative operations
Recall that the operation $⊗$ ($⊗$ is an arbitrary operation, for example +, ∗ or gcd) is called:
* associative, if $(a⊗b)⊗c=a⊗(b⊗c)$, that is, for any arrangement of parentheses in the expression, the result does not change.
* commutative if $a⊗b=b⊗a$, that is, for any order of arguments, the result does not change.
In the current version of the segment tree, we can handle only associative and commutative operations. Why is this so? Let's introduce a new request $modify(l, r, x)$: apply to all $a_i$ ($l≤i<r$) the operation $⊗$ with the second argument $x$, that is, $a_i=a_i⊗x$.
#### Why is it necessary for the operation to be associative?
When we want to change a value in a node that has already been changed before, we apply the $⊗$ operation to the two requests $x$ (old value) and $y$ (new request). That is, it is necessary that the results of $(a_i⊗x)⊗y$ and $a_i⊗(x⊗y)$ are equal.
#### Why is it necessary for the operation to be commutative?
Because we have to calculate the operation from the arguments that come in order from request to request, but we calculate the operation from the arguments that come in the order from the leaf to the root. That is, it is necessary that when the order of the arguments is changed, the result of the operation does not change.
```cpp=
#include <bits/stdc++.h>
using namespace std;
#define ll long long
struct segTree
{
int size;
vector<ll> vals;
void init(int n)
{
size = 1;
while (size < n)
size *= 2;
vals.resize(2 * size);
}
void changeMax(int l, int r, ll v, int x, int lx, int rx)
{
if (l >= rx || lx >= r)
return;
if (lx >= l && rx <= r)
{
vals[x] = max(vals[x], v);
return;
}
int m = (lx + rx) / 2;
changeMax(l, r, v, 2 * x + 1, lx, m);
changeMax(l, r, v, 2 * x + 2, m, rx);
}
void changeMax(int l, int r, int v)
{
changeMax(l, r, v, 0, 0, size);
}
ll get(int i, int x, int lx, int rx)
{
if (rx - lx == 1)
return vals[x];
int m = (lx + rx) / 2;
if (i < m)
return max(vals[x], get(i, 2 * x + 1, lx, m));
else
{
return max(vals[x], get(i, 2 * x + 2, m, rx));
}
}
ll get(int i)
{
return get(i, 0, 0, size);
}
};
int main()
{
int n, m;
cin >> n >> m;
segTree st;
st.init(n);
while (m--)
{
int op;
cin >> op;
if (op == 1)
{
int l, r, v;
cin >> l >> r >> v;
st.changeMax(l, r, v);
}
else
{
int i;
cin >> i;
cout << st.get(i) << '\n';
}
}
}
```
## [C. Assignment to Segment](https://codeforces.com/edu/course/2/lesson/5/1/practice/contest/279634/problem/C)
### 如果出現的運算不像上面一樣有commutative性質該怎麼辦(仍然保有associative)?
這時候運算的次序就顯得很重要,上面的例子中,都是把結果全部運算後在存在node裡,這是因為上面的例子中運算的次序並不重要。
### 而我們該如何保存運算的次序呢?
> Key Point: 在處理get的時候,我們是從leaf一路回傳值到root。
所以我們保存次序的方法就是使用lazy propagation,假設目前有一個對於線段$[l, r-1]$的運算$k_i$,如果在更前面的時間點同樣有一個運算$k_m$$(m < i)$,也對了$[l, r - 1]$運算,那我們就把這個$k_m$運算往leaf的方向傳遞(把$k_m$折成兩半傳遞給$[l, m]$以及$[m, r - 1]$),如此一來在處理get的時候(由leaf傳回root的時候)會先遇到$k_m$再遇到$k_i$,因此可以保存運算的先後性質。
### 甚麼時候使用propagte?
當使用$modify$更改區間的值的時候,我們就使用$propagate$將當前$x$ node的運算$propagate$給他的兩個子點$2x + 1$以及$2x + 2$,最後因為將當前$x$的運算已經被傳導下去了,所以要把當前的$x$ node內容修改為$NOP$(no operation,可以利用一個不可能出現的值來代替,處理一個node時,可以使用`node(x)== NOP`來測試這個node是不是不含任何運算)。另外,如果當前的node已經是leaf的話就沒有需要propagate了,直接`return`即可。
```cpp=
#include <bits/stdc++.h>
using namespace std;
#define ll long long
struct segTree
{
int size;
vector<ll> operations;
ll NOP = LLONG_MAX;
void init(int n)
{
size = 1;
while (size < n)
size *= 2;
operations.assign(2 * size, 0LL);
}
ll operation(ll a, ll b)
{
if (b == NOP)
return a;
return b;
}
void apply_operation(ll &a, ll &b)
{
a = operation(a, b);
}
void propogate(int x, int lx, int rx)
{
if (rx - lx == 1)
return;
apply_operation(operations[2 * x + 1], operations[x]);
apply_operation(operations[2 * x + 2], operations[x]);
operations[x] = NOP;
}
void modify(int l, int r, ll v, int x, int lx, int rx)
{
propogate(x, lx, rx);
if (lx >= r || l >= rx)
return;
if (lx >= l && rx <= r)
{
apply_operation(operations[x], v);
return;
}
int m = (lx + rx) / 2;
modify(l, r, v, 2 * x + 1, lx, m);
modify(l, r, v, 2 * x + 2, m, rx);
}
void modify(int l, int r, int v)
{
modify(l, r, v, 0, 0, size);
}
ll get(int i, int x, int lx, int rx)
{
if (rx - lx == 1)
return operations[x];
int m = (lx + rx) / 2;
ll res;
if (i < m)
{
res = get(i, 2 * x + 1, lx, m);
}
else
{
res = get(i, 2 * x + 2, m, rx);
}
return operation(res, operations[x]);
}
ll get(int i)
{
return get(i, 0, 0, size);
}
};
int main()
{
int n, m;
cin >> n >> m;
segTree st;
st.init(n);
while (m--)
{
int op;
cin >> op;
if (op == 1)
{
int l, r, v;
cin >> l >> r >> v;
st.modify(l, r, v);
}
else
{
int i;
cin >> i;
cout << st.get(i) << '\n';
}
}
}
```
# Extra Problems
## [Google KickStart Round C 2020 Problem D. Candies](https://codingcompetitions.withgoogle.com/kickstart/round/000000000019ff43/0000000000337b4d)
題目大意: 回答$1a_i - 2a_{i + 1} + 3a_{i + 2} + ...$的range query,並需要支持單點修改(update)
算是蠻奇形怪狀的range query問題,在這題裡面我們要分別對三種不同的數列建立計算range sum 的segment tree
* $a_1, -2a_2, 3a_3, ...$
* $a_1, -a_2, a_3, ...$
* $-a_1, a_2, -a_3, ...$
之後便可以用上面的第一種搭配二與三選其中一個的組合來湊出query的要求。
```cpp=
#include <bits/stdc++.h>
using namespace std;
#define ll long long
void solve();
struct SegmentTree
{
int size;
vector<ll> sums;
void init(int n)
{
size = 1;
while (size < n)
size *= 2;
sums.assign(2 * size, 0);
}
void build(vector<ll> &a, ll x, int lx, int rx)
{
if (rx - lx == 1)
{
if (lx < (int)a.size())
{
sums[x] = a[lx];
}
return;
}
int m = (rx + lx) / 2;
build(a, 2 * x + 1, lx, m);
build(a, 2 * x + 2, m, rx);
sums[x] = sums[2 * x + 1] + sums[2 * x + 2];
}
void build(vector<ll> &a)
{
build(a, 0, 0, size);
}
ll query(int l, int r, int x, int lx, int rx)
{
if (l >= rx || lx >= r)
{
return 0;
}
if (lx >= l && rx <= r)
{
return sums[x];
}
int m = (lx + rx) / 2;
ll left = query(l, r, 2 * x + 1, lx, m);
ll right = query(l, r, 2 * x + 2, m, rx);
return left + right;
}
ll query(int l, int r)
{
return query(l, r, 0, 0, size);
}
void set(int k, int v, int x, int lx, int rx)
{
if (rx - lx == 1)
{
sums[x] = v;
return;
}
int m = (rx + lx) / 2;
if (k < m)
{
set(k, v, 2 * x + 1, lx, m);
}
else
{
set(k, v, 2 * x + 2, m, rx);
}
sums[x] = sums[2 * x + 1] + sums[2 * x + 2];
}
void set(int k, int v)
{
set(k, v, 0, 0, size);
}
};
int main()
{
ios::sync_with_stdio(false);
cin.tie(0);
int t;
cin >> t;
int cases = 1;
while (t--)
{
cout << "Case #" << cases++ << ": ";
solve();
}
}
void solve()
{
int n, q;
cin >> n >> q;
vector<ll> a(n);
vector<ll> odd(n);
vector<ll> even(n);
int sign = 1;
for (int i = 0; i < n; i++)
{
cin >> a[i];
if (i % 2 == 0)
{
odd[i] = a[i];
even[i] = -a[i];
}
else
{
odd[i] = -a[i];
even[i] = a[i];
}
a[i] *= (i + 1);
a[i] *= sign;
sign *= -1;
}
SegmentTree st;
SegmentTree oddTree;
SegmentTree evenTree;
st.init(n);
oddTree.init(n);
evenTree.init(n);
st.build(a);
oddTree.build(odd);
evenTree.build(even);
ll ans = 0;
while (q--)
{
char op;
cin >> op;
if (op == 'Q')
{
int l, r;
cin >> l >> r;
l--;
ll result = st.query(l, r);
if (l % 2 == 0)
{
ll cancel = l * evenTree.query(l, r);
ans += (result + cancel);
}
else
{
result *= -1;
ll cancel = l * oddTree.query(l, r);
// cout << result << ' ' << cancel << '\n';
ans += (result + cancel);
}
// cout << ans << '\n';
}
else
{
int k, v;
cin >> k >> v;
k--;
if (k % 2 == 0)
{
st.set(k, (k + 1) * v);
}
else
{
st.set(k, (k + 1) * (-v));
}
if (k % 2 == 0)
{
oddTree.set(k, v);
evenTree.set(k, -v);
}
else
{
oddTree.set(k, -v);
evenTree.set(k, v);
}
}
}
cout << ans << '\n';
}
```