ONLY for OpenAI Triton (@triton.jit) kernel development. NEVER use for CUDA C++ kernels, TileIR, or profiling tools (ncu, nsys). The user's request must involve Triton explicitly. Covers Triton-specific patterns: fused elementwise, reductions (softmax, LayerNorm, RMSNorm), tiled GEMM with triton.autotune, and flash attention. Workflow: design, write, verify (with fast-path for explicit requests).
- kernel_fn, reference_fn, and get_inputs() exports for companion scripts.
- scripts/verify_kernel.py to validate against the reference.

Transcendental functions (tl.exp, tl.log, tl.math.erf, tl.math.tanh) require fp32 inputs.

    # WRONG -- compilation error or wrong results with fp16/bf16:
    result = tl.exp(x)

    # CORRECT -- cast to fp32, compute, cast back:
    x_fp32 = x.to(tl.float32)
    result = tl.exp(x_fp32).to(x.dtype)

Rule: any math function beyond basic arithmetic (+, -, *, /) requires fp32 cast in, original dtype cast out.
Additional precision constraints:

- tl.sigmoid() is unavailable in some Triton versions. Use 1.0 / (1.0 + tl.exp(-x_fp32)).
- Cast results back to x.dtype before tl.store; mismatches cause "Type mismatch, store Float32 to Float16".
- Use tl.float32 accumulators for tl.dot.
- tl.dot uses TF32 by default for fp32 inputs (same as torch.mm when TF32 matmul is enabled, i.e. torch.backends.cuda.matmul.allow_tf32 is True; stock PyTorch has defaulted this to False since 1.12). Do NOT add input_precision="ieee"; it is 3-8x slower. TF32 is the correct default. If verification fails due to TF32 precision (~0.01-0.1 abs diff), ensure reference_fn also uses TF32 (plain torch.mm with TF32 enabled, no allow_tf32=False).

Never call .item() in kernel wrappers. It forces a CPU-GPU sync (~50-100us per call).
| Pitfall | Fix |
|---|---|
| tensor.item() for seed | x.data_ptr() % (2**31) |
| torch.randint(...).item() | Use tensor metadata for pseudo-random seed |
| Allocating output every call | Accept pre-allocated output as parameter |
| Python loops calling kernel | Batch operations |
Triton uses C semantics (round toward zero) for // and %, NOT Python semantics (round toward negative infinity). This only matters when operands can be negative.
| Expression | Python | Triton/C |
|---|---|---|
| -7 // 2 | -4 | -3 |
| -7 % 2 | 1 | -1 |
Safe pattern: Ensure all index/offset values are non-negative. If negative values are possible, use (idx % BLOCK + BLOCK) % BLOCK.
See references/concepts-semantics.md for full rules and scalar-only exception.
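The difference can be reproduced in plain Python. The sketch below emulates C-style truncation with two hypothetical helpers, `c_div` and `c_mod` (they are not Triton APIs), and checks that the safe wrap-around pattern yields the same non-negative index under both semantics.

```python
def c_div(a: int, b: int) -> int:
    """Integer division that truncates toward zero, as C (and Triton) do."""
    q = abs(a) // abs(b)
    return q if (a < 0) == (b < 0) else -q


def c_mod(a: int, b: int) -> int:
    """Remainder matching C semantics: the result takes the sign of `a`."""
    return a - b * c_div(a, b)


def safe_wrap(idx: int, block: int) -> int:
    """The safe pattern (idx % BLOCK + BLOCK) % BLOCK, evaluated with C semantics."""
    return c_mod(c_mod(idx, block) + block, block)


print(-7 // 2, -7 % 2)              # Python semantics: -4 1
print(c_div(-7, 2), c_mod(-7, 2))   # C / Triton semantics: -3 -1
print(safe_wrap(-7, 4), -7 % 4)     # both give 1
```

For non-negative operands the two conventions agree, which is why keeping offsets non-negative is the simplest fix.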
- Grid size: triton.cdiv(n_elements, BLOCK_SIZE). With autotune, grid must be a lambda: lambda meta: (triton.cdiv(n, meta['BLOCK_SIZE']),).
- Decorator order: @triton.autotune (outermost) -> @triton.heuristics -> @triton.jit (innermost).
- reset_to_zero: Required for autotune on kernels that accumulate (e.g., matmul output). Without it, later configs see leftover values from earlier trials.

Fast path: If the user explicitly requests a Triton kernel (e.g., "Write a Triton kernel for X", "Implement softmax in Triton"), start at Phase 2. Only use Phase 0-1 when the request is ambiguous about whether Triton is appropriate.
Skip this phase if the user explicitly asks for a Triton kernel. Only use when the request is ambiguous (e.g., "optimize this operation").
Triton wins when 2+ operations can share registers instead of writing/reading global memory. Quick rules:
| Pattern | Decision |
|---|---|
| Single element-wise op (relu, sigmoid) | SKIP: PyTorch already optimal |
| Standalone matmul | SKIP: cuBLAS is optimized |
| Standard attention | SKIP: Use FlashAttention |
| Element-wise chain (2+ ops), reduction, matmul + epilogue | USE TRITON |
If SKIP, recommend the alternative and STOP. See references/operator-routing.md for edge cases.
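The register-sharing argument can be put in rough numbers. The sketch below is a back-of-the-envelope model, not a measurement: it assumes each unfused element-wise op reads its input from global memory and writes its output back once, while a fused kernel does one read and one write for the whole chain. The function name `global_traffic_bytes` is hypothetical.

```python
def global_traffic_bytes(n_elements: int, bytes_per_elem: int, n_ops: int, fused: bool) -> int:
    """Estimate global-memory traffic for a chain of unary element-wise ops.

    Assumption: every kernel launch reads one input tensor and writes one
    output tensor; caches and extra operands are ignored.
    """
    launches = 1 if fused else n_ops
    return launches * 2 * n_elements * bytes_per_elem


n = 1 << 20  # 1M fp16 elements
unfused = global_traffic_bytes(n, 2, n_ops=3, fused=False)
fused = global_traffic_bytes(n, 2, n_ops=3, fused=True)
print(unfused // fused)  # traffic ratio for a 3-op chain
```

Under this model a single op gains nothing from fusion, which is consistent with the SKIP rows in the table.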
From the user's request, identify: (1) operation type, (2) parallelization strategy, (3) input shapes and dtypes.
Pick the skeleton below that matches your operation. These skeletons are sufficient for element-wise, reduction, matmul, and fusion kernels — do NOT read reference files for these common patterns. Only consult references/ when implementing uncommon patterns (grouped GEMM, TMA, extern functions) or debugging issues.
Element-wise skeleton (GELU, dropout, fused ops on flat tensors):

    import triton
    import triton.language as tl

    @triton.jit
    def kernel(x_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
        pid = tl.program_id(0)
        offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
        mask = offsets < n_elements  # guard the tail block
        x = tl.load(x_ptr + offsets, mask=mask)
        result = x  # ... compute ... (placeholder: identity)
        tl.store(out_ptr + offsets, result, mask=mask)
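To see how the grid size and the mask interact at the tail of a tensor, the following pure-Python sketch mimics what each program computes. `cdiv` is a local stand-in for `triton.cdiv`, and `program_offsets` is a hypothetical helper; no Triton code runs here.

```python
def cdiv(n: int, d: int) -> int:
    """Ceiling division; local stand-in for triton.cdiv."""
    return (n + d - 1) // d


def program_offsets(pid: int, block_size: int, n_elements: int):
    """Offsets and mask that program `pid` would compute in the skeleton."""
    offsets = [pid * block_size + i for i in range(block_size)]
    mask = [off < n_elements for off in offsets]
    return offsets, mask


n_elements, block_size = 10, 4

# With autotune the grid is a callable that returns a tuple (note the trailing comma).
grid = lambda meta: (cdiv(n_elements, meta["BLOCK_SIZE"]),)
num_programs = grid({"BLOCK_SIZE": block_size})[0]

for pid in range(num_programs):
    print(pid, *program_offsets(pid, block_size, n_elements))
```

With 10 elements and a block size of 4, three programs are launched and the last one has two lanes masked off, so no load or store touches memory past the end of the tensor.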
Row-wise skeleton (softmax, LayerNorm, RMSNorm; one program per row):

    import triton
    import triton.language as tl

    @triton.jit
    def kernel(x_ptr, out_ptr, n_cols, BLOCK_SIZE: tl.constexpr):
        row_idx = tl.program_id(0)
        col_offsets = tl.arange(0, BLOCK_SIZE)  # BLOCK_SIZE must be >= n_cols
        mask = col_offsets < n_cols
        # Assumes contiguous rows (row stride == n_cols).
        x = tl.load(x_ptr + row_idx * n_cols + col_offsets, mask=mask, other=0.0)
        result = x  # ... reduce / normalize ... (placeholder: identity)
        tl.store(out_ptr + row_idx * n_cols + col_offsets, result, mask=mask)
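For a softmax built on this skeleton, the padding value matters: masked-off lanes should not influence the row maximum or the sum. The sketch below emulates one program in plain Python and is only an illustration; `row_softmax` and its `other` parameter are hypothetical names that mirror the `other=` argument of `tl.load`. It suggests padding with `-inf` for softmax, while `0.0` is suitable for plain sum reductions.

```python
import math


def row_softmax(row, block_size, other):
    """Emulate one program: load a padded block, then compute a stable softmax."""
    n_cols = len(row)
    # Masked load: lanes past n_cols receive the padding value `other`.
    x = [row[i] if i < n_cols else other for i in range(block_size)]
    row_max = max(x)
    num = [math.exp(v - row_max) for v in x]
    denom = sum(num)
    # Masked store: only the first n_cols lanes are written back.
    return [v / denom for v in num[:n_cols]]


row = [-3.0, -2.0, -1.0]
good = row_softmax(row, block_size=4, other=float("-inf"))
bad = row_softmax(row, block_size=4, other=0.0)
print(sum(good))  # sums to 1 (up to rounding)
print(sum(bad))   # less than 1: the padded lane leaked into the result
```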
Tiled matmul skeleton (GEMM with 2D tiling, grouped ordering, and autotune):

    import triton
    import triton.language as tl

    @triton.autotune(
        configs=[
            triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 64, 'GROUP_SIZE_M': 8}, num_warps=8, num_stages=3),
            triton.Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K': 32, 'GROUP_SIZE_M': 8}, num_warps=4, num_stages=4),
            triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 32, 'GROUP_SIZE_M': 8}, num_warps=4, num_stages=4),
            triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'BLOCK_K': 32, 'GROUP_SIZE_M': 8}, num_warps=4, num_stages=4),
        ],
        key=['M', 'N', 'K'],
    )
    @triton.jit
    def matmul_kernel(
        a_ptr, b_ptr, c_ptr, M, N, K,
        stride_am, stride_ak, stride_bk, stride_bn, stride_cm, stride_cn,
        BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
        GROUP_SIZE_M: tl.constexpr,
    ):
        pid = tl.program_id(0)
        num_m_blocks = tl.cdiv(M, BLOCK_M)
        num_n_blocks = tl.cdiv(N, BLOCK_N)
        # Grouped ordering for L2 cache locality
        num_pid_in_group = GROUP_SIZE_M * num_n_blocks
        group_id = pid // num_pid_in_group
        first_pid_m = group_id * GROUP_SIZE_M
        group_size_m = min(num_m_blocks - first_pid_m, GROUP_SIZE_M)
        pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
        pid_n = (pid % num_pid_in_group) // group_size_m
        # Pointers to the first A and B tiles of this output block
        offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
        offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
        offs_k = tl.arange(0, BLOCK_K)
        a_ptrs = a_ptr + offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak
        b_ptrs = b_ptr + offs_k[:, None] * stride_bk + offs_n[None, :] * stride_bn
        # Accumulate in fp32 regardless of the input dtype
        acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
        for k in range(0, tl.cdiv(K, BLOCK_K)):
            a_mask = (offs_m[:, None] < M) & (offs_k[None, :] < K)
            b_mask = (offs_k[:, None] < K) & (offs_n[None, :] < N)
            a = tl.load(a_ptrs, mask=a_mask, other=0.0)
            b = tl.load(b_ptrs, mask=b_mask, other=0.0)
            acc += tl.dot(a, b)
            # Advance to the next K tile
            a_ptrs += BLOCK_K * stride_ak
            b_ptrs += BLOCK_K * stride_bk
            offs_k += BLOCK_K
        c_mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)
        c_ptrs = c_ptr + offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn
        tl.store(c_ptrs, acc.to(c_ptr.dtype.element_ty), mask=c_mask)
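The grouped ordering is easier to follow outside the kernel. This sketch replays the same index arithmetic in plain Python for a small grid; `grouped_pid` is a hypothetical helper, not part of Triton. It shows consecutive program ids stepping through `GROUP_SIZE_M` row-blocks before moving to the next column-block, so programs scheduled close together work on a small set of A and B tiles.

```python
def grouped_pid(pid, num_m_blocks, num_n_blocks, group_size_m):
    """Map a linear program id to (pid_m, pid_n) using the grouped ordering."""
    num_pid_in_group = group_size_m * num_n_blocks
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * group_size_m
    # The last group may hold fewer than group_size_m row-blocks.
    size_m = min(num_m_blocks - first_pid_m, group_size_m)
    pid_m = first_pid_m + ((pid % num_pid_in_group) % size_m)
    pid_n = (pid % num_pid_in_group) // size_m
    return pid_m, pid_n


num_m, num_n, group = 5, 3, 2
tiles = [grouped_pid(p, num_m, num_n, group) for p in range(num_m * num_n)]
print(tiles[:6])  # first group: row-blocks 0-1 across all 3 column-blocks
```

Every output tile is still visited exactly once; only the visiting order changes.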
Create an output directory, then write the kernel file to {output_dir}/kernel.py.
The kernel file MUST include:
- @triton.jit decorated kernel function
- @triton.autotune for production kernels (see references/api-core.md)
- kernel_fn: alias to the wrapper function
- reference_fn(*args): PyTorch reference with identical signature
- get_inputs(): returns list of fresh CUDA tensors for testing/benchmarking

Concise example (fused GELU + dropout):

    import triton
    import triton.language as tl
    import torch


    @triton.autotune(
        configs=[
            triton.Config({'BLOCK_SIZE': 1024}, num_warps=4),
            triton.Config({'BLOCK_SIZE': 2048}, num_warps=8),
        ],
        key=['n_elements'],
    )
    @triton.jit
    def fused_gelu_dropout_kernel(
        x_ptr, out_ptr, n_elements, p, seed,
        BLOCK_SIZE: tl.constexpr,
    ):
        pid = tl.program_id(0)
        offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
        mask = offsets < n_elements
        x = tl.load(x_ptr + offsets, mask=mask)
        # GELU in fp32 (erf needs fp32), then cast back to the input dtype.
        x_fp32 = x.to(tl.float32)
        x = (0.5 * x_fp32 * (1.0 + tl.math.erf(x_fp32 * 0.7071067811865476))).to(x.dtype)
        # Inverted dropout: keep with probability 1 - p and rescale survivors.
        random = tl.rand(seed, offsets)
        x = tl.where(random > p, x / (1.0 - p), 0.0)
        tl.store(out_ptr + offsets, x, mask=mask)


    def fused_gelu_dropout_triton(x: torch.Tensor, p: float = 0.1) -> torch.Tensor:
        n_elements = x.numel()
        out = torch.empty_like(x)
        grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
        seed = (x.data_ptr() % (2**31)) ^ n_elements  # sync-free seed
        fused_gelu_dropout_kernel[grid](x, out, n_elements, p, seed)
        return out


    # --- Fixed contract (companion scripts rely on these names) ---
    kernel_fn = fused_gelu_dropout_triton


    def reference_fn(x, p=0.1):
        # Note: torch's RNG stream differs from tl.rand, so the dropout masks
        # are not expected to match element-wise.
        torch.manual_seed((x.data_ptr() % (2**31)) ^ x.numel())
        return torch.nn.functional.dropout(
            torch.nn.functional.gelu(x), p, training=True
        )


    def get_inputs():
        return [torch.randn(128 * 1024 * 1024, device="cuda")]
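The two formulas fused in the kernel can be checked on the CPU. The sketch below is a plain-Python illustration with hypothetical helper names (`gelu`, `inverted_dropout`); it uses Python's `random` module rather than `tl.rand`, so it reproduces the math but not the kernel's random stream.

```python
import math
import random


def gelu(x: float) -> float:
    """Exact (erf-based) GELU, the same formula as in the kernel."""
    return 0.5 * x * (1.0 + math.erf(x * 0.7071067811865476))


def inverted_dropout(values, p, seed):
    """Zero each value with probability p and rescale survivors by 1 / (1 - p)."""
    rng = random.Random(seed)
    return [v / (1.0 - p) if rng.random() > p else 0.0 for v in values]


activations = [gelu(1.0)] * 100_000
dropped = inverted_dropout(activations, p=0.1, seed=0)

# Rescaling keeps the mean roughly unchanged even though about 10% of values are zeroed.
print(gelu(1.0), sum(dropped) / len(dropped))
```

Because the surviving values are divided by `1 - p`, the expected value of each output equals its input, which is why no rescaling is needed at inference time.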
For more patterns (SiLU+mul, RMSNorm, linear+GELU, add+LayerNorm), see references/patterns-fusion.md. For GEMM patterns, see references/patterns-gemm.md.
Run the companion verification script:

    python scripts/verify_kernel.py {output_dir}/kernel.py --rtol 1e-3 --atol 1e-3

Output:

    {"correct": true, "max_abs_diff": 1.2e-7, "max_rel_diff": 3.4e-6, "details": "..."}

Stop if correct: false. Fix the kernel before benchmarking.
Tolerance guide:
| Dtype | rtol | atol | Notes |
|---|---|---|---|
| float16 | 1e-3 | 1e-3 | |
| bfloat16 | 1e-2 | 1e-2 | |
| float32 | 1e-5 | 1e-5 | Element-wise ops |
| float32 (matmul) | 1e-2 | 1e-1 | TF32 accumulation order differs between Triton tiles and cuBLAS |
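As a rough illustration of how `rtol` and `atol` combine, the sketch below applies the criterion used by `torch.allclose`, |actual - expected| <= atol + rtol * |expected|, to plain Python floats. The `TOLERANCES` dictionary copies the table above; the helper `within_tolerance` is hypothetical, and whether scripts/verify_kernel.py applies exactly this rule is an assumption.

```python
# (rtol, atol) per dtype, copied from the tolerance guide above.
TOLERANCES = {
    "float16": (1e-3, 1e-3),
    "bfloat16": (1e-2, 1e-2),
    "float32": (1e-5, 1e-5),
    "float32_matmul": (1e-2, 1e-1),
}


def within_tolerance(actual, expected, dtype):
    """Element-wise check: |actual - expected| <= atol + rtol * |expected|."""
    rtol, atol = TOLERANCES[dtype]
    return all(abs(a - e) <= atol + rtol * abs(e) for a, e in zip(actual, expected))


expected = [1.0, 100.0]
actual = [1.0005, 100.05]
print(within_tolerance(actual, expected, "float16"))  # passes the fp16 thresholds
print(within_tolerance(actual, expected, "float32"))  # fails the fp32 thresholds
```

The relative term scales with the magnitude of the reference value, so large outputs tolerate larger absolute differences than values near zero.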
Only benchmark if the user explicitly requests performance numbers. Skip this phase for correctness-focused requests.

    python scripts/benchmark_kernel.py {output_dir}/kernel.py

Output:

    {"kernel_time_ms": 0.45, "reference_time_ms": 1.23, "speedup": 2.73, "warmup_iters": 10, "benchmark_iters": 40}
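The benchmark output is plain JSON, so it can be consumed from another script. A minimal sketch, assuming the field names shown above and that `speedup` is the reference time divided by the kernel time (the sample numbers are consistent with that reading):

```python
import json

# Sample line copied from the output shown above.
raw = '{"kernel_time_ms": 0.45, "reference_time_ms": 1.23, "speedup": 2.73, "warmup_iters": 10, "benchmark_iters": 40}'

result = json.loads(raw)
recomputed = result["reference_time_ms"] / result["kernel_time_ms"]
print(f"reported {result['speedup']:.2f}x, recomputed {recomputed:.2f}x")
```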
The skeletons and principles above cover element-wise, reduction, matmul, and fusion kernels. Do NOT read reference files for these common patterns.
Only consult references/ when:
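- Implementing uncommon patterns (grouped GEMM, TMA, extern functions)
- Debugging issues
- Using an unfamiliar tl.* operation

How to search: Grep for your keyword across references/. Read only the file Grep points to.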
| File | When to use |
|---|---|
| references/api-core.md | Unfamiliar triton.autotune / triton.Config options |
| references/api-language.md | Unfamiliar tl.* operations |
| references/patterns-gemm.md | Grouped GEMM, persistent matmul, TMA, MX formats |
| references/patterns-advanced.md | Flash attention details, backward passes, libdevice |
| references/troubleshooting.md | Debug ops, interpreter mode, env vars |
| Error / Symptom | Cause | Fix |
|---|---|---|
| "Type mismatch, store Float32 to Float16" | Missing .to(x.dtype) before store | Cast fp32 result back |
| BLOCK_SIZE is not a constexpr | Block size passed as runtime value | Add : tl.constexpr annotation |
| shape mismatch in binary op | Tensor shapes don't broadcast | Check with tl.static_print; use [:, None] / [None, :] |
| Large diffs everywhere | Wrong dtype in tl.load | Check load dtype matches input |
| Matmul 3-8x slower than expected | input_precision="ieee" on tl.dot | Remove it; use TF32 default. Ensure reference_fn also uses TF32 |
| Matmul ~0.01-0.1 abs diff vs reference | TF32 vs IEEE mismatch | Use same precision in both kernel and reference (TF32 for both) |
| Diffs at boundaries | Missing mask | Add mask to all load/store ops |
| Random diffs | Race condition | Check atomics and ordering |
| NaN/Inf | Division by zero or fp16 overflow | Guard with epsilon; use tl.float32 accumulator |
| grid must be a tuple | Grid lambda returns int, not tuple | Return (value,) with trailing comma |
| expected constexpr in tl.arange | Non-constexpr argument | Both args of tl.arange(start, end) must be constexpr |
| triton.OutOfResources | Register/shared memory pressure | Reduce BLOCK_SIZE or num_stages |
| Kernel not updating after edit | Stale compilation cache | rm -rf ~/.triton/cache/ |
| Mismatched results vs PyTorch | C integer division semantics | Triton uses truncation; see references/concepts-semantics.md |
For extended error table, interpreter mode issues, and environment variables, see references/troubleshooting.md.
Stop and report failure if: