Guard Windows ROCm torchao override skip #6837
When doing full finetuning (FFT) of a bfloat16 model, the fp16/bf16 mismatch validation fires before the corrective logic runs, causing a misleading error even though the code would properly handle it downstream. Skip the validation when full_finetuning is active. Fixes unslothai#6731
…idation Instead of entirely skipping validation (which could let mismatches through when mixed_precision_dtype is float32), auto-correct explicit fp16/bf16 settings that conflict with the model's dtype for FFT. This way the existing validation still catches real mismatches for non-FFT cases, and the corrective logic below handles the normalized settings. Fixes the issue raised in Codex review of PR unslothai#6813.
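The auto-correction described in the two commit messages above can be sketched as a small pure function. This is an illustrative sketch only: the function name `normalize_fft_precision` and the use of plain strings for dtypes are assumptions made here, not the code in unsloth/models/rl.py.

```python
from typing import Tuple


def normalize_fft_precision(
    model_dtype: str, fp16: bool, bf16: bool, full_finetuning: bool
) -> Tuple[bool, bool]:
    """Return (fp16, bf16) adjusted to agree with the model dtype.

    Hypothetical helper: only full finetuning is auto-corrected. Other modes
    are returned unchanged so a later validation step can still reject real
    mismatches.
    """
    if not full_finetuning:
        return fp16, bf16
    if model_dtype == "bfloat16" and fp16:
        # fp16 was requested for a bf16 model: switch to bf16.
        return False, True
    if model_dtype == "float16" and bf16:
        # bf16 was requested for an fp16 model: switch to fp16.
        return True, False
    # float32 models and already-consistent settings pass through untouched.
    return fp16, bf16


print(normalize_fft_precision("bfloat16", True, False, full_finetuning=True))
```

Keeping the non-FFT path unchanged is the point of the second commit: the mismatch is normalized only where the downstream logic is known to handle it.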
Detect installed ROCm torch directly before applying the torchao override so Windows ROCm environments never install the crashing torchao package even if the earlier ROCm-installed flag is missing.
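A minimal sketch of such a guard follows, assuming the probe runs the target venv's interpreter and reads `torch.version.hip`, which is a version string on ROCm builds of PyTorch and `None` otherwise. The helper names `probe_hip_version` and `should_skip_torchao` are hypothetical, and the sketch assumes the probe prints nothing else on stdout.

```python
import subprocess
import sys
from typing import Callable, Optional

# One-line probe executed inside the target venv (assumed clean stdout).
_PROBE = "import torch; print(torch.version.hip or '')"


def probe_hip_version(python_exe: str) -> Optional[str]:
    """Return the HIP version reported by the venv's torch, or None."""
    try:
        result = subprocess.run(
            [python_exe, "-c", _PROBE],
            capture_output=True,
            text=True,
            timeout=60,
        )
    except (OSError, subprocess.TimeoutExpired):
        return None
    if result.returncode != 0:
        return None  # torch missing or import failed
    return result.stdout.strip() or None


def should_skip_torchao(
    python_exe: str,
    rocm_flag_set: bool,
    platform: str = sys.platform,
    probe: Callable[[str], Optional[str]] = probe_hip_version,
) -> bool:
    """Skip torchao on Windows if the flag is set or torch is a ROCm build."""
    if platform != "win32":
        return False
    if rocm_flag_set:
        return True
    # Fall back to the installed runtime when the earlier flag is missing.
    return probe(python_exe) is not None
```

Passing the probe in as a parameter keeps the decision testable without a real ROCm install, for example `should_skip_torchao("python", False, platform="win32", probe=lambda _: "6.4")`.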
Code Review
This pull request introduces a check to detect Windows ROCm PyTorch builds to safely skip torchao installation, along with corresponding unit tests. It also updates the TRL RL trainer patching logic to automatically adjust precision flags (use_fp16/use_bf16) during full finetuning to match the model's dtype. The review feedback suggests making the ROCm torch probe more robust against stray stdout noise by checking the last non-empty line, and ensuring that the underlying args.fp16 and args.bf16 attributes are updated alongside the local variables to maintain consistency.
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Tolerate stray stdout noise when probing Windows ROCm torch installs by checking the last non-empty output line, matching the existing torch version probe behavior. Also keep args.fp16 and args.bf16 synchronized with the full-finetuning precision auto-corrections in the RL trainer patch so downstream eval settings see a consistent TrainingArguments state.
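Both follow-ups in this commit are small enough to illustrate directly. The sketch below is hypothetical: `last_nonempty_line` and `sync_precision_flags` are names chosen here, and `SimpleNamespace` stands in for a real TrainingArguments object.

```python
from types import SimpleNamespace


def last_nonempty_line(stdout: str) -> str:
    """Return the last non-blank line of probe output, or '' if none exists."""
    for line in reversed(stdout.splitlines()):
        if line.strip():
            return line.strip()
    return ""


def sync_precision_flags(args, use_fp16: bool, use_bf16: bool) -> None:
    """Mirror the corrected local flags onto the args object."""
    args.fp16 = use_fp16
    args.bf16 = use_bf16


# Warnings printed before the probe result no longer affect the parsed value.
noisy = "UserWarning: something unrelated\n6.4.1\n\n"
print(last_nonempty_line(noisy))

# Stand-in for TrainingArguments; only the two attributes matter here.
args = SimpleNamespace(fp16=True, bf16=False)
sync_precision_flags(args, use_fp16=False, use_bf16=True)
print(args.fp16, args.bf16)
```

Reading only the last non-empty line works because the probe prints its result last, so anything emitted earlier by imports is ignored.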
for more information, see https://pre-commit.ci
Main installer change looks correct. I have one question: why is the unsloth/models/rl.py precision flag change included? The PR description is focused on Windows ROCm/torchao, so was it intentional to include it? Could we also make the new guard test a bit more behavioral rather than checking the exact source string? For example, patch …
@Imagineer99 Yes! the … Good point on the test as well. I'll switch the new guard test to a behavioral one by patching …
Patch imported MLXTrainer and MLXTrainingConfig objects to preserve the expected dataclass field ordering and to provide a _train_dataset_for_batches fallback when older trainers or test doubles only expose train_dataset. Also add focused worker tests covering both compatibility paths.
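The fallback half of this commit might look roughly like the sketch below. It assumes `_train_dataset_for_batches` is a plain attribute and uses a hypothetical helper name, `ensure_batch_dataset`; the actual patch and the dataclass field-ordering fix are not shown in this conversation.

```python
from types import SimpleNamespace


def ensure_batch_dataset(trainer):
    """Give older trainers or test doubles a _train_dataset_for_batches.

    Hypothetical compatibility shim: if the attribute is missing, reuse
    train_dataset so code that reads the newer name keeps working.
    """
    if not hasattr(trainer, "_train_dataset_for_batches"):
        trainer._train_dataset_for_batches = getattr(trainer, "train_dataset", None)
    return trainer


# A test double that only exposes train_dataset.
legacy = ensure_batch_dataset(SimpleNamespace(train_dataset=["a", "b"]))
print(legacy._train_dataset_for_batches)

# A trainer that already has the newer attribute is left unchanged.
modern = ensure_batch_dataset(
    SimpleNamespace(train_dataset=["a"], _train_dataset_for_batches=["x"])
)
print(modern._train_dataset_for_batches)
```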
… into issue-6833-windows-rocm-torchao

What changed

This PR hardens the Windows ROCm torchao skip path in studio/install_python_stack.py.

Previously, the installer skipped torchao only when _rocm_windows_torch_installed had been set earlier in the install flow. This change adds a direct runtime probe of the installed torch build and skips the torchao override whenever the target venv already contains a Windows ROCm torch wheel.

Why it changed

Issue #6833 reports that on Windows ROCm, torchao==0.17.0 crashes on import because the ROCm Windows torch build does not expose torch.ops._c10d_functional.all_gather_into_tensor.

The repo already had the main mitigation in place:

… torchao on Windows ROCm so transformers and peft imports do not crash.
install_python_stack.py skips installing torchao when the ROCm-install path marks _rocm_windows_torch_installed.

The remaining gap was that the skip depended on that earlier flag being present. If the flag was missed or stale but the environment already contained a ROCm torch wheel, the installer could still install torchao and ship a package that crashes on import.

Root cause

The issue is not general ROCm failure and not GPU detection failure. The root cause is that Windows ROCm torch can be present in the venv while the torchao override decision still relies on installer state rather than the actual installed torch runtime.

This PR closes that gap by checking the installed torch runtime directly before applying the torchao override.

User impact

… _rocm_windows_torch_installed to avoid installing incompatible torchao.
… torchao consistently.
… #6833.

Validation

Focused tests:

pytest -q tests/studio/install/test_rocm_support.py -k 'RocmTorchInstalledEnvVar or WindowsRocmTorchaoGuard'

Broader installer coverage:

pytest -q tests/studio/install/test_rocm_support.py tests/studio/install/test_pr5940_followups.py

Results:

8 passed
378 passed, 4 skipped