From bb611927e4c5c1f3dc4682068133aa9241b67a43 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Fri, 3 Jul 2026 14:22:36 +0000 Subject: [PATCH] Add a grouped bnb 4-bit training forward for gpt-oss experts The gpt-oss 4-bit training path loops over all experts per layer, launching a dequantize and two matmuls per expert even for experts that received no tokens. Add a grouped equivalent used when every expert is a plain LoRA-free Linear4bit with a populated quant_state: routing indices are sorted once on GPU, each projection stack is dequantized with a single dequantize_4bit call over one cross-expert QuantState (packed NF4 bytes are elementwise row-major, so concatenating per-expert bytes and fp32 absmax is exact), and the two projections run as torch._grouped_mm calls with grouped bias adds, followed by the identical fp32 index_add combine. With UNSLOTH_MOE_RECOMPUTE=1 the dequantized stacks are rebuilt in backward instead of staying on the tape. The path is gated behind a readiness probe (grouped_mm support, no LoRA wrappers, frozen quantized weights), can be disabled with UNSLOTH_GPTOSS_GROUPED=0, and falls back to the per-expert loop on any failure. Per-step losses match the loop path on the deterministic parity harness; the end-to-end gain is currently modest because this model's step time is dominated by Python-side launch overhead, which the grouped path reduces but does not eliminate. --- unsloth_zoo/temporary_patches/gpt_oss.py | 142 ++++++++++++++++++++++- 1 file changed, 140 insertions(+), 2 deletions(-) diff --git a/unsloth_zoo/temporary_patches/gpt_oss.py b/unsloth_zoo/temporary_patches/gpt_oss.py index e3e968730..d3f0c64cb 100644 --- a/unsloth_zoo/temporary_patches/gpt_oss.py +++ b/unsloth_zoo/temporary_patches/gpt_oss.py @@ -28,6 +28,7 @@ get_torch_compile_options, UNSLOTH_ENABLE_LOGGING, UNSLOTH_COMPILE_DISABLE, + logger, ) from importlib.metadata import version as importlib_version from ..utils import Version @@ -848,13 +849,133 @@ def __init__(self, config): "down_proj_bias", torch.empty(0, dtype=self.dtype), persistent=False ) + def _grouped_bnb4bit_ready(self): + """True when every expert is a plain (LoRA-free) bnb Linear4bit with a + populated quant_state, so the grouped torch._grouped_mm path is exact.""" + if os.environ.get("UNSLOTH_GPTOSS_GROUPED", "1") == "0": + return False + def _fail(reason): + if UNSLOTH_ENABLE_LOGGING and not getattr(self, "_unsloth_grouped_logged", False): + self._unsloth_grouped_logged = True + logger.info(f"Unsloth: gpt-oss grouped path disabled: {reason}") + return False + try: + import bitsandbytes as bnb + from bitsandbytes.nn import Params4bit + from .moe_utils import _check_torch_grouped_mm_supported + if not _check_torch_grouped_mm_supported(): + return _fail("torch._grouped_mm unsupported") + for lin in list(self.gate_up_projs) + list(self.down_projs): + if hasattr(lin, "lora_A") or hasattr(lin, "base_layer"): + return _fail(f"LoRA-wrapped expert {type(lin).__name__}") + w = getattr(lin, "weight", None) + if not (isinstance(w, Params4bit) and getattr(w, "quant_state", None) is not None): + return _fail(f"expert weight {type(w).__name__} without quant_state") + if w.requires_grad: + return _fail("expert weight requires_grad") + b = getattr(lin, "bias", None) + if b is not None and b.requires_grad: + return _fail("expert bias requires_grad") + except Exception as e: + return _fail(f"{type(e).__name__}: {e}") + return True + + def _forward_grouped_bnb4bit(self, hidden_states, router_indices, routing_weights, + batch_size, num_tokens, num_experts, top_k): + """Grouped equivalent of the per-expert loop: one gather, two grouped_mm + calls over dequantized stacks (rebuilt in backward when + UNSLOTH_MOE_RECOMPUTE=1), grouped bias adds, fp32 index_add combine.""" + import bitsandbytes as bnb + from .moe_utils import _base_grouped_mm + + device = hidden_states.device + with torch.no_grad(): + flat_experts = router_indices.flatten() + token_ids = torch.arange(num_tokens, device=device).repeat_interleave(top_k) + sorted_idx = flat_experts.argsort(stable=True) + sorted_tokens = token_ids[sorted_idx] + sorted_experts = flat_experts[sorted_idx] + counts = torch.bincount(flat_experts, minlength=num_experts) + offsets = counts.cumsum(0, dtype=torch.int32) + expert_ids = torch.repeat_interleave( + torch.arange(num_experts, device=device), counts + ) + + recompute = os.environ.get("UNSLOTH_MOE_RECOMPUTE", "0") == "1" + + cached = getattr(self, "_unsloth_grouped_qs", None) + if cached is None: + # One QuantState spanning all experts (packed NF4 is elementwise + # row-major, so per-expert byte + fp32 absmax concat is exact): + # ONE dequantize_4bit launch per stack instead of num_experts. + from bitsandbytes.functional import QuantState + + def _absmax_fp32(qs): + # Materialize a QuantState's absmax as flat fp32 (denesting double-quant). + if getattr(qs, "nested", False): + absmax = bnb.functional.dequantize_blockwise(qs.absmax, qs.state2) + return (absmax + qs.offset).float() + return qs.absmax.float() + + def build_qs(projs): + states = [l.weight.quant_state for l in projs] + absmax = torch.cat([_absmax_fp32(qs) for qs in states]) + q0 = states[0] + return QuantState( + absmax=absmax, + shape=torch.Size((len(projs),) + tuple(q0.shape)), + code=q0.code, + blocksize=q0.blocksize, + quant_type=q0.quant_type, + dtype=q0.dtype, + ) + + cached = ( + build_qs(self.gate_up_projs), + build_qs(self.down_projs), + torch.stack([l.bias for l in self.gate_up_projs]), # (E, 2I) + torch.stack([l.bias for l in self.down_projs]), # (E, H) + ) + self._unsloth_grouped_qs = cached + gate_up_qs, down_qs, gate_up_bias, down_bias = cached + + def _stack(projs, quant_state): + # One fused dequant gives (E, out, in); grouped_mm takes the transposed view. + data = torch.cat([l.weight.data.reshape(-1, 1) for l in projs]) + return bnb.functional.dequantize_4bit(data, quant_state).transpose(1, 2) + + xg = hidden_states[sorted_tokens] + gate_up = _base_grouped_mm( + xg, offsets, lambda: _stack(self.gate_up_projs, gate_up_qs), recompute) + gate_up = gate_up + gate_up_bias[expert_ids] + gated = swiglu_torch_forward(gate_up, self.alpha, self.limit) + out = _base_grouped_mm( + gated, offsets, lambda: _stack(self.down_projs, down_qs), recompute) + out = out + down_bias[expert_ids] + + weighted = out.to(torch.float32) * routing_weights[sorted_tokens, sorted_experts, None].to(torch.float32) + next_states = torch.zeros(num_tokens, self.hidden_size, dtype=torch.float32, device=device) + next_states.index_add_(0, sorted_tokens, weighted) + return next_states.view(batch_size, -1, self.hidden_size).to(hidden_states.dtype) + def forward(self, hidden_states: torch.Tensor, router_indices=None, routing_weights=None) -> torch.Tensor: batch_size = hidden_states.shape[0] hidden_states = hidden_states.reshape(-1, self.hidden_size) num_tokens = hidden_states.shape[0] num_experts = routing_weights.shape[1] top_k = router_indices.shape[1] - + + if self.training and self._grouped_bnb4bit_ready(): + try: + return self._forward_grouped_bnb4bit( + hidden_states, router_indices, routing_weights, + batch_size, num_tokens, num_experts, top_k, + ) + except Exception: + if UNSLOTH_ENABLE_LOGGING: + import traceback; traceback.print_exc() + # fall through to the per-expert loop + if self.training: with torch.no_grad(): flat_experts = router_indices.flatten() # [tokens * topk] @@ -1768,7 +1889,24 @@ def torch_native_forward( num_tokens = hidden_states.shape[0] num_experts = routing_weights.shape[1] top_k = router_indices.shape[1] - + + # Grouped bnb-4bit fast path. The class dispatches this module-level + # function (forward is rebound below), so the gate must live here too. + if ( + self.training + and hasattr(self, "_grouped_bnb4bit_ready") + and self._grouped_bnb4bit_ready() + ): + try: + return self._forward_grouped_bnb4bit( + hidden_states, router_indices, routing_weights, + batch_size, num_tokens, num_experts, top_k, + ) + except Exception: + if UNSLOTH_ENABLE_LOGGING: + import traceback; traceback.print_exc() + # fall through to the per-expert loop + if self.training: with torch.no_grad(): flat_experts = router_indices.flatten() # [tokens * topk]