Optional gate gradient identity for the grouped MoE combine #858

Open: danielhanchen wants to merge 1 commit into main from moe-gategrad-identity

@danielhanchen

Member

Summary

In the grouped MoE combine, differentiating out = down(inter) * gate keeps the down-projection output Y = down(inter) on the autograd tape solely to form the routing-weight gradient dGate = <dOut, Y>. That tensor is large at long sequence lengths.

What this does

With UNSLOTH_MOE_GATEGRAD=1 (default off), forward_native_grouped_mm instead derives the gate gradient from the inner-product identity dGate = <A, dA> / gate, where A is the pre-down-projection activation and dA is its incoming gradient, so Y never has to stay on the tape. The identity is exact for any linear down projection (frozen base, the recompute path, additive LoRA) but not for a bias added after the down matmul, so the path auto-disables when a down-projection bias is present.
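
The identity follows in a few lines. This is a sketch under notation that the PR text does not spell out and that is assumed here: h is the unscaled expert activation for one routed token, g its scalar routing weight, W a bias-free down-projection matrix, and A = g h the gate-scaled activation that enters the down projection.

```latex
\begin{aligned}
\mathrm{out} &= g\,(W h) = W (g h) = W A, \qquad Y = W h, \\
\mathrm{dGate} &= \langle \mathrm{dOut},\, Y \rangle
  = \langle \mathrm{dOut},\, W h \rangle
  = \langle W^{\top}\mathrm{dOut},\, h \rangle
  = \frac{\langle \mathrm{dA},\, A \rangle}{g},
  \qquad \mathrm{dA} = W^{\top}\mathrm{dOut}.
\end{aligned}
```

With a bias b added after the matmul, out = g (W h + b), so the true gradient is <dOut, W h + b> = <A, dA> / g + <dOut, b>. The identity misses the second term, which is consistent with the path disabling itself when a down-projection bias is present.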

Testing

  • 4 CPU unit tests check the identity against plain autograd in fp32, with and without a LoRA term on the down weight, plus forward-identity and env gating.
  • Full-model loss parity on Qwen3-30B-A3B: max per-step deviation 0.0063 over 20 deterministic steps.
  • Throughput neutral; peak memory down 1.41 GiB (6 percent) at 16k tokens, with the saving growing with sequence length.
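
The kind of check described in the first bullet can be approximated in a small stand-alone script. This is a sketch, not the PR's test file: the tensor names (h, W, lora_A, lora_B), the shapes, and both helper functions are illustrative assumptions, and the gate is assumed to be folded into the activation before the down projection.

```python
import torch

torch.manual_seed(0)
tokens, d_inter, d_model, rank = 6, 8, 5, 2

h = torch.randn(tokens, d_inter)       # unscaled expert activation (assumed name)
gate = torch.rand(tokens) + 0.1        # strictly positive routing weights
W = torch.randn(d_model, d_inter)      # frozen down-projection weight
lora_A = torch.randn(rank, d_inter)    # hypothetical LoRA factors
lora_B = torch.randn(d_model, rank)
W_eff = W + lora_B @ lora_A            # additive LoRA keeps the map linear
dOut = torch.randn(tokens, d_model)    # upstream gradient


def reference_dgate(weight, bias=None):
    """Plain autograd through out = down(h) * gate."""
    g = gate.clone().requires_grad_(True)
    y = h @ weight.T
    if bias is not None:
        y = y + bias
    (y * g[:, None]).backward(dOut)
    return g.grad


def identity_dgate(weight):
    """dGate = <A, dA> / gate with A = gate * h and dA = dOut @ weight."""
    A = h * gate[:, None]
    dA = dOut @ weight
    return (A * dA).sum(dim=-1) / gate


ok_base = torch.allclose(reference_dgate(W), identity_dgate(W), atol=1e-4)
ok_lora = torch.allclose(reference_dgate(W_eff), identity_dgate(W_eff), atol=1e-4)

# A bias after the matmul breaks linearity in A, so the identity should not match.
bias = torch.randn(d_model)
ok_bias = torch.allclose(reference_dgate(W, bias), identity_dgate(W), atol=1e-4)
```

In this toy setup the base and LoRA cases should agree with autograd to fp32 tolerance, while the biased case should not.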

With UNSLOTH_MOE_GATEGRAD=1, forward_native_grouped_mm derives the routing
weight gradient from the inner product identity dGate = (A dot dA) / gate over
the pre-down-projection activation instead of differentiating the post-down
multiply. The down projection output then never has to stay on the autograd
tape solely for the gate gradient, which lowers peak training memory at long
sequence lengths (about 1.4 GiB, 6 percent, at 16k tokens on a 30B A3B MoE)
with unchanged throughput and per-step losses.

The identity is exact for any linear down projection (frozen base, recompute
path, additive LoRA) but not for a bias added after the down matmul, so the
path auto-disables when a down projection bias is present. Off by default.
danielhanchen requested a review from Datta0 as a code owner on July 3, 2026 18:40

gemini-code-assist (bot) left a comment

Contributor

Code Review

This pull request introduces an optional gate gradient optimization (_MoEGateGradIdentity) for MoE layers, enabled via the UNSLOTH_MOE_GATEGRAD environment variable. The optimization derives the routing-weight gradient from an inner-product identity, so the down-projection output does not have to be cached on the autograd tape. The feedback highlights a potential runtime crash in the backward pass if grad_inter is None (e.g., when the base model is frozen) and suggests a safer division-by-zero check. It also recommends using .reshape(-1) instead of .view(-1) on top_k_weights to avoid a potential RuntimeError with non-contiguous tensors.

Comment on lines +1029 to +1034
    @staticmethod
    def backward(ctx, grad_inter):
        inter, gate = ctx.saved_tensors
        dgate = (inter.to(torch.float32) * grad_inter.to(torch.float32)).sum(dim=-1)
        dgate = dgate / gate.to(torch.float32).clamp_min(1e-12)
        return grad_inter, dgate.to(gate.dtype)

high

If the base model is frozen (e.g., in PEFT/LoRA where only the down-projection LoRA or the gating network is trained), inter might not require gradients. In such cases, grad_inter will be None during the backward pass, causing a runtime crash (AttributeError: 'NoneType' object has no attribute 'to') when attempting to cast it. Additionally, clamping the denominator to 1e-12 can artificially shrink the gradient for very small but valid gate values (e.g., 1e-15 in bfloat16/float32). Using torch.where with a safer threshold avoids this underestimation while preventing division by zero.

    @staticmethod
    def backward(ctx, grad_inter):
        if grad_inter is None:
            return None, None
        inter, gate = ctx.saved_tensors
        dgate = (inter.to(torch.float32) * grad_inter.to(torch.float32)).sum(dim=-1)
        safe_gate = gate.to(torch.float32)
        dgate = torch.where(safe_gate > 1e-20, dgate / safe_gate, 0.0)
        return grad_inter, dgate.to(gate.dtype)
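
The difference between the two denominators can be seen on a toy example. This is only an illustration, not the PR's kernel: the numerator is constructed by hand so that the exact gradient is 3.0 wherever the gate is nonzero, and the values are chosen to mirror the 1e-15 case mentioned in the comment.

```python
import torch

gate = torch.tensor([0.5, 1e-15, 0.0], dtype=torch.float32)
# <A, dA> scales with the gate, so build a numerator whose exact quotient is 3.0.
numer = gate * 3.0

# Original approach: clamp the denominator from below.
clamped = numer / gate.clamp_min(1e-12)

# Suggested approach: divide by the real gate above a tiny threshold, else return 0.
masked = torch.where(gate > 1e-20, numer / gate, torch.zeros_like(gate))

print(clamped)
print(masked)
```

The clamp shrinks the gradient for the 1e-15 gate by about three orders of magnitude, whereas the masked division recovers it. Note that a `>` threshold maps every negative gate to zero as well, so it is only appropriate when routing weights are known to be non-negative.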

    )
    permuted_weights = None
    if _gategrad:
        permuted_weights = top_k_weights.view(-1)[sorted_indices]

medium

Using .view(-1) on top_k_weights can raise a RuntimeError if the tensor is non-contiguous (e.g., if it is a slice or transposed view). Using .reshape(-1) is safer as it automatically handles non-contiguous tensors by copying them if necessary.

Suggested change
- permuted_weights = top_k_weights.view(-1)[sorted_indices]
+ permuted_weights = top_k_weights.reshape(-1)[sorted_indices]
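
A minimal reproduction of the failure mode, assuming a transposed tensor as the non-contiguous input (the shapes and index values are made up for illustration and are not taken from the PR):

```python
import torch

# A transposed view is non-contiguous: same storage, swapped strides.
top_k_weights = torch.arange(6, dtype=torch.float32).reshape(3, 2).t()
is_contiguous = top_k_weights.is_contiguous()

try:
    top_k_weights.view(-1)
    view_failed = False
except RuntimeError:
    # view() cannot express this layout without copying.
    view_failed = True

# reshape() falls back to a copy when a view is impossible.
flat_weights = top_k_weights.reshape(-1)

sorted_indices = torch.tensor([5, 0, 3])   # hypothetical permutation
permuted_weights = flat_weights[sorted_indices]
```

When the input is already contiguous, reshape(-1) returns a view just as view(-1) would, so the change should cost nothing in the common case.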

Comment on lines +1295 to +1296
flat_weights = top_k_weights.view(-1)
permuted_weights = flat_weights[sorted_indices]

medium

Using .view(-1) on top_k_weights can raise a RuntimeError if the tensor is non-contiguous. Using .reshape(-1) is safer and more robust.

Suggested change
- flat_weights = top_k_weights.view(-1)
- permuted_weights = flat_weights[sorted_indices]
+ flat_weights = top_k_weights.reshape(-1)
+ permuted_weights = flat_weights[sorted_indices]

chatgpt-codex-connector (bot) left a comment

💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: c0215d744b

    def backward(ctx, grad_inter):
        inter, gate = ctx.saved_tensors
        dgate = (inter.to(torch.float32) * grad_inter.to(torch.float32)).sum(dim=-1)
        dgate = dgate / gate.to(torch.float32).clamp_min(1e-12)

P2: Divide by the signed routing weight

When UNSLOTH_MOE_GATEGRAD=1 is used with routed weights that are not guaranteed positive, this clamp no longer implements the stated identity. In particular, Gemma4 folds the unconstrained router.per_expert_scale parameter into top_k_weights (gemma4_moe.py:267 and gemma4_moe.py:287). With a negative learned scale, grad_inter carries the signed gate, but the division uses 1e-12 instead of the negative gate. Because the later multiply is detached, this is the only gradient path for that weight, so the router/per-expert-scale gradient gets the wrong magnitude and sign. Use the actual signed gate as the denominator, or disable this path when any routed weight is non-positive or too small.
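
One way to honor the signed gate is sketched below. The class GateGradIdentitySketch is a hypothetical stand-in written for this example, not the PR's _MoEGateGradIdentity, and the tensors are toy values. The sketch assumes the gate is multiplied into the activation in detached form, so that the custom Function is the only gradient path to the gate.

```python
import torch


class GateGradIdentitySketch(torch.autograd.Function):
    """Hypothetical stand-in: identity on the activation, custom gate gradient."""

    @staticmethod
    def forward(ctx, inter, gate):
        ctx.save_for_backward(inter, gate)
        return inter.view_as(inter)

    @staticmethod
    def backward(ctx, grad_inter):
        inter, gate = ctx.saved_tensors
        numer = (inter.float() * grad_inter.float()).sum(dim=-1)
        g = gate.float()
        # Divide by the signed gate; only exact zeros are masked out.
        dgate = torch.where(g != 0, numer / g, torch.zeros_like(g))
        return grad_inter, dgate.to(gate.dtype)


torch.manual_seed(0)
h = torch.randn(4, 8)                 # unscaled activation (assumed name)
W = torch.randn(5, 8)                 # bias-free down projection
dOut = torch.randn(4, 5)              # upstream gradient
gate_vals = torch.tensor([0.7, -0.3, 0.2, -1.1])   # includes negative gates

# Reference: plain autograd through out = down(h) * gate.
g_ref = gate_vals.clone().requires_grad_(True)
((h @ W.T) * g_ref[:, None]).backward(dOut)

# Identity path: detached multiply, then the Function supplies dGate.
g_id = gate_vals.clone().requires_grad_(True)
A = GateGradIdentitySketch.apply(h * g_id.detach()[:, None], g_id)
(A @ W.T).backward(dOut)

# What a positive clamp would have produced for the same numerator.
numer = (h * gate_vals[:, None] * (dOut @ W)).sum(dim=-1)
clamped = numer / gate_vals.clamp_min(1e-12)
```

Here g_id.grad should match g_ref.grad, including at the negative entries, while the clamped quotient diverges there because the denominator is replaced by 1e-12.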
