001package algs32.kdtree;
002import algs12.Point2D;
003import algs13.Queue;
004import stdlib.*;
005
006public class NearestNeighborCorrectnessTest {
007
008        static int NUM_TARGETS = 1000;
009        static int NUM_SIZES = 12;
010        static int NUM_TESTS = 200;
011        static int NUM_POSSIBLE_INIT = 1;
012        static int TREE_SIZE_INIT = 0;
013        static boolean ALLOW_DUPLICATES = true;
014        static boolean SHOW_TREE_ON_FAILURE = true;
015        static boolean STOP_AFTER_FIRST_FAILURE = true;
016        static boolean CATCH_EXCEPTIONS = false;
017        private static boolean passed = true;
018
019        protected static Point2D nearest (KdTree kdtree, Point2D target) {
020                if (!CATCH_EXCEPTIONS) {
021                        return kdtree.nearest (target);
022                } else {
023                        try {
024                                return kdtree.nearest (target);
025                        } catch (Throwable e) {
026                                if (passed) {
027                                        passed = false;
028                                        e.printStackTrace ();
029                                }
030                                return new Point2D (666, 666);
031                        }
032                }
033        }
034        private static boolean showInsertionException = true;
035        protected static boolean insert (KdTree kdtree, Point2D p) {
036                if (!CATCH_EXCEPTIONS) {
037                        kdtree.insert (p);
038                        return true;
039                } else {
040                        try {
041                                kdtree.insert (p);
042                                return true;
043                        } catch (Throwable e) {
044                                if (showInsertionException) {
045                                        showInsertionException = false;
046                                        e.printStackTrace ();
047                                }
048                                passed = false;
049                                return false;
050                        }
051                }
052        }
053        private static double random(int numPossible) {
054                return StdRandom.uniform (numPossible)/(double)numPossible;
055        }
056        public static void main(String[] args) {
057                //StdRandom.setSeed (0); // uncomment to get the same results over and over
058
059                Queue<Point2D> queue = new Queue<> ();
060                for (int i=0; i<NUM_TARGETS; i++)
061                        queue.enqueue(new Point2D(random(1000), random(1000)));
062
063                // treeSize and numPossible vary each time around the test loop
064                // trying small trees with few possible values for points to start
065                // doubling the treeSize each time
066                // keeping numPossible a power of 10 so that decimal fractions print nicely
067                int numPossible = NUM_POSSIBLE_INIT;
068                int treeSize = TREE_SIZE_INIT;
069                int numTested = 0;
070                int numPassed = 0;
071                int numTreesAttempted = 0;
072                int numTreesCreated = 0;
073                test: for (int numsize=0; numsize<NUM_SIZES; numsize++) {
074                        StdOut.format ("trying treeSize %d\n", treeSize);
075                        for (int numtest=0; numtest<NUM_TESTS; numtest++) {
076                                PointSET brute = new PointSET();
077                                KdTree kdtree = new KdTree();
078
079                                boolean treeCreated = true;
080                                for (int i=0; i<treeSize; i++) {
081                                        Point2D p = new Point2D(random (numPossible), random (numPossible));
082                                        if (ALLOW_DUPLICATES || !brute.contains (p)) {
083                                                if (!insert(kdtree, p)) treeCreated = false;
084                                                brute.insert(p);
085                                        }
086                                }
087                                numTreesAttempted ++;
088                                if (treeCreated) numTreesCreated ++;
089                                point: for (Point2D p : queue) {
090                                        numTested ++;
091                                        Point2D b = brute.nearest(p);
092                                        Point2D k = nearest (kdtree, p);
093                                        if (b==null) {
094                                                if (k!=null) {
095                                                        printError (treeSize, brute, kdtree, p);
096                                                        if (STOP_AFTER_FIRST_FAILURE) break test; else continue point;
097                                                }
098                                        } else if (k==null) {
099                                                printError (treeSize, brute, kdtree, p);
100                                                if (STOP_AFTER_FIRST_FAILURE) break test; else continue test;
101                                        } else if (p.distanceTo(b) - p.distanceTo (k) != 0.0) {
102                                                printError (treeSize, brute, kdtree, p);
103                                                if (STOP_AFTER_FIRST_FAILURE) break test; else continue point;
104                                        }
105                                        numPassed ++;
106                                }
107                        }
108                        treeSize += (treeSize==0) ? 1 : treeSize;
109                        if (numsize % 4==0) numPossible *= 10;
110                }
111                StdOut.format ("#NearestNeighbor %s: %d/%d passed, %d/%d trees created without thrown exception\n", passed ? "passed" : "failed", numPassed, numTested, numTreesCreated, numTreesAttempted);
112
113        }
114        private static void printError (int treeSize, PointSET brute, KdTree kdtree, Point2D p) {
115                if (passed) {
116                        passed = false;
117                        StdOut.println ("Error!");
118                        //StdOut.println ("  treeSize should be " + treeSize);
119                        //if (brute.size() != treeSize) StdOut.println ("  duplicate points");
120                        StdOut.println ("  PointSET         = " + brute);
121                        StdOut.println ("  KdTree           = " + kdtree);
122                        StdOut.println ("  target           = " + p);
123                        StdOut.println ("  PointSET nearest = " + brute.nearest(p));
124                        StdOut.println ("  KdTree nearest   = " + nearest(kdtree, p));
125                        if (SHOW_TREE_ON_FAILURE) {
126                                kdtree.toGraphviz ();
127                                kdtree.draw ();
128                        }
129                }
130        }
131}