Skip to content

Commit

Permalink
ported avl trees (tsearch) from musl
Browse files Browse the repository at this point in the history
  • Loading branch information
stsp committed Oct 11, 2024
1 parent 60500be commit e64bca9
Show file tree
Hide file tree
Showing 3 changed files with 211 additions and 0 deletions.
6 changes: 6 additions & 0 deletions include/search.h
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,12 @@ void * lfind(const void *_key, void *_base, size_t *_nelp, size_t _width,
void * lsearch(const void *_key, void *_base, size_t *_nelp, size_t _width,
int(*_compar)(const void *, const void *));

typedef enum { preorder, postorder, endorder, leaf } VISIT;
void *tdelete(const void *, void **, int(*)(const void *, const void *));
void *tfind(const void *, void *const *, int(*)(const void *, const void *));
void *tsearch(const void *, void **, int (*)(const void *, const void *));
void twalk(const void *, void (*)(const void *, VISIT, int));

#ifndef _POSIX_SOURCE

/* qelem type for insque and remque. */
Expand Down
1 change: 1 addition & 0 deletions src/libc/posix/search/makefile
Original file line number Diff line number Diff line change
Expand Up @@ -5,5 +5,6 @@ TOP=../..
SRC += _hash2v.c
SRC += hcreate.c
SRC += lfind.c
SRC += tsearch_avl.c

include $(TOP)/../makefile.inc
204 changes: 204 additions & 0 deletions src/libc/posix/search/tsearch_avl.c
Original file line number Diff line number Diff line change
@@ -0,0 +1,204 @@
#include <stdlib.h>
#include <search.h>

/*
avl tree implementation using recursive functions
the height of an n node tree is less than 1.44*log2(n+2)-1
(so the max recursion depth in case of a tree with 2^32 nodes is 45)
*/

struct node {
const void *key;
struct node *left;
struct node *right;
int height;
};

static int delta(struct node *n) {
return (n->left ? n->left->height:0) - (n->right ? n->right->height:0);
}

static void updateheight(struct node *n) {
n->height = 0;
if (n->left && n->left->height > n->height)
n->height = n->left->height;
if (n->right && n->right->height > n->height)
n->height = n->right->height;
n->height++;
}

static struct node *rotl(struct node *n) {
struct node *r = n->right;
n->right = r->left;
r->left = n;
updateheight(n);
updateheight(r);
return r;
}

static struct node *rotr(struct node *n) {
struct node *l = n->left;
n->left = l->right;
l->right = n;
updateheight(n);
updateheight(l);
return l;
}

static struct node *balance(struct node *n) {
int d = delta(n);

if (d < -1) {
if (delta(n->right) > 0)
n->right = rotr(n->right);
return rotl(n);
} else if (d > 1) {
if (delta(n->left) < 0)
n->left = rotl(n->left);
return rotr(n);
}
updateheight(n);
return n;
}

static struct node *find(struct node *n, const void *k,
int (*cmp)(const void *, const void *))
{
int c;

if (!n)
return 0;
c = cmp(k, n->key);
if (c == 0)
return n;
if (c < 0)
return find(n->left, k, cmp);
else
return find(n->right, k, cmp);
}

static struct node *insert(struct node *n, const void *k,
int (*cmp)(const void *, const void *), struct node **found)
{
struct node *r;
int c;

if (!n) {
n = malloc(sizeof *n);
if (n) {
n->key = k;
n->left = n->right = 0;
n->height = 1;
}
*found = n;
return n;
}
c = cmp(k, n->key);
if (c == 0) {
*found = n;
return 0;
}
r = insert(c < 0 ? n->left : n->right, k, cmp, found);
if (r) {
if (c < 0)
n->left = r;
else
n->right = r;
r = balance(n);
}
return r;
}

static struct node *remove_rightmost(struct node *n, struct node **rightmost)
{
if (!n->right) {
*rightmost = n;
return n->left;
}
n->right = remove_rightmost(n->right, rightmost);
return balance(n);
}

static struct node *remove(struct node **n, const void *k,
int (*cmp)(const void *, const void *), struct node *parent)
{
int c;

if (!*n)
return 0;
c = cmp(k, (*n)->key);
if (c == 0) {
struct node *r = *n;
if (r->left) {
r->left = remove_rightmost(r->left, n);
(*n)->left = r->left;
(*n)->right = r->right;
*n = balance(*n);
} else
*n = r->right;
free(r);
return parent;
}
if (c < 0)
parent = remove(&(*n)->left, k, cmp, *n);
else
parent = remove(&(*n)->right, k, cmp, *n);
if (parent)
*n = balance(*n);
return parent;
}

void *tdelete(const void *restrict key, void **restrict rootp,
int(*compar)(const void *, const void *))
{
if (!rootp)
return 0;
struct node *n = *rootp;
struct node *ret;
/* last argument is arbitrary non-null pointer
which is returned when the root node is deleted */
ret = remove(&n, key, compar, n);
*rootp = n;
return ret;
}

void *tfind(const void *key, void *const *rootp,
int(*compar)(const void *, const void *))
{
if (!rootp)
return 0;
return find(*rootp, key, compar);
}

void *tsearch(const void *key, void **rootp,
int (*compar)(const void *, const void *))
{
struct node *update;
struct node *ret;
if (!rootp)
return 0;
update = insert(*rootp, key, compar, &ret);
if (update)
*rootp = update;
return ret;
}

static void walk(const struct node *r, void (*action)(const void *, VISIT, int), int d)
{
if (r == 0)
return;
if (r->left == 0 && r->right == 0)
action(r, leaf, d);
else {
action(r, preorder, d);
walk(r->left, action, d+1);
action(r, postorder, d);
walk(r->right, action, d+1);
action(r, endorder, d);
}
}

void twalk(const void *root, void (*action)(const void *, VISIT, int))
{
walk(root, action, 0);
}

0 comments on commit e64bca9

Please sign in to comment.