## Introduct vector<int> nums = {1, 5, -2, 7, 4, 3, 6, 2}; 如果要求 range sum,例如 sum(1, 5) = 17 有兩種operation,1. 求某個range的和,2. update每個一值。 如果使用vector,和prefix sum,分別的time complexity如下: | | update | getSum | | ------ | ------ | ------ | | vector | $O(1)$ | $O(N)$ | | prefix-sum| $O(N)$ | $O(1)$ | 如果是update非常頻繁,但是求和的情況非常少,可以使用vector。 相反的, 如果是update很少,但是求和的情況非常頻繁,就可以使用prefix-sum。 如果是update和求和都很頻繁,那就必須要用binary index tree, | | update | getSum | | ------ | ------ | ------ | | BIT|$(logN)$ | $O(logN)$ | ## Reference 1. [Binary Indexed Tree or Fenwick Tree](https://www.geeksforgeeks.org/binary-indexed-tree-or-fenwick-tree-2/) ## Binary Index Tree 1. 使用N-ary tree,parent和child只相差一個bit。 2. 使用dummy node(index 0),不儲存任何資料 ![](https://i.imgur.com/HZYZGrc.png) 3. child存放的數值,是到parent的總和(不包括parent)。例如: index-12 = nums[9] + nums[10] + nums[11] + nums[12]。 ![](https://i.imgur.com/oUfnEA6.png) 4. getSum(index),計算從0 ~ index的總和。就是從一路從index加到root。例如: sum(0 ~ 13) = bit(13) + bit(12) + bit(8); ![](https://i.imgur.com/T642puo.png) 因為是tree的走訪,所以time complexity = $O(logN)$ 5. update element, 必須update有儲存此index的項目。例如update n[5],必須同時update bit[5], bit[6]和bit[8] ![](https://i.imgur.com/dCaf0B0.png) ## Code Snippet ### getSum() 從BIT中計算0 ~ index的總和。 ```cpp= int getSum(int idx) { ++idx; // 轉換成在BIT中使用的index int sum = 0; while(idx > 0) {// 因為index 0為dummy node sum += BIT[idx]; idx -= idx & -idx; // 往parent移動,每次都減掉left most set bit } return sum; } ``` ### update() 更新某個vector中的數值,因為n[i]已經加起來了,所以只能增減n[i]的數值。 ```cpp= // 這邊的val是新舊n[idx]的差值。 // val = n'[idx] - n[idx] void update(int idx, int val) { ++idx; // 轉換成在BIT中使用的index // sz為vector的大小,因為BIT長度多1,所以使用小於等於 while(idx <= sz) { BIT[idx] += val; idx += idx & -idx; } } ``` ### Construct the BIT 從一個vector<int> nums建立一個BIT。 ```cpp= vector<int> nums = {1, 5, -2, 7, 4, 3, 6, 2}; int sz = nums.size(); // 因為多一個dummy node, 預設所有值都為0 vector<int> BIT(sz + 1); for(int i = 0; i < sz; ++i) update(i, nums[i]); ``` ## 2D BIT ref : [Fenwick Tree/Binary Indexed Tree (BIT)](https://www.jianshu.com/p/b6b788b24c09) ```cpp vector<vector<int>> nums; int m = nums.size(); int n = nums[0].size(); ``` 和1D BIT一樣,使用長寬都大1的2D vector來儲存。 ```cpp vector<vector<int>> bit(m + 1, vector<int>(n + 1)); // construct the 2D BIT void construct(vector<vector<int>>& nums) { for(int y = 0; y < m; ++y) for(int x = 0; x < n; ++x) update(y, x, nums[y][x]); } ``` update function ```cpp void update(int y, int x, int diff) { y++; x++; for(int i = y; i >= m; i += i & -i) { for(int j = x; j >= n; j += j & -j) { bit[i][j] += diff; } } } ``` calculate range sum function ```cpp int getSum(int y, int x) { int sum = 0; y++; x++; for(int i = y; i > 0; i -= i & -i) for(int j = x; j > 0; j -= i & -i) sum += bit[i][j]; return sum; } int getRangeSum(int top, int left, int bottom, int right) { return getSum(bottom, left) - getSum(top - 1, right) - getSum(bottom, left - 1) + getSum(top - 1, left - 1); } ``` ## 計算個數 ### 把value當成index BIT可以用來計算range sum,如果反過來使用把數值當成index, value改為個數,則可計算比目前數值大或小的個數。 例如[315. Count of Smaller Numbers After Self](https://hackmd.io/ogxY5ToqTT-RAZUqRhxbgw?both#315-Count-of-Smaller-Numbers-After-Self) nums = [5,2,6,1] 因為0為dummy node,所以index = nums[i] + 1 當idx = 0, nums[0] = 5, index = 5 + 1 = 6往右看如下 | index | 0 | 1 | 2 | 3 | 4 | 5 | 6★ | 7 | | ----- | --- | --- | --- | --- | --- | --- | --- | --- | | value | | | 1 | 1 | | | | 1 | 則range sum(1, 5)即為小於nums[0] = 5的個數。 ## Problems ### [307. Range Sum Query - Mutable](https://leetcode.com/problems/range-sum-query-mutable/) 題目很明顯就是要實現Binary index tree。 ```cpp= class NumArray { vector<int> nums; vector<int> bit; int sz; void updateBIT(int idx, int val) { ++idx; while(idx <= sz) { bit[idx] += val; idx += idx & -idx; } } int getSum(int idx) { ++idx; int sum = 0; while(idx > 0) { sum += bit[idx]; idx -= idx & -idx; } return sum; } public: NumArray(vector<int>& nums) { // 儲存原本的vector,才可以計算diff this->nums = nums; sz = nums.size(); bit.resize(sz + 1); for(int i = 0; i < sz; ++i) updateBIT(i, nums[i]); } void update(int index, int val) { updateBIT(index, val - nums[index]); // 必須把val記錄起來,才可以算出正確的diff nums[index] = val; } int sumRange(int left, int right) { // left == 0, left - 1 = -1, // 但是在BIT中,index會修正為0 // 所以可以正常使用 return getSum(right) - getSum(left - 1); } }; ``` ### [406. Queue Reconstruction by Height](https://leetcode.com/problems/queue-reconstruction-by-height/) ```cpp= vector<vector<int>> reconstructQueue(vector<vector<int>>& people) { int sz = people.size(); auto cmp = [&](vector<int>& a, vector<int>& b) { if(a[0] < b[0]) return true; else if(a[0] == b[0]) return a[1] > b[1]; else return false; }; // method : binary index tree // binary index tree 求 0 ~ n 的合 tc = O(logN), update其中一個數 tc = O(logN) // 一開始要知道前面有幾個空位 input vector = [0, 1, 1, 1, ... 1] // sum = [0, 1, 2, 3, ... n - 1] <-- 使用BIT表示 // 當有一個數值被填入之後,空格減少 [0, 1, 1, 0, ... 1] // sum = [0, 1, 2, 2, ....n - 2] <-- 使用BIT表示 sort(people.begin(), people.end(), cmp); vector<vector<int>> rtn(sz, vector<int>()); vector<int> bit(sz + 1); auto update = [&](int idx, int val) { idx++; while(idx <= sz) { bit[idx] += val; idx += idx & -idx; } }; auto getsum = [&](int idx) -> int { int sum = 0; idx++; while(idx > 0) { sum += bit[idx]; idx -= idx & -idx; } return sum; }; // input array [0, 1, 1, 1, ..., 1] for(int i = 1; i < sz; ++i) update(i, 1); for(auto& p : people) { int left = 0, right = sz; // 因為有可能是sz - 1 while(left < right) { int mid = left + (right - left) / 2; int sum = getsum(mid); // if p[1] == 3; // 因為是要取3 的第一個 // left/mid(<) right(>=) // 2, 3 if (sum >= p[1]) right = mid; else left = mid + 1; } rtn[left] = p; update(left, -1); // 空格子減少 } // time : O(NlogN) + O(N * logN) return rtn; } ``` ### [315. Count of Smaller Numbers After Self](https://leetcode.com/problems/count-of-smaller-numbers-after-self/) 給你一個vector<int> nums, 回傳右邊比目前的數還小的個數。 > 1. 使用BIT來計算比目前的數還小的個數。 ```cpp= class BIT{ vector<int> v; public: // 多一個長度是因為index = 0為dummy node BIT(int n) : v(n + 1){} void update(int idx, int val) { // 因為小值為-1e4,如果只加 1e4則idx = 0 為dummy node // 所以改加 1e4 + 1 idx += 1e4 + 1; while(idx < v.size()) { v[idx] += val; idx += idx & -idx; } } int query(int idx) { // 因為是找比自己還小的數,所以不用加1 idx += 1e4; int rtn = 0; while(idx > 0) { rtn += v[idx]; idx -= idx & -idx; } return rtn; } }; class Solution { public: vector<int> countSmaller(vector<int>& nums) { int sz = nums.size(); BIT bit(2e4 + 1); // 因為從-1e4到1e4的長度為2e4 + 1(包括0)。 vector<int> rtn(sz); bit.update(nums.back(), 1); for(int i = sz - 2; i >= 0; --i) { rtn[i] = bit.query(nums[i]); bit.update(nums[i], 1); } return rtn; } }; ``` ###### tags: `leetcode` `刷題`