Code instrumentation for timing workloads. Two scenarios:

1. Training loop — inject manual timing to report per-iteration latency, throughput (samples/sec), and data load time.
2. Standalone kernel/op — write CUDA event timing code with warmup, per-iteration statistics, and anti-pattern avoidance.

Also covers NVTX annotation for labeling profiler timelines.

NOT for: running or analyzing profiler tools (nsys, ncu, Nsight Systems, Nsight Compute), writing kernels (Triton, CuTe, CUDA), applying optimizations (CUDA Graphs, gradient checkpointing, fusion), or interpreting roofline/SOL% metrics.

Triggers: "measure throughput", "benchmark this function", "time my training loop", "samples per second", "NVTX annotate", "instrument my dataloader", "data load time", "kernel timing", "how do I time".
Pick ONE path based on the workload type:
| Workload | Approach | Section |
|---|---|---|
| Training loop | Manual torch.cuda.synchronize() + time.perf_counter() with warmup | Loop Workloads — Manual Timing |
| Single kernel or op | Write CUDA event benchmark (pre-allocate, warmup, event pairs) | Non-Loop Workloads — CUDA Event Benchmarking |
| Add timeline labels for nsys | Use @nvtx.annotate decorator or context manager | NVTX Reference |
Key principles:

- CPU wall-clock timers (`time.perf_counter()`) include host overhead and miss asynchronous GPU execution.
- `torch.cuda.synchronize()` adds 10-50us of overhead per call. Record CUDA events asynchronously instead and sync once at the end.
- Warmup must cover one-time costs such as JIT compilation (e.g. `cute.compile()`).

## Loop Workloads — Manual Timing

For training loops and iterative workloads, use manual `torch.cuda.synchronize()` + `time.perf_counter()` timing with warmup to measure per-iteration latency, throughput, and data load time.
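The synchronize-before-reading-the-clock rule exists because GPU kernel launches return to the host immediately. A CUDA-free sketch of the same effect, using a worker thread as a stand-in for the device (the thread pool and the 50 ms sleep are illustrative assumptions, not part of any real workload):

```python
import time
from concurrent.futures import ThreadPoolExecutor

# A worker thread stands in for the GPU: "launching" work returns at once,
# just as a CUDA kernel launch does; completion happens later.
pool = ThreadPoolExecutor(max_workers=1)

t0 = time.perf_counter()
fut = pool.submit(time.sleep, 0.05)            # launch 50 ms of "device" work
launch_ms = (time.perf_counter() - t0) * 1000  # naive timing: launch cost only

fut.result()                                   # the analogue of synchronize()
total_ms = (time.perf_counter() - t0) * 1000   # true end-to-end time

print(f"naive {launch_ms:.2f} ms vs synchronized {total_ms:.2f} ms")
pool.shutdown()
```

The naive reading is typically well under a millisecond while the synchronized one is around 50 ms; timing GPU code with `time.perf_counter()` alone and no synchronize makes the same error.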
Read the user's training script, understand the dataloader and loop structure, then inject timing code.
```python
import time
import statistics

import torch

WARMUP = 5
NUM_ITERS = 30
BATCH_SIZE = 128  # global batch size for throughput calculation

iter_times = []
data_times = []

for i, batch in enumerate(dataloader):
    if i >= WARMUP + NUM_ITERS:
        break
    t_data_end = time.perf_counter()  # batch just arrived from the dataloader
    torch.cuda.synchronize()
    t_start = time.perf_counter()

    # ... existing training loop body ...

    torch.cuda.synchronize()
    t_end = time.perf_counter()
    if i >= WARMUP:
        iter_ms = (t_end - t_start) * 1000
        iter_times.append(iter_ms)
        if i > 0:
            data_times.append((t_data_end - prev_iter_end) * 1000)
        print(f"[{i:04d}]: iter {iter_ms:.2f} ms, fps {BATCH_SIZE / (iter_ms / 1000):.2f}")
    prev_iter_end = t_end

print(f"Average: iter {statistics.mean(iter_times):.2f} ms, "
      f"fps {BATCH_SIZE / (statistics.mean(iter_times) / 1000):.2f}")
```
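The collected lists can be folded into a single summary. A minimal sketch (the `summarize` helper and its field names are illustrative, not part of the instrumentation above):

```python
import statistics

def summarize(iter_times_ms, data_times_ms, batch_size):
    """Reduce per-iteration timings to throughput and a data-load share."""
    mean_iter = statistics.mean(iter_times_ms)
    mean_data = statistics.mean(data_times_ms) if data_times_ms else 0.0
    return {
        "iter_ms": mean_iter,
        "data_ms": mean_data,
        "fps": batch_size * 1000 / mean_iter,       # samples per second
        "data_fraction": mean_data / mean_iter,     # share of iter spent waiting on data
    }

print(summarize([10.0, 12.0], [3.0, 3.0], batch_size=128))
```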
If `data / iter > 0.2`, data loading is a bottleneck.

Manual timing reports aggregate iteration timing — not a per-sub-phase breakdown (forward, backward, optimizer). When the user asks where time is spent within compute:
- add `torch.cuda.synchronize()` + `time.perf_counter()` around each sub-phase for a one-off diagnosis, OR
- use `nsys profile` for timeline visualization.

## Non-Loop Workloads — CUDA Event Benchmarking

For single kernels, one-shot inference, or standalone operations, write CUDA event benchmarking code directly.
```python
import torch

def benchmark(fn, warmup=50, iters=100):
    for _ in range(warmup):
        fn()
    torch.cuda.synchronize()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        fn()
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters  # ms per iteration
```
For per-iteration statistics (variance, outliers), record one event pair per iteration:

```python
import statistics

import torch

def benchmark_detailed(fn, warmup=50, iters=100):
    for _ in range(warmup):
        fn()
    torch.cuda.synchronize()
    starts = [torch.cuda.Event(enable_timing=True) for _ in range(iters)]
    ends = [torch.cuda.Event(enable_timing=True) for _ in range(iters)]
    for i in range(iters):
        starts[i].record()
        fn()
        ends[i].record()
    torch.cuda.synchronize()
    times = [starts[i].elapsed_time(ends[i]) for i in range(iters)]
    return {
        "mean_ms": statistics.mean(times),
        "median_ms": statistics.median(times),
        "std_ms": statistics.stdev(times) if len(times) > 1 else 0,
        "min_ms": min(times),
        "max_ms": max(times),
    }
```
| Anti-Pattern | Problem |
|---|---|
| `torch.cuda.synchronize()` before AND after each iteration | Adds ~10-50us overhead per iteration |
| `time.perf_counter()` for GPU timing | Measures CPU time, misses async GPU execution |
| Missing warmup | First iterations include JIT, clock ramp-up, context init |
| Allocating tensors inside measurement loop | Allocation overhead pollutes timing |
| Reporting only mean | Hides variance, outliers, bimodal distributions |
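To avoid the mean-only trap in the last row, report the shape of the distribution, not one number. A minimal sketch (the `report` helper and the p95 choice are illustrative):

```python
import statistics

def report(times_ms):
    """Summarize a timing distribution; the mean alone hides its shape."""
    qs = statistics.quantiles(times_ms, n=100)  # 99 cut points; qs[94] is p95
    return {
        "mean_ms": statistics.mean(times_ms),
        "p50_ms": statistics.median(times_ms),
        "p95_ms": qs[94],
        "min_ms": min(times_ms),
        "max_ms": max(times_ms),
    }

# A bimodal run: the 5.5 ms mean describes no iteration that actually occurred,
# while min/max and p95 expose the two modes.
bimodal = [1.0] * 50 + [10.0] * 50
print(report(bimodal))
```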
For additional benchmarking templates (CUDA Graph, CuTe DSL, Triton, Raw CUDA), see references/benchmarking-patterns.md.
## NVTX Reference

NVTX (NVIDIA Tools Extension) adds named annotations to profiler timelines. Use NVTX to label phases (forward, backward, optimizer) for readability in nsys — not for measurement.
```python
import nvtx

# Decorator — annotates every call
@nvtx.annotate("training_step", color="blue")
def training_step():
    ...

# Context manager — annotates a code block
with nvtx.annotate("data_loading", color="green"):
    batch = next(dataloader)
```
For NVTX domains, categories, payloads, and legacy API details, see references/nvtx-api.md.