Skip to content

feat(mlx): add DDP training support#841

Open
Lyxot wants to merge 10 commits into
unslothai:mainfrom
Lyxot:feat/mlx-ddp
Open

feat(mlx): add DDP training support#841
Lyxot wants to merge 10 commits into
unslothai:mainfrom
Lyxot:feat/mlx-ddp

Conversation

@Lyxot

@Lyxot Lyxot commented Jun 29, 2026

Copy link
Copy Markdown
Contributor

Summary

Adds initial MLX backend distributed data parallel (DDP) training support using MLX distributed primitives.

The trainer initializes distributed state through mx.distributed.init() without forcing a backend. This follows MLX's default backend="any" behavior: MLX tries the available distributed backends and returns a singleton group if distributed initialization is unavailable. The PR validates the implementation through single-node mlx.launch -n coverage and manual multi-node smoke tests; examples pin --backend ring only where an explicit, reproducible backend is useful.

This PR keeps the implementation inside the MLX backend. It does not add CLI, Studio, or CI orchestration for multi-node launches.

References

What Changed

  • Added MLX distributed trainer state and result fields:
    • distributed_world_size
    • distributed_rank
    • distributed_is_main_process
  • Added DDP-aware batch sharding for:
    • ordered text batches
    • labeled / response-only text batches
    • streaming text batches
    • VLM batches
    • streaming VLM batches
  • Synchronized MLX text training under DDP:
    • rank-local forward/backward
    • gradient all-reduce
    • global token and loss accounting
    • main-rank callback/save behavior
  • Added text checkpoint resume support under DDP:
    • resume path validation across ranks
    • synchronized resume contract
    • parity coverage against uninterrupted training
  • Extended the DDP path to VLM and streaming training.
  • Enabled mx.compile for the DDP local loss/gradient-accumulation step while keeping distributed collectives eager.
  • Added compile failure handling:
    • strict mode raises
    • best-effort mode falls back to eager consistently across ranks
    • non-compile runtime errors are not mislabeled as compile fallback
  • Sharded text evaluation under DDP and reduced eval metrics across ranks.
  • Coordinated DDP failure handling for batch-fetch failures and zero-token batches.
  • Added DDP diagnostics to training results:
    • host/rank map
    • per-rank runtime
    • per-rank token counts
    • per-rank throughput history
    • per-rank peak memory
    • eval metrics
    • compile fallback status

What Is Not Included

  • No CLI support.
  • No Studio support.
  • No multi-node CI job.
  • No CUDA DDP changes.
  • No tensor parallelism, pipeline parallelism, ZeRO, FSDP, or optimizer-state sharding.
  • No automatic host discovery, SSH setup, or network-interface routing abstraction.
  • No benchmark scripts or benchmark artifacts.

Reviewer Notes

  • The DDP implementation is intentionally limited to the MLX backend.
  • Unsloth does not select or restrict the MLX communication backend. Backend selection is delegated to MLX distributed initialization.
  • Tests are added with the behavior they cover and run through mlx.launch -n on a single node.
  • In DDP mode, compile covers only the local loss/gradient computation. Distributed collectives remain eager.
  • Main-rank-only behavior is preserved for callbacks and model/checkpoint saves.
  • The multi-node launch path was manually smoke-tested, but this PR does not add automated multi-node CI.

Performance / Result Sanity Check

These are small local smoke-test numbers, not a formal benchmark. They cover two models on yahma/alpaca-cleaned for 8 optimizer steps with LoRA rank 8, bfloat16, CCE enabled, and compile enabled. The multi-node rows used heterogeneous Apple Silicon hardware connected through Thunderbolt 4.

Single-process regression against main

main has no MLX DDP support, so the main comparison is single-process only. The PR matches main exactly on avg loss and final loss in these smoke runs.

Model / seq Hardware Branch Runtime Post-step1 median tok/s Avg loss Final loss
Qwen3-0.6B / 512 MacBook Pro 16, M3 Max main 2.988s 1251.8 1.816272 1.561391
Qwen3-0.6B / 512 MacBook Pro 16, M3 Max PR 3.071s 1224.0 1.816272 1.561391
Qwen3-0.6B / 512 MacBook Pro 14, M3 Pro main 5.080s 692.3 1.816272 1.561391
Qwen3-0.6B / 512 MacBook Pro 14, M3 Pro PR 5.077s 691.6 1.816272 1.561391
Qwen3-1.7B / 2048 MacBook Pro 16, M3 Max main 5.675s 617.2 1.569812 1.292439
Qwen3-1.7B / 2048 MacBook Pro 16, M3 Max PR 5.548s 647.1 1.569812 1.292439
Qwen3-1.7B / 2048 MacBook Pro 14, M3 Pro main 10.557s 331.1 1.569812 1.292439
Qwen3-1.7B / 2048 MacBook Pro 14, M3 Pro PR 10.547s 330.3 1.569812 1.292439

