Kth Smallest Element in a BST

Question (LC.230)

Given the root of a binary search tree (BST) and an integer k, return the kth smallest value among its nodes (k is 1-indexed).

Example

I: tree = [3,1,4,null,2], k = 1

   3
  / \
 1   4
  \
   2

O: 1

I: tree = [5,3,6,2,4,null,null,1], k = 3

       5
      / \
     3   6
    / \
   2   4
  /
 1

O: 3 

Analysis

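An in-order traversal of a BST (left subtree, node, right subtree) visits the values in ascending order, so the answer is the kth node visited. The traversal can stop there instead of visiting the whole tree.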
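The ordering claim can be checked directly on the examples. The sketch below is a standalone illustration, not part of the solution: `TreeNode` is a minimal stand-in for LeetCode's class, and `build_tree` and `in_order_values` are hypothetical helpers that decode LeetCode's level-order list (with `None` for `null`) and collect the in-order values.

```python
from collections import deque
from typing import List, Optional


class TreeNode:
    # Minimal stand-in for LeetCode's TreeNode (assumption: same val/left/right fields).
    def __init__(self, val: int = 0, left: "Optional[TreeNode]" = None,
                 right: "Optional[TreeNode]" = None):
        self.val = val
        self.left = left
        self.right = right


def build_tree(values: List[Optional[int]]) -> Optional[TreeNode]:
    """Hypothetical helper: build a tree from a level-order list; None marks a missing child."""
    if not values or values[0] is None:
        return None
    root = TreeNode(values[0])
    queue = deque([root])
    i = 1
    while queue and i < len(values):
        node = queue.popleft()
        # The next two entries (if present) are this node's left and right children.
        if i < len(values) and values[i] is not None:
            node.left = TreeNode(values[i])
            queue.append(node.left)
        i += 1
        if i < len(values) and values[i] is not None:
            node.right = TreeNode(values[i])
            queue.append(node.right)
        i += 1
    return root


def in_order_values(root: Optional[TreeNode]) -> List[int]:
    """Left subtree, node, right subtree: ascending order for a BST."""
    if root is None:
        return []
    return in_order_values(root.left) + [root.val] + in_order_values(root.right)


values = in_order_values(build_tree([5, 3, 6, 2, 4, None, None, 1]))
print(values)         # [1, 2, 3, 4, 5, 6]
print(values[3 - 1])  # kth smallest for k = 3 -> 3
```

Building the full list costs O(n); the solutions that follow stop after k nodes instead.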

Code


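from typing import Iterator, Optional

# TreeNode (val, left, right) is the node class supplied by LeetCode.


class Solution:

    def inOrderIter(self, root: Optional[TreeNode]) -> Iterator[TreeNode]:
        """Lazily yield the nodes of the subtree at `root` in in-order (ascending for a BST)."""
        if root is None:
            return  # a bare return ends the generator

        yield from self.inOrderIter(root.left)   # left subtree first
        yield root                                # then the node itself
        yield from self.inOrderIter(root.right)  # then the right subtree

    def kthSmallest(self, root: Optional[TreeNode], k: int) -> int:
        # Count visited nodes starting at 1 and stop as soon as we reach the kth.
        for i, node in enumerate(self.inOrderIter(root), start=1):
            if i == k:
                return node.val

        return -1  # the tree has fewer than k nodes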

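An iterator with an explicit stack turns out to be slightly faster than the recursive generator, which has to pass every yielded node up through a chain of nested generator frames.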

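from collections import deque

# TreeNode (val, left, right) is the node class supplied by LeetCode.


class InOrderIterator:
    """Iterative in-order traversal that yields one node per call using an explicit stack."""

    def __init__(self, root: TreeNode):
        self.stack = deque()  # ancestors whose left subtree is still being visited
        self.cur = root       # root of the next subtree to descend into

    def __iter__(self):
        return self

    def __next__(self) -> TreeNode:
        if self.cur is None and not self.stack:
            raise StopIteration

        # Go left: push the current node and all of its left descendants.
        while self.cur is not None:
            self.stack.append(self.cur)
            self.cur = self.cur.left

        # Visit: the top of the stack is the smallest unvisited node.
        node = self.stack.pop()

        # Continue with the right subtree on the next call (it may be None).
        self.cur = node.right
        return node


class Solution:

    def kthSmallest(self, root: TreeNode, k: int) -> int:
        # Count visited nodes starting at 1 and stop as soon as we reach the kth.
        for i, node in enumerate(InOrderIterator(root), start=1):
            if i == k:
                return node.val

        return -1  # the tree has fewer than k nodes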

Complexity

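Time: O(h + k) for the explicit-stack iterator, where h is the height of the tree (O(log n) when balanced, O(n) in the worst case): O(h) to reach the smallest node, then k in-order steps.

Space: O(h) for the stack.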

Follow Up

What if the BST is modified (insert/delete operations) often and you need to find the kth smallest frequently? How would you optimize the kth smallest routine?

  1. If we know k is a constant

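    1. Maintain a separate data structure alongside the BST, such as a max-heap holding the k smallest values; its top is the kth smallest.

    2. Deletion is tricky: if the deleted value is in the heap, it must be removed (O(k) with a binary heap) and replaced by the next-larger value from the BST (O(h_{BST})). This is acceptable when k is small.

    3. $O(h_{BST}) + O(h_{heap})$ insert, where $h_{heap} = O(\log k)$; $O(h_{BST} + k)$ delete; $O(1)$ find kth smallest.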

  2. If k is a variable that is part of the query

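    1. The current approach is fine: O(h) insert and O(h + k) to find the kth smallest.

    2. Finding the kth smallest can be sped up slightly by also threading the BST nodes into a sorted doubly linked list, so the query walks k list nodes without first descending the tree.

    3. This also needs a mapping from each TreeNode to its ListNode: on insertion, the new node's in-order predecessor is found on the BST search path and the mapping locates where to splice into the list in O(1); on deletion, the mapping finds the list node to unlink in O(1).

    4. Then O(h) insert, O(h) delete, O(k) find kth smallest.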
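For the constant-k case, a minimal sketch of the heap side might look like the following. It assumes distinct values and that every BST insert/delete is mirrored into this structure. `KSmallestHeap` and the `bst_successor` callback (standing in for an O(h) "smallest value greater than x" query on the BST) are hypothetical names, not an existing API. Python's `heapq` is a min-heap, so values are stored negated to get a max-heap.

```python
import heapq
from typing import Callable, List, Optional


class KSmallestHeap:
    """Hypothetical sketch for a fixed k: a max-heap of the k smallest values in the BST.

    Values are stored negated, so -heap[0] is the kth smallest.
    Assumes distinct values and that every BST insert/delete is mirrored here.
    """

    def __init__(self, k: int):
        self.k = k
        self.heap: List[int] = []

    def insert(self, val: int) -> None:
        # O(log k): keep only the k smallest values seen so far.
        if len(self.heap) < self.k:
            heapq.heappush(self.heap, -val)
        elif val < -self.heap[0]:
            heapq.heapreplace(self.heap, -val)  # evict the old kth smallest

    def delete(self, val: int, bst_successor: Callable[[int], Optional[int]]) -> None:
        # bst_successor(x): hypothetical O(h) BST query, smallest value > x or None.
        if not self.heap or val > -self.heap[0]:
            return  # val was not among the k smallest; the heap is unchanged
        old_kth = -self.heap[0]
        self.heap.remove(-val)    # O(k) linear search
        heapq.heapify(self.heap)  # O(k) to restore the heap property
        nxt = bst_successor(old_kth)
        if nxt is not None:
            heapq.heappush(self.heap, -nxt)

    def kth_smallest(self) -> Optional[int]:
        # O(1): the heap top, once at least k values exist.
        return -self.heap[0] if len(self.heap) == self.k else None
```

The O(k) removal is the "tricky" deletion cost from item 1; refilling the heap needs the BST because only the tree knows the next-larger value.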

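The variable-k idea of threading the BST into a sorted linked list can be sketched as below. This is a hypothetical illustration (`LinkedBST`, `ListNode`, and the `list_node` attribute are made-up names): the TreeNode to ListNode mapping is stored as an attribute on each tree node rather than in a separate dict, values are assumed comparable, and deletion is omitted to keep it short.

```python
from typing import Optional


class ListNode:
    def __init__(self, val: int):
        self.val = val
        self.prev: Optional["ListNode"] = None
        self.next: Optional["ListNode"] = None


class TreeNode:
    def __init__(self, val: int):
        self.val = val
        self.left: Optional["TreeNode"] = None
        self.right: Optional["TreeNode"] = None
        self.list_node = ListNode(val)  # the TreeNode -> ListNode mapping


class LinkedBST:
    """Hypothetical sketch: a BST whose nodes are also kept in a sorted doubly linked list."""

    def __init__(self):
        self.root: Optional[TreeNode] = None
        self.head = ListNode(0)  # sentinel (its val is unused); head.next is the smallest

    def insert(self, val: int) -> None:  # O(h)
        new = TreeNode(val)
        pred: Optional[TreeNode] = None  # in-order predecessor, found on the search path
        parent, cur = None, self.root
        while cur is not None:
            parent = cur
            if val < cur.val:
                cur = cur.left
            else:
                pred = cur  # every node we step right from precedes val
                cur = cur.right
        if parent is None:
            self.root = new
        elif val < parent.val:
            parent.left = new
        else:
            parent.right = new
        # Splice the new list node right after its predecessor (or the sentinel): O(1).
        before = pred.list_node if pred is not None else self.head
        node = new.list_node
        node.prev, node.next = before, before.next
        if before.next is not None:
            before.next.prev = node
        before.next = node

    def kth_smallest(self, k: int) -> int:  # O(k): walk the list, no tree descent
        node = self.head.next
        for _ in range(k - 1):
            if node is None:
                return -1
            node = node.next
        return node.val if node is not None else -1
```

For example, after inserting 5, 3, 6, 2, 4, 1, `kth_smallest(3)` walks three list nodes and returns 3.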