DeepSpeed ZeRO stages trade memory savings for communication overhead by partitioning optimizer states, gradients, and parameters across GPUs.

Distributed training in PyTorch typically relies on Distributed Data Parallel (DDP), which replicates the full model, gradients, and optimizer states on every GPU. This redundancy keeps the implementation simple, but it caps the trainable model size at what a single device can hold. DeepSpeed is a deep learning optimization library that introduces the Zero Redundancy Optimizer (ZeRO) to remove this redundancy. ZeRO shards the training state across GPUs, so the memory available for that state grows with the aggregate memory of the cluster instead of being limited by one device.

The mechanism relies on partitioning specific tensors and reassembling them only when they are needed during the forward and backward passes. Each stage of ZeRO removes one more layer of redundancy, reducing memory usage per GPU. The choice between stages is not a performance optimization but a capacity decision. A model whose full training state slightly exceeds one GPU under DDP may fit on the same 8 GPUs with ZeRO-1 or ZeRO-2, but a model whose weights alone exceed a single GPU's memory (for example, 100GB of parameters on 80GB cards) requires ZeRO-3. The price is communication: ZeRO-1 and ZeRO-2 move the same total volume as DDP, while ZeRO-3 moves roughly 1.5 times as much and issues many more collective calls per step.

The memory baseline for Adam optimizers

Understanding the memory savings requires the baseline math for the Adam optimizer under mixed-precision training, the usual setup for large models. The working copy of the weights is stored in FP16 or BF16 at 2 bytes per parameter, and the gradients produced by the backward pass take another 2 bytes per parameter. Adam keeps an FP32 master copy of the weights for precise updates (4 bytes) plus two FP32 state variables per parameter, momentum and variance (4 bytes each). This totals 12 bytes of optimizer state per parameter.

In a standard DDP setup, every GPU holds the full half-precision weights (2 bytes), full gradients (2 bytes), and full optimizer states (12 bytes). This sums to 16 bytes per parameter per GPU, before activations are counted. (Pure FP32 training lands at the same 16 bytes: 4 for weights, 4 for gradients, and 8 for the two Adam moments, with no separate master copy.) For a 70 billion parameter model, this requires 1.12 terabytes of memory on every GPU. A single A100 GPU with 80GB cannot hold this state.
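The arithmetic can be checked with a short Python sketch. It assumes the mixed-precision Adam accounting used in the ZeRO paper (2 + 2 + 12 bytes per parameter) and ignores activations, temporary buffers, and fragmentation; the constant and function names are illustrative.

```python
# Bytes per parameter for mixed-precision Adam (ZeRO paper accounting).
# Activations, temporary buffers, and allocator fragmentation are ignored.
HALF_PARAMS = 2    # FP16/BF16 working copy of the weights
HALF_GRADS = 2     # FP16/BF16 gradients
FP32_MASTER = 4    # FP32 master weights kept by the optimizer
ADAM_MOMENTUM = 4  # FP32 first moment
ADAM_VARIANCE = 4  # FP32 second moment

OPTIMIZER_STATE = FP32_MASTER + ADAM_MOMENTUM + ADAM_VARIANCE  # 12 bytes
BYTES_PER_PARAM = HALF_PARAMS + HALF_GRADS + OPTIMIZER_STATE    # 16 bytes


def ddp_model_state_bytes(num_params: int) -> int:
    """Per-GPU model-state memory under DDP: every rank holds everything."""
    return num_params * BYTES_PER_PARAM


if __name__ == "__main__":
    total = ddp_model_state_bytes(70_000_000_000)
    print(f"{total / 1e12:.2f} TB per GPU")  # decimal terabytes
    print(f"fits on one 80 GB GPU: {total <= 80e9}")
```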

DeepSpeed ZeRO reduces this footprint by ensuring each GPU stores only a fraction of the states that are partitioned. For each partitioned state, the per-GPU share shrinks linearly with the number of GPUs: with 8 GPUs, each GPU stores roughly 1/8th of it, while any state that stays replicated does not shrink. The cost appears when a full tensor is needed again. Under ZeRO-3, which partitions the parameters themselves, the system must gather the necessary shards from other GPUs before computing each layer, introducing communication latency into the forward and backward passes.
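A toy, pure-Python sketch of the bookkeeping: a flat weight buffer is padded and split into one equal shard per rank, and an all-gather concatenates the shards back into the full buffer. Plain lists stand in for GPU tensors and for NCCL, and partition and all_gather are hypothetical helpers; DeepSpeed's real buffer layout differs in detail.

```python
from typing import List


def partition(flat: List[float], world_size: int) -> List[List[float]]:
    """Split a flat buffer into world_size equal shards, zero-padding the tail."""
    shard_len = -(-len(flat) // world_size)  # ceiling division
    padded = flat + [0.0] * (shard_len * world_size - len(flat))
    return [padded[r * shard_len:(r + 1) * shard_len] for r in range(world_size)]


def all_gather(shards: List[List[float]], numel: int) -> List[float]:
    """What every rank receives: all shards concatenated, padding dropped."""
    full = [x for shard in shards for x in shard]
    return full[:numel]


if __name__ == "__main__":
    weights = [float(i) for i in range(10)]
    shards = partition(weights, world_size=4)  # each "GPU" keeps 3 elements
    print(shards)
    assert all_gather(shards, len(weights)) == weights
```

Each rank stores ceil(10 / 4) = 3 elements instead of 10, and must receive the other shards before it can use the full buffer.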

The three stages of partitioning

ZeRO defines three distinct stages of optimization, each removing a specific set of redundant data. Stage 1 partitions only the optimizer states. Stage 2 partitions optimizer states and gradients. Stage 3 partitions optimizer states, gradients, and parameters.

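| Stage | Partitioned components | Memory reduction vs. DDP (mixed-precision Adam) | Communication volume vs. DDP |
|---|---|---|---|
| ZeRO-1 | Optimizer states | Low (up to ~4x) | Same |
| ZeRO-2 | Optimizer states + gradients | Medium (up to ~8x) | Same |
| ZeRO-3 | Optimizer states + gradients + parameters | High (grows with GPU count) | ~1.5x |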
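The memory column can be made concrete with the per-GPU formulas from the ZeRO paper, sketched below for mixed-precision Adam (k = 12 bytes of optimizer state per parameter). The function name is hypothetical, and activations, buffers, and fragmentation are left out.

```python
def model_state_bytes_per_gpu(num_params: float, num_gpus: int, stage: int,
                              k: int = 12) -> float:
    """Per-GPU model-state bytes for mixed-precision Adam (ZeRO paper formulas).

    Stage 0 is plain data parallelism (DDP); k is optimizer-state bytes/param.
    """
    p, n = num_params, num_gpus
    if stage == 0:  # everything replicated
        return (2 + 2 + k) * p
    if stage == 1:  # optimizer states partitioned
        return (2 + 2) * p + k * p / n
    if stage == 2:  # + gradients partitioned
        return 2 * p + (2 + k) * p / n
    if stage == 3:  # + parameters partitioned
        return (2 + 2 + k) * p / n
    raise ValueError(f"unknown ZeRO stage {stage}")


if __name__ == "__main__":
    for s in range(4):
        gb = model_state_bytes_per_gpu(7.5e9, 64, s) / 1e9
        print(f"stage {s}: {gb:.1f} GB")
```

With 7.5 billion parameters on 64 GPUs this gives 120, 31.4, 16.6, and 1.9 GB for stages 0 to 3, matching the worked example in the ZeRO paper.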

The configuration for ZeRO is passed to the DeepSpeed engine as a JSON file (for example, one mounted into the training container) or as an equivalent Python dict. The zero_optimization block defines the stage and its specific settings. The following JSON snippet demonstrates a ZeRO-2 configuration, a common starting point for large models, with optional CPU offload of the optimizer states enabled.

{
  "zero_optimization": {
    "stage": 2,
    "offload_optimizer": {
      "device": "cpu",
      "pin_memory": true
    },
    "allgather_partitions": true,
    "allgather_bucket_size": 2e8,
    "overlap_comm": true,
    "reduce_scatter": true,
    "reduce_bucket_size": 2e8,
    "contiguous_gradients": true
  }
}

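In this configuration, the stage field is set to 2. The offload_optimizer block moves optimizer states to CPU memory and runs the optimizer step there, further reducing GPU memory pressure at the cost of PCIe transfers and CPU compute. The allgather_bucket_size and reduce_bucket_size fields are counted in elements, not bytes, and control how much data a single all-gather or reduce call moves. Larger buckets reduce the number of communication rounds but increase the peak memory required for the communication buffers. overlap_comm overlaps gradient reduction with the backward computation, reduce_scatter averages gradients with reduce-scatter instead of all-reduce, and contiguous_gradients copies gradients into a contiguous buffer as they are produced to avoid memory fragmentation.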
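Because the bucket sizes are element counts, the buffer memory depends on the data type. The sketch below estimates how bucket size trades the number of calls against buffer size, under the simplifying assumptions of one bucket in flight and 2-byte FP16/BF16 elements; with overlap_comm enabled, DeepSpeed keeps several buckets alive, so treat the buffer figure as a lower bound. The helper name is hypothetical.

```python
import math


def bucket_tradeoff(num_params: int, bucket_elems: int,
                    bytes_per_elem: int = 2) -> tuple:
    """Return (collective calls per pass, bytes per bucket buffer).

    Assumes data is communicated in fixed-size buckets of bucket_elems
    elements, each bytes_per_elem bytes (2 for FP16/BF16).
    """
    calls = math.ceil(num_params / bucket_elems)
    buffer_bytes = bucket_elems * bytes_per_elem
    return calls, buffer_bytes


if __name__ == "__main__":
    for bucket in (int(5e7), int(2e8), int(5e8)):
        calls, buf = bucket_tradeoff(7_000_000_000, bucket)
        print(f"bucket={bucket:.0e} elems: {calls} calls, {buf / 1e6:.0f} MB buffer")
```

For a 7-billion-parameter model and the 2e8 setting above, this gives 35 calls per pass and a 400 MB buffer per bucket.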

ZeRO-3 behaves differently by partitioning the parameters themselves. Each layer's full weights are reassembled with an all-gather just before the layer runs in the forward pass, gathered again for the backward pass, and released afterwards. This is conceptually the same approach as Fully Sharded Data Parallel (FSDP) in PyTorch. The offload_param setting in ZeRO-3 (with device set to cpu or nvme) lets parameters reside outside GPU memory, enabling training of models larger than total cluster GPU memory. This is not a speed optimization; it is a capacity enablement.
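The same settings can be expressed as a Python dict, which deepspeed.initialize accepts through its config argument in place of a file path. The sketch below builds a minimal ZeRO-3 configuration with optional parameter offload; the key names follow DeepSpeed's config schema, while the helper name and the specific values are illustrative assumptions, not tuned recommendations.

```python
import json


def zero3_config(offload_params: bool = False, micro_batch: int = 1) -> dict:
    """Build a minimal ZeRO-3 DeepSpeed config as a plain dict (hypothetical helper).

    Key names follow DeepSpeed's config schema; the values are illustrative.
    """
    zero = {
        "stage": 3,
        "overlap_comm": True,
        "contiguous_gradients": True,
        # Limits on how many gathered parameter elements stay resident or are
        # prefetched at once; lower values reduce peak memory.
        "stage3_max_live_parameters": int(1e9),
        "stage3_prefetch_bucket_size": int(5e8),
    }
    if offload_params:
        # Keep parameter shards in pinned host memory; copy them in on demand.
        zero["offload_param"] = {"device": "cpu", "pin_memory": True}
    return {
        "train_micro_batch_size_per_gpu": micro_batch,
        "bf16": {"enabled": True},  # assumes GPUs with BF16 support (e.g. A100)
        "zero_optimization": zero,
    }


if __name__ == "__main__":
    print(json.dumps(zero3_config(offload_params=True), indent=2))
```

A typical call would then be deepspeed.initialize(model=model, model_parameters=model.parameters(), config=zero3_config(offload_params=True)), with model standing in for the user's torch.nn.Module.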

Communication volume and network topology

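Partitioning states changes which collectives are used. In DDP, gradients are averaged with All-Reduce. ZeRO splits that work into its two halves: a Reduce-Scatter leaves each GPU with the averaged values for the shard it owns, and an All-Gather reassembles full tensors from the shards. A ring All-Reduce is itself a Reduce-Scatter followed by an All-Gather, so an All-Gather moves about half the data of an All-Reduce on the same tensor. ZeRO's extra cost comes from how many gathers it needs, not from All-Gather being inherently more expensive.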
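The relative volumes follow from the textbook ring algorithm. The sketch below is a bandwidth model, not a measurement (NCCL may choose tree or other algorithms); it computes the bytes each GPU sends for one collective, where tensor_bytes is the size of the full, unsharded tensor.

```python
def ring_bytes_sent_per_gpu(op: str, tensor_bytes: float, world_size: int) -> float:
    """Bytes each GPU sends for one ring collective over a tensor of tensor_bytes."""
    frac = (world_size - 1) / world_size
    if op in ("all_gather", "reduce_scatter"):
        return frac * tensor_bytes
    if op == "all_reduce":  # ring all-reduce = reduce-scatter + all-gather
        return 2 * frac * tensor_bytes
    raise ValueError(f"unknown collective {op!r}")


if __name__ == "__main__":
    for op in ("reduce_scatter", "all_gather", "all_reduce"):
        print(op, ring_bytes_sent_per_gpu(op, 1e9, 8) / 1e9, "GB")
```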

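ZeRO-1 reduces gradients as DDP does; each GPU then updates only its shard of the parameters using its shard of the optimizer states, and the updated parameters are all-gathered once per step. ZeRO-2 reduce-scatters gradients during the backward pass, so each GPU keeps only the gradient shard it owns, followed by the same all-gather of updated parameters. Both move the same total volume as DDP. ZeRO-3 all-gathers each layer's parameters before the layer runs in the forward pass, all-gathers them again in the backward pass, and reduce-scatters the gradients, for roughly 1.5 times the DDP volume. The volume scales with the parameter count, while the number of collective calls scales with the number of layers: for a transformer model, communication occurs at every layer, not just at the end of the step.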
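A simplified schedule makes the difference in call counts visible. The sketch assumes one communication bucket per layer and shows ZeRO-1 the same way as ZeRO-2; real DeepSpeed merges small tensors into buckets, prefetches parameters, and keeps some small parameters resident, so actual counts differ. step_schedule is a hypothetical helper.

```python
from typing import List, Tuple

Call = Tuple[str, str]  # (phase, collective)


def step_schedule(stage: int, num_layers: int) -> List[Call]:
    """Simplified per-step collective schedule, one bucket per layer."""
    if stage == 0:  # DDP: gradients all-reduced bucket by bucket
        return [("backward", "all_reduce grads")] * num_layers
    if stage in (1, 2):
        # ZeRO-2 reduce-scatters gradients; ZeRO-1 is shown the same way here.
        calls = [("backward", "reduce_scatter grads")] * num_layers
        # Each rank updates its own shard, then the new weights are gathered.
        calls.append(("after step", "all_gather params"))
        return calls
    if stage == 3:
        calls = [("forward", "all_gather params")] * num_layers
        for _ in range(num_layers):
            calls.append(("backward", "all_gather params"))
            calls.append(("backward", "reduce_scatter grads"))
        return calls
    raise ValueError(f"unknown ZeRO stage {stage}")


if __name__ == "__main__":
    for s in (0, 2, 3):
        print(f"stage {s}: {len(step_schedule(s, 32))} collective calls")
```

For 32 layers this yields 32 calls for DDP, 33 for ZeRO-2, and 96 for ZeRO-3: DDP and ZeRO-2 move about the same bytes, while ZeRO-3 spreads about 1.5 times the bytes over three times as many synchronization points.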

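The NCCL library handles the actual communication between GPUs. It chooses between ring and tree algorithms and uses the fastest transport it finds: NVLink or PCIe within a node, and InfiniBand or RoCE between nodes when available, falling back to TCP sockets otherwise. If the cluster relies on standard Ethernet without NVLink or InfiniBand, the communication overhead of ZeRO-3 can dominate the training time. Setting the NCCL_DEBUG environment variable to INFO makes NCCL log how it has set up these communication paths.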

export NCCL_DEBUG=INFO

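With this setting, NCCL reports at startup which transport it selected for each connection (for example NVLink peer-to-peer, InfiniBand, or TCP sockets) and how its rings and trees are laid out. It does not time individual collectives; measuring those requires a profiler such as the PyTorch profiler or Nsight Systems. Long All-Gather calls in such a profile indicate a network bottleneck. If the network is 100GbE, ZeRO-3 training is likely to be noticeably slower than DDP on the same hardware. If the network is InfiniBand or NVLink, the overhead is much smaller and can often be hidden by overlapping communication with computation, a modest price for the memory savings.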
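A back-of-the-envelope bound shows why link speed matters so much. The sketch below assumes every GPU's traffic crosses a link of the given speed and uses the ZeRO paper's volumes (about 2x the half-precision model size per step for DDP, ZeRO-1, and ZeRO-2; about 3x for ZeRO-3) with the ring (N-1)/N factor. It ignores latency and overlap with compute, so it is a lower bound on communication time, not a prediction of step time; the function name is hypothetical.

```python
def comm_seconds_lower_bound(num_params: float, stage: int, world_size: int,
                             link_gbit_per_s: float, bytes_per_elem: int = 2) -> float:
    """Bandwidth-only lower bound on per-step communication time for one GPU."""
    model_bytes = num_params * bytes_per_elem
    multiplier = 3 if stage == 3 else 2  # ZeRO-3 adds a second parameter gather
    sent = multiplier * model_bytes * (world_size - 1) / world_size
    link_bytes_per_s = link_gbit_per_s * 1e9 / 8  # Gbit/s -> bytes/s
    return sent / link_bytes_per_s


if __name__ == "__main__":
    for gbit in (100.0, 400.0):
        t = comm_seconds_lower_bound(7e9, 3, 16, gbit)
        print(f"ZeRO-3, {gbit:.0f} Gb/s: >= {t:.2f} s per step")
```

For a 7-billion-parameter BF16 model on 16 GPUs, ZeRO-3 needs at least about 3.15 s of communication per step over 100GbE versus about 0.79 s over a 400 Gb/s link; DDP on 100GbE needs at least 2.1 s.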

Failure modes and debugging

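The most common failure mode is an out-of-memory (OOM) error during the gather phase. Even if the shard fits on the GPU, the temporary buffers required to reconstruct full tensors for computation may exceed available VRAM. This is distinct from the OOM that occurs under DDP, where the fully replicated state simply does not fit. A ZeRO-3 OOM typically surfaces during the forward pass with a traceback that ends inside DeepSpeed's parameter-gathering code rather than in the application code, and in a multi-rank launch it can be buried among the errors or hangs it triggers on the other ranks. In ZeRO-3, lowering stage3_max_live_parameters and stage3_prefetch_bucket_size limits how much gathered parameter data is resident at once.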
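A rough estimate shows how a ZeRO-3 rank can run out of memory even though its shard fits. The sketch adds the rank's share of model states to a gathered-parameter buffer sized by a cap on live parameters; the cap is an assumption modeled on the intent of stage3_max_live_parameters, not on DeepSpeed's exact allocator behavior, and activations are not counted. The function name is hypothetical.

```python
def zero3_peak_estimate_bytes(num_params: float, num_gpus: int,
                              max_live_params: float, bytes_per_elem: int = 2,
                              k: int = 12) -> float:
    """Rough ZeRO-3 peak per GPU: own shard plus gathered-parameter buffers.

    Excludes activations, communication buckets, and fragmentation.
    """
    shard = (2 + 2 + k) * num_params / num_gpus        # partitioned model states
    gathered = min(max_live_params, num_params) * bytes_per_elem
    return shard + gathered


if __name__ == "__main__":
    peak = zero3_peak_estimate_bytes(70e9, 16, 1e9)
    print(f"{peak / 1e9:.0f} GB before activations")
```

For 70 billion parameters on 16 GPUs with a 1e9-element cap, the shard alone is 70 GB and the estimate reaches 72 GB before any activations, leaving little of an 80GB card.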

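Another failure mode is a collective timeout. Because ZeRO-3 increases the number of communication rounds, each step contains more synchronization points, and the job becomes more sensitive to stragglers. If one GPU is slower due to thermal throttling or disk I/O contention, every other GPU waits at the next collective. The timeout is a property of the process group, set through the timeout argument of torch.distributed.init_process_group (or deepspeed.init_distributed); its default depends on the PyTorch and DeepSpeed versions and is commonly 10 or 30 minutes. If a collective does not complete within it, the NCCL watchdog aborts the job with a collective operation timeout that names the stuck operation. A wait that long usually means one rank has hung or crashed, not merely that it is slow.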
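A toy model shows why more synchronization points amplify stragglers. Each rank's per-segment compute times are given as input; with a collective after every segment, each segment costs as much as its slowest rank, while with a single synchronization only the slowest total matters. Communication time is ignored, and step_time is a hypothetical illustration, not a DeepSpeed API.

```python
from typing import List


def step_time(segment_times: List[List[float]], sync_every_segment: bool) -> float:
    """Step time with a barrier after every segment vs. one barrier per step.

    segment_times[r][s] is rank r's compute time for segment s (e.g. a layer).
    """
    if sync_every_segment:
        num_segments = len(segment_times[0])
        # Every barrier waits for the slowest rank in that segment.
        return sum(max(rank[s] for rank in segment_times)
                   for s in range(num_segments))
    # A single barrier only waits for the slowest rank's total.
    return max(sum(rank) for rank in segment_times)


if __name__ == "__main__":
    times = [[1, 2], [2, 1]]  # two ranks, alternately slow
    print(step_time(times, True), step_time(times, False))
```

With two ranks that are alternately slow, per-segment synchronization gives 4 time units and a single synchronization gives 3: uncorrelated jitter adds up at every barrier.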

Memory offloading to CPU mitigates GPU OOM but introduces PCIe bandwidth contention. The offload_optimizer setting moves optimizer states to system RAM, and the optimizer step then runs on the CPU. Because host memory and the PCIe link are far slower than GPU memory, the training step time usually increases. nvidia-smi shows GPU memory usage, but it does not show CPU memory pressure. The operator must monitor system memory separately (for example with free or top) to ensure the offloaded states do not exhaust RAM; with pin_memory enabled, those buffers are page-locked and cannot be swapped out.
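A quick lower bound on host RAM helps when sizing nodes for offload. The sketch assumes each rank offloads its 1/world_size share of the 12 bytes per parameter of FP32 optimizer state and that all ranks on a node share that node's RAM; offloaded gradient buffers, pinned staging buffers, the data loader, and the operating system are not counted, so real usage is higher. The function name is hypothetical.

```python
def host_ram_per_node_bytes(num_params: float, world_size: int,
                            gpus_per_node: int, k: int = 12) -> float:
    """Lower bound on host RAM one node needs for offloaded optimizer states.

    k is FP32 optimizer-state bytes per parameter (master weights, momentum,
    variance); each rank offloads only its own partition.
    """
    per_rank = k * num_params / world_size
    return per_rank * gpus_per_node  # ranks on one node share its RAM


if __name__ == "__main__":
    need = host_ram_per_node_bytes(70e9, world_size=16, gpus_per_node=8)
    print(f"{need / 1e9:.0f} GB per node for optimizer states alone")
```

For 70 billion parameters on 16 GPUs split across two 8-GPU nodes, each node needs at least 420 GB for the offloaded states alone.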

Decision frame

The choice between ZeRO stages is not about speed but about capacity. ZeRO-1 suits models that come close to fitting under DDP. ZeRO-2 is the standard choice for large models where gradient and optimizer memory are the bottleneck. ZeRO-3 is required only when the replicated half-precision parameters no longer fit on a single GPU alongside everything else, even with optimizer states and gradients partitioned. Because ZeRO-1 and ZeRO-2 communicate no more than DDP, the tradeoff between memory footprint and network bandwidth bites mainly at ZeRO-3. If the cluster has high-bandwidth interconnects, ZeRO-3 is viable. If the cluster relies on standard Ethernet, ZeRO-3 will likely bottleneck on communication. The next time a training job fails with OOM, the first step is not to add more GPUs but to check whether the ZeRO stage matches both the model size and the network topology.
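The decision can be sketched as a function that returns the lowest stage whose model states fit on each GPU, using the ZeRO paper's mixed-precision formulas. The 30% headroom reserved for activations and buffers is an arbitrary assumption to tune per workload, and the function ignores the interconnect, so a result of 3 on an Ethernet-only cluster still has to be weighed against the network, as discussed above. choose_zero_stage is a hypothetical helper.

```python
def choose_zero_stage(num_params: float, num_gpus: int, gpu_bytes: float,
                      headroom: float = 0.3, k: int = 12) -> int:
    """Lowest ZeRO stage whose model states fit in (1 - headroom) of a GPU.

    Stage 0 means plain DDP. Returns -1 if no stage fits without offload.
    """
    p, n = num_params, num_gpus
    per_gpu = {
        0: (2 + 2 + k) * p,          # everything replicated
        1: 4 * p + k * p / n,        # optimizer states sharded
        2: 2 * p + (2 + k) * p / n,  # + gradients sharded
        3: (4 + k) * p / n,          # + parameters sharded
    }
    budget = gpu_bytes * (1 - headroom)
    for stage in range(4):
        if per_gpu[stage] <= budget:
            return stage
    return -1


if __name__ == "__main__":
    for params in (1.3e9, 7e9, 70e9):
        print(f"{params:.1e} params on 8x80GB: stage {choose_zero_stage(params, 8, 80e9)}")
```

On 80GB GPUs with eight ranks, a 1.3-billion-parameter model returns 0 and a 7-billion-parameter model returns 1; a 70-billion-parameter model returns -1, meaning offload or more GPUs are needed.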