001package algs32.kdtree; 002import algs12.Point2D; 003import algs13.Queue; 004import stdlib.*; 005 006public class NearestNeighborCorrectnessTest { 007 008 static int NUM_TARGETS = 1000; 009 static int NUM_SIZES = 12; 010 static int NUM_TESTS = 200; 011 static int NUM_POSSIBLE_INIT = 1; 012 static int TREE_SIZE_INIT = 0; 013 static boolean ALLOW_DUPLICATES = true; 014 static boolean SHOW_TREE_ON_FAILURE = true; 015 static boolean STOP_AFTER_FIRST_FAILURE = true; 016 static boolean CATCH_EXCEPTIONS = false; 017 private static boolean passed = true; 018 019 protected static Point2D nearest (KdTree kdtree, Point2D target) { 020 if (!CATCH_EXCEPTIONS) { 021 return kdtree.nearest (target); 022 } else { 023 try { 024 return kdtree.nearest (target); 025 } catch (Throwable e) { 026 if (passed) { 027 passed = false; 028 e.printStackTrace (); 029 } 030 return new Point2D (666, 666); 031 } 032 } 033 } 034 private static boolean showInsertionException = true; 035 protected static boolean insert (KdTree kdtree, Point2D p) { 036 if (!CATCH_EXCEPTIONS) { 037 kdtree.insert (p); 038 return true; 039 } else { 040 try { 041 kdtree.insert (p); 042 return true; 043 } catch (Throwable e) { 044 if (showInsertionException) { 045 showInsertionException = false; 046 e.printStackTrace (); 047 } 048 passed = false; 049 return false; 050 } 051 } 052 } 053 private static double random(int numPossible) { 054 return StdRandom.uniform (numPossible)/(double)numPossible; 055 } 056 public static void main(String[] args) { 057 //StdRandom.setSeed (0); // uncomment to get the same results over and over 058 059 Queue<Point2D> queue = new Queue<> (); 060 for (int i=0; i<NUM_TARGETS; i++) 061 queue.enqueue(new Point2D(random(1000), random(1000))); 062 063 // treeSize and numPossible vary each time around the test loop 064 // trying small trees with few possible values for points to start 065 // doubling the treeSize each time 066 // keeping numPossible a power of 10 so that decimal fractions print nicely 067 int numPossible = NUM_POSSIBLE_INIT; 068 int treeSize = TREE_SIZE_INIT; 069 int numTested = 0; 070 int numPassed = 0; 071 int numTreesAttempted = 0; 072 int numTreesCreated = 0; 073 test: for (int numsize=0; numsize<NUM_SIZES; numsize++) { 074 StdOut.format ("trying treeSize %d\n", treeSize); 075 for (int numtest=0; numtest<NUM_TESTS; numtest++) { 076 PointSET brute = new PointSET(); 077 KdTree kdtree = new KdTree(); 078 079 boolean treeCreated = true; 080 for (int i=0; i<treeSize; i++) { 081 Point2D p = new Point2D(random (numPossible), random (numPossible)); 082 if (ALLOW_DUPLICATES || !brute.contains (p)) { 083 if (!insert(kdtree, p)) treeCreated = false; 084 brute.insert(p); 085 } 086 } 087 numTreesAttempted ++; 088 if (treeCreated) numTreesCreated ++; 089 point: for (Point2D p : queue) { 090 numTested ++; 091 Point2D b = brute.nearest(p); 092 Point2D k = nearest (kdtree, p); 093 if (b==null) { 094 if (k!=null) { 095 printError (treeSize, brute, kdtree, p); 096 if (STOP_AFTER_FIRST_FAILURE) break test; else continue point; 097 } 098 } else if (k==null) { 099 printError (treeSize, brute, kdtree, p); 100 if (STOP_AFTER_FIRST_FAILURE) break test; else continue test; 101 } else if (p.distanceTo(b) - p.distanceTo (k) != 0.0) { 102 printError (treeSize, brute, kdtree, p); 103 if (STOP_AFTER_FIRST_FAILURE) break test; else continue point; 104 } 105 numPassed ++; 106 } 107 } 108 treeSize += (treeSize==0) ? 1 : treeSize; 109 if (numsize % 4==0) numPossible *= 10; 110 } 111 StdOut.format ("#NearestNeighbor %s: %d/%d passed, %d/%d trees created without thrown exception\n", passed ? "passed" : "failed", numPassed, numTested, numTreesCreated, numTreesAttempted); 112 113 } 114 private static void printError (int treeSize, PointSET brute, KdTree kdtree, Point2D p) { 115 if (passed) { 116 passed = false; 117 StdOut.println ("Error!"); 118 //StdOut.println (" treeSize should be " + treeSize); 119 //if (brute.size() != treeSize) StdOut.println (" duplicate points"); 120 StdOut.println (" PointSET = " + brute); 121 StdOut.println (" KdTree = " + kdtree); 122 StdOut.println (" target = " + p); 123 StdOut.println (" PointSET nearest = " + brute.nearest(p)); 124 StdOut.println (" KdTree nearest = " + nearest(kdtree, p)); 125 if (SHOW_TREE_ON_FAILURE) { 126 kdtree.toGraphviz (); 127 kdtree.draw (); 128 } 129 } 130 } 131}