Module: Mesa
Branch: main
Commit: e8b3006bfdfdcec4116eee90fbf96161420f73cc
URL:    http://cgit.freedesktop.org/mesa/mesa/commit/?id=e8b3006bfdfdcec4116eee90fbf96161420f73cc

Author: Connor Abbott <cwabbo...@gmail.com>
Date:   Mon Feb  6 17:57:27 2023 +0100

util/rb_tree: Add augmented trees and interval trees

An "augmented tree" is a tree with extra data attached which flows from
the leaves to the root. An "interval tree" is a data structure of
(potentially overlapping) intervals where, in addition to inserting and
removing intervals, we can quickly look up all the intervals which
overlap a given interval.

After describing red-black trees, CLRS explains how it's possible to
implement an interval tree using an augmented red-black tree where the
nodes are ordered by interval start and each node also stores the
maximum interval end for its entire subtree.

Implement the interval tree extension described by CLRS. Iterating over
all overlapping intervals is left as an exercise in CLRS, so we have to
solve the exercise. The natural recursive solution has been rewritten to use
the parent pointers to avoid needing a stack, similarly to rb_tree_first()
and rb_node_next().

For now, we only implement unsigned intervals, but the core algorithms
are all abstracted to allow other types. There's still some boilerplate,
but it's the best that can be done in C.

Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/22071>

---

 src/.clang-format               |   1 +
 src/util/rb_tree.c              | 286 +++++++++++++++++++++++++++++++++++++---
 src/util/rb_tree.h              | 136 ++++++++++++++++++-
 src/util/tests/rb_tree_test.cpp |  87 +++++++++++-
 4 files changed, 486 insertions(+), 24 deletions(-)

diff --git a/src/.clang-format b/src/.clang-format
index 0f0916c369e..802b782ed03 100644
--- a/src/.clang-format
+++ b/src/.clang-format
@@ -62,6 +62,7 @@ ForEachMacros:
   - rb_tree_foreach_rev
   - rb_tree_foreach_rev_safe
   - rb_tree_foreach_safe
+  - uinterval_tree_foreach
 
   - set_foreach
   - set_foreach_remove
diff --git a/src/util/rb_tree.c b/src/util/rb_tree.c
index a3bdd1b4912..98784129b17 100644
--- a/src/util/rb_tree.c
+++ b/src/util/rb_tree.c
@@ -39,6 +39,8 @@
 #include <string.h>
 #include <assert.h>
 
+#include "macros.h"
+
 static bool
 rb_node_is_black(struct rb_node *n)
 {
@@ -118,7 +120,8 @@ rb_tree_splice(struct rb_tree *T, struct rb_node *u, struct rb_node *v)
 }
 
 static void
-rb_tree_rotate_left(struct rb_tree *T, struct rb_node *x)
+rb_tree_rotate_left(struct rb_tree *T, struct rb_node *x,
+                    void (*update)(struct rb_node *))
 {
     assert(x && x->right);
 
@@ -129,10 +132,15 @@ rb_tree_rotate_left(struct rb_tree *T, struct rb_node *x)
     rb_tree_splice(T, x, y);
     y->left = x;
     rb_node_set_parent(x, y);
+    if (update) {
+        update(x);
+        update(y);
+    }
 }
 
 static void
-rb_tree_rotate_right(struct rb_tree *T, struct rb_node *y)
+rb_tree_rotate_right(struct rb_tree *T, struct rb_node *y,
+                     void (*update)(struct rb_node *))
 {
     assert(y && y->left);
 
@@ -143,15 +151,23 @@ rb_tree_rotate_right(struct rb_tree *T, struct rb_node *y)
     rb_tree_splice(T, y, x);
     x->right = y;
     rb_node_set_parent(y, x);
+    if (update) {
+        update(y);
+        update(x);
+    }
 }
 
 void
