Add a grouped bnb 4-bit training forward for gpt-oss experts#862
Open
danielhanchen wants to merge 1 commit into
Open
Add a grouped bnb 4-bit training forward for gpt-oss experts#862danielhanchen wants to merge 1 commit into
danielhanchen wants to merge 1 commit into
Conversation
The gpt-oss 4-bit training path loops over all experts per layer, launching a dequantize and two matmuls per expert even for experts that received no tokens. Add a grouped equivalent used when every expert is a plain LoRA-free Linear4bit with a populated quant_state: routing indices are sorted once on GPU, each projection stack is dequantized with a single dequantize_4bit call over one cross-expert QuantState (packed NF4 bytes are elementwise row-major, so concatenating per-expert bytes and fp32 absmax is exact), and the two projections run as torch._grouped_mm calls with grouped bias adds, followed by the identical fp32 index_add combine. With UNSLOTH_MOE_RECOMPUTE=1 the dequantized stacks are rebuilt in backward instead of staying on the tape. The path is gated behind a readiness probe (grouped_mm support, no LoRA wrappers, frozen quantized weights), can be disabled with UNSLOTH_GPTOSS_GROUPED=0, and falls back to the per-expert loop on any failure. Per-step losses match the loop path on the deterministic parity harness; the end-to-end gain is currently modest because this model's step time is dominated by Python-side launch overhead, which the grouped path reduces but does not eliminate.
|
Codex usage limits have been reached for code reviews. Please check with the admins of this repo to increase the limits by adding credits. |
Contributor
|
Warning Gemini encountered an error creating the review. You can try again by commenting |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary
The gpt-oss 4-bit training path loops over all experts per layer, launching a dequantize and two matmuls per expert even for experts that received no tokens. A profiled step is about 80 percent GPU-idle: 2.27s of CPU work against 0.61s of CUDA, with 33k kernel launches, 2062 per-expert 4-bit matmuls, and 3140 dequantize calls.
Stacked on #838 (the base is
moe-recompute-backward, which provides the grouped matmul primitives); the diff here is only the gpt-oss commit.What this does
Adds a grouped equivalent used when every expert is a plain LoRA-free
Linear4bitwith a populated quant_state: routing indices are sorted once on GPU, each projection stack is dequantized with a singledequantize_4bitcall over one cross-expert QuantState (packed NF4 bytes are elementwise row-major, so concatenating per-expert bytes and fp32 absmax is exact), and the two projections run astorch._grouped_mmcalls with grouped bias adds, followed by the identical fp32 index_add combine. WithUNSLOTH_MOE_RECOMPUTE=1the dequantized stacks are rebuilt in backward instead of staying on the tape.The gate lives in
torch_native_forward(the function the experts class actually dispatches) behind a readiness probe (grouped_mm support, no LoRA wrappers, frozen quantized weights), can be disabled withUNSLOTH_GPTOSS_GROUPED=0, and falls back to the per-expert loop on any failure.Testing