DDP parity check

This keeps the effective global batch constant: single process uses batch=2, grad_accum=1; DDP uses batch=1/rank, grad_accum=1 across 2 ranks. Loss parity is the main correctness signal.

Model / seq Run Hardware Runtime Post-step1 median tok/s Avg loss Final loss Final-loss delta vs PR single
Qwen3-0.6B / 512 PR single process MacBook Pro 16, M3 Max 3.071s 1224.0 1.816272 1.561391 -
Qwen3-0.6B / 512 PR single-node DDP MacBook Pro 16, M3 Max 2.872s 1312.0 1.815974 1.561426 +0.000034
Qwen3-0.6B / 512 PR multi-node DDP M3 Max + M3 Pro 2.887s 1267.7 1.815974 1.561426 +0.000034
Qwen3-1.7B / 2048 PR single process MacBook Pro 16, M3 Max 5.548s 647.1 1.569812 1.292439 -
Qwen3-1.7B / 2048 PR single-node DDP MacBook Pro 16, M3 Max 4.992s 741.3 1.573234 1.294697 +0.002258
Qwen3-1.7B / 2048 PR multi-node DDP M3 Max + M3 Pro 5.481s 625.2 1.573234 1.294697 +0.002258

Guide-shape throughput check

This mirrors the Unsloth DDP guide shape: per_device_train_batch_size=1, gradient_accumulation_steps=4. Single process has effective global batch 4; 2-rank DDP has effective global batch 8, so this is a throughput sanity check, not a parity test.

Model / seq Run Hardware Effective global batch Runtime Tokens Avg tok/s Speed vs single Avg loss Final loss
Qwen3-0.6B / 512 PR single process MacBook Pro 16, M3 Max 4 6.204s 6234 1004.8 1.00x 1.738459 1.378429
Qwen3-0.6B / 512 PR single-node DDP MacBook Pro 16, M3 Max 8 8.444s 11515 1363.7 1.36x 1.811959 1.619031
Qwen3-0.6B / 512 PR multi-node DDP M3 Max + M3 Pro 8 8.943s 11515 1287.5 1.28x 1.811959 1.619031
Qwen3-1.7B / 2048 PR single process MacBook Pro 16, M3 Max 4 10.198s 6234 611.3 1.00x 1.499014 1.179985
Qwen3-1.7B / 2048 PR single-node DDP MacBook Pro 16, M3 Max 8 15.779s 11515 729.7 1.19x 1.582486 1.336699
Qwen3-1.7B / 2048 PR multi-node DDP M3 Max + M3 Pro 8 17.479s 11515 658.8 1.08x 1.582486 1.336699

Notes:

  • DDP compile scope is ddp_local_grad; distributed collectives remain eager.
  • The guide-shape DDP rows process more tokens per optimizer step, so final losses are expected to differ from the single-process row.
  • Multi-node rows use heterogeneous M3 Max + M3 Pro hardware, so they should be read as launch and throughput smoke tests, not as ideal scaling claims.

How To Test

Focused single-process checks

python -m pytest tests/test_mlx_trainer_internals.py -q

Focused single-node DDP check

This launches two local MLX ranks with mlx.launch -n.

python -m pytest tests/test_mlx_ddp_metal.py -q

Manual single-node training launch

MLX's minimal local launch form is:

mlx.launch -n 2 path/to/train.py --your-training-args

-n 2 means two local ranks. The training script should initialize distributed state through the trainer; the trainer calls mx.distributed.init() without forcing a backend, so MLX uses its default backend-selection behavior.

If a test needs an explicit backend or exact interpreter/environment, use the same mlx.launch shape with extra launcher options:

cd /path/to/unsloth-zoo
mlx.launch -n 2 \
  --backend ring \
  --env PYTHONPATH="$PWD" \
  --python "$(command -v python)" \
  --cwd "$PWD" \
  path/to/train.py --your-training-args

