
IMA crash, with the probe repro #6830


@xvdev09

[Bug] weight_dequant_kernel int32 offset overflow → illegal memory access on FP8 tensors > 2^31 elements

Environment

  • unsloth 2026.6.9, unsloth_zoo 2026.6.7
  • torch 2.10.0+cu128, triton 3.6.0, CUDA 12.8
  • 8x NVIDIA H200 SXM (SM90)
  • transformers 5.13.0.dev0 (git main)
  • Model: GLM-5.2 FP8 block-quantized (GlmMoeDsaForCausalLM, 256 experts/layer, weight_block_size=[128,128])

Summary

weight_dequant_kernel in unsloth/kernels/fp8.py computes element offsets in int32:

offs = offs_m[:, None] * N + offs_n[None, :]

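In Triton, tl.arange returns int32, and multiplying by N (an int32 kernel argument at these sizes) keeps the product in int32. For any 2D tensor with more than 2^31 = 2,147,483,648 elements, the offsets of elements at index 2^31 and beyond wrap around (first to negative values) and the kernel crashes with: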

RuntimeError: Triton Error [CUDA]: an illegal memory access was encountered
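To see where the wrap starts, the sketch below emulates two's-complement int32 arithmetic in plain Python for the row-major offset row * N + col. It is an illustration only: wrap_int32, last_offset and first_overflowing_row are hypothetical helpers, not unsloth or Triton APIs. It assumes a contiguous row-major layout and a BLOCK_SIZE of 128 rows per program, matching block=128 in the repro.

```python
# Plain-Python emulation of int32 offset arithmetic (illustration only, not unsloth code).
INT32_MIN, INT32_MAX = -(2**31), 2**31 - 1


def wrap_int32(v: int) -> int:
    """Reinterpret an arbitrary Python int as a two's-complement int32."""
    return (v + 2**31) % 2**32 - 2**31


def last_offset(M: int, N: int) -> int:
    """Row-major offset of element [M - 1, N - 1], i.e. numel - 1."""
    return (M - 1) * N + (N - 1)


def first_overflowing_row(N: int) -> int:
    """Index of the first row that contains an offset greater than INT32_MAX."""
    return (INT32_MAX + 1) // N


for M, N in [(4096, 6144), (400_000, 6144)]:
    exact = last_offset(M, N)
    print(f"M={M} N={N} exact={exact:,} as int32={wrap_int32(exact):,}")

# Assumes BLOCK_SIZE = 128 rows per program (hypothetical, matching block=128 in the repro).
row = first_overflowing_row(6144)
print(f"first wrapped row: {row:,} (tile pid_m={row // 128})")
```

For the failing repro shape, the last offset 2,457,599,999 becomes -1,837,367,297 as an int32. With N = 6144 the first wrapped offset lies in row 349,525, so any M of 349,526 or more hits it.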

When this triggers

The FP8 MoE training path (unsloth_zoo/temporary_patches/moe_utils_fp8.py::_dequantize_full_expert_weights_unsloth) flattens a layer's full stacked expert weight into one 2D tensor before dequantizing. For GLM-5.2:

  • gate_up_proj stack: [256 experts, 4096, 6144] → flattened [1048576, 6144] = 6.44e9 elements (3x over the int32 limit)

Any FP8 block-quantized MoE with num_experts × rows × cols > 2^31 is affected. Smaller MoEs (Qwen, GLM-4.7-Flash) stay under the limit, which is likely why this hasn't been reported.
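The "num_experts × rows × cols > 2^31" condition can be checked with a few lines of arithmetic. This is a hypothetical helper, not part of unsloth_zoo. It assumes the stack is flattened to [num_experts * rows, cols] as described above and uses the GLM-5.2 gate_up_proj shape from this report.

```python
INT32_LIMIT = 2**31  # above this numel, the last row-major offset no longer fits in int32


def flattened_numel(num_experts: int, rows: int, cols: int) -> int:
    """numel of a [num_experts, rows, cols] stack flattened to [num_experts * rows, cols]."""
    return num_experts * rows * cols


def overflows_int32(num_experts: int, rows: int, cols: int) -> bool:
    """True if an int32 offset into the flattened stack would wrap."""
    return flattened_numel(num_experts, rows, cols) > INT32_LIMIT


def max_safe_experts(rows: int, cols: int) -> int:
    """Largest expert count whose flattened stack still fits within int32 offsets."""
    return INT32_LIMIT // (rows * cols)


# GLM-5.2 gate_up_proj stack from the report: [256, 4096, 6144]
numel = flattened_numel(256, 4096, 6144)
print(f"{numel:,} elements, {numel / INT32_LIMIT:.2f}x the int32 limit")
print("affected:", overflows_int32(256, 4096, 6144))
print("max experts per flattened tensor:", max_safe_experts(4096, 6144))
```

A single 4096 × 6144 expert is about 25M elements, so at most 85 such experts fit under the limit, well short of the 256 stacked here.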

Minimal repro

import torch
from unsloth.kernels.fp8 import weight_dequant_block

def probe(M, N, block=128):
    x = torch.randn(M, N, device="cuda", dtype=torch.float32).to(torch.float8_e4m3fn)
    s = torch.rand((M + block - 1)//block, (N + block - 1)//block,
                   device="cuda", dtype=torch.float32)
    y = weight_dequant_block(x, s, block_size=block, dtype=torch.bfloat16)
    torch.cuda.synchronize()
    print(f"OK M={M} N={N} numel={M*N:,}")

probe(4096, 6144)      # 25M elements — passes
probe(400_000, 6144)   # 2.46e9 elements > 2^31 — illegal memory access

Fix (verified)

One-line change in weight_dequant_kernel: cast the row index to int64 before the multiply, so the whole offset expression is computed in int64:

# before
offs = offs_m[:, None] * N + offs_n[None, :]
# after
offs = offs_m[:, None].to(tl.int64) * N + offs_n[None, :]

After the patch, both repro cases pass and full GLM-5.2 FP8 LoRA training runs (finite loss; dequant output verified bitwise-deterministic and matching a pure-PyTorch dequant reference to bf16 tolerance, with max_abs_diff 4.2e-04 on real expert weights).

The scale-pointer offset (pid_m * n + pid_n) indexes the much smaller per-block scale tensor, so it stays comfortably within int32 and needs no change.
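A quick bound on the scale-pointer offset pid_m * n + pid_n for the flattened GLM-5.2 gate_up_proj shape. This sketch assumes a (cdiv(M, 128), cdiv(N, 128)) launch grid and n = cdiv(N, 128) scale columns; both are assumptions about the kernel's layout, not confirmed from unsloth source, and cdiv and max_scale_offset are hypothetical helpers.

```python
def cdiv(a: int, b: int) -> int:
    """Ceiling division for non-negative ints."""
    return -(-a // b)


def max_scale_offset(M: int, N: int, block: int = 128) -> int:
    """Largest pid_m * n + pid_n, assuming (hypothetically) a (cdiv(M, block), cdiv(N, block))
    grid and n = cdiv(N, block) scale columns."""
    n = cdiv(N, block)
    last_pid_m = cdiv(M, block) - 1
    last_pid_n = n - 1
    return last_pid_m * n + last_pid_n


# Flattened GLM-5.2 gate_up_proj: [1048576, 6144]
off = max_scale_offset(1_048_576, 6144)
print(f"max scale offset {off:,} vs int32 max {2**31 - 1:,}")
```

Under these assumptions the bound is the scale tensor's numel minus one (8192 × 48 - 1 = 393,215), so it would only overflow once the scale tensor itself exceeds 2^31 entries, a weight roughly 128 × 128 = 16,384 times larger than the size at which the data offset wraps.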
