# Wavelet Tree ## Intro **一棵 Wavelet Tree** 支援以下三種詢問: 1. 詢問區間 $[l,\ r]$中第 $k$ 大/小的數 2. 詢問區間 $[l,\ r]$中 $k$ 出現的次數 3. 詢問區間 $[l,\ r]$中 $\leq k$ 的數字的數量 不支援修改 :crying_cat_face: ## 結構 ![alt image](<https://i.imgur.com/IKCjdls.png> =500x) 將一串字串(也可以是數字)用**類似 $\text{quick-sort}$** 的方式拆成兩邊,直到該區間所有數字皆相等 同時在每個節點維護一個 $0-1$ 陣列,代表這個數字被分到左邊還是右邊 ($1$ 代表左邊) 在整棵建構完後會發現**原陣列被排序好了** ## 建構 建構的過程會用到 $\text{std::stable_partition}$,沒用過的可以到[**這裡**](https://en.cppreference.com/w/cpp/algorithm/stable_partition)或[**這裡**](http://c.biancheng.net/view/7513.html)先看看 ### 節點 ```cpp= int lo, hi, m; vector<int> lf; wavelet_tree *ln, *rn; ``` * $lo,\ hi$ : 該區間的最小/最大值 * $m$ : $(lo+hi)/2$ * $lf$ : 結構那邊講到的 $0-1$ 陣列的**前綴和** ### 建樹 把陣列從 $m$ 拆成兩邊 ```cpp= vector<LL> val; //開在全域,原陣列 wavelet_tree(LL l, LL r, LL x, LL y) : lo(x), hi(y) { ln = rn = NULL; if (l >= r) return; if (lo == hi) return; lf.pb(0); m = (lo + hi) >> 1; auto cmp = [&](LL x) { return x <= m; }; for (int i = l; i < r; i++) lf.pb(lf.back() + (LL)cmp(val[i])); LL pivot = stable_partition(val.begin() + l, val.begin() + r, cmp) - val.begin(); ln = new wavelet_tree(l, pivot, lo, *max_element(val.begin() + l, val.begin() + pivot)); rn = new wavelet_tree(pivot, r, *min_element(val.begin() + pivot, val.begin() + r), hi); } // wavelet_tree tr(0, n, *min_element(all(val)), *max_element(all(val))); ``` 直接照定義做,除了對 $lf$ 做前綴和以外應該沒什麼重點 ## 詢問 1. K-th 以第 $k$ 小為例 ```cpp= LL kth(LL l, LL r, LL k) //1-based [l,r] { if (lo == hi) return lo; LL t = lf[r] - lf[l - 1]; if (t >= k) return ln->kth(lf[l - 1] + 1, lf[r], k); return rn->kth(l - lf[l - 1], r - lf[r], k - t); } ``` 變數 $t$ 代表 $[l,\ r]$ 有幾個數字被分到左邊 ![alt image](<https://i.imgur.com/JYKJrHH.png> =300x) 如果 $t<k$ 代表答案會在右子樹,所以往右邊遞迴 否則往左邊遞迴 但如果直接把 $[l,\ r]$ 傳下去顯然不對 因為這個區間有一些數字是分到另一邊的,$[l,\ r]$ 太大了 所以 * 如果往左子樹 $\to$ 新的 $l$ 設為 $lf[l - 1] + 1$,$r$ 設為 $lf[r]$ * 如果往右子樹 $\to$ 新的 $l$ 設為 $l-lf[l - 1]$,$r$ 設為 $r-lf[r]$ $lf[l-1]$ 代表在 $l$ 之前已經有幾個數字在左子樹了,所以新的 $l$ 要從 $lf[l - 1] + 1$ 開始。 而 $lf[r]$ 也差不多 右子樹也一樣 Wavelet Tree 支援的三種詢問縮小區間的方法都一樣,所以下面兩種詢問的說明會很短,真的很短的那種 ## 詢問 2. count 詢問區間 $[l,\ r]$中 $k$ 出現的次數 ```cpp= LL count(LL l, LL r, LL k) { if (l > r || k > hi || k < lo) return 0; if (lo == hi) return r - l + 1; if (k <= m) return ln->count(lf[l - 1] + 1, lf[r], k); return rn->count(l - lf[l - 1], r - lf[r], k); } ``` ## 詢問 3. LTE 詢問區間 $[l,\ r]$中 $\leq k$ 的數字的數量 ```cpp= LL LTE(LL l, LL r, LL k) { if (l > r || k < lo) return 0; if (hi <= k) return r - l + 1; return ln->LTE(lf[l - 1] + 1, lf[r], k) + rn->LTE(l - lf[l - 1], r - lf[r], k); } ``` ## 複雜度 ### 時間複雜度 建樹 $\mathcal{O}(n \log n)$ 三種詢問單次詢問的時間複雜度均為 $\mathcal{O}(\log n)$ ### 空間複雜度 一層 $\mathcal{O}(n)\times \log n$ 層 $=\mathcal{O}(n\log n)$ ## 例題 [**Range Kth Smallest**](https://judge.yosupo.jp/problem/range_kth_smallest) [**Static Range Frequency**](https://judge.yosupo.jp/problem/static_range_frequency) ## Template ```cpp= vector<int> val; struct wavelet_tree { int lo, hi, m; vector<int> lf; wavelet_tree *ln, *rn; wavelet_tree(LL l, LL r, LL x, LL y) : lo(x), hi(y) // 0-based [l, r) { ln = rn = NULL; if (l >= r) return; if (lo == hi) return; lf.push_back(0); m = (lo + hi) >> 1; auto cmp = [&](LL x) { return x <= m; }; for (int i = l; i < r; i++) lf.push_back(lf.back() + (LL)cmp(val[i])); int pivot = stable_partition(val.begin() + l, val.begin() + r, cmp) - val.begin(); ln = new wavelet_tree(l, pivot, lo, *max_element(val.begin() + l, val.begin() + pivot)); rn = new wavelet_tree(pivot, r, *min_element(val.begin() + pivot, val.begin() + r), hi); } int kth(int l, int r, int k) // 1-based [l, r] { if (lo == hi) return lo; LL t = lf[r] - lf[l - 1]; if (k <= t) return ln->kth(lf[l - 1] + 1, lf[r], k); return rn->kth(l - lf[l - 1], r - lf[r], k - t); } int count(int l, int r, int k) // 1-based [l, r] { if (l > r || k > hi || k < lo) return 0; if (lo == hi) return r - l + 1; if (k <= m) return ln->count(lf[l - 1] + 1, lf[r], k); return rn->count(l - lf[l - 1], r - lf[r], k); } int LTE(int l, int r, int k) // 1-based [l, r] { if (l > r || k < lo) return 0; if (hi <= k) return r - l + 1; return ln->LTE(lf[l - 1] + 1, lf[r], k) + rn->LTE(l - lf[l - 1], r - lf[r], k); } }; wavelet_tree tr(0, n, *min_element(all(val)), *max_element(all(val))); ``` ## End 網路上的圖幾乎都是把 $1$ 當成往右邊,但我懶得改 $\text{Code}$ 了所以乾脆自己畫圖