001package algs35; 002import stdlib.*; 003/* *********************************************************************** 004 * Compilation: javac SparseMatrix.java 005 * Execution: java SparseMatrix 006 * 007 * A sparse, square matrix, implementing using two arrays of sparse 008 * vectors, one representation for the rows and one for the columns. 009 * 010 * For matrix-matrix product, we might also want to store the 011 * column representation. 012 * 013 *************************************************************************/ 014 015public class XSparseMatrix { 016 private final int N; // N-by-N matrix 017 private final SparseVector[] rows; // the rows, each row is a sparse vector 018 019 // initialize an N-by-N matrix of all 0s 020 public XSparseMatrix(int N) { 021 this.N = N; 022 rows = new SparseVector[N]; 023 for (int i = 0; i < N; i++) rows[i] = new SparseVector(N); 024 } 025 026 // put A[i][j] = value 027 public void put(int i, int j, double value) { 028 if (i < 0 || i >= N) throw new Error("Illegal index"); 029 if (j < 0 || j >= N) throw new Error("Illegal index"); 030 rows[i].put(j, value); 031 } 032 033 // return A[i][j] 034 public double get(int i, int j) { 035 if (i < 0 || i >= N) throw new Error("Illegal index"); 036 if (j < 0 || j >= N) throw new Error("Illegal index"); 037 return rows[i].get(j); 038 } 039 040 // return the number of nonzero entries (not the most efficient implementation) 041 public int nnz() { 042 int sum = 0; 043 for (int i = 0; i < N; i++) 044 sum += rows[i].nnz(); 045 return sum; 046 } 047 048 // return the matrix-vector product b = Ax 049 public SparseVector times(SparseVector x) { 050 if (N != x.size()) throw new Error("Dimensions disagree"); 051 SparseVector b = new SparseVector(N); 052 for (int i = 0; i < N; i++) 053 b.put(i, rows[i].dot(x)); 054 return b; 055 } 056 057 // return C = A + B 058 public XSparseMatrix plus(XSparseMatrix B) { 059 XSparseMatrix A = this; 060 if (A.N != B.N) throw new Error("Dimensions disagree"); 061 XSparseMatrix C = new XSparseMatrix(N); 062 for (int i = 0; i < N; i++) 063 C.rows[i] = A.rows[i].plus(B.rows[i]); 064 return C; 065 } 066 067 068 // return a string representation 069 public String toString() { 070 String s = "N = " + N + ", nonzeros = " + nnz() + "\n"; 071 for (int i = 0; i < N; i++) { 072 s += i + ": " + rows[i] + "\n"; 073 } 074 return s; 075 } 076 077 078 // test client 079 public static void main(String[] args) { 080 XSparseMatrix A = new XSparseMatrix(5); 081 SparseVector x = new SparseVector(5); 082 A.put(0, 0, 1.0); 083 A.put(1, 1, 1.0); 084 A.put(2, 2, 1.0); 085 A.put(3, 3, 1.0); 086 A.put(4, 4, 1.0); 087 A.put(2, 4, 0.3); 088 x.put(0, 0.75); 089 x.put(2, 0.11); 090 StdOut.println("x : " + x); 091 StdOut.println("A : " + A); 092 StdOut.println("Ax : " + A.times(x)); 093 StdOut.println("A + A : " + A.plus(A)); 094 } 095 096}