Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
142 changes: 140 additions & 2 deletions unsloth_zoo/temporary_patches/gpt_oss.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
get_torch_compile_options,
UNSLOTH_ENABLE_LOGGING,
UNSLOTH_COMPILE_DISABLE,
logger,
)
from importlib.metadata import version as importlib_version
from ..utils import Version
Expand Down Expand Up @@ -848,13 +849,133 @@ def __init__(self, config):
"down_proj_bias", torch.empty(0, dtype=self.dtype), persistent=False
)

def _grouped_bnb4bit_ready(self):
"""True when every expert is a plain (LoRA-free) bnb Linear4bit with a
populated quant_state, so the grouped torch._grouped_mm path is exact."""
if os.environ.get("UNSLOTH_GPTOSS_GROUPED", "1") == "0":
return False
def _fail(reason):
if UNSLOTH_ENABLE_LOGGING and not getattr(self, "_unsloth_grouped_logged", False):
self._unsloth_grouped_logged = True
logger.info(f"Unsloth: gpt-oss grouped path disabled: {reason}")
return False
try:
import bitsandbytes as bnb
from bitsandbytes.nn import Params4bit
from .moe_utils import _check_torch_grouped_mm_supported
if not _check_torch_grouped_mm_supported():
return _fail("torch._grouped_mm unsupported")
for lin in list(self.gate_up_projs) + list(self.down_projs):
if hasattr(lin, "lora_A") or hasattr(lin, "base_layer"):
return _fail(f"LoRA-wrapped expert {type(lin).__name__}")
w = getattr(lin, "weight", None)
if not (isinstance(w, Params4bit) and getattr(w, "quant_state", None) is not None):
return _fail(f"expert weight {type(w).__name__} without quant_state")
if w.requires_grad:
return _fail("expert weight requires_grad")
b = getattr(lin, "bias", None)
if b is not None and b.requires_grad:
return _fail("expert bias requires_grad")
except Exception as e:
return _fail(f"{type(e).__name__}: {e}")
return True

