Guide for implementing CUDA or CPU JIT kernels in mllm-kernel. Use when the user asks to create, add, or implement a new kernel in mllm-kernel.
mllm-kernel uses a JIT (Just-In-Time) compilation system built on tvm_ffi. Kernels are written in C++20 (.cuh for CUDA, .cpp for CPU), validated at runtime via TensorMatcher, and exposed to Python through a @jit decorator. No pre-compilation is needed -- kernels compile on first call and are cached at ~/.cache/mllm_kernel/.
For a kernel named my_kernel:
mllm-kernel/
  mllm_kernel/
    cuda/
      csrc/my_kernel.cuh            # CUDA kernel implementation
      jit/my_kernel.py              # Python JIT wrapper
      jit/__init__.py               # Add export here
    cpu/
      csrc/my_kernel.cpp            # CPU kernel implementation (Highway SIMD)
      include/mllm_kernel/cpu/
        my_kernel.hpp               # CPU SIMD body (NO #pragma once)
      jit/my_kernel.py              # Python JIT wrapper
      jit/__init__.py               # Add export here
  tests/test_my_kernel.py           # Pytest correctness tests
  benchmarks/bench_my_kernel.py     # Profiler benchmark vs PyTorch reference
Create mllm_kernel/cuda/csrc/my_kernel.cuh:
#pragma once

#include <mllm_kernel/tensor.hpp>   // TensorMatcher, SymbolicSize, SymbolicDevice, SymbolicDType
#include <mllm_kernel/utils.hpp>    // RuntimeCheck, Panic, div_ceil
#include <mllm_kernel/utils.cuh>    // LaunchKernel, fp16_t, bf16_t, PDL helpers

#include <dlpack/dlpack.h>
#include <tvm/ffi/container/tensor.h>

#include <cstdint>

namespace {

// ---------------------------------------------------------------------------
// 1. Parameter struct (trivially copyable, passed to kernel by value)
// ---------------------------------------------------------------------------
struct MyKernelParams {
  const float* __restrict__ input;
  float* __restrict__ output;
  int32_t num_elements;
};

// ---------------------------------------------------------------------------
// 2. CUDA kernel
// ---------------------------------------------------------------------------
__global__ void my_kernel(const MyKernelParams params) {
  const int idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (idx >= params.num_elements) return;
  params.output[idx] = params.input[idx] * 2.0f;
}

// ---------------------------------------------------------------------------
// 3. Host-side launcher (entry point for TVM FFI binding)
// ---------------------------------------------------------------------------
struct MyKernel {
  static void run(tvm::ffi::TensorView input, tvm::ffi::TensorView output) {
    using namespace mllm_kernel::host;

    // --- Validate tensors ---
    SymbolicSize N{"num_elements"};
    SymbolicDevice device;

    (void)TensorMatcher({N})
        .with_dtype<float>()
        .with_device<kDLCUDA>(device)
        .verify(input);

    (void)TensorMatcher({N})
        .with_dtype<float>()
        .with_device(device)
        .verify(output);

    const int64_t n = N.unwrap();
    RuntimeCheck(n > 0, "num_elements must be positive, got ", n);

    // --- Build params ---
    MyKernelParams params{
        .input = static_cast<const float*>(input.data_ptr()),
        .output = static_cast<float*>(output.data_ptr()),
        .num_elements = static_cast<int32_t>(n),
    };

    // --- Launch ---
    constexpr int kBlock = 256;
    const int grid = static_cast<int>(div_ceil(n, kBlock));
    LaunchKernel(grid, kBlock, device.unwrap())(my_kernel, params);
  }
};

}  // namespace
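The launch configuration above rounds the grid size up with div_ceil and relies on the kernel's bounds check to discard the surplus threads in the last block. The following is a minimal host-only sketch of that indexing in plain Python. The names `div_ceil` and `simulate_launch` here are illustrative stand-ins written for this example, not the mllm-kernel helpers themselves.

```python
from typing import List


def div_ceil(a: int, b: int) -> int:
    """Integer ceiling division for a >= 0 and b > 0."""
    return (a + b - 1) // b


def simulate_launch(num_elements: int, block: int = 256) -> List[int]:
    """Count how many times each element index would be written by the launch."""
    grid = div_ceil(num_elements, block)
    writes = [0] * num_elements
    for block_idx in range(grid):
        for thread_idx in range(block):
            idx = block_idx * block + thread_idx
            if idx >= num_elements:
                # Mirrors the kernel's early return for out-of-range threads.
                continue
            writes[idx] += 1
    return writes
```

For any positive `num_elements`, every index is written exactly once, including sizes that are not a multiple of the block size.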
Key rules:
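- Wrap the kernel code in namespace {} (anonymous namespace).
- Expose the entry point as a static void run(tvm::ffi::TensorView ...) method.
- Validate every tensor with TensorMatcher before reading .data_ptr().
- Do not dereference the pointer on the host: .data_ptr() returns a GPU pointer.
- Use LaunchKernel to launch -- it handles stream resolution and error checking.

Create mllm_kernel/cuda/jit/my_kernel.py: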
"""JIT wrapper for my_kernel CUDA kernel."""

import torch

from mllm_kernel.jit_utils import jit


@jit(
    args=[],
    device="cuda",
    cuda_files=["my_kernel.cuh"],
    cpp_wrappers=[],
    cuda_wrappers=[("my_kernel", "MyKernel::run")],
    func_name="my_kernel",
)
def _kernel(compiled_module, input: torch.Tensor, output: torch.Tensor) -> None:
    compiled_module.my_kernel(input, output)


def my_kernel(input: torch.Tensor) -> torch.Tensor:
    """Double every element in *input*.

    Parameters
    ----------
    input : torch.Tensor
        1-D float32 tensor on CUDA.

    Returns
    -------
    torch.Tensor
        Same shape and dtype as *input*.
    """
    output = torch.empty_like(input)
    _kernel(input, output)
    return output
Edit mllm_kernel/cuda/jit/__init__.py and add:
from mllm_kernel.cuda.jit.my_kernel import my_kernel
Any time you modify the .cuh file, delete the cached .so:
rm -rf ~/.cache/mllm_kernel/cuda_my_kernel*
The next Python call will trigger recompilation automatically.
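For scripted workflows, the same cleanup can be done from Python. This is a minimal sketch that assumes only the cache location and the cuda_<name>* naming shown in the command above. The helper `clear_kernel_cache` and its `cache_root` parameter are hypothetical and are not part of mllm_kernel.

```python
import shutil
from pathlib import Path
from typing import List, Optional


def clear_kernel_cache(kernel_name: str, cache_root: Optional[Path] = None) -> List[Path]:
    """Delete cache entries matching cuda_<kernel_name>* and return the removed paths."""
    # Default location taken from the guide; override it for testing.
    root = cache_root if cache_root is not None else Path.home() / ".cache" / "mllm_kernel"
    removed: List[Path] = []
    if not root.is_dir():
        return removed
    for entry in sorted(root.glob(f"cuda_{kernel_name}*")):
        if entry.is_dir():
            shutil.rmtree(entry)
        else:
            entry.unlink()
        removed.append(entry)
    return removed
```

Like the shell glob, the prefix match also removes entries for any other kernel whose name starts with the same string.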
When the kernel takes compile-time constants (e.g. block size, dtype), use make_cpp_args:
from mllm_kernel.jit_utils import jit, make_cpp_args


def _make_kernel(block_size: int, use_pdl: bool):
    cpp_args = make_cpp_args(block_size, use_pdl)  # e.g. (256, True) -> "256, true"

    @jit(
        args=[block_size, use_pdl],
        device="cuda",
        cuda_files=["my_kernel.cuh"],
        cpp_wrappers=[],
        cuda_wrappers=[("my_kernel", f"MyKernel<{cpp_args}>::run")],
        func_name="my_kernel",
    )
    def _kernel(compiled_module, input, output):
        compiled_module.my_kernel(input, output)

    return _kernel
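The factory above splices the converted arguments into the wrapper name, giving a string such as MyKernel<256, true>::run. The sketch below shows one way that string assembly could work for int, float, and bool values only. `to_cpp_literal` and `wrapper_name` are hypothetical stand-ins written for illustration. They are not the real make_cpp_args, and torch.dtype handling is deliberately left out.

```python
def to_cpp_literal(value) -> str:
    """Convert one Python value to a C++ template-argument string (illustrative)."""
    # bool must be checked before int because bool is a subclass of int in Python.
    if isinstance(value, bool):
        return "true" if value else "false"
    if isinstance(value, (int, float)):
        return str(value)
    raise TypeError(f"unsupported template argument: {value!r}")


def wrapper_name(struct_name: str, *template_args) -> str:
    """Build a wrapper string such as 'MyKernel<256, true>::run'."""
    cpp_args = ", ".join(to_cpp_literal(arg) for arg in template_args)
    return f"{struct_name}<{cpp_args}>::run"
```

Checking bool first matters: without it, `True` would be rendered as the Python string "True", which is not a valid C++ literal.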
make_cpp_args converts Python types to C++ literals:
- int/float -> string literal
- bool -> "true" / "false"
- torch.dtype -> C++ type (torch.float32 -> "fp32_t", torch.float16 -> "fp16_t", torch.bfloat16 -> "bf16_t", torch.int32 -> "int32_t", etc.)

CPU kernels use Google Highway for portable SIMD. The key difference from CUDA: the .hpp body is included multiple times by Highway's foreach_target dispatch, so it must NOT have #pragma once.
Create mllm_kernel/cpu/include/mllm_kernel/cpu/my_kernel.hpp:
// NOTE: NO #pragma once -- this file is included multiple times by Highway.
#include <hwy/highway.h>

HWY_BEFORE_NAMESPACE();
namespace mllm_kernel::cpu {
namespace HWY_NAMESPACE {
namespace hn = hwy::HWY_NAMESPACE;

template <int Constant>
inline void my_kernel_impl(float* HWY_RESTRICT dst,
                           const float* HWY_RESTRICT src,
                           size_t count) {
  const hn::ScalableTag<float> d;
  const size_t lanes = hn::Lanes(d);
  const auto vc = hn::Set(d, static_cast<float>(Constant));

  // Vectorized main loop: one full vector per iteration.
  size_t i = 0;
  for (; i + lanes <= count; i += lanes) {
    const auto v = hn::Load(d, src + i);
    hn::Store(hn::Add(v, vc), d, dst + i);
  }
  // Scalar tail for the remaining elements.
  for (; i < count; ++i) {
    dst[i] = src[i] + static_cast<float>(Constant);
  }
}

// Named entry points for HWY_EXPORT
static HWY_NOINLINE HWY_MAYBE_UNUSED void my_kernel_1(float* d, const float* s, size_t n) {
  my_kernel_impl<1>(d, s, n);
}

}  // namespace HWY_NAMESPACE
}  // namespace mllm_kernel::cpu
HWY_AFTER_NAMESPACE();
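The body above follows the usual two-phase pattern: a main loop that consumes one full vector of `lanes` elements per iteration, then a scalar tail for whatever is left. Here is a minimal plain-Python model of that control flow. The function name and the fixed `lanes=4` default are assumptions for illustration only; in the real code the lane count comes from hn::Lanes(d) and depends on the SIMD target.

```python
from typing import List


def add_constant_chunked(src: List[float], constant: float, lanes: int = 4) -> List[float]:
    """Add `constant` to every element using a chunked main loop plus a scalar tail."""
    count = len(src)
    dst = [0.0] * count
    i = 0
    # Main loop: whole groups of `lanes` elements (stands in for Load / Add / Store).
    while i + lanes <= count:
        dst[i:i + lanes] = [x + constant for x in src[i:i + lanes]]
        i += lanes
    # Tail loop: the remaining count % lanes elements, one at a time.
    while i < count:
        dst[i] = src[i] + constant
        i += 1
    return dst
```

With six inputs and four lanes, the main loop handles indices 0 to 3 and the tail handles indices 4 and 5.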
Create mllm_kernel/cpu/csrc/my_kernel.cpp:
#include <mllm_kernel/tensor.hpp>
#include <mllm_kernel/utils.hpp>
#include <tvm/ffi/container/tensor.h>

#undef HWY_TARGET_INCLUDE
#define HWY_TARGET_INCLUDE "../csrc/my_kernel.cpp"
#include <hwy/foreach_target.h>

#include <mllm_kernel/cpu/my_kernel.hpp>

#if HWY_ONCE
#include <hwy/targets.cc>
#endif

namespace mllm_kernel::cpu {

#if HWY_ONCE
HWY_EXPORT(my_kernel_1);

template <int Constant>
void my_kernel(tvm::ffi::TensorView dst, tvm::ffi::TensorView src) {
  using namespace mllm_kernel::host;

  SymbolicSize N{"num_elements"};
  SymbolicDevice device_;

  (void)TensorMatcher({N})
      .with_dtype<float>()
      .with_device<kDLCPU>(device_)
      .verify(dst)
      .verify(src);

  const size_t n = N.unwrap();
  auto* dst_ptr = static_cast<float*>(dst.data_ptr());
  const auto* src_ptr = static_cast<const float*>(src.data_ptr());

  HWY_DYNAMIC_DISPATCH(my_kernel_1)(dst_ptr, src_ptr, n);
}

// Explicit instantiation
template void my_kernel<1>(tvm::ffi::TensorView, tvm::ffi::TensorView);
#endif

}  // namespace mllm_kernel::cpu
Create mllm_kernel/cpu/jit/my_kernel.py:
import torch

from mllm_kernel.jit_utils import jit


@jit(
    args=1,
    device="cpu",
    cpp_files=["my_kernel.cpp"],
    cpp_wrappers=[("my_kernel", "mllm_kernel::cpu::my_kernel<1>")],
    func_name="my_kernel",
)
def _kernel_1(compiled_module, dst, src):
    compiled_module.my_kernel(dst, src)


def my_kernel(src: torch.Tensor) -> torch.Tensor:
    dst = torch.empty_like(src)
    _kernel_1(dst, src)
    return dst
Key CPU differences from CUDA:
| Aspect | CUDA | CPU |
|---|---|---|
| Source file | .cuh in cuda/csrc/ | .cpp + .hpp in cpu/csrc/ and cpu/include/ |
| Namespace | Anonymous namespace {} | mllm_kernel::cpu |
| Device check | with_device<kDLCUDA> | with_device<kDLCPU> |
| Launch | LaunchKernel(grid, block, device)(...) | Direct function call via HWY_DYNAMIC_DISPATCH |
| SIMD | CUDA warps | Highway ScalableTag<T> |
| Wrapper fields | cuda_files, cuda_wrappers | cpp_files, cpp_wrappers |
| Wrapper name | "MyKernel::run" | "mllm_kernel::cpu::my_kernel<1>" (fully qualified) |
TensorMatcher validates shape, dtype, device, and strides of tvm::ffi::TensorView arguments.
using namespace mllm_kernel::host;

// Symbolic dimensions -- bind on first .verify(), check consistency on subsequent calls
SymbolicSize B{"batch"}, N{"seq_len"}, D{"dim"};
SymbolicSize Stride0{"stride0"};
SymbolicDType dtype;
SymbolicDevice device;

// Shape [B, N, D], contiguous, float32, on CUDA
(void)TensorMatcher({B, N, D})
    .with_dtype<float>(dtype)
    .with_device<kDLCUDA>(device)
    .verify(tensor_a);

// Shape [B, N, D], same dtype and device (already bound)
(void)TensorMatcher({B, N, D})
    .with_dtype(dtype)
    .with_device(device)
    .verify(tensor_b);

// Shape [B, D] with explicit strides (non-contiguous OK)
(void)TensorMatcher({B, D})
    .with_strides({Stride0, 1})
    .with_dtype<int32_t>()
    .with_device(device)
    .verify(indices);

// Multiple acceptable dtypes
SymbolicDType flex_dtype;
(void)TensorMatcher({N})
    .with_dtype<float, __half, __nv_bfloat16>(flex_dtype)
    .with_device(device)
    .verify(mixed_tensor);

// Extract bound values
int64_t batch = B.unwrap();
int64_t dim = D.unwrap();
DLDevice dev = device.unwrap();
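The bind-then-check behaviour of the symbolic values can be modelled in a few lines of Python. This is only a conceptual sketch of the idea described in the comments above (bind on the first verify, check consistency afterwards). The class `SymbolicValue` and the function `verify_shape` are hypothetical and do not reflect how TensorMatcher is actually implemented.

```python
from typing import Optional, Sequence


class SymbolicValue:
    """Binds to the first value it sees; later values must agree (illustrative)."""

    def __init__(self, name: str) -> None:
        self.name = name
        self._value: Optional[int] = None

    def match(self, actual: int) -> None:
        if self._value is None:
            self._value = actual  # first use: bind
        elif self._value != actual:
            raise ValueError(f"{self.name}: expected {self._value}, got {actual}")

    def unwrap(self) -> int:
        if self._value is None:
            raise ValueError(f"{self.name} is unbound")
        return self._value


def verify_shape(symbols: Sequence[SymbolicValue], shape: Sequence[int]) -> None:
    """Match each dimension of `shape` against the corresponding symbol."""
    if len(symbols) != len(shape):
        raise ValueError(f"rank mismatch: expected {len(symbols)}, got {len(shape)}")
    for symbol, actual in zip(symbols, shape):
        symbol.match(actual)
```

In this model, verifying `[B, N, D]` against `(2, 16, 64)` binds all three symbols, after which `[B, D]` accepts `(2, 64)` and rejects `(3, 64)`.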
using namespace mllm_kernel::host;
// Basic launch (resolves CUDA stream from DLDevice)
DLDevice dev = device.unwrap();
LaunchKernel(grid_dim, block_dim, dev)(kernel_func, param_struct);
// With shared memory
LaunchKernel(grid, block, dev, shared_mem_bytes)(kernel, params);
// With PDL (Programmatic Dependent Launch, sm_90+)
LaunchKernel(grid, block, dev).enable_pdl(true)(kernel, params);
Host-side helpers (mllm_kernel::host):
| Function | Description |
|---|---|
| RuntimeCheck(cond, msg...) | Throws PanicError if cond is false |
| Panic(msg...) | Always throws (unreachable code) |
| div_ceil(a, b) | Integer ceiling division |
| dtype_bytes(DLDataType) | Byte size of a DLPack dtype |
CUDA-only (mllm_kernel::device):
| Symbol | Value |
|---|---|
| kWarpThreads | 32 |
| kFullMask | 0xffffffff |
| fp16_t | __half |
| bf16_t | __nv_bfloat16 |
Create tests/test_my_kernel.py:
import pytest
import torch

from mllm_kernel.cuda.jit.my_kernel import my_kernel


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required")
@pytest.mark.parametrize("n", [1, 128, 1024, 65536])
def test_my_kernel(n):
    x = torch.randn(n, dtype=torch.float32, device="cuda")
    result = my_kernel(x)
    torch.cuda.synchronize()
    expected = x * 2.0
    assert torch.allclose(result, expected)
Run:
pytest tests/test_my_kernel.py -v
Create benchmarks/bench_my_kernel.py. Use torch.profiler.profile with ProfilerActivity.CPU and ProfilerActivity.CUDA. Compare the JIT kernel against a naive PyTorch implementation and report speedup.
Run:
python benchmarks/bench_my_kernel.py --num-elements 1000000
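The comparison step of such a benchmark comes down to timing two callables and dividing the results. The sketch below shows that step with the standard library only, swapping in time.perf_counter for torch.profiler so it runs without a GPU. `time_callable` and `speedup` are hypothetical helper names, not part of mllm_kernel.

```python
import statistics
import time
from typing import Callable, List


def time_callable(fn: Callable[[], object], warmup: int = 3, iters: int = 10) -> float:
    """Return the median wall-clock seconds per call, measured after warmup runs."""
    for _ in range(warmup):
        fn()  # warmup keeps one-time costs such as JIT compilation out of the samples
    samples: List[float] = []
    for _ in range(iters):
        start = time.perf_counter()
        fn()
        samples.append(time.perf_counter() - start)
    return statistics.median(samples)


def speedup(reference_seconds: float, candidate_seconds: float) -> float:
    """Values above 1.0 mean the candidate is faster than the reference."""
    return reference_seconds / candidate_seconds
```

When timing CUDA work this way, each callable should end with torch.cuda.synchronize(), since kernel launches are asynchronous and would otherwise return before the work finishes.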
Checklist:
- .cuh / .cpp + .hpp kernel source created
- TensorMatcher validates all tensor arguments (shape, dtype, device)
- @jit wrapper created with correct cuda_wrappers or cpp_wrappers
- Public function wraps the JIT-decorated function (_kernel)
- Export added to jit/__init__.py
- Cache cleared after .cuh edits (rm -rf ~/.cache/mllm_kernel/cuda_<name>*)
- Tests written with @pytest.mark.parametrize and PyTorch reference
- Benchmark written with torch.profiler (optional but recommended)

Common pitfalls:
- Host reads of device memory -- tensor.data_ptr() returns a GPU pointer for CUDA tensors. Never read its contents in host code. Use TensorMatcher for validation instead.
- Stale cache -- after editing a .cuh, delete ~/.cache/mllm_kernel/cuda_<kernel_name>*/. The old .so will be reused otherwise.
- Missing #include <hwy/targets.cc> -- CPU kernels must include this inside #if HWY_ONCE to provide GetChosenTarget for the JIT-built module.
- #pragma once in Highway .hpp -- Highway's foreach_target includes the file multiple times for different SIMD targets. #pragma once breaks this.
- Wrapper names -- CUDA uses the unqualified name ("MyKernel::run"); CPU uses fully qualified names ("mllm_kernel::cpu::my_kernel<1>").
- Random generators in tests -- torch.randperm needs a CUDA generator on CUDA; torch.randint only accepts CPU generators. Use separate generators.