Handle odd shapes and non-float scales in FP8BlockQuantLinear#6848
Handle odd shapes and non-float scales in FP8BlockQuantLinear#6848danielhanchen wants to merge 5 commits into
Conversation
There was a problem hiding this comment.
Code Review
This pull request introduces a helper function '_blockwise_weight_dequant_any_shape' to handle blockwise FP8 weight dequantization for arbitrary shapes, falling back to a torch-native per-block scale expansion when weights do not tile evenly. It also updates 'FP8BlockQuantLinear' to handle cases where the hidden dimension is not divisible by the activation block size. Feedback is provided to handle scalar 'weight_scale' values defensively, as calling 'repeat_interleave' on a scalar would raise an error.
Important
The consumer version of Gemini Code Assist on GitHub is being sunset. Starting June 18, 2026, new organization installations will be blocked, and all code review activity will officially cease on July 17, 2026.
For more details on the timeline and next steps, please review the Help Documentation.
| if m % block_size[0] != 0 or n % block_size[1] != 0: | ||
| s_full = weight_scale.repeat_interleave(block_size[0], 0)[:m] | ||
| s_full = s_full.repeat_interleave(block_size[1], 1)[:, :n] | ||
| return (weight.to(torch.float32) * s_full).to(out_dtype) | ||
| return weight_dequant(weight, weight_scale).to(out_dtype) |
There was a problem hiding this comment.
If weight_scale is a scalar (i.e., weight_scale.numel() == 1) and the weight matrix does not tile evenly into block_size, calling repeat_interleave on weight_scale will raise an error (e.g., IndexError or ValueError) because repeat_interleave expects a dimension to repeat along. We should defensively check if weight_scale is a scalar first and directly perform the multiplication, which is shape-safe and avoids the crash.
| if m % block_size[0] != 0 or n % block_size[1] != 0: | |
| s_full = weight_scale.repeat_interleave(block_size[0], 0)[:m] | |
| s_full = s_full.repeat_interleave(block_size[1], 1)[:, :n] | |
| return (weight.to(torch.float32) * s_full).to(out_dtype) | |
| return weight_dequant(weight, weight_scale).to(out_dtype) | |
| if weight_scale.numel() == 1: | |
| return (weight.to(torch.float32) * weight_scale).to(out_dtype) | |
| if m % block_size[0] != 0 or n % block_size[1] != 0: | |
| s_full = weight_scale.repeat_interleave(block_size[0], 0)[:m] | |
| s_full = s_full.repeat_interleave(block_size[1], 1)[:, :n] | |
| return (weight.to(torch.float32) * s_full).to(out_dtype) | |
| return weight_dequant(weight, weight_scale).to(out_dtype) |
Small fp8 checkpoints (e.g. tiny test models) break the block-quantized linear in three ways: weight scales stored in a float8 dtype such as float8_e8m0fnu have no triton dtype mapping; activations whose hidden dim is not a multiple of the activation quant block fail act_quant's divisibility assert; and weights whose dims are not multiples of the weight block cannot be tiled by the triton dequant kernel. Cast non-float scales to float32 on entry, and when the hidden dim does not divide into the activation block, dequantize the weight and run a plain matmul instead of the fp8 block matmul. The dequant goes through a new shape-safe helper that falls back to a torch-native scale expansion when the weight does not tile evenly; backward uses the same helper so the gradient path works for every shape the forward accepts. Full-size checkpoints are unaffected.
6d7eb13 to
b926de3
Compare
for more information, see https://pre-commit.ci
|
@codex review |
|
/gemini review |
There was a problem hiding this comment.
Code Review
This pull request adds support for handling tiny, non-tileable weights and float8_e8m0fnu scales in the FP8 block-quantized linear layer (FP8BlockQuantLinear) by falling back to a torch-native dequantization and standard matrix multiplication when Triton-based paths are not supported. Feedback on the changes includes: 1) Ensuring weight_dequant_block is called with the correct block_size parameter instead of using weight_dequant which defaults to 128. 2) Utilizing original_weight_scale in the fallback path to avoid unnecessarily triggering the slower repeat_interleave path during both forward and backward passes when the scale is a scalar.
Important
The consumer version of Gemini Code Assist on GitHub is being sunset. Starting June 18, 2026, new organization installations will be blocked, and all code review activity will officially cease on July 17, 2026.
For more details on the timeline and next steps, please review the Help Documentation.
| s_full = weight_scale.repeat_interleave(block_size[0], 0)[:m] | ||
| s_full = s_full.repeat_interleave(block_size[1], 1)[:, :n] | ||
| return (weight.to(torch.float32) * s_full).to(out_dtype) | ||
| return weight_dequant(weight, weight_scale).to(out_dtype) |
There was a problem hiding this comment.
The helper function weight_dequant does not accept a block_size parameter and defaults to 128 internally. If FP8BlockQuantLinear is used with a non-default block size (e.g., [64, 64]), calling weight_dequant will ignore the actual block size and use 128, leading to incorrect dequantization. Since weight_scale is guaranteed to be block-quantized at this point, call weight_dequant_block directly and pass block_size[0] to ensure correctness.
| return weight_dequant(weight, weight_scale).to(out_dtype) | |
| return weight_dequant_block(weight, weight_scale, block_size=block_size[0], dtype=out_dtype) |
|
|
||
| if X.shape[-1] % block_size[1] != 0: | ||
| # Hidden dim not divisible by the activation block: dequant + plain matmul. | ||
| W_deq = _blockwise_weight_dequant_any_shape(weight, weight_scale, block_size, X.dtype) |
There was a problem hiding this comment.
Passing the expanded weight_scale to _blockwise_weight_dequant_any_shape causes it to execute the slower repeat_interleave path even when the scale was originally a scalar (per-tensor scale). Passing original_weight_scale instead allows the helper to use the fast scalar multiplication path directly.
| W_deq = _blockwise_weight_dequant_any_shape(weight, weight_scale, block_size, X.dtype) | |
| W_deq = _blockwise_weight_dequant_any_shape(weight, original_weight_scale, block_size, X.dtype) |
| # Hidden dim not divisible by the activation block: dequant + plain matmul. | ||
| W_deq = _blockwise_weight_dequant_any_shape(weight, weight_scale, block_size, X.dtype) | ||
| ctx.weight = weight | ||
| ctx.weight_scale = weight_scale |
There was a problem hiding this comment.
Saving the expanded weight_scale to ctx.weight_scale in the fallback path causes the backward pass to use the slower repeat_interleave path when the scale was originally a scalar. Saving original_weight_scale instead avoids this overhead and maintains consistency with the normal path.
| ctx.weight_scale = weight_scale | |
| ctx.weight_scale = original_weight_scale |
for more information, see https://pre-commit.ci
|
Fixed the fallback to dequantize with the real block size (it was defaulting to 128) and to keep the original per-tensor scale so a scalar scale stays on the fast path in both forward and backward. The scalar-scale case is already guarded before |
Summary
Small fp8 checkpoints (e.g. tiny test models) break the block-quantized linear in three ways: weight scales stored in a float8 dtype such as
float8_e8m0fnuhave no triton dtype mapping; activations whose hidden dim is not a multiple of the activation quant block fail act_quant's divisibility assert; and weights whose dims are not multiples of the weight block cannot be tiled by the triton dequant kernel.What this does
Casts non-float scales to float32 on entry, and when the hidden dim does not divide into the activation block, dequantizes the weight and runs a plain matmul instead of the fp8 block matmul. The dequant goes through a new shape-safe helper that falls back to a torch-native scale expansion when the weight does not tile evenly; backward uses the same helper so the gradient path works for every shape the forward accepts. Full-size checkpoints are unaffected.
Testing
Odd-shape forward+backward smoke test on GPU: forward error at fp8 quantization noise level (2e-3), finite gradients. Regression test for the tiny/e8m0 path included.