
Serving ML Models

Load and serve scikit-learn, PyTorch, and TensorFlow models through FastAPI endpoints with proper lifecycle management and batch inference.

Model Loading with Lifespan

Use FastAPI's lifespan to load models at startup and release resources on shutdown:

Python
from contextlib import asynccontextmanager
from fastapi import FastAPI
import joblib

ml_models = {}

@asynccontextmanager
async def lifespan(app: FastAPI):
    # Load models at startup
    ml_models["classifier"] = joblib.load("models/classifier.pkl")
    ml_models["vectorizer"] = joblib.load("models/vectorizer.pkl")
    yield
    # Cleanup on shutdown
    ml_models.clear()

app = FastAPI(lifespan=lifespan)
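
The lifespan function is an async context manager. FastAPI runs the code before yield once, before the app starts serving, and the code after yield once, at shutdown. The stdlib-only sketch enters the context manager by hand to make that ordering visible. load_model is a hypothetical stand-in for joblib.load, and simulate_server only roughly approximates what the ASGI server does.

```python
import asyncio
from contextlib import asynccontextmanager

events: list[str] = []
ml_models: dict[str, str] = {}


def load_model(path: str) -> str:
    # Hypothetical stand-in for joblib.load(path)
    return f"model from {path}"


@asynccontextmanager
async def lifespan(app):
    # Startup: runs once, before any request is handled
    ml_models["classifier"] = load_model("models/classifier.pkl")
    events.append("startup")
    yield
    # Shutdown: runs once, when the server stops
    ml_models.clear()
    events.append("shutdown")


async def simulate_server() -> None:
    # Rough approximation of the server: enter the lifespan, handle requests, exit
    async with lifespan(app=None):
        events.append(f"request handled with {ml_models['classifier']}")


asyncio.run(simulate_server())
print(events)
```

Every request handled inside the context sees the loaded model, and the dictionary is empty again once the context exits.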

Serving scikit-learn

Python
from pydantic import BaseModel
import numpy as np

class SKLearnInput(BaseModel):
    features: list[float]

class SKLearnOutput(BaseModel):
    prediction: int
    probabilities: list[float]

# Plain def (not async def): FastAPI runs it in a worker thread, so the
# CPU-bound predict calls do not block the event loop.
@app.post("/sklearn/predict", response_model=SKLearnOutput)
def sklearn_predict(payload: SKLearnInput):
    X = np.array([payload.features])  # shape (1, n_features)
    clf = ml_models["classifier"]
    prediction = clf.predict(X)[0]
    probs = clf.predict_proba(X)[0]
    return SKLearnOutput(
        prediction=int(prediction),  # assumes integer class labels
        probabilities=probs.tolist()
    )
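
As written, SKLearnInput accepts a feature list of any length, so a request with the wrong number of features only fails inside predict, where the unhandled ValueError surfaces as a 500 error. A minimal drop-in replacement, assuming Pydantic v2 and a classifier trained on four features (N_FEATURES is a hypothetical constant), moves that check into the schema so FastAPI rejects bad requests with a 422 response instead:

```python
from pydantic import BaseModel, Field, ValidationError

# Hypothetical: the classifier was trained on 4 features
N_FEATURES = 4


class SKLearnInput(BaseModel):
    # Pydantic v2: min_length/max_length on a list constrain its number of items
    features: list[float] = Field(min_length=N_FEATURES, max_length=N_FEATURES)


# A well-formed payload validates normally
print(SKLearnInput(features=[5.1, 3.5, 1.4, 0.2]))

# A short payload is rejected before it ever reaches the model
try:
    SKLearnInput(features=[1.0, 2.0])
except ValidationError as exc:
    print(exc.errors()[0]["type"])
```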

Serving PyTorch

Python
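import torch
from contextlib import asynccontextmanager
from fastapi import FastAPI, File, UploadFile  # file uploads need python-multipart

# Alternative lifespan for a PyTorch model. A lifespan only takes effect when
# passed to the constructor: app = FastAPI(lifespan=lifespan)
@asynccontextmanager
async def lifespan(app: FastAPI):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # model.pt holds a fully pickled nn.Module, so weights_only=False is needed
    # on PyTorch >= 2.6 (where it defaults to True). Only load trusted files.
    ml_models["model"] = torch.load(
        "models/model.pt", map_location=device, weights_only=False
    )
    ml_models["model"].eval()  # inference behavior for dropout / batch norm
    ml_models["device"] = device
    yield
    ml_models.clear()

@app.post("/pytorch/predict")
async def pytorch_predict(file: UploadFile = File(...)):
    # preprocess_image is your own helper: bytes -> (1, C, H, W) float tensor
    image = preprocess_image(await file.read())
    with torch.no_grad():  # skip autograd bookkeeping during inference
        output = ml_models["model"](image.to(ml_models["device"]))
        probs = torch.nn.functional.softmax(output, dim=1)
    return {"predictions": probs.cpu().tolist()}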
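
The endpoint calls preprocess_image, which this section does not define. One possible implementation follows, assuming Pillow is installed and the model expects 224x224 RGB input normalized with the widely used ImageNet mean and standard deviation. These constants are assumptions; match them to however your model was trained.

```python
import io

import numpy as np
import torch
from PIL import Image

# Assumed input contract (hypothetical): 224x224 RGB, ImageNet normalization
IMAGE_SIZE = (224, 224)
MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)


def preprocess_image(data: bytes) -> torch.Tensor:
    """Decode image bytes into a normalized (1, 3, H, W) float32 tensor."""
    # Decode and force 3 channels (handles grayscale, RGBA and palette images)
    image = Image.open(io.BytesIO(data)).convert("RGB")
    image = image.resize(IMAGE_SIZE)
    # HWC uint8 -> HWC float32 in [0, 1]
    array = np.asarray(image, dtype=np.float32) / 255.0
    # Per-channel normalization broadcasts over the last (channel) axis
    array = (array - MEAN) / STD
    # HWC -> CHW, then prepend the batch dimension the model expects
    array = np.ascontiguousarray(array.transpose(2, 0, 1))[np.newaxis]
    return torch.from_numpy(array)
```

The function returns a CPU tensor; the endpoint's image.to(ml_models["device"]) call moves it to the GPU when one is available.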

Serving TensorFlow

Python
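import tensorflow as tf
from contextlib import asynccontextmanager
from fastapi import FastAPI

# Alternative lifespan for a TensorFlow model; pass it via FastAPI(lifespan=lifespan)
@asynccontextmanager
async def lifespan(app: FastAPI):
    # Assumes the SavedModel was exported with a callable __call__ tf.function;
    # otherwise call one of its .signatures (e.g. "serving_default") instead.
    ml_models["tf_model"] = tf.saved_model.load("models/saved_model")
    yield
    ml_models.clear()

@app.post("/tensorflow/predict")
def tf_predict(payload: SKLearnInput):  # reuses the flat feature-vector schema
    tensor = tf.constant([payload.features], dtype=tf.float32)
    result = ml_models["tf_model"](tensor)
    return {"prediction": result.numpy().tolist()}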
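
Calling the loaded object directly, as ml_models["tf_model"](tensor) does, only works if the SavedModel was exported with a __call__ method. A minimal sketch of such an export, using a hypothetical LinearScorer module that takes four float32 features, saved to a temporary directory rather than models/saved_model:

```python
import tempfile

import tensorflow as tf


class LinearScorer(tf.Module):
    """Hypothetical toy model: one linear layer over 4 features."""

    def __init__(self):
        super().__init__()
        self.w = tf.Variable(tf.ones([4, 1]))
        self.b = tf.Variable(tf.zeros([1]))

    # A fixed input_signature makes __call__ exportable:
    # any batch size (None), exactly 4 float32 features.
    @tf.function(input_signature=[tf.TensorSpec(shape=[None, 4], dtype=tf.float32)])
    def __call__(self, x):
        return tf.matmul(x, self.w) + self.b


export_dir = tempfile.mkdtemp()  # in the app this would be "models/saved_model"
tf.saved_model.save(LinearScorer(), export_dir)

# Same call pattern as the endpoint: load once, then call the object directly
loaded = tf.saved_model.load(export_dir)
result = loaded(tf.constant([[1.0, 2.0, 3.0, 4.0]]))
print(result.numpy().tolist())  # [[10.0]]
```

Pinning the input signature means the exported function accepts any number of rows but exactly four features, which matches the one-row float32 tensor the endpoint builds from each request.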

Batch Inference

Python
class BatchInput(BaseModel):
    items: list[SKLearnInput]

@app.post("/batch/predict")
def batch_predict(batch: BatchInput):
    # Stack every item into one (n_items, n_features) matrix: a single vectorized call
    X = np.array([item.features for item in batch.items])
    predictions = ml_models["classifier"].predict(X)
    return {"predictions": predictions.tolist()}
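
Stacking the whole batch into one matrix lets the model vectorize the work, but a very large request also means one very large call. A hypothetical helper, predict_in_chunks, keeps each call to the model at a bounded size by slicing the rows and concatenating the results; the default chunk size of 256 is arbitrary, not a recommendation.

```python
from typing import Callable

import numpy as np


def predict_in_chunks(
    predict: Callable[[np.ndarray], np.ndarray],
    X: np.ndarray,
    chunk_size: int = 256,
) -> np.ndarray:
    """Call predict on consecutive row slices of X and concatenate the outputs."""
    if chunk_size <= 0:
        raise ValueError("chunk_size must be positive")
    parts = [predict(X[start:start + chunk_size]) for start in range(0, len(X), chunk_size)]
    # An empty batch yields no parts; return an empty result instead of failing
    return np.concatenate(parts) if parts else np.empty((0,))
```

Inside batch_predict this would replace the direct call with predict_in_chunks(ml_models["classifier"].predict, X). Capping the request itself, for example items: list[SKLearnInput] = Field(max_length=1000) in Pydantic v2, is a complementary option.
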
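
Thread safety and concurrency: scikit-learn estimators generally do not modify their fitted state during predict, so one loaded model can be shared across requests. For PyTorch, model.eval() (called once at load time) switches layers such as dropout and batch normalization to inference behavior, and torch.no_grad() stops autograd from recording operations; neither one is a concurrency control. TensorFlow 2 executes eagerly and has no sessions, so load the SavedModel once in the lifespan and reuse that object across requests. Finally, model calls are CPU- or GPU-bound: inside an async def endpoint they block the event loop, while plain def endpoints run in FastAPI's threadpool.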
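
The PyTorch endpoint has to be async def to await file.read(), so its forward pass runs on the event loop. One way to move blocking work off the loop is asyncio.to_thread (Python 3.9+); FastAPI also provides run_in_threadpool in fastapi.concurrency for the same purpose. A stdlib-only sketch, with a hypothetical blocking_predict standing in for the model call:

```python
import asyncio
import time


def blocking_predict(features: list[float]) -> float:
    # Hypothetical stand-in for a CPU/GPU-bound call such as model(x) or predict(X)
    time.sleep(0.2)
    return sum(features)


async def predict_endpoint(features: list[float]) -> dict[str, float]:
    # Run the blocking call in a worker thread; the event loop stays free meanwhile
    prediction = await asyncio.to_thread(blocking_predict, features)
    return {"prediction": prediction}


async def main() -> list[dict[str, float]]:
    # Two concurrent "requests": their sleeps overlap instead of running back to back
    return await asyncio.gather(predict_endpoint([1.0, 2.0]), predict_endpoint([3.0, 4.0]))


start = time.perf_counter()
results = asyncio.run(main())
print(results, f"{time.perf_counter() - start:.2f}s")
```

When applying this to PyTorch, put the torch.no_grad() block inside the function you offload: grad mode is thread-local, so a no_grad() context entered on the event-loop thread does not apply in the worker thread.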

What's Next?

Next, we will learn how to stream LLM responses using Server-Sent Events and WebSockets for real-time inference.