001package algs91; // section 6.5 002import stdlib.*; 003/* *********************************************************************** 004 * Compilation: javac Simplex.java 005 * Execution: java Simplex 006 * 007 * Given an M-by-N matrix A, an M-length vector b, and an 008 * N-length vector c, solve the LP { max cx : Ax <= b, x >= 0 }. 009 * Assumes that b >= 0 so that x = 0 is a basic feasible solution. 010 * 011 * Creates an (M+1)-by-(N+M+1) simplex tableaux with the 012 * RHS in column M+N, the objective function in row M, and 013 * slack variables in columns M through M+N-1. 014 * 015 *************************************************************************/ 016 017public class Simplex { 018 private static final double EPSILON = 1.0E-10; 019 private final double[][] a; // tableaux 020 private final int M; // number of constraints 021 private final int N; // number of original variables 022 023 private final int[] basis; // basis[i] = basic variable corresponding to row i 024 // only needed to print out solution, not book 025 026 // sets up the simplex tableaux 027 public Simplex(double[][] A, double[] b, double[] c) { 028 M = b.length; 029 N = c.length; 030 a = new double[M+1][N+M+1]; 031 for (int i = 0; i < M; i++) 032 for (int j = 0; j < N; j++) 033 a[i][j] = A[i][j]; 034 for (int i = 0; i < M; i++) a[i][N+i] = 1.0; 035 for (int j = 0; j < N; j++) a[M][j] = c[j]; 036 for (int i = 0; i < M; i++) a[i][M+N] = b[i]; 037 038 basis = new int[M]; 039 for (int i = 0; i < M; i++) basis[i] = N + i; 040 041 solve(); 042 043 // check optimality conditions 044 assert check(A, b, c); 045 } 046 047 // run simplex algorithm starting from initial BFS 048 private void solve() { 049 while (true) { 050 051 // find entering column q 052 int q = bland(); 053 if (q == -1) break; // optimal 054 055 // find leaving row p 056 int p = minRatioRule(q); 057 if (p == -1) throw new Error("Linear program is unbounded"); 058 059 // pivot 060 pivot(p, q); 061 062 // update basis 063 basis[p] = q; 064 } 065 } 066 067 // lowest index of a non-basic column with a positive cost 068 private int bland() { 069 for (int j = 0; j < M + N; j++) 070 if (a[M][j] > 0) return j; 071 return -1; // optimal 072 } 073 074 // index of a non-basic column with most positive cost 075 private int dantzig() { 076 int q = 0; 077 for (int j = 1; j < M + N; j++) 078 if (a[M][j] > a[M][q]) q = j; 079 080 if (a[M][q] <= 0) return -1; // optimal 081 else return q; 082 } 083 084 // find row p using min ratio rule (-1 if no such row) 085 private int minRatioRule(int q) { 086 int p = -1; 087 for (int i = 0; i < M; i++) { 088 if (a[i][q] <= 0) continue; 089 else if (p == -1) p = i; 090 else if ((a[i][M+N] / a[i][q]) < (a[p][M+N] / a[p][q])) p = i; 091 } 092 return p; 093 } 094 095 // pivot on entry (p, q) using Gauss-Jordan elimination 096 private void pivot(int p, int q) { 097 098 // everything but row p and column q 099 for (int i = 0; i <= M; i++) 100 for (int j = 0; j <= M + N; j++) 101 if (i != p && j != q) a[i][j] -= a[p][j] * a[i][q] / a[p][q]; 102 103 // zero out column q 104 for (int i = 0; i <= M; i++) 105 if (i != p) a[i][q] = 0.0; 106 107 // scale row p 108 for (int j = 0; j <= M + N; j++) 109 if (j != q) a[p][j] /= a[p][q]; 110 a[p][q] = 1.0; 111 } 112 113 // return optimal objective value 114 public double value() { 115 return -a[M][M+N]; 116 } 117 118 // return primal solution vector 119 public double[] primal() { 120 double[] x = new double[N]; 121 for (int i = 0; i < M; i++) 122 if (basis[i] < N) x[basis[i]] = a[i][M+N]; 123 return x; 124 } 125 126 // return dual solution vector 127 public double[] dual() { 128 double[] y = new double[M]; 129 for (int i = 0; i < M; i++) 130 y[i] = -a[M][N+i]; 131 return y; 132 } 133 134 135 // is the solution primal feasible? 136 private boolean isPrimalFeasible(double[][] A, double[] b) { 137 double[] x = primal(); 138 139 // check that x >= 0 140 for (int j = 0; j < x.length; j++) { 141 if (x[j] < 0.0) { 142 StdOut.println("x[" + j + "] = " + x[j] + " is negative"); 143 return false; 144 } 145 } 146 147 // check that Ax <= b 148 for (int i = 0; i < M; i++) { 149 double sum = 0.0; 150 for (int j = 0; j < N; j++) { 151 sum += A[i][j] * x[j]; 152 } 153 if (sum > b[i] + EPSILON) { 154 StdOut.println("not primal feasible"); 155 StdOut.println("b[" + i + "] = " + b[i] + ", sum = " + sum); 156 return false; 157 } 158 } 159 return true; 160 } 161 162 // is the solution dual feasible? 163 private boolean isDualFeasible(double[][] A, double[] c) { 164 double[] y = dual(); 165 166 // check that y >= 0 167 for (int i = 0; i < y.length; i++) { 168 if (y[i] < 0.0) { 169 StdOut.println("y[" + i + "] = " + y[i] + " is negative"); 170 return false; 171 } 172 } 173 174 // check that yA >= c 175 for (int j = 0; j < N; j++) { 176 double sum = 0.0; 177 for (int i = 0; i < M; i++) { 178 sum += A[i][j] * y[i]; 179 } 180 if (sum < c[j] - EPSILON) { 181 StdOut.println("not dual feasible"); 182 StdOut.println("c[" + j + "] = " + c[j] + ", sum = " + sum); 183 return false; 184 } 185 } 186 return true; 187 } 188 189 // check that optimal value = cx = yb 190 private boolean isOptimal(double[] b, double[] c) { 191 double[] x = primal(); 192 double[] y = dual(); 193 double value = value(); 194 195 // check that value = cx = yb 196 double value1 = 0.0; 197 for (int j = 0; j < x.length; j++) 198 value1 += c[j] * x[j]; 199 double value2 = 0.0; 200 for (int i = 0; i < y.length; i++) 201 value2 += y[i] * b[i]; 202 if (Math.abs(value - value1) > EPSILON || Math.abs(value - value2) > EPSILON) { 203 StdOut.println("value = " + value + ", cx = " + value1 + ", yb = " + value2); 204 return false; 205 } 206 207 return true; 208 } 209 210 private boolean check(double[][]A, double[] b, double[] c) { 211 return isPrimalFeasible(A, b) && isDualFeasible(A, c) && isOptimal(b, c); 212 } 213 214 // print tableaux 215 public void show() { 216 StdOut.println("M = " + M); 217 StdOut.println("N = " + N); 218 for (int i = 0; i <= M; i++) { 219 for (int j = 0; j <= M + N; j++) { 220 StdOut.format("%7.2f ", a[i][j]); 221 } 222 StdOut.println(); 223 } 224 StdOut.println("value = " + value()); 225 for (int i = 0; i < M; i++) 226 if (basis[i] < N) StdOut.println("x_" + basis[i] + " = " + a[i][M+N]); 227 StdOut.println(); 228 } 229 230 231 public static void test(double[][] A, double[] b, double[] c) { 232 Simplex lp = new Simplex(A, b, c); 233 StdOut.println("value = " + lp.value()); 234 double[] x = lp.primal(); 235 for (int i = 0; i < x.length; i++) 236 StdOut.println("x[" + i + "] = " + x[i]); 237 double[] y = lp.dual(); 238 for (int j = 0; j < y.length; j++) 239 StdOut.println("y[" + j + "] = " + y[j]); 240 } 241 242 public static void test1() { 243 double[][] A = { 244 { -1, 1, 0 }, 245 { 1, 4, 0 }, 246 { 2, 1, 0 }, 247 { 3, -4, 0 }, 248 { 0, 0, 1 }, 249 }; 250 double[] c = { 1, 1, 1 }; 251 double[] b = { 5, 45, 27, 24, 4 }; 252 test(A, b, c); 253 } 254 255 256 // x0 = 12, x1 = 28, opt = 800 257 public static void test2() { 258 double[] c = { 13.0, 23.0 }; 259 double[] b = { 480.0, 160.0, 1190.0 }; 260 double[][] A = { 261 { 5.0, 15.0 }, 262 { 4.0, 4.0 }, 263 { 35.0, 20.0 }, 264 }; 265 test(A, b, c); 266 } 267 268 // unbounded 269 public static void test3() { 270 double[] c = { 2.0, 3.0, -1.0, -12.0 }; 271 double[] b = { 3.0, 2.0 }; 272 double[][] A = { 273 { -2.0, -9.0, 1.0, 9.0 }, 274 { 1.0, 1.0, -1.0, -2.0 }, 275 }; 276 test(A, b, c); 277 } 278 279 // degenerate - cycles if you choose most positive objective function coefficient 280 public static void test4() { 281 double[] c = { 10.0, -57.0, -9.0, -24.0 }; 282 double[] b = { 0.0, 0.0, 1.0 }; 283 double[][] A = { 284 { 0.5, -5.5, -2.5, 9.0 }, 285 { 0.5, -1.5, -0.5, 1.0 }, 286 { 1.0, 0.0, 0.0, 0.0 }, 287 }; 288 test(A, b, c); 289 } 290 291 292 293 // test client 294 public static void main(String[] args) { 295 296 try { test1(); } 297 catch (ArithmeticException e) { e.printStackTrace(); } 298 StdOut.println("--------------------------------"); 299 300 try { test2(); } 301 catch (ArithmeticException e) { e.printStackTrace(); } 302 StdOut.println("--------------------------------"); 303 304 try { test3(); } 305 catch (ArithmeticException e) { e.printStackTrace(); } 306 StdOut.println("--------------------------------"); 307 308 try { test4(); } 309 catch (ArithmeticException e) { e.printStackTrace(); } 310 StdOut.println("--------------------------------"); 311 312 313 int M = Integer.parseInt(args[0]); 314 int N = Integer.parseInt(args[1]); 315 double[] c = new double[N]; 316 double[] b = new double[M]; 317 double[][] A = new double[M][N]; 318 for (int j = 0; j < N; j++) 319 c[j] = StdRandom.uniform(1000); 320 for (int i = 0; i < M; i++) 321 b[i] = StdRandom.uniform(1000); 322 for (int i = 0; i < M; i++) 323 for (int j = 0; j < N; j++) 324 A[i][j] = StdRandom.uniform(100); 325 Simplex lp = new Simplex(A, b, c); 326 StdOut.println(lp.value()); 327 } 328 329}