
DP Optimization Techniques

Four advanced techniques that take your DP solutions from good to optimal: space optimization for memory reduction, bitmask DP for subset problems, digit DP for counting problems, and DP on trees for hierarchical structures.

1. Space Optimization

Many 2D DP problems only need the current row and the previous row. By recognizing this dependency pattern, you can reduce space from O(m*n) to O(n) or even O(1).

Pattern: Rolling Array

When dp[i][j] depends only on the previous row dp[i-1][...] and on earlier entries of its own row, keep only two rows.

# Example: Edit Distance - from O(m*n) to O(min(m,n)) space
def edit_distance_optimized(word1, word2):
    m, n = len(word1), len(word2)
    if m < n:  # Ensure we use the shorter string for the 1D array
        return edit_distance_optimized(word2, word1)

    prev = list(range(n + 1))
    for i in range(1, m + 1):
        curr = [i] + [0] * n
        for j in range(1, n + 1):
            if word1[i-1] == word2[j-1]:
                curr[j] = prev[j-1]
            else:
                curr[j] = 1 + min(prev[j], curr[j-1], prev[j-1])
        prev = curr
    return prev[n]

# Example: LCS - from O(m*n) to O(min(m,n)) space
def lcs_optimized(text1, text2):
    if len(text1) < len(text2):
        text1, text2 = text2, text1
    n = len(text2)
    prev = [0] * (n + 1)
    for i in range(len(text1)):
        curr = [0] * (n + 1)
        for j in range(n):
            if text1[i] == text2[j]:
                curr[j+1] = prev[j] + 1
            else:
                curr[j+1] = max(prev[j+1], curr[j])
        prev = curr
    return prev[n]
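
The two-row version can be squeezed a little further. The sketch below is an illustration rather than part of the original examples: it keeps a single row plus one scalar that remembers the diagonal value dp[i-1][j] before the in-place update overwrites it. The name lcs_single_row is hypothetical.

```python
def lcs_single_row(text1, text2):
    """LCS length using one row of size min(m, n) + 1 plus one scalar."""
    if len(text1) < len(text2):
        text1, text2 = text2, text1  # keep the row as short as possible
    n = len(text2)
    row = [0] * (n + 1)
    for ch in text1:
        diag = 0  # dp[i-1][0], the diagonal neighbour for j = 0
        for j in range(n):
            above = row[j + 1]  # dp[i-1][j+1], saved before it is overwritten
            if ch == text2[j]:
                row[j + 1] = diag + 1
            else:
                row[j + 1] = max(above, row[j])
            diag = above  # becomes the diagonal for the next column
    return row[n]


print(lcs_single_row("abcde", "ace"))
```

The asymptotic space is the same as the two-row version; the gain is one fewer allocation per outer iteration.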

Pattern: Two Variables

When dp[i] depends only on dp[i-1] and dp[i-2], use two variables instead of an array.

# Already shown: Fibonacci, Climbing Stairs, House Robber
# General pattern: the caller supplies the two base cases and the recurrence f
def two_var_template(n, base_case_0, base_case_1, f):
    if n == 0:
        return base_case_0
    prev2 = base_case_0
    prev1 = base_case_1
    for i in range(2, n + 1):
        curr = f(prev1, prev2)  # Your recurrence: dp[i] = f(dp[i-1], dp[i-2])
        prev2 = prev1
        prev1 = curr
    return prev1

💡 Interview tip: Always mention space optimization as a follow-up improvement. Even if the interviewer does not ask for it, saying "we can reduce space from O(n) to O(1) by using two variables" demonstrates deeper understanding.
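
As a concrete instance of the two-variable pattern, here is a minimal sketch of the linear House Robber problem. The function name rob_linear is hypothetical, and the recurrence assumed is dp[i] = max(dp[i-1], dp[i-2] + nums[i]).

```python
def rob_linear(nums):
    """Max sum of non-adjacent elements in O(1) extra space."""
    prev2, prev1 = 0, 0  # best totals up to two houses back / one house back
    for money in nums:
        # Either skip this house (prev1) or rob it on top of prev2
        prev2, prev1 = prev1, max(prev1, prev2 + money)
    return prev1


print(rob_linear([2, 7, 9, 3, 1]))
```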

2. Bitmask DP

Bitmask DP uses an integer's binary representation to encode which elements of a set have been selected. Each bit position represents an element: 1 means selected, 0 means not selected. This is useful when n is small (typically n <= 20).
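
The basic set operations on a mask can be sketched as follows. The helper names (contains, add, remove, members, submasks) are hypothetical and exist only to label the bit tricks the examples in this section rely on.

```python
def contains(mask, i):
    """True if element i is in the set encoded by mask."""
    return bool(mask & (1 << i))

def add(mask, i):
    """Return mask with element i added."""
    return mask | (1 << i)

def remove(mask, i):
    """Return mask with element i removed."""
    return mask & ~(1 << i)

def members(mask):
    """List the elements present in mask, lowest index first."""
    return [i for i in range(mask.bit_length()) if mask & (1 << i)]

def submasks(mask):
    """Yield every subset of mask, from mask itself down to the empty set."""
    sub = mask
    while sub:
        yield sub
        sub = (sub - 1) & mask  # next smaller subset of mask
    yield 0


print(members(0b1010), list(submasks(0b101)))
```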

Example: Traveling Salesman Problem (TSP)

Problem: Given n cities and distances between them, find the shortest route that visits every city exactly once and returns to the starting city.

State: dp[mask][i] = minimum cost to visit all cities in the bitmask, ending at city i.

# TSP with Bitmask DP: O(2^n * n^2) time, O(2^n * n) space
def tsp(dist):
    n = len(dist)
    ALL_VISITED = (1 << n) - 1  # All bits set
    INF = float('inf')

    # dp[mask][i] = min cost to reach city i having visited cities in mask
    dp = [[INF] * n for _ in range(1 << n)]
    dp[1][0] = 0  # Start at city 0, only city 0 visited (mask = 1)

    for mask in range(1 << n):
        for u in range(n):
            if dp[mask][u] == INF:
                continue
            if not (mask & (1 << u)):
                continue  # u must be in the visited set
            # Try visiting each unvisited city
            for v in range(n):
                if mask & (1 << v):
                    continue  # v already visited
                new_mask = mask | (1 << v)
                dp[new_mask][v] = min(
                    dp[new_mask][v],
                    dp[mask][u] + dist[u][v]
                )

    # Return to starting city
    result = INF
    for u in range(n):
        if dp[ALL_VISITED][u] + dist[u][0] < result:
            result = dp[ALL_VISITED][u] + dist[u][0]
    return result

# Test:
dist = [
    [0, 10, 15, 20],
    [10, 0, 35, 25],
    [15, 35, 0, 30],
    [20, 25, 30, 0]
]
# tsp(dist) = 80 (0->1->3->2->0: 10+25+30+15=80)
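
One way to gain confidence in a bitmask DP is to compare it against exhaustive search on small inputs. The sketch below is a hypothetical checker, not part of the DP itself: it tries every ordering of the non-start cities in O(n!) time, so it is only usable for very small n.

```python
from itertools import permutations

def tsp_brute_force(dist):
    """Shortest tour starting and ending at city 0, by trying every order."""
    n = len(dist)
    best = float("inf")
    for order in permutations(range(1, n)):
        route = (0,) + order + (0,)
        cost = sum(dist[a][b] for a, b in zip(route, route[1:]))
        best = min(best, cost)
    return best


dist = [
    [0, 10, 15, 20],
    [10, 0, 35, 25],
    [15, 35, 0, 30],
    [20, 25, 30, 0],
]
print(tsp_brute_force(dist))  # 80, matching the DP result above
```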

Example: Assign Tasks to Workers

Problem: Given n workers and n tasks with a cost matrix, assign each task to exactly one worker to minimize total cost. (Hungarian algorithm solves this in O(n^3), but bitmask DP is simpler for n <= 20.)

# Assignment problem: O(2^n * n) time
def min_cost_assignment(cost):
    n = len(cost)
    dp = [float('inf')] * (1 << n)
    dp[0] = 0

    for mask in range(1 << n):
        # The number of tasks already assigned is the index of the next worker
        worker = bin(mask).count('1')
        if worker >= n:
            continue
        # Try assigning each unassigned task to this worker
        for task in range(n):
            if mask & (1 << task):
                continue  # Task already assigned
            new_mask = mask | (1 << task)
            dp[new_mask] = min(dp[new_mask],
                              dp[mask] + cost[worker][task])

    return dp[(1 << n) - 1]

# Test:
cost = [
    [9, 2, 7, 8],
    [6, 4, 3, 7],
    [5, 8, 1, 8],
    [7, 6, 9, 4]
]
# min_cost_assignment(cost) = 13
# Optimal: worker0=task1(2), worker1=task0(6), worker2=task2(1), worker3=task3(4) = 13
# (Giving task2 to worker1 instead forces worker2 onto task0: 2 + 3 + 5 + 4 = 14.)
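
The DP above returns only the minimum cost. If the assignment itself is needed, one option is to record which task produced each dp[mask] and walk backwards from the full mask. The following is a sketch under that assumption; the function name and the choice array are hypothetical additions.

```python
def min_cost_assignment_with_plan(cost):
    """Return (min total cost, plan) where plan[worker] is the chosen task."""
    n = len(cost)
    INF = float("inf")
    dp = [INF] * (1 << n)
    choice = [-1] * (1 << n)  # task added last on the best way to reach mask
    dp[0] = 0

    for mask in range(1 << n):
        worker = bin(mask).count("1")  # index of the next worker to assign
        if worker >= n or dp[mask] == INF:
            continue
        for task in range(n):
            if mask & (1 << task):
                continue  # task already taken
            new_mask = mask | (1 << task)
            candidate = dp[mask] + cost[worker][task]
            if candidate < dp[new_mask]:
                dp[new_mask] = candidate
                choice[new_mask] = task

    # Walk back from the full mask, peeling off one task per worker
    plan = [None] * n
    mask = (1 << n) - 1
    while mask:
        task = choice[mask]
        plan[bin(mask).count("1") - 1] = task
        mask ^= 1 << task
    return dp[(1 << n) - 1], plan


cost = [
    [9, 2, 7, 8],
    [6, 4, 3, 7],
    [5, 8, 1, 8],
    [7, 6, 9, 4],
]
print(min_cost_assignment_with_plan(cost))  # (13, [1, 0, 2, 3])
```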

3. Digit DP

Digit DP counts numbers in a range [L, R] that satisfy some property by processing digits from most significant to least significant. The key idea is tracking whether we are still "tight" (bounded by the upper limit) or "free" (any digit is valid).

Example: Count Numbers with Digit Sum <= S

Problem: Count integers in [1, N] whose digits sum to at most S.

# Digit DP template
def count_digit_sum(N, S):
    """Count integers from 1 to N with digit sum <= S."""
    if N < 1:
        return 0  # Empty range; keeps count_in_range safe when L - 1 < 1
    digits = [int(d) for d in str(N)]
    n = len(digits)
    memo = {}

    def dp(pos, digit_sum, tight, started):
        """
        pos: current digit position (0 = most significant)
        digit_sum: sum of digits placed so far
        tight: True if previous digits match N exactly (limits current digit)
        started: True if we have placed a non-zero digit (handles leading zeros)
        """
        if digit_sum > S:
            return 0
        if pos == n:
            return 1 if started else 0

        state = (pos, digit_sum, tight, started)
        if state in memo:
            return memo[state]

        limit = digits[pos] if tight else 9
        result = 0

        for d in range(0, limit + 1):
            result += dp(
                pos + 1,
                digit_sum + d,
                tight and (d == limit),
                started or (d > 0)
            )

        memo[state] = result
        return result

    return dp(0, 0, True, False)

# Count in range [L, R]: count(R) - count(L-1)
def count_in_range(L, R, S):
    return count_digit_sum(R, S) - count_digit_sum(L - 1, S)

# Test: count_digit_sum(100, 5) = 21
# Numbers 1-100 with digit sum <= 5:
# 1,2,3,4,5 (5), 10,11,12,13,14 (5), 20,21,22,23 (4), 30,31,32 (3),
# 40,41 (2), 50 (1), 100 (1) = 5+5+4+3+2+1+1 = 21
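
Digit DP is easy to get subtly wrong (leading zeros, the tight flag), so it helps to compare against a direct count on small inputs. The functions below are a hypothetical brute-force checker that runs in O(N) time; they are meant for testing only, not as a replacement for the DP.

```python
def count_digit_sum_brute(N, S):
    """Count integers in [1, N] whose digit sum is at most S, by enumeration."""
    return sum(1 for x in range(1, N + 1)
               if sum(int(d) for d in str(x)) <= S)

def count_in_range_brute(L, R, S):
    """Same count restricted to [L, R], using the prefix-difference idea."""
    return count_digit_sum_brute(R, S) - count_digit_sum_brute(L - 1, S)


print(count_digit_sum_brute(100, 5))     # 21
print(count_in_range_brute(10, 50, 5))   # 15
```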

4. DP on Trees

DP on trees processes subtrees bottom-up, computing results for children before parents. The state typically includes the node and whether we "include" or "exclude" it.

Example: Maximum Independent Set on a Tree

Problem: Select the maximum number of nodes from a tree such that no two selected nodes are adjacent.

# DP on trees: O(n) time, O(n) space
def max_independent_set(adj, root=0):
    """
    adj: adjacency list of an unrooted tree
    Returns maximum independent set size.
    """
    n = len(adj)
    # dp[node][0] = max independent set size in subtree if node is NOT selected
    # dp[node][1] = max independent set size in subtree if node IS selected
    dp = [[0, 0] for _ in range(n)]
    visited = [False] * n

    def dfs(node):
        visited[node] = True
        dp[node][1] = 1  # Select this node

        for child in adj[node]:
            if not visited[child]:
                dfs(child)
                # If node is NOT selected, children can be either
                dp[node][0] += max(dp[child][0], dp[child][1])
                # If node IS selected, children must NOT be selected
                dp[node][1] += dp[child][0]

    dfs(root)
    return max(dp[root][0], dp[root][1])

# Test:
# Tree:     0
#          / \
#         1   2
#        / \
#       3   4
adj = [
    [1, 2],  # 0 -> 1, 2
    [0, 3, 4],  # 1 -> 0, 3, 4
    [0],  # 2 -> 0
    [1],  # 3 -> 1
    [1],  # 4 -> 1
]
# max_independent_set(adj) = 3 (select nodes 2, 3, 4)
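
The recursive DFS can exceed Python's default recursion limit (about 1000 frames in CPython) on deep trees such as long paths. A possible workaround, sketched below under the same adjacency-list assumption, is to compute a parent-before-child visiting order with an explicit stack and then fold results upward in reverse. The function name is hypothetical.

```python
def max_independent_set_iterative(adj, root=0):
    """Maximum independent set size on a tree, without recursion."""
    n = len(adj)
    if n == 0:
        return 0
    parent = [-1] * n
    seen = [False] * n
    seen[root] = True
    order = []  # every node appears after its parent
    stack = [root]
    while stack:
        node = stack.pop()
        order.append(node)
        for nxt in adj[node]:
            if not seen[nxt]:
                seen[nxt] = True
                parent[nxt] = node
                stack.append(nxt)

    skip = [0] * n  # best size in subtree if node is NOT selected
    take = [1] * n  # best size in subtree if node IS selected
    for node in reversed(order):  # children are finished before parents
        p = parent[node]
        if p != -1:
            skip[p] += max(skip[node], take[node])
            take[p] += skip[node]
    return max(skip[root], take[root])


adj = [[1, 2], [0, 3, 4], [0], [1], [1]]
print(max_independent_set_iterative(adj))  # 3
```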

Example: House Robber III (Tree Version)

Problem: Nodes in a binary tree represent houses with money. Rob the maximum amount without robbing two directly connected houses.

# Binary tree node
class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right

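def rob_tree(root):
    """Return the maximum amount that can be robbed from the tree."""
    def dfs(node):
        # Returns (rob, skip): the best totals for this subtree when the
        # node itself is robbed or skipped
        if not node:
            return (0, 0)  # (rob, skip)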

        left = dfs(node.left)
        right = dfs(node.right)

        # Rob this node: cannot rob children
        rob = node.val + left[1] + right[1]
        # Skip this node: take the best from each child
        skip = max(left) + max(right)

        return (rob, skip)

    return max(dfs(root))

# Test:
#     3
#    / \
#   2   3
#    \   \
#     3   1
root = TreeNode(3,
    TreeNode(2, None, TreeNode(3)),
    TreeNode(3, None, TreeNode(1)))
# rob_tree(root) = 7 (rob the root 3 plus the two leaves 3 and 1)

Example: Tree Diameter

Problem: Find the length, measured in edges, of the longest path between any two nodes in a tree.

def tree_diameter(adj):
    """Find diameter of unrooted tree using DP."""
    n = len(adj)
    if n <= 1:
        return 0

    # dp[node] = length of longest path starting from node going downward
    dp = [0] * n
    diameter = [0]  # Use list to allow modification in nested function
    visited = [False] * n

    def dfs(node):
        visited[node] = True
        top_two = [0, 0]  # Two longest paths to children

        for child in adj[node]:
            if not visited[child]:
                dfs(child)
                child_depth = dp[child] + 1
                if child_depth > top_two[0]:
                    top_two[1] = top_two[0]
                    top_two[0] = child_depth
                elif child_depth > top_two[1]:
                    top_two[1] = child_depth

        dp[node] = top_two[0]
        # Diameter through this node = sum of two longest child paths
        diameter[0] = max(diameter[0], top_two[0] + top_two[1])

    dfs(0)
    return diameter[0]

# Test: Linear tree 0-1-2-3-4
adj = [[1], [0, 2], [1, 3], [2, 4], [3]]
# tree_diameter(adj) = 4
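
As a cross-check for the DP, the diameter can also be found with a different, non-DP technique: run BFS from any node to find a farthest node, then run BFS again from that node. The sketch below uses that two-BFS method; the function names are hypothetical.

```python
from collections import deque

def farthest_node(adj, start):
    """Return (node, distance) for a node farthest from start."""
    dist = [-1] * len(adj)
    dist[start] = 0
    queue = deque([start])
    last = start
    while queue:
        node = queue.popleft()
        last = node  # BFS pops nodes in non-decreasing distance order
        for nxt in adj[node]:
            if dist[nxt] == -1:
                dist[nxt] = dist[node] + 1
                queue.append(nxt)
    return last, dist[last]

def tree_diameter_bfs(adj):
    """Diameter in edges of an unrooted tree, via two BFS passes."""
    if len(adj) <= 1:
        return 0
    end, _ = farthest_node(adj, 0)
    _, diameter = farthest_node(adj, end)
    return diameter


print(tree_diameter_bfs([[1], [0, 2], [1, 3], [2, 4], [3]]))  # 4
```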

When to Use Each Technique

Technique          | Use When                                   | Constraint                       | Time Impact
Space Optimization | dp[i] depends on dp[i-1] only              | Cannot reconstruct solution path | Same time, less space
Bitmask DP         | Subset selection problems                  | n <= 20 (2^20 = 1M states)       | O(2^n * n) or O(2^n * n^2)
Digit DP           | Counting numbers with properties in [L, R] | Numbers up to 10^18              | O(digits * states)
DP on Trees        | Optimization on tree structures            | Input is a tree (n-1 edges)      | O(n) with DFS

Key Takeaways

  • Space optimization applies whenever the recurrence reads only a fixed number of previous rows or states, and it is worth mentioning in interviews as a follow-up improvement.
  • Bitmask DP encodes subsets as integers. It is the standard approach for small-n subset problems (TSP, assignment, set cover) where n <= 20.
  • Digit DP processes numbers digit by digit, tracking whether we are still bounded by the upper limit. It handles ranges up to 10^18 efficiently.
  • DP on trees uses DFS to compute subtree results bottom-up. The include/exclude pattern from house robber generalizes to trees.
  • These techniques are composable: you can use bitmask DP on trees, or space-optimize a digit DP solution.