Fix MLX notebook generate input/output compatibility #855

Open
Lyxot wants to merge 2 commits into unslothai:main from Lyxot:fix/mlx-inference-notebook-compat
Conversation

@Lyxot Lyxot commented Jul 3, 2026

Contributor

Summary

This PR fixes the MLX-side compatibility gaps that prevented existing supported Unsloth notebooks from running their inference cells unchanged after importing Unsloth.

The changes are intentionally limited to the notebook-facing input/output contracts in the MLX loader:

  • return text and VLM generate(...) results as a batched generated-id sequence
  • prefer a torch.long tensor when torch is importable, with a NumPy int64 fallback when torch is unavailable
  • keep the returned sequence shaped as (1, prompt_length + generated_length) so existing notebook slicing and decode code works
  • make the patched mlx-lm TokenizerWrapper support apply_chat_template(..., return_dict=True) by returning a Hugging Face BatchEncoding
  • preserve the existing callable-tokenizer shim for mlx-lm wrappers that do not define __call__
  • add focused regression coverage for text generation, VLM generation, and chat-template return_dict=True expansion
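
As a rough illustration of the first three bullets, the sketch below builds the batched id sequence, preferring torch and falling back to NumPy. The name `build_generate_output` is hypothetical and is not the helper in unsloth_zoo/mlx/loader.py, and the token ids are made up for the example.

```python
import numpy as np


def build_generate_output(prompt_ids, generated_ids):
    """Hypothetical helper: return ids shaped (1, prompt + generated)."""
    sequence = [list(prompt_ids) + list(generated_ids)]
    try:
        import torch  # optional dependency
    except ImportError:
        # Torch-free install: fall back to a NumPy int64 array.
        return np.asarray(sequence, dtype=np.int64)
    return torch.tensor(sequence, dtype=torch.long)


prompt_ids = [101, 7592, 2088]  # made-up token ids
outputs = build_generate_output(prompt_ids, [2023, 2003, 102])

# Notebook-style slicing: drop the prompt, keep only the new ids.
new_tokens = outputs[:, len(prompt_ids):]
print(tuple(outputs.shape), new_tokens.tolist())
```

Both containers expose `.shape`, two-dimensional slicing, and `.tolist()`, which is why the same notebook slicing code can run against either one.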

Why

Some existing notebooks use standard Hugging Face / CUDA-style inference patterns such as:

inputs = tokenizer.apply_chat_template(..., return_tensors="pt", return_dict=True).to("cuda")
outputs = model.generate(**inputs, max_new_tokens=...)
response = tokenizer.batch_decode(outputs[:, inputs.input_ids.shape[-1]:])

On the MLX path, the model and tokenizer objects are backed by mlx-lm / mlx-vlm rather than Transformers. Two notebook-visible mismatches showed up:

  1. generate(...) returned an MLX array. That array can be sliced, but Transformers utilities such as to_py_obj(...) do not recognize the mlx array type. Returning torch when available gives notebooks the CUDA-like generated-id container they expect, while the NumPy fallback keeps torch-free MLX installs usable.
  2. mlx-lm's TokenizerWrapper.apply_chat_template(..., return_dict=True) did not consistently return a mapping-like object that can be moved with .to(...) and expanded into model.generate(**inputs).

This PR normalizes those two surfaces without changing notebook code.
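
The contract in point 2 can be sketched with a plain-Python stand-in: an object that is a mapping, has a chainable `.to(...)`, and expands with `**`. `EncodingStandIn` and `fake_generate` are hypothetical names used only for illustration; they are not the Hugging Face `BatchEncoding` or the MLX `generate(...)` shim, and the real `.to(...)` behavior is not modeled here.

```python
from collections import UserDict


class EncodingStandIn(UserDict):
    """Hypothetical stand-in for a BatchEncoding-like mapping."""

    def to(self, device):
        # Only record the requested device and return self so that
        # calls can be chained; no tensors are moved in this sketch.
        self.device = device
        return self


def fake_generate(input_ids, attention_mask=None, max_new_tokens=1):
    """Hypothetical generate(): echo the prompt, append placeholder ids."""
    return [input_ids[0] + [0] * max_new_tokens]


inputs = EncodingStandIn(
    {"input_ids": [[5, 6, 7]], "attention_mask": [[1, 1, 1]]}
).to("cuda")

# Keyword expansion works because the object is a mapping.
outputs = fake_generate(**inputs, max_new_tokens=2)
print(outputs)
```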

Scope

This is deliberately not a broad Transformers generation compatibility layer.

Not included in this PR:

  • no GenerationConfig merge or precedence support
  • no HF logits_processor compatibility policy
  • no return_dict_in_generate output mode
  • no changes to notebook files

The goal is only to make the currently supported notebook inference patterns work unchanged on MLX while keeping the diff small and reviewable. CUDA/ROCm paths are unaffected because these changes are isolated to unsloth_zoo/mlx/loader.py.

Validation

Focused checks run locally:

PYTHONPATH=/Users/long/Github/unsloth/unsloth:/Users/long/Github/unsloth/worktrees/unsloth-zoo-mlx-inference-notebooks:$PYTHONPATH \
  python -m pytest tests/test_mlx_save_export_regressions.py -q

Result:

35 passed

Additional checks:

python -m py_compile unsloth_zoo/mlx/loader.py tests/test_mlx_save_export_regressions.py
git diff --check

Both passed.

A manual smoke check with torch imports blocked also verified that _mlx_generate_output(...) falls back to a NumPy int64 array when torch cannot be imported.
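
One way such a blocked-torch check could be reproduced is sketched below. It relies on the standard Python behavior that a `None` entry in `sys.modules` makes the corresponding import fail. `torch_blocked` and `ids_container` are hypothetical names; this is not the check that was actually run for this PR.

```python
import sys
from contextlib import contextmanager

import numpy as np


@contextmanager
def torch_blocked():
    """Make `import torch` raise ImportError inside the with-block."""
    missing = object()
    saved = sys.modules.get("torch", missing)
    sys.modules["torch"] = None  # a None entry makes the import fail
    try:
        yield
    finally:
        # Restore whatever was there before the block.
        if saved is missing:
            del sys.modules["torch"]
        else:
            sys.modules["torch"] = saved


def ids_container(ids):
    """Hypothetical helper: torch.long if importable, else NumPy int64."""
    try:
        import torch
    except ImportError:
        return np.asarray([ids], dtype=np.int64)
    return torch.tensor([ids], dtype=torch.long)


with torch_blocked():
    fallback = ids_container([1, 2, 3])

print(type(fallback).__name__, fallback.dtype, fallback.shape)
```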

The regression tests cover:

  • text generate(...) returns a batched torch generated-id tensor with working .shape, slicing, .tolist(), and transformers.to_py_obj(...)
  • VLM generate(...) returns the same notebook-friendly torch sequence shape
  • the MLX-on-torch regression module skips cleanly when torch is unavailable
  • apply_chat_template(..., return_dict=True) returns a BatchEncoding that supports .to(...) and model.generate(**inputs) expansion

Copilot AI review requested due to automatic review settings July 3, 2026 15:55
@chatgpt-codex-connector


Codex usage limits have been reached for code reviews. Please check with the admins of this repo to increase the limits by adding credits.
Repo admins can enable using credits for code reviews in their settings.

@gemini-code-assist

Contributor

Warning

Gemini encountered an error creating the review. You can try again by commenting /gemini review.

Copilot AI left a comment

Contributor

Pull request overview

Note

Copilot couldn't run its full agentic review because no GitHub Actions runner was available. Make sure your repository has a runner available to run Copilot's review, or add a copilot-setup-steps.yml file specifying one with the runs-on attribute. See the docs for more details.

Fixes MLX notebook-facing inference compatibility by normalizing generate(...) return types and improving mlx-lm TokenizerWrapper.apply_chat_template(..., return_dict=True) behavior to better match common Hugging Face notebook patterns.

Changes:

  • Add a shared _mlx_generate_output(...) helper and use it for both text and VLM generate(...) shims.
  • Patch mlx-lm TokenizerWrapper to (a) remain callable when needed and (b) return a BatchEncoding for apply_chat_template(..., return_dict=True).
  • Add regression tests for generation output shape/type and chat-template return_dict=True expansion.

Reviewed changes

Copilot reviewed 2 out of 2 changed files in this pull request and generated 3 comments.

| File | Description |
| --- | --- |
| unsloth_zoo/mlx/loader.py | Normalizes MLX generate(...) outputs and patches mlx-lm tokenizer wrapper to return HF-like BatchEncoding for chat templates. |
| tests/test_mlx_save_export_regressions.py | Adds regression assertions for generate output behavior and a new test covering chat-template return_dict=True expansion. |


Comment thread unsloth_zoo/mlx/loader.py
Comment on lines +3441 to 3449
def _mlx_generate_output(prompt_ids, generated_ids):
    """Build a Transformers-friendly batched generate return value."""
    import numpy as np

    return np.asarray([list(prompt_ids) + list(generated_ids)], dtype=np.int64)


def _mlx_eos_token_id_set(eos_token_id):
    """Normalize HF-style eos_token_id values into a set of token ids."""
Comment thread unsloth_zoo/mlx/loader.py Outdated
Comment on lines 3444 to 3449

    return np.asarray([list(prompt_ids) + list(generated_ids)], dtype=np.int64)


def _mlx_eos_token_id_set(eos_token_id):
    """Normalize HF-style eos_token_id values into a set of token ids."""
Comment on lines 33 to 36
@pytest.fixture(autouse=True, scope="module")
def _install_mlx_torch_shim():
    from mlx_simulation import simulate_mlx_on_torch
