Grouped-GEMM MoE forward for transformers<5 ModuleList experts #837

Open
danielhanchen wants to merge 11 commits into main from moe-grouped-modulelist

@danielhanchen

Member

Problem

On transformers < 5 a Mixture-of-Experts block (Qwen3MoeSparseMoeBlock, MixtralSparseMoeBlock, ...) keeps its experts as a nn.ModuleList of per-expert MLPs and the decoder loops over every expert in Python. For a 128-expert model that is O(num_experts) tiny matmuls plus a data-dependent sync per layer, so the block is launch/sync bound and torch._grouped_mm is never called. The v5 stacked-parameter layout already uses a grouped path; the < 5 ModuleList layout does not.

Change

New temporary_patches/moe_grouped_modulelist.py replaces the per-expert loop with the grouped recipe:

route -> sort tokens by expert -> grouped_mm (gate_up) -> act(gate) * up
-> grouped_mm (down) -> router-weight scale -> float32 scatter-add
  • Experts stay quantized. The dequantized bf16 stack is rebuilt in backward (recompute, memory-safe and auto-on when gradient checkpointing is off) or held resident (cache, a big-GPU opt-in). Same bf16 math and float32 accumulation as the loop, so accuracy is neutral.
  • The instance forward is patched after the model is built, so it wins over the compiled-cache class patch and survives cache regeneration.
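
The recipe above can be sketched on CPU in plain PyTorch. This is a minimal illustration under stated assumptions, not the code in this PR: the per-expert segment loop stands in for the two torch._grouped_mm calls, the router (float32 softmax then top-k, no renormalisation), the SiLU activation, and every name here (`route`, `loop_moe`, `grouped_moe`, `w_gate_up`, `w_down`) are hypothetical, and the weights are plain float32 tensors rather than dequantized 4-bit stacks.

```python
import torch
import torch.nn.functional as F


def route(x, router_w, top_k):
    # Assumed router: float32 softmax over experts, then top_k.
    probs = F.softmax(x @ router_w, dim=-1, dtype=torch.float32)
    return torch.topk(probs, top_k, dim=-1)  # (weights, expert ids), each (T, top_k)


def loop_moe(x, router_w, w_gate_up, w_down, top_k):
    """Reference: one pass per expert, as in the ModuleList layout."""
    weights, experts = route(x, router_w, top_k)
    final = torch.zeros(x.shape, dtype=torch.float32)
    for e in range(router_w.shape[1]):
        tok, slot = torch.where(experts == e)
        if tok.numel() == 0:
            continue
        gate, up = (x[tok] @ w_gate_up[e]).chunk(2, dim=-1)
        y = (F.silu(gate) * up) @ w_down[e]
        final.index_add_(0, tok, (y * weights[tok, slot].unsqueeze(-1)).float())
    return final.to(x.dtype)


def grouped_moe(x, router_w, w_gate_up, w_down, top_k):
    """Same math, but tokens are sorted so each expert owns one contiguous segment."""
    T = x.shape[0]
    num_experts = router_w.shape[1]
    weights, experts = route(x, router_w, top_k)

    # Sort the (token, expert) pairs by expert id.
    flat_expert = experts.reshape(-1)
    _, order = torch.sort(flat_expert, stable=True)
    sorted_tok = torch.arange(T).repeat_interleave(top_k)[order]
    sorted_w = weights.reshape(-1)[order]
    # offsets[e] is the end of expert e's segment in the sorted layout.
    offsets = torch.cumsum(torch.bincount(flat_expert, minlength=num_experts), dim=0)

    xs = x[sorted_tok]
    down = torch.zeros_like(xs)
    start = 0
    for e in range(num_experts):  # CPU stand-in for the two grouped_mm calls
        end = int(offsets[e])
        if end > start:
            gate, up = (xs[start:end] @ w_gate_up[e]).chunk(2, dim=-1)
            down[start:end] = (F.silu(gate) * up) @ w_down[e]
        start = end

    # Router-weight scale, then float32 scatter-add back to token order.
    down = down * sorted_w.unsqueeze(-1)
    final = torch.zeros(x.shape, dtype=torch.float32)
    final.index_add_(0, sorted_tok, down.float())
    return final.to(x.dtype)
```

On random inputs the two functions should agree to within float32 rounding, which is the same kind of block-level parity check the Results section reports for the real implementation.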

Safety / activation

Activates only when all hold, else the original forward runs unchanged:

  • a known block class with no shared expert (registry: Qwen3MoeSparseMoeBlock, MixtralSparseMoeBlock),
  • frozen bnb-4bit (or plain-frozen) ModuleList experts with no LoRA attached to them,
  • torch._grouped_mm is supported (CUDA).

So transformers v5, non-bnb, LoRA-on-experts, full finetuning, and Mac/MLX/AMD/Intel/CPU are all no-ops. Disable explicitly with UNSLOTH_MOE_GROUPED=0.
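
The "all conditions hold, else no-op" gating can be pictured roughly as below. This is a hypothetical helper, not the PR's code: only the UNSLOTH_MOE_GROUPED variable name and the two registry class names come from the description, while the way the variable is parsed (only the string "0" disables) and the boolean inputs are assumptions made for the sketch.

```python
import os

# Block classes named in the description; the real registry's shape is not shown here.
KNOWN_BLOCKS = {"Qwen3MoeSparseMoeBlock", "MixtralSparseMoeBlock"}


def should_use_grouped(block_class_name, has_shared_expert, experts_frozen,
                       experts_have_lora, grouped_mm_supported, env=None):
    """Return True only when every activation condition holds."""
    env = os.environ if env is None else env
    # Assumption: "0" disables; unset or any other value leaves the feature on.
    if env.get("UNSLOTH_MOE_GROUPED", "1") == "0":
        return False
    return (
        block_class_name in KNOWN_BLOCKS
        and not has_shared_expert
        and experts_frozen
        and not experts_have_lora
        and grouped_mm_supported
    )
```

Any single failing condition returns False, which corresponds to leaving the original forward untouched.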

Results

Qwen3-30B-A3B QLoRA, 1x B200, seq 2048, gradient checkpointing on, attention LoRA:

| path | tok/s | peak (active) |
| --- | --- | --- |
| loop (current) | 204 | 18.5 GiB |
| grouped (this change) | 664 | 18.8 GiB |

~3.3x faster at the same 4-bit memory. Block-level parity vs the loop: output cosine 0.99999, dX cosine 0.99997, router logits bit-identical; per-step training loss matches within bf16 noise.

Note

A small companion change in unsloth wires the loader to call enable_grouped_moe after from_pretrained / get_peft_model (via wrap_loader_for_grouped_moe), gated by UNSLOTH_MOE_GROUPED. That import is wrapped in try/except, so this can land independently.

On transformers<5 a MoE block keeps experts as a nn.ModuleList and loops over them in Python, launching O(num_experts) tiny matmuls per layer and never calling torch._grouped_mm. This adds the grouped recipe (sort tokens by expert, grouped_mm gate_up, act, grouped_mm down, scatter-add) for frozen bnb-4bit (or plain-frozen) ModuleList experts. Experts stay 4-bit; the dequant stack is rebuilt in backward (recompute) or cached. Gated to known block classes with no shared expert, no LoRA on experts, and torch._grouped_mm support, so everything else is a no-op.

@gemini-code-assist gemini-code-assist Bot left a comment

Contributor

Code Review

This pull request introduces a Grouped-GEMM MoE forward implementation for transformers < 5 using the nn.ModuleList expert layout, aiming to optimize performance by grouping expert matrix multiplications. The feedback suggests several robustness improvements: dynamically rebuilding the dequantized expert cache if a device or dtype mismatch is detected, using getattr instead of hasattr when checking for base_layer to safely handle falsy values, and performing an identity check during unpatching to ensure only the originally patched function is restored.

Comment on lines +210 to +213
if getattr(self, "_cached_gate_up", None) is None:
    with torch.no_grad():
        self._cached_gate_up = _build_gate_up_stack(experts, spec, dtype)
        self._cached_down = _build_down_stack(experts, spec, dtype)

medium

When caching the dequantized expert weights, we should verify that the cached tensors match the current device and dtype of the input hidden_states. If the model is cast (e.g., to float16 or bfloat16) or moved to a different GPU after the first forward pass, reusing the stale cache will trigger device or dtype mismatch runtime errors. Rebuilding the cache dynamically on mismatch prevents these crashes.

Suggested change
- if getattr(self, "_cached_gate_up", None) is None:
-     with torch.no_grad():
-         self._cached_gate_up = _build_gate_up_stack(experts, spec, dtype)
-         self._cached_down = _build_down_stack(experts, spec, dtype)
+ cached_gate_up = getattr(self, "_cached_gate_up", None)
+ if cached_gate_up is None or cached_gate_up.device != dev or cached_gate_up.dtype != dtype:
+     with torch.no_grad():
+         self._cached_gate_up = _build_gate_up_stack(experts, spec, dtype)
+         self._cached_down = _build_down_stack(experts, spec, dtype)

w = getattr(lin, "weight", None)
if w is None:
    return None
if getattr(lin, "lora_A", None) is not None or hasattr(lin, "base_layer"):  # LoRA on experts -> defer

medium

Use a truthy getattr check instead of hasattr when checking for base_layer to safely handle cases where the attribute might be set to None or False, preventing unintended behavior.

Suggested change
- if getattr(lin, "lora_A", None) is not None or hasattr(lin, "base_layer"):  # LoRA on experts -> defer
+ if getattr(lin, "lora_A", None) is not None or getattr(lin, "base_layer", None) is not None:  # LoRA on experts -> defer
References
  1. Use a truthy getattr check (e.g., getattr(cfg, "attr", None)) instead of hasattr when you need to preserve falsy values like None or False. hasattr returns True even if the attribute is set to None or False, which can lead to unintended overwriting of default behaviors.

Comment on lines +269 to +270
if hasattr(module, "_orig_moe_forward"):
    module.forward = module._orig_moe_forward

medium

When unpatching a function, we should perform an identity check to ensure the function being replaced is the one originally patched by our code. This is more robust than relying on the state of other modules and prevents accidentally overwriting a different patch or wrapper that might have been applied to module.forward in the meantime.

    if hasattr(module, "_orig_moe_forward"):
        if getattr(module.forward, "__func__", None) is grouped_moe_forward:
            module.forward = module._orig_moe_forward
References
  1. When unpatching a function, perform an identity check to ensure the function being replaced is the one originally patched by your code. This is more robust than relying on the state of other modules.
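
To make the instance-level patch and the suggested identity check concrete, here is a small self-contained sketch. The `Block` class and the `patch_block` / `restore_block` helpers are stand-ins invented for illustration; only the `_orig_moe_forward` attribute and the `grouped_moe_forward` name appear in the PR, and the real functions may be structured differently.

```python
import types


class Block:
    """Stand-in for an MoE block; a real one would be an nn.Module."""

    def forward(self, x):
        return ("original", x)


def grouped_moe_forward(self, x):
    return ("grouped", x)


def patch_block(module):
    # Patch the instance, not the class, and remember the original bound method.
    if "_orig_moe_forward" not in vars(module):
        module._orig_moe_forward = module.forward
        module.forward = types.MethodType(grouped_moe_forward, module)


def restore_block(module):
    orig = vars(module).get("_orig_moe_forward")
    if orig is None:
        return False
    # A bound method exposes the plain function as __func__; restore only if it is ours.
    if getattr(module.forward, "__func__", None) is not grouped_moe_forward:
        return False  # something else wrapped forward in the meantime; leave it alone
    module.forward = orig
    del module._orig_moe_forward
    return True
```

With this shape, a later wrapper installed on `module.forward` by other code is left in place rather than being silently overwritten on restore.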

@chatgpt-codex-connector chatgpt-codex-connector Bot left a comment

💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: c3720b4b7a

ℹ️ About Codex in GitHub

Your team has set up Codex to review pull requests in this repo. Reviews are triggered when you

  • Open a pull request for review
  • Mark a draft as ready
  • Comment "@codex review".

If Codex has suggestions, it will comment; otherwise it will react with 👍.

Codex can also answer questions or update the PR. Try commenting "@codex address that feedback".

dtype = hidden_states.dtype
experts = self.experts

router_logits = self.gate(hidden_states)

P2 Badge Preserve Mixtral router jitter in grouped forward

For MixtralSparseMoeBlock training runs where router_jitter_noise/self.jitter_noise is nonzero, the stock ModuleList forward applies multiplicative jitter to hidden_states before computing router_logits; this replacement goes directly to self.gate(hidden_states), so enabling grouped MoE silently disables the configured router jitter and changes both routing and expert activations for those experiments.
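
A sketch of the jitter step this comment refers to, assuming it is multiplicative uniform noise applied only in training mode before the gate, as the comment describes. The function name and the out-of-place multiply are illustrative choices, not the stock implementation.

```python
import torch


def apply_router_jitter(hidden_states, jitter_noise, training):
    """Scale activations by uniform noise in [1 - j, 1 + j) during training only."""
    if training and jitter_noise > 0:
        noise = torch.empty_like(hidden_states).uniform_(1.0 - jitter_noise, 1.0 + jitter_noise)
        hidden_states = hidden_states * noise
    return hidden_states
```

A grouped forward that wants to match the loop would call something like this before `self.gate(hidden_states)`; with `jitter_noise == 0` or in eval mode it is a no-op.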

Comment on lines +247 to +249
e0 = experts[0]
for name in (g_name, u_name, d_name):
    lin = getattr(e0, name, None)

P2 Badge Check every expert before bypassing LoRA paths

Eligibility only inspects experts[0], so a model with adapters or trainable weights attached selectively to a later expert still gets patched; the grouped forward then builds stacks from the base weights via _expert_weight and never applies those later expert adapters/gradients. This matters for selective expert LoRA or partial expert fine-tuning, where the patch would silently ignore the trained expert parameters instead of falling back to the original forward.

with torch.no_grad():
    out = _grouped_mm_fix(x, weight_fn(), offsets)
return out

P2 Badge Avoid saving unused grouped inputs in recompute mode

When recompute=True (the default without gradient checkpointing), this custom Function saves the full grouped input tensor even though backward never uses x to compute dX; it only needs offsets and a rebuilt weight. For long sequences this retains the permuted hidden state and intermediate activations for every patched MoE layer, undermining the memory-safe recompute path and causing avoidable peak-memory pressure.
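
The point can be illustrated with a dense (non-grouped) stand-in: for y = x @ W with a frozen W, the input gradient is grad_out @ W^T, so backward needs the weight but not x. The `FrozenMM` class and `weight_fn` callable below are hypothetical and much simpler than the PR's grouped Function; they only show that nothing has to be saved for backward when the weight can be rebuilt.

```python
import torch


class FrozenMM(torch.autograd.Function):
    """y = x @ W for a frozen W; backward rebuilds W and never touches x."""

    @staticmethod
    def forward(ctx, x, weight_fn):
        # Keep only the callable that rebuilds the (dequantized) weight.
        ctx.weight_fn = weight_fn
        return x @ weight_fn()  # deliberately no ctx.save_for_backward(x)

    @staticmethod
    def backward(ctx, grad_out):
        w = ctx.weight_fn()  # recompute instead of keeping the stack resident
        # dX only; the frozen weight (second argument) gets no gradient.
        return grad_out @ w.t(), None
```

Usage would look like `y = FrozenMM.apply(x, lambda: W)`; the weight builder runs once in forward and once in backward, and the permuted activation is free to be released in between.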

    for module in model.modules():
        _restore_block(module)
    return 0
if not _grouped_mm_supported():

P2 Badge Gate patching on each block's execution device

On CUDA hosts, this global support probe can be true even for a model that is still CPU-resident or has MoE blocks offloaded to CPU; enable_grouped_moe then replaces the forward, and the next CPU/offloaded forward calls torch._grouped_mm with CPU tensors instead of falling back to the original loop. This breaks CPU/offload workflows on machines that merely have a supported GPU available, so eligibility needs to include the block/input device rather than only process-wide CUDA support.

Comment thread unsloth_zoo/temporary_patches/moe_grouped_modulelist.py Fixed
    ok = True
except Exception:
    ok = False
_GROUPED_MM_SUPPORTED = ok
- Do not save the grouped input for backward; only offsets are needed for the frozen base dX, so the recompute path no longer pins the permuted activation.
- Runtime fallback to the original forward for CPU/offloaded inputs or experts that gained LoRA after patching, so a process-wide grouped_mm probe or a late adapter attach cannot mis-route.
- Preserve Mixtral router jitter in the grouped forward (no-op for models without jitter_noise).
- Check every expert for frozen/no-LoRA, not just the first, so selective per-expert adapters fall back.
- Rebuild the cached dequant stacks on device/dtype change; restore only our own patch; document the guarded excepts.

@chatgpt-codex-connector chatgpt-codex-connector Bot left a comment

💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: 3531449d84

Comment on lines +181 to +182
if (not hidden_states.is_cuda) or getattr(lin0, "lora_A", None) is not None \
or getattr(lin0, "base_layer", None) is not None:

P2 Badge Recheck all expert projections before grouped forward

The eligibility loop now checks every expert, but this runtime guard still only rechecks experts[0]'s gate projection after patching. If adapters are attached after enable_grouped_moe runs and they target up_proj/down_proj or any later expert, this guard stays false and the grouped path rebuilds stacks from base weights via _expert_weight, silently ignoring those adapter outputs/gradients instead of falling back to _orig_moe_forward.

@chatgpt-codex-connector chatgpt-codex-connector Bot left a comment

💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: 306014be79

Comment on lines +274 to +275
is_plain_frozen = (not w.requires_grad) and w.dtype in (torch.bfloat16, torch.float16, torch.float32)
if not (is_4bit or is_plain_frozen):

P2 Badge Fall back for plain fp32 experts

When a known ModuleList MoE is loaded as plain frozen fp32 weights, this eligibility check patches it, but the grouped path later rebuilds fp32 stacks and passes them to torch._grouped_mm. The support probe only validates a bf16 grouped matmul, and PyTorch's grouped GEMM path is documented/implemented for low-precision CUDA inputs, so a non-4bit fp32 Mixtral/Qwen MoE on an otherwise supported GPU can be patched and then fail on the first expert matmul instead of using the original loop. Please either restrict plain-frozen experts to supported compute dtypes or add a runtime fallback before calling grouped_mm.

Comment on lines +205 to +206
dev = hidden_states.device
dtype = hidden_states.dtype

P2 Badge Respect Linear4bit compute dtype

For 4-bit experts whose bnb_4bit_compute_dtype differs from the activation dtype (for example fp16 hidden states with bf16 Linear4bit compute), this grouped path dequantizes every expert stack to hidden_states.dtype and runs grouped_mm in that dtype. The original bitsandbytes Linear4bit casts inputs to the layer's compute_dtype for the matmul and then casts the result back, so enabling this patch changes expert numerics for those valid quantization configs instead of being accuracy-neutral.

from .mxfp4 import *
from .bitsandbytes import *
from .moe_utils_bnb4bit import *
from .moe_grouped_modulelist import *

P2 Badge Add the new temporary patch to the import smoke list

Adding this new submodule to the package import chain also adds moe_grouped_modulelist.py to the on-disk temporary_patches inventory, but tests/test_temporary_patches_imports.py has a hard-coded TEMPORARY_PATCHES_SUBMODULES list and its completeness test explicitly fails for any new file absent from that list. Please add this module there, otherwise the import-smoke CI will fail even though the package import itself succeeds.

@chatgpt-codex-connector chatgpt-codex-connector Bot left a comment

💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: e1d97c19e2

    return None
if getattr(lin, "lora_A", None) is not None or getattr(lin, "base_layer", None) is not None:
    return None
is_4bit = HAS_BNB and isinstance(w, Params4bit) and getattr(w, "quant_state", None) is not None

P2 Badge Require 4-bit experts to be frozen

When a 4-bit expert weight has requires_grad=True (for example, an attempted full expert fine-tune), this branch still marks the block eligible because is_4bit ignores w.requires_grad. The grouped path is explicitly frozen: _expert_weight reads w.data and _GroupedFrozenMM only returns dX, so patched trainable 4-bit expert weights silently receive no useful gradients instead of falling back to the original expert loop.

Comment on lines +190 to +191
if HAS_BNB and isinstance(w, Params4bit):
    return getattr(getattr(w, "quant_state", None), "dtype", torch.bfloat16)

P2 Badge Use Linear4bit compute dtype for the guard

Fresh evidence in this revision is that the new guard reads weight.quant_state.dtype, but bitsandbytes stores bnb_4bit_compute_dtype on the Linear4bit module as compute_dtype. For 4-bit experts loaded with fp16 activations/weights but bf16 or fp32 compute, this returns fp16, lets the grouped path run, and _expert_weight(..., hidden_states.dtype) performs the expert matmuls in fp16 instead of the configured bitsandbytes compute dtype, changing the numerics rather than falling back.

# Fall back to the exact original loop unless this is the frozen-expert CUDA path on a
# grouped_mm-supported dtype that matches the experts' compute dtype (so CPU/offload,
# LoRA-on-experts, fp32, or a bnb compute_dtype != activation dtype stay bit-identical).
if (not hidden_states.is_cuda) or _experts_have_lora(experts, spec) \

P2 Badge Reprobe grouped_mm on the input CUDA device

On mixed-GPU hosts where enable_grouped_moe probes a supported GPU as the current device but a device_map sends this MoE block to an older CUDA GPU, this runtime guard only checks hidden_states.is_cuda. _GROUPED_MM_SUPPORTED can therefore stay true from another device, and the later _grouped_mm_fix calls run torch._grouped_mm on an unsupported GPU instead of falling back to _orig_moe_forward; cache support by hidden_states.device or recheck before taking the grouped path.


router_logits = self.gate(hidden_states)
rw, sel = spec[3](self, router_logits, top_k) # exact per-model router
rw = rw.to(dtype)

P2 Badge Preserve the router-logit dtype fix

When the Unsloth compiler has rewritten old-transformers MoE forwards, patch_moe_routing_weights_cast changes the route weight cast to router_logits.dtype for the documented bf16 router-logit dtype fix. This instance forward is intended to win over that compiled-cache class patch, but it casts the softmax weights back to the activation dtype here, so cases with higher-precision router logits regress to bf16/fp16 routing weights instead of matching the patched original forward.

Comment on lines +255 to +256
final = torch.zeros((T, hidden_dim), dtype=torch.float32, device=dev)
final.index_add_(0, sorted_tok, down.to(torch.float32))

P2 Badge Preserve Mixtral's accumulation dtype

For old-transformers MixtralSparseMoeBlock, the stock ModuleList loop initializes final_hidden_states in hidden_states.dtype and index-adds expert outputs cast to that same dtype. This grouped replacement always accumulates the routed expert outputs in fp32 before casting back, so bf16/fp16 Mixtral runs no longer match the original loop's reductions when grouped MoE is enabled.

OLMoE's routed block has the same structure as Qwen3-MoE (gate/up/down
SwiGLU experts, softmax(fp32) -> top_k -> optional renorm, no shared
expert), so it reuses the same spec and router.

transformers>=5 stacks its MoE experts, so no shipped model exercises the
ModuleList path there. The new test builds synthetic ModuleList blocks for
each spec (Qwen3-MoE / Mixtral / OLMoE) and checks grouped_moe_forward
against a plain per-expert reference loop in the default, cache and
recompute modes, forward and backward (dX), plus the eligibility and
shared-expert bail. Skips where torch._grouped_mm is unsupported.
Adds a round-trip test (enable_grouped_moe swaps in the grouped forward
with parity to the reference loop, disable_grouped_moe restores the
original and clears the patch attrs) and a loader-wrapper test
(wrap_loader_for_grouped_moe is idempotent, passes the loader return
through, and enables the grouped path on the returned model). This is the
entry point the loader hook wires in.

@chatgpt-codex-connector chatgpt-codex-connector Bot left a comment

💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: dd5a54c300

# softmax(fp32) -> top_k -> optional renorm via norm_topk_prob) and no shared expert, so
# the same spec/router applies. Verified against the reference loop in
# tests/test_moe_grouped_modulelist_parity.py.
"OlmoeSparseMoeBlock": ("gate_proj", "up_proj", "down_proj", _route_softmax_topk),

P2 Badge Keep Olmoe on the tensor-return path

When this entry matches the legacy ModuleList OLMoE layout, enable_grouped_moe will replace OlmoeSparseMoeBlock.forward, but OLMoE's decoder uses the MoE block as hidden_states = self.mlp(hidden_states) and the block returns only the hidden-state tensor. The grouped replacement unconditionally returns (final, router_logits), so eligible OLMoE models will hand a tuple to the residual add instead of a tensor; either omit Olmoe here or make the patched forward preserve that model's return contract.

The compute-dtype fallback guard read weight.quant_state.dtype (the dequant
target dtype), but bitsandbytes casts inputs to Linear4bit.compute_dtype for
the matmul. For a 4-bit expert whose compute_dtype differs from the
activation dtype the guard could take the fast path and change numerics
instead of falling back. Read compute_dtype directly, falling back to the
quant_state dtype only when it is unset. Adds a unit test.
@chatgpt-codex-connector

Codex usage limits have been reached for code reviews. Please check with the admins of this repo to increase the limits by adding credits.
Repo admins can enable using credits for code reviews in their settings.

@danielhanchen

Member Author

@codex review

@chatgpt-codex-connector chatgpt-codex-connector Bot left a comment

💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: 0d099a8bdb

Comment on lines +213 to +215
if (not hidden_states.is_cuda) or _experts_have_lora(experts, spec) \
or hidden_states.dtype not in (torch.bfloat16, torch.float16) \
or _expert_compute_dtype(experts, spec) != hidden_states.dtype:

P2 Badge Guard against offloaded expert weights

When a model is loaded with a device map or CPU offload, the block can receive CUDA hidden_states while some expert weights are still on CPU or another CUDA device; this guard only checks the activation device/dtype, so the grouped path proceeds, builds the stacked weights on the experts' device, and later calls _grouped_mm_fix with tensors on different devices instead of falling back to the original expert modules/hooks. This turns the documented CPU/offload no-op case into a runtime device-mismatch crash for offloaded MoE layers.

    return None
if getattr(lin, "lora_A", None) is not None or getattr(lin, "base_layer", None) is not None:
    return None
is_4bit = HAS_BNB and isinstance(w, Params4bit) and getattr(w, "quant_state", None) is not None

P2 Badge Reject trainable 4-bit expert weights

The grouped implementation treats expert weights as constants (_GroupedFrozenMM only returns dX, and cache mode stores detached dequantized stacks), but this 4-bit eligibility branch does not require not w.requires_grad. In a 4-bit full-finetuning or explicitly-unfrozen expert setup, the block is patched and expert-weight gradients are silently dropped instead of falling back to the original Linear4bit forward; the plain-linear branch already enforces frozen weights, so the 4-bit branch should too.

    ok = True
except Exception:
    ok = False
_GROUPED_MM_SUPPORTED = ok

P3 Badge Cache grouped-mm support per CUDA device

In heterogeneous CUDA setups, the first successful probe is cached globally and then reused for every later layer/device. If enable_grouped_moe probes a supported GPU but an eligible MoE block is placed on an older CUDA device via a device map, grouped_moe_forward only checks hidden_states.is_cuda before calling torch._grouped_mm, so that layer can crash instead of falling back to the original loop.

enable/auto_enable is re-entrant: a block whose experts gained LoRA is
restored to the original loop while a frozen block stays on the grouped
path. This is the property the loader relies on when it re-runs
auto_enable_grouped_moe after attaching a loaded PEFT adapter.
Comment thread tests/test_moe_grouped_modulelist_parity.py Fixed
…rop unused import

- Eligibility now requires 4-bit experts to be frozen too (mirrors the plain-Linear
  branch); the grouped path returns dX only, so a trainable expert would silently get
  no gradient. In practice Params4bit stores uint8 and cannot require grad, so this is
  defensive consistency; the reachable trainable case is a plain bf16 expert.
- grouped_moe_forward falls back to the original loop when the experts' weights are not
  on the activation device, so a device_map / CPU offload that splits a block across
  devices no longer crashes in grouped_mm.
- Remove the now-unused _route_softmax_topk test import. Tests cover both guards.

@chatgpt-codex-connector chatgpt-codex-connector Bot left a comment

💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: 40a8011c8f

Comment on lines +211 to +212
w = getattr(getattr(experts[0], spec[0], None), "weight", None)
return getattr(w, "device", device) == device

P2 Badge Check every expert projection device

Fresh evidence in this revision is that _experts_on_device was added, but it only checks experts[0]'s gate projection. When a device map or CPU offload leaves any later expert/projection on CPU or another CUDA device while that first weight is on the activation device, this guard passes and _build_gate_up_stack/_build_down_stack then read all projections and try to stack or grouped-mm tensors across devices instead of falling back to the original module path/hooks.

Comment on lines +222 to +225
if (not hidden_states.is_cuda) or _experts_have_lora(experts, spec) \
or hidden_states.dtype not in (torch.bfloat16, torch.float16) \
or _expert_compute_dtype(experts, spec) != hidden_states.dtype \
or not _experts_on_device(experts, spec, hidden_states.device):

P2 Badge Recheck frozen expert weights at forward time

Fresh evidence in this revision is that _block_is_eligible now enforces frozen experts when enable_grouped_moe runs, but this runtime guard does not revalidate requires_grad. If a user unfreezes expert weights after loading/patching for selective or full expert fine-tuning, the grouped path still runs and the recompute/cache paths treat those weights as constants, so expert gradients are silently dropped instead of falling back to _orig_moe_forward.

Comment on lines +197 to +204
lin = getattr(experts[0], spec[0], None)
w = getattr(lin, "weight", None)
if HAS_BNB and isinstance(w, Params4bit):
    compute_dtype = getattr(lin, "compute_dtype", None)
    if compute_dtype is not None:
        return compute_dtype
    return getattr(getattr(w, "quant_state", None), "dtype", torch.bfloat16)
return getattr(w, "dtype", torch.bfloat16)

P2 Badge Validate compute dtype for all projections

This guard only samples experts[0]'s gate projection, but _build_gate_up_stack and _build_down_stack later force every expert projection to hidden_states.dtype. If any later expert or up_proj/down_proj has a different Linear4bit.compute_dtype or low-precision weight dtype, the guard passes and the grouped path changes that projection's matmul dtype instead of falling back to the original Linear/Linear4bit forward.

…rts[0]

The stacks are built from all experts, but the runtime guard only sampled
experts[0] for device and compute dtype and never re-checked requires_grad, so a
device_map / CPU offload, a partially-unfrozen block, or a mixed compute-dtype
block could slip through and crash in grouped_mm or silently drop expert
gradients. Consolidate the LoRA / frozen / device / compute-dtype checks into a
single _experts_grouped_ready pass over every expert projection; fall back to the
original loop on any mismatch. Test covers a mismatch on a later expert.
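
A consolidated pass of this kind might look roughly like the sketch below. It is not the PR's `_experts_grouped_ready`: the function name, signature, and the `SimpleNamespace` stand-ins (used so the example runs without a GPU or bitsandbytes) are assumptions; only the attribute names `lora_A`, `base_layer`, `compute_dtype`, and `requires_grad` come from the snippets quoted in the review.

```python
from types import SimpleNamespace

PROJ = ("gate_proj", "up_proj", "down_proj")


def experts_ready_sketch(experts, proj_names, device, dtype):
    """True only if every projection of every expert is frozen, adapter-free,
    on `device`, and computes in `dtype`."""
    for expert in experts:
        for name in proj_names:
            lin = getattr(expert, name, None)
            w = getattr(lin, "weight", None)
            if w is None:
                return False
            if getattr(lin, "lora_A", None) is not None or getattr(lin, "base_layer", None) is not None:
                return False  # an adapter is attached to this projection
            if getattr(w, "requires_grad", False):
                return False  # trainable weight: the grouped path would drop its gradient
            if getattr(w, "device", None) != device:
                return False  # offloaded or placed on another device
            compute_dtype = getattr(lin, "compute_dtype", None)
            if compute_dtype is None:
                compute_dtype = getattr(w, "dtype", None)
            if compute_dtype != dtype:
                return False
    return True


def fake_linear(device="cuda:0", dtype="bf16", requires_grad=False, **extra):
    # Stand-in for a Linear/Linear4bit layer; strings replace real device/dtype objects.
    weight = SimpleNamespace(device=device, dtype=dtype, requires_grad=requires_grad)
    return SimpleNamespace(weight=weight, **extra)


def fake_expert(**overrides):
    # overrides maps a projection name to keyword arguments for fake_linear.
    return SimpleNamespace(**{n: fake_linear(**overrides.get(n, {})) for n in PROJ})
```

For example, `experts_ready_sketch([fake_expert(), fake_expert(down_proj={"device": "cpu"})], PROJ, "cuda:0", "bf16")` returns False because the second expert's down projection sits on another device, which is the "mismatch on a later expert" case the commit message says the test covers.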
@danielhanchen

Member Author

@codex review

@chatgpt-codex-connector

Codex Review: Didn't find any major issues. Swish!

Reviewed commit: 0991307e09

Shorten the verbose module/function docstrings and drop obvious inline comments;
code is unchanged (verified comment-only).