Advanced Data Structures
Advanced data structures are the secret weapon of competitive programmers. They turn O(N) or O(N^2) operations into O(log N) queries, which is the difference between TLE and AC in contests. This lesson covers five essential data structures with ML-context problems and complete implementations.
Problem 1: Segment Tree — Real-Time Metric Range Queries
Input: N (1 ≤ N ≤ 10^5), Q operations (Q ≤ 10^5)
Output: Answer for each QUERY operation
Approach: Segment Tree with Point Update and Range Min Query
class SegmentTree:
    """Segment tree for range minimum queries with point updates.

    Build: O(N), Query: O(log N), Update: O(log N), Space: O(N).
    Classic 1-indexed implicit-heap layout: node k has children 2k and
    2k + 1; 4*N slots are enough for any N.
    """

    def __init__(self, arr):
        self.n = len(arr)
        self.tree = [0] * (4 * self.n)
        # Guard: building on an empty array would recurse forever
        # (start=0, end=-1 never reaches the start == end base case).
        if self.n:
            self._build(arr, 1, 0, self.n - 1)

    def _build(self, arr, node, start, end):
        """Recursively fill the tree for the segment arr[start..end]."""
        if start == end:
            self.tree[node] = arr[start]  # leaf stores the element itself
            return
        mid = (start + end) // 2
        self._build(arr, 2 * node, start, mid)
        self._build(arr, 2 * node + 1, mid + 1, end)
        # Internal node aggregates its two children.
        self.tree[node] = min(self.tree[2 * node], self.tree[2 * node + 1])

    def update(self, idx, val, node=1, start=0, end=None):
        """Point update: set arr[idx] = val in O(log N)."""
        if end is None:
            end = self.n - 1
        if start == end:
            self.tree[node] = val
            return
        mid = (start + end) // 2
        if idx <= mid:
            self.update(idx, val, 2 * node, start, mid)
        else:
            self.update(idx, val, 2 * node + 1, mid + 1, end)
        # Re-aggregate on the way back up.
        self.tree[node] = min(self.tree[2 * node], self.tree[2 * node + 1])

    def query(self, l, r, node=1, start=0, end=None):
        """Range minimum query: return min(arr[l..r]) in O(log N).

        Returns float('inf') when [l, r] does not intersect the array
        (including any query against an empty tree).
        """
        if end is None:
            end = self.n - 1
        if r < start or end < l:
            return float('inf')  # disjoint segment contributes the identity
        if l <= start and end <= r:
            return self.tree[node]  # fully covered: precomputed minimum
        mid = (start + end) // 2
        left_min = self.query(l, r, 2 * node, start, mid)
        right_min = self.query(l, r, 2 * node + 1, mid + 1, end)
        return min(left_min, right_min)
# Example: ML metrics monitoring
epoch_losses = [0.95, 0.82, 0.71, 0.65, 0.58, 0.52, 0.48, 0.45]
loss_tree = SegmentTree(epoch_losses)
# Expected minima: 0.65, 0.45, 0.52
for lo, hi in ((0, 3), (4, 7), (2, 5)):
    print(f"Min loss in epochs {lo}-{hi}: {loss_tree.query(lo, hi)}")
