001package algs9; // section 6.5
002import stdlib.*;
003/* ***********************************************************************
004 *  Compilation:  javac Simplex.java
005 *  Execution:    java Simplex
006 *
007 *  Given an M-by-N matrix A, an M-length vector b, and an
008 *  N-length vector c, solve the  LP { max cx : Ax <= b, x >= 0 }.
009 *  Assumes that b >= 0 so that x = 0 is a basic feasible solution.
010 *
011 *  Creates an (M+1)-by-(N+M+1) simplex tableaux with the
012 *  RHS in column M+N, the objective function in row M, and
013 *  slack variables in columns N through N+M-1.
014 *
015 *************************************************************************/
016
017public class Simplex {
018        private static final double EPSILON = 1.0E-10;
019        private final double[][] a;   // tableaux
020        private final int M;          // number of constraints
021        private final int N;          // number of original variables
022
023        private final int[] basis;    // basis[i] = basic variable corresponding to row i
024        // only needed to print out solution, not book
025
026        // sets up the simplex tableaux
027        public Simplex(double[][] A, double[] b, double[] c) {
028                M = b.length;
029                N = c.length;
030                a = new double[M+1][N+M+1];
031                for (int i = 0; i < M; i++)
032                        for (int j = 0; j < N; j++)
033                                a[i][j] = A[i][j];
034                for (int i = 0; i < M; i++) a[i][N+i] = 1.0;
035                for (int j = 0; j < N; j++) a[M][j]   = c[j];
036                for (int i = 0; i < M; i++) a[i][M+N] = b[i];
037
038                basis = new int[M];
039                for (int i = 0; i < M; i++) basis[i] = N + i;
040
041                solve();
042
043                // check optimality conditions
044                assert check(A, b, c);
045        }
046
047        // run simplex algorithm starting from initial BFS
048        private void solve() {
049                while (true) {
050
051                        // find entering column q
052                        int q = bland();
053                        if (q == -1) break;  // optimal
054
055                        // find leaving row p
056                        int p = minRatioRule(q);
057                        if (p == -1) throw new Error("Linear program is unbounded");
058
059                        // pivot
060                        pivot(p, q);
061
062                        // update basis
063                        basis[p] = q;
064                }
065        }
066
067        // lowest index of a non-basic column with a positive cost
068        private int bland() {
069                for (int j = 0; j < M + N; j++)
070                        if (a[M][j] > 0) return j;
071                return -1;  // optimal
072        }
073
074        // index of a non-basic column with most positive cost
075        private int dantzig() {
076                int q = 0;
077                for (int j = 1; j < M + N; j++)
078                        if (a[M][j] > a[M][q]) q = j;
079
080                if (a[M][q] <= 0) return -1;  // optimal
081                else return q;
082        }
083
084        // find row p using min ratio rule (-1 if no such row)
085        private int minRatioRule(int q) {
086                int p = -1;
087                for (int i = 0; i < M; i++) {
088                        if (a[i][q] <= 0) continue;
089                        else if (p == -1) p = i;
090                        else if ((a[i][M+N] / a[i][q]) < (a[p][M+N] / a[p][q])) p = i;
091                }
092                return p;
093        }
094
095        // pivot on entry (p, q) using Gauss-Jordan elimination
096        private void pivot(int p, int q) {
097
098                // everything but row p and column q
099                for (int i = 0; i <= M; i++)
100                        for (int j = 0; j <= M + N; j++)
101                                if (i != p && j != q) a[i][j] -= a[p][j] * a[i][q] / a[p][q];
102
103                // zero out column q
104                for (int i = 0; i <= M; i++)
105                        if (i != p) a[i][q] = 0.0;
106
107                // scale row p
108                for (int j = 0; j <= M + N; j++)
109                        if (j != q) a[p][j] /= a[p][q];
110                a[p][q] = 1.0;
111        }
112
113        // return optimal objective value
114        public double value() {
115                return -a[M][M+N];
116        }
117
118        // return primal solution vector
119        public double[] primal() {
120                double[] x = new double[N];
121                for (int i = 0; i < M; i++)
122                        if (basis[i] < N) x[basis[i]] = a[i][M+N];
123                return x;
124        }
125
126        // return dual solution vector
127        public double[] dual() {
128                double[] y = new double[M];
129                for (int i = 0; i < M; i++)
130                        y[i] = -a[M][N+i];
131                return y;
132        }
133
134
135        // is the solution primal feasible?
136        private boolean isPrimalFeasible(double[][] A, double[] b) {
137                double[] x = primal();
138
139                // check that x >= 0
140                for (int j = 0; j < x.length; j++) {
141                        if (x[j] < 0.0) {
142                                StdOut.println("x[" + j + "] = " + x[j] + " is negative");
143                                return false;
144                        }
145                }
146
147                // check that Ax <= b
148                for (int i = 0; i < M; i++) {
149                        double sum = 0.0;
150                        for (int j = 0; j < N; j++) {
151                                sum += A[i][j] * x[j];
152                        }
153                        if (sum > b[i] + EPSILON) {
154                                StdOut.println("not primal feasible");
155                                StdOut.println("b[" + i + "] = " + b[i] + ", sum = " + sum);
156                                return false;
157                        }
158                }
159                return true;
160        }
161
162        // is the solution dual feasible?
163        private boolean isDualFeasible(double[][] A, double[] c) {
164                double[] y = dual();
165
166                // check that y >= 0
167                for (int i = 0; i < y.length; i++) {
168                        if (y[i] < 0.0) {
169                                StdOut.println("y[" + i + "] = " + y[i] + " is negative");
170                                return false;
171                        }
172                }
173
174                // check that yA >= c
175                for (int j = 0; j < N; j++) {
176                        double sum = 0.0;
177                        for (int i = 0; i < M; i++) {
178                                sum += A[i][j] * y[i];
179                        }
180                        if (sum < c[j] - EPSILON) {
181                                StdOut.println("not dual feasible");
182                                StdOut.println("c[" + j + "] = " + c[j] + ", sum = " + sum);
183                                return false;
184                        }
185                }
186                return true;
187        }
188
189        // check that optimal value = cx = yb
190        private boolean isOptimal(double[] b, double[] c) {
191                double[] x = primal();
192                double[] y = dual();
193                double value = value();
194
195                // check that value = cx = yb
196                double value1 = 0.0;
197                for (int j = 0; j < x.length; j++)
198                        value1 += c[j] * x[j];
199                double value2 = 0.0;
200                for (int i = 0; i < y.length; i++)
201                        value2 += y[i] * b[i];
202                if (Math.abs(value - value1) > EPSILON || Math.abs(value - value2) > EPSILON) {
203                        StdOut.println("value = " + value + ", cx = " + value1 + ", yb = " + value2);
204                        return false;
205                }
206
207                return true;
208        }
209
210        private boolean check(double[][]A, double[] b, double[] c) {
211                return isPrimalFeasible(A, b) && isDualFeasible(A, c) && isOptimal(b, c);
212        }
213
214        // print tableaux
215        public void show() {
216                StdOut.println("M = " + M);
217                StdOut.println("N = " + N);
218                for (int i = 0; i <= M; i++) {
219                        for (int j = 0; j <= M + N; j++) {
220                                StdOut.format("%7.2f ", a[i][j]);
221                        }
222                        StdOut.println();
223                }
224                StdOut.println("value = " + value());
225                for (int i = 0; i < M; i++)
226                        if (basis[i] < N) StdOut.println("x_" + basis[i] + " = " + a[i][M+N]);
227                StdOut.println();
228        }
229
230
231        public static void test(double[][] A, double[] b, double[] c) {
232                Simplex lp = new Simplex(A, b, c);
233                StdOut.println("value = " + lp.value());
234                double[] x = lp.primal();
235                for (int i = 0; i < x.length; i++)
236                        StdOut.println("x[" + i + "] = " + x[i]);
237                double[] y = lp.dual();
238                for (int j = 0; j < y.length; j++)
239                        StdOut.println("y[" + j + "] = " + y[j]);
240        }
241
242        public static void test1() {
243                double[][] A = {
244                                { -1,  1,  0 },
245                                {  1,  4,  0 },
246                                {  2,  1,  0 },
247                                {  3, -4,  0 },
248                                {  0,  0,  1 },
249                };
250                double[] c = { 1, 1, 1 };
251                double[] b = { 5, 45, 27, 24, 4 };
252                test(A, b, c);
253        }
254
255
256        // x0 = 12, x1 = 28, opt = 800
257        public static void test2() {
258                double[] c = {  13.0,  23.0 };
259                double[] b = { 480.0, 160.0, 1190.0 };
260                double[][] A = {
261                                {  5.0, 15.0 },
262                                {  4.0,  4.0 },
263                                { 35.0, 20.0 },
264                };
265                test(A, b, c);
266        }
267
268        // unbounded
269        public static void test3() {
270                double[] c = { 2.0, 3.0, -1.0, -12.0 };
271                double[] b = {  3.0,   2.0 };
272                double[][] A = {
273                                { -2.0, -9.0,  1.0,  9.0 },
274                                {  1.0,  1.0, -1.0, -2.0 },
275                };
276                test(A, b, c);
277        }
278
279        // degenerate - cycles if you choose most positive objective function coefficient
280        public static void test4() {
281                double[] c = { 10.0, -57.0, -9.0, -24.0 };
282                double[] b = {  0.0,   0.0,  1.0 };
283                double[][] A = {
284                                { 0.5, -5.5, -2.5, 9.0 },
285                                { 0.5, -1.5, -0.5, 1.0 },
286                                { 1.0,  0.0,  0.0, 0.0 },
287                };
288                test(A, b, c);
289        }
290
291
292
293        // test client
294        public static void main(String[] args) {
295
296                try                           { test1();             }
297                catch (ArithmeticException e) { e.printStackTrace(); }
298                StdOut.println("--------------------------------");
299
300                try                           { test2();             }
301                catch (ArithmeticException e) { e.printStackTrace(); }
302                StdOut.println("--------------------------------");
303
304                try                           { test3();             }
305                catch (ArithmeticException e) { e.printStackTrace(); }
306                StdOut.println("--------------------------------");
307
308                try                           { test4();             }
309                catch (ArithmeticException e) { e.printStackTrace(); }
310                StdOut.println("--------------------------------");
311
312
313                int M = Integer.parseInt(args[0]);
314                int N = Integer.parseInt(args[1]);
315                double[] c = new double[N];
316                double[] b = new double[M];
317                double[][] A = new double[M][N];
318                for (int j = 0; j < N; j++)
319                        c[j] = StdRandom.uniform(1000);
320                for (int i = 0; i < M; i++)
321                        b[i] = StdRandom.uniform(1000);
322                for (int i = 0; i < M; i++)
323                        for (int j = 0; j < N; j++)
324                                A[i][j] = StdRandom.uniform(100);
325                Simplex lp = new Simplex(A, b, c);
326                StdOut.println(lp.value());
327        }
328
329}