001 
002 
003 
004 
005 
006 
007 
008 
009 
010 
011 
012 
013 
014 
015 
016 
017 
018 
019 
020 
021 
022 
023 
024 
025 
026 
027 
028 
029 
030 
031 
032 
033 
034 
035 
036 
037 
038 
039 
040 
041 
042 
043 
044 
045 
046 
047 
048 
049 
050 
051 
052 
053 
054 
055 
056 
057 
058 
059 
060 
061 
062 
063 
064 
065 
066 
067 
068 
069 
070 
071 
072 
073 
074 
075 
076 
077 
078 
079 
080 
081 
082 
083 
084 
085 
086 
087 
088 
089 
090 
091 
092 
093 
094 
095 
096 
097 
098 
099 
100 
101 
102 
103 
104 
105 
106 
107 
108 
109 
110 
111 
112 
113 
114 
115 
116 
117 
118 
119 
120 
121 
122 
 | 
package algs35;
import stdlib.*;
/* ***********************************************************************
 *  Compilation:  javac SparseVector.java
 *  Execution:    java SparseVector
 *
 *  A sparse vector, implemented using a symbol table.
 *
 *  [Not clear we need the instance variable N except for error checking.]
 *
 *************************************************************************/
public class SparseVector {
  private final int N;                   // length (dimension) of the vector
  private final ST<Integer, Double> st;  // nonzero entries only, as index-value pairs

  /**
   * Initializes the all-zeros vector of length N.
   *
   * @param N the length of the vector
   * @throws Error if N is negative
   */
  public SparseVector(int N) {
    if (N < 0) throw new Error("Illegal length");
    this.N  = N;
    this.st = new ST<>();
  }

  // Rejects any index outside [0, N).
  private void checkIndex(int i) {
    if (i < 0 || i >= N) throw new Error("Illegal index");
  }

  // Rejects an operand whose length differs from this vector's length.
  private void checkLength(int otherLength) {
    if (N != otherLength) throw new Error("Vector lengths disagree");
  }

  /**
   * Sets entry i to value. Zeros are never stored: writing 0.0 removes the
   * entry, so st only ever holds nonzero values.
   *
   * @throws Error if i is not in [0, N)
   */
  public void put(int i, double value) {
    checkIndex(i);
    if (value == 0.0) st.delete(i);
    else              st.put(i, value);
  }

  /**
   * Returns entry i (0.0 for any index that has no stored entry).
   *
   * @throws Error if i is not in [0, N)
   */
  public double get(int i) {
    checkIndex(i);
    if (st.contains(i)) return st.get(i);
    else                return 0.0;
  }

  /** Returns the number of nonzero (stored) entries. */
  public int nnz() {
    return st.size();
  }

  /** Returns the length of the vector. */
  public int size() {
    return N;
  }

  /**
   * Returns the dot product of this vector with that vector.
   *
   * @throws Error if the two vectors have different lengths
   */
  public double dot(SparseVector that) {
    checkLength(that.N);
    // Iterate over the operand with fewer nonzeros; only indices stored in
    // both vectors can contribute a nonzero product.
    SparseVector fewer = (this.st.size() <= that.st.size()) ? this : that;
    SparseVector other = (fewer == this) ? that : this;
    double sum = 0.0;
    for (int i : fewer.st.keys())
      if (other.st.contains(i)) sum += fewer.st.get(i) * other.st.get(i);
    return sum;
  }

  /**
   * Returns the dot product of this vector and that dense array.
   *
   * @throws Error if that.length differs from the length of this vector
   */
  public double dot(double[] that) {
    // Without this check a short array throws ArrayIndexOutOfBounds and a long
    // one silently yields a result; match dot(SparseVector) instead.
    checkLength(that.length);
    double sum = 0.0;
    for (int i : st.keys())
      sum += that[i] * st.get(i);
    return sum;
  }

  /** Returns the Euclidean (2-) norm of this vector. */
  public double norm() {
    return Math.sqrt(this.dot(this));
  }

  /** Returns a new vector equal to alpha * this; this vector is unchanged. */
  public SparseVector scale(double alpha) {
    SparseVector c = new SparseVector(N);
    // put() drops products that are 0.0, so c stays sparse (e.g. alpha == 0).
    for (int i : st.keys()) c.put(i, alpha * st.get(i));
    return c;
  }

  /**
   * Returns a new vector equal to this + that; neither operand is changed.
   *
   * @throws Error if the two vectors have different lengths
   */
  public SparseVector plus(SparseVector that) {
    checkLength(that.N);
    SparseVector c = new SparseVector(N);
    for (int i : this.st.keys()) c.put(i, this.st.get(i));              // c = this
    // c.get(i) is 0.0 for indices not in this; put() drops sums that cancel to 0.0
    for (int i : that.st.keys()) c.put(i, that.st.get(i) + c.get(i));   // c = c + that
    return c;
  }

  /** Returns the stored entries as "(index, value) " pairs in key order. */
  @Override
  public String toString() {
    StringBuilder s = new StringBuilder();
    for (int i : st.keys())
      s.append('(').append(i).append(", ").append(st.get(i)).append(") ");
    return s.toString();
  }

  // test client
  public static void main(String[] args) {
    SparseVector a = new SparseVector(10);
    SparseVector b = new SparseVector(10);
    a.put(3, 0.50);
    a.put(9, 0.75);
    a.put(6, 0.11);
    a.put(6, 0.00);
    b.put(3, 0.60);
    b.put(4, 0.90);
    StdOut.println("a = " + a);
    StdOut.println("b = " + b);
    StdOut.println("a dot b = " + a.dot(b));
    StdOut.println("a + b   = " + a.plus(b));
  }
}
 |