Skip to content

Instantly share code, notes, and snippets.

@cezarguimaraes
Last active February 12, 2017 12:39
Show Gist options
  • Save cezarguimaraes/5992d25027d25bc12890a119f7bff6d3 to your computer and use it in GitHub Desktop.
Save cezarguimaraes/5992d25027d25bc12890a119f7bff6d3 to your computer and use it in GitHub Desktop.
Splay Tree
#include <bits/stdc++.h>
using namespace std;
// Splay-tree node. l / r / p are raw child and parent links (nothing in this
// file frees nodes through them); s caches the number of nodes in the subtree
// rooted here and is kept current by upd().
struct node {
    node *l = nullptr, *r = nullptr, *p = nullptr;
    // Initialised so a default-constructed node never carries an indeterminate key.
    int key = 0;
    // A freshly created, unlinked node is a one-node subtree, so its size is 1.
    int s = 1;
};
// Null-safe accessors: each yields nullptr (or 0 for size) when handed a null
// node, so tree-walking code can chain them without explicit null checks.
node *left(node *t) {
    if (t == nullptr) return nullptr;
    return t->l;
}
node *right(node *t) {
    if (t == nullptr) return nullptr;
    return t->r;
}
node *p(node *t) {
    if (t == nullptr) return nullptr;
    return t->p;
}
// Grandparent of t; nullptr if t or its parent is null.
node *g(node *t) {
    node *parent = p(t);
    return p(parent);
}
// Cached subtree size at t; an empty subtree has size 0.
int size(node *t) {
    if (t == nullptr) return 0;
    return t->s;
}
// Recompute t's cached subtree size from its children's cached sizes.
// Does nothing when t is nullptr.
void upd(node *t) {
    if (t == nullptr) return;
    const int below = size(left(t)) + size(right(t));
    t->s = below + 1;
}
// Install c (possibly nullptr) as t's left child, point c back at t, and
// refresh t's size. t must be non-null. The previous left child of t is
// simply overwritten; its own parent pointer is not touched here.
void link_left(node *t, node *c) {
    if (c != nullptr) c->p = t;
    t->l = c;
    upd(t);
}
// Install c (possibly nullptr) as t's right child, point c back at t, and
// refresh t's size. t must be non-null. The previous right child of t is
// simply overwritten; its own parent pointer is not touched here.
void link_right(node *t, node *c) {
    if (c != nullptr) c->p = t;
    t->r = c;
    upd(t);
}
// Left rotation at t: t's right child (pivot) takes t's place under t's former
// parent, t becomes pivot's left child, and pivot's old left subtree becomes
// t's right subtree. Sizes are refreshed bottom-up (t, then pivot, then the
// parent). Requires t and right(t) to be non-null.
void rotate_left(node *t) {
    node *const pivot = right(t);
    node *const parent = p(t);
    // Which side of parent t hangs on; the links below never modify parent's
    // child pointers, so recording this first is equivalent.
    const bool was_left = (t == left(parent));
    const bool was_right = !was_left && (t == right(parent));
    link_right(t, left(pivot));
    link_left(pivot, t);
    if (was_left) {
        link_left(parent, pivot);
    } else if (was_right) {
        link_right(parent, pivot);
    } else {
        pivot->p = nullptr;  // t was a root, so pivot is the new root
    }
}
// Right rotation at t: t's left child (pivot) takes t's place under t's former
// parent, t becomes pivot's right child, and pivot's old right subtree becomes
// t's left subtree. Sizes are refreshed bottom-up (t, then pivot, then the
// parent). Requires t and left(t) to be non-null.
void rotate_right(node *t) {
    node *const pivot = left(t);
    node *const parent = p(t);
    // Which side of parent t hangs on; the links below never modify parent's
    // child pointers, so recording this first is equivalent.
    const bool was_left = (t == left(parent));
    const bool was_right = !was_left && (t == right(parent));
    link_left(t, right(pivot));
    link_right(pivot, t);
    if (was_left) {
        link_left(parent, pivot);
    } else if (was_right) {
        link_right(parent, pivot);
    } else {
        pivot->p = nullptr;  // t was a root, so pivot is the new root
    }
}
// Splay t to the root of its tree with zig / zig-zig / zig-zag rotations and
// return t. Each rotation goes through link_left/link_right, which refresh the
// cached subtree sizes of every node whose children change.
// A null t is returned as-is (there is nothing to splay) instead of being
// dereferenced.
node *splay(node *t) {
    if (!t) return nullptr;
    while (node *par = p(t)) {
        node *gp = p(par);
        const bool t_is_left = (t == left(par));
        if (!gp) {
            // zig: par is the root, so one rotation finishes the splay.
            if (t_is_left) rotate_right(par);
            else rotate_left(par);
            continue;
        }
        const bool par_is_left = (par == left(gp));
        if (t_is_left == par_is_left) {
            // zig-zig: rotate the grandparent first, then the parent.
            if (t_is_left) {
                rotate_right(gp);
                rotate_right(par);
            } else {
                rotate_left(gp);
                rotate_left(par);
            }
        } else {
            // zig-zag: lift t over its parent, then over its grandparent.
            if (t_is_left) {
                rotate_right(par);
                rotate_left(gp);
            } else {
                rotate_left(par);
                rotate_right(gp);
            }
        }
    }
    return t;
}
// Search the tree rooted at t for key. The node holding key, or, if key is
// absent, the last node visited on the search path, is splayed to the root
// and returned. Returns nullptr only when t is nullptr.
// Written iteratively: splay trees can become path-shaped (e.g. after
// inserting keys in sorted order), and a recursive descent could then
// overflow the call stack.
node *find(node *t, int key) {
    node *last = nullptr;
    while (t) {
        last = t;
        if (key == t->key) break;
        t = (key < t->key) ? left(t) : right(t);
    }
    return last ? splay(last) : nullptr;
}
// Concatenate two trees, assuming (not checking) that every key in t1 is <=
// every key in t2: t1's maximum is splayed to the top and t2 becomes its
// right subtree. Returns the new root, or nullptr if both are empty.
// Both arguments are treated as roots. Their parent links are cleared first,
// so passing the two children of a node that is being removed cannot make
// splay climb back into that node, and the returned root never carries a
// stale parent.
node *join(node *t1, node *t2) {
    if (t1) t1->p = nullptr;
    if (t2) t2->p = nullptr;
    if (!t1) return t2;
    if (!t2) return t1;
    node *t = t1;
    while (right(t)) t = right(t);  // maximum of t1
    splay(t);                        // t is now t1's root with no right child
    link_right(t, t2);
    return t;
}
pair<node *, node *> split(node *t, int key) {
if (!t) return make_pair(nullptr, nullptr);
t = find(t, key);
if (t->key <= key) {
node *l = t, *r = t->r;
link_right(l, nullptr);
if (r) r->p = nullptr;
return make_pair(l, r);
} else {
node *l = t->l, *r = t;
link_left(r, nullptr);
if (l) l->p = nullptr;
return make_pair(l, r);
}
}
// Insert node x into the tree rooted at t and return x, which becomes the new
// root. Keys <= x->key go to x's left subtree and keys > x->key to its right
// subtree, so duplicate keys are accepted.
// x's parent link is cleared because x is the new root; this matters when a
// previously linked node is reused. A null x leaves the tree unchanged and t
// is returned.
node *insert(node *t, node *x) {
    if (!x) return t;
    node *l, *r;
    tie(l, r) = split(t, x->key);
    x->p = nullptr;
    link_left(x, l);
    link_right(x, r);
    return x;
}
// Remove one node whose key equals key from the tree rooted at t and return
// the new root. If key is absent, the tree is still splayed by find and its
// new root is returned unchanged otherwise.
// The removed node is unlinked completely: its children's parent links and
// its own child links are cleared. Otherwise, when one side is empty, the
// surviving subtree's root would still point at the removed node, and a later
// splay would re-attach it (the deleted key would reappear).
// NOTE(review): the removed node is not freed. This file allocates nodes with
// `new` in main but never deletes them; settle ownership before adding a delete.
node *erase(node *t, int key) {
    t = find(t, key);
    if (!t || t->key != key) return t;
    node *lo = left(t), *hi = right(t);
    if (lo) lo->p = nullptr;
    if (hi) hi->p = nullptr;
    t->l = t->r = nullptr;
    upd(t);
    return join(lo, hi);
}
// Print the keys of the subtree rooted at t in in-order (ascending for a valid
// BST), each followed by one space, with no trailing newline.
// Uses an explicit stack instead of recursion so that a path-shaped tree,
// which this structure can produce, cannot overflow the call stack.
void print(node *t) {
    std::vector<node *> pending;
    node *cur = t;
    while (cur || !pending.empty()) {
        while (cur) {  // descend to the leftmost unvisited node
            pending.push_back(cur);
            cur = left(cur);
        }
        cur = pending.back();
        pending.pop_back();
        printf("%d ", cur->key);
        cur = right(cur);
    }
}
int main() {
node *root = nullptr;
int n, m;
scanf("%d %d", &n, &m);
while (m--) {
char op;
int x;
scanf(" %c %d", &op, &x);
node *t = root, *last = nullptr;
int past = 0, rx = x;
while (t) {
last = t;
int hl = t->key - size(left(t)) - past - 1;
if (x <= hl) {
t = left(t);
} else {
past += size(left(t)) + 1;
rx = t->key + x - hl;
t = right(t);
}
}
if (last) root = splay(last);
if (op == 'L') printf("%d\n", rx);
else {
node *t = new node();
t->key = rx;
root = insert(root, t);
}
}
return 0;
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment