1/*
2 * ctrees.c
3 *
4 *  Author: mozman
5 *  Copyright (c) 2010-2013 by Manfred Moitzi
6 *  License: MIT-License
7 */
8
9#include "ctrees.h"
10#include "stack.h"
11#include <Python.h>
12
#define LEFT 0
#define RIGHT 1

/*
 * Field accessors for node_t.  Every argument is parenthesized so that any
 * pointer expression (e.g. `up[top]`, `p + 1`, `&head`) expands correctly.
 */
#define KEY(node) ((node)->key)
#define VALUE(node) ((node)->value)
#define LEFT_NODE(node) ((node)->link[LEFT])
#define RIGHT_NODE(node) ((node)->link[RIGHT])
#define LINK(node, dir) ((node)->link[(dir)])

/*
 * xdata is a per-tree-type scratch field with three aliases:
 *   RED     - red-black trees: 1 = red, anything else = black
 *   BALANCE - AVL trees: height of the subtree rooted at the node (leaf = 0)
 * The plain binary-tree functions create nodes with xdata = 0.
 */
#define XDATA(node) ((node)->xdata)
#define RED(node) ((node)->xdata)
#define BALANCE(node) ((node)->xdata)
23
24static node_t *
25ct_new_node(PyObject *key, PyObject *value, int xdata)
26{
27	node_t *new_node = PyMem_Malloc(sizeof(node_t));
28	if (new_node != NULL) {
29		KEY(new_node) = key;
30		Py_INCREF(key);
31		VALUE(new_node) = value;
32		Py_INCREF(value);
33		LEFT_NODE(new_node) = NULL;
34		RIGHT_NODE(new_node) = NULL;
35		XDATA(new_node) = xdata;
36	}
37	return new_node;
38}
39
40static void
41ct_delete_node(node_t *node)
42{
43	if (node != NULL) {
44		Py_XDECREF(KEY(node));
45		Py_XDECREF(VALUE(node));
46		LEFT_NODE(node) = NULL;
47		RIGHT_NODE(node) = NULL;
48		PyMem_Free(node);
49	}
50}
51
52extern void
53ct_delete_tree(node_t *root)
54{
55	if (root == NULL)
56		return;
57	if (LEFT_NODE(root) != NULL) {
58		ct_delete_tree(LEFT_NODE(root));
59	}
60	if (RIGHT_NODE(root) != NULL) {
61		ct_delete_tree(RIGHT_NODE(root));
62	}
63	ct_delete_node(root);
64}
65
66static void
67ct_swap_data(node_t *node1, node_t *node2)
68{
69	PyObject *tmp;
70	tmp = KEY(node1);
71	KEY(node1) = KEY(node2);
72	KEY(node2) = tmp;
73	tmp = VALUE(node1);
74	VALUE(node1) = VALUE(node2);
75	VALUE(node2) = tmp;
76}
77
78int
79ct_compare(PyObject *key1, PyObject *key2)
80{
81	int res;
82
83	res = PyObject_RichCompareBool(key1, key2, Py_LT);
84	if (res > 0)
85		return -1;
86	else if (res < 0) {
87		PyErr_SetString(PyExc_TypeError, "invalid type for key");
88		return 0;
89		}
90	/* second compare:
91	+1 if key1 > key2
92	 0 if not -> equal
93	-1 means error, if error, it should happend at the first compare
94	*/
95	return PyObject_RichCompareBool(key1, key2, Py_GT);
96}
97
98extern node_t *
99ct_find_node(node_t *root, PyObject *key)
100{
101	int res;
102	while (root != NULL) {
103		res = ct_compare(key, KEY(root));
104		if (res == 0) /* key found */
105			return root;
106		else {
107			root = LINK(root, (res > 0));
108		}
109	}
110	return NULL; /* key not found */
111}
112
113extern PyObject *
114ct_get_item(node_t *root, PyObject *key)
115{
116	node_t *node;
117	PyObject *tuple;
118
119	node = ct_find_node(root, key);
120	if (node != NULL) {
121		tuple = PyTuple_New(2);
122		PyTuple_SET_ITEM(tuple, 0, KEY(node));
123		PyTuple_SET_ITEM(tuple, 1, VALUE(node));
124		return tuple;
125	}
126	Py_RETURN_NONE;
127}
128
129extern node_t *
130ct_max_node(node_t *root)
131/* get node with largest key */
132{
133	if (root == NULL)
134		return NULL;
135	while (RIGHT_NODE(root) != NULL)
136		root = RIGHT_NODE(root);
137	return root;
138}
139
140extern node_t *
141ct_min_node(node_t *root)
142// get node with smallest key
143{
144	if (root == NULL)
145		return NULL;
146	while (LEFT_NODE(root) != NULL)
147		root = LEFT_NODE(root);
148	return root;
149}
150
151extern int
152ct_bintree_remove(node_t **rootaddr, PyObject *key)
153/* attention: rootaddr is the address of the root pointer */
154{
155	node_t *node, *parent, *replacement;
156	int direction, cmp_res, down_dir;
157
158	node = *rootaddr;
159
160	if (node == NULL)
161		return 0; /* root is NULL */
162	parent = NULL;
163	direction = 0;
164
165	while (1) {
166		cmp_res = ct_compare(key, KEY(node));
167		if (cmp_res == 0) /* key found, remove node */
168		{
169			if ((LEFT_NODE(node) != NULL) && (RIGHT_NODE(node) != NULL)) {
170				/* find replacement node: smallest key in right-subtree */
171				parent = node;
172				direction = RIGHT;
173				replacement = RIGHT_NODE(node);
174				while (LEFT_NODE(replacement) != NULL) {
175					parent = replacement;
176					direction = LEFT;
177					replacement = LEFT_NODE(replacement);
178				}
179				LINK(parent, direction) = RIGHT_NODE(replacement);
180				/* swap places */
181				ct_swap_data(node, replacement);
182				node = replacement; /* delete replacement node */
183			}
184			else {
185				down_dir = (LEFT_NODE(node) == NULL) ? RIGHT : LEFT;
186				if (parent == NULL) /* root */
187				{
188					*rootaddr = LINK(node, down_dir);
189				}
190				else {
191					LINK(parent, direction) = LINK(node, down_dir);
192				}
193			}
194			ct_delete_node(node);
195			return 1; /* remove was success full */
196		}
197		else {
198			direction = (cmp_res < 0) ? LEFT : RIGHT;
199			parent = node;
200			node = LINK(node, direction);
201			if (node == NULL)
202				return 0; /* error key not found */
203		}
204	}
205}
206
207extern int
208ct_bintree_insert(node_t **rootaddr, PyObject *key, PyObject *value)
209/* attention: rootaddr is the address of the root pointer */
210{
211	node_t *parent, *node;
212	int direction, cval;
213	node = *rootaddr;
214	if (node == NULL) {
215		node = ct_new_node(key, value, 0); /* new node is also the root */
216		if (node == NULL)
217			return -1; /* got no memory */
218		*rootaddr = node;
219	}
220	else {
221		direction = LEFT;
222		parent = NULL;
223		while (1) {
224			if (node == NULL) {
225				node = ct_new_node(key, value, 0);
226				if (node == NULL)
227					return -1; /* get no memory */
228				LINK(parent, direction) = node;
229				return 1;
230			}
231			cval = ct_compare(key, KEY(node));
232			if (cval == 0) {
233				/* key exists, replace value object */
234				Py_XDECREF(VALUE(node)); /* release old value object */
235				VALUE(node) = value; /* set new value object */
236				Py_INCREF(value); /* take new value object */
237				return 0;
238			}
239			else {
240				parent = node;
241				direction = (cval < 0) ? LEFT : RIGHT;
242				node = LINK(node, direction);
243			}
244		}
245	}
246	return 1;
247}
248
249static int
250is_red (node_t *node)
251{
252	return (node != NULL) && (RED(node) == 1);
253}
254
255#define rb_new_node(key, value) ct_new_node(key, value, 1)
256
257static node_t *
258rb_single(node_t *root, int dir)
259{
260	node_t *save = root->link[!dir];
261
262	root->link[!dir] = save->link[dir];
263	save->link[dir] = root;
264
265	RED(root) = 1;
266	RED(save) = 0;
267	return save;
268}
269
270static node_t *
271rb_double(node_t *root, int dir)
272{
273	root->link[!dir] = rb_single(root->link[!dir], !dir);
274	return rb_single(root, dir);
275}
276
277#define rb_new_node(key, value) ct_new_node(key, value, 1)
278
279extern int
280rb_insert(node_t **rootaddr, PyObject *key, PyObject *value)
281{
282	node_t *root = *rootaddr;
283
284	if (root == NULL) {
285		/*
286		 We have an empty tree; attach the
287		 new node directly to the root
288		 */
289		root = rb_new_node(key, value);
290		if (root == NULL)
291			return -1; // got no memory
292	}
293	else {
294		node_t head; /* False tree root */
295		node_t *g, *t; /* Grandparent & parent */
296		node_t *p, *q; /* Iterator & parent */
297		int dir = 0;
298		int last = 0;
299		int new_node = 0;
300
301		/* Set up our helpers */
302		t = &head;
303		g = NULL;
304		p = NULL;
305		RIGHT_NODE(t) = root;
306		LEFT_NODE(t) = NULL;
307		q = RIGHT_NODE(t);
308
309		/* Search down the tree for a place to insert */
310		for (;;) {
311			int cmp_res;
312			if (q == NULL) {
313				/* Insert a new node at the first null link */
314				q = rb_new_node(key, value);
315				p->link[dir] = q;
316				new_node = 1;
317				if (q == NULL)
318					return -1; // get no memory
319			}
320			else if (is_red(q->link[0]) && is_red(q->link[1])) {
321				/* Simple red violation: color flip */
322				RED(q) = 1;
323				RED(q->link[0]) = 0;
324				RED(q->link[1]) = 0;
325			}
326
327			if (is_red(q) && is_red(p)) {
328				/* Hard red violation: rotations necessary */
329				int dir2 = (t->link[1] == g);
330
331				if (q == p->link[last])
332					t->link[dir2] = rb_single(g, !last);
333				else
334					t->link[dir2] = rb_double(g, !last);
335			}
336
337			/*  Stop working if we inserted a new node. */
338			if (new_node)
339				break;
340
341			cmp_res = ct_compare(KEY(q), key);
342			if (cmp_res == 0) {       /* key exists?              */
343				Py_XDECREF(VALUE(q)); /* release old value object */
344				VALUE(q) = value;     /* set new value object     */
345				Py_INCREF(value);     /* take new value object    */
346				return 0;
347			}
348			last = dir;
349			dir = (cmp_res < 0);
350
351			/* Move the helpers down */
352			if (g != NULL)
353				t = g;
354
355			g = p;
356			p = q;
357			q = q->link[dir];
358		}
359		/* Update the root (it may be different) */
360		root = head.link[1];
361	}
362
363	/* Make the root black for simplified logic */
364	RED(root) = 0;
365	(*rootaddr) = root;
366	return 1;
367}
368
369extern int
370rb_remove(node_t **rootaddr, PyObject *key)
371{
372	node_t *root = *rootaddr;
373
374	node_t head = { { NULL } }; /* False tree root */
375	node_t *q, *p, *g; /* Helpers */
376	node_t *f = NULL; /* Found item */
377	int dir = 1;
378
379	if (root == NULL)
380		return 0;
381
382	/* Set up our helpers */
383	q = &head;
384	g = p = NULL;
385	RIGHT_NODE(q) = root;
386
387	/*
388	 Search and push a red node down
389	 to fix red violations as we go
390	 */
391	while (q->link[dir] != NULL) {
392		int last = dir;
393		int cmp_res;
394
395		/* Move the helpers down */
396		g = p, p = q;
397		q = q->link[dir];
398
399		cmp_res =  ct_compare(KEY(q), key);
400
401		dir = cmp_res < 0;
402
403		/*
404		 Save the node with matching data and keep
405		 going; we'll do removal tasks at the end
406		 */
407		if (cmp_res == 0)
408			f = q;
409
410		/* Push the red node down with rotations and color flips */
411		if (!is_red(q) && !is_red(q->link[dir])) {
412			if (is_red(q->link[!dir]))
413				p = p->link[last] = rb_single(q, dir);
414			else if (!is_red(q->link[!dir])) {
415				node_t *s = p->link[!last];
416
417				if (s != NULL) {
418					if (!is_red(s->link[!last]) &&
419						!is_red(s->link[last])) {
420						/* Color flip */
421						RED(p) = 0;
422						RED(s) = 1;
423						RED(q) = 1;
424					}
425					else {
426						int dir2 = g->link[1] == p;
427
428						if (is_red(s->link[last]))
429							g->link[dir2] = rb_double(p, last);
430						else if (is_red(s->link[!last]))
431							g->link[dir2] = rb_single(p, last);
432
433						/* Ensure correct coloring */
434						RED(q) = RED(g->link[dir2]) = 1;
435						RED(g->link[dir2]->link[0]) = 0;
436						RED(g->link[dir2]->link[1]) = 0;
437					}
438				}
439			}
440		}
441	}
442
443	/* Replace and remove the saved node */
444	if (f != NULL) {
445		ct_swap_data(f, q);
446		p->link[p->link[1] == q] = q->link[q->link[0] == NULL];
447		ct_delete_node(q);
448	}
449
450	/* Update the root (it may be different) */
451	root = head.link[1];
452
453	/* Make the root black for simplified logic */
454	if (root != NULL)
455		RED(root) = 0;
456	*rootaddr = root;
457	return (f != NULL);
458}
459
/* AVL nodes are created as leaves; xdata stores the subtree height (leaf = 0). */
#define avl_new_node(key, value) ct_new_node(key, value, 0)
/* Stored height of a subtree; an empty subtree has height -1. */
#define height(p) ((p) != NULL ? (p)->xdata : -1)
/* Larger of two integers. */
#define avl_max(a, b) ((a) < (b) ? (b) : (a))
463
464static node_t *
465avl_single(node_t *root, int dir)
466{
467  node_t *save = root->link[!dir];
468	int rlh, rrh, slh;
469
470	/* Rotate */
471	root->link[!dir] = save->link[dir];
472	save->link[dir] = root;
473
474	/* Update balance factors */
475	rlh = height(root->link[0]);
476	rrh = height(root->link[1]);
477	slh = height(save->link[!dir]);
478
479	BALANCE(root) = avl_max(rlh, rrh) + 1;
480	BALANCE(save) = avl_max(slh, BALANCE(root)) + 1;
481
482	return save;
483}
484
485static node_t *
486avl_double(node_t *root, int dir)
487{
488	root->link[!dir] = avl_single(root->link[!dir], !dir);
489	return avl_single(root, dir);
490}
491
492extern int
493avl_insert(node_t **rootaddr, PyObject *key, PyObject *value)
494{
495	node_t *root = *rootaddr;
496
497	if (root == NULL) {
498		root = avl_new_node(key, value);
499		if (root == NULL)
500			return -1; // got no memory
501	}
502	else {
503		node_t *it, *up[32];
504		int upd[32], top = 0;
505		int done = 0;
506		int cmp_res;
507
508		it = root;
509		/* Search for an empty link, save the path */
510		for (;;) {
511			/* Push direction and node onto stack */
512			cmp_res = ct_compare(KEY(it), key);
513			if (cmp_res == 0) {
514				Py_XDECREF(VALUE(it)); // release old value object
515				VALUE(it) = value; // set new value object
516				Py_INCREF(value); // take new value object
517				return 0;
518			}
519			// upd[top] = it->data < data;
520			upd[top] = (cmp_res < 0);
521			up[top++] = it;
522
523			if (it->link[upd[top - 1]] == NULL)
524				break;
525			it = it->link[upd[top - 1]];
526		}
527
528		/* Insert a new node at the bottom of the tree */
529		it->link[upd[top - 1]] = avl_new_node(key, value);
530		if (it->link[upd[top - 1]] == NULL)
531			return -1; // got no memory
532
533		/* Walk back up the search path */
534		while (--top >= 0 && !done) {
535			// int dir = (cmp_res < 0);
536			int lh, rh, max;
537
538			cmp_res = ct_compare(KEY(up[top]), key);
539
540			lh = height(up[top]->link[upd[top]]);
541			rh = height(up[top]->link[!upd[top]]);
542
543			/* Terminate or rebalance as necessary */
544			if (lh - rh == 0)
545				done = 1;
546			if (lh - rh >= 2) {
547				node_t *a = up[top]->link[upd[top]]->link[upd[top]];
548				node_t *b = up[top]->link[upd[top]]->link[!upd[top]];
549
550				if (height( a ) >= height( b ))
551					up[top] = avl_single(up[top], !upd[top]);
552				else
553					up[top] = avl_double(up[top], !upd[top]);
554
555				/* Fix parent */
556				if (top != 0)
557					up[top - 1]->link[upd[top - 1]] = up[top];
558				else
559					root = up[0];
560				done = 1;
561			}
562			/* Update balance factors */
563			lh = height(up[top]->link[upd[top]]);
564			rh = height(up[top]->link[!upd[top]]);
565			max = avl_max(lh, rh);
566			BALANCE(up[top]) = max + 1;
567		}
568	}
569	(*rootaddr) = root;
570	return 1;
571}
572
573extern int
574avl_remove(node_t **rootaddr, PyObject *key)
575{
576	node_t *root = *rootaddr;
577	int cmp_res;
578
579	if (root != NULL) {
580		node_t *it, *up[32];
581		int upd[32], top = 0;
582
583		it = root;
584		for (;;) {
585			/* Terminate if not found */
586			if (it == NULL)
587				return 0;
588			cmp_res = ct_compare(KEY(it), key);
589			if (cmp_res == 0)
590				break;
591
592			/* Push direction and node onto stack */
593			upd[top] = (cmp_res < 0);
594			up[top++] = it;
595			it = it->link[upd[top - 1]];
596		}
597
598		/* Remove the node */
599		if (it->link[0] == NULL ||
600			it->link[1] == NULL) {
601			/* Which child is not null? */
602			int dir = it->link[0] == NULL;
603
604			/* Fix parent */
605			if (top != 0)
606				up[top - 1]->link[upd[top - 1]] = it->link[dir];
607			else
608				root = it->link[dir];
609
610			ct_delete_node(it);
611		}
612		else {
613			/* Find the inorder successor */
614			node_t *heir = it->link[1];
615
616			/* Save the path */
617			upd[top] = 1;
618			up[top++] = it;
619
620			while ( heir->link[0] != NULL ) {
621				upd[top] = 0;
622				up[top++] = heir;
623				heir = heir->link[0];
624			}
625			/* Swap data */
626			ct_swap_data(it, heir);
627			/* Unlink successor and fix parent */
628			up[top - 1]->link[up[top - 1] == it] = heir->link[1];
629			ct_delete_node(heir);
630		}
631
632		/* Walk back up the search path */
633		while (--top >= 0) {
634			int lh = height(up[top]->link[upd[top]]);
635			int rh = height(up[top]->link[!upd[top]]);
636			int max = avl_max(lh, rh);
637
638			/* Update balance factors */
639			BALANCE(up[top]) = max + 1;
640
641			/* Terminate or rebalance as necessary */
642			if (lh - rh == -1)
643				break;
644			if (lh - rh <= -2) {
645				node_t *a = up[top]->link[!upd[top]]->link[upd[top]];
646				node_t *b = up[top]->link[!upd[top]]->link[!upd[top]];
647
648				if (height(a) <= height(b))
649					up[top] = avl_single(up[top], upd[top]);
650				else
651					up[top] = avl_double(up[top], upd[top]);
652
653				/* Fix parent */
654				if (top != 0)
655					up[top - 1]->link[upd[top - 1]] = up[top];
656				else
657					root = up[0];
658			}
659		}
660	}
661	(*rootaddr) = root;
662	return 1;
663}
664
665extern node_t *
666ct_succ_node(node_t *root, PyObject *key)
667{
668	node_t *succ = NULL;
669	node_t *node = root;
670	int cval;
671
672	while (node != NULL) {
673		cval = ct_compare(key, KEY(node));
674		if (cval == 0)
675			break;
676		else if (cval < 0) {
677			if ((succ == NULL) ||
678				(ct_compare(KEY(node), KEY(succ)) < 0))
679				succ = node;
680			node = LEFT_NODE(node);
681		} else
682			node = RIGHT_NODE(node);
683	}
684	if (node == NULL)
685		return NULL;
686	/* found node of key */
687	if (RIGHT_NODE(node) != NULL) {
688		/* find smallest node of right subtree */
689		node = RIGHT_NODE(node);
690		while (LEFT_NODE(node) != NULL)
691			node = LEFT_NODE(node);
692		if (succ == NULL)
693			succ = node;
694		else if (ct_compare(KEY(node), KEY(succ)) < 0)
695			succ = node;
696	}
697	return succ;
698}
699
700extern node_t *
701ct_prev_node(node_t *root, PyObject *key)
702{
703	node_t *prev = NULL;
704	node_t *node = root;
705	int cval;
706
707	while (node != NULL) {
708		cval = ct_compare(key, KEY(node));
709		if (cval == 0)
710			break;
711		else if (cval < 0)
712			node = LEFT_NODE(node);
713		else {
714			if ((prev == NULL) || (ct_compare(KEY(node), KEY(prev)) > 0))
715				prev = node;
716			node = RIGHT_NODE(node);
717		}
718	}
719	if (node == NULL) /* stay at dead end (None) */
720		return NULL;
721	/* found node of key */
722	if (LEFT_NODE(node) != NULL) {
723		/* find biggest node of left subtree */
724		node = LEFT_NODE(node);
725		while (RIGHT_NODE(node) != NULL)
726			node = RIGHT_NODE(node);
727		if (prev == NULL)
728			prev = node;
729		else if (ct_compare(KEY(node), KEY(prev)) > 0)
730			prev = node;
731	}
732	return prev;
733}
734
735extern node_t *
736ct_floor_node(node_t *root, PyObject *key)
737{
738	node_t *prev = NULL;
739	node_t *node = root;
740	int cval;
741
742	while (node != NULL) {
743		cval = ct_compare(key, KEY(node));
744		if (cval == 0)
745			return node;
746		else if (cval < 0)
747			node = LEFT_NODE(node);
748		else {
749			if ((prev == NULL) || (ct_compare(KEY(node), KEY(prev)) > 0))
750				prev = node;
751			node = RIGHT_NODE(node);
752		}
753	}
754	return prev;
755}
756
757extern node_t *
758ct_ceiling_node(node_t *root, PyObject *key)
759{
760	node_t *succ = NULL;
761	node_t *node = root;
762	int cval;
763
764	while (node != NULL) {
765		cval = ct_compare(key, KEY(node));
766		if (cval == 0)
767			return node;
768		else if (cval < 0) {
769			if ((succ == NULL) ||
770				(ct_compare(KEY(node), KEY(succ)) < 0))
771				succ = node;
772			node = LEFT_NODE(node);
773		} else
774			node = RIGHT_NODE(node);
775	}
776	return succ;
777}
778
779extern int
780ct_index_of(node_t *root, PyObject *key)
781/*
782get index of item <key>, returns -1 if key not found.
783*/
784{
785	node_t *node = root;
786	int index = 0;
787	int go_down = 1;
788	node_stack_t *stack;
789	stack = stack_init(32);
790
791	for (;;) {
792		if ((LEFT_NODE(node) != NULL) && go_down) {
793			stack_push(stack, node);
794			node = LEFT_NODE(node);
795		}
796		else {
797			if (ct_compare(KEY(node), key) == 0) {
798				stack_delete(stack);
799				return index;
800			}
801			index++;
802			if (RIGHT_NODE(node) != NULL) {
803				node = RIGHT_NODE(node);
804				go_down = 1;
805			}
806			else {
807				if (stack_is_empty(stack)) {
808					stack_delete(stack);
809					return -1;
810				}
811				node = stack_pop(stack);
812				go_down = 0;
813			}
814		}
815	}
816}
817
818extern node_t *
819ct_node_at(node_t *root, int index)
820{
821/*
822root -- root node of tree
823index -- index of wanted node
824
825return NULL if index out of range
826*/
827	node_t *node = root;
828	int counter = 0;
829	int go_down = 1;
830	node_stack_t *stack;
831
832	if (index < 0) return NULL;
833
834	stack = stack_init(32);
835
836	for(;;) {
837		if ((LEFT_NODE(node) != NULL) && go_down) {
838			stack_push(stack, node);
839			node = LEFT_NODE(node);
840		}
841		else {
842			if (counter == index) {
843				/* reached wanted index */
844				stack_delete(stack);
845				return node;
846			}
847			counter++;
848			if (RIGHT_NODE(node) != NULL) {
849				node = RIGHT_NODE(node);
850				go_down = 1;
851			}
852			else {
853				if (stack_is_empty(stack)) { /* index out of range */
854					stack_delete(stack);
855					return NULL;
856                }
857				node = stack_pop(stack);
858				go_down = 0;
859			}
860		}
861    }
862}
863