Optional gate gradient identity for the grouped MoE combine#858
Optional gate gradient identity for the grouped MoE combine#858danielhanchen wants to merge 1 commit into
Conversation
With UNSLOTH_MOE_GATEGRAD=1, forward_native_grouped_mm derives the routing weight gradient from the inner product identity dGate = (A dot dA) / gate over the pre-down-projection activation instead of differentiating the post-down multiply. The down projection output then never has to stay on the autograd tape solely for the gate gradient, which lowers peak training memory at long sequence lengths (about 1.4 GiB, 6 percent, at 16k tokens on a 30B A3B MoE) with unchanged throughput and per-step losses. The identity is exact for any linear down projection (frozen base, recompute path, additive LoRA) but not for a bias added after the down matmul, so the path auto-disables when a down projection bias is present. Off by default.
There was a problem hiding this comment.
Code Review
This pull request introduces an optional gate gradient optimization (_MoEGateGradIdentity) for MoE layers, enabled via the UNSLOTH_MOE_GATEGRAD environment variable. This optimization derives the routing-weight gradient using an inner-product identity, preventing the down-projection output from being cached on the autograd tape. The feedback highlights a potential runtime crash in the backward pass if grad_inter is None (e.g., when the base model is frozen) and suggests a safer division-by-zero check. Additionally, it is recommended to use .reshape(-1) instead of .view(-1) on top_k_weights to avoid potential RunTimeErrors with non-contiguous tensors.
Important
The consumer version of Gemini Code Assist on GitHub is being sunset. Starting June 18, 2026, new organization installations will be blocked, and all code review activity will officially cease on July 17, 2026.
For more details on the timeline and next steps, please review the Help Documentation.
| @staticmethod | ||
| def backward(ctx, grad_inter): | ||
| inter, gate = ctx.saved_tensors | ||
| dgate = (inter.to(torch.float32) * grad_inter.to(torch.float32)).sum(dim=-1) | ||
| dgate = dgate / gate.to(torch.float32).clamp_min(1e-12) | ||
| return grad_inter, dgate.to(gate.dtype) |
There was a problem hiding this comment.
If the base model is frozen (e.g., in PEFT/LoRA where only the down-projection LoRA or the gating network is trained), inter might not require gradients. In such cases, grad_inter will be None during the backward pass, causing a runtime crash (AttributeError: 'NoneType' object has no attribute 'to') when attempting to cast it. Additionally, clamping the denominator to 1e-12 can artificially shrink the gradient for very small but valid gate values (e.g., 1e-15 in bfloat16/float32). Using torch.where with a safer threshold avoids this underestimation while preventing division by zero.
@staticmethod
def backward(ctx, grad_inter):
if grad_inter is None:
return None, None
inter, gate = ctx.saved_tensors
dgate = (inter.to(torch.float32) * grad_inter.to(torch.float32)).sum(dim=-1)
safe_gate = gate.to(torch.float32)
dgate = torch.where(safe_gate > 1e-20, dgate / safe_gate, 0.0)
return grad_inter, dgate.to(gate.dtype)| ) | ||
| permuted_weights = None | ||
| if _gategrad: | ||
| permuted_weights = top_k_weights.view(-1)[sorted_indices] |
There was a problem hiding this comment.
Using .view(-1) on top_k_weights can raise a RuntimeError if the tensor is non-contiguous (e.g., if it is a slice or transposed view). Using .reshape(-1) is safer as it automatically handles non-contiguous tensors by copying them if necessary.
| permuted_weights = top_k_weights.view(-1)[sorted_indices] | |
| permuted_weights = top_k_weights.reshape(-1)[sorted_indices] |
| flat_weights = top_k_weights.view(-1) | ||
| permuted_weights = flat_weights[sorted_indices] |
There was a problem hiding this comment.
Using .view(-1) on top_k_weights can raise a RuntimeError if the tensor is non-contiguous. Using .reshape(-1) is safer and more robust.
| flat_weights = top_k_weights.view(-1) | |
| permuted_weights = flat_weights[sorted_indices] | |
| flat_weights = top_k_weights.reshape(-1) | |
| permuted_weights = flat_weights[sorted_indices] |
There was a problem hiding this comment.
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: c0215d744b
ℹ️ About Codex in GitHub
Your team has set up Codex to review pull requests in this repo. Reviews are triggered when you
- Open a pull request for review
- Mark a draft as ready
- Comment "@codex review".
If Codex has suggestions, it will comment; otherwise it will react with 👍.
Codex can also answer questions or update the PR. Try commenting "@codex address that feedback".
| def backward(ctx, grad_inter): | ||
| inter, gate = ctx.saved_tensors | ||
| dgate = (inter.to(torch.float32) * grad_inter.to(torch.float32)).sum(dim=-1) | ||
| dgate = dgate / gate.to(torch.float32).clamp_min(1e-12) |
There was a problem hiding this comment.
Divide by the signed routing weight
When UNSLOTH_MOE_GATEGRAD=1 is used with routed weights that are not guaranteed positive, this clamp no longer implements the stated identity. In particular, Gemma4 folds the unconstrained router.per_expert_scale parameter into top_k_weights (gemma4_moe.py:267 and gemma4_moe.py:287), so a negative learned scale makes grad_inter contain the signed gate but this division uses 1e-12 instead of the negative gate; because the later multiply is detached, this is the only gradient path for that weight and the router/per-expert-scale gradient gets the wrong magnitude/sign. Use the actual signed gate denominator, or disable this path when any routed weight is non-positive/too small.
Useful? React with 👍 / 👎.
Summary
In the grouped MoE combine, differentiating
out = down(inter) * gatepins the down-projection output on the autograd tape solely to form the routing-weight gradientdGate = <dOut, Y>. That tensor is large at long sequence lengths.What this does
With
UNSLOTH_MOE_GATEGRAD=1(default off),forward_native_grouped_mmderives the gate gradient from the inner-product identitydGate = <A, dA> / gateover the pre-down-projection activation instead, so Y never has to stay on the tape. The identity is exact for any linear down projection (frozen base, the recompute path, additive LoRA) but not for a bias added after the down matmul, so the path auto-disables when a down-projection bias is present.Testing