Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 57 additions & 0 deletions tests/test_mlx_save_export_regressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@

@pytest.fixture(autouse=True, scope="module")
def _install_mlx_torch_shim():
pytest.importorskip("torch")
from mlx_simulation import simulate_mlx_on_torch

simulate_mlx_on_torch()
Expand Down Expand Up @@ -223,6 +224,8 @@ def fake_push_to_hub_gguf(
def test_text_generate_honors_do_sample_false(monkeypatch):
import mlx_lm
import mlx_lm.sample_utils as sample_utils
import torch
from transformers.tokenization_utils_base import to_py_obj
import unsloth_zoo.mlx.loader as loader

calls = {}
Expand Down Expand Up @@ -263,7 +266,12 @@ def fake_stream_generate(_model, tokenizer, prompt, max_tokens=None, **kwargs):
max_length=4,
)

assert isinstance(out, torch.Tensor)
assert out.dtype == torch.long
assert out.tolist() == [[1, 2, 9, 5]]
assert out.shape == (1, 4)
assert out[:, 2:].tolist() == [[9, 5]]
assert to_py_obj(out) == [[1, 2, 9, 5]]
assert calls["sampler"] == {
"temp": 0.0,
"top_p": 0.0,
Expand All @@ -277,7 +285,52 @@ def fake_stream_generate(_model, tokenizer, prompt, max_tokens=None, **kwargs):
assert tokenizer.eos_token_ids == {2}


def test_tokenizer_wrapper_chat_template_return_dict_expands_for_generate():
    """Check the patched wrapper's ``return_dict=True`` chat-template output.

    The result must unpack with ``**`` into keyword arguments, support
    ``.to("cpu")``, leave ``tokenize=False`` output untouched, and the patched
    wrapper must still forward plain calls to the inner tokenizer.
    """
    import unsloth_zoo.mlx.loader as loader

    class InnerTokenizer:
        def __call__(self, *args, **kwargs):
            return {"called": True}

        def apply_chat_template(self, *args, tokenize=True, **kwargs):
            if not tokenize:
                return "rendered"
            if kwargs.get("return_dict", False):
                return {
                    "input_ids": [1, 2, 3],
                    "attention_mask": [1, 1, 1],
                }
            return [1, 2, 3]

    # The class name must stay "TokenizerWrapper": the loader patch keys on it.
    class TokenizerWrapper:
        def __init__(self):
            self._tokenizer = InnerTokenizer()

        def apply_chat_template(self, *args, tokenize=True, **kwargs):
            if tokenize:
                return [1, 2, 3]
            return "rendered"

    tokenizer = TokenizerWrapper()
    loader._patch_mlx_tokenizer_call(tokenizer)

    messages = [{"role": "user", "content": "hi"}]
    encoded = tokenizer.apply_chat_template(
        messages,
        tokenize=True,
        return_dict=True,
    )

    def expand_generate_inputs(**kwargs):
        # Stand-in for ``model.generate(**inputs)``: echoes what was unpacked.
        return kwargs

    unpacked = expand_generate_inputs(**encoded)
    assert unpacked == {
        "input_ids": [1, 2, 3],
        "attention_mask": [1, 1, 1],
    }
    assert encoded.to("cpu")["input_ids"] == [1, 2, 3]

    rendered = tokenizer.apply_chat_template([], tokenize=False, return_dict=True)
    assert rendered == "rendered"
    assert tokenizer("hi") == {"called": True}


def test_vlm_generate_hf_kwargs(monkeypatch):
import torch
from transformers.tokenization_utils_base import to_py_obj
import unsloth_zoo.mlx.loader as loader

fake_mlx_vlm = types.ModuleType("mlx_vlm")
Expand Down Expand Up @@ -305,7 +358,11 @@ def fake_stream_generate(_model, _processor, _prompt, max_tokens=None, **batch):
max_new_tokens=1,
)

assert isinstance(out, torch.Tensor)
assert out.dtype == torch.long
assert out.tolist() == [[1, 2]]
assert out.shape == (1, 2)
assert to_py_obj(out) == [[1, 2]]
assert calls[0][0] == 1
assert tuple(calls[0][1]["input_ids"].shape) == (1, 2)
assert tuple(calls[0][1]["mask"].shape) == (1, 2)
Expand Down
80 changes: 69 additions & 11 deletions unsloth_zoo/mlx/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
import tempfile
import types
import warnings
from collections.abc import Mapping
from contextlib import contextmanager
from dataclasses import asdict, dataclass
from fnmatch import fnmatch
Expand Down Expand Up @@ -3437,6 +3438,17 @@ def _mlx_apply_attention_mask(prompt_ids, attention_mask):
return [token for token, keep in zip(prompt_ids, mask) if keep != 0]


