Advanced Data Structures

Advanced data structures are the secret weapon of competitive programmers: they turn O(N) or O(N^2) operations into O(log N) queries, which is often the difference between TLE and AC. This lesson covers five essential structures, each framed as an ML-flavored problem with a complete implementation.

Problem 1: Segment Tree — Real-Time Metric Range Queries

📝
Problem Statement: Your ML monitoring dashboard receives a stream of metric values (loss, accuracy, etc.). Support two operations: (1) UPDATE: set metric at timestamp i to value v, (2) QUERY: find the minimum metric value in the timestamp range [l, r]. Process Q operations on an array of N timestamps.

Input: N (1 ≤ N ≤ 10^5), Q operations (Q ≤ 10^5)
Output: Answer for each QUERY operation

Approach: Segment Tree with Point Update and Range Min Query

class SegmentTree:
    """Segment tree for range minimum queries with point updates.
    Build: O(N), Query: O(log N), Update: O(log N), Space: O(N)
    """
    def __init__(self, arr):
        self.n = len(arr)
        self.tree = [0] * (4 * self.n)
        self._build(arr, 1, 0, self.n - 1)

    def _build(self, arr, node, start, end):
        if start == end:
            self.tree[node] = arr[start]
            return
        mid = (start + end) // 2
        self._build(arr, 2 * node, start, mid)
        self._build(arr, 2 * node + 1, mid + 1, end)
        self.tree[node] = min(self.tree[2 * node], self.tree[2 * node + 1])

    def update(self, idx, val, node=1, start=0, end=None):
        """Point update: set arr[idx] = val"""
        if end is None:
            end = self.n - 1
        if start == end:
            self.tree[node] = val
            return
        mid = (start + end) // 2
        if idx <= mid:
            self.update(idx, val, 2 * node, start, mid)
        else:
            self.update(idx, val, 2 * node + 1, mid + 1, end)
        self.tree[node] = min(self.tree[2 * node], self.tree[2 * node + 1])

    def query(self, l, r, node=1, start=0, end=None):
        """Range minimum query: min(arr[l..r])"""
        if end is None:
            end = self.n - 1
        if r < start or end < l:
            return float('inf')
        if l <= start and end <= r:
            return self.tree[node]
        mid = (start + end) // 2
        left_min = self.query(l, r, 2 * node, start, mid)
        right_min = self.query(l, r, 2 * node + 1, mid + 1, end)
        return min(left_min, right_min)

# Example: ML metrics monitoring
losses = [0.95, 0.82, 0.71, 0.65, 0.58, 0.52, 0.48, 0.45]
st = SegmentTree(losses)

print(f"Min loss in epochs 0-3: {st.query(0, 3)}")    # 0.65
print(f"Min loss in epochs 4-7: {st.query(4, 7)}")    # 0.45
print(f"Min loss in epochs 2-5: {st.query(2, 5)}")    # 0.52

# Update: epoch 3 had a spike
st.update(3, 0.90)
print(f"After spike, min loss 0-3: {st.query(0, 3)}")  # 0.71

Complexity: Build O(N), Query O(log N), Update O(log N), Space O(N).
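The recursive tree above is the easiest to reason about. For tight time limits, a bottom-up iterative variant drops recursion entirely; the sketch below (the name `IterSegTree` is introduced here, not part of the reference solution above) uses the classic 2N-array layout that works for any N, with half-open query bounds internally.

```python
class IterSegTree:
    """Bottom-up iterative segment tree for range-min (sketch).
    Leaves live at indices [n, 2n); node i has children 2i and 2i+1."""
    def __init__(self, arr):
        self.n = len(arr)
        self.tree = [float('inf')] * (2 * self.n)
        self.tree[self.n:] = arr                    # copy leaves into place
        for i in range(self.n - 1, 0, -1):          # build internal nodes
            self.tree[i] = min(self.tree[2 * i], self.tree[2 * i + 1])

    def update(self, idx, val):
        """Point update: set arr[idx] = val, then fix ancestors."""
        i = idx + self.n
        self.tree[i] = val
        while i > 1:
            i //= 2
            self.tree[i] = min(self.tree[2 * i], self.tree[2 * i + 1])

    def query(self, l, r):
        """Range minimum over the inclusive range [l, r]."""
        res = float('inf')
        l += self.n
        r += self.n + 1                             # half-open [l, r) internally
        while l < r:
            if l & 1:                               # l is a right child: take it
                res = min(res, self.tree[l])
                l += 1
            if r & 1:                               # r is past a right child
                r -= 1
                res = min(res, self.tree[r])
            l //= 2
            r //= 2
        return res
```

The same three operations behave identically to the recursive version, without function-call overhead per level.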

Problem 2: BIT/Fenwick Tree — Cumulative Training Statistics

📝
Problem Statement: Track cumulative statistics across training batches. Support: (1) ADD value v to batch i, (2) PREFIX_SUM: compute sum of values from batch 1 to batch i, (3) RANGE_SUM: compute sum of values from batch l to batch r. Process Q operations efficiently.

Input: N batches (1 ≤ N ≤ 10^5), Q operations (Q ≤ 10^5)
Output: Answer for each sum query

Approach: Binary Indexed Tree (Fenwick Tree)

The Fenwick tree stores partial sums in a clever binary scheme. Each index i is responsible for the sum of elements in a range determined by the lowest set bit of i. This gives O(log N) for both update and prefix sum with only N+1 space and very low constant factors.
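The responsibility ranges are easy to see concretely (a small sketch; `lowbit` is a helper name introduced here):

```python
def lowbit(i):
    """Lowest set bit of i, e.g. lowbit(12) = 4 (binary 1100 -> 100)."""
    return i & (-i)

# Index i is responsible for the lowbit(i) elements ending at position i.
for i in [1, 2, 3, 4, 6, 8, 12]:
    span = lowbit(i)
    print(f"tree[{i}] covers positions [{i - span + 1}..{i}] ({span} elements)")
```

So `tree[12]` (binary 1100) covers positions 9..12, while `tree[7]` (binary 111) covers only position 7; walking `i -= lowbit(i)` stitches these disjoint blocks into any prefix.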

class FenwickTree:
    """Binary Indexed Tree for prefix sum queries and point updates.
    Update: O(log N), Query: O(log N), Space: O(N)
    Advantages over a segment tree: less memory (N+1 vs. 4N entries), simpler code, lower constants.
    """
    def __init__(self, n):
        self.n = n
        self.tree = [0] * (n + 1)  # 1-indexed

    def update(self, i, delta):
        """Add delta to position i (1-indexed)"""
        while i <= self.n:
            self.tree[i] += delta
            i += i & (-i)  # Add lowest set bit

    def prefix_sum(self, i):
        """Sum of elements from 1 to i (1-indexed)"""
        s = 0
        while i > 0:
            s += self.tree[i]
            i -= i & (-i)  # Remove lowest set bit
        return s

    def range_sum(self, l, r):
        """Sum of elements from l to r (1-indexed)"""
        return self.prefix_sum(r) - self.prefix_sum(l - 1)

    @staticmethod
    def from_array(arr):
        """Build a BIT from an array via N point updates: O(N log N)"""
        n = len(arr)
        bit = FenwickTree(n)
        for i, val in enumerate(arr):
            bit.update(i + 1, val)
        return bit

# Example: Track training batch statistics
batch_sizes = [32, 64, 32, 128, 64, 32, 64, 128]
bit = FenwickTree.from_array(batch_sizes)

print(f"Total samples processed (batches 1-4): {bit.prefix_sum(4)}")   # 256
print(f"Total samples processed (batches 1-8): {bit.prefix_sum(8)}")   # 544
print(f"Samples in batches 3-6: {bit.range_sum(3, 6)}")               # 256

# Dynamic update: batch 5 was reprocessed with 128 samples
bit.update(5, 64)  # Add 64 more (was 64, now effectively 128)
print(f"After reprocessing, batches 5-8: {bit.range_sum(5, 8)}")       # 352

# Counting inversions (classic CP problem) using BIT
def count_inversions(arr):
    """Count inversions in array: O(N log N)"""
    # Coordinate compression
    sorted_unique = sorted(set(arr))
    rank = {v: i + 1 for i, v in enumerate(sorted_unique)}

    bit = FenwickTree(len(sorted_unique))
    inversions = 0

    for i in range(len(arr) - 1, -1, -1):
        r = rank[arr[i]]
        inversions += bit.prefix_sum(r - 1)
        bit.update(r, 1)

    return inversions

print(f"\nInversions in [3,1,2]: {count_inversions([3, 1, 2])}")  # 2
print(f"Inversions in [5,4,3,2,1]: {count_inversions([5, 4, 3, 2, 1])}")  # 10

Complexity: O(log N) per update and query, O(N) space. In practice a BIT is typically noticeably faster than a segment tree thanks to its simpler operations and better cache behavior.

💡
When to use BIT vs Segment Tree: Use BIT when you only need prefix sums and point updates (most common case). Use segment tree when you need range updates, range minimum/maximum, or more complex operations. BIT is simpler to code and faster, but less flexible.
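One qualification to "BIT can't do range updates": the difference-array trick handles range update with point query. A minimal sketch, assuming 1-indexed positions as above (`RangeUpdateBIT` is a name introduced here):

```python
class RangeUpdateBIT:
    """Range update + point query via a BIT over differences (sketch).
    range_add(l, r, v) becomes two point updates on the difference array;
    point_query(i) is then a prefix sum of differences."""
    def __init__(self, n):
        self.n = n
        self.tree = [0] * (n + 1)  # 1-indexed

    def _add(self, i, delta):
        while i <= self.n:
            self.tree[i] += delta
            i += i & (-i)

    def range_add(self, l, r, v):
        """Add v to every position in [l, r] (1-indexed)."""
        self._add(l, v)
        self._add(r + 1, -v)       # no-op if r == n

    def point_query(self, i):
        """Current value at position i = sum of differences 1..i."""
        s = 0
        while i > 0:
            s += self.tree[i]
            i -= i & (-i)
        return s
```

(Range update with range query is also possible with two BITs, but at that point a segment tree with lazy propagation is usually the clearer choice.)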

Problem 3: Disjoint Set Union — Feature Clustering

📝
Problem Statement: You have N features in a dataset. As you discover correlations between features, you merge them into clusters. Support: (1) UNION(a, b): merge the clusters containing features a and b, (2) FIND(a): return the cluster ID of feature a, (3) SIZE(a): return the number of features in a's cluster. Process Q operations.

Input: N features (1 ≤ N ≤ 10^6), Q operations (Q ≤ 10^6)
Output: Results for FIND and SIZE queries

Approach: Union-Find with Path Compression and Union by Rank

class DSU:
    """Disjoint Set Union with path compression and union by rank.
    Almost O(1) amortized per operation (inverse Ackermann).
    """
    def __init__(self, n):
        self.parent = list(range(n))
        self.rank = [0] * n
        self.size = [1] * n
        self.components = n

    def find(self, x):
        """Find root with path compression"""
        if self.parent[x] != x:
            self.parent[x] = self.find(self.parent[x])
        return self.parent[x]

    def union(self, x, y):
        """Union by rank, return True if merged (were different)"""
        rx, ry = self.find(x), self.find(y)
        if rx == ry:
            return False

        # Attach smaller tree under larger tree
        if self.rank[rx] < self.rank[ry]:
            rx, ry = ry, rx
        self.parent[ry] = rx
        self.size[rx] += self.size[ry]
        if self.rank[rx] == self.rank[ry]:
            self.rank[rx] += 1

        self.components -= 1
        return True

    def connected(self, x, y):
        return self.find(x) == self.find(y)

    def get_size(self, x):
        return self.size[self.find(x)]

# Example: Feature clustering based on correlation discovery
features = ["age", "income", "height", "weight", "education",
            "experience", "salary", "BMI"]
n = len(features)
dsu = DSU(n)

# Discover correlations
correlations = [
    (0, 1, "age-income"),       # age correlates with income
    (2, 3, "height-weight"),    # height correlates with weight
    (1, 4, "income-education"), # income correlates with education
    (5, 6, "experience-salary"),# experience correlates with salary
    (3, 7, "weight-BMI"),       # weight correlates with BMI
    (1, 6, "income-salary"),    # income correlates with salary
]

for a, b, name in correlations:
    dsu.union(a, b)
    print(f"Merged {name}: {dsu.components} clusters remaining")

print(f"\nFinal clusters:")
clusters = {}
for i in range(n):
    root = dsu.find(i)
    if root not in clusters:
        clusters[root] = []
    clusters[root].append(features[i])

for cluster in clusters.values():
    print(f"  {cluster} (size: {len(cluster)})")

# Cluster 1: [age, income, education, experience, salary] (size 5)
# Cluster 2: [height, weight, BMI] (size 3)

Complexity: Nearly O(1) amortized per operation with both path compression and union by rank (inverse Ackermann function). O(N) space.
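One practical caveat: with N up to 10^6, the recursive find above can exceed Python's default recursion limit (about 1000) on a long chain before compression flattens it. An iterative variant with path halving avoids this entirely; the sketch below (`find_iterative` is a hypothetical helper operating directly on a parent array) is a drop-in replacement for the compression step.

```python
def find_iterative(parent, x):
    """Iterative find with path halving: every node on the search path is
    repointed to its grandparent, giving the same amortized bound as full
    path compression without any recursion."""
    while parent[x] != x:
        parent[x] = parent[parent[x]]  # point x past its parent
        x = parent[x]
    return x

# A degenerate chain 0 <- 1 <- 2 <- ... <- 4999 that would overflow
# the recursive find's call stack:
parent = list(range(5000))
for i in range(1, 5000):
    parent[i] = i - 1
root = find_iterative(parent, 4999)  # root is 0; no recursion involved
```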

Problem 4: Sparse Table — Immutable Range Minimum for Batch Metrics

📝
Problem Statement: Given a fixed array of N training metrics (no updates), answer Q range minimum queries. Each query asks for the minimum value in arr[l..r]. Since there are no updates, we can preprocess for O(1) query time.

Input: N values (1 ≤ N ≤ 10^6), Q queries (Q ≤ 10^6)
Output: Minimum value for each query

Approach: Sparse Table

Sparse table preprocesses all ranges of length 2^k for each starting position. Any range [l, r] can be answered by overlapping two precomputed ranges of length 2^k where k = floor(log2(r - l + 1)). Since min is idempotent (min(a, a) = a), overlapping is safe.
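A worked sketch of the decomposition, for the range [2, 10]:

```python
from math import floor, log2

l, r = 2, 10
k = floor(log2(r - l + 1))           # length 9 -> k = 3, block length 2^3 = 8
left_block = (l, l + (1 << k) - 1)   # [2, 9]: 8 elements starting at l
right_block = (r - (1 << k) + 1, r)  # [3, 10]: 8 elements ending at r
# The blocks overlap on [3, 9]; since min(a, a) = a, the overlap is harmless:
# min(arr[2..10]) = min(min(arr[2..9]), min(arr[3..10]))
```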

from math import log2, floor

class SparseTable:
    """Sparse table for O(1) range minimum queries on static arrays.
    Build: O(N log N), Query: O(1), Space: O(N log N)
    """
    def __init__(self, arr):
        self.n = len(arr)
        self.LOG = floor(log2(self.n)) + 1 if self.n > 0 else 0
        self.table = [[0] * self.n for _ in range(self.LOG)]

        # Base case: ranges of length 1
        self.table[0] = arr[:]

        # Fill table: table[k][i] = min of arr[i..i+2^k-1]
        for k in range(1, self.LOG):
            for i in range(self.n - (1 << k) + 1):
                self.table[k][i] = min(
                    self.table[k-1][i],
                    self.table[k-1][i + (1 << (k-1))]
                )

        # Precompute log values for O(1) query
        self.log_table = [0] * (self.n + 1)
        for i in range(2, self.n + 1):
            self.log_table[i] = self.log_table[i // 2] + 1

    def query(self, l, r):
        """Range minimum query in O(1)"""
        length = r - l + 1
        k = self.log_table[length]
        return min(self.table[k][l], self.table[k][r - (1 << k) + 1])

# Example: Static training metrics analysis
metrics = [0.95, 0.82, 0.71, 0.65, 0.58, 0.52, 0.48, 0.45,
           0.43, 0.42, 0.50, 0.47, 0.44, 0.41, 0.40, 0.39]

st = SparseTable(metrics)

# O(1) range minimum queries
print(f"Min loss epochs 0-7:  {st.query(0, 7)}")    # 0.45
print(f"Min loss epochs 8-15: {st.query(8, 15)}")   # 0.39
print(f"Min loss epochs 4-11: {st.query(4, 11)}")   # 0.42
print(f"Min loss overall:     {st.query(0, 15)}")   # 0.39

# Benchmark: 100K queries on a 100K-element array
import time
big_arr = list(range(100000, 0, -1))
big_st = SparseTable(big_arr)
start = time.perf_counter()
for i in range(100000):
    big_st.query(i % 50000, i % 50000 + 50000)
elapsed = time.perf_counter() - start
print(f"\n100K queries on 100K elements: {elapsed:.3f}s")

Complexity: O(N log N) build time and space. O(1) per query. Optimal for static arrays with many queries.
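The same layout works for any associative, idempotent operation: max, gcd, bitwise AND/OR. A generic sketch with the operation passed as a parameter (`build_sparse` and `query_sparse` are names introduced here; each level stores only the valid starting positions):

```python
from math import gcd

def build_sparse(arr, op):
    """Build a sparse table for an idempotent op: O(N log N) time and space."""
    n = len(arr)
    table = [arr[:]]                 # level 0: ranges of length 1
    k = 1
    while (1 << k) <= n:
        prev = table[k - 1]
        half = 1 << (k - 1)
        table.append([op(prev[i], prev[i + half])
                      for i in range(n - (1 << k) + 1)])
        k += 1
    return table

def query_sparse(table, l, r, op):
    """O(1) query over [l, r] by combining two overlapping 2^k blocks."""
    k = (r - l + 1).bit_length() - 1  # floor(log2(length)) via bit tricks
    return op(table[k][l], table[k][r - (1 << k) + 1])

vals = [12, 18, 6, 9, 24, 15]
t = build_sparse(vals, gcd)
print(query_sparse(t, 0, 5, gcd))    # gcd of the whole array
```

Note that sum does NOT qualify: it is associative but not idempotent, so the overlapping-blocks trick would double-count; use a prefix-sum array or a BIT instead.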

Problem 5: Balanced BST — Dynamic Median Tracking for Online Learning

📝
Problem Statement: In online learning, you receive data points one at a time and need to track the running median for anomaly detection. Support: (1) INSERT a new value, (2) MEDIAN: return the current median of all inserted values. Both operations must be O(log N).

Input: Q operations (Q ≤ 10^5)
Output: Median after each INSERT

Approach: Two Heaps (Simulating Balanced BST)

Maintain a max-heap for the lower half and a min-heap for the upper half. The median is always at the top of one of the heaps. This is the standard approach in competitive programming since Python does not have a built-in balanced BST.

import heapq

class MedianTracker:
    """Track running median using two heaps.
    Insert: O(log N), Median: O(1)
    """
    def __init__(self):
        self.lo = []  # max-heap (negate values) for lower half
        self.hi = []  # min-heap for upper half

    def insert(self, val):
        """Insert value and rebalance"""
        # Push to max-heap (lower half)
        heapq.heappush(self.lo, -val)

        # Ensure max of lower <= min of upper
        if self.hi and -self.lo[0] > self.hi[0]:
            lo_max = -heapq.heappop(self.lo)
            hi_min = heapq.heappop(self.hi)
            heapq.heappush(self.lo, -hi_min)
            heapq.heappush(self.hi, lo_max)

        # Balance sizes: lo can have at most 1 more than hi
        if len(self.lo) > len(self.hi) + 1:
            heapq.heappush(self.hi, -heapq.heappop(self.lo))
        elif len(self.hi) > len(self.lo):
            heapq.heappush(self.lo, -heapq.heappop(self.hi))

    def median(self):
        """Get current median in O(1)"""
        if len(self.lo) > len(self.hi):
            return -self.lo[0]
        return (-self.lo[0] + self.hi[0]) / 2

# Example: Online anomaly detection via median tracking
tracker = MedianTracker()
data_stream = [0.45, 0.82, 0.31, 0.95, 0.67, 0.12, 0.88, 0.54]

print("Online median tracking:")
for val in data_stream:
    tracker.insert(val)
    median = tracker.median()
    print(f"  Inserted {val:.2f} -> Median: {median:.3f}")

# Anomaly detection: flag if value > 2 * median
print("\nAnomaly detection:")
tracker2 = MedianTracker()
monitoring = [0.5, 0.48, 0.52, 0.49, 0.51, 2.5, 0.50, 0.47]
for val in monitoring:
    tracker2.insert(val)
    med = tracker2.median()
    flag = " ** ANOMALY **" if val > 2 * med else ""
    print(f"  Value: {val:.2f}, Median: {med:.3f}{flag}")

Complexity: O(log N) insert, O(1) median query. O(N) space.

💡
Contest tip: The two-heap median trick is one of the most frequently tested data structure problems. In Python, remember that heapq only provides min-heap. For max-heap, negate all values. For balanced BST operations (order statistics, rank queries), use the SortedList from the sortedcontainers library, which is allowed on most competitive programming judges that support Python.
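The SortedList interface in miniature, sketched here with stdlib `bisect` so it runs anywhere: `bisect.insort` keeps a plain list sorted at O(N) per insert, while `sortedcontainers.SortedList` exposes the same k-th-element and rank queries with O(log N) updates (check availability on your judge).

```python
import bisect

sl = []
for v in [5, 1, 4, 2, 3]:
    bisect.insort(sl, v)              # keep the list sorted on insert

kth = sl[2]                           # k-th smallest (0-indexed) -> 3
rank = bisect.bisect_left(sl, 4)      # rank of 4: count of elements < 4 -> 3
```

With `SortedList`, the loop body becomes `sl.add(v)` and both queries keep the exact same syntax.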

Key Takeaways

  • Segment trees support O(log N) range queries and point updates — the most versatile data structure in competitive programming.
  • Fenwick trees (BIT) are simpler and faster than segment trees for prefix sum operations, using half the memory.
  • Disjoint Set Union with path compression and union by rank achieves nearly O(1) per operation — essential for connectivity and clustering problems.
  • Sparse tables provide O(1) range minimum queries on static arrays after O(N log N) preprocessing.
  • The two-heap median tracker gives O(log N) insert and O(1) median — the standard approach for online median tracking.