001// Exercise 2.2.19 (Solution published at http://algs4.cs.princeton.edu/) 002package algs22; 003import stdlib.*; 004/* *********************************************************************** 005 * Compilation: javac Inversions.java 006 * Execution: java Inversions N 007 * 008 * Generate N pseudo-random numbers between 0 and 1 and count 009 * the number of inversions in O(N log N) time. 010 * 011 *************************************************************************/ 012 013public class XInversions { 014 015 // merge and count 016 private static <T extends Comparable<? super T>> int merge(T[] a, T[] aux, int lo, int mid, int hi) { 017 int inversions = 0; 018 019 // copy to aux[] 020 for (int k = lo; k <= hi; k++) { 021 aux[k] = a[k]; 022 } 023 024 // merge back to a[] 025 int i = lo, j = mid+1; 026 for (int k = lo; k <= hi; k++) { 027 if (i > mid) a[k] = aux[j++]; 028 else if (j > hi) a[k] = aux[i++]; 029 else if (less(aux[j], aux[i])) { a[k] = aux[j++]; inversions += (mid - i + 1); } 030 else a[k] = aux[i++]; 031 } 032 return inversions; 033 } 034 035 // return the number of inversions in the subarray b[lo..hi] 036 // side effect b[lo..hi] is rearranged in ascending order 037 private static <T extends Comparable<? super T>> int count(T[] a, T[] b, T[] aux, int lo, int hi) { 038 int inversions = 0; 039 if (hi <= lo) return 0; 040 int mid = lo + (hi - lo) / 2; 041 inversions += count(a, b, aux, lo, mid); 042 inversions += count(a, b, aux, mid+1, hi); 043 inversions += merge(b, aux, lo, mid, hi); 044 assert inversions == brute(a, lo, hi); 045 return inversions; 046 } 047 048 049 // count number of inversions in the array a[] - do not overwrite a[] 050 @SuppressWarnings("unchecked") 051 public static <T extends Comparable<? super T>> int count(T[] a) { 052 T[] b = (T[]) new Comparable[a.length]; 053 T[] aux = (T[]) new Comparable[a.length]; 054 for (int i = 0; i < a.length; i++) b[i] = a[i]; 055 int inversions = count(a, b, aux, 0, a.length - 1); 056 return inversions; 057 } 058 059 060 // is v < w ? 061 private static <T extends Comparable<? super T>> boolean less(T v, T w) { 062 return (v.compareTo(w) < 0); 063 } 064 065 // count number of inversions in a[lo..hi] via brute force (for debugging only) 066 private static <T extends Comparable<? super T>> int brute(T[] a, int lo, int hi) { 067 int inversions = 0; 068 for (int i = lo; i <= hi; i++) 069 for (int j = i + 1; j <= hi; j++) 070 if (less(a[j], a[i])) inversions++; 071 return inversions; 072 } 073 074 075 076 077 // generate N real numbers between 0 and 1, and mergesort them 078 public static void main(String[] args) { 079 args = new String[] { "20" }; 080 081 int N = Integer.parseInt(args[0]); 082 Double[] a = new Double[N]; 083 for (int i = 0; i < N; i++) 084 a[i] = Math.random(); 085 StdOut.println(brute(a, 0, N-1)); 086 StdOut.println(count(a)); 087 for (int i = 0; i < N; i++) 088 StdOut.println(a[i]); 089 } 090}