fix(mlx): preserve inference norm dtypes#844
Conversation
There was a problem hiding this comment.
Code Review
This pull request refactors the MLX normalization output cast patching logic by moving it from unsloth_zoo/mlx/trainer.py to unsloth_zoo/mlx/utils.py and introducing state snapshotting and restoration capabilities. This ensures that the process-global norm class patch state is properly restored after training runs, even if setup fails mid-patch. Additionally, corresponding tests were added and updated to verify this behavior. Feedback suggests adding an identity check during state restoration to ensure we only unpatch wrappers originally applied by our code, preventing accidental overwriting of third-party patches.
Important
The consumer version of Gemini Code Assist on GitHub is being sunset. Starting June 18, 2026, new organization installations will be blocked, and all code review activity will officially cease on July 17, 2026.
For more details on the timeline and next steps, please review the Help Documentation.
There was a problem hiding this comment.
Pull request overview
Note
Copilot couldn't run its full agentic review because no GitHub Actions runner was available. Make sure your repository has a runner available to run Copilot's review, or add a copilot-setup-steps.yml file specifying one with the runs-on attribute. See the docs for more details.
This PR fixes an MLX inference performance regression by ensuring normalization parameter dtypes are preserved on inference loads while still preparing fp32 norms for training stability, and it improves the robustness of norm output-cast patch setup/teardown.
Changes:
- Move fp32 norm-parameter upcasting from
from_pretrained()intoMLXTrainer.train()setup so inference preserves loaded dtypes. - Centralize MLX norm output-cast patching/snapshot/restore utilities and broaden custom norm detection.
- Add/extend tests to validate trainer upcasting behavior and robust restoration of prior global norm-cast state.
Reviewed changes
Copilot reviewed 5 out of 5 changed files in this pull request and generated 4 comments.
Show a summary per file
| File | Description |
|---|---|
unsloth_zoo/mlx/utils.py |
Adds shared utilities for discovering norm modules and snapshotting/restoring global norm output-cast monkey patches. |
unsloth_zoo/mlx/trainer.py |
Switches trainer to snapshot/restore global norm-cast state and applies fp32 norm upcasting during training setup. |
unsloth_zoo/mlx/loader.py |
Stops upcasting norm params on model load and re-scopes fp32 norm prep to training; refactors norm-path detection usage. |
tests/test_mlx_trainer_internals.py |
Adds regression tests for trainer norm upcasting and robust restoration (including failure paths). |
tests/test_mlx_batching_and_decay.py |
Updates tests to use the new shared norm output-cast state helpers and expands custom norm coverage. |
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
|
Codex usage limits have been reached for code reviews. Please check with the admins of this repo to increase the limits by adding credits. |
Summary
This fixes the MLX inference performance regression introduced by #684. The regression came from applying
_keep_norm_parameters_float32(model)duringFastMLXModel.from_pretrained(), which promoted normalization parameters to fp32 for inference loads as well as training loads.The PR restores the intended split:
MLXTrainer.train()setup.What changed
_keep_norm_parameters_float32(model)fromFastMLXModel.from_pretrained()inference/load paths._keep_norm_parameters_float32(model)fromMLXTrainer.train()setup before_train_inner().unsloth_zoo/mlx/utils.pyand removing unused trainer aliases/static variables.MLXTrainer.train()restores the captured norm-cast snapshot unconditionally infinally, including failures during setup after a partial patch.Trainer behavior
Successful trainer behavior is intended to be unchanged relative to current
main/ #684:_train_inner()runs;cast_norm_output_to_input_dtype=Trueremains the default training behavior;The only intentional broad behavior change is for inference: loading a model for inference no longer upcasts norm parameters to fp32. The unconditional restore path is cleanup hardening for setup failures; it should not change successful training runs.
Architecture-specific norm patches
A few model-specific MLX patches still intentionally run local fp32 norm math and cast the result back to the activation dtype for numerical parity/stability. This is separate from the removed generic inference-time parameter upcast.
Currently this includes:
These special cases may still cast norm computation internally, but they do not store all loaded inference norm parameters as fp32.
Performance validation
Cached GLM-OCR bf16 VLM 1024-token generation showed the #684 regression on
mainand recovery on this PR:mainUnsloth MLXmlx-vlmreferenceThe PR branch is back to stock-level inference throughput and memory for this cached-model sanity check.
Reviewer notes
The norm output cast implementation still uses class-level monkey-patching, so it remains process-global. This PR improves sequential setup/teardown and exception cleanup, but it does not make overlapping trainer/inference work in the same Python process isolated.
No new module was added; the shared helpers remain in
unsloth_zoo/mlx/utils.pybecause both loader and trainer need the logic.Validation
git diff --check_train_inner(), prior norm-cast state restoration, and cleanup after mid-setup patch failure