-
Notifications
You must be signed in to change notification settings - Fork 289
Feat(mlx): Add ORPO (loss_type='orpo') for text models to MLXTrainer #830
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
599f051
2e3c22e
89108ff
3f14771
72bb6c3
203919e
7eda394
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -502,6 +502,217 @@ def loss_fn(model, batch, lengths, labels=None): | |
| return loss_fn | ||
|
|
||
|
|
||
| # ============================================================================ | ||
| # ORPO (Odds Ratio Preference Optimization) — text models. | ||
| # Mirrors TRL/CUDA's concatenated-forward: chosen and rejected are stacked into | ||
| # one batch (chosen block then rejected block) and run through a single forward, | ||
| # then split. Loss follows the ORPO paper: L = L_SFT + beta * L_OR, where | ||
| # L_OR = -log(sigmoid(log_odds_chosen - log_odds_rejected)) and the odds use the | ||
| # full p/(1-p) form (not a simplified log-prob ratio). Built on this module's | ||
| # own length-mask convention (see make_baseline_loss_fn), not ported from TRL | ||
| # (no TRL on MLX) or copied from third-party MLX projects. | ||
| # ============================================================================ | ||
def make_orpo_loss_fn(beta=0.1):
    """Create an ORPO loss function over a concatenated [chosen; rejected] batch.

    The returned closure has the signature
    ``(model, batch, lengths, labels=None) -> (loss, ntoks)`` (same as
    make_baseline_loss_fn), so it drops into the trainer's text step path
    unchanged.

    Loss is ``L = L_SFT + beta * L_OR`` where

    * ``L_SFT`` is the pooled token-mean cross-entropy over every non-pad
      target of the chosen rows (prompt + response), i.e. TRL's
      ``chosen_nll_loss``;
    * ``L_OR = -mean(log_sigmoid(log_odds_chosen - log_odds_rejected))`` with
      ``log_odds = logp - log(1 - exp(logp))`` and ``logp`` the
      length-normalized response-only log-prob of a row.

    Args:
        beta: Weight applied to the odds-ratio term.

    The closure expects:
        model: Callable mapping ``batch[:, :-1]`` to next-token logits.
        batch: ``(2B, L)`` token ids; rows ``[0:B]`` are chosen and rows
            ``[B:2B]`` are rejected, paired by index (as produced by
            ``create_preference_batches``).
        lengths: ``(2B, 2)`` with per-row ``[response_start, seq_end)``.
        labels: Unused; accepted only for signature compatibility.

    The closure returns ``(loss, ntoks)`` where ``ntoks`` is the number of
    response targets selected by the response mask across all ``2B`` rows.

    The closure raises ``ValueError`` when ``batch`` has zero or an odd number
    of rows, because chosen/rejected rows could not be paired by index.
    """
    def loss_fn(model, batch, lengths, labels=None):
        # Validate the [chosen; rejected] layout before any compute. Without
        # this, an odd row count silently broadcasts mismatched halves (or
        # averages over an empty slice and yields NaN).
        n_rows = batch.shape[0]
        if n_rows == 0 or n_rows % 2 != 0:
            raise ValueError(
                "Unsloth MLX: ORPO expects a concatenated [chosen; rejected] "
                f"batch with an even, non-zero number of rows; got {n_rows}."
            )
        B = n_rows // 2

        inputs = batch[:, :-1]
        targets = batch[:, 1:]
        logits = model(inputs)
        # steps[j] is the token position (in ``batch``) of target column j.
        steps = mx.arange(1, targets.shape[1] + 1)
        # Keep the per-token loss in float32 so the odds-ratio math below
        # (and its 1e-12 floor, which float16 cannot represent) does not
        # depend on the dtype of the model's logits.
        ce_tok = nn.losses.cross_entropy(logits, targets).astype(mx.float32)

        # Response-only mask: positions in [response_start, seq_end).
        mask = mx.logical_and(
            steps >= lengths[:, 0:1], steps < lengths[:, 1:]
        ).astype(mx.float32)
        ntok_row = mask.sum(axis=1)
        # Length-normalized response log-prob per row; rows with an empty
        # response span divide by 1 instead of 0.
        logp = (-ce_tok * mask).sum(axis=1) / mx.maximum(ntok_row, mx.array(1.0))
        logp_c, logp_r = logp[:B], logp[B:]

        # SFT term: TRL chosen_nll_loss. Pooled token-mean cross-entropy over
        # the full prompt+response span (all non-pad chosen tokens), matching
        # TRL's nn.CrossEntropyLoss default reduction. This is NOT the
        # response-only, length-normalized logp used for the odds-ratio term.
        nll_mask = (steps < lengths[:, 1:]).astype(mx.float32)[:B]
        sft = (ce_tok[:B] * nll_mask).sum() / mx.maximum(
            nll_mask.sum(), mx.array(1.0)
        )

        # Odds-ratio term. log(p / (1 - p)) per side; 1 - exp(logp) is
        # computed as -expm1(logp) and floored so log() never sees zero.
        floor = mx.array(1e-12, dtype=mx.float32)
        one_minus_p_c = mx.maximum(-mx.expm1(logp_c), floor)
        one_minus_p_r = mx.maximum(-mx.expm1(logp_r), floor)
        log_odds = (logp_c - mx.log(one_minus_p_c)) - (
            logp_r - mx.log(one_minus_p_r)
        )
        or_loss = -mx.mean(nn.log_sigmoid(log_odds))

        loss = sft + beta * or_loss
        return loss, mask.sum()
    return loss_fn
