From 599f0512176daa21b308d84a5f99e7526dbc8131 Mon Sep 17 00:00:00 2001 From: Bardia Koopah Date: Wed, 24 Jun 2026 12:28:19 -0700 Subject: [PATCH 1/7] Add ORPO (loss_type='orpo') for text models to MLXTrainer --- unsloth_zoo/mlx/trainer.py | 29 ++++++++++- unsloth_zoo/mlx/utils.py | 102 +++++++++++++++++++++++++++++++++++++ 2 files changed, 129 insertions(+), 2 deletions(-) diff --git a/unsloth_zoo/mlx/trainer.py b/unsloth_zoo/mlx/trainer.py index 70572bd51..ba4ceba49 100644 --- a/unsloth_zoo/mlx/trainer.py +++ b/unsloth_zoo/mlx/trainer.py @@ -58,6 +58,8 @@ make_vlm_cce_loss_fn, make_vlm_baseline_loss_fn, create_batches, + create_orpo_batches, + make_orpo_loss_fn, create_ordered_batches, iterate_training_batches, create_vlm_batches, @@ -502,6 +504,8 @@ class MLXTrainingConfig: # Eval eval_steps: int = 0 # 0 = disabled + loss_type: str = "sft" # "sft" or "orpo" + orpo_beta: float = 0.1 # ORPO odds-ratio weight (TRL default) # SFT-specific (from SFTConfig, for API compat) dataset_text_field: str = "text" @@ -1226,7 +1230,11 @@ def _train_inner(self): ) print("Unsloth: Using VLM standard cross-entropy loss.") else: - if use_cce: + if getattr(args, "loss_type", "sft") == "orpo": + _ob = getattr(args, "orpo_beta", 0.1) + loss_fn = make_orpo_loss_fn(beta=_ob) + print("Unsloth: Using ORPO loss (beta=" + str(_ob) + ").") + elif use_cce: loss_fn = make_cce_loss_fn(model) cce_backend = getattr(loss_fn, "_unsloth_cce_backend", "unknown") print(f"Unsloth: Using CCE loss ({cce_backend}) for memory-efficient training.") @@ -1618,7 +1626,10 @@ def step_fn(batch_data, prev_state, do_update): # Prepare eval batches eval_batches = None text_completion_only_loss = _text_completion_only_loss_arg(args) - if args.eval_steps > 0 and self.eval_dataset is not None: + if (getattr(args, "loss_type", "sft") == "orpo" + and args.eval_steps > 0 and self.eval_dataset is not None): + print("Unsloth: eval is not yet supported for ORPO; skipping eval.") + elif args.eval_steps > 0 and self.eval_dataset is not None: # Use pre-built labeled eval batches if available _labeled_eval = getattr(self, '_eval_batches_labeled', None) if _labeled_eval is not None: @@ -2048,6 +2059,20 @@ def _prepare_data(self, is_vlm): ) text_completion_only_loss = _text_completion_only_loss_arg(args) + if getattr(args, "loss_type", "sft") == "orpo": + if is_vlm: + raise ValueError( + "ORPO is not yet supported for VLM models on MLX." + ) + batches = create_orpo_batches( + dataset=self.train_dataset, + tokenizer=self.tokenizer, + batch_size=args.per_device_train_batch_size, + max_seq_length=args.max_seq_length, + num_batches=total_batches_needed, + ) + return batches, None + if is_vlm: _vlm_mask_fn = getattr(self, '_vlm_response_mask_fn', None) vlm_dataset_order = ( diff --git a/unsloth_zoo/mlx/utils.py b/unsloth_zoo/mlx/utils.py index a079c5af6..a4b6f99ba 100644 --- a/unsloth_zoo/mlx/utils.py +++ b/unsloth_zoo/mlx/utils.py @@ -502,6 +502,108 @@ def loss_fn(model, batch, lengths, labels=None): return loss_fn +# ============================================================================ +# ORPO (Odds Ratio Preference Optimization) — text models. +# Mirrors TRL/CUDA's concatenated-forward: chosen and rejected are stacked into +# one batch (chosen block then rejected block) and run through a single forward, +# then split. Loss follows the ORPO paper: L = L_SFT + beta * L_OR, where +# L_OR = -log(sigmoid(log_odds_chosen - log_odds_rejected)) and the odds use the +# full p/(1-p) form (not a simplified log-prob ratio). Built on this module's +# own length-mask convention (see make_baseline_loss_fn), not ported from TRL +# (no TRL on MLX) or copied from third-party MLX projects. +# ============================================================================ +def make_orpo_loss_fn(beta=0.1): + """Create an ORPO loss function over a concatenated [chosen; rejected] batch. + + Signature matches make_baseline_loss_fn: (model, batch, lengths, labels=None) + -> (loss, ntoks), so it drops into the trainer's text step path unchanged. + + Expects ``batch`` shape (2B, L) where rows [0:B] are chosen and rows [B:2B] + are rejected, paired by index (produced by ``create_orpo_batches``). + ``lengths`` is (2B, 2) with per-row [response_start, seq_end); only response + tokens are scored (the prompt is masked). + """ + def loss_fn(model, batch, lengths, labels=None): + inputs = batch[:, :-1] + targets = batch[:, 1:] + logits = model(inputs) + steps = mx.arange(1, targets.shape[1] + 1) + mask = mx.logical_and( + steps >= lengths[:, 0:1], steps < lengths[:, 1:] + ).astype(mx.float32) + logp_tok = -nn.losses.cross_entropy(logits, targets) * mask + ntok_row = mask.sum(axis=1) + logp = logp_tok.sum(axis=1) / mx.maximum(ntok_row, mx.array(1.0)) + B = batch.shape[0] // 2 + logp_c, logp_r = logp[:B], logp[B:] + # SFT term: NLL on chosen (length-normalized). + sft = -mx.mean(logp_c) + # Odds-ratio term. log(p/(1-p)) per side; 1-exp(logp) stabilized. + val_c = mx.maximum(-mx.expm1(logp_c), mx.array(1e-12, logp_c.dtype)) + val_r = mx.maximum(-mx.expm1(logp_r), mx.array(1e-12, logp_r.dtype)) + log_odds = (logp_c - mx.log(val_c)) - (logp_r - mx.log(val_r)) + or_loss = -mx.mean(nn.log_sigmoid(log_odds)) + loss = sft + beta * or_loss + return loss, mask.sum() + return loss_fn + + +def create_orpo_batches(dataset, tokenizer, batch_size, max_seq_length, + prompt_key="prompt", chosen_key="chosen", + rejected_key="rejected", pad_to_multiple=32, + num_batches=None): + """Build concatenated [chosen; rejected] preference batches for ORPO. + + Each example contributes ``prompt + chosen`` and ``prompt + rejected``. + Pairs are length-sorted (by max of the two), grouped into batches of + ``batch_size`` PAIRS, and every row in a batch is padded to that batch's + max length, rounded up to ``pad_to_multiple`` (Apple-Silicon padding). + + Returns a list of ``(batch, lengths, None)`` tuples: + batch: (2B, L) int32 — rows [0:B] chosen, [B:2B] rejected, paired by index + lengths: (2B, 2) — per row [response_start, seq_end) + """ + hf = getattr(tokenizer, "_tokenizer", tokenizer) + pad_id = hf.eos_token_id if hf.eos_token_id is not None else 0 + + rows = [] + for ex in dataset: + if prompt_key not in ex or chosen_key not in ex or rejected_key not in ex: + raise ValueError( + f"ORPO requires '{prompt_key}', '{chosen_key}', '{rejected_key}' " + f"columns; got {sorted(ex.keys())}." + ) + prompt = ex[prompt_key] + p_ids = hf.encode(prompt) + c_ids = hf.encode(prompt + ex[chosen_key])[:max_seq_length] + r_ids = hf.encode(prompt + ex[rejected_key])[:max_seq_length] + pe = min(len(p_ids), len(c_ids), len(r_ids)) + rows.append((pe, c_ids, r_ids)) + + rows.sort(key=lambda t: max(len(t[1]), len(t[2]))) + + out = [] + for i in range(0, len(rows), batch_size): + chunk = rows[i:i + batch_size] + Lmax = max(max(len(c), len(r)) for _, c, r in chunk) + if pad_to_multiple: + Lmax = ((Lmax + pad_to_multiple - 1) // pad_to_multiple) * pad_to_multiple + chosen_rows, rejected_rows, lengths = [], [], [] + for pe, c, r in chunk: + chosen_rows.append(c + [pad_id] * (Lmax - len(c))) + lengths.append([pe, len(c)]) + for pe, c, r in chunk: + rejected_rows.append(r + [pad_id] * (Lmax - len(r))) + lengths.append([pe, len(r)]) + batch = mx.array(chosen_rows + rejected_rows, dtype=mx.int32) + lengths_arr = mx.array(lengths) + out.append((batch, lengths_arr, None)) + if num_batches is not None and len(out) >= num_batches: + break + mx.eval([b for b, _, _ in out] + [l for _, l, _ in out]) + return out + + def make_baseline_loss_fn(): """Create a standard cross-entropy loss function (full logits via LM head). From 2e3c22e9ffd1ed635dd175f08d695a319e4f68da Mon Sep 17 00:00:00 2001 From: Bardia Koopah Date: Wed, 24 Jun 2026 13:10:46 -0700 Subject: [PATCH 2/7] Add DPO (loss_type='dpo') with live LoRA-disable reference; share preference collator with ORPO --- unsloth_zoo/mlx/trainer.py | 25 +++++++++++---- unsloth_zoo/mlx/utils.py | 62 ++++++++++++++++++++++++++++++++++++-- 2 files changed, 78 insertions(+), 9 deletions(-) diff --git a/unsloth_zoo/mlx/trainer.py b/unsloth_zoo/mlx/trainer.py index ba4ceba49..2a3e817c3 100644 --- a/unsloth_zoo/mlx/trainer.py +++ b/unsloth_zoo/mlx/trainer.py @@ -47,6 +47,7 @@ import mlx.nn as nn import mlx.optimizers as optim from mlx.utils import tree_flatten, tree_map, tree_reduce, tree_unflatten +from mlx_lm.tuner.lora import LoRALinear _PAD_MULTIPLE = 32 SUPPORTED_MLX_OPTIMIZERS = ("adafactor", "adamw", "adam", "sgd", "muon", "lion") @@ -58,8 +59,9 @@ make_vlm_cce_loss_fn, make_vlm_baseline_loss_fn, create_batches, - create_orpo_batches, + create_preference_batches, make_orpo_loss_fn, + make_dpo_loss_fn, create_ordered_batches, iterate_training_batches, create_vlm_batches, @@ -506,6 +508,8 @@ class MLXTrainingConfig: eval_steps: int = 0 # 0 = disabled loss_type: str = "sft" # "sft" or "orpo" orpo_beta: float = 0.1 # ORPO odds-ratio weight (TRL default) + dpo_beta: float = 0.1 # DPO beta (TRL default) + reference_free: bool = False # DPO: drop the reference term if True # SFT-specific (from SFTConfig, for API compat) dataset_text_field: str = "text" @@ -1234,6 +1238,15 @@ def _train_inner(self): _ob = getattr(args, "orpo_beta", 0.1) loss_fn = make_orpo_loss_fn(beta=_ob) print("Unsloth: Using ORPO loss (beta=" + str(_ob) + ").") + elif getattr(args, "loss_type", "sft") == "dpo": + _db = getattr(args, "dpo_beta", 0.1) + _rf = bool(getattr(args, "reference_free", False)) + _lora_mods = [mod for _, mod in tree_flatten( + model, is_leaf=lambda x: isinstance(x, LoRALinear)) + if isinstance(mod, LoRALinear)] + loss_fn = make_dpo_loss_fn(beta=_db, lora_mods=_lora_mods, reference_free=_rf) + print("Unsloth: Using DPO loss (beta=" + str(_db) + + (", reference_free" if _rf else "") + ").") elif use_cce: loss_fn = make_cce_loss_fn(model) cce_backend = getattr(loss_fn, "_unsloth_cce_backend", "unknown") @@ -1626,9 +1639,9 @@ def step_fn(batch_data, prev_state, do_update): # Prepare eval batches eval_batches = None text_completion_only_loss = _text_completion_only_loss_arg(args) - if (getattr(args, "loss_type", "sft") == "orpo" + if (getattr(args, "loss_type", "sft") in ("orpo", "dpo") and args.eval_steps > 0 and self.eval_dataset is not None): - print("Unsloth: eval is not yet supported for ORPO; skipping eval.") + print(f"Unsloth: eval is not yet supported for {args.loss_type}; skipping eval.") elif args.eval_steps > 0 and self.eval_dataset is not None: # Use pre-built labeled eval batches if available _labeled_eval = getattr(self, '_eval_batches_labeled', None) @@ -2059,12 +2072,12 @@ def _prepare_data(self, is_vlm): ) text_completion_only_loss = _text_completion_only_loss_arg(args) - if getattr(args, "loss_type", "sft") == "orpo": + if getattr(args, "loss_type", "sft") in ("orpo", "dpo"): if is_vlm: raise ValueError( - "ORPO is not yet supported for VLM models on MLX." + f"{args.loss_type.upper()} is not yet supported for VLM models on MLX." ) - batches = create_orpo_batches( + batches = create_preference_batches( dataset=self.train_dataset, tokenizer=self.tokenizer, batch_size=args.per_device_train_batch_size, diff --git a/unsloth_zoo/mlx/utils.py b/unsloth_zoo/mlx/utils.py index a4b6f99ba..5e134c0de 100644 --- a/unsloth_zoo/mlx/utils.py +++ b/unsloth_zoo/mlx/utils.py @@ -519,7 +519,7 @@ def make_orpo_loss_fn(beta=0.1): -> (loss, ntoks), so it drops into the trainer's text step path unchanged. Expects ``batch`` shape (2B, L) where rows [0:B] are chosen and rows [B:2B] - are rejected, paired by index (produced by ``create_orpo_batches``). + are rejected, paired by index (produced by ``create_preference_batches``). ``lengths`` is (2B, 2) with per-row [response_start, seq_end); only response tokens are scored (the prompt is masked). """ @@ -548,11 +548,67 @@ def loss_fn(model, batch, lengths, labels=None): return loss_fn -def create_orpo_batches(dataset, tokenizer, batch_size, max_seq_length, +def make_dpo_loss_fn(beta=0.1, lora_mods=None, reference_free=False): + """Create a DPO (Direct Preference Optimization) loss function. + + Operates on a concatenated [chosen; rejected] batch (same layout as + make_orpo_loss_fn / create_preference_batches): rows [0:B] chosen, + [B:2B] rejected. Signature matches the other loss fns: + (model, batch, lengths, labels=None) -> (loss, ntoks). + + DPO compares the policy's per-response log-probs against a frozen + reference. For LoRA models the reference is the base model: obtained by + temporarily zeroing the LoRA scales (adapters off), running the reference + forward under stop_gradient, then restoring the scales in a finally block. + Mirrors TRL's disable-adapter approach (no second model copy). With + reference_free=True the reference term is dropped, matching TRL. + + ``lora_mods`` is the list of LoRALinear modules to toggle; collected once + by the trainer at setup. + """ + _mods = list(lora_mods) if lora_mods is not None else [] + + def _row_logp_and_mask(model, batch, lengths): + inputs = batch[:, :-1] + targets = batch[:, 1:] + logits = model(inputs) + steps = mx.arange(1, targets.shape[1] + 1) + mask = mx.logical_and( + steps >= lengths[:, 0:1], steps < lengths[:, 1:] + ).astype(mx.float32) + logp_tok = -nn.losses.cross_entropy(logits, targets) * mask + return logp_tok.sum(axis=1), mask.sum() + + def loss_fn(model, batch, lengths, labels=None): + B = batch.shape[0] // 2 + pol, ntoks = _row_logp_and_mask(model, batch, lengths) + pol_c, pol_r = pol[:B], pol[B:] + + if reference_free or not _mods: + ref_c = mx.zeros(pol_c.shape) + ref_r = mx.zeros(pol_r.shape) + else: + saved = [md.scale for md in _mods] + try: + for md in _mods: + md.scale = 0.0 + ref, _ = _row_logp_and_mask(model, batch, lengths) + ref = mx.stop_gradient(ref) + finally: + for md, s in zip(_mods, saved): + md.scale = s + ref_c, ref_r = ref[:B], ref[B:] + + logits = beta * ((pol_c - ref_c) - (pol_r - ref_r)) + loss = -mx.mean(nn.log_sigmoid(logits)) + return loss, ntoks + return loss_fn + +def create_preference_batches(dataset, tokenizer, batch_size, max_seq_length, prompt_key="prompt", chosen_key="chosen", rejected_key="rejected", pad_to_multiple=32, num_batches=None): - """Build concatenated [chosen; rejected] preference batches for ORPO. + """Build concatenated [chosen; rejected] preference batches for ORPO/DPO. Each example contributes ``prompt + chosen`` and ``prompt + rejected``. Pairs are length-sorted (by max of the two), grouped into batches of From 89108ff06c34222b45a5c0933e50470f29f66b05 Mon Sep 17 00:00:00 2001 From: Bardia Koopah Date: Thu, 25 Jun 2026 11:35:02 -0700 Subject: [PATCH 3/7] Add MLXDPOTrainer/MLXORPOTrainer + configs (TRL-style API) over loss_type; flag kept for back-compat --- unsloth_zoo/mlx/trainer.py | 44 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 44 insertions(+) diff --git a/unsloth_zoo/mlx/trainer.py b/unsloth_zoo/mlx/trainer.py index 2a3e817c3..7c451db7e 100644 --- a/unsloth_zoo/mlx/trainer.py +++ b/unsloth_zoo/mlx/trainer.py @@ -545,6 +545,20 @@ class MLXTrainingConfig: vlm_chat_template: object = None # Unsloth template name/tuple or raw Jinja string +@dataclass +class MLXORPOConfig(MLXTrainingConfig): + """ORPO config mirroring TRL's ORPOConfig. Presets loss_type='orpo'; + tune orpo_beta (inherited). Use with MLXORPOTrainer.""" + loss_type: str = "orpo" + + +@dataclass +class MLXDPOConfig(MLXTrainingConfig): + """DPO config mirroring TRL's DPOConfig. Presets loss_type='dpo'; + tune dpo_beta / reference_free (inherited). Use with MLXDPOTrainer.""" + loss_type: str = "dpo" + + class MLXTrainer: """MLX-native trainer for Apple Silicon, mirroring SFTTrainer's constructor API.""" @@ -2330,6 +2344,36 @@ def save_model(self, output_dir=None): save_merged_model(self.model, self.tokenizer, output_dir) +class MLXORPOTrainer(MLXTrainer): + """ORPO trainer mirroring TRL's ORPOTrainer. Forces loss_type='orpo' so + the class is authoritative regardless of the config passed.""" + def __init__(self, model, tokenizer, train_dataset, eval_dataset=None, + dataset_text_field=None, max_seq_length=None, packing=None, + data_collator=None, args=None, formatting_func=None, processor=None): + if args is None: + args = MLXORPOConfig() + elif getattr(args, "loss_type", "sft") != "orpo": + args.loss_type = "orpo" + super().__init__(model, tokenizer, train_dataset, eval_dataset, + dataset_text_field, max_seq_length, packing, + data_collator, args, formatting_func, processor) + + +class MLXDPOTrainer(MLXTrainer): + """DPO trainer mirroring TRL's DPOTrainer. Forces loss_type='dpo' so + the class is authoritative regardless of the config passed.""" + def __init__(self, model, tokenizer, train_dataset, eval_dataset=None, + dataset_text_field=None, max_seq_length=None, packing=None, + data_collator=None, args=None, formatting_func=None, processor=None): + if args is None: + args = MLXDPOConfig() + elif getattr(args, "loss_type", "sft") != "dpo": + args.loss_type = "dpo" + super().__init__(model, tokenizer, train_dataset, eval_dataset, + dataset_text_field, max_seq_length, packing, + data_collator, args, formatting_func, processor) + + def _create_labeled_batches(dataset, tokenizer, mask_fn, batch_size, max_seq_length, formatting_func=None, dataset_text_field="text", num_batches=None, From 3f1477183c338851ab0655c4a75ee861b59bf0c1 Mon Sep 17 00:00:00 2001 From: Bardia Koopah Date: Tue, 30 Jun 2026 22:39:49 -0700 Subject: [PATCH 4/7] fix(mlx): ORPO SFT/NLL term over prompt+response to match TRL The MLX ORPO loss computed its SFT/NLL term from the response-only, length-normalized log-prob (the same quantity used for the odds-ratio term). TRL's chosen_nll_loss is instead a pooled token-mean cross-entropy over the full prompt+response span (all non-pad chosen tokens, matching nn.CrossEntropyLoss default reduction). The narrower span made the SFT signal trivially easy, so ORPO saturated to ~0 loss on toy data while CUDA did not. Compute the NLL term over prompt+response for the chosen rows; the odds-ratio term (response-only, length-normalized) is unchanged. --- unsloth_zoo/mlx/utils.py | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/unsloth_zoo/mlx/utils.py b/unsloth_zoo/mlx/utils.py index 5e134c0de..1c8fd207d 100644 --- a/unsloth_zoo/mlx/utils.py +++ b/unsloth_zoo/mlx/utils.py @@ -520,24 +520,30 @@ def make_orpo_loss_fn(beta=0.1): Expects ``batch`` shape (2B, L) where rows [0:B] are chosen and rows [B:2B] are rejected, paired by index (produced by ``create_preference_batches``). - ``lengths`` is (2B, 2) with per-row [response_start, seq_end); only response - tokens are scored (the prompt is masked). + ``lengths`` is (2B, 2) with per-row [response_start, seq_end). The odds-ratio + term scores response tokens only; the SFT/NLL term (TRL chosen_nll_loss) is + computed over the full prompt+response span (all non-pad chosen tokens). """ def loss_fn(model, batch, lengths, labels=None): inputs = batch[:, :-1] targets = batch[:, 1:] logits = model(inputs) steps = mx.arange(1, targets.shape[1] + 1) + ce_tok = nn.losses.cross_entropy(logits, targets) mask = mx.logical_and( steps >= lengths[:, 0:1], steps < lengths[:, 1:] ).astype(mx.float32) - logp_tok = -nn.losses.cross_entropy(logits, targets) * mask + logp_tok = -ce_tok * mask ntok_row = mask.sum(axis=1) logp = logp_tok.sum(axis=1) / mx.maximum(ntok_row, mx.array(1.0)) B = batch.shape[0] // 2 logp_c, logp_r = logp[:B], logp[B:] - # SFT term: NLL on chosen (length-normalized). - sft = -mx.mean(logp_c) + # SFT term: TRL chosen_nll_loss. Pooled token-mean cross-entropy over the + # full prompt+response span (all non-pad chosen tokens), matching TRL's + # nn.CrossEntropyLoss default reduction. This is NOT the response-only, + # length-normalized logp used for the odds-ratio term below. + nll_mask = (steps < lengths[:, 1:]).astype(mx.float32)[:B] + sft = (ce_tok[:B] * nll_mask).sum() / mx.maximum(nll_mask.sum(), mx.array(1.0)) # Odds-ratio term. log(p/(1-p)) per side; 1-exp(logp) stabilized. val_c = mx.maximum(-mx.expm1(logp_c), mx.array(1e-12, logp_c.dtype)) val_r = mx.maximum(-mx.expm1(logp_r), mx.array(1e-12, logp_r.dtype)) From 72bb6c3cd0c1e2013357190715e569bfb33511a8 Mon Sep 17 00:00:00 2001 From: Bardia Koopah Date: Thu, 2 Jul 2026 09:50:49 -0700 Subject: [PATCH 5/7] MLX ORPO/DPO: make preference-batch row order configurable MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit create_preference_batches unconditionally length-sorted every pair by max(len(chosen), len(rejected)), which minimizes padding but diverges from CUDA/TRL — HF Trainer feeds preference data via SequentialSampler (dataset order) or a seeded RandomSampler (torch.randperm), never length sorted. Identical-seed MLX and CUDA runs therefore saw examples in different order from step 1, producing different batches and curves. Add dataset_order ("default" | "sequential" | "torch_randperm") plus a preserve_dataset_order shortcut and seed, mirroring the existing SFT/VLM builders, and wire them at the trainer call site. "default" keeps the historical length-sort so existing runs are byte-identical; "sequential" (or preserve_dataset_order=True) reproduces CUDA SequentialSampler order and "torch_randperm" reproduces RandomSampler order for parity testing. --- unsloth_zoo/mlx/trainer.py | 6 ++++++ unsloth_zoo/mlx/utils.py | 32 +++++++++++++++++++++++++++----- 2 files changed, 33 insertions(+), 5 deletions(-) diff --git a/unsloth_zoo/mlx/trainer.py b/unsloth_zoo/mlx/trainer.py index 7c451db7e..c29fa9c7d 100644 --- a/unsloth_zoo/mlx/trainer.py +++ b/unsloth_zoo/mlx/trainer.py @@ -2097,6 +2097,12 @@ def _prepare_data(self, is_vlm): batch_size=args.per_device_train_batch_size, max_seq_length=args.max_seq_length, num_batches=total_batches_needed, + dataset_order=( + "sequential" + if getattr(args, "preserve_dataset_order", False) + else getattr(args, "dataset_order", "default") + ), + seed=getattr(args, "seed", None), ) return batches, None diff --git a/unsloth_zoo/mlx/utils.py b/unsloth_zoo/mlx/utils.py index 1c8fd207d..84db07ec8 100644 --- a/unsloth_zoo/mlx/utils.py +++ b/unsloth_zoo/mlx/utils.py @@ -613,18 +613,34 @@ def loss_fn(model, batch, lengths, labels=None): def create_preference_batches(dataset, tokenizer, batch_size, max_seq_length, prompt_key="prompt", chosen_key="chosen", rejected_key="rejected", pad_to_multiple=32, - num_batches=None): + num_batches=None, dataset_order="default", + preserve_dataset_order=False, seed=None): """Build concatenated [chosen; rejected] preference batches for ORPO/DPO. Each example contributes ``prompt + chosen`` and ``prompt + rejected``. - Pairs are length-sorted (by max of the two), grouped into batches of - ``batch_size`` PAIRS, and every row in a batch is padded to that batch's - max length, rounded up to ``pad_to_multiple`` (Apple-Silicon padding). + Pairs are grouped into batches of ``batch_size`` PAIRS, and every row in a + batch is padded to that batch's max length, rounded up to + ``pad_to_multiple`` (Apple-Silicon padding). + + ``dataset_order`` controls how pairs are ordered before batching (mirrors + the SFT/VLM builders so preference runs can match CUDA/TRL parity): + "default" length-sort by ``max(len(chosen), len(rejected))`` — least + padding / best throughput; the historical behavior. + "sequential" keep dataset order — matches CUDA ``SequentialSampler``. + "torch_randperm" seeded permutation — matches CUDA ``RandomSampler``. + ``preserve_dataset_order=True`` forces "sequential" (Studio wiring). + ``seed`` seeds the "torch_randperm" order. Returns a list of ``(batch, lengths, None)`` tuples: batch: (2B, L) int32 — rows [0:B] chosen, [B:2B] rejected, paired by index lengths: (2B, 2) — per row [response_start, seq_end) """ + order_mode = "sequential" if preserve_dataset_order else dataset_order + if order_mode not in ("default", "sequential", "torch_randperm"): + raise ValueError( + f"Unsloth MLX: unsupported preference dataset_order={order_mode!r}. " + "Expected 'default', 'sequential', or 'torch_randperm'." + ) hf = getattr(tokenizer, "_tokenizer", tokenizer) pad_id = hf.eos_token_id if hf.eos_token_id is not None else 0 @@ -642,7 +658,13 @@ def create_preference_batches(dataset, tokenizer, batch_size, max_seq_length, pe = min(len(p_ids), len(c_ids), len(r_ids)) rows.append((pe, c_ids, r_ids)) - rows.sort(key=lambda t: max(len(t[1]), len(t[2]))) + # rows are collected in dataset order above; reorder per the requested mode. + if order_mode == "default": + rows.sort(key=lambda t: max(len(t[1]), len(t[2]))) + elif order_mode == "torch_randperm": + order = _torch_randperm_order(len(rows), _normalize_seed(seed)) + rows = [rows[i] for i in order] + # "sequential": leave rows in dataset order (CUDA SequentialSampler parity). out = [] for i in range(0, len(rows), batch_size): From 203919e12a3a7ad0a5c64bcb638803a54bd954fb Mon Sep 17 00:00:00 2001 From: Bardia Koopah Date: Thu, 2 Jul 2026 10:09:45 -0700 Subject: [PATCH 6/7] MLX ORPO/DPO: append EOS to preference completions (TRL parity) create_preference_batches encoded prompt+chosen / prompt+rejected with no EOS, while TRL appends the tokenizer's EOS to every completion (DPO tokenize_row unconditionally; ORPO via add_eos_token_if_needed). Missing EOS changes the trained token sequence, logprobs, and loss: the model never has to predict the completion's end token, so the ORPO loss sits systematically below the CUDA/TRL reference. Add an append_eos flag (default True, matching TRL) that appends the EOS id to each chosen/rejected completion, guarded against a double EOS (mirrors add_eos_token_if_needed) and applied before max_seq_length truncation (same append-then-truncate contract as the SFT path). The EOS lands after the prompt boundary, inside the [response_start, seq_end) loss span, so it is trained on as in TRL. Verified: the appended EOS has mask==1.0 at its target position, and the untrained-model ORPO loss over a fixed set rises (+0.256) once completions carry the end token. --- unsloth_zoo/mlx/trainer.py | 1 + unsloth_zoo/mlx/utils.py | 25 +++++++++++++++++++++---- 2 files changed, 22 insertions(+), 4 deletions(-) diff --git a/unsloth_zoo/mlx/trainer.py b/unsloth_zoo/mlx/trainer.py index c29fa9c7d..c7afbb039 100644 --- a/unsloth_zoo/mlx/trainer.py +++ b/unsloth_zoo/mlx/trainer.py @@ -2103,6 +2103,7 @@ def _prepare_data(self, is_vlm): else getattr(args, "dataset_order", "default") ), seed=getattr(args, "seed", None), + append_eos=bool(getattr(args, "append_eos", True)), ) return batches, None diff --git a/unsloth_zoo/mlx/utils.py b/unsloth_zoo/mlx/utils.py index 84db07ec8..ffd7b3ed0 100644 --- a/unsloth_zoo/mlx/utils.py +++ b/unsloth_zoo/mlx/utils.py @@ -614,7 +614,8 @@ def create_preference_batches(dataset, tokenizer, batch_size, max_seq_length, prompt_key="prompt", chosen_key="chosen", rejected_key="rejected", pad_to_multiple=32, num_batches=None, dataset_order="default", - preserve_dataset_order=False, seed=None): + preserve_dataset_order=False, seed=None, + append_eos=True): """Build concatenated [chosen; rejected] preference batches for ORPO/DPO. Each example contributes ``prompt + chosen`` and ``prompt + rejected``. @@ -631,6 +632,13 @@ def create_preference_batches(dataset, tokenizer, batch_size, max_seq_length, ``preserve_dataset_order=True`` forces "sequential" (Studio wiring). ``seed`` seeds the "torch_randperm" order. + ``append_eos`` (default True, matching TRL) appends the tokenizer's EOS id + to each chosen/rejected completion, guarded to avoid a double EOS (mirrors + TRL's ``add_eos_token_if_needed``). EOS is appended before truncation to + ``max_seq_length`` (same contract as the SFT path). Because EOS lands after + the prompt boundary it falls inside the ``[response_start, seq_end)`` loss + span, so it is trained on — as in TRL. + Returns a list of ``(batch, lengths, None)`` tuples: batch: (2B, L) int32 — rows [0:B] chosen, [B:2B] rejected, paired by index lengths: (2B, 2) — per row [response_start, seq_end) @@ -642,7 +650,16 @@ def create_preference_batches(dataset, tokenizer, batch_size, max_seq_length, "Expected 'default', 'sequential', or 'torch_randperm'." ) hf = getattr(tokenizer, "_tokenizer", tokenizer) - pad_id = hf.eos_token_id if hf.eos_token_id is not None else 0 + eos_id = hf.eos_token_id + pad_id = eos_id if eos_id is not None else 0 + + def _with_eos(ids): + # TRL appends EOS to the completion, avoiding a double EOS + # (add_eos_token_if_needed). Append before max_seq_length truncation, + # matching the SFT path's append-then-truncate contract. + if append_eos and eos_id is not None and (not ids or ids[-1] != eos_id): + ids = ids + [eos_id] + return ids[:max_seq_length] rows = [] for ex in dataset: @@ -653,8 +670,8 @@ def create_preference_batches(dataset, tokenizer, batch_size, max_seq_length, ) prompt = ex[prompt_key] p_ids = hf.encode(prompt) - c_ids = hf.encode(prompt + ex[chosen_key])[:max_seq_length] - r_ids = hf.encode(prompt + ex[rejected_key])[:max_seq_length] + c_ids = _with_eos(hf.encode(prompt + ex[chosen_key])) + r_ids = _with_eos(hf.encode(prompt + ex[rejected_key])) pe = min(len(p_ids), len(c_ids), len(r_ids)) rows.append((pe, c_ids, r_ids)) From 7eda39468679dfb3d96e3d5d574dd3c2f38c99b9 Mon Sep 17 00:00:00 2001 From: Bardia Koopah Date: Thu, 2 Jul 2026 10:23:36 -0700 Subject: [PATCH 7/7] fail loud on full-FT DPO reference instead of silent collapse (DPO portion) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit DPO's reference is computed by disabling LoRA adapters. When a model has no LoRA modules (full fine-tuning), the disable-adapter path has nothing to toggle, so DPO silently fell back to reference-free — dropping the reference term and training a different objective than requested, with no warning. Raise a clear ValueError at loss-fn construction when there are no LoRA mods and a reference is actually requested (not reference_free). The message points users to a LoRA/PEFT model or to reference_free=True if they genuinely want reference-free DPO. Legitimate reference-free configs pass through unchanged. This is the DPO portion of the full-FT fail-loud guard; the matching GRPO KL guard lands with the GRPO branch (make_grpo_loss_fn does not exist here). A real full-FT reference (a frozen second model) remains a separate maintainer design decision. --- unsloth_zoo/mlx/utils.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/unsloth_zoo/mlx/utils.py b/unsloth_zoo/mlx/utils.py index ffd7b3ed0..4ff65b2f5 100644 --- a/unsloth_zoo/mlx/utils.py +++ b/unsloth_zoo/mlx/utils.py @@ -573,6 +573,14 @@ def make_dpo_loss_fn(beta=0.1, lora_mods=None, reference_free=False): by the trainer at setup. """ _mods = list(lora_mods) if lora_mods is not None else [] + if not _mods and not reference_free: + raise ValueError( + "Unsloth: DPO with a reference model is not yet supported for full " + "fine-tuning on MLX — the reference is obtained by disabling LoRA " + "adapters, but this model has none. Use a LoRA/PEFT model, or pass " + "reference_free=True to train without a reference (TRL-style " + "reference-free DPO)." + ) def _row_logp_and_mask(model, batch, lengths): inputs = batch[:, :-1]