001package algs52; // section 5.2
002import stdlib.*;
003import algs13.Queue;
004/* ***********************************************************************
005 *  Compilation:  javac TST.java
006 *  Execution:    java TST < words.txt
 *  Dependencies: StdIn.java, StdOut.java, Queue.java
008 *
009 *  Symbol table with string keys, implemented using a ternary search
010 *  trie (TST).
011 *
012 *
013 *  % java TST < shellsST.txt
014 *  by 4
015 *  sea 6
016 *  sells 1
017 *  she 0
018 *  shells 3
019 *  shore 7
020 *  the 5
 *
 *  % java TST
 *  theory the now is the time for all good men
 *  (keys are read from standard input until end of input, then printed
 *   in sorted order with their values)
 *
026 *  Remarks
027 *  --------
028 *    - can't use a key that is the empty string ""
029 *
030 *************************************************************************/
031
032public class TST<V> {
033        private int N;       // size
034        private Node<V> root;   // root of TST
035
036        private static class Node<V> {
037                public Node() { }
038                public char c;                 // character
039                public Node<V> left, mid, right;  // left, middle, and right subtries
040                public V val;              // value associated with string
041        }
042
043        // return number of key-value pairs
044        public int size() {
045                return N;
046        }
047
048        /* ************************************************************
049         * Is string key in the symbol table?
050         **************************************************************/
051        public boolean contains(String key) {
052                return get(key) != null;
053        }
054
055        public V get(String key) {
056                if (key == null || key.length() == 0) throw new Error("illegal key");
057                Node<V> x = get(root, key, 0);
058                if (x == null) return null;
059                return x.val;
060        }
061
062        // return subtrie corresponding to given key
063        private Node<V> get(Node<V> x, String key, int d) {
064                if (key == null || key.length() == 0) throw new Error("illegal key");
065                if (x == null) return null;
066                char c = key.charAt(d);
067                if      (c < x.c)              return get(x.left,  key, d);
068                else if (c > x.c)              return get(x.right, key, d);
069                else if (d < key.length() - 1) return get(x.mid,   key, d+1);
070                else                           return x;
071        }
072
073
074        /* ************************************************************
075         * Insert string s into the symbol table.
076         **************************************************************/
077        public void put(String s, V val) {
078                if (!contains(s)) N++;
079                root = put(root, s, val, 0);
080        }
081
082        private Node<V> put(Node<V> x, String s, V val, int d) {
083                char c = s.charAt(d);
084                if (x == null) {
085                        x = new Node<>();
086                        x.c = c;
087                }
088                if      (c < x.c)             x.left  = put(x.left,  s, val, d);
089                else if (c > x.c)             x.right = put(x.right, s, val, d);
090                else if (d < s.length() - 1)  x.mid   = put(x.mid,   s, val, d+1);
091                else                          x.val   = val;
092                return x;
093        }
094
095
096        /* ************************************************************
097         * Find and return longest prefix of s in TST
098         **************************************************************/
099        public String longestPrefixOf(String s) {
100                if (s == null || s.length() == 0) return null;
101                int length = 0;
102                Node<V> x = root;
103                int i = 0;
104                while (x != null && i < s.length()) {
105                        char c = s.charAt(i);
106                        if      (c < x.c) x = x.left;
107                        else if (c > x.c) x = x.right;
108                        else {
109                                i++;
110                                if (x.val != null) length = i;
111                                x = x.mid;
112                        }
113                }
114                return s.substring(0, length);
115        }
116
117        // all keys in symbol table
118        public Iterable<String> keys() {
119                Queue<String> queue = new Queue<>();
120                collect(root, "", queue);
121                return queue;
122        }
123
124        // all keys starting with given prefix
125        public Iterable<String> prefixMatch(String prefix) {
126                Queue<String> queue = new Queue<>();
127                Node<V> x = get(root, prefix, 0);
128                if (x == null) return queue;
129                if (x.val != null) queue.enqueue(prefix);
130                collect(x.mid, prefix, queue);
131                return queue;
132        }
133
134        // all keys in subtrie rooted at x with given prefix
135        private void collect(Node<V> x, String prefix, Queue<String> queue) {
136                if (x == null) return;
137                collect(x.left,  prefix,       queue);
138                if (x.val != null) queue.enqueue(prefix + x.c);
139                collect(x.mid,   prefix + x.c, queue);
140                collect(x.right, prefix,       queue);
141        }
142
143
144        // return all keys matching given wilcard pattern
145        public Iterable<String> wildcardMatch(String pat) {
146                Queue<String> queue = new Queue<>();
147                collect(root, "", 0, pat, queue);
148                return queue;
149        }
150
151        private void collect(Node<V> x, String prefix, int i, String pat, Queue<String> q) {
152                if (x == null) return;
153                char c = pat.charAt(i);
154                if (c == '.' || c < x.c) collect(x.left, prefix, i, pat, q);
155                if (c == '.' || c == x.c) {
156                        if (i == pat.length() - 1 && x.val != null) q.enqueue(prefix + x.c);
157                        if (i < pat.length() - 1) collect(x.mid, prefix + x.c, i+1, pat, q);
158                }
159                if (c == '.' || c > x.c) collect(x.right, prefix, i, pat, q);
160        }
161
162
163
164        // test client
165        public static void main(String[] args) {
166                // build symbol table from standard input
167                TST<Integer> st = new TST<>();
168                for (int i = 0; !StdIn.isEmpty(); i++) {
169                        String key = StdIn.readString();
170                        st.put(key, i);
171                }
172
173
174                // print results
175                for (String key : st.keys()) {
176                        StdOut.println(key + " " + st.get(key));
177                }
178        }
179}