001package algs33;
002import stdlib.*;
003import algs13.Queue;
004/* ***********************************************************************
005 *  Compilation:  javac RedBlackBST.java
006 *  Execution:    java RedBlackBST < input.txt
007 *  Dependencies: StdIn.java StdOut.java
008 *  Data files:   http://algs4.cs.princeton.edu/33balanced/tinyST.txt
009 *
010 *  A symbol table implemented using a left-leaning red-black BST.
011 *  This is the 2-3 version.
012 *
013 *  % more tinyST.txt
014 *  S E A R C H E X A M P L E
015 *
016 *  % java RedBlackBST < tinyST.txt
017 *  A 8
018 *  C 4
019 *  E 12
020 *  H 5
021 *  L 11
022 *  M 9
023 *  P 10
024 *  R 3
025 *  S 0
026 *  X 7
027 *
028 *************************************************************************/
029public class RedBlackBST<K extends Comparable<? super K>, V> {
030
031        private static final boolean RED   = true;
032        private static final boolean BLACK = false;
033
034        private Node<K,V> root;     // root of the BST
035
036        // BST helper node data type
037        private static class Node<K,V> {
038                public K key;         // key
039                public V val;         // associated data
040                public Node<K,V> left, right;  // links to left and right subtrees
041                public boolean color;     // color of parent link
042                public int N;             // subtree count
043
044                public Node(K key, V val, boolean color, int N) {
045                        this.key = key;
046                        this.val = val;
047                        this.color = color;
048                        this.N = N;
049                }
050        }
051
052        /* ***********************************************************************
053         *  Node<K,V> helper methods
054         *************************************************************************/
055        // is node x red; false if x is null ?
056        private boolean isRed(Node<K,V> x) {
057                if (x == null) return false;
058                return (x.color == RED);
059        }
060
061        // number of node in subtree rooted at x; 0 if x is null
062        private int size(Node<K,V> x) {
063                if (x == null) return 0;
064                return x.N;
065        }
066
067
068        /* ***********************************************************************
069         *  Size methods
070         *************************************************************************/
071
072        // return number of key-value pairs in this symbol table
073        public int size() { return size(root); }
074
075        // is this symbol table empty?
076        public boolean isEmpty() { return root == null; }
077
078        /* ***********************************************************************
079         *  Standard BST search
080         *************************************************************************/
081
082        // value associated with the given key; null if no such key
083        public V get(K key) { return get(root, key); }
084
085        // value associated with the given key in subtree rooted at x; null if no such key
086        private V get(Node<K,V> x, K key) {
087                while (x != null) {
088                        int cmp = key.compareTo(x.key);
089                        if      (cmp < 0) x = x.left;
090                        else if (cmp > 0) x = x.right;
091                        else              return x.val;
092                }
093                return null;
094        }
095
096        // is there a key-value pair with the given key?
097        public boolean contains(K key) { return (get(key) != null); }
098
099        // is there a key-value pair with the given key in the subtree rooted at x?
100        private boolean contains(Node<K,V> x, K key) { return (get(x, key) != null); }
101
102        /* ***********************************************************************
103         *  Red-black insertion
104         *************************************************************************/
105
106        // insert the key-value pair; overwrite the old value with the new value
107        // if the key is already present
108        public void put(K key, V val) {
109                root = put(root, key, val);
110                root.color = BLACK;
111                assert check();
112        }
113
114        // insert the key-value pair in the subtree rooted at h
115        private Node<K,V> put(Node<K,V> h, K key, V val) {
116                if (h == null) return new Node<>(key, val, RED, 1);
117
118                int cmp = key.compareTo(h.key);
119                if      (cmp < 0) h.left  = put(h.left,  key, val);
120                else if (cmp > 0) h.right = put(h.right, key, val);
121                else              h.val   = val;
122
123                // fix-up any right-leaning links
124                if (isRed(h.right) && !isRed(h.left))      h = rotateLeft(h);
125                if (isRed(h.left)  &&  isRed(h.left.left)) h = rotateRight(h);
126                if (isRed(h.left)  &&  isRed(h.right))     flipColors(h);
127                h.N = size(h.left) + size(h.right) + 1;
128
129                return h;
130        }
131        /* ***********************************************************************
132         *  Red-black deletion
133         *************************************************************************/
134
135        // delete the key-value pair with the minimum key
136        public void deleteMin() {
137                if (isEmpty()) throw new Error("BST underflow");
138
139                // if both children of root are black, set root to red
140                if (!isRed(root.left) && !isRed(root.right))
141                        root.color = RED;
142
143                root = deleteMin(root);
144                if (!isEmpty()) root.color = BLACK;
145                assert check();
146        }
147
148        // delete the key-value pair with the minimum key rooted at h
149        private Node<K,V> deleteMin(Node<K,V> h) {
150                if (h.left == null)
151                        return null;
152
153                if (!isRed(h.left) && !isRed(h.left.left))
154                        h = moveRedLeft(h);
155
156                h.left = deleteMin(h.left);
157                return balance(h);
158        }
159
160
161        // delete the key-value pair with the maximum key
162        public void deleteMax() {
163                if (isEmpty()) throw new Error("BST underflow");
164
165                // if both children of root are black, set root to red
166                if (!isRed(root.left) && !isRed(root.right))
167                        root.color = RED;
168
169                root = deleteMax(root);
170                if (!isEmpty()) root.color = BLACK;
171                assert check();
172        }
173
174        // delete the key-value pair with the maximum key rooted at h
175        private Node<K,V> deleteMax(Node<K,V> h) {
176                if (isRed(h.left))
177                        h = rotateRight(h);
178
179                if (h.right == null)
180                        return null;
181
182                if (!isRed(h.right) && !isRed(h.right.left))
183                        h = moveRedRight(h);
184
185                h.right = deleteMax(h.right);
186
187                return balance(h);
188        }
189
190        // delete the key-value pair with the given key
191        public void delete(K key) {
192                if (!contains(key)) {
193                        System.err.println("symbol table does not contain " + key);
194                        return;
195                }
196
197                // if both children of root are black, set root to red
198                if (!isRed(root.left) && !isRed(root.right))
199                        root.color = RED;
200
201                root = delete(root, key);
202                if (!isEmpty()) root.color = BLACK;
203                assert check();
204        }
205
206        // delete the key-value pair with the given key rooted at h
207        private Node<K,V> delete(Node<K,V> h, K key) {
208                assert contains(h, key);
209
210                if (key.compareTo(h.key) < 0)  {
211                        if (!isRed(h.left) && !isRed(h.left.left))
212                                h = moveRedLeft(h);
213                        h.left = delete(h.left, key);
214                }
215                else {
216                        if (isRed(h.left))
217                                h = rotateRight(h);
218                        if (key.compareTo(h.key) == 0 && (h.right == null))
219                                return null;
220                        if (!isRed(h.right) && !isRed(h.right.left))
221                                h = moveRedRight(h);
222                        if (key.compareTo(h.key) == 0) {
223                                h.val = get(h.right, min(h.right).key);
224                                h.key = min(h.right).key;
225                                h.right = deleteMin(h.right);
226                        }
227                        else h.right = delete(h.right, key);
228                }
229                return balance(h);
230        }
231
232        /* ***********************************************************************
233         *  red-black tree helper functions
234         *************************************************************************/
235
236        // make a left-leaning link lean to the right
237        private Node<K,V> rotateRight(Node<K,V> h) {
238                assert (h != null) && isRed(h.left);
239                Node<K,V> x = h.left;
240                h.left = x.right;
241                x.right = h;
242                x.color = x.right.color;
243                x.right.color = RED;
244                x.N = h.N;
245                h.N = size(h.left) + size(h.right) + 1;
246                return x;
247        }
248
249        // make a right-leaning link lean to the left
250        private Node<K,V> rotateLeft(Node<K,V> h) {
251                assert (h != null) && isRed(h.right);
252                Node<K,V> x = h.right;
253                h.right = x.left;
254                x.left = h;
255                x.color = x.left.color;
256                x.left.color = RED;
257                x.N = h.N;
258                h.N = size(h.left) + size(h.right) + 1;
259                return x;
260        }
261
262        // flip the colors of a node and its two children
263        private void flipColors(Node<K,V> h) {
264                // h must have opposite color of its two children
265                assert (h != null) && (h.left != null) && (h.right != null);
266                assert (!isRed(h) &&  isRed(h.left) &&  isRed(h.right))
267                || (isRed(h)  && !isRed(h.left) && !isRed(h.right));
268                h.color = !h.color;
269                h.left.color = !h.left.color;
270                h.right.color = !h.right.color;
271        }
272
273        // Assuming that h is red and both h.left and h.left.left
274        // are black, make h.left or one of its children red.
275        private Node<K,V> moveRedLeft(Node<K,V> h) {
276                assert (h != null);
277                assert isRed(h) && !isRed(h.left) && !isRed(h.left.left);
278
279                flipColors(h);
280                if (isRed(h.right.left)) {
281                        h.right = rotateRight(h.right);
282                        h = rotateLeft(h);
283                        // flipColors(h);
284                }
285                return h;
286        }
287
288        // Assuming that h is red and both h.right and h.right.left
289        // are black, make h.right or one of its children red.
290        private Node<K,V> moveRedRight(Node<K,V> h) {
291                assert (h != null);
292                assert isRed(h) && !isRed(h.right) && !isRed(h.right.left);
293                flipColors(h);
294                if (isRed(h.left.left)) {
295                        h = rotateRight(h);
296                        // flipColors(h);
297                }
298                return h;
299        }
300
301        // restore red-black tree invariant
302        private Node<K,V> balance(Node<K,V> h) {
303                assert (h != null);
304
305                if (isRed(h.right))                      h = rotateLeft(h);
306                if (isRed(h.left) && isRed(h.left.left)) h = rotateRight(h);
307                if (isRed(h.left) && isRed(h.right))     flipColors(h);
308
309                h.N = size(h.left) + size(h.right) + 1;
310                return h;
311        }
312
313
314        /* ***********************************************************************
315         *  Utility functions
316         *************************************************************************/
317
318        // height of tree; 0 if empty
319        public int height() { return height(root); }
320        private int height(Node<K,V> x) {
321                if (x == null) return 0;
322                return 1 + Math.max(height(x.left), height(x.right));
323        }
324
325        /* ***********************************************************************
326         *  Ordered symbol table methods.
327         *************************************************************************/
328
329        // the smallest key; null if no such key
330        public K min() {
331                if (isEmpty()) return null;
332                return min(root).key;
333        }
334
335        // the smallest key in subtree rooted at x; null if no such key
336        private Node<K,V> min(Node<K,V> x) {
337                assert x != null;
338                if (x.left == null) return x;
339                else                return min(x.left);
340        }
341
342        // the largest key; null if no such key
343        public K max() {
344                if (isEmpty()) return null;
345                return max(root).key;
346        }
347
348        // the largest key in the subtree rooted at x; null if no such key
349        private Node<K,V> max(Node<K,V> x) {
350                assert x != null;
351                if (x.right == null) return x;
352                else                 return max(x.right);
353        }
354
355        // the largest key less than or equal to the given key
356        public K floor(K key) {
357                Node<K,V> x = floor(root, key);
358                if (x == null) return null;
359                else           return x.key;
360        }
361
362        // the largest key in the subtree rooted at x less than or equal to the given key
363        private Node<K,V> floor(Node<K,V> x, K key) {
364                if (x == null) return null;
365                int cmp = key.compareTo(x.key);
366                if (cmp == 0) return x;
367                if (cmp < 0)  return floor(x.left, key);
368                Node<K,V> t = floor(x.right, key);
369                if (t != null) return t;
370                else           return x;
371        }
372
373        // the smallest key greater than or equal to the given key
374        public K ceiling(K key) {
375                Node<K,V> x = ceiling(root, key);
376                if (x == null) return null;
377                else           return x.key;
378        }
379
380        // the smallest key in the subtree rooted at x greater than or equal to the given key
381        private Node<K,V> ceiling(Node<K,V> x, K key) {
382                if (x == null) return null;
383                int cmp = key.compareTo(x.key);
384                if (cmp == 0) return x;
385                if (cmp > 0)  return ceiling(x.right, key);
386                Node<K,V> t = ceiling(x.left, key);
387                if (t != null) return t;
388                else           return x;
389        }
390
391
392        // the key of rank k
393        public K select(int k) {
394                if (k < 0 || k >= size())  return null;
395                Node<K,V> x = select(root, k);
396                return x.key;
397        }
398
399        // the key of rank k in the subtree rooted at x
400        private Node<K,V> select(Node<K,V> x, int k) {
401                assert x != null;
402                assert k >= 0 && k < size(x);
403                int t = size(x.left);
404                if      (t > k) return select(x.left,  k);
405                else if (t < k) return select(x.right, k-t-1);
406                else            return x;
407        }
408
409        // number of keys less than key
410        public int rank(K key) {
411                return rank(key, root);
412        }
413
414        // number of keys less than key in the subtree rooted at x
415        private int rank(K key, Node<K,V> x) {
416                if (x == null) return 0;
417                int cmp = key.compareTo(x.key);
418                if      (cmp < 0) return rank(key, x.left);
419                else if (cmp > 0) return 1 + size(x.left) + rank(key, x.right);
420                else              return size(x.left);
421        }
422
423        /* *********************************************************************
424         *  Range count and range search.
425         ***********************************************************************/
426
427        // all of the keys, as an Iterable
428        public Iterable<K> keys() {
429                return keys(min(), max());
430        }
431
432        // the keys between lo and hi, as an Iterable
433        public Iterable<K> keys(K lo, K hi) {
434                Queue<K> queue = new Queue<>();
435                // if (isEmpty() || lo.compareTo(hi) > 0) return queue;
436                keys(root, queue, lo, hi);
437                return queue;
438        }
439
440        // add the keys between lo and hi in the subtree rooted at x
441        // to the queue
442        private void keys(Node<K,V> x, Queue<K> queue, K lo, K hi) {
443                if (x == null) return;
444                int cmplo = lo.compareTo(x.key);
445                int cmphi = hi.compareTo(x.key);
446                if (cmplo < 0) keys(x.left, queue, lo, hi);
447                if (cmplo <= 0 && cmphi >= 0) queue.enqueue(x.key);
448                if (cmphi > 0) keys(x.right, queue, lo, hi);
449        }
450
451        // number keys between lo and hi
452        public int size(K lo, K hi) {
453                if (lo.compareTo(hi) > 0) return 0;
454                if (contains(hi)) return rank(hi) - rank(lo) + 1;
455                else              return rank(hi) - rank(lo);
456        }
457
458        /* ***********************************************************************
459         *  Check integrity of red-black BST data structure
460         *************************************************************************/
461        private boolean check() {
462                if (!isBST())            StdOut.format("Not in symmetric order: %s\n", this);
463                if (!isSizeConsistent()) StdOut.format("Subtree counts not consistent: %s\n", this);
464                if (!isRankConsistent()) StdOut.format("Ranks not consistent: %s\n", this);
465                if (!is23())             StdOut.format("Not a 2-3 tree: %s\n", this);
466                if (!isBalanced())       StdOut.format("Not balanced: %s\n", this);
467                return isBST() && isSizeConsistent() && isRankConsistent() && is23() && isBalanced();
468        }
469
470        // does this binary tree satisfy symmetric order?
471        // Note: this test also ensures that data structure is a binary tree since order is strict
472        private boolean isBST() {
473                return isBST(root, null, null);
474        }
475
476        // is the tree rooted at x a BST with all keys strictly between min and max
477        // (if min or max is null, treat as empty constraint)
478        // Credit: Bob Dondero's elegant solution
479        private boolean isBST(Node<K,V> x, K min, K max) {
480                if (x == null) return true;
481                if (min != null && x.key.compareTo(min) <= 0) return false;
482                if (max != null && x.key.compareTo(max) >= 0) return false;
483                return isBST(x.left, min, x.key) && isBST(x.right, x.key, max);
484        }
485
486        // are the size fields correct?
487        private boolean isSizeConsistent() { return isSizeConsistent(root); }
488        private boolean isSizeConsistent(Node<K,V> x) {
489                if (x == null) return true;
490                if (x.N != size(x.left) + size(x.right) + 1) return false;
491                return isSizeConsistent(x.left) && isSizeConsistent(x.right);
492        }
493
494        // check that ranks are consistent
495        private boolean isRankConsistent() {
496                for (int i = 0; i < size(); i++)
497                        if (i != rank(select(i))) return false;
498                for (K key : keys())
499                        if (key.compareTo(select(rank(key))) != 0) return false;
500                return true;
501        }
502
503        // Does the tree have no red right links, and at most one (left)
504        // red links in a row on any path?
505        private boolean is23() { return is23(root); }
506        private boolean is23(Node<K,V> x) {
507                if (x == null) return true;
508                if (isRed(x.right)) return false;
509                if (x != root && isRed(x) && isRed(x.left))
510                        return false;
511                return is23(x.left) && is23(x.right);
512        }
513
514        // do all paths from root to leaf have same number of black edges?
515        private boolean isBalanced() {
516                int black = 0;     // number of black links on path from root to min
517                Node<K,V> x = root;
518                while (x != null) {
519                        if (!isRed(x)) black++;
520                        x = x.left;
521                }
522                return isBalanced(root, black);
523        }
524
525        // does every path from the root to a leaf have the given number of black links?
526        private boolean isBalanced(Node<K,V> x, int black) {
527                if (x == null) return black == 0;
528                if (!isRed(x)) black--;
529                return isBalanced(x.left, black) && isBalanced(x.right, black);
530        }
531
532        /* ***************************************************************************
533         *  Visualization
534         *****************************************************************************/
535        private Iterable<Node<K,V>> levelOrderNodes() {
536                Queue<Node<K,V>> keys = new Queue<>();
537                Queue<Node<K,V>> queue = new Queue<>();
538                queue.enqueue(root);
539                while (!queue.isEmpty()) {
540                        Node<K,V> x = queue.dequeue();
541                        if (x == null) continue;
542                        keys.enqueue(x);
543                        queue.enqueue(x.left);
544                        queue.enqueue(x.right);
545                }
546                return keys;
547        }
548
549        public String toString() {
550                StringBuilder sb = new StringBuilder();
551                for (Node<K,V> n: levelOrderNodes())
552                        sb.append (n.key + (n.color ? "* " : " "));
553                return sb.toString ();
554        }
555
556        public void toGraphviz(String filename) {
557                GraphvizBuilder gb = new GraphvizBuilder ();
558                toGraphviz (gb, null, root);
559                gb.toFileUndirected (filename, "ordering=\"out\"");
560        }
561        private void toGraphviz (GraphvizBuilder gb, Node<K, V> parent, Node<K, V> n) {
562                if (n == null) { gb.addNullEdge (parent); return; }
563                String nodeProperties = n.color ? "color=\"red\"" : "";
564                String edgeProperties = n.color ? "color=\"red\",style=\"bold\"" : "";
565                gb.addLabeledNode (n, n.key.toString (), nodeProperties);
566                if (parent != null) gb.addEdge (parent, n, edgeProperties);
567                toGraphviz (gb, n, n.left);
568                toGraphviz (gb, n, n.right);
569        }
570
571        public void drawTree() {
572                if (root != null) {
573                        StdDraw.setCanvasSize(1200,700);
574                        drawTree(root, .5, 1, .25, 0);
575                }
576        }
577        private void drawTree (Node<K,V> n, double x, double y, double range, int depth) {
578                int CUTOFF = 5;
579                StdDraw.setPenColor (StdDraw.BLACK);
580                StdDraw.text (x, y, n.key.toString ());
581                StdDraw.setPenRadius (.005);
582                if (n.left != null && depth != CUTOFF) {
583                        if (n.left.color == RED) {
584                                StdDraw.setPenRadius (.01);
585                                StdDraw.setPenColor (StdDraw.RED);
586                        }
587                        StdDraw.line (x-range, y-.13, x-.01, y-.01);
588                        drawTree (n.left, x-range, y-.15, range*.5, depth+1);
589                }
590                if (n.right != null && depth != CUTOFF) {
591                        StdDraw.line (x+range, y-.13, x+.01, y-.01);
592                        drawTree (n.right, x+range, y-.15, range*.5, depth+1);
593                }
594        }
595        /* ***************************************************************************
596         *  Test client
597         *****************************************************************************/
598        public static void main(String[] args) {
599                StdIn.fromString ("S E A R C H E X A M P L E");
600                //StdIn.fromString ("D F B  G E A C");
601
602                RedBlackBST<String, Integer> st = new RedBlackBST<>();
603                for (int i = 0; !StdIn.isEmpty(); i++) {
604                        String key = StdIn.readString();
605                        st.put(key, i);
606                }
607                st.toGraphviz ("g.png");
608                for (String s : st.keys())
609                        StdOut.println(s + " " + st.get(s));
610                st.drawTree ();
611        }
612}