Feat(mlx): GRPO support for MLX #832

Open
BardiaKoopah wants to merge 9 commits into unslothai:main from BardiaKoopah:feat/mlx-grpo

@BardiaKoopah

Contributor

Add GRPO (reinforcement learning) for text models on MLX

Adds GRPO (Group Relative Policy Optimization) to the MLX trainer — the first RL method on the Apple Silicon path. Mirrors TRL's GRPOTrainer / GRPOConfig API.

What's included

  • MLXGRPOTrainer + MLXGRPOConfig — TRL-style classes, consistent with the existing MLXDPOTrainer / MLXORPOTrainer.
  • make_grpo_loss_fn (in utils.py) — token-level loss matching TRL's GRPOTrainer._compute_loss:
    • k3 KL estimator: exp(ref − pol) − (ref − pol) − 1
    • PPO-clip: coef_1 = exp(logp − stop_grad(logp)), coef_2 = clip(coef_1, 1−ε, 1+ε), per_token_loss = −min(coef_1·adv, coef_2·adv) + β·KL
    • masked-mean reduction over completion tokens
    • group-relative advantages: (r − mean) / (std + 1e-4) (same 1e-4 epsilon as TRL)
  • Reward-function interface matching TRL: reward_funcs accepts a single callable or a list (summed); signature fn(completions, prompts=..., **kwargs) -> list[float].
  • KL reference via LoRA-disable — same mechanism as the DPO PR (zero the LoRA scales under stop_gradient, restore in finally); verified to survive mx.compile.
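
The loss terms listed above can be sketched in NumPy as follows. This is an illustration under stated assumptions, not the PR's make_grpo_loss_fn: the function name, the clip range eps=0.2, and the toy inputs are hypothetical, and old_logp stands in for stop_grad(logp). The reduction follows the per-row masked mean shown later in the review thread.

```python
import numpy as np

def grpo_loss_sketch(pol_logp, old_logp, ref_logp, advantages, mask,
                     beta=0.04, eps=0.2):
    """Token-level GRPO loss (NumPy sketch; eps=0.2 is an assumed default)."""
    # k3 KL estimator: exp(ref - pol) - (ref - pol) - 1, which is always >= 0.
    delta = ref_logp - pol_logp
    per_token_kl = np.exp(delta) - delta - 1.0

    # PPO-clip: ratio of the current policy to the (gradient-stopped) old policy.
    coef_1 = np.exp(pol_logp - old_logp)
    coef_2 = np.clip(coef_1, 1.0 - eps, 1.0 + eps)
    adv = advantages[:, None]  # one advantage per row, broadcast over tokens
    per_token_loss = -np.minimum(coef_1 * adv, coef_2 * adv) + beta * per_token_kl

    # Masked mean over completion tokens per row, then mean over rows.
    tok_per_row = np.maximum(mask.sum(-1), 1.0)
    loss = ((per_token_loss * mask).sum(-1) / tok_per_row).mean()
    return loss, per_token_kl

# Toy inputs: 2 rows, 2 tokens each; the last token of row 1 is padding.
pol = np.log(np.array([[0.5, 0.25], [0.1, 0.2]]))
ref = np.log(np.array([[0.4, 0.25], [0.2, 0.2]]))
mask = np.array([[1.0, 1.0], [1.0, 0.0]])
adv = np.array([1.0, -1.0])
loss, kl = grpo_loss_sketch(pol, pol, ref, adv, mask)
```

With old_logp equal to pol_logp the ratio is 1 everywhere, so the forward value reduces to the negated advantages plus beta times the KL term.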

Rollout engine

Uses mlx-lm generation for rollouts (the group of completions per prompt). This is parity with CUDA Unsloth's non-vLLM generation path (the else branch of if use_vllm:).

vLLM-style fast rollout (via vllm-metal) is a planned follow-up. It currently has no installable distribution in the Studio Python 3.13 environment (pip install vllm-metal → "No matching distribution found"), so it is deferred as a pluggable fast-path rather than a hard dependency. The rollout is isolated in _grpo_rollout_generator, so a faster engine can slot in behind the same seam later.

Verification (Qwen2.5-0.5B, 4-bit)

  • Rollout produces diverse completions (temperature sampling).
  • Advantages correct: nonzero when rewards vary within a group, zero when uniform.
  • Loss finite, KL ≥ 0 (k3 property), reference scales restored (forward and through grad).
  • Model measurably learns — mean reward 0.709 → 0.759 over 20 steps with a graded reward (upward trend confirms gradient direction is correct).
  • Edge/stress paths tested: grad_accum > 1, multiple summed reward functions, reference_free=True end-to-end, num_generations=2, and uniform-reward stress (std=0 → graceful zero-grad, no NaN).
  • Regressions: SFT / DPO / ORPO still train; existing MLX suite 90 passed.
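
The advantage behavior checked above (nonzero when rewards vary within a group, zero when uniform) can be reproduced with a short NumPy sketch. The helper name is hypothetical, and using the population standard deviation (ddof=0) is an assumption; the PR text only specifies the (r − mean) / (std + 1e-4) form.

```python
import numpy as np

def group_advantages(rewards, num_generations, eps=1e-4):
    """Group-relative advantages: (r - mean) / (std + eps) within each prompt group."""
    r = np.asarray(rewards, dtype=np.float64).reshape(-1, num_generations)
    mean = r.mean(axis=1, keepdims=True)
    std = r.std(axis=1, keepdims=True)  # ddof=0 is an assumption of this sketch
    return ((r - mean) / (std + eps)).reshape(-1)

# Uniform rewards in a group: std is 0, eps keeps the division finite, result is 0.
uniform = group_advantages([1.0, 1.0, 1.0, 1.0], num_generations=4)
# Varying rewards in each group of 2: advantages are nonzero and sum to zero.
varied = group_advantages([0.0, 1.0, 0.0, 1.0], num_generations=2)
```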

Scope / follow-ups

  • Text-only — VLM raises a clear guard ("GRPO is not yet supported for VLM models on MLX.").
  • vllm-metal fast rollout — when installable / benchmarked against mlx-lm.
  • num_iterations > 1 (PPO multi-epoch) — the clip structure is in place but only exercised at num_iterations = 1 (where coef_1 = 1).
  • Reward models (e.g. AutoModelForSequenceClassification) — callable reward functions only for now.
  • Smoke-tested for correctness, not quality-benchmarked against CUDA GRPO on a full task.

Note on base branch

Built on feat/mlx-orpo (reuses the LoRA-disable reference pattern from DPO). Best reviewed / merged after the ORPO/DPO PR; the diff here shows only the GRPO commit. Once ORPO/DPO lands, this rebases cleanly onto main as a single-commit PR.

@gemini-code-assist gemini-code-assist Bot left a comment

Code Review

This pull request adds support for ORPO, DPO, and GRPO training to the MLX trainer, introducing new configuration classes, specialized trainers, preference batch creation helpers, and corresponding loss functions. The review feedback highlights several opportunities for improvement, including utilizing the existing iter_mlx_lora_modules helper to find LoRA modules (which removes the dependency on LoRALinear), optimizing prompt extraction for large datasets, and padding GRPO rollout batches to a multiple of 32 to avoid excessive recompilations on Apple Silicon. Additionally, the reviewer recommends specifying explicit dtypes for mx.zeros to prevent unintended float32 promotion and points out a limitation in the GRPO loss function's PPO clipping mechanism when num_iterations > 1 is used.

import mlx.nn as nn
import mlx.optimizers as optim
from mlx.utils import tree_flatten, tree_map, tree_reduce, tree_unflatten
from mlx_lm.tuner.lora import LoRALinear

medium

Since we can use iter_mlx_lora_modules to find LoRA modules, we no longer need to import LoRALinear from mlx_lm.tuner.lora.

Comment on lines +1272 to +1274
_lora_mods = [mod for _, mod in tree_flatten(
model, is_leaf=lambda x: isinstance(x, LoRALinear))
if isinstance(mod, LoRALinear)]

medium

Using tree_flatten with a custom is_leaf filter to find LoRALinear modules is less robust and introduces an unnecessary dependency on LoRALinear from mlx_lm. We can use the existing iter_mlx_lora_modules(model) helper instead, which is cleaner and more robust.

                _lora_mods = [mod for _, mod in iter_mlx_lora_modules(model)]

Comment on lines +1282 to +1284
_lora_mods = [mod for _, mod in tree_flatten(
model, is_leaf=lambda x: isinstance(x, LoRALinear))
if isinstance(mod, LoRALinear)]

medium

Using tree_flatten with a custom is_leaf filter to find LoRALinear modules is less robust and introduces an unnecessary dependency on LoRALinear from mlx_lm. We can use the existing iter_mlx_lora_modules(model) helper instead, which is cleaner and more robust.

                _lora_mods = [mod for _, mod in iter_mlx_lora_modules(model)]

Comment on lines +2426 to +2433
def _grpo_prompts(self):
    """Extract prompt strings from the dataset (expects a 'prompt' column)."""
    prompts = []
    for ex in self.train_dataset:
        if "prompt" not in ex:
            raise ValueError("GRPO requires a 'prompt' column in the dataset.")
        prompts.append(ex["prompt"])
    return prompts

medium

Iterating over the entire dataset in a Python loop to extract the 'prompt' column can be slow and memory-intensive for large Hugging Face datasets. We can check if the dataset supports column indexing (e.g., self.train_dataset['prompt']) to retrieve the column instantly.

Suggested change
- def _grpo_prompts(self):
-     """Extract prompt strings from the dataset (expects a 'prompt' column)."""
-     prompts = []
-     for ex in self.train_dataset:
-         if "prompt" not in ex:
-             raise ValueError("GRPO requires a 'prompt' column in the dataset.")
-         prompts.append(ex["prompt"])
-     return prompts
+ def _grpo_prompts(self):
+     """Extract prompt strings from the dataset (expects a 'prompt' column)."""
+     if hasattr(self.train_dataset, "column_names") and "prompt" in self.train_dataset.column_names:
+         return self.train_dataset["prompt"]
+     prompts = []
+     for ex in self.train_dataset:
+         if "prompt" not in ex:
+             raise ValueError("GRPO requires a 'prompt' column in the dataset.")
+         prompts.append(ex["prompt"])
+     return prompts

Comment on lines +2471 to +2472
L = max(len(r) for r in rows)
batch = mx.array([r + [pad_id] * (L - len(r)) for r in rows], dtype=mx.int32)

medium

Since the training step is compiled with mx.compile, having dynamic sequence lengths that are not aligned to a multiple of 32 can cause excessive recompilations or suboptimal Metal execution. Padding L to a multiple of 32 is highly recommended to improve performance on Apple Silicon.

Suggested change
- L = max(len(r) for r in rows)
- batch = mx.array([r + [pad_id] * (L - len(r)) for r in rows], dtype=mx.int32)
+ L = max(len(r) for r in rows)
+ pad_to_multiple = 32
+ if pad_to_multiple:
+     L = ((L + pad_to_multiple - 1) // pad_to_multiple) * pad_to_multiple
+ batch = mx.array([r + [pad_id] * (L - len(r)) for r in rows], dtype=mx.int32)

Comment thread unsloth_zoo/mlx/utils.py
Comment on lines +588 to +589
ref_c = mx.zeros(pol_c.shape)
ref_r = mx.zeros(pol_r.shape)

medium

Calling mx.zeros without specifying a dtype defaults to float32, which can cause low-precision tensors (like float16 or bfloat16) to be promoted to float32 during operations like subtraction or addition. Specifying dtype=pol_c.dtype prevents unintended promotion.

Suggested change
- ref_c = mx.zeros(pol_c.shape)
- ref_r = mx.zeros(pol_r.shape)
+ ref_c = mx.zeros(pol_c.shape, dtype=pol_c.dtype)
+ ref_r = mx.zeros(pol_r.shape, dtype=pol_r.dtype)

Comment thread unsloth_zoo/mlx/utils.py
md.scale = s
per_token_kl = mx.exp(ref_logp - pol_logp) - (ref_logp - pol_logp) - 1
else:
per_token_kl = mx.zeros(pol_logp.shape)

medium

Calling mx.zeros without specifying a dtype defaults to float32, which can cause low-precision tensors (like float16 or bfloat16) to be promoted to float32 during operations like subtraction or addition. Specifying dtype=pol_logp.dtype prevents unintended promotion.

Suggested change
- per_token_kl = mx.zeros(pol_logp.shape)
+ per_token_kl = mx.zeros(pol_logp.shape, dtype=pol_logp.dtype)

Comment thread unsloth_zoo/mlx/utils.py
else:
per_token_kl = mx.zeros(pol_logp.shape)

old_logp = mx.stop_gradient(pol_logp)

medium

Because old_logp = mx.stop_gradient(pol_logp) is computed inside the loss function, the ratio coef_1 = exp(pol_logp - old_logp) evaluates to exactly 1 in the forward pass of every step. This is fine for num_iterations = 1, but if num_iterations > 1 is ever used, the PPO clipping never engages across the optimization epochs. To support multi-epoch PPO/GRPO, the old-policy log-probabilities of the rollout must be computed during generation and passed to the loss function.
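
A small NumPy sketch of the difference, with made-up probabilities and an assumed clip range of 0.2 (neither comes from the PR): when old_logp is recomputed from the current policy the ratio is pinned at 1, whereas log-probs cached at rollout time let the ratio move and the clip engage.

```python
import numpy as np

def clip_ratio(pol_logp, old_logp, eps=0.2):
    """Return the unclipped and clipped importance ratios."""
    coef_1 = np.exp(pol_logp - old_logp)
    return coef_1, np.clip(coef_1, 1.0 - eps, 1.0 + eps)

# Hypothetical per-token probabilities, for illustration only.
rollout_logp = np.log(np.array([0.20, 0.50]))  # cached when the rollout was sampled
updated_logp = np.log(np.array([0.30, 0.45]))  # same tokens after an optimizer step

# old_logp recomputed from the current policy: ratio is identically 1.
same_1, same_2 = clip_ratio(updated_logp, updated_logp)
# old_logp cached from the rollout: ratio is [1.5, 0.9], clipped to [1.2, 0.9].
cached_1, cached_2 = clip_ratio(updated_logp, rollout_logp)
```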

@chatgpt-codex-connector chatgpt-codex-connector Bot left a comment

💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: e7f9c12a43

Comment thread unsloth_zoo/mlx/trainer.py Outdated
# rewards: sum across reward functions (TRL-style)
total = [0.0] * N
for rf in self.reward_funcs:
vals = rf(completions=comps, prompts=[prompt] * N)

P2 Badge Pass reward dataset columns to reward functions

For GRPO datasets where rewards need labels or metadata (for example an answer column), _grpo_prompts() has already discarded the source example and this call only passes completions and prompts. A TRL-style reward such as def reward_func(completions, answer, **kwargs) will raise or be forced to recover labels out-of-band, so common supervised GRPO tasks cannot run correctly; keep the source example and pass each non-prompt column repeated for the N generations.
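
One way this could look, as a sketch only: keep the source example, repeat each non-prompt column once per generation, and forward the columns as keyword arguments. The helper call_reward_funcs and the reward exact_match are hypothetical names, not part of the PR or of TRL.

```python
def call_reward_funcs(reward_funcs, example, completions, prompt_key="prompt"):
    """Sum rewards across functions, forwarding dataset columns as kwargs."""
    n = len(completions)
    # Repeat every non-prompt column once per generated completion.
    extra = {k: [v] * n for k, v in example.items() if k != prompt_key}
    total = [0.0] * n
    for rf in reward_funcs:
        vals = rf(completions=completions, prompts=[example[prompt_key]] * n, **extra)
        total = [t + float(v) for t, v in zip(total, vals)]
    return total

def exact_match(completions, answer, **kwargs):
    """TRL-style reward that needs the dataset's 'answer' column."""
    return [1.0 if c.strip() == a else 0.0 for c, a in zip(completions, answer)]

example = {"prompt": "2+2=", "answer": "4"}
rewards = call_reward_funcs([exact_match], example, ["4", " 5", "4 "])
```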

Comment thread unsloth_zoo/mlx/utils.py

tok_per_row = mx.maximum(mask.sum(-1), 1.0)
loss = ((per_token_loss * mask).sum(-1) / tok_per_row).mean()
return loss, mask.sum()

P2 Badge Avoid token-weighting accumulated GRPO gradients

With the default gradient_accumulation_steps=4, _train_inner scales each micro-batch gradient by the returned toks before dividing by the accumulated token count. This GRPO loss is already averaged per generated row, so returning mask.sum() makes longer completions/prompts carry more optimizer weight than shorter ones instead of averaging the GRPO objective per sample/group; return a sample/group count or bypass the token-scaling path for this loss.
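
A toy calculation of the effect, using scalar stand-ins for micro-batch gradients (all numbers are invented for illustration): weighting by token count lets the longer micro-batch dominate, while weighting by row count averages the per-row objective evenly.

```python
def accumulate(grads, weights):
    """Weighted average of micro-batch gradients (scalars stand in for tensors)."""
    total = sum(weights)
    return sum(g * w for g, w in zip(grads, weights)) / total

micro_grads = [1.0, 3.0]   # gradient of the per-row-averaged loss in each micro-batch
token_counts = [10, 90]    # completion tokens in each micro-batch
row_counts = [4, 4]        # generated rows in each micro-batch

token_weighted = accumulate(micro_grads, token_counts)  # pulled toward the long batch
row_weighted = accumulate(micro_grads, row_counts)      # plain average of the two
```

Here token_weighted comes out at 2.8 and row_weighted at 2.0, so the two accumulation schemes produce different optimizer steps from the same micro-batches.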

Comment thread unsloth_zoo/mlx/utils.py
def loss_fn(model, batch, lengths, advantages):
pol_logp, mask = _per_token_logp_and_mask(model, batch, lengths)

if beta != 0.0 and not reference_free and _mods:

P2 Badge Preserve a GRPO reference for non-LoRA training

When GRPO is run on a full-finetuned model, or any model without LoRALinear adapters, _mods is empty even though reference_free defaults to false and grpo_beta defaults to 0.04. This condition skips the reference forward and falls through to a zero KL term, silently disabling the regularization users requested; either materialize a frozen reference or reject non-LoRA GRPO unless reference_free=True.
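
The "reject" option could be a small guard run before training starts. The sketch below is hypothetical (function name, argument names, and error text are not from the PR); it only shows the decision logic: a KL reference is needed when beta is nonzero and reference_free is false, and without LoRA modules there is nothing to disable.

```python
def check_reference_available(lora_modules, beta, reference_free):
    """Return True if a KL reference forward is needed; fail loudly if it cannot be built."""
    needs_reference = beta != 0.0 and not reference_free
    if needs_reference and not lora_modules:
        raise ValueError(
            "GRPO with beta != 0 needs a reference policy, but no LoRA modules were "
            "found. Set reference_free=True or provide a frozen reference model."
        )
    return needs_reference
```

Failing early keeps the requested KL regularization from being dropped silently on full-finetuning runs.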

Comment thread unsloth_zoo/mlx/utils.py
pol, ntoks = _row_logp_and_mask(model, batch, lengths)
pol_c, pol_r = pol[:B], pol[B:]

if reference_free or not _mods:

P2 Badge Preserve a DPO reference for non-LoRA training

For full_finetuning=True or models whose adapters are not LoRALinear, _mods is empty while reference_free is false by default. This branch sets both reference log-probs to zero, so DPO silently becomes reference-free DPO rather than comparing against the frozen initial policy; create a reference model or fail unless reference_free=True.

eval_batches = None
text_completion_only_loss = _text_completion_only_loss_arg(args)
if args.eval_steps > 0 and self.eval_dataset is not None:
if (getattr(args, "loss_type", "sft") in ("orpo", "dpo")

P2 Badge Skip GRPO eval until an eval loss exists

When loss_type='grpo' and eval_steps/eval_dataset are set, this guard skips only ORPO/DPO, so the next branch builds ordinary SFT eval batches. _evaluate_batch_totals then passes labels=None as the fourth argument to make_grpo_loss_fn, which treats it as advantages and crashes on advantages.reshape; include GRPO in the unsupported-eval guard or add a real GRPO eval path.

raise ValueError(
f"{args.loss_type.upper()} is not yet supported for VLM models on MLX."
)
batches = create_preference_batches(

P2 Badge Honor streaming for preference trainers

When ORPO/DPO is used with args.streaming=True, this new branch bypasses the existing streaming path and unconditionally calls create_preference_batches, which iterates and materializes the whole dataset before training. For iterable or very large preference datasets this can hang before the first step or exhaust memory; either provide a streaming preference iterator or reject streaming=True for these losses.

Comment on lines +2451 to +2454
resp = batch_generate(
self.model, self.tokenizer, prompts=[pids] * N,
max_tokens=args.max_completion_length, sampler=sampler, verbose=False,
)

P2 Badge Generate GRPO rollouts in eval mode

The training loop has already called model.train() before this generator runs, so batch_generate samples rollouts with dropout enabled whenever the base model or LoRA adapters use nonzero dropout. Those completions are then optimized against a different stochastic forward pass in the loss, making GRPO off-policy/noisy for dropout-enabled runs; temporarily switch to eval for rollout generation and restore train mode afterward.
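
A context manager is one way to make the switch and the restore hard to get wrong. The sketch assumes only that the model exposes train(), eval(), and a training flag, as mlx.nn.Module does; rollout_eval_mode and ToyModel are hypothetical names used for illustration.

```python
from contextlib import contextmanager

@contextmanager
def rollout_eval_mode(model):
    """Run the body in eval mode, then restore train mode if it was active."""
    was_training = bool(getattr(model, "training", False))
    model.eval()
    try:
        yield model
    finally:
        if was_training:
            model.train()

class ToyModel:
    """Stand-in exposing train()/eval()/training; not a real MLX module."""
    def __init__(self):
        self.training = True
    def train(self):
        self.training = True
    def eval(self):
        self.training = False

model = ToyModel()
with rollout_eval_mode(model):
    mode_during_rollout = model.training  # False: dropout would be off here
mode_after_rollout = model.training       # True: train mode restored
```

The finally clause restores train mode even if generation raises partway through a rollout.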

Comment thread unsloth_zoo/mlx/trainer.py Outdated
while True:
prompt = prompts[idx % len(prompts)]
idx += 1
pids = hf.encode(prompt)

P2 Badge Render GRPO prompts before encoding

When users pass prompts as chat messages or set args.chat_template/formatting_func, GRPO never goes through the base text data preparation path and encodes ex["prompt"] directly here. A conversational prompt will fail in encode, and a plain instruct prompt will be generated without the chat template or assistant generation marker, so rewards are computed for a prompt format the model was not tuned to follow; render/format prompts before encoding.
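
A minimal sketch of such a rendering step, assuming a Hugging Face tokenizer with apply_chat_template: plain strings pass through, a conversation ending in a user turn gets the generation prompt, and one ending in an assistant turn is continued. The function name render_prompt and the error raised for other final roles are assumptions of this sketch.

```python
def render_prompt(tokenizer, prompt):
    """Render a conversational prompt to text; plain strings pass through unchanged."""
    if isinstance(prompt, str):
        return prompt
    last_role = prompt[-1]["role"]
    if last_role == "user":
        # Append the assistant generation marker so the model starts a reply.
        return tokenizer.apply_chat_template(
            prompt, tokenize=False, add_generation_prompt=True
        )
    if last_role == "assistant":
        # Leave the final assistant message open so generation continues it.
        return tokenizer.apply_chat_template(
            prompt, tokenize=False, continue_final_message=True
        )
    # Assumption: other final roles are treated as invalid input.
    raise ValueError(f"Unsupported final role: {last_role!r}")
```

The rendered string, rather than the raw message list, is then what gets encoded into prompt IDs.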

Comment thread unsloth_zoo/mlx/trainer.py Outdated
pe = len(pids)
rows, lengths = [], []
for c in comps:
full = hf.encode(prompt + c)[: args.max_seq_length]

P2 Badge Optimize the sampled GRPO token IDs

GRPO currently takes decoded strings from batch_generate and re-encodes prompt + c for the loss. Decode/encode is not guaranteed to round-trip to the sampled token IDs (byte fallback, cleanup, BPE boundary merges, and stop handling can change the sequence), so the gradient can be computed for tokens that were not actually generated; request return_token_ids=True and build full from pids + generated_ids.
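
The round-trip problem can be shown with a toy three-token vocabulary and greedy longest-match encoding. Everything here is invented for illustration and is much simpler than a real BPE tokenizer, but the failure mode is the same: the sampled IDs and the re-encoded IDs decode to the same text while being different sequences.

```python
VOCAB = {"a": 0, "b": 1, "ab": 2}
INV = {i: s for s, i in VOCAB.items()}

def decode(ids):
    return "".join(INV[i] for i in ids)

def encode(text):
    """Greedy longest-match encoding over the toy vocabulary."""
    ids, pos = [], 0
    pieces = sorted(VOCAB, key=len, reverse=True)
    while pos < len(text):
        piece = next(p for p in pieces if text.startswith(p, pos))
        ids.append(VOCAB[piece])
        pos += len(piece)
    return ids

prompt_ids = [0]      # "a"
sampled_ids = [0, 1]  # the model sampled "a" then "b" as separate tokens

# Re-encoding the decoded text merges "a" + "b" into the single token "ab".
reencoded = encode(decode(prompt_ids) + decode(sampled_ids))  # [0, 2]
# Concatenating IDs keeps exactly what was sampled.
from_ids = prompt_ids + sampled_ids                           # [0, 0, 1]
```

A loss computed on reencoded would score a token the policy never sampled, which is the concern raised above.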

@Lyxot

Lyxot commented Jun 26, 2026

Contributor

I ran a stricter parity check against CUDA after checking the GRPO notebooks, because the notebook path uses chat-style prompts and reward functions that depend on dataset columns like answer.

What I matched for the fair numeric run:

  • Model: Qwen/Qwen3-0.6B
  • LoRA: r=8, alpha=16
  • 8 train steps, 32 rows, 4 generations
  • max completion length 32
  • lr 1e-4, beta 0.04, max grad norm 0.1
  • no vLLM, no dataset shuffle
  • byte-identical rendered Qwen chat prompt on MLX and CUDA

Prompt hash matched on both backends:

e5c0ecfc914e0f9fbb4d61b3cef3d3a70e00f6cef5a8ffce6549d769b0fc6652

Result summary

| Backend | Avg loss | Avg grad norm | Avg reward | Finite? |
| --- | --- | --- | --- | --- |
| MLX | 0.00141 | 2.0556 | 0.63397 | yes |
| CUDA | 0.00632 | 2.0596 | 0.63309 | yes |

CUDA had one KL/loss spike at step 7. Excluding only that spike, avg loss is MLX 0.00118 vs CUDA 0.00217.

Loss curve:

MLX : [-0.0000006, 0.000245, 0.000940, 0.001277, 0.001331, 0.001969, 0.003023, 0.002522]
CUDA: [-0.0000026, 0.000664, 0.001611, 0.001715, 0.003509, 0.002928, 0.035392, 0.004745]

Grad norm curve:

MLX : [1.896, 2.229, 1.715, 2.122, 2.358, 1.775, 2.528, 1.823]
CUDA: [1.748, 1.623, 1.288, 1.939, 1.489, 1.468, 5.166, 1.754]

Reward mean curve:

MLX : [0.20850, 0.21075, 0.20650, 0.90875, 0.40950, 0.21175, 0.21050, 2.70550]
CUDA: [0.20700, 0.21150, 0.20825, 0.90950, 0.40775, 0.21100, 0.20525, 2.70450]

Reward std curve:

MLX : [0.00602, 0.00334, 0.00427, 0.00390, 0.00320, 0.00228, 0.00180, 0.00456]
CUDA: [0.00424, 0.00173, 0.00330, 0.00645, 0.00602, 0.00535, 0.00499, 0.00191]

CUDA KL curve:

[-0.000000001, 0.01671, 0.04027, 0.04288, 0.08770, 0.07324, 0.88485, 0.11937]

What looks good

The fair run is finite on both backends. After fixing the harness so the prompt text is identical, the reward curves are almost identical. That means the main loss gap is probably not caused by prompt formatting or reward scoring in this rendered-prompt path.

Issues to address

  1. Notebook-style conversational prompts still fail on MLX before rollout.

The GRPO notebooks pass prompt as chat messages, e.g. [{"role": "system", ...}, {"role": "user", ...}]. The current MLX path appears to call hf.encode(prompt) directly, which raises a tokenizer ValueError for list-of-message prompts. This should use the chat template path, equivalent to apply_chat_template(..., add_generation_prompt=True).

  2. Reward function kwargs do not match TRL.

In the fair run, CUDA reward funcs received answer, completion_ids, question, and trainer_state. MLX reward funcs only received completions and prompts. The notebook reward functions depend on answer, so this will break normal notebook-style usage.

  3. MLX appears to train from decoded text re-encoding instead of generated token IDs.

TRL trains from the actual generated completion_ids. The MLX path appears to rebuild rows with something like hf.encode(prompt + completion_text). Re-encoding decoded text can change token boundaries, masks, logprobs, and KL/loss.

  4. Loss/KL still differs after prompt and reward are matched.

CUDA loss roughly tracks beta * KL, and the biggest CUDA loss point corresponds to a KL spike. MLX currently does not log KL, so it is hard to tell whether the gap is reference-logprob, masking, tokenization, or quantization related.
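
The "loss roughly tracks beta * KL" observation can be checked directly from the CUDA curves posted in this comment. The snippet below only redoes that arithmetic with beta = 0.04; the variable names are illustrative.

```python
beta = 0.04
cuda_loss = [-0.0000026, 0.000664, 0.001611, 0.001715,
             0.003509, 0.002928, 0.035392, 0.004745]
cuda_kl = [-0.000000001, 0.01671, 0.04027, 0.04288,
           0.08770, 0.07324, 0.88485, 0.11937]

# Per-step gap between the logged loss and beta * KL.
residual = [loss - beta * kl for loss, kl in zip(cuda_loss, cuda_kl)]
max_abs_residual = max(abs(r) for r in residual)

# Step with the largest loss and step with the largest KL (0-based indices).
spike_loss_step = cuda_loss.index(max(cuda_loss))
spike_kl_step = cuda_kl.index(max(cuda_kl))
```

On these eight steps the gap stays below 5e-5 and both maxima fall on the same step, which is why a logged MLX KL would make the comparison much easier.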

Suggested next step

Please add a fixed-rollout parity test before relying on stochastic train curves:

  • same prompt IDs
  • same generated completion IDs
  • same rewards and advantages
  • no sampling randomness
  • compare loss and grad norm MLX vs CUDA

That will isolate the GRPO loss/grad math from generation differences. After that passes, the stochastic train-curve comparison will be much easier to interpret.

The MLX ORPO loss computed its SFT/NLL term from the response-only,
length-normalized log-prob (the same quantity used for the odds-ratio
term). TRL's chosen_nll_loss is instead a pooled token-mean cross-entropy
over the full prompt+response span (all non-pad chosen tokens, matching
nn.CrossEntropyLoss default reduction). The narrower span made the SFT
signal trivially easy, so ORPO saturated to ~0 loss on toy data while
CUDA did not.

Compute the NLL term over prompt+response for the chosen rows; the
odds-ratio term (response-only, length-normalized) is unchanged.

The GRPO rollout generator called tokenizer.encode(prompt) directly, which
raises a ValueError when the prompt is a list of chat message dicts (the
format the GRPO notebooks use), so GRPO crashed before the first rollout.

Add _grpo_render_prompt: conversational prompts (a list of {role, content}
dicts) are rendered via the tokenizer's chat template, mirroring TRL's
maybe_apply_chat_template (add_generation_prompt when the last role is
'user'; continue_final_message when it is 'assistant'). Plain-string
prompts pass through unchanged. The rendered string is used for both the
prompt-id encode and the prompt+completion encode.

Verified the rendered output is byte-identical to TRL's
maybe_apply_chat_template on the same messages, so the MLX rollout prompt
matches the CUDA/TRL path by construction.

@BardiaKoopah

Contributor Author

  1. Chat-prompt crash — fixed in f35e410. The rollout generator called tokenizer.encode(prompt) directly,
    which raises a ValueError for list-of-message prompts, so GRPO crashed on the normal notebook format
    before the first rollout. It now renders conversational prompts through the chat template
    (_grpo_render_prompt), mirroring TRL's maybe_apply_chat_template: add_generation_prompt=True when the
    last message role is user, continue_final_message=True when it's assistant. Plain-string prompts pass
    through unchanged (back-compat), and the rendered string is used for both the prompt-id encode and the
    prompt+completion encode.

Rather than reproduce a single prompt hash, I verified the stronger property: the rendered output is
byte-identical to TRL's maybe_apply_chat_template on the same messages (checked on Qwen/Qwen3-0.6B,
user-only and system+user). Since the MLX render equals the TRL render for any shared input, the rollout
prompt matches the CUDA/TRL path by construction — not just for one example. (Note: this uses the
tokenizer's default chat-template behavior on both sides, so Qwen3's enable_thinking default is applied
identically to MLX and CUDA.)

  2. Fixed-rollout parity harness: added (dev/throwaway for now). A diagnostic that isolates the loss/grad
    math from generation randomness: fixed prompt IDs, hardcoded completion IDs (no sampling), fixed rewards
    → fixed advantages, driving the real make_grpo_loss_fn. It confirms the loss is finite and deterministic
    (identical across runs), the grad is finite/nonzero with LoRA scales restored, and per-token KL ≥ 0
    (k3). It also prints per-row policy/reference log-probs so a CUDA/TRL run with identical fixed inputs can
    be diffed component-by-component. Happy to promote it into the committed test suite if you'd prefer it
    there.
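
The shape of such a check, as a rough sketch rather than the actual harness: call the loss several times on the same fixed inputs and require a finite, identical value. check_fixed_rollout and toy_loss are hypothetical; a real harness would drive make_grpo_loss_fn with fixed token IDs and advantages.

```python
import math

def check_fixed_rollout(loss_fn, fixed_inputs, runs=3):
    """Evaluate loss_fn repeatedly on fixed inputs; require finite, identical values."""
    values = [float(loss_fn(**fixed_inputs)) for _ in range(runs)]
    if not all(math.isfinite(v) for v in values):
        raise AssertionError("loss is not finite")
    if len(set(values)) != 1:
        raise AssertionError("loss is not deterministic on fixed inputs")
    return values[0]

def toy_loss(logps, advantages):
    """Stand-in loss: negative advantage-weighted mean log-prob."""
    return -sum(l * a for l, a in zip(logps, advantages)) / len(logps)

fixed = {"logps": [-1.0, -2.0], "advantages": [0.5, -0.5]}
value = check_fixed_rollout(toy_loss, fixed)
```

The same wrapper can be pointed at a CUDA/TRL loss with identical inputs so the two values can be diffed.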

Still follow-up (not in this commit):

  • Reward-func kwargs parity — MLX currently passes only completions/prompts; TRL also passes dataset
    columns like answer (plus completion_ids, etc.). Needed for notebook-style supervised GRPO rewards.
  • Token-ID rollout — MLX re-encodes prompt + completion_text instead of using the sampled completion_ids,
    which can shift token boundaries/masks/logprobs. The parity harness above is the tool to confirm whether
    this re-encoding is the real source of the remaining MLX-vs-CUDA loss gap before investing in
    return_token_ids=True.
  • KL logging — surfacing per-step KL so loss ≈ beta·KL can be checked directly against CUDA.

Now that prompt rendering and reward scoring match on the rendered-prompt path (your reward curves line
up), these three are the remaining suspects for the loss delta, and the fixed-rollout harness should let
us pin it down component-by-component.

@chatgpt-codex-connector chatgpt-codex-connector Bot left a comment

💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: f35e410869

Comment thread unsloth_zoo/mlx/utils.py
Comment on lines +702 to +705
prompt = ex[prompt_key]
p_ids = hf.encode(prompt)
c_ids = hf.encode(prompt + ex[chosen_key])[:max_seq_length]
r_ids = hf.encode(prompt + ex[rejected_key])[:max_seq_length]

P2 Badge Render preference chat examples before encoding

When an ORPO/DPO dataset uses conversational prompt/chosen/rejected columns, this helper sends the raw list of messages to hf.encode and concatenates lists before encoding, so those datasets fail instead of being chat-templated; raw string prompts also ignore args.chat_template unlike the SFT path. Render the preference example with the tokenizer/template before computing p_ids, c_ids, and r_ids.

Comment on lines +2482 to +2484
resp = batch_generate(
self.model, self.tokenizer, prompts=[pids] * N,
max_tokens=args.max_completion_length, sampler=sampler, verbose=False,

P2 Badge Batch GRPO rollouts over all requested prompts

For any GRPO run with per_device_train_batch_size > 1 (including the inherited default of 2), this repeats one selected prompt N times and advances only one dataset example per micro-batch. The trainer logs and schedules as if the configured batch size were used, but the loss/advantages are computed for a single prompt group, silently reducing the effective prompt batch by a factor of per_device_train_batch_size; collect that many prompts and repeat each by num_generations.
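
A minimal sketch of the suggested selection, assuming a cycled cursor over an in-memory prompt list. `select_rollout_prompts` and its parameters are hypothetical names, not part of the PR.

```python
def select_rollout_prompts(prompts, cursor, batch_size, num_generations):
    """Pick `batch_size` prompts starting at `cursor` (wrapping around) and
    repeat each one `num_generations` times, keeping groups contiguous."""
    n = len(prompts)
    picked = [prompts[(cursor + k) % n] for k in range(batch_size)]
    expanded = [p for p in picked for _ in range(num_generations)]
    next_cursor = (cursor + batch_size) % n
    return expanded, next_cursor

expanded, next_cursor = select_rollout_prompts(
    ["p0", "p1", "p2"], cursor=2, batch_size=2, num_generations=2
)
```

Keeping each group contiguous means group-relative advantages can still be computed per block of `num_generations` rows.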

Comment thread unsloth_zoo/mlx/utils.py
Comment on lines +608 to +610
logits = beta * ((pol_c - ref_c) - (pol_r - ref_r))
loss = -mx.mean(nn.log_sigmoid(logits))
return loss, ntoks

P2 Badge Accumulate preference gradients by pairs

With gradient_accumulation_steps > 1 (default 4), step_fn multiplies each micro-batch gradient by the returned ntoks and later divides by the accumulated token count. DPO/ORPO losses are already averaged over preference pairs, so returning response-token counts makes micro-batches with longer chosen/rejected responses dominate the optimizer step; return the number of pairs or bypass token-weighted accumulation for these losses.
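
A small numeric sketch of the weighting difference. It uses per-micro-batch mean losses as a stand-in for gradients (accumulation is linear, so the weighting behaves the same way); the numbers and the `accumulate` helper are invented for illustration.

```python
def accumulate(micro_batches, weight_key):
    """Weighted average of per-micro-batch mean losses, weighted by `weight_key`."""
    total = sum(mb["mean_loss"] * mb[weight_key] for mb in micro_batches)
    return total / sum(mb[weight_key] for mb in micro_batches)

# Two micro-batches with the same number of pairs but very different response lengths.
micro_batches = [
    {"mean_loss": 0.2, "pairs": 2, "tokens": 20},
    {"mean_loss": 0.8, "pairs": 2, "tokens": 180},
]

by_tokens = accumulate(micro_batches, "tokens")  # long-response batch dominates
by_pairs = accumulate(micro_batches, "pairs")    # each pair counts equally
```

With token weighting the step is pulled toward the long-response micro-batch, even though each loss was already a mean over pairs.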

Comment thread unsloth_zoo/mlx/utils.py
Comment on lines +727 to +728
if num_batches is not None and len(out) >= num_batches:
break

P2 Badge Avoid capping preference batches after sorting

When max_steps is finite and the preference dataset has more than max_steps * gradient_accumulation_steps batches, this break is applied after rows has been sorted by length, so training keeps only the shortest pairs and never sees the rest of the dataset. Shuffle/sample before length bucketing, or build all batches and let the training loop cycle through them, so finite-step DPO/ORPO runs are not biased to a length-sorted prefix.
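
One way to read "shuffle/sample before length bucketing" is sketched below. `capped_length_bucketed_batches` is a hypothetical helper operating on plain token-ID lists, not the signature of `create_preference_batches`.

```python
import random

def capped_length_bucketed_batches(rows, batch_size, num_batches=None, seed=0):
    """Sample the subset first, then sort by length for bucketing, so a cap
    on the number of batches does not select only the shortest rows."""
    rng = random.Random(seed)
    rows = list(rows)
    rng.shuffle(rows)                               # sample before sorting
    if num_batches is not None:
        rows = rows[: num_batches * batch_size]     # cap on the shuffled order
    rows.sort(key=len)                              # bucket only the kept rows
    return [rows[i:i + batch_size] for i in range(0, len(rows), batch_size)]

rows = [[0] * n for n in range(1, 101)]             # lengths 1..100
batches = capped_length_bucketed_batches(rows, batch_size=4, num_batches=5)
longest_kept = max(len(r) for b in batches for r in b)
```

Batches are still length-homogeneous, but the kept rows are drawn from the whole dataset rather than from a length-sorted prefix.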

Comment thread unsloth_zoo/mlx/utils.py
Comment on lines +706 to +707
pe = min(len(p_ids), len(c_ids), len(r_ids))
rows.append((pe, c_ids, r_ids))

P2 Badge Preserve completion tokens during preference truncation

When a prompt reaches max_seq_length, or either response is truncated down to prompt-only, pe becomes equal to that row's sequence length and the response mask is empty. DPO then uses a zero log-prob for that side and ORPO optimizes prompt-only SFT/zero odds, silently corrupting long-prompt preference examples; drop these rows or truncate the prompt to leave at least one chosen and rejected completion token.
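
A minimal sketch of the "truncate the prompt or drop the row" option, assuming prompt and completion IDs are available as separate lists. `truncate_pair`, `min_completion`, and the choice to keep the tail of the prompt are all assumptions, not what the PR does.

```python
def truncate_pair(p_ids, completion_ids, max_seq_length, min_completion=1):
    """Return (prompt_ids, completion_ids) that fit in max_seq_length while
    keeping at least `min_completion` completion tokens, or None to drop the row."""
    if not completion_ids or max_seq_length <= min_completion:
        return None
    budget = max_seq_length - min_completion        # most the prompt may occupy
    if len(p_ids) > budget:
        p_ids = p_ids[-budget:]                     # keep the tail of the prompt (assumption)
    completion_ids = completion_ids[: max_seq_length - len(p_ids)]
    return p_ids, completion_ids

long_prompt = truncate_pair(list(range(10)), [100, 101, 102], max_seq_length=8)
short_prompt = truncate_pair([1, 2], [3, 4, 5], max_seq_length=4)
empty_completion = truncate_pair([1], [], max_seq_length=4)
```

The same guard would have to be applied to both the chosen and rejected side, dropping the pair if either returns None.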

Comment on lines +2429 to +2432
for ex in self.train_dataset:
if "prompt" not in ex:
raise ValueError("GRPO requires a 'prompt' column in the dataset.")
prompts.append(ex["prompt"])

P2 Badge Honor streaming GRPO prompt datasets

When GRPO is used with an iterable/streaming prompt dataset, this loop drains the entire dataset into a list before the first rollout, so an infinite stream hangs and a very large stream can exhaust memory despite MLXTrainingConfig.streaming existing. Iterate prompts lazily (or reject streaming GRPO explicitly) instead of materializing all prompts up front.
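
A sketch of the lazy alternative, assuming the dataset is any iterable of dict rows. `iter_prompts` and `endless_stream` are hypothetical names used only to show that nothing is materialized up front.

```python
from itertools import islice

def iter_prompts(dataset, prompt_key="prompt"):
    """Yield prompts one at a time instead of draining the dataset into a list."""
    for ex in dataset:
        if prompt_key not in ex:
            raise ValueError(f"GRPO requires a {prompt_key!r} column in the dataset.")
        yield ex[prompt_key]

def endless_stream():
    """Stand-in for an infinite streaming dataset."""
    i = 0
    while True:
        yield {"prompt": f"question {i}"}
        i += 1

# Pulling three prompts terminates even though the stream never ends.
first_three = list(islice(iter_prompts(endless_stream()), 3))
```

A finite map-style dataset would still need explicit cycling once exhausted; that part is left out here.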

The GRPO rollout only passed completions/prompts to reward functions, while
TRL passes every dataset column (except prompt/completion/completion_ids)
through as kwargs. Notebook reward funcs that read columns like 'answer'
therefore broke on the MLX path.

Build the full dataset rows alongside prompts and pass the row's other
columns as reward kwargs, each repeated per generation to align row-for-row
with completions, mirroring GRPOTrainer._calculate_rewards. The prompt and
the example are indexed by the same cycled index so a reward func reading
'answer' always gets the scored prompt's row (verified across index wrap).
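
As an illustration of the kwargs pass-through described above, here is a minimal sketch. `build_reward_kwargs`, `RESERVED`, and `correctness_reward` are hypothetical names; only the reserved column names and the reward signature come from the description.

```python
RESERVED = {"prompt", "completion", "completion_ids"}

def build_reward_kwargs(row, num_generations):
    """Repeat every non-reserved dataset column once per generation so the
    lists align row-for-row with the completions."""
    return {k: [v] * num_generations for k, v in row.items() if k not in RESERVED}

def correctness_reward(completions, prompts=None, answer=None, **kwargs):
    """Notebook-style reward that reads the dataset's `answer` column."""
    return [1.0 if a in c else 0.0 for c, a in zip(completions, answer)]

row = {"prompt": "2+2=?", "answer": "4"}
completions = ["It is 4", "It is 5"]
reward_kwargs = build_reward_kwargs(row, len(completions))
rewards = correctness_reward(
    completions=completions,
    prompts=[row["prompt"]] * len(completions),
    **reward_kwargs,
)
```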

completion_ids and trainer_state are documented as not-yet-available rather
than faked: mlx_lm batch_generate does not surface generated token IDs, and
there is no transformers TrainerState on the MLX path.

make_grpo_loss_fn computes per_token_kl but only returns (loss, ntoks), and
the GRPO step is mx.compile'd — so KL could not be surfaced by changing the
loss-fn return (breaks the shared compiled step) or via a closure holder
(a Python write inside a compiled fn runs only at trace time, freezing the
value). Instead add an eager probe outside the compiled step.

_grpo_mean_kl computes masked-mean k3 KL (policy logp with adapters,
reference logp via LoRA-disable) at the CURRENT weights, in eval mode so
dropout draws no RNG and neither the trajectory nor the compiled step is
perturbed. The loop runs it only for steps that will be logged, on the same
batch/lengths the loss uses, and BEFORE the optimizer update, so the KL
reflects the step's pre-update weights and is comparable to CUDA's logged
KL. KL is stored in _kl_history and printed on the step line; returns None
(nothing logged) when KL is undefined (no LoRA, beta==0, reference_free).
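
For reference, a plain-Python sketch of the masked-mean k3 estimate, using the `exp(ref − pol) − (ref − pol) − 1` form from the PR description. The function names are hypothetical, and this is scalar arithmetic rather than the MLX implementation.

```python
import math

def k3_kl(pol_logp, ref_logp):
    """Per-token k3 estimator: exp(ref - pol) - (ref - pol) - 1, always >= 0."""
    d = ref_logp - pol_logp
    return math.exp(d) - d - 1.0

def masked_mean_kl(pol_logps, ref_logps, mask):
    """Mean k3 KL over completion tokens (mask == 1); None if no token is masked in."""
    total = sum(k3_kl(p, r) * m for p, r, m in zip(pol_logps, ref_logps, mask))
    count = sum(mask)
    return total / count if count else None

# Identical policy and reference log-probs (adapters still zero) give exactly 0.
kl_step1 = masked_mean_kl([-1.0, -2.0], [-1.0, -2.0], [1, 1])
# Once the policy moves, the estimate becomes positive.
kl_later = masked_mean_kl([-1.0, -2.0], [-1.5, -2.0], [1, 1])
```
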
…ollapse

DPO's reference and GRPO's KL term are computed by disabling LoRA adapters.
When a model has no LoRA modules (full fine-tuning), the disable-adapter
path has nothing to toggle, so both silently fell back to reference-free:
DPO dropped the reference term and GRPO dropped the KL term with no warning,
giving the user a different objective than requested.

Raise a clear ValueError at loss-fn construction when there are no LoRA mods
and a reference/KL is actually requested: DPO when not reference_free, GRPO
when not reference_free and beta != 0. The messages point users to a
LoRA/PEFT model or to reference_free=True / grpo_beta=0 if they genuinely
want reference-free training. Legitimate reference-free configs are allowed
through unchanged.

This is the minimum fail-loud fix; a real full-FT reference (a frozen second
model) is left as a separate maintainer design decision.
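
The guard can be sketched as below. `require_reference` and its message text are hypothetical; only the conditions (no LoRA modules, not `reference_free`, and for GRPO `beta != 0`) are taken from the commit message.

```python
def require_reference(method, has_lora, reference_free, beta=None):
    """Raise if a reference/KL term is requested but there are no LoRA
    adapters to disable. `method` is "dpo" or "grpo"."""
    if has_lora or reference_free:
        return                                  # reference available, or not wanted
    if method == "grpo" and beta == 0:
        return                                  # KL term is switched off anyway
    raise ValueError(
        f"{method.upper()} needs a reference model, but this model has no LoRA "
        "modules to disable. Use a LoRA/PEFT model, or set reference_free=True "
        "(or grpo_beta=0 for GRPO) for reference-free training."
    )

require_reference("dpo", has_lora=True, reference_free=False)             # ok
require_reference("grpo", has_lora=False, reference_free=False, beta=0)   # ok
```
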
@BardiaKoopah

Contributor Author

@Lyxot

  1. Reward-function kwargs parity (abff6ed). MLX now passes every dataset column (except
    prompt/completion/completion_ids) through to reward functions as kwargs, each repeated per generation to
    align row-for-row with completions — mirroring GRPOTrainer._calculate_rewards. So notebook rewards
    reading answer (etc.) work now. The prompt and its dataset row are indexed by the same cycled index, so a
    reward func always gets the scored prompt's row (verified across index wrap-around). completion_ids and
    trainer_state are not faked: mlx_lm.batch_generate doesn't surface generated token IDs, and there's no
    transformers.TrainerState on the MLX path; both are documented as not-yet-available (the token-ID part is tied to
    #4 below).

  2. GRPO KL logging (b3ffc75). KL is now logged per step. Since the GRPO step is mx.compiled, KL couldn't
    be surfaced by changing the loss-fn return (breaks the shared compiled step) or via a closure (a write
    inside a compiled fn only runs at trace time). Instead there's an eager probe (_grpo_mean_kl) computed
    before the optimizer update, on the same batch the loss used, so it reflects the step's pre-update
    weights and is comparable to CUDA's logged KL. It runs in eval mode (no dropout RNG draw, so it doesn't
    perturb the trajectory or the compiled step) and only on logged steps. Sanity check on a moving policy:
    step-1 KL is exactly 0 (adapters still zero pre-update) and grows to a nonzero, varying curve after —
    confirming it measures real divergence, not a constant.

  3. Full-FT reference now fails loud (65b816a). Previously, on a model with no LoRA adapters (full
    fine-tuning), DPO's reference and GRPO's KL silently collapsed to reference-free because the
    disable-adapter path had nothing to toggle. Both now raise a clear ValueError at construction when a
    reference/KL is actually requested (DPO: not reference_free; GRPO: not reference_free and beta != 0),
    pointing users to a LoRA model or to reference_free=True / grpo_beta=0. Legitimate reference-free configs
    pass through unchanged. This is the minimal fail-loud fix; a real full-FT reference (frozen second
    model) is left as a separate design decision.

Still follow-up:

  • #4 token-ID rollout: training still re-encodes prompt + completion_text rather than using the sampled
    completion_ids. This needs a CUDA-side parity run to confirm it's the source of the remaining loss delta
    before changing it; the fixed-rollout harness is set up to isolate exactly that, component-by-component.
    (Providing completion_ids to reward funcs is blocked on the same thing.)
  • #7 DPO/ORPO batch/preprocessing parity: separate PR; needs your exact dataset formatting to compare.

Verification for all three: each behavior was confirmed on Qwen2.5-0.5B (4-bit), the SFT/DPO/ORPO/GRPO
regression runs are unaffected (LoRA DPO step-1 = 0.6931), and the MLX baseline suite is green (90 passed).

@chatgpt-codex-connector chatgpt-codex-connector Bot left a comment

💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: 65b816af19

Comment thread unsloth_zoo/mlx/utils.py
Comment on lines +719 to +722
p_ids = hf.encode(prompt)
c_ids = hf.encode(prompt + ex[chosen_key])[:max_seq_length]
r_ids = hf.encode(prompt + ex[rejected_key])[:max_seq_length]
pe = min(len(p_ids), len(c_ids), len(r_ids))

P2 Badge Tokenize preference completions as continuations

For standard preference datasets whose chosen/rejected strings do not already include EOS, these full-string encodes never add a supervised stop token; and for tokenizers that merge the last prompt token with the first answer token, len(p_ids) can also point past the first completion token. In those cases DPO/ORPO compute log-probs over the wrong completion span and omit termination training, unlike the SFT path which explicitly appends EOS when absent; build the completion continuation from the combined encoding, adjust the response start when the prompt prefix differs, and append EOS before truncation.
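
A rough sketch of the suggested fix, operating on already-encoded ID lists. `completion_row` is a hypothetical helper, and using the longest common prefix as the response start is one possible reading of "adjust the response start when the prompt prefix differs".

```python
def completion_row(p_ids, full_ids, eos_id, max_seq_length):
    """Build the training row from the combined encoding: append EOS if it is
    missing, truncate, and start the response where the row stops matching
    the prompt-only encoding."""
    ids = list(full_ids)
    if not ids or ids[-1] != eos_id:
        ids.append(eos_id)                      # supervise termination
    ids = ids[:max_seq_length]
    start = 0
    for a, b in zip(p_ids, ids):                # longest common prefix with the prompt
        if a != b:
            break
        start += 1
    return ids, start

# Last prompt token (7) merged with the first answer token into 9.
merged_ids, merged_start = completion_row([5, 6, 7], [5, 6, 9, 10], eos_id=2, max_seq_length=8)
# No merge: the response simply starts after the prompt.
plain_ids, plain_start = completion_row([5, 6], [5, 6, 8], eos_id=2, max_seq_length=8)
```

In the merged case the boundary token is attributed to the completion, which is a judgment call; masking it out instead is equally defensible.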

Comment thread unsloth_zoo/mlx/utils.py
Comment on lines +708 to +709
hf = getattr(tokenizer, "_tokenizer", tokenizer)
pad_id = hf.eos_token_id if hf.eos_token_id is not None else 0

P2 Badge Keep HF tokenizers wrapped for preference batching

When ORPO/DPO callers pass a regular Hugging Face fast tokenizer, this unwraps it to the low-level Rust tokenizer stored in _tokenizer; that object does not expose eos_token_id and its encode() returns Encoding objects rather than token-id lists, so create_preference_batches() fails before training. This repo already special-cases that distinction in _resolve_response_mask_tokenizer, so keep the public HF tokenizer unless the object is an mlx-lm wrapper.
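
One duck-typed way to express "unwrap only when the inner object is itself a usable tokenizer" is sketched below. `resolve_encoding_tokenizer` and the `Fake*` classes are stand-ins invented for the example; the repo's `_resolve_response_mask_tokenizer` may use a different check.

```python
def resolve_encoding_tokenizer(tokenizer):
    """Unwrap `_tokenizer` only if the inner object exposes `eos_token_id`;
    otherwise keep the public tokenizer as-is."""
    inner = getattr(tokenizer, "_tokenizer", None)
    if inner is not None and hasattr(inner, "eos_token_id"):
        return inner
    return tokenizer

class FakeRustBackend:
    """Stand-in for a low-level backend: no eos_token_id, encode returns an object."""
    def encode(self, text):
        return object()

class FakeHFTokenizer:
    """Stand-in for a public HF fast tokenizer that wraps a low-level backend."""
    eos_token_id = 2
    def __init__(self):
        self._tokenizer = FakeRustBackend()
    def encode(self, text):
        return [ord(c) for c in text]

class FakeMLXWrapper:
    """Stand-in for a wrapper whose `_tokenizer` is the public HF tokenizer."""
    def __init__(self, hf):
        self._tokenizer = hf

hf = FakeHFTokenizer()
from_hf = resolve_encoding_tokenizer(hf)
from_wrapper = resolve_encoding_tokenizer(FakeMLXWrapper(hf))
```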

}
total = [0.0] * N
for rf in self.reward_funcs:
vals = rf(completions=comps, prompts=[prompt] * N, **reward_kwargs)

P2 Badge Wrap conversational GRPO completions for rewards

When the dataset prompt is conversational (the renderer above accepts a list of message dicts), TRL-style chat reward functions expect completions to have the same conversational shape, for example completion[0]["content"]. This passes raw generated strings instead, so common format rewards for chat data either index characters or raise before advantages are computed; wrap each generated text as an assistant message when the source prompt is conversational.
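
A minimal sketch of the wrapping step, assuming a conversational prompt is a non-empty list of message dicts with a `role` key. `is_conversational` and `shape_completions` are hypothetical helper names.

```python
def is_conversational(prompt):
    """True when the prompt is a non-empty list of chat message dicts."""
    return (
        isinstance(prompt, list)
        and bool(prompt)
        and all(isinstance(m, dict) and "role" in m for m in prompt)
    )

def shape_completions(prompt, texts):
    """Wrap generated strings as assistant messages for chat prompts;
    leave them as plain strings otherwise."""
    if is_conversational(prompt):
        return [[{"role": "assistant", "content": t}] for t in texts]
    return list(texts)

chat_prompt = [{"role": "user", "content": "2+2=?"}]
chat_completions = shape_completions(chat_prompt, ["4", "5"])
text_completions = shape_completions("2+2=?", ["4", "5"])
```

With this shape a chat reward function can read `completion[0]["content"]` as it would on the TRL path.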

Comment on lines +2595 to +2597
full = hf.encode(rendered + c)[: args.max_seq_length]
rows.append(full)
lengths.append([pe, len(full)])

P2 Badge Preserve GRPO completion tokens after truncation

If a rendered GRPO prompt is already at or above max_seq_length, this truncates prompt + completion back to prompt-only while pe still points past the truncated row. The completion mask is then empty for those samples, yielding zero reward-gradient rows or a zero-token batch error, so long-prompt GRPO runs cannot train correctly; truncate the prompt to leave completion budget or drop such examples before building lengths.

Comment on lines +2587 to +2588
for i, v in enumerate(vals):
total[i] += float(v)

P2 Badge Skip None GRPO reward values

TRL-style custom reward functions may return None for samples where that reward is not applicable, but this loop unconditionally casts every value to float. Multi-task GRPO runs that rely on that behavior will abort as soon as one reward returns None; ignore those entries (and handle the all-None case) instead of summing them.
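
A sketch of a None-tolerant sum. `sum_rewards` is a hypothetical helper; returning the indices where every reward was None (rather than warning or raising) is an assumption about how the all-None case should be handled.

```python
def sum_rewards(per_func_values, n):
    """Sum rewards across functions per sample, skipping None entries.
    Also returns the indices for which every function returned None."""
    total = [0.0] * n
    seen = [False] * n
    for vals in per_func_values:
        for i, v in enumerate(vals):
            if v is None:
                continue                        # reward not applicable to this sample
            total[i] += float(v)
            seen[i] = True
    missing = [i for i, s in enumerate(seen) if not s]
    return total, missing

totals, missing = sum_rewards([[1.0, None], [0.5, 2]], n=2)
all_none_totals, all_none_missing = sum_rewards([[None]], n=1)
```

The caller can then decide whether a sample with no applicable reward should be dropped, warned about, or treated as zero.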

2 participants