001package algs9; // section 9.9
002import stdlib.*;
003/* ***********************************************************************
 *  Compilation:  javac XGaussJordanElimination.java
 *  Execution:    java algs9.XGaussJordanElimination N
 *
 *  Finds a solution to Ax = b using Gauss-Jordan elimination with partial
 *  pivoting. If no solution exists, finds a vector y with yA = 0 and yb != 0,
 *  which serves as a certificate of infeasibility.
010 *
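 *  Sample output (abbreviated; its headings differ from those printed by
 *  the current test() method):
 *
 *  % java algs9.XGaussJordanElimination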
012 *  -1.000000
013 *  2.000000
014 *  2.000000
015 *
016 *  3.000000
017 *  -1.000000
018 *  -2.000000
019 *
020 *  System is infeasible
021 *
022 *  -6.250000
023 *  -4.500000
024 *  0.000000
025 *  0.000000
026 *  1.000000
027 *
028 *  System is infeasible
029 *
030 *  -1.375000
031 *  1.625000
032 *  0.000000
033 *
034 *
035 *************************************************************************/
036
037public class XGaussJordanElimination {
038        private static final double EPSILON = 1e-8;
039
040        private final int N;      // N-by-N system
041        private final double[][] a;     // N-by-N+1 augmented matrix
042
043        // Gauss-Jordan elimination with partial pivoting
044        public XGaussJordanElimination(double[][] A, double[] b) {
045                N = b.length;
046
047                // build augmented matrix
048                a = new double[N][N+N+1];
049                for (int i = 0; i < N; i++)
050                        for (int j = 0; j < N; j++)
051                                a[i][j] = A[i][j];
052
053                // only need if you want to find certificate of infeasibility (or compute inverse)
054                for (int i = 0; i < N; i++)
055                        a[i][N+i] = 1.0;
056
057                for (int i = 0; i < N; i++) a[i][N+N] = b[i];
058
059                solve();
060
061                assert check(A, b);
062        }
063
064        private void solve() {
065
066                // Gauss-Jordan elimination
067                for (int p = 0; p < N; p++) {
068                        // show();
069
070                        // find pivot row using partial pivoting
071                        int max = p;
072                        for (int i = p+1; i < N; i++) {
073                                if (Math.abs(a[i][p]) > Math.abs(a[max][p])) {
074                                        max = i;
075                                }
076                        }
077
078                        // exchange row p with row max
079                        swap(p, max);
080
081                        // singular or nearly singular
082                        if (Math.abs(a[p][p]) <= EPSILON) {
083                                continue;
084                                // throw new Error("Matrix is singular or nearly singular");
085                        }
086
087                        // pivot
088                        pivot(p, p);
089                }
090                // show();
091        }
092
093        // swap row1 and row2
094        private void swap(int row1, int row2) {
095                double[] temp = a[row1];
096                a[row1] = a[row2];
097                a[row2] = temp;
098        }
099
100
101        // pivot on entry (p, q) using Gauss-Jordan elimination
102        private void pivot(int p, int q) {
103
104                // everything but row p and column q
105                for (int i = 0; i < N; i++) {
106                        double alpha = a[i][q] / a[p][q];
107                        for (int j = 0; j <= N+N; j++) {
108                                if (i != p && j != q) a[i][j] -= alpha * a[p][j];
109                        }
110                }
111
112                // zero out column q
113                for (int i = 0; i < N; i++)
114                        if (i != p) a[i][q] = 0.0;
115
116                // scale row p (ok to go from q+1 to N, but do this for consistency with simplex pivot)
117                for (int j = 0; j <= N+N; j++)
118                        if (j != q) a[p][j] /= a[p][q];
119                a[p][q] = 1.0;
120        }
121
122        // extract solution to Ax = b
123        public double[] primal() {
124                double[] x = new double[N];
125                for (int i = 0; i < N; i++) {
126                        if (Math.abs(a[i][i]) > EPSILON)
127                                x[i] = a[i][N+N] / a[i][i];
128                        else if (Math.abs(a[i][N+N]) > EPSILON)
129                                return null;
130                }
131                return x;
132        }
133
134        // extract solution to yA = 0, yb != 0
135        public double[] dual() {
136                double[] y = new double[N];
137                for (int i = 0; i < N; i++) {
138                        if ((Math.abs(a[i][i]) <= EPSILON) && (Math.abs(a[i][N+N]) > EPSILON)) {
139                                for (int j = 0; j < N; j++)
140                                        y[j] = a[i][N+j];
141                                return y;
142                        }
143                }
144                return null;
145        }
146
147        // does the system have a solution?
148        public boolean isFeasible() {
149                return primal() != null;
150        }
151
152        // print the tableaux
153        private void show() {
154                for (int i = 0; i < N; i++) {
155                        for (int j = 0; j < N; j++) {
156                                StdOut.format("%8.3f ", a[i][j]);
157                        }
158                        StdOut.format("| ");
159                        for (int j = N; j < N+N; j++) {
160                                StdOut.format("%8.3f ", a[i][j]);
161                        }
162                        StdOut.format("| %8.3f\n", a[i][N+N]);
163                }
164                StdOut.println();
165        }
166
167
168        // check that Ax = b or yA = 0, yb != 0
169        private boolean check(double[][] A, double[] b) {
170
171                // check that Ax = b
172                if (isFeasible()) {
173                        double[] x = primal();
174                        for (int i = 0; i < N; i++) {
175                                double sum = 0.0;
176                                for (int j = 0; j < N; j++) {
177                                        sum += A[i][j] * x[j];
178                                }
179                                if (Math.abs(sum - b[i]) > EPSILON) {
180                                        StdOut.println("not feasible");
181                                        StdOut.format("b[%d] = %8.3f, sum = %8.3f\n", i, b[i], sum);
182                                        return false;
183                                }
184                        }
185                        return true;
186                }
187
188                // or that yA = 0, yb != 0
189                else {
190                        double[] y = dual();
191                        for (int j = 0; j < N; j++) {
192                                double sum = 0.0;
193                                for (int i = 0; i < N; i++) {
194                                        sum += A[i][j] * y[i];
195                                }
196                                if (Math.abs(sum) > EPSILON) {
197                                        StdOut.println("invalid certificate of infeasibility");
198                                        StdOut.format("sum = %8.3f\n", sum);
199                                        return false;
200                                }
201                        }
202                        double sum = 0.0;
203                        for (int i = 0; i < N; i++) {
204                                sum += y[i] * b[i];
205                        }
206                        if (Math.abs(sum) < EPSILON) {
207                                StdOut.println("invalid certificate of infeasibility");
208                                StdOut.format("yb  = %8.3f\n", sum);
209                                return false;
210                        }
211                        return true;
212                }
213        }
214
215
216        public static void test(double[][] A, double[] b) {
217                XGaussJordanElimination gaussian = new XGaussJordanElimination(A, b);
218                if (gaussian.isFeasible()) {
219                        StdOut.println("Solution to Ax = b");
220                        double[] x = gaussian.primal();
221                        for (double element : x) {
222                                StdOut.format("%10.6f\n", element);
223                        }
224                }
225                else {
226                        StdOut.println("Certificate of infeasibility");
227                        double[] y = gaussian.dual();
228                        for (double element : y) {
229                                StdOut.format("%10.6f\n", element);
230                        }
231                }
232                StdOut.println();
233        }
234
235
236        // 3-by-3 nonsingular system
237        public static void test1() {
238                double[][] A = {
239                                { 0, 1,  1 },
240                                { 2, 4, -2 },
241                                { 0, 3, 15 }
242                };
243                double[] b = { 4, 2, 36 };
244                test(A, b);
245        }
246
247        // 3-by-3 nonsingular system
248        public static void test2() {
249                double[][] A = {
250                                {  1, -3,   1 },
251                                {  2, -8,   8 },
252                                { -6,  3, -15 }
253                };
254                double[] b = { 4, -2, 9 };
255                test(A, b);
256        }
257
258        // 5-by-5 singular: no solutions
259        // y = [ -1, 0, 1, 1, 0 ]
260        public static void test3() {
261                double[][] A = {
262                                {  2, -3, -1,  2,  3 },
263                                {  4, -4, -1,  4, 11 },
264                                {  2, -5, -2,  2, -1 },
265                                {  0,  2,  1,  0,  4 },
266                                { -4,  6,  0,  0,  7 },
267                };
268                double[] b = { 4, 4, 9, -6, 5 };
269                test(A, b);
270        }
271
272        // 5-by-5 singluar: infinitely many solutions
273        public static void test4() {
274                double[][] A = {
275                                {  2, -3, -1,  2,  3 },
276                                {  4, -4, -1,  4, 11 },
277                                {  2, -5, -2,  2, -1 },
278                                {  0,  2,  1,  0,  4 },
279                                { -4,  6,  0,  0,  7 },
280                };
281                double[] b = { 4, 4, 9, -5, 5 };
282                test(A, b);
283        }
284
285        // 3-by-3 singular: no solutions
286        // y = [ 1, 0, 1/3 ]
287        public static void test5() {
288                double[][] A = {
289                                {  2, -1,  1 },
290                                {  3,  2, -4 },
291                                { -6,  3, -3 },
292                };
293                double[] b = { 1, 4, 2 };
294                test(A, b);
295        }
296
297        // 3-by-3 singular: infinitely many solutions
298        public static void test6() {
299                double[][] A = {
300                                {  1, -1,  2 },
301                                {  4,  4, -2 },
302                                { -2,  2, -4 },
303                };
304                double[] b = { -3, 1, 6 };
305                test(A, b);
306        }
307
308        // sample client
309        public static void main(String[] args) {
310
311                try                 { test1();             }
312                catch (Exception e) { e.printStackTrace(); }
313                StdOut.println("--------------------------------");
314
315                try                 { test2();             }
316                catch (Exception e) { e.printStackTrace(); }
317                StdOut.println("--------------------------------");
318
319                try                 { test3();             }
320                catch (Exception e) { e.printStackTrace(); }
321                StdOut.println("--------------------------------");
322
323                try                 { test4();             }
324                catch (Exception e) { e.printStackTrace(); }
325                StdOut.println("--------------------------------");
326
327                try                 { test5();             }
328                catch (Exception e) { e.printStackTrace(); }
329                StdOut.println("--------------------------------");
330
331                try                 { test6();             }
332                catch (Exception e) { e.printStackTrace(); }
333                StdOut.println("--------------------------------");
334
335                // N-by-N random system (likely full rank)
336                int N = Integer.parseInt(args[0]);
337                double[][] A = new double[N][N];
338                for (int i = 0; i < N; i++)
339                        for (int j = 0; j < N; j++)
340                                A[i][j] = StdRandom.uniform(1000);
341                double[] b = new double[N];
342                for (int i = 0; i < N; i++)
343                        b[i] = StdRandom.uniform(1000);
344                test(A, b);
345
346                StdOut.println("--------------------------------");
347
348                // N-by-N random system (likely infeasible)
349                A = new double[N][N];
350                for (int i = 0; i < N-1; i++)
351                        for (int j = 0; j < N; j++)
352                                A[i][j] = StdRandom.uniform(1000);
353                for (int i = 0; i < N-1; i++) {
354                        double alpha = StdRandom.uniform(11) - 5.0;
355                        for (int j = 0; j < N; j++) {
356                                A[N-1][j] += alpha * A[i][j];
357                        }
358                }
359                b = new double[N];
360                for (int i = 0; i < N; i++)
361                        b[i] = StdRandom.uniform(1000);
362                test(A, b);
363
364
365        }
366
367}