Step-by-step tutorial for adding a new lightweight JIT CUDA kernel to sglang's jit_kernel module
This tutorial walks through adding a simple element-wise scale operation as a JIT kernel. We'll implement scale(x, factor) = x * factor to demonstrate the complete workflow.
Add a new operation that scales each element of a tensor by a scalar factor:
- Input: tensor x (CUDA) and scalar factor (float, passed at runtime)
- Output: x * factor (element-wise), allocated internally
- Supported dtypes: FP16 (torch.float16), BF16 (torch.bfloat16), FP32 (torch.float32)

JIT (jit_kernel) vs AOT (sgl-kernel)

- JIT (jit_kernel): prefer this first for kernels that do not depend on CUTLASS or another large C++ project. It is the default choice for lightweight kernels that benefit from rapid iteration and first-use compilation.
- AOT (sgl-kernel): prefer this when the kernel does depend on CUTLASS or another large C++ project, or when it should live in sgl-kernel/ and participate in the wheel build / torch op registration flow.
- Exception: kernels that depend on flashinfer, or on CUTLASS that is already provided through flashinfer, can still be implemented as jit_kernel.

Common abstractions in python/sglang/jit_kernel/include/sgl_kernel/

Always prefer these abstractions over raw CUDA primitives. They provide safety, readability, and consistency with the rest of the codebase.
Important include rule: for every #include <sgl_kernel/...> line, add a short trailing comment explaining why that header is included (for example // For TensorMatcher, SymbolicSize, SymbolicDevice). This matches the current JIT kernel style and keeps include usage self-documenting.
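The include rule above is easy to check mechanically. The following is a minimal sketch of such a check; `find_uncommented_includes` is a hypothetical helper written for this tutorial, not something shipped with sglang.

```python
import re

# Matches `#include <sgl_kernel/...>` and captures whatever follows the closing `>`.
_INCLUDE_RE = re.compile(r"^\s*#include\s*<sgl_kernel/[^>]+>(.*)$")


def find_uncommented_includes(source: str) -> list[int]:
    """Return 1-based line numbers of sgl_kernel includes without a trailing // comment."""
    missing = []
    for lineno, line in enumerate(source.splitlines(), start=1):
        match = _INCLUDE_RE.match(line)
        if match is None:
            continue  # not an sgl_kernel include, so the rule does not apply
        trailing = match.group(1).strip()
        # Require a non-empty explanation after the `//`.
        if not (trailing.startswith("//") and trailing[2:].strip()):
            missing.append(lineno)
    return missing


example = """\
#include <sgl_kernel/tensor.h>  // For TensorMatcher, SymbolicSize, SymbolicDevice
#include <sgl_kernel/vec.cuh>
#include <dlpack/dlpack.h>
"""
print(find_uncommented_includes(example))
```

Only the second line is reported: the first has an explanation, and the third is not an sgl_kernel header.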
utils.h — Host-side utilities

#include <sgl_kernel/utils.h>

- host::RuntimeCheck(cond, args...) — Assert a condition at runtime; throws PanicError with file/line info on failure. Prefer this over bare assert.
- host::Panic(args...) — Unconditionally throw a PanicError with a descriptive message.
- host::div_ceil(a, b) — Integer ceiling division (a + b - 1) / b.
- host::irange(n) / host::irange(start, end) — Range views for cleaner loops.
- host::pointer::offset(ptr, offsets...) — Byte-safe pointer arithmetic on void*. Use this instead of raw casts.

utils.cuh — Device-side utilities + LaunchKernel

#include <sgl_kernel/utils.cuh>

- Type aliases: fp16_t, bf16_t, fp32_t, fp8_e4m3_t, fp8_e5m2_t and their packed variants fp16x2_t, bf16x2_t, fp32x2_t, etc.
- SGL_DEVICE — Expands to __forceinline__ __device__. Use on all device functions.
- device::kWarpThreads — Constant 32.
- device::load_as<T>(ptr, offset) / device::store_as<T>(ptr, val, offset) — Type-safe loads/stores from void*.
- device::pointer::offset(ptr, offsets...) — Pointer arithmetic on device.
- host::LaunchKernel(grid, block, device_or_stream [, smem]) — RAII kernel launcher that:
  - Resolves the CUDA stream from a DLDevice via TVM-FFI automatically.
  - Launches the kernel and checks for errors via operator()(kernel, args...).
  - Supports .enable_pdl(bool) for PDL (Programmatic Dependent Launch, SM90+).
- host::RuntimeDeviceCheck(cudaError_t) — Check a CUDA error; throw on failure.

tensor.h — Tensor validation (TensorMatcher, Symbolic types)

#include <sgl_kernel/tensor.h>
This is the primary validation API for all kernel launchers. Use it to validate every tvm::ffi::TensorView argument.
- host::SymbolicSize{"name"} — A named symbolic dimension. Call .set_value(n) to pin it, .unwrap() to extract after verification.
- host::SymbolicDType — Symbolic dtype. Use .set_options<Ts...>() to restrict allowed types.
- host::SymbolicDevice — Symbolic device. Use .set_options<kDLCUDA>() to restrict to CUDA.
- host::TensorMatcher({dims...}) — Fluent builder for tensor validation:
  - .with_dtype<T>() — require a specific C++ type (e.g. fp16_t)
  - .with_dtype<T1, T2, ...>() — allow a set of types
  - .with_device<kDLCUDA>(device_sym) — require CUDA and bind the checked device to a SymbolicDevice
  - .with_strides({strides...}) — validate strides (omit to require contiguous)
  - .verify(tensor_view) — execute the check; throws PanicError with full context on failure; chainable (verify(a).verify(b) to check multiple tensors with the same shape)

Typical pattern:
auto N = SymbolicSize{"num_elements"};
auto device = SymbolicDevice{};
device.set_options<kDLCUDA>();
TensorMatcher({N})  //
    .with_dtype<fp16_t>()
    .with_device<kDLCUDA>(device)
    .verify(dst)
    .verify(src);  // same shape, dtype, device as dst
const size_t n = N.unwrap();
const DLDevice dev = device.unwrap();
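One way to read the pattern above: a symbolic dimension is bound by the first tensor that is verified, and every later tensor must agree with it. The toy Python model below sketches that bind-then-check behavior under this assumption. `ToySymbolicSize` and `ToyMatcher` are hypothetical names for illustration only; they are not the C++ classes and they ignore dtype, device, and strides.

```python
class ToySymbolicSize:
    """Toy stand-in for a named symbolic dimension (hypothetical, Python-only)."""

    def __init__(self, name: str):
        self.name = name
        self.value = None  # unbound until the first shape is verified

    def unwrap(self) -> int:
        if self.value is None:
            raise RuntimeError(f"{self.name} was never bound")
        return self.value


class ToyMatcher:
    """Binds symbolic dims on first use, then requires later shapes to agree."""

    def __init__(self, dims):
        self.dims = dims

    def verify(self, shape):
        if len(shape) != len(self.dims):
            raise RuntimeError(
                f"rank mismatch: expected {len(self.dims)}, got {len(shape)}"
            )
        for sym, actual in zip(self.dims, shape):
            if sym.value is None:
                sym.value = actual  # first sighting binds the symbol
            elif sym.value != actual:
                raise RuntimeError(f"{sym.name}: expected {sym.value}, got {actual}")
        return self  # returning self is what makes verify(a).verify(b) chainable


N = ToySymbolicSize("num_elements")
ToyMatcher([N]).verify((1024,)).verify((1024,))  # dst shape, then src shape
print(N.unwrap())
```

A second shape that disagrees with the first, such as (1024,) followed by (512,), raises in this model, which mirrors the "same shape as dst" comment in the C++ pattern.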
type.cuh — dtype_trait<T> and packed_t<T>

#include <sgl_kernel/type.cuh>

- dtype_trait<T> — Static trait struct for each scalar type. Provides:
  - dtype_trait<T>::from(value) — convert from another type (e.g. fp32_t → fp16_t)
  - dtype_trait<T>::abs/sqrt/rsqrt/exp/sin/cos(x) — type-dispatched unary math (primarily for fp32_t)
  - dtype_trait<T>::max/min(x, y) — type-dispatched binary math (primarily for fp32_t)
- packed_t<T> — Two-element packed alias: packed_t<fp16_t> = fp16x2_t, packed_t<bf16_t> = bf16x2_t, packed_t<fp32_t> = fp32x2_t. Use for vectorized loads/stores.
- device::cast<To, From>(value) — Type-safe cast using dtype_trait, e.g. cast<fp32x2_t, fp16x2_t>(v).

vec.cuh — Vectorized memory access (AlignedVector)

#include <sgl_kernel/vec.cuh>
- device::AlignedVector<T, N> — Aligned storage for N elements of type T. N must be a power of two, and sizeof(T)*N <= 32. Enables vectorized loads/stores for bandwidth efficiency. The API and codegen allow vectors of up to 256 bits; in practice, 128-bit is the portable default, while 256-bit vectorization is typically only viable on SM100+ and should be gated by an architecture check when needed.
  - .load(ptr, offset) — vectorized load from ptr[offset]
  - .store(ptr, offset) — vectorized store to ptr[offset]
  - .fill(value) — fill all lanes
  - operator[](i) — element access

tile.cuh — tile::Memory (strided memory access pattern)

#include <sgl_kernel/tile.cuh>
tile::Memory<T> is fundamentally a 1D cooperative accessor over a contiguous region.

- device::tile::Memory<T>::cta(blockDim.x) — Creates a tile accessor where each thread handles tid = threadIdx.x with stride tsize (for cta(blockDim.x), this is blockDim.x). Common for loops over a 1D array.
  - .load(ptr, offset) — loads ptr[tid + offset * tsize]
  - .store(ptr, val, offset) — stores to ptr[tid + offset * tsize]
  - .in_bound(n, offset) — boundary check

For a 2D tile, either flatten (row, col) into a linear tile index first, or compute the address manually with ptr[row * stride + col] using your thread/block coordinates.
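The index arithmetic described above can be sketched in plain Python. This is only a model of the addressing pattern; it assumes that the boundary check amounts to tid + offset * tsize < n, which the text does not spell out, and `tile_indices` and `flatten` are hypothetical helper names.

```python
def tile_indices(tid: int, tsize: int, n: int):
    """Yield the element indices one thread visits: tid + offset * tsize while in bounds."""
    offset = 0
    while tid + offset * tsize < n:  # assumed meaning of an in_bound(n, offset) check
        yield tid + offset * tsize
        offset += 1


def flatten(row: int, col: int, stride: int) -> int:
    """2D case: turn (row, col) into a linear index before using a 1D accessor."""
    return row * stride + col


tsize, n = 4, 10  # 4 threads cooperating over 10 contiguous elements
per_thread = {tid: list(tile_indices(tid, tsize, n)) for tid in range(tsize)}
print(per_thread)
print(flatten(2, 3, stride=8))
```

Every element index from 0 to n - 1 appears exactly once across the threads, and neighbouring threads touch neighbouring addresses at each step, which is the access pattern that coalesces well.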
math.cuh — Device math (device::math::)

#include <sgl_kernel/math.cuh>

- device::math::max/min<T>(a, b) — type-dispatched binary math via dtype_trait
- device::math::abs/sqrt/rsqrt/exp/sin/cos<T>(x) — type-dispatched unary math via dtype_trait

warp.cuh — Warp-level primitives

#include <sgl_kernel/warp.cuh>

- device::warp::reduce_sum<T>(value) — warp-level sum reduction via __shfl_xor_sync
- device::warp::reduce_max<T>(value) — warp-level max reduction

cta.cuh — CTA-level primitives

#include <sgl_kernel/cta.cuh>

- device::cta::reduce_max<T>(value, smem, min_value) — CTA-wide max using shared memory + warp reduction. The caller is responsible for a __syncthreads() afterwards if the result in smem[0] is needed.

atomic.cuh — Atomic operations

#include <sgl_kernel/atomic.cuh>

- device::atomic::max(float* addr, float value) — float atomic max (handles negative values correctly via bit tricks).

runtime.cuh — Occupancy and device info

#include <sgl_kernel/runtime.cuh>

- host::runtime::get_blocks_per_sm(kernel, block_dim) — max active blocks per SM (occupancy)
- host::runtime::get_sm_count(device_id) — number of SMs on the device
- host::runtime::get_cc_major(device_id) — compute capability major version

Persistent kernel pattern (cap blocks to SM count × occupancy):
static const uint32_t max_occ = runtime::get_blocks_per_sm(kernel, kBlockSize);
static const uint32_t num_sm = runtime::get_sm_count(device.unwrap().device_id);
const auto num_blocks = std::min(num_sm * max_occ, div_ceil(n, kBlockSize));
LaunchKernel(num_blocks, kBlockSize, device.unwrap())(kernel, params);
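The block-count arithmetic in the persistent pattern can be checked with a few lines of Python. The SM count (132) and occupancy (8 blocks per SM) below are illustrative assumptions, not values queried from a device, and `persistent_grid` is a hypothetical helper name.

```python
def div_ceil(a: int, b: int) -> int:
    """Integer ceiling division, same formula as host::div_ceil."""
    return (a + b - 1) // b


def persistent_grid(n: int, block_size: int, num_sm: int, blocks_per_sm: int) -> int:
    """Cap the grid at what the device can keep resident, but never exceed the work."""
    return min(num_sm * blocks_per_sm, div_ceil(n, block_size))


# Small input: only 4 blocks of work exist, so the cap does not apply.
print(persistent_grid(1_000, 256, num_sm=132, blocks_per_sm=8))
# Large input: 39063 blocks of work, capped at 132 * 8 resident blocks.
print(persistent_grid(10_000_000, 256, num_sm=132, blocks_per_sm=8))
```

With the cap in place, each block loops over several chunks of the input instead of the launch creating one block per chunk.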
(Optional) Generate a .clangd config for better IDE support:

python -m sglang.jit_kernel

Implement the CUDA kernel in jit_kernel/csrc/

Create python/sglang/jit_kernel/csrc/elementwise/scale.cuh.
The implementation fully uses the project abstractions described above:
#include <sgl_kernel/tensor.h> // For TensorMatcher, SymbolicSize, SymbolicDevice
#include <sgl_kernel/type.cuh> // For dtype_trait, fp16_t, bf16_t, fp32_t
#include <sgl_kernel/utils.h> // For RuntimeCheck, div_ceil
#include <sgl_kernel/utils.cuh> // For LaunchKernel, SGL_DEVICE
#include <sgl_kernel/vec.cuh> // For AlignedVector
#include <dlpack/dlpack.h>
#include <tvm/ffi/container/tensor.h>

namespace {

// ----------------------------------------------------------------
// Kernel: element-wise scale using vectorized 128-bit loads/stores
// T      = fp16_t | bf16_t | fp32_t
// kVecN  = number of elements per vector load (e.g. 8 for fp16)
// factor = runtime scale factor
// ----------------------------------------------------------------
template <typename T, int kVecN>
__global__ void scale_kernel(T* __restrict__ dst,
                             const T* __restrict__ src,
                             float factor,
                             uint32_t n_total) {
  using vec_t = device::AlignedVector<T, kVecN>;
  const uint32_t n_vecs = n_total / kVecN;

  // --- vectorised body ---
  const uint32_t vec_stride = blockDim.x * gridDim.x;
  for (uint32_t vi = blockIdx.x * blockDim.x + threadIdx.x;
       vi < n_vecs;
       vi += vec_stride) {
    vec_t v;
    v.load(src, vi);
#pragma unroll
    for (int i = 0; i < kVecN; ++i) {
      v[i] = static_cast<T>(static_cast<float>(v[i]) * factor);
    }
    v.store(dst, vi);
  }

  // --- scalar tail ---
  const uint32_t base = n_vecs * kVecN;
  const uint32_t scalar_stride = blockDim.x * gridDim.x;
  for (uint32_t i = blockIdx.x * blockDim.x + threadIdx.x;
       base + i < n_total;
       i += scalar_stride) {
    dst[base + i] = static_cast<T>(static_cast<float>(src[base + i]) * factor);
  }
}

// ----------------------------------------------------------------
// Launcher: validates tensors, selects vector width, launches kernel
// ----------------------------------------------------------------
template <typename T>
void scale(tvm::ffi::TensorView dst, tvm::ffi::TensorView src, float factor) {
  using namespace host;

  // 1. Validate input tensors with TensorMatcher
  SymbolicSize N = {"num_elements"};
  SymbolicDevice device_;
  device_.set_options<kDLCUDA>();
  TensorMatcher({N})  //
      .with_dtype<T>()
      .with_device<kDLCUDA>(device_)
      .verify(dst)
      .verify(src);  // same shape / dtype / device as dst

  const uint32_t n = static_cast<uint32_t>(N.unwrap());
  const DLDevice device = device_.unwrap();
  RuntimeCheck(n > 0, "scale: num_elements must be > 0, got ", n);

  // 2. Choose vector width for 128-bit loads (16 bytes)
  //    fp16/bf16: 8 elements × 2 bytes = 16 bytes
  //    fp32:      4 elements × 4 bytes = 16 bytes
  constexpr int kVecN = 16 / sizeof(T);
  const uint32_t n_work_items = div_ceil(n, static_cast<uint32_t>(kVecN));

  // 3. Launch
  constexpr uint32_t kBlockSize = 256;
  const uint32_t grid = div_ceil(n_work_items, kBlockSize);
  LaunchKernel(grid, kBlockSize, device)(
      scale_kernel<T, kVecN>,
      static_cast<T*>(dst.data_ptr()),
      static_cast<const T*>(src.data_ptr()),
      factor,
      n);
}

}  // namespace
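To see why the vectorised body plus the scalar tail cover every element exactly once, the two grid-stride loops can be replayed on the CPU. The sketch below is a plain-Python model of the loop structure only (no half-precision rounding, no real vector loads); `scale_reference` is a hypothetical name, and the block size is shrunk to 4 so the trace stays small.

```python
def div_ceil(a: int, b: int) -> int:
    return (a + b - 1) // b


def scale_reference(src, factor, vec_n, block_size=4):
    """Replay the kernel's two grid-stride loops sequentially, thread by thread."""
    n_total = len(src)
    dst = [None] * n_total
    # Same sizing as the launcher: one work item per vector, rounded up.
    grid = div_ceil(div_ceil(n_total, vec_n), block_size)
    n_vecs = n_total // vec_n
    stride = block_size * grid  # total number of threads in the launch
    base = n_vecs * vec_n       # first element not covered by a full vector
    for block in range(grid):
        for thread in range(block_size):
            gtid = block * block_size + thread
            # Vectorised body: vector index vi covers elements [vi*vec_n, (vi+1)*vec_n).
            for vi in range(gtid, n_vecs, stride):
                for i in range(vi * vec_n, (vi + 1) * vec_n):
                    assert dst[i] is None  # no element is written twice
                    dst[i] = src[i] * factor
            # Scalar tail: leftover elements past the last full vector.
            for i in range(gtid, n_total - base, stride):
                assert dst[base + i] is None
                dst[base + i] = src[base + i] * factor
    return dst


src = [float(i) for i in range(13)]  # 3 full vectors of 4, plus a tail of 1
print(scale_reference(src, 2.0, vec_n=4))
```

Because the stride equals the total thread count, each vector index and each tail index is owned by exactly one thread, so the result matches x * factor for any length, including lengths shorter than one vector.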
Key points:
- Include headers from sgl_kernel/ — not raw CUDA headers for anything already covered
- Add a short trailing // For ... explanation to every #include <sgl_kernel/...> line
- Use TensorMatcher for all tensor validation; never manually check shape/dtype/device
- Use AlignedVector for vectorised 128-bit loads/stores — significant bandwidth win
- Use LaunchKernel — it resolves the stream and checks errors automatically
- Use RuntimeCheck for runtime assertions with useful error messages
- Pass runtime scalars such as factor directly unless compile-time specialisation is genuinely required
- fp16_t / bf16_t / fp32_t are the project's type aliases (from utils.cuh)
- Use device::cast<To, From> or dtype_trait<T>::from(val) for cross-type conversions
- Use device::math:: functions for device math instead of bare __ intrinsics

Add the Python wrapper in jit_kernel/

Create python/sglang/jit_kernel/scale.py:
from __future__ import annotations

from typing import TYPE_CHECKING

import torch

from sglang.jit_kernel.utils import cache_once, load_jit, make_cpp_args

if TYPE_CHECKING:
    from tvm_ffi.module import Module


@cache_once
def _jit_scale_module(dtype: torch.dtype) -> Module:
    """Compile and cache the JIT scale module for a given dtype."""
    args = make_cpp_args(dtype)
    return load_jit(
        "scale",
        *args,
        cuda_files=["elementwise/scale.cuh"],
        cuda_wrappers=[("scale", f"scale<{args}>")],
    )


def scale(src: torch.Tensor, factor: float, out: torch.Tensor | None = None) -> torch.Tensor:
    """
    Element-wise scale: dst = src * factor.

    Supported dtypes: torch.float16, torch.bfloat16, torch.float32.

    Parameters
    ----------
    src    : CUDA tensor (FP16 / BF16 / FP32)
    factor : scale factor
    out    : optional pre-allocated output tensor (same shape/dtype as src)

    Returns
    -------
    Scaled tensor (dst = src * factor).
    """
    if not src.is_cuda:
        raise RuntimeError("src must be a CUDA tensor")
    if src.dtype not in (torch.float16, torch.bfloat16, torch.float32):
        raise RuntimeError(
            f"Unsupported dtype {src.dtype}. Supported: float16, bfloat16, float32"
        )
    if out is None:
        out = torch.empty_like(src)
    else:
        if out.shape != src.shape:
            raise RuntimeError("out shape must match src")
        if out.dtype != src.dtype:
            raise RuntimeError("out dtype must match src")
        if out.device != src.device:
            raise RuntimeError("out device must match src")

    # Keep the Python wrapper thin, but still enforce the basic preconditions
    # that the current JIT/FFI path does not reject safely on its own.
    module = _jit_scale_module(src.dtype)
    module.scale(out, src, factor)
    return out
Key points:
- Use cache_once — not functools.lru_cache (incompatible with torch.compile)
- load_jit first arg(s) form the unique build marker; same marker = same cached binary
- A runtime scalar such as factor should stay runtime unless the kernel truly needs templating
- cuda_wrappers: (export_name, kernel_symbol) — export_name is called from Python
- Keep basic validation in the Python wrapper (is_cuda, supported dtype, out metadata). In the current JIT/FFI path, invalid tensors are not always rejected safely before launch
- make_cpp_args(dtype, ...) converts torch.dtype to C++ type alias:

| torch.dtype | C++ type |
|---|---|
| torch.float16 | fp16_t |
| torch.bfloat16 | bf16_t |
| torch.float32 | fp32_t |
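The points about the build marker, the wrapper symbol, and per-dtype caching can be sketched without torch or a compiler. Everything below is a toy model: `CPP_TYPE`, `toy_jit_module`, and the returned dictionary are hypothetical stand-ins, and the real `cache_once`, `make_cpp_args`, and `load_jit` in sglang.jit_kernel.utils may behave differently in detail.

```python
# Mirrors the table above; keys are dtype names so the sketch runs without torch.
CPP_TYPE = {"float16": "fp16_t", "bfloat16": "bf16_t", "float32": "fp32_t"}

_modules = {}     # cache: one entry per dtype, filled on first use
build_log = []    # records every simulated "compilation"


def toy_jit_module(dtype_name: str) -> dict:
    """Simulate compile-once-per-dtype: build on first request, reuse afterwards."""
    if dtype_name not in _modules:
        cpp_type = CPP_TYPE[dtype_name]
        build_log.append(cpp_type)
        _modules[dtype_name] = {
            # Build marker: the kernel name plus template args identify the binary.
            "marker": ("scale", cpp_type),
            # (export_name, kernel_symbol): Python calls `scale`, C++ provides scale<T>.
            "wrappers": [("scale", f"scale<{cpp_type}>")],
        }
    return _modules[dtype_name]


first = toy_jit_module("float16")
again = toy_jit_module("float16")   # served from the cache, no second build
other = toy_jit_module("float32")   # different dtype, different marker
print(first["wrappers"], build_log)
```

Keeping factor out of the marker is what lets one compiled module per dtype serve every scale factor; templating on it would trigger a new build for each distinct value.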
To pass extra compiler flags to the JIT build, add extra_cuda_cflags to the load_jit call:

    return load_jit(
        "scale",
        *args,
        cuda_files=["elementwise/scale.cuh"],
        cuda_wrappers=[("scale", f"scale<{args}>")],
        extra_cuda_cflags=["-O3", "--use_fast_math"],
    )
If your kernel requires SM90+, raise a clear Python error before calling load_jit:
if torch.cuda.get_device_capability()[0] < 9:
    raise RuntimeError("This kernel requires SM90 (Hopper) or later")
JIT kernel tests live under python/sglang/jit_kernel/tests/. CI does not run pytest in that directory directly. The unified runner test/run_suite.py discovers every test_*.py there (and every bench_*.py under benchmark/), collects register_*_ci(...) calls by statically parsing each file’s AST, and executes the selected suite. Every test file must register at least one CUDA entry or the collector fails its sanity check.
- Per-commit (test/run_suite.py → PER_COMMIT_SUITES): JIT unit tests use stage-b-kernel-unit-1-gpu-large (see .github/workflows/pr-test-jit-kernel.yml: python3 run_suite.py --hw cuda --suite stage-b-kernel-unit-1-gpu-large).
- Nightly: nightly-kernel-1-gpu with --nightly — typically used with SGLANG_JIT_KERNEL_RUN_FULL_TESTS=1 in CI for expanded parameter grids (see python/sglang/jit_kernel/utils.py → should_run_full_tests / get_ci_test_range). Wired in .github/workflows/nightly-test-nvidia.yml (e.g. python3 run_suite.py --hw cuda --suite nightly-kernel-1-gpu --nightly --continue-on-error).

Registration pattern (module level, literal est_time and suite strings — required for AST parsing):
from sglang.test.ci.ci_register import register_cuda_ci
register_cuda_ci(est_time=30, suite="stage-b-kernel-unit-1-gpu-large")
# Optional second registration: same file also listed under the nightly kernel suite
# register_cuda_ci(est_time=120, suite="nightly-kernel-1-gpu", nightly=True)
Keep est_time and suite as literal values. run_suite.py collects them from the file AST, so computed values and helper wrappers can break CI discovery.
Use register_cuda_ci(..., disabled="reason") if the file must stay in-tree but should be skipped in CI (e.g. multi-GPU only).
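The literal-values requirement follows from how static collection works: the collector reads the syntax tree and never executes the file, so it can only see constants written directly in the call. The sketch below illustrates the idea with Python's ast module. It is not the actual run_suite.py logic; `collect_registrations` is a hypothetical helper that shows why a computed est_time cannot be collected.

```python
import ast


def collect_registrations(source: str) -> list[dict]:
    """Return the literal keyword args of module-level register_cuda_ci(...) calls."""
    found = []
    for node in ast.parse(source).body:  # module-level statements only
        if not (isinstance(node, ast.Expr) and isinstance(node.value, ast.Call)):
            continue
        call = node.value
        if not (isinstance(call.func, ast.Name) and call.func.id == "register_cuda_ci"):
            continue
        kwargs = {}
        for kw in call.keywords:
            # kw.arg is None for **kwargs; non-Constant nodes are names or expressions.
            if kw.arg is None or not isinstance(kw.value, ast.Constant):
                raise ValueError(f"line {node.lineno}: arguments must be literal keywords")
            kwargs[kw.arg] = kw.value.value
        found.append(kwargs)
    return found


good = 'register_cuda_ci(est_time=30, suite="stage-b-kernel-unit-1-gpu-large")\n'
bad = 'T = 30\nregister_cuda_ci(est_time=T, suite="stage-b-kernel-unit-1-gpu-large")\n'

print(collect_registrations(good))
try:
    collect_registrations(bad)
except ValueError as err:
    print("rejected:", err)
```

In the second source string the value of T is only known at run time, so a parser that never runs the file has nothing to record for est_time.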
Run like CI (from repo root):
cd test && python3 run_suite.py --hw cuda --suite stage-b-kernel-unit-1-gpu-large
For fast iteration you can still run pytest on a single file locally; CI coverage is via run_suite.py.
Create python/sglang/jit_kernel/tests/test_scale.py:
import pytest
import torch

from sglang.jit_kernel.scale import scale
from sglang.test.ci.ci_register import register_cuda_ci

register_cuda_ci(est_time=30, suite="stage-b-kernel-unit-1-gpu-large")


@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32])
@pytest.mark.parametrize("size", [1, 127, 128, 1024, 4097])  # cover tail remainder
@pytest.mark.parametrize("factor", [0.5, 1.0, 2.0, 3.0])
def test_scale_correctness(dtype, size, factor):
    src = torch.randn(size, dtype=dtype, device="cuda")
    out = scale(src, factor)
    expected = src * factor

    rtol, atol = (1e-5, 1e-6) if dtype == torch.float32 else (1e-2, 1e-2)
    torch.testing.assert_close(out, expected, rtol=rtol, atol=atol)


@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32])
def test_scale_out_param(dtype):
    src = torch.randn(1024, dtype=dtype, device="cuda")
    out = torch.empty_like(src)
    result = scale(src, 2.0, out=out)
    assert result is out
    torch.testing.assert_close(out, src * 2.0, rtol=1e-2, atol=1e-2)


def test_scale_cpu_error():
    src = torch.randn(128, dtype=torch.float16)  # CPU tensor
    with pytest.raises(RuntimeError, match="CUDA"):
        scale(src, 2.0)


def test_scale_unsupported_dtype():
    src = torch.randint(0, 10, (128,), dtype=torch.int32, device="cuda")
    with pytest.raises(RuntimeError, match="dtype"):
        scale(src, 2.0)


if __name__ == "__main__":
    import sys

    sys.exit(pytest.main([__file__, "-v", "-s"]))
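The test sizes are chosen so that both kernel paths run. A short calculation, assuming the 16-byte vector width used by the launcher, shows how each size splits into full vectors and a scalar tail. `split` is a hypothetical helper for this illustration.

```python
ITEMSIZE = {"float16": 2, "bfloat16": 2, "float32": 4}  # bytes per element


def split(size: int, itemsize: int, vec_bytes: int = 16) -> tuple[int, int]:
    """Return (number of full vectors, leftover tail elements) for one tensor."""
    vec_n = vec_bytes // itemsize  # 8 for 2-byte types, 4 for fp32
    n_vecs = size // vec_n
    return n_vecs, size - n_vecs * vec_n


for size in [1, 127, 128, 1024, 4097]:
    for name, itemsize in ITEMSIZE.items():
        n_vecs, tail = split(size, itemsize)
        print(f"size={size:5d} {name:9s} full_vectors={n_vecs:4d} tail={tail}")
```

Size 1 exercises the tail alone, 128 and 1024 exercise the vector path alone, and 127 and 4097 exercise both in the same launch.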
Benchmarks are bench_*.py files under python/sglang/jit_kernel/benchmark/. They are picked up by the same run_suite.py machinery as unit tests. Register them for stage-b-kernel-benchmark-1-gpu-large (PR JIT benchmark job: python3 run_suite.py --hw cuda --suite stage-b-kernel-benchmark-1-gpu-large).
Create python/sglang/jit_kernel/benchmark/bench_scale.py:
import itertools

import torch
import triton
import triton.testing

from sglang.jit_kernel.benchmark.utils import (
    DEFAULT_DEVICE,
    DEFAULT_DTYPE,
    get_benchmark_range,
    run_benchmark,
)
from sglang.jit_kernel.scale import scale as jit_scale
from sglang.test.ci.ci_register import register_cuda_ci

register_cuda_ci(est_time=6, suite="stage-b-kernel-benchmark-1-gpu-large")

SIZE_LIST = get_benchmark_range(
    full_range=[2**n for n in range(10, 20)],  # 1K … 512K elements
    ci_range=[4096, 65536],
)

configs = list(itertools.product(SIZE_LIST))


@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["size"],
        x_vals=configs,
        line_arg="provider",
        line_vals=["jit", "torch"],
        line_names=["SGL JIT Kernel", "PyTorch"],
        styles=[("blue", "-"), ("red", "--")],
        ylabel="us",
        plot_name="scale-performance",
        args={},
    )
)
def benchmark(size: int, provider: str):
    src = torch.randn(size, dtype=DEFAULT_DTYPE, device=DEFAULT_DEVICE)
    factor = 2.0

    if provider == "jit":
        fn = lambda: jit_scale(src, factor)
    else:
        fn = lambda: src * factor

    return run_benchmark(fn)


if __name__ == "__main__":
    benchmark.run(print_data=True)
Run locally:
python python/sglang/jit_kernel/benchmark/bench_scale.py
Run the benchmark suite the way CI does:
cd test && python3 run_suite.py --hw cuda --suite stage-b-kernel-benchmark-1-gpu-large
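The benchmark reports microseconds. For a memory-bound kernel like scale it can help to convert a timing into effective bandwidth, since each element is read once and written once. The helper below is a hypothetical post-processing sketch that is not part of the benchmark utilities, and the 4.0 µs figure is a made-up input chosen to show the arithmetic, not a measurement.

```python
def effective_bandwidth_gbps(num_elements: int, itemsize: int, time_us: float) -> float:
    """Bytes moved per second, in GB/s, for an out-of-place element-wise kernel."""
    bytes_moved = 2 * num_elements * itemsize  # one read of src + one write of dst
    return bytes_moved / (time_us * 1e-6) / 1e9


# Hypothetical timing: 65536 two-byte elements processed in 4.0 microseconds.
print(effective_bandwidth_gbps(65536, itemsize=2, time_us=4.0))
```

Comparing this number against the device's peak memory bandwidth gives a rough sense of how much headroom remains; at small sizes launch overhead usually dominates, so low figures there are expected.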
Troubleshooting

- No CI registry found in ... from run_suite.py: add a module-level register_cuda_ci(...) with literal est_time and suite (and optional nightly=True); starred args and non-literal values break AST collection
- JIT compilation problems: make sure the .cuh file is under python/sglang/jit_kernel/csrc/; reduce template argument combinations
- CUDA crashes or illegal memory accesses: rerun with CUDA_LAUNCH_BLOCKING=1; run under compute-sanitizer --tool memcheck python ...
- Benchmark timing: run_benchmark uses CUDA-graph-based timing by default

References

- docs/developer_guide/development_jit_kernel_guide.md
- test/run_suite.py — suite names, discovery of jit_kernel/tests/ and jit_kernel/benchmark/, execution entrypoint for CI
- python/sglang/test/ci/ci_register.py — register_cuda_ci and AST registration rules
- python/sglang/jit_kernel/utils.py — cache_once, load_jit, make_cpp_args, should_run_full_tests, get_ci_test_range
- python/sglang/jit_kernel/include/sgl_kernel/tensor.h — TensorMatcher, SymbolicSize/DType/Device
- python/sglang/jit_kernel/include/sgl_kernel/utils.cuh — type aliases, LaunchKernel, SGL_DEVICE
- python/sglang/jit_kernel/include/sgl_kernel/vec.cuh — AlignedVector
- python/sglang/jit_kernel/include/sgl_kernel/tile.cuh — tile::Memory
- python/sglang/jit_kernel/include/sgl_kernel/type.cuh — dtype_trait, packed_t, device::cast
- python/sglang/jit_kernel/include/sgl_kernel/math.cuh — device::math::
- python/sglang/jit_kernel/include/sgl_kernel/warp.cuh — warp::reduce_sum/max
- python/sglang/jit_kernel/include/sgl_kernel/cta.cuh — cta::reduce_max
- python/sglang/jit_kernel/include/sgl_kernel/atomic.cuh — atomic::max
- python/sglang/jit_kernel/include/sgl_kernel/runtime.cuh — occupancy / SM count helpers
- python/sglang/jit_kernel/csrc/add_constant.cuh — minimal runnable reference
- python/sglang/jit_kernel/csrc/elementwise/rmsnorm.cuh — real example using TensorMatcher + LaunchKernel + tile::Memory
- python/sglang/jit_kernel/csrc/elementwise/qknorm.cuh — real example using runtime::get_blocks_per_sm + persistent kernel pattern
- python/sglang/jit_kernel/benchmark/utils.py — benchmark helpers

Summary of files created

python/sglang/jit_kernel/csrc/elementwise/scale.cuh   # NEW: CUDA kernel
python/sglang/jit_kernel/scale.py                     # NEW: Python wrapper
python/sglang/jit_kernel/tests/test_scale.py          # NEW: Tests
python/sglang/jit_kernel/benchmark/bench_scale.py     # NEW: Benchmark