---
tags: data-structure
GA: UA-145883477-1
---
# Binary Indexed Tree
寫LeetCode過程中,學了不少的資料結構,binary indexed tree (Fenwick tree)的有趣度絕對排得上前三名。在看過它之前,我從來沒有想過可以用這樣的方式存取資料。
和binary indexed tree相關的問題是求prefix sum:給定一個array `nums`,進行k次query,每次query給定一個index `i`,回傳`nums[0..i]`的和。
看到這個問題,通常第一個想到的做法是每次query就直接計算總和,時間複雜度O(kN),空間O(1)。想省時間的話,拿到`nums`就先把所有index的prefix sum算好存下,時間O(N + k),空間O(N)。如果問題就只是這樣(i.e. `nums`是immutable),那麼binary indexed tree還派不上用場。
如果`nums`是mutable,每次update一個元素,進行l次update,以上述省時間的做法,每次update時間複雜度為O(N),總共時間為O(N + lN + k)。
Binary indexed tree提供了另一個選擇,能將每次update時間複雜度降為O(logN),代價是每次query時間上升為O(logN)。總時間複雜度為O(NlogN + llogN + klogN),空間同樣是O(N)。適合在大量update的情境使用。
## 概念
(以下為方便說明,令`nums`的index從1開始到N)
Binary indexed tree準備了一個array `t`,用來存放預先計算好的sum。其中`t[i]`存放的sum不像直觀的做法是存`nums[1..i]`的sum,而是`nums[i - lowbit(i) + 1, i]`的sum,這樣的設計讓query和update可以在O(logN)完成。
以下方Fenwick原始論文中的圖來說明,每個長方形都有一個灰色部分,對應到一個index `i`,`t[i]`存放的值就是該長方形所涵蓋的所有elements的sum。例如,`t[12]`存`nums[9..12]`的sum,`t[11]`存`nums[11..11]`的sum。

(source: [A new data structure for cumulative frequency tables](http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.14.8917&rep=rep1&type=pdf))
### lowbit function
上述lowbit function是用來計算某個數的最小的bit 1的值。
例如lowbit(12) = 4 (12 = 0xb1100,最小的bit 1為0xb0100 = 4)。
### Get prefix sum
計算prefix sum只要從圖上index對應的灰色部分一路traverse到root。例如prefix sum of index 11
= t[11] + t[10] + t[8]
= nums[11] + nums[9..10] + nums[1..8]
仔細觀察會發現traversal path下個index是當前index拿掉最小的bit 1 (減去lowbit):11 (0b1011) -> 10 (0b1010) -> 8 (0b1000)
因為每次都會少掉一個bit 1,時間複雜度為O(logN)。
### Update an element
`nums[i]`如果有update,有影響的即為圖上所有有涵蓋到`nums[i]`的`t`的元素。例如
nums[3]被update,t[3]、t[4]、t[8]會受影響。
仔細觀察會發現被影響的下個index是當前index加上lowbit:
3 (0b0011) -> 4 (0b0100) -> 8 (0b1000)
因為每次至少往MSB推進一個bit,時間複雜度為O(logN)。
## 實作
https://leetcode.com/problems/range-sum-query-mutable/
```cpp=
class NumArray {
public:
NumArray(vector<int>& nums) {
n = nums.size();
t.assign(n+1, 0);
for (int i = 1; i <= n; i++)
for (int j = i; j <= n; j += lowbit(j))
t[j] += nums[i-1];
}
void update(int i, int val) {
int num = sumRange(i, i);
int d = val - num;
for (i = i + 1; i <= n; i += lowbit(i))
t[i] += d;
}
int sumRange(int i, int j) {
return sumPrefix(j) - sumPrefix(i-1);
}
int sumPrefix(int i)
{
int res = 0;
for (i = i + 1; i > 0; i -= lowbit(i))
res += t[i];
return res;
}
int lowbit(int i)
{
return i & ~(i - 1);
}
private:
vector<int> t;
int n;
};
```
Tips
* 迴圈只有兩種邏輯:update an element和get the prefix sum。初始化可以想成是逐一update element,從0 update成`nums[i]`。
* 一致的index規則:function input index皆為存取 `nums`的index。存取 `t`前,記得將index + 1。