owned this note
owned this note
Published
Linked with GitHub
# Treap
Treap = Tree + Heap
1. Tree : Tree其實是Binary Search Tree,具備二元搜尋的功能
2. Heap : 所有父節點都 < 子節點,用意是平衡樹的高度
支援的操作 :
```
插入一個數
刪除一個數
查詢一個數的排名
查詢排名為多少的數
查詢小於一個數最大的數
查詢大於一個數最小的數
```
這些操作BST都可以做到,只是如果只用BST最壞的複雜度可能到O(n),所以我們要加上Heap的操作來讓整棵樹維持複雜度期望值為O(log_n),以下就讓我們來講解Treap的操作。
## 構造

Treap的每一個節點會記錄兩個值,一個是BST用的val(<font color="blue">藍</font>),一個是Heap用的key(<font color="red">紅</font>)。
**val : 左節點 > 右節點**
**key : 子節點 > 父節點**
在初始化時,可以先建立兩個節點,-inf跟inf,這樣在一開始新增節點的時候會比較方便,但在查詢排名的時候要注意-inf應該要是第0名。
```cpp=
typedef struct Node{
int val, key, size, cnt; //size為以當前節點為跟的樹的節點總數,cnt為val的數量
struct Node *l, *r;
Node(int x){ //初始化
val=x, size=cnt=1, key=rand();
l=r=NULL;
}
void update_son(){ //更新子樹節點函數
size=cnt;
if(l != NULL) size+=l->size;
if(r != NULL) size+=r->size;
}
}Node;
Node *built(){
Node* rt=new Node(inf);
rt->l=new Node(-inf); //-inf < inf,所以在左邊
rt->key=-inf, rt->l->key=-inf; //讓key相同就不會和heap性質衝突
rt->update_son(); //更新節點數
return rt;
}
Node *root=built();
```
## 插入
首先跟BST的插入一樣,如果要插入的值小於中間的值就往左遞迴,如果大於中間的值就往右遞迴,否則就插入節點。
雖然當前滿足了BST的條件,但還不一定滿足Heap的性質,如果有任意子節點 < 父節點,那就要利用左旋和右璇來維護key。
把插入的節點往插入的路徑上推 :
如果左子樹的key比中間小就右璇,如果右子樹的key比中間小就左旋,讓較小的key往上提。
### 左旋
把中間的key更新成右子節點的key
:::spoiler 示意圖



:::
3可以更新到2是因為2<3<4,3可以是4的左子節點也可以是2的右子節點,因為4更新後多了一個節點2,2更新後少了一個節點4,因此多的3要改接在2的右子節點。
```cpp=
void left_rotate(Node*& a){ //要取本尊才更新的到
Node* b=a->r; //a=2, b=4
a->r=b->l; //先把3接到2的右
b->l=a; //才能更新4的左
a=b; //更新root節點為4
a->update_son(), a->l->update_son(); //更新總節點數
}
```
### 右旋
把中間的key更新成左子節點的key
:::spoiler 示意圖



