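We want the k points closest to a given origin, which means finding the k smallest elements under a distance ordering. A max heap of size k fits this. Its root is the farthest point kept so far, so when a closer point arrives, one pop evicts the current local maximum. First we need a notion of closeness. The standard Euclidean distance suffices, and because sqrt is monotonic, comparing squared distances gives the same order.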
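Here is a minimal Python sketch of that closeness measure. The name squared_dist and the sample points are illustrative and do not come from the original. It checks that sorting by squared distance gives the same order as sorting by the true distance from math.dist, which requires Python 3.8+.

```python
import math

def squared_dist(point, origin=(0, 0)):
    # Squared Euclidean distance; skipping sqrt keeps integer inputs exact
    dx = point[0] - origin[0]
    dy = point[1] - origin[1]
    return dx * dx + dy * dy

points = [(3, 3), (5, -1), (-2, 4), (1, 1)]
by_squared = sorted(points, key=squared_dist)
by_true = sorted(points, key=lambda p: math.dist(p, (0, 0)))
print(by_squared)  # [(1, 1), (3, 3), (-2, 4), (5, -1)]
```

Both orderings match because sqrt never reorders non-negative values. The squared form avoids floating-point rounding and saves a function call.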
Code
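Java 8 lambda syntax keeps the comparator concise. The methods below assume a Point class with int fields x and y, and they break distance ties by x, then by y.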
import java.util.*;

// squared Euclidean distance; sqrt is skipped because it does not change the ordering
private double eucDist(Point origin, Point pt) {
    return Math.pow(origin.x - pt.x, 2) + Math.pow(origin.y - pt.y, 2);
}

public Point[] kClosest(Point[] points, Point origin, int k) {
    if (k <= 0) return new Point[0];
    // order by distance, then x, then y; Double.compare avoids truncating the difference to int
    Comparator<Point> ptComp = (pt1, pt2) -> {
        int closeness = Double.compare(eucDist(origin, pt1), eucDist(origin, pt2));
        if (closeness != 0) return closeness;
        if (pt1.x != pt2.x) return Integer.compare(pt1.x, pt2.x);
        return Integer.compare(pt1.y, pt2.y);
    };
    // max heap of size k: either reverseOrder() or swapping pt1 and pt2 would work
    Queue<Point> maxHeap = new PriorityQueue<>(k, Collections.reverseOrder(ptComp));
    // keep the k closest points seen so far; evict the root (farthest) when a closer point arrives
    for (Point pt : points) {
        if (maxHeap.size() < k) {
            maxHeap.offer(pt);
        } else if (ptComp.compare(pt, maxHeap.peek()) < 0) {
            maxHeap.poll();
            maxHeap.offer(pt);
        }
    }
    // polling yields the farthest point first, so fill from the back (ascending result)
    Point[] result = new Point[maxHeap.size()];
    int index = result.length - 1;
    while (!maxHeap.isEmpty()) {
        result[index--] = maxHeap.poll();
    }
    return result;
}
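For comparison, Python's standard library offers heapq.nsmallest(n, iterable, key=...), which is documented as equivalent to sorted(iterable, key=key)[:n]. The sketch below swaps in that library call in place of the hand-written heap loop. It encodes the Java comparator's ordering (distance, then x, then y) as a tuple key. Representing points as (x, y) tuples and the function name k_closest_nsmallest are assumptions for illustration.

```python
import heapq

def k_closest_nsmallest(points, origin, k):
    # Same ordering as the Java comparator: squared distance, then x, then y
    def key(p):
        dx, dy = p[0] - origin[0], p[1] - origin[1]
        return (dx * dx + dy * dy, p[0], p[1])
    # nsmallest returns the k smallest items in ascending key order
    return heapq.nsmallest(k, points, key=key)

points = [(4, 6), (4, 7), (4, 4), (2, 5), (1, 1)]
print(k_closest_nsmallest(points, (0, 0), 3))  # [(1, 1), (2, 5), (4, 4)]
```

The tuple key makes the tie-breaking explicit without a custom comparator. For example, (1, 0), (0, 1) and (-1, 0) are all at distance 1, so they are ordered by x.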
Python's heapq only provides a min heap, so a max heap has to be simulated by negating the priority. This invites bugs: the structure is conceptually a max heap, but in reality it is a min heap over negated distances. Walking through a small example is always a good sanity check.
import math
import heapq
from typing import List

class Solution:
    def computeDist(self, point: List[int]) -> float:
        # squared distance; math.sqrt is unnecessary because it preserves the ordering
        return math.pow(point[0], 2) + math.pow(point[1], 2)

    def kClosest(self, points: List[List[int]], K: int) -> List[List[int]]:
        if len(points) == 0 or K == 0:
            return []
        # maintain a max heap of size K
        # heapq is a min heap, so negate the distance to simulate a max heap:
        # the root holds the most negative value, i.e. the farthest point kept
        # e.g. distances 20 and 26 are stored as -20 and -26, and the root is -26
        max_heap = []
        for point in points:
            dist = self.computeDist(point)
            if len(max_heap) < K:
                heapq.heappush(max_heap, (-dist, point))
            elif -dist > max_heap[0][0]:
                # the new point is closer than the farthest kept point, so replace it
                heapq.heappop(max_heap)
                heapq.heappush(max_heap, (-dist, point))
        return [p for (d, p) in max_heap]
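Here is a sketch of the sanity check suggested above. It traces the negated-distance heap by hand for K = 2, using squared distances 20 and 26 (the values hinted at in the original "-20, -26" comment). The sample points are chosen for illustration.

```python
import heapq

# Simulate a max heap of size 2 by pushing negated squared distances
max_heap = []
heapq.heappush(max_heap, (-20, [2, 4]))
heapq.heappush(max_heap, (-26, [5, 1]))
# The root is the most negative entry, i.e. the farthest point kept
farthest = max_heap[0]
print(farthest)  # (-26, [5, 1])

# A closer point (squared distance 2) arrives: -2 > -26, so evict the farthest
dist = 2
if -dist > max_heap[0][0]:
    heapq.heappop(max_heap)
    heapq.heappush(max_heap, (-dist, [1, 1]))
print(sorted(max_heap))  # [(-20, [2, 4]), (-2, [1, 1])]
```

The comparison reads "greater" even though the goal is a smaller distance. That inversion is exactly where the negation trick tends to go wrong.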
Time & Space Complexity
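Time O(n log k): each of the n points costs at most one pop and one push on a heap of at most k entries. Space O(k) for the heap.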
Memoization
If computing the distance is expensive, we can memoize (point, distance) pairs. Each point's distance is then computed once instead of on every heap comparison.
import java.util.*;

public Point[] kClosest(Point[] points, Point origin, int k) {
    if (k <= 0) return new Point[0];
    // precompute every distance once; store it as a double to avoid truncation
    Map<Point, Double> memo = new HashMap<>(points.length);
    for (Point pt : points) {
        memo.put(pt, eucDist(origin, pt));
    }
    // same ordering as before (distance, then x, then y), reading distances from the cache
    Comparator<Point> ptComp = (pt1, pt2) -> {
        int closeness = Double.compare(memo.get(pt1), memo.get(pt2));
        if (closeness != 0) return closeness;
        if (pt1.x != pt2.x) return Integer.compare(pt1.x, pt2.x);
        return Integer.compare(pt1.y, pt2.y);
    };
    // max heap of size k
    Queue<Point> maxHeap = new PriorityQueue<>(k, Collections.reverseOrder(ptComp));
    // keep the k closest points; evict the root when a closer point arrives
    for (Point pt : points) {
        if (maxHeap.size() < k) {
            maxHeap.offer(pt);
        } else if (ptComp.compare(pt, maxHeap.peek()) < 0) {
            maxHeap.poll();
            maxHeap.offer(pt);
        }
    }
    // max heap to result array (ascending)
    Point[] result = new Point[maxHeap.size()];
    int index = result.length - 1;
    while (!maxHeap.isEmpty()) {
        result[index--] = maxHeap.poll();
    }
    return result;
}
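This trades space for time, but it saves only a constant factor. The number of heap operations is unchanged, and each comparison becomes a map lookup instead of a distance computation. Time O(n log k), space O(n) for the cache.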
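The same memoization idea can be sketched in Python. This is a hypothetical helper k_closest_memo, not part of the original solution. It computes each squared distance once into a list and keeps (negated distance, index) pairs in the heap, so the points themselves are never compared.

```python
import heapq
from typing import List

def k_closest_memo(points: List[List[int]], k: int) -> List[List[int]]:
    if k <= 0:
        return []
    # Compute each squared distance exactly once (O(n) extra space)
    dists = [x * x + y * y for x, y in points]
    max_heap = []  # entries are (-distance, index into points)
    for i, d in enumerate(dists):
        if len(max_heap) < k:
            heapq.heappush(max_heap, (-d, i))
        elif -d > max_heap[0][0]:
            # heapreplace pops the root (farthest kept) and pushes the new entry
            heapq.heapreplace(max_heap, (-d, i))
    return [points[i] for _, i in max_heap]
```

The result comes back in heap order, not sorted by distance. The Python solution above already gets the "compute once" benefit with only O(k) extra space, because it stores each distance inside its heap entry. The full cache only pays off when the comparator itself recomputes distances, as in the first Java version.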