Skip to content

fix(mlx): preserve inference norm dtypes#844

Open
Lyxot wants to merge 6 commits into
unslothai:mainfrom
Lyxot:fix/mlx-norm-output-cast
Open

fix(mlx): preserve inference norm dtypes#844
Lyxot wants to merge 6 commits into
unslothai:mainfrom
Lyxot:fix/mlx-norm-output-cast

Conversation

@Lyxot

@Lyxot Lyxot commented Jul 1, 2026

Copy link
Copy Markdown
Contributor

Summary

This fixes the MLX inference performance regression introduced by #684. The regression came from applying _keep_norm_parameters_float32(model) during FastMLXModel.from_pretrained(), which promoted normalization parameters to fp32 for inference loads as well as training loads.

The PR restores the intended split:

  • inference loads preserve the model's loaded dtype;
  • training still prepares fp32 norm parameters during MLXTrainer.train() setup.

What changed

  • Stop running _keep_norm_parameters_float32(model) from FastMLXModel.from_pretrained() inference/load paths.
  • Run _keep_norm_parameters_float32(model) from MLXTrainer.train() setup before _train_inner().
  • Keep norm output cast handling trainer-scoped: training can cast fp32 norm outputs back to activation dtype, while the previous process-global patch state is restored after training.
  • Improve norm detection for custom MLX norm modules so loader/trainer behavior covers non-core norm implementations without weakening base RMSNorm/LayerNorm handling.
  • Reduce trainer-owned global/static norm-cast surface by sharing the norm helpers through unsloth_zoo/mlx/utils.py and removing unused trainer aliases/static variables.
  • Make norm output cast cleanup more robust: MLXTrainer.train() restores the captured norm-cast snapshot unconditionally in finally, including failures during setup after a partial patch.

Trainer behavior

Successful trainer behavior is intended to be unchanged relative to current main / #684:

  • norm parameters are still promoted to fp32 before _train_inner() runs;
  • cast_norm_output_to_input_dtype=True remains the default training behavior;
  • trainer setup still temporarily controls the norm-output cast patch and restores the prior process-global state afterward.

The only intentional broad behavior change is for inference: loading a model for inference no longer upcasts norm parameters to fp32. The unconditional restore path is cleanup hardening for setup failures; it should not change successful training runs.

Architecture-specific norm patches

A few model-specific MLX patches still intentionally run local fp32 norm math and cast the result back to the activation dtype for numerical parity/stability. This is separate from the removed generic inference-time parameter upcast.

Currently this includes:

  • Gemma3 text RMSNorm: fp32 RMSNorm math, then cast back to the input activation dtype.
  • Gemma3/SigLIP vision LayerNorm paths: fp32 LayerNorm math for encoder/post LayerNorm, then cast back.
  • Qwen3-VL vision block LayerNorm: fp32 stats/affine with output cast controlled by the same norm-output cast flag used by training.

These special cases may still cast norm computation internally, but they do not store all loaded inference norm parameters as fp32.

Performance validation

Cached GLM-OCR bf16 VLM 1024-token generation showed the #684 regression on main and recovery on this PR:

Path Throughput Peak memory
main Unsloth MLX ~54.4 tok/s ~4.23 GB
this PR, Unsloth MLX ~213.3 tok/s ~3.03 GB
stock mlx-vlm reference ~204.9 tok/s ~3.03 GB

The PR branch is back to stock-level inference throughput and memory for this cached-model sanity check.

Reviewer notes

The norm output cast implementation still uses class-level monkey-patching, so it remains process-global. This PR improves sequential setup/teardown and exception cleanup, but it does not make overlapping trainer/inference work in the same Python process isolated.

No new module was added; the shared helpers remain in unsloth_zoo/mlx/utils.py because both loader and trainer need the logic.

Validation

  • git diff --check
  • Focused MLX norm tests: 5 passed
  • Final fresh review round: 2 clean reviewers
  • Behavioral trainer test covers norm upcast before _train_inner(), prior norm-cast state restoration, and cleanup after mid-setup patch failure

Copilot AI review requested due to automatic review settings July 1, 2026 18:34

@gemini-code-assist gemini-code-assist Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request refactors the MLX normalization output cast patching logic by moving it from unsloth_zoo/mlx/trainer.py to unsloth_zoo/mlx/utils.py and introducing state snapshotting and restoration capabilities. This ensures that the process-global norm class patch state is properly restored after training runs, even if setup fails mid-patch. Additionally, corresponding tests were added and updated to verify this behavior. Feedback suggests adding an identity check during state restoration to ensure we only unpatch wrappers originally applied by our code, preventing accidental overwriting of third-party patches.

Important

The consumer version of Gemini Code Assist on GitHub is being sunset. Starting June 18, 2026, new organization installations will be blocked, and all code review activity will officially cease on July 17, 2026.
For more details on the timeline and next steps, please review the Help Documentation.

Comment thread unsloth_zoo/mlx/utils.py Outdated

Copilot AI left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

Note

Copilot couldn't run its full agentic review because no GitHub Actions runner was available. Make sure your repository has a runner available to run Copilot's review, or add a copilot-setup-steps.yml file specifying one with the runs-on attribute. See the docs for more details.

This PR fixes an MLX inference performance regression by ensuring normalization parameter dtypes are preserved on inference loads while still preparing fp32 norms for training stability, and it improves the robustness of norm output-cast patch setup/teardown.

Changes:

  • Move fp32 norm-parameter upcasting from from_pretrained() into MLXTrainer.train() setup so inference preserves loaded dtypes.
  • Centralize MLX norm output-cast patching/snapshot/restore utilities and broaden custom norm detection.
  • Add/extend tests to validate trainer upcasting behavior and robust restoration of prior global norm-cast state.

Reviewed changes

Copilot reviewed 5 out of 5 changed files in this pull request and generated 4 comments.

Show a summary per file
File Description
unsloth_zoo/mlx/utils.py Adds shared utilities for discovering norm modules and snapshotting/restoring global norm output-cast monkey patches.
unsloth_zoo/mlx/trainer.py Switches trainer to snapshot/restore global norm-cast state and applies fp32 norm upcasting during training setup.
unsloth_zoo/mlx/loader.py Stops upcasting norm params on model load and re-scopes fp32 norm prep to training; refactors norm-path detection usage.
tests/test_mlx_trainer_internals.py Adds regression tests for trainer norm upcasting and robust restoration (including failure paths).
tests/test_mlx_batching_and_decay.py Updates tests to use the new shared norm output-cast state helpers and expands custom norm coverage.

💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

Comment thread unsloth_zoo/mlx/utils.py
Comment thread unsloth_zoo/mlx/utils.py
Comment thread unsloth_zoo/mlx/utils.py
Comment thread unsloth_zoo/mlx/utils.py
@chatgpt-codex-connector

Copy link
Copy Markdown

Codex usage limits have been reached for code reviews. Please check with the admins of this repo to increase the limits by adding credits.
Repo admins can enable using credits for code reviews in their settings.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants