Skip to content
93 changes: 91 additions & 2 deletions unsloth_zoo/mlx/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
import mlx.nn as nn
import mlx.optimizers as optim
from mlx.utils import tree_flatten, tree_map, tree_reduce, tree_unflatten
from mlx_lm.tuner.lora import LoRALinear

_PAD_MULTIPLE = 32
SUPPORTED_MLX_OPTIMIZERS = ("adafactor", "adamw", "adam", "sgd", "muon", "lion")
Expand All @@ -58,6 +59,9 @@
make_vlm_cce_loss_fn,
make_vlm_baseline_loss_fn,
create_batches,
create_preference_batches,
make_orpo_loss_fn,
make_dpo_loss_fn,
create_ordered_batches,
iterate_training_batches,
create_vlm_batches,
Expand Down Expand Up @@ -502,6 +506,10 @@ class MLXTrainingConfig:

# Eval
eval_steps: int = 0 # 0 = disabled
loss_type: str = "sft" # "sft", "orpo", or "dpo"
orpo_beta: float = 0.1 # ORPO odds-ratio weight (TRL default)
dpo_beta: float = 0.1 # DPO beta (TRL default)
reference_free: bool = False # DPO: drop the reference term if True

# SFT-specific (from SFTConfig, for API compat)
dataset_text_field: str = "text"
Expand Down Expand Up @@ -537,6 +545,20 @@ class MLXTrainingConfig:
vlm_chat_template: object = None # Unsloth template name/tuple or raw Jinja string


@dataclass
class MLXORPOConfig(MLXTrainingConfig):
    """Training configuration preset for ORPO runs (counterpart of TRL's ORPOConfig).

    Every field is inherited from ``MLXTrainingConfig``; the only difference is
    that ``loss_type`` defaults to ``"orpo"``. The odds-ratio weight is the
    inherited ``orpo_beta`` field. Intended to be passed to ``MLXORPOTrainer``.
    """

    # Default that selects the ORPO loss branch in the trainer.
    loss_type: str = "orpo"


@dataclass
class MLXDPOConfig(MLXTrainingConfig):
    """Training configuration preset for DPO runs (counterpart of TRL's DPOConfig).

    Every field is inherited from ``MLXTrainingConfig``; the only difference is
    that ``loss_type`` defaults to ``"dpo"``. The DPO knobs are the inherited
    ``dpo_beta`` and ``reference_free`` fields. Intended to be passed to
    ``MLXDPOTrainer``.
    """

    # Default that selects the DPO loss branch in the trainer.
    loss_type: str = "dpo"


class MLXTrainer:
"""MLX-native trainer for Apple Silicon, mirroring SFTTrainer's constructor API."""

Expand Down Expand Up @@ -1226,7 +1248,20 @@ def _train_inner(self):
)
print("Unsloth: Using VLM standard cross-entropy loss.")
else:
if use_cce:
if getattr(args, "loss_type", "sft") == "orpo":
_ob = getattr(args, "orpo_beta", 0.1)
loss_fn = make_orpo_loss_fn(beta=_ob)
print("Unsloth: Using ORPO loss (beta=" + str(_ob) + ").")
elif getattr(args, "loss_type", "sft") == "dpo":
_db = getattr(args, "dpo_beta", 0.1)
_rf = bool(getattr(args, "reference_free", False))
_lora_mods = [mod for _, mod in tree_flatten(
model, is_leaf=lambda x: isinstance(x, LoRALinear))
if isinstance(mod, LoRALinear)]
loss_fn = make_dpo_loss_fn(beta=_db, lora_mods=_lora_mods, reference_free=_rf)
print("Unsloth: Using DPO loss (beta=" + str(_db) +
(", reference_free" if _rf else "") + ").")
elif use_cce:
loss_fn = make_cce_loss_fn(model)
cce_backend = getattr(loss_fn, "_unsloth_cce_backend", "unknown")
print(f"Unsloth: Using CCE loss ({cce_backend}) for memory-efficient training.")
Expand Down Expand Up @@ -1618,7 +1653,10 @@ def step_fn(batch_data, prev_state, do_update):
# Prepare eval batches
eval_batches = None
text_completion_only_loss = _text_completion_only_loss_arg(args)
if args.eval_steps > 0 and self.eval_dataset is not None:
if (getattr(args, "loss_type", "sft") in ("orpo", "dpo")
and args.eval_steps > 0 and self.eval_dataset is not None):
print(f"Unsloth: eval is not yet supported for {args.loss_type}; skipping eval.")
elif args.eval_steps > 0 and self.eval_dataset is not None:
# Use pre-built labeled eval batches if available
_labeled_eval = getattr(self, '_eval_batches_labeled', None)
if _labeled_eval is not None:
Expand Down Expand Up @@ -2048,6 +2086,27 @@ def _prepare_data(self, is_vlm):
)
text_completion_only_loss = _text_completion_only_loss_arg(args)

