Write Metal/MPS kernels for PyTorch operators. Use when adding MPS device support to operators, implementing Metal shaders, or porting CUDA kernels to Apple Silicon. Covers native_functions.yaml dispatch, host-side operators, and Metal kernel implementation.
This skill guides you through implementing Metal kernels for PyTorch operators on Apple Silicon.
Important: The goal of this skill is to use native Metal capabilities via the c10/metal/ infrastructure, NOT MPSGraph. Native Metal kernels provide better control, performance, and maintainability.
There are two workflows covered by this skill:

1. Adding native Metal support for an operator that does not yet have it
2. Migrating an existing MPSGraph-based operator to a native Metal kernel

Both workflows involve three locations:

- `aten/src/ATen/native/native_functions.yaml` — dispatch registration
- `aten/src/ATen/native/mps/kernels/` — Metal shader code
- `aten/src/ATen/native/mps/operations/` — host-side dispatch code

Location: `aten/src/ATen/native/native_functions.yaml`
Find the operator entry and add MPS dispatch:
```yaml
# Simple MPS-specific implementation
- func: my_op(Tensor self) -> Tensor
  dispatch:
    CPU: my_op_cpu
    CUDA: my_op_cuda
    MPS: my_op_mps

# Shared implementation across devices (preferred for structured kernels)
- func: my_op.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
  dispatch:
    CPU, CUDA, MPS: my_op_out

# Structured kernel (preferred for new ops)
- func: my_op.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
  structured: True
  structured_inherits: TensorIteratorBase
  dispatch:
    CPU, CUDA, MPS: my_op_out
```
When migrating an existing operator from MPSGraph to native Metal, consolidate the dispatch entry:
```yaml
# BEFORE (MPSGraph-based, separate dispatch)
- func: atan2.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
  structured: True
  structured_inherits: TensorIteratorBase
  dispatch:
    CPU, CUDA: atan2_out
    MPS: atan2_out_mps  # Separate MPS implementation

# AFTER (native Metal, shared dispatch via stub)
- func: atan2.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
  structured: True
  structured_inherits: TensorIteratorBase
  dispatch:
    CPU, CUDA, MPS: atan2_out  # MPS now uses the same stub mechanism
```
Key change: remove the separate `MPS: my_op_out_mps` entry and add `MPS` to the shared dispatch line (e.g., `CPU, CUDA, MPS: my_op_out`).
Dispatch naming conventions:

- `MPS: function_name_mps` — MPS-specific implementation (old MPSGraph pattern)
- `CPU, CUDA, MPS: function_name` — shared stub implementation (native Metal pattern)

Location: `aten/src/ATen/native/mps/kernels/`
```metal
// MyKernel.metal
#include <c10/metal/indexing.h>
#include <c10/metal/utils.h>
#include <metal_stdlib>

using namespace metal;
using namespace c10::metal;

// Define operation functor
struct my_op_functor {
  template <typename T>
  inline T operator()(const T x) {
    return /* your operation */;
  }
};

// Register for supported types
REGISTER_UNARY_OP(my_op, float, float);
REGISTER_UNARY_OP(my_op, half, half);
REGISTER_UNARY_OP(my_op, bfloat, bfloat);

struct my_binary_functor {
  template <typename T>
  inline T operator()(const T a, const T b) {
    return /* your operation */;
  }
};

REGISTER_BINARY_OP(my_binary, float, float);
REGISTER_BINARY_OP(my_binary, half, half);
```
For binary operations, use the convenience macros defined in BinaryKernel.metal:
```metal
// Floating-point types only (float, half, bfloat)
REGISTER_FLOAT_BINARY_OP(my_op);

// Integral types with float output (for math ops like atan2, copysign)
// Registers: long->float, int->float, short->float, uchar->float, char->float, bool->float
REGISTER_INT2FLOAT_BINARY_OP(my_op);

// Integral types with same-type output (for bitwise/logical ops)
// Registers: long, int, short, uchar, char, bool
REGISTER_INTEGER_BINARY_OP(my_op);

// Floating-point with opmath precision (for ops needing higher precision)
REGISTER_OPMATH_FLOAT_BINARY_OP(my_op);
```
Common patterns:

- Math ops that promote integer inputs to float output (e.g., atan2, copysign): `REGISTER_FLOAT_BINARY_OP` plus `REGISTER_INT2FLOAT_BINARY_OP`
- Ops with same-type integer output (e.g., bitwise/logical ops): `REGISTER_FLOAT_BINARY_OP` plus `REGISTER_INTEGER_BINARY_OP`

Example for atan2 (supports both float and int inputs):
```metal
struct atan2_functor {
  template <typename T, enable_if_t<is_floating_point_v<T>, bool> = true>
  inline T operator()(const T a, const T b) {
    return static_cast<T>(precise::atan2(float(a), float(b)));
  }
  template <typename T, enable_if_t<is_integral_v<T>, bool> = true>
  inline float operator()(const T a, const T b) {
    return precise::atan2(float(a), float(b));
  }
};

REGISTER_FLOAT_BINARY_OP(atan2);
REGISTER_INT2FLOAT_BINARY_OP(atan2);
```
```metal
struct my_alpha_functor {
  template <typename T>
  inline T operator()(const T a, const T b, const T alpha) {
    return a + c10::metal::mul(alpha, b);
  }
};

REGISTER_UNARY_ALPHA_OP(my_alpha, float, float, float);
REGISTER_UNARY_ALPHA_OP(my_alpha, half, half, half);
```
```metal
struct special_functor {
  // Floating point types
  template <typename T, enable_if_t<is_scalar_floating_point_v<T>, bool> = true>
  inline T operator()(const T x) {
    return precise::exp(x); // Use precise math
  }
  // Integral types
  template <typename T, enable_if_t<is_scalar_integral_v<T>, bool> = true>
  inline float operator()(const T x) {
    return precise::exp(float(x));
  }
  // Complex types (float2 for cfloat, half2 for chalf)
  template <typename T, enable_if_t<is_complex_v<T>, bool> = true>
  inline T operator()(const T x) {
    // x.x = real, x.y = imaginary
    return T(/* real */, /* imag */);
  }
};
```
Note on complex types: Complex numbers in Metal are represented as vector types:
- `c10::complex<float>` maps to `float2` (x = real, y = imaginary)
- `c10::complex<half>` maps to `half2`

Use `is_complex_v<T>` to specialize for complex types in functors.
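As a concrete illustration of this representation, here is a minimal complex-conjugate-style functor sketch. The functor name is made up for illustration; the vector-component access follows the `float2`/`half2` convention above:

```metal
// Hypothetical sketch: negate the imaginary part (complex conjugate).
struct conj_sketch_functor {
  template <typename T, enable_if_t<is_complex_v<T>, bool> = true>
  inline T operator()(const T x) {
    // x.x = real part, x.y = imaginary part
    return T(x.x, -x.y);
  }
};
```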
`utils.h`:

- `opmath_t<T>` — operation math type (half->float)
- `accum_t<T>` — accumulation type for reductions
- `max()`, `min()` with NaN propagation

`special_math.h`:

- `precise::exp()`, `precise::log()`, `precise::sqrt()`
- `precise::sin()`, `precise::cos()`, `precise::tan()`
- `erf()`, `erfc()`, `erfinv()`

`indexing.h`:

- `REGISTER_UNARY_OP(name, in_type, out_type)`
- `REGISTER_BINARY_OP(name, in_type, out_type)`
- `REGISTER_UNARY_ALPHA_OP(name, in_type, alpha_type, out_type)`

Location: `aten/src/ATen/native/mps/operations/`
Choose or create an appropriate file based on operation type:
- `UnaryKernel.mm` — single-input operations via stub dispatch
- `BinaryKernel.mm` — two-input operations via stub dispatch
- `UnaryOps.mm` / `BinaryOps.mm` — legacy MPSGraph implementations (for reference)
- `ReduceOps.mm` — reductions (sum, mean, max, etc.)

For structured kernels that use the TensorIterator pattern:
```cpp
// In BinaryKernel.mm (or appropriate file)
static void my_op_mps_kernel(TensorIteratorBase& iter) {
  lib.exec_binary_kernel(iter, "my_op"); // "my_op" matches the functor name in .metal
}

// Register the MPS stub - this connects to the dispatch system
REGISTER_DISPATCH(my_op_stub, &my_op_mps_kernel)
```
For unary operations:
```cpp
static void my_unary_mps_kernel(TensorIteratorBase& iter) {
  lib.exec_unary_kernel(iter, "my_unary");
}
REGISTER_DISPATCH(my_unary_stub, &my_unary_mps_kernel)
```
When migrating from MPSGraph, also remove the old implementation:
Remove from `BinaryOps.mm` (or `UnaryOps.mm`):

- The `TORCH_IMPL_FUNC(my_op_out_mps)` implementation
- The `#include <ATen/ops/my_op_native.h>` header

Add to `BinaryKernel.mm` (or `UnaryKernel.mm`):

- The kernel stub and `REGISTER_DISPATCH` call

After making changes, compile to verify everything builds correctly:
```sh
cd build && ninja torch_cpu
```
Basic operator support is already tested by test_output_match in test/test_mps.py. After implementing an operator, enable testing by removing expected failures:
Location: torch/testing/_internal/common_mps.py
Find and remove the operator from skip/xfail lists:
```python
# Remove entries like:
MPS_XFAILLIST = {
    "my_op": ...,  # Remove this line
}
MPS_SKIPLIST = {
    "my_op": ...,  # Remove this line
}
```
Location: torch/testing/_internal/common_methods_invocations.py (or related files)
Remove MPS-specific decorators from the OpInfo:
```python
OpInfo(
    "my_op",
    # Remove decorators like:
    # decorators=[skipMPS, expectedFailureMPS("reason")],
    ...
)
```
```sh
# Run the specific operator test
python test/test_mps.py -k test_output_match_my_op

# Or run full MPS test suite
python test/test_mps.py
```
Use `torch.mps.compile_shader` to JIT-compile and test individual Metal kernels in isolation. This is invaluable for debugging multi-kernel pipelines where you need to verify each stage independently.
```python
import torch

source = '''
#include <metal_stdlib>
using namespace metal;

kernel void my_kernel(
    const device float* input [[buffer(0)]],
    device float* output [[buffer(1)]],
    uint tid [[thread_position_in_grid]]) {
  output[tid] = input[tid] * 2.0;
}
'''

lib = torch.mps.compile_shader(source)
inp = torch.tensor([1.0, 2.0, 3.0], device='mps')
out = torch.zeros(3, device='mps')
lib.my_kernel(inp, out, threads=[3, 1, 1], group_size=[3, 1, 1])
torch.mps.synchronize()
print(out)  # tensor([2., 4., 6.], device='mps:0')
```
`compile_shader` uses dispatchThreads semantics (same as `mtl_dispatch1DJob` in PyTorch):

- `threads=[N, 1, 1]` — total number of threads (NOT threadgroups)
- `group_size=[G, 1, 1]` — threads per threadgroup

This differs from the dispatchThreadgroups API used by some host-side code. To match `dispatchThreadgroups:MTLSizeMake(num_tgs, num_slices, 1) threadsPerThreadgroup:MTLSizeMake(TG_SIZE, 1, 1)`:
```python
# Equivalent compile_shader call:
lib.kernel(args...,
           threads=[num_tgs * TG_SIZE, num_slices, 1],
           group_size=[TG_SIZE, 1, 1])
```
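To keep the two dispatch conventions straight, the conversion can be captured in a small helper. This is a hypothetical utility for your own test scripts, not part of the PyTorch API:

```python
def to_compile_shader_dims(num_tgs, tg_size, num_slices=1):
    """Convert dispatchThreadgroups-style counts (threadgroup count +
    threadgroup size) into compile_shader's dispatchThreads-style
    arguments (total threads + group size)."""
    threads = [num_tgs * tg_size, num_slices, 1]
    group_size = [tg_size, 1, 1]
    return threads, group_size

# 5 threadgroups of 256 threads each -> 1280 total threads
threads, group_size = to_compile_shader_dims(5, 256)
print(threads)      # [1280, 1, 1]
print(group_size)   # [256, 1, 1]
```

This mirrors the gotcha below: `threads` counts individual threads, so passing the threadgroup count directly would launch far too few threads.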
Pass scalar constants as single-element tensors:
```python
slice_size = torch.tensor([1024], dtype=torch.int32, device='mps')
lib.my_kernel(data, output, slice_size, threads=[1024, 1, 1], group_size=[256, 1, 1])
```
When a pipeline of kernels (e.g., histogram → prefix_sum → scatter) produces wrong results, test each kernel individually and verify its output against a Python/NumPy reference:
```python
# 1. Run GPU kernel
lib.histogram(keys, hist, ..., threads=[N, 1, 1], group_size=[256, 1, 1])
torch.mps.synchronize()

# 2. Compute reference in Python
ref_hist = compute_histogram_cpu(keys.cpu().numpy(), ...)

# 3. Compare
assert np.array_equal(hist.cpu().numpy(), ref_hist), "Histogram mismatch!"
```
This isolates which kernel in the pipeline is broken, rather than debugging the entire pipeline at once.
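The CPU reference from step 2 is often just a few lines of NumPy. A sketch of the hypothetical `compute_histogram_cpu` used above, assuming keys are bin indices in `[0, num_bins)`:

```python
import numpy as np

def compute_histogram_cpu(keys, num_bins):
    """Reference histogram: count how many keys fall into each bin.
    (Illustrative helper, not part of PyTorch.)"""
    return np.bincount(keys, minlength=num_bins)

keys = np.array([0, 2, 2, 1, 3, 2], dtype=np.int64)
print(compute_histogram_cpu(keys, 4))  # [1 1 3 1]
```

Keeping the reference this simple makes it easy to trust when the GPU result disagrees.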
Common gotchas:

- `threads` count — `threads` is total threads, not threadgroups. For 5 threadgroups of 256, use `threads=[1280, 1, 1]`.
- `compile_shader` doesn't support `[[threadgroup(N)]]` parameters directly. If your kernel needs threadgroup memory, restructure to use threadgroup arrays declared inside the kernel body instead.

Quick reference:

- `native_functions.yaml` — dispatch registration
- `kernels/` — Metal shaders
- `operations/` — host-side stubs
- `torch/testing/_internal/common_mps.py` — test skip/xfail lists
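A sketch of the threadgroup-memory workaround: the kernel name and the fixed size of 256 are illustrative assumptions, but the pattern (a `threadgroup` array declared in the kernel body rather than taken as a `[[threadgroup(N)]]` parameter) is what keeps the kernel usable from `compile_shader`:

```metal
#include <metal_stdlib>
using namespace metal;

kernel void tg_sum_sketch(
    const device float* input [[buffer(0)]],
    device float* output [[buffer(1)]],
    uint tid [[thread_position_in_threadgroup]],
    uint gid [[thread_position_in_grid]]) {
  // Declared in the body instead of as a [[threadgroup(0)]] parameter.
  threadgroup float shared[256];
  shared[tid] = input[gid];
  threadgroup_barrier(mem_flags::mem_threadgroup);
  // Thread 0 of each threadgroup writes the group's sum.
  if (tid == 0) {
    float acc = 0;
    for (uint i = 0; i < 256; ++i) {
      acc += shared[i];
    }
    output[gid / 256] = acc;
  }
}
```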