
Feat(mlx): Add ORPO (loss_type='orpo') for text models to MLXTrainer#830

Open
BardiaKoopah wants to merge 7 commits into
unslothai:mainfrom
BardiaKoopah:feat/mlx-orpo
Conversation

BardiaKoopah commented Jun 24, 2026

Add ORPO and DPO preference tuning for text models to MLXTrainer

Adds two preference-tuning methods — ORPO and DPO — to the MLX trainer, bringing preference optimization to the Apple-Silicon path. On CUDA, Unsloth gets these by patching TRL's trainers (rl_replacements.py); since there is no TRL on MLX, these are native implementations built on the existing MLX trainer, sharing a common preference data collator (mirroring TRL's single DataCollatorForPreference).

Enable via loss_type="orpo" or loss_type="dpo". Default ("sft") is unchanged.

ORPO

Combines SFT with a preference signal in one loss, no reference model:
L = L_SFT + beta * L_OR, where L_OR = -log(sigmoid(log_odds_chosen - log_odds_rejected)) using the full p/(1-p) odds (paper-faithful, not a simplified log-prob ratio), with the 1 - exp(logp) term numerically stabilized.

Config: orpo_beta: float = 0.1 (TRL default).
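As a rough illustration of the ORPO formula above, here is a minimal pure-Python sketch for a single chosen/rejected pair. It is not the PR's MLX code: the function names are hypothetical, the inputs are scalar length-normalized log-probabilities, and the `1e-12` floor simply mirrors the value quoted later in the review thread.

```python
import math

def log_sigmoid(x):
    """Numerically stable log(sigmoid(x))."""
    return min(x, 0.0) - math.log1p(math.exp(-abs(x)))

def log_odds(logp, floor=1e-12):
    """log(p / (1 - p)) computed from a log-probability logp <= 0."""
    # 1 - exp(logp) via expm1, clamped away from zero so the log stays finite.
    one_minus_p = max(-math.expm1(logp), floor)
    return logp - math.log(one_minus_p)

def orpo_loss(sft_loss, logp_chosen, logp_rejected, beta=0.1):
    """L = L_SFT + beta * L_OR for one preference pair (illustrative only)."""
    l_or = -log_sigmoid(log_odds(logp_chosen) - log_odds(logp_rejected))
    return sft_loss + beta * l_or

# Equal log-probs give L_OR = log(2); a better chosen response lowers it.
neutral = orpo_loss(0.0, -1.0, -1.0)
good = orpo_loss(0.0, -0.5, -2.0)
bad = orpo_loss(0.0, -2.0, -0.5)
```

With equal chosen and rejected log-probabilities the odds-ratio term is exactly log(2), so `neutral` equals `beta * log(2)`; `good` falls below it and `bad` rises above it.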

DPO

Compares the policy against a frozen reference:
L = -log(sigmoid(beta * [(logp_chosen^pol - logp_chosen^ref) - (logp_rejected^pol - logp_rejected^ref)])).

Reference via live LoRA-disable (matching TRL's is_peft_model default — no second model copy). The base model is the reference: we temporarily zero the LoRA scales (mlx_lm's LoRALinear.__call__ is y + scale*z, so scale=0 disables the adapter), run the reference forward under mx.stop_gradient, and restore the scales in a finally block. Verified the scale restore is bit-exact and survives value_and_grad tracing.
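The zero-then-restore pattern described above can be sketched with a context manager. This is an assumption-laden toy, not the PR's implementation: `ToyLoRALinear` and `lora_disabled` are hypothetical names, the layer is a plain Python object rather than mlx_lm's `LoRALinear`, and the `mx.stop_gradient` wrapping of the reference forward is left out.

```python
from contextlib import contextmanager

class ToyLoRALinear:
    """Hypothetical stand-in for a LoRA layer: output = y + scale * z."""

    def __init__(self, scale):
        self.scale = scale

    def __call__(self, y, z):
        return y + self.scale * z

@contextmanager
def lora_disabled(layers):
    """Temporarily set every adapter scale to 0, restoring it on exit."""
    saved = [layer.scale for layer in layers]
    try:
        for layer in layers:
            layer.scale = 0.0  # scale = 0 makes the layer return the base output y
        yield
    finally:
        # Runs even if the reference forward raises, so scales never leak.
        for layer, scale in zip(layers, saved):
            layer.scale = scale

layers = [ToyLoRALinear(2.0), ToyLoRALinear(0.5)]
with lora_disabled(layers):
    reference_out = layers[0](1.0, 3.0)  # adapter contribution is dropped
policy_out = layers[0](1.0, 3.0)         # adapter contribution is back
```

Restoring the saved values in `finally` is what keeps the policy unchanged after an exception during the reference pass.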

Config: dpo_beta: float = 0.1, reference_free: bool = False (drops the reference term, matching TRL). Precompute of reference log-probs (TRL's precompute_ref_log_probs) is a possible follow-up optimization; this PR computes the reference live each step.
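To make the DPO formula and the `reference_free` switch concrete, here is a small scalar sketch. The function name and argument layout are hypothetical; the real loss operates on batched MLX arrays.

```python
import math

def log_sigmoid(x):
    """Numerically stable log(sigmoid(x))."""
    return min(x, 0.0) - math.log1p(math.exp(-abs(x)))

def dpo_loss(pol_chosen, pol_rejected, ref_chosen=0.0, ref_rejected=0.0,
             beta=0.1, reference_free=False):
    """-log sigmoid(beta * [(pol_c - ref_c) - (pol_r - ref_r)]) for one pair."""
    if reference_free:
        # Dropping the reference term is the same as using zero reference log-probs.
        ref_chosen = ref_rejected = 0.0
    logits = (pol_chosen - ref_chosen) - (pol_rejected - ref_rejected)
    return -log_sigmoid(beta * logits)

# Untrained adapter: policy == reference, so the logits are 0 and the loss is log(2).
step1 = dpo_loss(-5.0, -7.0, ref_chosen=-5.0, ref_rejected=-7.0)
# The same log-probs without a reference no longer sit at the log(2) midpoint.
step1_ref_free = dpo_loss(-5.0, -7.0, reference_free=True)
```

`step1` reproduces the 0.6931 starting value reported under Testing, while `step1_ref_free` differs because the policy margin is no longer cancelled by the reference.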

Shared design

  • Concatenated forward, mirroring TRL's concatenated_forward: chosen and rejected stacked into one batch (chosen block then rejected block), single forward, split in the loss. Verified MLX value_and_grad differentiates cleanly through the concat-and-split.
  • Shared collator create_preference_batches: length-sorts pairs, dynamic-pads to a multiple of 32 (Apple-Silicon padding), produces (batch, lengths, None) consumed by the existing text step path. Both ORPO and DPO use it.
  • Minimal trainer footprint: both loss fns match make_baseline_loss_fn's (model, batch, lengths, labels) signature; the batch path short-circuits in _prepare_data.
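The chosen-then-rejected layout and the pad-to-32 rule in the list above can be sketched with plain lists. The helper names are hypothetical and the return value is simplified to `(batch, lengths)`; the PR's collator also length-sorts and returns `(batch, lengths, None)`.

```python
def pad_to_multiple(n, multiple=32):
    """Round n up to the next multiple (32 in the PR description)."""
    return ((n + multiple - 1) // multiple) * multiple

def build_concat_batch(chosen, rejected, pad_id=0, multiple=32):
    """Stack chosen rows first, then rejected rows, padded to one shared width."""
    assert len(chosen) == len(rejected), "preference pairs must line up"
    rows = chosen + rejected
    width = pad_to_multiple(max(len(r) for r in rows), multiple)
    batch = [r + [pad_id] * (width - len(r)) for r in rows]
    lengths = [len(r) for r in rows]
    return batch, lengths

def split_halves(values):
    """Undo the stacking: first half is chosen, second half is rejected."""
    half = len(values) // 2
    return values[:half], values[half:]

chosen = [[1, 2, 3], [4]]
rejected = [[5, 6], [7, 8, 9, 10]]
batch, lengths = build_concat_batch(chosen, rejected)
chosen_lengths, rejected_lengths = split_halves(lengths)
```

A single forward pass over `batch` yields per-row outputs that `split_halves` separates again inside the loss.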

Scope

Text models only. VLMs raise a clear not-supported error; eval during preference tuning prints a skip notice and trains normally. Both are follow-ups.

Testing (Qwen2.5-0.5B, small preference sets)

  • ORPO end-to-end: trains, loss decreases; loss math verified against the paper in all regimes (chosen≫rejected, rejected≫chosen, equal) and at the numerical edge.
  • DPO end-to-end: trains, loss decreases. Step 1 loss = 0.6931 = -log(0.5) exactly (untrained adapter → policy == reference → zero logits), the theoretically-correct starting value, then moves as the policy diverges. DPO loss math verified directionally in a controlled test (GOOD < NEUTRAL < BAD, NEUTRAL = -log(0.5)). Reference scale-toggle verified bit-exact and gradient-safe across multiple steps (no leak).
  • reference_free DPO: trains, reference term dropped (distinct loss behavior, no -log(0.5) midpoint).
  • Regression: SFT (default) unchanged (prints CCE loss); ORPO unaffected by the DPO wiring; both verified after the shared-collator rename.
  • gradient_accumulation_steps>1: both methods run clean.
  • Guards: VLM raises the clear error; eval prints the skip notice and trains.
  • Existing MLX suite: 90 passed across baseline-loss-parity, batching, padding, module-exports, trainer-internals. The only failures are three pre-existing CCE-kernel tests (test_cce_*) that fail identically on main without this change — a torch-shim issue unrelated to this work.
    Smoke tests confirm the objectives run correctly and the losses behave as expected; they do not measure downstream preference-alignment quality (needs a full eval, out of scope). Tested on small datasets.

gemini-code-assist Bot left a comment

Code Review

This pull request introduces support for Odds Ratio Preference Optimization (ORPO) training in the MLX trainer, including configuration options, batch creation, and loss function implementation. The review feedback highlights two critical improvements: first, create_orpo_batches should apply chat templates and support message lists to prevent broken formatting and type errors; second, the odds-ratio calculation in make_orpo_loss_fn should be performed in float32 to avoid numerical underflow and NaN gradients during low-precision training.

Comment thread unsloth_zoo/mlx/utils.py
Comment on lines +570 to +581
    for ex in dataset:
        if prompt_key not in ex or chosen_key not in ex or rejected_key not in ex:
            raise ValueError(
                f"ORPO requires '{prompt_key}', '{chosen_key}', '{rejected_key}' "
                f"columns; got {sorted(ex.keys())}."
            )
        prompt = ex[prompt_key]
        p_ids = hf.encode(prompt)
        c_ids = hf.encode(prompt + ex[chosen_key])[:max_seq_length]
        r_ids = hf.encode(prompt + ex[rejected_key])[:max_seq_length]
        pe = min(len(p_ids), len(c_ids), len(r_ids))
        rows.append((pe, c_ids, r_ids))

critical

The current implementation of create_orpo_batches directly encodes prompt and prompt + ex[chosen_key] as raw strings. This has two major issues:

  1. Broken Chat Templates / Control Tokens: It bypasses the model's chat template. For chat/instruction models, prompts and responses must be formatted with control tokens (e.g., <|im_start|>, [INST], etc.). Without this, the model trains on raw text and will fail to learn the correct chat format.
  2. Type Errors with Message Lists: Many preference datasets (e.g., UltraFeedback) store prompt and chosen/rejected as lists of message dicts (conversations). Directly doing prompt + ex[chosen_key] will raise a TypeError when trying to concatenate a list and a string, or when passing a list of dicts to hf.encode.

To fix this, we should normalize the inputs into message lists, apply the tokenizer's chat template (if available) to format them with control tokens, and then encode the resulting formatted strings.

    for ex in dataset:
        if prompt_key not in ex or chosen_key not in ex or rejected_key not in ex:
            raise ValueError(
                f"ORPO requires '{prompt_key}', '{chosen_key}', '{rejected_key}' "
                f"columns; got {sorted(ex.keys())}."
            )
        prompt = ex[prompt_key]
        chosen = ex[chosen_key]
        rejected = ex[rejected_key]

        if isinstance(prompt, str):
            prompt_messages = [{"role": "user", "content": prompt}]
        else:
            prompt_messages = prompt

        if isinstance(chosen, str):
            chosen_messages = [{"role": "assistant", "content": chosen}]
        else:
            chosen_messages = chosen

        if isinstance(rejected, str):
            rejected_messages = [{"role": "assistant", "content": rejected}]
        else:
            rejected_messages = rejected

        apply_tmpl = getattr(tokenizer, "apply_chat_template", None) or getattr(hf, "apply_chat_template", None)
        if apply_tmpl is not None:
            prompt_str = apply_tmpl(prompt_messages, tokenize=False, add_generation_prompt=True)
            chosen_str = apply_tmpl(prompt_messages + chosen_messages, tokenize=False)
            rejected_str = apply_tmpl(prompt_messages + rejected_messages, tokenize=False)
        else:
            prompt_str = prompt if isinstance(prompt, str) else ""
            chosen_str = prompt_str + (chosen if isinstance(chosen, str) else "")
            rejected_str = prompt_str + (rejected if isinstance(rejected, str) else "")

        p_ids = hf.encode(prompt_str)
        c_ids = hf.encode(chosen_str)[:max_seq_length]
        r_ids = hf.encode(rejected_str)[:max_seq_length]
        pe = min(len(p_ids), len(c_ids), len(r_ids))
        rows.append((pe, c_ids, r_ids))

Comment thread unsloth_zoo/mlx/utils.py
Comment on lines +541 to +546
    # Odds-ratio term. log(p/(1-p)) per side; 1-exp(logp) stabilized.
    val_c = mx.maximum(-mx.expm1(logp_c), mx.array(1e-12, logp_c.dtype))
    val_r = mx.maximum(-mx.expm1(logp_r), mx.array(1e-12, logp_r.dtype))
    log_odds = (logp_c - mx.log(val_c)) - (logp_r - mx.log(val_r))
    or_loss = -mx.mean(nn.log_sigmoid(log_odds))
    loss = sft + beta * or_loss

high

In low-precision training (such as float16 or bfloat16), computing odds-ratio terms directly can lead to numerical instability. Specifically, the constant 1e-12 underflows to 0.0 in float16 (where the minimum positive subnormal is 5.96e-8). If logp_c is 0.0 (e.g., if a token has a probability of 1.0), -mx.expm1(logp_c) becomes 0.0, and mx.maximum(0.0, 1e-12) will evaluate to 0.0 due to underflow. This results in mx.log(0.0) = -inf, causing NaN gradients and training crashes.

To prevent this, perform the odds-ratio calculations in float32 and cast the final or_loss back to the original dtype.

Suggested change

    - # Odds-ratio term. log(p/(1-p)) per side; 1-exp(logp) stabilized.
    - val_c = mx.maximum(-mx.expm1(logp_c), mx.array(1e-12, logp_c.dtype))
    - val_r = mx.maximum(-mx.expm1(logp_r), mx.array(1e-12, logp_r.dtype))
    - log_odds = (logp_c - mx.log(val_c)) - (logp_r - mx.log(val_r))
    - or_loss = -mx.mean(nn.log_sigmoid(log_odds))
    - loss = sft + beta * or_loss
    + # Odds-ratio term. log(p/(1-p)) per side; 1-exp(logp) stabilized in float32.
    + logp_c_f32 = logp_c.astype(mx.float32)
    + logp_r_f32 = logp_r.astype(mx.float32)
    + val_c = mx.maximum(-mx.expm1(logp_c_f32), mx.array(1e-12, dtype=mx.float32))
    + val_r = mx.maximum(-mx.expm1(logp_r_f32), mx.array(1e-12, dtype=mx.float32))
    + log_odds = (logp_c_f32 - mx.log(val_c)) - (logp_r_f32 - mx.log(val_r))
    + or_loss = -mx.mean(nn.log_sigmoid(log_odds))
    + loss = sft + beta * or_loss.astype(sft.dtype)
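The float16 underflow behind this suggestion can be checked without MLX. The sketch below uses Python's `struct` half-precision format (`"e"`) as a stand-in for a float16 array element; `to_float16` is a hypothetical helper, and the check covers float16 only, not bfloat16.

```python
import struct

def to_float16(x):
    """Round a Python float to IEEE 754 half precision and back."""
    return struct.unpack("<e", struct.pack("<e", x))[0]

# The 1e-12 clamp floor is far below the smallest float16 subnormal (2**-24).
floor_f16 = to_float16(1e-12)
smallest_subnormal = to_float16(2.0 ** -24)

# If 1 - exp(logp) is exactly 0 (probability 1.0), the float16 clamp does nothing.
clamped = max(0.0, floor_f16)
```

Here `floor_f16` is `0.0`, so the clamp leaves a zero in place and the subsequent log has nothing protecting it, whereas a float32 floor of `1e-12` stays nonzero.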

Lyxot commented Jun 25, 2026

I ran a small cross-backend train-curve check against the CUDA/TRL path. This is not a full quality eval, and after checking the reference notebooks I would treat it as a smoke/parity probe rather than a strict apples-to-apples benchmark.

Setup:

  • PR head: 89108ff
  • Model: Qwen/Qwen3-0.6B
  • CUDA reference: import unsloth first, then TRL trainers; runtime resolves to UnslothORPOTrainer / UnslothDPOTrainer, not vanilla TRL
  • MLX: this PR via FastMLXModel, 4-bit MLX affine quantization
  • CUDA: Unsloth + TRL, load_in_4bit=True, RTX 4070 SUPER
  • Shared config: 8 train steps, batch size 2, max_seq_length=128, lr 1e-4, beta 0.1, LoRA r=8, alpha=16, seed 3407

Results summary (grad = global grad norm):

| dataset | method | MLX loss curve | MLX grad curve | CUDA loss curve | CUDA grad curve |
|---|---|---|---|---|---|
| arithmetic toy | ORPO | 1.572, 1.029, 0.349, 0.169, 0.058, 0.016, 0.000, 0.000 | 17.787, 13.891, 7.204, 4.311, 1.951, 0.925, 0.490, 0.382 | 3.582, 3.397, 2.768, 2.343, 2.298, 2.297, 2.025, 1.863 | 12.483, 10.201, 8.139, 11.295, 6.754, 7.254, 5.469, 5.122 |
| arithmetic toy | DPO | 0.693, 0.627, 0.594, 0.491, 0.484, 0.438, 0.367, 0.360 | 2.502, 2.419, 1.815, 3.371, 2.282, 2.172, 2.874, 2.385 | 0.693, 0.640, 0.601, 0.600, 0.482, 0.473, 0.341, 0.392 | 4.045, 2.728, 2.574, 2.350, 2.382, 3.008, 2.311, 3.458 |
| QA preference set | ORPO | 3.993, 2.098, 1.526, 1.378, 3.572, 1.723, 2.233, 0.964 | 47.459, 18.906, 17.048, 17.903, 36.140, 22.757, 20.603, 18.054 | 3.923, 3.190, 2.972, 2.634, 2.584, 2.359, 2.320, 2.055 | 25.142, 24.234, 11.785, 8.813, 10.179, 8.897, 6.843, 5.843 |
| QA preference set | DPO | 0.693, 0.704, 0.703, 0.698, 0.683, 0.636, 0.676, 0.658 | 2.082, 2.982, 2.880, 2.101, 3.500, 3.325, 1.467, 2.607 | 0.693, 0.687, 0.667, 0.597, 0.666, 0.633, 0.678, 0.633 | 4.442, 2.617, 3.128, 3.031, 3.352, 3.253, 3.735, 3.742 |

DPO looks reasonably close: both backends start at 0.693 ~= -log(0.5), and the loss curves stay in the same range. Grad norms are not identical, but they are broadly comparable for this kind of cross-backend 4-bit run. That is a good sign for the reference/LoRA-disable direction.

ORPO still has a real parity gap. On the arithmetic toy set, MLX saturates to near-zero loss while CUDA does not. On the QA set, MLX no longer saturates, but the loss/grad shape is still noticeably different.

Confirmed differences I found:

  1. ORPO SFT/NLL semantics differ from TRL. Unsloth/TRL ORPO computes chosen_nll_loss from the chosen input_ids with the attention mask, i.e. full prompt + response NLL. This MLX implementation masks to response tokens only and uses response-normalized logp_c as the SFT term. That is the strongest confirmed reason ORPO scale/shape can differ.

  2. The synthetic harness is not a strict notebook reproduction. The reference ORPO notebook explicitly formats an Alpaca-style prompt/chosen/rejected dataset. The DPO notebook explicitly applies a chat template into text_prompt/text_chosen/text_rejected. My harness used small synthetic prompt/chosen/rejected rows, so it is useful for catching gross loss/grad issues but not enough to prove notebook-level parity.

  3. CUDA and MLX sequence construction are not identical yet. On the synthetic QA row, CUDA/TRL tokenized raw prompt/completion and appended EOS to completions; MLX currently directly encodes prompt + chosen/rejected in create_preference_batches. That is closer than a chat-template mismatch for this specific harness, but still not bit-equivalent batching.

Suggested next step before judging parity: rerun with the exact notebook-style dataset formatting, dump the first CUDA and MLX batch token IDs/masks/labels, and compare one frozen no-update loss component by component. For ORPO, align the TRL chosen-NLL/SFT term first. DPO already looks close enough that I would focus mostly on strict preprocessing/batch parity and ORPO semantics.

The MLX ORPO loss computed its SFT/NLL term from the response-only,
length-normalized log-prob (the same quantity used for the odds-ratio
term). TRL's chosen_nll_loss is instead a pooled token-mean cross-entropy
over the full prompt+response span (all non-pad chosen tokens, matching
nn.CrossEntropyLoss default reduction). The narrower span made the SFT
signal trivially easy, so ORPO saturated to ~0 loss on toy data while
CUDA did not.

Compute the NLL term over prompt+response for the chosen rows; the
odds-ratio term (response-only, length-normalized) is unchanged.
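The difference between the two NLL spans can be seen on a toy row. This is a simplified sketch with made-up per-token log-probabilities; the helper names and the `(token_logps, seq_end)` layout are assumptions, and the target shifting done in a real trainer is ignored.

```python
def response_only_nll(token_logps, response_start, seq_end):
    """Old SFT term: length-normalized NLL over the response tokens only."""
    span = token_logps[response_start:seq_end]
    return -sum(span) / len(span)

def pooled_nll(rows):
    """New SFT term: token-mean NLL pooled over all non-pad chosen tokens.

    rows is a list of (token_logps, seq_end) for the chosen sequences.
    """
    total, count = 0.0, 0
    for token_logps, seq_end in rows:
        total += -sum(token_logps[:seq_end])  # prompt + response, padding excluded
        count += seq_end
    return total / count

# Two hard prompt tokens, two easy response tokens, one pad position.
token_logps = [-3.0, -3.0, -0.1, -0.1, 0.0]
narrow = response_only_nll(token_logps, response_start=2, seq_end=4)
wide = pooled_nll([(token_logps, 4)])
```

On this row the response-only term is much smaller than the pooled one, which is consistent with the saturation the commit message describes on toy data.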
BardiaKoopah commented Jul 1, 2026

Fix (3f14771): compute the NLL term over prompt+response for the chosen rows (nll_mask = steps < seq_end, pooled token-mean CE). The odds-ratio term is left exactly as-is; it was already parity-correct (response-only, length-normalized logps; the log_odds algebra matches TRL's odds_ratio_loss). The final loss = sft + beta * or_loss combination is unchanged.
Verified on unsloth/Qwen2.5-0.5B (4-bit) over 8 steps on a toy arithmetic preference set:

  • before: [1.20, 1.16, 0.19, 0.002, 0.002, 0.0, 0.0, 0.004] — saturates to ~0
  • after: [3.06, 3.04, 1.75, 0.66, 0.29, 0.55, 0.43, 0.35] — CUDA-like start, no collapse

DPO (step-1 = 0.6931) and SFT still train; the MLX baseline test suite is green (90 passed / 5 skipped, 3 known CCE deselects).

Lyxot commented Jul 2, 2026

Re-tested current HEAD (3f147718) against CUDA/Unsloth TRL on Qwen3-0.6B LoRA, 8 steps, same preference rows / seed / LR.

Summary RMSE vs CUDA:

| loss | data | loss RMSE | grad RMSE |
|---|---|---|---|
| ORPO | arithmetic | 1.170 | 3.808 |
| ORPO | QA | 0.890 | 3.136 |
| DPO | arithmetic | 0.048 | 1.364 |
| DPO | QA | 0.103 | 2.111 |

ORPO arithmetic loss curve:

xychart-beta
  title "ORPO arithmetic loss"
  x-axis [1, 2, 3, 4, 5, 6, 7, 8]
  y-axis "loss" 0 --> 4.5
  line "CUDA" [4.289, 3.310, 3.050, 2.641, 2.223, 1.999, 1.670, 1.703]
  line "MLX HEAD" [2.301, 1.966, 2.207, 1.406, 1.013, 1.134, 0.866, 1.388]
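The comment does not say how the RMSE figures were computed. Assuming the plain root-mean-square of per-step differences, the two ORPO arithmetic loss curves from the chart reproduce the 1.170 reported in the summary:

```python
import math

def rmse(a, b):
    """Root-mean-square error between two equal-length curves."""
    assert len(a) == len(b)
    return math.sqrt(sum((x - y) ** 2 for x, y in zip(a, b)) / len(a))

# Per-step ORPO arithmetic losses copied from the chart.
cuda = [4.289, 3.310, 3.050, 2.641, 2.223, 1.999, 1.670, 1.703]
mlx_head = [2.301, 1.966, 2.207, 1.406, 1.013, 1.134, 0.866, 1.388]

orpo_arithmetic_loss_rmse = rmse(cuda, mlx_head)
print(round(orpo_arithmetic_loss_rmse, 3))  # 1.17
```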

DPO loss is much closer, but grad still differs. Example DPO arithmetic grad:

CUDA: [2.37, 2.70, 4.32, 4.47, 4.09, 3.55, 3.59, 1.73]
MLX: [1.94, 2.82, 2.36, 2.74, 2.38, 2.97, 1.47, 1.51]

Main remaining gaps I see:

  • preference batching does not append EOS to chosen/rejected like TRL;
  • preference batching length-sorts and does not honor dataset_order / preserve_dataset_order, so seeded CUDA/MLX runs can see different row order;
  • conversational preference rows are not passed through chat templates (prompt/chosen/rejected as messages);
  • DPO silently becomes reference-free when no LoRA modules are found, even if reference_free=False.

Tiny example of the EOS/order wiring I’d expect:

create_preference_batches(
    ...,
    append_eos=args.append_eos,
    dataset_order="sequential" if args.preserve_dataset_order else args.dataset_order,
    seed=args.seed,
)

The ORPO NLL fix definitely helped, but I think the preference data path still needs these parity fixes before the curves are reliably CUDA-comparable.

create_preference_batches unconditionally length-sorted every pair by
max(len(chosen), len(rejected)), which minimizes padding but diverges
from CUDA/TRL — HF Trainer feeds preference data via SequentialSampler
(dataset order) or a seeded RandomSampler (torch.randperm), never length
sorted. Identical-seed MLX and CUDA runs therefore saw examples in
different order from step 1, producing different batches and curves.

Add dataset_order ("default" | "sequential" | "torch_randperm") plus a
preserve_dataset_order shortcut and seed, mirroring the existing SFT/VLM
builders, and wire them at the trainer call site. "default" keeps the
historical length-sort so existing runs are byte-identical; "sequential"
(or preserve_dataset_order=True) reproduces CUDA SequentialSampler order
and "torch_randperm" reproduces RandomSampler order for parity testing.
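
A minimal sketch of the two deterministic ordering modes described above, using token-id lists for each pair. `order_pairs` is a hypothetical helper; the `"torch_randperm"` mode is deliberately left out because it depends on `torch.randperm`, which this stdlib-only sketch cannot reproduce.

```python
def order_pairs(pairs, dataset_order="default"):
    """Return the index order in which (chosen_ids, rejected_ids) pairs are batched."""
    n = len(pairs)
    if dataset_order == "default":
        # Historical behavior: sort by the longer side of each pair.
        # sorted() is stable, so ties keep dataset order.
        return sorted(range(n), key=lambda i: max(len(pairs[i][0]), len(pairs[i][1])))
    if dataset_order == "sequential":
        # Dataset order, as a sequential sampler would yield it.
        return list(range(n))
    raise ValueError(f"unsupported dataset_order in this sketch: {dataset_order!r}")

pairs = [
    ([1, 2, 3, 4, 5], [1, 2]),   # max length 5
    ([1], [1, 2]),               # max length 2
    ([1, 2, 3], [1, 2, 3]),      # max length 3
]
length_sorted = order_pairs(pairs, "default")
sequential = order_pairs(pairs, "sequential")
```

With these rows the length-sorted order is `[1, 2, 0]` while the sequential order is `[0, 1, 2]`, so the two modes already disagree on the first batch.
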
create_preference_batches encoded prompt+chosen / prompt+rejected with no
EOS, while TRL appends the tokenizer's EOS to every completion (DPO
tokenize_row unconditionally; ORPO via add_eos_token_if_needed). Missing
EOS changes the trained token sequence, logprobs, and loss: the model
never has to predict the completion's end token, so the ORPO loss sits
systematically below the CUDA/TRL reference.

Add an append_eos flag (default True, matching TRL) that appends the EOS
id to each chosen/rejected completion, guarded against a double EOS
(mirrors add_eos_token_if_needed) and applied before max_seq_length
truncation (same append-then-truncate contract as the SFT path). The EOS
lands after the prompt boundary, inside the [response_start, seq_end)
loss span, so it is trained on as in TRL. Verified: the appended EOS has
mask==1.0 at its target position, and the untrained-model ORPO loss over
a fixed set rises (+0.256) once completions carry the end token.
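
The append-then-truncate contract with the double-EOS guard can be sketched on plain token-id lists. The helper name and the example ids are hypothetical; the real collator works on tokenizer output.

```python
def append_eos_then_truncate(ids, eos_id, max_seq_length, append_eos=True):
    """Append EOS unless the sequence already ends with it, then truncate."""
    out = list(ids)
    if append_eos and (not out or out[-1] != eos_id):
        out.append(eos_id)          # guard: never produce a double EOS
    return out[:max_seq_length]     # truncation happens after the append

EOS = 2
plain = append_eos_then_truncate([5, 6, 7], EOS, max_seq_length=8)
already_terminated = append_eos_then_truncate([5, 6, EOS], EOS, max_seq_length=8)
too_long = append_eos_then_truncate([5, 6, 7], EOS, max_seq_length=3)
```

One consequence of appending before truncating, visible in `too_long`, is that a completion already at `max_seq_length` loses the EOS again.
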
…rtion)

DPO's reference is computed by disabling LoRA adapters. When a model has no
LoRA modules (full fine-tuning), the disable-adapter path has nothing to
toggle, so DPO silently fell back to reference-free — dropping the reference
term and training a different objective than requested, with no warning.

Raise a clear ValueError at loss-fn construction when there are no LoRA mods
and a reference is actually requested (not reference_free). The message points
users to a LoRA/PEFT model or to reference_free=True if they genuinely want
reference-free DPO. Legitimate reference-free configs pass through unchanged.

This is the DPO portion of the full-FT fail-loud guard; the matching GRPO KL
guard lands with the GRPO branch (make_grpo_loss_fn does not exist here). A
real full-FT reference (a frozen second model) remains a separate maintainer
design decision.
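
The fail-loud guard amounts to a small check at loss-function construction time. The sketch below is illustrative: `check_dpo_reference` and its message are hypothetical, and the LoRA modules are represented by an ordinary list.

```python
def check_dpo_reference(lora_modules, reference_free):
    """Raise if a reference is requested but there is no adapter to disable."""
    if not lora_modules and not reference_free:
        raise ValueError(
            "DPO needs a reference, but no LoRA modules were found to disable. "
            "Use a LoRA/PEFT model, or set reference_free=True."
        )

# Legitimate configurations pass through silently.
check_dpo_reference(lora_modules=["q_proj.lora"], reference_free=False)
check_dpo_reference(lora_modules=[], reference_free=True)

# Full fine-tuning with a requested reference fails loudly instead of
# silently training the reference-free objective.
try:
    check_dpo_reference(lora_modules=[], reference_free=False)
    guard_fired = False
except ValueError:
    guard_fired = True
```
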
@BardiaKoopah

@Lyxot

  1. Row order is now configurable (parity opt-in). You were right that create_preference_batches unconditionally length-sorted while CUDA/TRL never does. HF Trainer._get_train_sampler feeds preference data via SequentialSampler (dataset order) or a seeded RandomSampler (torch.randperm), so identical-seed runs diverged from step 1. Added dataset_order ("default" | "sequential" | "torch_randperm") plus a preserve_dataset_order shortcut and seed, wired as you suggested (dataset_order="sequential" if args.preserve_dataset_order else args.dataset_order, seed=args.seed), mirroring the existing SFT/VLM builders.
    Behavior-change note: the default stays "default" = length-sort, so existing runs are byte-identical and only opt-in users get parity ordering. I kept parity opt-in rather than default because flipping it changes everyone's batch composition/throughput. Flagging it as arguably a maintainer call if you'd prefer parity-by-default.
    One honest finding: in an apples-to-apples 8-step harness, switching to sequential order changed the curve but did not by itself move it toward the CUDA reference, so order alone doesn't appear to be the dominant RMSE driver. Which leads to:
  2. EOS now appended to completions (the bigger lever). Confirmed the raw tokenizer emits no trailing EOS, matching your point. Added append_eos (default True, matching TRL), which appends the EOS id to each chosen/rejected completion, guarded against double-EOS (mirrors TRL's add_eos_token_if_needed; DPO's tokenize_row does it unconditionally, so the guard is the safe superset). Verified the appended EOS lands inside the [response_start, seq_end) loss span (mask==1.0 at its target position), so it's actually trained. Convergence-free check: appending EOS raises the untrained-model ORPO loss (+0.256 over a fixed set), i.e. it lifts the systematically-too-low loss level, consistent with the model now having to predict a hard end-token. DPO step-1 still 0.6931.
  3. Fail-loud on full-FT DPO reference (your pt 4), DPO portion. Full-FT DPO with a reference now raises a clear ValueError instead of silently collapsing to reference-free; reference_free=True passes through unchanged. (The matching GRPO KL guard lives on the GRPO branch, since that code doesn't exist here.)
    Baseline: canonical 5-file MLX suite gives 90 passed, 5 skipped, 3 deselected.
    Since I can't run the CUDA side here, I've verified the mechanisms (order configurable, EOS in the loss span and lifting the level, guard fires) but not the final MLX-vs-CUDA curve convergence. If you get a chance to re-run your parity harness against these commits, that'll confirm whether the RMSE actually closes. My expectation is EOS does most of the work and order mainly matters for exact seeded reproduction.
