Intermediate

BST Problems

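Binary Search Trees maintain sorted order through their structure: every value in a node's left subtree is smaller than the node, and every value in its right subtree is larger. This property enables O(h) search, insertion, and deletion, where h is the tree height. That is O(log n) when the tree is balanced, but O(n) in the worst case: for example, inserting already-sorted keys produces a chain. It is the same principle behind KD-trees for nearest neighbor search in ML and decision tree splits that partition feature space.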

Problem 1: Validate Binary Search Tree

🎯
Problem: Given the root of a binary tree, determine if it is a valid BST (left subtree values < node < right subtree values, for every node).
ML Context: When building decision trees from scratch, you must validate that each split correctly partitions the data. An invalid BST structure means the model's decision boundaries are inconsistent.
class TreeNode:
    """A binary tree node: a value plus optional left and right children.

    Shared by every problem on this page; an absent child is represented by None.
    """

    def __init__(self, val=0, left=None, right=None):
        # Payload first, then the two child links.
        self.val, self.left, self.right = val, left, right

def is_valid_bst(root: TreeNode) -> bool:
    """Return True if the tree rooted at root satisfies the strict BST property.

    Every node's value must lie strictly between the bounds inherited from its
    ancestors: stepping into a left subtree caps values at the parent's value,
    stepping into a right subtree floors them at it. An empty tree is valid;
    a duplicate value makes the tree invalid.
    Time: O(n), Space: O(h) recursion depth.
    """
    def check(node, lower, upper):
        if node is None:
            return True
        value = node.val
        if value <= lower or value >= upper:
            return False
        # Left side must stay below this node; right side must stay above it.
        if not check(node.left, lower, value):
            return False
        return check(node.right, value, upper)

    return check(root, float('-inf'), float('inf'))

def is_valid_bst_inorder(root: TreeNode) -> bool:
    """Return True if an inorder walk of the tree yields strictly increasing values.

    Inorder traversal visits a BST's values in sorted order, so any value that
    is not larger than its predecessor (including a duplicate) proves the tree
    is not a valid BST. The walk stops at the first violation.
    Time: O(n), Space: O(h) for the explicit stack.
    """
    def ascending():
        # Iterative inorder traversal, yielding one value at a time.
        pending = []
        node = root
        while node or pending:
            while node:
                pending.append(node)
                node = node.left
            node = pending.pop()
            yield node.val
            node = node.right

    last = float('-inf')
    for value in ascending():
        if value <= last:
            return False
        last = value
    return True

# Test
valid = TreeNode(2, TreeNode(1), TreeNode(3))
# 5's right subtree holds 4 and 3, both smaller than 5, so the BST property fails at the root.
invalid = TreeNode(5, TreeNode(1), TreeNode(4, TreeNode(3), TreeNode(6)))
print(is_valid_bst(valid))    # True
print(is_valid_bst(invalid))  # False (4 and 3 are in the right subtree of 5 but smaller than 5)

Problem 2: Kth Smallest Element in BST

🎯
Problem: Given the root of a BST and an integer k, return the kth smallest value (1-indexed).
ML Context: Finding the kth smallest element is equivalent to finding a specific percentile in sorted data — a common operation when computing feature statistics for normalization or selecting threshold values in decision boundaries.
def kth_smallest(root: TreeNode, k: int) -> int:
    """Return the kth smallest value (1-indexed) in a BST.

    Walks the tree with an iterative inorder traversal, which visits BST values
    in ascending order, and stops as soon as the kth value is popped.

    Returns -1 if k < 1 or k exceeds the number of nodes. Note that -1 is
    ambiguous when the tree can itself contain -1; callers in that situation
    should check k against the tree size before calling.

    Time: O(h + k), Space: O(h).
    """
    if k < 1:
        # No node can be the 0th (or negative) smallest; answer immediately
        # instead of traversing the entire tree before giving up.
        return -1

    stack = []
    current = root
    count = 0

    while current or stack:
        # Descend as far left as possible; the deepest pending node is next in order.
        while current:
            stack.append(current)
            current = current.left
        current = stack.pop()
        count += 1
        if count == k:
            return current.val
        current = current.right

    return -1  # k is larger than tree size

# Test
#       3
#      / \
#     1   4
#      \
#       2
# Inorder visits 1, 2, 3, 4, so k=1 gives 1 and k=3 gives 3.
root = TreeNode(3, TreeNode(1, None, TreeNode(2)), TreeNode(4))
print(kth_smallest(root, 1))  # 1
print(kth_smallest(root, 3))  # 3

Problem 3: Lowest Common Ancestor of BST

🎯
Problem: Given a BST and two nodes p and q, find their lowest common ancestor (LCA).
ML Context: In hierarchical classification (taxonomy trees), the LCA of two categories tells you the most specific shared category. This is used in label hierarchy-aware loss functions and ontology-based feature engineering.
def lca_bst(root: TreeNode, p: int, q: int) -> int:
    """Return the value of the lowest common ancestor of values p and q in a BST.

    Starting at the root, keep stepping toward the side that contains both
    values. The first node where they do not fall on the same side (or where
    one of them equals the node's value) is their LCA.
    Assumes p and q are present in the tree. If the walk runs off the tree,
    which can happen for an empty tree or when the values are absent, the
    function returns None.
    Time: O(h), Space: O(1).
    """
    node = root
    while node is not None:
        both_left = p < node.val and q < node.val
        both_right = p > node.val and q > node.val
        if both_left:
            node = node.left
        elif both_right:
            node = node.right
        else:
            return node.val  # p and q diverge here
    return None

# General binary tree LCA (works without BST property)
def lca_general(root: TreeNode, p: int, q: int) -> TreeNode:
    """Return the lowest common ancestor node of values p and q in any binary tree.

    Each call reports what its subtree contains. It returns the node holding
    p or q, or their LCA, and None if neither was found. The first node whose
    two subtrees both report a hit is the LCA.

    A node whose value equals p or q is returned immediately, without
    searching below it. So if only one of the values is present, that
    value's node is the result.
    Time: O(n), Space: O(h) recursion depth.
    """
    if root is None:
        return None
    if root.val == p or root.val == q:
        return root
    from_left = lca_general(root.left, p, q)
    from_right = lca_general(root.right, p, q)
    if from_left and from_right:
        return root  # p and q were found on opposite sides
    return from_left or from_right

# Test
#       6
#      / \
#     2   8
#    / \ / \
#   0  4 7  9
root = TreeNode(6,
    TreeNode(2, TreeNode(0), TreeNode(4)),
    TreeNode(8, TreeNode(7), TreeNode(9))
)
print(lca_bst(root, 2, 8))  # 6 (2 lies left of 6 and 8 lies right, so they split at the root)
print(lca_bst(root, 2, 4))  # 2 (4 is in 2's subtree, so 2 is the LCA)

Problem 4: Insert and Delete in BST

🎯
Problem: Implement insert and delete operations for a BST.
ML Context: Online learning algorithms update decision trees incrementally as new data arrives. Insert adds new split points; delete removes splits that are no longer discriminative. Understanding these operations is key to building streaming ML systems.
def insert_bst(root: TreeNode, val: int) -> TreeNode:
    """Insert a value into BST and return the root.

    An empty tree (root is None) becomes a new single-node tree. If val is
    already present the tree is left unchanged, so no duplicates are stored.

    Walks down iteratively rather than recursively. Inserting into a very deep
    tree, such as one built from already-sorted values where h == n, therefore
    cannot raise RecursionError.
    Time: O(h), Space: O(1).
    """
    if not root:
        return TreeNode(val)

    node = root
    while True:
        if val < node.val:
            if not node.left:
                node.left = TreeNode(val)
                break
            node = node.left
        elif val > node.val:
            if not node.right:
                node.right = TreeNode(val)
                break
            node = node.right
        else:
            break  # val already in the tree: nothing to insert
    return root

def delete_bst(root: TreeNode, val: int) -> TreeNode:
    """Remove val from the BST rooted at root and return the (possibly new) root.

    Deleting a value that is not present leaves the tree unchanged. For the
    node holding val:
      - no left child: it is replaced by its right subtree (this covers leaves);
      - no right child: it is replaced by its left subtree;
      - two children: it takes the value of its inorder successor (the minimum
        of its right subtree), and that successor is then deleted from the
        right subtree.
    Time: O(h), Space: O(h) recursion depth.
    """
    if root is None:
        return None

    if val < root.val:
        root.left = delete_bst(root.left, val)
        return root
    if val > root.val:
        root.right = delete_bst(root.right, val)
        return root

    # root holds val. With at most one child, splice that child (or None) in.
    if root.left is None:
        return root.right
    if root.right is None:
        return root.left

    # Two children: copy up the smallest value of the right subtree...
    leftmost = root.right
    while leftmost.left is not None:
        leftmost = leftmost.left
    root.val = leftmost.val
    # ...then remove that value's original node, which has no left child.
    root.right = delete_bst(root.right, leftmost.val)
    return root

# Helper to verify with inorder traversal
def inorder(root):
    """Return the node values of a binary tree in inorder (left, node, right).

    For a valid BST the result is sorted ascending. The tests use this to
    confirm a tree is still a BST after modifications. An empty tree gives [].

    Uses a single output list and an explicit stack, which brings two benefits:
      - Each value is appended exactly once, so the cost is O(n). Concatenating
        freshly built per-subtree lists costs O(n * h), i.e. O(n^2) on a
        skewed tree.
      - Deep or skewed trees cannot hit Python's recursion limit.
    Time: O(n), Space: O(n) for the result plus O(h) for the stack.
    """
    values = []
    stack = []
    node = root
    while node or stack:
        # Push the whole left spine; the top of the stack is the next value in order.
        while node:
            stack.append(node)
            node = node.left
        node = stack.pop()
        values.append(node.val)
        node = node.right
    return values

# Test
root = TreeNode(5, TreeNode(3, TreeNode(2), TreeNode(4)), TreeNode(7, TreeNode(6), TreeNode(8)))
print(inorder(root))                    # [2, 3, 4, 5, 6, 7, 8]
root = insert_bst(root, 1)              # 1 becomes the left child of 2
print(inorder(root))                    # [1, 2, 3, 4, 5, 6, 7, 8]
root = delete_bst(root, 5)              # 5 has two children, so its value is replaced by successor 6
print(inorder(root))                    # [1, 2, 3, 4, 6, 7, 8]

Problem 5: Convert Sorted Array to BST

🎯
Problem: Given an integer array sorted in ascending order, convert it to a height-balanced BST.
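ML Context: Building balanced BSTs from sorted data is the foundation of KD-tree construction for nearest neighbor search. KD-trees also use the median element at each level (the middle of the sorted array) to partition the space. The difference is that a KD-tree re-sorts its points along a different coordinate axis at each level, while this algorithm works with a single one-dimensional ordering.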
def sorted_array_to_bst(nums: list) -> TreeNode:
    """Convert sorted array to height-balanced BST.

    Key insight: always pick the middle element as root to ensure balance.
    For an even-length slice the upper middle (index len // 2) is chosen.

    Time: O(n log n). Besides the O(n) node creations, every recursion level
    copies its sub-arrays via slicing. The slices at one level add up to O(n)
    elements, and there are O(log n) levels.
    Space: O(n) auxiliary, because the slices alive along one recursion path
    sum to about n elements, plus O(log n) recursion depth.
    sorted_array_to_bst_optimal avoids the copies and runs in O(n) time.
    """
    if not nums:
        return None

    mid = len(nums) // 2
    root = TreeNode(nums[mid])
    # Each slice is a new list: this copying is what makes the function O(n log n).
    root.left = sorted_array_to_bst(nums[:mid])
    root.right = sorted_array_to_bst(nums[mid + 1:])
    return root

def sorted_array_to_bst_optimal(nums: list) -> TreeNode:
    """Build a height-balanced BST from sorted nums using index bounds instead of slices.

    Each call covers the inclusive index range [lo, hi] and roots it at the
    lower middle index (lo + hi) // 2, so no sub-arrays are ever copied.
    Time: O(n), Space: O(log n) recursion depth.
    """
    def subtree(lo, hi):
        if hi < lo:
            return None  # empty range
        center = (lo + hi) // 2
        node = TreeNode(nums[center])
        node.left, node.right = subtree(lo, center - 1), subtree(center + 1, hi)
        return node

    return subtree(0, len(nums) - 1)

# Test
nums = [-10, -3, 0, 5, 9]
# The root is the middle element 0; its subtrees are built from [-10, -3] and [5, 9].
root = sorted_array_to_bst(nums)
print(inorder(root))  # [-10, -3, 0, 5, 9] (inorder is sorted, consistent with a valid BST)

# ML Application: KD-Tree construction sketch
def build_kd_tree(points, depth=0):
    """Build a simplified KD-tree by splitting on the median along one axis per level.

    The splitting axis cycles through the dimensions (depth % k). At each level
    the points are ordered along that axis and the middle one becomes the node,
    mirroring how sorted_array_to_bst picks the middle element as the root.

    Args:
        points: a sequence (list or tuple) of coordinate sequences. All points
            are assumed to have the same number of dimensions as points[0].
            The caller's sequence is not modified.
        depth: current tree depth; selects the splitting axis.

    Returns:
        None if there are no points. Otherwise a dict with three keys:
        'point' (the median point along the current axis), plus 'left' and
        'right' (subtrees built the same way).

    This is a teaching sketch of the idea. Library implementations such as
    sklearn.neighbors.KDTree differ in their details; for example, they store
    several points per leaf (leaf_size).
    """
    if not points:
        return None
    k = len(points[0])  # number of dimensions (assumed shared by all points)
    axis = depth % k
    # sorted() builds a new list, so the caller's sequence is never reordered
    # as a side effect, and tuple input works as well as list input.
    ordered = sorted(points, key=lambda p: p[axis])
    mid = len(ordered) // 2
    return {
        'point': ordered[mid],
        'left': build_kd_tree(ordered[:mid], depth + 1),
        'right': build_kd_tree(ordered[mid + 1:], depth + 1),
    }
💡
BST pattern summary: The key insight for all BST problems is exploiting the sorted property. If you need sorted order, use inorder traversal. If you need to search, follow the BST path (go left if smaller, right if larger). If you need to build, always pick the median to stay balanced.