Testing Model Training Functions
Testing the training loop and model behavior. Part of the Unit Testing for ML Pipelines course at AI School by Lilly Tech Systems.
What to Test in Model Training
Testing model training functions is about verifying that the training process works correctly, not that the model achieves a specific accuracy. You are testing the engineering, not the science. Does the training function accept the right inputs? Does it produce a valid model artifact? Does it handle edge cases like empty datasets or single-class data gracefully?
Testing Model Output Properties
After training, verify that the model produces outputs with the expected properties:
```python
import pytest
import numpy as np
from sklearn.ensemble import RandomForestClassifier


def train_classifier(X, y, n_estimators=100, random_state=42):
    """Train a RandomForest classifier with input validation."""
    if len(X) == 0:
        raise ValueError("Cannot train on empty dataset")
    if len(np.unique(y)) < 2:
        raise ValueError("Need at least 2 classes for classification")
    model = RandomForestClassifier(
        n_estimators=n_estimators,
        random_state=random_state,
        n_jobs=-1,
    )
    model.fit(X, y)
    return model


class TestTrainClassifier:
    def test_returns_fitted_model(self, sample_features_and_labels):
        X, y = sample_features_and_labels
        model = train_classifier(X, y)
        assert hasattr(model, 'predict')
        assert hasattr(model, 'predict_proba')

    def test_predictions_shape(self, sample_features_and_labels):
        X, y = sample_features_and_labels
        model = train_classifier(X, y)
        predictions = model.predict(X[:5])
        assert predictions.shape == (5,)

    def test_probabilities_sum_to_one(self, sample_features_and_labels):
        X, y = sample_features_and_labels
        model = train_classifier(X, y)
        probas = model.predict_proba(X[:10])
        for row in probas:
            assert abs(sum(row) - 1.0) < 1e-6

    def test_predictions_contain_valid_classes(self, sample_features_and_labels):
        X, y = sample_features_and_labels
        model = train_classifier(X, y)
        predictions = model.predict(X)
        valid_classes = set(np.unique(y))
        assert set(np.unique(predictions)).issubset(valid_classes)

    def test_empty_dataset_raises(self):
        with pytest.raises(ValueError, match="empty dataset"):
            train_classifier(np.array([]).reshape(0, 5), np.array([]))

    def test_single_class_raises(self):
        X = np.random.randn(100, 5)
        y = np.zeros(100)
        with pytest.raises(ValueError, match="at least 2 classes"):
            train_classifier(X, y)

    def test_reproducibility(self, sample_features_and_labels):
        X, y = sample_features_and_labels
        model1 = train_classifier(X, y, random_state=42)
        model2 = train_classifier(X, y, random_state=42)
        pred1 = model1.predict(X)
        pred2 = model2.predict(X)
        np.testing.assert_array_equal(pred1, pred2)
```
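The tests above assume a `sample_features_and_labels` fixture defined elsewhere (typically in `conftest.py`). One minimal sketch of such a fixture, using a deterministic synthetic dataset so test runs are reproducible (the shapes and the label rule here are illustrative assumptions, not part of the course's fixture):

```python
import numpy as np
import pytest


def make_sample_data(n_samples=100, n_features=5, seed=42):
    """Build a small, deterministic two-class dataset for tests."""
    rng = np.random.default_rng(seed)
    X = rng.normal(size=(n_samples, n_features))
    # Labels depend on the first two features, so the data is learnable.
    y = (X[:, 0] + X[:, 1] > 0).astype(int)
    return X, y


@pytest.fixture
def sample_features_and_labels():
    return make_sample_data()
```

Keeping the data-building logic in a plain function (`make_sample_data`) lets you reuse it outside pytest, while the fixture stays a thin wrapper.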
Testing Model Serialization
If your training function saves model artifacts, test the full save-load cycle:
```python
import joblib


def test_model_serialization_roundtrip(sample_features_and_labels, tmp_path):
    X, y = sample_features_and_labels
    model = train_classifier(X, y)

    model_path = tmp_path / "model.pkl"
    joblib.dump(model, model_path)
    loaded_model = joblib.load(model_path)

    original_pred = model.predict(X[:10])
    loaded_pred = loaded_model.predict(X[:10])
    np.testing.assert_array_equal(original_pred, loaded_pred)
```
Testing Training with Different Configurations
Use parametrized tests to verify training works across different hyperparameter configurations:

```python
@pytest.mark.parametrize("n_estimators", [10, 50, 100])
@pytest.mark.parametrize("max_depth", [3, 5, None])
def test_training_different_configs(sample_features_and_labels, n_estimators, max_depth):
    X, y = sample_features_and_labels
    model = RandomForestClassifier(
        n_estimators=n_estimators, max_depth=max_depth, random_state=42
    )
    model.fit(X, y)
    score = model.score(X, y)
    assert score > 0.5, f"Model with n_est={n_estimators}, depth={max_depth} failed"
```
Testing Training Error Handling
Verify that your training function handles errors gracefully. Test with invalid inputs, corrupted data, and resource constraints. A production training job that fails silently is worse than one that crashes loudly with a clear error message.
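As one sketch of this idea, the tests below check that obviously invalid inputs fail loudly rather than silently. They call `RandomForestClassifier` directly to stay self-contained; scikit-learn's `fit` raises `ValueError` both for NaN features and for mismatched `X`/`y` lengths, and a wrapper like `train_classifier` would propagate those errors:

```python
import numpy as np
import pytest
from sklearn.ensemble import RandomForestClassifier


def test_nan_features_raise():
    """Training data containing NaN should fail with a clear error."""
    X = np.random.randn(20, 5)
    X[3, 2] = np.nan
    y = np.array([0, 1] * 10)
    with pytest.raises(ValueError):
        RandomForestClassifier(n_estimators=5).fit(X, y)


def test_mismatched_lengths_raise():
    """X and y with different numbers of samples should be rejected."""
    X = np.random.randn(20, 5)
    y = np.zeros(10)
    with pytest.raises(ValueError):
        RandomForestClassifier(n_estimators=5).fit(X, y)
```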
Testing Training Callbacks and Logging
If your training function logs metrics or calls callbacks, verify that these are invoked correctly. Use mock objects to capture callback invocations and assert that they were called with the expected arguments at the expected times.
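A minimal sketch of this pattern, assuming a hypothetical `train_with_callback` function (not part of the course code) that accepts an `on_epoch_end` callback; `unittest.mock.Mock` records the call so the test can assert on its arguments:

```python
from unittest.mock import Mock

import numpy as np
from sklearn.ensemble import RandomForestClassifier


def train_with_callback(X, y, on_epoch_end=None):
    """Hypothetical trainer that reports metrics via a callback."""
    model = RandomForestClassifier(n_estimators=10, random_state=42)
    model.fit(X, y)
    if on_epoch_end is not None:
        on_epoch_end(metrics={"train_accuracy": model.score(X, y)})
    return model


def test_callback_invoked_with_metrics():
    rng = np.random.default_rng(0)
    X = rng.normal(size=(50, 4))
    y = (X[:, 0] > 0).astype(int)

    callback = Mock()
    train_with_callback(X, y, on_epoch_end=callback)

    # The mock captures how it was called, so we can assert on both
    # the call count and the keyword arguments it received.
    callback.assert_called_once()
    _, kwargs = callback.call_args
    assert "train_accuracy" in kwargs["metrics"]
    assert 0.0 <= kwargs["metrics"]["train_accuracy"] <= 1.0
```

The same approach works for loggers: pass a mock logger (or use pytest's `caplog` fixture) and assert the expected messages were emitted.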