Treap是平衡搜索树的一种。所谓 “Treap”,即 “Tree” + “Heap”, 顾名思义,是使用堆方法对搜索树进行平衡的一种数据结构。

约定该篇文章讨论的二叉搜索树都遵守 “比当前节点值大的节点在右子树,小于等于当前节点值的节点在其左子树” 这一规定。

Treap维护平衡的方式

Treap的目的主要是利用堆的性质来平衡原搜索树。因为堆是一棵完全二叉树,深度最优嘛。

对于一棵普通的二叉搜索树,我们对其每个节点再随机赋上一个优先级权值 $p$. 即对任意节点,多了一个变量 $p=random()$.

例如,如果我们以小根堆的规则来约束这个平衡树,那么插入和删除两个主要操作就会变成这样:

插入:先按照普通平衡树的方式将元素插入到合适位置。然后我们需要向上回溯,在回溯的过程中检查这些节点关于 $p$ 是否符合小根堆的性质。如果当前节点的 $p$ 大于左右子节点 $p$ 的最小值,那么就不符合小根堆的性质。这个时候,如果左子节点的 $p$ 更小,我们就需要左旋,反之需要右旋,以此保证平衡。

至于旋转的具体操作,即:

以右旋为例:

  1. 将当前节点下放至其左子节点的位置
  2. 右子节点的左子节点 称为当前节点的新的右子节点
  3. 右子节点的左子节点变为当前节点,取代当前节点的位置

删除操作当然是同理了,回溯的时候观察 $p$ 来维护平衡。由于 $p$ 值是随机的, Treap也不是严格的堆结构,所以它是一种期望平衡的弱平衡的平衡树,搜索、插入和删除的期望时间复杂度为 $O(\log n)$.

Treap的优势在于其实现起来非常简单,包含的也就上面这些东西了。

模板部分

定义

1
2
3
4
5
6
int cnt; // cnt 表示当前treap内节点总数,用于动态开点
const int maxn = XXXXX;
struct node {
// 分别表示左右子节点、所在子树的大小、当前节点值以及p值
int l, r, size, val, p;
}a[maxn];

回溯统计

1
2
3
void pushup(int k) {
a[k].size = a[a[k].l].size + a[a[k].r].size + 1;
}

旋转

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
// 右旋
void rrotate(int &k) {
int tmp = a[k].l;
a[k].l = a[tmp].r;
a[tmp].r = k;
a[tmp].size = a[k].size;
pushup(k);
k = tmp;
}
// 左旋
void lrotate(int &k) {
int tmp = a[k].r;
a[k].r = a[tmp].l;
a[tmp].l = k;
a[tmp].size = a[k].size;
pushup(k);
k = tmp;
}

插入

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
void ins(int &k, int val) {
if (k == 0) {
k = ++cnt;
a[k].val = val;
a[k].p = rand();
a[k].size = 1;
return;
}
a[k].size++;
if (val >= a[k].val) {
ins(a[k].r, val);
}
else ins(a[k].l, val);
if (a[k].l && a[k].p > a[a[k].l].p) {
rrotate(k);
}
if (a[k].r && a[k].p > a[a[k].r].p) {
lrotate(k);
}
pushup(k);
}

删除

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
void del(int &k, int val) {
a[k].size--;
if (a[k].val == val) {
if (!a[k].l && !a[k].r) {
k = 0;
return;
}
if (!a[k].l || !a[k].r) {
k = a[k].l + a[k].r;
return;
}
if (a[a[k].l].p < a[a[k].r].p) {
rrotate(k);
del(a[k].r, val); return;
}
else {
lrotate(k);
del(a[k].l, val); return;
}
}
if (a[k].val >= val) {
del(a[k].l, val);
}
else {
del(a[k].r, val);
}
pushup(k);
}

例题分割线~

例题:

P3369 【模板】普通平衡树

其中还要维护查询排名,查询前驱后继等搜索操作。

当普通平衡树维护就可以啦,比较简单。

完整代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
#include <bits/stdc++.h>
using namespace std;
const int maxn = 1e5 + 10;
int cnt;
int root;
// 注意按p是小根堆
struct node {
int l, r, size, val, p;
}a[maxn];
void pushup(int k) {
a[k].size = 1 + a[a[k].l].size + a[a[k].r].size;
}
void rrotate(int &k) {
int tmp = a[k].l;
a[k].l = a[tmp].r;
a[tmp].r = k;
a[tmp].size = a[k].size;
pushup(k);
k = tmp;
}
void lrotate(int &k) {
int tmp = a[k].r;
a[k].r = a[tmp].l;
a[tmp].l = k;
a[tmp].size = a[k].size;
pushup(k);
k = tmp;
}
void ins(int &k, int val) {
if (k == 0) {
k = ++cnt;
a[k].val = val;
a[k].p = rand();
a[k].size = 1;
return;
}
a[k].size++;
if (val >= a[k].val) {
ins(a[k].r, val);
}
else ins(a[k].l, val);
if (a[k].l && a[k].p > a[a[k].l].p) {
rrotate(k);
}
if (a[k].r && a[k].p > a[a[k].r].p) {
lrotate(k);
}
pushup(k);
}
void del(int &k, int val) {
a[k].size--;
if (a[k].val == val) {
if (!a[k].l && !a[k].r) {
k = 0;
return;
}
if (!a[k].l || !a[k].r) {
k = a[k].l + a[k].r;
return;
}
if (a[a[k].l].p < a[a[k].r].p) {
rrotate(k);
del(a[k].r, val); return;
}
else {
lrotate(k);
del(a[k].l, val); return;
}
}
if (a[k].val >= val) {
del(a[k].l, val);
}
else {
del(a[k].r, val);
}
pushup(k);
}
int rk(int k, int val) {
if (!k) return 0;
if (val > a[k].val) {
return a[a[k].l].size + rk(a[k].r, val) + 1;
}
return rk(a[k].l, val);
}
int find(int k, int rnk) {
if (rnk == a[a[k].l].size + 1) return a[k].val;
if (rnk > a[a[k].l].size + 1) return find(a[k].r, rnk - a[a[k].l].size - 1);
return find(a[k].l, rnk);
}
int query_pre(int k, int val) {
if (!k) return 0;
if (a[k].val >= val) {
return query_pre(a[k].l, val);
}
int tmp = query_pre(a[k].r, val);
if (!tmp) return a[k].val;
return tmp;
}
int query_suf(int k, int val) {
if (!k) return 0;
if (a[k].val <= val) {
return query_suf(a[k].r, val);
}
int tmp = query_suf(a[k].l, val);
if (!tmp) return a[k].val;
return tmp;
}
int main()
{
ios::sync_with_stdio(0);
cin.tie(0);
int n;
cin >> n;
int opt, x;
for (int i = 1; i <= n; ++i) {
cin >> opt >> x;
if (opt == 1) {
ins(root, x);
}
else if (opt == 2) {
del(root, x);
}
else if (opt == 3) {
cout << rk(root, x) + 1 << endl;
}
else if (opt == 4) {
cout << find(root, x) << endl;
}
else if (opt == 5) {
cout << query_pre(root, x) << endl;
}
else cout << query_suf(root, x) << endl;
}
return 0;
}