Manual multi-node training launch

MLX's minimal multi-host launch form is:

mlx.launch --hosts ip1,ip2 path/to/train.py --your-training-args

Each node must have:

  • the same branch checked out
  • the same Python environment
  • passwordless SSH from the launch machine
  • communication addresses reachable by all MLX ranks

For more control, use an MLX hostfile. The hostfile can separate the SSH target from the communication IPs, which is useful when SSH and collectives should use different network interfaces:

[
  {"ssh": "hostname1", "ips": ["123.123.1.1", "123.123.2.1"]},
  {"ssh": "hostname2", "ips": ["123.123.1.2", "123.123.2.2"]}
]

Example with one rank per host and an explicit ring backend:

cd /path/to/unsloth-zoo
mlx.launch --hosts /path/to/hostfile.json \
  --backend ring \
  --env PYTHONPATH="/path/to/unsloth-zoo" \
  --python "/path/to/python" \
  --cwd "/path/to/unsloth-zoo" \
  path/to/train.py --your-training-args

For Thunderbolt-connected Macs, MLX recommends mlx.distributed_config to discover/configure the topology and create the hostfile for mlx.launch. This PR's multi-node smoke rows used Thunderbolt 4 hardware, but the PR does not add host configuration automation.

Validation

git diff --check origin/main...HEAD

Passed.

python -m pytest tests/test_mlx_ddp_metal.py tests/test_mlx_trainer_internals.py -q

Passed: 55 passed.

@gemini-code-assist gemini-code-assist Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces distributed data parallel (DDP) training support for the MLX trainer, including rank-sharded batching, synchronization, and comprehensive integration tests. The review feedback highlights critical improvements for padding correctness by ensuring the tokenizer is properly passed to distributed text batching functions to avoid defaulting to a pad token ID of 0. Additionally, it recommends using a local random number generator instead of modifying the global NumPy random seed to prevent unintended side-effects.

Important

The consumer version of Gemini Code Assist on GitHub is being sunset. Starting June 18, 2026, new organization installations will be blocked, and all code review activity will officially cease on July 17, 2026.
For more details on the timeline and next steps, please review the Help Documentation.

Comment thread unsloth_zoo/mlx/utils.py
Comment thread unsloth_zoo/mlx/utils.py
Comment thread unsloth_zoo/mlx/utils.py Outdated
Comment thread unsloth_zoo/mlx/utils.py
Comment thread unsloth_zoo/mlx/utils.py Outdated
@Lyxot Lyxot marked this pull request as ready for review June 29, 2026 19:48
Copilot AI review requested due to automatic review settings June 29, 2026 19:48

Copilot AI left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

This PR adds initial distributed data parallel (DDP) training support to the MLX backend, including rank-aware batch sharding for text and VLM workloads, synchronized training/eval metric reductions across ranks, coordinated failure handling, and richer per-rank diagnostics in the training results.

Changes:

  • Added MLXTrainer distributed metadata/diagnostics (rank/world-size/is-main, per-rank runtime/tokens/throughput/memory) and DDP-safe control flow (stop propagation, failure consensus, main-rank-only side effects).
  • Implemented DDP-aware batch sharding and padding alignment for text + VLM (materialized and streaming) and added checkpoint-resume validation across ranks.
  • Added single-node mlx.launch -n test coverage validating sharding behavior, compile fallback behavior, resume parity, and main-rank-only callbacks/saves.

Reviewed changes

Copilot reviewed 4 out of 4 changed files in this pull request and generated 4 comments.

File Description
unsloth_zoo/mlx/utils.py Adds DDP-aware batch sharding/padding utilities and integrates them into text/VLM batching and streaming iterators.
unsloth_zoo/mlx/trainer.py Introduces MLX distributed initialization, DDP-synchronized training/eval reductions, compile scoping/fallback coordination, and per-rank diagnostics.
tests/test_mlx_trainer_internals.py Adds unit tests for distributed defaults and distributed-group caching behavior.
tests/test_mlx_ddp_metal.py Adds integration tests using mlx.launch to validate end-to-end DDP behavior for text and VLM training/eval/resume/compile-fallback paths.

💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

Comment thread unsloth_zoo/mlx/utils.py
Comment thread unsloth_zoo/mlx/utils.py Outdated
Comment thread unsloth_zoo/mlx/utils.py Outdated
Comment thread unsloth_zoo/mlx/utils.py
@Lyxot Lyxot marked this pull request as draft June 29, 2026 20:00

