001package algs64; // section 6.4
002import stdlib.*;
/* ***********************************************************************
 *  Compilation:  javac XHungarian.java
 *  Execution:    java algs64.XHungarian N
 *  Dependencies: FordFulkerson.java FlowNetwork.java FlowEdge.java
 *                StdOut.java StdRandom.java
 *
 *  Solve an N-by-N assignment problem, i.e., find a minimum-weight
 *  perfect matching between rows and columns. Bare-bones implementation:
 *     - takes N^5 time in worst case.
 *     - assumes weights are >= 0  (add a large constant if not)
 *
 *********************************************************************/
014
015public class XHungarian {
016        private final int N;              // number of rows and columns
017        private final double[][] weight;  // the N-by-N weight matrix
018        private final double[] x;         // dual variables for rows
019        private final double[] y;         // dual variables for columns
020        private final int[] xy;           // xy[i] = j means i-j is a match
021        private final int[] yx;           // yx[j] = i means i-j is a match
022
023        public XHungarian(double[][] weight) {
024                this.weight = weight;
025                N = weight.length;
026                x = new double[N];
027                y = new double[N];
028                xy = new int[N];
029                yx = new int[N];
030                for (int i = 0; i < N; i++) xy[i] = -1;
031                for (int j = 0; j < N; j++) yx[j] = -1;
032
033                while (true) {
034
035                        // build graph of 0-reduced cost edges
036                        FlowNetwork G = new FlowNetwork(2*N+2);
037                        int s = 2*N, t = 2*N+1;
038                        for (int i = 0; i < N; i++) {
039                                if (xy[i] == -1) G.addEdge(new FlowEdge(s, i, 1.0));
040                                else             G.addEdge(new FlowEdge(s, i, 1.0, 1.0));
041                        }
042                        for (int j = 0; j < N; j++) {
043                                if (yx[j] == -1) G.addEdge(new FlowEdge(N+j, t, 1.0));
044                                else             G.addEdge(new FlowEdge(N+j, t, 1.0, 1.0));
045                        }
046                        for (int i = 0; i < N; i++) {
047                                for (int j = 0; j < N; j++) {
048                                        if (reduced(i, j) == 0) {
049                                                if (xy[i] != j) G.addEdge(new FlowEdge(i, N+j, 1.0));
050                                                else            G.addEdge(new FlowEdge(i, N+j, 1.0, 1.0));
051                                        }
052                                }
053                        }
054
055                        // to make N^4, start from previous solution
056                        FordFulkerson ff = new FordFulkerson(G, s, t);
057
058                        // current matching
059                        for (int i = 0; i < N; i++) xy[i] = -1;
060                        for (int j = 0; j < N; j++) yx[j] = -1;
061                        for (int i = 0; i < N; i++) {
062                                for (FlowEdge e : G.adj(i)) {
063                                        if ((e.from() == i) && (e.flow() > 0)) {
064                                                xy[i] = e.to() - N;
065                                                yx[e.to() - N] = i;
066                                        }
067                                }
068                        }
069
070                        // perfect matching
071                        if (ff.value() == N) break;
072
073                        // find bottleneck weight
074                        double max = Double.POSITIVE_INFINITY;
075                        for (int i = 0; i < N; i++)
076                                for (int j = 0; j < N; j++)
077                                        if (ff.inCut(i) && !ff.inCut(N+j) && (reduced(i, j) < max))
078                                                max = reduced(i, j);
079
080                        // update dual variables
081                        for (int i = 0; i < N; i++)
082                                if (!ff.inCut(i))   x[i] -= max;
083                        for (int j = 0; j < N; j++)
084                                if (!ff.inCut(N+j)) y[j] += max;
085
086                        StdOut.println("value = " + ff.value());
087                }
088                assert check();
089        }
090
091        // reduced cost of i-j
092        private double reduced(int i, int j) {
093                return weight[i][j] - x[i] - y[j];
094        }
095
096        private double weight() {
097                double totalWeight = 0.0;
098                for (int i = 0; i < N; i++) totalWeight += weight[i][xy[i]];
099                return totalWeight;
100        }
101
102        private int sol(int i) {
103                return xy[i];
104        }
105
106
107        // check optimality conditions
108        private boolean check() {
109                // check that xy[] is a permutation
110                boolean[] perm = new boolean[N];
111                for (int i = 0; i < N; i++) {
112                        if (perm[xy[i]]) {
113                                StdOut.println("Not a perfect matching");
114                                return false;
115                        }
116                        perm[xy[i]] = true;
117                }
118
119                // check that all edges in xy[] have 0-reduced cost
120                for (int i = 0; i < N; i++) {
121                        if (reduced(i, xy[i]) != 0) {
122                                StdOut.println("Solution does not have 0 reduced cost");
123                                return false;
124                        }
125                }
126
127                // check that all edges have >= 0 reduced cost
128                for (int i = 0; i < N; i++) {
129                        for (int j = 0; j < N; j++) {
130                                if (reduced(i, j) < 0) {
131                                        StdOut.println("Some edges have negative reduced cost");
132                                        return false;
133                                }
134                        }
135                }
136                return true;
137        }
138
139        public static void main(String[] args) {
140
141                int N = Integer.parseInt(args[0]);
142                double[][] weight = new double[N][N];
143                for (int i = 0; i < N; i++)
144                        for (int j = 0; j < N; j++)
145                                weight[i][j] = StdRandom.random();
146
147                XHungarian assignment = new XHungarian(weight);
148                StdOut.println("weight = " + assignment.weight());
149                for (int i = 0; i < N; i++)
150                        StdOut.println(i + "-" + assignment.sol(i));
151        }
152
153}