Binary Tree Leaves

Question (LC.366)

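Given a binary tree, collect its nodes as follows: collect and remove all current leaves, then repeat until the tree is empty. Return the leaves removed in each round.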

Example

Input:
          1
         / \
        2   3
       / \
      4   5
Output: [[4,5,3],[2],[1]]

Traverse and Delete

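This is the most obvious solution given the example. A DFS helper removeLeaves() collects every leaf of the current tree into a list and marks it as deleted; we repeat this pass until the root itself has been deleted. Each pass walks all nodes that are still present, so a skewed tree of n nodes takes O(n^2) time in the worst case.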

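// Assumes java.util.* is imported and LeetCode's TreeNode (val, left, right).
public List<List<Integer>> findLeaves(TreeNode root) {
    List<List<Integer>> result = new ArrayList<>();
    if (root == null) {
        return result;
    }
    // Nodes are "removed" by marking them in a set; the tree itself is not modified.
    Set<TreeNode> deleted = new HashSet<>();
    // Each pass peels off one layer of leaves; stop once the root has been peeled.
    while (!deleted.contains(root)) {
        List<Integer> leaves = new ArrayList<>();
        removeLeaves(root, deleted, leaves);
        result.add(leaves);
    }
    return result;
}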

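// Collects and marks every node that is a leaf of the current (partially deleted) tree.
private void removeLeaves(TreeNode node, Set<TreeNode> deleted, List<Integer> leaves) {
    // A node counts as a leaf once both children are absent or already deleted.
    // The check runs before recursing, so a parent whose children are removed
    // during this pass is only collected in the next pass.
    if ((node.left == null || deleted.contains(node.left)) &&
        (node.right == null || deleted.contains(node.right))) {
        leaves.add(node.val);
        deleted.add(node);
        return;
    }
    // Otherwise descend into the children that are still present.
    if (node.left != null && !deleted.contains(node.left)) {
        removeLeaves(node.left, deleted, leaves);
    }
    if (node.right != null && !deleted.contains(node.right)) {
        removeLeaves(node.right, deleted, leaves);
    }
}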
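The set-based version never modifies the tree. If destroying the input is acceptable, a common variant detaches leaves directly: the helper returns null for a leaf so that its parent drops the pointer, and the loop runs until the root itself comes back as null. Below is a minimal sketch of that variant, assuming a stand-in TreeNode with fields val, left and right (LeetCode supplies its own); the class name FindLeavesPrune and the helper prune() are hypothetical names chosen for illustration.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical class name; a sketch of the "detach leaves in place" variant.
public class FindLeavesPrune {
    // Assumed stand-in for LeetCode's TreeNode (fields: val, left, right).
    static class TreeNode {
        int val;
        TreeNode left, right;
        TreeNode(int val) { this.val = val; }
        TreeNode(int val, TreeNode left, TreeNode right) {
            this.val = val;
            this.left = left;
            this.right = right;
        }
    }

    // Strips one layer of leaves per round until the root itself is removed.
    // Note: this destroys the input tree.
    static List<List<Integer>> findLeaves(TreeNode root) {
        List<List<Integer>> result = new ArrayList<>();
        while (root != null) {
            List<Integer> leaves = new ArrayList<>();
            root = prune(root, leaves);
            result.add(leaves);
        }
        return result;
    }

    // Hypothetical helper: returns the subtree with its current leaves detached,
    // or null if node is itself a leaf. Leaves are collected left to right.
    private static TreeNode prune(TreeNode node, List<Integer> leaves) {
        if (node == null) {
            return null;
        }
        if (node.left == null && node.right == null) {
            leaves.add(node.val);
            return null; // the parent drops its pointer to this leaf
        }
        node.left = prune(node.left, leaves);
        node.right = prune(node.right, leaves);
        return node;
    }

    public static void main(String[] args) {
        // Example tree from the problem statement.
        TreeNode root = new TreeNode(1,
                new TreeNode(2, new TreeNode(4), new TreeNode(5)),
                new TreeNode(3));
        System.out.println(findLeaves(root)); // [[4, 5, 3], [2], [1]]
    }
}
```

As in the set-based version, the leaf test happens before recursing, so a node whose children are detached in the current round waits for the next round. The worst-case cost is still O(n^2) on a skewed tree; the gain is only that no HashSet is needed.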

Bottom Up with Height

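Recall max depth of a binary tree. Instead of measuring depth top down from the root, we can compute each node's height bottom up: a leaf has height 0, and any other node has height one more than its taller child. A node is removed in exactly the round given by its height (counting rounds from 0), because it only becomes a leaf once its tallest subtree has been stripped away. So the answer is simply the nodes grouped by height (distance down to the farthest leaf below them) rather than by depth. One post-order traversal computes every height, so this runs in O(n) time.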
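As a worked check, here is the height recurrence applied to the example tree, with an empty subtree assigned height -1 to match the base case of treeHeight():

$$
\begin{aligned}
h(\text{null}) &= -1, \qquad h(v) = \max\big(h(v.\mathit{left}),\, h(v.\mathit{right})\big) + 1 \\
h(4) = h(5) = h(3) &= \max(-1, -1) + 1 = 0 \\
h(2) &= \max\big(h(4), h(5)\big) + 1 = 1 \\
h(1) &= \max\big(h(2), h(3)\big) + 1 = 2
\end{aligned}
$$

Grouping the nodes by height gives [[4,5,3],[2],[1]], which matches the expected output.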

public List<List<Integer>> findLeaves(TreeNode root) {
    List<List<Integer>> result = new ArrayList<>();
    if (root == null) {
        return result;
    }
    treeHeight(root, result);
    return result;
}

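// Returns the height of root (leaves have height 0) and files root.val under that height.
private int treeHeight(TreeNode root, List<List<Integer>> result) {
    // base: an empty subtree has height -1, so a leaf gets height 0
    if (root == null) {
        return -1;
    }
    // divide: post-order, so both child heights are known first
    int leftHeight = treeHeight(root.left, result);
    int rightHeight = treeHeight(root.right, result);
    // merge: a node is removed one round after its tallest child
    int height = Math.max(leftHeight, rightHeight) + 1;
    // height <= result.size() always holds (for height > 0 a child of height - 1
    // was filed first), so at most one new bucket is ever needed
    if (height == result.size()) {
        result.add(new ArrayList<Integer>());
    }
    result.get(height).add(root.val);
    // optional: uncomment to physically detach the children as well
    // root.left = null; root.right = null;
    return height;
}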
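To try the height-based approach outside LeetCode, here is a minimal self-contained harness. The TreeNode class is an assumed stand-in mirroring LeetCode's (val, left, right), and the class name FindLeavesByHeight is hypothetical. It runs the example tree and a left-skewed chain, the case where the repeated-pass approach degrades to O(n^2).

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical class name; a runnable sketch of the bottom-up height approach.
public class FindLeavesByHeight {
    // Assumed stand-in for LeetCode's TreeNode (fields: val, left, right).
    static class TreeNode {
        int val;
        TreeNode left, right;
        TreeNode(int val) { this.val = val; }
        TreeNode(int val, TreeNode left, TreeNode right) {
            this.val = val;
            this.left = left;
            this.right = right;
        }
    }

    // Groups node values by height; an empty tree yields an empty list.
    static List<List<Integer>> findLeaves(TreeNode root) {
        List<List<Integer>> result = new ArrayList<>();
        treeHeight(root, result);
        return result;
    }

    // Post-order: files node.val under its height (leaves = 0) and returns that height.
    private static int treeHeight(TreeNode node, List<List<Integer>> result) {
        if (node == null) {
            return -1;
        }
        int height = Math.max(treeHeight(node.left, result),
                              treeHeight(node.right, result)) + 1;
        if (height == result.size()) {
            result.add(new ArrayList<>()); // first node seen at this height
        }
        result.get(height).add(node.val);
        return height;
    }

    public static void main(String[] args) {
        // Example tree from the problem statement.
        TreeNode example = new TreeNode(1,
                new TreeNode(2, new TreeNode(4), new TreeNode(5)),
                new TreeNode(3));
        System.out.println(findLeaves(example)); // [[4, 5, 3], [2], [1]]

        // Left-skewed chain 1 -> 2 -> 3: one node per height.
        TreeNode chain = new TreeNode(1, new TreeNode(2, new TreeNode(3), null), null);
        System.out.println(findLeaves(chain)); // [[3], [2], [1]]
    }
}
```

This wrapper drops the explicit null check: treeHeight(null, result) returns -1 without touching result, so an empty tree already yields an empty list.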
