feat(mlx): add DDP training support #841
Code Review
This pull request introduces distributed data parallel (DDP) training support for the MLX trainer, including rank-sharded batching, synchronization, and comprehensive integration tests. The review feedback highlights critical improvements for padding correctness by ensuring the tokenizer is properly passed to distributed text batching functions to avoid defaulting to a pad token ID of 0. Additionally, it recommends using a local random number generator instead of modifying the global NumPy random seed to prevent unintended side-effects.
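The random-number recommendation can be illustrated with a minimal sketch. The helper name `shuffled_indices` and its parameters are hypothetical and not taken from the PR; the point is only that `np.random.default_rng(seed)` returns a generator with private state, so seeding it does not disturb other code that relies on NumPy's global random state.

```python
import numpy as np


def shuffled_indices(n: int, seed: int) -> np.ndarray:
    """Return a deterministic permutation of range(n).

    Hypothetical helper: it uses a local Generator instead of calling
    np.random.seed(), so NumPy's global RNG state is left untouched.
    """
    rng = np.random.default_rng(seed)  # local generator, private state
    return rng.permutation(n)
```

Calling the helper twice with the same seed gives the same order on every rank, which is what a rank-sharded sampler needs, without changing what later `np.random.*` calls return elsewhere in the process.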
Pull request overview
This PR adds initial distributed data parallel (DDP) training support to the MLX backend, including rank-aware batch sharding for text and VLM workloads, synchronized training/eval metric reductions across ranks, coordinated failure handling, and richer per-rank diagnostics in the training results.
Changes:
- Added MLXTrainer distributed metadata/diagnostics (rank/world-size/is-main, per-rank runtime/tokens/throughput/memory) and DDP-safe control flow (stop propagation, failure consensus, main-rank-only side effects).
- Implemented DDP-aware batch sharding and padding alignment for text + VLM (materialized and streaming) and added checkpoint-resume validation across ranks.
- Added single-node `mlx.launch -n` test coverage validating sharding behavior, compile fallback behavior, resume parity, and main-rank-only callbacks/saves.
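The rank-aware sharding and padding alignment listed above can be pictured with a small sketch. Everything here is an assumption for illustration: `rank_slice` is a hypothetical stand-in, not the PR's helper, and the padding policy (cycling through `pad_source`) is only one plausible choice. The idea is that a global batch of `batch_size * world_size` indices is cut into contiguous per-rank chunks, and a short final batch is padded so every rank still receives exactly `batch_size` items and all ranks run the same number of steps.

```python
from typing import List, Sequence


def rank_slice(
    global_indices: Sequence[int],
    batch_size: int,
    rank: int,
    world_size: int,
    pad_source: Sequence[int],
) -> List[int]:
    """Return this rank's contiguous slice of one global batch (illustrative)."""
    if not pad_source:
        raise ValueError("pad_source must not be empty")
    needed = batch_size * world_size
    padded = list(global_indices[:needed])
    # Assumed policy: fill a short global batch by cycling through pad_source
    # so that every rank ends up with the same number of rows.
    i = 0
    while len(padded) < needed:
        padded.append(pad_source[i % len(pad_source)])
        i += 1
    start = rank * batch_size
    return padded[start:start + batch_size]
```

With `batch_size=2` and two ranks, a full global batch `[0, 1, 2, 3]` gives rank 0 `[0, 1]` and rank 1 `[2, 3]`; a short batch `[4, 5, 6]` padded from `[0, 1, 2]` gives rank 1 `[6, 0]`.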
Reviewed changes
Copilot reviewed 4 out of 4 changed files in this pull request and generated 4 comments.
| File | Description |
|---|---|
| unsloth_zoo/mlx/utils.py | Adds DDP-aware batch sharding/padding utilities and integrates them into text/VLM batching and streaming iterators. |
| unsloth_zoo/mlx/trainer.py | Introduces MLX distributed initialization, DDP-synchronized training/eval reductions, compile scoping/fallback coordination, and per-rank diagnostics. |
| tests/test_mlx_trainer_internals.py | Adds unit tests for distributed defaults and distributed-group caching behavior. |
| tests/test_mlx_ddp_metal.py | Adds integration tests using mlx.launch to validate end-to-end DDP behavior for text and VLM training/eval/resume/compile-fallback paths. |
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: 16846bf899
ℹ️ About Codex in GitHub
Your team has set up Codex to review pull requests in this repo. Reviews are triggered when you
- Open a pull request for review
- Mark a draft as ready
- Comment "@codex review".
If Codex has suggestions, it will comment; otherwise it will react with 👍.
Codex can also answer questions or update the PR. Try commenting "@codex address that feedback".
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: 75dc5a3364
    bi = _rank_slice_distributed_batch(
        global_indices,
        batch_size,
        comm_group=comm_group,
        pad_source=indices,
Synchronize materialized VLM prep failures
In materialized DDP VLM runs, this sharding makes each rank build only its local slice before the trainer reaches any failure-consensus collective. If one rank hits a local _item/processor error while preparing its shard, such as an unreadable image on one host, that rank exits during _prepare_data while peers that prepared successfully continue to the first training collective and can hang. Wrap materialized preparation in the same rank-wide failure synchronization used for batch fetch/eval.
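A minimal sketch of the kind of rank-wide failure synchronization suggested here, written against an injected `all_sum` callable so it runs without MLX. The names `run_with_failure_consensus` and `PeerFailure` are hypothetical and are not the trainer's actual helpers; in a real run `all_sum` would wrap a distributed sum over the communication group.

```python
from typing import Callable, Optional, TypeVar

T = TypeVar("T")


class PeerFailure(RuntimeError):
    """Raised on healthy ranks when some other rank reported a failure."""


def run_with_failure_consensus(
    fn: Callable[[], T],
    all_sum: Callable[[int], int],
) -> Optional[T]:
    """Run fn locally, then agree with all ranks on whether anyone failed."""
    error: Optional[BaseException] = None
    result: Optional[T] = None
    try:
        result = fn()
    except Exception as exc:  # capture instead of exiting before the collective
        error = exc
    # Every rank reaches this collective, whether or not fn succeeded.
    failed_ranks = all_sum(1 if error is not None else 0)
    if error is not None:
        raise error
    if failed_ranks > 0:
        raise PeerFailure(f"{failed_ranks} rank(s) failed during preparation")
    return result
```

The design point is that the local exception is deferred until after the collective, so a rank with an unreadable image and a rank with a clean shard both leave the preparation step together instead of one of them waiting on a collective that never completes.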
    dataset_order=getattr(args, "dataset_order", "default"),
    preserve_dataset_order=bool(getattr(args, "preserve_dataset_order", False)),
    num_epochs=labeled_num_epochs,
    comm_group=comm_group,
Check response masks before rank sharding
Passing the DDP group here makes _create_labeled_batches return only this rank's shard, but _check_all_masked(batches) immediately below still treats that local shard as the whole dataset. In a valid completion-only dataset where one rank happens to receive only rows whose labels are all -100 while another rank has trainable responses, this rank raises ZeroDivisionError before training while peers can proceed to the first collective and hang. Aggregate the all-masked decision across ranks, or run the check before sharding.
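One way to aggregate the decision, sketched under assumptions: `count_trainable` and `check_not_all_masked` are hypothetical names, labels are plain Python lists, and `all_sum` stands in for a distributed sum over the communication group. Each rank counts its own trainable label positions, the counts are summed across ranks, and every rank then makes the same decision from the global total.

```python
from typing import Callable, Iterable, Sequence

IGNORE_INDEX = -100  # label value excluded from the loss


def count_trainable(label_rows: Iterable[Sequence[int]]) -> int:
    """Count label positions on this rank that contribute to the loss."""
    return sum(1 for row in label_rows for tok in row if tok != IGNORE_INDEX)


def check_not_all_masked(
    local_label_rows: Iterable[Sequence[int]],
    all_sum: Callable[[int], int],
) -> int:
    """Raise on every rank only if no rank has any trainable label."""
    local = count_trainable(local_label_rows)
    total = all_sum(local)  # identical value on all ranks after the reduction
    if total == 0:
        raise ValueError("all labels are masked on every rank")
    return total
```

Because the decision is taken on the reduced total, a rank whose shard happens to be fully masked no longer fails alone, and a genuinely all-masked dataset fails on every rank at the same point.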
    grad, toks_f = grad_accum_state
    grad_norm = _apply_update(grad, toks_f)
Synchronize update failures after DDP reduction
For DDP optimizer steps, _apply_update performs the gradient/token all-reduces and then runs clipping, weight decay, and optimizer.update without another failure consensus. If a rank-local error occurs after those collectives, such as an OOM during clipping or optimizer-state update on one worker, that rank exits while peers continue to the next _distributed_should_stop() collective and can hang; wrap this update path in the same rank-wide failure synchronization used around local step execution.
    for key, fill_value in (
        ("input_ids", pad_id),
        ("attention_mask", 0),
        ("labels", -100),
Clear vision fields for empty VLM pads
The new empty-pad path first builds synthetic VLM rows from a real sample, but this helper only overwrites text ids, attention masks, and labels. For non-divisible DDP eval shards on models/processors that pass pixel_values, image_grid_thw, or derived position metadata, the padded row still carries a real image while its text contains no image tokens, which can make Qwen/Phi-style VLM forwards fail their image-token/feature alignment or wastefully process stale images; drop or neutralize the row-aligned vision metadata too.
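A sketch of the "drop" variant of this suggestion, assuming batch rows are plain dicts of Python lists. `make_empty_pad_row` is a hypothetical helper, and the key list contains only the two field names mentioned in this comment; other processors may attach different row-aligned vision fields, and some models may need neutral placeholder values rather than removal.

```python
from typing import Any, Dict, Sequence

# Field names taken from the comment above; treat the list as an assumption.
VISION_KEYS = ("pixel_values", "image_grid_thw")


def make_empty_pad_row(
    template: Dict[str, Any],
    pad_id: int,
    vision_keys: Sequence[str] = VISION_KEYS,
) -> Dict[str, Any]:
    """Build a padding row from a real sample without keeping its image data."""
    seq_len = len(template["input_ids"])
    # Copy everything except row-aligned vision fields.
    row = {k: v for k, v in template.items() if k not in vision_keys}
    row["input_ids"] = [pad_id] * seq_len
    row["attention_mask"] = [0] * seq_len
    row["labels"] = [-100] * seq_len
    return row
```

The template sample is left unmodified, and the padded row carries no image whose tokens are absent from its text.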
Summary
Adds initial MLX backend distributed data parallel (DDP) training support using MLX distributed primitives.
The trainer initializes distributed state through `mx.distributed.init()` without forcing a backend. This follows MLX's default `backend="any"` behavior: MLX tries the available distributed backends and returns a singleton group if distributed initialization is unavailable. The PR validates the implementation through single-node `mlx.launch -n` coverage and manual multi-node smoke tests; examples pin `--backend ring` only where an explicit, reproducible backend is useful.
This PR keeps the implementation inside the MLX backend. It does not add CLI, Studio, or CI orchestration for multi-node launches.
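As a rough illustration of the initialization behavior described above, the sketch below calls `mx.distributed.init()` with no backend argument and reads the rank and world size from the returned group. The function name `distributed_info` is hypothetical, and the single-process fallback when MLX cannot be imported is an addition for portability, not something the trainer is known to do.

```python
try:
    import mlx.core as mx
except ImportError:  # lets the sketch run on machines without MLX installed
    mx = None


def distributed_info() -> dict:
    """Return rank metadata; single-process values when MLX is unavailable."""
    if mx is None:
        rank, world_size = 0, 1
    else:
        group = mx.distributed.init()  # no backend forced: MLX picks one
        rank, world_size = group.rank(), group.size()
    return {
        "distributed_rank": rank,
        "distributed_world_size": world_size,
        "distributed_is_main_process": rank == 0,
    }
```

Run as a plain script, this reports a world size of 1; the same script started through `mlx.launch` with several ranks would report one distinct rank per process.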
References
What Changed
- `distributed_world_size`
- `distributed_rank`
- `distributed_is_main_process`
- `mx.compile` for the DDP local loss/gradient-accumulation step while keeping distributed collectives eager.
What Is Not Included
Reviewer Notes
- `mlx.launch -n` on a single node.
Performance / Result Sanity Check
These are small local smoke-test numbers, not a formal benchmark. They cover two models on `yahma/alpaca-cleaned` for 8 optimizer steps with LoRA rank 8, `bfloat16`, CCE enabled, and compile enabled. The multi-node rows used heterogeneous Apple Silicon hardware connected through Thunderbolt 4.
Single-process regression against `main`
`main` has no MLX DDP support, so the `main` comparison is single-process only. The PR matches `main` exactly on avg loss and final loss in these smoke runs.
DDP parity check
This keeps the effective global batch constant: single process uses `batch=2, grad_accum=1`; DDP uses `batch=1/rank, grad_accum=1` across 2 ranks. Loss parity is the main correctness signal.
Guide-shape throughput check
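For reference, the effective global batch sizes quoted in these checks are the product of per-rank batch size, gradient-accumulation steps, and rank count. A minimal sketch of that arithmetic (`effective_global_batch` is an illustrative helper, not part of the PR):

```python
def effective_global_batch(per_rank_batch: int, grad_accum: int, world_size: int) -> int:
    """Samples contributing to one optimizer step across all ranks."""
    return per_rank_batch * grad_accum * world_size


# Parity check: both configurations see 2 samples per optimizer step.
assert effective_global_batch(2, 1, 1) == effective_global_batch(1, 1, 2) == 2

# Guide-shape check: same per-rank settings, so two ranks double the batch.
assert effective_global_batch(1, 4, 1) == 4
assert effective_global_batch(1, 4, 2) == 8
```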
This mirrors the Unsloth DDP guide shape: `per_device_train_batch_size=1`, `gradient_accumulation_steps=4`. Single process has effective global batch 4; 2-rank DDP has effective global batch 8, so this is a throughput sanity check, not a parity test.
Notes:
- `ddp_local_grad`; distributed collectives remain eager.
How To Test
Focused single-process checks
Focused single-node DDP check
This launches two local MLX ranks with `mlx.launch -n`.
Manual single-node training launch
MLX's minimal local launch form is:
`-n 2` means two local ranks. The training script should initialize distributed state through the trainer; the trainer calls `mx.distributed.init()` without forcing a backend, so MLX uses its default backend-selection behavior.
If a test needs an explicit backend or exact interpreter/environment, use the same `mlx.launch` shape with extra launcher options:
Manual multi-node training launch
MLX's minimal multi-host launch form is:
Each node must have:
For more control, use an MLX hostfile. The hostfile can separate the SSH target from the communication IPs, which is useful when SSH and collectives should use different network interfaces:
[
  {"ssh": "hostname1", "ips": ["123.123.1.1", "123.123.2.1"]},
  {"ssh": "hostname2", "ips": ["123.123.1.2", "123.123.2.2"]}
]
Example with one rank per host and an explicit ring backend:
For Thunderbolt-connected Macs, MLX recommends `mlx.distributed_config` to discover/configure the topology and create the hostfile for `mlx.launch`. This PR's multi-node smoke rows used Thunderbolt 4 hardware, but the PR does not add host configuration automation.
Validation
Passed.
Passed:
55 passed.