001package algs9; // section 9.9
002import stdlib.*;
003/* ***********************************************************************
004 *  Compilation:  javac GaussianElimination.java
005 *  Execution:    java GaussianElimination
006 *
007 *  Gaussian elimination with partial pivoting.
008 *
009 *  % java GaussianElimination
010 *  -1.0
011 *  2.0
012 *  2.0
013 *
014 *************************************************************************/
015
016public class GaussianElimination {
017        private static final double EPSILON = 1e-10;
018
019        // Gaussian elimination with partial pivoting
020        public static double[] lsolve(double[][] A, double[] b) {
021                int N  = b.length;
022
023                for (int p = 0; p < N; p++) {
024
025                        // find pivot row and swap
026                        int max = p;
027                        for (int i = p + 1; i < N; i++) {
028                                if (Math.abs(A[i][p]) > Math.abs(A[max][p])) {
029                                        max = i;
030                                }
031                        }
032                        double[] temp = A[p]; A[p] = A[max]; A[max] = temp;
033                        double   t    = b[p]; b[p] = b[max]; b[max] = t;
034
035                        // singular or nearly singular
036                        if (Math.abs(A[p][p]) <= EPSILON) {
037                                throw new Error("Matrix is singular or nearly singular");
038                        }
039
040                        // pivot within A and b
041                        for (int i = p + 1; i < N; i++) {
042                                double alpha = A[i][p] / A[p][p];
043                                b[i] -= alpha * b[p];
044                                for (int j = p; j < N; j++) {
045                                        A[i][j] -= alpha * A[p][j];
046                                }
047                        }
048                }
049
050                // back substitution
051                double[] x = new double[N];
052                for (int i = N - 1; i >= 0; i--) {
053                        double sum = 0.0;
054                        for (int j = i + 1; j < N; j++) {
055                                sum += A[i][j] * x[j];
056                        }
057                        x[i] = (b[i] - sum) / A[i][i];
058                }
059                return x;
060        }
061
062
063        // sample client
064        public static void main(String[] args) {
065                int N = 3;
066                double[][] A = {
067                                { 0, 1,  1 },
068                                { 2, 4, -2 },
069                                { 0, 3, 15 }
070                };
071                double[] b = { 4, 2, 36 };
072                double[] x = lsolve(A, b);
073
074
075                // print results
076                for (int i = 0; i < N; i++) {
077                        StdOut.println(x[i]);
078                }
079
080        }
081
082}