001package algs35;
002import stdlib.*;
003/* ***********************************************************************
004 *  Compilation:  javac SparseVector.java
005 *  Execution:    java SparseVector
006 *
 *  A sparse vector, implemented using a symbol table.
008 *
009 *  [Not clear we need the instance variable N except for error checking.]
010 *
011 *************************************************************************/
012
013public class SparseVector {
014        private final int N;                   // length
015        private final ST<Integer, Double> st;  // the vector, represented by index-value pairs
016
017        // initialize the all 0s vector of length N
018        public SparseVector(int N) {
019                this.N  = N;
020                this.st = new ST<>();
021        }
022
023        // put st[i] = value
024        public void put(int i, double value) {
025                if (i < 0 || i >= N) throw new Error("Illegal index");
026                if (value == 0.0) st.delete(i);
027                else              st.put(i, value);
028        }
029
030        // return st[i]
031        public double get(int i) {
032                if (i < 0 || i >= N) throw new Error("Illegal index");
033                if (st.contains(i)) return st.get(i);
034                else                return 0.0;
035        }
036
037        // return the number of nonzero entries
038        public int nnz() {
039                return st.size();
040        }
041
042        // return the size of the vector
043        public int size() {
044                return N;
045        }
046
047        // return the dot product of this vector with that vector
048        public double dot(SparseVector that) {
049                if (this.N != that.N) throw new Error("Vector lengths disagree");
050                double sum = 0.0;
051
052                // iterate over the vector with the fewest nonzeros
053                if (this.st.size() <= that.st.size()) {
054                        for (int i : this.st.keys())
055                                if (that.st.contains(i)) sum += this.get(i) * that.get(i);
056                }
057                else  {
058                        for (int i : that.st.keys())
059                                if (this.st.contains(i)) sum += this.get(i) * that.get(i);
060                }
061                return sum;
062        }
063
064
065        // return the dot product of this vector and that array
066        public double dot(double[] that) {
067                double sum = 0.0;
068                for (int i : st.keys())
069                        sum += that[i] * this.get(i);
070                return sum;
071        }
072
073
074        // return the 2-norm
075        public double norm() {
076                SparseVector a = this;
077                return Math.sqrt(a.dot(a));
078        }
079
080        // return alpha * this
081        public SparseVector scale(double alpha) {
082                SparseVector c = new SparseVector(N);
083                for (int i : this.st.keys()) c.put(i, alpha * this.get(i));
084                return c;
085        }
086
087        // return this + that
088        public SparseVector plus(SparseVector that) {
089                if (this.N != that.N) throw new Error("Vector lengths disagree");
090                SparseVector c = new SparseVector(N);
091                for (int i : this.st.keys()) c.put(i, this.get(i));                // c = this
092                for (int i : that.st.keys()) c.put(i, that.get(i) + c.get(i));     // c = c + that
093                return c;
094        }
095
096        // return a string representation
097        public String toString() {
098                String s = "";
099                for (int i : st.keys()) {
100                        s += "(" + i + ", " + st.get(i) + ") ";
101                }
102                return s;
103        }
104
105
106        // test client
107        public static void main(String[] args) {
108                SparseVector a = new SparseVector(10);
109                SparseVector b = new SparseVector(10);
110                a.put(3, 0.50);
111                a.put(9, 0.75);
112                a.put(6, 0.11);
113                a.put(6, 0.00);
114                b.put(3, 0.60);
115                b.put(4, 0.90);
116                StdOut.println("a = " + a);
117                StdOut.println("b = " + b);
118                StdOut.println("a dot b = " + a.dot(b));
119                StdOut.println("a + b   = " + a.plus(b));
120        }
121
122}