001package algs33; 002import stdlib.*; 003import java.util.Iterator; 004import java.util.NoSuchElementException; 005import algs13.Stack; 006/* *********************************************************************** 007 * Compilation: javac RandomizedBST.java 008 * Execution: java RandomizedBST 009 * 010 * Symbol table (map) implemented with a randomized BST. 011 * 012 * 013 *************************************************************************/ 014 015public class XRandomizedBST<K extends Comparable<? super K>, V> implements Iterable<K> { 016 017 private Node<K,V> root; // root of the BST 018 019 // BST helper node data type 020 private static class Node<K,V> { 021 public K key; // key 022 public V val; // associated data 023 public Node<K,V> left, right; // left and right subtrees 024 public int N; // node count of descendents 025 026 public Node(K key, V val) { 027 this.key = key; 028 this.val = val; 029 this.N = 1; 030 } 031 } 032 033 034 /* *********************************************************************** 035 * BST search 036 *************************************************************************/ 037 038 public boolean contains(K key) { 039 return (get(key) != null); 040 } 041 042 // return value associated with the given key 043 // if no such value, return null 044 // if multiple such values, return first one on path from root 045 public V get(K key) { 046 return get(root, key); 047 } 048 049 private V get(Node<K,V> x, K key) { 050 if (x == null) return null; 051 int cmp = key.compareTo(x.key); 052 if (cmp == 0) return x.val; 053 else if (cmp < 0) return get(x.left, key); 054 else return get(x.right, key); 055 } 056 057 058 /* *********************************************************************** 059 * randomized insertion 060 *************************************************************************/ 061 public void put(K key, V val) { 062 root = put(root, key, val); 063 } 064 065 // make new node the root with uniform probability 066 private Node<K,V> put(Node<K,V> x, K key, V val) { 067 if (x == null) return new Node<>(key, val); 068 int cmp = key.compareTo(x.key); 069 if (cmp == 0) { x.val = val; return x; } 070 if (StdRandom.bernoulli(1.0 / (size(x) + 1.0))) return putRoot(x, key, val); 071 if (cmp < 0) x.left = put(x.left, key, val); 072 else x.right = put(x.right, key, val); 073 // (x.N)++; 074 fix(x); 075 return x; 076 } 077 078 079 private Node<K,V> putRoot(Node<K,V> x, K key, V val) { 080 if (x == null) return new Node<>(key, val); 081 int cmp = key.compareTo(x.key); 082 if (cmp == 0) { x.val = val; return x; } 083 else if (cmp < 0) { x.left = putRoot(x.left, key, val); x = rotR(x); } 084 else { x.right = putRoot(x.right, key, val); x = rotL(x); } 085 return x; 086 } 087 088 089 090 /* *********************************************************************** 091 * deletion 092 *************************************************************************/ 093 private Node<K,V> joinLR(Node<K,V> a, Node<K,V> b) { 094 if (a == null) return b; 095 if (b == null) return a; 096 097 if (StdRandom.bernoulli((double) size(a) / (size(a) + size(b)))) { 098 a.right = joinLR(a.right, b); 099 fix(a); 100 return a; 101 } 102 else { 103 b.left = joinLR(a, b.left); 104 fix(b); 105 return b; 106 } 107 } 108 109 private Node<K,V> remove(Node<K,V> x, K key) { 110 if (x == null) return null; 111 int cmp = key.compareTo(x.key); 112 if (cmp == 0) x = joinLR(x.left, x.right); 113 else if (cmp < 0) x.left = remove(x.left, key); 114 else x.right = remove(x.right, key); 115 fix(x); 116 return x; 117 } 118 119 // remove and return value associated with given key; if no such key, return null 120 public V remove(K key) { 121 V val = get(key); 122 root = remove(root, key); 123 return val; 124 } 125 126 /* *********************************************************************** 127 * Selection 128 *************************************************************************/ 129 130 // return the kth largest key 131 public K select(int k) { Node<K,V> x = select(root, k); return x.key; } 132 private Node<K,V> select(Node<K,V> x, int k) { 133 if (x == null) return null; 134 int t = size(x.left); 135 if (t > k) return select(x.left, k); 136 else if (t < k) return select(x.right, k-t-1); 137 else return x; 138 } 139 140 141 142 // return the smallest key 143 public K min() { 144 K key = null; 145 for (Node<K,V> x = root; x != null; x = x.left) 146 key = x.key; 147 return key; 148 } 149 150 // return the largest key 151 public K max() { 152 K key = null; 153 for (Node<K,V> x = root; x != null; x = x.right) 154 key = x.key; 155 return key; 156 } 157 158 // return the smallest key >= query key; if no such key return null 159 public K ceil(K key) { 160 Node<K,V> best = ceil(root, key, null); 161 if (best == null) return null; 162 return best.key; 163 } 164 private Node<K,V> ceil(Node<K,V> x, K key, Node<K,V> best) { 165 if (x == null) return best; 166 else if (eq(key, x.key)) return x; 167 else if (less(key, x.key)) return ceil(x.left, key, x); 168 else return ceil(x.right, key, best); 169 } 170 171 // return the smallest key >= query key; if no such key return null 172 public K ceil2(K key) { 173 Node<K,V> best = null; 174 Node<K,V> x = root; 175 while (x != null) { 176 int cmp = key.compareTo(x.key); 177 if (cmp < 0) { best = x; x = x.left; } 178 else if (cmp > 0) { x = x.right; } 179 else return x.key; 180 } 181 if (best == null) return null; 182 return best.key; 183 } 184 185 186 /* ********************************************************************* 187 * Iterate using inorder traversal using a stack. 188 * Iterating through N elements takes O(N) time. 189 ***********************************************************************/ 190 public Iterator<K> iterator() { return new BSTIterator(root); } 191 192 // an iterator 193 private class BSTIterator implements Iterator<K> { 194 private Stack<Node<K,V>> stack = new Stack<>(); 195 196 public BSTIterator(Node<K,V> x) { 197 while (x != null) { 198 stack.push(x); 199 x = x.left; 200 } 201 } 202 203 public boolean hasNext() { return !stack.isEmpty(); } 204 205 // it's optional and we don't want to support it 206 public void remove() { throw new UnsupportedOperationException(); } 207 208 public K next() { 209 if (!hasNext()) throw new NoSuchElementException(); 210 Node<K,V> x = stack.pop(); 211 K key = x.key; 212 x = x.right; 213 while (x != null) { 214 stack.push(x); 215 x = x.left; 216 } 217 return key; 218 } 219 } 220 221 222 223 224 /* *********************************************************************** 225 * Utility functions. 226 *************************************************************************/ 227 228 // return number of nodes in subtree rooted at x 229 public int size() { return size(root); } 230 private int size(Node<K,V> x) { 231 if (x == null) return 0; 232 else return x.N; 233 } 234 235 // height of tree (empty tree height = 0) 236 public int height() { return height(root); } 237 private int height(Node<K,V> x) { 238 if (x == null) return 0; 239 return 1 + Math.max(height(x.left), height(x.right)); 240 } 241 242 243 /* *********************************************************************** 244 * helper BST functions 245 *************************************************************************/ 246 247 // fix subtree count field 248 private void fix(Node<K,V> x) { 249 if (x == null) return; 250 x.N = 1 + size(x.left) + size(x.right); 251 } 252 253 // right rotate 254 private Node<K,V> rotR(Node<K,V> h) { 255 Node<K,V> x = h.left; 256 h.left = x.right; 257 x.right = h; 258 fix(h); 259 fix(x); 260 return x; 261 } 262 263 // left rotate 264 private Node<K,V> rotL(Node<K,V> h) { 265 Node<K,V> x = h.right; 266 h.right = x.left; 267 x.left = h; 268 fix(h); 269 fix(x); 270 return x; 271 } 272 273 274 /* *********************************************************************** 275 * Debugging functions that test the integrity of the tree 276 *************************************************************************/ 277 278 // check integrity of subtree count fields 279 public boolean check() { return checkCount() && isBST(); } 280 281 // check integrity of count fields 282 private boolean checkCount() { return checkCount(root); } 283 private boolean checkCount(Node<K,V> x) { 284 if (x == null) return true; 285 return checkCount(x.left) && checkCount(x.right) && (x.N == 1 + size(x.left) + size(x.right)); 286 } 287 288 289 // does this tree satisfy the BST property? 290 private boolean isBST() { return isBST(root, min(), max()); } 291 292 // are all the values in the BST rooted at x between min and max, and recursively? 293 private boolean isBST(Node<K,V> x, K min, K max) { 294 if (x == null) return true; 295 if (less(x.key, min) || less(max, x.key)) return false; 296 return isBST(x.left, min, x.key) && isBST(x.right, x.key, max); 297 } 298 299 300 301 /* *********************************************************************** 302 * helper comparison functions 303 *************************************************************************/ 304 305 private boolean less(K k1, K k2) { 306 return k1.compareTo(k2) < 0; 307 } 308 309 private boolean eq(K k1, K k2) { 310 return k1.compareTo(k2) == 0; 311 } 312 313 314 315 /* *********************************************************************** 316 * test client 317 *************************************************************************/ 318 public static void main(String[] args) { 319 XRandomizedBST<String, String> st = new XRandomizedBST<>(); 320 321 // insert some key-value pairs 322 st.put("www.cs.princeton.edu", "128.112.136.11"); 323 st.put("www.cs.princeton.edu", "128.112.136.35"); // overwrite old value 324 st.put("www.princeton.edu", "128.112.130.211"); 325 st.put("www.math.princeton.edu", "128.112.18.11"); 326 st.put("www.yale.edu", "130.132.51.8"); 327 st.put("www.amazon.com", "207.171.163.90"); 328 st.put("www.simpsons.com", "209.123.16.34"); 329 st.put("www.stanford.edu", "171.67.16.120"); 330 st.put("www.google.com", "64.233.161.99"); 331 st.put("www.ibm.com", "129.42.16.99"); 332 st.put("www.apple.com", "17.254.0.91"); 333 st.put("www.slashdot.com", "66.35.250.150"); 334 st.put("www.whitehouse.gov", "204.153.49.136"); 335 st.put("www.espn.com", "199.181.132.250"); 336 st.put("www.snopes.com", "66.165.133.65"); 337 st.put("www.movies.com", "199.181.132.250"); 338 st.put("www.cnn.com", "64.236.16.20"); 339 st.put("www.iitb.ac.in", "202.68.145.210"); 340 341 342 StdOut.println(st.get("www.cs.princeton.edu")); 343 StdOut.println(st.get("www.harvardsucks.com")); 344 StdOut.println(st.get("www.simpsons.com")); 345 StdOut.println(); 346 347 StdOut.println("integrity check: " + st.check()); 348 StdOut.println(); 349 350 StdOut.println("ceil(www.simpsonr.com) = " + st.ceil("www.simpsonr.com")); 351 StdOut.println("ceil(www.simpsons.com) = " + st.ceil("www.simpsons.com")); 352 StdOut.println("ceil(www.simpsont.com) = " + st.ceil("www.simpsont.com")); 353 354 StdOut.println("ceil(www.simpsonr.com) = " + st.ceil2("www.simpsonr.com")); 355 StdOut.println("ceil(www.simpsons.com) = " + st.ceil2("www.simpsons.com")); 356 StdOut.println("ceil(www.simpsont.com) = " + st.ceil2("www.simpsont.com")); 357 StdOut.println(); 358 359 for (int i = 0; i < st.size(); i++) { 360 StdOut.println(i + "th: key " + st.select(i)); 361 } 362 StdOut.println(); 363 364 StdOut.println("min key: " + st.min()); 365 StdOut.println("max key: " + st.max()); 366 StdOut.println("size: " + st.size()); 367 StdOut.println("height: " + st.height()); 368 StdOut.println(); 369 } 370 371}