Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
146740d
Fix: skip fp16/bf16 validation for full finetuning in RL trainers
Jul 2, 2026
340cd50
Fix: auto-correct fp16/bf16 mismatches for full finetuning before val…
Jul 2, 2026
63b8b4b
Guard Windows ROCm torchao override skip
Jul 3, 2026
4d0475d
Update unsloth/models/rl.py
InfoSage05 Jul 3, 2026
f1091e8
Update studio/install_python_stack.py
InfoSage05 Jul 3, 2026
69e7401
Harden ROCm probe and sync RL precision flags
Jul 3, 2026
0a8b0cf
Merge upstream/main into issue-6833-windows-rocm-torchao
Jul 3, 2026
e74e491
Merge remote PR updates into issue-6833-windows-rocm-torchao
Jul 3, 2026
1b12c69
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 3, 2026
bf77875
Merge branch 'main' into issue-6833-windows-rocm-torchao
Imagineer99 Jul 3, 2026
319e544
Add MLX trainer compatibility shims
Jul 3, 2026
262f699
Merge remote-tracking branch 'origin/issue-6833-windows-rocm-torchao'…
Jul 3, 2026
3705eb0
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 3, 2026
a1d9c9c
Merge branch 'main' into issue-6833-windows-rocm-torchao
Imagineer99 Jul 3, 2026
d3f4270
Merge branch 'main' into issue-6833-windows-rocm-torchao
Imagineer99 Jul 3, 2026
eba4669
Scope PR to Windows ROCm torchao guard
Imagineer99 Jul 3, 2026
c177ce1
Restore PR scope to Windows ROCm guard
Imagineer99 Jul 3, 2026
f5f34ed
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 3, 2026
1492dd0
test: cover Windows ROCm torchao skip behavior
Imagineer99 Jul 3, 2026
070e9b9
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 3, 2026
9fed497
Merge branch 'main' into issue-6833-windows-rocm-torchao
Imagineer99 Jul 3, 2026
a4d4967
Merge branch 'main' into issue-6833-windows-rocm-torchao
Imagineer99 Jul 3, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 34 additions & 1 deletion studio/install_python_stack.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,39 @@ def _probe_installed_torch_version() -> str | None:
return lines[-1] if lines else None


def _installed_torch_is_windows_rocm() -> bool:
"""Return True when the target venv currently has a Windows ROCm torch build.

This is a belt-and-suspenders guard for the torchao override step: if the
earlier ROCm install path failed to set _rocm_windows_torch_installed but the
venv already contains a ROCm torch wheel, still skip torchao because it
crashes on import on Windows ROCm.
"""
if not IS_WINDOWS:
return False
try:
probe = subprocess.run(
[
sys.executable,
"-c",
(
"import sys, torch; "
"hip = getattr(getattr(torch, 'version', None), 'hip', None) or ''; "
"ver = getattr(torch, '__version__', '').lower(); "
"sys.stdout.write('yes' if (hip or 'rocm' in ver or 'rocmsdk' in ver) else '')"
),
],
stdout = subprocess.PIPE,
stderr = subprocess.DEVNULL,
text = True,
timeout = 90,
**_windows_hidden_subprocess_kwargs(),
)
except (OSError, subprocess.TimeoutExpired):
return False
return probe.returncode == 0 and (probe.stdout or "").strip() == "yes"
Comment thread
InfoSage05 marked this conversation as resolved.
Outdated


# AMD Windows ROCm wheels (repo.amd.com/rocm/whl/{arch_family}/).
# Override with UNSLOTH_ROCM_WINDOWS_MIRROR for air-gapped/mirror installs.
_ROCM_WINDOWS_INDEX_BASE = (
Expand Down Expand Up @@ -2221,7 +2254,7 @@ def _win_amd_smi_has_gpu(stdout: str) -> bool:
# (no working build; see below).
if NO_TORCH:
_progress("dependency overrides (skipped, no torch)")
elif _rocm_windows_torch_installed:
elif _rocm_windows_torch_installed or _installed_torch_is_windows_rocm():
# No working Windows ROCm torchao build: it imports an absent c10d backend
# and crashes transformers.quantizers. Studio stubs it at runtime, so
# installing it only ships a package that crashes on import -- skip it.
Expand Down
35 changes: 35 additions & 0 deletions tests/studio/install/test_rocm_support.py
Original file line number Diff line number Diff line change
Expand Up @@ -2294,6 +2294,41 @@ def _bad_probe(*a, **kw):
mock_bnb.assert_not_called()


class TestWindowsRocmTorchaoGuard:
"""Verify the torchao skip can detect an installed Windows ROCm torch build."""

def test_installed_torch_is_windows_rocm_accepts_rocm_probe(self):
rv = MagicMock()
rv.returncode = 0
rv.stdout = "yes"
with (
patch.object(stack_mod, "IS_WINDOWS", True),
patch.object(stack_mod.subprocess, "run", return_value = rv),
):
assert stack_mod._installed_torch_is_windows_rocm() is True

def test_installed_torch_is_windows_rocm_rejects_non_rocm_probe(self):
rv = MagicMock()
rv.returncode = 0
rv.stdout = ""
with (
patch.object(stack_mod, "IS_WINDOWS", True),
patch.object(stack_mod.subprocess, "run", return_value = rv),
):
assert stack_mod._installed_torch_is_windows_rocm() is False

def test_installed_torch_is_windows_rocm_is_non_windows_noop(self):
with patch.object(stack_mod, "IS_WINDOWS", False):
assert stack_mod._installed_torch_is_windows_rocm() is False

def test_install_python_stack_uses_direct_windows_rocm_torchao_guard(self):
source = _STACK_PATH.read_text(encoding = "utf-8")
assert (
"elif _rocm_windows_torch_installed or _installed_torch_is_windows_rocm():"
in source
)


# TEST: worker.py -- Windows ROCm patches (source-level checks)


Expand Down
3 changes: 3 additions & 0 deletions unsloth/models/rl.py
Original file line number Diff line number Diff line change
Expand Up @@ -1015,6 +1015,9 @@ def _patch_trl_rl_trainers_impl(trainer_file = "grpo_trainer"):
"dtype = _get_dtype(dtype)\n"
"float16 = dtype == torch.float16\n"
"bfloat16 = dtype == torch.bfloat16\n"
"if full_finetuning:\n"
" if bfloat16 and use_fp16: use_fp16 = False\n"
" if float16 and use_bf16: use_bf16 = False\n"
Comment thread
InfoSage05 marked this conversation as resolved.
"if not force_float32 and (float16 and use_bf16): raise TypeError('Unsloth: Model is in float16 precision but you want to use bfloat16 precision. Set fp16 to `True` and bf16 to `False`')\n"
"if not force_float32 and (bfloat16 and use_fp16): raise TypeError('Unsloth: Model is in bfloat16 precision but you want to use float16 precision. Set fp16 to `False` and bf16 to `True`')\n"
"if force_float32:\n"
Expand Down