def _forward_grouped_bnb4bit(self, hidden_states, router_indices, routing_weights,
                             batch_size, num_tokens, num_experts, top_k):
    """Grouped equivalent of the per-expert loop for bnb-4bit experts.

    Routed (token, expert) pairs are sorted by expert, the matching rows of
    ``hidden_states`` are gathered once, and the two expert projections are
    run through ``_base_grouped_mm`` over dequantized weight stacks. Bias adds
    are done per routed row, and the weighted expert outputs are combined per
    token with an fp32 ``index_add_``.

    Args:
        hidden_states: Token activations, indexed by token along dim 0
            (the caller passes them flattened to ``(-1, hidden_size)``).
        router_indices: Selected expert ids per token; flattened here, with
            ``top_k`` entries per token.
        routing_weights: Routing weights indexed as ``[token, expert]``.
        batch_size: Leading dimension used to reshape the result.
        num_tokens: Number of rows in ``hidden_states``.
        num_experts: Number of experts (``minlength`` of the bincount).
        top_k: Number of experts selected per token.

    Returns:
        Tensor viewed as ``(batch_size, -1, hidden_size)`` in the dtype of
        ``hidden_states``.

    Notes:
        ``UNSLOTH_MOE_RECOMPUTE=1`` is forwarded as the ``recompute`` flag of
        ``_base_grouped_mm`` together with a callable that rebuilds the
        dequantized stack. The merged quant states and stacked biases are
        cached on the module as ``_unsloth_grouped_qs`` after the first call.
    """
    import bitsandbytes as bnb
    from .moe_utils import _base_grouped_mm

    device = hidden_states.device
    with torch.no_grad():
        flat_experts = router_indices.flatten()
        token_ids = torch.arange(num_tokens, device=device).repeat_interleave(top_k)
        # Stable sort keeps the token order within each expert's group.
        sorted_idx = flat_experts.argsort(stable=True)
        sorted_tokens = token_ids[sorted_idx]
        # Expert id of every routed row, in ascending (grouped) order. This is
        # exactly repeat_interleave(arange(num_experts), counts), so it also
        # serves as the per-row bias index without a second,
        # data-dependent-size op.
        sorted_experts = flat_experts[sorted_idx]
        counts = torch.bincount(flat_experts, minlength=num_experts)
        # Cumulative group ends, passed to _base_grouped_mm as int32 offsets.
        offsets = counts.cumsum(0, dtype=torch.int32)

    recompute = os.environ.get("UNSLOTH_MOE_RECOMPUTE", "0") == "1"

    cached = getattr(self, "_unsloth_grouped_qs", None)
    if cached is None:
        # One QuantState spanning all experts (packed NF4 is elementwise
        # row-major, so per-expert byte + fp32 absmax concat is exact):
        # ONE dequantize_4bit launch per stack instead of num_experts.
        from bitsandbytes.functional import QuantState

        def _absmax_fp32(qs):
            # Materialize a QuantState's absmax as flat fp32 (denesting double-quant).
            if getattr(qs, "nested", False):
                absmax = bnb.functional.dequantize_blockwise(qs.absmax, qs.state2)
                return (absmax + qs.offset).float()
            return qs.absmax.float()

        def build_qs(projs):
            # Merge the per-expert quant states of one stack; code, blocksize,
            # quant_type and dtype are taken from the first expert.
            states = [l.weight.quant_state for l in projs]
            absmax = torch.cat([_absmax_fp32(qs) for qs in states])
            q0 = states[0]
            return QuantState(
                absmax=absmax,
                shape=torch.Size((len(projs),) + tuple(q0.shape)),
                code=q0.code,
                blocksize=q0.blocksize,
                quant_type=q0.quant_type,
                dtype=q0.dtype,
            )

        cached = (
            build_qs(self.gate_up_projs),
            build_qs(self.down_projs),
            torch.stack([l.bias for l in self.gate_up_projs]),  # (E, 2I)
            torch.stack([l.bias for l in self.down_projs]),     # (E, H)
        )
        self._unsloth_grouped_qs = cached
    gate_up_qs, down_qs, gate_up_bias, down_bias = cached

    def _stack(projs, quant_state):
        # One fused dequant gives (E, out, in); grouped_mm takes the transposed view.
        data = torch.cat([l.weight.data.reshape(-1, 1) for l in projs])
        return bnb.functional.dequantize_4bit(data, quant_state).transpose(1, 2)

    # Gather each routed token's activations in expert-sorted order.
    xg = hidden_states[sorted_tokens]
    gate_up = _base_grouped_mm(
        xg, offsets, lambda: _stack(self.gate_up_projs, gate_up_qs), recompute)
    gate_up = gate_up + gate_up_bias[sorted_experts]
    gated = swiglu_torch_forward(gate_up, self.alpha, self.limit)
    out = _base_grouped_mm(
        gated, offsets, lambda: _stack(self.down_projs, down_qs), recompute)
    out = out + down_bias[sorted_experts]

    # Scale each routed row by its routing weight and accumulate per token in fp32.
    weighted = out.to(torch.float32) * routing_weights[sorted_tokens, sorted_experts, None].to(torch.float32)
    next_states = torch.zeros(num_tokens, self.hidden_size, dtype=torch.float32, device=device)
    next_states.index_add_(0, sorted_tokens, weighted)
    return next_states.view(batch_size, -1, self.hidden_size).to(hidden_states.dtype)

def forward(self, hidden_states: torch.Tensor, router_indices=None, routing_weights=None) -> torch.Tensor:
batch_size = hidden_states.shape[0]
hidden_states = hidden_states.reshape(-1, self.hidden_size)
num_tokens = hidden_states.shape[0]
num_experts = routing_weights.shape[1]
top_k = router_indices.shape[1]


if self.training and self._grouped_bnb4bit_ready():
try:
return self._forward_grouped_bnb4bit(
hidden_states, router_indices, routing_weights,
batch_size, num_tokens, num_experts, top_k,
)
except Exception:
if UNSLOTH_ENABLE_LOGGING:
import traceback; traceback.print_exc()
# fall through to the per-expert loop

if self.training:
with torch.no_grad():
flat_experts = router_indices.flatten() # [tokens * topk]
Expand Down Expand Up @@ -1768,7 +1889,24 @@ def torch_native_forward(
num_tokens = hidden_states.shape[0]
num_experts = routing_weights.shape[1]
top_k = router_indices.shape[1]


# Grouped bnb-4bit fast path. The class dispatches this module-level
# function (forward is rebound below), so the gate must live here too.
if (
self.training
and hasattr(self, "_grouped_bnb4bit_ready")
and self._grouped_bnb4bit_ready()
):
try:
return self._forward_grouped_bnb4bit(
hidden_states, router_indices, routing_weights,
batch_size, num_tokens, num_experts, top_k,
)
except Exception:
if UNSLOTH_ENABLE_LOGGING:
import traceback; traceback.print_exc()
# fall through to the per-expert loop

if self.training:
with torch.no_grad():
flat_experts = router_indices.flatten() # [tokens * topk]
Expand Down
Loading