Feat(mlx): Add ORPO (loss_type='orpo') for text models to MLXTrainer#830
Feat(mlx): Add ORPO (loss_type='orpo') for text models to MLXTrainer#830BardiaKoopah wants to merge 7 commits into
Conversation
|
Codex usage limits have been reached for code reviews. Please check with the admins of this repo to increase the limits by adding credits. |
There was a problem hiding this comment.
Code Review
This pull request introduces support for Odds Ratio Preference Optimization (ORPO) training in the MLX trainer, including configuration options, batch creation, and loss function implementation. The review feedback highlights two critical improvements: first, create_orpo_batches should apply chat templates and support message lists to prevent broken formatting and type errors; second, the odds-ratio calculation in make_orpo_loss_fn should be performed in float32 to avoid numerical underflow and NaN gradients during low-precision training.
Important
The consumer version of Gemini Code Assist on GitHub is being sunset. Starting June 18, 2026, new organization installations will be blocked, and all code review activity will officially cease on July 17, 2026.
For more details on the timeline and next steps, please review the Help Documentation.
| for ex in dataset: | ||
| if prompt_key not in ex or chosen_key not in ex or rejected_key not in ex: | ||
| raise ValueError( | ||
| f"ORPO requires '{prompt_key}', '{chosen_key}', '{rejected_key}' " | ||
| f"columns; got {sorted(ex.keys())}." | ||
| ) | ||
| prompt = ex[prompt_key] | ||
| p_ids = hf.encode(prompt) | ||
| c_ids = hf.encode(prompt + ex[chosen_key])[:max_seq_length] | ||
| r_ids = hf.encode(prompt + ex[rejected_key])[:max_seq_length] | ||
| pe = min(len(p_ids), len(c_ids), len(r_ids)) | ||
| rows.append((pe, c_ids, r_ids)) |
There was a problem hiding this comment.
The current implementation of create_orpo_batches directly encodes prompt and prompt + ex[chosen_key] as raw strings. This has two major issues:
- Broken Chat Templates / Control Tokens: It bypasses the model's chat template. For chat/instruction models, prompts and responses must be formatted with control tokens (e.g., <|im_start|>, [INST], etc.). Without this, the model trains on raw text and will fail to learn the correct chat format.
- Type Errors with Message Lists: Many preference datasets (e.g., UltraFeedback) store prompt and chosen/rejected as lists of message dicts (conversations). Directly doing prompt + ex[chosen_key] will raise a TypeError when trying to concatenate a list and a string, or when passing a list of dicts to hf.encode.
To fix this, we should normalize the inputs into message lists, apply the tokenizer's chat template (if available) to format them with control tokens, and then encode the resulting formatted strings.
for ex in dataset:
if prompt_key not in ex or chosen_key not in ex or rejected_key not in ex:
raise ValueError(
f"ORPO requires '{prompt_key}', '{chosen_key}', '{rejected_key}' "
f"columns; got {sorted(ex.keys())}."
)
prompt = ex[prompt_key]
chosen = ex[chosen_key]
rejected = ex[rejected_key]
if isinstance(prompt, str):
prompt_messages = [{"role": "user", "content": prompt}]
else:
prompt_messages = prompt
if isinstance(chosen, str):
chosen_messages = [{"role": "assistant", "content": chosen}]
else:
chosen_messages = chosen
if isinstance(rejected, str):
rejected_messages = [{"role": "assistant", "content": rejected}]
else:
rejected_messages = rejected
apply_tmpl = getattr(tokenizer, "apply_chat_template", None) or getattr(hf, "apply_chat_template", None)
if apply_tmpl is not None:
prompt_str = apply_tmpl(prompt_messages, tokenize=False, add_generation_prompt=True)
chosen_str = apply_tmpl(prompt_messages + chosen_messages, tokenize=False)
rejected_str = apply_tmpl(prompt_messages + rejected_messages, tokenize=False)
else:
prompt_str = prompt if isinstance(prompt, str) else ""
chosen_str = prompt_str + (chosen if isinstance(chosen, str) else "")
rejected_str = prompt_str + (rejected if isinstance(rejected, str) else "")
p_ids = hf.encode(prompt_str)
c_ids = hf.encode(chosen_str)[:max_seq_length]
r_ids = hf.encode(rejected_str)[:max_seq_length]
pe = min(len(p_ids), len(c_ids), len(r_ids))
rows.append((pe, c_ids, r_ids))| # Odds-ratio term. log(p/(1-p)) per side; 1-exp(logp) stabilized. | ||
| val_c = mx.maximum(-mx.expm1(logp_c), mx.array(1e-12, logp_c.dtype)) | ||
| val_r = mx.maximum(-mx.expm1(logp_r), mx.array(1e-12, logp_r.dtype)) | ||
| log_odds = (logp_c - mx.log(val_c)) - (logp_r - mx.log(val_r)) | ||
| or_loss = -mx.mean(nn.log_sigmoid(log_odds)) | ||
| loss = sft + beta * or_loss |
There was a problem hiding this comment.
In low-precision training (such as float16 or bfloat16), computing odds-ratio terms directly can lead to numerical instability. Specifically, the constant 1e-12 underflows to 0.0 in float16 (where the minimum positive subnormal is 5.96e-8). If logp_c is 0.0 (e.g., if a token has a probability of 1.0), -mx.expm1(logp_c) becomes 0.0, and mx.maximum(0.0, 1e-12) will evaluate to 0.0 due to underflow. This results in mx.log(0.0) = -inf, causing NaN gradients and training crashes.
To prevent this, perform the odds-ratio calculations in float32 and cast the final or_loss back to the original dtype.
| # Odds-ratio term. log(p/(1-p)) per side; 1-exp(logp) stabilized. | |
| val_c = mx.maximum(-mx.expm1(logp_c), mx.array(1e-12, logp_c.dtype)) | |
| val_r = mx.maximum(-mx.expm1(logp_r), mx.array(1e-12, logp_r.dtype)) | |
| log_odds = (logp_c - mx.log(val_c)) - (logp_r - mx.log(val_r)) | |
| or_loss = -mx.mean(nn.log_sigmoid(log_odds)) | |
| loss = sft + beta * or_loss | |
| # Odds-ratio term. log(p/(1-p)) per side; 1-exp(logp) stabilized in float32. | |
| logp_c_f32 = logp_c.astype(mx.float32) | |
| logp_r_f32 = logp_r.astype(mx.float32) | |
| val_c = mx.maximum(-mx.expm1(logp_c_f32), mx.array(1e-12, dtype=mx.float32)) | |
| val_r = mx.maximum(-mx.expm1(logp_r_f32), mx.array(1e-12, dtype=mx.float32)) | |
| log_odds = (logp_c_f32 - mx.log(val_c)) - (logp_r_f32 - mx.log(val_r)) | |
| or_loss = -mx.mean(nn.log_sigmoid(log_odds)) | |
| loss = sft + beta * or_loss.astype(sft.dtype) |
…ference collator with ORPO
…type; flag kept for back-compat
|
I ran a small cross-backend train-curve check against the CUDA/TRL path. This is not a full quality eval, and after checking the reference notebooks I would treat it as a smoke/parity probe rather than a strict apples-to-apples benchmark. Setup:
Results summary (
DPO looks reasonably close: both backends start at ORPO still has a real parity gap. On the arithmetic toy set, MLX saturates to near-zero loss while CUDA does not. On the QA set, MLX no longer saturates, but the loss/grad shape is still noticeably different. Confirmed differences I found:
Suggested next step before judging parity: rerun with the exact notebook-style dataset formatting, dump the first CUDA and MLX batch token IDs/masks/labels, and compare one frozen no-update loss component by component. For ORPO, align the TRL chosen-NLL/SFT term first. DPO already looks close enough that I would focus mostly on strict preprocessing/batch parity and ORPO semantics. |
The MLX ORPO loss computed its SFT/NLL term from the response-only, length-normalized log-prob (the same quantity used for the odds-ratio term). TRL's chosen_nll_loss is instead a pooled token-mean cross-entropy over the full prompt+response span (all non-pad chosen tokens, matching nn.CrossEntropyLoss default reduction). The narrower span made the SFT signal trivially easy, so ORPO saturated to ~0 loss on toy data while CUDA did not. Compute the NLL term over prompt+response for the chosen rows; the odds-ratio term (response-only, length-normalized) is unchanged.
|
Fix (3f14771): compute the NLL term over prompt+response for the chosen rows (nll_mask = steps <
DPO (step-1 = 0.6931) and SFT still train; the MLX baseline test suite is green (90 passed / 5 skipped, |
|
Re-tested current HEAD ( Summary RMSE vs CUDA:
ORPO arithmetic loss curve: xychart-beta
title "ORPO arithmetic loss"
x-axis [1, 2, 3, 4, 5, 6, 7, 8]
y-axis "loss" 0 --> 4.5
line "CUDA" [4.289, 3.310, 3.050, 2.641, 2.223, 1.999, 1.670, 1.703]
line "MLX HEAD" [2.301, 1.966, 2.207, 1.406, 1.013, 1.134, 0.866, 1.388]
DPO loss is much closer, but grad still differs. Example DPO arithmetic grad:
Main remaining gaps I see:
Tiny example of the EOS/order wiring I’d expect: create_preference_batches(
...,
append_eos=args.append_eos,
dataset_order="sequential" if args.preserve_dataset_order else args.dataset_order,
seed=args.seed,
)The ORPO NLL fix definitely helped, but I think the preference data path still needs these parity fixes before the curves are reliably CUDA-comparable. |
create_preference_batches unconditionally length-sorted every pair by
max(len(chosen), len(rejected)), which minimizes padding but diverges
from CUDA/TRL — HF Trainer feeds preference data via SequentialSampler
(dataset order) or a seeded RandomSampler (torch.randperm), never length
sorted. Identical-seed MLX and CUDA runs therefore saw examples in
different order from step 1, producing different batches and curves.
Add dataset_order ("default" | "sequential" | "torch_randperm") plus a
preserve_dataset_order shortcut and seed, mirroring the existing SFT/VLM
builders, and wire them at the trainer call site. "default" keeps the
historical length-sort so existing runs are byte-identical; "sequential"
(or preserve_dataset_order=True) reproduces CUDA SequentialSampler order
and "torch_randperm" reproduces RandomSampler order for parity testing.
create_preference_batches encoded prompt+chosen / prompt+rejected with no EOS, while TRL appends the tokenizer's EOS to every completion (DPO tokenize_row unconditionally; ORPO via add_eos_token_if_needed). Missing EOS changes the trained token sequence, logprobs, and loss: the model never has to predict the completion's end token, so the ORPO loss sits systematically below the CUDA/TRL reference. Add an append_eos flag (default True, matching TRL) that appends the EOS id to each chosen/rejected completion, guarded against a double EOS (mirrors add_eos_token_if_needed) and applied before max_seq_length truncation (same append-then-truncate contract as the SFT path). The EOS lands after the prompt boundary, inside the [response_start, seq_end) loss span, so it is trained on as in TRL. Verified: the appended EOS has mask==1.0 at its target position, and the untrained-model ORPO loss over a fixed set rises (+0.256) once completions carry the end token.
…rtion) DPO's reference is computed by disabling LoRA adapters. When a model has no LoRA modules (full fine-tuning), the disable-adapter path has nothing to toggle, so DPO silently fell back to reference-free — dropping the reference term and training a different objective than requested, with no warning. Raise a clear ValueError at loss-fn construction when there are no LoRA mods and a reference is actually requested (not reference_free). The message points users to a LoRA/PEFT model or to reference_free=True if they genuinely want reference-free DPO. Legitimate reference-free configs pass through unchanged. This is the DPO portion of the full-FT fail-loud guard; the matching GRPO KL guard lands with the GRPO branch (make_grpo_loss_fn does not exist here). A real full-FT reference (a frozen second model) remains a separate maintainer design decision.
|
Add ORPO and DPO preference tuning for text models to MLXTrainer
Adds two preference-tuning methods — ORPO and DPO — to the MLX trainer, bringing preference optimization to the Apple-Silicon path. On CUDA, Unsloth gets these by patching TRL's trainers (
rl_replacements.py); since there is no TRL on MLX, these are native implementations built on the existing MLX trainer, sharing a common preference data collator (mirroring TRL's singleDataCollatorForPreference).Enable via
loss_type="orpo"orloss_type="dpo". Default ("sft") is unchanged.ORPO
Combines SFT with a preference signal in one loss, no reference model:
L = L_SFT + beta * L_OR, whereL_OR = -log(sigmoid(log_odds_chosen - log_odds_rejected))using the fullp/(1-p)odds (paper-faithful, not a simplified log-prob ratio), with the1 - exp(logp)term numerically stabilized.Config:
orpo_beta: float = 0.1(TRL default).DPO
Compares the policy against a frozen reference:
L = -log(sigmoid(beta * [(logp_chosen^pol - logp_chosen^ref) - (logp_rejected^pol - logp_rejected^ref)])).Reference via live LoRA-disable (matching TRL's
is_peft_modeldefault — no second model copy). The base model is the reference: we temporarily zero the LoRA scales (mlx_lm'sLoRALinear.__call__isy + scale*z, so scale=0 disables the adapter), run the reference forward undermx.stop_gradient, and restore the scales in afinallyblock. Verified the scale restore is bit-exact and survivesvalue_and_gradtracing.Config:
dpo_beta: float = 0.1,reference_free: bool = False(drops the reference term, matching TRL). Precompute of reference log-probs (TRL'sprecompute_ref_log_probs) is a possible follow-up optimization; this PR computes the reference live each step.Shared design
concatenated_forward: chosen and rejected stacked into one batch (chosen block then rejected block), single forward, split in the loss. Verified MLXvalue_and_graddifferentiates cleanly through the concat-and-split.create_preference_batches: length-sorts pairs, dynamic-pads to a multiple of 32 (Apple-Silicon padding), produces(batch, lengths, None)consumed by the existing text step path. Both ORPO and DPO use it.make_baseline_loss_fn's(model, batch, lengths, labels)signature; the batch path short-circuits in_prepare_data.Scope
Text models only. VLMs raise a clear not-supported error; eval during preference tuning prints a skip notice and trains normally. Both are follow-ups.
Testing (Qwen2.5-0.5B, small preference sets)
0.6931 = -log(0.5)exactly (untrained adapter → policy == reference → zero logits), the theoretically-correct starting value, then moves as the policy diverges. DPO loss math verified directionally in a controlled test (GOOD < NEUTRAL < BAD, NEUTRAL = -log(0.5)). Reference scale-toggle verified bit-exact and gradient-safe across multiple steps (no leak).test_cce_*) that fail identically onmainwithout this change — a torch-shim issue unrelated to this work.Smoke tests confirm the objectives run correctly and the losses behave as expected; they do not measure downstream preference-alignment quality (needs a full eval, out of scope). Tested on small datasets.