# Treap Treap = Tree + Heap 1. Tree : Tree其實是Binary Search Tree,具備二元搜尋的功能 2. Heap : 所有父節點都 < 子節點,用意是平衡樹的高度 支援的操作 : ``` 插入一個數 刪除一個數 查詢一個數的排名 查詢排名為多少的數 查詢小於一個數最大的數 查詢大於一個數最小的數 ``` 這些操作BST都可以做到,只是如果只用BST最壞的複雜度可能到O(n),所以我們要加上Heap的操作來讓整棵樹維持複雜度期望值為O(log_n),以下就讓我們來講解Treap的操作。 ## 構造 ![](https://i.imgur.com/AFleVAj.jpg) Treap的每一個節點會記錄兩個值,一個是BST用的val(<font color="blue">藍</font>),一個是Heap用的key(<font color="red">紅</font>)。 **val&nbsp; : 左節點 > 右節點** **key : 子節點 > 父節點** 在初始化時,可以先建立兩個節點,-inf跟inf,這樣在一開始新增節點的時候會比較方便,但在查詢排名的時候要注意-inf應該要是第0名。 ```cpp= typedef struct Node{ int val, key, size, cnt; //size為以當前節點為跟的樹的節點總數,cnt為val的數量 struct Node *l, *r; Node(int x){ //初始化 val=x, size=cnt=1, key=rand(); l=r=NULL; } void update_son(){ //更新子樹節點函數 size=cnt; if(l != NULL) size+=l->size; if(r != NULL) size+=r->size; } }Node; Node *built(){ Node* rt=new Node(inf); rt->l=new Node(-inf); //-inf < inf,所以在左邊 rt->key=-inf, rt->l->key=-inf; //讓key相同就不會和heap性質衝突 rt->update_son(); //更新節點數 return rt; } Node *root=built(); ``` ## 插入 首先跟BST的插入一樣,如果要插入的值小於中間的值就往左遞迴,如果大於中間的值就往右遞迴,否則就插入節點。 雖然當前滿足了BST的條件,但還不一定滿足Heap的性質,如果有任意子節點 < 父節點,那就要利用左旋和右璇來維護key。 把插入的節點往插入的路徑上推 : 如果左子樹的key比中間小就右璇,如果右子樹的key比中間小就左旋,讓較小的key往上提。 ### 左旋 把中間的key更新成右子節點的key :::spoiler 示意圖 ![](https://i.imgur.com/vm2SlR1.jpg) ![](https://i.imgur.com/8YqlaS7.jpg) ![](https://i.imgur.com/gfQ66az.jpg) ::: 3可以更新到2是因為2<3<4,3可以是4的左子節點也可以是2的右子節點,因為4更新後多了一個節點2,2更新後少了一個節點4,因此多的3要改接在2的右子節點。 ```cpp= void left_rotate(Node*& a){ //要取本尊才更新的到 Node* b=a->r; //a=2, b=4 a->r=b->l; //先把3接到2的右 b->l=a; //才能更新4的左 a=b; //更新root節點為4 a->update_son(), a->l->update_son(); //更新總節點數 } ``` ### 右旋 把中間的key更新成左子節點的key :::spoiler 示意圖 ![](https://i.imgur.com/FjyNKks.jpg) ![](https://i.imgur.com/9eNKfAG.jpg) ![](https://i.imgur.com/AHjSJXG.jpg) ::: ```cpp= void right_rotate(Node*& a){ Node* b=a->l; a->l=b->r, b->r=a, a=b; a->update_son(), a->r->update_son(); } ``` 經過以上操作後,可以在維持BST的前提下,更改key的順序讓整棵Treap符合Heap的性質。 ```cpp= void insert(Node*& node, int val){ if(!node){ //如果到了葉節點就插入val node=new Node(val); return; } if(val == node->val){ //如果已有val,就把數量加一 node->cnt++; }else if(val < node->val){ //往左遞迴 insert(node->l, val); if(node->l->key < node->key) right_rotate(node); //左子節點比較小就右璇 }else{ //往右遞迴 insert(node->r, val); if(node->r->key < node->key) left_rotate(node); //右子節點比較小就左旋 } node->update_son(); //更新節點數 } ``` ## 刪除 在刪除的時候要先把要刪除的節點移到葉節點才能刪除,移動的方法就是上述的左璇和右璇。 ```cpp= void del(Node*& node, int val){ if(val < node->val) del(node->l, val); else if( node->val < val) del(node->r, val); else{ //當前節點是val if(node->cnt > 1){ //如果數量>1還不用移除節點,把數量-1即可 node->cnt--; node->update_son(); return; } if(node->l == NULL && node->r == NULL){ //移到葉節點就可以刪除了 delete node; node=NULL; return; } //否則要往下移到葉節點 if(node->l == NULL){ //只有右邊可以拿來移 left_rotate(node); //把右子節點提上來代替要刪除的節點 del(node->l,val); //左旋完中間節點會變成在左子節點,繼續遞迴 }else if(node->r == NULL){ //只有左邊可以拿來移 right_rotate(node); del(node->r,val); }else{ //兩邊都有節點可以移 if(node->l->key < node->r->key){ //選擇較小的來當root,效益最大 right_rotate(node); del(node->r,val); }else{ left_rotate(node); del(node->l,val); } } } node->update_son(); //移完要記得更新節點總數 } ``` ## 查詢一個數的排名(第幾小) 如果中間比我小就往左遞迴,如果中間比我大就往右遞迴,且要加上左邊及中間的個數,如果找到了就回傳左邊的個數。 ```cpp= int val_to_rank(Node *node, int val){ if(val == node->val) return size(node->l) + 1; //加1是加上我自己 else if(val < node->val) return val_to_rank(node->l,val); else return size(node->l) + node->cnt + val_to_rank(node->r,val); } ``` ## 查詢排名為多少的數 如果要查詢的排名在左邊就往左遞迴,如果在中間就回傳該數,如果在右邊就往右遞迴,並且減掉左邊及中間的個數。 ```cpp= int rank_to_val(Node *node, int rank){ if(rank <= size(node->l)) return rank_to_val(node->l,rank); else if(rank <= size(node->l) + node->cnt) return node->val; else return rank_to_val(node->r,rank - size(node->l) - node->cnt); } ``` ## 查詢比一個數小最大的數 在不斷的往左往或右中,與答案的距離會越來越小(BST性質),答案就是在最後一個往右時的節點。 ```cpp= int pref(Node* p, int x){ int pre=0; while(p){ if(x <= p->val) p=p->l; else pre=p->val, p=p->r; } return pre; } ``` ## 查詢比一個數大最小的數 同理,答案會在最後一個往左時的節點。 ```cpp= int suff(Node* p, int x){ int suf=0; while(p){ if(x >= p->val) p=p->r; else suf=p->val,p=p->l; } return suf; } ``` ::: spoiler 完整Code ```cpp= #include <bits/stdc++.h> using namespace std; const int inf=1e9; typedef struct Node{ int val, size, cnt, w; struct Node *l, *r; Node(int x){ val=x,size=cnt=1,w=rand(); l=r=NULL; } void update_son(){ size=cnt; if(l != NULL) size+=l->size; if(r != NULL) size+=r->size; } }Node; Node *built(){ Node* rt=new Node(inf); rt->l=new Node(-inf); rt->w=-inf, rt->l->w=-inf; rt->update_son(); return rt; } Node *root=built(); void left_rotate(Node*& a){ Node* b=a->r; a->r=b->l, b->l=a, a=b; a->update_son(), a->l->update_son(); } void right_rotate(Node*& a){ Node* b=a->l; a->l=b->r, b->r=a, a=b; a->update_son(), a->r->update_son(); } int size(Node *node){ return node ? node->size : 0; } void insert(Node*& node, int val){ if(!node){ node=new Node(val); return; } if(val == node->val){ node->cnt++; }else if(val < node->val){ insert(node->l, val); if(node->l->w < node->w) right_rotate(node); }else{ insert(node->r, val); if(node->r->w < node->w) left_rotate(node); } node->update_son(); } void del(Node*& node, int val){ if(val < node->val) del(node->l, val); else if( node->val < val) del(node->r, val); else{ if(node->cnt > 1){ node->cnt--; node->update_son(); return; } if(node->l == NULL && node->r == NULL){ delete node; node=NULL; return; } if(node->l == NULL){ left_rotate(node); del(node->l,val); }else if(node->r == NULL){ right_rotate(node); del(node->r,val); }else{ if(node->l->w < node->r->w){ right_rotate(node); del(node->r,val); }else{ left_rotate(node); del(node->l,val); } } } node->update_son(); } int val_to_rank(Node *node, int val){ if(val == node->val) return size(node->l) + 1; else if(val < node->val) return val_to_rank(node->l,val); else return size(node->l) + node->cnt + val_to_rank(node->r,val); } int rank_to_val(Node *node, int rank){ if(rank <= size(node->l)) return rank_to_val(node->l,rank); else if(rank <= size(node->l) + node->cnt) return node->val; else return rank_to_val(node->r,rank - size(node->l) - node->cnt); } int pref(Node* p, int val){ int pre=0; while(p){ if(val <= p->val) p=p->l; else pre=p->val,p=p->r; } return pre; } int suff(Node* p, int val){ int suf=0; while(p){ if(val >= p->val) p=p->r; else suf=p->val,p=p->l; } return suf; } int main(){ int n,x,y; cin >> n; while(n--){ cin >> x >> y; if(x==1){ insert(root,y); }else if(x==2){ del(root,y); }else if(x==3){ cout << val_to_rank(root,y)-1 << '\n'; }else if(x==4){ cout << rank_to_val(root,y+1) << '\n'; }else if(x==5){ cout << pref(root,y) << '\n'; }else{ cout << suff(root,y) << '\n'; } } } ``` :::