001package algs33;
002import stdlib.*;
003import java.util.Iterator;
004import java.util.NoSuchElementException;
005import algs13.Stack;
006/* ***********************************************************************
007 *  Compilation:  javac XRandomizedBST.java
008 *  Execution:    java XRandomizedBST
009 *
010 *  Symbol table (map) implemented with a randomized BST.
011 *
012 *
013 *************************************************************************/
014
015public class XRandomizedBST<K extends Comparable<? super K>, V> implements Iterable<K> {
016
017        private Node<K,V> root;   // root of the BST
018
019        // BST helper node data type
020        private static class Node<K,V> {
021                public K key;          // key
022                public V val;          // associated data
023                public Node<K,V> left, right;   // left and right subtrees
024                public int N;              // node count of descendents
025
026                public Node(K key, V val) {
027                        this.key = key;
028                        this.val = val;
029                        this.N   = 1;
030                }
031        }
032
033
034        /* ***********************************************************************
035         *  BST search
036         *************************************************************************/
037
038        public boolean contains(K key) {
039                return (get(key) != null);
040        }
041
042        // return value associated with the given key
043        // if no such value, return null
044        // if multiple such values, return first one on path from root
045        public V get(K key) {
046                return get(root, key);
047        }
048
049        private V get(Node<K,V> x, K key) {
050                if (x == null) return null;
051                int cmp = key.compareTo(x.key);
052                if      (cmp == 0) return x.val;
053                else if (cmp  < 0) return get(x.left,  key);
054                else               return get(x.right, key);
055        }
056
057
058        /* ***********************************************************************
059         *  randomized insertion
060         *************************************************************************/
061        public void put(K key, V val) {
062                root = put(root, key, val);
063        }
064
065        // make new node the root with uniform probability
066        private Node<K,V> put(Node<K,V> x, K key, V val) {
067                if (x == null) return new Node<>(key, val);
068                int cmp = key.compareTo(x.key);
069                if (cmp == 0) { x.val = val; return x; }
070                if (StdRandom.bernoulli(1.0 / (size(x) + 1.0))) return putRoot(x, key, val);
071                if (cmp < 0) x.left  = put(x.left,  key, val);
072                else         x.right = put(x.right, key, val);
073                // (x.N)++;
074                fix(x);
075                return x;
076        }
077
078
079        private Node<K,V> putRoot(Node<K,V> x, K key, V val) {
080                if (x == null) return new Node<>(key, val);
081                int cmp = key.compareTo(x.key);
082                if      (cmp == 0) { x.val = val; return x; }
083                else if (cmp  < 0) { x.left  = putRoot(x.left,  key, val); x = rotR(x); }
084                else               { x.right = putRoot(x.right, key, val); x = rotL(x); }
085                return x;
086        }
087
088
089
090        /* ***********************************************************************
091         *  deletion
092         *************************************************************************/
093        private Node<K,V> joinLR(Node<K,V> a, Node<K,V> b) {
094                if (a == null) return b;
095                if (b == null) return a;
096
097                if (StdRandom.bernoulli((double) size(a) / (size(a) + size(b))))  {
098                        a.right = joinLR(a.right, b);
099                        fix(a);
100                        return a;
101                }
102                else {
103                        b.left = joinLR(a, b.left);
104                        fix(b);
105                        return b;
106                }
107        }
108
109        private Node<K,V> remove(Node<K,V> x, K key) {
110                if (x == null) return null;
111                int cmp = key.compareTo(x.key);
112                if      (cmp == 0) x = joinLR(x.left, x.right);
113                else if (cmp  < 0) x.left  = remove(x.left,  key);
114                else               x.right = remove(x.right, key);
115                fix(x);
116                return x;
117        }
118
119        // remove and return value associated with given key; if no such key, return null
120        public V remove(K key) {
121                V val = get(key);
122                root = remove(root, key);
123                return val;
124        }
125
126        /* ***********************************************************************
127         *  Selection
128         *************************************************************************/
129
130        // return the kth largest key
131        public K select(int k) { Node<K,V> x = select(root, k); return x.key; }
132        private Node<K,V> select(Node<K,V> x, int k) {
133                if (x == null) return null;
134                int t = size(x.left);
135                if      (t > k) return select(x.left,  k);
136                else if (t < k) return select(x.right, k-t-1);
137                else            return x;
138        }
139
140
141
142        // return the smallest key
143        public K min() {
144                K key = null;
145                for (Node<K,V> x = root; x != null; x = x.left)
146                        key = x.key;
147                return key;
148        }
149
150        // return the largest key
151        public K max() {
152                K key = null;
153                for (Node<K,V> x = root; x != null; x = x.right)
154                        key = x.key;
155                return key;
156        }
157
158        // return the smallest key >= query key; if no such key return null
159        public K ceil(K key) {
160                Node<K,V> best = ceil(root, key, null);
161                if (best == null) return null;
162                return best.key;
163        }
164        private Node<K,V> ceil(Node<K,V> x, K key, Node<K,V> best) {
165                if      (x == null)        return best;
166                else if (eq(key, x.key))   return x;
167                else if (less(key, x.key)) return ceil(x.left,  key, x);
168                else                       return ceil(x.right, key, best);
169        }
170
171        // return the smallest key >= query key; if no such key return null
172        public K ceil2(K key) {
173                Node<K,V> best = null;
174                Node<K,V> x = root;
175                while (x != null) {
176                        int cmp = key.compareTo(x.key);
177                        if      (cmp < 0) { best = x; x = x.left; }
178                        else if (cmp > 0) { x = x.right;          }
179                        else              return x.key;
180                }
181                if (best == null) return null;
182                return best.key;
183        }
184
185
186        /* *********************************************************************
187         *  Iterate using inorder traversal using a stack.
188         *  Iterating through N elements takes O(N) time.
189         ***********************************************************************/
190        public Iterator<K> iterator() { return new BSTIterator(root); }
191
192        // an iterator
193        private class BSTIterator implements Iterator<K> {
194                private Stack<Node<K,V>> stack = new Stack<>();
195
196                public BSTIterator(Node<K,V> x) {
197                        while (x != null) {
198                                stack.push(x);
199                                x = x.left;
200                        }
201                }
202
203                public boolean hasNext()  { return !stack.isEmpty();                    }
204
205                // it's optional and we don't want to support it
206                public void remove()      { throw new UnsupportedOperationException();  }
207
208                public K next() {
209                        if (!hasNext()) throw new NoSuchElementException();
210                        Node<K,V> x = stack.pop();
211                        K key = x.key;
212                        x = x.right;
213                        while (x != null) {
214                                stack.push(x);
215                                x = x.left;
216                        }
217                        return key;
218                }
219        }
220
221
222
223
224        /* ***********************************************************************
225         *  Utility functions.
226         *************************************************************************/
227
228        // return number of nodes in subtree rooted at x
229        public int size() { return size(root); }
230        private int size(Node<K,V> x) {
231                if (x == null) return 0;
232                else           return x.N;
233        }
234
235        // height of tree (empty tree height = 0)
236        public int height() { return height(root); }
237        private int height(Node<K,V> x) {
238                if (x == null) return 0;
239                return 1 + Math.max(height(x.left), height(x.right));
240        }
241
242
243        /* ***********************************************************************
244         *  helper BST functions
245         *************************************************************************/
246
247        // fix subtree count field
248        private void fix(Node<K,V> x) {
249                if (x == null) return;
250                x.N = 1 + size(x.left) + size(x.right);
251        }
252
253        // right rotate
254        private Node<K,V> rotR(Node<K,V> h) {
255                Node<K,V> x = h.left;
256                h.left = x.right;
257                x.right = h;
258                fix(h);
259                fix(x);
260                return x;
261        }
262
263        // left rotate
264        private Node<K,V> rotL(Node<K,V> h) {
265                Node<K,V> x = h.right;
266                h.right = x.left;
267                x.left = h;
268                fix(h);
269                fix(x);
270                return x;
271        }
272
273
274        /* ***********************************************************************
275         *  Debugging functions that test the integrity of the tree
276         *************************************************************************/
277
278        // check integrity of subtree count fields
279        public boolean check() { return checkCount() && isBST(); }
280
281        // check integrity of count fields
282        private boolean checkCount() { return checkCount(root); }
283        private boolean checkCount(Node<K,V> x) {
284                if (x == null) return true;
285                return checkCount(x.left) && checkCount(x.right) && (x.N == 1 + size(x.left) + size(x.right));
286        }
287
288
289        // does this tree satisfy the BST property?
290        private boolean isBST() { return isBST(root, min(), max()); }
291
292        // are all the values in the BST rooted at x between min and max, and recursively?
293        private boolean isBST(Node<K,V> x, K min, K max) {
294                if (x == null) return true;
295                if (less(x.key, min) || less(max, x.key)) return false;
296                return isBST(x.left, min, x.key) && isBST(x.right, x.key, max);
297        }
298
299
300
301        /* ***********************************************************************
302         *  helper comparison functions
303         *************************************************************************/
304
305        private boolean less(K k1, K k2) {
306                return k1.compareTo(k2) < 0;
307        }
308
309        private boolean eq(K k1, K k2) {
310                return k1.compareTo(k2) == 0;
311        }
312
313
314
315        /* ***********************************************************************
316         *  test client
317         *************************************************************************/
318        public static void main(String[] args) {
319                XRandomizedBST<String, String> st = new XRandomizedBST<>();
320
321                // insert some key-value pairs
322                st.put("www.cs.princeton.edu",   "128.112.136.11");
323                st.put("www.cs.princeton.edu",   "128.112.136.35");    // overwrite old value
324                st.put("www.princeton.edu",      "128.112.130.211");
325                st.put("www.math.princeton.edu", "128.112.18.11");
326                st.put("www.yale.edu",           "130.132.51.8");
327                st.put("www.amazon.com",         "207.171.163.90");
328                st.put("www.simpsons.com",       "209.123.16.34");
329                st.put("www.stanford.edu",       "171.67.16.120");
330                st.put("www.google.com",         "64.233.161.99");
331                st.put("www.ibm.com",            "129.42.16.99");
332                st.put("www.apple.com",          "17.254.0.91");
333                st.put("www.slashdot.com",       "66.35.250.150");
334                st.put("www.whitehouse.gov",     "204.153.49.136");
335                st.put("www.espn.com",           "199.181.132.250");
336                st.put("www.snopes.com",         "66.165.133.65");
337                st.put("www.movies.com",         "199.181.132.250");
338                st.put("www.cnn.com",            "64.236.16.20");
339                st.put("www.iitb.ac.in",         "202.68.145.210");
340
341
342                StdOut.println(st.get("www.cs.princeton.edu"));
343                StdOut.println(st.get("www.harvardsucks.com"));
344                StdOut.println(st.get("www.simpsons.com"));
345                StdOut.println();
346
347                StdOut.println("integrity check: " + st.check());
348                StdOut.println();
349
350                StdOut.println("ceil(www.simpsonr.com) = " + st.ceil("www.simpsonr.com"));
351                StdOut.println("ceil(www.simpsons.com) = " + st.ceil("www.simpsons.com"));
352                StdOut.println("ceil(www.simpsont.com) = " + st.ceil("www.simpsont.com"));
353
354                StdOut.println("ceil(www.simpsonr.com) = " + st.ceil2("www.simpsonr.com"));
355                StdOut.println("ceil(www.simpsons.com) = " + st.ceil2("www.simpsons.com"));
356                StdOut.println("ceil(www.simpsont.com) = " + st.ceil2("www.simpsont.com"));
357                StdOut.println();
358
359                for (int i = 0; i < st.size(); i++) {
360                        StdOut.println(i + "th: key  " + st.select(i));
361                }
362                StdOut.println();
363
364                StdOut.println("min key: " + st.min());
365                StdOut.println("max key: " + st.max());
366                StdOut.println("size:    " + st.size());
367                StdOut.println("height:  " + st.height());
368                StdOut.println();
369        }
370
371}