|
|
||
|
|
||
def make_dpo_loss_fn(beta=0.1, lora_mods=None, reference_free=False):
    """Create a DPO (Direct Preference Optimization) loss function.

    Operates on a concatenated [chosen; rejected] batch (same layout as
    make_orpo_loss_fn / create_preference_batches): rows [0:B] chosen,
    [B:2B] rejected. Signature matches the other loss fns:
    (model, batch, lengths, labels=None) -> (loss, ntoks).

    DPO compares the policy's per-response log-probs against a frozen
    reference. For LoRA models the reference is the base model: obtained by
    temporarily zeroing the LoRA scales (adapters off), running the reference
    forward under stop_gradient, then restoring the scales in a finally block.
    Mirrors TRL's disable-adapter approach (no second model copy). With
    reference_free=True the reference term is dropped, matching TRL.

    Args:
        beta: Scale applied to the policy-vs-reference log-ratio difference
            before the log-sigmoid.
        lora_mods: Iterable of modules exposing a writable ``scale``
            attribute (the LoRALinear modules to toggle); collected once by
            the trainer at setup. ``None`` means no adapters.
        reference_free: When True, the reference log-probs are taken as zero.

    Returns:
        A closure ``loss_fn(model, batch, lengths, labels=None)`` returning
        ``(loss, ntoks)``; ``ntoks`` is the number of response targets
        selected by the response mask over all rows.

    Raises:
        ValueError: At construction, when no LoRA modules are given and
            ``reference_free`` is False. The closure raises ValueError when
            ``batch`` has zero or an odd number of rows.
    """
    _mods = list(lora_mods) if lora_mods is not None else []
    if not _mods and not reference_free:
        raise ValueError(
            "Unsloth: DPO with a reference model is not yet supported for full "
            "fine-tuning on MLX — the reference is obtained by disabling LoRA "
            "adapters, but this model has none. Use a LoRA/PEFT model, or pass "
            "reference_free=True to train without a reference (TRL-style "
            "reference-free DPO)."
        )

    def _row_logp_and_mask(model, batch, lengths):
        """Return (summed response log-prob per row, total masked targets)."""
        inputs = batch[:, :-1]
        targets = batch[:, 1:]
        logits = model(inputs)
        # steps[j] is the token position (in ``batch``) of target column j.
        steps = mx.arange(1, targets.shape[1] + 1)
        mask = mx.logical_and(
            steps >= lengths[:, 0:1], steps < lengths[:, 1:]
        ).astype(mx.float32)
        logp_tok = -nn.losses.cross_entropy(logits, targets) * mask
        return logp_tok.sum(axis=1), mask.sum()

    def loss_fn(model, batch, lengths, labels=None):
        # Validate the [chosen; rejected] layout before any forward pass or
        # adapter toggling. Without this, an odd row count silently
        # broadcasts mismatched halves (or averages over an empty slice).
        n_rows = batch.shape[0]
        if n_rows == 0 or n_rows % 2 != 0:
            raise ValueError(
                "Unsloth MLX: DPO expects a concatenated [chosen; rejected] "
                f"batch with an even, non-zero number of rows; got {n_rows}."
            )
        B = n_rows // 2

        pol, ntoks = _row_logp_and_mask(model, batch, lengths)
        pol_c, pol_r = pol[:B], pol[B:]

        if reference_free or not _mods:
            ref_c = mx.zeros(pol_c.shape)
            ref_r = mx.zeros(pol_r.shape)
        else:
            # Adapters off -> base model acts as the frozen reference. The
            # original scales are always restored, even if the forward fails.
            saved = [md.scale for md in _mods]
            try:
                for md in _mods:
                    md.scale = 0.0
                ref, _ = _row_logp_and_mask(model, batch, lengths)
                ref = mx.stop_gradient(ref)
            finally:
                for md, s in zip(_mods, saved):
                    md.scale = s
            ref_c, ref_r = ref[:B], ref[B:]

        logits = beta * ((pol_c - ref_c) - (pol_r - ref_r))
        loss = -mx.mean(nn.log_sigmoid(logits))
        return loss, ntoks
    return loss_fn