# Update: epoch 3 had a spike
loss_tree.update(3, 0.90)
print(f"After spike, min loss 0-3: {loss_tree.query(0, 3)}")  # 0.71
Complexity: Build O(N), Query O(log N), Update O(log N), Space O(N).
Problem 2: BIT/Fenwick Tree — Cumulative Training Statistics
Input: N batches (1 ≤ N ≤ 10^5), Q operations (Q ≤ 10^5)
Output: Answer for each sum query
Approach: Binary Indexed Tree (Fenwick Tree)
The Fenwick tree stores partial sums in a clever binary scheme. Each index i is responsible for the sum of elements in a range determined by the lowest set bit of i. This gives O(log N) for both update and prefix sum with only N+1 space and very low constant factors.
class FenwickTree:
    """Binary Indexed Tree for prefix sum queries and point updates.

    Update: O(log N), Query: O(log N), Space: O(N).
    Advantages over segment tree: 2x less memory, simpler code, lower constant.
    Internally 1-indexed: tree[i] holds the sum of the block of elements
    ending at i whose length equals the lowest set bit of i.
    """

    def __init__(self, n):
        self.n = n
        self.tree = [0] * (n + 1)  # 1-indexed; slot 0 is unused

    def update(self, i, delta):
        """Add delta to position i (1-indexed) in O(log N)."""
        while i <= self.n:
            self.tree[i] += delta
            i += i & (-i)  # jump to the next block covering position i

    def prefix_sum(self, i):
        """Sum of elements from 1 to i (1-indexed) in O(log N)."""
        s = 0
        while i > 0:
            s += self.tree[i]
            i -= i & (-i)  # strip lowest set bit to reach the previous block
        return s

    def range_sum(self, l, r):
        """Sum of elements from l to r inclusive (1-indexed)."""
        return self.prefix_sum(r) - self.prefix_sum(l - 1)

    @staticmethod
    def from_array(arr):
        """Build BIT from array in O(N).

        Seeds each slot with its own element, then pushes each partial
        sum into its parent block in one linear pass — genuinely O(N),
        unlike calling update() per element (O(N log N)).
        """
        n = len(arr)
        bit = FenwickTree(n)
        bit.tree[1:] = arr  # tree[i] starts as arr[i - 1]
        for i in range(1, n + 1):
            parent = i + (i & (-i))
            if parent <= n:
                bit.tree[parent] += bit.tree[i]
        return bit
# Example: Track training batch statistics
batch_sizes = [32, 64, 32, 128, 64, 32, 64, 128]
stats = FenwickTree.from_array(batch_sizes)
print(f"Total samples processed (batches 1-4): {stats.prefix_sum(4)}")  # 256
print(f"Total samples processed (batches 1-8): {stats.prefix_sum(8)}")  # 544
print(f"Samples in batches 3-6: {stats.range_sum(3, 6)}")  # 256
# Dynamic update: batch 5 was reprocessed with 128 samples
stats.update(5, 64)  # add the extra 64 (was 64, now effectively 128)
print(f"After reprocessing, batches 5-8: {stats.range_sum(5, 8)}")  # 352
# Counting inversions (classic CP problem) using BIT
def count_inversions(arr):
    """Count inversions in array: O(N log N)"""
    # Coordinate compression: map each distinct value to a 1-based rank.
    ranks = {v: i + 1 for i, v in enumerate(sorted(set(arr)))}
    tree = FenwickTree(len(ranks))
    total = 0
    # Sweep right-to-left so the tree holds elements to the right of the
    # current one; each strictly-smaller element there is an inversion.
    for value in reversed(arr):
        r = ranks[value]
        total += tree.prefix_sum(r - 1)
        tree.update(r, 1)
    return total
inv_small = count_inversions([3, 1, 2])
inv_reversed = count_inversions([5, 4, 3, 2, 1])
print(f"\nInversions in [3,1,2]: {inv_small}")  # 2
print(f"Inversions in [5,4,3,2,1]: {inv_reversed}")  # 10 = C(5,2), fully reversed
Complexity: O(log N) per update and query. O(N) space. Approximately 2x faster than segment tree in practice due to simpler operations and better cache behavior.
Problem 3: Disjoint Set Union — Feature Clustering
Input: N features (1 ≤ N ≤ 10^6), Q operations (Q ≤ 10^6)
Output: Results for FIND and SIZE queries
Approach: Union-Find with Path Compression and Union by Rank
class DSU:
    """Disjoint Set Union with path compression and union by rank.

    Almost O(1) amortized per operation (inverse Ackermann).
    Invariants: `size` is valid only at roots; `components` counts the
    current number of disjoint sets.
    """

    def __init__(self, n):
        self.parent = list(range(n))  # every node starts as its own root
        self.rank = [0] * n           # upper bound on tree height
        self.size = [1] * n           # component size, maintained at roots
        self.components = n

    def find(self, x):
        """Return the root of x, compressing the path along the way."""
        root = x
        while self.parent[root] != root:
            root = self.parent[root]
        # Second pass: point every node on the path directly at the root.
        while self.parent[x] != root:
            self.parent[x], x = root, self.parent[x]
        return root

    def union(self, x, y):
        """Merge the sets of x and y; return True if they were distinct."""
        rx = self.find(x)
        ry = self.find(y)
        if rx == ry:
            return False
        # Keep the higher-rank root on top so trees stay shallow.
        if self.rank[rx] < self.rank[ry]:
            rx, ry = ry, rx
        self.parent[ry] = rx
        self.size[rx] += self.size[ry]
        if self.rank[rx] == self.rank[ry]:
            self.rank[rx] += 1
        self.components -= 1
        return True

    def connected(self, x, y):
        """True when x and y are in the same set."""
        return self.find(x) == self.find(y)

    def get_size(self, x):
        """Size of the set containing x."""
        return self.size[self.find(x)]
