[Bug] weight_dequant_kernel int32 offset overflow → illegal memory access on FP8 tensors > 2^31 elements
Environment
- unsloth 2026.6.9, unsloth_zoo 2026.6.7
- torch 2.10.0+cu128, triton 3.6.0, CUDA 12.8
- 8x NVIDIA H200 SXM (SM90)
- transformers 5.13.0.dev0 (git main)
- Model: GLM-5.2 FP8 block-quantized (GlmMoeDsaForCausalLM, 256 experts/layer, weight_block_size=[128,128])
Summary
weight_dequant_kernel in unsloth/kernels/fp8.py computes element offsets in int32:
offs = offs_m[:, None] * N + offs_n[None, :]
tl.arange yields int32 values, so this product and sum are evaluated in int32 by default. For any 2D tensor with more than 2^31 = 2,147,483,648 elements, the offsets past that point wrap negative and the kernel crashes with:
RuntimeError: Triton Error [CUDA]: an illegal memory access was encountered
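The wraparound can be reproduced without a GPU. The following is a minimal sketch that emulates signed 32-bit arithmetic in plain Python, assuming the overflow behaves as two's-complement wraparound; `wrap_int32` and `flat_offset_int32` are illustrative helpers, not part of unsloth or Triton.

```python
def wrap_int32(value: int) -> int:
    """Emulate two's-complement wraparound of a signed 32-bit integer."""
    value &= 0xFFFFFFFF
    return value - 2**32 if value >= 2**31 else value


def flat_offset_int32(row: int, col: int, n_cols: int) -> int:
    """Row-major offset row * n_cols + col, with every step wrapped to int32."""
    return wrap_int32(wrap_int32(row * n_cols) + col)


N = 6144

# Offset of the last element of a 4096 x 6144 tensor.
small_true = 4095 * N + (N - 1)
small_i32 = flat_offset_int32(4095, N - 1, N)

# Offset of the last element of a 400_000 x 6144 tensor.
large_true = 399_999 * N + (N - 1)
large_i32 = flat_offset_int32(399_999, N - 1, N)

print(small_true, small_i32)  # identical: the small case fits in int32
print(large_true, large_i32)  # the emulated int32 value is negative
```

A negative offset added to the base pointer addresses memory before the start of the tensor, which is consistent with the illegal memory access reported above.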
When this triggers
The FP8 MoE training path (unsloth_zoo/temporary_patches/moe_utils_fp8.py::_dequantize_full_expert_weights_unsloth) flattens a layer's full stacked expert weight into one 2D tensor before dequantizing. For GLM-5.2:
gate_up_proj stack: [256 experts, 4096, 6144] → flattened [1048576, 6144] = 6,442,450,944 elements (6.44e9, exactly 3 × 2^31, so three times the int32 limit)
Any FP8 block-quantized MoE with num_experts × rows × cols > 2^31 is affected. Smaller MoEs (Qwen, GLM-4.7-Flash) stay under the limit, which is likely why this hasn't been reported.
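To check whether a given model is affected, the condition above reduces to simple arithmetic on the stacked shape. This is a rough sketch; `overflows_int32` is a hypothetical helper, and the second shape is a made-up small configuration used only for contrast.

```python
INT32_MAX = 2**31 - 1


def max_flat_offset(num_experts: int, rows: int, cols: int) -> int:
    """Largest element offset in the flattened [num_experts * rows, cols] tensor."""
    return num_experts * rows * cols - 1


def overflows_int32(num_experts: int, rows: int, cols: int) -> bool:
    """True if some element offset cannot be represented in a signed int32."""
    return max_flat_offset(num_experts, rows, cols) > INT32_MAX


glm_5_2_gate_up = (256, 4096, 6144)   # shape quoted in this report
hypothetical_small = (8, 4096, 6144)  # made-up shape, for contrast only

for shape in (glm_5_2_gate_up, hypothetical_small):
    numel = shape[0] * shape[1] * shape[2]
    print(shape, f"numel={numel:,}", "overflows" if overflows_int32(*shape) else "fits")
```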
Minimal repro
import torch
from unsloth.kernels.fp8 import weight_dequant_block
def probe(M, N, block=128):
    x = torch.randn(M, N, device="cuda", dtype=torch.float32).to(torch.float8_e4m3fn)
    s = torch.rand((M + block - 1)//block, (N + block - 1)//block,
                   device="cuda", dtype=torch.float32)
    y = weight_dequant_block(x, s, block_size=block, dtype=torch.bfloat16)
    torch.cuda.synchronize()
    print(f"OK M={M} N={N} numel={M*N:,}")
probe(4096, 6144) # 25M elements — passes
probe(400_000, 6144) # 2.46e9 elements > 2^31 — illegal memory access
Fix (verified)
One-line change in weight_dequant_kernel — promote the offset arithmetic to int64:
# before
offs = offs_m[:, None] * N + offs_n[None, :]
# after
offs = offs_m[:, None].to(tl.int64) * N + offs_n[None, :]
After the patch, both repro cases pass and full GLM-5.2 FP8 LoRA training runs with finite loss. The dequant output is bitwise-deterministic and matches a pure-PyTorch dequant reference to bf16 tolerance (max_abs_diff 4.2e-04 on real expert weights).
The scale-pointer offset (pid_m * n + pid_n) stays comfortably within int32 and needs no change.
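The headroom for the scale offset can be estimated as follows. This sketch assumes that `n` in `pid_m * n + pid_n` is the number of block columns and that the grid has one program per 128×128 block; `max_scale_offset` is an illustrative helper, not a function from the kernel.

```python
INT32_MAX = 2**31 - 1


def ceil_div(a: int, b: int) -> int:
    """Integer ceiling division."""
    return (a + b - 1) // b


def max_scale_offset(M: int, N: int, block: int = 128) -> int:
    """Largest value of pid_m * n + pid_n, assuming one program per block."""
    m_blocks = ceil_div(M, block)
    n_blocks = ceil_div(N, block)
    return (m_blocks - 1) * n_blocks + (n_blocks - 1)


# Flattened GLM-5.2 gate_up_proj shape from this report.
offset = max_scale_offset(1_048_576, 6144)
print(f"max scale offset = {offset:,}, int32 headroom = {INT32_MAX // offset:,}x")
```

Because the scale grid shrinks by a factor of block² relative to the element count, the scale offset would only overflow for tensors far larger than the one that triggers this bug.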