diff --git a/unsloth_zoo/temporary_patches/gpt_oss.py b/unsloth_zoo/temporary_patches/gpt_oss.py
index e3e968730..d84357731 100644
--- a/unsloth_zoo/temporary_patches/gpt_oss.py
+++ b/unsloth_zoo/temporary_patches/gpt_oss.py
@@ -28,6 +28,7 @@
     get_torch_compile_options,
     UNSLOTH_ENABLE_LOGGING,
     UNSLOTH_COMPILE_DISABLE,
+    logger,
 )
 from importlib.metadata import version as importlib_version
 from ..utils import Version
@@ -848,13 +849,169 @@ def __init__(self, config):
             "down_proj_bias", torch.empty(0, dtype=self.dtype), persistent=False
         )
 
+    def _grouped_bnb4bit_ready(self):
+        """True when every expert is a plain (LoRA-free) bnb Linear4bit with a
+        populated quant_state, so the grouped torch._grouped_mm path is exact.
+
+        Also False when UNSLOTH_GPTOSS_GROUPED=0, after a grouped forward has
+        already raised on this module, when an expert weight or bias is trainable
+        or a bias is missing, or when the experts of one projection do not share
+        a block-aligned quant layout (the grouped path dequantizes each projection
+        through a single merged QuantState built from expert 0's metadata)."""
+        if os.environ.get("UNSLOTH_GPTOSS_GROUPED", "1") == "0":
+            return False
+        if getattr(self, "_unsloth_grouped_failed", False):
+            # Set by the dispatchers when the grouped forward raised; skip it from
+            # then on instead of re-running a failing fast path every step.
+            return False
+        def _fail(reason):
+            if UNSLOTH_ENABLE_LOGGING and not getattr(self, "_unsloth_grouped_logged", False):
+                self._unsloth_grouped_logged = True
+                logger.info(f"Unsloth: gpt-oss grouped path disabled: {reason}")
+            return False
+        try:
+            from bitsandbytes.nn import Params4bit
+            from .moe_utils import _check_torch_grouped_mm_supported
+            if not _check_torch_grouped_mm_supported():
+                return _fail("torch._grouped_mm unsupported")
+            for projs in (self.gate_up_projs, self.down_projs):
+                layout = None
+                for lin in projs:
+                    if hasattr(lin, "lora_A") or hasattr(lin, "base_layer"):
+                        return _fail(f"LoRA-wrapped expert {type(lin).__name__}")
+                    w = getattr(lin, "weight", None)
+                    if not (isinstance(w, Params4bit) and getattr(w, "quant_state", None) is not None):
+                        return _fail(f"expert weight {type(w).__name__} without quant_state")
+                    if w.requires_grad:
+                        return _fail("expert weight requires_grad")
+                    qs = w.quant_state
+                    # The grouped forward merges these experts into one QuantState
+                    # (expert 0's metadata, concatenated absmax); that is only exact
+                    # when all layouts match and each expert ends on a block boundary.
+                    this_layout = (tuple(qs.shape), qs.blocksize, qs.quant_type, qs.dtype)
+                    if layout is None:
+                        layout = this_layout
+                    elif this_layout != layout:
+                        return _fail("expert quant_state layouts differ")
+                    if torch.Size(qs.shape).numel() % qs.blocksize != 0:
+                        return _fail("expert weight not aligned to quant blocksize")
+                    b = getattr(lin, "bias", None)
+                    # Grouped path stacks per-expert biases, so a missing bias would
+                    # break torch.stack; require all present and fall back otherwise.
+                    if b is None:
+                        return _fail("expert bias is None")
+                    if b.requires_grad:
+                        return _fail("expert bias requires_grad")
+        except Exception as e:
+            return _fail(f"{type(e).__name__}: {e}")
+        return True
+
+    def _forward_grouped_bnb4bit(self, hidden_states, router_indices, routing_weights,
+                                 batch_size, num_tokens, num_experts, top_k):
+        """Grouped equivalent of the per-expert loop: one gather, two grouped_mm
+        calls over dequantized stacks (rebuilt in backward when
+        UNSLOTH_MOE_RECOMPUTE=1), grouped bias adds, fp32 index_add combine.
+
+        hidden_states is gathered by token row, router_indices is flattened to
+        num_tokens * top_k expert ids, and routing_weights is indexed as
+        [token, expert]. Returns fp32 viewed as (batch_size, -1, hidden_size).
+        Callers gate on _grouped_bnb4bit_ready(); the merged QuantStates and
+        stacked biases are built on first call and cached on the module."""
+        import bitsandbytes as bnb
+        from .moe_utils import _base_grouped_mm
+
+        device = hidden_states.device
+        with torch.no_grad():
+            flat_experts = router_indices.flatten()
+            token_ids = torch.arange(num_tokens, device=device).repeat_interleave(top_k)
+            sorted_idx = flat_experts.argsort(stable=True)
+            sorted_tokens = token_ids[sorted_idx]
+            # Expert id of each gathered row, in ascending order: the same values
+            # repeat_interleave(arange(num_experts), counts) would produce, without
+            # the stream sync that call needs to size its output.
+            sorted_experts = flat_experts[sorted_idx]
+            counts = torch.bincount(flat_experts, minlength=num_experts)
+            offsets = counts.cumsum(0, dtype=torch.int32)
+
+        recompute = os.environ.get("UNSLOTH_MOE_RECOMPUTE", "0") == "1"
+
+        cached = getattr(self, "_unsloth_grouped_qs", None)
+        if cached is None:
+            # One QuantState spanning all experts (packed NF4 is elementwise row-major,
+            # so per-expert byte + fp32 absmax concat is exact): one dequant launch per stack.
+            from bitsandbytes.functional import QuantState
+
+            def _absmax_fp32(qs):
+                # Materialize a QuantState's absmax as flat fp32 (denesting double-quant).
+                if getattr(qs, "nested", False):
+                    absmax = bnb.functional.dequantize_blockwise(qs.absmax, qs.state2)
+                    return (absmax + qs.offset).float()
+                return qs.absmax.float()
+
+            def build_qs(projs):
+                states = [l.weight.quant_state for l in projs]
+                absmax = torch.cat([_absmax_fp32(qs) for qs in states])
+                q0 = states[0]
+                return QuantState(
+                    absmax=absmax,
+                    shape=torch.Size((len(projs),) + tuple(q0.shape)),
+                    code=q0.code,
+                    blocksize=q0.blocksize,
+                    quant_type=q0.quant_type,
+                    dtype=q0.dtype,
+                )
+
+            cached = (
+                build_qs(self.gate_up_projs),
+                build_qs(self.down_projs),
+                torch.stack([l.bias for l in self.gate_up_projs]),  # (E, 2I)
+                torch.stack([l.bias for l in self.down_projs]),  # (E, H)
+            )
+            self._unsloth_grouped_qs = cached
+        gate_up_qs, down_qs, gate_up_bias, down_bias = cached
+
+        def _stack(projs, quant_state):
+            # One fused dequant gives (E, out, in); grouped_mm takes the transposed view.
+            data = torch.cat([l.weight.data.reshape(-1, 1) for l in projs])
+            return bnb.functional.dequantize_4bit(data, quant_state).transpose(1, 2)
+
+        xg = hidden_states[sorted_tokens]
+        gate_up = _base_grouped_mm(
+            xg, offsets, lambda: _stack(self.gate_up_projs, gate_up_qs), recompute)
+        gate_up = gate_up + gate_up_bias[sorted_experts]
+        gated = swiglu_torch_forward(gate_up, self.alpha, self.limit)
+        out = _base_grouped_mm(
+            gated, offsets, lambda: _stack(self.down_projs, down_qs), recompute)
+        out = out + down_bias[sorted_experts]
+
+        weighted = out.to(torch.float32) * routing_weights[sorted_tokens, sorted_experts, None].to(torch.float32)
+        next_states = torch.zeros(num_tokens, self.hidden_size, dtype=torch.float32, device=device)
+        next_states.index_add_(0, sorted_tokens, weighted)
+        # Grouped path only runs in training; mirror torch_native_forward's
+        # training branch, which returns fp32 (fp16 NaN protection).
+        return next_states.view(batch_size, -1, self.hidden_size).to(torch.float32)
+
     def forward(self, hidden_states: torch.Tensor, router_indices=None, routing_weights=None) -> torch.Tensor:
         batch_size = hidden_states.shape[0]
         hidden_states = hidden_states.reshape(-1, self.hidden_size)
         num_tokens = hidden_states.shape[0]
         num_experts = routing_weights.shape[1]
         top_k = router_indices.shape[1]
