Intermediate
Model Evaluation
Properly evaluate models using train/test splits, cross-validation, hyperparameter tuning, and model comparison.
Train/Test Split with rsample
R
library(tidymodels)
set.seed(42)

# Simple random split (75/25)
split <- initial_split(mtcars, prop = 0.75)
train <- training(split)
test <- testing(split)

# Stratified split (maintain class proportions)
data <- iris |> mutate(Species = factor(Species))
split <- initial_split(data, prop = 0.75, strata = Species)
train <- training(split)
test <- testing(split)
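As a quick sanity check (a sketch added here, not part of the original example), you can confirm the split sizes and that stratification kept the class proportions similar in both sets:

R
# Check split sizes
dim(train)  # ~75% of rows
dim(test)   # ~25% of rows

# Class proportions should be roughly equal across the two sets
train |> count(Species) |> mutate(prop = n / sum(n))
test |> count(Species) |> mutate(prop = n / sum(n))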
Cross-Validation
R
# 10-fold cross-validation
folds <- vfold_cv(train, v = 10)

# Stratified CV
folds <- vfold_cv(train, v = 10, strata = Species)

# Fit model on each fold
rf_spec <- rand_forest(trees = 500) |>
  set_engine("ranger") |>
  set_mode("classification")

wf <- workflow() |>
  add_formula(Species ~ .) |>
  add_model(rf_spec)

cv_results <- wf |>
  fit_resamples(
    resamples = folds,
    metrics = metric_set(accuracy, roc_auc)
  )

# View CV results
collect_metrics(cv_results)
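collect_metrics() averages performance across folds. To see per-fold numbers, or to keep the held-out predictions from each fold, one approach (a minimal sketch using tune's control options) is:

R
# Per-fold metrics instead of the averaged summary
collect_metrics(cv_results, summarize = FALSE)

# Re-run with save_pred = TRUE to keep the out-of-fold predictions
cv_results <- wf |>
  fit_resamples(
    resamples = folds,
    metrics = metric_set(accuracy, roc_auc),
    control = control_resamples(save_pred = TRUE)
  )
collect_predictions(cv_results)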
Hyperparameter Tuning
R
# Define a model with tunable parameters
rf_spec <- rand_forest(
  trees = 500,
  mtry = tune(),   # Tune this
  min_n = tune()   # Tune this
) |>
  set_engine("ranger") |>
  set_mode("classification")

wf <- workflow() |>
  add_formula(Species ~ .) |>
  add_model(rf_spec)

# Grid search
tune_results <- wf |>
  tune_grid(
    resamples = folds,
    grid = 20,  # 20 candidate combinations from a space-filling design
    metrics = metric_set(accuracy)
  )

# View best parameters
show_best(tune_results, metric = "accuracy")
best_params <- select_best(tune_results, metric = "accuracy")

# Finalize the workflow with best params
final_wf <- wf |> finalize_workflow(best_params)
final_fit <- final_wf |> fit(data = train)

# Bayesian tuning (smarter search)
tune_bayes_results <- wf |>
  tune_bayes(
    resamples = folds,
    iter = 30,
    metrics = metric_set(accuracy)
  )
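After finalizing, last_fit() fits the workflow on the full training set and evaluates it once on the held-out test set. A minimal sketch, assuming the split object from the first section is still in scope:

R
# Visualize performance across the tuning grid
autoplot(tune_results)

# Fit on train and evaluate on test in one step
final_results <- final_wf |> last_fit(split)
collect_metrics(final_results)      # Test-set metrics
collect_predictions(final_results)  # Test-set predictions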
Model Comparison with workflowsets
R
# Compare multiple models at once
models <- workflow_set(
  preproc = list(formula = Species ~ .),
  models = list(
    rf = rand_forest() |> set_engine("ranger") |> set_mode("classification"),
    tree = decision_tree() |> set_engine("rpart") |> set_mode("classification"),
    # Species has three classes, so use multinomial regression
    # (binary logistic_reg() with "glm" would not work here)
    mlr = multinom_reg() |> set_engine("nnet")
  )
)

results <- models |>
  workflow_map("fit_resamples", resamples = folds)

rank_results(results, rank_metric = "accuracy")
autoplot(results)
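To pull a single workflow back out of the set for a final fit, workflow_set IDs follow the "{preproc}_{model}" pattern, so the random forest above is "formula_rf". A sketch, assuming the default IDs:

R
# Metrics for every workflow in the set
collect_metrics(results)

# Extract the winning workflow by ID and refit it on the training data
best_wf <- extract_workflow(results, id = "formula_rf")
best_fit <- best_wf |> fit(data = train)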
ROC Curves
R
# Get probability predictions
probs <- final_fit |> predict(test, type = "prob")
results <- test |> bind_cols(probs)

# Plot one-vs-all ROC curves, one per class
results |>
  roc_curve(truth = Species, .pred_setosa:.pred_virginica) |>
  autoplot()
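ROC curves use the class probabilities; for hard-class metrics such as a confusion matrix, bind the class predictions as well (a minimal sketch):

R
# Add hard class predictions alongside the probabilities
preds <- final_fit |> predict(test)  # type = "class" is the default
results <- results |> bind_cols(preds)

# Confusion matrix and multiclass ROC AUC (Hand-Till by default)
results |> conf_mat(truth = Species, estimate = .pred_class)
results |> roc_auc(truth = Species, .pred_setosa:.pred_virginica)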