Intermediate

Testing Model Training Functions

Testing the training loop and model behavior. Part of the Unit Testing for ML Pipelines course at AI School by Lilly Tech Systems.

What to Test in Model Training

Testing model training functions is about verifying that the training process works correctly, not that the model achieves a specific accuracy. You are testing the engineering, not the science. Does the training function accept the right inputs? Does it produce a valid model artifact? Does it handle edge cases like empty datasets or single-class data gracefully?

Testing Model Output Properties

After training, verify that the model produces outputs with the expected properties:

import pytest
import numpy as np
from sklearn.ensemble import RandomForestClassifier

def train_classifier(X, y, n_estimators=100, random_state=42):
    """Train a RandomForestClassifier, validating the inputs first."""
    if len(X) == 0:
        raise ValueError("Cannot train on empty dataset")
    if len(np.unique(y)) < 2:
        raise ValueError("Need at least 2 classes for classification")

    model = RandomForestClassifier(
        n_estimators=n_estimators,
        random_state=random_state,
        n_jobs=-1
    )
    model.fit(X, y)
    return model

class TestTrainClassifier:
    def test_returns_fitted_model(self, sample_features_and_labels):
        X, y = sample_features_and_labels
        model = train_classifier(X, y)
        assert hasattr(model, 'predict')
        assert hasattr(model, 'predict_proba')

    def test_predictions_shape(self, sample_features_and_labels):
        X, y = sample_features_and_labels
        model = train_classifier(X, y)
        predictions = model.predict(X[:5])
        assert predictions.shape == (5,)

    def test_probabilities_sum_to_one(self, sample_features_and_labels):
        X, y = sample_features_and_labels
        model = train_classifier(X, y)
        probas = model.predict_proba(X[:10])
        assert np.allclose(probas.sum(axis=1), 1.0)

    def test_predictions_contain_valid_classes(self, sample_features_and_labels):
        X, y = sample_features_and_labels
        model = train_classifier(X, y)
        predictions = model.predict(X)
        valid_classes = set(np.unique(y))
        assert set(np.unique(predictions)).issubset(valid_classes)

    def test_empty_dataset_raises(self):
        with pytest.raises(ValueError, match="empty dataset"):
            train_classifier(np.array([]).reshape(0, 5), np.array([]))

    def test_single_class_raises(self):
        X = np.random.randn(100, 5)
        y = np.zeros(100)
        with pytest.raises(ValueError, match="at least 2 classes"):
            train_classifier(X, y)

    def test_reproducibility(self, sample_features_and_labels):
        X, y = sample_features_and_labels
        model1 = train_classifier(X, y, random_state=42)
        model2 = train_classifier(X, y, random_state=42)
        pred1 = model1.predict(X)
        pred2 = model2.predict(X)
        np.testing.assert_array_equal(pred1, pred2)
💡 Best practice: always test reproducibility by verifying that training with the same random seed produces identical results. This catches issues with uncontrolled randomness in your training pipeline.

Testing Model Serialization

If your training function saves model artifacts, test the full save-load cycle:

import joblib

def test_model_serialization_roundtrip(sample_features_and_labels, tmp_path):
    X, y = sample_features_and_labels
    model = train_classifier(X, y)

    model_path = tmp_path / "model.pkl"
    joblib.dump(model, model_path)
    loaded_model = joblib.load(model_path)

    original_pred = model.predict(X[:10])
    loaded_pred = loaded_model.predict(X[:10])
    np.testing.assert_array_equal(original_pred, loaded_pred)

Testing Training with Different Configurations

Use parameterized tests to verify training works across different hyperparameter configurations:

@pytest.mark.parametrize("n_estimators", [10, 50, 100])
@pytest.mark.parametrize("max_depth", [3, 5, None])
def test_training_different_configs(sample_features_and_labels, n_estimators, max_depth):
    X, y = sample_features_and_labels
    model = RandomForestClassifier(n_estimators=n_estimators, max_depth=max_depth, random_state=42)
    model.fit(X, y)
    score = model.score(X, y)
    assert score > 0.5, f"Model with n_est={n_estimators}, depth={max_depth} failed"

Testing Training Error Handling

Verify that your training function handles errors gracefully. Test with invalid inputs, corrupted data, and resource constraints. A production training job that fails silently is worse than one that crashes loudly with a clear error message.
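For example, scikit-learn's own input validation rejects NaN values when fitting a RandomForestClassifier, so a corrupted-data test can simply assert that a ValueError surfaces rather than being swallowed (a sketch, trained directly on the estimator for self-containment):

```python
import numpy as np
import pytest
from sklearn.ensemble import RandomForestClassifier

def test_nan_features_raise():
    rng = np.random.default_rng(0)
    X = rng.normal(size=(40, 5))
    y = np.array([0, 1] * 20)
    X[0, 0] = np.nan  # simulate one corrupted feature value
    # sklearn's input validation rejects NaNs during fit with a ValueError
    with pytest.raises(ValueError):
        RandomForestClassifier(n_estimators=5).fit(X, y)
```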

Testing Training Callbacks and Logging

If your training function logs metrics or calls callbacks, verify that these are invoked correctly. Use mock objects to capture callback invocations and assert that they were called with the expected arguments at the expected times.
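As a sketch of this pattern, suppose a (hypothetical, not from the course) trainer `train_with_logging` calls `on_epoch_end(epoch, loss)` after each epoch. A `Mock` object can stand in for the callback and record every invocation:

```python
from unittest.mock import Mock

def train_with_logging(X, y, epochs, on_epoch_end):
    # Hypothetical trainer that reports a loss value after each epoch.
    for epoch in range(epochs):
        loss = 1.0 / (epoch + 1)  # stand-in for a real training step
        on_epoch_end(epoch, loss)

def test_callback_invoked_each_epoch():
    callback = Mock()
    train_with_logging(X=None, y=None, epochs=3, on_epoch_end=callback)
    assert callback.call_count == 3
    # First invocation: epoch 0 with loss 1.0
    callback.assert_any_call(0, 1.0)
```

The same approach works for loggers: pass a mock logger and assert on the messages it received.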

Important: Training tests should be fast. Do not train full models in unit tests. Use small datasets, few iterations, and simple configurations. Save full training tests for integration test suites that run less frequently.
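In practice that means a unit-level training test can look like this smoke test (dataset size and estimator count are arbitrary small values chosen for speed, not accuracy):

```python
import numpy as np
from sklearn.ensemble import RandomForestClassifier

def test_training_smoke_fast():
    # Tiny dataset, few trees: the point is that training completes
    # and produces usable predictions, in well under a second.
    rng = np.random.default_rng(0)
    X = rng.normal(size=(40, 4))
    y = np.array([0, 1] * 20)
    model = RandomForestClassifier(n_estimators=5, random_state=42).fit(X, y)
    assert model.predict(X[:3]).shape == (3,)
```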