0%

Treap树

Treap树的定义

Treap树的名字即为Tree和Heap的结合,它既是一颗二叉查找树(BST),又满足堆的性质(通过左旋右旋不改变Treap的二叉查找树性质而又尽可能的平衡)

Treap树的结构

1
2
3
4
5
6
7
8
9
10
11
struct node{
int l,r,v,siz,rnd,cnt;
}p[manx];
int cnt;
//l:左节点id
//r:右节点id
//v:节点权重
//siz:子树大小
//p[x].cnt:v的出现次数,这样权重相同的节点作为一个点来处理
//cnt:总共的节点个数
//node节点中可以加入其他需要维护的信息

节点更新

每当节点关系更改时,需要更新节点信息,包括

  • 节点的子树大小
  • 维护的其他内容
    1
    2
    3
    inline void update(int id){
    p[id].siz=p[p[id].l].siz+p[p[id].r].siz+p[id].cnt;
    }

左旋和右旋

左旋和右旋是Treap树的核心操作,每个节点在插入时都被赋予了随机的优先级$rnd$,通过保证优先级的堆性质,进而保证了Treap树的期望高度是$\log n$,而不至于退化为链,是Treap树尽可能保持平衡的关键

  • 左旋和右旋不改变节点的二叉树查找树性质
  • 该操作保证$rnd$的小顶堆性质
  • 右旋: 左儿子的$rnd$小于父节点时,左儿子的右儿子$\Rightarrow$父节点的左儿子,父节点$\Rightarrow$左儿子的新右儿子
  • 左旋: 右儿子的$rnd$小于父节点时,右儿子的左儿子$\Rightarrow$父节点的右儿子,父节点$\Rightarrow$右儿子的新左儿子
    Treap树旋转图解
  • 右旋代码如下
    1
    2
    3
    4
    5
    6
    7
    8
    void rturn(int &k){
    int tmp=p[k].l;//记录左儿子
    p[k].l=p[tmp].r;//左儿子的右儿子变成父节点的左儿子
    p[tmp].r=k;//父节点变成左儿子的新右儿子
    p[tmp].siz=p[k].siz;//整个子树的大小不变
    update(k);//新右节点更新
    k=tmp;//左儿子变成了新父节点,完成了右旋上位
    }
  • 左旋同理

节点插入和删除

插入

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
void insert(int &id,int val){
if(!id){
//如果到叶子节点了
id=++cnt;
p[id].siz=1;
p[id].cnt=1;
p[id].val=val;
p[id].rnd=rand();//随机优先级
return;
}
p[id].siz++;
if(p[id].val==val)
p[id].cnt++;
else if(p[id].val<val){
insert(p[id].r,val);
if(p[p[id].r].rnd<p[id].rnd)
lturn(id);
}else{
insert(p[id].l,val);
if(p[p[id].l].rnd<p[id].rnd)
rturn(id);
}
}

删除

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
void del(int &id,int val){
if(!id)
return;
if(p[id].val==val){
if(p[id].cnt>1){
p[id].cnt--;
p[id].siz--;
}else{
if(p[id].l==0||p[id].r==0)
id=p[id].l+p[id].r;//如果该节点最多只有一个子节点,那么子节点接上或者直接删除该节点
else if(p[p[id].l].rnd<p[p[id].r].rnd){
rturn(id);
del(id,val);
}else{
lturn(id);
del(id,val);
}
}
}else if(val>p[id].val){
p[id].siz--;
del(p[id].r,val);
}else{
p[id].siz--;
del(p[id].l,val);
}
}

查询val的排名

1
2
3
4
5
6
7
8
9
10
int find_rank(int id,int val){
if(id==0)
return 0;
if(p[id].val==val)
return p[p[id].l].siz+1;
else if(p[id].val<val)
return p[p[id].l].siz+p[id].cnt+find_rank(p[id].r,val);
else
return find_rank(p[id].l,val);
}

