Advanced

Real-time Inference with Kafka

Deploy ML models as Kafka consumers that produce predictions in real-time, with strategies for scaling, model updates, and monitoring.

Architecture Pattern

In a Kafka-based inference system, models run as consumers that read events, enrich them with features, run predictions, and produce results to an output topic.

Basic Inference Consumer

Python — ML Inference Consumer
from confluent_kafka import Consumer, Producer
import json
import joblib
import numpy as np
import redis
import time

# Load model
# NOTE(review): joblib.load unpickles arbitrary objects — the model file
# must come from a trusted source.
model = joblib.load("fraud_model_v2.pkl")
# Feature store: precomputed streaming features are read from here per user.
redis_client = redis.Redis(host='localhost', port=6379)

consumer = Consumer({
    'bootstrap.servers': 'localhost:9092',
    'group.id': 'fraud-inference',
    # Start at the newest offsets: for real-time scoring, fresh
    # transactions matter more than backfilling history.
    'auto.offset.reset': 'latest',
    # Broker evicts the consumer if poll() isn't called within 60s —
    # effectively a cap on per-message inference time.
    'max.poll.interval.ms': 60000
})
consumer.subscribe(['transactions'])

# Output side: predictions are published back to Kafka for downstream use.
producer = Producer({'bootstrap.servers': 'localhost:9092'})

def get_features(user_id, transaction):
    """Assemble the model input row for one transaction.

    Joins raw transaction fields with the sliding-window features cached
    in Redis under ``features:<user_id>``. A missing or empty cache entry
    falls back to zero-valued window features.
    """
    raw = redis_client.get(f"features:{user_id}")
    window = json.loads(raw) if raw else {}

    row = [
        transaction['amount'],
        transaction['merchant_category'],
        window.get('tx_count_1h', 0),
        window.get('avg_amount_1h', 0),
        window.get('unique_merchants_1h', 0),
    ]
    return np.array([row])

# --- Main inference loop ---
# Poll, featurize, score, and publish one transaction at a time.
try:
    while True:
        msg = consumer.poll(0.1)
        if msg is None:
            continue
        if msg.error():
            # Surface broker errors instead of silently dropping them:
            # they can signal partition loss or auth failures.
            print(f"Consumer error: {msg.error()}")
            continue

        start = time.time()
        try:
            tx = json.loads(msg.value())

            # Build features and predict
            features = get_features(tx['user_id'], tx)
            fraud_prob = model.predict_proba(features)[0][1]

            result = {
                'transaction_id': tx['transaction_id'],
                'user_id': tx['user_id'],
                'fraud_probability': float(fraud_prob),
                'is_fraud': fraud_prob > 0.5,
                'model_version': 'v2',
                'latency_ms': (time.time() - start) * 1000
            }
        except (KeyError, ValueError, TypeError) as e:
            # One malformed message must not kill the whole consumer.
            print(f"Skipping bad message: {e}")
            continue

        # Key by user_id so all predictions for a user stay on one
        # partition (preserves per-user ordering downstream).
        producer.produce(
            'fraud-predictions',
            key=str(tx['user_id']).encode(),
            value=json.dumps(result).encode(),
        )
        producer.poll(0)  # serve delivery callbacks without blocking
finally:
    consumer.close()   # commit final offsets and leave the group cleanly
    producer.flush()   # drain in-flight predictions before exit

Model Hot-Swapping

Python — Hot-swap Models Without Downtime
import threading
import os

class ModelServer:
    """Serve predictions from a joblib model that can be hot-swapped.

    A daemon thread listens on the ``model-deployments`` Kafka topic and
    reloads the model whenever a deployment event arrives, so inference
    continues without restarting the consumer process.
    """

    def __init__(self, model_path):
        # Initial load is synchronous so the server can predict before
        # the watcher thread starts.
        self.model = joblib.load(model_path)
        self.model_path = model_path
        self.lock = threading.Lock()
        self._start_watcher()

    def predict(self, features):
        """Score *features* with whichever model is currently active."""
        with self.lock:
            return self.model.predict_proba(features)

    def reload(self, new_path):
        """Load the model at *new_path* and swap it in atomically.

        Deserialization (potentially slow disk I/O) happens outside the
        lock so in-flight predictions are never blocked on it.
        """
        new_model = joblib.load(new_path)  # Load first, outside the lock
        with self.lock:
            self.model = new_model          # Swap atomically
            self.model_path = new_path
        print(f"Model swapped to {new_path}")

    def _start_watcher(self):
        """Watch for new model versions via Kafka topic (daemon thread)."""
        def watch():
            c = Consumer({
                'bootstrap.servers': 'localhost:9092',
                'group.id': 'model-updates',
                'auto.offset.reset': 'latest'
            })
            c.subscribe(['model-deployments'])
            while True:
                msg = c.poll(5.0)
                if msg is None or msg.error():
                    continue
                try:
                    update = json.loads(msg.value())
                    self.reload(update['model_path'])
                except Exception as e:
                    # A malformed event or unloadable model file must not
                    # kill the watcher thread (which would silently disable
                    # all future hot-swaps). Keep serving the old model.
                    print(f"Model reload failed, keeping current model: {e}")

        t = threading.Thread(target=watch, daemon=True)
        t.start()

server = ModelServer("fraud_model_v2.pkl")

Scaling Inference

  • Horizontal scaling: Add more consumer instances to the same consumer group. Kafka distributes partitions automatically.
  • Partition count: Set partitions equal to your maximum expected consumer count. Consumers beyond the partition count sit idle, since each partition is assigned to at most one consumer in a group.
  • Batch inference: Accumulate messages into mini-batches for GPU inference efficiency.
  • Async processing: Use async I/O for feature lookups to avoid blocking on Redis/DB calls.

Monitoring Predictions

Python — Prediction Monitoring Metrics
from prometheus_client import Counter, Histogram, start_http_server

# Metrics
# Monotonic prediction count, labeled by model version so rollouts can
# be compared side by side.
predictions_total = Counter('ml_predictions_total', 'Total predictions', ['model_version'])
# End-to-end latency of a single model call.
prediction_latency = Histogram('ml_prediction_latency_seconds', 'Prediction latency')
# Score distribution; coarse buckets are enough to spot drift (e.g. the
# model suddenly scoring everything near 0.5).
fraud_score = Histogram('ml_fraud_score', 'Fraud probability distribution', buckets=[0.1, 0.3, 0.5, 0.7, 0.9])

start_http_server(8000)  # Expose metrics endpoint (scraped at /metrics)

# In your prediction loop:
# NOTE: `model` and `features` come from the surrounding inference code.
with prediction_latency.time():
    prob = model.predict_proba(features)[0][1]

predictions_total.labels(model_version='v2').inc()
fraud_score.observe(prob)
Track training-serving skew: Log the feature values used at inference time to a Kafka topic. Compare the distribution of served features against training features to detect data drift early.