Advanced

Distributed Patterns

Distributed systems interviews increasingly ask you to implement the building blocks: consistent hashing for sharding, bloom filters for efficient membership testing, circuit breakers for fault tolerance, and retry with exponential backoff. These are the patterns that turn system design diagrams into working code.

Problem 1: Consistent Hashing

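Implement a consistent hash ring for distributing keys across nodes. When a node is added or removed, only about K/N keys need to be remapped on average (where K = total keys and N = total nodes). With naive `hash(key) % N` sharding, by contrast, changing N remaps almost every key.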

import bisect
import hashlib
import time
from typing import Any, Callable, Dict, List, Optional


class ConsistentHashRing:
    """Consistent hashing with virtual nodes.

    Maps keys to nodes using a hash ring. Each physical node is placed on
    the ring ``num_replicas`` times (virtual nodes) so that load spreads
    evenly across physical nodes.

    Key properties:
    - Adding/removing a node only moves ~K/N keys on average
      (K = total keys, N = physical nodes)
    - Virtual nodes prevent hotspots from uneven hashing
    - O(log V) lookup by binary search over the sorted ring, where
      V = N * num_replicas is the number of virtual nodes
    - Adding or removing a node rebuilds the sorted list in roughly O(V)

    Used by: DynamoDB, Cassandra, Memcached clients, load balancers
    """

    def __init__(self, num_replicas: int = 150):
        """
        Args:
            num_replicas: Virtual nodes placed on the ring per physical node.

        Raises:
            ValueError: if num_replicas < 1 (nodes would never own any keys).
        """
        if num_replicas < 1:
            raise ValueError("num_replicas must be >= 1")
        self._num_replicas = num_replicas
        self._ring: List[int] = []           # Sorted list of virtual-node hash positions
        self._ring_map: Dict[int, str] = {}  # hash position -> owning node name
        self._nodes: set = set()             # Physical nodes currently on the ring

    def _hash(self, key: str) -> int:
        """Return a stable 128-bit integer hash of ``key``.

        MD5 is used for its even spread and stability across processes,
        not for any security property.
        """
        return int(hashlib.md5(key.encode()).hexdigest(), 16)

    def _successor_index(self, hash_val: int) -> int:
        """Index of the first ring position strictly greater than ``hash_val``.

        Wraps to 0 past the end, which closes the ring. Requires a non-empty ring.
        """
        return bisect.bisect_right(self._ring, hash_val) % len(self._ring)

    def add_node(self, node: str):
        """Add a node with its virtual replicas to the ring (no-op if present).

        If a virtual node hashes to a position that is already taken (an MD5
        collision), that position is skipped. This keeps every ring position
        owned by exactly one node, so a later remove_node() cannot orphan it.
        """
        if node in self._nodes:
            return

        self._nodes.add(node)

        new_positions = []
        for i in range(self._num_replicas):
            hash_val = self._hash(f"{node}:replica:{i}")
            if hash_val in self._ring_map:
                continue
            self._ring_map[hash_val] = node
            new_positions.append(hash_val)

        # Timsort merges the existing sorted ring with the new sorted run in
        # roughly linear time, instead of num_replicas separate O(V) insorts.
        new_positions.sort()
        self._ring.extend(new_positions)
        self._ring.sort()

    def remove_node(self, node: str):
        """Remove a node and every ring position it owns (no-op if absent)."""
        if node not in self._nodes:
            return

        self._nodes.discard(node)

        # One pass over the ring instead of num_replicas list.remove() calls.
        # Filtering by owner never touches positions held by other nodes.
        kept = []
        for hash_val in self._ring:
            if self._ring_map[hash_val] == node:
                del self._ring_map[hash_val]
            else:
                kept.append(hash_val)
        self._ring = kept  # Still sorted: filtering preserves order

    def get_node(self, key: str) -> Optional[str]:
        """Find which node a key maps to.

        Walks clockwise from the key's hash position to the first virtual
        node strictly after it (wrapping around). Returns None if the ring
        is empty.
        """
        if not self._ring:
            return None
        idx = self._successor_index(self._hash(key))
        return self._ring_map[self._ring[idx]]

    def get_nodes_for_key(self, key: str, count: int = 3) -> List[str]:
        """Get up to ``count`` distinct nodes for replication, walking clockwise.

        The first entry is the same node get_node(key) returns. Fewer than
        ``count`` nodes are returned when fewer physical nodes exist.
        """
        if not self._ring:
            return []

        # Stop as soon as every available physical node has been found,
        # rather than scanning the whole ring when count > number of nodes.
        wanted = min(count, len(self._nodes))
        ring_len = len(self._ring)
        start = self._successor_index(self._hash(key))

        nodes: List[str] = []
        seen = set()
        for step in range(ring_len):
            if len(nodes) >= wanted:
                break
            node = self._ring_map[self._ring[(start + step) % ring_len]]
            if node not in seen:
                seen.add(node)
                nodes.append(node)

        return nodes

    def get_distribution(self, num_keys: int = 10000) -> Dict[str, int]:
        """Count how many of ``num_keys`` synthetic keys ("key_0", "key_1", ...)
        land on each node (for testing balance)."""
        distribution = dict.fromkeys(self._nodes, 0)
        for i in range(num_keys):
            node = self.get_node(f"key_{i}")
            if node is not None:  # "" is a legal node name; only None means "no node"
                distribution[node] += 1
        return distribution


# ---- Usage ----
ring = ConsistentHashRing(num_replicas=150)

# Add nodes
for node in ["server-1", "server-2", "server-3"]:
    ring.add_node(node)

# Map keys to nodes
for key in ["user:100", "user:200", "user:300", "order:50", "session:abc"]:
    node = ring.get_node(key)
    replicas = ring.get_nodes_for_key(key, count=2)
    print(f"{key} -> primary: {node}, replicas: {replicas}")

# Check distribution. sample_keys mirrors the "key_<i>" keys that
# get_distribution() samples, so the remap count below is comparable.
sample_keys = [f"key_{i}" for i in range(10000)]
before = {k: ring.get_node(k) for k in sample_keys}
dist = ring.get_distribution()
print("\nDistribution across 10,000 keys:")
for node, count in sorted(dist.items()):
    print(f"  {node}: {count} keys ({100 * count / len(sample_keys):.1f}%)")

# Remove a node: only the keys that lived on server-2 (~1/3) should move
print("\nRemoving server-2...")
ring.remove_node("server-2")
moved = sum(1 for k in sample_keys if ring.get_node(k) != before[k])
print(f"  keys remapped: {moved} ({100 * moved / len(sample_keys):.1f}%)")
dist = ring.get_distribution()
for node, count in sorted(dist.items()):
    print(f"  {node}: {count} keys ({100 * count / len(sample_keys):.1f}%)")

Problem 2: Bloom Filter

Implement a space-efficient probabilistic data structure for membership testing. A bloom filter can tell you "definitely not in set" or "probably in set" with a configurable false positive rate.

import math
import hashlib
from typing import List


class BloomFilter:
    """Bloom filter for probabilistic membership testing.

    Properties:
    - No false negatives: if contains() returns False, the item
      is definitely not in the set
    - Possible false positives: if contains() returns True, the item
      is PROBABLY in the set (with configurable error rate)
    - O(k) hash probes for both add and contains (k = number of hash functions)
    - Space-efficient: bits are packed 8 per byte in a bytearray, so m bits
      take about m/8 bytes

    Used by: Chrome (malicious URL check), Cassandra (SSTable lookup),
    Medium (article recommendation deduplication)
    """

    def __init__(self, expected_items: int, false_positive_rate: float = 0.01):
        """
        Args:
            expected_items: Expected number of items to be inserted (> 0)
            false_positive_rate: Desired false positive probability, in (0, 1)

        Raises:
            ValueError: if an argument is out of range. Otherwise the sizing
                formulas would divide by zero, take log(0), or produce a
                negative size.
        """
        if expected_items <= 0:
            raise ValueError("expected_items must be positive")
        if not 0.0 < false_positive_rate < 1.0:
            raise ValueError("false_positive_rate must be between 0 and 1 (exclusive)")

        # Calculate optimal bit array size: m = -(n * ln(p)) / (ln(2)^2)
        self._size = self._optimal_size(expected_items, false_positive_rate)
        # Calculate optimal number of hash functions: k = (m/n) * ln(2)
        self._num_hashes = self._optimal_hashes(self._size, expected_items)

        # Packed bits: bit i lives in byte i >> 3 under mask 1 << (i & 7).
        # A list of bools would store one object pointer per bit instead.
        self._bit_array = bytearray((self._size + 7) // 8)
        self._count = 0

    @staticmethod
    def _optimal_size(n: int, p: float) -> int:
        """Calculate optimal bit array size."""
        m = -(n * math.log(p)) / (math.log(2) ** 2)
        return int(m) + 1

    @staticmethod
    def _optimal_hashes(m: int, n: int) -> int:
        """Calculate optimal number of hash functions."""
        k = (m / n) * math.log(2)
        return max(1, int(k) + 1)

    def _get_hash_positions(self, item: str) -> List[int]:
        """Generate k bit positions using double hashing.

        h_i(x) = (h1(x) + i * h2(x)) mod m, with h1 from MD5 and h2 from
        SHA-256. Reducing h2 mod m first leaves every h_i unchanged. If the
        reduced h2 is 0, it is bumped to 1; otherwise all k probes would
        land on the same bit.
        """
        data = item.encode()
        h1 = int(hashlib.md5(data).hexdigest(), 16)
        h2 = int(hashlib.sha256(data).hexdigest(), 16) % self._size
        if h2 == 0:
            h2 = 1
        return [(h1 + i * h2) % self._size for i in range(self._num_hashes)]

    def add(self, item: str):
        """Add an item to the bloom filter."""
        bits = self._bit_array
        for pos in self._get_hash_positions(item):
            bits[pos >> 3] |= 1 << (pos & 7)
        self._count += 1

    def contains(self, item: str) -> bool:
        """Check if an item might be in the set.

        Returns False -> item is DEFINITELY NOT in the set
        Returns True  -> item is PROBABLY in the set
        """
        bits = self._bit_array
        return all(bits[pos >> 3] & (1 << (pos & 7))
                   for pos in self._get_hash_positions(item))

    def __contains__(self, item: str) -> bool:
        """Support ``item in bloom_filter`` (same semantics as contains())."""
        return self.contains(item)

    def estimated_false_positive_rate(self) -> float:
        """Calculate current estimated false positive rate."""
        # p = (1 - e^(-kn/m))^k
        if self._count == 0:
            return 0.0
        exponent = -self._num_hashes * self._count / self._size
        return (1 - math.exp(exponent)) ** self._num_hashes

    @property
    def size_bytes(self) -> int:
        """Bytes used by the packed bit array (excluding Python object overhead)."""
        return len(self._bit_array)

    def __len__(self) -> int:
        """Number of add() calls made (duplicates are counted each time)."""
        return self._count

    def __repr__(self) -> str:
        bits_set = sum(bin(byte).count("1") for byte in self._bit_array)
        return (f"BloomFilter(size={self._size} bits, "
                f"hashes={self._num_hashes}, "
                f"items={self._count}, "
                f"bits_set={bits_set}/{self._size}, "
                f"est_fpr={self.estimated_false_positive_rate():.4f})")


# ---- Usage ----
NUM_ITEMS = 10000
bf = BloomFilter(expected_items=NUM_ITEMS, false_positive_rate=0.01)

# Insert user:0 .. user:9999
for user_id in range(NUM_ITEMS):
    bf.add(f"user:{user_id}")

# Membership: an inserted key is always found; an unseen key almost never is
for probe in ("user:42", "user:99999"):
    print(f"{probe} in filter: {bf.contains(probe)}")

# None of user:10000 .. user:19999 were inserted, so every hit is a false positive
unseen = [f"user:{user_id}" for user_id in range(NUM_ITEMS, 2 * NUM_ITEMS)]
false_positives = sum(bf.contains(key) for key in unseen)
print(f"False positive rate: {false_positives / len(unseen):.4f}")

# Rough baseline: about 20 bytes per stored string key
naive_bytes = NUM_ITEMS * 20
print(f"Memory: {bf.size_bytes:,} bytes vs {naive_bytes:,} bytes (storing strings)")
print(bf)

Problem 3: Circuit Breaker

Implement the circuit breaker pattern to prevent cascading failures. When a service fails repeatedly, the circuit "opens" and immediately rejects requests instead of waiting for timeouts.

from enum import Enum


class CircuitState(Enum):
    """The three states of a circuit breaker."""
    CLOSED = "closed"        # Normal: requests pass through
    OPEN = "open"            # Tripped: requests fail immediately
    HALF_OPEN = "half_open"  # Probing: trial requests test whether the service recovered


class CircuitOpenError(Exception):
    """Returned (not raised) by CircuitBreaker.execute when the circuit is OPEN."""


class CircuitBreaker:
    """Circuit breaker for fault tolerance.

    States:
    CLOSED    -> Normal operation. Consecutive failures are counted.
    OPEN      -> Service is assumed down. Requests are rejected immediately.
    HALF_OPEN -> Requests are let through as trial calls. After
                 success_threshold successes the circuit closes; any
                 failure reopens it.

    Transitions:
    CLOSED -> OPEN:      when consecutive failures >= failure_threshold
    OPEN -> HALF_OPEN:   once recovery_timeout seconds have passed since the
                         last recorded failure (checked lazily when `state`
                         is read)
    HALF_OPEN -> CLOSED: after success_threshold successful trial calls
    HALF_OPEN -> OPEN:   on any failed trial call

    Note: there is no locking, and HALF_OPEN does not cap how many trial
    calls run at once. Add synchronization before sharing across threads.
    """

    def __init__(self, failure_threshold: int = 5,
                 recovery_timeout: float = 30.0,
                 success_threshold: int = 1,
                 clock: Optional[Callable[[], float]] = None):
        """
        Args:
            failure_threshold: Consecutive failures before opening circuit
            recovery_timeout: Seconds to stay OPEN before trying half-open
            success_threshold: Successes in half-open before closing
            clock: Zero-argument callable returning seconds as a float.
                Defaults to time.monotonic, so wall-clock adjustments cannot
                shorten or stretch the recovery timeout. Inject a fake clock
                in tests.
        """
        self._failure_threshold = failure_threshold
        self._recovery_timeout = recovery_timeout
        self._success_threshold = success_threshold
        self._clock = clock if clock is not None else time.monotonic

        self._state = CircuitState.CLOSED
        self._failure_count = 0       # Consecutive failures while CLOSED
        self._success_count = 0       # Successful trial calls while HALF_OPEN
        self._last_failure_time = 0.0
        self._total_requests = 0      # Every allow_request() call, including rejections
        self._total_failures = 0
        self._total_rejected = 0      # Requests refused while OPEN

    def _seconds_until_retry(self) -> float:
        """Seconds left before an OPEN circuit may move to HALF_OPEN (never negative)."""
        elapsed = self._clock() - self._last_failure_time
        return max(0.0, self._recovery_timeout - elapsed)

    @property
    def state(self) -> CircuitState:
        """Current state (may transition from OPEN to HALF_OPEN)."""
        if (self._state is CircuitState.OPEN and
                self._clock() - self._last_failure_time >= self._recovery_timeout):
            self._state = CircuitState.HALF_OPEN
            self._success_count = 0
        return self._state

    def allow_request(self) -> bool:
        """Count a request and report whether it may proceed.

        Returns False only while the circuit is OPEN.
        """
        current_state = self.state
        self._total_requests += 1

        if current_state is CircuitState.OPEN:
            self._total_rejected += 1
            return False
        return True  # CLOSED, or a HALF_OPEN trial request

    def record_success(self):
        """Record a successful request."""
        if self._state is CircuitState.HALF_OPEN:
            self._success_count += 1
            if self._success_count >= self._success_threshold:
                self._state = CircuitState.CLOSED
                self._failure_count = 0
        elif self._state is CircuitState.CLOSED:
            self._failure_count = 0  # Only consecutive failures trip the breaker

    def record_failure(self):
        """Record a failed request; restarts the recovery timer."""
        self._total_failures += 1
        self._last_failure_time = self._clock()

        if self._state is CircuitState.HALF_OPEN:
            # Trial request failed -> back to OPEN
            self._state = CircuitState.OPEN
        elif self._state is CircuitState.CLOSED:
            self._failure_count += 1
            if self._failure_count >= self._failure_threshold:
                self._state = CircuitState.OPEN

    def execute(self, func, *args, **kwargs):
        """Execute a function with circuit breaker protection.

        Any Exception raised by func counts as a failure and is returned,
        not re-raised. When the circuit is OPEN, func is not called and a
        CircuitOpenError is returned, carrying the remaining wait time.

        Returns (success: bool, result_or_error: Any)
        """
        if not self.allow_request():
            return False, CircuitOpenError(
                f"Circuit is OPEN. Retry after "
                f"{self._seconds_until_retry():.1f}s"
            )

        try:
            result = func(*args, **kwargs)
        except Exception as e:
            self.record_failure()
            return False, e
        self.record_success()
        return True, result

    def get_stats(self) -> Dict[str, Any]:
        """Snapshot of counters; reading `state` here may move OPEN -> HALF_OPEN."""
        return {
            "state": self.state.value,
            "failure_count": self._failure_count,
            "total_requests": self._total_requests,
            "total_failures": self._total_failures,
            "total_rejected": self._total_rejected,
            "error_rate": self._total_failures / max(1, self._total_requests),
        }


# ---- Usage ----
cb = CircuitBreaker(failure_threshold=3, recovery_timeout=5.0)

call_count = 0


def unreliable_service():
    """Simulated service that fails on its first 5 calls, then succeeds."""
    global call_count
    call_count += 1
    if call_count <= 5:
        raise ConnectionError("Service unavailable")
    return "Success!"


# Try calling the service. After 3 failures the circuit opens, so the
# remaining calls are rejected immediately with CircuitOpenError: this
# tight loop never waits out the 5s recovery timeout.
for i in range(8):
    success, result = cb.execute(unreliable_service)
    state = cb.state.value
    if success:
        print(f"Call {i+1}: SUCCESS -> {result} [{state}]")
    else:
        print(f"Call {i+1}: FAILED -> {type(result).__name__} [{state}]")

print(f"\nStats: {cb.get_stats()}")

Problem 4: Retry with Exponential Backoff

Implement a retry mechanism with exponential backoff, jitter, and configurable retry policies. This is essential for building resilient clients that interact with unreliable services.

import random
from dataclasses import dataclass
from typing import Callable, Optional, Tuple, Set, Type
from enum import Enum


class BackoffStrategy(Enum):
    """How the base (pre-jitter) delay grows with the attempt number."""
    EXPONENTIAL = "exponential"   # base * 2^attempt
    LINEAR = "linear"             # base * (attempt + 1)
    CONSTANT = "constant"         # Always base


class JitterType(Enum):
    """How randomness is applied to the computed backoff."""
    NONE = "none"                  # backoff
    FULL = "full"                  # random(0, backoff)
    EQUAL = "equal"                # backoff/2 + random(0, backoff/2)
    DECORRELATED = "decorrelated"  # min(max_delay, random(base, prev_delay * 3))


@dataclass
class RetryConfig:
    """Retry policy settings.

    Attributes:
        max_retries: Retries after the first attempt (at most max_retries + 1 calls)
        base_delay: Seconds; the unit every backoff strategy scales from
        max_delay: Seconds; upper bound on any computed delay
        strategy: How the backoff grows per attempt
        jitter: How randomness is applied to the backoff
        retryable_exceptions: Exception types that trigger a retry;
            None means any Exception

    Raises:
        ValueError: on a negative max_retries, base_delay or max_delay.
    """
    max_retries: int = 3
    base_delay: float = 1.0
    max_delay: float = 60.0
    strategy: BackoffStrategy = BackoffStrategy.EXPONENTIAL
    jitter: JitterType = JitterType.FULL
    retryable_exceptions: Optional[Set[Type[Exception]]] = None

    def __post_init__(self):
        # Fail fast: a negative max_retries would make execute() return
        # without ever calling the function.
        if self.max_retries < 0:
            raise ValueError("max_retries must be >= 0")
        if self.base_delay < 0 or self.max_delay < 0:
            raise ValueError("base_delay and max_delay must be >= 0")
        if self.retryable_exceptions is None:
            self.retryable_exceptions = {Exception}


class RetryWithBackoff:
    """Retry mechanism with configurable backoff and jitter.

    Jitter is critical to avoid the "thundering herd" problem where
    all clients retry at the same time after a service recovers.

    AWS recommendation: Use "Full Jitter" for best performance.
    """

    def __init__(self, config: Optional[RetryConfig] = None,
                 sleep: Optional[Callable[[float], None]] = None):
        """
        Args:
            config: Retry policy; defaults to RetryConfig().
            sleep: Called with each computed delay (in seconds) before the
                next attempt, e.g. time.sleep, or a recording fake in tests.
                The default None only records delays in the history and
                does not wait.
        """
        self._config = config if config is not None else RetryConfig()
        self._sleep = sleep
        self._attempt_history: List[Dict] = []

    def _calculate_delay(self, attempt: int, prev_delay: float) -> float:
        """Calculate delay with backoff strategy and jitter.

        Args:
            attempt: Zero-based index of the attempt that just failed.
            prev_delay: Previous delay; used only by DECORRELATED jitter.
        """
        config = self._config

        # Base backoff
        if config.strategy == BackoffStrategy.EXPONENTIAL:
            try:
                backoff = config.base_delay * (2 ** attempt)
            except OverflowError:
                # 2**attempt is too large to convert to float; it is past any cap anyway
                backoff = config.max_delay
        elif config.strategy == BackoffStrategy.LINEAR:
            backoff = config.base_delay * (attempt + 1)
        else:  # CONSTANT
            backoff = config.base_delay

        # Cap at max delay
        backoff = min(backoff, config.max_delay)

        # Apply jitter
        if config.jitter == JitterType.FULL:
            return random.uniform(0, backoff)
        if config.jitter == JitterType.EQUAL:
            return backoff / 2 + random.uniform(0, backoff / 2)
        if config.jitter == JitterType.DECORRELATED:
            # Grows from the previous delay rather than the attempt number
            return min(random.uniform(config.base_delay, prev_delay * 3),
                       config.max_delay)
        return backoff  # JitterType.NONE

    def _is_retryable(self, exception: Exception) -> bool:
        """Check if an exception should trigger a retry."""
        return isinstance(exception, tuple(self._config.retryable_exceptions))

    def _record(self, attempt: int, status: str, **details) -> None:
        """Append one history entry: attempt, status, any details, then a wall-clock timestamp."""
        self._attempt_history.append(
            {"attempt": attempt, "status": status, **details, "timestamp": time.time()}
        )

    def execute(self, func: Callable, *args, **kwargs) -> Tuple[bool, Any]:
        """Execute a function with retry logic.

        Calls func up to max_retries + 1 times. A non-retryable exception
        stops immediately. The history (see get_history) records every
        attempt with a status of "success", "retry", "non_retryable_error"
        or "exhausted".

        Returns (success: bool, result_or_last_error: Any)
        """
        self._attempt_history = []
        config = self._config
        prev_delay = config.base_delay
        last_error = None

        for attempt in range(config.max_retries + 1):
            try:
                result = func(*args, **kwargs)
            except Exception as e:
                last_error = e
            else:
                self._record(attempt, "success")
                return True, result

            if not self._is_retryable(last_error):
                self._record(attempt, "non_retryable_error", error=str(last_error))
                return False, last_error

            if attempt == config.max_retries:
                self._record(attempt, "exhausted", error=str(last_error))
                break

            delay = self._calculate_delay(attempt, prev_delay)
            prev_delay = delay
            self._record(attempt, "retry", error=str(last_error), delay=delay)
            if self._sleep is not None:
                self._sleep(delay)

        return False, last_error

    def get_history(self) -> List[Dict]:
        """Return a copy of the attempt log from the most recent execute() call."""
        return list(self._attempt_history)


# ---- Usage ----
config = RetryConfig(
    max_retries=5,
    base_delay=1.0,
    max_delay=30.0,
    strategy=BackoffStrategy.EXPONENTIAL,
    jitter=JitterType.FULL,
    retryable_exceptions={ConnectionError, TimeoutError},
)

retrier = RetryWithBackoff(config)

attempt_counter = 0


def flaky_api_call():
    """Fails first 3 times, succeeds on 4th."""
    global attempt_counter
    attempt_counter += 1
    if attempt_counter <= 3:
        raise ConnectionError(f"Connection refused (attempt {attempt_counter})")
    return {"status": "ok", "data": [1, 2, 3]}


success, result = retrier.execute(flaky_api_call)
print(f"Success: {success}")
print(f"Result: {result}")
print("\nRetry history:")
for entry in retrier.get_history():
    # Only "retry" entries carry a delay; every non-success entry carries an error
    delay = f", delay={entry['delay']:.2f}s" if "delay" in entry else ""
    error = f", error={entry['error']}" if "error" in entry else ""
    print(f"  Attempt {entry['attempt']}: {entry['status']}{error}{delay}")
💡
Interview tip: When asked about retry logic, always mention jitter. Without jitter, all clients retry at the same time after a failure, causing a "thundering herd" that can bring the service down again. AWS's recommended approach is "Full Jitter": delay = random(0, min(cap, base * 2^attempt)), where cap is the maximum allowed delay. This one detail shows deep operational experience.

Key Takeaways

💡
  • Consistent hashing uses virtual nodes for even distribution and O(log V) lookup via binary search over the V virtual nodes on the ring
  • Bloom filters trade a configurable false positive rate for massive space savings (no false negatives)
  • Circuit breakers have three states: CLOSED (normal), OPEN (reject), HALF_OPEN (test recovery)
  • Exponential backoff with jitter prevents thundering herd after service recovery
  • Always mention the trade-offs: consistent hashing vs. modular hashing, bloom filter vs. hash set
  • These patterns appear in both coding interviews and system design discussions