Intermediate
Serving ML Models
Load and serve scikit-learn, PyTorch, and TensorFlow models through FastAPI endpoints with proper lifecycle management and batch inference.
Model Loading with Lifespan
Use FastAPI's lifespan to load models at startup and release resources on shutdown:
Python
from contextlib import asynccontextmanager

from fastapi import FastAPI
import joblib

ml_models = {}

@asynccontextmanager
async def lifespan(app: FastAPI):
    # Load models at startup
    ml_models["classifier"] = joblib.load("models/classifier.pkl")
    ml_models["vectorizer"] = joblib.load("models/vectorizer.pkl")
    yield
    # Cleanup on shutdown
    ml_models.clear()

app = FastAPI(lifespan=lifespan)
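A request that arrives before startup has finished, or after a load has failed, would hit a bare KeyError from the ml_models dict. One defensive sketch (the helper name and error type below are hypothetical, not part of FastAPI):

```python
ml_models: dict[str, object] = {}  # the registry from the lifespan example, redeclared so this sketch is self-contained

class ModelNotLoaded(RuntimeError):
    """Raised when an endpoint asks for a model that is not in the registry."""

def get_model(name: str):
    # Look up a model loaded during lifespan; fail loudly if it is absent
    try:
        return ml_models[name]
    except KeyError:
        raise ModelNotLoaded(f"model {name!r} is not loaded") from None
```

An endpoint could catch ModelNotLoaded and return a 503 so clients know to retry once startup completes.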
Serving scikit-learn
Python
from pydantic import BaseModel
import numpy as np

class SKLearnInput(BaseModel):
    features: list[float]

class SKLearnOutput(BaseModel):
    prediction: int
    probabilities: list[float]

@app.post("/sklearn/predict", response_model=SKLearnOutput)
async def sklearn_predict(input: SKLearnInput):
    X = np.array([input.features])
    prediction = ml_models["classifier"].predict(X)[0]
    probs = ml_models["classifier"].predict_proba(X)[0]
    return SKLearnOutput(
        prediction=int(prediction),
        probabilities=probs.tolist(),
    )
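Note the wrapping in np.array([input.features]): scikit-learn estimators expect a 2-D array of shape (n_samples, n_features), so a single feature vector must become a one-row batch. A quick illustration (the feature values are arbitrary):

```python
import numpy as np

features = [5.1, 3.5, 1.4, 0.2]  # one example, as received in SKLearnInput

X = np.array([features])   # wrap in a list -> 2-D batch of one row
print(X.shape)             # (1, 4): the shape predict() expects

# Passing the bare vector would give a 1-D array, which estimators reject
x_bad = np.array(features)
print(x_bad.shape)         # (4,)
```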
Serving PyTorch
Python
import torch
from fastapi import UploadFile, File

@asynccontextmanager
async def lifespan(app):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    ml_models["model"] = torch.load("models/model.pt", map_location=device)
    ml_models["model"].eval()
    ml_models["device"] = device
    yield

@app.post("/pytorch/predict")
async def pytorch_predict(file: UploadFile = File(...)):
    image = preprocess_image(await file.read())
    with torch.no_grad():
        output = ml_models["model"](image.to(ml_models["device"]))
    probs = torch.nn.functional.softmax(output, dim=1)
    return {"predictions": probs.cpu().tolist()}
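The endpoint above calls a preprocess_image helper that the snippet does not define. A plausible sketch for an ImageNet-style classifier follows; the 224x224 size and the normalization constants are assumptions about the model, not givens:

```python
import io

import numpy as np
import torch
from PIL import Image

# ImageNet channel statistics -- adjust to whatever your model was trained with
IMAGENET_MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
IMAGENET_STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)

def preprocess_image(data: bytes) -> torch.Tensor:
    # Decode the uploaded bytes and force three channels
    image = Image.open(io.BytesIO(data)).convert("RGB")
    image = image.resize((224, 224))
    # HWC uint8 -> float32 in [0, 1], then normalize per channel
    array = np.asarray(image, dtype=np.float32) / 255.0
    array = (array - IMAGENET_MEAN) / IMAGENET_STD
    # HWC -> CHW, plus the batch dimension the model expects
    return torch.from_numpy(array).permute(2, 0, 1).unsqueeze(0)
```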
Serving TensorFlow
Python
import tensorflow as tf

@asynccontextmanager
async def lifespan(app):
    ml_models["tf_model"] = tf.saved_model.load("models/saved_model")
    yield

@app.post("/tensorflow/predict")
async def tf_predict(input: SKLearnInput):
    tensor = tf.constant([input.features])
    result = ml_models["tf_model"](tensor)
    return {"prediction": result.numpy().tolist()}
Batch Inference
Python
class BatchInput(BaseModel):
    items: list[SKLearnInput]

@app.post("/batch/predict")
async def batch_predict(batch: BatchInput):
    X = np.array([item.features for item in batch.items])
    predictions = ml_models["classifier"].predict(X)
    return {"predictions": predictions.tolist()}
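The endpoint above only batches what one client sends in a single request. Server-side dynamic batching, which holds individual requests for a few milliseconds and merges them into one predict call, is a common latency/throughput trade-off. A minimal asyncio sketch; the class and parameter names are hypothetical:

```python
import asyncio

class DynamicBatcher:
    """Collects single requests for a short window, then runs one batched predict."""

    def __init__(self, predict_fn, max_batch: int = 32, window_ms: float = 10.0):
        self.predict_fn = predict_fn      # takes a list of inputs, returns a list of outputs
        self.max_batch = max_batch
        self.window = window_ms / 1000.0
        self.queue: asyncio.Queue = asyncio.Queue()
        self._task = None

    def start(self):
        # Must be called with a running event loop (e.g. inside lifespan)
        self._task = asyncio.create_task(self._loop())

    async def predict(self, features):
        # Enqueue one request and wait for its slice of the batched result
        fut = asyncio.get_running_loop().create_future()
        await self.queue.put((features, fut))
        return await fut

    async def _loop(self):
        while True:
            # Block for the first request, then top up until the window closes
            batch = [await self.queue.get()]
            deadline = asyncio.get_running_loop().time() + self.window
            while len(batch) < self.max_batch:
                timeout = deadline - asyncio.get_running_loop().time()
                if timeout <= 0:
                    break
                try:
                    batch.append(await asyncio.wait_for(self.queue.get(), timeout))
                except asyncio.TimeoutError:
                    break
            inputs = [features for features, _ in batch]
            results = self.predict_fn(inputs)  # one batched model call
            for (_, fut), result in zip(batch, results):
                fut.set_result(result)
```

In an app, the batcher would be created and started inside lifespan, each handler would simply await batcher.predict(input.features), and a blocking predict_fn would typically be pushed off the event loop (for example with asyncio.to_thread).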
Thread safety: scikit-learn models are generally thread-safe for prediction. For PyTorch, call model.eval() once at load time and wrap inference in torch.no_grad(); concurrent forward passes through a model whose weights are never mutated are safe. For TensorFlow 2, a SavedModel loaded once at startup can serve concurrent calls; the advice to share a single session applies only to the legacy TensorFlow 1.x API.
What's Next?
Next, we will learn how to stream LLM responses using Server-Sent Events and WebSockets for real-time inference.
Lilly Tech Systems