|
|
||
| def create_preference_batches(dataset, tokenizer, batch_size, max_seq_length, | ||
| prompt_key="prompt", chosen_key="chosen", | ||
| rejected_key="rejected", pad_to_multiple=32, | ||
| num_batches=None, dataset_order="default", | ||
| preserve_dataset_order=False, seed=None, | ||
| append_eos=True): | ||
| """Build concatenated [chosen; rejected] preference batches for ORPO/DPO. | ||
|
|
||
| Each example contributes ``prompt + chosen`` and ``prompt + rejected``. | ||
| Pairs are grouped into batches of ``batch_size`` PAIRS, and every row in a | ||
| batch is padded to that batch's max length, rounded up to | ||
| ``pad_to_multiple`` (Apple-Silicon padding). | ||
|
|
||
| ``dataset_order`` controls how pairs are ordered before batching (mirrors | ||
| the SFT/VLM builders so preference runs can match CUDA/TRL parity): | ||
| "default" length-sort by ``max(len(chosen), len(rejected))`` — least | ||
| padding / best throughput; the historical behavior. | ||
| "sequential" keep dataset order — matches CUDA ``SequentialSampler``. | ||
| "torch_randperm" seeded permutation — matches CUDA ``RandomSampler``. | ||
| ``preserve_dataset_order=True`` forces "sequential" (Studio wiring). | ||
| ``seed`` seeds the "torch_randperm" order. | ||
|
|
||
| ``append_eos`` (default True, matching TRL) appends the tokenizer's EOS id | ||
| to each chosen/rejected completion, guarded to avoid a double EOS (mirrors | ||
| TRL's ``add_eos_token_if_needed``). EOS is appended before truncation to | ||
| ``max_seq_length`` (same contract as the SFT path). Because EOS lands after | ||
| the prompt boundary it falls inside the ``[response_start, seq_end)`` loss | ||
| span, so it is trained on — as in TRL. | ||
|
|
||
| Returns a list of ``(batch, lengths, None)`` tuples: | ||
| batch: (2B, L) int32 — rows [0:B] chosen, [B:2B] rejected, paired by index | ||
| lengths: (2B, 2) — per row [response_start, seq_end) | ||
| """ | ||
| order_mode = "sequential" if preserve_dataset_order else dataset_order | ||
| if order_mode not in ("default", "sequential", "torch_randperm"): | ||
| raise ValueError( | ||
| f"Unsloth MLX: unsupported preference dataset_order={order_mode!r}. " | ||
| "Expected 'default', 'sequential', or 'torch_randperm'." | ||
| ) | ||
| hf = getattr(tokenizer, "_tokenizer", tokenizer) | ||
| eos_id = hf.eos_token_id | ||
| pad_id = eos_id if eos_id is not None else 0 | ||
|
|
||
| def _with_eos(ids): | ||
| # TRL appends EOS to the completion, avoiding a double EOS | ||
| # (add_eos_token_if_needed). Append before max_seq_length truncation, | ||
| # matching the SFT path's append-then-truncate contract. | ||
| if append_eos and eos_id is not None and (not ids or ids[-1] != eos_id): | ||
| ids = ids + [eos_id] | ||
| return ids[:max_seq_length] | ||
|
|
||
| rows = [] | ||
| for ex in dataset: | ||
| if prompt_key not in ex or chosen_key not in ex or rejected_key not in ex: | ||
| raise ValueError( | ||
| f"ORPO requires '{prompt_key}', '{chosen_key}', '{rejected_key}' " | ||
| f"columns; got {sorted(ex.keys())}." | ||
| ) | ||
| prompt = ex[prompt_key] | ||
| p_ids = hf.encode(prompt) | ||
| c_ids = _with_eos(hf.encode(prompt + ex[chosen_key])) | ||
| r_ids = _with_eos(hf.encode(prompt + ex[rejected_key])) | ||
| pe = min(len(p_ids), len(c_ids), len(r_ids)) | ||
| rows.append((pe, c_ids, r_ids)) | ||
|
Comment on lines
+673
to
+684
Contributor
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The current implementation of create_orpo_batches directly encodes prompt and prompt + ex[chosen_key] as raw strings. This has two major issues:
To fix this, we should normalize the inputs into message lists, apply the tokenizer's chat template (if available) to format them with control tokens, and then encode the resulting formatted strings.
for ex in dataset:
if prompt_key not in ex or chosen_key not in ex or rejected_key not in ex:
raise ValueError(
f"ORPO requires '{prompt_key}', '{chosen_key}', '{rejected_key}' "
f"columns; got {sorted(ex.keys())}."
)
prompt = ex[prompt_key]
chosen = ex[chosen_key]
rejected = ex[rejected_key]
if isinstance(prompt, str):
prompt_messages = [{"role": "user", "content": prompt}]
else:
prompt_messages = prompt
if isinstance(chosen, str):
chosen_messages = [{"role": "assistant", "content": chosen}]
else:
chosen_messages = chosen
if isinstance(rejected, str):
rejected_messages = [{"role": "assistant", "content": rejected}]
else:
rejected_messages = rejected
apply_tmpl = getattr(tokenizer, "apply_chat_template", None) or getattr(hf, "apply_chat_template", None)
if apply_tmpl is not None:
prompt_str = apply_tmpl(prompt_messages, tokenize=False, add_generation_prompt=True)
chosen_str = apply_tmpl(prompt_messages + chosen_messages, tokenize=False)
rejected_str = apply_tmpl(prompt_messages + rejected_messages, tokenize=False)
else:
prompt_str = prompt if isinstance(prompt, str) else ""
chosen_str = prompt_str + (chosen if isinstance(chosen, str) else "")
rejected_str = prompt_str + (rejected if isinstance(rejected, str) else "")
p_ids = hf.encode(prompt_str)
c_ids = hf.encode(chosen_str)[:max_seq_length]
r_ids = hf.encode(rejected_str)[:max_seq_length]
pe = min(len(p_ids), len(c_ids), len(r_ids))
rows.append((pe, c_ids, r_ids)) |
||
|
|
||
| # rows are collected in dataset order above; reorder per the requested mode. | ||
| if order_mode == "default": | ||
| rows.sort(key=lambda t: max(len(t[1]), len(t[2]))) | ||
| elif order_mode == "torch_randperm": | ||
| order = _torch_randperm_order(len(rows), _normalize_seed(seed)) | ||
| rows = [rows[i] for i in order] | ||
| # "sequential": leave rows in dataset order (CUDA SequentialSampler parity). | ||
|
|
||
| out = [] | ||
| for i in range(0, len(rows), batch_size): | ||
| chunk = rows[i:i + batch_size] | ||
| Lmax = max(max(len(c), len(r)) for _, c, r in chunk) | ||
| if pad_to_multiple: | ||
| Lmax = ((Lmax + pad_to_multiple - 1) // pad_to_multiple) * pad_to_multiple | ||
| chosen_rows, rejected_rows, lengths = [], [], [] | ||
| for pe, c, r in chunk: | ||
| chosen_rows.append(c + [pad_id] * (Lmax - len(c))) | ||
| lengths.append([pe, len(c)]) | ||
| for pe, c, r in chunk: | ||
| rejected_rows.append(r + [pad_id] * (Lmax - len(r))) | ||
| lengths.append([pe, len(r)]) | ||
| batch = mx.array(chosen_rows + rejected_rows, dtype=mx.int32) | ||
| lengths_arr = mx.array(lengths) | ||
| out.append((batch, lengths_arr, None)) | ||
| if num_batches is not None and len(out) >= num_batches: | ||
| break | ||
| mx.eval([b for b, _, _ in out] + [l for _, l, _ in out]) | ||
| return out | ||
|
|
||
|
|
||
| def make_baseline_loss_fn(): | ||
| """Create a standard cross-entropy loss function (full logits via LM head). | ||
|
|
||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
In low-precision training (such as float16 or bfloat16), computing odds-ratio terms directly can lead to numerical instability. Specifically, the constant 1e-12 underflows to 0.0 in float16 (where the minimum positive subnormal is 5.96e-8). If logp_c is 0.0 (e.g., if a token has a probability of 1.0), -mx.expm1(logp_c) becomes 0.0, and mx.maximum(0.0, 1e-12) will evaluate to 0.0 due to underflow. This results in mx.log(0.0) = -inf, causing NaN gradients and training crashes.
To prevent this, perform the odds-ratio calculations in float32 and cast the final or_loss back to the original dtype.