Skip to content
Merged
Show file tree
Hide file tree
Changes from 21 commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
146740d
Fix: skip fp16/bf16 validation for full finetuning in RL trainers
Jul 2, 2026
340cd50
Fix: auto-correct fp16/bf16 mismatches for full finetuning before val…
Jul 2, 2026
63b8b4b
Guard Windows ROCm torchao override skip
Jul 3, 2026
4d0475d
Update unsloth/models/rl.py
InfoSage05 Jul 3, 2026
f1091e8
Update studio/install_python_stack.py
InfoSage05 Jul 3, 2026
69e7401
Harden ROCm probe and sync RL precision flags
Jul 3, 2026
0a8b0cf
Merge upstream/main into issue-6833-windows-rocm-torchao
Jul 3, 2026
e74e491
Merge remote PR updates into issue-6833-windows-rocm-torchao
Jul 3, 2026
1b12c69
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 3, 2026
bf77875
Merge branch 'main' into issue-6833-windows-rocm-torchao
Imagineer99 Jul 3, 2026
319e544
Add MLX trainer compatibility shims
Jul 3, 2026
262f699
Merge remote-tracking branch 'origin/issue-6833-windows-rocm-torchao'…
Jul 3, 2026
3705eb0
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 3, 2026
a1d9c9c
Merge branch 'main' into issue-6833-windows-rocm-torchao
Imagineer99 Jul 3, 2026
d3f4270
Merge branch 'main' into issue-6833-windows-rocm-torchao
Imagineer99 Jul 3, 2026
eba4669
Scope PR to Windows ROCm torchao guard
Imagineer99 Jul 3, 2026
c177ce1
Restore PR scope to Windows ROCm guard
Imagineer99 Jul 3, 2026
f5f34ed
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 3, 2026
1492dd0
test: cover Windows ROCm torchao skip behavior
Imagineer99 Jul 3, 2026
070e9b9
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 3, 2026
9fed497
Merge branch 'main' into issue-6833-windows-rocm-torchao
Imagineer99 Jul 3, 2026
a4d4967
Merge branch 'main' into issue-6833-windows-rocm-torchao
Imagineer99 Jul 3, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 53 additions & 6 deletions studio/backend/tests/test_torchao_select.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

import sys
from pathlib import Path
from unittest.mock import MagicMock

import pytest

Expand Down Expand Up @@ -71,12 +72,58 @@ def test_default_spec_matches_table(monkeypatch):
assert mod._select_torchao_spec("2.9.0") == mod._TORCHAO_DEFAULT_SPEC


@pytest.mark.parametrize(
    ("rocm_windows_torch_installed", "installed_torch_is_windows_rocm"),
    [
        # Marker set by the ROCm install path, live probe says "not ROCm".
        (True, False),
        # Marker never set, but the live probe finds a ROCm torch build.
        (False, True),
    ],
)
def test_skips_torchao_on_windows_rocm(
    monkeypatch, tmp_path, rocm_windows_torch_installed, installed_torch_is_windows_rocm
):
    """The overrides step must skip torchao on Windows ROCm: no working build exists
    there (it imports an absent c10d backend and crashes transformers.quantizers),
    so the installer skips it and relies on the runtime stub instead.

    The two parametrized cases cover each side of the installer's guard on its
    own: the module-level ``_rocm_windows_torch_installed`` flag being set, and
    ``_installed_torch_is_windows_rocm()`` returning True. In both cases
    ``install_python_stack()`` must succeed, must not hand any ``torchao`` spec
    to ``pip_install``, and must report the skip through ``_progress``.
    """
    mod = _load_module(monkeypatch)
    installed_specs: list[str] = []
    progress_labels: list[str] = []

    def _record_pip_install(*args, **kwargs):
        # Capture every positional argument so a torchao spec cannot slip by.
        installed_specs.extend(str(arg) for arg in args)
        return 0

    # Point the local plugin locations at real (empty) directories.
    unstructured_plugin = tmp_path / "unstructured"
    github_plugin = tmp_path / "github"
    unstructured_plugin.mkdir()
    github_plugin.mkdir()

    # Generic "succeeded with no output" result for any subprocess.run call.
    subprocess_result = MagicMock()
    subprocess_result.returncode = 0
    subprocess_result.stdout = ""

    monkeypatch.setenv("SKIP_STUDIO_BASE", "1")
    # Pretend to be a Windows host with torch enabled.
    monkeypatch.setattr(mod, "IS_WINDOWS", True)
    monkeypatch.setattr(mod, "IS_MACOS", False)
    monkeypatch.setattr(mod, "IS_MAC_ARM", False)
    monkeypatch.setattr(mod, "NO_TORCH", False)
    # The two inputs of the guard under test.
    monkeypatch.setattr(mod, "_rocm_windows_torch_installed", rocm_windows_torch_installed)
    monkeypatch.setattr(
        mod, "_installed_torch_is_windows_rocm", lambda: installed_torch_is_windows_rocm
    )
    # Neutralise every step that would touch the real environment.
    monkeypatch.setattr(mod, "_bootstrap_uv", lambda: False)
    monkeypatch.setattr(mod, "_repair_bad_anyio", lambda: None)
    monkeypatch.setattr(mod, "_ensure_rocm_torch", lambda: None)
    monkeypatch.setattr(mod, "_ensure_cuda_torch", lambda: None)
    monkeypatch.setattr(mod, "_has_usable_nvidia_gpu", lambda: True)
    monkeypatch.setattr(mod, "run", lambda *args, **kwargs: None)
    monkeypatch.setattr(mod, "pip_install", _record_pip_install)
    monkeypatch.setattr(mod, "_progress", lambda label: progress_labels.append(label))
    monkeypatch.setattr(mod, "LOCAL_DD_UNSTRUCTURED_PLUGIN", unstructured_plugin)
    monkeypatch.setattr(mod, "LOCAL_DD_GITHUB_PLUGIN", github_plugin)
    monkeypatch.setattr(mod.subprocess, "run", lambda *args, **kwargs: subprocess_result)

    assert mod.install_python_stack() == 0

    assert not any(spec.startswith("torchao") for spec in installed_specs)
    assert "dependency overrides (skipped, Windows ROCm)" in progress_labels
