001package algs35;
002import stdlib.*;
/* ***********************************************************************
 *  Compilation:  javac XSparseMatrix.java
 *  Execution:    java XSparseMatrix
 *
 *  A sparse, square matrix, implemented using an array of sparse
 *  vectors, one per row. Only the row representation is stored.
 *
 *  For matrix-matrix product, we might also want to store the
 *  column representation.
 *
 *************************************************************************/
014
015public class XSparseMatrix {
016        private final int N;                 // N-by-N matrix
017        private final SparseVector[] rows;   // the rows, each row is a sparse vector
018
019        // initialize an N-by-N matrix of all 0s
020        public XSparseMatrix(int N) {
021                this.N  = N;
022                rows = new SparseVector[N];
023                for (int i = 0; i < N; i++) rows[i] = new SparseVector(N);
024        }
025
026        // put A[i][j] = value
027        public void put(int i, int j, double value) {
028                if (i < 0 || i >= N) throw new Error("Illegal index");
029                if (j < 0 || j >= N) throw new Error("Illegal index");
030                rows[i].put(j, value);
031        }
032
033        // return A[i][j]
034        public double get(int i, int j) {
035                if (i < 0 || i >= N) throw new Error("Illegal index");
036                if (j < 0 || j >= N) throw new Error("Illegal index");
037                return rows[i].get(j);
038        }
039
040        // return the number of nonzero entries (not the most efficient implementation)
041        public int nnz() {
042                int sum = 0;
043                for (int i = 0; i < N; i++)
044                        sum += rows[i].nnz();
045                return sum;
046        }
047
048        // return the matrix-vector product b = Ax
049        public SparseVector times(SparseVector x) {
050                if (N != x.size()) throw new Error("Dimensions disagree");
051                SparseVector b = new SparseVector(N);
052                for (int i = 0; i < N; i++)
053                        b.put(i, rows[i].dot(x));
054                return b;
055        }
056
057        // return C = A + B
058        public XSparseMatrix plus(XSparseMatrix B) {
059                XSparseMatrix A = this;
060                if (A.N != B.N) throw new Error("Dimensions disagree");
061                XSparseMatrix C = new XSparseMatrix(N);
062                for (int i = 0; i < N; i++)
063                        C.rows[i] = A.rows[i].plus(B.rows[i]);
064                return C;
065        }
066
067
068        // return a string representation
069        public String toString() {
070                String s = "N = " + N + ", nonzeros = " + nnz() + "\n";
071                for (int i = 0; i < N; i++) {
072                        s += i + ": " + rows[i] + "\n";
073                }
074                return s;
075        }
076
077
078        // test client
079        public static void main(String[] args) {
080                XSparseMatrix A = new XSparseMatrix(5);
081                SparseVector x = new SparseVector(5);
082                A.put(0, 0, 1.0);
083                A.put(1, 1, 1.0);
084                A.put(2, 2, 1.0);
085                A.put(3, 3, 1.0);
086                A.put(4, 4, 1.0);
087                A.put(2, 4, 0.3);
088                x.put(0, 0.75);
089                x.put(2, 0.11);
090                StdOut.println("x     : " + x);
091                StdOut.println("A     : " + A);
092                StdOut.println("Ax    : " + A.times(x));
093                StdOut.println("A + A : " + A.plus(A));
094        }
095
096}