01 02 03 04 05 06 07 08 09 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 // Exercise 2.2.19 (Solution published at http://algs4.cs.princeton.edu/) package algs22; import stdlib.*; /* *********************************************************************** * Compilation: javac Inversions.java * Execution: java Inversions N * * Generate N pseudo-random numbers between 0 and 1 and count * the number of inversions in O(N log N) time. * *************************************************************************/ public class XInversions { // merge and count private static > int merge(T[] a, T[] aux, int lo, int mid, int hi) { int inversions = 0; // copy to aux[] for (int k = lo; k <= hi; k++) { aux[k] = a[k]; } // merge back to a[] int i = lo, j = mid+1; for (int k = lo; k <= hi; k++) { if (i > mid) a[k] = aux[j++]; else if (j > hi) a[k] = aux[i++]; else if (less(aux[j], aux[i])) { a[k] = aux[j++]; inversions += (mid - i + 1); } else a[k] = aux[i++]; } return inversions; } // return the number of inversions in the subarray b[lo..hi] // side effect b[lo..hi] is rearranged in ascending order private static > int count(T[] a, T[] b, T[] aux, int lo, int hi) { int inversions = 0; if (hi <= lo) return 0; int mid = lo + (hi - lo) / 2; inversions += count(a, b, aux, lo, mid); inversions += count(a, b, aux, mid+1, hi); inversions += merge(b, aux, lo, mid, hi); assert inversions == brute(a, lo, hi); return inversions; } // count number of inversions in the array a[] - do not overwrite a[] @SuppressWarnings("unchecked") public static > int count(T[] a) { T[] b = (T[]) new Comparable[a.length]; T[] aux = (T[]) new Comparable[a.length]; for (int i = 0; i < a.length; i++) b[i] = a[i]; int inversions = count(a, b, aux, 0, a.length - 1); return inversions; } // is v < w ? private static > boolean less(T v, T w) { return (v.compareTo(w) < 0); } // count number of inversions in a[lo..hi] via brute force (for debugging only) private static > int brute(T[] a, int lo, int hi) { int inversions = 0; for (int i = lo; i <= hi; i++) for (int j = i + 1; j <= hi; j++) if (less(a[j], a[i])) inversions++; return inversions; } // generate N real numbers between 0 and 1, and mergesort them public static void main(String[] args) { args = new String[] { "20" }; int N = Integer.parseInt(args[0]); Double[] a = new Double[N]; for (int i = 0; i < N; i++) a[i] = Math.random(); StdOut.println(brute(a, 0, N-1)); StdOut.println(count(a)); for (int i = 0; i < N; i++) StdOut.println(a[i]); } }