-        
+
+        if self.training and self._grouped_bnb4bit_ready():
+            try:
+                return self._forward_grouped_bnb4bit(
+                    hidden_states, router_indices, routing_weights,
+                    batch_size, num_tokens, num_experts, top_k,
+                )
+            except Exception:
+                # Latch the failure so _grouped_bnb4bit_ready() returns False from
+                # now on; otherwise the failing path is retried on every step.
+                self._unsloth_grouped_failed = True
+                if UNSLOTH_ENABLE_LOGGING:
+                    import traceback; traceback.print_exc()
+                # fall through to the per-expert loop
+
         if self.training:
             with torch.no_grad():
                 flat_experts = router_indices.flatten()  # [tokens * topk]
@@ -1768,7 +1925,27 @@ def torch_native_forward(
     num_tokens = hidden_states.shape[0]
     num_experts = routing_weights.shape[1]
     top_k = router_indices.shape[1]
-    
+
+    # Grouped bnb-4bit fast path. The class dispatches this module-level
+    # function (forward is rebound below), so the gate must live here too.
+    if (
+        self.training
+        and hasattr(self, "_grouped_bnb4bit_ready")
+        and self._grouped_bnb4bit_ready()
+    ):
+        try:
+            return self._forward_grouped_bnb4bit(
+                hidden_states, router_indices, routing_weights,
+                batch_size, num_tokens, num_experts, top_k,
+            )
+        except Exception:
+            # Latch the failure so _grouped_bnb4bit_ready() returns False from
+            # now on; otherwise the failing path is retried on every step.
+            self._unsloth_grouped_failed = True
+            if UNSLOTH_ENABLE_LOGGING:
+                import traceback; traceback.print_exc()
+            # fall through to the per-expert loop
+
     if self.training:
         with torch.no_grad():
             flat_experts = router_indices.flatten()  # [tokens * topk]
diff --git a/unsloth_zoo/temporary_patches/moe_utils.py b/unsloth_zoo/temporary_patches/moe_utils.py
index 2dadc9160..c9b8c99d8 100644
--- a/unsloth_zoo/temporary_patches/moe_utils.py
+++ b/unsloth_zoo/temporary_patches/moe_utils.py
@@ -161,11 +161,9 @@ def _manual_grouped_mm(
     return inputs.new_empty((0, weight.shape[-1]))
 
 
-# Optional recompute-in-backward for the frozen base expert GEMM. With
-# UNSLOTH_MOE_RECOMPUTE=1 the dequantized bf16 expert stack is rebuilt from the 4-bit
-# Params4bit in backward (for dX only; the base is frozen and LoRA is a separate additive
-# grouped_mm) instead of being pinned on the tape. Output is unchanged; off by default the
-# call below is the prior eager path.
+# Optional recompute-in-backward for the frozen base expert GEMM
+# (UNSLOTH_MOE_RECOMPUTE=1): rebuild the dequantized bf16 stack from the 4-bit
+# Params4bit in backward (dX only) instead of pinning it. Output unchanged; off by default.
 def _moe_recompute_enabled(source) -> bool: