def test_e8m0_scale_is_upcast_and_runs():
    """A float8_e8m0fnu weight scale must run a finite forward and backward.

    Skipped when the installed torch build has no ``float8_e8m0fnu`` dtype.
    The weight is 8x8, so the test only checks that the call succeeds and
    that the output and the input gradient are finite.
    """
    from unsloth.kernels.fp8 import FP8BlockQuantLinear

    if not hasattr(torch, "float8_e8m0fnu"):
        pytest.skip("torch build lacks float8_e8m0fnu")

    # Seed like the sibling tests so a failure can be reproduced.
    torch.manual_seed(0)
    dev = "cuda"
    m, n = 8, 8
    weight = torch.randn(m, n, device = dev, dtype = torch.bfloat16)
    scale = (torch.rand(1, 1, device = dev) + 1.0).to(torch.float8_e8m0fnu)
    X = torch.randn(4, n, device = dev, dtype = torch.bfloat16, requires_grad = True)

    out = FP8BlockQuantLinear.apply(X, weight, scale)
    assert torch.isfinite(out).all(), "forward produced non-finite values"

    out.sum().backward()
    # Check for a missing gradient explicitly: torch.isfinite(None) would raise
    # a TypeError instead of reporting a clean assertion failure.
    assert X.grad is not None and torch.isfinite(X.grad).all(), "backward non-finite"
+ from unsloth.kernels.fp8 import FP8BlockQuantLinear + + if not hasattr(torch, "float8_e8m0fnu"): + pytest.skip("torch build lacks float8_e8m0fnu") + + torch.manual_seed(0) + dev = "cuda" + block = [64, 64] + # in-dim 96 is not divisible by block[1]=64 -> forward takes the torch dequant + # fallback (no fp8 matmul kernel). Scale shape (2, 2) validates for [64, 64] but + # not [128, 128] (which expects (1, 1)). + m, n = 128, 96 + weight = torch.randn(m, n, device = dev, dtype = torch.bfloat16) # no block_size attr + scale_f = torch.rand(2, 2, device = dev) + 1.0 + scale = scale_f.to(torch.float8_e8m0fnu) + scale.block_size = block # attribute lives on the scale, not the weight + X = torch.randn(4, n, device = dev, dtype = torch.bfloat16, requires_grad = True) + + # With [128, 128] this raises "not compatible with block size"; success proves + # the [64, 64] attribute survived the e8m0 -> float32 upcast. + out = FP8BlockQuantLinear.apply(X, weight, scale) + assert torch.isfinite(out).all() + + ref = _reference(X.detach(), weight, scale.to(torch.float32), block) + torch.testing.assert_close(out, ref, atol = 5e-2, rtol = 5e-2) + + out.sum().backward() + assert X.grad is not None and torch.isfinite(X.grad).all() + + +if __name__ == "__main__": + import sys + sys.exit(pytest.main([__file__, "-q"])) diff --git a/unsloth/kernels/fp8.py b/unsloth/kernels/fp8.py index ca608fa01b..80db2f466b 100644 --- a/unsloth/kernels/fp8.py +++ b/unsloth/kernels/fp8.py @@ -327,11 +327,42 @@ def torchao_block_matmul( ) +def _blockwise_weight_dequant_any_shape(weight, weight_scale, block_size, out_dtype): + """Blockwise fp8 weight dequant for any shape: triton when the weight tiles + evenly into block_size, else a torch-native per-block scale expansion.""" + m, n = weight.shape + if weight_scale.dtype not in (torch.float32, torch.float16, torch.bfloat16): + weight_scale = weight_scale.to(torch.float32) # e.g. 
float8_e8m0fnu scales break triton + if weight_scale.numel() == 1: + # Per-tensor scale: the normal forward stashes the un-expanded scalar, + # which repeat_interleave cannot grow to (m, n). Scale directly. + return (weight.to(torch.float32) * weight_scale.float()).to(out_dtype) + if m % block_size[0] != 0 or n % block_size[1] != 0 or block_size[0] != block_size[1]: + # Uneven tiling, or rectangular blocks. The triton kernel uses a single + # BLOCK_SIZE for both axes and derives the column scale stride from it, so + # it mis-indexes the scale when block_size[0] != block_size[1]. Expand the + # per-block scales in torch, which handles both dimensions independently. + s_full = weight_scale.repeat_interleave(block_size[0], 0)[:m] + s_full = s_full.repeat_interleave(block_size[1], 1)[:, :n] + return (weight.to(torch.float32) * s_full).to(out_dtype) + # Even tiling with square blocks: block-quant dequant with the real block size + # (weight_dequant would silently default to 128 and dequantize wrongly). + return weight_dequant_block(weight, weight_scale, block_size = block_size[0], dtype = out_dtype) + + class FP8BlockQuantLinear(torch.autograd.Function): @staticmethod def forward(ctx, X, weight, weight_scale): m, n = weight.shape + if weight_scale.dtype not in (torch.float32, torch.float16, torch.bfloat16): + # Upcast (e.g. e8m0) returns a fresh tensor and drops any Python + # attribute, so carry block_size across the cast for the lookup below. 
@staticmethod
def backward(ctx, grad_output):
    """Return the gradient for X; weight and weight_scale get None.

    The weight is dequantized again from the tensors and block size that
    forward stored on ctx, in grad_output's dtype, and grad_output is
    multiplied by it.
    """
    dequantized = _blockwise_weight_dequant_any_shape(
        ctx.weight,
        ctx.weight_scale,
        ctx.block_size,
        grad_output.dtype,
    )
    input_grad = torch_matmul(grad_output, dequantized)
    # Drop the reference to the dequantized weight once the product exists.
    del dequantized
    return input_grad, None, None