Grouped-GEMM MoE forward for transformers<5 ModuleList experts#837
Grouped-GEMM MoE forward for transformers<5 ModuleList experts#837danielhanchen wants to merge 11 commits into
Conversation
On transformers<5 a MoE block keeps experts as a nn.ModuleList and loops over them in Python, launching O(num_experts) tiny matmuls per layer and never calling torch._grouped_mm. This adds the grouped recipe (sort tokens by expert, grouped_mm gate_up, act, grouped_mm down, scatter-add) for frozen bnb-4bit (or plain-frozen) ModuleList experts. Experts stay 4-bit; the dequant stack is rebuilt in backward (recompute) or cached. Gated to known block classes with no shared expert, no LoRA on experts, and torch._grouped_mm support, so everything else is a no-op.
There was a problem hiding this comment.
Code Review
This pull request introduces a Grouped-GEMM MoE forward implementation for transformers < 5 using the nn.ModuleList expert layout, aiming to optimize performance by grouping expert matrix multiplications. The feedback suggests several robustness improvements: dynamically rebuilding the dequantized expert cache if a device or dtype mismatch is detected, using getattr instead of hasattr when checking for base_layer to safely handle falsy values, and performing an identity check during unpatching to ensure only the originally patched function is restored.
Important
The consumer version of Gemini Code Assist on GitHub is being sunset. Starting June 18, 2026, new organization installations will be blocked, and all code review activity will officially cease on July 17, 2026.
For more details on the timeline and next steps, please review the Help Documentation.
| if getattr(self, "_cached_gate_up", None) is None: | ||
| with torch.no_grad(): | ||
| self._cached_gate_up = _build_gate_up_stack(experts, spec, dtype) | ||
| self._cached_down = _build_down_stack(experts, spec, dtype) |
There was a problem hiding this comment.
When caching the dequantized expert weights, we should verify that the cached tensors match the current device and dtype of the input hidden_states. If the model is cast (e.g., to float16 or bfloat16) or moved to a different GPU after the first forward pass, reusing the stale cache will trigger device or dtype mismatch runtime errors. Rebuilding the cache dynamically on mismatch prevents these crashes.
| if getattr(self, "_cached_gate_up", None) is None: | |
| with torch.no_grad(): | |
| self._cached_gate_up = _build_gate_up_stack(experts, spec, dtype) | |
| self._cached_down = _build_down_stack(experts, spec, dtype) | |
| cached_gate_up = getattr(self, "_cached_gate_up", None) | |
| if cached_gate_up is None or cached_gate_up.device != dev or cached_gate_up.dtype != dtype: | |
| with torch.no_grad(): | |
| self._cached_gate_up = _build_gate_up_stack(experts, spec, dtype) | |
| self._cached_down = _build_down_stack(experts, spec, dtype) |
| w = getattr(lin, "weight", None) | ||
| if w is None: | ||
| return None | ||
| if getattr(lin, "lora_A", None) is not None or hasattr(lin, "base_layer"): # LoRA on experts -> defer |
There was a problem hiding this comment.
Use a truthy getattr check instead of hasattr when checking for base_layer to safely handle cases where the attribute might be set to None or False, preventing unintended behavior.
| if getattr(lin, "lora_A", None) is not None or hasattr(lin, "base_layer"): # LoRA on experts -> defer | |
| if getattr(lin, "lora_A", None) is not None or getattr(lin, "base_layer", None) is not None: # LoRA on experts -> defer |
References
- Use a truthy
getattrcheck (e.g.,getattr(cfg, "attr", None)) instead ofhasattrwhen you need to preserve falsy values likeNoneorFalse.hasattrreturnsTrueeven if the attribute is set toNoneorFalse, which can lead to unintended overwriting of default behaviors.
| if hasattr(module, "_orig_moe_forward"): | ||
| module.forward = module._orig_moe_forward |
There was a problem hiding this comment.
When unpatching a function, we should perform an identity check to ensure the function being replaced is the one originally patched by our code. This is more robust than relying on the state of other modules and prevents accidentally overwriting a different patch or wrapper that might have been applied to module.forward in the meantime.
if hasattr(module, "_orig_moe_forward"):
if getattr(module.forward, "__func__", None) is grouped_moe_forward:
module.forward = module._orig_moe_forwardReferences
- When unpatching a function, perform an identity check to ensure the function being replaced is the one originally patched by your code. This is more robust than relying on the state of other modules.
There was a problem hiding this comment.
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: c3720b4b7a
ℹ️ About Codex in GitHub
Your team has set up Codex to review pull requests in this repo. Reviews are triggered when you
- Open a pull request for review
- Mark a draft as ready
- Comment "@codex review".
If Codex has suggestions, it will comment; otherwise it will react with 👍.
Codex can also answer questions or update the PR. Try commenting "@codex address that feedback".
| dtype = hidden_states.dtype | ||
| experts = self.experts | ||
|
|
||
| router_logits = self.gate(hidden_states) |
There was a problem hiding this comment.
Preserve Mixtral router jitter in grouped forward
For MixtralSparseMoeBlock training runs where router_jitter_noise/self.jitter_noise is nonzero, the stock ModuleList forward applies multiplicative jitter to hidden_states before computing router_logits; this replacement goes directly to self.gate(hidden_states), so enabling grouped MoE silently disables the configured router jitter and changes both routing and expert activations for those experiments.
Useful? React with 👍 / 👎.
| e0 = experts[0] | ||
| for name in (g_name, u_name, d_name): | ||
| lin = getattr(e0, name, None) |
There was a problem hiding this comment.
Check every expert before bypassing LoRA paths
Eligibility only inspects experts[0], so a model with adapters or trainable weights attached selectively to a later expert still gets patched; the grouped forward then builds stacks from the base weights via _expert_weight and never applies those later expert adapters/gradients. This matters for selective expert LoRA or partial expert fine-tuning, where the patch would silently ignore the trained expert parameters instead of falling back to the original forward.
Useful? React with 👍 / 👎.
| with torch.no_grad(): | ||
| out = _grouped_mm_fix(x, weight_fn(), offsets) | ||
| return out | ||
|
|
There was a problem hiding this comment.
Avoid saving unused grouped inputs in recompute mode
When recompute=True (the default without gradient checkpointing), this custom Function saves the full grouped input tensor even though backward never uses x to compute dX; it only needs offsets and a rebuilt weight. For long sequences this retains the permuted hidden state and intermediate activations for every patched MoE layer, undermining the memory-safe recompute path and causing avoidable peak-memory pressure.
Useful? React with 👍 / 👎.
| for module in model.modules(): | ||
| _restore_block(module) | ||
| return 0 | ||
| if not _grouped_mm_supported(): |
There was a problem hiding this comment.
Gate patching on each block's execution device
On CUDA hosts, this global support probe can be true even for a model that is still CPU-resident or has MoE blocks offloaded to CPU; enable_grouped_moe then replaces the forward, and the next CPU/offloaded forward calls torch._grouped_mm with CPU tensors instead of falling back to the original loop. This breaks CPU/offload workflows on machines that merely have a supported GPU available, so eligibility needs to include the block/input device rather than only process-wide CUDA support.
Useful? React with 👍 / 👎.
| ok = True | ||
| except Exception: | ||
| ok = False | ||
| _GROUPED_MM_SUPPORTED = ok |
- Do not save the grouped input for backward; only offsets are needed for the frozen base dX, so the recompute path no longer pins the permuted activation. - Runtime fallback to the original forward for CPU/offloaded inputs or experts that gained LoRA after patching, so a process-wide grouped_mm probe or a late adapter attach cannot mis-route. - Preserve Mixtral router jitter in the grouped forward (no-op for models without jitter_noise). - Check every expert for frozen/no-LoRA, not just the first, so selective per-expert adapters fall back. - Rebuild the cached dequant stacks on device/dtype change; restore only our own patch; document the guarded excepts.
There was a problem hiding this comment.
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: 3531449d84
ℹ️ About Codex in GitHub
Your team has set up Codex to review pull requests in this repo. Reviews are triggered when you
- Open a pull request for review
- Mark a draft as ready
- Comment "@codex review".
If Codex has suggestions, it will comment; otherwise it will react with 👍.
Codex can also answer questions or update the PR. Try commenting "@codex address that feedback".
| if (not hidden_states.is_cuda) or getattr(lin0, "lora_A", None) is not None \ | ||
| or getattr(lin0, "base_layer", None) is not None: |
There was a problem hiding this comment.
Recheck all expert projections before grouped forward
The eligibility loop now checks every expert, but this runtime guard still only rechecks experts[0]'s gate projection after patching. If adapters are attached after enable_grouped_moe runs and they target up_proj/down_proj or any later expert, this guard stays false and the grouped path rebuilds stacks from base weights via _expert_weight, silently ignoring those adapter outputs/gradients instead of falling back to _orig_moe_forward.
Useful? React with 👍 / 👎.
There was a problem hiding this comment.
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: 306014be79
ℹ️ About Codex in GitHub
Your team has set up Codex to review pull requests in this repo. Reviews are triggered when you
- Open a pull request for review
- Mark a draft as ready
- Comment "@codex review".
If Codex has suggestions, it will comment; otherwise it will react with 👍.
Codex can also answer questions or update the PR. Try commenting "@codex address that feedback".
| is_plain_frozen = (not w.requires_grad) and w.dtype in (torch.bfloat16, torch.float16, torch.float32) | ||
| if not (is_4bit or is_plain_frozen): |
There was a problem hiding this comment.
Fall back for plain fp32 experts
When a known ModuleList MoE is loaded as plain frozen fp32 weights, this eligibility check patches it, but the grouped path later rebuilds fp32 stacks and passes them to torch._grouped_mm. The support probe only validates a bf16 grouped matmul, and PyTorch's grouped GEMM path is documented/implemented for low-precision CUDA inputs, so a non-4bit fp32 Mixtral/Qwen MoE on an otherwise supported GPU can be patched and then fail on the first expert matmul instead of using the original loop. Please either restrict plain-frozen experts to supported compute dtypes or add a runtime fallback before calling grouped_mm.
Useful? React with 👍 / 👎.
| dev = hidden_states.device | ||
| dtype = hidden_states.dtype |
There was a problem hiding this comment.
Respect Linear4bit compute dtype
For 4-bit experts whose bnb_4bit_compute_dtype differs from the activation dtype (for example fp16 hidden states with bf16 Linear4bit compute), this grouped path dequantizes every expert stack to hidden_states.dtype and runs grouped_mm in that dtype. The original bitsandbytes Linear4bit casts inputs to the layer's compute_dtype for the matmul and then casts the result back, so enabling this patch changes expert numerics for those valid quantization configs instead of being accuracy-neutral.
Useful? React with 👍 / 👎.
| from .mxfp4 import * | ||
| from .bitsandbytes import * | ||
| from .moe_utils_bnb4bit import * | ||
| from .moe_grouped_modulelist import * |
There was a problem hiding this comment.
Add the new temporary patch to the import smoke list
Adding this new submodule to the package import chain also adds moe_grouped_modulelist.py to the on-disk temporary_patches inventory, but tests/test_temporary_patches_imports.py has a hard-coded TEMPORARY_PATCHES_SUBMODULES list and its completeness test explicitly fails for any new file absent from that list. Please add this module there, otherwise the import-smoke CI will fail even though the package import itself succeeds.
Useful? React with 👍 / 👎.
There was a problem hiding this comment.
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: e1d97c19e2
ℹ️ About Codex in GitHub
Your team has set up Codex to review pull requests in this repo. Reviews are triggered when you
- Open a pull request for review
- Mark a draft as ready
- Comment "@codex review".
If Codex has suggestions, it will comment; otherwise it will react with 👍.
Codex can also answer questions or update the PR. Try commenting "@codex address that feedback".
| return None | ||
| if getattr(lin, "lora_A", None) is not None or getattr(lin, "base_layer", None) is not None: | ||
| return None | ||
| is_4bit = HAS_BNB and isinstance(w, Params4bit) and getattr(w, "quant_state", None) is not None |
There was a problem hiding this comment.
Require 4-bit experts to be frozen
When a 4-bit expert weight has requires_grad=True (for example, an attempted full expert fine-tune), this branch still marks the block eligible because is_4bit ignores w.requires_grad. The grouped path is explicitly frozen: _expert_weight reads w.data and _GroupedFrozenMM only returns dX, so patched trainable 4-bit expert weights silently receive no useful gradients instead of falling back to the original expert loop.
Useful? React with 👍 / 👎.
| if HAS_BNB and isinstance(w, Params4bit): | ||
| return getattr(getattr(w, "quant_state", None), "dtype", torch.bfloat16) |
There was a problem hiding this comment.
Use Linear4bit compute dtype for the guard
Fresh evidence in this revision is that the new guard reads weight.quant_state.dtype, but bitsandbytes stores bnb_4bit_compute_dtype on the Linear4bit module as compute_dtype. For 4-bit experts loaded with fp16 activations/weights but bf16 or fp32 compute, this returns fp16, lets the grouped path run, and _expert_weight(..., hidden_states.dtype) performs the expert matmuls in fp16 instead of the configured bitsandbytes compute dtype, changing the numerics rather than falling back.
Useful? React with 👍 / 👎.
| # Fall back to the exact original loop unless this is the frozen-expert CUDA path on a | ||
| # grouped_mm-supported dtype that matches the experts' compute dtype (so CPU/offload, | ||
| # LoRA-on-experts, fp32, or a bnb compute_dtype != activation dtype stay bit-identical). | ||
| if (not hidden_states.is_cuda) or _experts_have_lora(experts, spec) \ |
There was a problem hiding this comment.
Reprobe grouped_mm on the input CUDA device
On mixed-GPU hosts where enable_grouped_moe probes a supported GPU as the current device but a device_map sends this MoE block to an older CUDA GPU, this runtime guard only checks hidden_states.is_cuda. _GROUPED_MM_SUPPORTED can therefore stay true from another device, and the later _grouped_mm_fix calls run torch._grouped_mm on an unsupported GPU instead of falling back to _orig_moe_forward; cache support by hidden_states.device or recheck before taking the grouped path.
Useful? React with 👍 / 👎.
|
|
||
| router_logits = self.gate(hidden_states) | ||
| rw, sel = spec[3](self, router_logits, top_k) # exact per-model router | ||
| rw = rw.to(dtype) |
There was a problem hiding this comment.
Preserve the router-logit dtype fix
When the Unsloth compiler has rewritten old-transformers MoE forwards, patch_moe_routing_weights_cast changes the route weight cast to router_logits.dtype for the documented bf16 router-logit dtype fix. This instance forward is intended to win over that compiled-cache class patch, but it casts the softmax weights back to the activation dtype here, so cases with higher-precision router logits regress to bf16/fp16 routing weights instead of matching the patched original forward.
Useful? React with 👍 / 👎.
| final = torch.zeros((T, hidden_dim), dtype=torch.float32, device=dev) | ||
| final.index_add_(0, sorted_tok, down.to(torch.float32)) |
There was a problem hiding this comment.
Preserve Mixtral's accumulation dtype
For old-transformers MixtralSparseMoeBlock, the stock ModuleList loop initializes final_hidden_states in hidden_states.dtype and index-adds expert outputs cast to that same dtype. This grouped replacement always accumulates the routed expert outputs in fp32 before casting back, so bf16/fp16 Mixtral runs no longer match the original loop's reductions when grouped MoE is enabled.
Useful? React with 👍 / 👎.
OLMoE's routed block has the same structure as Qwen3-MoE (gate/up/down SwiGLU experts, softmax(fp32) -> top_k -> optional renorm, no shared expert), so it reuses the same spec and router. transformers>=5 stacks its MoE experts, so no shipped model exercises the ModuleList path there. The new test builds synthetic ModuleList blocks for each spec (Qwen3-MoE / Mixtral / OLMoE) and checks grouped_moe_forward against a plain per-expert reference loop in the default, cache and recompute modes, forward and backward (dX), plus the eligibility and shared-expert bail. Skips where torch._grouped_mm is unsupported.
Adds a round-trip test (enable_grouped_moe swaps in the grouped forward with parity to the reference loop, disable_grouped_moe restores the original and clears the patch attrs) and a loader-wrapper test (wrap_loader_for_grouped_moe is idempotent, passes the loader return through, and enables the grouped path on the returned model). This is the entry point the loader hook wires in.
There was a problem hiding this comment.
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: dd5a54c300
ℹ️ About Codex in GitHub
Your team has set up Codex to review pull requests in this repo. Reviews are triggered when you
- Open a pull request for review
- Mark a draft as ready
- Comment "@codex review".
If Codex has suggestions, it will comment; otherwise it will react with 👍.
Codex can also answer questions or update the PR. Try commenting "@codex address that feedback".
| # softmax(fp32) -> top_k -> optional renorm via norm_topk_prob) and no shared expert, so | ||
| # the same spec/router applies. Verified against the reference loop in | ||
| # tests/test_moe_grouped_modulelist_parity.py. | ||
| "OlmoeSparseMoeBlock": ("gate_proj", "up_proj", "down_proj", _route_softmax_topk), |
There was a problem hiding this comment.
Keep Olmoe on the tensor-return path
When this entry matches the legacy ModuleList OLMoE layout, enable_grouped_moe will replace OlmoeSparseMoeBlock.forward, but OLMoE's decoder uses the MoE block as hidden_states = self.mlp(hidden_states) and the block returns only the hidden-state tensor. The grouped replacement unconditionally returns (final, router_logits), so eligible OLMoE models will hand a tuple to the residual add instead of a tensor; either omit Olmoe here or make the patched forward preserve that model's return contract.
Useful? React with 👍 / 👎.
The compute-dtype fallback guard read weight.quant_state.dtype (the dequant target dtype), but bitsandbytes casts inputs to Linear4bit.compute_dtype for the matmul. For a 4-bit expert whose compute_dtype differs from the activation dtype the guard could take the fast path and change numerics instead of falling back. Read compute_dtype directly, falling back to the quant_state dtype only when it is unset. Adds a unit test.
|
Codex usage limits have been reached for code reviews. Please check with the admins of this repo to increase the limits by adding credits. |
|
@codex review |
There was a problem hiding this comment.
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: 0d099a8bdb
ℹ️ About Codex in GitHub
Your team has set up Codex to review pull requests in this repo. Reviews are triggered when you
- Open a pull request for review
- Mark a draft as ready
- Comment "@codex review".
If Codex has suggestions, it will comment; otherwise it will react with 👍.
Codex can also answer questions or update the PR. Try commenting "@codex address that feedback".
| if (not hidden_states.is_cuda) or _experts_have_lora(experts, spec) \ | ||
| or hidden_states.dtype not in (torch.bfloat16, torch.float16) \ | ||
| or _expert_compute_dtype(experts, spec) != hidden_states.dtype: |
There was a problem hiding this comment.
Guard against offloaded expert weights
When a model is loaded with a device map or CPU offload, the block can receive CUDA hidden_states while some expert weights are still on CPU or another CUDA device; this guard only checks the activation device/dtype, so the grouped path proceeds, builds the stacked weights on the experts' device, and later calls _grouped_mm_fix with tensors on different devices instead of falling back to the original expert modules/hooks. This turns the documented CPU/offload no-op case into a runtime device-mismatch crash for offloaded MoE layers.
Useful? React with 👍 / 👎.
| return None | ||
| if getattr(lin, "lora_A", None) is not None or getattr(lin, "base_layer", None) is not None: | ||
| return None | ||
| is_4bit = HAS_BNB and isinstance(w, Params4bit) and getattr(w, "quant_state", None) is not None |
There was a problem hiding this comment.
Reject trainable 4-bit expert weights
The grouped implementation treats expert weights as constants (_GroupedFrozenMM only returns dX, and cache mode stores detached dequantized stacks), but this 4-bit eligibility branch does not require not w.requires_grad. In a 4-bit full-finetuning or explicitly-unfrozen expert setup, the block is patched and expert-weight gradients are silently dropped instead of falling back to the original Linear4bit forward; the plain-linear branch already enforces frozen weights, so the 4-bit branch should too.
Useful? React with 👍 / 👎.
| ok = True | ||
| except Exception: | ||
| ok = False | ||
| _GROUPED_MM_SUPPORTED = ok |
There was a problem hiding this comment.
Cache grouped-mm support per CUDA device
In heterogeneous CUDA setups, the first successful probe is cached globally and then reused for every later layer/device. If enable_grouped_moe probes a supported GPU but an eligible MoE block is placed on an older CUDA device via a device map, grouped_moe_forward only checks hidden_states.is_cuda before calling torch._grouped_mm, so that layer can crash instead of falling back to the original loop.
Useful? React with 👍 / 👎.
enable/auto_enable is re-entrant: a block whose experts gained LoRA is restored to the original loop while a frozen block stays on the grouped path. This is the property the loader relies on when it re-runs auto_enable_grouped_moe after attaching a loaded PEFT adapter.
…rop unused import - Eligibility now requires 4-bit experts to be frozen too (mirrors the plain-Linear branch); the grouped path returns dX only, so a trainable expert would silently get no gradient. In practice Params4bit stores uint8 and cannot require grad, so this is defensive consistency; the reachable trainable case is a plain bf16 expert. - grouped_moe_forward falls back to the original loop when the experts' weights are not on the activation device, so a device_map / CPU offload that splits a block across devices no longer crashes in grouped_mm. - Remove the now-unused _route_softmax_topk test import. Tests cover both guards.
There was a problem hiding this comment.
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: 40a8011c8f
ℹ️ About Codex in GitHub
Your team has set up Codex to review pull requests in this repo. Reviews are triggered when you
- Open a pull request for review
- Mark a draft as ready
- Comment "@codex review".
If Codex has suggestions, it will comment; otherwise it will react with 👍.
Codex can also answer questions or update the PR. Try commenting "@codex address that feedback".
| w = getattr(getattr(experts[0], spec[0], None), "weight", None) | ||
| return getattr(w, "device", device) == device |
There was a problem hiding this comment.
Check every expert projection device
Fresh evidence in this revision is that _experts_on_device was added, but it only checks experts[0]'s gate projection. When a device map or CPU offload leaves any later expert/projection on CPU or another CUDA device while that first weight is on the activation device, this guard passes and _build_gate_up_stack/_build_down_stack then read all projections and try to stack or grouped-mm tensors across devices instead of falling back to the original module path/hooks.
Useful? React with 👍 / 👎.
| if (not hidden_states.is_cuda) or _experts_have_lora(experts, spec) \ | ||
| or hidden_states.dtype not in (torch.bfloat16, torch.float16) \ | ||
| or _expert_compute_dtype(experts, spec) != hidden_states.dtype \ | ||
| or not _experts_on_device(experts, spec, hidden_states.device): |
There was a problem hiding this comment.
Recheck frozen expert weights at forward time
Fresh evidence in this revision is that _block_is_eligible now enforces frozen experts when enable_grouped_moe runs, but this runtime guard does not revalidate requires_grad. If a user unfreezes expert weights after loading/patching for selective or full expert fine-tuning, the grouped path still runs and the recompute/cache paths treat those weights as constants, so expert gradients are silently dropped instead of falling back to _orig_moe_forward.
Useful? React with 👍 / 👎.
| lin = getattr(experts[0], spec[0], None) | ||
| w = getattr(lin, "weight", None) | ||
| if HAS_BNB and isinstance(w, Params4bit): | ||
| compute_dtype = getattr(lin, "compute_dtype", None) | ||
| if compute_dtype is not None: | ||
| return compute_dtype | ||
| return getattr(getattr(w, "quant_state", None), "dtype", torch.bfloat16) | ||
| return getattr(w, "dtype", torch.bfloat16) |
There was a problem hiding this comment.
Validate compute dtype for all projections
This guard only samples experts[0]'s gate projection, but _build_gate_up_stack and _build_down_stack later force every expert projection to hidden_states.dtype. If any later expert or up_proj/down_proj has a different Linear4bit.compute_dtype or low-precision weight dtype, the guard passes and the grouped path changes that projection's matmul dtype instead of falling back to the original Linear/Linear4bit forward.
Useful? React with 👍 / 👎.
…rts[0] The stacks are built from all experts, but the runtime guard only sampled experts[0] for device and compute dtype and never re-checked requires_grad, so a device_map / CPU offload, a partially-unfrozen block, or a mixed compute-dtype block could slip through and crash in grouped_mm or silently drop expert gradients. Consolidate the LoRA / frozen / device / compute-dtype checks into a single _experts_grouped_ready pass over every expert projection; fall back to the original loop on any mismatch. Test covers a mismatch on a later expert.
|
@codex review |
|
Codex Review: Didn't find any major issues. Swish! Reviewed commit: ℹ️ About Codex in GitHubYour team has set up Codex to review pull requests in this repo. Reviews are triggered when you
If Codex has suggestions, it will comment; otherwise it will react with 👍. Codex can also answer questions or update the PR. Try commenting "@codex address that feedback". |
Shorten the verbose module/function docstrings and drop obvious inline comments; code is unchanged (verified comment-only).
Problem
On transformers < 5 a Mixture-of-Experts block (
Qwen3MoeSparseMoeBlock,MixtralSparseMoeBlock, ...) keeps its experts as ann.ModuleListof per-expert MLPs and the decoder loops over every expert in Python. For a 128-expert model that is O(num_experts) tiny matmuls plus a data-dependent sync per layer, so the block is launch/sync bound andtorch._grouped_mmis never called. The v5 stacked-parameter layout already uses a grouped path; the < 5 ModuleList layout does not.Change
New
temporary_patches/moe_grouped_modulelist.pyreplaces the per-expert loop with the grouped recipe:recompute, memory-safe and auto-on when gradient checkpointing is off) or held resident (cache, a big-GPU opt-in). Same bf16 math and float32 accumulation as the loop, so accuracy is neutral.Safety / activation
Activates only when all hold, else the original forward runs unchanged:
Qwen3MoeSparseMoeBlock,MixtralSparseMoeBlock),torch._grouped_mmis supported (CUDA).So transformers v5, non-bnb, LoRA-on-experts, full finetuning, and Mac/MLX/AMD/Intel/CPU are all no-ops. Disable explicitly with
UNSLOTH_MOE_GROUPED=0.Results
Qwen3-30B-A3B QLoRA, 1x B200, seq 2048, gradient checkpointing on, attention LoRA:
~3.3x faster at the same 4-bit memory. Block-level parity vs the loop: output cosine 0.99999, dX cosine 0.99997, router logits bit-identical; per-step training loss matches within bf16 noise.
Note
A small companion change in
unslothwires the loader to callenable_grouped_moeafterfrom_pretrained/get_peft_model(viawrap_loader_for_grouped_moe), gated byUNSLOTH_MOE_GROUPED. That import is wrapped in try/except, so this can land independently.