Fix Gemma3N/Gemma4 gradient_checkpointing_enable crash on positional kwargs#6826
Fix Gemma3N/Gemma4 gradient_checkpointing_enable crash on positional kwargs#6826oobabooga wants to merge 2 commits into
Conversation
There was a problem hiding this comment.
Code Review
This pull request updates the _gc_enable_reentrant function in unsloth/models/vision.py to accept gradient_checkpointing_kwargs as an explicit parameter. This ensures compatibility when gradient checkpointing is re-enabled positionally by TRL or via keyword arguments by HF/PEFT. There are no review comments, so no additional feedback is provided.
Important
The consumer version of Gemini Code Assist on GitHub is being sunset. Starting June 18, 2026, new organization installations will be blocked, and all code review activity will officially cease on July 17, 2026.
For more details on the timeline and next steps, please review the Help Documentation.
|
Codex usage limits have been reached for code reviews. Please check with the admins of this repo to increase the limits by adding credits. |
Fixes #4886.
GRPO training on a Gemma3N/Gemma4 model crashes with:
Problem
For Gemma3N/Gemma4,
post_patch_modelreplacesmodel.gradient_checkpointing_enablewith a wrapper that forcesuse_reentrant=True(needed to avoid the audio-conformer stride assertion from #4629). The wrapper's signature wasdef _gc_enable_reentrant(**kwargs), keyword-only.HF's real method is
gradient_checkpointing_enable(self, gradient_checkpointing_kwargs=None), and TRL'sdisable_gradient_checkpointingre-enables it positionally:model.gradient_checkpointing_enable(gradient_checkpointing_kwargs). That positional call hit the keyword-only wrapper and raised the TypeError during GRPO generation.Fix
Accept
gradient_checkpointing_kwargsas a positional-or-keyword arg, matching HF's signature, and forceuse_reentrant=Trueregardless of call style. Thedict(...)copy also stops the wrapper mutating the caller's config in place.As a side effect this restores PEFT's capability probe:
prepare_model_for_kbit_trainingchecks"gradient_checkpointing_kwargs" in inspect.signature(...).parameters, which the old**kwargs-only signature returned False for, so PEFT was silently dropping the kwargs. The named param now makes it True.Scope
On current
main, the GRPO path that triggers this is usually masked:patch_trl_disable_gradient_checkpointingswaps TRL'sdisable_gradient_checkpointingfor a no-op, so the positional re-enable never runs. The fix still matters for installs where that no-op is absent (the reporter is on an older unsloth) or where itstrl.*module rebind misses the call site, for PEFT's introspection above, and for any direct positional caller, since HF's own signature accepts one.Verification
Reproduced end-to-end on a real
unsloth/gemma-3n-E2B-itload with TRL 1.0.0. With TRL's realdisable_gradient_checkpointingactive, one GRPO step onmaincrashes with the exact traceback above at_generate_and_score_completions, while the same step on this branch invokes the positional re-enable (confirmed by instrumentation) and completes.Also exercised the wrapper directly across every caller's call style:
use_reentrant=Trueuse_reentrant=True