Skip to content

Handle odd shapes and non-float scales in FP8BlockQuantLinear#6848

Open
danielhanchen wants to merge 5 commits into
mainfrom
fp8-blockquant-shape-guards
Open

Handle odd shapes and non-float scales in FP8BlockQuantLinear#6848
danielhanchen wants to merge 5 commits into
mainfrom
fp8-blockquant-shape-guards

Conversation

@danielhanchen

Copy link
Copy Markdown
Member

Summary

Small fp8 checkpoints (e.g. tiny test models) break the block-quantized linear in three ways: weight scales stored in a float8 dtype such as float8_e8m0fnu have no triton dtype mapping; activations whose hidden dim is not a multiple of the activation quant block fail act_quant's divisibility assert; and weights whose dims are not multiples of the weight block cannot be tiled by the triton dequant kernel.

What this does

Casts non-float scales to float32 on entry, and when the hidden dim does not divide into the activation block, dequantizes the weight and runs a plain matmul instead of the fp8 block matmul. The dequant goes through a new shape-safe helper that falls back to a torch-native scale expansion when the weight does not tile evenly; backward uses the same helper so the gradient path works for every shape the forward accepts. Full-size checkpoints are unaffected.

Testing

Odd-shape forward+backward smoke test on GPU: forward error at fp8 quantization noise level (2e-3), finite gradients. Regression test for the tiny/e8m0 path included.

@danielhanchen danielhanchen requested a review from Datta0 as a code owner July 3, 2026 18:41

@gemini-code-assist gemini-code-assist Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces a helper function '_blockwise_weight_dequant_any_shape' to handle blockwise FP8 weight dequantization for arbitrary shapes, falling back to a torch-native per-block scale expansion when weights do not tile evenly. It also updates 'FP8BlockQuantLinear' to handle cases where the hidden dimension is not divisible by the activation block size. Feedback is provided to handle scalar 'weight_scale' values defensively, as calling 'repeat_interleave' on a scalar would raise an error.

Important

The consumer version of Gemini Code Assist on GitHub is being sunset. Starting June 18, 2026, new organization installations will be blocked, and all code review activity will officially cease on July 17, 2026.
For more details on the timeline and next steps, please review the Help Documentation.

Comment thread unsloth/kernels/fp8.py Outdated
Comment on lines +336 to +340
if m % block_size[0] != 0 or n % block_size[1] != 0:
s_full = weight_scale.repeat_interleave(block_size[0], 0)[:m]
s_full = s_full.repeat_interleave(block_size[1], 1)[:, :n]
return (weight.to(torch.float32) * s_full).to(out_dtype)
return weight_dequant(weight, weight_scale).to(out_dtype)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

If weight_scale is a scalar (i.e., weight_scale.numel() == 1) and the weight matrix does not tile evenly into block_size, calling repeat_interleave on weight_scale will raise an error (e.g., IndexError or ValueError) because repeat_interleave expects a dimension to repeat along. We should defensively check if weight_scale is a scalar first and directly perform the multiplication, which is shape-safe and avoids the crash.

Suggested change
if m % block_size[0] != 0 or n % block_size[1] != 0:
s_full = weight_scale.repeat_interleave(block_size[0], 0)[:m]
s_full = s_full.repeat_interleave(block_size[1], 1)[:, :n]
return (weight.to(torch.float32) * s_full).to(out_dtype)
return weight_dequant(weight, weight_scale).to(out_dtype)
if weight_scale.numel() == 1:
return (weight.to(torch.float32) * weight_scale).to(out_dtype)
if m % block_size[0] != 0 or n % block_size[1] != 0:
s_full = weight_scale.repeat_interleave(block_size[0], 0)[:m]
s_full = s_full.repeat_interleave(block_size[1], 1)[:, :n]
return (weight.to(torch.float32) * s_full).to(out_dtype)
return weight_dequant(weight, weight_scale).to(out_dtype)

Small fp8 checkpoints (e.g. tiny test models) break the block-quantized
linear in three ways: weight scales stored in a float8 dtype such as
float8_e8m0fnu have no triton dtype mapping; activations whose hidden dim is
not a multiple of the activation quant block fail act_quant's divisibility
assert; and weights whose dims are not multiples of the weight block cannot
be tiled by the triton dequant kernel.

