001package algs52; // section 5.2 002import stdlib.*; 003import algs13.Queue; 004/* *********************************************************************** 005 * Compilation: javac TST.java 006 * Execution: java TST < words.txt 007 * Dependencies: StdIn.java 008 * 009 * Symbol table with string keys, implemented using a ternary search 010 * trie (TST). 011 * 012 * 013 * % java TST < shellsST.txt 014 * by 4 015 * sea 6 016 * sells 1 017 * she 0 018 * shells 3 019 * shore 7 020 * the 5 021 022 * 023 * % java TST 024 * theory the now is the time for all good men 025 026 * Remarks 027 * -------- 028 * - can't use a key that is the empty string "" 029 * 030 *************************************************************************/ 031 032public class TST<V> { 033 private int N; // size 034 private Node<V> root; // root of TST 035 036 private static class Node<V> { 037 public Node() { } 038 public char c; // character 039 public Node<V> left, mid, right; // left, middle, and right subtries 040 public V val; // value associated with string 041 } 042 043 // return number of key-value pairs 044 public int size() { 045 return N; 046 } 047 048 /* ************************************************************ 049 * Is string key in the symbol table? 050 **************************************************************/ 051 public boolean contains(String key) { 052 return get(key) != null; 053 } 054 055 public V get(String key) { 056 if (key == null || key.length() == 0) throw new Error("illegal key"); 057 Node<V> x = get(root, key, 0); 058 if (x == null) return null; 059 return x.val; 060 } 061 062 // return subtrie corresponding to given key 063 private Node<V> get(Node<V> x, String key, int d) { 064 if (key == null || key.length() == 0) throw new Error("illegal key"); 065 if (x == null) return null; 066 char c = key.charAt(d); 067 if (c < x.c) return get(x.left, key, d); 068 else if (c > x.c) return get(x.right, key, d); 069 else if (d < key.length() - 1) return get(x.mid, key, d+1); 070 else return x; 071 } 072 073 074 /* ************************************************************ 075 * Insert string s into the symbol table. 076 **************************************************************/ 077 public void put(String s, V val) { 078 if (!contains(s)) N++; 079 root = put(root, s, val, 0); 080 } 081 082 private Node<V> put(Node<V> x, String s, V val, int d) { 083 char c = s.charAt(d); 084 if (x == null) { 085 x = new Node<>(); 086 x.c = c; 087 } 088 if (c < x.c) x.left = put(x.left, s, val, d); 089 else if (c > x.c) x.right = put(x.right, s, val, d); 090 else if (d < s.length() - 1) x.mid = put(x.mid, s, val, d+1); 091 else x.val = val; 092 return x; 093 } 094 095 096 /* ************************************************************ 097 * Find and return longest prefix of s in TST 098 **************************************************************/ 099 public String longestPrefixOf(String s) { 100 if (s == null || s.length() == 0) return null; 101 int length = 0; 102 Node<V> x = root; 103 int i = 0; 104 while (x != null && i < s.length()) { 105 char c = s.charAt(i); 106 if (c < x.c) x = x.left; 107 else if (c > x.c) x = x.right; 108 else { 109 i++; 110 if (x.val != null) length = i; 111 x = x.mid; 112 } 113 } 114 return s.substring(0, length); 115 } 116 117 // all keys in symbol table 118 public Iterable<String> keys() { 119 Queue<String> queue = new Queue<>(); 120 collect(root, "", queue); 121 return queue; 122 } 123 124 // all keys starting with given prefix 125 public Iterable<String> prefixMatch(String prefix) { 126 Queue<String> queue = new Queue<>(); 127 Node<V> x = get(root, prefix, 0); 128 if (x == null) return queue; 129 if (x.val != null) queue.enqueue(prefix); 130 collect(x.mid, prefix, queue); 131 return queue; 132 } 133 134 // all keys in subtrie rooted at x with given prefix 135 private void collect(Node<V> x, String prefix, Queue<String> queue) { 136 if (x == null) return; 137 collect(x.left, prefix, queue); 138 if (x.val != null) queue.enqueue(prefix + x.c); 139 collect(x.mid, prefix + x.c, queue); 140 collect(x.right, prefix, queue); 141 } 142 143 144 // return all keys matching given wilcard pattern 145 public Iterable<String> wildcardMatch(String pat) { 146 Queue<String> queue = new Queue<>(); 147 collect(root, "", 0, pat, queue); 148 return queue; 149 } 150 151 private void collect(Node<V> x, String prefix, int i, String pat, Queue<String> q) { 152 if (x == null) return; 153 char c = pat.charAt(i); 154 if (c == '.' || c < x.c) collect(x.left, prefix, i, pat, q); 155 if (c == '.' || c == x.c) { 156 if (i == pat.length() - 1 && x.val != null) q.enqueue(prefix + x.c); 157 if (i < pat.length() - 1) collect(x.mid, prefix + x.c, i+1, pat, q); 158 } 159 if (c == '.' || c > x.c) collect(x.right, prefix, i, pat, q); 160 } 161 162 163 164 // test client 165 public static void main(String[] args) { 166 // build symbol table from standard input 167 TST<Integer> st = new TST<>(); 168 for (int i = 0; !StdIn.isEmpty(); i++) { 169 String key = StdIn.readString(); 170 st.put(key, i); 171 } 172 173 174 // print results 175 for (String key : st.keys()) { 176 StdOut.println(key + " " + st.get(key)); 177 } 178 } 179}