Feat(mlx): GRPO support for MLX #832
…ference collator with ORPO
…type; flag kept for back-compat
Code Review
This pull request adds support for ORPO, DPO, and GRPO training to the MLX trainer, introducing new configuration classes, specialized trainers, preference batch creation helpers, and corresponding loss functions. The review feedback highlights several opportunities for improvement, including utilizing the existing iter_mlx_lora_modules helper to find LoRA modules (which removes the dependency on LoRALinear), optimizing prompt extraction for large datasets, and padding GRPO rollout batches to a multiple of 32 to avoid excessive recompilations on Apple Silicon. Additionally, the reviewer recommends specifying explicit dtypes for mx.zeros to prevent unintended float32 promotion and points out a limitation in the GRPO loss function's PPO clipping mechanism when num_iterations > 1 is used.
| import mlx.nn as nn | ||
| import mlx.optimizers as optim | ||
| from mlx.utils import tree_flatten, tree_map, tree_reduce, tree_unflatten | ||
| from mlx_lm.tuner.lora import LoRALinear |
| _lora_mods = [mod for _, mod in tree_flatten( | ||
| model, is_leaf=lambda x: isinstance(x, LoRALinear)) | ||
| if isinstance(mod, LoRALinear)] |
Using tree_flatten with a custom is_leaf filter to find LoRALinear modules is less robust and introduces an unnecessary dependency on LoRALinear from mlx_lm. We can use the existing iter_mlx_lora_modules(model) helper instead, which is cleaner and more robust.
Suggested change:
_lora_mods = [mod for _, mod in iter_mlx_lora_modules(model)]
| def _grpo_prompts(self): | ||
| """Extract prompt strings from the dataset (expects a 'prompt' column).""" | ||
| prompts = [] | ||
| for ex in self.train_dataset: | ||
| if "prompt" not in ex: | ||
| raise ValueError("GRPO requires a 'prompt' column in the dataset.") | ||
| prompts.append(ex["prompt"]) | ||
| return prompts |
Iterating over the entire dataset in a Python loop to extract the 'prompt' column can be slow and memory-intensive for large Hugging Face datasets. We can check if the dataset supports column indexing (e.g., self.train_dataset['prompt']) to retrieve the column instantly.
Suggested change:
| def _grpo_prompts(self): | |
| """Extract prompt strings from the dataset (expects a 'prompt' column).""" | |
| if hasattr(self.train_dataset, "column_names") and "prompt" in self.train_dataset.column_names: | |
| return self.train_dataset["prompt"] | |
| prompts = [] | |
| for ex in self.train_dataset: | |
| if "prompt" not in ex: | |
| raise ValueError("GRPO requires a 'prompt' column in the dataset.") | |
| prompts.append(ex["prompt"]) | |
| return prompts |
| L = max(len(r) for r in rows) | ||
| batch = mx.array([r + [pad_id] * (L - len(r)) for r in rows], dtype=mx.int32) |
Since the training step is compiled with mx.compile, having dynamic sequence lengths that are not aligned to a multiple of 32 can cause excessive recompilations or suboptimal Metal execution. Padding L to a multiple of 32 is highly recommended to improve performance on Apple Silicon.
Suggested change:
| L = max(len(r) for r in rows) | |
| pad_to_multiple = 32 | |
| if pad_to_multiple: | |
| L = ((L + pad_to_multiple - 1) // pad_to_multiple) * pad_to_multiple | |
| batch = mx.array([r + [pad_id] * (L - len(r)) for r in rows], dtype=mx.int32) |
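The rounding in this suggestion can be checked in isolation. The following is a minimal sketch in plain Python (lists instead of mx.array, so it runs without MLX); `round_up` and `pad_rows` are hypothetical helper names used for illustration, not functions from this PR.

```python
def round_up(n: int, multiple: int) -> int:
    """Round n up to the nearest multiple (unchanged if already aligned)."""
    return ((n + multiple - 1) // multiple) * multiple


def pad_rows(rows, pad_id, multiple=32):
    """Right-pad every row to one shared length aligned to `multiple`."""
    target = round_up(max(len(r) for r in rows), multiple)
    return [r + [pad_id] * (target - len(r)) for r in rows]


rows = [[1, 2, 3], [4, 5, 6, 7, 8]]
padded = pad_rows(rows, pad_id=0)
print(len(padded[0]), len(padded[1]))  # 32 32
```

With lengths bucketed this way, batches whose longest row falls in the same 32-token bucket share one input shape, which is the property the comment relies on to limit retracing.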
| ref_c = mx.zeros(pol_c.shape) | ||
| ref_r = mx.zeros(pol_r.shape) |
Calling mx.zeros without specifying a dtype defaults to float32, which can cause low-precision tensors (like float16 or bfloat16) to be promoted to float32 during operations like subtraction or addition. Specifying dtype=pol_c.dtype prevents unintended promotion.
Suggested change:
| ref_c = mx.zeros(pol_c.shape, dtype=pol_c.dtype) | |
| ref_r = mx.zeros(pol_r.shape, dtype=pol_r.dtype) |
| md.scale = s | ||
| per_token_kl = mx.exp(ref_logp - pol_logp) - (ref_logp - pol_logp) - 1 | ||
| else: | ||
| per_token_kl = mx.zeros(pol_logp.shape) |
Calling mx.zeros without specifying a dtype defaults to float32, which can cause low-precision tensors (like float16 or bfloat16) to be promoted to float32 during operations like subtraction or addition. Specifying dtype=pol_logp.dtype prevents unintended promotion.
Suggested change:
| per_token_kl = mx.zeros(pol_logp.shape, dtype=pol_logp.dtype) |
| else: | ||
| per_token_kl = mx.zeros(pol_logp.shape) | ||
| old_logp = mx.stop_gradient(pol_logp) |
Because old_logp = mx.stop_gradient(pol_logp) is computed inside the loss function, the importance ratio exp(pol_logp - old_logp) always evaluates to 1 in the forward pass of every step. This is fine for num_iterations = 1, but if num_iterations > 1 is ever used, the ratio never leaves the clip range, so the PPO clipping mechanism is completely disabled across multiple optimization epochs. To support multi-epoch PPO/GRPO, the old (rollout-time) log-probabilities must be computed during generation and passed to the loss function.
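To make the point concrete, here is a small scalar sketch of the clipped surrogate in plain Python. `clipped_term` is a hypothetical name and the numbers are invented for illustration; it is not the PR's loss function.

```python
import math


def clipped_term(pol_logp, old_logp, adv, eps=0.2):
    """Per-token surrogate: -min(ratio * adv, clip(ratio, 1-eps, 1+eps) * adv)."""
    ratio = math.exp(pol_logp - old_logp)
    clipped = min(max(ratio, 1.0 - eps), 1.0 + eps)
    return -min(ratio * adv, clipped * adv)


# Case 1: old_logp is a detached copy of the current log-prob, so ratio == 1
# and the clip can never bind.
same = clipped_term(pol_logp=-1.2, old_logp=-1.2, adv=0.5)

# Case 2: old_logp was recorded at rollout time and the policy has since moved;
# ratio = exp(0.5) exceeds 1 + eps, so the clipped branch is selected.
moved = clipped_term(pol_logp=-0.7, old_logp=-1.2, adv=0.5)
```

Only the second case exercises the clip, which is why the rollout-time log-probabilities need to be stored once more than one optimization pass is made over the same rollouts.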
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: e7f9c12a43
ℹ️ About Codex in GitHub
Codex has been enabled to automatically review pull requests in this repo. Reviews are triggered when you
- Open a pull request for review
- Mark a draft as ready
- Comment "@codex review".
If Codex has suggestions, it will comment; otherwise it will react with 👍.
When you sign up for Codex through ChatGPT, Codex can also answer questions or update the PR, like "@codex address that feedback".
| # rewards: sum across reward functions (TRL-style) | ||
| total = [0.0] * N | ||
| for rf in self.reward_funcs: | ||
| vals = rf(completions=comps, prompts=[prompt] * N) |
Pass reward dataset columns to reward functions
For GRPO datasets where rewards need labels or metadata (for example an answer column), _grpo_prompts() has already discarded the source example and this call only passes completions and prompts. A TRL-style reward such as def reward_func(completions, answer, **kwargs) will raise or be forced to recover labels out-of-band, so common supervised GRPO tasks cannot run correctly; keep the source example and pass each non-prompt column repeated for the N generations.
| tok_per_row = mx.maximum(mask.sum(-1), 1.0) | ||
| loss = ((per_token_loss * mask).sum(-1) / tok_per_row).mean() | ||
| return loss, mask.sum() |
Avoid token-weighting accumulated GRPO gradients
With the default gradient_accumulation_steps=4, _train_inner scales each micro-batch gradient by the returned toks before dividing by the accumulated token count. This GRPO loss is already averaged per generated row, so returning mask.sum() makes longer completions/prompts carry more optimizer weight than shorter ones instead of averaging the GRPO objective per sample/group; return a sample/group count or bypass the token-scaling path for this loss.
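A toy calculation shows the size of the effect. This sketch uses scalar stand-ins for gradients and invented token counts; `accumulate` is a hypothetical helper, not the trainer's accumulation code.

```python
def accumulate(grads, weights):
    """Weighted average of micro-batch gradients (scalars for illustration)."""
    return sum(g * w for g, w in zip(grads, weights)) / sum(weights)


grads = [1.0, 3.0]   # gradient of each micro-batch's row-averaged loss
tokens = [100, 900]  # completion tokens in each micro-batch
rows = [8, 8]        # generated rows in each micro-batch

token_weighted = accumulate(grads, tokens)  # 2.8: the long micro-batch dominates
row_weighted = accumulate(grads, rows)      # 2.0: each row counts equally
```

Both micro-batches hold the same number of rows, yet the token-weighted result is pulled toward the one with longer completions.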
| def loss_fn(model, batch, lengths, advantages): | ||
| pol_logp, mask = _per_token_logp_and_mask(model, batch, lengths) | ||
| if beta != 0.0 and not reference_free and _mods: |
Preserve a GRPO reference for non-LoRA training
When GRPO is run on a full-finetuned model, or any model without LoRALinear adapters, _mods is empty even though reference_free defaults to false and grpo_beta defaults to 0.04. This condition skips the reference forward and falls through to a zero KL term, silently disabling the regularization users requested; either materialize a frozen reference or reject non-LoRA GRPO unless reference_free=True.
| pol, ntoks = _row_logp_and_mask(model, batch, lengths) | ||
| pol_c, pol_r = pol[:B], pol[B:] | ||
| if reference_free or not _mods: |
Preserve a DPO reference for non-LoRA training
For full_finetuning=True or models whose adapters are not LoRALinear, _mods is empty while reference_free is false by default. This branch sets both reference log-probs to zero, so DPO silently becomes reference-free DPO rather than comparing against the frozen initial policy; create a reference model or fail unless reference_free=True.
| eval_batches = None | ||
| text_completion_only_loss = _text_completion_only_loss_arg(args) | ||
| if args.eval_steps > 0 and self.eval_dataset is not None: | ||
| if (getattr(args, "loss_type", "sft") in ("orpo", "dpo") |
Skip GRPO eval until an eval loss exists
When loss_type='grpo' and eval_steps/eval_dataset are set, this guard skips only ORPO/DPO, so the next branch builds ordinary SFT eval batches. _evaluate_batch_totals then passes labels=None as the fourth argument to make_grpo_loss_fn, which treats it as advantages and crashes on advantages.reshape; include GRPO in the unsupported-eval guard or add a real GRPO eval path.
| raise ValueError( | ||
| f"{args.loss_type.upper()} is not yet supported for VLM models on MLX." | ||
| ) | ||
| batches = create_preference_batches( |
Honor streaming for preference trainers
When ORPO/DPO is used with args.streaming=True, this new branch bypasses the existing streaming path and unconditionally calls create_preference_batches, which iterates and materializes the whole dataset before training. For iterable or very large preference datasets this can hang before the first step or exhaust memory; either provide a streaming preference iterator or reject streaming=True for these losses.
| resp = batch_generate( | ||
| self.model, self.tokenizer, prompts=[pids] * N, | ||
| max_tokens=args.max_completion_length, sampler=sampler, verbose=False, | ||
| ) |
Generate GRPO rollouts in eval mode
The training loop has already called model.train() before this generator runs, so batch_generate samples rollouts with dropout enabled whenever the base model or LoRA adapters use nonzero dropout. Those completions are then optimized against a different stochastic forward pass in the loss, making GRPO off-policy/noisy for dropout-enabled runs; temporarily switch to eval for rollout generation and restore train mode afterward.
| while True: | ||
| prompt = prompts[idx % len(prompts)] | ||
| idx += 1 | ||
| pids = hf.encode(prompt) |
Render GRPO prompts before encoding
When users pass prompts as chat messages or set args.chat_template/formatting_func, GRPO never goes through the base text data preparation path and encodes ex["prompt"] directly here. A conversational prompt will fail in encode, and a plain instruct prompt will be generated without the chat template or assistant generation marker, so rewards are computed for a prompt format the model was not tuned to follow; render/format prompts before encoding.
| pe = len(pids) | ||
| rows, lengths = [], [] | ||
| for c in comps: | ||
| full = hf.encode(prompt + c)[: args.max_seq_length] |
Optimize the sampled GRPO token IDs
GRPO currently takes decoded strings from batch_generate and re-encodes prompt + c for the loss. Decode/encode is not guaranteed to round-trip to the sampled token IDs (byte fallback, cleanup, BPE boundary merges, and stop handling can change the sequence), so the gradient can be computed for tokens that were not actually generated; request return_token_ids=True and build full from pids + generated_ids.
I ran a stricter parity check against CUDA after checking the GRPO notebooks, because the notebook path uses chat-style prompts and reward functions that depend on dataset columns like
What I matched for the fair numeric run:
Prompt hash matched on both backends:
Result summary
CUDA had one KL/loss spike at step 7. Excluding only that spike, avg loss is MLX
[Figures: loss curve, grad norm curve, reward mean curve, reward std curve, CUDA KL curve]
What looks good
The fair run is finite on both backends. After fixing the harness so the prompt text is identical, the reward curves are almost identical. That means the main loss gap is probably not caused by prompt formatting or reward scoring in this rendered-prompt path.
Issues to address
The GRPO notebooks pass
In the fair run, CUDA reward funcs received
TRL trains from the actual generated
CUDA loss roughly tracks
Suggested next step
Please add a fixed-rollout parity test before relying on stochastic train curves:
That will isolate the GRPO loss/grad math from generation differences. After that passes, the stochastic train-curve comparison will be much easier to interpret.
The MLX ORPO loss computed its SFT/NLL term from the response-only, length-normalized log-prob (the same quantity used for the odds-ratio term). TRL's chosen_nll_loss is instead a pooled token-mean cross-entropy over the full prompt+response span (all non-pad chosen tokens, matching nn.CrossEntropyLoss default reduction). The narrower span made the SFT signal trivially easy, so ORPO saturated to ~0 loss on toy data while CUDA did not. Compute the NLL term over prompt+response for the chosen rows; the odds-ratio term (response-only, length-normalized) is unchanged.
…ges, k3 KL via LoRA-disable reference
The GRPO rollout generator called tokenizer.encode(prompt) directly, which
raises a ValueError when the prompt is a list of chat message dicts (the
format the GRPO notebooks use), so GRPO crashed before the first rollout.
Add _grpo_render_prompt: conversational prompts (a list of {role, content}
dicts) are rendered via the tokenizer's chat template, mirroring TRL's
maybe_apply_chat_template (add_generation_prompt when the last role is
'user'; continue_final_message when it is 'assistant'). Plain-string
prompts pass through unchanged. The rendered string is used for both the
prompt-id encode and the prompt+completion encode.
Verified the rendered output is byte-identical to TRL's
maybe_apply_chat_template on the same messages, so the MLX rollout prompt
matches the CUDA/TRL path by construction.
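The rendering rule described in this commit message can be sketched as follows. `render_prompt` is a hypothetical stand-in for `_grpo_render_prompt`, whose actual body is not shown here. The `apply_chat_template` keyword arguments are the Hugging Face tokenizer ones, and `_StubTokenizer` is an invented stand-in so the snippet runs without downloading a tokenizer.

```python
def render_prompt(tokenizer, prompt):
    """Return a string prompt; chat-message lists go through the chat template."""
    if isinstance(prompt, str):
        return prompt  # plain-string prompts pass through unchanged
    last_role = prompt[-1]["role"]
    return tokenizer.apply_chat_template(
        prompt,
        tokenize=False,
        add_generation_prompt=(last_role == "user"),
        continue_final_message=(last_role == "assistant"),
    )


class _StubTokenizer:
    """Toy template: tags each message, optionally opens an assistant turn."""

    def apply_chat_template(self, conversation, tokenize=False,
                            add_generation_prompt=False,
                            continue_final_message=False):
        text = "".join(f"<{m['role']}>{m['content']}" for m in conversation)
        return text + ("<assistant>" if add_generation_prompt else "")


tok = _StubTokenizer()
plain = render_prompt(tok, "What is 2 + 2?")
chat = render_prompt(tok, [{"role": "user", "content": "What is 2 + 2?"}])
print(chat)  # <user>What is 2 + 2?<assistant>
```

Using the same rendered string for both the prompt-only encode and the prompt+completion encode keeps the prompt boundary consistent between the two.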
Force-pushed from e7f9c12 to f35e410.
Rather than reproduce a single prompt hash, I verified the stronger property: the rendered output is
Still follow-up (not in this commit):
Now that prompt rendering and reward scoring match on the rendered-prompt path (your reward curves line
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: f35e410869
| prompt = ex[prompt_key] | ||
| p_ids = hf.encode(prompt) | ||
| c_ids = hf.encode(prompt + ex[chosen_key])[:max_seq_length] | ||
| r_ids = hf.encode(prompt + ex[rejected_key])[:max_seq_length] |
Render preference chat examples before encoding
When an ORPO/DPO dataset uses conversational prompt/chosen/rejected columns, this helper sends the raw list of messages to hf.encode and concatenates lists before encoding, so those datasets fail instead of being chat-templated; raw string prompts also ignore args.chat_template unlike the SFT path. Render the preference example with the tokenizer/template before computing p_ids, c_ids, and r_ids.
| resp = batch_generate( | ||
| self.model, self.tokenizer, prompts=[pids] * N, | ||
| max_tokens=args.max_completion_length, sampler=sampler, verbose=False, |
Batch GRPO rollouts over all requested prompts
For any GRPO run with per_device_train_batch_size > 1 (including the inherited default of 2), this repeats one selected prompt N times and advances only one dataset example per micro-batch. The trainer logs and schedules as if the configured batch size were used, but the loss/advantages are computed for a single prompt group, silently reducing the effective prompt batch by a factor of per_device_train_batch_size; collect that many prompts and repeat each by num_generations.
| logits = beta * ((pol_c - ref_c) - (pol_r - ref_r)) | ||
| loss = -mx.mean(nn.log_sigmoid(logits)) | ||
| return loss, ntoks |
Accumulate preference gradients by pairs
With gradient_accumulation_steps > 1 (default 4), step_fn multiplies each micro-batch gradient by the returned ntoks and later divides by the accumulated token count. DPO/ORPO losses are already averaged over preference pairs, so returning response-token counts makes micro-batches with longer chosen/rejected responses dominate the optimizer step; return the number of pairs or bypass token-weighted accumulation for these losses.
| if num_batches is not None and len(out) >= num_batches: | ||
| break |
Avoid capping preference batches after sorting
When max_steps is finite and the preference dataset has more than max_steps * gradient_accumulation_steps batches, this break is applied after rows has been sorted by length, so training keeps only the shortest pairs and never sees the rest of the dataset. Shuffle/sample before length bucketing, or build all batches and let the training loop cycle through them, so finite-step DPO/ORPO runs are not biased to a length-sorted prefix.
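One possible ordering is to sample first and bucket by length afterwards. The sketch below is only an illustration of that ordering: `make_batches`, its parameters and the seed handling are assumptions, not the signature of `create_preference_batches`.

```python
import random


def make_batches(rows, batch_size, num_batches=None, seed=0):
    """Shuffle, apply the cap, then sort the kept rows by length and batch them."""
    rng = random.Random(seed)
    rows = list(rows)
    rng.shuffle(rows)  # sample before any length-based ordering
    if num_batches is not None:
        rows = rows[: num_batches * batch_size]
    rows.sort(key=len)  # length bucketing only within the kept sample
    return [rows[i : i + batch_size] for i in range(0, len(rows), batch_size)]


# 100 toy rows with lengths 1..100; keep 5 batches of 4.
data = [[0] * n for n in range(1, 101)]
batches = make_batches(data, batch_size=4, num_batches=5)
```

Because the cap is applied to a shuffled list, the kept rows are a random sample of the dataset rather than its shortest prefix, while each batch still groups rows of similar length.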
| pe = min(len(p_ids), len(c_ids), len(r_ids)) | ||
| rows.append((pe, c_ids, r_ids)) |
Preserve completion tokens during preference truncation
When a prompt reaches max_seq_length, or either response is truncated down to prompt-only, pe becomes equal to that row's sequence length and the response mask is empty. DPO then uses a zero log-prob for that side and ORPO optimizes prompt-only SFT/zero odds, silently corrupting long-prompt preference examples; drop these rows or truncate the prompt to leave at least one chosen and rejected completion token.
| for ex in self.train_dataset: | ||
| if "prompt" not in ex: | ||
| raise ValueError("GRPO requires a 'prompt' column in the dataset.") | ||
| prompts.append(ex["prompt"]) |
Honor streaming GRPO prompt datasets
When GRPO is used with an iterable/streaming prompt dataset, this loop drains the entire dataset into a list before the first rollout, so an infinite stream hangs and a very large stream can exhaust memory despite MLXTrainingConfig.streaming existing. Iterate prompts lazily (or reject streaming GRPO explicitly) instead of materializing all prompts up front.
The GRPO rollout only passed completions/prompts to reward functions, while TRL passes every dataset column (except prompt/completion/completion_ids) through as kwargs. Notebook reward funcs that read columns like 'answer' therefore broke on the MLX path. Build the full dataset rows alongside prompts and pass the row's other columns as reward kwargs, each repeated per generation to align row-for-row with completions, mirroring GRPOTrainer._calculate_rewards. The prompt and the example are indexed by the same cycled index so a reward func reading 'answer' always gets the scored prompt's row (verified across index wrap). completion_ids and trainer_state are documented as not-yet-available rather than faked: mlx_lm batch_generate does not surface generated token IDs, and there is no transformers TrainerState on the MLX path.
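The column pass-through described above can be sketched like this. `reward_kwargs_for` and `correctness_reward` are hypothetical names and the example row is invented; only the reserved column names come from the commit message.

```python
def reward_kwargs_for(example, num_generations,
                      skip=("prompt", "completion", "completion_ids")):
    """Repeat each non-reserved column so it aligns row-for-row with completions."""
    return {k: [v] * num_generations for k, v in example.items() if k not in skip}


def correctness_reward(completions, answer, **kwargs):
    """Toy TRL-style reward: 1.0 when the completion contains the reference answer."""
    return [1.0 if a in c else 0.0 for c, a in zip(completions, answer)]


example = {"prompt": "What is 2 + 2?", "answer": "4"}
completions = ["The answer is 4.", "I think it is 5."]

kwargs = reward_kwargs_for(example, num_generations=len(completions))
scores = correctness_reward(
    completions=completions,
    prompts=[example["prompt"]] * len(completions),
    **kwargs,
)
print(scores)  # [1.0, 0.0]
```

Repeating each column once per generation is what lets a reward function zip `completions` against `answer` without knowing the group size.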
make_grpo_loss_fn computes per_token_kl but only returns (loss, ntoks), and the GRPO step is mx.compile'd — so KL could not be surfaced by changing the loss-fn return (breaks the shared compiled step) or via a closure holder (a Python write inside a compiled fn runs only at trace time, freezing the value). Instead add an eager probe outside the compiled step. _grpo_mean_kl computes masked-mean k3 KL (policy logp with adapters, reference logp via LoRA-disable) at the CURRENT weights, in eval mode so dropout draws no RNG and neither the trajectory nor the compiled step is perturbed. The loop runs it only for steps that will be logged, on the same batch/lengths the loss uses, and BEFORE the optimizer update, so the KL reflects the step's pre-update weights and is comparable to CUDA's logged KL. KL is stored in _kl_history and printed on the step line; returns None (nothing logged) when KL is undefined (no LoRA, beta==0, reference_free).
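For reference, the quantity the probe reports can be written out on plain lists. This is a sketch of the masked-mean k3 estimator only; `masked_mean_k3_kl` is a hypothetical name, and the real `_grpo_mean_kl` additionally handles the adapter toggle and eval mode described above.

```python
import math


def masked_mean_k3_kl(pol_logp, ref_logp, mask):
    """Masked mean of exp(ref - pol) - (ref - pol) - 1 over completion tokens."""
    total, count = 0.0, 0
    for p, r, m in zip(pol_logp, ref_logp, mask):
        if m:
            d = r - p
            total += math.exp(d) - d - 1.0
            count += 1
    return total / max(count, 1)  # avoid dividing by zero on an empty mask


# Identical policy and reference log-probs give exactly zero KL.
kl_same = masked_mean_k3_kl([-1.0, -2.0], [-1.0, -2.0], [1, 1])

# Only the first token is unmasked; the second is ignored.
kl_diff = masked_mean_k3_kl([-1.0, -5.0], [-1.5, 0.0], [1, 0])
```

Each per-token term is non-negative because exp(d) - d - 1 >= 0 for every real d, so the logged value cannot go below zero.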
…ollapse
DPO's reference and GRPO's KL term are computed by disabling LoRA adapters. When a model has no LoRA modules (full fine-tuning), the disable-adapter path has nothing to toggle, so both silently fell back to reference-free: DPO dropped the reference term and GRPO dropped the KL term with no warning, giving the user a different objective than requested.
Raise a clear ValueError at loss-fn construction when there are no LoRA mods and a reference/KL is actually requested: DPO when not reference_free, GRPO when not reference_free and beta != 0. The messages point users to a LoRA/PEFT model or to reference_free=True / grpo_beta=0 if they genuinely want reference-free training. Legitimate reference-free configs are allowed through unchanged.
This is the minimum fail-loud fix; a real full-FT reference (a frozen second model) is left as a separate maintainer design decision.
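The guard described in this commit message amounts to a small precondition check. The sketch below is one possible shape for it; `check_reference_available`, its parameters and the message wording are assumptions, not the PR's code.

```python
def check_reference_available(loss_type, lora_modules, reference_free, beta=0.0):
    """Raise if a reference/KL term is requested but no LoRA module can be disabled."""
    if lora_modules or reference_free:
        return  # a reference exists, or the user explicitly opted out of one
    if loss_type == "dpo":
        raise ValueError(
            "DPO needs a reference model: use a LoRA/PEFT model "
            "or set reference_free=True."
        )
    if loss_type == "grpo" and beta != 0:
        raise ValueError(
            "GRPO KL needs a reference model: use a LoRA/PEFT model, "
            "or set reference_free=True / grpo_beta=0."
        )


# Legitimate reference-free configurations pass through silently.
check_reference_available("dpo", [], reference_free=True)
check_reference_available("grpo", [], reference_free=False, beta=0.0)
```

Running the check at loss-function construction means a misconfigured run fails before any rollout or batch is built.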
Still follow-up:
Verification for all three: the specific behavior confirmed on Qwen2.5-0.5B (4-bit), SFT/DPO/ORPO/GRPO
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: 65b816af19
| p_ids = hf.encode(prompt) | ||
| c_ids = hf.encode(prompt + ex[chosen_key])[:max_seq_length] | ||
| r_ids = hf.encode(prompt + ex[rejected_key])[:max_seq_length] | ||
| pe = min(len(p_ids), len(c_ids), len(r_ids)) |
Tokenize preference completions as continuations
For standard preference datasets whose chosen/rejected strings do not already include EOS, these full-string encodes never add a supervised stop token; and for tokenizers that merge the last prompt token with the first answer token, len(p_ids) can also point past the first completion token. In those cases DPO/ORPO compute log-probs over the wrong completion span and omit termination training, unlike the SFT path which explicitly appends EOS when absent; build the completion continuation from the combined encoding, adjust the response start when the prompt prefix differs, and append EOS before truncation.
| hf = getattr(tokenizer, "_tokenizer", tokenizer) | ||
| pad_id = hf.eos_token_id if hf.eos_token_id is not None else 0 |
Keep HF tokenizers wrapped for preference batching
When ORPO/DPO callers pass a regular Hugging Face fast tokenizer, this unwraps it to the low-level Rust tokenizer stored in _tokenizer; that object does not expose eos_token_id and its encode() returns Encoding objects rather than token-id lists, so create_preference_batches() fails before training. This repo already special-cases that distinction in _resolve_response_mask_tokenizer, so keep the public HF tokenizer unless the object is an mlx-lm wrapper.
| } | ||
| total = [0.0] * N | ||
| for rf in self.reward_funcs: | ||
| vals = rf(completions=comps, prompts=[prompt] * N, **reward_kwargs) |
Wrap conversational GRPO completions for rewards
When the dataset prompt is conversational (the renderer above accepts a list of message dicts), TRL-style chat reward functions expect completions to have the same conversational shape, for example completion[0]["content"]. This passes raw generated strings instead, so common format rewards for chat data either index characters or raise before advantages are computed; wrap each generated text as an assistant message when the source prompt is conversational.
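The wrapping this comment asks for is a one-line shape change. A minimal sketch, with `shape_completions` as a hypothetical helper name:

```python
def shape_completions(prompt, texts):
    """Wrap generated texts as assistant messages when the prompt is conversational."""
    if isinstance(prompt, str):
        return list(texts)  # plain-text datasets keep plain-string completions
    return [[{"role": "assistant", "content": t}] for t in texts]


chat_prompt = [{"role": "user", "content": "Say hi."}]
wrapped = shape_completions(chat_prompt, ["Hi!", "Hello."])
print(wrapped[0][0]["content"])  # Hi!
```

With this shape, a chat-style reward function that reads `completion[0]["content"]` sees the generated text instead of indexing into a raw string.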
| full = hf.encode(rendered + c)[: args.max_seq_length] | ||
| rows.append(full) | ||
| lengths.append([pe, len(full)]) |
Preserve GRPO completion tokens after truncation
If a rendered GRPO prompt is already at or above max_seq_length, this truncates prompt + completion back to prompt-only while pe still points past the truncated row. The completion mask is then empty for those samples, yielding zero reward-gradient rows or a zero-token batch error, so long-prompt GRPO runs cannot train correctly; truncate the prompt to leave completion budget or drop such examples before building lengths.
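A sketch of the "leave completion budget" option, working on token-ID lists. `build_row` and `min_completion` are hypothetical names, and trimming the prompt from the left is one possible policy rather than what the PR does; dropping the example is the other option the comment mentions.

```python
def build_row(prompt_ids, completion_ids, max_seq_length, min_completion=1):
    """Return (row, prompt_end), or None when no completion token can be kept."""
    if not completion_ids or max_seq_length <= min_completion:
        return None
    # Keep the tail of an over-long prompt so the end of the prompt survives.
    prompt_budget = max_seq_length - min_completion
    kept_prompt = prompt_ids[-prompt_budget:]
    row = (kept_prompt + completion_ids)[:max_seq_length]
    # prompt_end is recomputed from the trimmed prompt, so it never
    # points past the end of the truncated row.
    return row, len(kept_prompt)


long_prompt = list(range(10))
row, prompt_end = build_row(long_prompt, [100, 101], max_seq_length=8)
print(row, prompt_end)  # [3, 4, 5, 6, 7, 8, 9, 100] 7
```

The returned `prompt_end` is always strictly less than `len(row)`, so the completion mask is never empty for a row that is kept.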
| for i, v in enumerate(vals): | ||
| total[i] += float(v) |
TRL-style custom reward functions may return None for samples where that reward is not applicable, but this loop unconditionally casts every value to float. Multi-task GRPO runs that rely on that behavior will abort as soon as one reward returns None; ignore those entries (and handle the all-None case) instead of summing them.
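A sketch of a None-tolerant sum. `sum_rewards` is a hypothetical helper, and leaving an all-None sample as None for the caller to handle is an assumption about the desired behavior, not something the comment specifies.

```python
def sum_rewards(per_function_values):
    """Sum rewards per sample, skipping None; all-None samples stay None."""
    n = len(per_function_values[0])
    totals = []
    for i in range(n):
        vals = [vs[i] for vs in per_function_values if vs[i] is not None]
        totals.append(float(sum(vals)) if vals else None)
    return totals


# Two reward functions scored three samples; None means "not applicable".
values = [
    [1.0, None, None],
    [0.5, 2.0, None],
]
print(sum_rewards(values))  # [1.5, 2.0, None]
```

The caller still has to decide what an all-None sample means for the group's advantages, for example dropping it or raising a warning.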
Add GRPO (reinforcement learning) for text models on MLX
Adds GRPO (Group Relative Policy Optimization) to the MLX trainer, the first RL method on the Apple Silicon path. Mirrors TRL's GRPOTrainer/GRPOConfig API.
What's included
- MLXGRPOTrainer + MLXGRPOConfig: TRL-style classes, consistent with the existing MLXDPOTrainer/MLXORPOTrainer.
- make_grpo_loss_fn (in utils.py): token-level loss matching TRL's GRPOTrainer._compute_loss:
  - KL term (k3 estimator): exp(ref − pol) − (ref − pol) − 1
  - Clipped objective: coef_1 = exp(logp − stop_grad(logp)), coef_2 = clip(coef_1, 1−ε, 1+ε), per_token_loss = −min(coef_1·adv, coef_2·adv) + β·KL
  - Advantages: (r − mean) / (std + 1e-4) (same 1e-4 epsilon as TRL)
- reward_funcs accepts a single callable or a list (summed); signature fn(completions, prompts=..., **kwargs) -> list[float].
- … stop_gradient, restore in finally); verified to survive mx.compile.
Rollout engine
Uses mlx-lm generation for rollouts (the group of completions per prompt). This is parity with CUDA Unsloth's non-vLLM generation path (the else branch of if use_vllm:).
vLLM-style fast rollout (via vllm-metal) is a planned follow-up. It currently has no installable distribution in the Studio Python 3.13 environment (pip install vllm-metal → "No matching distribution found"), so it is deferred as a pluggable fast-path rather than a hard dependency. The rollout is isolated in _grpo_rollout_generator, so a faster engine can slot in behind the same seam later.
Verification (Qwen2.5-0.5B, 4-bit)
- … 0.709 → 0.759 over 20 steps with a graded reward (upward trend confirms gradient direction is correct).
- … grad_accum > 1, multiple summed reward functions, reference_free=True end-to-end, num_generations=2, and uniform-reward stress (std=0 → graceful zero-grad, no NaN).
Scope / follow-ups
- … "GRPO is not yet supported for VLM models on MLX.").
- vllm-metal fast rollout: when installable / benchmarked against mlx-lm.
- num_iterations > 1 (PPO multi-epoch): the clip structure is in place but only exercised at num_iterations = 1 (where coef_1 = 1).
- … AutoModelForSequenceClassification): callable reward functions only for now.
Note on base branch
Built on feat/mlx-orpo (reuses the LoRA-disable reference pattern from DPO). Best reviewed / merged after the ORPO/DPO PR; the diff here shows only the GRPO commit. Once ORPO/DPO lands, this rebases cleanly onto main as a single-commit PR.
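As a worked example of the group-normalized advantage and the uniform-reward case mentioned in the description, here is a plain-Python sketch. `group_advantages` is a hypothetical name, and using the population standard deviation is an assumption; the description only fixes the 1e-4 epsilon.

```python
def group_advantages(rewards, eps=1e-4):
    """(r - mean) / (std + eps) within one prompt's group of completions."""
    n = len(rewards)
    mean = sum(rewards) / n
    # Population std (divide by n) is an assumption made for this sketch.
    std = (sum((r - mean) ** 2 for r in rewards) / n) ** 0.5
    return [(r - mean) / (std + eps) for r in rewards]


# Uniform rewards: std is 0, eps keeps the division finite, advantages are all 0.
print(group_advantages([1.0, 1.0, 1.0, 1.0]))  # [0.0, 0.0, 0.0, 0.0]

# Mixed rewards: the better completion gets a positive advantage.
mixed = group_advantages([0.0, 1.0])
```

With all advantages at zero, the clipped objective contributes nothing for that group, which matches the "std=0 → graceful zero-grad, no NaN" stress case.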