Scaffold a Triton kernel (triton_impl.py) with autotuning config and benchmark integration. Use when: adding Triton implementation, writing Triton kernel, Triton autotuning, triton jit kernel.
Create a Triton kernel implementation for an existing operator benchmark.
Prefer `scripts/add_operator.py` instead — it creates the Triton kernel along with all other files. The target file is `benchmarks/<op>/triton_impl.py`. The add_operator.py script creates triton_impl.py automatically:
python scripts/add_operator.py <op> --pytorch-op "torch.<op>(x)" \
--triton-body "tl.abs(x)"
Prerequisites: `benchmarks/<op>/` already exists with `config.py` and `pytorch_impl.py`, and Triton is installed (`pip install triton`). Conventions: name the public function `<op>_triton` (e.g., `abs_triton`, `softmax_triton`); autotune over `BLOCK_SIZE` (256, 512, 1024, 2048, 4096) and `num_warps`; cast indices to `tl.int64` before multiplying by a stride:
# BAD — int32 overflow for large shapes
row_start = input_ptr + row * stride_row
# GOOD — cast to int64 first
row_start = input_ptr + row.to(tl.int64) * stride_row
The function must accept and return a `torch.Tensor` with the same signature as `pytorch_impl`, and must support `torch.float16` and `torch.float32`. It is imported as `from benchmarks.<op>.triton_impl import <op>_triton`. Template:

"""Triton <op> implementation."""
import torch
import triton
import triton.language as tl
@triton.autotune(
    configs=[
        # Sweep power-of-two block sizes; larger blocks get more warps so
        # each program still has enough parallelism per element chunk.
        triton.Config({"BLOCK_SIZE": 256}, num_warps=4),
        triton.Config({"BLOCK_SIZE": 512}, num_warps=4),
        triton.Config({"BLOCK_SIZE": 1024}, num_warps=4),
        triton.Config({"BLOCK_SIZE": 2048}, num_warps=8),
        triton.Config({"BLOCK_SIZE": 4096}, num_warps=8),
    ],
    # Re-run autotuning whenever the total element count changes.
    key=["n_elements"],
)
@triton.jit
def _<op>_kernel(input_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    """Elementwise kernel: out[i] = op(in[i]) for all i < n_elements.

    One program instance handles one contiguous BLOCK_SIZE chunk of the
    flattened tensor.
    """
    pid = tl.program_id(0)
    # Cast to int64 BEFORE the multiply: pid * BLOCK_SIZE in int32 can
    # overflow for tensors larger than ~2^31 elements.
    offsets = pid.to(tl.int64) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE).to(tl.int64)
    # Mask guards the final partial block; masked lanes neither load nor store.
    mask = offsets < n_elements
    x = tl.load(input_ptr + offsets, mask=mask)
    result = tl.abs(x)  # <-- replace with actual computation
    tl.store(output_ptr + offsets, result, mask=mask)
def <op>_triton(x: torch.Tensor) -> torch.Tensor:
assert x.is_cuda, "Input must be a CUDA tensor"
output = torch.empty_like(x)
n_elements = x.numel()
def grid(meta):
return (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
_<op>_kernel[grid](x.reshape(-1), output.reshape(-1), n_elements)
return output
python -c "
import torch
from benchmarks.<op>.triton_impl import <op>_triton
from benchmarks.<op>.pytorch_impl import <op>_pytorch
x = torch.randn(1024, device='cuda', dtype=torch.float16)
assert torch.allclose(<op>_triton(x), <op>_pytorch(x), atol=1e-2)
print('Triton OK')
"
ruff check benchmarks/<op>/triton_impl.py