Advanced Multi-GPU Training

When a model is too large to fit on a single GPU, or training on one GPU is too slow, you need multi-GPU strategies. This lesson covers the parallelism techniques that enable training across multiple GPUs and multiple nodes, from simple data parallelism to advanced 3D parallelism used for frontier LLMs.

Parallelism Strategies

| Strategy | What's Split | Communication | Best For |
|---|---|---|---|
| Data Parallel | Training data | Gradient sync (AllReduce) | Model fits on 1 GPU |
| Model Parallel (Tensor) | Weight matrices within each layer, across GPUs | Activations between GPUs | Very large layers |
| Pipeline Parallel | Model stages across GPUs | Activations between stages | Deep models |
| FSDP / ZeRO | Parameters, gradients, optimizer state | Parameter gather on demand | Large models, memory-efficient |
| 3D Parallelism | All of the above combined | Mixed | Frontier LLM training |
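The simplest row of the table, data parallelism, can be sketched with PyTorch's DistributedDataParallel. The toy model, data, and environment defaults below are illustrative; under torchrun the rank/world-size variables are set for you, and the `gloo` backend lets the same sketch run on CPU.

```python
import os
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

def train_step() -> float:
    # torchrun normally sets these; default them so the sketch also
    # runs as a single CPU process (gloo backend).
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29501")
    os.environ.setdefault("RANK", "0")
    os.environ.setdefault("WORLD_SIZE", "1")
    dist.init_process_group("gloo")

    model = torch.nn.Linear(16, 4)   # toy stand-in for a real model
    ddp_model = DDP(model)           # hooks gradient AllReduce into backward

    opt = torch.optim.SGD(ddp_model.parameters(), lr=0.1)
    x, y = torch.randn(8, 16), torch.randn(8, 4)
    loss = torch.nn.functional.mse_loss(ddp_model(x), y)
    loss.backward()                  # gradients averaged across ranks here
    opt.step()

    dist.destroy_process_group()
    return loss.item()

if __name__ == "__main__":
    print(f"loss: {train_step():.4f}")
```

Each rank sees a different slice of the data; DDP keeps replicas in sync by averaging gradients during `backward()`, which is why it only works when the full model fits on one GPU.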

PyTorch FSDP Example

Python
import torch
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import MixedPrecision, ShardingStrategy

# torchrun sets the rank/world-size environment variables
dist.init_process_group("nccl")

# Keep parameters, gradient reduction, and buffers in bf16
mp_policy = MixedPrecision(
    param_dtype=torch.bfloat16,
    reduce_dtype=torch.bfloat16,
    buffer_dtype=torch.bfloat16,
)

# Wrap the model: FULL_SHARD shards parameters, gradients,
# and optimizer state across all ranks
model = FSDP(
    model,
    mixed_precision=mp_policy,
    sharding_strategy=ShardingStrategy.FULL_SHARD,
    use_orig_params=True,
)

# Launch with: torchrun --nproc_per_node=8 train.py

Infrastructure Requirements

  • Intra-node — NVLink provides 600-900 GB/s between GPUs within a node. Use 8-GPU instances for tight coupling.
  • Inter-node — EFA/InfiniBand provides 400-3200 Gbps between nodes. Use placement groups for co-location.
  • Storage — High-throughput shared filesystem (FSx Lustre, Filestore) to avoid data loading bottlenecks.
  • Orchestration — Use torchrun, SageMaker distributed, or Kubernetes with MPI operator.
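The orchestration bullet can be made concrete with a two-node torchrun launch; the hostname, port, and script name below are placeholders.

```shell
# Run on node 0; on the second node, set --node_rank=1.
# Endpoint and script name are placeholders for illustration.
torchrun \
  --nnodes=2 \
  --nproc_per_node=8 \
  --node_rank=0 \
  --rdzv_backend=c10d \
  --rdzv_endpoint=node0.example.com:29500 \
  train.py
```

The rendezvous endpoint must be reachable from every node; on AWS, that is where placement groups and EFA-enabled security groups matter.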

Scaling Rule: Start with FSDP for models up to ~70B parameters on a single 8-GPU node. Only add multi-node training when you exceed single-node memory. Add tensor/pipeline parallelism for 100B+ parameter models.
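The ~70B guideline can be sanity-checked with quick arithmetic. The sketch below assumes bf16 parameters and gradients plus fp32 Adam state (master weights, momentum, variance), all fully sharded; it ignores activations and memory fragmentation, so treat the result as a lower bound.

```python
def fsdp_memory_per_gpu_gb(n_params: float, n_gpus: int) -> float:
    """Rough per-GPU model-state memory under FSDP FULL_SHARD.

    bf16 params (2 B) + bf16 grads (2 B) + fp32 Adam state
    (master params, momentum, variance: 12 B) = 16 bytes/param,
    sharded evenly across GPUs. Activations are NOT included.
    """
    bytes_per_param = 2 + 2 + 12
    return n_params * bytes_per_param / n_gpus / 1e9

# 70B parameters on one 8-GPU node -> 140 GB of model state per GPU
print(round(fsdp_memory_per_gpu_gb(70e9, 8)))
```

Since 140 GB exceeds an 80 GB GPU, the single-node guideline implicitly relies on tricks like CPU offload, activation checkpointing, or reduced-precision optimizer state; past that point, multi-node FSDP or tensor/pipeline parallelism becomes unavoidable.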

Ready for Best Practices?

The final lesson covers GPU utilization optimization, profiling, and production operations.

Next: Best Practices →