def _mlx_generate_output(prompt_ids, generated_ids):
"""Build a Transformers-friendly batched generate return value."""
sequences = [list(prompt_ids) + list(generated_ids)]
try:
import torch
return torch.tensor(sequences, dtype=torch.long)
except ImportError:
import numpy as np
return np.asarray(sequences, dtype=np.int64)


def _mlx_eos_token_id_set(eos_token_id):
"""Normalize HF-style eos_token_id values into a set of token ids."""
Comment on lines +3441 to 3449
if eos_token_id is None:
Expand Down Expand Up @@ -3499,7 +3511,6 @@ def _mlx_token_to_int(token):

def _mlx_generate_vlm(self, *args, **kwargs):
"""HF-style VLM generate() shim backed by mlx-vlm stream_generate."""
import mlx.core as mx
from mlx_vlm import stream_generate
from .utils import _to_mx_vlm_batch

Expand Down Expand Up @@ -3614,12 +3625,11 @@ def _mlx_generate_vlm(self, *args, **kwargs):

if streamer is not None:
streamer.end()
return mx.array([prompt_ids + generated_ids])
return _mlx_generate_output(prompt_ids, generated_ids)


def _mlx_generate(self, *args, **kwargs):
"""HF-style text generate() shim backed by mlx-lm stream_generate."""
import mlx.core as mx
from mlx_lm import stream_generate
from mlx_lm.sample_utils import make_logits_processors, make_sampler

Expand Down Expand Up @@ -3736,24 +3746,72 @@ def _mlx_generate(self, *args, **kwargs):

if streamer is not None:
streamer.end()
return mx.array([prompt_ids + generated_ids])
return _mlx_generate_output(prompt_ids, generated_ids)


def _mlx_chat_template_batch_encoding(output):
    """Coerce tokenized chat-template output into a ``BatchEncoding``.

    An existing ``BatchEncoding`` is returned as-is, any other mapping is
    copied into a new one, and anything else is stored under ``input_ids``.
    """
    from transformers import BatchEncoding

    # Checked first so an existing BatchEncoding is passed through, not copied.
    if isinstance(output, BatchEncoding):
        return output
    payload = dict(output) if isinstance(output, Mapping) else {"input_ids": output}
    return BatchEncoding(payload)


def _patch_mlx_tokenizer_call(tokenizer):
"""Make mlx-lm TokenizerWrapper callable like its wrapped HF tokenizer."""
"""Patch mlx-lm TokenizerWrapper to match HF notebook tokenizer APIs."""
if tokenizer is None:
return
cls = type(tokenizer)
if cls.__name__ != "TokenizerWrapper" or "__call__" in cls.__dict__:
if cls.__name__ != "TokenizerWrapper":
return
if not hasattr(tokenizer, "_tokenizer") or not callable(tokenizer._tokenizer):
if "__call__" not in cls.__dict__:
if hasattr(tokenizer, "_tokenizer") and callable(tokenizer._tokenizer):
def tokenizer_wrapper_call(self, *args, **kwargs):
return self._tokenizer(*args, **kwargs)

tokenizer_wrapper_call._unsloth_mlx_call = True
cls.__call__ = tokenizer_wrapper_call

if getattr(cls, "_unsloth_mlx_apply_chat_template", False):
return
original_apply_chat_template = getattr(cls, "apply_chat_template", None)
if original_apply_chat_template is None:
return

def tokenizer_wrapper_call(self, *args, **kwargs):
return self._tokenizer(*args, **kwargs)
def tokenizer_wrapper_apply_chat_template(self, *args, tokenize=True, **kwargs):
return_dict = bool(kwargs.get("return_dict", False))
inner_tokenizer = getattr(self, "_tokenizer", None)
if (
tokenize
and return_dict
and getattr(self, "_chat_template", None) is None
and hasattr(inner_tokenizer, "apply_chat_template")
):
if "enable_thinking" not in kwargs:
kwargs["enable_thinking"] = getattr(self, "has_thinking", False)
output = inner_tokenizer.apply_chat_template(
*args,
tokenize=tokenize,
**kwargs,
)
return _mlx_chat_template_batch_encoding(output)

output = original_apply_chat_template(
self,
*args,
tokenize=tokenize,
**kwargs,
)
if tokenize and return_dict:
return _mlx_chat_template_batch_encoding(output)
return output

tokenizer_wrapper_call._unsloth_mlx_call = True
cls.__call__ = tokenizer_wrapper_call
tokenizer_wrapper_apply_chat_template._unsloth_mlx_call = True
cls.apply_chat_template = tokenizer_wrapper_apply_chat_template
cls._unsloth_mlx_apply_chat_template = True


def _patch_mlx_saving(model, tokenizer):
Expand Down
Loading