Two progressive pipeline stages that overlap global memory loads, LDS reads, and MFMA compute to hide memory latency. Apply Stage 1 first; verify it is effective before proceeding to Stage 2.
Both stages support CDNA3 and CDNA4. Each stage has one unified code template; the only arch-specific difference is the global load mechanism (a few marked lines to swap).
# Identify architecture
python3 -c "import torch; props = torch.cuda.get_device_properties(0); print(props.gcnArchName)"
# Find most recent compiled kernel ISA
find ~/.triton/cache -name "*.amdgcn" | xargs ls -lt | head -5
# Count stall signals in the hot loop
grep -c "s_waitcnt vmcnt(0)" <path>.amdgcn # Stage 1 signal: HBM/DMA latency exposed
grep -c "s_waitcnt lgkmcnt(0)" <path>.amdgcn # Stage 2 signal: LDS read latency exposed
| gcnArchName | Architecture | async_copy | Global load mechanism |
|---|---|---|---|
| gfx942 | CDNA3 | No | buffer_load → VGPR → ds_write |
| gfx950 | CDNA4 | Yes | async_copy.buffer_load_to_shared |
MI300X, MI308X, and MI325X are all gfx942 (CDNA3). MI350 is gfx950 (CDNA4).
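The table above can be encoded as a small dispatch helper. This is a sketch, not part of the kernel: `has_async_copy` is a hypothetical name, and the `gcnArchName` string is assumed to carry optional feature flags (e.g. `gfx942:sramecc+:xnack-`) that must be stripped before comparing.

```python
def has_async_copy(gcn_arch_name: str) -> bool:
    """True when the arch supports async_copy (CDNA4 / gfx950)."""
    # gcnArchName may include feature flags, e.g. "gfx942:sramecc+:xnack-"
    base = gcn_arch_name.split(":")[0]
    if base == "gfx950":   # CDNA4: MI350
        return True
    if base == "gfx942":   # CDNA3: MI300X / MI308X / MI325X
        return False
    raise ValueError(f"unsupported arch for this guide: {base}")

print(has_async_copy("gfx942:sramecc+:xnack-"))  # → False
```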
| Signal in amdgcn | Root cause | Fix |
|---|---|---|
| s_waitcnt vmcnt(0) before ds_write (CDNA3) | HBM load latency | Stage 1 |
| s_waitcnt vmcnt(0) before MFMA (CDNA4) | DMA latency | Stage 1 |
| s_waitcnt lgkmcnt(0) before MFMA | LDS read latency | Stage 2 |
If MFMA utilization is already > 85%, the kernel is compute-bound — skip both stages.
After applying bank-conflict-free LDS layouts, two independent latency sources remain:

1. Global load latency: a buffer_load/async_copy issues, and the data lands in LDS hundreds of cycles later.
2. LDS read latency: a ds_read issues, and the data lands in VGPRs tens of cycles later.

Stage 1 hides (1). Stage 2 hides (2). When both are applied:
Time → [global load for k+2] ────────────────────────────────────────▶
[ds_read for k+1] ──────────────▶
[MFMA for k] ──────────▶
On CDNA3, Stage 1 holds two full tiles in VGPRs simultaneously. This can reduce occupancy. After Stage 1, check:
# Inspect the compiled ISA
grep "NumVgprs:" <path>.amdgcn
# occupancy (waves/SIMD) = 512 // (ceil(NumVgprs / 8) * 8)   [gfx942: 512 VGPRs/SIMD, allocation granularity 8]
If VGPRs increased so much that occupancy dropped from 2→1 waves/SIMD and the kernel is slower, revert Stage 1 and document. Stage 2 can still be attempted if lgkmcnt stalls dominate without Stage 1 in place.
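The occupancy formula above, as a runnable sketch (`waves_per_simd` is a hypothetical helper; 512 VGPRs/SIMD and a granularity of 8 are the gfx942 values quoted in the comment):

```python
import math

def waves_per_simd(num_vgprs: int, vgprs_per_simd: int = 512, granularity: int = 8) -> int:
    # VGPR allocation is rounded up to the granularity before dividing.
    allocated = math.ceil(num_vgprs / granularity) * granularity
    return vgprs_per_simd // allocated

# Doubling tile residency from 128 to 256 VGPRs halves occupancy from 4 to 2;
# one VGPR past 256 drops it to 1.
print(waves_per_simd(128), waves_per_simd(256), waves_per_simd(257))  # → 4 2 1
```

This is why Stage 1 on CDNA3 can backfire: holding two tiles in VGPRs can push the allocation across one of these cliffs.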
Allocate two LDS buffers and pipeline the global load for tile k+1 so it runs
concurrently with the MFMA for tile k. The key difference from the baseline: each
iteration issues the next tile's global load before computing on the current tile,
instead of loading and computing serially.
The template is identical for CDNA3 and CDNA4; only the marked blocks differ.
nBuffers: gl.constexpr = 2
smemA = gl.allocate_shared_memory(
    a_ptr.type.element_ty, [nBuffers, BLOCK_M, BLOCK_K], layout=sharedLayoutA
)
smemB = gl.allocate_shared_memory(
    b_ptr.type.element_ty, [nBuffers, BLOCK_K, BLOCK_N], layout=sharedLayoutB
)
iterMax = gl.cdiv(K, BLOCK_K)
gl.assume(iterMax > 0)
g_idx = 0
# ── CDNA4 ──────────────────────────────────────────────────────────────────
gl.amd.cdna4.async_copy.buffer_load_to_shared(smemA.index(g_idx), a_base, a_offsets)
gl.amd.cdna4.async_copy.buffer_load_to_shared(smemB.index(g_idx), b_base, b_offsets)
gl.amd.cdna4.async_copy.commit_group()
# ── CDNA3 (replace the three lines above with these four) ──────────────────
vgpr_a = gl.amd.cdna3.buffer_load(ptr=a_ptr, offsets=a_offsets)
vgpr_b = gl.amd.cdna3.buffer_load(ptr=b_ptr, offsets=b_offsets)
smemA.index(g_idx).store(vgpr_a) # ds_write; s_waitcnt vmcnt(0) fires here
smemB.index(g_idx).store(vgpr_b)
# ───────────────────────────────────────────────────────────────────────────
a_base += BLOCK_K * stride_ak
b_base += BLOCK_K * stride_bk
for k in range(0, iterMax - 1):
    l_idx = k % 2      # LDS slot holding the tile to compute NOW
    g_idx = 1 - l_idx  # LDS slot to load the NEXT tile into
    # Issue global load for tile k+1 (non-blocking)
    # ── CDNA4 ──────────────────────────────────────────────────────────────
    gl.amd.cdna4.async_copy.buffer_load_to_shared(smemA.index(g_idx), a_base, a_offsets)
    gl.amd.cdna4.async_copy.buffer_load_to_shared(smemB.index(g_idx), b_base, b_offsets)
    gl.amd.cdna4.async_copy.commit_group()
    gl.amd.cdna4.async_copy.wait_group(1)  # allow 1 DMA in-flight while we compute
    # ── CDNA3 (replace the four lines above with these two) ────────────────
    vgpr_a = gl.amd.cdna3.buffer_load(ptr=a_ptr, offsets=a_offsets)
    vgpr_b = gl.amd.cdna3.buffer_load(ptr=b_ptr, offsets=b_offsets)
    # ───────────────────────────────────────────────────────────────────────
    # LDS read + MFMA for tile k (overlaps with in-flight global load above)
    a = smemA.index(l_idx).load(layout=dotOpLayoutA)
    b = smemB.index(l_idx).load(layout=dotOpLayoutB)
    acc = gl.amd.cdna3.mfma(a, b, acc)
    # Write tile k+1 into LDS (vmcnt stall hidden behind MFMA above)
    # ── CDNA4: no ds_write needed — async_copy already landed in LDS ───────
    # ── CDNA3 (add these two lines after MFMA) ─────────────────────────────
    smemA.index(g_idx).store(vgpr_a)
    smemB.index(g_idx).store(vgpr_b)
    # ───────────────────────────────────────────────────────────────────────
    a_base += BLOCK_K * stride_ak
    b_base += BLOCK_K * stride_bk
# ── CDNA4 ──────────────────────────────────────────────────────────────────
gl.amd.cdna4.async_copy.wait_group(0)
# ── CDNA3 (no extra wait needed — vmcnt(0) already fired in last loop iter) ──
# ───────────────────────────────────────────────────────────────────────────
l_idx = (iterMax - 1) % 2
a = smemA.index(l_idx).load(layout=dotOpLayoutA)
b = smemB.index(l_idx).load(layout=dotOpLayoutB)
acc = gl.amd.cdna3.mfma(a, b, acc)
Run this after implementing Stage 1. Do not proceed to Stage 2 unless Stage 1 passes both checks.
import torch, importlib.util
def load_kernel(path, name):
    spec = importlib.util.spec_from_file_location(name, path)
    mod = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(mod)
    return mod
baseline = load_kernel("kernel_baseline.py", "baseline")
stage1 = load_kernel("kernel_stage1.py", "stage1")
# Use the actual input shapes from the target workload (define B, M, N, K first)
x = torch.randn(B, M, K, dtype=torch.bfloat16, device="cuda")
w = torch.randn(N, K, dtype=torch.bfloat16, device="cuda")
c_ref = baseline.launcher(x, w)
c_new = stage1.launcher(x, w)
assert torch.allclose(c_ref, c_new, atol=1.0, rtol=0), \
    f"Stage 1 FAILED: max diff = {(c_ref - c_new).abs().max().item()}"
print("Stage 1 correctness OK")
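Alongside the correctness check, a quick wall-clock comparison feeds the outcomes table. A minimal sketch, assuming the launcher signatures from the script above; `bench` is a hypothetical helper, and for GPU launchers you must pass `sync=torch.cuda.synchronize` (CUDA events or Triton's `do_bench` give tighter numbers):

```python
import time

def bench(fn, *args, warmup=10, iters=50, sync=lambda: None):
    # Warm up (absorbs JIT compilation on the first call), then time `iters` calls.
    for _ in range(warmup):
        fn(*args)
    sync()                          # drain pending GPU work before starting the clock
    t0 = time.perf_counter()
    for _ in range(iters):
        fn(*args)
    sync()                          # wait for the last launch to finish
    return (time.perf_counter() - t0) * 1e3 / iters  # mean ms per call

# e.g.: bench(stage1.launcher, x, w, sync=torch.cuda.synchronize)
```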
Use /kernel-perf-analysis in Mode 3 after Stage 1 to measure MFMA efficiency
and confirm that vmcnt stalls are hidden. Mode 3 is triggered by mentioning "ATT",
"trace", "MFMA efficiency", "vmcnt", or "bottleneck":
/kernel-perf-analysis
Kernel file: <absolute path to stage1_kernel.py>
Mode hint: ATT trace MFMA efficiency vmcnt bottleneck
Label: stage1_global_prefetch
The skill collects ATT traces and prints:
MFMA efficiency : 72.4% (target > 80%; was ~57% before Stage 1)
Avg iteration cycles : 210.3
Time distribution : prologue=1.8%, loop=97.1%, epilogue=1.1%
Also check ISA stall counts:
stage1_isa=$(find ~/.triton/cache -name "*.amdgcn" | xargs ls -lt | head -1 | awk '{print $NF}')
echo "VGPRs:"; grep "NumVgprs:" "$stage1_isa"
echo "vmcnt(0) count:"; grep -c "s_waitcnt vmcnt(0)" "$stage1_isa"
echo "lgkmcnt(0) count:"; grep -c "s_waitcnt lgkmcnt(0)" "$stage1_isa"
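The two stall greps recur at every checkpoint, so a tiny helper keeps the comparisons uniform (a sketch; `stall_counts` is a hypothetical name):

```shell
stall_counts() {
  # Print both stall-signal counts for one .amdgcn file
  printf 'vmcnt(0):   %s\n' "$(grep -c 's_waitcnt vmcnt(0)' "$1")"
  printf 'lgkmcnt(0): %s\n' "$(grep -c 's_waitcnt lgkmcnt(0)' "$1")"
}
# Usage: stall_counts "$stage1_isa"
```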
| Outcome | Action |
|---|---|
| Faster and vmcnt(0) count dropped | ✅ Stage 1 succeeded — proceed to Stage 2 |
| Slower, CDNA3, VGPRs increased significantly | ❌ Revert. Check if occupancy dropped (waves/SIMD fell). Document and stop. |
| Slower, CDNA4 | ❌ Revert. Kernel is likely compute-bound or already has good HW prefetch. Stop. |
| Same speed, lgkmcnt(0) count is high | ⚠️ Stage 1 neutral — LDS latency dominates. Proceed to Stage 2 anyway. |
Architecture-independent: Stage 2 is purely a register-scheduling change on top of
Stage 1; the code it adds has no async_copy requirement of its own.
After Stage 1, ds_read may still stall before MFMA. Fix: issue ds_read for tile
k+1 at the end of iteration k, so by the time iteration k+1 reaches MFMA
the data is already in registers.
Stage 1 only: wait(1) → ds_read k → lgkmcnt stall → MFMA k
With Stage 2: MFMA k uses registers pre-loaded at end of iteration k-1 — no stall

This requires extending the prologue to load two tiles and pre-read the first into registers before the loop begins.
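The register rotation can be illustrated with a plain-Python analogue. No GPU semantics here: `tiles`, `compute`, and `pipelined` are stand-ins for LDS slots, MFMA, and the loop structure.

```python
def pipelined(tiles, compute):
    # Software-pipelined consume: pre-read tile 0 in the prologue, then each
    # iteration computes on the previously read value while reading the next
    # tile (on hardware, the read overlaps the compute).
    out = []
    a = tiles[0]                  # prologue: pre-read before the loop
    for k in range(len(tiles) - 1):
        out.append(compute(a))    # uses data read in the previous iteration
        a = tiles[k + 1]          # read tile k+1 for the next iteration
    out.append(compute(a))        # epilogue: final tile
    return out
```

The shape mirrors the template below: prologue pre-read, a loop whose compute never waits on the read it just issued, and an epilogue MFMA.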
The template is identical for CDNA3 and CDNA4; only the marked blocks differ (same mechanism as Stage 1).
iterMax = gl.cdiv(K, BLOCK_K)
gl.assume(iterMax > 1)
## --- Tile 0 → LDS[0] ---
g_idx = 0
# ── CDNA4 ──────────────────────────────────────────────────────────────────
gl.amd.cdna4.async_copy.buffer_load_to_shared(smemA.index(g_idx), a_base, a_offsets)
gl.amd.cdna4.async_copy.buffer_load_to_shared(smemB.index(g_idx), b_base, b_offsets)
gl.amd.cdna4.async_copy.commit_group()
# ── CDNA3 ──────────────────────────────────────────────────────────────────
vgpr_a = gl.amd.cdna3.buffer_load(ptr=a_ptr, offsets=a_offsets)
vgpr_b = gl.amd.cdna3.buffer_load(ptr=b_ptr, offsets=b_offsets)
smemA.index(g_idx).store(vgpr_a)
smemB.index(g_idx).store(vgpr_b)
# ───────────────────────────────────────────────────────────────────────────
a_base += BLOCK_K * stride_ak
b_base += BLOCK_K * stride_bk
## --- Tile 1 → LDS[1] ---
g_idx = 1
# ── CDNA4 ──────────────────────────────────────────────────────────────────
gl.amd.cdna4.async_copy.buffer_load_to_shared(smemA.index(g_idx), a_base, a_offsets)
gl.amd.cdna4.async_copy.buffer_load_to_shared(smemB.index(g_idx), b_base, b_offsets)
gl.amd.cdna4.async_copy.commit_group()
gl.amd.cdna4.async_copy.wait_group(1) # wait for tile 0 only; tile 1 still in-flight
# ── CDNA3 ──────────────────────────────────────────────────────────────────
vgpr_a = gl.amd.cdna3.buffer_load(ptr=a_ptr, offsets=a_offsets)
vgpr_b = gl.amd.cdna3.buffer_load(ptr=b_ptr, offsets=b_offsets)
smemA.index(g_idx).store(vgpr_a) # vmcnt stall fires here for tile 1
smemB.index(g_idx).store(vgpr_b)
# ───────────────────────────────────────────────────────────────────────────
a_base += BLOCK_K * stride_ak
b_base += BLOCK_K * stride_bk
## --- Pre-read tile 0 from LDS[0] into registers ---
# (tile 0 is guaranteed ready; tile 1 DMA/vmcnt may still be in-flight — that's fine)
a = smemA.index(0).load(layout=dotOpLayoutA)
b = smemB.index(0).load(layout=dotOpLayoutB)
for k in range(0, iterMax - 1):
    g_idx = k % 2      # LDS slot to write tile k+2 into (was consumed at iter k-1)
    l_idx = 1 - g_idx  # LDS slot holding tile k+1 (ready to read)
    ## MFMA on pre-loaded registers — no lgkmcnt stall
    acc = gl.amd.cdna3.mfma(a, b, acc)
    ## Issue global load for tile k+2 (non-blocking; skipped on the last loop iteration)
    if k < iterMax - 2:
        # ── CDNA4 ──────────────────────────────────────────────────────────
        gl.amd.cdna4.async_copy.buffer_load_to_shared(smemA.index(g_idx), a_base, a_offsets)
        gl.amd.cdna4.async_copy.buffer_load_to_shared(smemB.index(g_idx), b_base, b_offsets)
        gl.amd.cdna4.async_copy.commit_group()
        gl.amd.cdna4.async_copy.wait_group(1)  # tile k+1 has landed; tile k+2 stays in-flight
        # ── CDNA3 (replace the four lines above with these two) ────────────
        vgpr_a = gl.amd.cdna3.buffer_load(ptr=a_ptr, offsets=a_offsets)
        vgpr_b = gl.amd.cdna3.buffer_load(ptr=b_ptr, offsets=b_offsets)
        # ───────────────────────────────────────────────────────────────────
        a_base += BLOCK_K * stride_ak
        b_base += BLOCK_K * stride_bk
    else:
        # ── CDNA4 only (CDNA3: delete this else branch) ────────────────────
        gl.amd.cdna4.async_copy.wait_group(0)  # final tile must land before its ds_read
    ## Write tile k+2 into LDS (vmcnt stall hidden behind MFMA above)
    if k < iterMax - 2:
        # ── CDNA3 only ─────────────────────────────────────────────────────
        smemA.index(g_idx).store(vgpr_a)
        smemB.index(g_idx).store(vgpr_b)
        # ── CDNA4: async_copy already landed tile k+2 in LDS — nothing to do
    ## ds_read tile k+1 into registers for next MFMA (overlaps with ds_write above)
    a = smemA.index(l_idx).load(layout=dotOpLayoutA)
    b = smemB.index(l_idx).load(layout=dotOpLayoutB)
## Final MFMA — a, b were pre-loaded at the end of the last loop iteration
acc = gl.amd.cdna3.mfma(a, b, acc)
Run this after implementing Stage 2. Compare against the Stage 1 kernel (not the original baseline) to isolate Stage 2's contribution.
stage2 = load_kernel("kernel_stage2.py", "stage2")  # reuse load_kernel, x, w from the Stage 1 script
c_s1 = stage1.launcher(x, w)
c_s2 = stage2.launcher(x, w)
assert torch.allclose(c_s1, c_s2, atol=1.0, rtol=0), \
    f"Stage 2 FAILED: max diff = {(c_s1 - c_s2).abs().max().item()}"
print("Stage 2 correctness OK")
Use /kernel-perf-analysis in Mode 3 after Stage 2 to confirm that lgkmcnt
stalls are now hidden and MFMA efficiency improved further. Compare against the
Stage 1 ATT result:
/kernel-perf-analysis
Kernel file: <absolute path to stage2_kernel.py>
Mode hint: ATT trace MFMA efficiency lgkmcnt bottleneck
Label: stage2_local_prefetch
Expected output showing improvement over Stage 1:
MFMA efficiency : 84.1% (was ~72% after Stage 1; target > 80%)
Avg iteration cycles : 178.6
Time distribution : prologue=1.5%, loop=97.9%, epilogue=0.6%
Also check ISA stall counts to confirm lgkmcnt(0) dropped:
stage2_isa=$(find ~/.triton/cache -name "*.amdgcn" | xargs ls -lt | head -1 | awk '{print $NF}')
echo "VGPRs:"; grep "NumVgprs:" "$stage2_isa"
echo "vmcnt(0) count:"; grep -c "s_waitcnt vmcnt(0)" "$stage2_isa"
echo "lgkmcnt(0) count:"; grep -c "s_waitcnt lgkmcnt(0)" "$stage2_isa"
# Confirm lgkmcnt(0) count dropped compared to Stage 1
| Outcome | Action |
|---|---|
| MFMA efficiency > 80% and lgkmcnt(0) count dropped | ✅ Stage 2 succeeded — keep it |
| Slower or same speed | ❌ Revert to Stage 1. Compiler already schedules ds_read well, or kernel is MFMA-bound. Document. |
| VGPRs increased, CDNA3 only | ⚠️ Check occupancy. If waves/SIMD dropped, revert. |
| Stage | Hides | Mechanism | Expected Speedup |
|---|---|---|---|
| Stage 1 (CDNA4) | DMA latency (~200–800 cy) | wait_group(1) overlaps DMA with MFMA | 15–40% |
| Stage 1 (CDNA3) | HBM latency (~200–800 cy) | buffer_load into VGPR, vmcnt hidden behind MFMA | 5–25% (VGPR-dependent) |
| Stage 2 (both) | LDS read latency (~40–100 cy) | ds_read one iter ahead; MFMA uses pre-loaded registers | 5–20% additional |
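If the two stages compose roughly multiplicatively (an assumption; in practice they overlap inside the same loop body, so the true gain can be less), the table's ranges bound the combined CDNA4 speedup:

```python
# Ranges taken from the summary table above
s1 = (1.15, 1.40)   # Stage 1 on CDNA4
s2 = (1.05, 1.20)   # Stage 2 additional
print(f"combined: {s1[0] * s2[0]:.2f}x – {s1[1] * s2[1]:.2f}x")  # → combined: 1.21x – 1.68x
```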