-rb_tree_insert_at(struct rb_tree *T, struct rb_node *parent,
-                  struct rb_node *node, bool insert_left)
+rb_augmented_tree_insert_at(struct rb_tree *T, struct rb_node *parent,
+                            struct rb_node *node, bool insert_left,
+                            void (*update)(struct rb_node *node))
 {
     /* This sets null children, parent, and a color of red */
     memset(node, 0, sizeof(*node));
 
+    if (update)
+        update(node);
+
     if (parent == NULL) {
         assert(T->root == NULL);
         T->root = node;
@@ -168,6 +184,15 @@ rb_tree_insert_at(struct rb_tree *T, struct rb_node *parent,
     }
     rb_node_set_parent(node, parent);
 
+    /* Refresh every ancestor; rotations below update the nodes they move. */
+    if (update) {
+        struct rb_node *p = parent;
+        while (p) {
+            update(p);
+            p = rb_node_parent(p);
+        }
+    }
+
     /* Now we do the insertion fixup */
     struct rb_node *z = node;
     while (rb_node_is_red(rb_node_parent(z))) {
@@ -185,7 +210,7 @@ rb_tree_insert_at(struct rb_tree *T, struct rb_node *parent,
             } else {
                 if (z == z_p->right) {
                     z = z_p;
-                    rb_tree_rotate_left(T, z);
+                    rb_tree_rotate_left(T, z, update);
                     /* We changed z */
                     z_p = rb_node_parent(z);
                     assert(z == z_p->left || z == z_p->right);
@@ -193,7 +218,7 @@ rb_tree_insert_at(struct rb_tree *T, struct rb_node *parent,
                 }
                 rb_node_set_black(z_p);
                 rb_node_set_red(z_p_p);
-                rb_tree_rotate_right(T, z_p_p);
+                rb_tree_rotate_right(T, z_p_p, update);
             }
         } else {
             struct rb_node *y = z_p_p->left;
@@ -205,7 +230,7 @@ rb_tree_insert_at(struct rb_tree *T, struct rb_node *parent,
             } else {
                 if (z == z_p->left) {
                     z = z_p;
-                    rb_tree_rotate_right(T, z);
+                    rb_tree_rotate_right(T, z, update);
                     /* We changed z */
                     z_p = rb_node_parent(z);
                     assert(z == z_p->left || z == z_p->right);
@@ -213,7 +238,7 @@ rb_tree_insert_at(struct rb_tree *T, struct rb_node *parent,
                 }
                 rb_node_set_black(z_p);
                 rb_node_set_red(z_p_p);
-                rb_tree_rotate_left(T, z_p_p);
+                rb_tree_rotate_left(T, z_p_p, update);
             }
         }
     }
@@ -221,7 +246,8 @@ rb_tree_insert_at(struct rb_tree *T, struct rb_node *parent,
 }
 
 void
-rb_tree_remove(struct rb_tree *T, struct rb_node *z)
+rb_augmented_tree_remove(struct rb_tree *T, struct rb_node *z,
+                         void (*update)(struct rb_node *))
 {
     /* x_p is always the parent node of X.  We have to track this
      * separately because x may be NULL.
@@ -260,6 +286,15 @@ rb_tree_remove(struct rb_tree *T, struct rb_node *z)
 
     assert(x_p == NULL || x == x_p->left || x == x_p->right);
 
+    /* x_p and all its ancestors have changed subtrees; refresh them. */
+    if (update) {
+        struct rb_node *p = x_p;
+        while (p) {
+            update(p);
+            p = rb_node_parent(p);
+        }
+    }
+
     if (!y_was_black)
         return;
 
@@ -270,7 +305,7 @@ rb_tree_remove(struct rb_tree *T, struct rb_node *z)
             if (rb_node_is_red(w)) {
                 rb_node_set_black(w);
                 rb_node_set_red(x_p);
-                rb_tree_rotate_left(T, x_p);
+                rb_tree_rotate_left(T, x_p, update);
                 assert(x == x_p->left);
                 w = x_p->right;
             }
@@ -281,13 +316,13 @@ rb_tree_remove(struct rb_tree *T, struct rb_node *z)
                 if (rb_node_is_black(w->right)) {
                     rb_node_set_black(w->left);
                     rb_node_set_red(w);
-                    rb_tree_rotate_right(T, w);
+                    rb_tree_rotate_right(T, w, update);
                     w = x_p->right;
                 }
                 rb_node_copy_color(w, x_p);
                 rb_node_set_black(x_p);
                 rb_node_set_black(w->right);
-                rb_tree_rotate_left(T, x_p);
+                rb_tree_rotate_left(T, x_p, update);
                 x = T->root;
             }
         } else {
@@ -295,7 +330,7 @@ rb_tree_remove(struct rb_tree *T, struct rb_node *z)
             if (rb_node_is_red(w)) {
                 rb_node_set_black(w);
                 rb_node_set_red(x_p);
-                rb_tree_rotate_right(T, x_p);
+                rb_tree_rotate_right(T, x_p, update);
                 assert(x == x_p->right);
                 w = x_p->left;
             }
@@ -306,13 +341,13 @@ rb_tree_remove(struct rb_tree *T, struct rb_node *z)
                 if (rb_node_is_black(w->left)) {
                     rb_node_set_black(w->right);
                     rb_node_set_red(w);
-                    rb_tree_rotate_left(T, w);
+                    rb_tree_rotate_left(T, w, update);
                     w = x_p->left;
                 }
                 rb_node_copy_color(w, x_p);
                 rb_node_set_black(x_p);
                 rb_node_set_black(w->left);
-                rb_tree_rotate_right(T, x_p);
+                rb_tree_rotate_right(T, x_p, update);
                 x = T->root;
             }
         }
@@ -378,6 +413,229 @@ rb_node_prev(struct rb_node *node)
     }
 }
 