36 changes: 35 additions & 1 deletion studio/install_python_stack.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,40 @@ def _probe_installed_torch_version() -> str | None:
return lines[-1] if lines else None


def _installed_torch_is_windows_rocm() -> bool:
    """Report whether the installed torch looks like a Windows ROCm build.

    Secondary guard for the torchao override step. If the earlier ROCm install
    path did not set ``_rocm_windows_torch_installed`` but a ROCm torch wheel is
    already present, torchao should still be skipped because it crashes on
    import on Windows ROCm.

    The check runs ``sys.executable`` in a child process that imports torch and
    prints ``yes`` when ``torch.version.hip`` is truthy or the lower-cased
    ``torch.__version__`` mentions ROCm.

    Returns:
        True only on Windows, when the probe exits with status 0 and its last
        non-blank stdout line is exactly ``yes``. False otherwise, including
        when the probe cannot be started or exceeds the 90 second timeout.
    """
    if not IS_WINDOWS:
        return False

    # One-liner executed by the child interpreter; stdout is the only channel.
    probe_script = (
        "import sys, torch; "
        "hip = getattr(getattr(torch, 'version', None), 'hip', None) or ''; "
        "ver = getattr(torch, '__version__', '').lower(); "
        "sys.stdout.write('yes' if (hip or 'rocm' in ver or 'rocmsdk' in ver) else '')"
    )

    try:
        result = subprocess.run(
            [sys.executable, "-c", probe_script],
            stdout = subprocess.PIPE,
            stderr = subprocess.DEVNULL,  # import noise must not reach the console
            text = True,
            timeout = 90,
            **_windows_hidden_subprocess_kwargs(),
        )
    except (OSError, subprocess.TimeoutExpired):
        # Best effort: an unlaunchable or hung probe counts as "not ROCm".
        return False

    # Only the final non-blank line counts, so earlier output is ignored.
    stripped = (raw.strip() for raw in (result.stdout or "").splitlines())
    non_blank = [text for text in stripped if text]
    if result.returncode != 0:
        return False
    return bool(non_blank) and non_blank[-1] == "yes"


# constraints.txt caps new anyio resolutions at <4.14 (#6483), but an install
# from before the cap existed can already be stuck at 4.14+, which later
# constrained installs won't touch since it already satisfies mcp/fastmcp.
Expand Down Expand Up @@ -2256,7 +2290,7 @@ def _win_amd_smi_has_gpu(stdout: str) -> bool:
# (no working build; see below).
if NO_TORCH:
_progress("dependency overrides (skipped, no torch)")
elif _rocm_windows_torch_installed:
elif _rocm_windows_torch_installed or _installed_torch_is_windows_rocm():
# No working Windows ROCm torchao build: it imports an absent c10d backend
# and crashes transformers.quantizers. Studio stubs it at runtime, so
# installing it only ships a package that crashes on import -- skip it.
Expand Down
64 changes: 64 additions & 0 deletions tests/studio/install/test_rocm_support.py
Original file line number Diff line number Diff line change
Expand Up @@ -2296,6 +2296,70 @@ def _bad_probe(*a, **kw):
mock_bnb.assert_not_called()


