Diffusion training: mxfp8 base precision (Blackwell) + SDXL U-Net regional compile #6852
danielhanchen wants to merge 13 commits into
- base_precision="mxfp8": torchao MX block-scaled float8 compute on the frozen base linears (Blackwell sm100+, cuBLAS kernels). Applied after add_adapter like fp8, never fatal, weights stay bf16 in memory. Measured 1.16x over compiled bf16 on Z-Image at 1024px batch 4 (16k tokens/step); a wash at small token counts, so it stays an explicit opt-in and auto never picks it.
- SDXL: regionally compile the U-Net's BasicTransformerBlocks through the same never-fatal wrapper the DiT trainer uses. 1.35x steady state at 1024px batch 4 with same-seed loss parity (~1e-5 per step) and unchanged peak VRAM; ~30 s one-time warmup. Steady-state samples/sec now excludes step 1, matching the DiT trainer.
- /info: mxfp8 advertised only on sm100+; supports_compile now true for sdxl.
- NVFP4 training: not available in torchao 0.16 (no autograd path, no training recipe), so NVFP4 stays an inference-only quant for now.
193 diffusion backend tests green; frontend build clean.
Code Review
This pull request adds support for the mxfp8 (Blackwell-native block-scaled float8) base precision mode for DiT training, including GPU capability checks, module filtering, and fallback mechanisms. It also extends regional torch.compile support to SDXL's U-Net transformer blocks and updates the frontend to accommodate these options. The review feedback correctly identifies a UI bug where the compile control is nested inside an isDiT check, preventing it from being rendered for SDXL even though supportsCompile is enabled.
// Whether to show the torch.compile control. The backend advertises this per family
// (the SDXL U-Net path compiles regionally too now); default on for DiT families when
// an older backend does not report it.
const supportsCompile = reportedFamily?.supports_compile ?? isDiT;
While supportsCompile is correctly updated to allow compilation for SDXL (by falling back to isDiT but preferring reportedFamily?.supports_compile), the actual rendering of the Compile transformer control in trainingSettings (around line 814) is nested inside the isDiT conditional block. As a result, the compile control will never be displayed for SDXL in the UI.
To fix this, the Compile transformer control should be moved outside of the isDiT conditional block so that it renders for any family where supportsCompile is true.
Fixed in 2a5e957 on this PR: the compile control moved out of the DiT-only branch and renders for any family whose /info reports supports_compile. Verified live with Playwright: mxfp8 listed for DiT families on sm100, the compile select present for SDXL, and the DiT precision selector still hidden for SDXL.
The compile select was nested inside the DiT-only branch of the precision area, so supports_compile=true for sdxl never rendered it. Lift it out of the ternary: the precision selects stay family-specific, the compile control follows supports_compile. Verified with a live Playwright pass (mxfp8 listed for DiT families on sm100, compile select present for SDXL, DiT precision selector still hidden for SDXL).
…usion-train-perf2
# Conflicts:
#   studio/backend/core/training/diffusion_dit_trainer.py
#   studio/backend/core/training/diffusion_train_common.py
#   studio/frontend/src/features/images/train/diffusion-train-panel.tsx
The SDXL trainer now regionally compiles its transformer blocks, so /info advertises supports_compile for every family; the precision selector stays DiT-only.
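A minimal sketch of the /info gating described in this thread, not the backend's actual code. Only the supports_compile field name, the "mxfp8" value, and the sm100+ rule come from the PR; the function names, the base_precisions key, and the reading of "sm100+" as compute capability >= (10, 0) are assumptions made for illustration.

```python
from typing import Dict, List, Tuple


def mxfp8_supported(capability: Tuple[int, int]) -> bool:
    # Assumption: "sm100+" is read as compute capability >= (10, 0).
    # The real backend may gate specific architectures differently.
    return capability >= (10, 0)


def family_info(
    family: str,
    is_dit: bool,
    capability: Tuple[int, int],
    base: Tuple[str, ...] = ("bf16",),  # hypothetical default precision list
) -> Dict[str, object]:
    """Build a hypothetical per-family /info entry."""
    precisions: List[str] = []
    if is_dit:
        # The precision selector stays DiT-only; SDXL gets an empty list.
        precisions = list(base)
        if mxfp8_supported(capability):
            precisions.append("mxfp8")
    return {
        "family": family,
        # Every family now compiles regionally, SDXL included.
        "supports_compile": True,
        "base_precisions": precisions,  # hypothetical key name
    }
```

On a real machine the capability tuple would typically come from torch.cuda.get_device_capability().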
@codex review
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: 94b076226b
quantize_(
    transformer,
    MXLinearConfig.from_recipe_name("mxfp8_cublas"),
Use the MXFP8 API available in torchao 0.17
On Studio installs with torch 2.11/2.12, the installer selects torchao 0.17, whose torchao.prototype.mx_formats no longer exports MXLinearConfig. Because this call is inside the broad fallback, selecting base_precision="mxfp8" on a Blackwell GPU only emits a warning and trains dense bf16, so the advertised mxfp8 mode never actually engages for those supported installs.
Fixed in 5ea8eb9: _mxfp8_training_config tries the torchao 0.16 MXLinearConfig recipe first and falls back to the 0.17 MXFP8TrainingOpConfig.from_recipe API; both feed quantize_, and mxfp8 still degrades to bf16 with a warning when neither import resolves.
torchao 0.17 removed MXLinearConfig from prototype.mx_formats in favour of MXFP8TrainingOpConfig.from_recipe, which is shared with MoE training. _mxfp8_training_config tries the 0.16 API first and falls back to the 0.17 one; both feed quantize_. mxfp8 still degrades to bf16 with a warning when neither import resolves.
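The try-old-then-new pattern in this commit can be sketched roughly as below. This is an illustration, not the PR's _mxfp8_training_config: first_available and load_mx_config_016 are hypothetical names, and only the 0.16 loader is spelled out because the thread does not give the import path or arguments for the 0.17 MXFP8TrainingOpConfig.from_recipe call.

```python
from typing import Any, Callable, Iterable, Optional


def first_available(loaders: Iterable[Callable[[], Any]]) -> Optional[Any]:
    """Return the result of the first loader that resolves, else None."""
    for load in loaders:
        try:
            return load()
        except (ImportError, AttributeError):
            # This API is missing in the installed version; try the next one.
            continue
    return None


def load_mx_config_016() -> Any:
    # torchao 0.16 API as named in this thread.
    from torchao.prototype.mx_formats import MXLinearConfig

    return MXLinearConfig.from_recipe_name("mxfp8_cublas")


def resolve_mxfp8_config() -> Optional[Any]:
    # A second loader wrapping MXFP8TrainingOpConfig.from_recipe would be
    # appended here for torchao 0.17; its import path is not stated in the
    # thread, so it is left out rather than guessed.
    return first_available([load_mx_config_016])
```

A caller that gets None back would emit the warning event and keep training in bf16, which matches the never-fatal behaviour described above.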
# Conflicts:
#   studio/backend/core/training/diffusion_train_common.py
#   studio/backend/tests/test_diffusion_dit_trainer.py
#   studio/backend/tests/test_diffusion_training.py
/gemini review
Code Review
This pull request introduces support for mxfp8 (block-scaled float8 compute) base precision for DiT training on Blackwell GPUs (sm100+), leveraging torchao. Additionally, it extends regional torch.compile support to the SDXL U-Net's transformer blocks, refactors step-rate reporting to exclude step 1 warmup costs, and updates the frontend UI to expose these new capabilities. There are no review comments to address, so I have no feedback to provide.
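For reference, excluding step 1 from the step rate could look roughly like the sketch below. The function name and the choice of per-step end timestamps as input are assumptions; the trainer's real bookkeeping is not shown in this thread.

```python
from typing import Optional, Sequence


def steady_state_rate(
    step_end_times: Sequence[float], batch_size: int
) -> Optional[float]:
    """Samples/sec over steps 2..N, ignoring the warmup cost of step 1.

    step_end_times[i] is the wall-clock time at which step i + 1 finished.
    """
    if len(step_end_times) < 2:
        return None  # only the warmup step has run so far
    # Measure from the end of step 1, so compile/cudnn warmup is excluded.
    elapsed = step_end_times[-1] - step_end_times[0]
    if elapsed <= 0:
        return None
    steps = len(step_end_times) - 1
    return steps * batch_size / elapsed
```

With a 30 s first step followed by 1 s steps at batch 4, this reports 4 samples/sec instead of a figure dragged down by the one-time warmup.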
Stacked on #6843 (base: diffusion-krea2). Two training-speed additions, both measured on B200 with 40-step LoRA runs at 1024px, batch 4.
mxfp8 base precision (DiT families)
New base_precision="mxfp8": the frozen base linears switch to torchao MX block-scaled float8 training compute (MXLinearConfig.from_recipe_name("mxfp8_cublas")), applied after add_adapter exactly like the fp8 mode so the LoRA modules stay high precision. Weights stay bf16 in memory (it is a speed mode, not a memory mode), the swap is never fatal (any failure falls back to bf16 with a warning event), and the filter skips lora_/proj_out plus any dim not divisible by the 32-wide MX blocks.
Measured (Z-Image dense, compiled, 1024px batch 4 = 16k tokens/step): 1.16x over compiled bf16.
Microbench (fwd+bwd through DiT-shaped LoRA linears, compiled): 0.99x at 4k tokens, 1.14x at 16k, 1.16x at 64k. mxfp8 only pays off at high resolution or batch, so it stays an explicit opt-in: auto never picks it, and /info only advertises it on sm100+ (Blackwell), where its cuBLAS kernels exist. Same-seed loss curves track bf16 step for step with the expected small float8 forward noise; no divergence, no NaN.
Also documented: NVFP4 training is not possible on torchao 0.16 (no autograd path, no training recipe), so NVFP4 remains inference-only.
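The module filter described above can be sketched as a plain predicate. This is a guess at its shape, not the PR's code: the name mx_eligible, the substring matching on the module's qualified name, and checking both in_features and out_features are all assumptions.

```python
MX_BLOCK = 32  # MX formats scale values in 32-wide blocks
SKIP_SUBSTRINGS = ("lora_", "proj_out")  # names taken from the description


def mx_eligible(fqn: str, in_features: int, out_features: int) -> bool:
    """Decide whether a linear layer should be swapped to mxfp8 compute."""
    # LoRA adapters and the output projection stay in high precision.
    if any(token in fqn for token in SKIP_SUBSTRINGS):
        return False
    # Both dims must tile evenly into 32-wide MX blocks.
    return in_features % MX_BLOCK == 0 and out_features % MX_BLOCK == 0
```

A predicate like this could presumably be adapted into the filter callback that quantize_ accepts, reading the two dims off the module it is handed.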
SDXL U-Net regional compile
The SDXL trainer now compiles the U-Net's repeated BasicTransformerBlocks through the same never-fatal wrapper the DiT trainer uses (cfg.compile_transformer, default auto = on for this dense bf16 path, eager fallback with a warning event on any failure).
Same-seed losses match to ~1e-5 at every step; the one-time warmup is ~30 s. The SDXL progress rate now excludes step 1 (compile/cudnn warmup) like the DiT trainer, and model_load_completed carries compiled for the UI. /info reports supports_compile: true for sdxl so the Train tab shows the compile control there too.
Tests
15 new unit tests (mxfp8 validation incl. prequant/fp16 rejection, the MX module filter, compile policy, sm89/sm100/sm120 capability gating, sdxl /info, request-model Literal): 193 diffusion backend tests green on diffusers 0.39; frontend build clean.
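As an illustration of the never-fatal regional compile policy described for the SDXL U-Net, a dependency-free sketch follows. The function name and signature are hypothetical; in the trainer, compile_fn would presumably be torch.compile and blocks the U-Net's BasicTransformerBlocks.

```python
from typing import Any, Callable, List, Sequence, Tuple


def compile_regions(
    blocks: Sequence[Any],
    compile_fn: Callable[[Any], Any],
    warn: Callable[[str], None],
) -> Tuple[List[Any], bool]:
    """Compile each repeated block; on any failure keep every block eager."""
    compiled: List[Any] = []
    try:
        for block in blocks:
            compiled.append(compile_fn(block))
    except Exception as exc:  # never fatal: report and fall back
        warn(f"regional compile failed, running eager: {exc}")
        return list(blocks), False
    return compiled, True
```

The returned flag is the kind of value a model_load_completed event could carry as compiled. Note that this sketch only guards wrap time; a lazily compiling backend can still fail on the first forward call, which a real wrapper would also need to catch.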