@pytest.mark.parametrize(
    ("rocm_windows_torch_installed", "installed_torch_is_windows_rocm"),
    [
        (True, False),
        (False, True),
    ],
)
def test_skips_torchao_on_windows_rocm(
    monkeypatch, tmp_path, rocm_windows_torch_installed, installed_torch_is_windows_rocm
):
    """torchao must not be installed on Windows ROCm.

    No working torchao build exists there (it imports an absent c10d backend
    and crashes transformers.quantizers), so the overrides step skips it and
    relies on the runtime stub. The skip has to trigger from either signal:
    the module-level marker or the installed-torch probe.
    """
    mod = _load_module(monkeypatch)
    pip_args: list[str] = []
    labels: list[str] = []

    def _fake_pip_install(*args, **kwargs):
        # Record every positional argument so the assertions below can look
        # for a torchao requirement spec.
        for arg in args:
            pip_args.append(str(arg))
        return 0

    # The installer expects the local plugin directories to exist.
    plugin_dirs = {}
    for attr_name, dir_name in (
        ("LOCAL_DD_UNSTRUCTURED_PLUGIN", "unstructured"),
        ("LOCAL_DD_GITHUB_PLUGIN", "github"),
    ):
        plugin_path = tmp_path / dir_name
        plugin_path.mkdir()
        plugin_dirs[attr_name] = plugin_path

    completed = MagicMock()
    completed.returncode = 0
    completed.stdout = ""

    monkeypatch.setenv("SKIP_STUDIO_BASE", "1")

    module_overrides = {
        # Platform flags: pretend to be a Windows host with torch enabled.
        "IS_WINDOWS": True,
        "IS_MACOS": False,
        "IS_MAC_ARM": False,
        "NO_TORCH": False,
        # The two Windows-ROCm signals under test.
        "_rocm_windows_torch_installed": rocm_windows_torch_installed,
        "_installed_torch_is_windows_rocm": lambda: installed_torch_is_windows_rocm,
        # Neutralise every step that would touch the real environment.
        "_bootstrap_uv": lambda: False,
        "_repair_bad_anyio": lambda: None,
        "_ensure_rocm_torch": lambda: None,
        "_ensure_cuda_torch": lambda: None,
        "_has_usable_nvidia_gpu": lambda: True,
        "run": lambda *args, **kwargs: None,
        "pip_install": _fake_pip_install,
        "_progress": lambda label: labels.append(label),
    }
    module_overrides.update(plugin_dirs)
    for attr_name, replacement in module_overrides.items():
        monkeypatch.setattr(mod, attr_name, replacement)
    monkeypatch.setattr(mod.subprocess, "run", lambda *args, **kwargs: completed)

    assert mod.install_python_stack() == 0

    assert all(not spec.startswith("torchao") for spec in pip_args)
    assert "dependency overrides (skipped, Windows ROCm)" in labels
def _installed_torch_is_windows_rocm() -> bool:
    """Return True when the installed torch is a Windows ROCm build.

    This is a belt-and-suspenders guard for the torchao override step: if the
    earlier ROCm install path failed to set _rocm_windows_torch_installed but
    the environment already contains a ROCm torch wheel, torchao must still be
    skipped because it crashes on import on Windows ROCm.

    The check runs in a child interpreter (``sys.executable``) so that torch is
    never imported into the installer process itself. A build counts as ROCm
    when ``torch.version.hip`` is truthy or the version string contains
    ``rocm``.

    Returns:
        True only on Windows, when the probe exits with status 0 and its last
        non-blank stdout line is ``yes``. Any failure to run the probe
        (spawn error, timeout, non-zero exit, torch not importable) yields
        False.
    """
    if not IS_WINDOWS:
        return False

    probe_code = (
        "import sys, torch; "
        "hip = getattr(getattr(torch, 'version', None), 'hip', None) or ''; "
        "ver = getattr(torch, '__version__', '').lower(); "
        # 'rocmsdk' already contains 'rocm', so one substring test covers both.
        "sys.stdout.write('yes' if (hip or 'rocm' in ver) else '')"
    )
    try:
        probe = subprocess.run(
            [sys.executable, "-c", probe_code],
            stdout = subprocess.PIPE,
            stderr = subprocess.DEVNULL,
            text = True,
            timeout = 90,
            **_windows_hidden_subprocess_kwargs(),
        )
    except (OSError, subprocess.TimeoutExpired):
        # Best-effort guard: an unrunnable probe must not abort the install.
        return False

    if probe.returncode != 0:
        return False

    # Only the final non-blank line is the verdict; presumably this tolerates
    # anything printed to stdout while torch is being imported.
    lines = [line.strip() for line in (probe.stdout or "").splitlines() if line.strip()]
    return bool(lines) and lines[-1] == "yes"
