K Closest Points

Question (LI.612)

Given a list of points and an origin on a 2D plane, find the k points that are closest to the origin.

Example

Input: points = [(4,6), (4,7), (4,4), (2,5), (1,1)], origin = (0,0), k = 3
Output: [(1,1), (2,5), (4,4)]

Approach

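We want the k smallest elements, where "small" means close to the origin. First we define closeness. The standard Euclidean distance works, and because the square root is monotonically increasing, comparing squared distances gives the same ranking and avoids the sqrt. Then we keep a max heap of size k. Its root is the farthest of the k candidates seen so far, so whenever a closer point arrives we pop the root (the local max) and push the new point.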
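A minimal sketch of the closeness measure, written in Python. `squared_dist` is a hypothetical helper name. It skips the square root, and sorting the example input by the squared distance or by the true distance produces the same order, whose first three points match the expected output. This sort is only a check of the distance definition, not the heap solution.

```python
import math


def squared_dist(point, origin=(0, 0)):
    """Squared Euclidean distance from point to origin (hypothetical helper)."""
    dx = point[0] - origin[0]
    dy = point[1] - origin[1]
    return dx * dx + dy * dy


points = [(4, 6), (4, 7), (4, 4), (2, 5), (1, 1)]
# sqrt is monotonically increasing, so both keys rank the points identically
by_squared = sorted(points, key=squared_dist)
by_true = sorted(points, key=lambda p: math.sqrt(squared_dist(p)))
print(by_squared[:3])  # [(1, 1), (2, 5), (4, 4)]
```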

Code

Java 8 lambda syntax makes the comparator a lot cleaner.

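import java.util.*;

// (the methods below go inside the solution class; Point has int fields x and y)

// squared Euclidean distance; the square root is omitted because it
// does not change the ordering of distances
private double eucDist(Point origin, Point pt) {
    double dx = origin.x - pt.x;
    double dy = origin.y - pt.y;
    return dx * dx + dy * dy;
}

public Point[] kClosest(Point[] points, Point origin, int k) {
    // PriorityQueue rejects an initial capacity below 1, so handle k <= 0 first
    if (points == null || k <= 0) return new Point[0];
    // order by distance, then by x, then by y to break ties deterministically
    // (Double.compare avoids the truncation of casting a double difference to int)
    Comparator<Point> ptComp = (Point pt1, Point pt2) -> {
        int closeness = Double.compare(eucDist(origin, pt1), eucDist(origin, pt2));
        if (closeness != 0) return closeness;
        else if (pt1.x != pt2.x) return Integer.compare(pt1.x, pt2.x);
        else return Integer.compare(pt1.y, pt2.y);
    };
    // create a max heap of size k
    // either reverseOrder() or flip pt1 and pt2 would work
    Queue<Point> maxHeap = new PriorityQueue<>(k, Collections.reverseOrder(ptComp));
    // fill the heap up to k points; after that, a new point replaces the root
    // (the farthest of the current k) only if it is closer
    for (Point pt : points) {
        if (maxHeap.size() < k) {
            maxHeap.offer(pt);
        } else if (ptComp.compare(pt, maxHeap.peek()) < 0) {
            maxHeap.poll();
            maxHeap.offer(pt);
        }
    }
    // drain the max heap into the array from the back so the result is ascending;
    // sizing by the heap avoids null slots when k > points.length
    Point[] result = new Point[maxHeap.size()];
    for (int i = result.length - 1; i >= 0; i--) {
        result[i] = maxHeap.poll();
    }
    return result;
}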

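Python's heapq module only provides a min heap, so a max heap has to be simulated by negating the priority. The structure is conceptually a max heap but is really a min heap over negated distances, which makes sign mistakes easy to introduce. Walking through a small example is a good sanity check. This version follows the variant where the origin is fixed at (0, 0).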
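A small walk-through of the negation trick, using the squared distances of the example points: (4,6) is 52, (4,7) is 65, (4,4) is 32, and then (2,5) arrives at 29. This is a standalone illustration with illustrative variable names, not part of the solution.

```python
import heapq

# heapq only offers a min heap; pushing negated keys makes the root
# hold the LARGEST original key, so it behaves like a max heap
dists = [52, 65, 32]
max_heap = []
for d in dists:
    heapq.heappush(max_heap, -d)

print(-max_heap[0])  # 65, the farthest distance sits at the root

# a new point at distance 29 is closer than the root (65), so it replaces it;
# in negated terms the check is -29 > -65
candidate = 29
if -candidate > max_heap[0]:
    heapq.heapreplace(max_heap, -candidate)

print(sorted(-x for x in max_heap))  # [29, 32, 52]
```

Note that the "is it closer?" test flips to `>` once the keys are negated, which is exactly where sign bugs tend to hide.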

import heapq
from typing import List


class Solution:

    def computeDist(self, point: List[int]) -> int:
        # squared distance to the origin (0, 0); math.sqrt is unnecessary
        # because it does not change the ordering
        return point[0] ** 2 + point[1] ** 2

    def kClosest(self, points: List[List[int]], K: int) -> List[List[int]]:

        if len(points) == 0 or K <= 0:
            return []

        # maintain a max heap of size K
        # heapq is a min heap, so store the negated distance as the priority:
        # the root max_heap[0] then holds the farthest of the K kept points
        # go through an example to check the signs
        max_heap = []

        for point in points:

            dist = self.computeDist(point)

            if len(max_heap) < K:
                heapq.heappush(max_heap, (-dist, point))
            # -dist > root key  <=>  dist < farthest kept distance
            elif -dist > max_heap[0][0]:
                # pop the farthest point and push the closer one in one step
                heapq.heapreplace(max_heap, (-dist, point))

        # the K closest points, in heap order (not sorted by distance)
        return [p for (d, p) in max_heap]

Time & Space Complexity

Time O(n log k): each of the n points costs at most one push or pop on a heap of size k. Space O(k) for the heap.

Memoization

If computing the distance is expensive, we can memoize each (point, distance) pair so that every point's distance is computed only once, instead of again on every heap comparison.
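A minimal Python sketch of the same compute-once idea. `k_closest_cached` is a hypothetical name, and the cache here is a list indexed by position rather than a map keyed by point (an assumption that also keeps duplicate points distinct). Each heap entry is a `(-distance, index)` tuple, so heap comparisons reuse the cached value and ties fall back to the index instead of comparing points.

```python
import heapq
from typing import List, Sequence, Tuple


def k_closest_cached(points: Sequence[Tuple[int, int]], origin: Tuple[int, int],
                     k: int) -> List[Tuple[int, int]]:
    """Hypothetical helper: k closest points, each distance computed once."""
    if k <= 0:
        return []
    ox, oy = origin
    # cache: dist[i] is the squared distance of points[i], computed exactly once
    dist = [(x - ox) ** 2 + (y - oy) ** 2 for (x, y) in points]

    max_heap: List[Tuple[int, int]] = []  # entries are (-distance, index)
    for i, d in enumerate(dist):
        if len(max_heap) < k:
            heapq.heappush(max_heap, (-d, i))
        elif -d > max_heap[0][0]:
            # closer than the farthest kept point: replace the root
            heapq.heapreplace(max_heap, (-d, i))

    # order the survivors by cached distance (then index) for ascending output
    survivors = sorted(max_heap, key=lambda e: (-e[0], e[1]))
    return [points[i] for (_, i) in survivors]


print(k_closest_cached([(4, 6), (4, 7), (4, 4), (2, 5), (1, 1)], (0, 0), 3))
# [(1, 1), (2, 5), (4, 4)]
```

As with the Java map, this spends O(n) extra space on the cache; the heap itself still holds at most k entries.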

// reuses eucDist from the first Java solution; needs java.util.*
public Point[] kClosest(Point[] points, Point origin, int k) {
    // PriorityQueue rejects an initial capacity below 1, so handle k <= 0 first
    if (points == null || k <= 0) return new Point[0];
    // precompute each point's distance once; the comparator reads it from the map
    Map<Point, Double> dp = new HashMap<>(points.length);
    for (Point pt : points) {
        dp.put(pt, eucDist(origin, pt));
    }
    // order by cached distance, then by x, then by y
    Comparator<Point> ptComp = (Point pt1, Point pt2) -> {
        int closeness = Double.compare(dp.get(pt1), dp.get(pt2));
        if (closeness != 0) return closeness;
        else if (pt1.x != pt2.x) return Integer.compare(pt1.x, pt2.x);
        else return Integer.compare(pt1.y, pt2.y);
    };
    // create a max heap of size k
    Queue<Point> maxHeap = new PriorityQueue<>(k, Collections.reverseOrder(ptComp));
    // fill the heap up to k points; after that, a new point replaces the root
    // (the farthest of the current k) only if it is closer
    for (Point pt : points) {
        if (maxHeap.size() < k) {
            maxHeap.offer(pt);
        } else if (ptComp.compare(pt, maxHeap.peek()) < 0) {
            maxHeap.poll();
            maxHeap.offer(pt);
        }
    }
    // drain the max heap into the array from the back so the result is ascending
    Point[] result = new Point[maxHeap.size()];
    for (int i = result.length - 1; i >= 0; i--) {
        result[i] = maxHeap.poll();
    }
    return result;
}

This trades space for time, but the saving is only a constant factor. Time stays O(n log k), while space grows to O(n) for the distance map.