:::
```cpp=
void right_rotate(Node*& a){
Node* b=a->l;
a->l=b->r, b->r=a, a=b;
a->update_son(), a->r->update_son();
}
```
經過以上操作後,可以在維持BST的前提下,更改key的順序讓整棵Treap符合Heap的性質。
```cpp=
void insert(Node*& node, int val){
if(!node){ //如果到了葉節點就插入val
node=new Node(val);
return;
}
if(val == node->val){ //如果已有val,就把數量加一
node->cnt++;
}else if(val < node->val){ //往左遞迴
insert(node->l, val);
if(node->l->key < node->key) right_rotate(node); //左子節點比較小就右璇
}else{ //往右遞迴
insert(node->r, val);
if(node->r->key < node->key) left_rotate(node); //右子節點比較小就左旋
}
node->update_son(); //更新節點數
}
```
## 刪除
在刪除的時候要先把要刪除的節點移到葉節點才能刪除,移動的方法就是上述的左璇和右璇。
```cpp=
void del(Node*& node, int val){
if(val < node->val) del(node->l, val);
else if( node->val < val) del(node->r, val);
else{ //當前節點是val
if(node->cnt > 1){ //如果數量>1還不用移除節點,把數量-1即可
node->cnt--;
node->update_son();
return;
}
if(node->l == NULL && node->r == NULL){ //移到葉節點就可以刪除了
delete node;
node=NULL;
return;
}
//否則要往下移到葉節點
if(node->l == NULL){ //只有右邊可以拿來移
left_rotate(node); //把右子節點提上來代替要刪除的節點
del(node->l,val); //左旋完中間節點會變成在左子節點,繼續遞迴
}else if(node->r == NULL){ //只有左邊可以拿來移
right_rotate(node);
del(node->r,val);
}else{ //兩邊都有節點可以移
if(node->l->key < node->r->key){ //選擇較小的來當root,效益最大
right_rotate(node);
del(node->r,val);
}else{
left_rotate(node);
del(node->l,val);
}
}
}
node->update_son(); //移完要記得更新節點總數
}
```
## 查詢一個數的排名(第幾小)
如果中間比我小就往左遞迴,如果中間比我大就往右遞迴,且要加上左邊及中間的個數,如果找到了就回傳左邊的個數。
```cpp=
int val_to_rank(Node *node, int val){
if(val == node->val) return size(node->l) + 1; //加1是加上我自己
else if(val < node->val) return val_to_rank(node->l,val);
else return size(node->l) + node->cnt + val_to_rank(node->r,val);
}
```
## 查詢排名為多少的數
如果要查詢的排名在左邊就往左遞迴,如果在中間就回傳該數,如果在右邊就往右遞迴,並且減掉左邊及中間的個數。
```cpp=
int rank_to_val(Node *node, int rank){
if(rank <= size(node->l)) return rank_to_val(node->l,rank);
else if(rank <= size(node->l) + node->cnt) return node->val;
else return rank_to_val(node->r,rank - size(node->l) - node->cnt);
}
```
## 查詢比一個數小最大的數
在不斷的往左往或右中,與答案的距離會越來越小(BST性質),答案就是在最後一個往右時的節點。
```cpp=
int pref(Node* p, int x){
int pre=0;
while(p){
if(x <= p->val) p=p->l;
else pre=p->val, p=p->r;
}
return pre;
}
```
## 查詢比一個數大最小的數
同理,答案會在最後一個往左時的節點。
```cpp=
int suff(Node* p, int x){
int suf=0;
while(p){
if(x >= p->val) p=p->r;
else suf=p->val,p=p->l;
}
return suf;
}
```
::: spoiler 完整Code
```cpp=
#include <bits/stdc++.h>
using namespace std;
const int inf=1e9;
typedef struct Node{
int val, size, cnt, w;
struct Node *l, *r;
Node(int x){
val=x,size=cnt=1,w=rand();
l=r=NULL;
}
void update_son(){
size=cnt;
if(l != NULL) size+=l->size;
if(r != NULL) size+=r->size;
}
}Node;
Node *built(){
Node* rt=new Node(inf);
rt->l=new Node(-inf);
rt->w=-inf, rt->l->w=-inf;
rt->update_son();
return rt;
}
Node *root=built();
void left_rotate(Node*& a){
Node* b=a->r;
a->r=b->l, b->l=a, a=b;
a->update_son(), a->l->update_son();
}
void right_rotate(Node*& a){
Node* b=a->l;
a->l=b->r, b->r=a, a=b;
a->update_son(), a->r->update_son();
}
int size(Node *node){
return node ? node->size : 0;
}
void insert(Node*& node, int val){
if(!node){
node=new Node(val);
return;
}
if(val == node->val){
node->cnt++;
}else if(val < node->val){
insert(node->l, val);
if(node->l->w < node->w) right_rotate(node);
}else{
insert(node->r, val);
if(node->r->w < node->w) left_rotate(node);
}
node->update_son();
}
void del(Node*& node, int val){
if(val < node->val) del(node->l, val);
else if( node->val < val) del(node->r, val);
else{
if(node->cnt > 1){
node->cnt--;
node->update_son();
return;
}
if(node->l == NULL && node->r == NULL){
delete node;
node=NULL;
return;
}
if(node->l == NULL){
left_rotate(node);
del(node->l,val);
}else if(node->r == NULL){
right_rotate(node);
del(node->r,val);
}else{
if(node->l->w < node->r->w){
right_rotate(node);
del(node->r,val);
}else{
left_rotate(node);
del(node->l,val);
}
}
}
node->update_son();
}
int val_to_rank(Node *node, int val){
if(val == node->val) return size(node->l) + 1;
else if(val < node->val) return val_to_rank(node->l,val);
else return size(node->l) + node->cnt + val_to_rank(node->r,val);
}
int rank_to_val(Node *node, int rank){
if(rank <= size(node->l)) return rank_to_val(node->l,rank);
else if(rank <= size(node->l) + node->cnt) return node->val;
else return rank_to_val(node->r,rank - size(node->l) - node->cnt);
}
int pref(Node* p, int val){
int pre=0;
while(p){
if(val <= p->val) p=p->l;
else pre=p->val,p=p->r;
}
return pre;
}
int suff(Node* p, int val){
int suf=0;
while(p){
if(val >= p->val) p=p->r;
else suf=p->val,p=p->l;
}
return suf;
}
int main(){
int n,x,y;
cin >> n;
while(n--){
cin >> x >> y;
if(x==1){
insert(root,y);
}else if(x==2){
del(root,y);
}else if(x==3){
cout << val_to_rank(root,y)-1 << '\n';
}else if(x==4){
cout << rank_to_val(root,y+1) << '\n';
}else if(x==5){
cout << pref(root,y) << '\n';
}else{
cout << suff(root,y) << '\n';
}
}
}
```
:::