
Add a grouped bnb 4-bit training forward for gpt-oss experts #862

Open
danielhanchen wants to merge 1 commit into moe-recompute-backward from gptoss-grouped-4bit-experts

Conversation

@danielhanchen

Member

Summary

The gpt-oss 4-bit training path loops over all experts per layer, launching a dequantize and two matmuls per expert even for experts that received no tokens. A profiled step is about 80 percent GPU-idle: 2.27s of CPU work against 0.61s of CUDA, with 33k kernel launches, 2062 per-expert 4-bit matmuls, and 3140 dequantize calls.

Stacked on #838 (the base is moe-recompute-backward, which provides the grouped matmul primitives); the diff here is only the gpt-oss commit.

What this does

Adds a grouped equivalent that is used when every expert is a plain, LoRA-free Linear4bit with a populated quant_state. Routing indices are sorted once on the GPU. Each projection stack is dequantized with a single dequantize_4bit call over one cross-expert QuantState; packed NF4 bytes are elementwise row-major, so concatenating the per-expert bytes and fp32 absmax is exact. The two projections then run as torch._grouped_mm calls with grouped bias adds, followed by the same fp32 index_add combine that the loop path uses. With UNSLOTH_MOE_RECOMPUTE=1, the dequantized stacks are rebuilt in backward instead of staying on the tape.
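As a rough illustration of the sort-once grouping described above, the sketch below uses plain Python lists in place of tensors and a scalar affine map in place of each expert's two projections. The names `loop_forward` and `grouped_forward` and the toy data are hypothetical and not taken from the diff; the real path works on dequantized weight stacks with torch._grouped_mm.

```python
from itertools import accumulate

def loop_forward(xs, assignments, experts):
    """Reference path: visits every expert, including ones with no tokens."""
    out = [0.0] * len(xs)
    for e, (w, b) in enumerate(experts):
        for tok, expert, gate in assignments:
            if expert == e:
                out[tok] += gate * (w * xs[tok] + b)
    return out

def grouped_forward(xs, assignments, experts):
    """Sort routing entries by expert once, then walk contiguous segments."""
    # Stable sort keeps the original order of entries within each expert.
    order = sorted(range(len(assignments)), key=lambda i: assignments[i][1])
    counts = [0] * len(experts)
    for _, expert, _ in assignments:
        counts[expert] += 1
    offsets = list(accumulate(counts))  # end offset of each expert's segment
    out = [0.0] * len(xs)
    start = 0
    for e, end in enumerate(offsets):
        w, b = experts[e]
        for i in order[start:end]:  # empty slice for an idle expert
            tok, _, gate = assignments[i]
            out[tok] += gate * (w * xs[tok] + b)  # index_add-style combine
        start = end
    return out

# Toy data (illustrative only): 3 tokens, 4 experts, expert 1 gets no tokens.
xs = [1.0, 2.0, 3.0]
experts = [(2.0, 0.5), (1.0, 0.0), (-1.0, 1.0), (0.5, 0.25)]  # (weight, bias)
assignments = [  # (token index, expert id, routing weight)
    (0, 2, 0.75), (0, 0, 0.25), (1, 0, 0.5), (1, 3, 0.5), (2, 2, 1.0),
]
assert grouped_forward(xs, assignments, experts) == loop_forward(xs, assignments, experts)
```

The cumulative offsets here stand in for the group boundaries a grouped matmul consumes. An expert with no tokens produces an empty segment instead of its own dequantize and matmul launches.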

The gate lives in torch_native_forward (the function the experts class actually dispatches to) behind a readiness probe (grouped_mm support, no LoRA wrappers, frozen quantized weights). It can be disabled with UNSLOTH_GPTOSS_GROUPED=0 and falls back to the per-expert loop on any failure.
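The gating logic might look roughly like the following. This is a minimal sketch, not the code in the diff: `grouped_path_enabled`, `experts_forward`, and the callables passed in are hypothetical names, and treating an unset UNSLOTH_GPTOSS_GROUPED as "enabled" is an assumption.

```python
import os

def grouped_path_enabled(env=os.environ):
    # Assumption: the grouped path is on unless the variable is exactly "0".
    return env.get("UNSLOTH_GPTOSS_GROUPED", "1") != "0"

def experts_forward(x, grouped_fn, loop_fn, is_ready, env=os.environ):
    """Try the grouped path when allowed and ready, else use the loop."""
    if grouped_path_enabled(env) and is_ready():
        try:
            return grouped_fn(x)
        except Exception:
            # Any failure in the grouped path falls through to the loop.
            pass
    return loop_fn(x)
```

Keeping the readiness probe separate from the try/except means unsupported configurations (for example, LoRA-wrapped experts) skip the grouped path up front, while unexpected runtime errors still degrade to the loop.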

Testing

  • gpt-oss-20b-unsloth-bnb-4bit LoRA SFT at 4k tokens: 10194 tok/s grouped vs 2433 with the loop (4.2x); 11238 tok/s at 8k.
  • Per-step losses match the loop path to within 0.0022 over 20 deterministic steps.
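A parity check of this kind can be expressed as a maximum absolute gap between the two per-step loss series. The helper below is a hypothetical sketch rather than the harness used for this PR, and the loss values are placeholders chosen for illustration, not measurements.

```python
def max_abs_loss_gap(loop_losses, grouped_losses):
    """Largest per-step absolute difference between two loss series."""
    if len(loop_losses) != len(grouped_losses):
        raise ValueError("loss series must cover the same steps")
    return max(abs(a - b) for a, b in zip(loop_losses, grouped_losses))

# Placeholder values only; real runs would log one loss per training step.
loop_losses = [2.0, 1.5, 1.25]
grouped_losses = [2.0, 1.5 + 2**-9, 1.25 - 2**-10]

gap = max_abs_loss_gap(loop_losses, grouped_losses)
assert gap <= 0.0022  # tolerance borrowed from the figure quoted above
```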

The gpt-oss 4-bit training path loops over all experts per layer, launching a
dequantize and two matmuls per expert even for experts that received no
tokens. Add a grouped equivalent used when every expert is a plain LoRA-free
Linear4bit with a populated quant_state: routing indices are sorted once on
GPU, each projection stack is dequantized with a single dequantize_4bit call
over one cross-expert QuantState (packed NF4 bytes are elementwise row-major,
so concatenating per-expert bytes and fp32 absmax is exact), and the two
projections run as torch._grouped_mm calls with grouped bias adds, followed
by the identical fp32 index_add combine. With UNSLOTH_MOE_RECOMPUTE=1 the
dequantized stacks are rebuilt in backward instead of staying on the tape.

The path is gated behind a readiness probe (grouped_mm support, no LoRA
wrappers, frozen quantized weights), can be disabled with
UNSLOTH_GPTOSS_GROUPED=0, and falls back to the per-expert loop on any
failure. Per-step losses match the loop path on the deterministic parity
harness; the end-to-end gain is currently modest because this model's step
time is dominated by Python-side launch overhead, which the grouped path
reduces but does not eliminate.