Cast non-float scales to float32 on entry, and when the hidden dim does not
divide into the activation block, dequantize the weight and run a plain
matmul instead of the fp8 block matmul. The dequant goes through a new
shape-safe helper that falls back to a torch-native scale expansion when the
weight does not tile evenly; backward uses the same helper so the gradient
path works for every shape the forward accepts. Full-size checkpoints are
unaffected.
@danielhanchen danielhanchen force-pushed the fp8-blockquant-shape-guards branch from 6d7eb13 to b926de3 Compare July 3, 2026 19:04
@danielhanchen

Copy link
Copy Markdown
Member Author

@codex review

@danielhanchen

Copy link
Copy Markdown
Member Author

/gemini review

@gemini-code-assist gemini-code-assist Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request adds support for handling tiny, non-tileable weights and float8_e8m0fnu scales in the FP8 block-quantized linear layer (FP8BlockQuantLinear) by falling back to a torch-native dequantization and standard matrix multiplication when Triton-based paths are not supported. Feedback on the changes includes: 1) Ensuring weight_dequant_block is called with the correct block_size parameter instead of using weight_dequant which defaults to 128. 2) Utilizing original_weight_scale in the fallback path to avoid unnecessarily triggering the slower repeat_interleave path during both forward and backward passes when the scale is a scalar.

Important

The consumer version of Gemini Code Assist on GitHub is being sunset. Starting June 18, 2026, new organization installations will be blocked, and all code review activity will officially cease on July 17, 2026.
For more details on the timeline and next steps, please review the Help Documentation.

Comment thread unsloth/kernels/fp8.py Outdated
s_full = weight_scale.repeat_interleave(block_size[0], 0)[:m]
s_full = s_full.repeat_interleave(block_size[1], 1)[:, :n]
return (weight.to(torch.float32) * s_full).to(out_dtype)
return weight_dequant(weight, weight_scale).to(out_dtype)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The helper function weight_dequant does not accept a block_size parameter and defaults to 128 internally. If FP8BlockQuantLinear is used with a non-default block size (e.g., [64, 64]), calling weight_dequant will ignore the actual block size and use 128, leading to incorrect dequantization. Since weight_scale is guaranteed to be block-quantized at this point, call weight_dequant_block directly and pass block_size[0] to ensure correctness.

Suggested change
return weight_dequant(weight, weight_scale).to(out_dtype)
return weight_dequant_block(weight, weight_scale, block_size=block_size[0], dtype=out_dtype)

Comment thread unsloth/kernels/fp8.py Outdated

if X.shape[-1] % block_size[1] != 0:
# Hidden dim not divisible by the activation block: dequant + plain matmul.
W_deq = _blockwise_weight_dequant_any_shape(weight, weight_scale, block_size, X.dtype)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Passing the expanded weight_scale to _blockwise_weight_dequant_any_shape causes it to execute the slower repeat_interleave path even when the scale was originally a scalar (per-tensor scale). Passing original_weight_scale instead allows the helper to use the fast scalar multiplication path directly.

Suggested change
W_deq = _blockwise_weight_dequant_any_shape(weight, weight_scale, block_size, X.dtype)
W_deq = _blockwise_weight_dequant_any_shape(weight, original_weight_scale, block_size, X.dtype)

Comment thread unsloth/kernels/fp8.py Outdated
# Hidden dim not divisible by the activation block: dequant + plain matmul.
W_deq = _blockwise_weight_dequant_any_shape(weight, weight_scale, block_size, X.dtype)
ctx.weight = weight
ctx.weight_scale = weight_scale

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Saving the expanded weight_scale to ctx.weight_scale in the fallback path causes the backward pass to use the slower repeat_interleave path when the scale was originally a scalar. Saving original_weight_scale instead avoids this overhead and maintains consistency with the normal path.

Suggested change
ctx.weight_scale = weight_scale
ctx.weight_scale = original_weight_scale

@danielhanchen

Copy link
Copy Markdown
Member Author

Fixed the fallback to dequantize with the real block size (it was defaulting to 128) and to keep the original per-tensor scale so a scalar scale stays on the fast path in both forward and backward. The scalar-scale case is already guarded before repeat_interleave, so no change was needed there.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant