001package algs52; // section 5.2
002import stdlib.*;
003import algs13.Queue;
/* ***********************************************************************
 *  Compilation:  javac TrieST.java
 *  Execution:    java TrieST
 *  Dependencies: StdIn.java StdOut.java Queue.java
 *  Data files:   data/shellsST.txt
 *
 *  A string symbol table for extended-ASCII strings, implemented using a
 *  256-way trie. The test client reads its keys from data/shellsST.txt
 *  (the file name is hard-coded in main), not from redirected standard
 *  input.
 *
 *  % java TrieST
 *  by 4
 *  sea 6
 *  sells 1
 *  she 0
 *  shells 3
 *  shore 7
 *  the 5
 *
 *************************************************************************/
021
022public class TrieST<V> {
023        private static final int R = 256;        // extended ASCII
024
025        private Node<V> root = new Node<>();
026
027        private static class Node<V> {
028                public Node() { }
029                public V val;
030                @SuppressWarnings("unchecked")
031                public final Node<V>[] next = new Node[R];
032        }
033
034        /* **************************************************
035         * Is the key in the symbol table?
036         ****************************************************/
037        public boolean contains(String key) {
038                return get(key) != null;
039        }
040
041        public V get(String key) {
042                Node<V> x = get(root, key, 0);
043                if (x == null) return null;
044                return x.val;
045        }
046
047        private Node<V> get(Node<V> x, String key, int d) {
048                if (x == null) return null;
049                if (d == key.length()) return x;
050                char c = key.charAt(d);
051                return get(x.next[c], key, d+1);
052        }
053
054        /* **************************************************
055         * Insert key-value pair into the symbol table.
056         ****************************************************/
057        public void put(String key, V val) {
058                root = put(root, key, val, 0);
059        }
060
061        private Node<V> put(Node<V> x, String key, V val, int d) {
062                if (x == null) x = new Node<>();
063                if (d == key.length()) {
064                        x.val = val;
065                        return x;
066                }
067                char c = key.charAt(d);
068                x.next[c] = put(x.next[c], key, val, d+1);
069                return x;
070        }
071
072        // find the key that is the longest prefix of s
073        public String longestPrefixOf(String query) {
074                int length = longestPrefixOf(root, query, 0, 0);
075                return query.substring(0, length);
076        }
077
078        // find the key in the subtrie rooted at x that is the longest
079        // prefix of the query string, starting at the dth character
080        private int longestPrefixOf(Node<V> x, String query, int d, int length) {
081                if (x == null) return length;
082                if (x.val != null) length = d;
083                if (d == query.length()) return length;
084                char c = query.charAt(d);
085                return longestPrefixOf(x.next[c], query, d+1, length);
086        }
087
088
089        public Iterable<String> keys() {
090                return keysWithPrefix("");
091        }
092
093        public Iterable<String> keysWithPrefix(String prefix) {
094                Queue<String> queue = new Queue<>();
095                Node<V> x = get(root, prefix, 0);
096                collect(x, prefix, queue);
097                return queue;
098        }
099
100        private void collect(Node<V> x, String key, Queue<String> queue) {
101                if (x == null) return;
102                if (x.val != null) queue.enqueue(key);
103                for (int c = 0; c < R; c++)
104                        collect(x.next[c], key + (char) c, queue);
105        }
106
107
108        public Iterable<String> keysThatMatch(String pat) {
109                Queue<String> q = new Queue<>();
110                collect(root, "", pat, q);
111                return q;
112        }
113
114        private void collect(Node<V> x, String prefix, String pat, Queue<String> q) {
115                if (x == null) return;
116                if (prefix.length() == pat.length() && x.val != null) q.enqueue(prefix);
117                if (prefix.length() == pat.length()) return;
118                char next = pat.charAt(prefix.length());
119                for (int c = 0; c < R; c++)
120                        if (next == '.' || next == c)
121                                collect(x.next[c], prefix + (char) c, pat, q);
122        }
123
124        public void delete(String key) {
125                root = delete(root, key, 0);
126        }
127
128        private Node<V> delete(Node<V> x, String key, int d) {
129                if (x == null) return null;
130                if (d == key.length()) x.val = null;
131                else {
132                        char c = key.charAt(d);
133                        x.next[c] = delete(x.next[c], key, d+1);
134                }
135                if (x.val != null) return x;
136                for (int c = 0; c < R; c++)
137                        if (x.next[c] != null)
138                                return x;
139                return null;
140        }
141
142
143        // test client
144        public static void main(String[] args) {
145                StdIn.fromFile("data/shellsST.txt");
146
147                // build symbol table from standard input
148                TrieST<Integer> st = new TrieST<>();
149                for (int i = 0; !StdIn.isEmpty(); i++) {
150                        String key = StdIn.readString();
151                        st.put(key, i);
152                }
153
154                // print results
155                for (String key : st.keys()) {
156                        StdOut.println(key + " " + st.get(key));
157                }
158        }
159}