Apply CUDA Graphs to PyTorch workloads — API selection (torch.compile, PyTorch make_graphed_callables, TE make_graphed_callables, MCore CudaGraphManager, FullCudaGraphWrapper, manual torch.cuda.graph), code compatibility, capture workflows, dynamic pattern handling, and troubleshooting. Triggers: CUDA graph, torch.cuda.graph, make_graphed_callables, reduce-overhead, graph capture, graph replay, kernel launch overhead, CudaGraphManager, FullCudaGraphWrapper, full-iteration graph, stream capture.
CUDA Graphs capture a sequence of GPU operations once and replay them with minimal CPU overhead. This skill guides applying CUDA Graphs to PyTorch training and inference workloads using native PyTorch APIs, Transformer Engine, and Megatron-LM.
Reach for this skill when you encounter:
Do NOT use this skill for:
| Dependency | Version | Notes |
|---|---|---|
| PyTorch | >= 1.10 | torch.cuda.graph() available |
| CUDA | >= 11.0 | Graph update APIs |
| GPU | NVIDIA (any) | Required for CUDA |
| Nsight Systems | any | Optional, for profiling |
| APEX | any | Optional, for capturable optimizers |
| Transformer Engine | >= 2.2 | Optional, for FP8-aware graphing |
| Megatron-LM | core >= 0.14.0 | Optional, for CudaGraphManager / FullCudaGraphWrapper |
Choose the API based on your framework and performance needs.
| Situation | API | Workflow |
|---|---|---|
| Quick experiment, unknown graph boundaries | torch.compile(mode="reduce-overhead") | Workflow 2 |
| Training, need autograd, no FP8/PP | torch.cuda.make_graphed_callables() | Workflow 3 |
| Any PyTorch model, FP8 or PP support | TE make_graphed_callables | Workflow 4 |
| Megatron-LM, per-layer, automatic | MCore CudaGraphManager | Workflow 5 |
| Maximum perf, full-iteration capture | MCore FullCudaGraphWrapper | Workflow 6 |
| Full manual control, custom pipelines | torch.cuda.graph() | Workflow 7 |
Decision flowchart:
Strategy: Start with the highest-level API available for your framework. Move to lower-level APIs only if you need more control, hit limitations, or do not achieve the expected performance improvement.
Goal: Determine if CUDA Graphs will benefit your workload before investing effort.
nsys profile --cuda-graph-trace=graph python train.py
with torch.cuda.nvtx.range("forward"):
output = model(input)
Expected result: Identified bottleneck regions with low GPU occupancy between kernels. Proceed to the appropriate workflow from the API Selection Guide.
Goal: Automatic CUDA Graph capture with zero manual effort.
When to use: Quick experiment, unknown graph boundaries, already using
torch.compile.
Steps:
Decorate the step function with `@torch.compile(mode="reduce-overhead")`:
@torch.compile(mode="reduce-overhead")
def train_step(model, x, target, criterion):
output = model(x)
loss = criterion(output, target)
loss.backward()
return loss
nsys profile --cuda-graph-trace=graph python train.py
Check for graph breaks caused by `.item()`, `print()`, or data-dependent control flow. Fix these or escalate to Workflow 3+.

Trade-offs:
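The graph breaks mentioned above usually come from pulling tensor values back to Python. A minimal, CPU-runnable sketch of the fix (function names here are illustrative, not from any library): replace the `.item()`-plus-branch pattern with an on-device `torch.where` so `torch.compile` can capture the whole step.

```python
import torch

def step_with_break(x):
    # .item() forces a CPU sync and a graph break under torch.compile
    s = x.sum()
    if s.item() > 0:          # data-dependent Python branch: breaks the graph
        return x * 2
    return x

def step_graph_friendly(x):
    # Keep everything on-device: torch.where replaces the Python branch
    s = x.sum()
    return torch.where(s > 0, x * 2, x)

x = torch.tensor([1.0, -3.0, 4.0])
assert torch.equal(step_with_break(x), step_graph_friendly(x))
```

The two functions are numerically identical, but only the second compiles into a single graph.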
Goal: Training with autograd support. Separate forward/backward graphs.
When to use: Training with custom loops, non-FP8, need autograd.
Steps:
sample_input = torch.randn(batch_size, seq_len, hidden_size, device="cuda")
graphed_model = torch.cuda.make_graphed_callables(
model, (sample_input,), num_warmup_iters=3
)
Use `graphed_model` as a drop-in replacement in the training loop:
for data, target in dataloader:
optimizer.zero_grad()
output = graphed_model(data)
loss = criterion(output, target)
loss.backward()
optimizer.step()
If using AMP, run the graphed model under autocast with `cache_enabled=False`:
for data, target in dataloader:
optimizer.zero_grad()
with torch.amp.autocast("cuda", cache_enabled=False):
output = graphed_model(data)
loss = criterion(output, target)
loss.backward()
optimizer.step()
os.environ["TORCH_NCCL_ASYNC_ERROR_HANDLING"] = "0"
s = torch.cuda.Stream()
with torch.cuda.stream(s):
model = DistributedDataParallel(model)
torch.cuda.current_stream().wait_stream(s)
graphed_model = torch.cuda.make_graphed_callables(
model, (sample_input,), num_warmup_iters=11
)
Limitations:
Replay inputs must match `sample_args` exactly.

Goal: Per-callable graphing with FP8 support and pipeline parallelism.
When to use: FP8 training, PP with manual scheduling, non-Megatron models needing FP8, or any PyTorch model that needs FP8-aware CUDA Graphs.
Steps:
from transformer_engine.pytorch.graph import make_graphed_callables
from transformer_engine.pytorch.fp8 import fp8_autocast
sample_args = tuple(
(torch.randn(batch_size, seq_len, hidden_size, device="cuda"),)
for _ in range(num_callables * num_microbatches)
)
# Example: 2 chunks, 3 microbatches
layer_order = [1, 2, 1, 2, 1, 2, -2, -1, -2, -1, -2, -1]
graphed_layers = make_graphed_callables(
tuple(layers),
sample_args=sample_args,
fp8_enabled=True,
fp8_recipe=fp8_recipe,
fp8_weight_caching=True,
_order=layer_order, # None for no PP
)
Wrap replay in `fp8_autocast`:
with fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
for layer in graphed_layers[start:end]:
x = layer(x, is_first_microbatch=(mb_idx == 0))
# FP8 scaling auto-updated on fp8_autocast exit
optimizer.step()
Key points:
- Capture happens inside `make_graphed_callables()`.
- `_order`: The training loop must execute graphs in the same interleaved order as specified during capture.
- `fp8_autocast` required during replay: Without it, FP8 state is not properly configured.
- `fp8_weight_caching=True` caches FP8 weight quantization across microbatches; pass the `is_first_microbatch` kwarg to control when weights are requantized.

For full API details, see references/api-te-megatron.md.
Goal: Automatic per-layer graphing for Megatron-LM training.
When to use: Megatron-LM training, especially with PP > 1. Default choice for Megatron users.
Steps:
python pretrain_gpt.py \
--enable-cuda-graph \
--cuda-graph-num-warmup-steps 3
config = TransformerConfig(
enable_cuda_graph=True,
cuda_graph_num_warmup_steps=3,
)
Key points:
- Applies to `TransformerLayer` and `MambaLayer`.
- FP8 uses `fp8_autocast(..., _graph=True)` to skip per-layer amax reduction; reduction happens once after all backward graphs.
- Set `cuda_graph_share_io_buffers=True` to share I/O buffers between layers (requires no operations between layers).
- Set `cuda_graph_use_single_mempool=True` for a shared pool (higher graph count but may reduce fragmentation).

Goal: Maximum performance. Captures forward+backward for all microbatches as a single graph.
When to use: Maximum performance priority, static workloads, Megatron-LM training.
Steps:
python pretrain_gpt.py \
--enable-cuda-graph \
--cuda-graph-scope full_iteration \
--cuda-graph-warmup-steps 1 \
--te-rng-tracker \
--no-check-for-nan-in-loss-and-grad
Requires a fully static iteration (no `.item()`, no NaN check, no dynamic control flow).

Key points:
- `--te-rng-tracker` required: Standard RNG uses CPU scalars that cannot be captured; TE RNG uses device tensors compatible with graphs.
- `--no-check-for-nan-in-loss-and-grad` mandatory: NaN checking uses `.item()`, which requires a CPU-GPU sync, forbidden during capture.
- Capture happens at iteration `warmup_steps + 1`.

Goal: Full control over capture and replay. Custom pipelines, full-iteration capture without Megatron.
When to use: Need fine-grained control, non-Megatron full-iteration capture, custom pipelines.
Inference pattern:
static_input = torch.randn(batch_size, *shape, device="cuda")
s = torch.cuda.Stream()
with torch.cuda.stream(s):
for _ in range(3):
_ = model(static_input)
torch.cuda.current_stream().wait_stream(s)
g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
static_output = model(static_input)
Replay: copy new inputs in with `.copy_()`, clone outputs:
for data in loader:
static_input.copy_(data)
g.replay()
result = static_output.clone()
Full training pattern (fwd+bwd+optimizer in one graph):
model = MyModel().cuda()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = torch.nn.CrossEntropyLoss()
static_input = torch.randn(batch_size, *shape, device="cuda")
static_target = torch.randint(0, num_classes, (batch_size,), device="cuda")
# Warmup
s = torch.cuda.Stream()
with torch.cuda.stream(s):
for _ in range(3):
optimizer.zero_grad()
with torch.amp.autocast("cuda", cache_enabled=False):
out = model(static_input)
loss = criterion(out, static_target)
loss.backward()
torch.cuda.current_stream().wait_stream(s)
# Capture
g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
optimizer.zero_grad()
with torch.amp.autocast("cuda", cache_enabled=False):
static_output = model(static_input)
static_loss = criterion(static_output, static_target)
static_loss.backward()
# Replay loop
for data, target in loader:
static_input.copy_(data)
static_target.copy_(target)
g.replay()
optimizer.step()
DDP setup:
os.environ["TORCH_NCCL_ASYNC_ERROR_HANDLING"] = "0"
s = torch.cuda.Stream()
with torch.cuda.stream(s):
model = DistributedDataParallel(model)
# 11 warmup iterations for DDP
with torch.cuda.stream(s):
for _ in range(11):
out = model(static_input)
out.sum().backward()
torch.cuda.current_stream().wait_stream(s)
# Capture on the same side stream
with torch.cuda.graph(g):
static_output = model(static_input)
Memory pool sharing for multiple graphs:
g1 = torch.cuda.CUDAGraph()
with torch.cuda.graph(g1):
out1 = model_a(static_in_a)
# Second graph shares first graph's memory pool
g2 = torch.cuda.CUDAGraph()
with torch.cuda.graph(g2, pool=g1.pool()):
out2 = model_b(static_in_b)
Custom RNG registration:
gen = torch.cuda.default_generators[0]
g = torch.cuda.CUDAGraph()
g.register_generator_state(gen)
with torch.cuda.graph(g):
out = model(static_input) # RNG state properly captured
If manual capture becomes unwieldy, consider `torch.cuda.make_graphed_callables` (Workflow 3) for larger, fewer graphs, or TE `make_graphed_callables` (Workflow 4) for FP8.

These principles apply to all workflows. Code inside the captured region must satisfy three constraints.
Only GPU operations are captured. CPU-side code (Python logic, I/O, logging) executes during capture but is eliminated during replay.
Violations:
- `data = torch.load("file.pt")` won't reload on replay
- `tokens = tokenizer.encode(text)` won't re-tokenize
- `print(f"Step {i}")` won't print during replay
- `random.randint(0, 10)` won't regenerate
- `buffer.append(tensor)` won't populate during replay

Fix: Move all CPU-side operations outside the graphed region.
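One way to apply this fix is to split each iteration into a CPU-side prepare step and a tensor-only compute step. The sketch below uses hypothetical helper names and a fixed-address input buffer; only `graphable_compute` would be captured.

```python
import torch

def cpu_side_prepare(raw_batch, static_input):
    # Runs every iteration OUTSIDE the graph: I/O, Python logic, logging
    batch = torch.as_tensor(raw_batch, dtype=torch.float32)
    static_input.copy_(batch)        # feed the graph through a fixed address

def graphable_compute(static_input):
    # Pure tensor ops only -- this is the region that is safe to capture
    return static_input * 2 + 1

static_input = torch.empty(4)
cpu_side_prepare([0.0, 1.0, 2.0, 3.0], static_input)
out = graphable_compute(static_input)
assert torch.equal(out, torch.tensor([1.0, 3.0, 5.0, 7.0]))
```

During replay, only `cpu_side_prepare` runs in Python each iteration; the captured region sees fresh data purely through the `.copy_()` into the static buffer.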
No CPU-GPU synchronization inside the graph. The CPU queues work continuously without waiting for GPU results.
Violations:
- `.item()` to get scalar values
- `.cpu()` to move tensors for inspection
- `torch.cuda.synchronize()` or `stream.synchronize()`
- `print(tensor)` (implicitly syncs)

Fix: Invoke the perf-torch-sync-free skill for systematic detection and elimination of sync points. Use `torch.cuda.set_sync_debug_mode("warn")` to find hidden syncs.
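A common sync-free rewrite for loss logging, sketched here so it runs on CPU or GPU: accumulate the loss on-device and call `.item()` once after the loop instead of every step.

```python
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"

# Accumulate the running loss on-device; a single .item() after the
# loop replaces a per-step sync (which would also break capture).
running_loss = torch.zeros((), device=device)
for step in range(5):
    loss = torch.tensor(0.5, device=device)   # stand-in for criterion(...)
    running_loss += loss                       # no CPU-GPU sync here

avg = (running_loss / 5).item()                # one sync, outside the loop
assert abs(avg - 0.5) < 1e-6
```

The same pattern applies to accuracy counters and gradient norms: keep them as device tensors for the duration of the captured region.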
All operations, control flow, memory addresses, and shapes must be fixed across all replays.
Violations and fixes:
| Dynamic aspect | Fix |
|---|---|
| `if loss > threshold:` | `torch.where(condition, a, b)` |
| `input = new_tensor` (address changes) | Pre-allocate + `.copy_()` |
| Python scalars (lr, temperature) | GPU tensor + .fill_() |
| Variable batch size / sequence length | Padding or bucketing |
| MoE / dynamic routing | Partial graphing |
For detailed patterns, see references/patterns-dynamic.md.
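The padding/bucketing fix for variable shapes can be sketched as follows. The bucket sizes and helper name are assumptions for illustration; the point is that every replay then sees one of a few fixed shapes, each of which can have its own captured graph.

```python
import torch
import torch.nn.functional as F

BUCKETS = (128, 256, 512)   # assumed bucket sizes; one graph per bucket

def pad_to_bucket(x):
    """Pad a (batch, seq) tensor up to the smallest bucket that fits.
    Assumes seq <= max(BUCKETS)."""
    seq = x.shape[1]
    target = next(b for b in BUCKETS if b >= seq)
    return F.pad(x, (0, target - seq))   # zero-pad the last dim on the right

x = torch.ones(2, 200)
padded = pad_to_bucket(x)
assert padded.shape == (2, 256)
assert padded[:, 200:].abs().sum() == 0   # padding region is zeros
```

Remember to mask the padded positions (e.g., in attention or the loss) so they do not affect results.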
Verify every item before attempting capture:
- No `.item()`, `.cpu()`, `.numpy()`, or `print(tensor)` inside the graph
- No `torch.cuda.synchronize()` or `stream.synchronize()`
- No `if tensor_value:` -- use `torch.where()` instead
- Static inputs updated via `.copy_()`
- Python scalars updated via `.fill_()`
- Outputs `.clone()`d before the next replay
- `cache_enabled=False` with `torch.amp.autocast`
- Custom RNG registered via `graph.register_generator_state()`
- `graphsafe_get_state()` / `graphsafe_set_state()` for RNG
- DDP: `TORCH_NCCL_ASYNC_ERROR_HANDLING=0`, construct on side stream
- Side stream joined back via `torch.cuda.current_stream()`, not the default stream
- activation_checkpointing: `preserve_rng_state=False`
- No `torch.compile` functions inside manual capture without prior warmup

For the complete checklist with references, see references/patterns-compatibility.md.
Success indicators:
- `g.replay()` completes without errors
- Outputs match eager mode (`torch.allclose`)

Key metrics:
| Metric | How to Check |
|---|---|
| Correctness | torch.allclose(eager, graphed, rtol=1e-5) |
| Speedup | Wall-clock time comparison |
| GPU utilization | nvidia-smi or Nsight Systems timeline |
| Memory overhead | torch.cuda.memory_summary() |
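The correctness and speedup checks can be combined into a small harness. This is a minimal sketch, not a rigorous benchmark (no warmup passes, small iteration count); `verify_and_time` and its arguments are illustrative names, and on a real GPU run `graphed_fn` would wrap `g.replay()` plus the static-buffer copies.

```python
import time
import torch

def verify_and_time(eager_fn, graphed_fn, inputs, iters=10):
    """Check the graphed output against eager mode, then compare
    average wall-clock time per call for both variants."""
    ref, out = eager_fn(*inputs), graphed_fn(*inputs)
    assert torch.allclose(ref, out, rtol=1e-5), "graphed output diverged"

    def bench(fn):
        if torch.cuda.is_available():
            torch.cuda.synchronize()       # don't time queued async work
        t0 = time.perf_counter()
        for _ in range(iters):
            fn(*inputs)
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        return (time.perf_counter() - t0) / iters

    return bench(eager_fn), bench(graphed_fn)

# CPU stand-ins just to demonstrate the harness shape
x = torch.randn(8, 8)
t_eager, t_graphed = verify_and_time(lambda t: t @ t, lambda t: t @ t, (x,))
assert t_eager > 0 and t_graphed > 0
```

Run the comparison on representative batch shapes; a speedup that appears only on tiny inputs usually means the workload was launch-bound, which is exactly where graphs help.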
| Error | Cause | Fix |
|---|---|---|
| `StreamCaptureUnsupported` (900) | Sync op during capture (`.item()`, `.cpu()`) | Move sync outside graph |
| `StreamCaptureInvalidated` (901) | Background thread (e.g., pin_memory) | `capture_error_mode="thread_local"` |
| `StreamCaptureUnjoined` (904) | Side stream didn't rejoin capture stream | `capture_stream.wait_stream(side_stream)` |
| `StreamCaptureImplicit` (906) | AccumulateGrad on default stream | Warmup on side stream before capture |
| Illegal memory access | Input tensor freed/reassigned | Keep persistent ref, use .copy_() |
| Wrong numerical results | Dynamic behavior frozen at capture | See references/patterns-compatibility.md |
| OOM with multiple graphs | Pools can't share memory | pool=g1.pool() for sequential graphs |
| No speedup | Already GPU-bound or wrong capture scope | Profile with nsys first (Workflow 1) |
| FP8 scaling corruption | TE without fp8_autocast during replay | Wrap with fp8_autocast(enabled=True) |
| PP replay order mismatch | Wrong execution order during replay | Match _order / capture sequence exactly |
| FullCudaGraphWrapper capture fail | NaN check or sync enabled | --no-check-for-nan-in-loss-and-grad |
| RNG failure with FullCudaGraphWrapper | Standard RNG not capturable | --te-rng-tracker |
| DDP capture failure | Async error handling watchdog | TORCH_NCCL_ASYNC_ERROR_HANDLING=0 |
| DDP AccumulateGrad on default stream | DDP constructed on default stream | Construct DDP in side stream context |
| Autocast cache invalidation | Cached cast tensors freed on exit | cache_enabled=False |
For detailed troubleshooting, see references/troubleshooting.md.
Use this 3-tier lookup hierarchy -- start at Tier 1 and escalate only when needed.
You are reading it now. The workflows, compatibility checklist, and error table above cover the most common tasks. Search this file first before going deeper.
The references/ directory beside this file contains distilled reference
material -- API details, patterns, and troubleshooting pages.
How to search:
Grep for keywords in references/ -- headers are designed to be grep-friendly.

Available references:
- references/api-pytorch.md -- PyTorch CUDA Graph APIs (torch.cuda.graph, make_graphed_callables, torch.compile reduce-overhead)
- references/api-te-megatron.md -- TE make_graphed_callables, CudaGraphManager, FullCudaGraphWrapper implementations
- references/patterns-compatibility.md -- GPU-only, sync-free, and static principles with full checklist
- references/patterns-dynamic.md -- Dynamic control flow, tensors, scalars, shapes: workarounds and patterns
- references/troubleshooting.md -- Capture failures, numerical errors, memory issues, performance issues

If Tiers 1-2 do not answer the question, consult the original sources:
- https://docs.nvidia.com/dl-cuda-graph/latest/index.html
- https://docs.pytorch.org/docs/stable/notes/cuda.html (CUDA Graphs section)
- https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/index.html
- https://docs.nvidia.com/megatron-core/developer-guide/latest/index.html

Return to Tier 2 afterward and consider whether the answer should be distilled into the references directory for next time.
clip_grad_norm_ (PyTorch >= 1.13)