Treap是平衡搜索树的一种。所谓 “Treap”,即 “Tree” + “Heap”, 顾名思义,是使用堆方法对搜索树进行平衡的一种数据结构。
约定该篇文章讨论的二叉搜索树都遵守 “比当前节点值大的节点在右子树,小于等于当前节点值的节点在其左子树” 这一规定。
Treap维护平衡的方式
Treap的目的主要是利用堆的性质来平衡原搜索树。因为堆是一棵完全二叉树,深度最优嘛。
对于一棵普通的二叉搜索树,我们对其每个节点再随机赋上一个优先级权值 $p$. 即对任意节点,多了一个变量 $p=random()$.
例如,如果我们以小根堆的规则来约束这个平衡树,那么插入和删除两个主要操作就会变成这样:
插入:先按照普通平衡树的方式将元素插入到合适位置。然后我们需要向上回溯,在回溯的过程中检查这些节点关于 $p$ 是否符合小根堆的性质。如果当前节点的 $p$ 大于左右子节点 $p$ 的最小值,那么就不符合小根堆的性质。这个时候,如果左子节点的 $p$ 更小,我们就需要左旋,反之需要右旋,以此保证平衡。
至于旋转的具体操作,即:
以右旋为例:
- 将当前节点下放至其左子节点的位置
- 令右子节点的左子节点 称为当前节点的新的右子节点
- 右子节点的左子节点变为当前节点,取代当前节点的位置
删除操作当然是同理了,回溯的时候观察 $p$ 来维护平衡。由于 $p$ 值是随机的, Treap也不是严格的堆结构,所以它是一种期望平衡的弱平衡的平衡树,搜索、插入和删除的期望时间复杂度为 $O(\log n)$.
Treap的优势在于其实现起来非常简单,包含的也就上面这些东西了。
模板部分
定义
1 2 3 4 5 6
| int cnt; const int maxn = XXXXX; struct node { int l, r, size, val, p; }a[maxn];
|
回溯统计
1 2 3
| void pushup(int k) { a[k].size = a[a[k].l].size + a[a[k].r].size + 1; }
|
旋转
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18
| void rrotate(int &k) { int tmp = a[k].l; a[k].l = a[tmp].r; a[tmp].r = k; a[tmp].size = a[k].size; pushup(k); k = tmp; }
void lrotate(int &k) { int tmp = a[k].r; a[k].r = a[tmp].l; a[tmp].l = k; a[tmp].size = a[k].size; pushup(k); k = tmp; }
|
插入
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21
| void ins(int &k, int val) { if (k == 0) { k = ++cnt; a[k].val = val; a[k].p = rand(); a[k].size = 1; return; } a[k].size++; if (val >= a[k].val) { ins(a[k].r, val); } else ins(a[k].l, val); if (a[k].l && a[k].p > a[a[k].l].p) { rrotate(k); } if (a[k].r && a[k].p > a[a[k].r].p) { lrotate(k); } pushup(k); }
|
删除
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28
| void del(int &k, int val) { a[k].size--; if (a[k].val == val) { if (!a[k].l && !a[k].r) { k = 0; return; } if (!a[k].l || !a[k].r) { k = a[k].l + a[k].r; return; } if (a[a[k].l].p < a[a[k].r].p) { rrotate(k); del(a[k].r, val); return; } else { lrotate(k); del(a[k].l, val); return; } } if (a[k].val >= val) { del(a[k].l, val); } else { del(a[k].r, val); } pushup(k); }
|
例题分割线~
例题:
其中还要维护查询排名,查询前驱后继等搜索操作。
当普通平衡树维护就可以啦,比较简单。
完整代码:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135
| #include <bits/stdc++.h> using namespace std; const int maxn = 1e5 + 10; int cnt; int root;
struct node { int l, r, size, val, p; }a[maxn]; void pushup(int k) { a[k].size = 1 + a[a[k].l].size + a[a[k].r].size; } void rrotate(int &k) { int tmp = a[k].l; a[k].l = a[tmp].r; a[tmp].r = k; a[tmp].size = a[k].size; pushup(k); k = tmp; } void lrotate(int &k) { int tmp = a[k].r; a[k].r = a[tmp].l; a[tmp].l = k; a[tmp].size = a[k].size; pushup(k); k = tmp; } void ins(int &k, int val) { if (k == 0) { k = ++cnt; a[k].val = val; a[k].p = rand(); a[k].size = 1; return; } a[k].size++; if (val >= a[k].val) { ins(a[k].r, val); } else ins(a[k].l, val); if (a[k].l && a[k].p > a[a[k].l].p) { rrotate(k); } if (a[k].r && a[k].p > a[a[k].r].p) { lrotate(k); } pushup(k); } void del(int &k, int val) { a[k].size--; if (a[k].val == val) { if (!a[k].l && !a[k].r) { k = 0; return; } if (!a[k].l || !a[k].r) { k = a[k].l + a[k].r; return; } if (a[a[k].l].p < a[a[k].r].p) { rrotate(k); del(a[k].r, val); return; } else { lrotate(k); del(a[k].l, val); return; } } if (a[k].val >= val) { del(a[k].l, val); } else { del(a[k].r, val); } pushup(k); } int rk(int k, int val) { if (!k) return 0; if (val > a[k].val) { return a[a[k].l].size + rk(a[k].r, val) + 1; } return rk(a[k].l, val); } int find(int k, int rnk) { if (rnk == a[a[k].l].size + 1) return a[k].val; if (rnk > a[a[k].l].size + 1) return find(a[k].r, rnk - a[a[k].l].size - 1); return find(a[k].l, rnk); } int query_pre(int k, int val) { if (!k) return 0; if (a[k].val >= val) { return query_pre(a[k].l, val); } int tmp = query_pre(a[k].r, val); if (!tmp) return a[k].val; return tmp; } int query_suf(int k, int val) { if (!k) return 0; if (a[k].val <= val) { return query_suf(a[k].r, val); } int tmp = query_suf(a[k].l, val); if (!tmp) return a[k].val; return tmp; } int main() { ios::sync_with_stdio(0); cin.tie(0); int n; cin >> n; int opt, x; for (int i = 1; i <= n; ++i) { cin >> opt >> x; if (opt == 1) { ins(root, x); } else if (opt == 2) { del(root, x); } else if (opt == 3) { cout << rk(root, x) + 1 << endl; } else if (opt == 4) { cout << find(root, x) << endl; } else if (opt == 5) { cout << query_pre(root, x) << endl; } else cout << query_suf(root, x) << endl; } return 0; }
|