# Example: Feature clustering based on correlation discovery
features = ["age", "income", "height", "weight", "education",
            "experience", "salary", "BMI"]
n = len(features)
dsu = DSU(n)
# Discover correlations
correlations = [
    (0, 1, "age-income"),        # age correlates with income
    (2, 3, "height-weight"),     # height correlates with weight
    (1, 4, "income-education"),  # income correlates with education
    (5, 6, "experience-salary"), # experience correlates with salary
    (3, 7, "weight-BMI"),        # weight correlates with BMI
    (1, 6, "income-salary"),     # income correlates with salary
]
for a, b, name in correlations:
    dsu.union(a, b)
    print(f"Merged {name}: {dsu.components} clusters remaining")
print(f"\nFinal clusters:")
# Group feature names by their representative root.
clusters = {}
for idx in range(n):
    clusters.setdefault(dsu.find(idx), []).append(features[idx])
for members in clusters.values():
    print(f" {members} (size: {len(members)})")
# Cluster 1: [age, income, education, experience, salary] (size 5)
# Cluster 2: [height, weight, BMI] (size 3)
Complexity: Nearly O(1) amortized per operation with both path compression and union by rank (inverse Ackermann function). O(N) space.
Problem 4: Sparse Table — Immutable Range Minimum for Batch Metrics
Input: N values (1 ≤ N ≤ 10^6), Q queries (Q ≤ 10^6)
Output: Minimum value for each query
Approach: Sparse Table
Sparse table preprocesses all ranges of length 2^k for each starting position. Any range [l, r] can be answered by overlapping two precomputed ranges of length 2^k where k = floor(log2(r - l + 1)). Since min is idempotent (min(a, a) = a), overlapping is safe.
from math import log2, floor
class SparseTable:
    """Sparse table for O(1) range minimum queries on static arrays.

    Build: O(N log N), Query: O(1), Space: O(N log N).
    Correct because min is idempotent: two overlapping power-of-two
    ranges that cover [l, r] yield the exact range minimum.
    """

    def __init__(self, arr):
        self.n = len(arr)
        # Number of power-of-two levels. bit_length() equals
        # floor(log2(n)) + 1 for n > 0 (and 0 for n == 0) using exact
        # integer arithmetic — no float log2 rounding concerns.
        self.LOG = self.n.bit_length()
        self.table = [[0] * self.n for _ in range(self.LOG)]
        if self.n:
            # Base case: ranges of length 1.
            self.table[0] = arr[:]
        # table[k][i] = min of arr[i .. i + 2^k - 1]
        for k in range(1, self.LOG):
            for i in range(self.n - (1 << k) + 1):
                self.table[k][i] = min(
                    self.table[k - 1][i],
                    self.table[k - 1][i + (1 << (k - 1))],
                )
        # Precompute floor(log2(length)) for every query length.
        self.log_table = [0] * (self.n + 1)
        for i in range(2, self.n + 1):
            self.log_table[i] = self.log_table[i // 2] + 1

    def query(self, l, r):
        """Return min(arr[l..r]) in O(1); requires 0 <= l <= r < n."""
        length = r - l + 1
        k = self.log_table[length]
        # Two (possibly overlapping) 2^k blocks cover [l, r] exactly.
        return min(self.table[k][l], self.table[k][r - (1 << k) + 1])
# Example: Static training metrics analysis
metrics = [0.95, 0.82, 0.71, 0.65, 0.58, 0.52, 0.48, 0.45,
           0.43, 0.42, 0.50, 0.47, 0.44, 0.41, 0.40, 0.39]
metric_table = SparseTable(metrics)
# O(1) range minimum queries; expected minima: 0.45, 0.39, 0.42
for lo, hi in ((0, 7), (8, 15), (4, 11)):
    print(f"Min loss epochs {lo}-{hi}: {metric_table.query(lo, hi)}")
print(f"Min loss overall: {metric_table.query(0, 15)}")  # 0.39
# Benchmark: 10^5 queries on 10^5 elements = instant
# (comment previously claimed 10^6; the code below uses 100K of each)
import time
big_arr = list(range(100000, 0, -1))
big_st = SparseTable(big_arr)
# perf_counter() is monotonic and high-resolution — the right clock for
# benchmarking (time.time() can jump with wall-clock adjustments).
start = time.perf_counter()
for i in range(100000):
    big_st.query(i % 50000, i % 50000 + 50000)
elapsed = time.perf_counter() - start
print(f"\n100K queries on 100K elements: {elapsed:.3f}s")
Complexity: O(N log N) build time and space. O(1) per query. Optimal for static arrays with many queries.
Problem 5: Balanced BST — Dynamic Median Tracking for Online Learning
Input: Q operations (Q ≤ 10^5)
Output: Median after each INSERT
Approach: Two Heaps (Simulating Balanced BST)
Maintain a max-heap for the lower half and a min-heap for the upper half. The median is always at the top of one of the heaps. This is the standard approach in competitive programming since Python does not have a built-in balanced BST.
import heapq
class MedianTracker:
    """Track running median using two heaps.

    Insert: O(log N), Median: O(1).
    Invariants: every element of `lo` <= every element of `hi`, and
    len(lo) equals len(hi) or len(hi) + 1.
    """

    def __init__(self):
        self.lo = []  # max-heap via negated values: lower half
        self.hi = []  # min-heap: upper half

    def insert(self, val):
        """Insert value, then restore both heap invariants."""
        heapq.heappush(self.lo, -val)
        # Order invariant: max(lo) must not exceed min(hi); swap tops if so.
        if self.hi and -self.lo[0] > self.hi[0]:
            top_lo = -heapq.heappop(self.lo)
            top_hi = heapq.heappop(self.hi)
            heapq.heappush(self.lo, -top_hi)
            heapq.heappush(self.hi, top_lo)
        # Size invariant: lo may hold at most one extra element.
        if len(self.lo) - len(self.hi) > 1:
            heapq.heappush(self.hi, -heapq.heappop(self.lo))
        elif len(self.hi) > len(self.lo):
            heapq.heappush(self.lo, -heapq.heappop(self.hi))

    def median(self):
        """Current median in O(1): top of lo, or the mean of both tops."""
        if len(self.lo) > len(self.hi):
            return -self.lo[0]
        return (-self.lo[0] + self.hi[0]) / 2
# Example: Online anomaly detection via median tracking
tracker = MedianTracker()
data_stream = [0.45, 0.82, 0.31, 0.95, 0.67, 0.12, 0.88, 0.54]
print("Online median tracking:")
for reading in data_stream:
    tracker.insert(reading)
    print(f" Inserted {reading:.2f} -> Median: {tracker.median():.3f}")
# Anomaly detection: flag if value > 2 * median
print("\nAnomaly detection:")
tracker2 = MedianTracker()
monitoring = [0.5, 0.48, 0.52, 0.49, 0.51, 2.5, 0.50, 0.47]
for sample in monitoring:
    tracker2.insert(sample)
    current = tracker2.median()
    suffix = " ** ANOMALY **" if sample > 2 * current else ""
    print(f" Value: {sample:.2f}, Median: {current:.3f}{suffix}")
Complexity: O(log N) insert, O(1) median query. O(N) space.
Key Takeaways
- Segment trees support O(log N) range queries and point updates — the most versatile data structure in competitive programming.
- Fenwick trees (BIT) are simpler and faster than segment trees for prefix sum operations, using half the memory.
- Disjoint Set Union with path compression and union by rank achieves nearly O(1) per operation — essential for connectivity and clustering problems.
- Sparse tables provide O(1) range minimum queries on static arrays after O(N log N) preprocessing.
- The two-heap median tracker gives O(log N) insert and O(1) median — the standard approach for online median tracking.
Lilly Tech Systems