KAIO v0.2.0: Write GPU kernels in Rust, tensor-core matmul at 92.5% of cuBLAS sgemm

⚓ Rust    📅 2026-04-13    👤 surdeus    👁️ 4

surdeus

Hey all, I just published v0.2.0 of KAIO, a Rust-native GPU kernel framework. Wanted to share it here since it's aimed squarely at a gap in the Rust ecosystem.

The short version: you write #[gpu_kernel] fn my_kernel(...) in Rust, the proc macro lowers it to PTX at build time, and the runtime dispatches via the NVIDIA driver. No CUDA toolkit needed to build, no Python, works on Windows and Linux.

The tensor-core matmul hits 92.5% of cuBLAS sgemm at 4096² on an RTX 4090 (fp16 inputs, fp32 accumulation, cp.async double-buffered). The sync variant lands at 82.3%.

The motivating use case is custom ops: if you're using candle or burn and need a fused activation, a novel attention variant, or a quantization kernel the framework doesn't support, your options today are CUDA C++ plus FFI bindings. Python developers reach for triton.jit; KAIO is an attempt at the Rust equivalent.

Here's what a kernel looks like: the gated SiLU activation from LLaMA/Mistral/Qwen, normally hand-written in CUDA:

#[gpu_kernel(block_size = 256)]
fn fused_silu_gate(x: &[f32], gate: &[f32], out: &mut [f32], n: u32) {
    let idx = thread_idx_x() + block_idx_x() * block_dim_x();
    if idx < n {
        let xi = x[idx];
        let sig = 1.0f32 / (1.0f32 + exp(-xi));
        out[idx] = xi * sig * gate[idx];
    }
}
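For comparison, here is the same computation as a plain-Rust CPU reference. This is illustrative only, not part of KAIO's API; this kind of host-side function is handy for validating kernel output against (compare with a small tolerance, since floating-point evaluation order differs between CPU and GPU):

```rust
// CPU reference for fused_silu_gate: out[i] = x[i] * sigmoid(x[i]) * gate[i].
// Illustrative sketch -- not a KAIO API -- useful for checking GPU results.
fn fused_silu_gate_ref(x: &[f32], gate: &[f32], out: &mut [f32]) {
    for ((o, &xi), &gi) in out.iter_mut().zip(x).zip(gate) {
        let sig = 1.0f32 / (1.0f32 + (-xi).exp());
        *o = xi * sig * gi;
    }
}
```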

Limitations I want to be upfront about: NVIDIA-only (SM 7.0+), inference-focused (no autograd), pre-1.0 API. The kernel DSL is a Rust subset: no closures, traits, or generics inside kernel bodies. Small-shape matmul trails cuBLAS significantly. The tensor-core comparison is fp16→fp32 vs cuBLAS sgemm f32→f32: a project-local baseline, not a precision-identity claim.

One thing that surprised me during development: the 92.5% didn't come from vectorized loads. The jump came from bank-conflict padding on the shared B-tile plus hoisting (group_id, thread_id_in_group) out of the fragment loaders. The async path benefited disproportionately (+7.4pp) because cp.async was already saturating global bandwidth; the real bottleneck was shared-memory contention at fragment-read time.
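The padding trick is easy to see with the standard NVIDIA shared-memory model (32 banks, 4-byte words, bank = word address mod 32; these numbers are the textbook model, not measured from KAIO). With a 64-float row stride, 32 threads reading one column all hit the same bank; padding the stride to 65 spreads them over all 32 banks. A minimal sketch:

```rust
// Count how many distinct shared-memory banks a warp touches when 32 lanes
// each read one element of a tile column. Model: 32 banks, 4-byte words,
// bank = word_address % 32. A row stride of 64 words maps every access to
// the same bank (32-way conflict); padding to 65 uses all 32 banks.
fn banks_touched(row_stride_words: usize) -> usize {
    let mut banks = [false; 32];
    for lane in 0..32 {
        // lane `lane` reads tile[lane][0]; word address = lane * stride
        let addr = lane * row_stride_words;
        banks[addr % 32] = true;
    }
    banks.iter().filter(|&&b| b).count()
}
```

Since 65 mod 32 = 1, consecutive rows land in consecutive banks, which is why a single extra padding element is enough.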

Links:

Would love to hear what custom kernels people would want to write first, and whether the DSL subset feels too limiting. Also happy to go deep on any of the PTX generation or tiling work if anyone's curious.

1 post - 1 participant


๐Ÿท๏ธ Rust_feed