Oct 26, 2022

Splay Tree: One Tree to Rule Them All

Are you scared when you hear about all these pesky data structures in competitive programming, like the Fenwick Tree, RMQ and the Segment Tree?

Are you afraid of writing code to answer range queries, such as the minimum, maximum, or sum of a range, especially when the data can also be updated?

Well, fear no more! In this tutorial I will introduce the Swiss Army knife of sequence-manipulation data structures: one piece of code that can (theoretically) solve every problem of this kind, one tree to rule them all, the Splay Tree!

Disclaimer: One drawback of the Splay Tree is its somewhat large constant factor, even though its operations take amortized \( O(\log n) \) time. While good problems generally should not punish solutions for having a large constant, using it on problems with very strict time limits risks TLE (though you can always substitute a less demanding data structure if the constant is the issue).

From an Array to a Tree

The easiest way to store a sequence of data such as \( (1, 2, 4, 5, 6) \) is to put it in an array. An array is a simple and very good data structure that lets us look up any element by its index in \( O(1) \) time.

However, a problem comes up when we want to insert some data into the sequence: say we want to put \( 3 \) between \( 2 \) and \( 4 \). An array has no good way to perform this operation other than shifting every element after \( 2 \), i.e. \( 4, 5, 6 \), one slot to the right, which takes \( O(n) \) time for \( n \) elements.
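To make the cost concrete, here is a small sketch of our own (the function name `insert_at` is hypothetical) that performs the shift by hand and reports how many elements had to move. Inserting \( 3 \) at index 2 of \( (1, 2, 4, 5, 6) \) moves exactly three elements. `std::vector::insert` in the middle of a vector does the same kind of shifting internally.

```cpp
#include <cassert>
#include <vector>

// Inserts x at index pos by shifting a[pos..] one slot to the right, the way an
// array has to. Returns how many elements were moved (those after pos).
int insert_at(std::vector<int> &a, int pos, int x) {
  a.push_back(0);  // grow by one slot at the end
  int moved = 0;
  for (int i = static_cast<int>(a.size()) - 1; i > pos; --i, ++moved)
    a[i] = a[i - 1];  // shift one element to the right
  a[pos] = x;
  return moved;
}
```

For example, with `a = {1, 2, 4, 5, 6}`, `insert_at(a, 2, 3)` returns 3 and leaves `a` as `{1, 2, 3, 4, 5, 6}`.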

For those of you who have taken a data structures course in college, you know that a linked list does the reverse of an array: it can insert an element in \( O(1) \) (given a pointer to the insertion point) but needs \( O(n) \) time to look up an element by index. This is still not ideal, so let us try to solve the problem with a binary tree.

First, we need some way to relate a binary tree to a sequence. We will take the easy way: every node in the left subtree of a node comes before that node, and every node in the right subtree comes after it. Note that this definition uniquely determines a sequence of nodes from a tree (known as the in-order traversal, for those who have taken a data structures course). The above example can thus be represented by this tree:
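As a rough illustration, the sketch below hand-builds one valid tree for \( (1, 2, 4, 5, 6) \) and reads the sequence back with an in-order traversal. The tree is our own choice, not a reproduction of the original figure: 4 at the root, 2 and 5 below it, 1 as the left child of 2 and 6 as the right child of 5. `Node`, `example_tree` and `inorder` are hypothetical helpers, not the tutorial's code.

```cpp
#include <cassert>
#include <vector>

// Hypothetical minimal node: a value and two children.
struct Node {
  int val;
  Node *c[2];  // c[0] = left child, c[1] = right child
};

// One possible tree whose in-order is (1, 2, 4, 5, 6).
Node *example_tree() {
  static Node n1{1, {nullptr, nullptr}}, n6{6, {nullptr, nullptr}};
  static Node n2{2, {&n1, nullptr}}, n5{5, {nullptr, &n6}};
  static Node n4{4, {&n2, &n5}};
  return &n4;
}

// In-order: the left subtree first, then the node itself, then the right subtree.
void inorder(const Node *n, std::vector<int> &out) {
  if (!n) return;
  inorder(n->c[0], out);
  out.push_back(n->val);
  inorder(n->c[1], out);
}
```

Any other tree with the same in-order, such as a 'stick', represents the same sequence.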

It is very easy to write code to find the \( i \)-th element of any tree and to insert an element:

Note: all code in this tutorial is designed to minimize ICPC typing time, and thus recklessly disregards readability and common coding practice. I have added some comments to make it more readable, but I still encourage you to write a more 'readable' implementation of your own to make sure you understand the concept :)

struct node {
  // Father of the node
  node *f = nullptr;
  // Children of the node
  node *c[2] = {nullptr, nullptr};
  // Size of the subtree of the node
  int size = 1;
  // Re-calculate the size.
  void update() {
    size = 1;
    for (int t = 0; t < 2; ++t)
      if (c[t]) size += c[t]->size;
  }
};

// Root of the tree (assumed to be non-empty in the functions below).
node *root;

// Helper function to decide which way to go and update pos accordingly.
// v: direction. 0 - left, 1 - right.
// pos: we want to find the pos-th node in the tree.
// return size of the left subtree of n.
int walk(node *n, int &v, int &pos) {
  int s = n->c[0] ? n->c[0]->size : 0;
  // Assign a value to v and update pos simultaneously.
  if ((v = s < pos)) pos -= s + 1;
  return s;
}

// Find the pos-th node in the tree.
node *find(int pos) {
  node *c = root;
  int v;
  // If v == 1, we should go right.
  // Otherwise, if pos < size of the left subtree of n, we should go left.
  // Otherwise, pos == size of the left subtree of n, and we should stop.
  while (pos < walk(c, v, pos) || v) c = c->c[v];
  return c;
}

// Insert node n to position pos.
void insert(node *n, int pos) {
  node *c = root;
  int v;
  // Call walk first, and set c = c->c[v] if it is not null.
  while (walk(c, v, pos), c->c[v] && (c = c->c[v]));
  // Insert the node.
  c->c[v] = n, n->f = c;
  // Update the sizes.
  while (n) n->update(), n = n->f;
}
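If you want something you can compile and experiment with, here is a self-contained sketch of the same walk/find/insert idea. It is our own variant, not the original: it adds a hypothetical `key` payload so we can see which element was found, handles the empty tree (which the excerpt above leaves out), and spells out `find`'s loop more explicitly. A `height` helper is included to show that appending elements in order simply produces a 'stick'.

```cpp
#include <algorithm>
#include <cassert>

struct node {
  node *f = nullptr, *c[2] = {nullptr, nullptr};  // father and children
  int size = 1;                                   // size of the subtree
  int key;                                        // hypothetical payload
  explicit node(int k) : key(k) {}
  void update() {
    size = 1;
    for (node *ch : c)
      if (ch) size += ch->size;
  }
};

node *root = nullptr;

// Same contract as walk above: v = 0 (left) or 1 (right); when going right,
// pos is rebased onto the right subtree. Returns the left subtree's size.
int walk(node *n, int &v, int &pos) {
  int s = n->c[0] ? n->c[0]->size : 0;
  v = s < pos;
  if (v) pos -= s + 1;
  return s;
}

// Returns the pos-th node (0-based) in in-order; assumes 0 <= pos < size.
node *find(int pos) {
  node *c = root;
  int v;
  for (;;) {
    int s = walk(c, v, pos);
    if (!v && pos == s) return c;  // exactly s nodes come before c
    c = c->c[v];
  }
}

// Inserts n so that it becomes the pos-th node (0 <= pos <= size).
void insert(node *n, int pos) {
  if (!root) {
    root = n;
    return;
  }
  node *c = root;
  int v;
  while (walk(c, v, pos), c->c[v]) c = c->c[v];
  c->c[v] = n, n->f = c;
  for (; n; n = n->f) n->update();  // refresh sizes up to the root
}

int height(const node *n) {
  return n ? 1 + std::max(height(n->c[0]), height(n->c[1])) : 0;
}
```

Appending 1, 2, 4, 5, 6 with `insert(&x, i)` for `i = 0..4` gives a tree of height 5 (every node is the right child of the previous one). Calling `insert` for 3 at position 2 then makes `find(i)->key == i + 1` for every `i`.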

Now let us take a moment to analyze the complexity of the algorithm. Each operation above walks down a single path from the root, so it takes \( O(h) \) time, where \( h \) is the height of the tree. A binary tree can store \( n \) elements with a height of \( O(\log n) \), so we have a good chance of making both finding and inserting an element \( O(\log n) \).

Unfortunately, using a tree to solve the problem is not that easy. A binary tree can degrade into a 'stick' if we are not careful. Consider this tree, which also represents the example above:

This type of 'stick' tree has a height of \( n \), which destroys our complexity assumption. Therefore, we need some way to keep the height of our binary tree as low as \( O(\log n) \) for our complexity assumption to hold. This is called balancing a tree, and there are many ways to do it: AVL trees, red-black trees, scapegoat trees, etc. However, we are going to introduce a method that not only keeps our tree balanced but also keeps track of auxiliary information that lets us solve those pesky 'sum of a segment of the sequence' problems: the Splay Tree!

Balancing the Tree

Rotation

To balance a binary tree, we need some tools. One common tool is known as a rotation:

Nodes \( n \) and \( p \) are the two nodes we care about, and \( A, B, C \) are their three subtrees. Notice that both the left and the right tree have an in-order of \( (A, n, B, p, C) \), but they have slightly different heights. Also notice that a rotation lets us change the root of the tree to \( n \) (i.e. rotate \( n \) up).

We can try to program this behavior in 10 lines:

// Rotate n up.
void rotate(node *n) {
  // v: 1 if n is the left child of its parent, 0 if it is the right child.
  int v = n->f->c[0] == n;
  node *p = n->f, *m = n->c[v];
  if (p->f) p->f->c[p->f->c[1] == p] = n;
  n->f = p->f, n->c[v] = p;
  p->f = n, p->c[v ^ 1] = m;
  if (m) m->f = p;
  // Update the sizes.
  p->update(), n->update();
}
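A quick way to convince yourself that a rotation keeps the in-order intact is to build the small example with leaves named A, B and C, rotate \( n \) up, and compare. The sketch copies `rotate` from above. The `name` field and the `link`/`inorder` helpers are our own additions, used only for checking.

```cpp
#include <cassert>
#include <string>

struct node {
  node *f = nullptr, *c[2] = {nullptr, nullptr};
  int size = 1;
  char name;  // label used only for printing (our addition)
  explicit node(char ch) : name(ch) {}
  void update() {
    size = 1;
    for (int t = 0; t < 2; ++t)
      if (c[t]) size += c[t]->size;
  }
};

// Rotate n up (same as the article's version).
void rotate(node *n) {
  int v = n->f->c[0] == n;  // 1 if n is the left child
  node *p = n->f, *m = n->c[v];
  if (p->f) p->f->c[p->f->c[1] == p] = n;
  n->f = p->f, n->c[v] = p;
  p->f = n, p->c[v ^ 1] = m;
  if (m) m->f = p;
  p->update(), n->update();
}

// Hypothetical helper: make child the side-th child of parent.
void link(node *parent, int side, node *child) {
  parent->c[side] = child;
  child->f = parent;
}

// In-order as a string of names.
std::string inorder(const node *n) {
  return n ? inorder(n->c[0]) + n->name + inorder(n->c[1]) : std::string();
}
```

Before the rotation, `p` is the root with children `n` and `C`. After `rotate(&n)`, `n` is the root, `B` has moved under `p`, and `inorder` still returns `"AnBpC"`.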

The basic idea of Splay is very simple: bad cases in binary trees occur because the node we access lies very deep in the tree. Therefore, to reduce the cost of accessing this node in the future, we should rotate it up. In fact, let us rotate the node all the way to the root every time we access it!

Unfortunately, this does not quite work. Rotating the bottom node of a 'stick' tree to the root gives us another 'stick'-ish tree, which does not help reduce the complexity:

Therefore, we need to use a new strategy to rotate a node to the top. Enter Splay:

Splay Operation

In a splay operation, we usually rotate node \( n \) up twice back-to-back to reduce its depth by 2 (known as Zig-Zag):

However, there is a catch: if node \( n \) and its parent \( m \) are on the same side (i.e. both are left children or both are right children), then instead of rotating \( n \) up twice, we first rotate \( m \) up once and then rotate \( n \) up once. This also reduces the depth of \( n \) by 2, like this (known as Zig-Zig, because the two rotations go in the same direction):

Note: In the scenario where the parent of \( n \) is the root, it suffices to rotate \( n \) to the root directly in one rotation (known as Zig).

It is not immediately obvious why Zig-Zig works: the tree looks like a stick before and after the operation! But if we do it on the previous example it becomes more apparent.
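We can check this numerically with an experiment of our own (not taken from the original figures). The sketch below builds a left 'stick' of 1000 nodes twice. It then brings the deepest node to the root once with plain single rotations and once by splaying. Plain rotations leave a tree of height 1000, while splaying roughly halves it (by our hand count, to 502). Subtree sizes are omitted because only the shape matters here.

```cpp
#include <algorithm>
#include <cassert>
#include <vector>

// Shape-only node for this experiment (no sizes needed).
struct Node {
  Node *f = nullptr, *c[2] = {nullptr, nullptr};
};

// Same rotation as in the article, minus the size updates.
void rotate(Node *n) {
  Node *p = n->f, *g = p->f;
  int v = p->c[0] == n;  // 1 if n is a left child
  Node *m = n->c[v];
  if (g) g->c[g->c[1] == p] = n;
  n->f = g, n->c[v] = p;
  p->f = n, p->c[v ^ 1] = m;
  if (m) m->f = p;
}

// The naive idea: single rotations until n is the root.
void rotate_to_root(Node *n) {
  while (n->f) rotate(n);
}

// Splay: Zig if the parent is the root; otherwise Zig-Zig (rotate the parent
// first) when both links point the same way, or Zig-Zag (rotate n twice).
void splay_to_root(Node *n) {
  while (Node *m = n->f) {
    if (Node *l = m->f) rotate((l->c[0] == m) == (m->c[0] == n) ? m : n);
    rotate(n);
  }
}

int height(const Node *n) {
  return n ? 1 + std::max(height(n->c[0]), height(n->c[1])) : 0;
}

// Links nodes into a left 'stick': nodes.back() is the root, nodes[0] the deepest.
void make_stick(std::vector<Node> &nodes) {
  for (size_t i = 1; i < nodes.size(); ++i) {
    nodes[i].c[0] = &nodes[i - 1];
    nodes[i - 1].f = &nodes[i];
  }
}
```

On small sticks the difference is tiny: with 4 nodes both methods still give height 4. The halving only becomes visible as the stick grows.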

Zig-Zig somehow produces a tree of smaller height after a splay-to-root operation. To analyze why it works, however, we need to introduce a bit of witchery maths:

Complexity - Here be Maths

Note: This part is skippable if you do not want to know why Splay works. I think it is really interesting, though!

The Potential Method

Let us consider how to bound the complexity of a dynamic rotating tree \( T \) over a sequence \( Q \) of find-then-splay-to-root operations.

We first observe that finding a node has the same time complexity as splaying a node to the root, because they both access the node's path to the root and nothing else. We will restrict our discussion to the complexity of splaying a node to the root in the following analysis.

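One common idea here is the Potential Method, which lets us run the algorithm alongside a battery: a cheap operation may spend a little extra time to charge the battery, and an expensive operation may drain the battery to pay for itself. We define the amortized time of an operation as its actual time plus the change in battery that it causes.

At the end of the execution, we can see that our total actual execution time is

\[ \text{the total amortized time of all operations} + (\text{battery before the execution} - \text{battery after the execution})\text{.} \]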

Usually we want our battery to be something indicative of the data structure we are using. In this case we define our battery \( \Phi \) of a tree \( T \) to be:

\[ \begin{align} k&=\text{the (constant) number of operations it takes to perform one splay step (Zig, Zig-Zag or Zig-Zig)}\text{,}\\ \text{size}(n)&=\text{the size of the subtree with root at }n\text{,}\\ \text{rank}(n)&=\log_2\text{size}(n)\text{,}\\ \Phi&=k\sum_{n\in T}\text{rank}(n)\text{.} \end{align} \]

Observe that \( \text{rank}(n) \leq \log|T| \), so \( 0 \leq \Phi \leq k|T|\log|T| \). This means we have already bounded the difference in battery before and after the execution, \( \Delta\Phi \), by \( O(|T|\log|T|) \). Now we only need to ensure that each operation takes \( O(\log|T|) \) amortized time, and we will have a very nice total complexity of \( O((|Q|+|T|)\log|T|) \).
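For a feel of the numbers, here is a worked example of our own. Take \( k=1 \) and \( |T|=7 \). A stick has subtree sizes \( 1,2,\ldots,7 \), while a perfectly balanced tree has sizes \( 7,3,3,1,1,1,1 \):

\[ \begin{align} \Phi_{\text{stick}}&=\log_2 1+\log_2 2+\cdots+\log_2 7=\log_2 5040\approx 12.30\text{,}\\ \Phi_{\text{balanced}}&=\log_2 7+2\log_2 3+4\log_2 1=\log_2 63\approx 5.98\text{,}\\ k|T|\log|T|&=7\log_2 7\approx 19.65\text{.} \end{align} \]

Both values lie within the bound. The unbalanced stick stores more charge, and this is what pays for the long walk when we later splay its deepest node.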

Bounding the Operation

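Let us first consider the case of a Zig-Zag. Let \( m \) be the parent of \( n \) and \( l \) its grandparent, and let \( \text{size}' \) and \( \text{rank}' \) denote sizes and ranks after the step. Only these three nodes change their subtrees, so the amortized time is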

\[ \begin{align} \text{TIME}(\text{Zig-Zag})&=k+k(\text{rank}'(n)+\text{rank}'(m)+\text{rank}'(l)-\text{rank}(n)-\text{rank}(m)-\text{rank}(l))\\ &=k(1+\text{rank}'(n)+\text{rank}'(m)+\text{rank}'(l)-\text{rank}(n)-\text{rank}(m)-\text{rank}(l))\text{.} \end{align} \]

Since \( \text{rank}(l)=\text{rank}'(n) \) (both cover the whole subtree) and \( \text{rank}(m)>\text{rank}(n) \) (\( m \) was the parent of \( n \)),

\[ \begin{align} \text{TIME}(\text{Zig-Zag})&=k(1+\text{rank}'(m)+\text{rank}'(l)-\text{rank}(n)-\text{rank}(m))\\ &< k(1+\text{rank}'(m)+\text{rank}'(l)-2\text{rank}(n))\\ \end{align} \]

Since

\[ \begin{align} 1+\text{rank}'(m)+\text{rank}'(l)&=\log 2+\log \text{size}'(m)+\log \text{size}'(l)\\ &=\log 2\text{size}'(m)\text{size}'(l)\\ &<\log (\text{size}'(m)+\text{size}'(l))^2\\ &<\log \text{size}'^2(n)\\ &=\log \text{size}'(n)+\log \text{size}'(n)=2\text{rank}'(n)\text{,}\\ \end{align} \]

we obtain

\[ \text{TIME}(\text{Zig-Zag})<k(2\text{rank}'(n)-2\text{rank}(n))=2k(\text{rank}'(n)-\text{rank}(n))\text{.} \]

Applying the same reasoning to the Zig-Zig case:

\[ \begin{align} \text{TIME}(\text{Zig-Zig})&=k+k(\text{rank}'(n)+\text{rank}'(m)+\text{rank}'(l)-\text{rank}(n)-\text{rank}(m)-\text{rank}(l))\\ &=k(1+\text{rank}'(n)+\text{rank}'(m)+\text{rank}'(l)-\text{rank}(n)-\text{rank}(m)-\text{rank}(l))\\ &=k(1+\text{rank}'(m)+\text{rank}'(l)-\text{rank}(n)-\text{rank}(m))\\ &< k(1+\text{rank}'(m)+\text{rank}'(l)-2\text{rank}(n))\\ &=k(\log 2\text{size}'(m)\text{size}'(l)-2\text{rank}(n))\text{.} \end{align} \]

Now we use a property specific to the Zig-Zig operation. Let \( A, B \) be the subtrees below \( n \) before the step, and \( C, D \) the subtrees below \( l \) after it. Then

\[ \text{size}(n)+\text{size}'(l)+1=(\text{size}(A)+\text{size}(B)+1)+(\text{size}(C)+\text{size}(D)+1)+1=\text{size}'(n)\text{.} \]

Therefore,

\[ \begin{align} \text{TIME}(\text{Zig-Zig})&<k(\log 2\text{size}'(m)\text{size}'(l)-2\text{rank}(n))\\ &=k(\log 2\text{size}'(m)(\text{size}'(n)-\text{size}(n)-1)-2\text{rank}(n))\\ &<k(\log 2\text{size}'(n)(\text{size}'(n)-\text{size}(n))-2\text{rank}(n))\\ &=k(\log 2\text{size}'(n)(\text{size}'(n)-\text{size}(n))+\text{rank}(n)-3\text{rank}(n))\\ &=k(\log 2\text{size}'(n)(\text{size}'(n)-\text{size}(n))+\log\text{size}(n)-3\text{rank}(n))\\ &=k(\log 2\text{size}'(n)\text{size}(n)(\text{size}'(n)-\text{size}(n))-3\text{rank}(n))\text{.} \end{align} \]

Since \( x(N-x)\leq N^2/4 \) for every \( x \), we have \( \text{size}(n)(\text{size}'(n)-\text{size}(n)) \leq \frac{1}{4}\text{size}'^2(n) < \frac{1}{2}\text{size}'^2(n) \), and therefore

\[ \begin{align} \text{TIME}(\text{Zig-Zig})&<k(\log 2\text{size}'(n)\text{size}(n)(\text{size}'(n)-\text{size}(n))-3\text{rank}(n))\\ &<k(\log \text{size}'^3(n)-3\text{rank}(n))=3k(\text{rank}'(n)-\text{rank}(n))\text{.} \end{align} \]

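Since \( \text{rank}'(n)>\text{rank}(n) \), the Zig-Zag bound \( 2k(\text{rank}'(n)-\text{rank}(n)) \) is also below \( 3k(\text{rank}'(n)-\text{rank}(n)) \), and we conclude that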

\[ \text{TIME}(\text{Zig-Zag}),\text{TIME}(\text{Zig-Zig})<3k(\text{rank}'(n)-\text{rank}(n))\text{.} \]
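As a sanity check (not a proof), the program below plugs concrete subtree sizes into both bounds with \( k=1 \) and tests every combination of sizes up to a limit. The shapes follow the usual picture of each step. The labels \( A, B, C, D \) for the four hanging subtrees and all function names are our own choice and may not match the original figures exactly.

```cpp
#include <cassert>
#include <cmath>

static double rk(int size) { return std::log2(static_cast<double>(size)); }

// Zig-Zag. Before: m is the left child of l, n is the right child of m;
// m's left subtree is A, n's subtrees are B and C, l's right subtree is D.
// After: n is on top with children m (A, B) and l (C, D).
bool zig_zag_ok(int A, int B, int C, int D) {
  int n0 = B + C + 1, m0 = A + n0 + 1, l0 = m0 + D + 1;
  int m1 = A + B + 1, l1 = C + D + 1, n1 = m1 + l1 + 1;
  double time = 1 + rk(n1) + rk(m1) + rk(l1) - rk(n0) - rk(m0) - rk(l0);
  return time < 2 * (rk(n1) - rk(n0));
}

// Zig-Zig. Before: l -> m -> n along left links; n has A, B; m's right is C;
// l's right is D. After: n has A and m; m has B and l; l has C and D.
bool zig_zig_ok(int A, int B, int C, int D) {
  int n0 = A + B + 1, m0 = n0 + C + 1, l0 = m0 + D + 1;
  int l1 = C + D + 1, m1 = B + l1 + 1, n1 = A + m1 + 1;
  double time = 1 + rk(n1) + rk(m1) + rk(l1) - rk(n0) - rk(m0) - rk(l0);
  return time < 3 * (rk(n1) - rk(n0));
}

// Checks every combination of subtree sizes in [0, limit].
bool all_ok(int limit) {
  for (int A = 0; A <= limit; ++A)
    for (int B = 0; B <= limit; ++B)
      for (int C = 0; C <= limit; ++C)
        for (int D = 0; D <= limit; ++D)
          if (!zig_zag_ok(A, B, C, D) || !zig_zig_ok(A, B, C, D)) return false;
  return true;
}
```

The strict comparisons are safe in floating point here, because the derivation above leaves a slack of at least 1 in each bound (from the factor-of-2 steps).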

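Splaying to the root is a series of Zig-Zag and Zig-Zig operations, possibly followed by one final Zig. That Zig has amortized time at most \( k+k(\text{rank}'(n)-\text{rank}(n)) \), because \( \text{rank}'(n)=\text{rank}(m) \) and \( \text{rank}'(m)\leq\text{rank}'(n) \). Each step starts at the rank where the previous one ended, so the bounds telescope from \( \text{rank}(n) \) to \( \text{rank}(\text{root}) = \log|T| \):

\[ \begin{align} \text{TIME}(\text{Splay-to-root})\leq 3k(\text{rank}(\text{root})-\text{rank}(n))+k=O(\log |T|)\text{.} \end{align} \]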

And we have bounded the amortized time of the splay-to-root operation by \( O(\log |T|) \).

Implementation

With all the maths out of the way, we can start actually writing the code for the Splay operation in 10 lines:

// Splay n so that it is under s (or to root if s is null).
void splay(node *n, node *s = nullptr) {
  while (n->f != s) {
    node *m = n->f, *l = m->f;
    if (l == s) rotate(n);  // Zig
    else if ((l->c[0] == m) == (m->c[0] == n)) rotate(m), rotate(n);  // Zig-Zig
    else rotate(n), rotate(n);  // Zig-Zag
  }
  if (!s) root = n;
}

Operations on a Splay Tree

To recap, we have implemented a Splay Tree that lets us find the \( i \)-th element and insert an element at any position, both

in amortized \( O(\log n) \) time, which is cool because \( O(\log n) \) is very small. We are going to look at some more use cases, and I will provide a concrete example with code.

Sum of a Segment

The most common thing to do in a sequence \( (x_1,x_2,\ldots,x_n) \) is probably to find the sum of a segment \( \sum_{i=l}^{r}x_i \).

It may surprise you that we have already (sort of) implemented this in our previous code. Recall that we have kept track of each node's subtree size with node.size and node.update(). It turns out that we can keep track of each node's subtree sum in the same way. The problem is then to find a subtree that corresponds exactly to \( [x_l,x_r] \).

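The solution is quite straightforward: we splay node \( x_{r+1} \) to the root, and then splay node \( x_{l-1} \) to just below \( x_{r+1} \). Everything strictly between \( x_{l-1} \) and \( x_{r+1} \) must then lie in the right subtree of \( x_{l-1} \), so a subtree that corresponds exactly to \( [x_l,x_r] \) appears: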

Note: This assumes that \( x_{l-1} \) and \( x_{r+1} \) exist. In practice, this is easily ensured by adding a dummy head and a dummy tail to the tree.
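Putting the pieces together, here is a compact sketch of segment sums. It is not the template used later. The names `SumSeq`, `kth`, `range` and `sum` are our own. Nodes live in a `std::deque` so that their addresses stay stable. The sequence is built by appending each element as the new root, which gives a valid (just initially unbalanced) splay tree. Dummy elements \( x_0 \) and \( x_{n+1} \) with value 0 play the role of the head and tail.

```cpp
#include <cassert>
#include <deque>
#include <vector>

struct Node {
  Node *f = nullptr, *c[2] = {nullptr, nullptr};
  int size = 1;
  long long val, sum;
  explicit Node(long long v) : val(v), sum(v) {}
  void update() {  // recompute size and sum from the children
    size = 1, sum = val;
    for (Node *ch : c)
      if (ch) size += ch->size, sum += ch->sum;
  }
};

struct SumSeq {
  std::deque<Node> pool;  // owns all nodes; deque keeps addresses stable
  Node *root = nullptr;

  // Appends v: the new node becomes the root, the old tree its left subtree.
  void append(long long v) {
    pool.emplace_back(v);
    Node *n = &pool.back();
    n->c[0] = root;
    if (root) root->f = n;
    root = n;
    n->update();
  }
  // Stores x_1..x_n between a dummy head x_0 and a dummy tail x_{n+1}.
  explicit SumSeq(const std::vector<long long> &xs) {
    append(0);
    for (long long x : xs) append(x);
    append(0);
  }
  void rotate(Node *n) {  // rotate n above its parent
    Node *p = n->f, *g = p->f;
    int v = p->c[0] == n;  // 1 if n is a left child
    Node *m = n->c[v];
    if (g) g->c[g->c[1] == p] = n;
    n->f = g, n->c[v] = p;
    p->f = n, p->c[v ^ 1] = m;
    if (m) m->f = p;
    p->update(), n->update();
  }
  void splay(Node *n, Node *s = nullptr) {  // until n's parent is s
    while (n->f != s) {
      Node *m = n->f, *l = m->f;
      if (l != s)  // Zig-Zig rotates m first, Zig-Zag rotates n first
        rotate((l->c[0] == m) == (m->c[0] == n) ? m : n);
      rotate(n);
    }
    if (!s) root = n;
  }
  Node *kth(int k) {  // k-th node in in-order; x_0 has index 0
    Node *x = root;
    for (;;) {
      int s = x->c[0] ? x->c[0]->size : 0;
      if (k == s) return x;
      if (k < s) {
        x = x->c[0];
      } else {
        k -= s + 1;
        x = x->c[1];
      }
    }
  }
  // Subtree holding exactly x_l..x_r, assuming 1 <= l <= r <= n.
  Node *range(int l, int r) {
    Node *hi = kth(r + 1);
    splay(hi);        // x_{r+1} becomes the root
    Node *lo = kth(l - 1);
    splay(lo, hi);    // x_{l-1} becomes its left child
    return lo->c[1];  // everything strictly between the two
  }
  long long sum(int l, int r) { return range(l, r)->sum; }
};
```

For example, `SumSeq s({1, 2, 3, 4, 5});` followed by `s.sum(2, 4)` gives 9.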

Delete a Segment

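Deletion in Splay is easy: say we want to delete a segment \( [x_l,x_r] \). We only need to find the subtree that corresponds to \( [x_l,x_r] \) and detach it from its parent \( x_{l-1} \). Then we recompute the information (e.g. the sizes) of \( x_{l-1} \) and \( x_{r+1} \), since their subtrees have changed.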
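Here is a sketch of segment deletion under the same assumptions as our sum example (our own names; sums are omitted to keep it short). After the subtree is detached, only \( x_{l-1} \) and the root \( x_{r+1} \) need `update()`, because they are the only nodes whose subtrees changed. In this sketch the detached nodes are simply abandoned in the pool; a real implementation might recycle them, as the template's stack allocator does.

```cpp
#include <cassert>
#include <deque>
#include <vector>

struct Node {
  Node *f = nullptr, *c[2] = {nullptr, nullptr};
  int size = 1;
  long long val;
  explicit Node(long long v) : val(v) {}
  void update() {
    size = 1;
    for (Node *ch : c)
      if (ch) size += ch->size;
  }
};

struct Seq {
  std::deque<Node> pool;  // owns all nodes; deque keeps addresses stable
  Node *root = nullptr;

  void append(long long v) {  // new root; the old tree becomes its left subtree
    pool.emplace_back(v);
    Node *n = &pool.back();
    n->c[0] = root;
    if (root) root->f = n;
    root = n;
    n->update();
  }
  explicit Seq(const std::vector<long long> &xs) {  // with dummy head and tail
    append(0);
    for (long long x : xs) append(x);
    append(0);
  }
  void rotate(Node *n) {
    Node *p = n->f, *g = p->f;
    int v = p->c[0] == n;  // 1 if n is a left child
    Node *m = n->c[v];
    if (g) g->c[g->c[1] == p] = n;
    n->f = g, n->c[v] = p;
    p->f = n, p->c[v ^ 1] = m;
    if (m) m->f = p;
    p->update(), n->update();
  }
  void splay(Node *n, Node *s = nullptr) {
    while (n->f != s) {
      Node *m = n->f, *l = m->f;
      if (l != s) rotate((l->c[0] == m) == (m->c[0] == n) ? m : n);
      rotate(n);
    }
    if (!s) root = n;
  }
  Node *kth(int k) {  // x_0 has index 0
    Node *x = root;
    for (;;) {
      int s = x->c[0] ? x->c[0]->size : 0;
      if (k == s) return x;
      if (k < s) {
        x = x->c[0];
      } else {
        k -= s + 1;
        x = x->c[1];
      }
    }
  }
  Node *range(int l, int r) {  // subtree of x_l..x_r, 1 <= l <= r <= n
    Node *hi = kth(r + 1);
    splay(hi);
    Node *lo = kth(l - 1);
    splay(lo, hi);
    return lo->c[1];
  }
  // Deletes x_l..x_r: drop their subtree, then refresh x_{l-1} and x_{r+1}.
  void erase(int l, int r) {
    Node *x = range(l, r), *lo = x->f;  // lo = x_{l-1}; lo->f is the root x_{r+1}
    lo->c[1] = nullptr;
    x->f = nullptr;  // detached nodes are simply left in the pool
    lo->update();
    root->update();
  }
  void collect(const Node *n, std::vector<long long> &out) const {
    if (!n) return;
    collect(n->c[0], out);
    out.push_back(n->val);
    collect(n->c[1], out);
  }
  std::vector<long long> values() const {  // x_1..x_n without the dummies
    std::vector<long long> out;
    collect(root, out);
    return std::vector<long long>(out.begin() + 1, out.end() - 1);
  }
};
```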

Set a Segment to Some Value

Oftentimes we need to modify a segment \( [x_l,x_r] \) in some way: add some value, set the values to something, etc. We will look at an example where we want to set every element in \( [x_l,x_r] \) to some value \( v \).

The way to implement this operation is to maintain a lazy label on every node. A lazy label is either null or some value. If the label \( c \) is not null, then we treat the corresponding node's subtree as if every node in that subtree had a value of \( c \). Therefore, to apply the operation, we only need to find the subtree of \( [x_l,x_r] \) and set both the value and the label of its root to \( v \) (also updating any information the root maintains, such as its subtree sum).

Of course, when we actually use a node, we need to make sure that its ancestors do not have any lazy labels, because those labels can affect the node's value. Fortunately, we can call push_down on every ancestor, starting from the root, to clear its label. For a node with label \( c \), push_down does the following:

  1. If \( c \) is not null, apply the operation of setting value to each child of the node:
    1. Set the value of the child to \( c \).
    2. Set the label of the child to \( c \).
  2. Set the label of the node to null.
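Here is a sketch of the set-to-value label, again using our own names. We represent the null label with `std::optional`. The subtree sum of a labelled node is \( c \cdot \text{size} \). `kth` calls `push_down` on every node it visits, which is the same change the rewritten `find` makes below.

```cpp
#include <cassert>
#include <deque>
#include <optional>
#include <vector>

struct Node {
  Node *f = nullptr, *c[2] = {nullptr, nullptr};
  int size = 1;
  long long val, sum;
  std::optional<long long> label;  // empty optional = null label
  explicit Node(long long v) : val(v), sum(v) {}
  void update() {
    size = 1, sum = val;
    for (Node *ch : c)
      if (ch) size += ch->size, sum += ch->sum;
  }
  void assign(long long v) {  // treat every node of this subtree as v
    val = v;
    sum = v * size;
    label = v;
  }
  void push_down() {  // hand the label to the children, then clear it
    if (!label) return;
    for (Node *ch : c)
      if (ch) ch->assign(*label);
    label.reset();
  }
};

struct LazySeq {
  std::deque<Node> pool;  // owns all nodes; deque keeps addresses stable
  Node *root = nullptr;

  void append(long long v) {  // new root; the old tree becomes its left subtree
    pool.emplace_back(v);
    Node *n = &pool.back();
    n->c[0] = root;
    if (root) root->f = n;
    root = n;
    n->update();
  }
  explicit LazySeq(const std::vector<long long> &xs) {  // dummy head and tail
    append(0);
    for (long long x : xs) append(x);
    append(0);
  }
  void rotate(Node *n) {
    Node *p = n->f, *g = p->f;
    int v = p->c[0] == n;
    Node *m = n->c[v];
    if (g) g->c[g->c[1] == p] = n;
    n->f = g, n->c[v] = p;
    p->f = n, p->c[v ^ 1] = m;
    if (m) m->f = p;
    p->update(), n->update();
  }
  void splay(Node *n, Node *s = nullptr) {
    while (n->f != s) {
      Node *m = n->f, *l = m->f;
      if (l != s) rotate((l->c[0] == m) == (m->c[0] == n) ? m : n);
      rotate(n);
    }
    if (!s) root = n;
  }
  Node *kth(int k) {  // pushes labels down along the search path
    Node *x = root;
    for (;;) {
      x->push_down();
      int s = x->c[0] ? x->c[0]->size : 0;
      if (k == s) return x;
      if (k < s) {
        x = x->c[0];
      } else {
        k -= s + 1;
        x = x->c[1];
      }
    }
  }
  Node *range(int l, int r) {  // subtree of x_l..x_r, 1 <= l <= r <= n
    Node *hi = kth(r + 1);
    splay(hi);
    Node *lo = kth(l - 1);
    splay(lo, hi);
    return lo->c[1];
  }
  void set(int l, int r, long long v) {
    Node *x = range(l, r);
    x->assign(v);
    x->f->update();  // x_{l-1}
    root->update();  // x_{r+1}
  }
  long long sum(int l, int r) { return range(l, r)->sum; }
};
```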
Then we can rewrite the function that finds a node so that it calls push_down on every node it visits:
// Helper function to decide which way to go and update pos accordingly.
// v: direction. 0 - left, 1 - right.
// pos: we want to find the pos-th node in the tree.
// return size of the left subtree of n.
int walk(node *n, int &v, int &pos) {
  n->push_down();
  int s = n->c[0] ? n->c[0]->size : 0;
  // Assign a value to v and update pos simultaneously.
  if ((v = s < pos)) pos -= s + 1;
  return s;
}

// Find the node at position pos. If sp is true, splay it.
node *find(int pos, bool sp = true) {
  node *c = root;
  int v;
  // Account for the dummy head of the tree.
  ++pos;
  // If v == 1, we should go right.
  // Otherwise, if pos < size of the left subtree of n, we should go left.
  // Otherwise, pos == size of the left subtree of n, and we should stop.
  while (pos < walk(c, v, pos) || v) c = c->c[v];
  if (sp) splay(c);
  return c;
}

Reverse a Segment

Finally we are going to look at an interesting example: reversing a segment \( [x_l,x_r] \) in the sequence to \( (x_r, x_{r-1}, \ldots, x_l) \).

We implement this operation in the same way as above: maintain a label recording whether the subtree should be reversed. To reverse \( [x_l,x_r] \), we find its subtree, swap the two children of its root, and flip the root's label. Then, when doing push_down:

  1. If the label is set, apply the operation of reversing to the two children of the node:
    1. Swap the left child's two children.
    2. Swap the right child's two children.
    3. Flip the labels of the two children.
  2. Clear the label of the node.
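Here is a sketch of the reversal label. It follows the convention described above, which the template's `reversible_node` also uses: when a node is labelled, its own children have already been swapped, and the label means that its children still need to be reversed. The names are our own. Note that even reading out the sequence has to push labels down.

```cpp
#include <cassert>
#include <deque>
#include <utility>
#include <vector>

struct Node {
  Node *f = nullptr, *c[2] = {nullptr, nullptr};
  int size = 1;
  long long val;
  bool rev = false;  // true: my children still have to be reversed
  explicit Node(long long v) : val(v) {}
  void update() {
    size = 1;
    for (Node *ch : c)
      if (ch) size += ch->size;
  }
  void reverse() {  // reverse this subtree: swap now, children later
    std::swap(c[0], c[1]);
    rev = !rev;
  }
  void push_down() {
    if (!rev) return;
    for (Node *ch : c)
      if (ch) ch->reverse();
    rev = false;
  }
};

struct RevSeq {
  std::deque<Node> pool;  // owns all nodes; deque keeps addresses stable
  Node *root = nullptr;

  void append(long long v) {  // new root; the old tree becomes its left subtree
    pool.emplace_back(v);
    Node *n = &pool.back();
    n->c[0] = root;
    if (root) root->f = n;
    root = n;
    n->update();
  }
  explicit RevSeq(const std::vector<long long> &xs) {  // dummy head and tail
    append(0);
    for (long long x : xs) append(x);
    append(0);
  }
  void rotate(Node *n) {
    Node *p = n->f, *g = p->f;
    int v = p->c[0] == n;
    Node *m = n->c[v];
    if (g) g->c[g->c[1] == p] = n;
    n->f = g, n->c[v] = p;
    p->f = n, p->c[v ^ 1] = m;
    if (m) m->f = p;
    p->update(), n->update();
  }
  void splay(Node *n, Node *s = nullptr) {
    while (n->f != s) {
      Node *m = n->f, *l = m->f;
      if (l != s) rotate((l->c[0] == m) == (m->c[0] == n) ? m : n);
      rotate(n);
    }
    if (!s) root = n;
  }
  Node *kth(int k) {  // pushes labels down along the search path
    Node *x = root;
    for (;;) {
      x->push_down();
      int s = x->c[0] ? x->c[0]->size : 0;
      if (k == s) return x;
      if (k < s) {
        x = x->c[0];
      } else {
        k -= s + 1;
        x = x->c[1];
      }
    }
  }
  Node *range(int l, int r) {  // subtree of x_l..x_r, 1 <= l <= r <= n
    Node *hi = kth(r + 1);
    splay(hi);
    Node *lo = kth(l - 1);
    splay(lo, hi);
    return lo->c[1];
  }
  void reverse(int l, int r) { range(l, r)->reverse(); }
  void collect(Node *n, std::vector<long long> &out) {  // pushes labels too
    if (!n) return;
    n->push_down();
    collect(n->c[0], out);
    out.push_back(n->val);
    collect(n->c[1], out);
  }
  std::vector<long long> values() {  // x_1..x_n without the dummies
    std::vector<long long> out;
    collect(root, out);
    return std::vector<long long>(out.begin() + 1, out.end() - 1);
  }
};
```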

A Concrete Example

We are going to look at POJ 3580 (SuperMemo), which asks us to maintain a sequence under the operations ADD, REVERSE, REVOLVE, INSERT, DELETE and MIN.

I prefer a template-based approach for these large data structure problems: copy-paste the template verbatim and add code only after it. I have added some comments in the code, so see if you can follow along:

// Template starts here:
#include <bits/stdc++.h>

namespace allocator {
// Array allocator.
template <class T, int MAXSIZE> struct array {
  T v[MAXSIZE], *top;
  array() : top(v) {}
  T *alloc(const T &val = T()) { return &(*top++ = val); }
  void dealloc(T *p) {}
};
// Stack-based array allocator.
template <class T, int MAXSIZE> struct stack {
  T v[MAXSIZE];
  T *spot[MAXSIZE], **top;
  stack() {
    for (int i = 0; i < MAXSIZE; ++i) spot[i] = v + i;
    top = spot + MAXSIZE;
  }
  T *alloc(const T &val = T()) { return &(**--top = val); }
  void dealloc(T *p) { *top++ = p; }
};
} // namespace allocator

namespace splay {
// Abstract node struct.
template <class T> struct node {
  T *f, *c[2];
  int size;
  node() {
    f = c[0] = c[1] = nullptr;
    size = 1;
  }
  void push_down() {}
  void update() {
    size = 1;
    for (int t = 0; t < 2; ++t)
      if (c[t]) size += c[t]->size;
  }
};
// Abstract reversible node struct.
template <class T> struct reversible_node : node<T> {
  int r;
  reversible_node() : node<T>() { r = 0; }
  void push_down() {
    node<T>::push_down();
    if (r) {
      for (int t = 0; t < 2; ++t)
        if (node<T>::c[t]) node<T>::c[t]->reverse();
      r = 0;
    }
  }
  void update() { node<T>::update(); }
  // Reverse the range of this node.
  void reverse() {
    std::swap(node<T>::c[0], node<T>::c[1]);
    r = r ^ 1;
  }
};
template <class T, int MAXSIZE = 500000,
          class alloc = allocator::array<T, MAXSIZE + 2>>
struct tree {
  alloc pool;
  T *root;
  // Get a new node from the pool.
  T *new_node(const T &val = T()) { return pool.alloc(val); }
  tree() {
    root = new_node(), root->c[1] = new_node(), root->size = 2;
    root->c[1]->f = root;
  }
  // Helper function to rotate node.
  void rotate(T *n) {
    int v = n->f->c[0] == n;
    T *p = n->f, *m = n->c[v];
    if (p->f) p->f->c[p->f->c[1] == p] = n;
    n->f = p->f, n->c[v] = p;
    p->f = n, p->c[v ^ 1] = m;
    if (m) m->f = p;
    p->update(), n->update();
  }
  // Splay n so that it is under s (or to root if s is null).
  void splay(T *n, T *s = nullptr) {
    while (n->f != s) {
      T *m = n->f, *l = m->f;
      if (l == s) rotate(n);
      else if ((l->c[0] == m) == (m->c[0] == n)) rotate(m), rotate(n);
      else rotate(n), rotate(n);
    }
    if (!s) root = n;
  }
  // Get the size of the tree.
  int size() { return root->size - 2; }
  // Helper function to walk down the tree.
  int walk(T *n, int &v, int &pos) {
    n->push_down();
    int s = n->c[0] ? n->c[0]->size : 0;
    (v = s < pos) && (pos -= s + 1);
    return s;
  }
  // Insert node n to position pos.
  void insert(T *n, int pos) {
    T *c = root;
    int v;
    ++pos;
    while (walk(c, v, pos), c->c[v] && (c = c->c[v]))
      ;
    c->c[v] = n, n->f = c, splay(n);
  }
  // Find the node at position pos. If sp is true, splay it.
  T *find(int pos, bool sp = true) {
    T *c = root;
    int v;
    ++pos;
    while ((pos < walk(c, v, pos) || v) && (c = c->c[v]))
      ;
    if (sp) splay(c);
    return c;
  }
  // Find the range [posl, posr) on the splay tree.
  T *find_range(int posl, int posr) {
    T *r = find(posr), *l = find(posl - 1, false);
    splay(l, r);
    if (l->c[1]) l->c[1]->push_down();
    return l->c[1];
  }
  // Insert nn of size nn_size to position pos.
  void insert_range(T **nn, int nn_size, int pos) {
    T *r = find(pos), *l = find(pos - 1, false), *c = l;
    splay(l, r);
    for (int i = 0; i < nn_size; ++i) c->c[1] = nn[i], nn[i]->f = c, c = nn[i];
    for (int i = nn_size - 1; i >= 0; --i) nn[i]->update();
    l->update(), r->update(), splay(nn[nn_size - 1]);
  }
  // Helper function to dealloc a subtree.
  void dealloc(T *n) {
    if (!n) return;
    dealloc(n->c[0]);
    dealloc(n->c[1]);
    pool.dealloc(n);
  }
  // Remove from position [posl, posr).
  void erase_range(int posl, int posr) {
    T *n = find_range(posl, posr);
    n->f->c[1] = nullptr, n->f->update(), n->f->f->update(), n->f = nullptr;
    dealloc(n);
  }
};
} // namespace splay

// TODO:
// 1. Define a node inheriting from splay::node<node> or splay::reversible_node<node>.
// 2. Add the values you want to use.
// 3. Overload the constructor, push_down, update, (potentially) reverse.
//    (Do not forget to call the base method.)
// 4. Add whatever operations you want.
// 5. Define a tree with splay::tree<node, MAXSIZE, allocator::stack<node, MAXSIZE + 2>> t;
// 6. Profit.
// Template ends here.

const int MAXSIZE = 200000;
struct node : splay::reversible_node<node> {
  long long val, val_min, label_add;
  node(long long v = 0) : splay::reversible_node<node>(), val(v) {
    val_min = v;  // a single node's minimum is its own value
    label_add = 0;
  }
  // Add v to the subtree.
  void add(long long v) {
    val += v;
    val_min += v;
    label_add += v;
  }
  void push_down() {
    splay::reversible_node<node>::push_down();
    for (int t = 0; t < 2; ++t)
      if (c[t]) c[t]->add(label_add);
    label_add = 0;
  }
  void update() {
    splay::reversible_node<node>::update();
    val_min = val;
    for (int t = 0; t < 2; ++t)
      if (c[t]) val_min = std::min(val_min, c[t]->val_min);
  }
};
splay::tree<node, MAXSIZE, allocator::stack<node, MAXSIZE + 2>> t;

int main() {
  int N, M;
  scanf("%d", &N);
  while (N--) {
    long long u;
    scanf("%lld", &u);
    t.insert(t.new_node(node(u)), t.size());
  }
  scanf("%d", &M);
  while (M--) {
    char c[10];
    scanf("%s", c);
    if (strcmp(c, "ADD") == 0) {
      int x, y;
      long long D;
      scanf("%d%d%lld", &x, &y, &D);
      t.find_range(x - 1, y)->add(D);
    } else if (strcmp(c, "REVERSE") == 0) {
      int x, y;
      scanf("%d%d", &x, &y);
      t.find_range(x - 1, y)->reverse();
    } else if (strcmp(c, "REVOLVE") == 0) {
      int x, y;
      long long T;
      scanf("%d%d%lld", &x, &y, &T);
      T %= (y - x + 1);
      if (T > 0) {
        // swap [x - 1, y - T) and [y - T, y)
        node *right = t.find_range(y - T, y);
        right->f->c[1] = nullptr, right->f->update(), right->f->f->update(), right->f = nullptr;
        t.insert(right, x - 1);
      }
    } else if (strcmp(c, "INSERT") == 0) {
      int x;
      long long P;
      scanf("%d%lld", &x, &P);
      t.insert(t.new_node(node(P)), x);
    } else if (strcmp(c, "DELETE") == 0) {
      int x;
      scanf("%d", &x);
      t.erase_range(x - 1, x);
    } else if (strcmp(c, "MIN") == 0) {
      int x, y;
      scanf("%d%d", &x, &y);
      printf("%lld\n", t.find_range(x - 1, y)->val_min);
    }
  }
}

Practice Problems