PyTorch Training Throughput: The Patterns That Actually Move the Number

torch.compile, mixed precision, gradient accumulation, DDP vs FSDP, and the profiler — the five levers I reach for before rethinking the model architecture.

By Jovani Pink September 29, 2025 7 min — Platform & AI Engineering

Outcome focus: Cut training wall-clock time and GPU memory pressure by applying compile, AMP, and accumulation patterns in sequence before ever touching model architecture.

The training run that takes twelve hours instead of three is usually not an architecture problem.

Nine times out of ten when I dig into a slow or unstable training loop, the issue is one of five things: the model is not compiled, precision is wrong for the hardware, the effective batch size is too small because memory forced it down, the parallelism strategy is wrong for the model size, or nobody has looked at the profiler output. These are not obscure optimizations. They are table stakes, and they compose — you get the benefit of each on top of the last.

This is the order I apply them.

torch.compile

torch.compile landed in PyTorch 2.0 and is the cheapest throughput win available. It traces your model's computation graph and emits optimized kernels without you rewriting anything. The code change is one line:

compile.py
model = MyModel().cuda()
model = torch.compile(model)

That is the whole thing. No manual kernel selection, no operator rewriting. But the first forward pass will be slow — PyTorch is tracing and compiling the graph and caching it. Do not benchmark on the first batch. Warm it up, then time it:

compile_warmup.py
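import time

import torch

# Warm up with inputs of the same shape and dtype as the real batches;
# a new input shape can trigger a recompile that lands inside the timed region.
for _ in range(3):
    _ = model(dummy_input)

torch.cuda.synchronize()  # CUDA launches are async: drain queued work first
start = time.perf_counter()
iters = 20
for _ in range(iters):
    output = model(real_input)
torch.cuda.synchronize()  # wait for the timed kernels to actually finish
elapsed = (time.perf_counter() - start) / iters  # seconds per forward pass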

On a trace-friendly model I typically see 20–40% throughput improvement. Python control flow that depends on tensor values, or an unsupported operation inside forward, causes a graph break: compile splits the graph there and runs the untraceable piece in eager mode without raising anything, so part of the speedup disappears and you would never know. The fullgraph=True argument makes it strict:

compile_strict.py
model = torch.compile(model, fullgraph=True)

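With fullgraph=True, a graph break raises an error instead of quietly degrading. I use it in development to catch graph breaks early. In production, drop it if you have breaks you cannot eliminate; the fallback is safe, you just lose the speedup in those regions. Changing input shapes are a separate cost: they trigger recompiles rather than graph breaks, so fullgraph=True will not flag them.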

Mixed precision

Modern accelerators have dedicated hardware for lower-precision arithmetic. Not using it is leaving performance on the table.

Which dtype to use depends on the hardware, not on preference. Use bfloat16 on Ampere and newer GPUs: A100, H100, and the RTX 30- and 40-series (3090, 4090) all support it. Use fp16 on older cards without bfloat16 support, such as the V100, T4, and RTX 20-series. The reason is numerical range: bfloat16 has the same 8 exponent bits as float32, so large activations do not overflow. fp16 has only 5 exponent bits and tops out at 65,504; when activations spike, which gets more likely as models get deeper, values overflow to inf and the gradients turn into NaNs.
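The range difference is easy to check directly. A small sketch that runs on CPU, assuming only that PyTorch is installed: torch.finfo reports each dtype's largest finite value, and a value just past fp16's ceiling overflows to inf while bfloat16 keeps it finite at reduced precision. The pick_amp_dtype helper is a hypothetical name; it relies on torch.cuda.is_bf16_supported() to choose the autocast dtype from the current GPU:

```python
import torch

# Largest finite value each dtype can represent.
for dtype in (torch.float32, torch.bfloat16, torch.float16):
    print(f"{str(dtype):15} max = {torch.finfo(dtype).max:.4g}")

# An activation just past fp16's ceiling of 65504.
spike = torch.tensor([70000.0])
as_fp16 = spike.to(torch.float16)   # overflows to inf
as_bf16 = spike.to(torch.bfloat16)  # stays finite, rounded to a nearby bf16 value
print(as_fp16, as_bf16)


def pick_amp_dtype() -> torch.dtype:
    """Hypothetical helper: prefer bfloat16 when the current GPU supports it."""
    if torch.cuda.is_available() and torch.cuda.is_bf16_supported():
        return torch.bfloat16
    return torch.float16
```

Passing pick_amp_dtype() as the dtype argument to torch.autocast keeps the choice in one place instead of hard-coding it per machine.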

With bfloat16, drop the GradScaler entirely. Its range matches float32, so loss scaling is unnecessary:

amp_bfloat16.py
optimizer.zero_grad()
 
with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
    output = model(batch["input"])
    loss = criterion(output, batch["target"])
 
loss.backward()
optimizer.step()

On GPUs without bfloat16 support, use fp16 with a GradScaler. The scaler multiplies the loss by a scale factor so small gradients do not underflow to zero in fp16, watches for overflow, and adjusts the scale to keep gradients representable:

amp_fp16.py
scaler = torch.amp.GradScaler()
 
with torch.autocast(device_type="cuda", dtype=torch.float16):
    output = model(batch["input"])
    loss = criterion(output, batch["target"])
 
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
optimizer.zero_grad()

One thing that tripped me up early: the GradScaler silently skips optimizer steps when it detects overflow. The first time this happened I spent two hours convinced my learning rate schedule was broken. Loss looked like it had plateaued; the LR curve was fine. The actual issue was overflow: the scaler was detecting it, skipping updates, and nobody told me. Check scaler.get_scale() across steps if your loss curve stops moving. A handful of skipped steps at startup is normal while the scale settles down from its default of 2^16 (65,536). If the value keeps bouncing or trends steadily downward, overflow is happening far more than it should. Lowering init_scale only trims those startup skips; when the scale keeps collapsing, the overflow is coming from the model itself, and switching to bfloat16 on hardware that supports it is the real fix.
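A skipped step shows up as a drop in the scale, so the check can be automated. A minimal sketch, assuming default GradScaler settings (init_scale of 2^16, backoff_factor of 0.5); skipped_steps is a hypothetical helper, and the sample history below is made up for illustration:

```python
from typing import List, Sequence


def skipped_steps(scales_after_update: Sequence[float],
                  init_scale: float = 2.0 ** 16) -> List[int]:
    """Return the indices of optimizer steps that GradScaler skipped.

    `scales_after_update` holds scaler.get_scale() read right after each
    scaler.update() call. GradScaler multiplies its scale by backoff_factor
    (< 1) whenever it finds inf/NaN gradients and skips that step, and only
    grows the scale after a run of clean steps, so any drop marks a skip.
    """
    skipped = []
    previous = init_scale
    for step, scale in enumerate(scales_after_update):
        if scale < previous:
            skipped.append(step)
        previous = scale
    return skipped


# In the fp16 loop, record the scale after every update:
#     scaler.step(optimizer)
#     scaler.update()
#     scales.append(scaler.get_scale())
history = [32768.0, 16384.0, 16384.0, 16384.0, 8192.0, 8192.0]
skips = skipped_steps(history)
print(f"skipped {len(skips)} of {len(history)} steps: {skips}")
```

If you construct the scaler with a custom init_scale, pass the same value here so the first step is judged correctly.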

Gradient accumulation

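GPU memory caps your batch size. A small batch makes gradient estimates noisy, which makes training unstable, which forces you to lower the learning rate, which makes everything slower. Gradient accumulation sidesteps that by computing gradients over several mini-batches before applying a single update. With a mean-reduced loss and equal-sized mini-batches it is mathematically equivalent to one larger batch (layers like BatchNorm, which compute statistics per mini-batch, are the exception), and it costs wall-clock time per update instead of memory.

The part people miss: divide the loss by the number of accumulation steps before calling backward. Gradients add up across backward calls, so without the division you get the sum of the mini-batch gradients instead of their average, accum_steps times larger than the large-batch gradient. Under SGD that silently multiplies your effective learning rate, and with clipping it changes what max_norm actually means: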

grad_accum.py
accum_steps = 4
optimizer.zero_grad()
 
for step, batch in enumerate(dataloader):
    with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
        output = model(batch["input"])
        # Without this division, the gradient is 4x what it should be
        loss = criterion(output, batch["target"]) / accum_steps
 
    loss.backward()
 
    if (step + 1) % accum_steps == 0:
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
        optimizer.zero_grad()
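The equivalence, and what goes wrong without the division, can be checked on CPU in a few lines. A small sketch with a toy linear model and random data (all names here are illustrative): the gradient from one full batch of 8 is compared against four accumulated micro-batches of 2, with and without dividing the loss:

```python
import torch

torch.manual_seed(0)
model = torch.nn.Linear(4, 1)
criterion = torch.nn.MSELoss()  # mean reduction, as in the loops above
inputs, targets = torch.randn(8, 4), torch.randn(8, 1)
accum_steps = 4


def grads_after(fn):
    """Run fn (which calls backward) from cleared grads; return a copy of the weight grad."""
    model.zero_grad()
    fn()
    return model.weight.grad.clone()


# Reference: one backward pass over the full batch of 8.
full = grads_after(lambda: criterion(model(inputs), targets).backward())


def accumulate(divide: bool):
    # Four micro-batches of 2; .grad sums across backward() calls.
    for x, y in zip(inputs.chunk(accum_steps), targets.chunk(accum_steps)):
        loss = criterion(model(x), y)
        (loss / accum_steps if divide else loss).backward()


averaged = grads_after(lambda: accumulate(divide=True))
summed = grads_after(lambda: accumulate(divide=False))

print(torch.allclose(averaged, full, atol=1e-6))              # divided: matches the big batch
print(torch.allclose(summed, accum_steps * full, atol=1e-5))  # undivided: 4x too large
```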

With fp16 and a GradScaler, the clip has to happen after unscale_ and before scaler.step(). If you clip scaled gradients, the max norm threshold is meaningless — you are comparing against values inflated by the scaler's current scale factor:

grad_accum_scaler.py
scaler = torch.amp.GradScaler()
accum_steps = 4
optimizer.zero_grad()
 
for step, batch in enumerate(dataloader):
    with torch.autocast(device_type="cuda", dtype=torch.float16):
        loss = criterion(model(batch["input"]), batch["target"]) / accum_steps
 
    scaler.scale(loss).backward()
 
    if (step + 1) % accum_steps == 0:
        scaler.unscale_(optimizer)  # must happen before the clip
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad()

DDP vs FSDP

Multi-GPU training adds communication overhead that the rest of your optimization work cannot hide, so the strategy choice matters.

DistributedDataParallel replicates the full model on every GPU and syncs gradients after each backward pass. Lower communication overhead per step, simpler to reason about, and the right default when your model fits comfortably on one GPU:

ddp.py
import os

import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

# Launch with: torchrun --nproc_per_node=<num_gpus> train.py
# torchrun sets LOCAL_RANK, RANK, and WORLD_SIZE for each process.
dist.init_process_group(backend="nccl")
local_rank = int(os.environ["LOCAL_RANK"])
torch.cuda.set_device(local_rank)

model = MyModel().cuda()
model = DDP(model, device_ids=[local_rank])
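DDP's per-backward gradient sync interacts with gradient accumulation: by default every micro-batch's backward triggers an all-reduce, even though only the last one in each group feeds an optimizer step. DDP's no_sync() context manager skips that all-reduce and just accumulates locally; the next backward outside it syncs the accumulated gradients. A minimal sketch, assuming a single-process gloo group on CPU so it runs without GPUs or torchrun (under torchrun you would keep the nccl setup above); the model, data, and file-based init are placeholders:

```python
import contextlib
import os
import tempfile

import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

# Single-process CPU group so the sketch runs anywhere.
init_file = os.path.join(tempfile.mkdtemp(), "pg_init")
dist.init_process_group("gloo", init_method=f"file://{init_file}", rank=0, world_size=1)

model = DDP(torch.nn.Linear(4, 1))  # CPU module: no device_ids
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
criterion = torch.nn.MSELoss()
accum_steps = 4
updates = 0

data = [(torch.randn(2, 4), torch.randn(2, 1)) for _ in range(8)]
for step, (x, y) in enumerate(data):
    is_update_step = (step + 1) % accum_steps == 0
    # Skip the gradient all-reduce on every micro-batch except the last in each group.
    sync_ctx = contextlib.nullcontext() if is_update_step else model.no_sync()
    with sync_ctx:  # forward and backward both belong inside the context
        loss = criterion(model(x), y) / accum_steps
        loss.backward()
    if is_update_step:
        optimizer.step()
        optimizer.zero_grad()
        updates += 1

dist.destroy_process_group()
```

With N GPUs this cuts gradient all-reduces per optimizer step from accum_steps to one.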

FullyShardedDataParallel shards parameters, gradients, and optimizer state across all GPUs. Each GPU holds only a fraction of the model at any time; parameters are gathered when needed for forward and backward and then resharded. Memory for parameters, gradients, and optimizer state drops roughly in proportion to the number of GPUs (activations are not sharded), which is what makes large-model training feasible at all:

fsdp.py
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import ShardingStrategy
 
model = MyModel().cuda()
model = FSDP(
    model,
    sharding_strategy=ShardingStrategy.FULL_SHARD,
    use_orig_params=True,  # required for torch.compile compatibility
)
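One catch with the snippet above: without an auto_wrap_policy, FSDP treats the whole model as a single unit, so forward all-gathers every parameter at once and the parameter-memory savings disappear while the model runs. Wrapping each repeated block as its own FSDP unit means only one block's parameters are gathered at a time. A sketch, assuming a transformer-style model built from a hypothetical Block class and an already-initialized NCCL process group; transformer_auto_wrap_policy comes from torch.distributed.fsdp.wrap:

```python
import functools

import torch
from torch import nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import ShardingStrategy
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy


class Block(nn.Module):
    """Hypothetical repeated layer standing in for a transformer block."""

    def __init__(self, dim: int = 64):
        super().__init__()
        self.ff = nn.Sequential(nn.Linear(dim, 4 * dim), nn.GELU(), nn.Linear(4 * dim, dim))

    def forward(self, x):
        return x + self.ff(x)


# Each Block becomes its own FSDP unit; the root FSDP wrapper gets the rest.
wrap_policy = functools.partial(transformer_auto_wrap_policy, transformer_layer_cls={Block})


def shard_model(model: nn.Module) -> FSDP:
    # Assumes the process group is already initialized (as in ddp.py)
    # and the current CUDA device has been set for this rank.
    return FSDP(
        model,
        auto_wrap_policy=wrap_policy,
        sharding_strategy=ShardingStrategy.FULL_SHARD,
        device_id=torch.cuda.current_device(),
        use_orig_params=True,
    )
```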

FSDP issues all_gather calls to reconstitute each unit's parameters during forward and again during backward, reshards them after use, and reduce-scatters gradients so each GPU keeps only its shard. On NVLink-connected GPUs, that communication is fast and the memory savings dominate. On PCIe-only machines or across nodes over Ethernet, FSDP can be slower than DDP because the communication volume is higher even though the memory footprint is lower.
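The higher volume can be put in numbers with the standard ring-collective cost model, where each GPU sends about (N-1)/N of the payload for an all-gather or a reduce-scatter, and twice that for an all-reduce. A back-of-the-envelope sketch under those assumptions (it ignores latency, compute overlap, and bucketing): DDP does one gradient all-reduce per step, while FULL_SHARD does an all-gather in forward, another in backward, and a reduce-scatter of gradients. The 1B-parameter model exchanged in bf16 on 8 GPUs is an arbitrary example:

```python
from typing import Dict


def per_gpu_traffic_gb(num_params: float, bytes_per_param: int, num_gpus: int) -> Dict[str, float]:
    """Approximate data each GPU sends per training step, in GB (ring collectives)."""
    payload = num_params * bytes_per_param
    ring = (num_gpus - 1) / num_gpus  # fraction of the payload each rank sends
    ddp = 2 * ring * payload                # all-reduce = reduce-scatter + all-gather
    fsdp_full_shard = 3 * ring * payload    # all-gather (fwd) + all-gather (bwd) + reduce-scatter
    return {"ddp": ddp / 1e9, "fsdp_full_shard": fsdp_full_shard / 1e9}


traffic = per_gpu_traffic_gb(num_params=1e9, bytes_per_param=2, num_gpus=8)
print(traffic)  # FULL_SHARD sends about 1.5x what DDP sends per step
```

The 1.5x ratio holds for any model size, which is why the interconnect, not the model, decides whether that extra traffic hurts.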

I reach for DDP first. If the model does not fit on a single GPU after AMP and gradient checkpointing, then FSDP. I have watched people jump to FSDP on a two-GPU machine to chase "maximum scalability" and end up with slower training than a single DDP run.
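Gradient checkpointing, mentioned above as the step before FSDP, trades compute for memory: activations inside a checkpointed segment are discarded after forward and recomputed during backward. A minimal sketch using torch.utils.checkpoint on a toy stack of layers (the class name and sizes are illustrative); use_reentrant=False is the variant the PyTorch documentation recommends:

```python
import torch
from torch import nn
from torch.utils.checkpoint import checkpoint


class CheckpointedStack(nn.Module):
    def __init__(self, dim: int = 32, depth: int = 4, use_checkpoint: bool = True):
        super().__init__()
        self.layers = nn.ModuleList(
            nn.Sequential(nn.Linear(dim, dim), nn.GELU()) for _ in range(depth)
        )
        self.use_checkpoint = use_checkpoint

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for layer in self.layers:
            if self.use_checkpoint and self.training:
                # Activations inside `layer` are not kept; backward reruns its forward.
                x = checkpoint(layer, x, use_reentrant=False)
            else:
                x = layer(x)
        return x


torch.manual_seed(0)
plain = CheckpointedStack(use_checkpoint=False)
ckpt = CheckpointedStack(use_checkpoint=True)
ckpt.load_state_dict(plain.state_dict())  # identical weights

x = torch.randn(8, 32)
plain(x).sum().backward()
ckpt(x).sum().backward()

# Same gradients; only memory use and compute differ.
print(torch.allclose(plain.layers[0][0].weight.grad, ckpt.layers[0][0].weight.grad))
```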

Profile before you guess

Every recommendation above comes with an asterisk: the actual gain depends on your model shape, your data, and your hardware. The profiler resolves the asterisk.

profiler.py
from torch.profiler import profile, ProfilerActivity, schedule
 
prof = profile(
    activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
    schedule=schedule(wait=1, warmup=2, active=5, repeat=1),
    record_shapes=True,
    profile_memory=True,
    with_stack=False,  # adds overhead; enable only to trace a specific op
)
 
prof.start()
for step, batch in enumerate(dataloader):
    if step >= 8:
        break
    with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
        loss = criterion(model(batch["input"]), batch["target"])
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    prof.step()
prof.stop()
 
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=20))
prof.export_chrome_trace("trace.json")

Open trace.json in chrome://tracing or the Perfetto UI. What to look for: GPU idle gaps between kernels (a data-pipeline bottleneck), long host-to-device copies (usually pin_memory is off, so copies come from pageable memory and cannot run asynchronously), and individual ops that are disproportionately expensive relative to what they should cost (usually a sign that a fused kernel is available but not firing).
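Labeling the phases of a step makes both the table and the trace much easier to read. A CPU-only sketch, assuming a toy linear model and random batches; torch.profiler.record_function is the real API, while the label names ("data", "forward", and so on) are arbitrary choices. Each label becomes its own row in key_averages() and a named span in the trace viewer:

```python
import torch
from torch.profiler import ProfilerActivity, profile, record_function

model = torch.nn.Linear(64, 64)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

with profile(activities=[ProfilerActivity.CPU]) as prof:  # add ProfilerActivity.CUDA on a GPU box
    for _ in range(3):
        with record_function("data"):
            batch = torch.randn(32, 64)
        with record_function("forward"):
            loss = model(batch).pow(2).mean()
        with record_function("backward"):
            loss.backward()
        with record_function("optimizer"):
            optimizer.step()
            optimizer.zero_grad()

labels = {evt.key for evt in prof.key_averages()}
print(sorted(label for label in labels if label in {"data", "forward", "backward", "optimizer"}))
```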

The profiler output has killed more "we need a better architecture" conversations than I can count. Again and again the answer has been a DataLoader running with num_workers=0, pin_memory missing, or a .item() call inside the training loop forcing a GPU sync on every single step.
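The DataLoader side of those fixes is a few arguments. A sketch with a synthetic TensorDataset standing in for a real dataset and a hypothetical make_loader helper; the default of four workers is a starting point to tune, not a recommendation, and pin_memory only matters when there is a GPU to copy to:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset


def make_loader(dataset, batch_size: int, num_workers: int = 4) -> DataLoader:
    use_cuda = torch.cuda.is_available()
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,             # load and collate off the main process
        pin_memory=use_cuda,                 # page-locked buffers enable async H2D copies
        persistent_workers=num_workers > 0,  # keep workers alive across epochs
    )


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dataset = TensorDataset(torch.randn(256, 16), torch.randint(0, 10, (256,)))
# num_workers=0 here only keeps the demo single-process; use workers in real training.
loader = make_loader(dataset, batch_size=32, num_workers=0)

for inputs, targets in loader:
    # non_blocking=True lets the copy overlap with compute when the source is pinned.
    inputs = inputs.to(device, non_blocking=True)
    targets = targets.to(device, non_blocking=True)
```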

The .item() call is the easiest of the three to miss. Any .item(), .cpu(), or .tolist() call on a CUDA tensor inside the hot path forces a GPU-CPU sync: the training loop stalls until the GPU finishes all queued work before continuing. Computing a metric with .item() inside the loop does this on every step. Move metric computation out of the hot path, accumulate with tensor ops, and call .item() once per epoch or once per logging interval.
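A minimal sketch of that pattern, assuming a toy model and a log_every interval (both illustrative): the running loss stays a tensor on the device, and the only host sync happens at each logging boundary:

```python
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = torch.nn.Linear(8, 1).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
log_every = 10
logged = []

running_loss = torch.zeros((), device=device)
for step in range(30):
    x = torch.randn(16, 8, device=device)
    loss = model(x).pow(2).mean()
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

    # detach() so the accumulator does not keep autograd graphs alive;
    # this add is queued on the device and does not force a sync.
    running_loss += loss.detach()

    if (step + 1) % log_every == 0:
        avg = (running_loss / log_every).item()  # one sync per logging interval
        logged.append(avg)
        running_loss.zero_()
```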

The sequence matters

Apply these in order. Compile without AMP leaves throughput on the table. AMP without the right dtype for your hardware introduces silent failures that look like training instability. Both without gradient accumulation means your effective batch size is whatever fits in GPU memory, not what training stability requires.

Profile last. The profiler tells you where the time is actually going — which is only useful once the easy wins are already in and you are looking for what remains.

Start from the top, measure after each change, and let the numbers tell you when to stop.