@chatgpt-codex-connector chatgpt-codex-connector Bot left a comment

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: 16846bf899

ℹ️ About Codex in GitHub

Your team has set up Codex to review pull requests in this repo. Reviews are triggered when you

  • Open a pull request for review
  • Mark a draft as ready
  • Comment "@codex review".

If Codex has suggestions, it will comment; otherwise it will react with 👍.

Codex can also answer questions or update the PR. Try commenting "@codex address that feedback".

Comment thread unsloth_zoo/mlx/utils.py Outdated
Comment thread unsloth_zoo/mlx/trainer.py
Comment thread unsloth_zoo/mlx/trainer.py
Comment thread unsloth_zoo/mlx/trainer.py Outdated
Comment thread unsloth_zoo/mlx/trainer.py Outdated
@Lyxot Lyxot marked this pull request as ready for review June 30, 2026 16:02

@chatgpt-codex-connector chatgpt-codex-connector Bot left a comment

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: 75dc5a3364

ℹ️ About Codex in GitHub

Your team has set up Codex to review pull requests in this repo. Reviews are triggered when you

  • Open a pull request for review
  • Mark a draft as ready
  • Comment "@codex review".

If Codex has suggestions, it will comment; otherwise it will react with 👍.

Codex can also answer questions or update the PR. Try commenting "@codex address that feedback".

Comment thread unsloth_zoo/mlx/utils.py
Comment on lines +3278 to +3282
bi = _rank_slice_distributed_batch(
global_indices,
batch_size,
comm_group=comm_group,
pad_source=indices,

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Badge Synchronize materialized VLM prep failures

In materialized DDP VLM runs, this sharding makes each rank build only its local slice before the trainer reaches any failure-consensus collective. If one rank hits a local _item/processor error while preparing its shard, such as an unreadable image on one host, that rank exits during _prepare_data while peers that prepared successfully continue to the first training collective and can hang. Wrap materialized preparation in the same rank-wide failure synchronization used for batch fetch/eval.

Useful? React with 👍 / 👎.

dataset_order=getattr(args, "dataset_order", "default"),
preserve_dataset_order=bool(getattr(args, "preserve_dataset_order", False)),
num_epochs=labeled_num_epochs,
comm_group=comm_group,

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Badge Check response masks before rank sharding

Passing the DDP group here makes _create_labeled_batches return only this rank's shard, but _check_all_masked(batches) immediately below still treats that local shard as the whole dataset. In a valid completion-only dataset where one rank happens to receive only all--100 rows while another rank has trainable responses, this rank raises ZeroDivisionError before training while peers can proceed to the first collective and hang; aggregate the all-masked decision across ranks or run the check before sharding.

Useful? React with 👍 / 👎.

Comment on lines +2357 to +2358
grad, toks_f = grad_accum_state
grad_norm = _apply_update(grad, toks_f)

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Badge Synchronize update failures after DDP reduction

For DDP optimizer steps, _apply_update performs the gradient/token all-reduces and then runs clipping, weight decay, and optimizer.update without another failure consensus. If a rank-local error occurs after those collectives, such as an OOM during clipping or optimizer-state update on one worker, that rank exits while peers continue to the next _distributed_should_stop() collective and can hang; wrap this update path in the same rank-wide failure synchronization used around local step execution.

Useful? React with 👍 / 👎.

Comment thread unsloth_zoo/mlx/utils.py
Comment on lines +3795 to +3798
for key, fill_value in (
("input_ids", pad_id),
("attention_mask", 0),
("labels", -100),

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Badge Clear vision fields for empty VLM pads

The new empty-pad path first builds synthetic VLM rows from a real sample, but this helper only overwrites text ids, attention masks, and labels. For non-divisible DDP eval shards on models/processors that pass pixel_values, image_grid_thw, or derived position metadata, the padded row still carries a real image while its text contains no image tokens, which can make Qwen/Phi-style VLM forwards fail their image-token/feature alignment or wastefully process stale images; drop or neutralize the row-aligned vision metadata too.

Useful? React with 👍 / 👎.

@chatgpt-codex-connector

Copy link
Copy Markdown

Codex usage limits have been reached for code reviews. Please check with the admins of this repo to increase the limits by adding credits.
Repo admins can enable using credits for code reviews in their settings.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants