001package algs35; 002import stdlib.*; 003/* *********************************************************************** 004 * Compilation: javac SparseVector.java 005 * Execution: java SparseVector 006 * 007 * A sparse vector, implementing using a symbol table. 008 * 009 * [Not clear we need the instance variable N except for error checking.] 010 * 011 *************************************************************************/ 012 013public class SparseVector { 014 private final int N; // length 015 private final ST<Integer, Double> st; // the vector, represented by index-value pairs 016 017 // initialize the all 0s vector of length N 018 public SparseVector(int N) { 019 this.N = N; 020 this.st = new ST<>(); 021 } 022 023 // put st[i] = value 024 public void put(int i, double value) { 025 if (i < 0 || i >= N) throw new Error("Illegal index"); 026 if (value == 0.0) st.delete(i); 027 else st.put(i, value); 028 } 029 030 // return st[i] 031 public double get(int i) { 032 if (i < 0 || i >= N) throw new Error("Illegal index"); 033 if (st.contains(i)) return st.get(i); 034 else return 0.0; 035 } 036 037 // return the number of nonzero entries 038 public int nnz() { 039 return st.size(); 040 } 041 042 // return the size of the vector 043 public int size() { 044 return N; 045 } 046 047 // return the dot product of this vector with that vector 048 public double dot(SparseVector that) { 049 if (this.N != that.N) throw new Error("Vector lengths disagree"); 050 double sum = 0.0; 051 052 // iterate over the vector with the fewest nonzeros 053 if (this.st.size() <= that.st.size()) { 054 for (int i : this.st.keys()) 055 if (that.st.contains(i)) sum += this.get(i) * that.get(i); 056 } 057 else { 058 for (int i : that.st.keys()) 059 if (this.st.contains(i)) sum += this.get(i) * that.get(i); 060 } 061 return sum; 062 } 063 064 065 // return the dot product of this vector and that array 066 public double dot(double[] that) { 067 double sum = 0.0; 068 for (int i : st.keys()) 069 sum += that[i] * this.get(i); 070 return sum; 071 } 072 073 074 // return the 2-norm 075 public double norm() { 076 SparseVector a = this; 077 return Math.sqrt(a.dot(a)); 078 } 079 080 // return alpha * this 081 public SparseVector scale(double alpha) { 082 SparseVector c = new SparseVector(N); 083 for (int i : this.st.keys()) c.put(i, alpha * this.get(i)); 084 return c; 085 } 086 087 // return this + that 088 public SparseVector plus(SparseVector that) { 089 if (this.N != that.N) throw new Error("Vector lengths disagree"); 090 SparseVector c = new SparseVector(N); 091 for (int i : this.st.keys()) c.put(i, this.get(i)); // c = this 092 for (int i : that.st.keys()) c.put(i, that.get(i) + c.get(i)); // c = c + that 093 return c; 094 } 095 096 // return a string representation 097 public String toString() { 098 String s = ""; 099 for (int i : st.keys()) { 100 s += "(" + i + ", " + st.get(i) + ") "; 101 } 102 return s; 103 } 104 105 106 // test client 107 public static void main(String[] args) { 108 SparseVector a = new SparseVector(10); 109 SparseVector b = new SparseVector(10); 110 a.put(3, 0.50); 111 a.put(9, 0.75); 112 a.put(6, 0.11); 113 a.put(6, 0.00); 114 b.put(3, 0.60); 115 b.put(4, 0.90); 116 StdOut.println("a = " + a); 117 StdOut.println("b = " + b); 118 StdOut.println("a dot b = " + a.dot(b)); 119 StdOut.println("a + b = " + a.plus(b)); 120 } 121 122}