001package algs44;
002import  stdlib.*;
003
004/* ***********************************************************************
005 *  Compilation:  javac AssignmentProblem.java
 *  Execution:    java AssignmentProblem filename
 *                (the file holds N followed by the N*N weights, row by row)
 *  Dependencies: DijkstraSP.java DirectedEdge.java EdgeWeightedDigraph.java
008 *
009 *  Solve an N-by-N assignment problem in N^3 log N time using the
010 *  successive shortest path algorithm.
011 *
 *  Remark: could use the dense version of Dijkstra's algorithm for
 *  improved theoretical efficiency of N^3, but it doesn't seem to
 *  help in practice.
015 *
016 *  Assumes N-by-N cost matrix is nonnegative.
017 *
018 *
019 *********************************************************************/
020
021public class AssignmentProblem {
022        private static final int UNMATCHED = -1;
023
024        private int N;              // number of rows and columns
025        private double[][] weight;  // the N-by-N cost matrix
026        private double[] px;        // px[i] = dual variable for row i
027        private double[] py;        // py[j] = dual variable for col j
028        private int[] xy;           // xy[i] = j means i-j is a match
029        private int[] yx;           // yx[j] = i means i-j is a match
030
031
032        public AssignmentProblem(double[][] weight) {
033                N = weight.length;
034                this.weight = new double[N][N];
035                for (int i = 0; i < N; i++)
036                        for (int j = 0; j < N; j++)
037                                this.weight[i][j] = weight[i][j];
038
039                // dual variables
040                px = new double[N];
041                py = new double[N];
042
043                // initial matching is empty
044                xy = new int[N];
045                yx = new int[N];
046                for (int i = 0; i < N; i++) xy[i] = UNMATCHED;
047                for (int j = 0; j < N; j++) yx[j] = UNMATCHED;
048
049                // add N edges to matching
050                for (int k = 0; k < N; k++) {
051                        assert isDualFeasible();
052                        assert isComplementarySlack();
053                        augment();
054                }
055                assert check();
056        }
057
058        // find shortest augmenting path and upate
059        private void augment() {
060
061                // build residual graph
062                EdgeWeightedDigraph G = new EdgeWeightedDigraph(2*N+2);
063                int s = 2*N, t = 2*N+1;
064                for (int i = 0; i < N; i++) {
065                        if (xy[i] == UNMATCHED) G.addEdge(new DirectedEdge(s, i, 0.0));
066                }
067                for (int j = 0; j < N; j++) {
068                        if (yx[j] == UNMATCHED) G.addEdge(new DirectedEdge(N+j, t, py[j]));
069                }
070                for (int i = 0; i < N; i++) {
071                        for (int j = 0; j < N; j++) {
072                                if (xy[i] == j) G.addEdge(new DirectedEdge(N+j, i, 0.0));
073                                else            G.addEdge(new DirectedEdge(i, N+j, reduced(i, j)));
074                        }
075                }
076
077                // compute shortest path from s to every other vertex
078                DijkstraSP spt = new DijkstraSP(G, s);
079
080                // augment along alternating path
081                for (DirectedEdge e : spt.pathTo(t)) {
082                        int i = e.from(), j = e.to() - N;
083                        if (i < N) {
084                                xy[i] = j;
085                                yx[j] = i;
086                        }
087                }
088
089                // update dual variables
090                for (int i = 0; i < N; i++) px[i] += spt.distTo(i);
091                for (int j = 0; j < N; j++) py[j] += spt.distTo(N+j);
092        }
093
094        // reduced cost of i-j
095        private double reduced(int i, int j) {
096                return weight[i][j] + px[i] - py[j];
097        }
098
099        // dual variable for row i
100        public double dualRow(int i) {
101                return px[i];
102        }
103
104        // dual variable for column j
105        public double dualCol(int j) {
106                return py[j];
107        }
108
109        // total weight of min weight perfect matching
110        public double weight() {
111                double total = 0.0;
112                for (int i = 0; i < N; i++) {
113                        if (xy[i] != UNMATCHED)
114                                total += weight[i][xy[i]];
115                }
116                return total;
117        }
118
119        public int sol(int i) {
120                return xy[i];
121        }
122
123        // check that dual variables are feasible
124        private boolean isDualFeasible() {
125                // check that all edges have >= 0 reduced cost
126                for (int i = 0; i < N; i++) {
127                        for (int j = 0; j < N; j++) {
128                                if (reduced(i, j) < 0) {
129                                        StdOut.println("Dual variables are not feasible");
130                                        return false;
131                                }
132                        }
133                }
134                return true;
135        }
136
137        // check that primal and dual variables are complementary slack
138        private boolean isComplementarySlack() {
139
140                // check that all matched edges have 0-reduced cost
141                for (int i = 0; i < N; i++) {
142                        if ((xy[i] != UNMATCHED) && (reduced(i, xy[i]) != 0)) {
143                                StdOut.println("Primal and dual variables are not complementary slack");
144                                return false;
145                        }
146                }
147                return true;
148        }
149
150        // check that primal variables are a perfect matching
151        private boolean isPerfectMatching() {
152
153                // check that xy[] is a perfect matching
154                boolean[] perm = new boolean[N];
155                for (int i = 0; i < N; i++) {
156                        if (perm[xy[i]]) {
157                                StdOut.println("Not a perfect matching");
158                                return false;
159                        }
160                        perm[xy[i]] = true;
161                }
162
163                // check that xy[] and yx[] are inverses
164                for (int j = 0; j < N; j++) {
165                        if (xy[yx[j]] != j) {
166                                StdOut.println("xy[] and yx[] are not inverses");
167                                return false;
168                        }
169                }
170                for (int i = 0; i < N; i++) {
171                        if (yx[xy[i]] != i) {
172                                StdOut.println("xy[] and yx[] are not inverses");
173                                return false;
174                        }
175                }
176
177                return true;
178        }
179
180
181        // check optimality conditions
182        private boolean check() {
183                return isPerfectMatching() && isDualFeasible() && isComplementarySlack();
184        }
185
186        public static void main(String[] args) {
187                In in = new In(args[0]);
188                int N = in.readInt();
189                double[][] weight = new double[N][N];
190                for (int i = 0; i < N; i++) {
191                        for (int j = 0; j < N; j++) {
192                                weight[i][j] = in.readDouble();
193                        }
194                }
195
196                AssignmentProblem assignment = new AssignmentProblem(weight);
197                StdOut.println("weight = " + assignment.weight());
198                for (int i = 0; i < N; i++)
199                        StdOut.println(i + "-" + assignment.sol(i) + "' " + weight[i][assignment.sol(i)]);
200
201                for (int i = 0; i < N; i++)
202                        StdOut.println("px[" + i + "] = " + assignment.dualRow(i));
203                for (int j = 0; j < N; j++)
204                        StdOut.println("py[" + j + "] = " + assignment.dualCol(j));
205                for (int i = 0; i < N; i++)
206                        for (int j = 0; j < N; j++)
207                                StdOut.println("reduced[" + i + "-" + j + "] = " + assignment.reduced(i, j));
208
209        }
210
211}