查询排名为k的数

  • 如果k小于父节点左子树的大小,那么这个数一定在左子树里面,就到左子树里查找排名为$k$的数
  • 如果k大于左子树与父节点的大小之和,那么到右子树去找排名为$k-(p[l].siz+cnt)$的数
  • 如果既不在左子树也不右子树,父节点就是我们要找的数
    1
    2
    3
    4
    5
    6
    7
    8
    9
    int kth(int id,int k)
    {
    if(k<=p[p[id].l].siz)
    return kth(p[id].l,k);
    else if(k>p[p[id].l].siz+p[id].cnt)
    return kth(p[id].r,k-(p[p[id].l].siz+p[id].cnt));
    else
    return p[id].val;
    }

查询val的前驱

1
2
3
4
5
6
7
8
9
const int INF=0x3f3f3f3f;
int pre(int id,int val){
if(!id)
return -INF;
if(p[id].val>=val)
return pre(p[id].l,val);
else
return max(p[id].val,pre(p[id].r,val));
}

查询val的后继

1
2
3
4
5
6
7
8
9
const int INF=0x3f3f3f3f;
int nxt(int id,int val){
if(!id)
return INF;
if(p[id].val<=val)
return nxt(p[id].r,val);
else
return min(p[id].val,nxt(p[id].l,val));
}

完整模板

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
const int INF = 0x3f3f3f3f;
struct node {
int l, r, val, siz, rnd, cnt;
} p[maxn];
int cnt;

inline void update(int id) {
p[id].siz = p[p[id].l].siz + p[p[id].r].siz + p[id].cnt;
}

void rturn(int &k) {
int tmp = p[k].l;
p[k].l = p[tmp].r;
p[tmp].r = k;
p[tmp].siz = p[k].siz;
update(k);
k = tmp;
}

void lturn(int &k) {
int tmp = p[k].r;
p[k].r = p[tmp].l;
p[tmp].l = k;
p[tmp].siz = p[k].siz;
update(k);
k = tmp;
}

void insert(int &id, int val) {
if (!id) {
id = ++cnt;
p[id].siz = 1;
p[id].cnt = 1;
p[id].val = val;
p[id].rnd = rand();
return;
}
p[id].siz++;
if (p[id].val == val)
p[id].cnt++;
else if (p[id].val < val) {
insert(p[id].r, val);
if (p[p[id].r].rnd < p[id].rnd)
lturn(id);
} else {
insert(p[id].l, val);
if (p[p[id].l].rnd < p[id].rnd)
rturn(id);
}
}

void del(int &id, int val) {
if (!id)
return;
if (p[id].val == val) {
if (p[id].cnt > 1) {
p[id].cnt--;
p[id].siz--;
} else {
if (p[id].l == 0 || p[id].r == 0)
id = p[id].l + p[id].r;
else if (p[p[id].l].rnd < p[p[id].r].rnd) {
rturn(id);
del(id, val);
} else {
lturn(id);
del(id, val);
}
}
} else if (val > p[id].val) {
p[id].siz--;
del(p[id].r, val);
} else {
p[id].siz--;
del(p[id].l, val);
}
}

int find_rank(int id, int val) {
if (id == 0)
return 0;
if (p[id].val == val)
return p[p[id].l].siz + 1;
else if (p[id].val < val)
return p[p[id].l].siz + p[id].cnt + find_rank(p[id].r, val);
else
return find_rank(p[id].l, val);
}

int kth(int id, int k) {
if (k <= p[p[id].l].siz)
return kth(p[id].l, k);
else if (k > p[p[id].l].siz + p[id].cnt)
return kth(p[id].r, k - (p[p[id].l].siz + p[id].cnt));
else
return p[id].val;
}

int pre(int id, int val) {
if (!id)
return -INF;
if (p[id].val >= val)
return pre(p[id].l, val);
else
return max(p[id].val, pre(p[id].r, val));
}

int nxt(int id, int val) {
if (!id)
return INF;
if (p[id].val <= val)
return nxt(p[id].r, val);
else
return min(p[id].val, nxt(p[id].l, val));
}