+/* Return the first node in the subtree rooted at "node" that intersects a
+ * given interval or point, or NULL. The interval and max-field tests are
+ * function pointers, so that this works for any type of interval.
+ */
+static struct rb_node *
+rb_node_min_intersecting(struct rb_node *node, const void *interval,
+                         int (*cmp_interval)(const struct rb_node *node,
+                                             const void *interval),
+                         bool (*cmp_max)(const struct rb_node *node,
+                                         const void *interval))
+{
+    if (!cmp_max(node, interval))
+        return NULL;
+
+    while (node) {
+        int cmp = cmp_interval(node, interval);
+
+        /* If the node's interval is entirely to the right of the interval
+         * we're searching for, all of its right descendants are also to the
+         * right and don't intersect so we have to search to the left.
+         */
+        if (cmp > 0) {
+            node = node->left;
+            continue;
+        }
+
+        /* The interval overlaps or is to the left. This must also be true for
+         * its left descendants because their start points are to the left of
+         * node's. We can use the max to tell if there is an interval in its
+         * left descendants which overlaps our interval, in which case we
+         * should descend to the left.
+         */
+        if (node->left && cmp_max(node->left, interval)) {
+            node = node->left;
+            continue;
+        }
+
+        /* Now the only possibilities are the node's interval intersects the
+         * interval or one of its right descendants does.
+         */
+        if (cmp == 0)
+            return node;
+
+        node = node->right;
+        if (node && !cmp_max(node, interval))
+            return NULL;
+    }
+
+    return NULL;
+}
+
+/* Return the next node after "node" that intersects a given interval.
+ *
+ * Because rb_node_min_intersecting() takes O(log n) time and may be called up
+ * to O(log n) times, in addition to the O(log n) crawl up the tree, a naive
+ * runtime analysis would show that this takes O((log n)^2) time, but actually
+ * it's O(log n). Proving this is tricky:
+ *
+ * Let N be the rightmost node in the tree whose start is at or before the
+ * end of the search interval. All nodes after N in the tree are to the
+ * right of the interval. We'll divide the search into two phases: in phase 1,
+ * "node" is *not* an ancestor of N, and in phase 2 it is. Because we always
+ * crawl up the tree, phase 2 cannot turn back into phase 1, but phase 1 may
+ * be followed by phase 2. We'll prove that the calls to
+ * rb_node_min_intersecting() take O(log n) time in both phases.
+ *
+ * Phase 1: Because "node" is to the left of N and N isn't a descendant of
+ * "node", the start of every interval in "node"'s subtree must be less than
+ * or equal to N's start which is less than or equal to the end of the search
+ * interval. Furthermore, either "node"'s max_end is less than the start of
+ * the interval, in which case rb_node_min_intersecting() immediately returns
+ * NULL, or some descendant has an end equal to "node"'s max_end which is
+ * greater than or equal to the search interval's start, and therefore it
+ * intersects the search interval and rb_node_min_intersecting() must return
+ * non-NULL which causes us to terminate. rb_node_min_intersecting() is called
+ * O(log n) times, with all but the last call taking constant time and the
+ * last call taking O(log n), so the overall runtime is O(log n).
+ *
+ * Phase 2: After the first call to rb_node_min_intersecting, we may crawl up
+ * the tree until we get to a node p where "node", and therefore N, is in p's
+ * left subtree. However this means that p is to the right of N in the tree
+ * and is therefore to the right of the search interval, and the search
+ * terminates on the first iteration of the loop so that
+ * rb_node_min_intersecting() is only called once.
+ */
+static struct rb_node *
+rb_node_next_intersecting(struct rb_node *node,
+                          const void *interval,
+                          int (*cmp_interval)(const struct rb_node *node,
+                                              const void *interval),
+                          bool (*cmp_max)(const struct rb_node *node,
+                                          const void *interval))
+{
+    while (true) {
+        /* The first place to search is the node's right subtree. */
+        if (node->right) {
+            struct rb_node *next =
+                rb_node_min_intersecting(node->right, interval, cmp_interval, cmp_max);
+            if (next)
+                return next;
+        }
+
+        /* If we don't find a matching interval there, crawl up the tree until
+         * we find an ancestor to the right. This is the next node after the
+         * right subtree which we determined didn't match.
+         */
+        struct rb_node *p = rb_node_parent(node);
+        while (p && node == p->right) {
+            node = p;
+            p = rb_node_parent(node);
+        }
+        assert(p == NULL || node == p->left);
+
+        /* Check if we've searched everything in the tree. */
+        if (!p)
+            return NULL;
+
+        int cmp = cmp_interval(p, interval);
+
+        /* If it intersects, return it. */
+        if (cmp == 0)
+            return p;
+
+        /* If it's to the right of the interval, all following nodes will be
+         * to the right and we can bail early.
+         */
+        if (cmp > 0)
+            return NULL;
+
+        node = p;
+    }
+}
+
+static int
+uinterval_cmp(struct uinterval a, struct uinterval b)
+{
+    /* Closed intervals: <0 if a ends before b, >0 if a starts after b. */
+    if (a.end < b.start)
+        return -1;
+    if (b.end < a.start)
+        return 1;
+    return 0;
+}
+
+static int
+uinterval_node_cmp(const struct rb_node *_a, const struct rb_node *_b)
+{
+    unsigned a = rb_node_data(struct uinterval_node, _a, node)->interval.start;
+    unsigned b = rb_node_data(struct uinterval_node, _b, node)->interval.start;
+    /* Sign of (b - a) without unsigned wraparound or int truncation. */
+    return (a < b) - (a > b);
+}
+
+static int
+uinterval_search_cmp(const struct rb_node *_node, const void *_interval)
+{
+    const struct uinterval_node *node = rb_node_data(struct uinterval_node, _node, node);
+    const struct uinterval *interval = _interval;
+    /* <0: node ends before interval, >0: node starts after it, 0: overlap. */
+    return uinterval_cmp(node->interval, *interval);
+}
+
+static bool
+uinterval_max_cmp(const struct rb_node *_node, const void *data)
+{
+    const struct uinterval_node *node = rb_node_data(struct uinterval_node, _node, node);
+    const struct uinterval *interval = data;
+    /* If false, nothing in this subtree ends at or after interval->start. */
+    return node->max_end >= interval->start;
+}
+
+static void
+uinterval_update_max(struct rb_node *_node)
+{
+    struct uinterval_node *node = rb_node_data(struct uinterval_node, _node, node);
+    node->max_end = node->interval.end;
+    if (node->node.left) {
+        struct uinterval_node *left = rb_node_data(struct uinterval_node, node->node.left, node);
+        node->max_end = MAX2(node->max_end, left->max_end);
+    }
+    if (node->node.right) {
+        struct uinterval_node *right = rb_node_data(struct uinterval_node, node->node.right, node);
+        node->max_end = MAX2(node->max_end, right->max_end);
+    }
+}
+
+void
+uinterval_tree_insert(struct rb_tree *tree, struct uinterval_node *node)
+{
+    struct rb_node *rb = &node->node; /* nodes are ordered by interval.start */
+    rb_augmented_tree_insert(tree, rb, uinterval_node_cmp, uinterval_update_max);
+}
+
+void
+uinterval_tree_remove(struct rb_tree *tree, struct uinterval_node *node)
+{
+    rb_augmented_tree_remove(tree, &(node->node), uinterval_update_max);
+}
+
+struct uinterval_node *
+uinterval_tree_first(struct rb_tree *tree, struct uinterval interval)
+{
+    struct rb_node *found = NULL;
+
+    if (tree->root != NULL)
+        found = rb_node_min_intersecting(tree->root, &interval,
+                                         uinterval_search_cmp, uinterval_max_cmp);
+
+    /* Map the embedded rb_node back to its containing uinterval_node. */
+    return found ? rb_node_data(struct uinterval_node, found, node) : NULL;
+}
+
+struct uinterval_node *
+uinterval_node_next(struct uinterval_node *node, struct uinterval interval)
+{
+    struct rb_node *succ = rb_node_next_intersecting(&node->node, &interval,
+                                                     uinterval_search_cmp,
+                                                     uinterval_max_cmp);
+    if (succ == NULL)
+        return NULL;
+    return rb_node_data(struct uinterval_node, succ, node);
+}
+
 static void
 validate_rb_node(struct rb_node *n, int black_depth)
 {
diff --git a/src/util/rb_tree.h b/src/util/rb_tree.h
index 5e00977b5ba..b5400306fa4 100644
--- a/src/util/rb_tree.h
+++ b/src/util/rb_tree.h
@@ -117,6 +117,36 @@ struct rb_node *rb_node_prev(struct rb_node *node);
 #define rb_node_data(type, node, field) \
     ((type *)(((char *)(node)) - rb_tree_offsetof(type, field, node)))
 
+/** Insert a node into a possibly augmented tree at a particular location
+ *
+ * This function should probably not be used directly as it relies on the
+ * caller to ensure that the parent node is correct.  Use
+ * rb_augmented_tree_insert() instead.
+ *
+ * If \p update is non-NULL, it will be called for the node being inserted as
+ * well as any nodes which have their children changed and all of their
+ * ancestors. The intent is that each node may contain some augmented data
+ * which is calculated recursively from the node itself and its children, and
+ * \p update should recalculate that data. It's assumed that the function used
+ * to calculate the node data is associative in order to avoid calling it
+ * redundantly after rebalancing the tree.
+ *
+ * \param   T           The red-black tree into which to insert the new node
+ *
+ * \param   parent      The node in the tree that will be the parent of the
+ *                      newly inserted node
+ *
+ * \param   node        The node to insert
+ *
+ * \param   insert_left If true, the new node will be the left child of
+ *                      \p parent, otherwise it will be the right child
+ *
+ * \param   update      The optional function used to calculate per-node data
+ */
+void rb_augmented_tree_insert_at(struct rb_tree *T, struct rb_node *parent,
+                                 struct rb_node *node, bool insert_left,
+                                 void (*update)(struct rb_node *));
+
 /** Insert a node into a tree at a particular location
  *
  * This function should probably not be used directly as it relies on the
@@ -133,20 +163,27 @@ struct rb_node *rb_node_prev(struct rb_node *node);
  * \param   insert_left If true, the new node will be the left child of
  *                      \p parent, otherwise it will be the right child
  */
-void rb_tree_insert_at(struct rb_tree *T, struct rb_node *parent,
-                       struct rb_node *node, bool insert_left);
+static inline void
+rb_tree_insert_at(struct rb_tree *T, struct rb_node *parent,
+                  struct rb_node *node, bool insert_left)
+{
+    rb_augmented_tree_insert_at(T, parent, node, insert_left, NULL);
+}
 
-/** Insert a node into a tree
+/** Insert a node into a possibly augmented tree
  *
  * \param   T       The red-black tree into which to insert the new node
  *
  * \param   node    The node to insert
  *
  * \param   cmp     A comparison function to use to order the nodes.
+ *
+ * \param   update  Same meaning as in rb_augmented_tree_insert_at()
  */
 static inline void
-rb_tree_insert(struct rb_tree *T, struct rb_node *node,
-               int (*cmp)(const struct rb_node *, const struct rb_node *))
+rb_augmented_tree_insert(struct rb_tree *T, struct rb_node *node,
+                         int (*cmp)(const struct rb_node *, const struct rb_node *),
+                         void (*update)(struct rb_node *))
 {
     /* This function is declared inline in the hopes that the compiler can
      * optimize away the comparison function pointer call.
@@ -163,16 +200,47 @@ rb_tree_insert(struct rb_tree *T, struct rb_node *node,
             x = x->right;
     }
 
-    rb_tree_insert_at(T, y, node, left);
+    rb_augmented_tree_insert_at(T, y, node, left, update);
 }
 
+/** Insert a node into a tree
+ *
+ * \param   T       The red-black tree into which to insert the new node
+ *
+ * \param   node    The node to insert
+ *
+ * \param   cmp     A comparison function to use to order the nodes.
+ */
+static inline void
+rb_tree_insert(struct rb_tree *T, struct rb_node *node,
+               int (*cmp)(const struct rb_node *, const struct rb_node *))
+{
+    rb_augmented_tree_insert(T, node, cmp, NULL);
+}
+
+/** Remove a node from a possibly augmented tree
+ *
+ * \param   T       The red-black tree from which to remove the node
+ *
+ * \param   z       The node to remove
+ *
+ * \param   update  Same meaning as in rb_augmented_tree_insert_at()
+ *
+ */
+void rb_augmented_tree_remove(struct rb_tree *T, struct rb_node *z,
+                              void (*update)(struct rb_node *));
+
 /** Remove a node from a tree
  *
  * \param   T       The red-black tree from which to remove the node
  *
  * \param   node    The node to remove
  */
-void rb_tree_remove(struct rb_tree *T, struct rb_node *z);
+static inline void
+rb_tree_remove(struct rb_tree *T, struct rb_node *z)
+{
+    rb_augmented_tree_remove(T, z, NULL);
+}
 
 /** Search the tree for a node
  *
@@ -332,6 +400,60 @@ rb_tree_search_sloppy(struct rb_tree *T, const void *key,
         __node = __prev, \
         __prev = (type *)rb_node_prev_or_null((struct rb_node *)__node))
 
+/** Unsigned interval
+ *
+ * Intervals are closed by convention: both start and end are included.
+ */
+struct uinterval {
+    unsigned start, end;
+};
+
+struct uinterval_node {
+    struct rb_node node;
+
+    /* Must be filled in before inserting and not changed while inserted */
+    struct uinterval interval;
+
+    /* Maximum interval.end in this node's subtree, maintained by the tree */
+    unsigned max_end;
+};
+
+/** Insert a node into an unsigned interval tree. */
+void uinterval_tree_insert(struct rb_tree *tree, struct uinterval_node *node);
+
+/** Remove a node from an unsigned interval tree. */
+void uinterval_tree_remove(struct rb_tree *tree, struct uinterval_node *node);
+
+/** Get the first node intersecting the given interval. */
+struct uinterval_node *uinterval_tree_first(struct rb_tree *tree,
+                                            struct uinterval interval);
+
+/** Get the next node after \p node intersecting the given interval. */
+struct uinterval_node *uinterval_node_next(struct uinterval_node *node,
+                                           struct uinterval interval);
+
+/** Iterate over the nodes in the tree intersecting the given interval
+ *
+ * The loop should take O(k log n) time for k iterations in a tree of n
+ * nodes. \p interval is re-evaluated on every step, so avoid side effects.
+ *
+ * \param   type    The type of the containing data structure
+ *
+ * \param   iter    The variable name for current node in the iteration;
+ *                  this will be declared as a pointer to \p type
+ *
+ * \param   interval The struct uinterval to be tested against.
+ *
+ * \param   T       The red-black tree
+ *
+ * \param   field   The uinterval_node field in containing data structure
+ */
+#define uinterval_tree_foreach(type, iter, interval, T, field) \
+    for (type *iter, *__node = (type *)uinterval_tree_first(T, interval); \
+        __node != NULL && \
+        (iter = rb_node_data(type, (struct uinterval_node *)__node, field), true); \
+        __node = (type *)uinterval_node_next((struct uinterval_node *)__node, interval))
+
 /** Validate a red-black tree
  *
  * This function walks the tree and validates that this is a valid red-
diff --git a/src/util/tests/rb_tree_test.cpp b/src/util/tests/rb_tree_test.cpp
index 2676bd52b15..60680947637 100644
--- a/src/util/tests/rb_tree_test.cpp
+++ b/src/util/tests/rb_tree_test.cpp
@@ -28,6 +28,8 @@
 #include <gtest/gtest.h>
 #include <limits.h>
 
+#include "macros.h"
+
 /* A list of 100 random numbers from 1 to 100.  The number 30 is explicitly
  * missing from this list.
  */
@@ -46,8 +48,6 @@ int test_numbers[] = {
 
 #define NON_EXISTANT_NUMBER 30
 
-#define ARRAY_SIZE(a) (sizeof(a) / sizeof(*a))
-
 struct rb_test_node {
     int key;
     struct rb_node node;
@@ -283,3 +283,86 @@ TEST(RBTreeTest, FindFirstOfMiddle)
 
     EXPECT_NE(rb_test_node_cmp(prev, n), 0);
 }
+
+struct uinterval_test_node {
+    struct uinterval_node node;
+};
+
+static void
+validate_interval_search(struct rb_tree *tree,
+                         struct uinterval_test_node *nodes,
+                         int first_node, int last_node,
+                         unsigned start,
+                         unsigned end)
+{
+    /* Count the number of intervals intersecting */
+    unsigned actual_count = 0;
+    for (int i = first_node; i <= last_node; i++) {
+        if (nodes[i].node.interval.start <= end &&
+            nodes[i].node.interval.end >= start)
+            actual_count++;
+    }
+
+    /* iterate over matching intervals */
+    struct uinterval interval = { start, end };
+    unsigned max_val = 0;
+    struct uinterval_test_node *prev = NULL;
+    unsigned count = 0;
+    uinterval_tree_foreach (struct uinterval_test_node, n, interval, tree, node) {
+        EXPECT_TRUE(n->node.interval.start <= end &&
+                    n->node.interval.end >= start);
+
+        /* Everything should be in increasing order */
+        EXPECT_GE(n->node.interval.start, max_val);
+        if (n->node.interval.start > max_val) {
+            max_val = n->node.interval.start;
+        } else {
+            /* Things should be stable, i.e., given equal keys, they should
+             * show up in the list in order of insertion.  We insert them
+             * in the order they are in the array.
+             */
+            EXPECT_TRUE(prev == NULL || prev < n);
+        }
+
+        prev = n;
+        count++;
+    }
+
+    EXPECT_EQ(count, actual_count);
+}
+
+TEST(IntervalTreeTest, InsertAndSearch)
+{
+    struct uinterval_test_node nodes[ARRAY_SIZE(test_numbers) / 2];
+    struct rb_tree tree;
+
+    rb_tree_init(&tree);
+    /* Each node's interval spans one consecutive pair of test_numbers. */
+    for (unsigned i = 0; 2 * i + 1 < ARRAY_SIZE(test_numbers); i++) {
+        nodes[i].node.interval.start = MIN2(test_numbers[2 * i], test_numbers[2 * i + 1]);
+        nodes[i].node.interval.end = MAX2(test_numbers[2 * i], test_numbers[2 * i + 1]);
+        uinterval_tree_insert(&tree, &nodes[i].node);
+        rb_tree_validate(&tree);
+        validate_interval_search(&tree, nodes, 0, i, 0, 100);
+        validate_interval_search(&tree, nodes, 0, i, 0, 50);
+        validate_interval_search(&tree, nodes, 0, i, 50, 100);
+        validate_interval_search(&tree, nodes, 0, i, 0, 2);
+    }
+
+    for (unsigned i = 0; 2 * i + 1 < ARRAY_SIZE(test_numbers); i++) {
+        uinterval_tree_remove(&tree, &nodes[i].node);
+        rb_tree_validate(&tree);
+        validate_interval_search(&tree, nodes, i + 1,
+                                 ARRAY_SIZE(test_numbers) / 2 - 1,
+                                 0, 100);
+        validate_interval_search(&tree, nodes, i + 1,
+                                 ARRAY_SIZE(test_numbers) / 2 - 1,
+                                 0, 50);
+        validate_interval_search(&tree, nodes, i + 1,
+                                 ARRAY_SIZE(test_numbers) / 2 - 1,
+                                 50, 100);
+        validate_interval_search(&tree, nodes, i + 1,
+                                 ARRAY_SIZE(test_numbers) / 2 - 1,
+                                 0, 2);
+    }
+}

Reply via email to