class TestWindowsRocmTorchaoGuard:
    """Checks that the torchao skip recognises an already-installed Windows ROCm torch.

    The first three tests exercise ``_installed_torch_is_windows_rocm`` directly
    with ``subprocess.run`` replaced by a canned result; the last one drives
    ``install_python_stack`` end to end with all installer side effects patched.
    """

    @staticmethod
    def _successful_probe(stdout):
        """Build a stand-in for a zero-exit ``subprocess.run`` result with the given stdout."""
        return MagicMock(returncode = 0, stdout = stdout)

    def test_installed_torch_is_windows_rocm_accepts_rocm_probe(self):
        # A probe that printed "yes" means a ROCm torch build was found.
        fake_result = self._successful_probe("yes")
        with (
            patch.object(stack_mod, "IS_WINDOWS", True),
            patch.object(stack_mod.subprocess, "run", return_value = fake_result),
        ):
            detected = stack_mod._installed_torch_is_windows_rocm()
            assert detected is True

    def test_installed_torch_is_windows_rocm_rejects_non_rocm_probe(self):
        # A successful probe with empty output means torch is not a ROCm build.
        fake_result = self._successful_probe("")
        with (
            patch.object(stack_mod, "IS_WINDOWS", True),
            patch.object(stack_mod.subprocess, "run", return_value = fake_result),
        ):
            detected = stack_mod._installed_torch_is_windows_rocm()
            assert detected is False

    def test_installed_torch_is_windows_rocm_is_non_windows_noop(self):
        # Off Windows the helper must answer False.
        with patch.object(stack_mod, "IS_WINDOWS", False):
            detected = stack_mod._installed_torch_is_windows_rocm()
            assert detected is False

    @patch.object(stack_mod, "_repair_bad_anyio")
    @patch.object(stack_mod, "_ensure_rocm_torch")
    @patch.object(stack_mod, "_ensure_cuda_torch")
    @patch.object(stack_mod, "_has_usable_nvidia_gpu", return_value = True)
    @patch.object(stack_mod, "run")
    @patch.object(stack_mod, "pip_install")
    def test_install_python_stack_skips_torchao_when_windows_rocm_torch_is_installed(
        self, mock_pip, mock_run, mock_has_nvidia, mock_cuda, mock_rocm, mock_anyio, tmp_path
    ):
        # Real, empty directories stand in for the local plugin locations.
        plugin_dirs = {}
        for plugin_name in ("unstructured", "github"):
            plugin_dir = tmp_path / plugin_name
            plugin_dir.mkdir()
            plugin_dirs[plugin_name] = plugin_dir

        quiet_success = self._successful_probe("")

        # The install-path marker is left False so that only the live probe
        # (forced to True here) can trigger the torchao skip.
        with (
            patch.dict(os.environ, {"SKIP_STUDIO_BASE": "1"}),
            patch.object(stack_mod, "IS_WINDOWS", True),
            patch.object(stack_mod, "IS_MACOS", False),
            patch.object(stack_mod, "IS_MAC_ARM", False),
            patch.object(stack_mod, "NO_TORCH", False),
            patch.object(stack_mod, "_rocm_windows_torch_installed", False),
            patch.object(stack_mod, "_bootstrap_uv", return_value = False),
            patch.object(stack_mod, "_installed_torch_is_windows_rocm", return_value = True),
            patch.object(stack_mod, "LOCAL_DD_UNSTRUCTURED_PLUGIN", plugin_dirs["unstructured"]),
            patch.object(stack_mod, "LOCAL_DD_GITHUB_PLUGIN", plugin_dirs["github"]),
            patch.object(stack_mod.subprocess, "run", return_value = quiet_success),
        ):
            exit_code = stack_mod.install_python_stack()
            assert exit_code == 0

        # Flatten every positional argument ever passed to pip_install.
        requested = []
        for recorded_call in mock_pip.call_args_list:
            requested.extend(str(positional) for positional in recorded_call.args)
        assert all("torchao" not in spec for spec in requested)


# TEST: worker.py -- Windows ROCm patches (source-level checks)


Expand Down
Loading