Search Range in Binary Search Tree

Question (LC.11)

Given two values k1 and k2 (k1 < k2) and the root of a binary search tree (BST), find all keys x in the tree such that k1 <= x <= k2, i.e. x lies in the closed range [k1, k2]. Return the keys in ascending order.

Example

      20
     /  \
    8    22
   / \
  4   12
Input: k1 = 10, k2 = 22
Return: [12, 20, 22]
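As a baseline for this example, here is a minimal Python sketch that checks every key. It assumes a node type with `val`, `left` and `right` fields; `TreeNode` and `search_range_brute_force` are hypothetical names, not part of the original. It walks all n nodes in order and keeps those in [k1, k2], so it always costs O(n), but it gives a reference answer to compare the pruned versions against.

```python
from typing import List, Optional


class TreeNode:
    """Minimal BST node; the field names val/left/right are assumed."""

    def __init__(self, val: int, left: "Optional[TreeNode]" = None,
                 right: "Optional[TreeNode]" = None) -> None:
        self.val = val
        self.left = left
        self.right = right


def search_range_brute_force(root: Optional[TreeNode], k1: int, k2: int) -> List[int]:
    """Visit every node in order and keep the keys in [k1, k2]."""
    result: List[int] = []

    def inorder(node: Optional[TreeNode]) -> None:
        if node is None:
            return
        inorder(node.left)           # smaller keys first
        if k1 <= node.val <= k2:
            result.append(node.val)
        inorder(node.right)          # then larger keys

    inorder(root)
    return result


# The example tree: 20 at the root, 8 and 22 below it, 4 and 12 under 8.
example = TreeNode(20, TreeNode(8, TreeNode(4), TreeNode(12)), TreeNode(22))
print(search_range_brute_force(example, 10, 22))  # [12, 20, 22]
```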

Analysis

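Use the BST ordering to prune the search. If the root is inside the search interval (low <= root.val <= high), record it and search both the left and right subtrees. If root.val < low, every key in the left subtree is smaller still, so search only the right subtree. If root.val > high, every key in the right subtree is larger still, so search only the left subtree.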
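To make the pruning visible, here is a small Python sketch of the three cases above that also records every node it touches. `TreeNode`, `search_range_traced` and the `visited` list are hypothetical names used only for illustration. On the example tree with k1 = 10, k2 = 22, node 8 is below the range, so its left subtree (node 4) is never visited.

```python
from typing import List, Optional


class TreeNode:
    """Minimal BST node; the field names val/left/right are assumed."""

    def __init__(self, val: int, left: "Optional[TreeNode]" = None,
                 right: "Optional[TreeNode]" = None) -> None:
        self.val = val
        self.left = left
        self.right = right


def search_range_traced(root: Optional[TreeNode], low: int, high: int,
                        visited: List[int]) -> List[int]:
    """Pruned search; appends every touched key to `visited`."""
    results: List[int] = []

    def helper(node: Optional[TreeNode]) -> None:
        if node is None:
            return
        visited.append(node.val)
        if low <= node.val <= high:
            # In range: keep the key and search both subtrees.
            results.append(node.val)
            helper(node.left)
            helper(node.right)
        elif node.val < low:
            # Everything in node.left is < node.val < low: skip it.
            helper(node.right)
        else:
            # node.val > high: everything in node.right is > high, skip it.
            helper(node.left)

    helper(root)
    # Keys were collected root-first (preorder), so sort before returning.
    return sorted(results)


example = TreeNode(20, TreeNode(8, TreeNode(4), TreeNode(12)), TreeNode(22))
trace: List[int] = []
print(search_range_traced(example, 10, 22, trace))  # [12, 20, 22]
print(trace)  # [20, 8, 12, 22]
```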

Traverse Code

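import java.util.ArrayList;
import java.util.Collections;

// Assumes the usual TreeNode class with fields: int val; TreeNode left, right.
public ArrayList<Integer> searchRange(TreeNode root, int k1, int k2) {
    ArrayList<Integer> results = new ArrayList<>();
    bsearchRange(results, root, k1, k2);
    // Keys were collected root-first (preorder), so sort them before returning.
    Collections.sort(results);
    return results;
}

private void bsearchRange(ArrayList<Integer> results, TreeNode root, int low, int high) {
    if (root == null) {
        return;
    }
    if (root.val >= low && root.val <= high) {
        // In range: record the key; both subtrees may hold more keys in range.
        results.add(root.val);
        bsearchRange(results, root.left, low, high);
        bsearchRange(results, root.right, low, high);
    } else if (root.val < low) {
        // Every key in the left subtree is smaller than root.val < low.
        bsearchRange(results, root.right, low, high);
    } else {
        // root.val > high: every key in the right subtree is larger still.
        bsearchRange(results, root.left, low, high);
    }
}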

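We don't need to sort the results. A binary search tree is already sorted: an in-order traversal (left subtree, root, right subtree) visits its keys in ascending order. If the helper adds keys in that order, searchRange can return results directly and drop the Collections.sort call.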

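private void bsearchRange(ArrayList<Integer> results, TreeNode root, int low, int high) {
    if (root == null) {
        return;
    }
    if (root.val < low) {
        // Left subtree is entirely below low; only the right side can match.
        bsearchRange(results, root.right, low, high);
    } else if (root.val > high) {
        // Right subtree is entirely above high; only the left side can match.
        bsearchRange(results, root.left, low, high);
    } else {
        // low <= root.val <= high: in-order (left, root, right) keeps results ascending.
        bsearchRange(results, root.left, low, high);
        results.add(root.val);
        bsearchRange(results, root.right, low, high);
    }
}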
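The same in-order idea can also be written without recursion. The Python sketch below is a variant that is not in the original: it uses an explicit stack (which avoids deep recursion on a skewed tree), skips a left subtree once node.val <= low, and stops as soon as it pops a key greater than high, because every later in-order key is larger. `TreeNode` and `search_range_iterative` are hypothetical names.

```python
from typing import List, Optional


class TreeNode:
    """Minimal BST node; the field names val/left/right are assumed."""

    def __init__(self, val: int, left: "Optional[TreeNode]" = None,
                 right: "Optional[TreeNode]" = None) -> None:
        self.val = val
        self.left = left
        self.right = right


def search_range_iterative(root: Optional[TreeNode], low: int, high: int) -> List[int]:
    """Pruned in-order traversal with an explicit stack; output is ascending."""
    results: List[int] = []
    stack: List[TreeNode] = []
    node = root
    while stack or node is not None:
        # Descend left, but skip a left subtree once node.val <= low,
        # because all of its keys are < node.val <= low.
        while node is not None:
            stack.append(node)
            node = node.left if node.val > low else None
        node = stack.pop()
        if node.val > high:
            # In-order is ascending, so every remaining key is > high too.
            break
        if node.val >= low:
            results.append(node.val)
        node = node.right
    return results


example = TreeNode(20, TreeNode(8, TreeNode(4), TreeNode(12)), TreeNode(22))
print(search_range_iterative(example, 10, 22))  # [12, 20, 22]
```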

Divide and Conquer (D&C) Code

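# Assumes LintCode's provided TreeNode (val, left, right) and `from typing import List`.
def searchRange(self, root: TreeNode, k1: int, k2: int) -> List[int]:
    # base: an empty subtree has no keys
    if root is None:
        return []

    # divide: the left subtree holds keys < root.val, so it can only contain
    # keys in range when k1 < root.val (symmetrically for the right subtree)
    left_vals = []
    if k1 < root.val:
        left_vals = self.searchRange(root.left, k1, k2)

    right_vals = []
    if k2 > root.val:
        right_vals = self.searchRange(root.right, k1, k2)

    # merge: left keys < root.val < right keys, so the concatenation is sorted
    cur_vals = []
    if k1 <= root.val <= k2:
        cur_vals = [root.val]

    return left_vals + cur_vals + right_vals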

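Time complexity is O(n) in the worst case, e.g. when every key lies in [k1, k2]. Pruning skips every subtree that cannot contain a key in range, so the search only visits the nodes on the search paths for k1 and k2 plus the m keys it reports: O(h + m) for a tree of height h. The recursion stack uses O(h) extra space.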
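To see how much the pruning saves, the Python sketch below re-implements the D&C search with a visit counter and runs it on a perfectly balanced BST of 1023 keys (10 levels). `TreeNode`, `build_balanced`, `levels` and `search_range_counted` are hypothetical helpers for this illustration. For a narrow range the number of visited nodes stays within m + 2·h (m keys reported, h levels), while a range covering every key visits all n nodes, which is the O(n) worst case.

```python
from typing import List, Optional, Tuple


class TreeNode:
    """Minimal BST node; the field names val/left/right are assumed."""

    def __init__(self, val: int, left: "Optional[TreeNode]" = None,
                 right: "Optional[TreeNode]" = None) -> None:
        self.val = val
        self.left = left
        self.right = right


def build_balanced(keys: List[int]) -> Optional[TreeNode]:
    """Build a height-balanced BST from sorted keys (middle key as root)."""
    if not keys:
        return None
    mid = len(keys) // 2
    return TreeNode(keys[mid], build_balanced(keys[:mid]), build_balanced(keys[mid + 1:]))


def levels(node: Optional[TreeNode]) -> int:
    """Number of nodes on the longest root-to-leaf path."""
    if node is None:
        return 0
    return 1 + max(levels(node.left), levels(node.right))


def search_range_counted(root: Optional[TreeNode], k1: int, k2: int) -> Tuple[List[int], int]:
    """D&C range search that also returns how many nodes it visited."""
    if root is None:
        return [], 0
    left_vals, left_count = [], 0
    if k1 < root.val:
        left_vals, left_count = search_range_counted(root.left, k1, k2)
    right_vals, right_count = [], 0
    if k2 > root.val:
        right_vals, right_count = search_range_counted(root.right, k1, k2)
    cur = [root.val] if k1 <= root.val <= k2 else []
    return left_vals + cur + right_vals, 1 + left_count + right_count


root = build_balanced(list(range(1023)))    # n = 1023 keys
found, visited = search_range_counted(root, 500, 509)
h = levels(root)
print(len(found), h, visited <= len(found) + 2 * h)  # 10 10 True
```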
