# Binary Indexed Tree 樹狀樹組 ## 介紹 - 每個 `bit[x]` 紀錄 `[x − lowbit(x) + 1, x]` 這段區間的總和 - The index of x 必須是從 1 開始 - ![](https://i.imgur.com/GDU40YF.png) - presum(7) - ans += bit[7] - ans += bit[6] - ans += bit[4] - update(5, 1) - bit[5] += 1 - bit[6] += 1 - bit[8] += 1 - 可以在 `O(logn)` 時間單點修改 `x` 位置的值 - 可以在 `O(logn)` 時間查詢 `[1, x]` 的區間和(前綴和) ```clike= int n; int bit[1000]; void update(int x, int val){ // 把 x 位置加 val for(int i = x; i <= n; i += (i&(-i))) bit[i] += val; } int presum(int x){ // 找出 [1, x] 總和 int sum = 0; for(int i = x; i > 0; i -= (i&(-i))) sum += bit[i]; return sum; } ``` ## 1.單點加值,查詢區間總和 ```clike= // 初始化 n = 5; vector<int> data = {1,4,9,16,25}; for(int i=0;i<n;i++) update(i+1,data[i]); // 1 4 9 16 25 // 查詢區間 [2, 4] 總和 cout<<presum(4) - presum(1)<<endl; // 更新區間值 update(2, -2); // 1 2 9 16 25 update(3, -5); // 1 2 4 16 25 ``` ## 2.區間加值,單點查詢數值 ```clike= // 初始化 n = 5; vector<int> data = {1,4,9,16,25}; // 維持差分數列 dif = 1 3 5 7 9 update(1, data[0]); for(int i=1;i<n;i++) update(i+1,data[i]-data[i-1]); // 查詢 data[4] 數值 cout<<presum(4)<<endl; // 更新區間 [2, 4] 同加 3 update(2, 3); update(5, -3); // dif = 1 6 5 7 6 ``` ## 3.區間加值,區間查詢總和 ```clike= #include <bits/stdc++.h> using namespace std; int n; int bit1[1000]; // 維持 d[i] int bit2[1000]; // 維持 d[i] * i void update(int *bit, int x, int val){ for(int i = x; i <= n; i += (i&(-i))) bit[i] += val; } int presum(int *bit, int x){ int sum = 0; for(int i = x; i > 0; i -= (i&(-i))) sum += bit[i]; return sum; } // 因為是維持陣列 d[i] 及陣列 d[i]*i , 區間修改變成單點修改 void change(int l, int r, int val){ update(bit1, l, val ); update(bit1, r+1, -1*val ); update(bit2, l, val*l ); update(bit2, r+1, -1*val*(r+1) ); } /* 單點查詢 a[i] = d[1] + d[2] + ... + d[i] 區間查詢 a[1] + ... + a[10] = d[1] + (d[1] + d[2]) + ... + (d[1] + ... + d[10]) = d[1] * 10 + d[2] * 9 + ... + d[10] * 1 = (d[1] + ... + d[10])*11 - (d[1] * 1 + d[2] * 2 + ... + d[10] * 10) = presum(bit1, 10)*11 - presum(bit2, 10) */ int query(int l, int r){ l--; int rr = presum(bit1, r)*(r+1) - presum(bit2, r); int ll = presum(bit1, l)*(l+1) - presum(bit2, l); return rr - ll; } int main(){ vector<int> data = {1,2,3,1,2,3,4}; n = data.size(); // 初始化 update(bit1, 1, data[0]); update(bit2, 1, data[0]); for(int i=1;i<n;i++){ update(bit1, i+1, (data[i]-data[i-1]) ); update(bit2, i+1, (data[i]-data[i-1])*(i+1) ); } // 1 2 3 1 2 3 4 cout<<query(1,3)<<endl; cout<<query(5,7)<<endl; change(3, 6, -1); // 1 2 2 0 1 2 4 cout<<query(1,3)<<endl; cout<<query(5,7)<<endl; return 0; } ```