001// Exercise 2.2.19 (Solution published at http://algs4.cs.princeton.edu/)
002package algs22;
003import stdlib.*;
/* ***********************************************************************
 *  Compilation:  javac XInversions.java
 *  Execution:    java algs22.XInversions N
 *
 *  Generate N pseudo-random numbers in the range [0, 1) and count
 *  the number of inversions in O(N log N) time.
 *
 *************************************************************************/
012
013public class XInversions {
014
015        // merge and count
016        private static <T extends Comparable<? super T>> int merge(T[] a, T[] aux, int lo, int mid, int hi) {
017                int inversions = 0;
018
019                // copy to aux[]
020                for (int k = lo; k <= hi; k++) {
021                        aux[k] = a[k];
022                }
023
024                // merge back to a[]
025                int i = lo, j = mid+1;
026                for (int k = lo; k <= hi; k++) {
027                        if      (i > mid)                a[k] = aux[j++];
028                        else if (j > hi)                 a[k] = aux[i++];
029                        else if (less(aux[j], aux[i])) { a[k] = aux[j++]; inversions += (mid - i + 1); }
030                        else                             a[k] = aux[i++];
031                }
032                return inversions;
033        }
034
035        // return the number of inversions in the subarray b[lo..hi]
036        // side effect b[lo..hi] is rearranged in ascending order
037        private static <T extends Comparable<? super T>> int count(T[] a, T[] b, T[] aux, int lo, int hi) {
038                int inversions = 0;
039                if (hi <= lo) return 0;
040                int mid = lo + (hi - lo) / 2;
041                inversions += count(a, b, aux, lo, mid);
042                inversions += count(a, b, aux, mid+1, hi);
043                inversions += merge(b, aux, lo, mid, hi);
044                assert inversions == brute(a, lo, hi);
045                return inversions;
046        }
047
048
049        // count number of inversions in the array a[] - do not overwrite a[]
050        @SuppressWarnings("unchecked")
051        public static <T extends Comparable<? super T>> int count(T[] a) {
052                T[] b   = (T[]) new Comparable[a.length];
053                T[] aux = (T[]) new Comparable[a.length];
054                for (int i = 0; i < a.length; i++) b[i] = a[i];
055                int inversions = count(a, b, aux, 0, a.length - 1);
056                return inversions;
057        }
058
059
060        // is v < w ?
061        private static <T extends Comparable<? super T>> boolean less(T v, T w) {
062                return (v.compareTo(w) < 0);
063        }
064
065        // count number of inversions in a[lo..hi] via brute force (for debugging only)
066        private static <T extends Comparable<? super T>> int brute(T[] a, int lo, int hi) {
067                int inversions = 0;
068                for (int i = lo; i <= hi; i++)
069                        for (int j = i + 1; j <= hi; j++)
070                                if (less(a[j], a[i])) inversions++;
071                return inversions;
072        }
073
074
075
076
077        // generate N real numbers between 0 and 1, and mergesort them
078        public static void main(String[] args) {
079                args = new String[] { "20" };
080
081                int N = Integer.parseInt(args[0]);
082                Double[] a = new Double[N];
083                for (int i = 0; i < N; i++)
084                        a[i] = Math.random();
085                StdOut.println(brute(a, 0, N-1));
086                StdOut.println(count(a));
087                for (int i = 0; i < N; i++)
088                        StdOut.println(a[i]);
089        }
090}