class TestWindowsRocmTorchaoGuard:
    """Verify the torchao skip can detect an installed Windows ROCm torch build."""

    @staticmethod
    def _probe_result(stdout, returncode = 0):
        """Build a stand-in for the object returned by ``subprocess.run``."""
        result = MagicMock()
        result.returncode = returncode
        result.stdout = stdout
        return result

    def test_installed_torch_is_windows_rocm_accepts_rocm_probe(self):
        with (
            patch.object(stack_mod, "IS_WINDOWS", True),
            patch.object(
                stack_mod.subprocess, "run", return_value = self._probe_result("yes")
            ),
        ):
            assert stack_mod._installed_torch_is_windows_rocm() is True

    def test_installed_torch_is_windows_rocm_rejects_non_rocm_probe(self):
        with (
            patch.object(stack_mod, "IS_WINDOWS", True),
            patch.object(
                stack_mod.subprocess, "run", return_value = self._probe_result("")
            ),
        ):
            assert stack_mod._installed_torch_is_windows_rocm() is False

    def test_installed_torch_is_windows_rocm_reads_last_nonblank_line(self):
        # Output printed before the verdict must not hide it.
        noisy = self._probe_result("some banner\nyes\n\n")
        with (
            patch.object(stack_mod, "IS_WINDOWS", True),
            patch.object(stack_mod.subprocess, "run", return_value = noisy),
        ):
            assert stack_mod._installed_torch_is_windows_rocm() is True

    def test_installed_torch_is_windows_rocm_rejects_failed_probe(self):
        # A non-zero exit (e.g. torch not importable) is never a positive
        # detection, whatever ended up on stdout.
        failed = self._probe_result("yes", returncode = 1)
        with (
            patch.object(stack_mod, "IS_WINDOWS", True),
            patch.object(stack_mod.subprocess, "run", return_value = failed),
        ):
            assert stack_mod._installed_torch_is_windows_rocm() is False

    def test_installed_torch_is_windows_rocm_treats_probe_errors_as_non_rocm(self):
        failures = (
            OSError("cannot spawn interpreter"),
            stack_mod.subprocess.TimeoutExpired(cmd = "python", timeout = 90),
        )
        for failure in failures:
            with (
                patch.object(stack_mod, "IS_WINDOWS", True),
                patch.object(stack_mod.subprocess, "run", side_effect = failure),
            ):
                assert stack_mod._installed_torch_is_windows_rocm() is False

    def test_installed_torch_is_windows_rocm_is_non_windows_noop(self):
        with (
            patch.object(stack_mod, "IS_WINDOWS", False),
            patch.object(stack_mod.subprocess, "run") as mock_subprocess_run,
        ):
            assert stack_mod._installed_torch_is_windows_rocm() is False
        # Off Windows the guard must return before spawning any probe.
        mock_subprocess_run.assert_not_called()

    @patch.object(stack_mod, "_repair_bad_anyio")
    @patch.object(stack_mod, "_ensure_rocm_torch")
    @patch.object(stack_mod, "_ensure_cuda_torch")
    @patch.object(stack_mod, "_has_usable_nvidia_gpu", return_value = True)
    @patch.object(stack_mod, "run")
    @patch.object(stack_mod, "pip_install")
    def test_install_python_stack_skips_torchao_when_windows_rocm_torch_is_installed(
        self, mock_pip, mock_run, mock_has_nvidia, mock_cuda, mock_rocm, mock_anyio, tmp_path
    ):
        """The installed-torch probe alone must be enough to skip torchao.

        The module-level marker is forced to False so that only
        ``_installed_torch_is_windows_rocm`` can trigger the skip.
        """
        unstructured_plugin = tmp_path / "unstructured"
        github_plugin = tmp_path / "github"
        unstructured_plugin.mkdir()
        github_plugin.mkdir()

        subprocess_result = self._probe_result("")

        with (
            patch.dict(os.environ, {"SKIP_STUDIO_BASE": "1"}),
            patch.object(stack_mod, "IS_WINDOWS", True),
            patch.object(stack_mod, "IS_MACOS", False),
            patch.object(stack_mod, "IS_MAC_ARM", False),
            patch.object(stack_mod, "NO_TORCH", False),
            patch.object(stack_mod, "_rocm_windows_torch_installed", False),
            patch.object(stack_mod, "_bootstrap_uv", return_value = False),
            patch.object(stack_mod, "_installed_torch_is_windows_rocm", return_value = True),
            patch.object(stack_mod, "LOCAL_DD_UNSTRUCTURED_PLUGIN", unstructured_plugin),
            patch.object(stack_mod, "LOCAL_DD_GITHUB_PLUGIN", github_plugin),
            patch.object(stack_mod.subprocess, "run", return_value = subprocess_result),
        ):
            assert stack_mod.install_python_stack() == 0

        # No positional argument of any pip_install call may mention torchao.
        installed_specs = [str(arg) for call in mock_pip.call_args_list for arg in call.args]
        assert not any("torchao" in arg for arg in installed_specs)