Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 67 additions & 0 deletions tests/test_fp8_tiny_e8m0.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
"""FP8 block-quant linear must handle tiny / non-tileable weights and e8m0 scales.

Two things break the triton block path:
* a hidden dim not divisible by the activation block size (tiny test models),
* float8_e8m0fnu weight scales, which have no triton dtype mapping.
The forward falls back to a torch-native blockwise dequant + bf16 matmul; this
test checks that fallback runs finite forward + backward and matches a plain
dequant reference.
"""

import pytest
import torch

# Every test in this module allocates its tensors on "cuda", so skip the whole
# module on hosts where torch reports no CUDA device.
pytestmark = pytest.mark.skipif(not torch.cuda.is_available(), reason = "needs CUDA")


def _reference(X, weight, scale, block):
# Expand the per-block scale to full weight shape and dequantize.
m, n = weight.shape
s = scale.to(torch.float32)
s = s.repeat_interleave(block[0], 0)[:m].repeat_interleave(block[1], 1)[:, :n]
W = (weight.to(torch.float32) * s).to(X.dtype)
return X @ W.T


def test_tiny_non_tileable_forward_backward_matches_reference():
    """Tiny 8x8 weight: forward and backward must both match the dequant reference.

    The hidden dim (8) is not a multiple of the 128-wide block. The gradient is
    compared by value, not just checked for finiteness: with a square weight a
    wrong or transposed backward would still produce a finite, correctly-shaped
    ``X.grad``.
    """
    from unsloth.kernels.fp8 import FP8BlockQuantLinear

    torch.manual_seed(0)
    dev = "cuda"
    block = [128, 128]
    m, n = 8, 8  # non-tileable, in-dim % 128 != 0
    weight = torch.randn(m, n, device = dev, dtype = torch.bfloat16)  # (out=m, in=n)
    scale = torch.rand(1, 1, device = dev, dtype = torch.float32) + 0.5
    X = torch.randn(4, n, device = dev, dtype = torch.bfloat16, requires_grad = True)

    out = FP8BlockQuantLinear.apply(X, weight, scale)
    assert torch.isfinite(out).all(), "forward produced non-finite values"

    # Independent leaf so the reference accumulates its own gradient.
    X_ref = X.detach().clone().requires_grad_(True)
    ref = _reference(X_ref, weight, scale, block)
    torch.testing.assert_close(out.detach(), ref.detach(), atol = 5e-2, rtol = 5e-2)

    out.sum().backward()
    assert X.grad is not None and torch.isfinite(X.grad).all(), "backward non-finite"

    ref.sum().backward()
    torch.testing.assert_close(X.grad, X_ref.grad, atol = 5e-2, rtol = 5e-2)


def test_e8m0_scale_is_upcast_and_runs():
    """A ``float8_e8m0fnu`` weight scale must run a finite forward and backward.

    Skipped when the installed torch build does not expose the dtype.
    """
    from unsloth.kernels.fp8 import FP8BlockQuantLinear

    if not hasattr(torch, "float8_e8m0fnu"):
        pytest.skip("torch build lacks float8_e8m0fnu")

    torch.manual_seed(0)  # deterministic inputs, as in the sibling test
    dev = "cuda"
    m, n = 8, 8
    weight = torch.randn(m, n, device = dev, dtype = torch.bfloat16)
    scale = (torch.rand(1, 1, device = dev) + 1.0).to(torch.float8_e8m0fnu)
    X = torch.randn(4, n, device = dev, dtype = torch.bfloat16, requires_grad = True)

    out = FP8BlockQuantLinear.apply(X, weight, scale)
    assert torch.isfinite(out).all(), "forward produced non-finite values"

    out.sum().backward()
    # Guard against a missing gradient first so the failure is a clear assertion
    # rather than a TypeError raised inside torch.isfinite(None).
    assert X.grad is not None and torch.isfinite(X.grad).all(), "backward non-finite"


if __name__ == "__main__":
    # Running the file directly executes its tests quietly and exits with
    # pytest's status code.
    raise SystemExit(pytest.main([__file__, "-q"]))
39 changes: 38 additions & 1 deletion unsloth/kernels/fp8.py
Original file line number Diff line number Diff line change
Expand Up @@ -327,11 +327,33 @@ def torchao_block_matmul(
)


def _blockwise_weight_dequant_any_shape(weight, weight_scale, block_size, out_dtype):
"""Blockwise fp8 weight dequant for any shape: triton when the weight tiles
evenly into block_size, else a torch-native per-block scale expansion."""
m, n = weight.shape
if weight_scale.dtype not in (torch.float32, torch.float16, torch.bfloat16):
weight_scale = weight_scale.to(torch.float32) # e.g. float8_e8m0fnu scales break triton
if weight_scale.numel() == 1:
# Per-tensor scale: the normal forward stashes the un-expanded scalar,
# which repeat_interleave cannot grow to (m, n). Scale directly.
return (weight.to(torch.float32) * weight_scale.float()).to(out_dtype)
if m % block_size[0] != 0 or n % block_size[1] != 0:
s_full = weight_scale.repeat_interleave(block_size[0], 0)[:m]
s_full = s_full.repeat_interleave(block_size[1], 1)[:, :n]
return (weight.to(torch.float32) * s_full).to(out_dtype)
# Even tiling: block-quant dequant with the real block size (weight_dequant
# would silently default to 128 and dequantize wrongly for other sizes).
return weight_dequant_block(weight, weight_scale, block_size = block_size[0], dtype = out_dtype)


class FP8BlockQuantLinear(torch.autograd.Function):
@staticmethod
def forward(ctx, X, weight, weight_scale):
m, n = weight.shape

if weight_scale.dtype not in (torch.float32, torch.float16, torch.bfloat16):
weight_scale = weight_scale.to(torch.float32) # e8m0 scales break triton dtype mapping

# Original scale, saved for backward before any transformation
original_weight_scale = weight_scale

Expand Down Expand Up @@ -360,6 +382,18 @@ def forward(ctx, X, weight, weight_scale):
if not weight.is_contiguous():
weight = weight.contiguous()

if X.shape[-1] % block_size[1] != 0:
# Hidden dim not divisible by the activation block: dequant + plain matmul.
# Use the original (un-expanded) scale so a scalar per-tensor scale keeps
# the fast scalar path in both forward and backward.
W_deq = _blockwise_weight_dequant_any_shape(
weight, original_weight_scale, block_size, X.dtype
)
ctx.weight = weight
ctx.weight_scale = original_weight_scale
ctx.block_size = block_size
return torch_matmul(X, W_deq.T).to(X.dtype)

qinput, scale = act_quant(X, block_size[1])
output = fp8_block_matmul(
qinput,
Expand All @@ -371,11 +405,14 @@ def forward(ctx, X, weight, weight_scale):
)
ctx.weight = weight
ctx.weight_scale = original_weight_scale # Save original for backward
ctx.block_size = block_size
return output.to(X.dtype)

@staticmethod
def backward(ctx, grad_output):
W_deq = weight_dequant(ctx.weight, ctx.weight_scale)
W_deq = _blockwise_weight_dequant_any_shape(
ctx.weight, ctx.weight_scale, ctx.block_size, grad_output.dtype
)
grad_X = torch_matmul(grad_output, W_deq)
del W_deq
return grad_X, None, None
Expand Down
Loading