if getattr(args, "loss_type", "sft") in ("orpo", "dpo"):
if is_vlm:
raise ValueError(
f"{args.loss_type.upper()} is not yet supported for VLM models on MLX."
)
batches = create_preference_batches(
dataset=self.train_dataset,
tokenizer=self.tokenizer,
batch_size=args.per_device_train_batch_size,
max_seq_length=args.max_seq_length,
num_batches=total_batches_needed,
dataset_order=(
"sequential"
if getattr(args, "preserve_dataset_order", False)
else getattr(args, "dataset_order", "default")
),
seed=getattr(args, "seed", None),
append_eos=bool(getattr(args, "append_eos", True)),
)
return batches, None

if is_vlm:
_vlm_mask_fn = getattr(self, '_vlm_response_mask_fn', None)
vlm_dataset_order = (
Expand Down Expand Up @@ -2292,6 +2351,36 @@ def save_model(self, output_dir=None):
save_merged_model(self.model, self.tokenizer, output_dir)


class MLXORPOTrainer(MLXTrainer):
    """ORPO flavour of ``MLXTrainer`` (counterpart of TRL's ORPOTrainer).

    The trainer class, not the config, decides the loss: whatever config is
    supplied ends up with ``loss_type == "orpo"``.
    """

    def __init__(self, model, tokenizer, train_dataset, eval_dataset=None,
                 dataset_text_field=None, max_seq_length=None, packing=None,
                 data_collator=None, args=None, formatting_func=None,
                 processor=None):
        """Normalise ``args`` to an ORPO config, then defer to ``MLXTrainer``.

        When ``args`` is omitted a default ``MLXORPOConfig`` is created. A
        caller-supplied config whose ``loss_type`` is anything other than
        ``"orpo"`` is overwritten IN PLACE (the caller's object is mutated).
        All remaining parameters are forwarded positionally, unchanged.
        """
        if args is None:
            args = MLXORPOConfig()
        if getattr(args, "loss_type", "sft") != "orpo":
            # The class is authoritative: override a mismatching config.
            args.loss_type = "orpo"
        super().__init__(
            model, tokenizer, train_dataset, eval_dataset,
            dataset_text_field, max_seq_length, packing,
            data_collator, args, formatting_func, processor,
        )


class MLXDPOTrainer(MLXTrainer):
    """DPO flavour of ``MLXTrainer`` (counterpart of TRL's DPOTrainer).

    The trainer class, not the config, decides the loss: whatever config is
    supplied ends up with ``loss_type == "dpo"``.
    """

    def __init__(self, model, tokenizer, train_dataset, eval_dataset=None,
                 dataset_text_field=None, max_seq_length=None, packing=None,
                 data_collator=None, args=None, formatting_func=None,
                 processor=None):
        """Normalise ``args`` to a DPO config, then defer to ``MLXTrainer``.

        When ``args`` is omitted a default ``MLXDPOConfig`` is created. A
        caller-supplied config whose ``loss_type`` is anything other than
        ``"dpo"`` is overwritten IN PLACE (the caller's object is mutated).
        All remaining parameters are forwarded positionally, unchanged.
        """
        if args is None:
            args = MLXDPOConfig()
        if getattr(args, "loss_type", "sft") != "dpo":
            # The class is authoritative: override a mismatching config.
            args.loss_type = "dpo"
        super().__init__(
            model, tokenizer, train_dataset, eval_dataset,
            dataset_text_field, max_seq_length, packing,
            data_collator, args, formatting_func, processor,
        )


def _create_labeled_batches(dataset, tokenizer, mask_fn, batch_size,
max_seq_length, formatting_func=None,
dataset_text_field="text", num_batches=None,
Expand Down
211 changes: 211 additions & 0 deletions unsloth_zoo/mlx/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -502,6 +502,217 @@ def loss_fn(model, batch, lengths, labels=None):
return loss_fn


# ============================================================================
# ORPO (Odds Ratio Preference Optimization) — text models.
# Mirrors TRL/CUDA's concatenated-forward: chosen and rejected are stacked into
# one batch (chosen block then rejected block) and run through a single forward,
# then split. Loss follows the ORPO paper: L = L_SFT + beta * L_OR, where
# L_OR = -log(sigmoid(log_odds_chosen - log_odds_rejected)) and the odds use the
# full p/(1-p) form (not a simplified log-prob ratio). Built on this module's
# own length-mask convention (see make_baseline_loss_fn), not ported from TRL
# (no TRL on MLX) or copied from third-party MLX projects.
# ============================================================================
def make_orpo_loss_fn(beta=0.1):
"""Create an ORPO loss function over a concatenated [chosen; rejected] batch.

Signature matches make_baseline_loss_fn: (model, batch, lengths, labels=None)
-> (loss, ntoks), so it drops into the trainer's text step path unchanged.

Expects ``batch`` shape (2B, L) where rows [0:B] are chosen and rows [B:2B]
are rejected, paired by index (produced by ``create_preference_batches``).
``lengths`` is (2B, 2) with per-row [response_start, seq_end). The odds-ratio
term scores response tokens only; the SFT/NLL term (TRL chosen_nll_loss) is
computed over the full prompt+response span (all non-pad chosen tokens).
"""
def loss_fn(model, batch, lengths, labels=None):
inputs = batch[:, :-1]
targets = batch[:, 1:]
logits = model(inputs)
steps = mx.arange(1, targets.shape[1] + 1)
ce_tok = nn.losses.cross_entropy(logits, targets)
mask = mx.logical_and(
steps >= lengths[:, 0:1], steps < lengths[:, 1:]
).astype(mx.float32)
logp_tok = -ce_tok * mask
ntok_row = mask.sum(axis=1)
logp = logp_tok.sum(axis=1) / mx.maximum(ntok_row, mx.array(1.0))
B = batch.shape[0] // 2
logp_c, logp_r = logp[:B], logp[B:]
# SFT term: TRL chosen_nll_loss. Pooled token-mean cross-entropy over the
# full prompt+response span (all non-pad chosen tokens), matching TRL's
# nn.CrossEntropyLoss default reduction. This is NOT the response-only,
# length-normalized logp used for the odds-ratio term below.
nll_mask = (steps < lengths[:, 1:]).astype(mx.float32)[:B]
sft = (ce_tok[:B] * nll_mask).sum() / mx.maximum(nll_mask.sum(), mx.array(1.0))
# Odds-ratio term. log(p/(1-p)) per side; 1-exp(logp) stabilized.
val_c = mx.maximum(-mx.expm1(logp_c), mx.array(1e-12, logp_c.dtype))
val_r = mx.maximum(-mx.expm1(logp_r), mx.array(1e-12, logp_r.dtype))
log_odds = (logp_c - mx.log(val_c)) - (logp_r - mx.log(val_r))
or_loss = -mx.mean(nn.log_sigmoid(log_odds))
loss = sft + beta * or_loss
Comment on lines +547 to +552

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

In low-precision training (such as float16 or bfloat16), computing odds-ratio terms directly can lead to numerical instability. Specifically, the constant 1e-12 underflows to 0.0 in float16 (where the minimum positive subnormal is 5.96e-8). If logp_c is 0.0 (e.g., if a token has a probability of 1.0), -mx.expm1(logp_c) becomes 0.0, and mx.maximum(0.0, 1e-12) will evaluate to 0.0 due to underflow. This results in mx.log(0.0) = -inf, causing NaN gradients and training crashes.

To prevent this, perform the odds-ratio calculations in float32 and cast the final or_loss back to the original dtype.

Suggested change
# Odds-ratio term. log(p/(1-p)) per side; 1-exp(logp) stabilized.
val_c = mx.maximum(-mx.expm1(logp_c), mx.array(1e-12, logp_c.dtype))
val_r = mx.maximum(-mx.expm1(logp_r), mx.array(1e-12, logp_r.dtype))
log_odds = (logp_c - mx.log(val_c)) - (logp_r - mx.log(val_r))
or_loss = -mx.mean(nn.log_sigmoid(log_odds))
loss = sft + beta * or_loss
# Odds-ratio term. log(p/(1-p)) per side; 1-exp(logp) stabilized in float32.
logp_c_f32 = logp_c.astype(mx.float32)
logp_r_f32 = logp_r.astype(mx.float32)
val_c = mx.maximum(-mx.expm1(logp_c_f32), mx.array(1e-12, dtype=mx.float32))
val_r = mx.maximum(-mx.expm1(logp_r_f32), mx.array(1e-12, dtype=mx.float32))
log_odds = (logp_c_f32 - mx.log(val_c)) - (logp_r_f32 - mx.log(val_r))
or_loss = -mx.mean(nn.log_sigmoid(log_odds))
loss = sft + beta * or_loss.astype(sft.dtype)

return loss, mask.sum()
return loss_fn


def make_dpo_loss_fn(beta=0.1, lora_mods=None, reference_free=False):
    """Build a DPO (Direct Preference Optimization) loss closure.

    The returned callable has the signature
    ``(model, batch, lengths, labels=None) -> (loss, ntoks)``.

    ``batch`` is a stacked preference batch: the first half of the rows are
    the chosen sequences and the second half the rejected ones, paired by
    index. ``lengths`` holds, per row, ``[response_start, seq_end)``; only
    target positions inside that span contribute to a row's summed log-prob.

    Reference log-probs come from the same model with its LoRA adapters
    switched off: each module in ``lora_mods`` has its ``scale`` set to 0.0
    for the reference forward pass and restored afterwards (also on error).
    The reference result is wrapped in ``stop_gradient``. When
    ``reference_free`` is true the reference term is replaced by zeros.

    Args:
        beta: Multiplier applied to the policy-vs-reference margin.
        lora_mods: Iterable of modules exposing a ``scale`` attribute that is
            toggled to obtain the reference; ``None`` means no adapters.
        reference_free: Skip the reference forward pass entirely.

    Raises:
        ValueError: if no adapter modules are supplied and ``reference_free``
            is false, since no reference could be computed.
    """
    adapters = [] if lora_mods is None else list(lora_mods)
    if not (adapters or reference_free):
        raise ValueError(
            "Unsloth: DPO with a reference model is not yet supported for full "
            "fine-tuning on MLX — the reference is obtained by disabling LoRA "
            "adapters, but this model has none. Use a LoRA/PEFT model, or pass "
            "reference_free=True to train without a reference (TRL-style "
            "reference-free DPO)."
        )

    def _sequence_logps(model, batch, lengths):
        # Next-token setup: position t of the input predicts token t + 1.
        targets = batch[:, 1:]
        logits = model(batch[:, :-1])
        # Target index i corresponds to original token position i + 1.
        positions = mx.arange(1, targets.shape[1] + 1)
        in_span = mx.logical_and(
            positions >= lengths[:, 0:1], positions < lengths[:, 1:]
        ).astype(mx.float32)
        token_logps = -nn.losses.cross_entropy(logits, targets) * in_span
        # Per-row summed log-prob, plus the total number of scored tokens.
        return token_logps.sum(axis=1), in_span.sum()

    def _reference_logps(model, batch, lengths):
        # Adapters off -> reference forward -> adapters back on, even on error.
        original_scales = [mod.scale for mod in adapters]
        try:
            for mod in adapters:
                mod.scale = 0.0
            ref_logps, _ = _sequence_logps(model, batch, lengths)
            return mx.stop_gradient(ref_logps)
        finally:
            for mod, scale in zip(adapters, original_scales):
                mod.scale = scale

    def loss_fn(model, batch, lengths, labels=None):
        half = batch.shape[0] // 2
        policy, ntoks = _sequence_logps(model, batch, lengths)
        policy_chosen, policy_rejected = policy[:half], policy[half:]

        if reference_free or not adapters:
            ref_chosen = mx.zeros(policy_chosen.shape)
            ref_rejected = mx.zeros(policy_rejected.shape)
        else:
            reference = _reference_logps(model, batch, lengths)
            ref_chosen, ref_rejected = reference[:half], reference[half:]

        margin = beta * (
            (policy_chosen - ref_chosen) - (policy_rejected - ref_rejected)
        )
        return -mx.mean(nn.log_sigmoid(margin)), ntoks

    return loss_fn

def create_preference_batches(dataset, tokenizer, batch_size, max_seq_length,
prompt_key="prompt", chosen_key="chosen",
rejected_key="rejected", pad_to_multiple=32,
num_batches=None, dataset_order="default",
preserve_dataset_order=False, seed=None,
append_eos=True):
"""Build concatenated [chosen; rejected] preference batches for ORPO/DPO.

Each example contributes ``prompt + chosen`` and ``prompt + rejected``.
Pairs are grouped into batches of ``batch_size`` PAIRS, and every row in a
batch is padded to that batch's max length, rounded up to
``pad_to_multiple`` (Apple-Silicon padding).

``dataset_order`` controls how pairs are ordered before batching (mirrors
the SFT/VLM builders so preference runs can match CUDA/TRL parity):
"default" length-sort by ``max(len(chosen), len(rejected))`` — least
padding / best throughput; the historical behavior.
"sequential" keep dataset order — matches CUDA ``SequentialSampler``.
"torch_randperm" seeded permutation — matches CUDA ``RandomSampler``.
``preserve_dataset_order=True`` forces "sequential" (Studio wiring).
``seed`` seeds the "torch_randperm" order.

``append_eos`` (default True, matching TRL) appends the tokenizer's EOS id
to each chosen/rejected completion, guarded to avoid a double EOS (mirrors
TRL's ``add_eos_token_if_needed``). EOS is appended before truncation to
``max_seq_length`` (same contract as the SFT path). Because EOS lands after
the prompt boundary it falls inside the ``[response_start, seq_end)`` loss
span, so it is trained on — as in TRL.

Returns a list of ``(batch, lengths, None)`` tuples:
batch: (2B, L) int32 — rows [0:B] chosen, [B:2B] rejected, paired by index
lengths: (2B, 2) — per row [response_start, seq_end)
"""
order_mode = "sequential" if preserve_dataset_order else dataset_order
if order_mode not in ("default", "sequential", "torch_randperm"):
raise ValueError(
f"Unsloth MLX: unsupported preference dataset_order={order_mode!r}. "
"Expected 'default', 'sequential', or 'torch_randperm'."
)
hf = getattr(tokenizer, "_tokenizer", tokenizer)
eos_id = hf.eos_token_id
pad_id = eos_id if eos_id is not None else 0

def _with_eos(ids):
# TRL appends EOS to the completion, avoiding a double EOS
# (add_eos_token_if_needed). Append before max_seq_length truncation,
# matching the SFT path's append-then-truncate contract.
if append_eos and eos_id is not None and (not ids or ids[-1] != eos_id):
ids = ids + [eos_id]
return ids[:max_seq_length]

rows = []
for ex in dataset:
if prompt_key not in ex or chosen_key not in ex or rejected_key not in ex:
raise ValueError(
f"ORPO requires '{prompt_key}', '{chosen_key}', '{rejected_key}' "
f"columns; got {sorted(ex.keys())}."
)
prompt = ex[prompt_key]
p_ids = hf.encode(prompt)
c_ids = _with_eos(hf.encode(prompt + ex[chosen_key]))
r_ids = _with_eos(hf.encode(prompt + ex[rejected_key]))
pe = min(len(p_ids), len(c_ids), len(r_ids))
rows.append((pe, c_ids, r_ids))
Comment on lines +673 to +684

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

The current implementation of create_preference_batches directly encodes prompt and prompt + ex[chosen_key] as raw strings. This has two major issues:

  1. Broken Chat Templates / Control Tokens: It bypasses the model's chat template. For chat/instruction models, prompts and responses must be formatted with control tokens (e.g., <|im_start|>, [INST], etc.). Without this, the model trains on raw text and will fail to learn the correct chat format.
  2. Type Errors with Message Lists: Many preference datasets (e.g., UltraFeedback) store prompt and chosen/rejected as lists of message dicts (conversations). Directly doing prompt + ex[chosen_key] will raise a TypeError when trying to concatenate a list and a string, or when passing a list of dicts to hf.encode.

To fix this, we should normalize the inputs into message lists, apply the tokenizer's chat template (if available) to format them with control tokens, and then encode the resulting formatted strings.

    for ex in dataset:
        if prompt_key not in ex or chosen_key not in ex or rejected_key not in ex:
            raise ValueError(
                f"ORPO requires '{prompt_key}', '{chosen_key}', '{rejected_key}' "
                f"columns; got {sorted(ex.keys())}."
            )
        prompt = ex[prompt_key]
        chosen = ex[chosen_key]
        rejected = ex[rejected_key]

        if isinstance(prompt, str):
            prompt_messages = [{"role": "user", "content": prompt}]
        else:
            prompt_messages = prompt

        if isinstance(chosen, str):
            chosen_messages = [{"role": "assistant", "content": chosen}]
        else:
            chosen_messages = chosen

        if isinstance(rejected, str):
            rejected_messages = [{"role": "assistant", "content": rejected}]
        else:
            rejected_messages = rejected

        apply_tmpl = getattr(tokenizer, "apply_chat_template", None) or getattr(hf, "apply_chat_template", None)
        if apply_tmpl is not None:
            prompt_str = apply_tmpl(prompt_messages, tokenize=False, add_generation_prompt=True)
            chosen_str = apply_tmpl(prompt_messages + chosen_messages, tokenize=False)
            rejected_str = apply_tmpl(prompt_messages + rejected_messages, tokenize=False)
        else:
            prompt_str = prompt if isinstance(prompt, str) else ""
            chosen_str = prompt_str + (chosen if isinstance(chosen, str) else "")
            rejected_str = prompt_str + (rejected if isinstance(rejected, str) else "")

        p_ids = hf.encode(prompt_str)
        c_ids = hf.encode(chosen_str)[:max_seq_length]
        r_ids = hf.encode(rejected_str)[:max_seq_length]
        pe = min(len(p_ids), len(c_ids), len(r_ids))
        rows.append((pe, c_ids, r_ids))


# rows are collected in dataset order above; reorder per the requested mode.
if order_mode == "default":
rows.sort(key=lambda t: max(len(t[1]), len(t[2])))
elif order_mode == "torch_randperm":
order = _torch_randperm_order(len(rows), _normalize_seed(seed))
rows = [rows[i] for i in order]
# "sequential": leave rows in dataset order (CUDA SequentialSampler parity).

out = []
for i in range(0, len(rows), batch_size):
chunk = rows[i:i + batch_size]
Lmax = max(max(len(c), len(r)) for _, c, r in chunk)
if pad_to_multiple:
Lmax = ((Lmax + pad_to_multiple - 1) // pad_to_multiple) * pad_to_multiple
chosen_rows, rejected_rows, lengths = [], [], []
for pe, c, r in chunk:
chosen_rows.append(c + [pad_id] * (Lmax - len(c)))
lengths.append([pe, len(c)])
for pe, c, r in chunk:
rejected_rows.append(r + [pad_id] * (Lmax - len(r)))
lengths.append([pe, len(r)])
batch = mx.array(chosen_rows + rejected_rows, dtype=mx.int32)
lengths_arr = mx.array(lengths)
out.append((batch, lengths_arr, None))
if num_batches is not None and len(out) >= num_batches:
break
mx.eval([b for b, _, _ in out] + [l for _, l, _ in out])
return out


def make_baseline_loss_fn():
"""Create a standard cross-entropy loss function (full logits via LM head).

Expand Down
Loading