From 9982d778e2d5af03eb8272f0dde8dba8db37e82d Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 24 Jun 2026 11:11:52 +0000 Subject: [PATCH 01/82] Add shared Xet to HTTP stall fallback (hf_xet_fallback, hf_cache_state) Hugging Face Xet downloads can hang on a blob with no progress and no exception, and a blocked native Xet thread cannot be killed in-process. Unsloth Studio already recovers from this; this patch lands the same logic in unsloth_zoo so both Unsloth main and Studio can share one implementation. - hf_cache_state.py: sparse-aware HF cache primitives (st_blocks based byte accounting, active .incomplete detection) over the active HF_HUB_CACHE. - hf_xet_fallback.py: a 180s no-progress watchdog plus a spawn-child download that keeps Xet primary and falls back to plain HTTP exactly once on a stall. hf_hub_download_with_xet_fallback handles a single file; the new snapshot_download_with_xet_fallback warms a whole repo in a killable child (the entrypoint Unsloth from_pretrained uses before its in-process load), with an in-process local_files_only fast path for a warm cache. Honors the UNSLOTH_DISABLE_XET / UNSLOTH_STABLE_DOWNLOADS knobs (and Hugging Face's own HF_HUB_DISABLE_XET) and injectable prepare-for-http / scrub hooks so Studio can pass its marker-aware cache management. Both modules are AGPL-3.0-only (per-file), matching the Studio source. Tests: CPU-only, no network, no real download subprocess (the download seam is faked; only the two HF_HUB_DISABLE_XET precondition tests launch a fresh interpreter).
--- tests/test_hf_xet_fallback.py | 425 ++++++++++++++++++++++ unsloth_zoo/hf_cache_state.py | 193 ++++++++++ unsloth_zoo/hf_xet_fallback.py | 620 +++++++++++++++++++++++++++++++++ 3 files changed, 1238 insertions(+) create mode 100644 tests/test_hf_xet_fallback.py create mode 100644 unsloth_zoo/hf_cache_state.py create mode 100644 unsloth_zoo/hf_xet_fallback.py diff --git a/tests/test_hf_xet_fallback.py b/tests/test_hf_xet_fallback.py new file mode 100644 index 000000000..ceace39e7 --- /dev/null +++ b/tests/test_hf_xet_fallback.py @@ -0,0 +1,425 @@ +# SPDX-License-Identifier: AGPL-3.0-only +# Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. + +"""Tests for unsloth_zoo.hf_xet_fallback: the no-progress watchdog, the Xet->HTTP +transport policy, the per-file and whole-snapshot entrypoints, the UNSLOTH_DISABLE_XET +knob, and the HF_HUB_DISABLE_XET precondition the fallback rests on. + +CPU-only, no network, no real subprocess (the per-attempt download seam is +monkeypatched). The two modules under test are loaded directly via importlib so the +tests do not import the full ``unsloth_zoo`` package (which pulls in torch + GPU init). +""" + +from __future__ import annotations + +import importlib.util +import subprocess +import sys +import threading +import time +import types as _types +from pathlib import Path + +import pytest + +import huggingface_hub +from huggingface_hub import constants as hf_constants + +_ZOO_DIR = Path(__file__).resolve().parents[1] / "unsloth_zoo" + + +def _load(name: str, filename: str): + spec = importlib.util.spec_from_file_location(name, _ZOO_DIR / filename) + module = importlib.util.module_from_spec(spec) + sys.modules[name] = module + spec.loader.exec_module(module) + return module + + +# A package placeholder so ``from unsloth_zoo.hf_cache_state import ...`` inside +# hf_xet_fallback resolves to the file we load below, not the installed package. 
+if "unsloth_zoo" not in sys.modules: + _pkg = _types.ModuleType("unsloth_zoo") + _pkg.__path__ = [str(_ZOO_DIR)] + sys.modules["unsloth_zoo"] = _pkg + +_load("unsloth_zoo.hf_cache_state", "hf_cache_state.py") +xf = _load("unsloth_zoo.hf_xet_fallback", "hf_xet_fallback.py") + + +# --------------------------------------------------------------------------- # +# Watchdog: fires only on a constant-size .incomplete, sparse-aware byte total. +# --------------------------------------------------------------------------- # +REPO = "ztest/xet-watchdog" + + +@pytest.fixture +def hf_cache(tmp_path, monkeypatch): + monkeypatch.setattr(hf_constants, "HF_HUB_CACHE", str(tmp_path)) + return tmp_path + + +def _blobs_dir(root: Path, repo_id: str = REPO) -> Path: + d = root / f"models--{repo_id.replace('/', '--')}" / "blobs" + d.mkdir(parents = True, exist_ok = True) + return d + + +def _wait(predicate, timeout: float = 2.0, step: float = 0.02) -> bool: + deadline = time.monotonic() + timeout + while time.monotonic() < deadline: + if predicate(): + return True + time.sleep(step) + return predicate() + + +def test_constant_incomplete_fires_stall(hf_cache): + blobs = _blobs_dir(hf_cache) + (blobs / "deadbeef.incomplete").write_bytes(b"\0" * 1024) # never grows + + calls: list[str] = [] + stop = xf.start_watchdog( + repo_ids = [REPO], on_stall = calls.append, interval = 0.05, stall_timeout = 0.3 + ) + try: + assert _wait( + lambda: len(calls) >= 1, timeout = 3.0 + ), "watchdog never fired on a constant-size .incomplete" + finally: + stop.set() + assert "stalled" in calls[0].lower() + + +def test_growing_incomplete_never_stalls(hf_cache): + blobs = _blobs_dir(hf_cache) + part = blobs / "growing.incomplete" + part.write_bytes(b"\0" * 1024) + + grow_stop = threading.Event() + + def _grow(): + size = 1024 + while not grow_stop.wait(0.05): + size += 4096 + part.write_bytes(b"\0" * size) + + grower = threading.Thread(target = _grow, daemon = True) + grower.start() + + calls: list[str] = [] 
+ stop = xf.start_watchdog( + repo_ids = [REPO], on_stall = calls.append, interval = 0.05, stall_timeout = 0.3 + ) + try: + time.sleep(1.0) # well past stall_timeout, but bytes keep growing + assert calls == [], "watchdog fired despite continuous progress" + finally: + stop.set() + grow_stop.set() + + +def test_no_incomplete_never_stalls(hf_cache): + blobs = _blobs_dir(hf_cache) + (blobs / "finalized_blob").write_bytes(b"\0" * 4096) # no .incomplete + + calls: list[str] = [] + stop = xf.start_watchdog( + repo_ids = [REPO], on_stall = calls.append, interval = 0.05, stall_timeout = 0.3 + ) + try: + time.sleep(0.8) + assert calls == [], "watchdog fired with no active .incomplete" + finally: + stop.set() + + +def test_stall_fires_at_most_once(hf_cache): + blobs = _blobs_dir(hf_cache) + (blobs / "frozen.incomplete").write_bytes(b"\0" * 2048) + + calls: list[str] = [] + stop = xf.start_watchdog( + repo_ids = [REPO], on_stall = calls.append, interval = 0.05, stall_timeout = 0.2 + ) + try: + assert _wait(lambda: len(calls) >= 1, timeout = 3.0) + time.sleep(0.6) # keep ticking; must not fire again + assert len(calls) == 1, f"on_stall fired {len(calls)} times, expected exactly 1" + finally: + stop.set() + + +def test_get_state_empty_cache(hf_cache): + assert xf.get_hf_download_state([REPO]) == (0, False) + + +def test_get_state_absent_cache_root(tmp_path, monkeypatch): + monkeypatch.setattr(hf_constants, "HF_HUB_CACHE", str(tmp_path / "no-such-cache")) + assert xf.get_hf_download_state([REPO]) == (0, False) + + +def test_get_state_skips_local_paths(hf_cache): + # Filesystem paths are not HF repo IDs and must be ignored without error. 
+ assert xf.get_hf_download_state(["/abs/path", "./rel", "~user", "c:\\x"]) == (0, False) + + +def test_get_state_sparse_aware(hf_cache): + blobs = _blobs_dir(hf_cache) + sparse = blobs / "sparse.incomplete" + with open(sparse, "wb") as f: + f.truncate(64 * 1024 * 1024) # large apparent size, few allocated blocks + st = sparse.stat() + if getattr(st, "st_blocks", 0) == 0: + pytest.skip("filesystem does not report st_blocks; sparse accounting unavailable") + total, has_incomplete = xf.get_hf_download_state([REPO]) + assert has_incomplete is True + assert total < st.st_size, "sparse partial counted at apparent size, not allocated blocks" + + +# --------------------------------------------------------------------------- # +# Transport policy: cached short-circuit, cancel, error propagation, the single +# Xet->HTTP fallback, the injected prepare seam, and the UNSLOTH_DISABLE_XET knob. +# _run_download_attempt is faked, so no real spawn. +# --------------------------------------------------------------------------- # +DL_REPO, FILE = "ztest/xet-dl", "model-Q4_K_XL.gguf" + + +@pytest.fixture(autouse = True) +def _no_real_cache_hit(monkeypatch): + """Default: the file cached probe misses and the snapshot fast path misses, so + tests exercise the download seam unless they override these.""" + monkeypatch.setattr(huggingface_hub, "try_to_load_from_cache", lambda *a, **k: None) + + def _snap_miss(*a, **k): + raise FileNotFoundError("not fully cached") + + monkeypatch.setattr(huggingface_hub, "snapshot_download", _snap_miss) + # Neutralize the generic cache purge by default; tests that care record it. + monkeypatch.setattr(xf, "_default_prepare_for_http", lambda *a, **k: None) + # No env knob unless a test sets it. 
+ monkeypatch.delenv("UNSLOTH_DISABLE_XET", raising = False) + monkeypatch.delenv("UNSLOTH_STABLE_DOWNLOADS", raising = False) + monkeypatch.delenv("HF_HUB_DISABLE_XET", raising = False) + + +class _FakeAttempt: + """Records calls to the download seam and returns scripted results. + + Matches unsloth_zoo.hf_xet_fallback._run_download_attempt's signature. + """ + + def __init__(self, results): + self._results = list(results) + self.calls = [] + + def __call__( + self, + repo_id, + *, + kind, + params, + token, + repo_type, + disable_xet, + cancel_event, + stall_timeout, + interval, + grace_period, + on_status, + ): + self.calls.append( + _types.SimpleNamespace( + repo_id = repo_id, + kind = kind, + target = params.get("filename", repo_id), + disable_xet = disable_xet, + repo_type = repo_type, + ) + ) + return self._results[len(self.calls) - 1] + + +def _install(monkeypatch, results): + fake = _FakeAttempt(results) + monkeypatch.setattr(xf, "_run_download_attempt", fake) + return fake + + +def test_cached_file_short_circuits(monkeypatch, tmp_path): + cached = tmp_path / "cached.gguf" + cached.write_bytes(b"\0" * 8) + monkeypatch.setattr(huggingface_hub, "try_to_load_from_cache", lambda *a, **k: str(cached)) + fake = _install(monkeypatch, []) # must not be called + + out = xf.hf_hub_download_with_xet_fallback(DL_REPO, FILE, None) + assert out == str(cached) + assert fake.calls == [], "spawned a download for an already-cached file" + + +def test_cancel_before_start_raises_no_attempt(monkeypatch): + fake = _install(monkeypatch, []) + ev = threading.Event() + ev.set() + with pytest.raises(RuntimeError, match = "Cancelled"): + xf.hf_hub_download_with_xet_fallback(DL_REPO, FILE, None, cancel_event = ev) + assert fake.calls == [] + + +def test_nonstall_error_propagates_without_fallback(monkeypatch): + fake = _install(monkeypatch, [("error", "RepositoryNotFoundError: 404 not found")]) + with pytest.raises(RuntimeError, match = "RepositoryNotFoundError"): + 
xf.hf_hub_download_with_xet_fallback(DL_REPO, FILE, None) + assert len(fake.calls) == 1, "deterministic error must not trigger an HTTP fallback" + assert fake.calls[0].disable_xet is False + + +def test_immediate_success_uses_xet_only(monkeypatch): + prepared = [] + monkeypatch.setattr(xf, "_default_prepare_for_http", lambda *a: prepared.append(a)) + fake = _install(monkeypatch, [("ok", "/cache/model.gguf")]) + out = xf.hf_hub_download_with_xet_fallback(DL_REPO, FILE, None) + assert out == "/cache/model.gguf" + assert len(fake.calls) == 1 and fake.calls[0].disable_xet is False + assert prepared == [], "no cache prep should run when Xet succeeds first try" + + +def test_stall_then_http_fallback_succeeds(monkeypatch): + prepared = [] + monkeypatch.setattr(xf, "_default_prepare_for_http", lambda repo_type, repo_id: prepared.append((repo_type, repo_id))) + fake = _install(monkeypatch, [("stall", None), ("ok", "/cache/model.gguf")]) + + out = xf.hf_hub_download_with_xet_fallback(DL_REPO, FILE, None) + assert out == "/cache/model.gguf" + assert len(fake.calls) == 2 + assert fake.calls[0].disable_xet is False # Xet first + assert fake.calls[1].disable_xet is True # HTTP fallback + assert prepared == [("model", DL_REPO)], "must prep cache for HTTP before the retry" + + +def test_injected_prepare_for_http_used(monkeypatch): + """Studio injects its marker-aware prepare; the generic default must not run.""" + monkeypatch.setattr( + xf, "_default_prepare_for_http", lambda *a: pytest.fail("generic prepare ran") + ) + injected = [] + _install(monkeypatch, [("stall", None), ("ok", "/cache/model.gguf")]) + out = xf.hf_hub_download_with_xet_fallback( + DL_REPO, FILE, None, prepare_for_http_fn = lambda rt, rid: injected.append((rt, rid)) + ) + assert out == "/cache/model.gguf" + assert injected == [("model", DL_REPO)] + + +def test_second_stall_raises_download_stall_error(monkeypatch): + fake = _install(monkeypatch, [("stall", None), ("stall", None)]) + with 
pytest.raises(xf.DownloadStallError): + xf.hf_hub_download_with_xet_fallback(DL_REPO, FILE, None) + assert len(fake.calls) == 2 + + +def test_cancelled_midattempt_raises_no_fallback(monkeypatch): + fake = _install(monkeypatch, [("cancelled", None)]) + with pytest.raises(RuntimeError, match = "Cancelled"): + xf.hf_hub_download_with_xet_fallback(DL_REPO, FILE, None) + assert len(fake.calls) == 1 + + +def test_per_file_independent_fallback(monkeypatch): + """A stalled shard falls back; a sibling shard that succeeds does not.""" + fake = _install(monkeypatch, [("ok", "/a"), ("stall", None), ("ok", "/b")]) + assert xf.hf_hub_download_with_xet_fallback(DL_REPO, "shardA.gguf", None) == "/a" + assert xf.hf_hub_download_with_xet_fallback(DL_REPO, "shardB.gguf", None) == "/b" + assert [c.disable_xet for c in fake.calls] == [False, False, True] + + +def test_unsloth_disable_xet_forces_http_first(monkeypatch): + """UNSLOTH_DISABLE_XET=1 skips the Xet attempt: first (and only) attempt is HTTP.""" + monkeypatch.setenv("UNSLOTH_DISABLE_XET", "1") + fake = _install(monkeypatch, [("ok", "/http/model.gguf")]) + out = xf.hf_hub_download_with_xet_fallback(DL_REPO, FILE, None) + assert out == "/http/model.gguf" + assert len(fake.calls) == 1 and fake.calls[0].disable_xet is True + + +def test_unsloth_disable_xet_stall_raises_no_retry(monkeypatch): + """With the knob set, a stall on the (already HTTP) attempt does not retry.""" + monkeypatch.setenv("UNSLOTH_DISABLE_XET", "1") + fake = _install(monkeypatch, [("stall", None)]) + with pytest.raises(xf.DownloadStallError): + xf.hf_hub_download_with_xet_fallback(DL_REPO, FILE, None) + assert len(fake.calls) == 1 + + +# --------------------------------------------------------------------------- # +# Snapshot variant: in-process fast path on a warm cache, else watched download. 
+# --------------------------------------------------------------------------- # +def test_snapshot_fast_path_no_child(monkeypatch): + """A fully cached repo resolves in-process via local_files_only -- no attempt.""" + seen = {} + + def _snap(*a, **k): + seen["local_files_only"] = k.get("local_files_only") + return "/cache/snap-dir" + + monkeypatch.setattr(huggingface_hub, "snapshot_download", _snap) + fake = _install(monkeypatch, []) # must not be called + out = xf.snapshot_download_with_xet_fallback(DL_REPO, token = None) + assert out == "/cache/snap-dir" + assert seen["local_files_only"] is True + assert fake.calls == [], "spawned a download for an already-cached snapshot" + + +def test_snapshot_stall_then_http(monkeypatch): + prepared = [] + monkeypatch.setattr(xf, "_default_prepare_for_http", lambda rt, rid: prepared.append((rt, rid))) + fake = _install(monkeypatch, [("stall", None), ("ok", "/cache/snap-dir")]) + out = xf.snapshot_download_with_xet_fallback(DL_REPO, token = None) + assert out == "/cache/snap-dir" + assert [c.kind for c in fake.calls] == ["snapshot", "snapshot"] + assert [c.disable_xet for c in fake.calls] == [False, True] + assert prepared == [("model", DL_REPO)] + + +# --------------------------------------------------------------------------- # +# Precondition: HF_HUB_DISABLE_XET is read at import time, so assert its effect +# in a FRESH interpreter (huggingface/huggingface_hub#3266 once ignored it). 
+# --------------------------------------------------------------------------- # +def _safe_path() -> str: + import os + + return os.environ.get("PATH", "") + + +def test_disable_xet_constant_set_in_fresh_interpreter(): + code = ( + "from huggingface_hub import constants as c; " + "import sys; sys.exit(0 if c.HF_HUB_DISABLE_XET is True else 17)" + ) + proc = subprocess.run( + [sys.executable, "-c", code], + env = {"HF_HUB_DISABLE_XET": "1", "PATH": _safe_path()}, + capture_output = True, + text = True, + ) + assert proc.returncode == 0, ( + f"HF_HUB_DISABLE_XET=1 did not set constants.HF_HUB_DISABLE_XET=True " + f"(rc={proc.returncode}): {proc.stderr}" + ) + + +def test_default_leaves_xet_enabled(): + code = ( + "from huggingface_hub import constants as c; " + "import sys; sys.exit(0 if c.HF_HUB_DISABLE_XET is False else 17)" + ) + proc = subprocess.run( + [sys.executable, "-c", code], + env = {"PATH": _safe_path()}, # no HF_HUB_DISABLE_XET + capture_output = True, + text = True, + ) + assert proc.returncode == 0, ( + f"without the env var, constants.HF_HUB_DISABLE_XET was not False " + f"(rc={proc.returncode}): {proc.stderr}" + ) diff --git a/unsloth_zoo/hf_cache_state.py b/unsloth_zoo/hf_cache_state.py new file mode 100644 index 000000000..734434d05 --- /dev/null +++ b/unsloth_zoo/hf_cache_state.py @@ -0,0 +1,193 @@ +# SPDX-License-Identifier: AGPL-3.0-only +# Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. +# +# This file is licensed under the GNU Affero General Public License v3.0 only +# (AGPL-3.0-only), unlike the rest of unsloth_zoo which is LGPL-3.0-or-later. It +# is the single shared home for the sparse-aware Hugging Face cache primitives +# used by the Xet -> HTTP stall fallback (unsloth_zoo.hf_xet_fallback) and by +# Unsloth Studio's download manager, which re-exports the names below. +# See . + +"""Sparse-aware introspection of the active Hugging Face hub cache. 
+ +These helpers answer two questions for a repo's blobs under ``HF_HUB_CACHE``: +how many bytes are actually on disk (sparse-aware, so a partially written Xet / +``hf_transfer`` ``.incomplete`` is not mistaken for full-size progress) and +whether an ``.incomplete`` partial is present. The no-progress download watchdog +is built on exactly these two signals. + +Only the single active cache root (``huggingface_hub.constants.HF_HUB_CACHE``) is +scanned here; multi-root / legacy-cache enumeration and transport-marker logic +are download-manager concerns that live in the consumer, not in this module. +""" + +from __future__ import annotations + +import sys +from pathlib import Path +from typing import Iterator, Optional + + +__all__ = [ + "INCOMPLETE_SUFFIX", + "hf_cache_root", + "target_dir_name", + "repo_cache_dir_name", + "blob_bytes_present", + "latest_snapshot_dir", + "iter_active_repo_cache_dirs", + "repo_cache_dir_has_incomplete_blobs", + "has_active_incomplete_blobs", +] + + +INCOMPLETE_SUFFIX = ".incomplete" + + +def _safe_is_dir(path: Path) -> bool: + """``Path.is_dir()`` returning False instead of raising when the path or a + parent is unreadable (e.g. a restricted ``~/.cache/huggingface/hub``), so + cache enumeration skips that root rather than erroring.""" + try: + return path.is_dir() + except OSError: + return False + + +def hf_cache_root(*, create: bool = False) -> Optional[Path]: + """The active hub cache root (``HF_HUB_CACHE``), or None if unavailable. + + Read lazily so any cache redirect applied at import time (see + ``unsloth_zoo.hf_cache.redirect_hf_cache_if_readonly``) is honored. 
+ """ + try: + from huggingface_hub import constants as hf_constants + except ImportError: + return None + root = Path(hf_constants.HF_HUB_CACHE) + if create: + try: + root.mkdir(parents = True, exist_ok = True) + except OSError: + return None + return root + return root if _safe_is_dir(root) else None + + +def target_dir_name(repo_type: str, repo_id: str) -> str: + return repo_cache_dir_name(repo_type, repo_id).lower() + + +def repo_cache_dir_name(repo_type: str, repo_id: str) -> str: + return f"{repo_type}s--{repo_id.replace('/', '--')}" + + +def _blob_dir_is_partial(blobs_dir: Path) -> bool: + try: + for blob in blobs_dir.iterdir(): + if blob.is_file() and blob.name.endswith(INCOMPLETE_SUFFIX): + return True + except OSError: + return False + return False + + +def blob_bytes_present(path: Path) -> int: + """Sparse-aware on-disk size: XET / ``hf_transfer`` ``.incomplete`` partials + report a full ``st_size`` while only some blocks are allocated, so prefer + ``st_blocks``, falling back to ``st_size`` where it is unreported (Windows, + some network filesystems).""" + st = path.stat() + blocks = getattr(st, "st_blocks", 0) + if blocks > 0: + return min(blocks * 512, st.st_size) + if sys.platform == "win32": + allocated = _windows_allocated_size(path) + if allocated is not None: + return min(allocated, st.st_size) + return st.st_size + + +def _windows_allocated_size(path: Path) -> Optional[int]: + """Best-effort allocated-byte count for sparse files on Windows.""" + if sys.platform != "win32": + return None + try: + import ctypes + from ctypes import wintypes + + kernel32 = ctypes.WinDLL("kernel32", use_last_error = True) + get_compressed_file_size = kernel32.GetCompressedFileSizeW + get_compressed_file_size.argtypes = [ + wintypes.LPCWSTR, + ctypes.POINTER(wintypes.DWORD), + ] + get_compressed_file_size.restype = wintypes.DWORD + + high = wintypes.DWORD(0) + ctypes.set_last_error(0) + low = get_compressed_file_size(str(path), ctypes.byref(high)) + if low == 
0xFFFFFFFF and ctypes.get_last_error() != 0: + return None + return (int(high.value) << 32) + int(low) + except Exception: + return None + + +def latest_snapshot_dir(repo_dir: Path) -> Optional[Path]: + """Newest immediate child of ``repo_dir/snapshots`` by mtime, or None. + + mtime is the signal huggingface_hub's from_pretrained resolves to, so this + points at whatever snapshot most recently landed on disk. + """ + snapshots_dir = repo_dir / "snapshots" + try: + if not snapshots_dir.is_dir(): + return None + snapshots = [entry for entry in snapshots_dir.iterdir() if entry.is_dir()] + if not snapshots: + return None + return max(snapshots, key = lambda entry: entry.stat().st_mtime) + except OSError: + return None + + +def _repo_dir_has_broken_snapshot_symlinks(repo_dir: Path) -> bool: + latest = latest_snapshot_dir(repo_dir) + if latest is None: + return False + try: + for entry in latest.rglob("*"): + if entry.is_symlink() and not entry.exists(): + return True + except OSError: + return False + return False + + +def iter_active_repo_cache_dirs(repo_type: str, repo_id: str) -> Iterator[Path]: + """Yield the repo's cache dir(s) under the single active ``HF_HUB_CACHE`` root.""" + root = hf_cache_root() + if root is None: + return + target = target_dir_name(repo_type, repo_id) + try: + for entry in root.iterdir(): + if entry.name.lower() == target: + yield entry + except OSError: + return + + +def repo_cache_dir_has_incomplete_blobs(repo_dir: Path) -> bool: + blobs_dir = repo_dir / "blobs" + return (blobs_dir.is_dir() and _blob_dir_is_partial(blobs_dir)) or ( + _repo_dir_has_broken_snapshot_symlinks(repo_dir) + ) + + +def has_active_incomplete_blobs(repo_type: str, repo_id: str) -> bool: + for entry in iter_active_repo_cache_dirs(repo_type, repo_id): + if repo_cache_dir_has_incomplete_blobs(entry): + return True + return False diff --git a/unsloth_zoo/hf_xet_fallback.py b/unsloth_zoo/hf_xet_fallback.py new file mode 100644 index 000000000..48108be2c --- /dev/null +++ 
b/unsloth_zoo/hf_xet_fallback.py @@ -0,0 +1,620 @@ +# SPDX-License-Identifier: AGPL-3.0-only +# Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. +# +# This file is licensed under the GNU Affero General Public License v3.0 only +# (AGPL-3.0-only), unlike the rest of unsloth_zoo which is LGPL-3.0-or-later. It +# is the single shared home for the Xet -> HTTP stall fallback used by both +# Unsloth (FastModel.from_pretrained) and Unsloth Studio, which imports it. +# See <https://www.gnu.org/licenses/>. + +"""Xet-primary HF downloads with an automatic HTTP fallback on a no-progress stall. + +Xet (``hf_xet``) is the fast default but can hang with no progress and no +exception, and a blocked native thread cannot be killed. Keep Xet primary; fall +back to plain HTTP only when the parent observes a stall. ``HF_HUB_DISABLE_XET`` +is read at import time, so the fallback runs in a fresh ``spawn`` child (not a +thread) that sets the env before importing ``huggingface_hub``. Cached files +short-circuit with no child; deterministic errors (401/403/404/disk-full) and +cancellation propagate without a fallback. + +``hf_hub_download_with_xet_fallback`` downloads a single file; the new +``snapshot_download_with_xet_fallback`` does a whole repo (the entrypoint +Unsloth's ``from_pretrained`` uses to warm the cache in a killable child before +the in-process load). Studio-specific cache/secret/process helpers are used +best-effort (imported only if present) or injected, so the same code runs both +inside Studio and standalone.
+""" + +from __future__ import annotations + +import importlib.util +import multiprocessing as mp +import logging +import os +import queue +import re +import signal +import sys +import threading +import time +from typing import Any, Callable, Optional + +from unsloth_zoo.hf_cache_state import ( + INCOMPLETE_SUFFIX, + blob_bytes_present, + has_active_incomplete_blobs, + hf_cache_root, + iter_active_repo_cache_dirs, +) + +logger = logging.getLogger(__name__) + +_CTX = mp.get_context("spawn") + +# Defaults match the existing Studio inference watchdog and hub shutdown deadline. +DEFAULT_HEARTBEAT_INTERVAL = 30.0 +DEFAULT_STALL_TIMEOUT = 180.0 +DEFAULT_GRACE_PERIOD = 10.0 +_POLL_INTERVAL = 0.5 + +# Hugging Face boolean env convention: 1 / ON / YES / TRUE, case-insensitive. +_TRUTHY = {"1", "true", "yes", "on"} + + +def _is_true(value: Optional[str]) -> bool: + return value is not None and str(value).strip().lower() in _TRUTHY + + +class DownloadStallError(RuntimeError): + """Raised when no download progress is observed for too long. + + Canonical home; Studio's orchestrator re-imports it so all paths share one type. + """ + + +def is_hf_xet_available() -> bool: + """True iff the ``hf_xet`` extra is importable (Hub uses it automatically).""" + try: + return importlib.util.find_spec("hf_xet") is not None + except Exception: + return False + + +def xet_force_disabled() -> bool: + """Whether the user has asked us to skip Xet up front (force HTTP). + + Honors the Unsloth knobs ``UNSLOTH_DISABLE_XET`` / ``UNSLOTH_STABLE_DOWNLOADS`` + and Hugging Face's own ``HF_HUB_DISABLE_XET``. 
+ """ + return ( + _is_true(os.environ.get("UNSLOTH_DISABLE_XET")) + or _is_true(os.environ.get("UNSLOTH_STABLE_DOWNLOADS")) + or _is_true(os.environ.get("HF_HUB_DISABLE_XET")) + ) + + +def child_should_disable_xet(config: dict) -> bool: + """Single source of truth for the per-worker Xet env flip.""" + return bool(config.get("disable_xet")) + + +def _default_scrub_secrets(text: str, hf_token: Optional[str] = None) -> str: + """Best-effort redaction of a token / bearer credential from an error string.""" + if not text: + return text + out = text + if hf_token: + out = out.replace(hf_token, "***") + out = re.sub(r"hf_[A-Za-z0-9]{8,}", "***", out) + out = re.sub(r"([Bb]earer\s+)[A-Za-z0-9._\-]+", r"\1***", out) + return out + + +def _default_prepare_for_http(repo_type: str, repo_id: str) -> None: + """Generic 'make the partial safe for an HTTP resume': delete the repo's active + ``*.incomplete`` blobs (an HTTP resume over a sparse Xet/hf_transfer partial + silently corrupts the blob). Studio injects its marker-aware version instead.""" + try: + for entry in iter_active_repo_cache_dirs(repo_type, repo_id): + blobs_dir = entry / "blobs" + if not blobs_dir.is_dir(): + continue + for blob in blobs_dir.iterdir(): + if blob.is_file() and blob.name.endswith(INCOMPLETE_SUFFIX): + try: + blob.unlink() + except FileNotFoundError: + continue + except Exception as e: + logger.debug("default prepare_for_http failed for %s: %s", repo_id, e) + + +def get_hf_download_state( + repo_ids: Optional[list[str]] = None, *, repo_type: str = "model" +) -> Optional[tuple[int, bool]]: + """Return ``(total_on_disk_bytes, has_incomplete)`` for the active HF cache. + + Sparse-aware (st_blocks based) so a sparse Xet/``hf_transfer`` ``.incomplete`` + is not mistaken for full-size progress. ``None`` means the state could not be + measured, so callers skip stall logic for that tick. 
+ """ + try: + if hf_cache_root() is None: + return (0, False) + + total = 0 + has_incomplete = False + for repo_id in repo_ids or []: + # Skip local paths: HF IDs never start with / . ~ or contain "\". + if not repo_id or repo_id.startswith(("/", ".", "~")) or "\\" in repo_id: + continue + for entry in iter_active_repo_cache_dirs(repo_type, repo_id): + blobs_dir = entry / "blobs" + if not blobs_dir.is_dir(): + continue + for blob in blobs_dir.iterdir(): + try: + if blob.is_file(): + total += blob_bytes_present(blob) + except OSError: + pass + if has_active_incomplete_blobs(repo_type, repo_id): + has_incomplete = True + return (total, has_incomplete) + except Exception as e: + logger.debug("Failed to determine HF download state: %s", e) + return None + + +def start_watchdog( + *, + repo_ids: list[str], + on_stall: Callable[[str], None], + repo_type: str = "model", + interval: float = DEFAULT_HEARTBEAT_INTERVAL, + stall_timeout: float = DEFAULT_STALL_TIMEOUT, + xet_disabled: bool = False, + on_heartbeat: Optional[Callable[[str], None]] = None, +) -> threading.Event: + """Start a daemon thread that fires ``on_stall(message)`` exactly once iff a + ``*.incomplete`` is present AND the on-disk size is unchanged for + *stall_timeout* seconds. The timer resets while no ``*.incomplete`` exists, so + post-download init is never misread as a stall. Returns a stop event the + caller sets when the download phase ends. 
+ """ + stop = threading.Event() + transport = "https" if xet_disabled else "xet" + fired = False + + def _beat() -> None: + nonlocal fired + state = get_hf_download_state(repo_ids, repo_type = repo_type) + last_size = state[0] if state is not None else 0 + last_change = time.monotonic() + + while not stop.wait(interval): + state = get_hf_download_state(repo_ids, repo_type = repo_type) + now = time.monotonic() + + if state is None: + if on_heartbeat is not None: + on_heartbeat(f"Downloading ({transport} transport)...") + continue + + current_size, has_incomplete = state + if current_size != last_size: + last_size = current_size + last_change = now + + # Reset unless .incomplete confirms an active download, so model init + # and lock waits are not counted as a stall. + if not has_incomplete: + last_change = now + elif now - last_change >= stall_timeout: + if not fired: + fired = True + on_stall( + f"Download appears stalled ({transport} transport) " + f"-- no progress for {int(now - last_change)}s" + ) + return + + if on_heartbeat is not None: + on_heartbeat(f"Downloading ({transport} transport)...") + + threading.Thread(target = _beat, daemon = True, name = "hf-xet-watchdog").start() + return stop + + +def _scrub_in_child(text: str, token: Optional[str]) -> str: + """Redact secrets from a child error string, preferring Studio's richer + patterns when running inside Studio, else the generic redaction.""" + try: + from hub.utils.download_registry import scrub_secrets # type: ignore + + return scrub_secrets(text, hf_token = token) + except Exception: + return _default_scrub_secrets(text, hf_token = token) + + +def _child_download(*, kind: str, params: dict, token: Optional[str], repo_type: str) -> str: + """Run the actual HF download for one attempt inside the spawn child.""" + if kind == "snapshot": + from huggingface_hub import snapshot_download + + return snapshot_download( + repo_id = params["repo_id"], + repo_type = repo_type, + token = token, + revision = 
params.get("revision"), + cache_dir = params.get("cache_dir"), + allow_patterns = params.get("allow_patterns"), + ignore_patterns = params.get("ignore_patterns"), + ) + + from huggingface_hub import hf_hub_download + + return hf_hub_download( + repo_id = params["repo_id"], + filename = params["filename"], + repo_type = repo_type, + token = token, + revision = params.get("revision"), + ) + + +def _download_child_entry( + *, + kind: str, + params: dict, + token: Optional[str], + repo_type: str, + disable_xet: bool, + result_queue: Any, +) -> None: + """Spawn-child entrypoint: download and report the result. + + Top-level and picklable. Sets the Xet env BEFORE importing huggingface_hub, + forms its own process group so the parent can kill the whole transfer, and + never logs the token or signed URLs. + """ + # Die with the parent on Linux when running under Studio (best-effort; the + # module is absent standalone, in which case there is nothing to bind to). + try: + from utils.process_lifetime import bind_current_process_to_parent_lifetime # type: ignore + + bind_current_process_to_parent_lifetime() + except Exception: + pass + + if hasattr(os, "setsid"): + try: + os.setsid() + except OSError: + pass + + if disable_xet: + os.environ["HF_HUB_DISABLE_XET"] = "1" + # Keep the HTTP writer sequential and resumable (hf_transfer leaves sparse + # partials a sequential resume cannot safely continue). + os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "0" + os.environ.setdefault("HF_HUB_DISABLE_PROGRESS_BARS", "1") + + repo_id = params["repo_id"] + + # Test-only fault injection (never set in production): stall the Xet attempt + # so the watchdog + HTTP fallback can be exercised against a real repo. 
+ if not disable_xet and os.environ.get("UNSLOTH_HF_XET_FORCE_STALL") == "1": + try: + from huggingface_hub.constants import HF_HUB_CACHE + + blobs = os.path.join(HF_HUB_CACHE, "models--" + repo_id.replace("/", "--"), "blobs") + os.makedirs(blobs, exist_ok = True) + with open(os.path.join(blobs, "xet-force-stall.incomplete"), "wb") as fh: + fh.write(b"\0" * 4096) + except OSError: + pass + while True: + time.sleep(3600) + + try: + path = _child_download(kind = kind, params = params, token = token, repo_type = repo_type) + result_queue.put({"ok": True, "path": path}) + except BaseException as e: # noqa: BLE001 - report every failure to the parent + result_queue.put({"ok": False, "error": _scrub_in_child(f"{type(e).__name__}: {e}", token)}) + + +def _terminate_process_group(proc: "mp.process.BaseProcess", grace_period: float) -> None: + """Kill *proc* and its whole process group (Xet may spawn helper procs). + + The child calls ``os.setsid()`` so its pgid equals its pid; signal via + ``os.killpg(pid, ...)`` -- NOT ``getpgid``, which before the child becomes a + group leader resolves to OUR group. SIGTERM, then SIGKILL after *grace_period*. + """ + pid = proc.pid + + def _signal_group(sig: int) -> None: + if pid is not None and hasattr(os, "killpg"): + try: + os.killpg(pid, sig) + return + except (ProcessLookupError, PermissionError, OSError): + pass + # Windows or pre-setsid: best effort on the single process. 
+ try: + proc.terminate() if sig != getattr(signal, "SIGKILL", -9) else proc.kill() + except Exception: + pass + + _signal_group(getattr(signal, "SIGTERM", signal.SIGINT)) + proc.join(timeout = grace_period) + if proc.is_alive(): + _signal_group(getattr(signal, "SIGKILL", signal.SIGTERM)) + proc.join(timeout = 5.0) + + +def _run_download_attempt( + repo_id: str, + *, + kind: str, + params: dict, + token: Optional[str], + repo_type: str, + disable_xet: bool, + cancel_event: Optional[threading.Event], + stall_timeout: float, + interval: float, + grace_period: float, + on_status: Optional[Callable[[str], None]], +) -> tuple[str, Optional[str]]: + """Run one download in a spawn child supervised by the no-progress watchdog. + + Returns ``("ok", path)``, ``("stall", None)``, ``("cancelled", None)``, or + ``("error", message)``. This is the seam tests monkeypatch to avoid spawning. + """ + result_queue: Any = _CTX.Queue() + proc = _CTX.Process( + target = _download_child_entry, + kwargs = dict( + kind = kind, + params = params, + token = token, + repo_type = repo_type, + disable_xet = disable_xet, + result_queue = result_queue, + ), + daemon = True, + ) + proc.start() + + # Bind the child to the parent lifetime when running under Studio (best-effort). 
+ try: + from utils.process_lifetime import adopt_pid # type: ignore + + adopt_pid(proc.pid) + except Exception: + pass + + stalled = threading.Event() + stop_watchdog = start_watchdog( + repo_ids = [repo_id], + on_stall = lambda msg: stalled.set(), + repo_type = repo_type, + interval = interval, + stall_timeout = stall_timeout, + xet_disabled = disable_xet, + on_heartbeat = on_status, + ) + + result: Optional[dict] = None + try: + while proc.is_alive(): + if cancel_event is not None and cancel_event.is_set(): + _terminate_process_group(proc, grace_period) + return ("cancelled", None) + if stalled.is_set(): + _terminate_process_group(proc, grace_period) + return ("stall", None) + try: + result = result_queue.get(timeout = _POLL_INTERVAL) + break + except queue.Empty: + continue + else: + # Process exited; drain any result it enqueued. + try: + result = result_queue.get_nowait() + except queue.Empty: + result = None + finally: + stop_watchdog.set() + proc.join(timeout = grace_period) + + if result is None: + return ( + "error", + f"download process for '{repo_id}' exited " + f"(code={proc.exitcode}) without a result", + ) + if result.get("ok"): + return ("ok", result["path"]) + return ("error", result.get("error") or "unknown download error") + + +def _download_with_xet_fallback( + *, + repo_id: str, + label: str, + kind: str, + params: dict, + token: Optional[str], + repo_type: str, + cancel_event: Optional[threading.Event], + stall_timeout: float, + interval: float, + grace_period: float, + on_status: Optional[Callable[[str], None]], + prepare_for_http_fn: Optional[Callable[[str, str], None]], +) -> str: + """Shared 2-attempt loop: Xet primary, HTTP on a stall. Returns the local path.""" + if cancel_event is not None and cancel_event.is_set(): + raise RuntimeError("Cancelled") + + prepare_for_http = prepare_for_http_fn or _default_prepare_for_http + # The Unsloth/HF knobs can force HTTP from the very first attempt. 
+ disable_xet = xet_force_disabled() + + for attempt in range(2): + if disable_xet: + # Purge a non-HTTP partial before resuming over HTTP: an HTTP resume + # over a sparse Xet/hf_transfer partial silently corrupts the blob. + try: + prepare_for_http(repo_type, repo_id) + except Exception as e: + logger.debug("prepare_for_http failed for %s: %s", repo_id, e) + + kind_result, payload = _run_download_attempt( + repo_id, + kind = kind, + params = params, + token = token, + repo_type = repo_type, + disable_xet = disable_xet, + cancel_event = cancel_event, + stall_timeout = stall_timeout, + interval = interval, + grace_period = grace_period, + on_status = on_status, + ) + + if kind_result == "ok": + return payload # type: ignore[return-value] + if kind_result == "cancelled": + raise RuntimeError("Cancelled") + if kind_result == "error": + # Deterministic failure: the other transport would fail identically. + raise RuntimeError(payload) + # kind_result == "stall" + if not disable_xet: + logger.warning( + "Download stalled for '%s' -- retrying with HF_HUB_DISABLE_XET=1", label + ) + if on_status is not None: + on_status(f"{label}: Xet stalled, retrying over HTTP") + disable_xet = True + continue + raise DownloadStallError( + f"Download stalled for '{label}' even with HF_HUB_DISABLE_XET=1 " + f"-- check your network connection" + ) + + # Unreachable: the loop either returns or raises on each attempt. 
+ raise DownloadStallError(f"Download failed for '{label}'") + + +def hf_hub_download_with_xet_fallback( + repo_id: str, + filename: str, + token: Optional[str], + *, + cancel_event: Optional[threading.Event] = None, + repo_type: str = "model", + revision: Optional[str] = None, + stall_timeout: float = DEFAULT_STALL_TIMEOUT, + interval: float = DEFAULT_HEARTBEAT_INTERVAL, + grace_period: float = DEFAULT_GRACE_PERIOD, + on_status: Optional[Callable[[str], None]] = None, + prepare_for_http_fn: Optional[Callable[[str, str], None]] = None, +) -> str: + """Download a single file with Xet primary and HTTP as a stall-only fallback. + + Returns the local cache path. Raises ``RuntimeError("Cancelled")`` if + *cancel_event* is set, re-raises a deterministic child error unchanged (no + fallback), and raises ``DownloadStallError`` only if BOTH transports stall. + """ + # Finalized blob already cached: return it with no child and no network. + try: + from huggingface_hub import try_to_load_from_cache + + cached = try_to_load_from_cache(repo_id, filename, repo_type = repo_type, revision = revision) + if isinstance(cached, str) and os.path.exists(cached): + return cached + except Exception as e: + logger.debug("Cached probe failed for %s/%s: %s", repo_id, filename, e) + + return _download_with_xet_fallback( + repo_id = repo_id, + label = f"{repo_id}/{filename}", + kind = "file", + params = {"repo_id": repo_id, "filename": filename, "revision": revision}, + token = token, + repo_type = repo_type, + cancel_event = cancel_event, + stall_timeout = stall_timeout, + interval = interval, + grace_period = grace_period, + on_status = on_status, + prepare_for_http_fn = prepare_for_http_fn, + ) + + +def snapshot_download_with_xet_fallback( + repo_id: str, + *, + revision: Optional[str] = None, + token: Optional[str] = None, + repo_type: str = "model", + cache_dir: Optional[str] = None, + allow_patterns: Optional[Any] = None, + ignore_patterns: Optional[Any] = None, + cancel_event: 
Optional[threading.Event] = None, + stall_timeout: float = DEFAULT_STALL_TIMEOUT, + interval: float = DEFAULT_HEARTBEAT_INTERVAL, + grace_period: float = DEFAULT_GRACE_PERIOD, + on_status: Optional[Callable[[str], None]] = None, + prepare_for_http_fn: Optional[Callable[[str, str], None]] = None, +) -> str: + """Download a whole repo snapshot with Xet primary and HTTP as a stall-only + fallback, returning the local snapshot dir. + + Used by Unsloth's ``from_pretrained`` to warm the cache in a killable child + BEFORE the in-process model load (which then hits a warm cache and cannot + hang on a native Xet thread). A fully cached repo short-circuits in-process + via ``local_files_only`` with no child and no network. + """ + # Fast path: everything already on disk -> resolve in-process (no Xet, no hang). + try: + from huggingface_hub import snapshot_download + + return snapshot_download( + repo_id = repo_id, + repo_type = repo_type, + revision = revision, + cache_dir = cache_dir, + allow_patterns = allow_patterns, + ignore_patterns = ignore_patterns, + local_files_only = True, + ) + except Exception as e: + logger.debug("Snapshot not fully cached for %s (%s); downloading.", repo_id, e) + + return _download_with_xet_fallback( + repo_id = repo_id, + label = repo_id, + kind = "snapshot", + params = { + "repo_id": repo_id, + "revision": revision, + "cache_dir": cache_dir, + "allow_patterns": allow_patterns, + "ignore_patterns": ignore_patterns, + }, + token = token, + repo_type = repo_type, + cancel_event = cancel_event, + stall_timeout = stall_timeout, + interval = interval, + grace_period = grace_period, + on_status = on_status, + prepare_for_http_fn = prepare_for_http_fn, + ) From 0bba57176dad3aaedf90849ef69b069a7d4fd640 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 24 Jun 2026 11:56:18 +0000 Subject: [PATCH 02/82] Address review: thread cache_dir through watchdog/cleanup; drop unused sys import - The no-progress watchdog and the HTTP-retry cache purge now honor 
a caller supplied snapshot cache_dir, not just HF_HUB_CACHE, so a stall under a custom cache is detected and the partial is purged before the HTTP fallback. Threaded through hf_cache_root / iter_active_repo_cache_dirs / has_active_incomplete_blobs (hf_cache_state) and get_hf_download_state / start_watchdog / _default_prepare_for_http / _run_download_attempt (hf_xet_fallback). - Remove the unused 'import sys' (ruff F401). - Add a regression test that a stall under a custom cache_dir is watched and cleaned. --- tests/test_hf_xet_fallback.py | 43 +++++++++++++++++++++++++++++-- unsloth_zoo/hf_cache_state.py | 35 +++++++++++++++---------- unsloth_zoo/hf_xet_fallback.py | 47 ++++++++++++++++++++++------------ 3 files changed, 93 insertions(+), 32 deletions(-) diff --git a/tests/test_hf_xet_fallback.py b/tests/test_hf_xet_fallback.py index ceace39e7..c74adb72c 100644 --- a/tests/test_hf_xet_fallback.py +++ b/tests/test_hf_xet_fallback.py @@ -46,6 +46,9 @@ def _load(name: str, filename: str): _load("unsloth_zoo.hf_cache_state", "hf_cache_state.py") xf = _load("unsloth_zoo.hf_xet_fallback", "hf_xet_fallback.py") +# Real prep impl, captured before the autouse fixture stubs the module attribute. +_REAL_DEFAULT_PREPARE = xf._default_prepare_for_http + # --------------------------------------------------------------------------- # # Watchdog: fires only on a constant-size .incomplete, sparse-aware byte total. 
@@ -177,6 +180,42 @@ def test_get_state_sparse_aware(hf_cache): assert total < st.st_size, "sparse partial counted at apparent size, not allocated blocks" +def test_custom_cache_dir_is_watched_and_cleaned(tmp_path, monkeypatch): + """A stall under a caller-supplied snapshot ``cache_dir`` (not HF_HUB_CACHE) + must still be seen by the state probe, the watchdog, and the HTTP-prep purge.""" + default_cache = tmp_path / "default" + custom_cache = tmp_path / "custom" + default_cache.mkdir() + custom_cache.mkdir() + monkeypatch.setattr(hf_constants, "HF_HUB_CACHE", str(default_cache)) + + blobs = custom_cache / f"models--{REPO.replace('/', '--')}" / "blobs" + blobs.mkdir(parents = True) + partial = blobs / "stalled.incomplete" + partial.write_bytes(b"partial-bytes") + + # Default cache sees nothing; the custom cache sees the active partial. + assert xf.get_hf_download_state([REPO]) == (0, False) + total, has_incomplete = xf.get_hf_download_state([REPO], cache_dir = str(custom_cache)) + assert has_incomplete is True and total > 0 + + # The watchdog fires for the custom cache, not the (empty) default one. + calls: list[str] = [] + stop = xf.start_watchdog( + repo_ids = [REPO], on_stall = calls.append, cache_dir = str(custom_cache), + interval = 0.05, stall_timeout = 0.3, + ) + try: + assert _wait(lambda: len(calls) >= 1, timeout = 3.0), "watchdog ignored the custom cache_dir" + finally: + stop.set() + + # The HTTP-prep purge removes the unsafe partial from the custom cache + # (call the real impl; the autouse fixture stubs the module attribute). + _REAL_DEFAULT_PREPARE("model", REPO, cache_dir = str(custom_cache)) + assert not partial.exists() + + # --------------------------------------------------------------------------- # # Transport policy: cached short-circuit, cancel, error propagation, the single # Xet->HTTP fallback, the injected prepare seam, and the UNSLOTH_DISABLE_XET knob. 
@@ -286,7 +325,7 @@ def test_immediate_success_uses_xet_only(monkeypatch): def test_stall_then_http_fallback_succeeds(monkeypatch): prepared = [] - monkeypatch.setattr(xf, "_default_prepare_for_http", lambda repo_type, repo_id: prepared.append((repo_type, repo_id))) + monkeypatch.setattr(xf, "_default_prepare_for_http", lambda repo_type, repo_id, cache_dir = None: prepared.append((repo_type, repo_id))) fake = _install(monkeypatch, [("stall", None), ("ok", "/cache/model.gguf")]) out = xf.hf_hub_download_with_xet_fallback(DL_REPO, FILE, None) @@ -372,7 +411,7 @@ def _snap(*a, **k): def test_snapshot_stall_then_http(monkeypatch): prepared = [] - monkeypatch.setattr(xf, "_default_prepare_for_http", lambda rt, rid: prepared.append((rt, rid))) + monkeypatch.setattr(xf, "_default_prepare_for_http", lambda rt, rid, cache_dir = None: prepared.append((rt, rid))) fake = _install(monkeypatch, [("stall", None), ("ok", "/cache/snap-dir")]) out = xf.snapshot_download_with_xet_fallback(DL_REPO, token = None) assert out == "/cache/snap-dir" diff --git a/unsloth_zoo/hf_cache_state.py b/unsloth_zoo/hf_cache_state.py index 734434d05..5fbd9a23f 100644 --- a/unsloth_zoo/hf_cache_state.py +++ b/unsloth_zoo/hf_cache_state.py @@ -54,17 +54,22 @@ def _safe_is_dir(path: Path) -> bool: return False -def hf_cache_root(*, create: bool = False) -> Optional[Path]: - """The active hub cache root (``HF_HUB_CACHE``), or None if unavailable. +def hf_cache_root(*, create: bool = False, cache_dir: "Optional[str | Path]" = None) -> Optional[Path]: + """The hub cache root to scan, or None if unavailable. - Read lazily so any cache redirect applied at import time (see + When *cache_dir* is given (a caller-supplied ``snapshot_download`` cache), it + is used verbatim; otherwise the active ``HF_HUB_CACHE`` is read lazily so any + redirect applied at import time (see ``unsloth_zoo.hf_cache.redirect_hf_cache_if_readonly``) is honored. 
""" - try: - from huggingface_hub import constants as hf_constants - except ImportError: - return None - root = Path(hf_constants.HF_HUB_CACHE) + if cache_dir is not None: + root = Path(cache_dir) + else: + try: + from huggingface_hub import constants as hf_constants + except ImportError: + return None + root = Path(hf_constants.HF_HUB_CACHE) if create: try: root.mkdir(parents = True, exist_ok = True) @@ -165,9 +170,11 @@ def _repo_dir_has_broken_snapshot_symlinks(repo_dir: Path) -> bool: return False -def iter_active_repo_cache_dirs(repo_type: str, repo_id: str) -> Iterator[Path]: - """Yield the repo's cache dir(s) under the single active ``HF_HUB_CACHE`` root.""" - root = hf_cache_root() +def iter_active_repo_cache_dirs( + repo_type: str, repo_id: str, *, cache_dir: "Optional[str | Path]" = None +) -> Iterator[Path]: + """Yield the repo's cache dir(s) under *cache_dir* (or the active ``HF_HUB_CACHE``).""" + root = hf_cache_root(cache_dir = cache_dir) if root is None: return target = target_dir_name(repo_type, repo_id) @@ -186,8 +193,10 @@ def repo_cache_dir_has_incomplete_blobs(repo_dir: Path) -> bool: ) -def has_active_incomplete_blobs(repo_type: str, repo_id: str) -> bool: - for entry in iter_active_repo_cache_dirs(repo_type, repo_id): +def has_active_incomplete_blobs( + repo_type: str, repo_id: str, *, cache_dir: "Optional[str | Path]" = None +) -> bool: + for entry in iter_active_repo_cache_dirs(repo_type, repo_id, cache_dir = cache_dir): if repo_cache_dir_has_incomplete_blobs(entry): return True return False diff --git a/unsloth_zoo/hf_xet_fallback.py b/unsloth_zoo/hf_xet_fallback.py index 48108be2c..6b00db82b 100644 --- a/unsloth_zoo/hf_xet_fallback.py +++ b/unsloth_zoo/hf_xet_fallback.py @@ -34,7 +34,6 @@ import queue import re import signal -import sys import threading import time from typing import Any, Callable, Optional @@ -110,12 +109,14 @@ def _default_scrub_secrets(text: str, hf_token: Optional[str] = None) -> str: return out -def 
_default_prepare_for_http(repo_type: str, repo_id: str) -> None: +def _default_prepare_for_http( + repo_type: str, repo_id: str, *, cache_dir: Optional[str] = None +) -> None: """Generic 'make the partial safe for an HTTP resume': delete the repo's active ``*.incomplete`` blobs (an HTTP resume over a sparse Xet/hf_transfer partial silently corrupts the blob). Studio injects its marker-aware version instead.""" try: - for entry in iter_active_repo_cache_dirs(repo_type, repo_id): + for entry in iter_active_repo_cache_dirs(repo_type, repo_id, cache_dir = cache_dir): blobs_dir = entry / "blobs" if not blobs_dir.is_dir(): continue @@ -130,16 +131,20 @@ def _default_prepare_for_http(repo_type: str, repo_id: str) -> None: def get_hf_download_state( - repo_ids: Optional[list[str]] = None, *, repo_type: str = "model" + repo_ids: Optional[list[str]] = None, + *, + repo_type: str = "model", + cache_dir: Optional[str] = None, ) -> Optional[tuple[int, bool]]: - """Return ``(total_on_disk_bytes, has_incomplete)`` for the active HF cache. + """Return ``(total_on_disk_bytes, has_incomplete)`` for the HF cache being written. - Sparse-aware (st_blocks based) so a sparse Xet/``hf_transfer`` ``.incomplete`` - is not mistaken for full-size progress. ``None`` means the state could not be - measured, so callers skip stall logic for that tick. + Scans *cache_dir* when the download targets a caller-supplied cache, else the + active ``HF_HUB_CACHE``. Sparse-aware (st_blocks based) so a sparse Xet/ + ``hf_transfer`` ``.incomplete`` is not mistaken for full-size progress. ``None`` + means the state could not be measured, so callers skip stall logic for that tick. """ try: - if hf_cache_root() is None: + if hf_cache_root(cache_dir = cache_dir) is None: return (0, False) total = 0 @@ -148,7 +153,7 @@ def get_hf_download_state( # Skip local paths: HF IDs never start with / . ~ or contain "\". 
if not repo_id or repo_id.startswith(("/", ".", "~")) or "\\" in repo_id: continue - for entry in iter_active_repo_cache_dirs(repo_type, repo_id): + for entry in iter_active_repo_cache_dirs(repo_type, repo_id, cache_dir = cache_dir): blobs_dir = entry / "blobs" if not blobs_dir.is_dir(): continue @@ -158,7 +163,7 @@ def get_hf_download_state( total += blob_bytes_present(blob) except OSError: pass - if has_active_incomplete_blobs(repo_type, repo_id): + if has_active_incomplete_blobs(repo_type, repo_id, cache_dir = cache_dir): has_incomplete = True return (total, has_incomplete) except Exception as e: @@ -171,6 +176,7 @@ def start_watchdog( repo_ids: list[str], on_stall: Callable[[str], None], repo_type: str = "model", + cache_dir: Optional[str] = None, interval: float = DEFAULT_HEARTBEAT_INTERVAL, stall_timeout: float = DEFAULT_STALL_TIMEOUT, xet_disabled: bool = False, @@ -179,8 +185,9 @@ def start_watchdog( """Start a daemon thread that fires ``on_stall(message)`` exactly once iff a ``*.incomplete`` is present AND the on-disk size is unchanged for *stall_timeout* seconds. The timer resets while no ``*.incomplete`` exists, so - post-download init is never misread as a stall. Returns a stop event the - caller sets when the download phase ends. + post-download init is never misread as a stall. Scans *cache_dir* when the + download targets a caller-supplied cache, else the active ``HF_HUB_CACHE``. + Returns a stop event the caller sets when the download phase ends. 
""" stop = threading.Event() transport = "https" if xet_disabled else "xet" @@ -188,12 +195,12 @@ def start_watchdog( def _beat() -> None: nonlocal fired - state = get_hf_download_state(repo_ids, repo_type = repo_type) + state = get_hf_download_state(repo_ids, repo_type = repo_type, cache_dir = cache_dir) last_size = state[0] if state is not None else 0 last_change = time.monotonic() while not stop.wait(interval): - state = get_hf_download_state(repo_ids, repo_type = repo_type) + state = get_hf_download_state(repo_ids, repo_type = repo_type, cache_dir = cache_dir) now = time.monotonic() if state is None: @@ -400,6 +407,7 @@ def _run_download_attempt( repo_ids = [repo_id], on_stall = lambda msg: stalled.set(), repo_type = repo_type, + cache_dir = params.get("cache_dir"), interval = interval, stall_timeout = stall_timeout, xet_disabled = disable_xet, @@ -460,7 +468,7 @@ def _download_with_xet_fallback( if cancel_event is not None and cancel_event.is_set(): raise RuntimeError("Cancelled") - prepare_for_http = prepare_for_http_fn or _default_prepare_for_http + cache_dir = params.get("cache_dir") # The Unsloth/HF knobs can force HTTP from the very first attempt. disable_xet = xet_force_disabled() @@ -468,8 +476,13 @@ def _download_with_xet_fallback( if disable_xet: # Purge a non-HTTP partial before resuming over HTTP: an HTTP resume # over a sparse Xet/hf_transfer partial silently corrupts the blob. + # The generic purge is cache_dir-aware; an injected (Studio) hook owns + # its own cache accounting and keeps the (repo_type, repo_id) signature. 
try: - prepare_for_http(repo_type, repo_id) + if prepare_for_http_fn is None: + _default_prepare_for_http(repo_type, repo_id, cache_dir = cache_dir) + else: + prepare_for_http_fn(repo_type, repo_id) except Exception as e: logger.debug("prepare_for_http failed for %s: %s", repo_id, e) From 8af1b5debd7ffe00329cf10dcbd9f76d886c6d60 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 24 Jun 2026 13:01:58 +0000 Subject: [PATCH 03/82] Address review 2: set transport env in parent before spawn; cache_dir on file path - Critical: the HTTP retry could still use Xet. The spawn child re-imports the (heavy) unsloth_zoo package -- importing huggingface_hub, which reads HF_HUB_DISABLE_XET into a module constant at import time -- BEFORE the child body ran, so a child-side os.environ assignment landed too late. Now set the transport env in the parent (under a lock) around proc.start() so the child inherits it from creation; the child still sets it defensively. Tests assert the child inherits HF_HUB_DISABLE_XET=1 on the HTTP retry and is left untouched on the Xet attempt. - hf_hub_download_with_xet_fallback now accepts cache_dir (symmetric with the snapshot variant); threaded into try_to_load_from_cache and the child download. 
--- tests/test_hf_xet_fallback.py | 90 ++++++++++++++++++++++++++++++++++ unsloth_zoo/hf_xet_fallback.py | 33 +++++++++++-- 2 files changed, 120 insertions(+), 3 deletions(-) diff --git a/tests/test_hf_xet_fallback.py b/tests/test_hf_xet_fallback.py index c74adb72c..4061ec4d8 100644 --- a/tests/test_hf_xet_fallback.py +++ b/tests/test_hf_xet_fallback.py @@ -13,6 +13,7 @@ from __future__ import annotations import importlib.util +import os import subprocess import sys import threading @@ -272,6 +273,7 @@ def __call__( repo_id = repo_id, kind = kind, target = params.get("filename", repo_id), + cache_dir = params.get("cache_dir"), disable_xet = disable_xet, repo_type = repo_type, ) @@ -390,6 +392,94 @@ def test_unsloth_disable_xet_stall_raises_no_retry(monkeypatch): assert len(fake.calls) == 1 +def test_file_path_accepts_cache_dir(monkeypatch): + """The single-file wrapper accepts cache_dir (no TypeError) and threads it through.""" + fake = _install(monkeypatch, [("ok", "/cache/model.gguf")]) + out = xf.hf_hub_download_with_xet_fallback(DL_REPO, FILE, None, cache_dir = "/custom/cache") + assert out == "/cache/model.gguf" + assert fake.calls[0].cache_dir == "/custom/cache" + + +# --------------------------------------------------------------------------- # +# Spawn env-timing: the parent sets HF_HUB_DISABLE_XET before the child starts, +# so the child inherits it before re-importing huggingface_hub (whose constants +# cache the value at import). Uses a fake spawn context -- no real subprocess. 
+# --------------------------------------------------------------------------- # +class _FakeProc: + def __init__(self, recorder): + self._rec = recorder + self.pid = 4242 + self.exitcode = 0 + + def start(self): + self._rec["disable_xet"] = os.environ.get("HF_HUB_DISABLE_XET") + self._rec["hf_transfer"] = os.environ.get("HF_HUB_ENABLE_HF_TRANSFER") + + def is_alive(self): + return False + + def join(self, timeout = None): + pass + + +class _FakeQueue: + def __init__(self, result): + self._result = result + + def get(self, timeout = None): + return self._result + + def get_nowait(self): + return self._result + + def put(self, item): + pass + + +class _FakeCtx: + def __init__(self, recorder, result): + self._rec = recorder + self._result = result + + def Process(self, *, target = None, kwargs = None, daemon = None): + return _FakeProc(self._rec) + + def Queue(self): + return _FakeQueue(self._result) + + +def test_http_retry_sets_disable_xet_before_spawn(monkeypatch): + monkeypatch.delenv("HF_HUB_DISABLE_XET", raising = False) + monkeypatch.delenv("HF_HUB_ENABLE_HF_TRANSFER", raising = False) + rec: dict = {} + monkeypatch.setattr(xf, "_CTX", _FakeCtx(rec, {"ok": True, "path": "/cache/x"})) + + kind_result, payload = xf._run_download_attempt( + DL_REPO, kind = "snapshot", params = {"repo_id": DL_REPO}, token = None, + repo_type = "model", disable_xet = True, cancel_event = None, + stall_timeout = 0.2, interval = 0.05, grace_period = 0.2, on_status = None, + ) + assert (kind_result, payload) == ("ok", "/cache/x") + # Child inherited HTTP transport env at spawn time. + assert rec["disable_xet"] == "1" + assert rec["hf_transfer"] == "0" + # Parent env is restored afterwards (was unset). 
+ assert "HF_HUB_DISABLE_XET" not in os.environ + + +def test_xet_attempt_does_not_force_disable_before_spawn(monkeypatch): + monkeypatch.delenv("HF_HUB_DISABLE_XET", raising = False) + rec: dict = {} + monkeypatch.setattr(xf, "_CTX", _FakeCtx(rec, {"ok": True, "path": "/cache/x"})) + xf._run_download_attempt( + DL_REPO, kind = "snapshot", params = {"repo_id": DL_REPO}, token = None, + repo_type = "model", disable_xet = False, cancel_event = None, + stall_timeout = 0.2, interval = 0.05, grace_period = 0.2, on_status = None, + ) + # On the Xet-first attempt we must NOT force-disable Xet for the child. + assert rec["disable_xet"] is None + + # --------------------------------------------------------------------------- # # Snapshot variant: in-process fast path on a warm cache, else watched download. # --------------------------------------------------------------------------- # diff --git a/unsloth_zoo/hf_xet_fallback.py b/unsloth_zoo/hf_xet_fallback.py index 6b00db82b..b5afccc1b 100644 --- a/unsloth_zoo/hf_xet_fallback.py +++ b/unsloth_zoo/hf_xet_fallback.py @@ -56,6 +56,10 @@ DEFAULT_GRACE_PERIOD = 10.0 _POLL_INTERVAL = 0.5 +# Serializes the brief parent-env mutation around a child spawn (below) so +# concurrent downloads cannot observe each other's transport env. +_SPAWN_ENV_LOCK = threading.Lock() + # Hugging Face boolean env convention: 1 / ON / YES / TRUE, case-insensitive. _TRUTHY = {"1", "true", "yes", "on"} @@ -267,6 +271,7 @@ def _child_download(*, kind: str, params: dict, token: Optional[str], repo_type: repo_type = repo_type, token = token, revision = params.get("revision"), + cache_dir = params.get("cache_dir"), ) @@ -392,7 +397,26 @@ def _run_download_attempt( ), daemon = True, ) - proc.start() + # Set the transport env in THIS process around the spawn so the child inherits + # it from creation. 
HF reads HF_HUB_DISABLE_XET into constants at import time, + # and a spawn child re-imports the (heavy) unsloth_zoo package -- importing + # huggingface_hub -- before the child body runs, so a child-side os.environ + # assignment would land too late. The child still sets it too, defensively. + child_env = {"HF_HUB_DISABLE_PROGRESS_BARS": "1"} + if disable_xet: + child_env["HF_HUB_DISABLE_XET"] = "1" + child_env["HF_HUB_ENABLE_HF_TRANSFER"] = "0" + with _SPAWN_ENV_LOCK: + saved_env = {k: os.environ.get(k) for k in child_env} + try: + os.environ.update(child_env) + proc.start() + finally: + for k, v in saved_env.items(): + if v is None: + os.environ.pop(k, None) + else: + os.environ[k] = v # Bind the child to the parent lifetime when running under Studio (best-effort). try: @@ -533,6 +557,7 @@ def hf_hub_download_with_xet_fallback( cancel_event: Optional[threading.Event] = None, repo_type: str = "model", revision: Optional[str] = None, + cache_dir: Optional[str] = None, stall_timeout: float = DEFAULT_STALL_TIMEOUT, interval: float = DEFAULT_HEARTBEAT_INTERVAL, grace_period: float = DEFAULT_GRACE_PERIOD, @@ -549,7 +574,9 @@ def hf_hub_download_with_xet_fallback( try: from huggingface_hub import try_to_load_from_cache - cached = try_to_load_from_cache(repo_id, filename, repo_type = repo_type, revision = revision) + cached = try_to_load_from_cache( + repo_id, filename, repo_type = repo_type, revision = revision, cache_dir = cache_dir + ) if isinstance(cached, str) and os.path.exists(cached): return cached except Exception as e: @@ -559,7 +586,7 @@ def hf_hub_download_with_xet_fallback( repo_id = repo_id, label = f"{repo_id}/{filename}", kind = "file", - params = {"repo_id": repo_id, "filename": filename, "revision": revision}, + params = {"repo_id": repo_id, "filename": filename, "revision": revision, "cache_dir": cache_dir}, token = token, repo_type = repo_type, cancel_event = cancel_event, From 6266d0811d161ec0fe204b97cb5e1641c3dd138e Mon Sep 17 00:00:00 2001 From: 
Daniel Han Date: Wed, 24 Jun 2026 23:48:35 +0000 Subject: [PATCH 04/82] Harden cleanup, watchdog and result drain (Gemini review) - prepare_for_http: continue past a locked/permission-denied partial (catch OSError, not just FileNotFoundError) so one bad blob does not abort cleanup of the rest. - get_hf_download_state: also skip drive-letter (":") repo ids so Windows absolute paths like C:/models are not mistaken for Hub ids. - watchdog: reset the no-progress timer on an unmeasurable tick so a long unmeasurable gap cannot trip a false stall the instant state is readable. - result drain: get(timeout=1.0) instead of get_nowait() so a child that exits microseconds before its queue feeder flushes is not misreported as 'exited without a result'. --- tests/test_hf_xet_fallback.py | 4 +++- unsloth_zoo/hf_xet_fallback.py | 25 ++++++++++++++++++++----- 2 files changed, 23 insertions(+), 6 deletions(-) diff --git a/tests/test_hf_xet_fallback.py b/tests/test_hf_xet_fallback.py index 4061ec4d8..ef5ab9bb1 100644 --- a/tests/test_hf_xet_fallback.py +++ b/tests/test_hf_xet_fallback.py @@ -165,7 +165,9 @@ def test_get_state_absent_cache_root(tmp_path, monkeypatch): def test_get_state_skips_local_paths(hf_cache): # Filesystem paths are not HF repo IDs and must be ignored without error. 
- assert xf.get_hf_download_state(["/abs/path", "./rel", "~user", "c:\\x"]) == (0, False) + assert xf.get_hf_download_state( + ["/abs/path", "./rel", "~user", "c:\\x", "c:/x"] + ) == (0, False) def test_get_state_sparse_aware(hf_cache): diff --git a/unsloth_zoo/hf_xet_fallback.py b/unsloth_zoo/hf_xet_fallback.py index b5afccc1b..6b7adb6a2 100644 --- a/unsloth_zoo/hf_xet_fallback.py +++ b/unsloth_zoo/hf_xet_fallback.py @@ -128,7 +128,9 @@ def _default_prepare_for_http( if blob.is_file() and blob.name.endswith(INCOMPLETE_SUFFIX): try: blob.unlink() - except FileNotFoundError: + except OSError: + # A locked / permission-denied blob (common on Windows) + # must not abort cleanup of the rest of the partials. continue except Exception as e: logger.debug("default prepare_for_http failed for %s: %s", repo_id, e) @@ -154,8 +156,14 @@ def get_hf_download_state( total = 0 has_incomplete = False for repo_id in repo_ids or []: - # Skip local paths: HF IDs never start with / . ~ or contain "\". - if not repo_id or repo_id.startswith(("/", ".", "~")) or "\\" in repo_id: + # Skip local paths: HF IDs never start with / . ~, contain "\", or a + # drive-letter ":" (e.g. C:/models or C:\models on Windows). + if ( + not repo_id + or repo_id.startswith(("/", ".", "~")) + or "\\" in repo_id + or ":" in repo_id + ): continue for entry in iter_active_repo_cache_dirs(repo_type, repo_id, cache_dir = cache_dir): blobs_dir = entry / "blobs" @@ -208,6 +216,10 @@ def _beat() -> None: now = time.monotonic() if state is None: + # Unmeasurable this tick (transient FS error): treat as progress + # so a long unmeasurable gap cannot trip a false stall the instant + # the state becomes readable again. + last_change = now if on_heartbeat is not None: on_heartbeat(f"Downloading ({transport} transport)...") continue @@ -453,9 +465,12 @@ def _run_download_attempt( except queue.Empty: continue else: - # Process exited; drain any result it enqueued. + # Process exited; drain any result it enqueued. 
Use a short timeout, + # not get_nowait(): the child can exit microseconds before its queue + # feeder flushes the pipe, and a bare get_nowait() would then spuriously + # report "exited without a result" on an otherwise successful download. try: - result = result_queue.get_nowait() + result = result_queue.get(timeout = 1.0) except queue.Empty: result = None finally: From 2619c825f8ba1ed788795bebf32b531c6ba0b7ea Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 25 Jun 2026 08:18:33 +0000 Subject: [PATCH 05/82] Fix spawn from notebooks, case-collision deletes, stale symlinks, force_download (review 3) - spawn from notebook / stdin / python -c / unguarded script: multiprocessing 'spawn' re-opens __main__.__file__, which is missing or '<stdin>' there, so proc.start() raised before the child ran and the download never happened. Point __main__ at this importable module just for the spawn (serialized under the existing lock) and restore it. Verified from a real context: the child now runs and returns the expected HF error instead of failing to spawn. - _default_prepare_for_http only deletes from an exact-case cache dir, or a single unambiguous case-insensitive match, so preparing HTTP for Org/Repo no longer purges a case-colliding org/repo on a case-sensitive filesystem. - HTTP prep now also clears broken snapshot symlinks, which the incomplete-state detector counts as active; otherwise the HTTP retry inherited stale state and re-tripped the watchdog. - snapshot/file fallbacks accept and forward force_download, skipping the warm cache short-circuit so force_download=True re-fetches in the killable child. - Added regression tests for all four (30 passed, 1 skipped). 
--- tests/test_hf_xet_fallback.py | 66 +++++++++++++ unsloth_zoo/hf_xet_fallback.py | 167 +++++++++++++++++++++++++-------- 2 files changed, 192 insertions(+), 41 deletions(-) diff --git a/tests/test_hf_xet_fallback.py b/tests/test_hf_xet_fallback.py index ef5ab9bb1..c9938fd71 100644 --- a/tests/test_hf_xet_fallback.py +++ b/tests/test_hf_xet_fallback.py @@ -219,6 +219,46 @@ def test_custom_cache_dir_is_watched_and_cleaned(tmp_path, monkeypatch): assert not partial.exists() +def test_prepare_for_http_clears_broken_snapshot_symlink(tmp_path): + """A broken snapshot symlink is counted as active-incomplete state by the + detector, so HTTP prep must clear it too or the retry re-trips the watchdog.""" + repo = "ztest/broken-symlink" + repo_dir = tmp_path / f"models--{repo.replace('/', '--')}" + snap = repo_dir / "snapshots" / "abc123" + snap.mkdir(parents = True) + link = snap / "model.safetensors" + link.symlink_to(repo_dir / "blobs" / "missing-blob") # dangling + assert link.is_symlink() and not link.exists() + + # Detector treats the dangling link as active incomplete state. 
+ assert xf.get_hf_download_state([repo], cache_dir = str(tmp_path)) == (0, True) + + _REAL_DEFAULT_PREPARE("model", repo, cache_dir = str(tmp_path)) + + assert not link.is_symlink(), "broken snapshot symlink not cleared by HTTP prep" + assert xf.get_hf_download_state([repo], cache_dir = str(tmp_path)) == (0, False) + + +def test_prepare_for_http_preserves_case_colliding_repo(tmp_path): + """On a case-sensitive filesystem, preparing HTTP for ``Org/Repo`` must purge + only its exact-case cache dir, never a case-colliding ``org/repo``.""" + upper = tmp_path / "models--Org--Repo" / "blobs" + lower = tmp_path / "models--org--repo" / "blobs" + upper.mkdir(parents = True) + lower.mkdir(parents = True) + if upper.parent.resolve() == lower.parent.resolve(): + pytest.skip("case-insensitive filesystem; cannot collide cache dirs") + upper_partial = upper / "a.incomplete" + lower_partial = lower / "b.incomplete" + upper_partial.write_bytes(b"x") + lower_partial.write_bytes(b"y") + + _REAL_DEFAULT_PREPARE("model", "Org/Repo", cache_dir = str(tmp_path)) + + assert not upper_partial.exists(), "exact-case partial should be purged" + assert lower_partial.exists(), "case-colliding repo's partial must be preserved" + + # --------------------------------------------------------------------------- # # Transport policy: cached short-circuit, cancel, error propagation, the single # Xet->HTTP fallback, the injected prepare seam, and the UNSLOTH_DISABLE_XET knob. 
@@ -276,6 +316,7 @@ def __call__( kind = kind, target = params.get("filename", repo_id), cache_dir = params.get("cache_dir"), + force_download = params.get("force_download"), disable_xet = disable_xet, repo_type = repo_type, ) @@ -512,6 +553,31 @@ def test_snapshot_stall_then_http(monkeypatch): assert prepared == [("model", DL_REPO)] +def test_force_download_skips_fast_path_and_threads(monkeypatch): + """force_download=True must bypass the warm-cache short-circuit and re-fetch in + the killable child, forwarding force_download into the download params.""" + def _snap(*a, **k): + pytest.fail("force_download must not take the local_files_only fast path") + + monkeypatch.setattr(huggingface_hub, "snapshot_download", _snap) + fake = _install(monkeypatch, [("ok", "/cache/snap-dir")]) + out = xf.snapshot_download_with_xet_fallback(DL_REPO, token = None, force_download = True) + assert out == "/cache/snap-dir" + assert len(fake.calls) == 1 and fake.calls[0].force_download is True + + +def test_force_download_file_skips_cache_probe(monkeypatch, tmp_path): + """The single-file path must also skip the cached-blob short-circuit and thread + force_download through when force_download=True.""" + cached = tmp_path / "cached.gguf" + cached.write_bytes(b"\0" * 8) + monkeypatch.setattr(huggingface_hub, "try_to_load_from_cache", lambda *a, **k: str(cached)) + fake = _install(monkeypatch, [("ok", "/cache/x")]) + out = xf.hf_hub_download_with_xet_fallback(DL_REPO, FILE, None, force_download = True) + assert out == "/cache/x" + assert len(fake.calls) == 1 and fake.calls[0].force_download is True + + # --------------------------------------------------------------------------- # # Precondition: HF_HUB_DISABLE_XET is read at import time, so assert its effect # in a FRESH interpreter (huggingface/huggingface_hub#3266 once ignored it). 
diff --git a/unsloth_zoo/hf_xet_fallback.py b/unsloth_zoo/hf_xet_fallback.py index 6b7adb6a2..748100f5c 100644 --- a/unsloth_zoo/hf_xet_fallback.py +++ b/unsloth_zoo/hf_xet_fallback.py @@ -34,6 +34,7 @@ import queue import re import signal +import sys import threading import time from typing import Any, Callable, Optional @@ -44,6 +45,8 @@ has_active_incomplete_blobs, hf_cache_root, iter_active_repo_cache_dirs, + latest_snapshot_dir, + repo_cache_dir_name, ) logger = logging.getLogger(__name__) @@ -56,10 +59,14 @@ DEFAULT_GRACE_PERIOD = 10.0 _POLL_INTERVAL = 0.5 -# Serializes the brief parent-env mutation around a child spawn (below) so -# concurrent downloads cannot observe each other's transport env. +# Serializes the brief parent-env (and __main__.__file__) mutation around a child +# spawn (below) so concurrent downloads cannot observe each other's transport env. _SPAWN_ENV_LOCK = threading.Lock() +# Sentinel: "__main__.__file__ was not touched for this spawn" (distinct from a +# real saved value of None, which means the attribute was absent). +_UNSET = object() + # Hugging Face boolean env convention: 1 / ON / YES / TRUE, case-insensitive. _TRUTHY = {"1", "true", "yes", "on"} @@ -113,25 +120,65 @@ def _default_scrub_secrets(text: str, hf_token: Optional[str] = None) -> str: return out +def _destructive_repo_cache_dirs( + repo_type: str, repo_id: str, *, cache_dir: Optional[str] = None +) -> list: + """Repo cache dir(s) safe to delete from: an exact-case match, or a single + unambiguous case-insensitive match. + + ``iter_active_repo_cache_dirs`` matches case-insensitively, which is correct + for read-only state probing but unsafe for deletion: on a case-sensitive + filesystem with both ``models--Org--Repo`` and ``models--org--repo`` present, + preparing HTTP for ``Org/Repo`` would also delete the other repo's partial. 
+ """ + entries = list(iter_active_repo_cache_dirs(repo_type, repo_id, cache_dir = cache_dir)) + exact_name = repo_cache_dir_name(repo_type, repo_id) + exact = [entry for entry in entries if entry.name == exact_name] + if exact: + return exact + if len(entries) <= 1: + return entries + logger.debug( + "Ambiguous case-colliding cache dirs for %s; skipping destructive HTTP prep", repo_id + ) + return [] + + def _default_prepare_for_http( repo_type: str, repo_id: str, *, cache_dir: Optional[str] = None ) -> None: """Generic 'make the partial safe for an HTTP resume': delete the repo's active ``*.incomplete`` blobs (an HTTP resume over a sparse Xet/hf_transfer partial - silently corrupts the blob). Studio injects its marker-aware version instead.""" + silently corrupts the blob) and any broken snapshot symlinks the incomplete + detector counts as active (else the HTTP retry inherits stale 'incomplete' + state and trips the watchdog again). Studio injects its marker-aware version + instead.""" try: - for entry in iter_active_repo_cache_dirs(repo_type, repo_id, cache_dir = cache_dir): + for entry in _destructive_repo_cache_dirs(repo_type, repo_id, cache_dir = cache_dir): blobs_dir = entry / "blobs" - if not blobs_dir.is_dir(): - continue - for blob in blobs_dir.iterdir(): - if blob.is_file() and blob.name.endswith(INCOMPLETE_SUFFIX): - try: - blob.unlink() - except OSError: - # A locked / permission-denied blob (common on Windows) - # must not abort cleanup of the rest of the partials. - continue + if blobs_dir.is_dir(): + for blob in blobs_dir.iterdir(): + if blob.is_file() and blob.name.endswith(INCOMPLETE_SUFFIX): + try: + blob.unlink() + except OSError: + # A locked / permission-denied blob (common on Windows) + # must not abort cleanup of the rest of the partials. + continue + # repo_cache_dir_has_incomplete_blobs() also flags a broken snapshot + # symlink as active incomplete state; clear those too so the detector + # reads clean after prep. 
+ latest = latest_snapshot_dir(entry) + if latest is not None: + try: + for link in latest.rglob("*"): + if link.is_symlink() and not link.exists(): + try: + link.unlink() + except OSError: + continue + except OSError: + pass except Exception as e: logger.debug("default prepare_for_http failed for %s: %s", repo_id, e) @@ -273,6 +320,7 @@ def _child_download(*, kind: str, params: dict, token: Optional[str], repo_type: cache_dir = params.get("cache_dir"), allow_patterns = params.get("allow_patterns"), ignore_patterns = params.get("ignore_patterns"), + force_download = params.get("force_download", False), ) from huggingface_hub import hf_hub_download @@ -284,6 +332,7 @@ def _child_download(*, kind: str, params: dict, token: Optional[str], repo_type: token = token, revision = params.get("revision"), cache_dir = params.get("cache_dir"), + force_download = params.get("force_download", False), ) @@ -420,6 +469,19 @@ def _run_download_attempt( child_env["HF_HUB_ENABLE_HF_TRANSFER"] = "0" with _SPAWN_ENV_LOCK: saved_env = {k: os.environ.get(k) for k in child_env} + # multiprocessing 'spawn' re-opens __main__.__file__ in the child. From a + # notebook / `python -` / `python -c` / unguarded top-level script that + # file is missing or a pseudo-path like '<stdin>', so proc.start() raises + # before the child runs and the download never happens. Point __main__ at + # this importable module (no top-level download side effects) just for the + # spawn, then restore it. 
+ main_module = sys.modules.get("__main__") + saved_main_file = _UNSET + if main_module is not None: + main_file = getattr(main_module, "__file__", None) + if not main_file or str(main_file).startswith("<"): + saved_main_file = main_file + main_module.__file__ = __file__ try: os.environ.update(child_env) proc.start() @@ -429,6 +491,14 @@ def _run_download_attempt( os.environ.pop(k, None) else: os.environ[k] = v + if saved_main_file is not _UNSET: + if saved_main_file is None: + try: + delattr(main_module, "__file__") + except AttributeError: + pass + else: + main_module.__file__ = saved_main_file # Bind the child to the parent lifetime when running under Studio (best-effort). try: @@ -573,6 +643,7 @@ def hf_hub_download_with_xet_fallback( repo_type: str = "model", revision: Optional[str] = None, cache_dir: Optional[str] = None, + force_download: bool = False, stall_timeout: float = DEFAULT_STALL_TIMEOUT, interval: float = DEFAULT_HEARTBEAT_INTERVAL, grace_period: float = DEFAULT_GRACE_PERIOD, @@ -584,24 +655,33 @@ def hf_hub_download_with_xet_fallback( Returns the local cache path. Raises ``RuntimeError("Cancelled")`` if *cancel_event* is set, re-raises a deterministic child error unchanged (no fallback), and raises ``DownloadStallError`` only if BOTH transports stall. + ``force_download=True`` re-fetches even if cached (skips the cache short-circuit). """ - # Finalized blob already cached: return it with no child and no network. - try: - from huggingface_hub import try_to_load_from_cache + # Finalized blob already cached: return it with no child and no network + # (skipped when force_download re-fetches unconditionally). 
+ if not force_download: + try: + from huggingface_hub import try_to_load_from_cache - cached = try_to_load_from_cache( - repo_id, filename, repo_type = repo_type, revision = revision, cache_dir = cache_dir - ) - if isinstance(cached, str) and os.path.exists(cached): - return cached - except Exception as e: - logger.debug("Cached probe failed for %s/%s: %s", repo_id, filename, e) + cached = try_to_load_from_cache( + repo_id, filename, repo_type = repo_type, revision = revision, cache_dir = cache_dir + ) + if isinstance(cached, str) and os.path.exists(cached): + return cached + except Exception as e: + logger.debug("Cached probe failed for %s/%s: %s", repo_id, filename, e) return _download_with_xet_fallback( repo_id = repo_id, label = f"{repo_id}/{filename}", kind = "file", - params = {"repo_id": repo_id, "filename": filename, "revision": revision, "cache_dir": cache_dir}, + params = { + "repo_id": repo_id, + "filename": filename, + "revision": revision, + "cache_dir": cache_dir, + "force_download": force_download, + }, token = token, repo_type = repo_type, cancel_event = cancel_event, @@ -622,6 +702,7 @@ def snapshot_download_with_xet_fallback( cache_dir: Optional[str] = None, allow_patterns: Optional[Any] = None, ignore_patterns: Optional[Any] = None, + force_download: bool = False, cancel_event: Optional[threading.Event] = None, stall_timeout: float = DEFAULT_STALL_TIMEOUT, interval: float = DEFAULT_HEARTBEAT_INTERVAL, @@ -635,23 +716,26 @@ def snapshot_download_with_xet_fallback( Used by Unsloth's ``from_pretrained`` to warm the cache in a killable child BEFORE the in-process model load (which then hits a warm cache and cannot hang on a native Xet thread). A fully cached repo short-circuits in-process - via ``local_files_only`` with no child and no network. + via ``local_files_only`` with no child and no network. ``force_download=True`` + re-fetches in the killable child even if cached (skips that short-circuit). 
""" - # Fast path: everything already on disk -> resolve in-process (no Xet, no hang). - try: - from huggingface_hub import snapshot_download - - return snapshot_download( - repo_id = repo_id, - repo_type = repo_type, - revision = revision, - cache_dir = cache_dir, - allow_patterns = allow_patterns, - ignore_patterns = ignore_patterns, - local_files_only = True, - ) - except Exception as e: - logger.debug("Snapshot not fully cached for %s (%s); downloading.", repo_id, e) + # Fast path: everything already on disk -> resolve in-process (no Xet, no + # hang). Skipped when force_download re-fetches unconditionally. + if not force_download: + try: + from huggingface_hub import snapshot_download + + return snapshot_download( + repo_id = repo_id, + repo_type = repo_type, + revision = revision, + cache_dir = cache_dir, + allow_patterns = allow_patterns, + ignore_patterns = ignore_patterns, + local_files_only = True, + ) + except Exception as e: + logger.debug("Snapshot not fully cached for %s (%s); downloading.", repo_id, e) return _download_with_xet_fallback( repo_id = repo_id, @@ -663,6 +747,7 @@ def snapshot_download_with_xet_fallback( "cache_dir": cache_dir, "allow_patterns": allow_patterns, "ignore_patterns": ignore_patterns, + "force_download": force_download, }, token = token, repo_type = repo_type, From dc0fddb108ffdebdc200ad02ac1eb0d9fe1bbd14 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 25 Jun 2026 08:48:02 +0000 Subject: [PATCH 06/82] Normalize repo_type=None and make the cache read path case-collision safe (review 4) - repo_type=None (Hugging Face's default model repo) resolved to a bogus Nones-- cache dir, so get_hf_download_state/the watchdog missed the real models-- partial and could let a stalled Xet child run forever. Normalize repo_type to 'model' in repo_cache_dir_name and at the public entrypoints. 
- The destructive HTTP-prep path was case-collision safe but the read/watchdog path still yielded every case-insensitive match, so a stale partial in a colliding org/repo could make the watchdog kill an active Org/Repo download. Move the exact-or-unambiguous guard into iter_active_repo_cache_dirs so both paths share one rule, and drop the now-redundant _destructive_repo_cache_dirs. - Added regression tests for both (32 passed, 1 skipped). --- tests/test_hf_xet_fallback.py | 28 +++++++++++++++++++ unsloth_zoo/hf_cache_state.py | 49 ++++++++++++++++++++++++++-------- unsloth_zoo/hf_xet_fallback.py | 43 +++++++++-------------------- 3 files changed, 78 insertions(+), 42 deletions(-) diff --git a/tests/test_hf_xet_fallback.py b/tests/test_hf_xet_fallback.py index c9938fd71..c0f813717 100644 --- a/tests/test_hf_xet_fallback.py +++ b/tests/test_hf_xet_fallback.py @@ -259,6 +259,34 @@ def test_prepare_for_http_preserves_case_colliding_repo(tmp_path): assert lower_partial.exists(), "case-colliding repo's partial must be preserved" +def test_repo_type_none_resolves_model_cache(hf_cache): + """A caller forwarding repo_type=None (HF's default model) must still see the + real models-- partial, not look up a bogus Nones-- dir.""" + blobs = _blobs_dir(hf_cache) + (blobs / "x.incomplete").write_bytes(b"abc") + + model_state = xf.get_hf_download_state([REPO], repo_type = "model") + none_state = xf.get_hf_download_state([REPO], repo_type = None) + assert model_state == none_state + assert none_state[1] is True and none_state[0] > 0 + + +def test_state_ignores_case_colliding_repo_partial(tmp_path, monkeypatch): + """The read/watchdog path attributes a partial only to an exact-case repo dir, + so a stale partial in a case-colliding repo cannot trip the watchdog.""" + monkeypatch.setattr(hf_constants, "HF_HUB_CACHE", str(tmp_path)) + exact = tmp_path / "models--Org--Repo" / "blobs" + other = tmp_path / "models--org--repo" / "blobs" + exact.mkdir(parents = True) + other.mkdir(parents = 
True) + if exact.parent.resolve() == other.parent.resolve(): + pytest.skip("case-insensitive filesystem; cannot collide cache dirs") + (other / "stale.incomplete").write_bytes(b"x") # only the lowercase repo + + # Org/Repo has no partial of its own; the lowercase repo's must not count. + assert xf.get_hf_download_state(["Org/Repo"]) == (0, False) + + # --------------------------------------------------------------------------- # # Transport policy: cached short-circuit, cancel, error propagation, the single # Xet->HTTP fallback, the injected prepare seam, and the UNSLOTH_DISABLE_XET knob. diff --git a/unsloth_zoo/hf_cache_state.py b/unsloth_zoo/hf_cache_state.py index 5fbd9a23f..868355c45 100644 --- a/unsloth_zoo/hf_cache_state.py +++ b/unsloth_zoo/hf_cache_state.py @@ -79,11 +79,15 @@ def hf_cache_root(*, create: bool = False, cache_dir: "Optional[str | Path]" = N return root if _safe_is_dir(root) else None -def target_dir_name(repo_type: str, repo_id: str) -> str: +def target_dir_name(repo_type: Optional[str], repo_id: str) -> str: return repo_cache_dir_name(repo_type, repo_id).lower() -def repo_cache_dir_name(repo_type: str, repo_id: str) -> str: +def repo_cache_dir_name(repo_type: Optional[str], repo_id: str) -> str: + # Hugging Face treats repo_type=None as the default "model"; mirror that here + # so a caller forwarding repo_type=None still resolves models-- (not + # Nones--, which would make the cache-state probe miss real partials). + repo_type = repo_type or "model" return f"{repo_type}s--{repo_id.replace('/', '--')}" @@ -170,20 +174,43 @@ def _repo_dir_has_broken_snapshot_symlinks(repo_dir: Path) -> bool: return False +def _case_safe_repo_cache_dirs(root: Path, repo_type: Optional[str], repo_id: str) -> list: + """Cache dirs that can be safely attributed to this exact repo id. 
+ + The cache dir name is case-folded by the Hub, so a case-insensitive match is + needed for compatibility, but a bare case-insensitive match is unsafe: on a + case-sensitive filesystem ``models--Org--Repo`` and ``models--org--repo`` are + distinct repos. Prefer an exact-case match; otherwise accept a single + unambiguous folded match; on a 2+ way collision attribute to neither, so a + stale partial in one repo cannot be charged to the other (which would let the + watchdog kill an unrelated active download or HTTP-prep purge the wrong repo). + """ + target = repo_cache_dir_name(repo_type, repo_id) + folded_target = target.lower() + try: + entries = [entry for entry in root.iterdir() if entry.name.lower() == folded_target] + except OSError: + return [] + exact = [entry for entry in entries if entry.name == target] + if exact: + return exact + if len(entries) <= 1: + return entries + return [] + + def iter_active_repo_cache_dirs( - repo_type: str, repo_id: str, *, cache_dir: "Optional[str | Path]" = None + repo_type: Optional[str], repo_id: str, *, cache_dir: "Optional[str | Path]" = None ) -> Iterator[Path]: - """Yield the repo's cache dir(s) under *cache_dir* (or the active ``HF_HUB_CACHE``).""" + """Yield the repo's cache dir(s) under *cache_dir* (or the active ``HF_HUB_CACHE``). + + Case-collision safe (see ``_case_safe_repo_cache_dirs``), so both the read / + watchdog path and the destructive HTTP-prep path share one attribution rule. 
+ """ root = hf_cache_root(cache_dir = cache_dir) if root is None: return - target = target_dir_name(repo_type, repo_id) - try: - for entry in root.iterdir(): - if entry.name.lower() == target: - yield entry - except OSError: - return + yield from _case_safe_repo_cache_dirs(root, repo_type, repo_id) def repo_cache_dir_has_incomplete_blobs(repo_dir: Path) -> bool: diff --git a/unsloth_zoo/hf_xet_fallback.py b/unsloth_zoo/hf_xet_fallback.py index 748100f5c..ed89055f7 100644 --- a/unsloth_zoo/hf_xet_fallback.py +++ b/unsloth_zoo/hf_xet_fallback.py @@ -46,7 +46,6 @@ hf_cache_root, iter_active_repo_cache_dirs, latest_snapshot_dir, - repo_cache_dir_name, ) logger = logging.getLogger(__name__) @@ -120,30 +119,6 @@ def _default_scrub_secrets(text: str, hf_token: Optional[str] = None) -> str: return out -def _destructive_repo_cache_dirs( - repo_type: str, repo_id: str, *, cache_dir: Optional[str] = None -) -> list: - """Repo cache dir(s) safe to delete from: an exact-case match, or a single - unambiguous case-insensitive match. - - ``iter_active_repo_cache_dirs`` matches case-insensitively, which is correct - for read-only state probing but unsafe for deletion: on a case-sensitive - filesystem with both ``models--Org--Repo`` and ``models--org--repo`` present, - preparing HTTP for ``Org/Repo`` would also delete the other repo's partial. 
- """ - entries = list(iter_active_repo_cache_dirs(repo_type, repo_id, cache_dir = cache_dir)) - exact_name = repo_cache_dir_name(repo_type, repo_id) - exact = [entry for entry in entries if entry.name == exact_name] - if exact: - return exact - if len(entries) <= 1: - return entries - logger.debug( - "Ambiguous case-colliding cache dirs for %s; skipping destructive HTTP prep", repo_id - ) - return [] - - def _default_prepare_for_http( repo_type: str, repo_id: str, *, cache_dir: Optional[str] = None ) -> None: @@ -152,9 +127,13 @@ def _default_prepare_for_http( silently corrupts the blob) and any broken snapshot symlinks the incomplete detector counts as active (else the HTTP retry inherits stale 'incomplete' state and trips the watchdog again). Studio injects its marker-aware version - instead.""" + instead. + + ``iter_active_repo_cache_dirs`` is case-collision safe, so this destructive + purge only touches an exact-case (or single unambiguous) repo cache dir. + """ try: - for entry in _destructive_repo_cache_dirs(repo_type, repo_id, cache_dir = cache_dir): + for entry in iter_active_repo_cache_dirs(repo_type, repo_id, cache_dir = cache_dir): blobs_dir = entry / "blobs" if blobs_dir.is_dir(): for blob in blobs_dir.iterdir(): @@ -186,7 +165,7 @@ def _default_prepare_for_http( def get_hf_download_state( repo_ids: Optional[list[str]] = None, *, - repo_type: str = "model", + repo_type: Optional[str] = "model", cache_dir: Optional[str] = None, ) -> Optional[tuple[int, bool]]: """Return ``(total_on_disk_bytes, has_incomplete)`` for the HF cache being written. 
@@ -234,7 +213,7 @@ def start_watchdog( *, repo_ids: list[str], on_stall: Callable[[str], None], - repo_type: str = "model", + repo_type: Optional[str] = "model", cache_dir: Optional[str] = None, interval: float = DEFAULT_HEARTBEAT_INTERVAL, stall_timeout: float = DEFAULT_STALL_TIMEOUT, @@ -640,7 +619,7 @@ def hf_hub_download_with_xet_fallback( token: Optional[str], *, cancel_event: Optional[threading.Event] = None, - repo_type: str = "model", + repo_type: Optional[str] = "model", revision: Optional[str] = None, cache_dir: Optional[str] = None, force_download: bool = False, @@ -657,6 +636,7 @@ def hf_hub_download_with_xet_fallback( fallback), and raises ``DownloadStallError`` only if BOTH transports stall. ``force_download=True`` re-fetches even if cached (skips the cache short-circuit). """ + repo_type = repo_type or "model" # HF treats None as the default model repo. # Finalized blob already cached: return it with no child and no network # (skipped when force_download re-fetches unconditionally). if not force_download: @@ -698,7 +678,7 @@ def snapshot_download_with_xet_fallback( *, revision: Optional[str] = None, token: Optional[str] = None, - repo_type: str = "model", + repo_type: Optional[str] = "model", cache_dir: Optional[str] = None, allow_patterns: Optional[Any] = None, ignore_patterns: Optional[Any] = None, @@ -719,6 +699,7 @@ def snapshot_download_with_xet_fallback( via ``local_files_only`` with no child and no network. ``force_download=True`` re-fetches in the killable child even if cached (skips that short-circuit). """ + repo_type = repo_type or "model" # HF treats None as the default model repo. # Fast path: everything already on disk -> resolve in-process (no Xet, no # hang). Skipped when force_download re-fetches unconditionally. 
if not force_download: From 27c279588030eb52e5159be70cf473c834d593ae Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 25 Jun 2026 10:02:46 +0000 Subject: [PATCH 07/82] Harden process-group kill, status callbacks, and cache_dir expansion (review 4 P2) - _terminate_process_group always sends the post-grace SIGKILL to the whole process group, even when the Python leader already exited on SIGTERM, so a surviving Xet helper cannot keep the stalled writer alive during HTTP cleanup. - Watchdog status/heartbeat callbacks are wrapped: a raising on_status (e.g. a disconnected Studio client) no longer kills the daemon thread and stops stall detection for a genuinely hung child. - hf_cache_root expands ~ in a custom cache_dir (as huggingface_hub does on write), so a stall under e.g. ~/hf-cache is still seen and cleaned. - Added regression tests (34 passed, 1 skipped). --- tests/test_hf_xet_fallback.py | 35 ++++++++++++++++++++++++++++++++++ unsloth_zoo/hf_cache_state.py | 4 +++- unsloth_zoo/hf_xet_fallback.py | 25 +++++++++++++++++++----- 3 files changed, 58 insertions(+), 6 deletions(-) diff --git a/tests/test_hf_xet_fallback.py b/tests/test_hf_xet_fallback.py index c0f813717..706ae729b 100644 --- a/tests/test_hf_xet_fallback.py +++ b/tests/test_hf_xet_fallback.py @@ -287,6 +287,41 @@ def test_state_ignores_case_colliding_repo_partial(tmp_path, monkeypatch): assert xf.get_hf_download_state(["Org/Repo"]) == (0, False) +def test_cache_dir_is_expanded(tmp_path, monkeypatch): + """A custom cache_dir with ~ must be expanded (as HF does on write), else the + state probe scans the literal '~/...' 
path and misses the partial.""" + monkeypatch.setenv("HOME", str(tmp_path)) + monkeypatch.setenv("USERPROFILE", str(tmp_path)) # Windows home var + blobs = tmp_path / "hfcache" / f"models--{REPO.replace('/', '--')}" / "blobs" + blobs.mkdir(parents = True) + (blobs / "p.incomplete").write_bytes(b"abc") + + total, has_incomplete = xf.get_hf_download_state([REPO], cache_dir = "~/hfcache") + assert has_incomplete is True and total > 0 + + +def test_status_callback_failure_does_not_kill_watchdog(hf_cache): + """A raising on_heartbeat (e.g. a disconnected client) must not stop the + daemon watchdog from detecting a real stall and firing on_stall.""" + blobs = _blobs_dir(hf_cache) + (blobs / "x.incomplete").write_bytes(b"\0" * 1024) # constant size -> stalls + + def boom(_message): + raise RuntimeError("client disconnected") + + calls: list[str] = [] + stop = xf.start_watchdog( + repo_ids = [REPO], on_stall = calls.append, on_heartbeat = boom, + interval = 0.05, stall_timeout = 0.3, + ) + try: + assert _wait( + lambda: len(calls) >= 1, timeout = 3.0 + ), "a raising on_heartbeat killed stall detection" + finally: + stop.set() + + # --------------------------------------------------------------------------- # # Transport policy: cached short-circuit, cancel, error propagation, the single # Xet->HTTP fallback, the injected prepare seam, and the UNSLOTH_DISABLE_XET knob. diff --git a/unsloth_zoo/hf_cache_state.py b/unsloth_zoo/hf_cache_state.py index 868355c45..12f26e28d 100644 --- a/unsloth_zoo/hf_cache_state.py +++ b/unsloth_zoo/hf_cache_state.py @@ -63,7 +63,9 @@ def hf_cache_root(*, create: bool = False, cache_dir: "Optional[str | Path]" = N ``unsloth_zoo.hf_cache.redirect_hf_cache_if_readonly``) is honored. """ if cache_dir is not None: - root = Path(cache_dir) + # Match huggingface_hub, which expands ~ before writing; scanning the + # literal path would otherwise miss a partial under e.g. ~/hf-cache. 
+ root = Path(cache_dir).expanduser() else: try: from huggingface_hub import constants as hf_constants diff --git a/unsloth_zoo/hf_xet_fallback.py b/unsloth_zoo/hf_xet_fallback.py index ed89055f7..67fcf57cc 100644 --- a/unsloth_zoo/hf_xet_fallback.py +++ b/unsloth_zoo/hf_xet_fallback.py @@ -74,6 +74,19 @@ def _is_true(value: Optional[str]) -> bool: return value is not None and str(value).strip().lower() in _TRUTHY +def _safe_status(callback: Optional[Callable[[str], None]], message: str) -> None: + """Invoke a caller status/heartbeat callback without letting it kill the + daemon watchdog thread. A disconnected Studio client can make on_status raise; + if that propagated, stall detection for a genuinely hung child would stop and + the HTTP retry would never fire.""" + if callback is None: + return + try: + callback(message) + except Exception as e: + logger.debug("watchdog status callback raised (ignored): %s", e) + + class DownloadStallError(RuntimeError): """Raised when no download progress is observed for too long. @@ -246,8 +259,7 @@ def _beat() -> None: # so a long unmeasurable gap cannot trip a false stall the instant # the state becomes readable again. 
last_change = now - if on_heartbeat is not None: - on_heartbeat(f"Downloading ({transport} transport)...") + _safe_status(on_heartbeat, f"Downloading ({transport} transport)...") continue current_size, has_incomplete = state @@ -268,8 +280,7 @@ def _beat() -> None: ) return - if on_heartbeat is not None: - on_heartbeat(f"Downloading ({transport} transport)...") + _safe_status(on_heartbeat, f"Downloading ({transport} transport)...") threading.Thread(target = _beat, daemon = True, name = "hf-xet-watchdog").start() return stop @@ -400,8 +411,12 @@ def _signal_group(sig: int) -> None: _signal_group(getattr(signal, "SIGTERM", signal.SIGINT)) proc.join(timeout = grace_period) + # Always send the post-grace SIGKILL to the whole group, even if the Python + # leader already exited on SIGTERM: a Xet helper left in the group can keep + # the stalled writer alive while the parent starts HTTP cleanup. killpg on an + # already-dead group is a no-op (ProcessLookupError is caught in _signal_group). + _signal_group(getattr(signal, "SIGKILL", signal.SIGTERM)) if proc.is_alive(): - _signal_group(getattr(signal, "SIGKILL", signal.SIGTERM)) proc.join(timeout = 5.0) From e93e669ee64205e8020ae626785f9ad8de30e07c Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 25 Jun 2026 12:14:42 +0000 Subject: [PATCH 08/82] Let the download child skip unsloth_zoo's heavy GPU init (review 4 P1) The Xet fallback runs the download in a spawn child -- a fresh interpreter that re-imports unsloth_zoo.hf_xet_fallback and so re-runs unsloth_zoo/__init__.py, which on any non-MLX host imports torch + transformers and does device init the child never uses (~5-7s per download child). The cache/download submodules only need stdlib + huggingface_hub. 
Reuse the existing MLX lightweight-import path: __init__ now also honors an opt-in UNSLOTH_ZOO_DISABLE_GPU_INIT=1 (off by default, so normal CUDA/CPU runs are byte-for-byte unchanged), and the parent sets it in the child's environment only around spawning the download child (restored immediately, like the transport env). The unconditional HF cache redirect still runs, so the child writes the same cache. MLX stub injection is now gated on actual MLX, not the generic skip flag. Verified: normal import still runs full init (DEVICE_TYPE=cuda, torch+transformers imported); with the flag the helper imports with torch/transformers absent; a real download through the spawn child completes with zero heavy-init banners (child stays light) and the parent env is restored. 35 passed, 1 skipped. --- tests/test_hf_xet_fallback.py | 18 ++++++++++++++++++ unsloth_zoo/__init__.py | 21 ++++++++++++++++----- unsloth_zoo/hf_xet_fallback.py | 8 +++++++- 3 files changed, 41 insertions(+), 6 deletions(-) diff --git a/tests/test_hf_xet_fallback.py b/tests/test_hf_xet_fallback.py index 706ae729b..a0a38bce3 100644 --- a/tests/test_hf_xet_fallback.py +++ b/tests/test_hf_xet_fallback.py @@ -520,6 +520,7 @@ def __init__(self, recorder): def start(self): self._rec["disable_xet"] = os.environ.get("HF_HUB_DISABLE_XET") self._rec["hf_transfer"] = os.environ.get("HF_HUB_ENABLE_HF_TRANSFER") + self._rec["skip_gpu_init"] = os.environ.get("UNSLOTH_ZOO_DISABLE_GPU_INIT") def is_alive(self): return False @@ -586,6 +587,23 @@ def test_xet_attempt_does_not_force_disable_before_spawn(monkeypatch): assert rec["disable_xet"] is None +def test_child_skips_gpu_init_env_set_before_spawn_and_restored(monkeypatch): + """The download child inherits UNSLOTH_ZOO_DISABLE_GPU_INIT=1 at spawn (so its + fresh unsloth_zoo import skips heavy torch/transformers init), and the parent's + env is restored afterwards.""" + monkeypatch.delenv("UNSLOTH_ZOO_DISABLE_GPU_INIT", raising = False) + rec: dict = {} + 
monkeypatch.setattr(xf, "_CTX", _FakeCtx(rec, {"ok": True, "path": "/cache/x"})) + + xf._run_download_attempt( + DL_REPO, kind = "snapshot", params = {"repo_id": DL_REPO}, token = None, + repo_type = "model", disable_xet = False, cancel_event = None, + stall_timeout = 0.2, interval = 0.05, grace_period = 0.2, on_status = None, + ) + assert rec["skip_gpu_init"] == "1" # set in the parent before proc.start() + assert "UNSLOTH_ZOO_DISABLE_GPU_INIT" not in os.environ # restored after + + # --------------------------------------------------------------------------- # # Snapshot variant: in-process fast path on a warm cache, else watched download. # --------------------------------------------------------------------------- # diff --git a/unsloth_zoo/__init__.py b/unsloth_zoo/__init__.py index 821173deb..f86800e75 100644 --- a/unsloth_zoo/__init__.py +++ b/unsloth_zoo/__init__.py @@ -132,20 +132,31 @@ def has_429_exact_full_read(log_dir: str | Path) -> str: ALLOW_PREQUANTIZED_MODELS = True del _is_mlx_only, is_mlx_available, find_spec # Everything below this point is GPU-only. Use a flag to gate it. + _IS_MLX = True _SKIP_GPU_INIT = True else: - _SKIP_GPU_INIT = False + _IS_MLX = False + # Opt-in lightweight import. A short-lived helper subprocess that only needs + # the cache/download utilities (e.g. the unsloth_zoo.hf_xet_fallback download + # child) can set UNSLOTH_ZOO_DISABLE_GPU_INIT=1 to skip the heavy torch / + # transformers / device init it never uses. Off by default, so normal + # CUDA/CPU runs are byte-for-byte unchanged; the parent only sets it around + # spawning that child, never for a training/inference process. The + # unconditional HF cache redirect above still runs, so the child writes to the + # same cache as the parent. + _SKIP_GPU_INIT = os.environ.get("UNSLOTH_ZOO_DISABLE_GPU_INIT", "0") == "1" del _is_mlx_only, is_mlx_available -# Inject triton & bitsandbytes stubs on Apple Silicon with MLX so unsloth's -# CUDA-only imports don't error at startup. 
_SKIP_GPU_INIT is True only on -# Darwin/arm64 with mlx installed (the exact case stubs are needed). -if _SKIP_GPU_INIT: +# Inject triton & bitsandbytes stubs only on Apple Silicon with MLX so unsloth's +# CUDA-only imports don't error at startup (the generic light-import path never +# imports triton/bitsandbytes, so it must not mask the real modules). +if _IS_MLX: from .stubs.triton_stub import inject_into_sys_modules as _inject_triton _inject_triton() from .stubs.bitsandbytes_stub import inject_into_sys_modules as _inject_bnb _inject_bnb() del _inject_triton, _inject_bnb +del _IS_MLX # Lazy bridge for downstream code that still imports the old flat MLX module # names. Installed on every host so external scripts don't hit a hard diff --git a/unsloth_zoo/hf_xet_fallback.py b/unsloth_zoo/hf_xet_fallback.py index 67fcf57cc..9405554a5 100644 --- a/unsloth_zoo/hf_xet_fallback.py +++ b/unsloth_zoo/hf_xet_fallback.py @@ -457,7 +457,13 @@ def _run_download_attempt( # and a spawn child re-imports the (heavy) unsloth_zoo package -- importing # huggingface_hub -- before the child body runs, so a child-side os.environ # assignment would land too late. The child still sets it too, defensively. - child_env = {"HF_HUB_DISABLE_PROGRESS_BARS": "1"} + child_env = { + "HF_HUB_DISABLE_PROGRESS_BARS": "1", + # The download child is a fresh spawn interpreter that only needs + # huggingface_hub; tell unsloth_zoo's __init__ to skip its heavy torch / + # transformers / device init in that process (the parent keeps full init). 
+ "UNSLOTH_ZOO_DISABLE_GPU_INIT": "1", + } if disable_xet: child_env["HF_HUB_DISABLE_XET"] = "1" child_env["HF_HUB_ENABLE_HF_TRANSFER"] = "0" From e2300e55801feb1fa9e73c458621650d810c4295 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 25 Jun 2026 14:48:09 +0000 Subject: [PATCH 09/82] Address Codex review: spawn from unguarded scripts, partial-cache safety, robust callbacks/tokens P1: - Spawn child now always repoints __main__.__file__ at this side-effect-free module (not only for stdin/notebook callers). An unguarded top-level script (python script.py with a bare from_pretrained) previously had its file re-imported as __mp_main__ in the child, re-running the download and failing with the multiprocessing bootstrapping error; now the child imports this module instead and downloads. Verified end-to-end with a real unguarded script. - snapshot fast path no longer short-circuits a cached-but-incomplete snapshot (an interrupted download leaving .incomplete blobs / broken symlinks): it checks has_active_incomplete_blobs and completes the download in the killable child rather than letting from_pretrained load missing files. P2: - Xet to HTTP retry status callback wrapped in _safe_status so a raising on_status (disconnected client) cannot abort the recoverable retry. - hf_hub_download_with_xet_fallback defaults token to None (parity with hf_hub_download). - If an unsafe partial cannot be cleared before HTTP, force a clean re-download instead of an unsafe resume over a sparse partial. - _default_scrub_secrets tolerates token=True (no TypeError in the child scrubber). - Docstring: state the package imports with unsloth installed; the spawn child uses the UNSLOTH_ZOO_DISABLE_GPU_INIT lightweight path. 41 passed, 1 skipped. 
--- tests/test_hf_xet_fallback.py | 67 ++++++++++++++++++++++++++++++++++ unsloth_zoo/hf_xet_fallback.py | 67 +++++++++++++++++++++++++--------- 2 files changed, 116 insertions(+), 18 deletions(-) diff --git a/tests/test_hf_xet_fallback.py b/tests/test_hf_xet_fallback.py index a0a38bce3..82c50f4ee 100644 --- a/tests/test_hf_xet_fallback.py +++ b/tests/test_hf_xet_fallback.py @@ -521,6 +521,7 @@ def start(self): self._rec["disable_xet"] = os.environ.get("HF_HUB_DISABLE_XET") self._rec["hf_transfer"] = os.environ.get("HF_HUB_ENABLE_HF_TRANSFER") self._rec["skip_gpu_init"] = os.environ.get("UNSLOTH_ZOO_DISABLE_GPU_INIT") + self._rec["main_file"] = getattr(sys.modules.get("__main__"), "__file__", None) def is_alive(self): return False @@ -604,6 +605,72 @@ def test_child_skips_gpu_init_env_set_before_spawn_and_restored(monkeypatch): assert "UNSLOTH_ZOO_DISABLE_GPU_INIT" not in os.environ # restored after +def test_spawn_repoints_main_file_and_restores(monkeypatch): + """For an unguarded top-level caller script, the spawn child must import this + side-effect-free module as __mp_main__ rather than re-execute the caller, so the + parent repoints __main__.__file__ here at spawn and restores it afterwards.""" + main_mod = sys.modules["__main__"] + monkeypatch.setattr(main_mod, "__file__", "/fake/user_script.py", raising = False) + rec: dict = {} + monkeypatch.setattr(xf, "_CTX", _FakeCtx(rec, {"ok": True, "path": "/cache/x"})) + + xf._run_download_attempt( + DL_REPO, kind = "snapshot", params = {"repo_id": DL_REPO}, token = None, + repo_type = "model", disable_xet = False, cancel_event = None, + stall_timeout = 0.2, interval = 0.05, grace_period = 0.2, on_status = None, + ) + assert rec["main_file"] == xf.__file__ # child imports the helper, not the script + assert main_mod.__file__ == "/fake/user_script.py" # restored in the parent + + +def test_scrub_secrets_handles_boolean_token(): + """token=True ("use the cached token") must not crash the child error scrubber.""" + out 
= xf._default_scrub_secrets("auth failed for hf_abcdefghij", hf_token = True) + assert "hf_abcdefghij" not in out and "***" in out + + +def test_file_download_defaults_token_to_none(monkeypatch): + """The single-file helper accepts no token (parity with hf_hub_download).""" + fake = _install(monkeypatch, [("ok", "/cache/x")]) + out = xf.hf_hub_download_with_xet_fallback(DL_REPO, FILE) # no token arg + assert out == "/cache/x" and len(fake.calls) == 1 + + +def test_incomplete_cached_snapshot_not_short_circuited(hf_cache, monkeypatch): + """A cached-but-incomplete snapshot (interrupted download) must not take the + fast path; it must complete in the killable child instead.""" + monkeypatch.setattr(huggingface_hub, "snapshot_download", lambda *a, **k: "/cache/snap") + (_blobs_dir(hf_cache, DL_REPO) / "x.incomplete").write_bytes(b"abc") # active partial + fake = _install(monkeypatch, [("ok", "/cache/snap-fresh")]) + out = xf.snapshot_download_with_xet_fallback(DL_REPO, token = None) + assert out == "/cache/snap-fresh" and len(fake.calls) == 1 + + +def test_retry_status_failure_does_not_abort_fallback(monkeypatch): + """A raising on_status during the Xet->HTTP retry must not abort the fallback.""" + fake = _install(monkeypatch, [("stall", None), ("ok", "/cache/x")]) + + def boom(_message): + raise RuntimeError("client gone") + + out = xf.snapshot_download_with_xet_fallback(DL_REPO, token = None, on_status = boom) + assert out == "/cache/x" + assert [c.disable_xet for c in fake.calls] == [False, True] + + +def test_unclearable_partial_forces_clean_redownload(hf_cache, monkeypatch): + """When prep cannot clear an unsafe partial, the HTTP attempt forces a clean + re-download instead of an unsafe resume over the sparse partial.""" + # The autouse fixture makes _default_prepare_for_http a no-op (simulates a + # cleanup that left the partial in place). 
+ (_blobs_dir(hf_cache, DL_REPO) / "x.incomplete").write_bytes(b"abc") + fake = _install(monkeypatch, [("stall", None), ("ok", "/cache/x")]) + out = xf.snapshot_download_with_xet_fallback(DL_REPO, token = None) + assert out == "/cache/x" + assert fake.calls[0].force_download is False # Xet attempt: not forced + assert fake.calls[1].force_download is True # HTTP attempt: forced clean + + # --------------------------------------------------------------------------- # # Snapshot variant: in-process fast path on a warm cache, else watched download. # --------------------------------------------------------------------------- # diff --git a/unsloth_zoo/hf_xet_fallback.py b/unsloth_zoo/hf_xet_fallback.py index 9405554a5..08eb9725f 100644 --- a/unsloth_zoo/hf_xet_fallback.py +++ b/unsloth_zoo/hf_xet_fallback.py @@ -22,7 +22,13 @@ Unsloth's ``from_pretrained`` uses to warm the cache in a killable child before the in-process load). Studio-specific cache/secret/process helpers are used best-effort (imported only if present) or injected, so the same code runs both -inside Studio and standalone. +inside Unsloth Studio and in Unsloth itself. + +Like the rest of ``unsloth_zoo``, this module is imported with ``unsloth`` +installed; the package ``__init__`` runs its device init on first import. The +download spawn child does not need that and sets ``UNSLOTH_ZOO_DISABLE_GPU_INIT=1`` +before it imports the package, which selects ``unsloth_zoo``'s lightweight import +path (no torch/transformers), keeping each child fast. """ from __future__ import annotations @@ -125,7 +131,9 @@ def _default_scrub_secrets(text: str, hf_token: Optional[str] = None) -> str: if not text: return text out = text - if hf_token: + # HF callers commonly pass token=True ("use the cached token"); only a real + # string token can be substring-redacted (str.replace(True, ...) raises). 
+ if isinstance(hf_token, str) and hf_token: out = out.replace(hf_token, "***") out = re.sub(r"hf_[A-Za-z0-9]{8,}", "***", out) out = re.sub(r"([Bb]earer\s+)[A-Za-z0-9._\-]+", r"\1***", out) @@ -469,19 +477,21 @@ def _run_download_attempt( child_env["HF_HUB_ENABLE_HF_TRANSFER"] = "0" with _SPAWN_ENV_LOCK: saved_env = {k: os.environ.get(k) for k in child_env} - # multiprocessing 'spawn' re-opens __main__.__file__ in the child. From a - # notebook / `python -` / `python -c` / unguarded top-level script that - # file is missing or a pseudo-path like '', so proc.start() raises - # before the child runs and the download never happens. Point __main__ at - # this importable module (no top-level download side effects) just for the - # spawn, then restore it. + # multiprocessing 'spawn' reconstructs __main__ in the child from + # __main__.__file__. If that is a pseudo-path ('', a notebook) the + # child fails to start; if it is a real but UNGUARDED caller script the + # child re-imports it as __mp_main__ and re-runs the top-level + # from_pretrained/download, hitting the "start a new process before + # bootstrapping" error -> the parent then sees the child exit without a + # result. In every case we only need the child to unpickle and run + # _download_child_entry, so point __main__ at THIS importable, side-effect + # -free module for the spawn (and restore it after). The child imports us + # as __mp_main__ instead of re-executing the caller's script. 
main_module = sys.modules.get("__main__") saved_main_file = _UNSET if main_module is not None: - main_file = getattr(main_module, "__file__", None) - if not main_file or str(main_file).startswith("<"): - saved_main_file = main_file - main_module.__file__ = __file__ + saved_main_file = getattr(main_module, "__file__", _UNSET) + main_module.__file__ = __file__ try: os.environ.update(child_env) proc.start() @@ -491,8 +501,9 @@ def _run_download_attempt( os.environ.pop(k, None) else: os.environ[k] = v - if saved_main_file is not _UNSET: - if saved_main_file is None: + if main_module is not None: + if saved_main_file is _UNSET: + # __file__ was absent before; remove the one we added. try: delattr(main_module, "__file__") except AttributeError: @@ -594,6 +605,16 @@ def _download_with_xet_fallback( prepare_for_http_fn(repo_type, repo_id) except Exception as e: logger.debug("prepare_for_http failed for %s: %s", repo_id, e) + # If an unsafe partial could not be cleared (e.g. a locked file or a + # permission error), an HTTP resume over a sparse Xet/hf_transfer + # partial would silently corrupt the blob. Force a clean re-download + # for this HTTP attempt instead of resuming over it. + if has_active_incomplete_blobs(repo_type, repo_id, cache_dir = cache_dir): + logger.warning( + "Unsafe partial for '%s' could not be cleared; forcing a clean " + "HTTP re-download instead of an unsafe resume.", label + ) + params = {**params, "force_download": True} kind_result, payload = _run_download_attempt( repo_id, @@ -621,8 +642,10 @@ def _download_with_xet_fallback( logger.warning( "Download stalled for '%s' -- retrying with HF_HUB_DISABLE_XET=1", label ) - if on_status is not None: - on_status(f"{label}: Xet stalled, retrying over HTTP") + # _safe_status: a raising status hook (e.g. a disconnected client) must + # not abort the retry before disable_xet is set, turning a recoverable + # stall into a failed download. 
+ _safe_status(on_status, f"{label}: Xet stalled, retrying over HTTP") disable_xet = True continue raise DownloadStallError( @@ -637,7 +660,7 @@ def _download_with_xet_fallback( def hf_hub_download_with_xet_fallback( repo_id: str, filename: str, - token: Optional[str], + token: Optional[str] = None, *, cancel_event: Optional[threading.Event] = None, repo_type: Optional[str] = "model", @@ -727,7 +750,7 @@ def snapshot_download_with_xet_fallback( try: from huggingface_hub import snapshot_download - return snapshot_download( + cached_dir = snapshot_download( repo_id = repo_id, repo_type = repo_type, revision = revision, @@ -736,6 +759,14 @@ def snapshot_download_with_xet_fallback( ignore_patterns = ignore_patterns, local_files_only = True, ) + # local_files_only returns a snapshot dir whenever refs/ and + # snapshots/ exist, even if a prior download was interrupted and + # left .incomplete blobs or broken symlinks. Only short-circuit when + # the cache is actually clean; otherwise complete it in the killable + # child so the in-process load does not proceed with missing files. + if not has_active_incomplete_blobs(repo_type, repo_id, cache_dir = cache_dir): + return cached_dir + logger.debug("Cached snapshot for %s has incomplete state; downloading.", repo_id) except Exception as e: logger.debug("Snapshot not fully cached for %s (%s); downloading.", repo_id, e) From d53ab16498a8625936a5a49707cdc35af83d8910 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 25 Jun 2026 15:23:09 +0000 Subject: [PATCH 10/82] Validate the requested snapshot revision, not just the latest The snapshot fast path and the incomplete-state detector only inspected the newest snapshot by mtime. A caller requesting an older revision whose snapshot had a dangling symlink (an interrupted download) while a newer revision was clean would read the repo as healthy and load with missing files. 
- hf_cache_state.py: add snapshot_dir_has_broken_symlinks(snapshot_dir) and make the broken-symlink detector iterate every snapshot, not just the latest. - hf_xet_fallback.py: validate the exact returned revision dir in the local_files_only fast path; clear broken symlinks across all snapshots during HTTP prep; write the forced-stall test partial under the caller's cache_dir and the repo_type-correct dir name so the watchdog actually sees it. - tests: add regression coverage for the per-snapshot primitive, broken older snapshot detection, fast-path rejection of a broken requested revision, and all-snapshot symlink cleanup. --- tests/test_hf_xet_fallback.py | 70 +++++++++++++++++++++++++++++++++- unsloth_zoo/hf_cache_state.py | 34 ++++++++++++++--- unsloth_zoo/hf_xet_fallback.py | 36 +++++++++++++---- 3 files changed, 126 insertions(+), 14 deletions(-) diff --git a/tests/test_hf_xet_fallback.py b/tests/test_hf_xet_fallback.py index 82c50f4ee..3fbcdfa7f 100644 --- a/tests/test_hf_xet_fallback.py +++ b/tests/test_hf_xet_fallback.py @@ -44,7 +44,7 @@ def _load(name: str, filename: str): _pkg.__path__ = [str(_ZOO_DIR)] sys.modules["unsloth_zoo"] = _pkg -_load("unsloth_zoo.hf_cache_state", "hf_cache_state.py") +hcs = _load("unsloth_zoo.hf_cache_state", "hf_cache_state.py") xf = _load("unsloth_zoo.hf_xet_fallback", "hf_xet_fallback.py") # Real prep impl, captured before the autouse fixture stubs the module attribute. 
@@ -239,6 +239,74 @@ def test_prepare_for_http_clears_broken_snapshot_symlink(tmp_path): assert xf.get_hf_download_state([repo], cache_dir = str(tmp_path)) == (0, False) +def test_snapshot_dir_has_broken_symlinks_unit(tmp_path): + """The new per-snapshot primitive flags a dangling link and is clean otherwise.""" + snap = tmp_path / "snapshots" / "sha" + snap.mkdir(parents = True) + good = snap / "config.json" + good.write_text("{}") + assert hcs.snapshot_dir_has_broken_symlinks(snap) is False + (snap / "model.safetensors").symlink_to(tmp_path / "blobs" / "missing") + assert hcs.snapshot_dir_has_broken_symlinks(snap) is True + + +def test_broken_older_snapshot_detected_when_newer_is_clean(tmp_path): + """Detector must inspect every snapshot, not just the newest by mtime: an older + revision with a dangling symlink must read as incomplete even when a more + recently landed snapshot is fully present.""" + repo = "ztest/two-snaps" + repo_dir = tmp_path / f"models--{repo.replace('/', '--')}" + old = repo_dir / "snapshots" / "oldsha" + new = repo_dir / "snapshots" / "newsha" + old.mkdir(parents = True) + new.mkdir(parents = True) + # Broken (older) revision; clean (newer) revision. + (old / "model.safetensors").symlink_to(repo_dir / "blobs" / "missing") + (new / "config.json").write_text("{}") + # Make the clean snapshot the newest by mtime so a latest-only check would + # report the repo healthy. + os.utime(new, (time.time() + 10, time.time() + 10)) + assert xf.get_hf_download_state([repo], cache_dir = str(tmp_path)) == (0, True) + + +def test_snapshot_fast_path_rejects_broken_requested_revision(tmp_path, monkeypatch): + """snapshot_download(local_files_only=True) can hand back an older requested + revision whose snapshot is broken while the repo-wide scan is clean. 
The fast + path must validate the EXACT returned dir and complete in the killable child + rather than short-circuiting to a snapshot with missing files.""" + snap = tmp_path / "snapshots" / "oldsha" + snap.mkdir(parents = True) + (snap / "model.safetensors").symlink_to(tmp_path / "blobs" / "missing") # dangling + monkeypatch.setattr(huggingface_hub, "snapshot_download", lambda *a, **k: str(snap)) + # Repo-wide incomplete-blob scan sees nothing (empty cache root), so only the + # per-revision symlink check can catch the broken returned dir. + monkeypatch.setattr(hf_constants, "HF_HUB_CACHE", str(tmp_path / "empty-cache")) + fake = _install(monkeypatch, [("ok", "/cache/snap-fresh")]) + out = xf.snapshot_download_with_xet_fallback(DL_REPO, token = None) + assert out == "/cache/snap-fresh", "fast path returned a broken requested revision" + assert len(fake.calls) == 1 + + +def test_prepare_for_http_clears_broken_symlink_in_older_snapshot(tmp_path): + """HTTP prep must clear dangling links across all snapshots, not just the + newest, so the incomplete detector reads clean afterwards.""" + repo = "ztest/old-broken" + repo_dir = tmp_path / f"models--{repo.replace('/', '--')}" + old = repo_dir / "snapshots" / "oldsha" + new = repo_dir / "snapshots" / "newsha" + old.mkdir(parents = True) + new.mkdir(parents = True) + link = old / "model.safetensors" + link.symlink_to(repo_dir / "blobs" / "missing") # dangling, older snapshot + (new / "config.json").write_text("{}") + os.utime(new, (time.time() + 10, time.time() + 10)) # newer snapshot is clean + + assert xf.get_hf_download_state([repo], cache_dir = str(tmp_path)) == (0, True) + _REAL_DEFAULT_PREPARE("model", repo, cache_dir = str(tmp_path)) + assert not link.is_symlink(), "broken symlink in older snapshot not cleared" + assert xf.get_hf_download_state([repo], cache_dir = str(tmp_path)) == (0, False) + + def test_prepare_for_http_preserves_case_colliding_repo(tmp_path): """On a case-sensitive filesystem, preparing HTTP for 
``Org/Repo`` must purge only its exact-case cache dir, never a case-colliding ``org/repo``.""" diff --git a/unsloth_zoo/hf_cache_state.py b/unsloth_zoo/hf_cache_state.py index 12f26e28d..18c5701ef 100644 --- a/unsloth_zoo/hf_cache_state.py +++ b/unsloth_zoo/hf_cache_state.py @@ -35,6 +35,7 @@ "repo_cache_dir_name", "blob_bytes_present", "latest_snapshot_dir", + "snapshot_dir_has_broken_symlinks", "iter_active_repo_cache_dirs", "repo_cache_dir_has_incomplete_blobs", "has_active_incomplete_blobs", @@ -163,12 +164,13 @@ def latest_snapshot_dir(repo_dir: Path) -> Optional[Path]: return None -def _repo_dir_has_broken_snapshot_symlinks(repo_dir: Path) -> bool: - latest = latest_snapshot_dir(repo_dir) - if latest is None: - return False +def snapshot_dir_has_broken_symlinks(snapshot_dir: Path) -> bool: + """True if *snapshot_dir* contains a dangling symlink -- a file the snapshot + references whose blob is missing or still an ``.incomplete`` partial, i.e. an + interrupted download. Used to validate one specific (caller-requested) + revision, not just the newest one on disk.""" try: - for entry in latest.rglob("*"): + for entry in snapshot_dir.rglob("*"): if entry.is_symlink() and not entry.exists(): return True except OSError: @@ -176,6 +178,28 @@ def _repo_dir_has_broken_snapshot_symlinks(repo_dir: Path) -> bool: return False +def _iter_snapshot_dirs(repo_dir: Path) -> Iterator[Path]: + snapshots_dir = repo_dir / "snapshots" + try: + if not snapshots_dir.is_dir(): + return + children = [entry for entry in snapshots_dir.iterdir() if entry.is_dir()] + except OSError: + return + yield from children + + +def _repo_dir_has_broken_snapshot_symlinks(repo_dir: Path) -> bool: + # Check every snapshot, not just the newest by mtime: a caller may request an + # older revision whose snapshot is broken while a more recent one is clean, so + # a latest-only check would report the repo healthy and let the interrupted + # revision load with missing files. 
+ return any( + snapshot_dir_has_broken_symlinks(snapshot) + for snapshot in _iter_snapshot_dirs(repo_dir) + ) + + def _case_safe_repo_cache_dirs(root: Path, repo_type: Optional[str], repo_id: str) -> list: """Cache dirs that can be safely attributed to this exact repo id. diff --git a/unsloth_zoo/hf_xet_fallback.py b/unsloth_zoo/hf_xet_fallback.py index 08eb9725f..587fd6647 100644 --- a/unsloth_zoo/hf_xet_fallback.py +++ b/unsloth_zoo/hf_xet_fallback.py @@ -43,6 +43,7 @@ import sys import threading import time +from pathlib import Path from typing import Any, Callable, Optional from unsloth_zoo.hf_cache_state import ( @@ -51,7 +52,7 @@ has_active_incomplete_blobs, hf_cache_root, iter_active_repo_cache_dirs, - latest_snapshot_dir, + snapshot_dir_has_broken_symlinks, ) logger = logging.getLogger(__name__) @@ -167,18 +168,25 @@ def _default_prepare_for_http( continue # repo_cache_dir_has_incomplete_blobs() also flags a broken snapshot # symlink as active incomplete state; clear those too so the detector - # reads clean after prep. - latest = latest_snapshot_dir(entry) - if latest is not None: + # reads clean after prep. Sweep EVERY snapshot, not just the newest: + # the broken-symlink detector now inspects all of them, so a stale + # dangling link under an older revision would otherwise keep the repo + # marked incomplete after prep and re-trip the watchdog. 
+ snapshots_dir = entry / "snapshots" + try: + snapshot_dirs = [s for s in snapshots_dir.iterdir() if s.is_dir()] + except OSError: + snapshot_dirs = [] + for snapshot in snapshot_dirs: try: - for link in latest.rglob("*"): + for link in snapshot.rglob("*"): if link.is_symlink() and not link.exists(): try: link.unlink() except OSError: continue except OSError: - pass + continue except Exception as e: logger.debug("default prepare_for_http failed for %s: %s", repo_id, e) @@ -379,7 +387,13 @@ def _download_child_entry( try: from huggingface_hub.constants import HF_HUB_CACHE - blobs = os.path.join(HF_HUB_CACHE, "models--" + repo_id.replace("/", "--"), "blobs") + # Write the fake partial under the SAME cache the watchdog scans + # (params["cache_dir"] when the caller set one, else HF_HUB_CACHE) and + # under the repo_type-correct dir name, so has_active_incomplete_blobs + # sees it and the stall/HTTP fallback actually fires in tests. + cache_root = params.get("cache_dir") or HF_HUB_CACHE + repo_dir_name = f"{repo_type or 'model'}s--" + repo_id.replace("/", "--") + blobs = os.path.join(cache_root, repo_dir_name, "blobs") os.makedirs(blobs, exist_ok = True) with open(os.path.join(blobs, "xet-force-stall.incomplete"), "wb") as fh: fh.write(b"\0" * 4096) @@ -764,7 +778,13 @@ def snapshot_download_with_xet_fallback( # left .incomplete blobs or broken symlinks. Only short-circuit when # the cache is actually clean; otherwise complete it in the killable # child so the in-process load does not proceed with missing files. - if not has_active_incomplete_blobs(repo_type, repo_id, cache_dir = cache_dir): + # Validate the EXACT returned revision dir (snapshot_download may hand + # back an older requested revision while a newer one is clean), plus + # the repo-wide .incomplete blob check. 
+ if ( + not snapshot_dir_has_broken_symlinks(Path(cached_dir)) + and not has_active_incomplete_blobs(repo_type, repo_id, cache_dir = cache_dir) + ): return cached_dir logger.debug("Cached snapshot for %s has incomplete state; downloading.", repo_id) except Exception as e: From 805091877a5fec5a1f78f02214c1783da3921a36 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Fri, 26 Jun 2026 03:40:04 +0000 Subject: [PATCH 11/82] Keep the stub-injection gate on _SKIP_GPU_INIT; kill the download child on any exit The light-import env-gate change replaced the stub-injection gate `if _SKIP_GPU_INIT:` with `if _IS_MLX:`, which regressed the pinned-symbol guard test_apple_silicon_stub_injection_entrypoints_pinned (it requires the positive _SKIP_GPU_INIT gate, sub-bug (a) of commit 2053539). Restore the original gate and drop _IS_MLX: injecting the triton/bitsandbytes stubs whenever GPU init is skipped is the maintainer's invariant, and the download child never imports those modules, so the stubs are inert there. The MLX path is unchanged. Also harden _run_download_attempt: the finally only joined the child, so an unexpected exception (e.g. KeyboardInterrupt) could leak a live download process. Terminate the whole process group if it is still alive after the grace join. _terminate_process_group is idempotent, so the redundant call after the cancel/stall branch is a no-op. --- unsloth_zoo/__init__.py | 13 ++++++------- unsloth_zoo/hf_xet_fallback.py | 7 +++++++ 2 files changed, 13 insertions(+), 7 deletions(-) diff --git a/unsloth_zoo/__init__.py b/unsloth_zoo/__init__.py index f86800e75..6c0f1d93e 100644 --- a/unsloth_zoo/__init__.py +++ b/unsloth_zoo/__init__.py @@ -132,10 +132,8 @@ def has_429_exact_full_read(log_dir: str | Path) -> str: ALLOW_PREQUANTIZED_MODELS = True del _is_mlx_only, is_mlx_available, find_spec # Everything below this point is GPU-only. Use a flag to gate it. - _IS_MLX = True _SKIP_GPU_INIT = True else: - _IS_MLX = False # Opt-in lightweight import. 
A short-lived helper subprocess that only needs # the cache/download utilities (e.g. the unsloth_zoo.hf_xet_fallback download # child) can set UNSLOTH_ZOO_DISABLE_GPU_INIT=1 to skip the heavy torch / @@ -147,16 +145,17 @@ def has_429_exact_full_read(log_dir: str | Path) -> str: _SKIP_GPU_INIT = os.environ.get("UNSLOTH_ZOO_DISABLE_GPU_INIT", "0") == "1" del _is_mlx_only, is_mlx_available -# Inject triton & bitsandbytes stubs only on Apple Silicon with MLX so unsloth's -# CUDA-only imports don't error at startup (the generic light-import path never -# imports triton/bitsandbytes, so it must not mask the real modules). -if _IS_MLX: +# Inject triton & bitsandbytes stubs whenever GPU init is skipped: Apple Silicon +# with MLX (torch/CUDA absent), or the opt-in light-import download child. unsloth's +# CUDA-only imports then resolve to a loud no-op stub instead of a hard ImportError; +# the stub is never touched by the cache/download-only child, so it is inert there. +# On a normal CUDA/CPU run _SKIP_GPU_INIT is False and the real modules are untouched. +if _SKIP_GPU_INIT: from .stubs.triton_stub import inject_into_sys_modules as _inject_triton _inject_triton() from .stubs.bitsandbytes_stub import inject_into_sys_modules as _inject_bnb _inject_bnb() del _inject_triton, _inject_bnb -del _IS_MLX # Lazy bridge for downstream code that still imports the old flat MLX module # names. Installed on every host so external scripts don't hit a hard diff --git a/unsloth_zoo/hf_xet_fallback.py b/unsloth_zoo/hf_xet_fallback.py index 587fd6647..d85dcbaf0 100644 --- a/unsloth_zoo/hf_xet_fallback.py +++ b/unsloth_zoo/hf_xet_fallback.py @@ -571,6 +571,13 @@ def _run_download_attempt( finally: stop_watchdog.set() proc.join(timeout = grace_period) + # Any exit from the loop -- normal completion, cancel/stall, or an + # unexpected exception (e.g. KeyboardInterrupt) -- must not leak the child. + # If it is still alive after the grace join, kill its whole process group. 
+ # _terminate_process_group is idempotent, so a redundant call after the + # cancel/stall branch already terminated it is a harmless no-op. + if proc.is_alive(): + _terminate_process_group(proc, grace_period) if result is None: return ( From 36d19dd4963651c66b32b74d083f542597d5940e Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Fri, 26 Jun 2026 04:46:04 +0000 Subject: [PATCH 12/82] Harden spawn, cache attribution, offline, and secret scrubbing (Codex P2s) - spawn: also save/clear/restore __main__.__spec__, not just __file__. When the caller is launched as a module (python -m pkg), multiprocessing's spawn prep prefers __spec__.name over __file__ and re-imports the user's module by name, re-running its top-level from_pretrained in the child and hitting the bootstrapping error. Clearing __spec__ forces the path branch onto the repointed helper module. - hf_cache_state: only attribute a single folded-but-not-exact cache dir to a differently-cased repo when the filesystem is case-insensitive (the exact-case lookup resolves to the same entry). On a case-sensitive FS models--Org--Repo and models--org--repo are distinct repos, so a folded-only match is no longer charged here. - offline: add local_files_only to hf_hub_download_with_xet_fallback and snapshot_download_with_xet_fallback. When set, resolve from cache in-process and never spawn a network child, matching Hugging Face offline semantics. - cache_dir: expand ~ before the single-file cache probe (HF expands it before writing), so a finalized file under ~/hf-cache short-circuits instead of spawning a child; the expanded value is reused for the attempt. - scrub: redact the query string of presigned S3/CAS download URLs (X-Amz-Signature, sig, token, ...) in child error text so temporary credentials are not echoed to the parent and logged. Non-signed URLs keep their query. 
- tests: +5 regression tests (presigned-URL redaction, local_files_only file and snapshot no-child, expanded cache_dir probe, case-sensitive folded rejection). --- tests/test_hf_xet_fallback.py | 86 ++++++++++++++++++++++++++++++++++ unsloth_zoo/hf_cache_state.py | 22 ++++++--- unsloth_zoo/hf_xet_fallback.py | 73 +++++++++++++++++++++++++++++ 3 files changed, 175 insertions(+), 6 deletions(-) diff --git a/tests/test_hf_xet_fallback.py b/tests/test_hf_xet_fallback.py index 3fbcdfa7f..85fe1d412 100644 --- a/tests/test_hf_xet_fallback.py +++ b/tests/test_hf_xet_fallback.py @@ -355,6 +355,23 @@ def test_state_ignores_case_colliding_repo_partial(tmp_path, monkeypatch): assert xf.get_hf_download_state(["Org/Repo"]) == (0, False) +def test_single_folded_match_rejected_on_case_sensitive_fs(tmp_path, monkeypatch): + """A single folded-but-not-exact cache dir must not be attributed to a + differently-cased repo on a case-sensitive filesystem -- it is a different + repo, and charging its partial here could misread the watchdog or let HTTP-prep + delete it. Only an exact-case dir (or a folded dir the FS resolves to the same + entry on a case-insensitive FS) counts.""" + monkeypatch.setattr(hf_constants, "HF_HUB_CACHE", str(tmp_path)) + lower = tmp_path / "models--org--repo" / "blobs" + lower.mkdir(parents = True) + if (tmp_path / "models--Org--Repo").exists(): + pytest.skip("case-insensitive filesystem; the folded dir is the same entry") + (lower / "stale.incomplete").write_bytes(b"x") # only the lowercase repo exists + # Request the exact-case repo, which has no dir of its own: the lowercase repo's + # partial must not be attributed to it. + assert xf.get_hf_download_state(["Org/Repo"]) == (0, False) + + def test_cache_dir_is_expanded(tmp_path, monkeypatch): """A custom cache_dir with ~ must be expanded (as HF does on write), else the state probe scans the literal '~/...' 
path and misses the partial.""" @@ -697,6 +714,75 @@ def test_scrub_secrets_handles_boolean_token(): assert "hf_abcdefghij" not in out and "***" in out +def test_scrub_redacts_presigned_url(): + """A presigned S3/CAS blob URL in a child error carries temporary credentials in + its query string; the default scrubber must redact the query before it is + raised/logged in the parent, while leaving non-signed URLs intact.""" + url = ( + "https://cas-bridge.xethub.hf.co/xet-bridge-us/abc/def" + "?X-Amz-Signature=deadbeefcafe&X-Amz-Credential=AKIAEXAMPLE123" + ) + out = xf._default_scrub_secrets(f"403 Client Error for url: {url}") + assert "X-Amz-Signature" not in out + assert "deadbeefcafe" not in out and "AKIAEXAMPLE123" not in out + assert "cas-bridge.xethub.hf.co/xet-bridge-us/abc/def?***" in out + # A non-signed URL keeps its (harmless) query string. + plain = xf._default_scrub_secrets("see https://huggingface.co/org/repo?download=true now") + assert "download=true" in plain + + +def test_local_files_only_file_resolves_in_process(monkeypatch): + """local_files_only resolves the single file from cache in-process and never + spawns a network child (Hugging Face offline semantics).""" + seen = {} + + def _dl(*a, **k): + seen.update(k) + return "/cache/file.gguf" + + monkeypatch.setattr(huggingface_hub, "hf_hub_download", _dl) + fake = _install(monkeypatch, []) # the download seam must not be called + out = xf.hf_hub_download_with_xet_fallback(DL_REPO, FILE, None, local_files_only = True) + assert out == "/cache/file.gguf" + assert seen.get("local_files_only") is True + assert fake.calls == [], "local_files_only must not spawn a download child" + + +def test_local_files_only_snapshot_resolves_in_process(monkeypatch): + seen = {} + + def _snap(*a, **k): + seen.update(k) + return "/cache/snap" + + monkeypatch.setattr(huggingface_hub, "snapshot_download", _snap) + fake = _install(monkeypatch, []) + out = xf.snapshot_download_with_xet_fallback(DL_REPO, token = None, 
local_files_only = True) + assert out == "/cache/snap" + assert seen.get("local_files_only") is True + assert fake.calls == [], "local_files_only must not spawn a download child" + + +def test_file_probe_uses_expanded_cache_dir(monkeypatch, tmp_path): + """The single-file cache probe must use the expanded cache_dir (HF expands ~ + before writing), or a finalized file under ~/hf-cache is missed and a child is + spawned for an already-cached file.""" + monkeypatch.setenv("HOME", str(tmp_path)) + monkeypatch.setenv("USERPROFILE", str(tmp_path)) # Windows home var + seen = {} + + def _probe(repo_id, filename, *, repo_type, revision, cache_dir): + seen["cache_dir"] = cache_dir + return None # not cached -> falls through to the (faked) download seam + + monkeypatch.setattr(huggingface_hub, "try_to_load_from_cache", _probe) + fake = _install(monkeypatch, [("ok", "/cache/x")]) + xf.hf_hub_download_with_xet_fallback(DL_REPO, FILE, None, cache_dir = "~/hfcache") + assert seen["cache_dir"] == str(tmp_path / "hfcache") + # The expanded cache_dir is also what the download attempt receives. + assert fake.calls[0].cache_dir == str(tmp_path / "hfcache") + + def test_file_download_defaults_token_to_none(monkeypatch): """The single-file helper accepts no token (parity with hf_hub_download).""" fake = _install(monkeypatch, [("ok", "/cache/x")]) diff --git a/unsloth_zoo/hf_cache_state.py b/unsloth_zoo/hf_cache_state.py index 18c5701ef..5b59b1a7b 100644 --- a/unsloth_zoo/hf_cache_state.py +++ b/unsloth_zoo/hf_cache_state.py @@ -206,10 +206,11 @@ def _case_safe_repo_cache_dirs(root: Path, repo_type: Optional[str], repo_id: st The cache dir name is case-folded by the Hub, so a case-insensitive match is needed for compatibility, but a bare case-insensitive match is unsafe: on a case-sensitive filesystem ``models--Org--Repo`` and ``models--org--repo`` are - distinct repos. 
Prefer an exact-case match; otherwise accept a single - unambiguous folded match; on a 2+ way collision attribute to neither, so a - stale partial in one repo cannot be charged to the other (which would let the - watchdog kill an unrelated active download or HTTP-prep purge the wrong repo). + distinct repos. Prefer an exact-case match; otherwise accept a single folded + match ONLY when the filesystem is case-insensitive (so the folded dir really is + the same entry); on a 2+ way collision attribute to neither, so a stale partial + in one repo cannot be charged to the other (which would let the watchdog kill an + unrelated active download or HTTP-prep purge the wrong repo). """ target = repo_cache_dir_name(repo_type, repo_id) folded_target = target.lower() @@ -220,8 +221,17 @@ def _case_safe_repo_cache_dirs(root: Path, repo_type: Optional[str], repo_id: st exact = [entry for entry in entries if entry.name == target] if exact: return exact - if len(entries) <= 1: - return entries + if len(entries) == 1: + # A single folded-but-not-exact match. Attribute it to this repo only when + # the filesystem is case-insensitive: looking up the exact-case name then + # resolves to that same directory. On a case-sensitive filesystem the + # exact-case path does not exist, so the folded dir is a DIFFERENT repo and + # must not be charged here. 
+ try: + if (root / target).exists(): + return entries + except OSError: + return [] return [] diff --git a/unsloth_zoo/hf_xet_fallback.py b/unsloth_zoo/hf_xet_fallback.py index d85dcbaf0..29498b768 100644 --- a/unsloth_zoo/hf_xet_fallback.py +++ b/unsloth_zoo/hf_xet_fallback.py @@ -138,6 +138,20 @@ def _default_scrub_secrets(text: str, hf_token: Optional[str] = None) -> str: out = out.replace(hf_token, "***") out = re.sub(r"hf_[A-Za-z0-9]{8,}", "***", out) out = re.sub(r"([Bb]earer\s+)[A-Za-z0-9._\-]+", r"\1***", out) + # HF download errors can embed the presigned S3/CAS blob URL, whose query + # string carries temporary credentials (X-Amz-Signature, sig, token, ...). + # Redact the query of any URL that looks signed so it is not echoed back to + # the parent and logged. Non-signed URLs (e.g. ...?download=true) are kept. + def _redact_signed_query(match: "re.Match") -> str: + base, query = match.group(1), match.group(2) + if re.search( + r"(X-Amz-|[Ss]ignature|(?:^|&)(?:sig|token|key|Expires|Policy|Key-Pair-Id)=)", + query, + ): + return f"{base}?***" + return match.group(0) + + out = re.sub(r"(https?://[^\s?]+)\?([^\s]*)", _redact_signed_query, out) return out @@ -503,9 +517,18 @@ def _run_download_attempt( # as __mp_main__ instead of re-executing the caller's script. main_module = sys.modules.get("__main__") saved_main_file = _UNSET + saved_main_spec = _UNSET if main_module is not None: saved_main_file = getattr(main_module, "__file__", _UNSET) main_module.__file__ = __file__ + # When the caller was launched as a module (python -m pkg), spawn's + # preparation prefers __main__.__spec__.name over __file__ and re-imports + # the user's module BY NAME -> re-runs its top-level from_pretrained in + # the child and hits the bootstrapping error. Clearing __spec__ forces + # the path branch, which uses the __file__ we just repointed at this + # side-effect-free helper module. 
+ saved_main_spec = getattr(main_module, "__spec__", _UNSET) + main_module.__spec__ = None try: os.environ.update(child_env) proc.start() @@ -524,6 +547,13 @@ def _run_download_attempt( pass else: main_module.__file__ = saved_main_file + if saved_main_spec is _UNSET: + try: + delattr(main_module, "__spec__") + except AttributeError: + pass + else: + main_module.__spec__ = saved_main_spec # Bind the child to the parent lifetime when running under Studio (best-effort). try: @@ -688,6 +718,7 @@ def hf_hub_download_with_xet_fallback( revision: Optional[str] = None, cache_dir: Optional[str] = None, force_download: bool = False, + local_files_only: bool = False, stall_timeout: float = DEFAULT_STALL_TIMEOUT, interval: float = DEFAULT_HEARTBEAT_INTERVAL, grace_period: float = DEFAULT_GRACE_PERIOD, @@ -700,8 +731,29 @@ def hf_hub_download_with_xet_fallback( *cancel_event* is set, re-raises a deterministic child error unchanged (no fallback), and raises ``DownloadStallError`` only if BOTH transports stall. ``force_download=True`` re-fetches even if cached (skips the cache short-circuit). + ``local_files_only=True`` resolves from cache in-process and never spawns a + network child (matching Hugging Face offline semantics). """ repo_type = repo_type or "model" # HF treats None as the default model repo. + # Expand ~ as huggingface_hub does before writing, so the cache probe below and + # the child both resolve to the same on-disk location (else a warm ~/hf-cache + # is missed and we spawn a child for an already-cached file). + if isinstance(cache_dir, str): + cache_dir = os.path.expanduser(cache_dir) + # Offline: resolve purely from the local cache, never reaching the network. HF + # raises LocalEntryNotFoundError if it is not cached; let that propagate. 
+ if local_files_only: + from huggingface_hub import hf_hub_download + + return hf_hub_download( + repo_id = repo_id, + filename = filename, + token = token, + repo_type = repo_type, + revision = revision, + cache_dir = cache_dir, + local_files_only = True, + ) # Finalized blob already cached: return it with no child and no network # (skipped when force_download re-fetches unconditionally). if not force_download: @@ -748,6 +800,7 @@ def snapshot_download_with_xet_fallback( allow_patterns: Optional[Any] = None, ignore_patterns: Optional[Any] = None, force_download: bool = False, + local_files_only: bool = False, cancel_event: Optional[threading.Event] = None, stall_timeout: float = DEFAULT_STALL_TIMEOUT, interval: float = DEFAULT_HEARTBEAT_INTERVAL, @@ -763,8 +816,28 @@ def snapshot_download_with_xet_fallback( hang on a native Xet thread). A fully cached repo short-circuits in-process via ``local_files_only`` with no child and no network. ``force_download=True`` re-fetches in the killable child even if cached (skips that short-circuit). + ``local_files_only=True`` resolves from cache in-process and never spawns a + network child (matching Hugging Face offline semantics). """ repo_type = repo_type or "model" # HF treats None as the default model repo. + # Expand ~ as huggingface_hub does before writing, so the probe and the child + # resolve to the same on-disk cache location. + if isinstance(cache_dir, str): + cache_dir = os.path.expanduser(cache_dir) + # Offline: resolve purely from the local cache, never reaching the network. HF + # raises if the snapshot is not cached; let that propagate. + if local_files_only: + from huggingface_hub import snapshot_download + + return snapshot_download( + repo_id = repo_id, + repo_type = repo_type, + revision = revision, + cache_dir = cache_dir, + allow_patterns = allow_patterns, + ignore_patterns = ignore_patterns, + local_files_only = True, + ) # Fast path: everything already on disk -> resolve in-process (no Xet, no # hang). 
Skipped when force_download re-fetches unconditionally. if not force_download: From cb73a98933154461698a827d0e58bef316e43d6d Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Fri, 26 Jun 2026 05:12:50 +0000 Subject: [PATCH 13/82] Normalize path-like cache dirs and forward subfolder (Codex P2) - cache_dir: normalize pathlib.Path (not just str) via os.fspath before the cache probe and the child attempt. HF accepts Path cache dirs; an unexpanded Path("~/hf-cache") made the probe miss a warm file and the child write its .incomplete under the literal ~ while the watchdog watched $HOME, so the stall was never detected. Applies to both the file and snapshot helpers. - subfolder: add subfolder to hf_hub_download_with_xet_fallback and forward it to the child hf_hub_download, the local_files_only resolve, and (as the combined "/" path) the cache probe. A caller passing subfolder no longer hits TypeError. - tests: +2 (pathlib cache_dir normalization, subfolder forwarding). --- tests/test_hf_xet_fallback.py | 35 ++++++++++++++++++++++++++++++++++ unsloth_zoo/hf_xet_fallback.py | 24 +++++++++++++++-------- 2 files changed, 51 insertions(+), 8 deletions(-) diff --git a/tests/test_hf_xet_fallback.py b/tests/test_hf_xet_fallback.py index 85fe1d412..bcc423da9 100644 --- a/tests/test_hf_xet_fallback.py +++ b/tests/test_hf_xet_fallback.py @@ -464,6 +464,7 @@ def __call__( kind = kind, target = params.get("filename", repo_id), cache_dir = params.get("cache_dir"), + subfolder = params.get("subfolder"), force_download = params.get("force_download"), disable_xet = disable_xet, repo_type = repo_type, @@ -783,6 +784,40 @@ def _probe(repo_id, filename, *, repo_type, revision, cache_dir): assert fake.calls[0].cache_dir == str(tmp_path / "hfcache") +def test_pathlib_cache_dir_is_expanded(monkeypatch, tmp_path): + """A pathlib.Path cache_dir with ~ must be normalized too (HF accepts Path), or + the child writes under the literal '~/...' while the watchdog watches $HOME/... 
+ and the stall is never detected.""" + monkeypatch.setenv("HOME", str(tmp_path)) + monkeypatch.setenv("USERPROFILE", str(tmp_path)) + fake = _install(monkeypatch, [("ok", "/cache/snap")]) + xf.snapshot_download_with_xet_fallback( + DL_REPO, token = None, cache_dir = Path("~/hfcache") + ) + # Normalized to an expanded string for the child attempt + probes. + assert fake.calls[0].cache_dir == str(tmp_path / "hfcache") + + +def test_subfolder_forwarded_to_file_download(monkeypatch): + """A single-file caller passing subfolder must not get a TypeError; subfolder + is forwarded into the download params (and the cache probe uses the combined + '/' path).""" + probed = {} + + def _probe(repo_id, filename, *, repo_type, revision, cache_dir): + probed["filename"] = filename + return None # not cached -> falls through to the faked attempt + + monkeypatch.setattr(huggingface_hub, "try_to_load_from_cache", _probe) + fake = _install(monkeypatch, [("ok", "/cache/x")]) + out = xf.hf_hub_download_with_xet_fallback( + DL_REPO, FILE, None, subfolder = "checkpoint-10" + ) + assert out == "/cache/x" + assert probed["filename"] == f"checkpoint-10/{FILE}" # probe uses combined path + assert fake.calls[0].subfolder == "checkpoint-10" # forwarded to the child + + def test_file_download_defaults_token_to_none(monkeypatch): """The single-file helper accepts no token (parity with hf_hub_download).""" fake = _install(monkeypatch, [("ok", "/cache/x")]) diff --git a/unsloth_zoo/hf_xet_fallback.py b/unsloth_zoo/hf_xet_fallback.py index 29498b768..e2b488704 100644 --- a/unsloth_zoo/hf_xet_fallback.py +++ b/unsloth_zoo/hf_xet_fallback.py @@ -348,6 +348,7 @@ def _child_download(*, kind: str, params: dict, token: Optional[str], repo_type: return hf_hub_download( repo_id = params["repo_id"], filename = params["filename"], + subfolder = params.get("subfolder"), repo_type = repo_type, token = token, revision = params.get("revision"), @@ -717,6 +718,7 @@ def hf_hub_download_with_xet_fallback( 
repo_type: Optional[str] = "model", revision: Optional[str] = None, cache_dir: Optional[str] = None, + subfolder: Optional[str] = None, force_download: bool = False, local_files_only: bool = False, stall_timeout: float = DEFAULT_STALL_TIMEOUT, @@ -732,14 +734,16 @@ def hf_hub_download_with_xet_fallback( fallback), and raises ``DownloadStallError`` only if BOTH transports stall. ``force_download=True`` re-fetches even if cached (skips the cache short-circuit). ``local_files_only=True`` resolves from cache in-process and never spawns a - network child (matching Hugging Face offline semantics). + network child (matching Hugging Face offline semantics). ``subfolder`` is + forwarded to ``hf_hub_download`` for files stored under a repo subdirectory. """ repo_type = repo_type or "model" # HF treats None as the default model repo. # Expand ~ as huggingface_hub does before writing, so the cache probe below and # the child both resolve to the same on-disk location (else a warm ~/hf-cache - # is missed and we spawn a child for an already-cached file). - if isinstance(cache_dir, str): - cache_dir = os.path.expanduser(cache_dir) + # is missed and we spawn a child for an already-cached file). Path-like cache + # dirs are normalized too, since HF accepts pathlib.Path. + if isinstance(cache_dir, (str, os.PathLike)): + cache_dir = os.path.expanduser(os.fspath(cache_dir)) # Offline: resolve purely from the local cache, never reaching the network. HF # raises LocalEntryNotFoundError if it is not cached; let that propagate. if local_files_only: @@ -748,6 +752,7 @@ def hf_hub_download_with_xet_fallback( return hf_hub_download( repo_id = repo_id, filename = filename, + subfolder = subfolder, token = token, repo_type = repo_type, revision = revision, @@ -755,13 +760,15 @@ def hf_hub_download_with_xet_fallback( local_files_only = True, ) # Finalized blob already cached: return it with no child and no network - # (skipped when force_download re-fetches unconditionally). 
+ # (skipped when force_download re-fetches unconditionally). The cache stores a + # subfolder file under "<subfolder>/<filename>", which is what the probe wants. if not force_download: try: from huggingface_hub import try_to_load_from_cache + probe_filename = f"{subfolder}/{filename}" if subfolder else filename cached = try_to_load_from_cache( - repo_id, filename, repo_type = repo_type, revision = revision, cache_dir = cache_dir + repo_id, probe_filename, repo_type = repo_type, revision = revision, cache_dir = cache_dir ) if isinstance(cached, str) and os.path.exists(cached): return cached @@ -775,6 +782,7 @@ def hf_hub_download_with_xet_fallback( params = { "repo_id": repo_id, "filename": filename, + "subfolder": subfolder, "revision": revision, "cache_dir": cache_dir, "force_download": force_download, @@ -822,8 +830,8 @@ def snapshot_download_with_xet_fallback( repo_type = repo_type or "model" # HF treats None as the default model repo. # Expand ~ as huggingface_hub does before writing, so the probe and the child # resolve to the same on-disk cache location. - if isinstance(cache_dir, str): - cache_dir = os.path.expanduser(cache_dir) + if isinstance(cache_dir, (str, os.PathLike)): + cache_dir = os.path.expanduser(os.fspath(cache_dir)) # Offline: resolve purely from the local cache, never reaching the network. HF # raises if the snapshot is not cached; let that propagate. if local_files_only: From 5d898dae5abd4b5a92e03a42691be9c0c741d733 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Fri, 26 Jun 2026 05:39:30 +0000 Subject: [PATCH 14/82] Count a fully-sparse blob as 0 bytes, not full size (Codex P2) blob_bytes_present gated the sparse-aware path on `st_blocks > 0`, so a freshly truncated .incomplete (st_size == full, 0 allocated blocks) fell through to st_size and was counted as fully present -- defeating the watchdog's sparse byte accounting for an empty partial. 
Trust st_blocks whenever it is reported (POSIX), even when 0; only fall back to st_size when the field is absent (Windows / some network filesystems). Adds a regression test for the 0-block case. --- tests/test_hf_xet_fallback.py | 15 +++++++++++++++ unsloth_zoo/hf_cache_state.py | 8 ++++++-- 2 files changed, 21 insertions(+), 2 deletions(-) diff --git a/tests/test_hf_xet_fallback.py b/tests/test_hf_xet_fallback.py index bcc423da9..c562f0e9a 100644 --- a/tests/test_hf_xet_fallback.py +++ b/tests/test_hf_xet_fallback.py @@ -183,6 +183,21 @@ def test_get_state_sparse_aware(hf_cache): assert total < st.st_size, "sparse partial counted at apparent size, not allocated blocks" +def test_blob_bytes_present_zero_blocks_is_zero(tmp_path): + """A freshly truncated, fully-sparse .incomplete reports st_size > 0 with 0 + allocated blocks; it must count as 0 bytes present, not full size (a > 0 guard + would mis-read an empty partial as complete).""" + p = tmp_path / "sparse.incomplete" + with open(p, "wb") as f: + f.truncate(8 * 1024 * 1024) # apparent 8 MiB, nothing actually written + st = p.stat() + if getattr(st, "st_blocks", None) is None: + pytest.skip("st_blocks not reported on this platform") + if st.st_blocks != 0: + pytest.skip("filesystem pre-allocated blocks for the sparse file") + assert hcs.blob_bytes_present(p) == 0 + + def test_custom_cache_dir_is_watched_and_cleaned(tmp_path, monkeypatch): """A stall under a caller-supplied snapshot ``cache_dir`` (not HF_HUB_CACHE) must still be seen by the state probe, the watchdog, and the HTTP-prep purge.""" diff --git a/unsloth_zoo/hf_cache_state.py b/unsloth_zoo/hf_cache_state.py index 5b59b1a7b..f25d7cffd 100644 --- a/unsloth_zoo/hf_cache_state.py +++ b/unsloth_zoo/hf_cache_state.py @@ -110,8 +110,12 @@ def blob_bytes_present(path: Path) -> int: ``st_blocks``, falling back to ``st_size`` where it is unreported (Windows, some network filesystems).""" st = path.stat() - blocks = getattr(st, "st_blocks", 0) - if blocks > 0: + 
blocks = getattr(st, "st_blocks", None) + if blocks is not None: + # st_blocks is reported (POSIX): trust it even when 0. A freshly truncated + # sparse .incomplete reports st_size == full but 0 allocated blocks, and + # must count as 0 bytes present, not full size (a > 0 guard would fall + # through to st_size and read an empty partial as complete). return min(blocks * 512, st.st_size) if sys.platform == "win32": allocated = _windows_allocated_size(path) From b46f8167d747764e1354e73e108897290bea61d5 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Fri, 26 Jun 2026 06:49:11 +0000 Subject: [PATCH 15/82] Scope the snapshot fast-path validation to the requested revision (Codex P2) snapshot_download(local_files_only=True) already validated the exact revision it returns, so the per-snapshot broken-symlink check is sufficient to prove that revision is complete. The extra repo-wide has_active_incomplete_blobs term rejected a clean cached snapshot whenever an unrelated revision in the same repo cache was mid-download (a stale .incomplete blob or a broken older snapshot), forcing a needless re-fetch in the killable child. Drop that term and rely on validating only the returned snapshot dir. Replace test_incomplete_cached_snapshot_not_short_circuited (which encoded the old over-broad behavior) with a positive test asserting a clean cached snapshot short-circuits in-process despite an unrelated stale partial. The interrupted requested-revision case stays covered by test_snapshot_fast_path_rejects_broken_requested_revision. 
--- tests/test_hf_xet_fallback.py | 24 +++++++++++++++++------- unsloth_zoo/hf_xet_fallback.py | 19 +++++++++---------- 2 files changed, 26 insertions(+), 17 deletions(-) diff --git a/tests/test_hf_xet_fallback.py b/tests/test_hf_xet_fallback.py index c562f0e9a..43637aa7b 100644 --- a/tests/test_hf_xet_fallback.py +++ b/tests/test_hf_xet_fallback.py @@ -840,14 +840,24 @@ def test_file_download_defaults_token_to_none(monkeypatch): assert out == "/cache/x" and len(fake.calls) == 1 -def test_incomplete_cached_snapshot_not_short_circuited(hf_cache, monkeypatch): - """A cached-but-incomplete snapshot (interrupted download) must not take the - fast path; it must complete in the killable child instead.""" - monkeypatch.setattr(huggingface_hub, "snapshot_download", lambda *a, **k: "/cache/snap") - (_blobs_dir(hf_cache, DL_REPO) / "x.incomplete").write_bytes(b"abc") # active partial - fake = _install(monkeypatch, [("ok", "/cache/snap-fresh")]) +def test_unrelated_partial_does_not_block_clean_cached_snapshot(hf_cache, monkeypatch): + """A clean requested snapshot must short-circuit in-process even when the same + repo cache holds a stale .incomplete from another (unrelated) revision: the fast + path validates only the returned snapshot dir, not the whole repo, so a sibling + mid-download does not force a needless re-fetch of a snapshot that is complete.""" + blobs = _blobs_dir(hf_cache, DL_REPO) + repo_dir = blobs.parent + snap = repo_dir / "snapshots" / "goodsha" + snap.mkdir(parents = True) + good = blobs / "good" + good.write_bytes(b"weights") + (snap / "model.safetensors").symlink_to(good) # resolves -> snapshot is clean + (blobs / "other.incomplete").write_bytes(b"abc") # unrelated stale partial + monkeypatch.setattr(huggingface_hub, "snapshot_download", lambda *a, **k: str(snap)) + fake = _install(monkeypatch, []) # must NOT spawn a child out = xf.snapshot_download_with_xet_fallback(DL_REPO, token = None) - assert out == "/cache/snap-fresh" and len(fake.calls) == 1 + 
assert out == str(snap), "clean cached snapshot rejected by an unrelated partial" + assert fake.calls == [], "spawned a download despite a clean requested snapshot" def test_retry_status_failure_does_not_abort_fallback(monkeypatch): diff --git a/unsloth_zoo/hf_xet_fallback.py b/unsloth_zoo/hf_xet_fallback.py index e2b488704..af55528c8 100644 --- a/unsloth_zoo/hf_xet_fallback.py +++ b/unsloth_zoo/hf_xet_fallback.py @@ -863,16 +863,15 @@ def snapshot_download_with_xet_fallback( ) # local_files_only returns a snapshot dir whenever refs/ and # snapshots/ exist, even if a prior download was interrupted and - # left .incomplete blobs or broken symlinks. Only short-circuit when - # the cache is actually clean; otherwise complete it in the killable - # child so the in-process load does not proceed with missing files. - # Validate the EXACT returned revision dir (snapshot_download may hand - # back an older requested revision while a newer one is clean), plus - # the repo-wide .incomplete blob check. - if ( - not snapshot_dir_has_broken_symlinks(Path(cached_dir)) - and not has_active_incomplete_blobs(repo_type, repo_id, cache_dir = cache_dir) - ): + # left broken symlinks. Validate the EXACT returned revision dir (a + # dangling symlink there means a referenced blob is missing or still an + # .incomplete partial); if broken, complete it in the killable child so + # the in-process load never proceeds with missing files. Scope the check + # to the returned snapshot, NOT the whole repo: snapshot_download already + # validated this exact revision, so an unrelated revision mid-download (a + # stale .incomplete blob or a broken older snapshot elsewhere in the same + # repo cache) must not force a needless re-fetch of a complete snapshot. 
+ if not snapshot_dir_has_broken_symlinks(Path(cached_dir)): return cached_dir logger.debug("Cached snapshot for %s has incomplete state; downloading.", repo_id) except Exception as e: From be5e80978849f57f8cc2795b6c2cef1e2e82ae35 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Fri, 26 Jun 2026 07:56:26 +0000 Subject: [PATCH 16/82] Require weight-level completeness for a cached snapshot (Codex P1 + P2) snapshot_download(local_files_only=True) returns any existing snapshot dir, even one left by a prior patterned or interrupted download (a config-only snapshot from an AutoConfig fetch, or a partial shard pull), without checking that the weight files are present. The dangling-symlink-only fast-path check missed these because the absent files were never symlinked, so a config-only snapshot was treated as a warm cache, the killable child was skipped, and the in-process load hit Xet on the missing weights. Add snapshot_dir_is_complete() in hf_cache_state: no dangling symlinks, every shipped weight-shard index (model.safetensors.index.json / pytorch_model.bin.index.json) resolves all its shards on disk, and at least one weight file is present. Use it for the fast-path short-circuit so a config-only / partial cached snapshot falls through to the killable child. Also revalidate the killable child's snapshot result before returning it: Hugging Face falls back to an existing snapshot dir on an offline or timed-out request, so a child result with dangling symlinks is now retried over HTTP and, if still broken, raised rather than handed to the load with missing files. The child check uses the dangling-symlink signal only (not the weight-presence heuristic): the child just did a full download, so an absent weight format means the repo ships none, and rejecting it would wrongly fail a legitimately weight-less repo. Both checks stay scoped to the returned snapshot, so an unrelated revision mid-download still does not block a complete one. 
--- tests/test_hf_xet_fallback.py | 118 +++++++++++++++++++++++++++++++-- unsloth_zoo/hf_cache_state.py | 87 ++++++++++++++++++++++++ unsloth_zoo/hf_xet_fallback.py | 65 +++++++++++++++--- 3 files changed, 255 insertions(+), 15 deletions(-) diff --git a/tests/test_hf_xet_fallback.py b/tests/test_hf_xet_fallback.py index 43637aa7b..282011c53 100644 --- a/tests/test_hf_xet_fallback.py +++ b/tests/test_hf_xet_fallback.py @@ -13,6 +13,7 @@ from __future__ import annotations import importlib.util +import json import os import subprocess import sys @@ -888,22 +889,131 @@ def test_unclearable_partial_forces_clean_redownload(hf_cache, monkeypatch): # --------------------------------------------------------------------------- # # Snapshot variant: in-process fast path on a warm cache, else watched download. # --------------------------------------------------------------------------- # -def test_snapshot_fast_path_no_child(monkeypatch): - """A fully cached repo resolves in-process via local_files_only -- no attempt.""" +def test_snapshot_fast_path_no_child(hf_cache, monkeypatch): + """A fully cached repo (weights present) resolves in-process via local_files_only + -- no child attempt.""" + blobs = _blobs_dir(hf_cache, DL_REPO) + snap = blobs.parent / "snapshots" / "sha" + snap.mkdir(parents = True) + weight = blobs / "w" + weight.write_bytes(b"\0" * 16) + (snap / "model.safetensors").symlink_to(weight) # weights present -> complete + (snap / "config.json").write_text("{}") seen = {} def _snap(*a, **k): seen["local_files_only"] = k.get("local_files_only") - return "/cache/snap-dir" + return str(snap) monkeypatch.setattr(huggingface_hub, "snapshot_download", _snap) fake = _install(monkeypatch, []) # must not be called out = xf.snapshot_download_with_xet_fallback(DL_REPO, token = None) - assert out == "/cache/snap-dir" + assert out == str(snap) assert seen["local_files_only"] is True assert fake.calls == [], "spawned a download for an already-cached snapshot" +def 
test_snapshot_dir_is_complete_unit(tmp_path): + """Weight presence drives completeness: a config-only snapshot is incomplete; one + with a resolvable weight file is complete.""" + snap = tmp_path / "snap" + snap.mkdir() + (snap / "config.json").write_text("{}") + assert hcs.snapshot_dir_is_complete(snap) is False # no weights + blob = tmp_path / "blob" + blob.write_bytes(b"weights") + (snap / "model.safetensors").symlink_to(blob) + assert hcs.snapshot_dir_is_complete(snap) is True + + +def test_snapshot_dir_is_complete_broken_symlink(tmp_path): + """A dangling weight symlink reads as incomplete.""" + snap = tmp_path / "snap" + snap.mkdir() + (snap / "model.safetensors").symlink_to(tmp_path / "missing") + assert hcs.snapshot_dir_is_complete(snap) is False + + +def test_snapshot_dir_is_complete_missing_shard(tmp_path): + """A shard index whose shards are not all on disk reads as incomplete until they are.""" + snap = tmp_path / "snap" + snap.mkdir() + blob = tmp_path / "blob" + blob.write_bytes(b"x") + (snap / "model-00001-of-00002.safetensors").symlink_to(blob) + (snap / "model.safetensors.index.json").write_text( + json.dumps( + { + "weight_map": { + "a": "model-00001-of-00002.safetensors", + "b": "model-00002-of-00002.safetensors", + } + } + ) + ) + assert hcs.snapshot_dir_is_complete(snap) is False # shard 2 missing + (snap / "model-00002-of-00002.safetensors").symlink_to(blob) + assert hcs.snapshot_dir_is_complete(snap) is True + + +def test_fast_path_rejects_config_only_snapshot(hf_cache, monkeypatch): + """HF's local_files_only returns a config-only snapshot (e.g. left by an earlier + AutoConfig fetch) without checking weights. 
The fast path must reject it and complete + the download in the killable child rather than load with missing weights.""" + blobs = _blobs_dir(hf_cache, DL_REPO) + snap = blobs.parent / "snapshots" / "sha" + snap.mkdir(parents = True) + cfg_blob = blobs / "cfg" + cfg_blob.write_text("{}") + (snap / "config.json").symlink_to(cfg_blob) # only config, no weights + monkeypatch.setattr(huggingface_hub, "snapshot_download", lambda *a, **k: str(snap)) + fake = _install(monkeypatch, [("ok", "/cache/snap-fresh")]) + out = xf.snapshot_download_with_xet_fallback(DL_REPO, token = None) + assert out == "/cache/snap-fresh" and len(fake.calls) == 1 + + +def test_child_broken_snapshot_retries_over_http(monkeypatch, tmp_path): + """A real but broken child snapshot result (HF offline-fallback returning a dir with + dangling symlinks) is rejected on the Xet attempt and retried over HTTP; a clean + second result is accepted.""" + broken = tmp_path / "broken" + broken.mkdir() + (broken / "model.safetensors").symlink_to(tmp_path / "missing") # dangling + clean = tmp_path / "clean" + clean.mkdir() + blob = tmp_path / "b" + blob.write_bytes(b"x") + (clean / "model.safetensors").symlink_to(blob) + fake = _install(monkeypatch, [("ok", str(broken)), ("ok", str(clean))]) + out = xf.snapshot_download_with_xet_fallback(DL_REPO, token = None) + assert out == str(clean) + assert [c.disable_xet for c in fake.calls] == [False, True] + + +def test_child_broken_snapshot_after_http_raises(monkeypatch, tmp_path): + """If even the HTTP attempt returns a broken snapshot, fail loudly rather than hand + missing files to the load.""" + broken = tmp_path / "broken" + broken.mkdir() + (broken / "model.safetensors").symlink_to(tmp_path / "missing") + fake = _install(monkeypatch, [("ok", str(broken)), ("ok", str(broken))]) + with pytest.raises(xf.DownloadStallError, match = "incomplete snapshot"): + xf.snapshot_download_with_xet_fallback(DL_REPO, token = None) + assert [c.disable_xet for c in fake.calls] == 
[False, True] + + +def test_child_weightless_snapshot_is_accepted(monkeypatch, tmp_path): + """A child result that simply has no weight files (the repo ships none) must NOT be + rejected: the child just did a full download, so absent weights mean the repo has + none, not a partial. Only a dangling symlink marks a broken child result.""" + cfg_only = tmp_path / "cfg" + cfg_only.mkdir() + (cfg_only / "config.json").write_text("{}") # no weights, but no broken links + fake = _install(monkeypatch, [("ok", str(cfg_only))]) + out = xf.snapshot_download_with_xet_fallback(DL_REPO, token = None) + assert out == str(cfg_only) and len(fake.calls) == 1 + + def test_snapshot_stall_then_http(monkeypatch): prepared = [] monkeypatch.setattr(xf, "_default_prepare_for_http", lambda rt, rid, cache_dir = None: prepared.append((rt, rid))) diff --git a/unsloth_zoo/hf_cache_state.py b/unsloth_zoo/hf_cache_state.py index f25d7cffd..1c788d068 100644 --- a/unsloth_zoo/hf_cache_state.py +++ b/unsloth_zoo/hf_cache_state.py @@ -36,6 +36,7 @@ "blob_bytes_present", "latest_snapshot_dir", "snapshot_dir_has_broken_symlinks", + "snapshot_dir_is_complete", "iter_active_repo_cache_dirs", "repo_cache_dir_has_incomplete_blobs", "has_active_incomplete_blobs", @@ -55,6 +56,15 @@ def _safe_is_dir(path: Path) -> bool: return False +def _safe_is_file(path: Path) -> bool: + """``Path.is_file()`` (follows symlinks) returning False instead of raising on an + unreadable path or a dangling link, so snapshot enumeration never errors out.""" + try: + return path.is_file() + except OSError: + return False + + def hf_cache_root(*, create: bool = False, cache_dir: "Optional[str | Path]" = None) -> Optional[Path]: """The hub cache root to scan, or None if unavailable. @@ -182,6 +192,83 @@ def snapshot_dir_has_broken_symlinks(snapshot_dir: Path) -> bool: return False +# Model weight file extensions. A snapshot with none of these is config/tokenizer +# only (e.g. 
a prior AutoConfig fetch), so it is not a warm cache for a weight load. +_WEIGHT_FILE_SUFFIXES = ( + ".safetensors", + ".bin", + ".pt", + ".pth", + ".gguf", + ".ckpt", + ".onnx", + ".msgpack", + ".h5", + ".pdparams", +) + + +def _weight_shard_index_complete(index_path: Path) -> bool: + """True if every shard a HF weight index (``model.safetensors.index.json`` / + ``pytorch_model.bin.index.json``) lists is present next to the index. An unreadable + or non-sharded index is treated as satisfied (nothing extra to verify), so this only + ever rejects an index whose shards are demonstrably missing on disk.""" + import json + + try: + with open(index_path, "r", encoding = "utf-8") as f: + data = json.load(f) + except (OSError, ValueError): + return True + weight_map = data.get("weight_map") if isinstance(data, dict) else None + if not isinstance(weight_map, dict): + return True + # weight_map values are filenames relative to the index file's own directory. + base = index_path.parent + for shard in set(weight_map.values()): + try: + if not (base / shard).exists(): + return False + except OSError: + return False + return True + + +def snapshot_dir_is_complete(snapshot_dir: Path) -> bool: + """Best-effort check that a cached snapshot actually holds its model weights. + + ``snapshot_download(local_files_only=True)`` returns a snapshot dir whenever + ``refs/`` and ``snapshots/`` exist, even one left by a prior interrupted + or patterned download (a config-only snapshot from an ``AutoConfig`` fetch, or a + partial shard pull). A dangling-symlink check alone misses those: the missing files + were never symlinked, so nothing dangles. Treating such a snapshot as a warm cache + skips the killable child and lets the in-process load hit Xet on the absent weights. + + A snapshot is complete only when it has no dangling symlinks, every weight-shard + index it ships resolves all its shards on disk, and it contains at least one weight + file. 
This does NOT assert that every non-weight file is present (no offline manifest + exists for that); the killable child completes anything else still missing. The aim + is simply to never short-circuit a snapshot whose weights are not on disk.""" + if snapshot_dir_has_broken_symlinks(snapshot_dir): + return False + try: + entries = list(snapshot_dir.rglob("*")) + except OSError: + return False + has_weight = False + for entry in entries: + name = entry.name + if name.endswith((".safetensors.index.json", ".bin.index.json")): + if not _safe_is_file(entry): + continue + if not _weight_shard_index_complete(entry): + return False + has_weight = True + elif name.endswith(_WEIGHT_FILE_SUFFIXES) and _safe_is_file(entry): + has_weight = True + return has_weight + + def _iter_snapshot_dirs(repo_dir: Path) -> Iterator[Path]: snapshots_dir = repo_dir / "snapshots" try: diff --git a/unsloth_zoo/hf_xet_fallback.py b/unsloth_zoo/hf_xet_fallback.py index af55528c8..f405b933e 100644 --- a/unsloth_zoo/hf_xet_fallback.py +++ b/unsloth_zoo/hf_xet_fallback.py @@ -53,6 +53,7 @@ hf_cache_root, iter_active_repo_cache_dirs, snapshot_dir_has_broken_symlinks, + snapshot_dir_is_complete, ) logger = logging.getLogger(__name__) @@ -621,6 +622,30 @@ def _run_download_attempt( return ("error", result.get("error") or "unknown download error") +def _snapshot_payload_incomplete(payload: Any) -> bool: + """True when a snapshot download returned a real directory with dangling symlinks (a + referenced blob missing or still an .incomplete partial). Guarded to an existing + directory so a mocked / non-path payload (unit tests) or an unexpected return is + trusted rather than rejected; in production the child always returns a real snapshot + dir, where this catches HF handing back an existing broken snapshot on an offline or + timed-out request. 
+ + Unlike the fast-path snapshot_dir_is_complete check, this does NOT require weight + files to be present: the child just performed a full download, so an absent weight + format means the repo ships none (accept it), whereas a dangling symlink means a file + the snapshot references is genuinely missing (reject it).""" + try: + path = Path(payload) + except TypeError: + return False + try: + if not path.is_dir(): + return False + except OSError: + return False + return snapshot_dir_has_broken_symlinks(path) + + def _download_with_xet_fallback( *, repo_id: str, @@ -683,6 +708,24 @@ def _download_with_xet_fallback( ) if kind_result == "ok": + if kind == "snapshot" and _snapshot_payload_incomplete(payload): + # HF can return an existing, incomplete snapshot dir on an offline or + # timed-out request instead of fetching the missing files. Never hand a + # snapshot with absent weights to the in-process load: retry over HTTP, + # and if it still comes back incomplete, fail loudly rather than silently + # loading a broken cache. + if not disable_xet: + logger.warning( + "Download for '%s' returned an incomplete snapshot -- " + "retrying with HF_HUB_DISABLE_XET=1", label + ) + _safe_status(on_status, f"{label}: incomplete snapshot, retrying over HTTP") + disable_xet = True + continue + raise DownloadStallError( + f"Download for '{label}' returned an incomplete snapshot even with " + f"HF_HUB_DISABLE_XET=1 -- missing weight files, check your network connection" + ) return payload # type: ignore[return-value] if kind_result == "cancelled": raise RuntimeError("Cancelled") @@ -862,18 +905,18 @@ def snapshot_download_with_xet_fallback( local_files_only = True, ) # local_files_only returns a snapshot dir whenever refs/ and - # snapshots/ exist, even if a prior download was interrupted and - # left broken symlinks. 
Validate the EXACT returned revision dir (a - # dangling symlink there means a referenced blob is missing or still an - # .incomplete partial); if broken, complete it in the killable child so - # the in-process load never proceeds with missing files. Scope the check - # to the returned snapshot, NOT the whole repo: snapshot_download already - # validated this exact revision, so an unrelated revision mid-download (a - # stale .incomplete blob or a broken older snapshot elsewhere in the same - # repo cache) must not force a needless re-fetch of a complete snapshot. - if not snapshot_dir_has_broken_symlinks(Path(cached_dir)): + # snapshots/ exist, even one left by a prior interrupted or patterned + # download (a config-only snapshot from an AutoConfig fetch, or a partial + # shard pull). Validate that the EXACT returned revision dir is actually + # complete -- no dangling symlinks AND its weight files present on disk -- + # and complete it in the killable child otherwise, so the in-process load + # never proceeds with missing weights. Scope the check to the returned + # snapshot, NOT the whole repo: an unrelated revision mid-download (a stale + # .incomplete blob or a broken older snapshot elsewhere in the same repo + # cache) must not force a needless re-fetch of a complete snapshot. 
+ if snapshot_dir_is_complete(Path(cached_dir)): return cached_dir - logger.debug("Cached snapshot for %s has incomplete state; downloading.", repo_id) + logger.debug("Cached snapshot for %s is incomplete; downloading.", repo_id) except Exception as e: logger.debug("Snapshot not fully cached for %s (%s); downloading.", repo_id, e) From e105537254a025ef0394d566d494a1d2bf65bd14 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Fri, 26 Jun 2026 08:45:19 +0000 Subject: [PATCH 17/82] Reject incomplete child snapshots and spare active sibling partials (Codex P1) Child result completeness: snapshot_download(local_files_only=False) silently returns an existing local snapshot (rather than raising) when the Hub is unreachable, so the killable child can hand back a stale, weight-incomplete snapshot on a mid-download network outage. Reuse the fast-path weight-completeness check for the child result, not just a dangling-symlink check: this helper warms a model repo for an imminent load, so a result with no weight files means the download did not finish (a genuinely weightless repo could not be loaded as a model anyway). An incomplete child result is retried over HTTP and, if still incomplete, raised, instead of dropping from_pretrained onto missing weights and an in-process Xet load. Concurrent sibling partials: the generic HTTP-prep purge unlinked every *.incomplete blob in the repo cache, including a concurrent same-repo download's still-active temp file (on POSIX the sibling then keeps writing to an unlinked path and fails when the Hub moves it into place). Skip a partial whose mtime is within a short active-partial grace window: an actively-downloading sibling writes continuously, while our own killed partial has been static for the stall timeout. If a very short stall_timeout leaves ours inside the window, the has_active_incomplete_blobs guard still forces a clean HTTP re-download, so the retry never resumes unsafely over a sparse partial. 
--- tests/test_hf_xet_fallback.py | 48 ++++++++++++++++++++++++++++------ unsloth_zoo/hf_xet_fallback.py | 40 +++++++++++++++++++++------- 2 files changed, 70 insertions(+), 18 deletions(-) diff --git a/tests/test_hf_xet_fallback.py b/tests/test_hf_xet_fallback.py index 282011c53..a662a38fb 100644 --- a/tests/test_hf_xet_fallback.py +++ b/tests/test_hf_xet_fallback.py @@ -230,7 +230,10 @@ def test_custom_cache_dir_is_watched_and_cleaned(tmp_path, monkeypatch): stop.set() # The HTTP-prep purge removes the unsafe partial from the custom cache - # (call the real impl; the autouse fixture stubs the module attribute). + # (call the real impl; the autouse fixture stubs the module attribute). Age it + # past the active-partial grace so it reads as a stalled, not in-flight, blob. + old = time.time() - 600 + os.utime(partial, (old, old)) _REAL_DEFAULT_PREPARE("model", REPO, cache_dir = str(custom_cache)) assert not partial.exists() @@ -336,6 +339,11 @@ def test_prepare_for_http_preserves_case_colliding_repo(tmp_path): lower_partial = lower / "b.incomplete" upper_partial.write_bytes(b"x") lower_partial.write_bytes(b"y") + # Age both past the active-partial grace so the purge is exercised on stalled blobs + # (lower is preserved by repo attribution, not mtime). + old = time.time() - 600 + os.utime(upper_partial, (old, old)) + os.utime(lower_partial, (old, old)) _REAL_DEFAULT_PREPARE("model", "Org/Repo", cache_dir = str(tmp_path)) @@ -1002,16 +1010,40 @@ def test_child_broken_snapshot_after_http_raises(monkeypatch, tmp_path): assert [c.disable_xet for c in fake.calls] == [False, True] -def test_child_weightless_snapshot_is_accepted(monkeypatch, tmp_path): - """A child result that simply has no weight files (the repo ships none) must NOT be - rejected: the child just did a full download, so absent weights mean the repo has - none, not a partial. 
Only a dangling symlink marks a broken child result.""" +def test_child_weight_incomplete_snapshot_retries_over_http(monkeypatch, tmp_path): + """A child result with no weight files (HF silently returning a stale config-only + snapshot on an offline / timed-out request) is rejected on the Xet attempt and retried + over HTTP; a complete second result is accepted. The helper warms model repos, so a + weight-less result means the download did not finish, not that the repo is weightless.""" cfg_only = tmp_path / "cfg" cfg_only.mkdir() - (cfg_only / "config.json").write_text("{}") # no weights, but no broken links - fake = _install(monkeypatch, [("ok", str(cfg_only))]) + (cfg_only / "config.json").write_text("{}") # no weights + complete = tmp_path / "complete" + complete.mkdir() + blob = tmp_path / "b" + blob.write_bytes(b"x") + (complete / "model.safetensors").symlink_to(blob) + fake = _install(monkeypatch, [("ok", str(cfg_only)), ("ok", str(complete))]) out = xf.snapshot_download_with_xet_fallback(DL_REPO, token = None) - assert out == str(cfg_only) and len(fake.calls) == 1 + assert out == str(complete) + assert [c.disable_xet for c in fake.calls] == [False, True] + + +def test_prepare_for_http_spares_active_sibling_partial(hf_cache): + """The generic HTTP-prep purge must not unlink a concurrent download's still-active + .incomplete temp file: only stale (old-mtime) partials are removed, so a sibling + download of another file in the same repo keeps writing safely.""" + blobs = _blobs_dir(hf_cache, DL_REPO) + stale = blobs / "stalled.incomplete" + stale.write_bytes(b"\0" * 16) + active = blobs / "sibling.incomplete" + active.write_bytes(b"\0" * 16) + # Age the stalled partial well past the active-partial grace; leave the sibling current. 
+ old = time.time() - 600 + os.utime(stale, (old, old)) + _REAL_DEFAULT_PREPARE("model", DL_REPO, cache_dir = str(hf_cache)) + assert not stale.exists(), "stale partial should be purged for the HTTP resume" + assert active.exists(), "an actively-written sibling partial must be preserved" def test_snapshot_stall_then_http(monkeypatch): diff --git a/unsloth_zoo/hf_xet_fallback.py b/unsloth_zoo/hf_xet_fallback.py index f405b933e..129f0636c 100644 --- a/unsloth_zoo/hf_xet_fallback.py +++ b/unsloth_zoo/hf_xet_fallback.py @@ -156,6 +156,13 @@ def _redact_signed_query(match: "re.Match") -> str: return out +# A *.incomplete blob touched more recently than this is treated as another +# concurrent download's still-active temp file and left in place (see +# _default_prepare_for_http). Smaller than any realistic stall_timeout, so our own +# killed-then-stalled partial (static for >= stall_timeout) is still always purged. +_ACTIVE_PARTIAL_GRACE = 30.0 + + def _default_prepare_for_http( repo_type: str, repo_id: str, *, cache_dir: Optional[str] = None ) -> None: @@ -176,6 +183,17 @@ def _default_prepare_for_http( for blob in blobs_dir.iterdir(): if blob.is_file() and blob.name.endswith(INCOMPLETE_SUFFIX): try: + # Do not unlink a partial another concurrent download is + # still actively writing (recent mtime): on POSIX that lets + # the sibling keep writing to an unlinked path and then fail + # when the Hub moves its temp file into place. Our own killed + # partial has been static for the stall timeout, so it is + # older than this window; if a very short stall_timeout left + # it inside the window, the has_active_incomplete_blobs guard + # downstream still forces a clean HTTP re-download, so the + # HTTP attempt never resumes unsafely over it. 
+ if time.time() - blob.stat().st_mtime < _ACTIVE_PARTIAL_GRACE: + continue blob.unlink() except OSError: # A locked / permission-denied blob (common on Windows) @@ -623,17 +641,19 @@ def _run_download_attempt( def _snapshot_payload_incomplete(payload: Any) -> bool: - """True when a snapshot download returned a real directory with dangling symlinks (a - referenced blob missing or still an .incomplete partial). Guarded to an existing + """True when a snapshot download returned a real directory that is not weight-complete + (dangling symlinks, missing shards, or no weight file at all). Guarded to an existing directory so a mocked / non-path payload (unit tests) or an unexpected return is trusted rather than rejected; in production the child always returns a real snapshot - dir, where this catches HF handing back an existing broken snapshot on an offline or - timed-out request. - - Unlike the fast-path snapshot_dir_is_complete check, this does NOT require weight - files to be present: the child just performed a full download, so an absent weight - format means the repo ships none (accept it), whereas a dangling symlink means a file - the snapshot references is genuinely missing (reject it).""" + dir, where this catches HF silently handing back an existing partial snapshot on an + offline or timed-out request (verified: snapshot_download(local_files_only=False) + returns the stale local snapshot rather than raising when the Hub is unreachable). + + This reuses the fast-path weight-completeness check: snapshot_download_with_xet_fallback + is used to warm a model repo for an imminent load, so a result with no weight files + means the download did not finish, not that the repo is genuinely weightless (which + could not be loaded as a model anyway). 
Sending that on would drop from_pretrained back + onto missing weights and an in-process Xet load -- exactly what this wrapper prevents.""" try: path = Path(payload) except TypeError: @@ -643,7 +663,7 @@ def _snapshot_payload_incomplete(payload: Any) -> bool: return False except OSError: return False - return snapshot_dir_has_broken_symlinks(path) + return not snapshot_dir_is_complete(path) def _download_with_xet_fallback( From 66a723a2dcf3485bd7f4985a59e8551bc9b13c62 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Fri, 26 Jun 2026 09:00:01 +0000 Subject: [PATCH 18/82] Gate the snapshot weight check to full model warmups (Codex P2) The weight-completeness check rejected any snapshot without a model weight file, which broke legitimate weightless uses of this wrapper: a patterned download (allow_patterns=["config.json", "tokenizer*"]) or a non-model snapshot (repo_type="dataset") returns exactly the requested files and was then marked incomplete, retried over HTTP, and raised on the same valid result. Scope the weight requirement to the caller's intent via _snapshot_is_acceptable: a full model warmup (repo_type "model" with no allow_patterns) still requires its weight files on disk (so an offline-fallback config-only partial is rejected and retried), while a patterned or non-model request is accepted as long as no symlink dangles (every file it references is present). Used by both the fast-path short-circuit and the killable child's result validation. 
--- tests/test_hf_xet_fallback.py | 25 +++++++++++ unsloth_zoo/hf_xet_fallback.py | 79 +++++++++++++++++++++------------- 2 files changed, 75 insertions(+), 29 deletions(-) diff --git a/tests/test_hf_xet_fallback.py b/tests/test_hf_xet_fallback.py index a662a38fb..94e1fa972 100644 --- a/tests/test_hf_xet_fallback.py +++ b/tests/test_hf_xet_fallback.py @@ -1029,6 +1029,31 @@ def test_child_weight_incomplete_snapshot_retries_over_http(monkeypatch, tmp_pat assert [c.disable_xet for c in fake.calls] == [False, True] +def test_patterned_snapshot_without_weights_is_accepted(monkeypatch, tmp_path): + """A patterned download (allow_patterns) legitimately returns only the requested files + (e.g. config / tokenizer, no model weights). The child result must be accepted as-is, + not rejected and retried for lacking weights.""" + cfg_only = tmp_path / "cfg" + cfg_only.mkdir() + (cfg_only / "config.json").write_text("{}") # exactly what was requested, no weights + fake = _install(monkeypatch, [("ok", str(cfg_only))]) + out = xf.snapshot_download_with_xet_fallback( + DL_REPO, token = None, allow_patterns = ["config.json"] + ) + assert out == str(cfg_only) and len(fake.calls) == 1 + + +def test_dataset_snapshot_without_weights_is_accepted(monkeypatch, tmp_path): + """A non-model snapshot (repo_type='dataset') has no model weights by nature; its + child result must be accepted rather than retried/raised as 'incomplete'.""" + files = tmp_path / "ds" + files.mkdir() + (files / "data.json").write_text("[]") + fake = _install(monkeypatch, [("ok", str(files))]) + out = xf.snapshot_download_with_xet_fallback(DL_REPO, token = None, repo_type = "dataset") + assert out == str(files) and len(fake.calls) == 1 + + def test_prepare_for_http_spares_active_sibling_partial(hf_cache): """The generic HTTP-prep purge must not unlink a concurrent download's still-active .incomplete temp file: only stale (old-mtime) partials are removed, so a sibling diff --git a/unsloth_zoo/hf_xet_fallback.py 
b/unsloth_zoo/hf_xet_fallback.py index 129f0636c..59f694250 100644 --- a/unsloth_zoo/hf_xet_fallback.py +++ b/unsloth_zoo/hf_xet_fallback.py @@ -640,20 +640,36 @@ def _run_download_attempt( return ("error", result.get("error") or "unknown download error") -def _snapshot_payload_incomplete(payload: Any) -> bool: - """True when a snapshot download returned a real directory that is not weight-complete - (dangling symlinks, missing shards, or no weight file at all). Guarded to an existing - directory so a mocked / non-path payload (unit tests) or an unexpected return is - trusted rather than rejected; in production the child always returns a real snapshot - dir, where this catches HF silently handing back an existing partial snapshot on an - offline or timed-out request (verified: snapshot_download(local_files_only=False) - returns the stale local snapshot rather than raising when the Hub is unreachable). - - This reuses the fast-path weight-completeness check: snapshot_download_with_xet_fallback - is used to warm a model repo for an imminent load, so a result with no weight files - means the download did not finish, not that the repo is genuinely weightless (which - could not be loaded as a model anyway). Sending that on would drop from_pretrained back - onto missing weights and an in-process Xet load -- exactly what this wrapper prevents.""" +def _snapshot_is_acceptable( + snapshot_dir: Path, *, repo_type: str, allow_patterns: Any +) -> bool: + """Whether a cached / downloaded snapshot dir is complete enough to use, scoped to the + caller's intent. + + A FULL model warmup (``repo_type == "model"`` and no ``allow_patterns``) must have its + weight files present: this wrapper exists to warm those weights before an in-process + load, so a result with no weights means the download did not finish (HF silently + returns a stale local snapshot on an offline / timed-out request rather than raising). 
+ + A PATTERNED or non-model snapshot (``allow_patterns`` set, or ``repo_type`` such as + ``"dataset"``) legitimately holds only the requested subset -- e.g. just + ``config.json`` / ``tokenizer*`` files, or dataset files with no model weights at all + -- so requiring a weight file there would reject a perfectly valid result. For those it + is enough that no symlink dangles (every file the snapshot references is on disk).""" + if repo_type == "model" and not allow_patterns: + return snapshot_dir_is_complete(snapshot_dir) + return not snapshot_dir_has_broken_symlinks(snapshot_dir) + + +def _snapshot_payload_incomplete( + payload: Any, *, repo_type: str, allow_patterns: Any +) -> bool: + """True when a snapshot download returned a real directory that is not acceptable for + the request (see ``_snapshot_is_acceptable``). Guarded to an existing directory so a + mocked / non-path payload (unit tests) or an unexpected return is trusted rather than + rejected; in production the child always returns a real snapshot dir, where this + catches HF handing back an existing partial snapshot on an offline / timed-out + request.""" try: path = Path(payload) except TypeError: @@ -663,7 +679,7 @@ def _snapshot_payload_incomplete(payload: Any) -> bool: return False except OSError: return False - return not snapshot_dir_is_complete(path) + return not _snapshot_is_acceptable(path, repo_type = repo_type, allow_patterns = allow_patterns) def _download_with_xet_fallback( @@ -728,12 +744,15 @@ def _download_with_xet_fallback( ) if kind_result == "ok": - if kind == "snapshot" and _snapshot_payload_incomplete(payload): + if kind == "snapshot" and _snapshot_payload_incomplete( + payload, repo_type = repo_type, allow_patterns = params.get("allow_patterns") + ): # HF can return an existing, incomplete snapshot dir on an offline or - # timed-out request instead of fetching the missing files. 
Never hand a - # snapshot with absent weights to the in-process load: retry over HTTP, - # and if it still comes back incomplete, fail loudly rather than silently - # loading a broken cache. + # timed-out request instead of fetching the missing files. Never hand an + # incomplete snapshot to the in-process load: retry over HTTP, and if it + # still comes back incomplete, fail loudly rather than silently loading a + # broken cache. (A patterned / non-model request is judged by its own + # requested subset, so this never rejects a valid weightless snapshot.) if not disable_xet: logger.warning( "Download for '%s' returned an incomplete snapshot -- " @@ -744,7 +763,7 @@ def _download_with_xet_fallback( continue raise DownloadStallError( f"Download for '{label}' returned an incomplete snapshot even with " - f"HF_HUB_DISABLE_XET=1 -- missing weight files, check your network connection" + f"HF_HUB_DISABLE_XET=1 -- missing files, check your network connection" ) return payload # type: ignore[return-value] if kind_result == "cancelled": @@ -927,14 +946,16 @@ def snapshot_download_with_xet_fallback( # local_files_only returns a snapshot dir whenever refs/ and # snapshots/ exist, even one left by a prior interrupted or patterned # download (a config-only snapshot from an AutoConfig fetch, or a partial - # shard pull). Validate that the EXACT returned revision dir is actually - # complete -- no dangling symlinks AND its weight files present on disk -- - # and complete it in the killable child otherwise, so the in-process load - # never proceeds with missing weights. Scope the check to the returned - # snapshot, NOT the whole repo: an unrelated revision mid-download (a stale - # .incomplete blob or a broken older snapshot elsewhere in the same repo - # cache) must not force a needless re-fetch of a complete snapshot. - if snapshot_dir_is_complete(Path(cached_dir)): + # shard pull). 
Validate the EXACT returned revision dir against the request: + # a full model warmup requires its weight files on disk, a patterned / non-model + # request only its referenced files (no dangling symlinks). Complete it in the + # killable child otherwise, so the in-process load never proceeds with missing + # files. Scope the check to the returned snapshot, NOT the whole repo: an + # unrelated revision mid-download (a stale .incomplete blob or a broken older + # snapshot elsewhere in the same repo cache) must not force a needless re-fetch. + if _snapshot_is_acceptable( + Path(cached_dir), repo_type = repo_type, allow_patterns = allow_patterns + ): return cached_dir logger.debug("Cached snapshot for %s is incomplete; downloading.", repo_id) except Exception as e: From ea8f013c0eff0bb531a121ab421707cbf7d9ad52 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Fri, 26 Jun 2026 10:22:49 +0000 Subject: [PATCH 19/82] Honor ignore_patterns and tie the purge grace to the stall timeout (Codex P2) Weight gating now honors ignore_patterns, not just allow_patterns. A model repo fetched with ignore_patterns that drop every weight format (e.g. to warm only config/tokenizer files) legitimately yields a weightless snapshot, but the old check still treated it as a full model warmup and rejected/retried it. Add request_can_include_weights(allow, ignore) in hf_cache_state, which probes one canonical filename per recognized weight format through huggingface_hub.filter_repo_objects; _snapshot_is_acceptable requires weights only for a model repo whose patterns can still include a weight file. Unsloth's default prefetch ignores (onnx/h5/msgpack/gguf, never safetensors) still count as including weights, so model warmups keep requiring them. Active-partial purge grace now follows the stall timeout. The HTTP-prep cleanup used a fixed 30s window, but the watchdog does not declare a download stalled until stall_timeout (180s by default, or larger if the caller raised it). 
A slow sibling that simply had not written for >30s could be purged. Pass the stall_timeout in use as the grace, so only a partial static for at least the stall threshold (our killed one) is removed and a slower in-flight sibling is left alone. --- tests/test_hf_xet_fallback.py | 40 ++++++++++++++++-- unsloth_zoo/hf_cache_state.py | 47 +++++++++++++++++++++ unsloth_zoo/hf_xet_fallback.py | 77 +++++++++++++++++++--------------- 3 files changed, 126 insertions(+), 38 deletions(-) diff --git a/tests/test_hf_xet_fallback.py b/tests/test_hf_xet_fallback.py index 94e1fa972..2e2184926 100644 --- a/tests/test_hf_xet_fallback.py +++ b/tests/test_hf_xet_fallback.py @@ -533,7 +533,7 @@ def test_nonstall_error_propagates_without_fallback(monkeypatch): def test_immediate_success_uses_xet_only(monkeypatch): prepared = [] - monkeypatch.setattr(xf, "_default_prepare_for_http", lambda *a: prepared.append(a)) + monkeypatch.setattr(xf, "_default_prepare_for_http", lambda *a, **k: prepared.append(a)) fake = _install(monkeypatch, [("ok", "/cache/model.gguf")]) out = xf.hf_hub_download_with_xet_fallback(DL_REPO, FILE, None) assert out == "/cache/model.gguf" @@ -543,7 +543,7 @@ def test_immediate_success_uses_xet_only(monkeypatch): def test_stall_then_http_fallback_succeeds(monkeypatch): prepared = [] - monkeypatch.setattr(xf, "_default_prepare_for_http", lambda repo_type, repo_id, cache_dir = None: prepared.append((repo_type, repo_id))) + monkeypatch.setattr(xf, "_default_prepare_for_http", lambda repo_type, repo_id, cache_dir = None, **k: prepared.append((repo_type, repo_id))) fake = _install(monkeypatch, [("stall", None), ("ok", "/cache/model.gguf")]) out = xf.hf_hub_download_with_xet_fallback(DL_REPO, FILE, None) @@ -557,7 +557,7 @@ def test_stall_then_http_fallback_succeeds(monkeypatch): def test_injected_prepare_for_http_used(monkeypatch): """Studio injects its marker-aware prepare; the generic default must not run.""" monkeypatch.setattr( - xf, "_default_prepare_for_http", lambda 
*a: pytest.fail("generic prepare ran") + xf, "_default_prepare_for_http", lambda *a, **k: pytest.fail("generic prepare ran") ) injected = [] _install(monkeypatch, [("stall", None), ("ok", "/cache/model.gguf")]) @@ -1054,6 +1054,38 @@ def test_dataset_snapshot_without_weights_is_accepted(monkeypatch, tmp_path): assert out == str(files) and len(fake.calls) == 1 +def test_model_snapshot_with_weights_excluded_is_accepted(monkeypatch, tmp_path): + """A model repo fetched with ignore_patterns that drop every weight format (e.g. to + warm only config / tokenizer files) legitimately yields a weightless snapshot; the + result must be accepted, not rejected for lacking weights.""" + cfg_only = tmp_path / "cfg" + cfg_only.mkdir() + (cfg_only / "config.json").write_text("{}") + fake = _install(monkeypatch, [("ok", str(cfg_only))]) + out = xf.snapshot_download_with_xet_fallback( + DL_REPO, + token = None, + ignore_patterns = [ + "*.safetensors", "*.bin", "*.h5", "*.msgpack", "*.gguf", + "*.pt", "*.pth", "*.ckpt", "*.onnx", "*.pdparams", "*.index.json", + ], + ) + assert out == str(cfg_only) and len(fake.calls) == 1 + + +def test_request_can_include_weights_unit(): + """Unsloth's default prefetch ignores (onnx/h5/msgpack/gguf, never safetensors) still + count as including weights, so model warmups keep requiring them; excluding every + weight format does not.""" + assert hcs.request_can_include_weights(None, None) is True + assert hcs.request_can_include_weights(None, ["*.onnx", "*.h5", "*.msgpack", "*.gguf"]) is True + assert hcs.request_can_include_weights(["config.json"], None) is False + assert hcs.request_can_include_weights( + None, ["*.safetensors", "*.bin", "*.h5", "*.msgpack", "*.gguf", + "*.pt", "*.pth", "*.ckpt", "*.onnx", "*.pdparams", "*.index.json"] + ) is False + + def test_prepare_for_http_spares_active_sibling_partial(hf_cache): """The generic HTTP-prep purge must not unlink a concurrent download's still-active .incomplete temp file: only stale (old-mtime) 
partials are removed, so a sibling @@ -1073,7 +1105,7 @@ def test_prepare_for_http_spares_active_sibling_partial(hf_cache): def test_snapshot_stall_then_http(monkeypatch): prepared = [] - monkeypatch.setattr(xf, "_default_prepare_for_http", lambda rt, rid, cache_dir = None: prepared.append((rt, rid))) + monkeypatch.setattr(xf, "_default_prepare_for_http", lambda rt, rid, cache_dir = None, **k: prepared.append((rt, rid))) fake = _install(monkeypatch, [("stall", None), ("ok", "/cache/snap-dir")]) out = xf.snapshot_download_with_xet_fallback(DL_REPO, token = None) assert out == "/cache/snap-dir" diff --git a/unsloth_zoo/hf_cache_state.py b/unsloth_zoo/hf_cache_state.py index 1c788d068..ad8dbfd71 100644 --- a/unsloth_zoo/hf_cache_state.py +++ b/unsloth_zoo/hf_cache_state.py @@ -37,6 +37,7 @@ "latest_snapshot_dir", "snapshot_dir_has_broken_symlinks", "snapshot_dir_is_complete", + "request_can_include_weights", "iter_active_repo_cache_dirs", "repo_cache_dir_has_incomplete_blobs", "has_active_incomplete_blobs", @@ -269,6 +270,52 @@ def snapshot_dir_is_complete(snapshot_dir: Path) -> bool: return has_weight +def request_can_include_weights( + allow_patterns: "Optional[list]" = None, ignore_patterns: "Optional[list]" = None +) -> bool: + """Whether a download restricted by *allow_patterns* / *ignore_patterns* can still + include a model weight file. + + Used to decide whether snapshot completeness should require weights: a request that + filters every weight format out (e.g. ``ignore_patterns`` covering ``*.safetensors`` + and ``*.bin`` to fetch only config / tokenizer files from a model repo) legitimately + yields a weightless snapshot, so requiring a weight there would reject a valid result. 
+ An unfiltered request -- or one any weight filename survives -- includes weights.""" + if not allow_patterns and not ignore_patterns: + return True + try: + from huggingface_hub.utils import filter_repo_objects + except Exception: + return True # cannot evaluate the filter -> assume weights are expected + # One canonical filename per recognized weight format (plus sharded / index variants); + # if any survives the filter, the requested set can include weights. + probes = [ + "model.safetensors", + "model-00001-of-00002.safetensors", + "model.safetensors.index.json", + "pytorch_model.bin", + "pytorch_model-00001-of-00002.bin", + "pytorch_model.bin.index.json", + "tf_model.h5", + "flax_model.msgpack", + "model.gguf", + "model.pt", + "model.pth", + "model.ckpt", + "model.onnx", + "model.pdparams", + ] + try: + kept = list( + filter_repo_objects( + probes, allow_patterns = allow_patterns, ignore_patterns = ignore_patterns + ) + ) + except Exception: + return True + return len(kept) > 0 + + def _iter_snapshot_dirs(repo_dir: Path) -> Iterator[Path]: snapshots_dir = repo_dir / "snapshots" try: diff --git a/unsloth_zoo/hf_xet_fallback.py b/unsloth_zoo/hf_xet_fallback.py index 59f694250..67523b6af 100644 --- a/unsloth_zoo/hf_xet_fallback.py +++ b/unsloth_zoo/hf_xet_fallback.py @@ -52,6 +52,7 @@ has_active_incomplete_blobs, hf_cache_root, iter_active_repo_cache_dirs, + request_can_include_weights, snapshot_dir_has_broken_symlinks, snapshot_dir_is_complete, ) @@ -156,15 +157,12 @@ def _redact_signed_query(match: "re.Match") -> str: return out -# A *.incomplete blob touched more recently than this is treated as another -# concurrent download's still-active temp file and left in place (see -# _default_prepare_for_http). Smaller than any realistic stall_timeout, so our own -# killed-then-stalled partial (static for >= stall_timeout) is still always purged. 
-_ACTIVE_PARTIAL_GRACE = 30.0 - - def _default_prepare_for_http( - repo_type: str, repo_id: str, *, cache_dir: Optional[str] = None + repo_type: str, + repo_id: str, + *, + cache_dir: Optional[str] = None, + active_grace: float = DEFAULT_STALL_TIMEOUT, ) -> None: """Generic 'make the partial safe for an HTTP resume': delete the repo's active ``*.incomplete`` blobs (an HTTP resume over a sparse Xet/hf_transfer partial @@ -184,15 +182,15 @@ def _default_prepare_for_http( if blob.is_file() and blob.name.endswith(INCOMPLETE_SUFFIX): try: # Do not unlink a partial another concurrent download is - # still actively writing (recent mtime): on POSIX that lets - # the sibling keep writing to an unlinked path and then fail - # when the Hub moves its temp file into place. Our own killed - # partial has been static for the stall timeout, so it is - # older than this window; if a very short stall_timeout left - # it inside the window, the has_active_incomplete_blobs guard - # downstream still forces a clean HTTP re-download, so the - # HTTP attempt never resumes unsafely over it. - if time.time() - blob.stat().st_mtime < _ACTIVE_PARTIAL_GRACE: + # still actively writing: on POSIX that lets the sibling keep + # writing to an unlinked path and then fail when the Hub moves + # its temp file into place. Spare any partial written within + # active_grace (the stall timeout in use): the watchdog only + # declares a download stalled after that long with no growth, + # so a slower sibling that simply has not written recently is + # not stalled and must be left alone. Our own killed partial + # has been static for the full stall timeout, so it is purged. 
+ if time.time() - blob.stat().st_mtime < active_grace: continue blob.unlink() except OSError: @@ -641,28 +639,29 @@ def _run_download_attempt( def _snapshot_is_acceptable( - snapshot_dir: Path, *, repo_type: str, allow_patterns: Any + snapshot_dir: Path, *, repo_type: str, allow_patterns: Any, ignore_patterns: Any ) -> bool: """Whether a cached / downloaded snapshot dir is complete enough to use, scoped to the caller's intent. - A FULL model warmup (``repo_type == "model"`` and no ``allow_patterns``) must have its - weight files present: this wrapper exists to warm those weights before an in-process - load, so a result with no weights means the download did not finish (HF silently - returns a stale local snapshot on an offline / timed-out request rather than raising). - - A PATTERNED or non-model snapshot (``allow_patterns`` set, or ``repo_type`` such as - ``"dataset"``) legitimately holds only the requested subset -- e.g. just - ``config.json`` / ``tokenizer*`` files, or dataset files with no model weights at all - -- so requiring a weight file there would reject a perfectly valid result. For those it - is enough that no symlink dangles (every file the snapshot references is on disk).""" - if repo_type == "model" and not allow_patterns: + Weight files are required only when the request can actually include them: a model + repo (``repo_type == "model"``) whose ``allow_patterns`` / ``ignore_patterns`` do not + filter every weight format out. This wrapper exists to warm those weights before an + in-process load, so a result with no weights then means the download did not finish (HF + silently returns a stale local snapshot on an offline / timed-out request rather than + raising). 
+ + A PATTERNED or non-model snapshot that legitimately holds only a subset -- a dataset, or + a model repo fetched with ``allow_patterns=["config.json"]`` or ``ignore_patterns`` that + drop all weights -- would be wrongly rejected by a weight requirement, so for those it is + enough that no symlink dangles (every file the snapshot references is on disk).""" + if repo_type == "model" and request_can_include_weights(allow_patterns, ignore_patterns): return snapshot_dir_is_complete(snapshot_dir) return not snapshot_dir_has_broken_symlinks(snapshot_dir) def _snapshot_payload_incomplete( - payload: Any, *, repo_type: str, allow_patterns: Any + payload: Any, *, repo_type: str, allow_patterns: Any, ignore_patterns: Any ) -> bool: """True when a snapshot download returned a real directory that is not acceptable for the request (see ``_snapshot_is_acceptable``). Guarded to an existing directory so a @@ -679,7 +678,9 @@ def _snapshot_payload_incomplete( return False except OSError: return False - return not _snapshot_is_acceptable(path, repo_type = repo_type, allow_patterns = allow_patterns) + return not _snapshot_is_acceptable( + path, repo_type = repo_type, allow_patterns = allow_patterns, ignore_patterns = ignore_patterns + ) def _download_with_xet_fallback( @@ -713,7 +714,9 @@ def _download_with_xet_fallback( # its own cache accounting and keeps the (repo_type, repo_id) signature. 
try: if prepare_for_http_fn is None: - _default_prepare_for_http(repo_type, repo_id, cache_dir = cache_dir) + _default_prepare_for_http( + repo_type, repo_id, cache_dir = cache_dir, active_grace = stall_timeout + ) else: prepare_for_http_fn(repo_type, repo_id) except Exception as e: @@ -745,7 +748,10 @@ def _download_with_xet_fallback( if kind_result == "ok": if kind == "snapshot" and _snapshot_payload_incomplete( - payload, repo_type = repo_type, allow_patterns = params.get("allow_patterns") + payload, + repo_type = repo_type, + allow_patterns = params.get("allow_patterns"), + ignore_patterns = params.get("ignore_patterns"), ): # HF can return an existing, incomplete snapshot dir on an offline or # timed-out request instead of fetching the missing files. Never hand an @@ -954,7 +960,10 @@ def snapshot_download_with_xet_fallback( # unrelated revision mid-download (a stale .incomplete blob or a broken older # snapshot elsewhere in the same repo cache) must not force a needless re-fetch. if _snapshot_is_acceptable( - Path(cached_dir), repo_type = repo_type, allow_patterns = allow_patterns + Path(cached_dir), + repo_type = repo_type, + allow_patterns = allow_patterns, + ignore_patterns = ignore_patterns, ): return cached_dir logger.debug("Cached snapshot for %s is incomplete; downloading.", repo_id) From d4bf5e6976e9fa84aa61588f666c95a71e331170 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Fri, 26 Jun 2026 11:05:41 +0000 Subject: [PATCH 20/82] Tighten snapshot weight-completeness against index-only and shard-only caches request_can_include_weights no longer probes the shard index sidecars (model.safetensors.index.json / pytorch_model.bin.index.json). They are JSON metadata, not weights, so a metadata-only request such as allow_patterns=["*.json"] or ["*.index.json"] now reads as weightless and is accepted instead of being treated as a full weight load that demands shards on disk. 
snapshot_dir_is_complete now validates numbered shard sets even when no index sidecar was cached: a leftover model-00001-of-00002.safetensors from an interrupted multi-shard pull names its full set, so the missing siblings are detected without a remote manifest and the snapshot reads as incomplete until every shard is present. Adds unit coverage for both: index-only requests resolve weightless, and a single numbered shard without an index reads incomplete until the set is complete. --- tests/test_hf_xet_fallback.py | 29 ++++++++++++++++ unsloth_zoo/hf_cache_state.py | 63 ++++++++++++++++++++++++++++++----- 2 files changed, 84 insertions(+), 8 deletions(-) diff --git a/tests/test_hf_xet_fallback.py b/tests/test_hf_xet_fallback.py index 2e2184926..418df1605 100644 --- a/tests/test_hf_xet_fallback.py +++ b/tests/test_hf_xet_fallback.py @@ -964,6 +964,22 @@ def test_snapshot_dir_is_complete_missing_shard(tmp_path): assert hcs.snapshot_dir_is_complete(snap) is True +def test_snapshot_dir_is_complete_missing_shard_without_index(tmp_path): + """A leftover single numbered shard with NO index sidecar (an interrupted multi-shard + pull where the index was never cached) must read as incomplete: the shard name itself + states the full set, so the missing siblings are detectable without a manifest.""" + snap = tmp_path / "snap" + snap.mkdir() + blob = tmp_path / "blob" + blob.write_bytes(b"x") + (snap / "model-00001-of-00003.safetensors").symlink_to(blob) + assert hcs.snapshot_dir_is_complete(snap) is False # shards 2 and 3 missing + (snap / "model-00002-of-00003.safetensors").symlink_to(blob) + assert hcs.snapshot_dir_is_complete(snap) is False # shard 3 still missing + (snap / "model-00003-of-00003.safetensors").symlink_to(blob) + assert hcs.snapshot_dir_is_complete(snap) is True + + def test_fast_path_rejects_config_only_snapshot(hf_cache, monkeypatch): """HF's local_files_only returns a config-only snapshot (e.g. 
left by an earlier AutoConfig fetch) without checking weights. The fast path must reject it and complete @@ -1086,6 +1102,19 @@ def test_request_can_include_weights_unit(): ) is False +def test_request_can_include_weights_index_json_only(): + """A metadata-only request that matches the shard *index* sidecars but no real weight + file must read as weightless: the index is JSON, not weights, so a JSON-only warmup + (allow_patterns=['*.json'] or ['*.index.json']) should not be forced to land shards.""" + assert hcs.request_can_include_weights(["*.json"], None) is False + assert hcs.request_can_include_weights(["*.index.json"], None) is False + assert hcs.request_can_include_weights( + ["model.safetensors.index.json", "pytorch_model.bin.index.json"], None + ) is False + # A real weight pattern still counts as including weights. + assert hcs.request_can_include_weights(["*.safetensors"], None) is True + + def test_prepare_for_http_spares_active_sibling_partial(hf_cache): """The generic HTTP-prep purge must not unlink a concurrent download's still-active .incomplete temp file: only stale (old-mtime) partials are removed, so a sibling diff --git a/unsloth_zoo/hf_cache_state.py b/unsloth_zoo/hf_cache_state.py index ad8dbfd71..a0bbb4576 100644 --- a/unsloth_zoo/hf_cache_state.py +++ b/unsloth_zoo/hf_cache_state.py @@ -23,6 +23,7 @@ from __future__ import annotations +import re import sys from pathlib import Path from typing import Iterator, Optional @@ -209,6 +210,47 @@ def snapshot_dir_has_broken_symlinks(snapshot_dir: Path) -> bool: ) +# Numbered shard naming, e.g. ``model-00001-of-00002.safetensors`` or +# ``pytorch_model-00003-of-00004.bin``: prefix, 1-based index, total, suffix. 
+_NUMBERED_SHARD_RE = re.compile( + r"^(?P<prefix>.+)-(?P<index>\d+)-of-(?P<total>\d+)(?P<suffix>\.[^.]+)$" +) + + +def _numbered_shard_set_present(entry: Path) -> bool: + """For a numbered weight shard (``model-00001-of-00002.safetensors``), True only when + every shard in its ``-of-NNNNN`` set is present in the same directory. + + A leftover single shard from an interrupted multi-shard download reads as a weight + file on its own, so without this an incomplete pull (one shard on disk, the rest + never fetched) would short-circuit as a warm cache. This catches that even when the + shard *index* sidecar was never cached (so ``_weight_shard_index_complete`` has + nothing to check). A non-numbered / single-file weight matches no shard pattern and + is trivially satisfied.""" + match = _NUMBERED_SHARD_RE.match(entry.name) + if match is None: + return True + total_str = match.group("total") + try: + total = int(total_str) + except ValueError: + return True + if total <= 0: + return True + prefix = match.group("prefix") + suffix = match.group("suffix") + width = len(total_str) + base = entry.parent + for i in range(1, total + 1): + shard_name = f"{prefix}-{i:0{width}d}-of-{total_str}{suffix}" + try: + if not (base / shard_name).exists(): + return False + except OSError: + return False + return True + + def _weight_shard_index_complete(index_path: Path) -> bool: """True if every shard a HF weight index (``model.safetensors.index.json`` / ``pytorch_model.bin.index.json``) lists is present next to the index. An unreadable @@ -246,10 +288,11 @@ def snapshot_dir_is_complete(snapshot_dir: Path) -> bool: skips the killable child and lets the in-process load hit Xet on the absent weights. A snapshot is complete only when it has no dangling symlinks, every weight-shard - index it ships resolves all its shards on disk, and it contains at least one weight - file.
This does NOT assert that every non-weight file is present (no offline manifest - exists for that); the killable child completes anything else still missing. The aim - is simply to never short-circuit a snapshot whose weights are not on disk.""" + index it ships resolves all its shards on disk, every numbered shard set present has + all its members on disk (even with no index sidecar), and it contains at least one + weight file. This does NOT assert that every non-weight file is present (no offline + manifest exists for that); the killable child completes anything else still missing. + The aim is simply to never short-circuit a snapshot whose weights are not on disk.""" if snapshot_dir_has_broken_symlinks(snapshot_dir): return False try: @@ -266,6 +309,8 @@ def snapshot_dir_is_complete(snapshot_dir: Path) -> bool: return False has_weight = True elif name.endswith(_WEIGHT_FILE_SUFFIXES) and _safe_is_file(entry): + if not _numbered_shard_set_present(entry): + return False has_weight = True return has_weight @@ -287,15 +332,17 @@ def request_can_include_weights( from huggingface_hub.utils import filter_repo_objects except Exception: return True # cannot evaluate the filter -> assume weights are expected - # One canonical filename per recognized weight format (plus sharded / index variants); - # if any survives the filter, the requested set can include weights. + # One canonical filename per recognized weight format (plus sharded variants); + # if any survives the filter, the requested set can include weights. The shard + # *index* sidecars (``*.safetensors.index.json`` / ``*.bin.index.json``) are NOT + # probed: they are JSON metadata, not weights, so a metadata-only request such as + # ``allow_patterns=["*.json"]`` (or ``["*.index.json"]``) must read as weightless + # rather than being treated as a full weight load that requires shards on disk. 
probes = [ "model.safetensors", "model-00001-of-00002.safetensors", - "model.safetensors.index.json", "pytorch_model.bin", "pytorch_model-00001-of-00002.bin", - "pytorch_model.bin.index.json", "tf_model.h5", "flax_model.msgpack", "model.gguf", From cec935b8cba25bfb4de81311018226072b3d1e3e Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Fri, 26 Jun 2026 22:14:24 +0000 Subject: [PATCH 21/82] Scope single-file stall to child partials; exclude trainer state; probe path-qualified weights Single-file watchdog (hf_xet_fallback): the no-progress watchdog summed every blob for the repo, so a concurrent sibling download of a different file in the same repo kept resetting last_change with its own progress and a hung child was never killed past stall_timeout. Single-file downloads now capture the partials already on disk before spawning and follow only the child's own new partials, so a sibling's in-flight progress can no longer mask this file's stall. Snapshots keep the repo-wide measurement, since every blob there is part of the one pull. snapshot_dir_is_complete (hf_cache_state): trainer / optimizer state files (training_args.bin, optimizer.pt, scheduler.pt, rng_state.pth, rng_state_N.pth, scaler.pt) carry weight suffixes but are not loadable model weights. A checkpoint or patterned cache holding only those no longer reads as a warm model cache. request_can_include_weights (hf_cache_state): the probe list only held top-level canonical names, so a path-qualified allow_patterns such as checkpoint-10/* or models/*.safetensors, or a bare non-first shard, read as weightless and let the fast path accept a stale snapshot missing the requested weights. Canonical weight probes are now re-rooted under a concrete directory prefix, a bare concrete weight filename is probed verbatim, and a weight-targeting basename under a globbed parent directory is recognized. Config / tokenizer subfolder requests stay weightless. 
Adds watchdog scoping tests (sibling masking, repo-wide contrast, baseline-only no-fire), a trainer-artifact completeness test, and path-qualified weight-probe coverage. --- tests/test_hf_xet_fallback.py | 131 +++++++++++++++++++++++++++++++++ unsloth_zoo/hf_cache_state.py | 121 ++++++++++++++++++++++++------ unsloth_zoo/hf_xet_fallback.py | 60 ++++++++++++++- 3 files changed, 288 insertions(+), 24 deletions(-) diff --git a/tests/test_hf_xet_fallback.py b/tests/test_hf_xet_fallback.py index 418df1605..cef1941d3 100644 --- a/tests/test_hf_xet_fallback.py +++ b/tests/test_hf_xet_fallback.py @@ -155,6 +155,95 @@ def test_stall_fires_at_most_once(hf_cache): stop.set() +def test_file_watchdog_scopes_to_child_partial(hf_cache): + """A single-file download follows only its own child's partials. A concurrent sibling + download of a different file in the same repo (its partial already in flight, so in the + baseline) keeps growing, but must not keep resetting this file's stall timer -- the + constant child partial still fires.""" + blobs = _blobs_dir(hf_cache) + sibling = blobs / "sibling.incomplete" # already in flight -> captured in baseline + sibling.write_bytes(b"\0" * 1024) + baseline = {"sibling.incomplete"} + + grow_stop = threading.Event() + + def _grow(): + size = 1024 + while not grow_stop.wait(0.05): + size += 4096 + sibling.write_bytes(b"\0" * size) # healthy sibling keeps making progress + + grower = threading.Thread(target = _grow, daemon = True) + grower.start() + + # This download's child writes its own constant (stalled) partial, not in the baseline. 
+ (blobs / "child.incomplete").write_bytes(b"\0" * 2048) + + calls: list[str] = [] + stop = xf.start_watchdog( + repo_ids = [REPO], on_stall = calls.append, interval = 0.05, stall_timeout = 0.3, + watch_new_partials_only = True, baseline_incomplete_blobs = baseline, + ) + try: + assert _wait(lambda: len(calls) >= 1, timeout = 3.0), ( + "file watchdog never fired: a growing sibling partial masked the stalled child" + ) + finally: + stop.set() + grow_stop.set() + + +def test_repo_wide_watchdog_is_masked_by_sibling(hf_cache): + """Contrast for the single-file scoping: the default repo-wide measurement sums every + blob, so a growing sibling resets the timer and a constant partial never trips. This is + correct for snapshots (all blobs are one pull) and is exactly what file-scoping avoids.""" + blobs = _blobs_dir(hf_cache) + sibling = blobs / "sibling.incomplete" + sibling.write_bytes(b"\0" * 1024) + (blobs / "child.incomplete").write_bytes(b"\0" * 2048) # constant + + grow_stop = threading.Event() + + def _grow(): + size = 1024 + while not grow_stop.wait(0.05): + size += 4096 + sibling.write_bytes(b"\0" * size) + + grower = threading.Thread(target = _grow, daemon = True) + grower.start() + + calls: list[str] = [] + stop = xf.start_watchdog( # default: repo-wide (watch_new_partials_only = False) + repo_ids = [REPO], on_stall = calls.append, interval = 0.05, stall_timeout = 0.3, + ) + try: + time.sleep(1.0) # well past stall_timeout, but repo-wide bytes keep growing + assert calls == [], "repo-wide watchdog should be reset by the growing sibling" + finally: + stop.set() + grow_stop.set() + + +def test_file_watchdog_ignores_baseline_only_partials(hf_cache): + """If the only active partial is a baseline sibling's (this child has not written one + yet), the file watchdog sees no owned progress and must not fire: there is nothing of + ours to stall on, so post-spawn metadata/connect time is never misread as our stall.""" + blobs = _blobs_dir(hf_cache) + (blobs / 
"sibling.incomplete").write_bytes(b"\0" * 4096) # constant baseline sibling + + calls: list[str] = [] + stop = xf.start_watchdog( + repo_ids = [REPO], on_stall = calls.append, interval = 0.05, stall_timeout = 0.2, + watch_new_partials_only = True, baseline_incomplete_blobs = {"sibling.incomplete"}, + ) + try: + time.sleep(0.8) + assert calls == [], "file watchdog fired on a baseline sibling partial it must ignore" + finally: + stop.set() + + def test_get_state_empty_cache(hf_cache): assert xf.get_hf_download_state([REPO]) == (0, False) @@ -980,6 +1069,25 @@ def test_snapshot_dir_is_complete_missing_shard_without_index(tmp_path): assert hcs.snapshot_dir_is_complete(snap) is True +def test_snapshot_dir_is_complete_ignores_trainer_artifacts(tmp_path): + """Trainer / optimizer state files carry weight suffixes (.bin/.pt/.pth) but are not + loadable model weights. A checkpoint cache holding only those must read as incomplete + so the killable child still fetches the real weights.""" + snap = tmp_path / "snap" + snap.mkdir() + blob = tmp_path / "blob" + blob.write_bytes(b"x") + for junk in ( + "training_args.bin", "optimizer.pt", "scheduler.pt", + "rng_state.pth", "rng_state_0.pth", "scaler.pt", + ): + (snap / junk).symlink_to(blob) + (snap / "config.json").write_text("{}") + assert hcs.snapshot_dir_is_complete(snap) is False # only trainer state, no weights + (snap / "model.safetensors").symlink_to(blob) + assert hcs.snapshot_dir_is_complete(snap) is True # real weight present + + def test_fast_path_rejects_config_only_snapshot(hf_cache, monkeypatch): """HF's local_files_only returns a config-only snapshot (e.g. left by an earlier AutoConfig fetch) without checking weights. 
The fast path must reject it and complete @@ -1115,6 +1223,29 @@ def test_request_can_include_weights_index_json_only(): assert hcs.request_can_include_weights(["*.safetensors"], None) is True +def test_request_can_include_weights_path_qualified(): + """Path-qualified allow_patterns must be resolved inside their directory, and a bare + non-first shard recognized, so a subfolder / checkpoint / specific-shard weight request + is not misread as weightless (which would skip the killable child).""" + # Concrete subfolder globs: weights live under the directory. + assert hcs.request_can_include_weights(["checkpoint-10/*"], None) is True + assert hcs.request_can_include_weights(["checkpoint-10/*.safetensors"], None) is True + assert hcs.request_can_include_weights(["models/*.bin"], None) is True + # A specific (non-first) shard named verbatim. + assert hcs.request_can_include_weights(["model-00002-of-00005.safetensors"], None) is True + assert hcs.request_can_include_weights(["checkpoint-10/pytorch_model.bin"], None) is True + # Globbed parent dir, weight-targeting basename -> can include weights. + assert hcs.request_can_include_weights(["checkpoint-*/*.safetensors"], None) is True + # Subfolder requests that target only non-weight files stay weightless. + assert hcs.request_can_include_weights(["checkpoint-10/config.json"], None) is False + assert hcs.request_can_include_weights(["checkpoint-10/*.json"], None) is False + assert hcs.request_can_include_weights(["checkpoint-*/tokenizer.json"], None) is False + # The unsloth subfolder warmup shape: "/*" plus root aux files -> weights expected. 
+ assert hcs.request_can_include_weights( + ["checkpoint-10/*", "config.json", "tokenizer.json"], None + ) is True + + def test_prepare_for_http_spares_active_sibling_partial(hf_cache): """The generic HTTP-prep purge must not unlink a concurrent download's still-active .incomplete temp file: only stale (old-mtime) partials are removed, so a sibling diff --git a/unsloth_zoo/hf_cache_state.py b/unsloth_zoo/hf_cache_state.py index a0bbb4576..537581ae4 100644 --- a/unsloth_zoo/hf_cache_state.py +++ b/unsloth_zoo/hf_cache_state.py @@ -209,6 +209,38 @@ def snapshot_dir_has_broken_symlinks(snapshot_dir: Path) -> bool: ".pdparams", ) +# Trainer / optimizer state files carry weight suffixes (.bin / .pt / .pth) but are NOT +# loadable model weights. A checkpoint dir or a patterned pull can leave only these behind, +# so they must not satisfy the "snapshot holds its weights" check (which would skip the +# killable download while from_pretrained still lacks real weights). +_NON_WEIGHT_BASENAMES = frozenset({ + "training_args.bin", + "optimizer.bin", + "optimizer.pt", + "scheduler.bin", + "scheduler.pt", + "scaler.pt", + "rng_state.pt", + "rng_state.pth", +}) +# Distributed trainer runs shard the RNG state as rng_state_0.pth, rng_state_1.pth, ... +_NON_WEIGHT_BASENAME_PREFIXES = ("rng_state_",) + + +def _is_loadable_weight_file(name: str) -> bool: + """True if *name* is a loadable model-weight file: a recognized weight suffix that is + not a known trainer / optimizer state artifact (training_args.bin, optimizer.pt, + scheduler.pt, rng_state.pth, ...). Those share weight suffixes but are not model + weights, so a cache holding only them is not a warm model cache.""" + if not name.endswith(_WEIGHT_FILE_SUFFIXES): + return False + lowered = name.lower() + if lowered in _NON_WEIGHT_BASENAMES: + return False + if any(lowered.startswith(prefix) for prefix in _NON_WEIGHT_BASENAME_PREFIXES): + return False + return True + # Numbered shard naming, e.g. 
``model-00001-of-00002.safetensors`` or # ``pytorch_model-00003-of-00004.bin``: prefix, 1-based index, total, suffix. @@ -308,13 +340,51 @@ def snapshot_dir_is_complete(snapshot_dir: Path) -> bool: if not _weight_shard_index_complete(entry): return False has_weight = True - elif name.endswith(_WEIGHT_FILE_SUFFIXES) and _safe_is_file(entry): + elif _is_loadable_weight_file(name) and _safe_is_file(entry): if not _numbered_shard_set_present(entry): return False has_weight = True return has_weight +# One canonical filename per recognized weight format (plus a representative first shard); +# the probe set for "can this request include a weight file". The shard *index* sidecars +# (``*.safetensors.index.json`` / ``*.bin.index.json``) are intentionally absent: they are +# JSON metadata, not weights, so a metadata-only request such as ``allow_patterns=["*.json"]`` +# (or ``["*.index.json"]``) must read as weightless rather than as a full weight load. +_WEIGHT_PROBE_NAMES = ( + "model.safetensors", + "model-00001-of-00002.safetensors", + "pytorch_model.bin", + "pytorch_model-00001-of-00002.bin", + "tf_model.h5", + "flax_model.msgpack", + "model.gguf", + "model.pt", + "model.pth", + "model.ckpt", + "model.onnx", + "model.pdparams", +) + +_GLOB_CHARS = ("*", "?", "[") + + +def _has_glob(text: str) -> bool: + return any(ch in text for ch in _GLOB_CHARS) + + +def _pattern_basename_targets_weight(pattern: str) -> bool: + """True if *pattern*'s final path component looks like it selects a weight file: a + catch-all (``*`` / ``**``) or a name / glob ending in a recognized weight suffix. 
+ Used only when the pattern's parent directory is itself globbed, so no concrete probe + path can be formed.""" + base = pattern.rsplit("/", 1)[-1].lower() + if base in ("*", "**"): + return True + return base.endswith(_WEIGHT_FILE_SUFFIXES) + + def request_can_include_weights( allow_patterns: "Optional[list]" = None, ignore_patterns: "Optional[list]" = None ) -> bool: @@ -325,33 +395,40 @@ def request_can_include_weights( filters every weight format out (e.g. ``ignore_patterns`` covering ``*.safetensors`` and ``*.bin`` to fetch only config / tokenizer files from a model repo) legitimately yields a weightless snapshot, so requiring a weight there would reject a valid result. - An unfiltered request -- or one any weight filename survives -- includes weights.""" + An unfiltered request -- or one any weight filename survives -- includes weights. + + Path-qualified requests are handled too: ``allow_patterns`` such as + ``["checkpoint-10/*"]`` or ``["models/*.safetensors"]`` probe the canonical weight + names re-rooted under that directory, and a bare non-first shard like + ``["model-00002-of-00005.safetensors"]`` is probed verbatim, so a request that does + target weights inside a subfolder / at a specific shard is not misread as weightless.""" if not allow_patterns and not ignore_patterns: return True try: from huggingface_hub.utils import filter_repo_objects except Exception: return True # cannot evaluate the filter -> assume weights are expected - # One canonical filename per recognized weight format (plus sharded variants); - # if any survives the filter, the requested set can include weights. The shard - # *index* sidecars (``*.safetensors.index.json`` / ``*.bin.index.json``) are NOT - # probed: they are JSON metadata, not weights, so a metadata-only request such as - # ``allow_patterns=["*.json"]`` (or ``["*.index.json"]``) must read as weightless - # rather than being treated as a full weight load that requires shards on disk. 
- probes = [ - "model.safetensors", - "model-00001-of-00002.safetensors", - "pytorch_model.bin", - "pytorch_model-00001-of-00002.bin", - "tf_model.h5", - "flax_model.msgpack", - "model.gguf", - "model.pt", - "model.pth", - "model.ckpt", - "model.onnx", - "model.pdparams", - ] + + probes = list(_WEIGHT_PROBE_NAMES) + for pat in (allow_patterns or ()): + if "/" in pat: + prefix = pat.rsplit("/", 1)[0] + if _has_glob(prefix): + # Globbed parent dir (e.g. "checkpoint-*/*.safetensors"): no concrete path + # to test, so decide from the basename. Only a weight-targeting basename + # flips this on, so config/tokenizer globs under a wildcard dir stay + # weightless and are not forced into a strict weight check. + if _pattern_basename_targets_weight(pat): + return True + continue + # Concrete parent dir: re-root the canonical weight probes under it so a + # path-qualified request is checked inside that directory, not only at the root. + probes.extend(f"{prefix}/{name}" for name in _WEIGHT_PROBE_NAMES) + # A bare concrete weight filename (e.g. a specific non-first shard, or a + # subfolder-qualified weight) is itself a probe the filter can match verbatim. + if not _has_glob(pat) and pat.lower().endswith(_WEIGHT_FILE_SUFFIXES): + probes.append(pat) + try: kept = list( filter_repo_objects( diff --git a/unsloth_zoo/hf_xet_fallback.py b/unsloth_zoo/hf_xet_fallback.py index 67523b6af..a58d9fad5 100644 --- a/unsloth_zoo/hf_xet_fallback.py +++ b/unsloth_zoo/hf_xet_fallback.py @@ -222,6 +222,33 @@ def _default_prepare_for_http( logger.debug("default prepare_for_http failed for %s: %s", repo_id, e) +def _active_incomplete_blob_sizes( + repo_type: Optional[str], repo_id: str, cache_dir: Optional[str] = None +) -> dict[str, int]: + """Map ``{blob_filename: bytes_present}`` for the repo's ``*.incomplete`` partials. + + Sparse-aware (st_blocks based). 
The single-file watchdog uses this to follow only the + partials its own child created, so a concurrent sibling download of a different file in + the same repo (its partial already present when this download began) cannot mask this + file's stall by contributing its own progress. + """ + sizes: dict[str, int] = {} + try: + for entry in iter_active_repo_cache_dirs(repo_type, repo_id, cache_dir = cache_dir): + blobs_dir = entry / "blobs" + if not blobs_dir.is_dir(): + continue + for blob in blobs_dir.iterdir(): + try: + if blob.is_file() and blob.name.endswith(INCOMPLETE_SUFFIX): + sizes[blob.name] = blob_bytes_present(blob) + except OSError: + pass + except Exception: + pass + return sizes + + def get_hf_download_state( repo_ids: Optional[list[str]] = None, *, @@ -279,6 +306,8 @@ def start_watchdog( stall_timeout: float = DEFAULT_STALL_TIMEOUT, xet_disabled: bool = False, on_heartbeat: Optional[Callable[[str], None]] = None, + watch_new_partials_only: bool = False, + baseline_incomplete_blobs: Optional[set] = None, ) -> threading.Event: """Start a daemon thread that fires ``on_stall(message)`` exactly once iff a ``*.incomplete`` is present AND the on-disk size is unchanged for @@ -286,19 +315,35 @@ def start_watchdog( post-download init is never misread as a stall. Scans *cache_dir* when the download targets a caller-supplied cache, else the active ``HF_HUB_CACHE``. Returns a stop event the caller sets when the download phase ends. + + When *watch_new_partials_only* is set (single-file downloads), progress is measured + only over ``*.incomplete`` partials that were NOT present in *baseline_incomplete_blobs* + (captured before the child started) -- i.e. the child's own partials. This keeps a + concurrent sibling download of a different file in the same repo from resetting the + stall timer with its progress, which would otherwise keep a hung child alive forever. + Snapshot downloads keep the repo-wide measurement (every blob is part of the one pull). 
""" stop = threading.Event() transport = "https" if xet_disabled else "xet" fired = False + baseline = set(baseline_incomplete_blobs or ()) + single_repo_id = repo_ids[0] if repo_ids else "" + + def _measure() -> Optional[tuple[int, bool]]: + if watch_new_partials_only: + sizes = _active_incomplete_blob_sizes(repo_type, single_repo_id, cache_dir) + owned = {name: n for name, n in sizes.items() if name not in baseline} + return (sum(owned.values()), len(owned) > 0) + return get_hf_download_state(repo_ids, repo_type = repo_type, cache_dir = cache_dir) def _beat() -> None: nonlocal fired - state = get_hf_download_state(repo_ids, repo_type = repo_type, cache_dir = cache_dir) + state = _measure() last_size = state[0] if state is not None else 0 last_change = time.monotonic() while not stop.wait(interval): - state = get_hf_download_state(repo_ids, repo_type = repo_type, cache_dir = cache_dir) + state = _measure() now = time.monotonic() if state is None: @@ -493,6 +538,15 @@ def _run_download_attempt( Returns ``("ok", path)``, ``("stall", None)``, ``("cancelled", None)``, or ``("error", message)``. This is the seam tests monkeypatch to avoid spawning. """ + # A single-file download scopes its stall detection to its own child's partials. + # Capture the partials already on disk for this repo BEFORE spawning, so the watchdog + # can ignore a concurrent sibling's in-flight partial (a different file in the same + # repo) and only follow the blob(s) this child newly writes. Snapshots stay repo-wide. 
+ baseline_partials: Optional[set] = None + if kind == "file": + baseline_partials = set( + _active_incomplete_blob_sizes(repo_type, repo_id, params.get("cache_dir")) + ) result_queue: Any = _CTX.Queue() proc = _CTX.Process( target = _download_child_entry, @@ -591,6 +645,8 @@ def _run_download_attempt( stall_timeout = stall_timeout, xet_disabled = disable_xet, on_heartbeat = on_status, + watch_new_partials_only = (kind == "file"), + baseline_incomplete_blobs = baseline_partials, ) result: Optional[dict] = None From 047d8eefa0fc0dfc4048e5e02fb4203331959f62 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sat, 27 Jun 2026 02:27:08 +0000 Subject: [PATCH 22/82] Detect hung resumed single-file downloads; accept string allow_patterns Single-file watchdog (hf_xet_fallback): scoping by "partials not present before the child started" missed a hung resume. Hugging Face resumes a prior interrupted download by reusing the same blob-hash .incomplete filename, so the resumed partial stayed in the baseline, owned stayed empty, has_incomplete read false, and a stalled Xet resume was never declared stalled or retried over HTTP. The watchdog now identifies the child's own partial by the .incomplete blobs the child process actually has open (via psutil, falling back to /proc//fd), which is precise across a resumed partial that reuses a baseline filename and still excludes a concurrent sibling download's partial (held open by a different pid). The earlier baseline-name exclusion remains as a fallback only where neither psutil nor /proc is available. Snapshots keep the repo-wide measurement. request_can_include_weights (hf_cache_state): allow_patterns / ignore_patterns accept the bare-string form Hugging Face also accepts. Iterating a string walked it character by character, so "checkpoint-10/*" added no probes and a subfolder weight request was misclassified as weightless, letting the fast path accept a config-only cached snapshot. The arguments are now normalized to a list first. 
Adds tests for a hung resumed-partial stall, pid-scoped exclusion of an unowned growing sibling, and the string allow/ignore pattern forms. --- tests/test_hf_xet_fallback.py | 99 ++++++++++++++++++++++++++++++++++ unsloth_zoo/hf_cache_state.py | 18 ++++++- unsloth_zoo/hf_xet_fallback.py | 63 +++++++++++++++++++--- 3 files changed, 173 insertions(+), 7 deletions(-) diff --git a/tests/test_hf_xet_fallback.py b/tests/test_hf_xet_fallback.py index cef1941d3..eb7ac5409 100644 --- a/tests/test_hf_xet_fallback.py +++ b/tests/test_hf_xet_fallback.py @@ -244,6 +244,90 @@ def test_file_watchdog_ignores_baseline_only_partials(hf_cache): stop.set() +def _spawn_holding_open(path: Path) -> "subprocess.Popen": + """A real child process that opens *path* and holds it open without writing, modelling a + hung download. Prints 'ok' once the file is open so the caller can synchronize.""" + code = ( + "import sys, time\n" + "f = open(sys.argv[1], 'r+b')\n" + "sys.stdout.write('ok'); sys.stdout.flush()\n" + "time.sleep(30)\n" + ) + proc = subprocess.Popen( + [sys.executable, "-c", code, str(path)], stdout = subprocess.PIPE + ) + assert proc.stdout.read(2) == b"ok" # wait until the child holds the file open + return proc + + +def test_file_watchdog_detects_resumed_baseline_partial(hf_cache): + """A resumed single-file download reuses the prior blob-hash .incomplete, so it sits in + the baseline. 
Name-based exclusion would never flag a hung resume; scoping to the + partials the child process holds open detects it.""" + blobs = _blobs_dir(hf_cache) + partial = blobs / "resumed.incomplete" + partial.write_bytes(b"\0" * 4096) # leftover from a prior interrupted download + baseline = {"resumed.incomplete"} # present before the (resuming) child starts + + child = _spawn_holding_open(partial) # hung resume: holds it open, never grows it + try: + calls: list[str] = [] + stop = xf.start_watchdog( + repo_ids = [REPO], on_stall = calls.append, interval = 0.05, stall_timeout = 0.3, + watch_new_partials_only = True, baseline_incomplete_blobs = baseline, + child_pid = child.pid, + ) + try: + assert _wait(lambda: len(calls) >= 1, timeout = 3.0), ( + "watchdog did not fire on a hung resume of a baseline partial" + ) + finally: + stop.set() + finally: + child.terminate() + child.wait(timeout = 5) + + +def test_file_watchdog_pid_scope_ignores_unowned_sibling(hf_cache): + """With pid scoping, a sibling partial this child does NOT hold open is ignored even if + it grows, so the child's own constant partial still trips the stall.""" + blobs = _blobs_dir(hf_cache) + owned_partial = blobs / "owned.incomplete" + owned_partial.write_bytes(b"\0" * 2048) # the child holds this open, constant (hung) + sibling = blobs / "sibling.incomplete" + sibling.write_bytes(b"\0" * 1024) + + grow_stop = threading.Event() + + def _grow(): + size = 1024 + while not grow_stop.wait(0.05): + size += 4096 + sibling.write_bytes(b"\0" * size) # an unrelated sibling making progress + + grower = threading.Thread(target = _grow, daemon = True) + grower.start() + + child = _spawn_holding_open(owned_partial) + try: + calls: list[str] = [] + stop = xf.start_watchdog( + repo_ids = [REPO], on_stall = calls.append, interval = 0.05, stall_timeout = 0.3, + watch_new_partials_only = True, baseline_incomplete_blobs = set(), + child_pid = child.pid, + ) + try: + assert _wait(lambda: len(calls) >= 1, timeout = 3.0), ( 
+ "pid-scoped watchdog never fired: an unowned growing sibling masked the stall" + ) + finally: + stop.set() + finally: + grow_stop.set() + child.terminate() + child.wait(timeout = 5) + + def test_get_state_empty_cache(hf_cache): assert xf.get_hf_download_state([REPO]) == (0, False) @@ -1246,6 +1330,21 @@ def test_request_can_include_weights_path_qualified(): ) is True +def test_request_can_include_weights_string_form(): + """Hugging Face accepts allow / ignore patterns as a bare string; it must be treated as + one pattern, not iterated character by character (which would misclassify a subfolder + weight request as weightless).""" + assert hcs.request_can_include_weights("checkpoint-10/*", None) is True + assert hcs.request_can_include_weights("*.safetensors", None) is True + assert hcs.request_can_include_weights("config.json", None) is False + assert hcs.request_can_include_weights(None, "*.safetensors") is True # ignore as str + # A string ignore that drops every weight format leaves the request weightless. + assert hcs.request_can_include_weights( + "config.json", ["*.safetensors", "*.bin", "*.pt", "*.pth", "*.gguf", + "*.h5", "*.msgpack", "*.ckpt", "*.onnx", "*.pdparams"] + ) is False + + def test_prepare_for_http_spares_active_sibling_partial(hf_cache): """The generic HTTP-prep purge must not unlink a concurrent download's still-active .incomplete temp file: only stale (old-mtime) partials are removed, so a sibling diff --git a/unsloth_zoo/hf_cache_state.py b/unsloth_zoo/hf_cache_state.py index 537581ae4..4f1e74d37 100644 --- a/unsloth_zoo/hf_cache_state.py +++ b/unsloth_zoo/hf_cache_state.py @@ -374,6 +374,17 @@ def _has_glob(text: str) -> bool: return any(ch in text for ch in _GLOB_CHARS) +def _as_pattern_list(patterns: "Optional[object]") -> "Optional[list]": + """Normalize an allow / ignore pattern argument to a list. 
Hugging Face accepts a bare + ``str`` as well as a list, and iterating the ``str`` form would walk it character by + character (so ``"checkpoint-10/*"`` would never match), misclassifying the request.""" + if patterns is None: + return None + if isinstance(patterns, str): + return [patterns] + return list(patterns) + + def _pattern_basename_targets_weight(pattern: str) -> bool: """True if *pattern*'s final path component looks like it selects a weight file: a catch-all (``*`` / ``**``) or a name / glob ending in a recognized weight suffix. @@ -401,7 +412,12 @@ def request_can_include_weights( ``["checkpoint-10/*"]`` or ``["models/*.safetensors"]`` probe the canonical weight names re-rooted under that directory, and a bare non-first shard like ``["model-00002-of-00005.safetensors"]`` is probed verbatim, so a request that does - target weights inside a subfolder / at a specific shard is not misread as weightless.""" + target weights inside a subfolder / at a specific shard is not misread as weightless. + + *allow_patterns* / *ignore_patterns* accept the ``str`` or ``list[str]`` forms that + Hugging Face itself accepts.""" + allow_patterns = _as_pattern_list(allow_patterns) + ignore_patterns = _as_pattern_list(ignore_patterns) if not allow_patterns and not ignore_patterns: return True try: diff --git a/unsloth_zoo/hf_xet_fallback.py b/unsloth_zoo/hf_xet_fallback.py index a58d9fad5..a1c7d794e 100644 --- a/unsloth_zoo/hf_xet_fallback.py +++ b/unsloth_zoo/hf_xet_fallback.py @@ -249,6 +249,46 @@ def _active_incomplete_blob_sizes( return sizes +def _child_open_incomplete_blobs(pid: int) -> Optional[set]: + """Basenames of the ``*.incomplete`` blob files the download child *pid* currently has + open. 
+ + This pinpoints exactly the partial THIS child is writing -- including a resumed prior + partial that reuses the same blob-hash filename (which Hugging Face does on a retry), so + a hung resume is still detected -- without confusing it for a concurrent sibling + download's partial (held open by a different pid). Returns ``None`` when it cannot be + determined (no ``psutil`` and no ``/proc``, or the process is gone), so the caller falls + back to a coarser measure; an empty set means the child is running but not yet writing a + partial (connect / metadata phase). + """ + # Cross-platform (Linux / macOS / Windows) when psutil is available. + try: + import psutil # type: ignore + except ImportError: + psutil = None # type: ignore + if psutil is not None: + try: + files = psutil.Process(pid).open_files() + except Exception: + return None + return {os.path.basename(f.path) for f in files if f.path.endswith(INCOMPLETE_SUFFIX)} + # Linux fallback: read the open fds directly from /proc. + fd_dir = f"/proc/{pid}/fd" + try: + entries = os.listdir(fd_dir) + except OSError: + return None # no /proc (non-Linux) or the process is already gone + open_blobs: set = set() + for fd in entries: + try: + target = os.readlink(os.path.join(fd_dir, fd)) + except OSError: + continue + if target.endswith(INCOMPLETE_SUFFIX): + open_blobs.add(os.path.basename(target)) + return open_blobs + + def get_hf_download_state( repo_ids: Optional[list[str]] = None, *, @@ -308,6 +348,7 @@ def start_watchdog( on_heartbeat: Optional[Callable[[str], None]] = None, watch_new_partials_only: bool = False, baseline_incomplete_blobs: Optional[set] = None, + child_pid: Optional[int] = None, ) -> threading.Event: """Start a daemon thread that fires ``on_stall(message)`` exactly once iff a ``*.incomplete`` is present AND the on-disk size is unchanged for @@ -316,11 +357,13 @@ def start_watchdog( download targets a caller-supplied cache, else the active ``HF_HUB_CACHE``. 
Returns a stop event the caller sets when the download phase ends. - When *watch_new_partials_only* is set (single-file downloads), progress is measured - only over ``*.incomplete`` partials that were NOT present in *baseline_incomplete_blobs* - (captured before the child started) -- i.e. the child's own partials. This keeps a - concurrent sibling download of a different file in the same repo from resetting the - stall timer with its progress, which would otherwise keep a hung child alive forever. + When *watch_new_partials_only* is set (single-file downloads), progress is measured only + over the child's own partial, so a concurrent sibling download of a different file in the + same repo cannot reset the stall timer with its progress (which would keep a hung child + alive forever). The child's partial is identified, in order of preference, by the + ``*.incomplete`` blobs the *child_pid* process actually has open (precise across a + resumed download that reuses a prior blob-hash filename), else by the partials that did + NOT already exist in *baseline_incomplete_blobs* (captured before the child started). Snapshot downloads keep the repo-wide measurement (every blob is part of the one pull). """ stop = threading.Event() @@ -332,7 +375,14 @@ def start_watchdog( def _measure() -> Optional[tuple[int, bool]]: if watch_new_partials_only: sizes = _active_incomplete_blob_sizes(repo_type, single_repo_id, cache_dir) - owned = {name: n for name, n in sizes.items() if name not in baseline} + open_names = _child_open_incomplete_blobs(child_pid) if child_pid else None + if open_names is not None: + # Precise: only the partials this child holds open (handles a resumed + # partial that reuses a baseline blob-hash name, and excludes siblings). + owned = {name: n for name, n in sizes.items() if name in open_names} + else: + # Fallback (no psutil / no /proc): follow only newly-created partials. 
+ owned = {name: n for name, n in sizes.items() if name not in baseline} return (sum(owned.values()), len(owned) > 0) return get_hf_download_state(repo_ids, repo_type = repo_type, cache_dir = cache_dir) @@ -647,6 +697,7 @@ def _run_download_attempt( on_heartbeat = on_status, watch_new_partials_only = (kind == "file"), baseline_incomplete_blobs = baseline_partials, + child_pid = proc.pid, ) result: Optional[dict] = None From e2eba7c485fbfc3a2addd35afe3b603438db12fd Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sat, 27 Jun 2026 02:52:28 +0000 Subject: [PATCH 23/82] Probe adapter / consolidated weight globs; bump version for the helper request_can_include_weights (hf_cache_state): a weight-selecting basename glob whose stem is not the canonical "model" -- adapter_model.* (PEFT), consolidated.* (original / consolidated checkpoints), diffusion_pytorch_model.* (diffusers) -- matched none of the canonical probes and was misread as weightless, letting the fast path accept a stale snapshot missing the requested adapter / consolidated weights instead of spawning the guarded download. The probe set now carries one representative name per naming convention, so those globs resolve to a probe. Bumps __version__ to 2026.6.8: this is the unsloth_zoo release that first ships the hf_xet_fallback / hf_cache_state helpers, so a distinct version exists for unsloth's dependency pin to require (an env on the prior 2026.6.7 lacks them). Adds coverage for adapter / consolidated / diffusers weight globs and a non-weight basename glob staying weightless. 
--- tests/test_hf_xet_fallback.py | 14 ++++++++++++++ unsloth_zoo/__init__.py | 2 +- unsloth_zoo/hf_cache_state.py | 18 +++++++++++++----- 3 files changed, 28 insertions(+), 6 deletions(-) diff --git a/tests/test_hf_xet_fallback.py b/tests/test_hf_xet_fallback.py index eb7ac5409..45b2185f5 100644 --- a/tests/test_hf_xet_fallback.py +++ b/tests/test_hf_xet_fallback.py @@ -1330,6 +1330,20 @@ def test_request_can_include_weights_path_qualified(): ) is True +def test_request_can_include_weights_weight_selecting_globs(): + """Weight-selecting basename globs whose stem is not the canonical 'model' -- PEFT + adapters, consolidated / original checkpoints, diffusers -- must read as including + weights, so a stale snapshot missing them is not accepted on the weightless path.""" + assert hcs.request_can_include_weights(["adapter_model.*"], None) is True + assert hcs.request_can_include_weights(["adapter_model.safetensors"], None) is True + assert hcs.request_can_include_weights(["consolidated.*"], None) is True + assert hcs.request_can_include_weights(["consolidated.00.pth"], None) is True + assert hcs.request_can_include_weights(["diffusion_pytorch_model.*"], None) is True + assert hcs.request_can_include_weights(["adapter*.safetensors"], None) is True + # A non-weight basename glob stays weightless. + assert hcs.request_can_include_weights(["tokenizer.*"], None) is False + + def test_request_can_include_weights_string_form(): """Hugging Face accepts allow / ignore patterns as a bare string; it must be treated as one pattern, not iterated character by character (which would misclassify a subfolder diff --git a/unsloth_zoo/__init__.py b/unsloth_zoo/__init__.py index 6c0f1d93e..364c3afe3 100644 --- a/unsloth_zoo/__init__.py +++ b/unsloth_zoo/__init__.py @@ -14,7 +14,7 @@ # You should have received a copy of the GNU Lesser General Public License # along with this program. If not, see . 
-__version__ = "2026.6.7" +__version__ = "2026.6.8" import os import warnings diff --git a/unsloth_zoo/hf_cache_state.py b/unsloth_zoo/hf_cache_state.py index 4f1e74d37..f5b4cc91e 100644 --- a/unsloth_zoo/hf_cache_state.py +++ b/unsloth_zoo/hf_cache_state.py @@ -347,16 +347,24 @@ def snapshot_dir_is_complete(snapshot_dir: Path) -> bool: return has_weight -# One canonical filename per recognized weight format (plus a representative first shard); -# the probe set for "can this request include a weight file". The shard *index* sidecars -# (``*.safetensors.index.json`` / ``*.bin.index.json``) are intentionally absent: they are -# JSON metadata, not weights, so a metadata-only request such as ``allow_patterns=["*.json"]`` -# (or ``["*.index.json"]``) must read as weightless rather than as a full weight load. +# Representative loadable-weight filenames -- the probe set for "can this request include a +# weight file". One per recognized format and naming convention (full model, sharded, PEFT +# adapter, consolidated / original checkpoint, diffusers), so a weight-selecting glob like +# ``adapter_model.*`` or ``consolidated.*`` matches a probe and is not misread as weightless. +# The shard *index* sidecars (``*.safetensors.index.json`` / ``*.bin.index.json``) are +# intentionally absent: they are JSON metadata, not weights, so a metadata-only request such +# as ``allow_patterns=["*.json"]`` (or ``["*.index.json"]``) must read as weightless. 
_WEIGHT_PROBE_NAMES = ( "model.safetensors", "model-00001-of-00002.safetensors", "pytorch_model.bin", "pytorch_model-00001-of-00002.bin", + "adapter_model.safetensors", + "adapter_model.bin", + "consolidated.00.pth", + "consolidated.safetensors", + "diffusion_pytorch_model.safetensors", + "diffusion_pytorch_model.bin", "tf_model.h5", "flax_model.msgpack", "model.gguf", From 5d72351f0ea6c26e82954903ec6b67ea73723bb0 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sat, 27 Jun 2026 05:54:23 +0000 Subject: [PATCH 24/82] Validate the requested weight; treat no-slash directory globs as weight requests snapshot_dir_is_complete now takes allow_patterns / ignore_patterns and requires a weight the request actually selects, not just any weight on disk. Before, a request for a specific subset (adapter_model.safetensors, or a checkpoint shard) was satisfied by an unrelated weight the snapshot already carried, so the fast path accepted a stale snapshot missing the requested adapter / checkpoint weights and skipped the killable child. _snapshot_is_acceptable threads the patterns through, so the completeness check matches the cached weight files against them. request_can_include_weights now recognizes a no-slash directory glob such as checkpoint-* or global_step*: Hugging Face's fnmatch "*" spans "/", so these match nested weights (checkpoint-10/model.safetensors) yet only re-rooted probes for slash-containing patterns existed, leaving such a request classified as weightless and a clean stale snapshot acceptable. The canonical probes are now also re-rooted under a concretized form of a no-slash, no-extension glob. A no-slash file glob that carries an extension (tokenizer.*, *.json) stays weightless. Adds coverage for the requested-weight check (root and subfolder) and the no-slash directory-glob classification. 
--- tests/test_hf_xet_fallback.py | 54 +++++++++++++++++++++ unsloth_zoo/hf_cache_state.py | 88 +++++++++++++++++++++++++++++----- unsloth_zoo/hf_xet_fallback.py | 10 +++- 3 files changed, 139 insertions(+), 13 deletions(-) diff --git a/tests/test_hf_xet_fallback.py b/tests/test_hf_xet_fallback.py index 45b2185f5..7ec98fde9 100644 --- a/tests/test_hf_xet_fallback.py +++ b/tests/test_hf_xet_fallback.py @@ -1172,6 +1172,41 @@ def test_snapshot_dir_is_complete_ignores_trainer_artifacts(tmp_path): assert hcs.snapshot_dir_is_complete(snap) is True # real weight present +def test_snapshot_dir_is_complete_requires_the_requested_weight(tmp_path): + """A patterned request is satisfied only by a weight it actually selects: a snapshot that + carries some other weight but not the requested one (e.g. adapter / checkpoint) reads as + incomplete, so the guarded download still runs.""" + snap = tmp_path / "snap" + snap.mkdir() + blob = tmp_path / "blob" + blob.write_bytes(b"x") + (snap / "model.safetensors").symlink_to(blob) # base weight only + # Requesting the adapter weight: the base weight does not satisfy it. + assert hcs.snapshot_dir_is_complete( + snap, allow_patterns = ["adapter_model.safetensors"] + ) is False + # No patterns: any loadable weight suffices. + assert hcs.snapshot_dir_is_complete(snap) is True + # Once the requested adapter weight is present, the request is satisfied. 
+ (snap / "adapter_model.safetensors").symlink_to(blob) + assert hcs.snapshot_dir_is_complete( + snap, allow_patterns = ["adapter_model.safetensors"] + ) is True + + +def test_snapshot_dir_is_complete_requires_requested_subfolder_weight(tmp_path): + """A subfolder request is satisfied only by a weight under that subfolder, not by a + root-level weight the snapshot also carries.""" + snap = tmp_path / "snap" + (snap / "checkpoint-10").mkdir(parents = True) + blob = tmp_path / "blob" + blob.write_bytes(b"x") + (snap / "model.safetensors").symlink_to(blob) # root weight only + assert hcs.snapshot_dir_is_complete(snap, allow_patterns = ["checkpoint-10/*"]) is False + (snap / "checkpoint-10" / "model.safetensors").symlink_to(blob) + assert hcs.snapshot_dir_is_complete(snap, allow_patterns = ["checkpoint-10/*"]) is True + + def test_fast_path_rejects_config_only_snapshot(hf_cache, monkeypatch): """HF's local_files_only returns a config-only snapshot (e.g. left by an earlier AutoConfig fetch) without checking weights. The fast path must reject it and complete @@ -1330,6 +1365,25 @@ def test_request_can_include_weights_path_qualified(): ) is True +def test_request_can_include_weights_no_slash_dir_glob(): + """A no-slash directory glob (checkpoint-*, global_step*) matches nested weights via HF's + fnmatch '*'-spans-'/' rule, so it must read as weight-including; a no-slash file glob with + an extension (tokenizer.*, *.json) stays weightless.""" + assert hcs.request_can_include_weights(["checkpoint-*"], None) is True + assert hcs.request_can_include_weights(["epoch-*"], None) is True + assert hcs.request_can_include_weights(["global_step*"], None) is True + assert hcs.request_can_include_weights(["*"], None) is True + # ignore_patterns that drop every weight format still wins over the dir glob. 
+ assert hcs.request_can_include_weights( + ["checkpoint-*"], + ["*.safetensors", "*.bin", "*.pt", "*.pth", "*.gguf", + "*.h5", "*.msgpack", "*.ckpt", "*.onnx", "*.pdparams"], + ) is False + # File globs with an extension are not directory globs. + assert hcs.request_can_include_weights(["tokenizer.*"], None) is False + assert hcs.request_can_include_weights(["*.json"], None) is False + + def test_request_can_include_weights_weight_selecting_globs(): """Weight-selecting basename globs whose stem is not the canonical 'model' -- PEFT adapters, consolidated / original checkpoints, diffusers -- must read as including diff --git a/unsloth_zoo/hf_cache_state.py b/unsloth_zoo/hf_cache_state.py index f5b4cc91e..ca596fb4e 100644 --- a/unsloth_zoo/hf_cache_state.py +++ b/unsloth_zoo/hf_cache_state.py @@ -309,8 +309,13 @@ def _weight_shard_index_complete(index_path: Path) -> bool: return True -def snapshot_dir_is_complete(snapshot_dir: Path) -> bool: - """Best-effort check that a cached snapshot actually holds its model weights. +def snapshot_dir_is_complete( + snapshot_dir: Path, + *, + allow_patterns: "Optional[object]" = None, + ignore_patterns: "Optional[object]" = None, +) -> bool: + """Best-effort check that a cached snapshot actually holds the requested model weights. ``snapshot_download(local_files_only=True)`` returns a snapshot dir whenever ``refs/`` and ``snapshots/`` exist, even one left by a prior interrupted @@ -324,14 +329,19 @@ def snapshot_dir_is_complete(snapshot_dir: Path) -> bool: all its members on disk (even with no index sidecar), and it contains at least one weight file. This does NOT assert that every non-weight file is present (no offline manifest exists for that); the killable child completes anything else still missing. - The aim is simply to never short-circuit a snapshot whose weights are not on disk.""" + The aim is simply to never short-circuit a snapshot whose weights are not on disk. 
+ + When *allow_patterns* / *ignore_patterns* are given, the weight that must be present is + one the request actually selects: a request for ``adapter_model.safetensors`` (or a + specific checkpoint shard) is satisfied only by that weight on disk, not by some other + weight the snapshot happens to also carry. With no patterns, any loadable weight does.""" if snapshot_dir_has_broken_symlinks(snapshot_dir): return False try: entries = list(snapshot_dir.rglob("*")) except OSError: return False - has_weight = False + weight_rels: list = [] for entry in entries: name = entry.name if name.endswith((".safetensors.index.json", ".bin.index.json")): @@ -339,12 +349,32 @@ def snapshot_dir_is_complete(snapshot_dir: Path) -> bool: continue if not _weight_shard_index_complete(entry): return False - has_weight = True elif _is_loadable_weight_file(name) and _safe_is_file(entry): if not _numbered_shard_set_present(entry): return False - has_weight = True - return has_weight + try: + weight_rels.append(entry.relative_to(snapshot_dir).as_posix()) + except ValueError: + weight_rels.append(name) + if not weight_rels: + return False + allow_patterns = _as_pattern_list(allow_patterns) + ignore_patterns = _as_pattern_list(ignore_patterns) + if not allow_patterns and not ignore_patterns: + return True + # A patterned request must find a weight it actually selects on disk, not just any + # weight: the snapshot can carry an unrelated weight while the requested one is missing. 
+ try: + from huggingface_hub.utils import filter_repo_objects + + matched = list( + filter_repo_objects( + weight_rels, allow_patterns = allow_patterns, ignore_patterns = ignore_patterns + ) + ) + except Exception: + return True # cannot evaluate the filter -> do not reject a snapshot that has weights + return len(matched) > 0 # Representative loadable-weight filenames -- the probe set for "can this request include a @@ -404,6 +434,29 @@ def _pattern_basename_targets_weight(pattern: str) -> bool: return base.endswith(_WEIGHT_FILE_SUFFIXES) +def _concretize_glob(pattern: str) -> str: + """Replace glob wildcards in *pattern* with a literal filler so it can stand in as a + concrete directory name (e.g. ``checkpoint-*`` -> ``checkpoint-x``). A ``[...]`` class + collapses to one filler char. Used to probe weights nested under a no-slash directory + glob, since Hugging Face's ``fnmatch`` ``*`` spans ``/``.""" + out = [] + i = 0 + n = len(pattern) + while i < n: + ch = pattern[i] + if ch in ("*", "?"): + out.append("x") + i += 1 + elif ch == "[": + j = pattern.find("]", i + 1) + out.append("x") + i = (j + 1) if j != -1 else (i + 1) + else: + out.append(ch) + i += 1 + return "".join(out) + + def request_can_include_weights( allow_patterns: "Optional[list]" = None, ignore_patterns: "Optional[list]" = None ) -> bool: @@ -448,10 +501,23 @@ def request_can_include_weights( # Concrete parent dir: re-root the canonical weight probes under it so a # path-qualified request is checked inside that directory, not only at the root. probes.extend(f"{prefix}/{name}" for name in _WEIGHT_PROBE_NAMES) - # A bare concrete weight filename (e.g. a specific non-first shard, or a - # subfolder-qualified weight) is itself a probe the filter can match verbatim. - if not _has_glob(pat) and pat.lower().endswith(_WEIGHT_FILE_SUFFIXES): - probes.append(pat) + # A subfolder-qualified concrete weight filename is also a probe verbatim. 
+ if not _has_glob(pat) and pat.lower().endswith(_WEIGHT_FILE_SUFFIXES): + probes.append(pat) + elif not _has_glob(pat): + # A bare concrete weight filename (e.g. a specific non-first shard). + if pat.lower().endswith(_WEIGHT_FILE_SUFFIXES): + probes.append(pat) + elif "." not in pat: + # A no-slash directory glob (e.g. "checkpoint-*"). HF's fnmatch "*" spans "/", so + # it matches nested weights like checkpoint-10/model.safetensors. Probe the + # canonical weights re-rooted under a concretized form of the glob, so the request + # is recognized as weight-including (still subject to ignore_patterns). A no-slash + # glob that carries an extension (e.g. "*.json", "adapter_model.*") is a file glob + # handled by the canonical / representative probes, so it is excluded here -- which + # keeps "tokenizer.*" / "*.json" correctly weightless. + concrete = _concretize_glob(pat) + probes.extend(f"{concrete}/{name}" for name in _WEIGHT_PROBE_NAMES) try: kept = list( diff --git a/unsloth_zoo/hf_xet_fallback.py b/unsloth_zoo/hf_xet_fallback.py index a1c7d794e..fa635f201 100644 --- a/unsloth_zoo/hf_xet_fallback.py +++ b/unsloth_zoo/hf_xet_fallback.py @@ -761,9 +761,15 @@ def _snapshot_is_acceptable( A PATTERNED or non-model snapshot that legitimately holds only a subset -- a dataset, or a model repo fetched with ``allow_patterns=["config.json"]`` or ``ignore_patterns`` that drop all weights -- would be wrongly rejected by a weight requirement, so for those it is - enough that no symlink dangles (every file the snapshot references is on disk).""" + enough that no symlink dangles (every file the snapshot references is on disk). + + The completeness check is scoped to the requested patterns, so a request for a specific + weight (e.g. 
``allow_patterns=["adapter_model.safetensors"]`` or a checkpoint shard) is + satisfied only when THAT weight is on disk, not by some other weight already cached.""" if repo_type == "model" and request_can_include_weights(allow_patterns, ignore_patterns): - return snapshot_dir_is_complete(snapshot_dir) + return snapshot_dir_is_complete( + snapshot_dir, allow_patterns = allow_patterns, ignore_patterns = ignore_patterns + ) return not snapshot_dir_has_broken_symlinks(snapshot_dir) From 7bb19ab09b3cef513d2f363390bb1b4945bd887c Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sat, 27 Jun 2026 07:03:43 +0000 Subject: [PATCH 25/82] Harden weight detection for wildcard parents, shard subsets, and silent crashes request_can_include_weights: - Wildcard-parent weight globs (checkpoint-*/adapter_model.*, */model.*) now read as weight-including. The globbed-parent branch no longer early-returns: it re-roots the canonical weight probes under a concretized parent and lets the final filter decide, so a weight-targeting basename is recognized AND ignore_patterns are applied (an early return skipped the ignores, e.g. allow=["checkpoint-*/*"] with a weight-dropping ignore). - Dotted no-slash directory globs whose stem names a checkpoint dir (checkpoint-v1.*, global_step100.*) are recognized as weight-including, while file globs sharing the .* shape (tokenizer.*, *.json) stay weightless, so a tokenizer-only request is not forced into a strict weight check it can never satisfy. snapshot_dir_is_complete now scopes the shard-set check to the request: a deliberate single-shard request (allow_patterns=["model-00002-of-00005.safetensors"]) is satisfied by that one shard and is no longer rejected for missing the rest. The full -of-NNNNN set and the shard index are enforced only for an unpatterned full warm. 
hf_xet_fallback: a child that exits without enqueuing a result (a process-level crash such as a native hf_xet abort, with no captured Hub exception) is reported as "crashed" and retried over HTTP, rather than raised immediately as a deterministic error -- the other transport may still succeed. Adds coverage for each: wildcard-parent / dotted-checkpoint globs, ignore application under a wildcard parent, single-shard completeness, and the crash-then-HTTP-retry path. --- tests/test_hf_xet_fallback.py | 92 ++++++++++++++++++ unsloth_zoo/hf_cache_state.py | 169 ++++++++++++++++++++++----------- unsloth_zoo/hf_xet_fallback.py | 21 +++- 3 files changed, 225 insertions(+), 57 deletions(-) diff --git a/tests/test_hf_xet_fallback.py b/tests/test_hf_xet_fallback.py index 7ec98fde9..76ac214c4 100644 --- a/tests/test_hf_xet_fallback.py +++ b/tests/test_hf_xet_fallback.py @@ -704,6 +704,24 @@ def test_nonstall_error_propagates_without_fallback(monkeypatch): assert fake.calls[0].disable_xet is False +def test_crashed_child_retries_over_http(monkeypatch): + """A silent process-level crash (child exits without a result, e.g. 
a native hf_xet + abort) is not a deterministic error, so it retries over HTTP; a clean second result is + accepted.""" + fake = _install(monkeypatch, [("crashed", "exited without a result"), ("ok", "/cache/x")]) + out = xf.hf_hub_download_with_xet_fallback(DL_REPO, FILE, None) + assert out == "/cache/x" + assert [c.disable_xet for c in fake.calls] == [False, True] + + +def test_crashed_child_on_both_transports_raises(monkeypatch): + """If the child crashes on Xet AND on HTTP, surface a hard error after both attempts.""" + fake = _install(monkeypatch, [("crashed", "boom"), ("crashed", "boom")]) + with pytest.raises(RuntimeError, match = "boom"): + xf.hf_hub_download_with_xet_fallback(DL_REPO, FILE, None) + assert [c.disable_xet for c in fake.calls] == [False, True] + + def test_immediate_success_uses_xet_only(monkeypatch): prepared = [] monkeypatch.setattr(xf, "_default_prepare_for_http", lambda *a, **k: prepared.append(a)) @@ -871,6 +889,40 @@ def test_xet_attempt_does_not_force_disable_before_spawn(monkeypatch): assert rec["disable_xet"] is None +class _EmptyQueue: + def get(self, timeout = None): + import queue as _queue + raise _queue.Empty + + def get_nowait(self): + import queue as _queue + raise _queue.Empty + + def put(self, item): + pass + + +def test_run_attempt_no_result_is_crashed(monkeypatch): + """A child that exits without enqueuing a result maps to 'crashed' (a process-level + crash that HTTP may still recover), not a deterministic 'error'.""" + rec: dict = {} + + class _Ctx: + def Process(self, *, target = None, kwargs = None, daemon = None): + return _FakeProc(rec) + + def Queue(self): + return _EmptyQueue() + + monkeypatch.setattr(xf, "_CTX", _Ctx()) + kind_result, _ = xf._run_download_attempt( + DL_REPO, kind = "snapshot", params = {"repo_id": DL_REPO}, token = None, + repo_type = "model", disable_xet = False, cancel_event = None, + stall_timeout = 0.2, interval = 0.05, grace_period = 0.2, on_status = None, + ) + assert kind_result == "crashed" 
+ + def test_child_skips_gpu_init_env_set_before_spawn_and_restored(monkeypatch): """The download child inherits UNSLOTH_ZOO_DISABLE_GPU_INIT=1 at spawn (so its fresh unsloth_zoo import skips heavy torch/transformers init), and the parent's @@ -1207,6 +1259,22 @@ def test_snapshot_dir_is_complete_requires_requested_subfolder_weight(tmp_path): assert hcs.snapshot_dir_is_complete(snap, allow_patterns = ["checkpoint-10/*"]) is True +def test_snapshot_dir_is_complete_single_shard_request(tmp_path): + """A deliberate single-shard request is satisfied by that one shard; the full -of-NNNNN + set is required only for an unpatterned full warm.""" + snap = tmp_path / "snap" + snap.mkdir() + blob = tmp_path / "blob" + blob.write_bytes(b"x") + (snap / "model-00002-of-00005.safetensors").symlink_to(blob) + # Just the requested shard present -> complete for that request. + assert hcs.snapshot_dir_is_complete( + snap, allow_patterns = ["model-00002-of-00005.safetensors"] + ) is True + # An unpatterned full warm requires the whole set -> incomplete (4 shards missing). + assert hcs.snapshot_dir_is_complete(snap) is False + + def test_fast_path_rejects_config_only_snapshot(hf_cache, monkeypatch): """HF's local_files_only returns a config-only snapshot (e.g. left by an earlier AutoConfig fetch) without checking weights. The fast path must reject it and complete @@ -1382,6 +1450,30 @@ def test_request_can_include_weights_no_slash_dir_glob(): # File globs with an extension are not directory globs. assert hcs.request_can_include_weights(["tokenizer.*"], None) is False assert hcs.request_can_include_weights(["*.json"], None) is False + # A dotted no-slash glob whose stem names a checkpoint DIRECTORY still includes weights. 
+ assert hcs.request_can_include_weights(["checkpoint-v1.*"], None) is True + assert hcs.request_can_include_weights(["global_step100.*"], None) is True + + +def test_request_can_include_weights_wildcard_parent(): + """A wildcard parent dir with a weight basename glob (checkpoint-*/adapter_model.*, + */model.*) must read as weight-including, and ignore_patterns must still be applied to a + wildcard-parent request rather than bypassed by an early return.""" + assert hcs.request_can_include_weights(["checkpoint-*/adapter_model.*"], None) is True + assert hcs.request_can_include_weights(["*/model.*"], None) is True + assert hcs.request_can_include_weights(["checkpoint-*/*.safetensors"], None) is True + # ignore_patterns applies under a wildcard parent: dropping every weight format -> weightless. + assert hcs.request_can_include_weights( + ["checkpoint-*/*"], + ["*.safetensors", "*.bin", "*.pt", "*.pth", "*.gguf", + "*.h5", "*.msgpack", "*.ckpt", "*.onnx", "*.pdparams"], + ) is False + # Dropping only some formats leaves the request able to include the others. + assert hcs.request_can_include_weights( + ["checkpoint-*/*"], ["*.safetensors", "*.bin"] + ) is True + # A non-weight basename under a wildcard parent stays weightless. + assert hcs.request_can_include_weights(["checkpoint-*/tokenizer.json"], None) is False def test_request_can_include_weights_weight_selecting_globs(): diff --git a/unsloth_zoo/hf_cache_state.py b/unsloth_zoo/hf_cache_state.py index ca596fb4e..dc9503350 100644 --- a/unsloth_zoo/hf_cache_state.py +++ b/unsloth_zoo/hf_cache_state.py @@ -249,16 +249,48 @@ def _is_loadable_weight_file(name: str) -> bool: ) -def _numbered_shard_set_present(entry: Path) -> bool: +def _filter_paths( + paths: list, + allow_patterns: "Optional[list]" = None, + ignore_patterns: "Optional[list]" = None, +) -> list: + """Filter repo-relative *paths* by Hugging Face allow / ignore patterns, mirroring how + ``snapshot_download`` selects files. 
On any failure, treat all paths as selected so a + snapshot that does hold weights is never rejected for an unevaluable filter.""" + try: + from huggingface_hub.utils import filter_repo_objects + + return list( + filter_repo_objects( + paths, allow_patterns = allow_patterns, ignore_patterns = ignore_patterns + ) + ) + except Exception: + return list(paths) + + +def _numbered_shard_set_present( + entry: Path, + *, + snapshot_dir: "Optional[Path]" = None, + allow_patterns: "Optional[list]" = None, + ignore_patterns: "Optional[list]" = None, +) -> bool: """For a numbered weight shard (``model-00001-of-00002.safetensors``), True only when - every shard in its ``-of-NNNNN`` set is present in the same directory. + every shard in its ``-of-NNNNN`` set that the request selects is present in the same + directory. A leftover single shard from an interrupted multi-shard download reads as a weight file on its own, so without this an incomplete pull (one shard on disk, the rest never fetched) would short-circuit as a warm cache. This catches that even when the shard *index* sidecar was never cached (so ``_weight_shard_index_complete`` has nothing to check). A non-numbered / single-file weight matches no shard pattern and - is trivially satisfied.""" + is trivially satisfied. 
+ + When *allow_patterns* / *ignore_patterns* are given, a sibling shard is required only + if the request actually selects it: a deliberate single-shard request + (``allow_patterns=["model-00002-of-00005.safetensors"]``) is satisfied by that one shard + and must not demand the rest.""" match = _NUMBERED_SHARD_RE.match(entry.name) if match is None: return True @@ -273,10 +305,18 @@ def _numbered_shard_set_present(entry: Path) -> bool: suffix = match.group("suffix") width = len(total_str) base = entry.parent + scoped = bool(allow_patterns or ignore_patterns) and snapshot_dir is not None for i in range(1, total + 1): - shard_name = f"{prefix}-{i:0{width}d}-of-{total_str}{suffix}" + shard_path = base / f"{prefix}-{i:0{width}d}-of-{total_str}{suffix}" + if scoped: + try: + rel = shard_path.relative_to(snapshot_dir).as_posix() + except ValueError: + rel = shard_path.name + if not _filter_paths([rel], allow_patterns, ignore_patterns): + continue # this sibling is not part of the request -> do not require it try: - if not (base / shard_name).exists(): + if not shard_path.exists(): return False except OSError: return False @@ -334,47 +374,61 @@ def snapshot_dir_is_complete( When *allow_patterns* / *ignore_patterns* are given, the weight that must be present is one the request actually selects: a request for ``adapter_model.safetensors`` (or a specific checkpoint shard) is satisfied only by that weight on disk, not by some other - weight the snapshot happens to also carry. With no patterns, any loadable weight does.""" + weight the snapshot happens to also carry. A deliberate single-shard request likewise + requires only that shard, not its whole ``-of-NNNNN`` set. 
With no patterns, any loadable + weight does, and every numbered shard set present must be complete.""" if snapshot_dir_has_broken_symlinks(snapshot_dir): return False try: entries = list(snapshot_dir.rglob("*")) except OSError: return False - weight_rels: list = [] + allow_patterns = _as_pattern_list(allow_patterns) + ignore_patterns = _as_pattern_list(ignore_patterns) + has_patterns = bool(allow_patterns or ignore_patterns) + + index_entries: list = [] + weight_entries: list = [] # (entry, repo-relative path) for entry in entries: name = entry.name if name.endswith((".safetensors.index.json", ".bin.index.json")): - if not _safe_is_file(entry): - continue - if not _weight_shard_index_complete(entry): - return False + if _safe_is_file(entry): + index_entries.append(entry) elif _is_loadable_weight_file(name) and _safe_is_file(entry): - if not _numbered_shard_set_present(entry): - return False try: - weight_rels.append(entry.relative_to(snapshot_dir).as_posix()) + rel = entry.relative_to(snapshot_dir).as_posix() except ValueError: - weight_rels.append(name) - if not weight_rels: + rel = name + weight_entries.append((entry, rel)) + + # The weights the request selects that are present on disk (any present weight when the + # request is unpatterned). The snapshot can carry an unrelated weight while the requested + # one is missing, so a patterned request must find one it actually selects. + if has_patterns: + selected = set(_filter_paths([rel for _, rel in weight_entries], allow_patterns, ignore_patterns)) + else: + selected = {rel for _, rel in weight_entries} + if not selected: return False - allow_patterns = _as_pattern_list(allow_patterns) - ignore_patterns = _as_pattern_list(ignore_patterns) - if not allow_patterns and not ignore_patterns: - return True - # A patterned request must find a weight it actually selects on disk, not just any - # weight: the snapshot can carry an unrelated weight while the requested one is missing. 
- try: - from huggingface_hub.utils import filter_repo_objects - matched = list( - filter_repo_objects( - weight_rels, allow_patterns = allow_patterns, ignore_patterns = ignore_patterns - ) - ) - except Exception: - return True # cannot evaluate the filter -> do not reject a snapshot that has weights - return len(matched) > 0 + # Every selected numbered shard needs the sibling shards the request also selects (the + # whole set when unpatterned), so an interrupted multi-shard pull is not read as warm. + for entry, rel in weight_entries: + if rel not in selected: + continue + if not _numbered_shard_set_present( + entry, snapshot_dir = snapshot_dir, + allow_patterns = allow_patterns, ignore_patterns = ignore_patterns, + ): + return False + + # A full (unpatterned) warm also validates any shard index ships all its shards; a + # patterned request may legitimately want only a subset, so the index is not enforced. + if not has_patterns: + for index_entry in index_entries: + if not _weight_shard_index_complete(index_entry): + return False + return True # Representative loadable-weight filenames -- the probe set for "can this request include a @@ -423,15 +477,18 @@ def _as_pattern_list(patterns: "Optional[object]") -> "Optional[list]": return list(patterns) -def _pattern_basename_targets_weight(pattern: str) -> bool: - """True if *pattern*'s final path component looks like it selects a weight file: a - catch-all (``*`` / ``**``) or a name / glob ending in a recognized weight suffix. - Used only when the pattern's parent directory is itself globbed, so no concrete probe - path can be formed.""" - base = pattern.rsplit("/", 1)[-1].lower() - if base in ("*", "**"): - return True - return base.endswith(_WEIGHT_FILE_SUFFIXES) +# Stems that, by convention, name a per-checkpoint DIRECTORY (whose weights live inside), +# not a file. 
Used to disambiguate a dotted no-slash glob like ``checkpoint-v1.*`` (a +# checkpoint directory, weights nested) from a file glob like ``tokenizer.*`` -- both are +# structurally ``.*`` but only the former can include weights. +_CHECKPOINT_DIR_PREFIXES = ( + "checkpoint", "ckpt", "epoch", "step", "global_step", "iter", "iteration", +) + + +def _looks_like_checkpoint_dir(pattern: str) -> bool: + lowered = pattern.lower() + return any(lowered.startswith(prefix) for prefix in _CHECKPOINT_DIR_PREFIXES) def _concretize_glob(pattern: str) -> str: @@ -491,12 +548,14 @@ def request_can_include_weights( if "/" in pat: prefix = pat.rsplit("/", 1)[0] if _has_glob(prefix): - # Globbed parent dir (e.g. "checkpoint-*/*.safetensors"): no concrete path - # to test, so decide from the basename. Only a weight-targeting basename - # flips this on, so config/tokenizer globs under a wildcard dir stay - # weightless and are not forced into a strict weight check. - if _pattern_basename_targets_weight(pat): - return True + # Globbed parent dir (e.g. "checkpoint-*/*.safetensors" or + # "checkpoint-*/adapter_model.*"): re-root the canonical weight probes under a + # concretized form of the parent and let the final filter decide. This both + # recognizes a weight-targeting basename and still applies ignore_patterns -- + # an early True here would skip the ignores and wrongly require weights for, + # e.g., allow=["checkpoint-*/*"] + ignore that drops every weight format. + concrete = _concretize_glob(prefix) + probes.extend(f"{concrete}/{name}" for name in _WEIGHT_PROBE_NAMES) continue # Concrete parent dir: re-root the canonical weight probes under it so a # path-qualified request is checked inside that directory, not only at the root. @@ -508,14 +567,14 @@ def request_can_include_weights( # A bare concrete weight filename (e.g. a specific non-first shard). if pat.lower().endswith(_WEIGHT_FILE_SUFFIXES): probes.append(pat) - elif "." not in pat: - # A no-slash directory glob (e.g. 
"checkpoint-*"). HF's fnmatch "*" spans "/", so - # it matches nested weights like checkpoint-10/model.safetensors. Probe the - # canonical weights re-rooted under a concretized form of the glob, so the request - # is recognized as weight-including (still subject to ignore_patterns). A no-slash - # glob that carries an extension (e.g. "*.json", "adapter_model.*") is a file glob - # handled by the canonical / representative probes, so it is excluded here -- which - # keeps "tokenizer.*" / "*.json" correctly weightless. + elif "." not in pat or _looks_like_checkpoint_dir(pat): + # A no-slash DIRECTORY glob (e.g. "checkpoint-*", "global_step*", or the dotted + # "checkpoint-v1.*"). HF's fnmatch "*" spans "/", so it matches nested weights like + # checkpoint-10/model.safetensors. Probe the canonical weights re-rooted under a + # concretized form of the glob, so the request is recognized as weight-including + # (still subject to ignore_patterns). A no-slash FILE glob with an extension + # ("*.json", "adapter_model.*", "tokenizer.*") is handled by the canonical / + # representative probes, so it is excluded here and stays correctly weightless. concrete = _concretize_glob(pat) probes.extend(f"{concrete}/{name}" for name in _WEIGHT_PROBE_NAMES) diff --git a/unsloth_zoo/hf_xet_fallback.py b/unsloth_zoo/hf_xet_fallback.py index fa635f201..9bfb384fd 100644 --- a/unsloth_zoo/hf_xet_fallback.py +++ b/unsloth_zoo/hf_xet_fallback.py @@ -735,8 +735,12 @@ def _run_download_attempt( _terminate_process_group(proc, grace_period) if result is None: + # The child exited without enqueuing a result: a process-level crash (e.g. a native + # hf_xet abort / segfault), NOT a captured Hub exception. No deterministic error was + # observed, so the other transport may still succeed -- report it as "crashed" so the + # caller can retry over HTTP rather than surfacing a hard error. 
return ( - "error", + "crashed", f"download process for '{repo_id}' exited " f"(code={proc.exitcode}) without a result", ) @@ -888,7 +892,20 @@ def _download_with_xet_fallback( if kind_result == "cancelled": raise RuntimeError("Cancelled") if kind_result == "error": - # Deterministic failure: the other transport would fail identically. + # Deterministic failure (a captured Hub exception): the other transport would + # fail identically, so do not retry. + raise RuntimeError(payload) + if kind_result == "crashed": + # A process-level crash with no captured exception: HTTP may still succeed, so + # retry over it once before surfacing a hard error (mirrors the stall path). + if not disable_xet: + logger.warning( + "Download process for '%s' crashed without a result -- " + "retrying with HF_HUB_DISABLE_XET=1", label + ) + _safe_status(on_status, f"{label}: download crashed, retrying over HTTP") + disable_xet = True + continue raise RuntimeError(payload) # kind_result == "stall" if not disable_xet: From 2cb8d89c04416ded47fb24c96de6958ecd40ad5d Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sat, 27 Jun 2026 08:31:55 +0000 Subject: [PATCH 26/82] Recognize custom weight globs, bracket globs, and every named exact weight Three weight-detection gaps in hf_cache_state could let _snapshot_is_acceptable short-circuit a stale snapshot that lacks the requested weights, sending the in-process load into an unguarded Xet fetch. request_can_include_weights: a no-slash file glob whose basename is not a canonical probe name (lora_*.safetensors, *.bin, model-*.safetensors) added no probe and read as weightless. The no-slash branch now also adds a concretized self-probe whenever the glob's suffix is a weight suffix, so a custom weight glob is recognized while a non-weight glob (*.json, tokenizer.*) and the canonical adapter_model.* path are unchanged. ignore_patterns is still applied by the final filter. _concretize_glob: a [...] 
class was replaced with a literal x, which does not match the caller's own pattern (checkpoint-[0-9] -> checkpoint-x fails [0-9]), dropping every probe. It now picks a member the class actually matches ([0-9] -> 0, [a-z] -> a, a negated [!0-9] -> a non-excluded filler), so the probe satisfies the pattern. snapshot_dir_is_complete: a request naming multiple exact weights was accepted when only one was present, because the selected set was built from present weights and only checked non-empty. A new require_named_weights flag requires each explicitly named exact weight on disk. It is set only on the pre-download cache short-circuit (where accepting a stale base-only snapshot is the bug); the post-download check stays lenient, so an "either format" list (pytorch_model.bin + model.safetensors) against a safetensors-only repo is never turned into a spurious incomplete-snapshot failure. A glob still selects a subset, and a name the ignore filter drops is not required. Tests cover each: custom weight-suffix and bracket globs, the bracket concretizer, and the pre-download-strict / post-download-lenient named-weight behavior at both the snapshot_dir_is_complete and snapshot_download_with_xet_fallback levels. 
--- tests/test_hf_xet_fallback.py | 101 +++++++++++++++++++++++++++++++++ unsloth_zoo/hf_cache_state.py | 91 ++++++++++++++++++++++++----- unsloth_zoo/hf_xet_fallback.py | 16 +++++- 3 files changed, 190 insertions(+), 18 deletions(-) diff --git a/tests/test_hf_xet_fallback.py b/tests/test_hf_xet_fallback.py index 76ac214c4..e475662b0 100644 --- a/tests/test_hf_xet_fallback.py +++ b/tests/test_hf_xet_fallback.py @@ -1275,6 +1275,40 @@ def test_snapshot_dir_is_complete_single_shard_request(tmp_path): assert hcs.snapshot_dir_is_complete(snap) is False +def test_snapshot_dir_is_complete_requires_each_named_weight(tmp_path): + """require_named_weights makes a request naming multiple exact weights (base + adapter) + need EACH on disk, so the pre-download cache probe does not short-circuit a stale snapshot + holding only the base. Off (the post-download check) it stays lenient, so an "either + format" list (pytorch_model.bin + model.safetensors) against a safetensors-only repo is not + turned into a spurious incomplete-snapshot failure.""" + snap = tmp_path / "snap" + snap.mkdir() + blob = tmp_path / "blob" + blob.write_bytes(b"x") + (snap / "model.safetensors").symlink_to(blob) # base only; adapter missing + pair = ["model.safetensors", "adapter_model.safetensors"] + # Strict (pre-download): adapter missing -> incomplete -> do not short-circuit a stale cache. + assert hcs.snapshot_dir_is_complete(snap, allow_patterns = pair, require_named_weights = True) is False + # Lenient (post-download default): a present selected weight suffices. + assert hcs.snapshot_dir_is_complete(snap, allow_patterns = pair, require_named_weights = False) is True + # Either-format list, safetensors-only repo: strict still won't short-circuit, but the + # lenient check must NOT reject it (no error-forever on a name that doesn't exist). 
+ either = ["pytorch_model.bin", "model.safetensors"] + assert hcs.snapshot_dir_is_complete(snap, allow_patterns = either, require_named_weights = True) is False + assert hcs.snapshot_dir_is_complete(snap, allow_patterns = either, require_named_weights = False) is True + # Both present -> strict is satisfied. + (snap / "adapter_model.safetensors").symlink_to(blob) + assert hcs.snapshot_dir_is_complete(snap, allow_patterns = pair, require_named_weights = True) is True + # A named weight the ignore filter drops is not actually requested, so it is not required. + (snap / "adapter_model.safetensors").unlink() + assert hcs.snapshot_dir_is_complete( + snap, allow_patterns = pair, ignore_patterns = ["adapter_model.safetensors"], + require_named_weights = True, + ) is True + # A glob may legitimately select a subset, so it is never forced to be exhaustive. + assert hcs.snapshot_dir_is_complete(snap, allow_patterns = ["*.safetensors"], require_named_weights = True) is True + + def test_fast_path_rejects_config_only_snapshot(hf_cache, monkeypatch): """HF's local_files_only returns a config-only snapshot (e.g. left by an earlier AutoConfig fetch) without checking weights. 
The fast path must reject it and complete @@ -1291,6 +1325,44 @@ def test_fast_path_rejects_config_only_snapshot(hf_cache, monkeypatch): assert out == "/cache/snap-fresh" and len(fake.calls) == 1 +def test_fast_path_requires_each_named_weight(hf_cache, monkeypatch): + """The pre-download cache short-circuit must not accept a stale snapshot holding only one + of several explicitly named weights: base + adapter requested, only the base cached -> the + guarded child still runs to fetch the rest (Codex #829).""" + blobs = _blobs_dir(hf_cache, DL_REPO) + snap = blobs.parent / "snapshots" / "sha" + snap.mkdir(parents = True) + base_blob = blobs / "w" + base_blob.write_bytes(b"x") + (snap / "model.safetensors").symlink_to(base_blob) # base only; adapter missing + monkeypatch.setattr(huggingface_hub, "snapshot_download", lambda *a, **k: str(snap)) + fake = _install(monkeypatch, [("ok", "/cache/snap-fresh")]) + out = xf.snapshot_download_with_xet_fallback( + DL_REPO, token = None, + allow_patterns = ["model.safetensors", "adapter_model.safetensors"], + ) + assert out == "/cache/snap-fresh" and len(fake.calls) == 1 + + +def test_fast_path_either_format_not_failed_post_download(hf_cache, monkeypatch): + """An "either format" request (pytorch_model.bin + model.safetensors) against a repo that + only ships safetensors must not error: the child's safetensors-only snapshot is accepted + post-download, since the strict named-weight rule is pre-download only (no spurious + incomplete-snapshot failure for a name that does not exist, Codex #829).""" + blobs = _blobs_dir(hf_cache, DL_REPO) + child = blobs.parent / "snapshots" / "fresh" + child.mkdir(parents = True) + w = blobs / "w" + w.write_bytes(b"x") + (child / "model.safetensors").symlink_to(w) # safetensors only; no pytorch_model.bin + fake = _install(monkeypatch, [("ok", str(child))]) + out = xf.snapshot_download_with_xet_fallback( + DL_REPO, token = None, force_download = True, + allow_patterns = ["pytorch_model.bin", 
"model.safetensors"], + ) + assert out == str(child) and len(fake.calls) == 1 + + def test_child_broken_snapshot_retries_over_http(monkeypatch, tmp_path): """A real but broken child snapshot result (HF offline-fallback returning a dir with dangling symlinks) is rejected on the Xet attempt and retried over HTTP; a clean @@ -1490,6 +1562,35 @@ def test_request_can_include_weights_weight_selecting_globs(): assert hcs.request_can_include_weights(["tokenizer.*"], None) is False +def test_request_can_include_weights_custom_weight_suffix_globs(): + """A no-slash FILE glob whose stem matches no canonical probe but whose suffix is a weight + suffix (lora_*.safetensors, *.bin, model-*.safetensors, custom_*.pt) must read as + weight-including, so a stale snapshot missing it is not accepted on the weightless path. A + non-weight-suffix file glob (*.json) stays weightless, and ignore_patterns still wins.""" + assert hcs.request_can_include_weights(["lora_*.safetensors"], None) is True + assert hcs.request_can_include_weights(["*.bin"], None) is True + assert hcs.request_can_include_weights(["model-*.safetensors"], None) is True + assert hcs.request_can_include_weights(["my_custom_*.pt"], None) is True + assert hcs.request_can_include_weights(["*.json"], None) is False + # ignore_patterns dropping that very format wins over the weight-suffix glob. 
+ assert hcs.request_can_include_weights(["lora_*.safetensors"], ["*.safetensors"]) is False + + +def test_request_can_include_weights_bracket_globs(): + """A bracket / range glob (checkpoint-[0-9]/*.safetensors, model-[0-9].safetensors) is + concretized to a member the class actually matches, so the weight probe still satisfies the + caller's own pattern and the request reads as weight-including, not misclassified + weightless.""" + assert hcs.request_can_include_weights(["checkpoint-[0-9]/*.safetensors"], None) is True + assert hcs.request_can_include_weights(["model-[0-9].safetensors"], None) is True + assert hcs.request_can_include_weights(["ckpt-[0-9][0-9]/*"], None) is True + # The concretizer picks an in-class member, not a literal 'x' that the class would reject. + assert hcs._concretize_glob("checkpoint-[0-9]") == "checkpoint-0" + assert hcs._concretize_glob("layer_[a-f]") == "layer_a" + # A negated class yields a filler the class does not exclude (here a non-digit). + assert not hcs._concretize_glob("checkpoint-[!0-9]").endswith(tuple("0123456789")) + + def test_request_can_include_weights_string_form(): """Hugging Face accepts allow / ignore patterns as a bare string; it must be treated as one pattern, not iterated character by character (which would misclassify a subfolder diff --git a/unsloth_zoo/hf_cache_state.py b/unsloth_zoo/hf_cache_state.py index dc9503350..4c7349407 100644 --- a/unsloth_zoo/hf_cache_state.py +++ b/unsloth_zoo/hf_cache_state.py @@ -354,6 +354,7 @@ def snapshot_dir_is_complete( *, allow_patterns: "Optional[object]" = None, ignore_patterns: "Optional[object]" = None, + require_named_weights: bool = False, ) -> bool: """Best-effort check that a cached snapshot actually holds the requested model weights. @@ -376,7 +377,12 @@ def snapshot_dir_is_complete( specific checkpoint shard) is satisfied only by that weight on disk, not by some other weight the snapshot happens to also carry. 
A deliberate single-shard request likewise requires only that shard, not its whole ``-of-NNNNN`` set. With no patterns, any loadable - weight does, and every numbered shard set present must be complete.""" + weight does, and every numbered shard set present must be complete. + + *require_named_weights* additionally requires every explicitly named exact weight in + *allow_patterns* (e.g. ``["model.safetensors", "adapter_model.safetensors"]``) to be on + disk, so a stale cache holding only one of them is not treated as complete. Off by default + (used by the pre-download cache probe); a glob still selects a subset freely.""" if snapshot_dir_has_broken_symlinks(snapshot_dir): return False try: @@ -411,6 +417,24 @@ def snapshot_dir_is_complete( if not selected: return False + # A request that explicitly names exact weight files (e.g. a base model plus a PEFT + # adapter, ["model.safetensors", "adapter_model.safetensors"]) needs EACH of them, not + # just one: a stale cache holding only the first must not be accepted. Enforced only when + # the caller asks (the pre-download cache short-circuit), so the post-download check stays + # lenient and never errors when a named weight simply does not exist in the repo (an + # "either format" list like ["pytorch_model.bin", "model.safetensors"] against a + # safetensors-only repo). A glob may legitimately select a subset, so only concrete + # filenames are required, and one the ignore filter drops is not actually requested. 
+ if require_named_weights and allow_patterns: + present_rels = [rel for _, rel in weight_entries] + for pat in allow_patterns: + if _has_glob(pat) or not str(pat).lower().endswith(_WEIGHT_FILE_SUFFIXES): + continue + if ignore_patterns and not _filter_paths([pat], None, ignore_patterns): + continue + if not _filter_paths(present_rels, [pat], None): + return False + # Every selected numbered shard needs the sibling shards the request also selects (the # whole set when unpatterned), so an interrupted multi-shard pull is not read as warm. for entry, rel in weight_entries: @@ -491,11 +515,35 @@ def _looks_like_checkpoint_dir(pattern: str) -> bool: return any(lowered.startswith(prefix) for prefix in _CHECKPOINT_DIR_PREFIXES) +def _bracket_member(content: str) -> str: + """A single character that a glob ``[...]`` class *matches*, for concretizing a bracket + expression into a probe that still satisfies the caller's own pattern. ``[0-9]`` -> ``0``, + ``[a-z]`` -> ``a``; a negated class (``[!...]`` / ``[^...]``) -> a filler the class does + not exclude. Replacing the class with a non-member (a literal ``x`` for ``[0-9]``) would + make the probe fail the caller's pattern and misread the request as weightless.""" + negated = content[:1] in ("!", "^") + if not negated: + # The first listed item is a member: a literal char, or the low end of a leading range. + return content[0] if content else "x" + # Negated: pick a filler the class does not exclude (fnmatch mirrors HF's matcher). + try: + import fnmatch + + cls = "[" + content + "]" + for cand in ("x", "0", "a", "z", "9", "_", "-", "A"): + if fnmatch.fnmatch(cand, cls): + return cand + except Exception: + pass + return "x" + + def _concretize_glob(pattern: str) -> str: """Replace glob wildcards in *pattern* with a literal filler so it can stand in as a concrete directory name (e.g. ``checkpoint-*`` -> ``checkpoint-x``). A ``[...]`` class - collapses to one filler char. 
Used to probe weights nested under a no-slash directory - glob, since Hugging Face's ``fnmatch`` ``*`` spans ``/``.""" + collapses to one member char it actually matches (so the probe still satisfies the + pattern). Used to probe weights nested under a no-slash directory glob, since Hugging + Face's ``fnmatch`` ``*`` spans ``/``.""" out = [] i = 0 n = len(pattern) @@ -506,8 +554,12 @@ def _concretize_glob(pattern: str) -> str: i += 1 elif ch == "[": j = pattern.find("]", i + 1) - out.append("x") - i = (j + 1) if j != -1 else (i + 1) + if j != -1: + out.append(_bracket_member(pattern[i + 1 : j])) + i = j + 1 + else: + out.append("x") # unterminated class: treat "[" as a literal filler + i += 1 else: out.append(ch) i += 1 @@ -567,16 +619,25 @@ def request_can_include_weights( # A bare concrete weight filename (e.g. a specific non-first shard). if pat.lower().endswith(_WEIGHT_FILE_SUFFIXES): probes.append(pat) - elif "." not in pat or _looks_like_checkpoint_dir(pat): - # A no-slash DIRECTORY glob (e.g. "checkpoint-*", "global_step*", or the dotted - # "checkpoint-v1.*"). HF's fnmatch "*" spans "/", so it matches nested weights like - # checkpoint-10/model.safetensors. Probe the canonical weights re-rooted under a - # concretized form of the glob, so the request is recognized as weight-including - # (still subject to ignore_patterns). A no-slash FILE glob with an extension - # ("*.json", "adapter_model.*", "tokenizer.*") is handled by the canonical / - # representative probes, so it is excluded here and stays correctly weightless. - concrete = _concretize_glob(pat) - probes.extend(f"{concrete}/{name}" for name in _WEIGHT_PROBE_NAMES) + else: + # A no-slash glob. Two weight-including shapes: + # - a DIRECTORY glob ("checkpoint-*", "global_step*", or the dotted + # "checkpoint-v1.*"): HF's fnmatch "*" spans "/", so it matches nested weights + # like checkpoint-10/model.safetensors. Probe the canonical weights re-rooted + # under a concretized form of the glob. 
+ # - a FILE glob naming a weight by suffix ("lora_*.safetensors", "*.bin", + # "model-*.safetensors"): its stem need not match any canonical probe name, so + # add a concretized self-probe so it is recognized rather than misread as + # weightless. + # A plain extension file glob with a non-weight suffix ("*.json", "tokenizer.*") + # matches no probe here and stays correctly weightless; "adapter_model.*" still + # rides the canonical adapter probe. Everything is subject to the final + # ignore_patterns filter below. + if "." not in pat or _looks_like_checkpoint_dir(pat): + concrete = _concretize_glob(pat) + probes.extend(f"{concrete}/{name}" for name in _WEIGHT_PROBE_NAMES) + if pat.lower().endswith(_WEIGHT_FILE_SUFFIXES): + probes.append(_concretize_glob(pat)) try: kept = list( diff --git a/unsloth_zoo/hf_xet_fallback.py b/unsloth_zoo/hf_xet_fallback.py index 9bfb384fd..0526eb827 100644 --- a/unsloth_zoo/hf_xet_fallback.py +++ b/unsloth_zoo/hf_xet_fallback.py @@ -750,7 +750,8 @@ def _run_download_attempt( def _snapshot_is_acceptable( - snapshot_dir: Path, *, repo_type: str, allow_patterns: Any, ignore_patterns: Any + snapshot_dir: Path, *, repo_type: str, allow_patterns: Any, ignore_patterns: Any, + require_named_weights: bool = False, ) -> bool: """Whether a cached / downloaded snapshot dir is complete enough to use, scoped to the caller's intent. @@ -769,10 +770,16 @@ def _snapshot_is_acceptable( The completeness check is scoped to the requested patterns, so a request for a specific weight (e.g. ``allow_patterns=["adapter_model.safetensors"]`` or a checkpoint shard) is - satisfied only when THAT weight is on disk, not by some other weight already cached.""" + satisfied only when THAT weight is on disk, not by some other weight already cached. 
+ + ``require_named_weights`` makes a request that explicitly names multiple exact weights + require each of them on disk (set on the pre-download cache probe so a stale snapshot + missing one is not short-circuited; left off post-download so a named weight that simply + does not exist in the repo never turns a finished download into a spurious failure).""" if repo_type == "model" and request_can_include_weights(allow_patterns, ignore_patterns): return snapshot_dir_is_complete( - snapshot_dir, allow_patterns = allow_patterns, ignore_patterns = ignore_patterns + snapshot_dir, allow_patterns = allow_patterns, ignore_patterns = ignore_patterns, + require_named_weights = require_named_weights, ) return not snapshot_dir_has_broken_symlinks(snapshot_dir) @@ -1094,6 +1101,9 @@ def snapshot_download_with_xet_fallback( repo_type = repo_type, allow_patterns = allow_patterns, ignore_patterns = ignore_patterns, + # Pre-download short-circuit: require each explicitly named weight so a stale + # snapshot missing one (base present, adapter not) is completed, not accepted. + require_named_weights = True, ): return cached_dir logger.debug("Cached snapshot for %s is incomplete; downloading.", repo_id) From 8c0fd24a37797b651b2420d2fc219159e4e27246 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sat, 27 Jun 2026 09:02:51 +0000 Subject: [PATCH 27/82] Probe path-qualified weight globs, skip trainer artifacts, retry transient Xet errors Three more edge cases in the weight-detection and transport logic. request_can_include_weights: a path-qualified custom weight glob like checkpoint-10/lora_*.safetensors (or a globbed parent, checkpoint-*/lora_*.bin) only got canonical probes re-rooted under the parent, none of which match a lora_ basename, so it read as weightless and a stale snapshot without those weights could be accepted. The slash branch now also adds a concretized self-probe for the full pattern. 
The self-probe is shared by the bare-filename and no-slash-glob branches via a new _weight_self_probe helper. _weight_self_probe gates on _is_loadable_weight_file, so a request that names a trainer artifact (optimizer.pt, training_args.bin, scheduler.pt, rng_state_*.pth) no longer reads as weight-including just because the suffix is weight-like. That matches snapshot_dir_is_complete, which filters those basenames out as non-weights; previously a child that fetched exactly the requested artifact was retried over HTTP and could then fail as an incomplete snapshot. hf_xet_fallback: a transient Xet transport failure (an hf_xet / CAS error, connection reset, timeout, HTTP 5xx / 429) was raised immediately like a deterministic Hub error (401 / 404 / disk full), so the HTTP fallback was never tried even though disabling Xet could succeed. The child now classifies the captured exception (where the status code, errno and type are intact) and reports a transient failure as "retryable_error"; the parent retries it over HTTP once, then surfaces it if HTTP also fails. Deterministic and unknown errors stay non-retried, so a real repeatable failure is not looped between transports. Tests cover each: path-qualified custom and bracket globs, trainer-artifact requests staying weightless and acceptable, the transient-vs-deterministic classifier, and the retry-then-HTTP / raise-on-both paths. 
--- tests/test_hf_xet_fallback.py | 87 ++++++++++++++++++++++++++++++++++ unsloth_zoo/hf_cache_state.py | 83 +++++++++++++++++--------------- unsloth_zoo/hf_xet_fallback.py | 87 +++++++++++++++++++++++++++++++--- 3 files changed, 212 insertions(+), 45 deletions(-) diff --git a/tests/test_hf_xet_fallback.py b/tests/test_hf_xet_fallback.py index e475662b0..144c15fd2 100644 --- a/tests/test_hf_xet_fallback.py +++ b/tests/test_hf_xet_fallback.py @@ -12,6 +12,7 @@ from __future__ import annotations +import errno import importlib.util import json import os @@ -722,6 +723,60 @@ def test_crashed_child_on_both_transports_raises(monkeypatch): assert [c.disable_xet for c in fake.calls] == [False, True] +def test_retryable_xet_error_retries_over_http(monkeypatch): + """A transient Xet transport failure (CAS timeout / 5xx) is not a deterministic Hub error, + so it retries over HTTP; a clean HTTP result is accepted (Codex #829).""" + fake = _install(monkeypatch, [("retryable_error", "CasClientError: request timed out"), ("ok", "/cache/x")]) + out = xf.hf_hub_download_with_xet_fallback(DL_REPO, FILE, None) + assert out == "/cache/x" + assert [c.disable_xet for c in fake.calls] == [False, True] + + +def test_retryable_xet_error_on_both_transports_raises(monkeypatch): + """A transient error on Xet AND on HTTP has no other transport left, so it surfaces after + both attempts rather than looping (Codex #829).""" + fake = _install(monkeypatch, [("retryable_error", "503 Server Error"), ("retryable_error", "503 Server Error")]) + with pytest.raises(RuntimeError, match = "503"): + xf.hf_hub_download_with_xet_fallback(DL_REPO, FILE, None) + assert [c.disable_xet for c in fake.calls] == [False, True] + + +def test_is_retryable_download_error_classification(): + """Transient transport failures (hf_xet / CAS, timeout, reset, HTTP 5xx / 429) are + retryable; deterministic Hub errors (not-found, gated, 4xx, disk full) and unknown errors + are not, so a real repeatable failure is surfaced 
rather than looped (Codex #829).""" + f = xf._is_retryable_download_error + + # Transient transport failures -> retryable. + assert f(Exception("hf_xet download failed: data processing error")) is True + assert f(TimeoutError("connection reset by peer")) is True + assert f(Exception("CasClientError: deadline exceeded")) is True + + class _Resp503(Exception): + response = type("R", (), {"status_code": 503})() + + assert f(_Resp503("server error")) is True + + class _Status429(Exception): + status_code = 429 + + assert f(_Status429("Too Many Requests")) is True + + # Deterministic Hub errors -> not retryable (matched by class name or 4xx status). + class RepositoryNotFoundError(Exception): + pass + + assert f(RepositoryNotFoundError("404 Client Error")) is False + + class _Resp404(Exception): + response = type("R", (), {"status_code": 404})() + + assert f(_Resp404("not found")) is False + assert f(OSError(errno.ENOSPC, "No space left on device")) is False + # An unknown / generic error stays deterministic -> surfaced, not looped over transports. 
+ assert f(ValueError("unexpected response payload")) is False + + def test_immediate_success_uses_xet_only(monkeypatch): prepared = [] monkeypatch.setattr(xf, "_default_prepare_for_http", lambda *a, **k: prepared.append(a)) @@ -1591,6 +1646,38 @@ def test_request_can_include_weights_bracket_globs(): assert not hcs._concretize_glob("checkpoint-[!0-9]").endswith(tuple("0123456789")) +def test_request_can_include_weights_path_qualified_custom_globs(): + """A path-qualified custom weight glob (checkpoint-10/lora_*.safetensors, with a globbed + parent too) names a weight whose basename matches no canonical probe; it must read as + weight-including via a concretized self-probe, not weightless (Codex #829).""" + assert hcs.request_can_include_weights(["checkpoint-10/lora_*.safetensors"], None) is True + assert hcs.request_can_include_weights(["checkpoint-*/lora_*.bin"], None) is True + assert hcs.request_can_include_weights(["models/custom_*.pt"], None) is True + assert hcs.request_can_include_weights(["checkpoint-10/model-[0-9].safetensors"], None) is True + # A non-weight basename under a subfolder stays weightless. 
+ assert hcs.request_can_include_weights(["checkpoint-10/tokenizer.json"], None) is False + + +def test_request_can_include_weights_trainer_artifacts_weightless(tmp_path): + """A trainer / optimizer artifact (optimizer.pt, training_args.bin, rng_state_*.pth) carries + a weight suffix but is not a loadable weight: request_can_include_weights must read it as + weightless, matching snapshot_dir_is_complete -- otherwise a child that fetches exactly that + artifact loops as an 'incomplete snapshot' (Codex #829).""" + for art in ["optimizer.pt", "training_args.bin", "scheduler.pt", "rng_state_0.pth", + "checkpoint-10/optimizer.pt"]: + assert hcs.request_can_include_weights([art], None) is False, art + # Consistency: a snapshot holding only the requested artifact is acceptable (weightless + # path), so the guarded download is not retried into an incomplete-snapshot failure. + snap = tmp_path / "snap" + snap.mkdir() + blob = tmp_path / "blob" + blob.write_bytes(b"x") + (snap / "optimizer.pt").symlink_to(blob) + assert xf._snapshot_is_acceptable( + snap, repo_type = "model", allow_patterns = ["optimizer.pt"], ignore_patterns = None + ) is True + + def test_request_can_include_weights_string_form(): """Hugging Face accepts allow / ignore patterns as a bare string; it must be treated as one pattern, not iterated character by character (which would misclassify a subfolder diff --git a/unsloth_zoo/hf_cache_state.py b/unsloth_zoo/hf_cache_state.py index 4c7349407..db276f3a9 100644 --- a/unsloth_zoo/hf_cache_state.py +++ b/unsloth_zoo/hf_cache_state.py @@ -566,6 +566,24 @@ def _concretize_glob(pattern: str) -> str: return "".join(out) +def _weight_self_probe(pattern: str) -> "Optional[str]": + """A concretized stand-in for *pattern* when it names a loadable model weight by suffix + (``lora_*.safetensors`` -> ``lora_x.safetensors``, ``checkpoint-10/lora_*.bin`` -> + ``checkpoint-10/lora_x.bin``, a bare ``model-00002-of-00005.safetensors``), so a custom + weight basename that 
matches no canonical probe is still recognized. Returns None when the + suffix is not a weight suffix, or when the (concretized) basename is a known trainer / + optimizer artifact (``optimizer.pt``, ``training_args.bin``, ``rng_state_*.pth``): those + carry weight suffixes but the snapshot completeness check filters them out as non-weights, + so classifying such a request as weight-including would loop the guarded download.""" + if not pattern.lower().endswith(_WEIGHT_FILE_SUFFIXES): + return None + concrete = _concretize_glob(pattern) + basename = concrete.rsplit("/", 1)[-1] + if not _is_loadable_weight_file(basename): + return None + return concrete + + def request_can_include_weights( allow_patterns: "Optional[list]" = None, ignore_patterns: "Optional[list]" = None ) -> bool: @@ -597,47 +615,34 @@ def request_can_include_weights( probes = list(_WEIGHT_PROBE_NAMES) for pat in (allow_patterns or ()): + # A concretized stand-in when the pattern itself names a loadable weight by suffix + # (lora_*.safetensors, checkpoint-10/lora_*.bin, a bare non-first shard). None for a + # non-weight suffix and for a known trainer artifact (optimizer.pt, training_args.bin), + # keeping this consistent with the snapshot completeness check. + self_probe = _weight_self_probe(pat) if "/" in pat: + # Path-qualified: re-root the canonical weight probes under the parent dir + # (concretized when the parent itself is globbed) so the request is checked inside + # that directory, not only at the root. No early return -- the final filter still + # applies ignore_patterns (e.g. allow=["checkpoint-*/*"] + an all-weights ignore). prefix = pat.rsplit("/", 1)[0] - if _has_glob(prefix): - # Globbed parent dir (e.g. "checkpoint-*/*.safetensors" or - # "checkpoint-*/adapter_model.*"): re-root the canonical weight probes under a - # concretized form of the parent and let the final filter decide. 
This both - # recognizes a weight-targeting basename and still applies ignore_patterns -- - # an early True here would skip the ignores and wrongly require weights for, - # e.g., allow=["checkpoint-*/*"] + ignore that drops every weight format. - concrete = _concretize_glob(prefix) - probes.extend(f"{concrete}/{name}" for name in _WEIGHT_PROBE_NAMES) - continue - # Concrete parent dir: re-root the canonical weight probes under it so a - # path-qualified request is checked inside that directory, not only at the root. - probes.extend(f"{prefix}/{name}" for name in _WEIGHT_PROBE_NAMES) - # A subfolder-qualified concrete weight filename is also a probe verbatim. - if not _has_glob(pat) and pat.lower().endswith(_WEIGHT_FILE_SUFFIXES): - probes.append(pat) - elif not _has_glob(pat): - # A bare concrete weight filename (e.g. a specific non-first shard). - if pat.lower().endswith(_WEIGHT_FILE_SUFFIXES): - probes.append(pat) - else: - # A no-slash glob. Two weight-including shapes: - # - a DIRECTORY glob ("checkpoint-*", "global_step*", or the dotted - # "checkpoint-v1.*"): HF's fnmatch "*" spans "/", so it matches nested weights - # like checkpoint-10/model.safetensors. Probe the canonical weights re-rooted - # under a concretized form of the glob. - # - a FILE glob naming a weight by suffix ("lora_*.safetensors", "*.bin", - # "model-*.safetensors"): its stem need not match any canonical probe name, so - # add a concretized self-probe so it is recognized rather than misread as - # weightless. - # A plain extension file glob with a non-weight suffix ("*.json", "tokenizer.*") - # matches no probe here and stays correctly weightless; "adapter_model.*" still - # rides the canonical adapter probe. Everything is subject to the final - # ignore_patterns filter below. - if "." 
not in pat or _looks_like_checkpoint_dir(pat): - concrete = _concretize_glob(pat) - probes.extend(f"{concrete}/{name}" for name in _WEIGHT_PROBE_NAMES) - if pat.lower().endswith(_WEIGHT_FILE_SUFFIXES): - probes.append(_concretize_glob(pat)) + concrete_parent = _concretize_glob(prefix) if _has_glob(prefix) else prefix + probes.extend(f"{concrete_parent}/{name}" for name in _WEIGHT_PROBE_NAMES) + elif _has_glob(pat) and ("." not in pat or _looks_like_checkpoint_dir(pat)): + # A no-slash DIRECTORY glob ("checkpoint-*", "global_step*", the dotted + # "checkpoint-v1.*"): HF's fnmatch "*" spans "/", so it matches nested weights like + # checkpoint-10/model.safetensors. Probe the canonical weights re-rooted under a + # concretized form of the glob. A plain extension file glob ("*.json", "tokenizer.*") + # is not a directory glob and stays weightless unless it names a weight (self_probe). + concrete = _concretize_glob(pat) + probes.extend(f"{concrete}/{name}" for name in _WEIGHT_PROBE_NAMES) + # A pattern that itself names a loadable weight -- a bare filename, a path-qualified + # name, or a weight-suffix glob whose stem matches no canonical probe (lora_*.safetensors, + # checkpoint-*/lora_*.bin) -- is recognized via its self-probe. "adapter_model.*" rides + # the canonical adapter probe instead, and a trainer artifact yields no self-probe and + # stays weightless. Everything is subject to the final ignore_patterns filter below. 
+ if self_probe is not None: + probes.append(self_probe) try: kept = list( diff --git a/unsloth_zoo/hf_xet_fallback.py b/unsloth_zoo/hf_xet_fallback.py index 0526eb827..95ad9d828 100644 --- a/unsloth_zoo/hf_xet_fallback.py +++ b/unsloth_zoo/hf_xet_fallback.py @@ -33,6 +33,7 @@ from __future__ import annotations +import errno import importlib.util import multiprocessing as mp import logging @@ -439,6 +440,52 @@ def _scrub_in_child(text: str, token: Optional[str]) -> str: return _default_scrub_secrets(text, hf_token = token) +# Deterministic Hub failures that recur identically over either transport, so switching from +# Xet to HTTP is pointless: surface them. Matched by exception class name so the parent need +# not import huggingface_hub's error classes. +_DETERMINISTIC_ERROR_NAMES = frozenset({ + "RepositoryNotFoundError", + "RevisionNotFoundError", + "EntryNotFoundError", + "GatedRepoError", + "DisabledRepoError", + "LocalEntryNotFoundError", + "BadRequestError", +}) +# Substrings that mark a transient transport failure (hf_xet / CAS error, timeout, reset, +# HTTP 5xx / 429) that disabling Xet and retrying over HTTP may recover. +_TRANSIENT_ERROR_HINTS = ( + "xet", "casclient", "cas_", "timeout", "timed out", "connection", "reset by peer", + "temporarily", "try again", "incompleteread", "protocolerror", "remotedisconnected", + "broken pipe", "ssl", "eof occurred", "502", "503", "504", "500 server", "429", + "too many requests", "service unavailable", "bad gateway", "gateway time", + "connection aborted", +) + + +def _is_retryable_download_error(exc: BaseException) -> bool: + """True when a captured download exception looks like a transient transport failure (an + ``hf_xet`` / CAS error, connection reset, timeout, HTTP 5xx / 429) that the OTHER transport + may recover, vs a deterministic Hub error (auth, not-found, gated, disk-full) that would + fail identically. 
Unknown errors are treated as deterministic, so a real repeatable failure + is surfaced rather than looped between transports.""" + name = type(exc).__name__ + if name in _DETERMINISTIC_ERROR_NAMES: + return False + # Disk full / quota: a different transport cannot help. + if isinstance(exc, OSError) and getattr(exc, "errno", None) in (errno.ENOSPC, errno.EDQUOT): + return False + # An HTTP status (HfHubHTTPError carries a requests / httpx response): 5xx and 429 are + # transient; other 4xx (401 / 403 / 404 / 416) are deterministic. + status = getattr(getattr(exc, "response", None), "status_code", None) + if not isinstance(status, int): + status = getattr(exc, "status_code", None) + if isinstance(status, int): + return status >= 500 or status == 429 + text = f"{name}: {exc}".lower() + return any(hint in text for hint in _TRANSIENT_ERROR_HINTS) + + def _child_download(*, kind: str, params: dict, token: Optional[str], repo_type: str) -> str: """Run the actual HF download for one attempt inside the spawn child.""" if kind == "snapshot": @@ -533,7 +580,14 @@ def _download_child_entry( path = _child_download(kind = kind, params = params, token = token, repo_type = repo_type) result_queue.put({"ok": True, "path": path}) except BaseException as e: # noqa: BLE001 - report every failure to the parent - result_queue.put({"ok": False, "error": _scrub_in_child(f"{type(e).__name__}: {e}", token)}) + # Classify here, where the exception object (status code, errno, type) is intact, so the + # parent can retry a transient Xet transport failure over HTTP and still surface a + # deterministic Hub error without a pointless second attempt. 
+ result_queue.put({ + "ok": False, + "error": _scrub_in_child(f"{type(e).__name__}: {e}", token), + "retryable": _is_retryable_download_error(e), + }) def _terminate_process_group(proc: "mp.process.BaseProcess", grace_period: float) -> None: @@ -585,8 +639,11 @@ def _run_download_attempt( ) -> tuple[str, Optional[str]]: """Run one download in a spawn child supervised by the no-progress watchdog. - Returns ``("ok", path)``, ``("stall", None)``, ``("cancelled", None)``, or - ``("error", message)``. This is the seam tests monkeypatch to avoid spawning. + Returns ``("ok", path)``, ``("stall", None)``, ``("cancelled", None)``, + ``("crashed", message)`` (process-level crash, no captured exception), + ``("retryable_error", message)`` (a transient Xet transport failure worth an HTTP retry), + or ``("error", message)`` (a deterministic Hub error). This is the seam tests monkeypatch + to avoid spawning. """ # A single-file download scopes its stall detection to its own child's partials. # Capture the partials already on disk for this repo BEFORE spawning, so the watchdog @@ -746,7 +803,11 @@ def _run_download_attempt( ) if result.get("ok"): return ("ok", result["path"]) - return ("error", result.get("error") or "unknown download error") + message = result.get("error") or "unknown download error" + if result.get("retryable"): + # A transient transport failure the child flagged as worth another transport. + return ("retryable_error", message) + return ("error", message) def _snapshot_is_acceptable( @@ -899,8 +960,22 @@ def _download_with_xet_fallback( if kind_result == "cancelled": raise RuntimeError("Cancelled") if kind_result == "error": - # Deterministic failure (a captured Hub exception): the other transport would - # fail identically, so do not retry. + # Deterministic failure (a captured Hub exception: auth, not-found, gated, disk + # full): the other transport would fail identically, so do not retry. 
+ raise RuntimeError(payload) + if kind_result == "retryable_error": + # A transient transport failure (hf_xet CAS timeout, 5xx, connection reset) rather + # than a deterministic Hub error: disabling Xet and retrying over HTTP may recover, + # so try the other transport once before surfacing it (mirrors the crash / stall + # paths). If HTTP also failed, there is no other transport left -- raise. + if not disable_xet: + logger.warning( + "Download for '%s' hit a transient Xet transport error -- retrying " + "with HF_HUB_DISABLE_XET=1: %s", label, payload + ) + _safe_status(on_status, f"{label}: transient Xet error, retrying over HTTP") + disable_xet = True + continue raise RuntimeError(payload) if kind_result == "crashed": # A process-level crash with no captured exception: HTTP may still succeed, so From f612061272e5d78380650a40c9cd833468d87f7e Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sat, 27 Jun 2026 09:52:48 +0000 Subject: [PATCH 28/82] Harden weight detection for empty filters and subfolders; clarify the single-file watchdog Addresses a fresh review round (three findings) plus a parallel deep review of the same code. request_can_include_weights / snapshot_dir_is_complete: an empty allow list is now a real filter. Hugging Face treats allow_patterns=[] as selecting NO objects, so the request is weightless; both functions previously collapsed [] with None (an unfiltered warmup) and would reject a legitimately empty snapshot, looping the guarded download. They now distinguish allow_patterns is None from []. Path-qualified subfolder requests are reclassified conservatively. A bare catch-all under a subfolder (component/quant dirs such as unet/*, transformer/*, original/*, mp_rank_*/*, BF16/*) holds weights, so it must stay weight-including: reading an unknown subfolder as weightless would accept a stale config-only cache and re-open the silent Xet hang this module prevents. 
The re-rooting of the canonical weight probes now defaults ON and is skipped only when the request is clearly non-weight: a non-weight basename glob (*.json, tokenizer.*, *.txt) or a catch-all under a KNOWN auxiliary dir (tokenizer/, runs/, logs/, wandb/) that does not itself target a weight. tokenizer/* stays weightless; unet/* and unknown subfolders do not. The strict named-weight presence check (require_named_weights, the pre-download short-circuit) now uses a direct membership test instead of the fail-open _filter_paths helper, so an unevaluable filter requires the guarded download rather than silently accepting a stale cache. hf_xet_fallback: the single-file watchdog keeps its precise per-child fd scoping. An empty child open-set is treated as "the child owns no partial yet" (the connect / metadata phase), not as a helper-owned partial: hf_xet writes in-process and holds the .incomplete fd continuously, so attributing a concurrent sibling's post-baseline partial to a still-connecting child would false-stall and needlessly retry it. A transient HTTP 408 now retries over HTTP alongside 5xx / 429 (416 and other 4xx stay deterministic). The download child's result queue is closed deterministically, and a result that raced in during the same poll window the watchdog fired is preferred over a late stall. hf_xet_fallback also gains an __all__ for its Studio-facing import-* surface. Tests cover each: empty allow/ignore lists, weight-bearing vs auxiliary subfolders (with an end-to-end stale config-only BF16/ cache that must not short-circuit), the basename weight-classification helpers, the empty open-set not firing on a sibling partial, and the 408 / 416 boundary. 
--- tests/test_hf_xet_fallback.py | 115 +++++++++++++++++++++++++++++++++ unsloth_zoo/hf_cache_state.py | 106 ++++++++++++++++++++++++++---- unsloth_zoo/hf_xet_fallback.py | 45 +++++++++++-- 3 files changed, 250 insertions(+), 16 deletions(-) diff --git a/tests/test_hf_xet_fallback.py b/tests/test_hf_xet_fallback.py index 144c15fd2..d0c3dd825 100644 --- a/tests/test_hf_xet_fallback.py +++ b/tests/test_hf_xet_fallback.py @@ -329,6 +329,30 @@ def _grow(): child.wait(timeout = 5) +def test_file_watchdog_empty_open_set_ignores_sibling(hf_cache, monkeypatch): + """hf_xet writes in-process and holds its .incomplete fd continuously, so an EMPTY child + open-set means the child owns no partial YET (the connect / metadata phase), NOT that a + helper process owns one. A concurrent sibling's post-baseline partial must therefore NOT be + attributed to a still-connecting child -- otherwise a stalled sibling would kill it and force + a needless HTTP retry. The precise empty-set branch owns nothing, so no stall fires.""" + blobs = _blobs_dir(hf_cache) + # A sibling partial created after baseline (not name-excluded), constant (stalled). 
+ (blobs / "sibling.incomplete").write_bytes(b"\0" * 4096) + monkeypatch.setattr(xf, "_child_open_incomplete_blobs", lambda pid: set()) # child owns none + + calls: list[str] = [] + stop = xf.start_watchdog( + repo_ids = [REPO], on_stall = calls.append, interval = 0.05, stall_timeout = 0.2, + watch_new_partials_only = True, baseline_incomplete_blobs = set(), + child_pid = 4242, # non-None so the precise child-open path is taken + ) + try: + time.sleep(0.8) + assert calls == [], "watchdog falsely fired on a sibling partial the child does not own" + finally: + stop.set() + + def test_get_state_empty_cache(hf_cache): assert xf.get_hf_download_state([REPO]) == (0, False) @@ -762,7 +786,16 @@ class _Status429(Exception): assert f(_Status429("Too Many Requests")) is True + class _Status408(Exception): + status_code = 408 + + assert f(_Status408("Request Timeout")) is True # 408 is transient + # Deterministic Hub errors -> not retryable (matched by class name or 4xx status). + class _Status416(Exception): + status_code = 416 + + assert f(_Status416("Range Not Satisfiable")) is False # a retry would fail identically class RepositoryNotFoundError(Exception): pass @@ -1678,6 +1711,88 @@ def test_request_can_include_weights_trainer_artifacts_weightless(tmp_path): ) is True +def test_request_can_include_weights_empty_allow_list(tmp_path): + """allow_patterns=[] is a real filter that selects NO objects (Hugging Face semantics), so + the request is weightless -- it must not collapse with None (an unfiltered warmup) and + reject a legitimately empty snapshot (Codex #829). 
ignore_patterns=[] ignores nothing, so + it stays weight-including.""" + assert hcs.request_can_include_weights([], None) is False + assert hcs.request_can_include_weights(None, None) is True + assert hcs.request_can_include_weights(None, []) is True + assert hcs.request_can_include_weights([], []) is False + # snapshot_dir_is_complete agrees: allow=[] is a scoped (select-nothing) request, not a full + # warmup, so a snapshot carrying an unrelated weight is not read as complete for it. + snap = tmp_path / "snap" + snap.mkdir() + blob = tmp_path / "blob" + blob.write_bytes(b"x") + (snap / "model.safetensors").symlink_to(blob) + assert hcs.snapshot_dir_is_complete(snap, allow_patterns = []) is False + assert hcs.snapshot_dir_is_complete(snap, allow_patterns = None) is True + + +def test_request_can_include_weights_non_weight_subfolders(): + """A generic glob under a plain (non-checkpoint) subfolder such as tokenizer/* or runs/* + must read as weightless -- the unconditional canonical re-rooting would otherwise add a + synthetic tokenizer/model.safetensors probe and misclassify a tokenizer-only download as a + model warmup (Codex #829). A checkpoint/weight dir or a weight-targeting basename still + includes weights.""" + assert hcs.request_can_include_weights(["tokenizer/*"], None) is False + assert hcs.request_can_include_weights(["runs/*"], None) is False + assert hcs.request_can_include_weights(["logs/*.txt"], None) is False + # Weight-bearing cases stay weight-including. + assert hcs.request_can_include_weights(["checkpoint-10/*"], None) is True + assert hcs.request_can_include_weights(["*/model.*"], None) is True + assert hcs.request_can_include_weights(["models/*.safetensors"], None) is True + # A weight-suffix basename under a plain subfolder is still recognized (self-probe). + assert hcs.request_can_include_weights(["tokenizer/*.safetensors"], None) is True + # A checkpoint dir nested anywhere in the parent path counts. 
+ assert hcs.request_can_include_weights(["runs/checkpoint-5/*"], None) is True + + +def test_request_can_include_weights_weight_bearing_subfolders(tmp_path): + """A component / quant subfolder (unet/, transformer/, original/, BF16/, Q8_0/) holds + weights, so a bare catch-all under it must stay weight-including -- reading an unknown + subfolder as weightless would accept a stale config-only cache and re-open the silent Xet + hang. Only KNOWN auxiliary dirs (tokenizer/, runs/) are weightless (Codex #829).""" + for d in ["unet/*", "transformer/*", "text_encoder/*", "vae/*", "original/*", + "mp_rank_00/*", "BF16/*", "Q8_0/*", "Q4_K_M/*", "unknown_component/*"]: + assert hcs.request_can_include_weights([d], None) is True, d + # End to end: a stale config-only BF16/ snapshot (weight missing, no dangling symlink) must + # NOT be short-circuited as warm -- the guarded download still runs. + snap = tmp_path / "snap" + (snap / "BF16").mkdir(parents = True) + (snap / "BF16" / "config.json").write_text("{}") # config only, no weight + assert xf._snapshot_is_acceptable( + snap, repo_type = "model", allow_patterns = ["BF16/*"], ignore_patterns = None, + require_named_weights = True, + ) is False + # Once the weight is present, it is acceptable. 
+ blob = tmp_path / "blob" + blob.write_bytes(b"x") + (snap / "BF16" / "model.safetensors").symlink_to(blob) + assert xf._snapshot_is_acceptable( + snap, repo_type = "model", allow_patterns = ["BF16/*"], ignore_patterns = None, + ) is True + + +def test_basename_weight_classification_helpers(): + """Lock the catch-all-vs-weight distinction the subfolder gating rests on: a weight-stem + basename targets a weight, a config / tokenizer glob is clearly non-weight, and a bare + catch-all ('*') is NEITHER (so it defaults to weight-including under an unknown dir).""" + tw, nw = hcs._basename_targets_weight, hcs._basename_is_non_weight + assert tw("model.*") is True and nw("model.*") is False + assert tw("*.safetensors") is True and nw("*.safetensors") is False + assert tw("adapter_model.*") is True + assert tw("*.json") is False and nw("*.json") is True + assert tw("tokenizer.*") is False and nw("tokenizer.*") is True + assert tw("config.json") is False and nw("config.json") is True + # A catch-all matches both a weight and a non-weight representative -> neither classifier. + assert tw("*") is False and nw("*") is False + # *.bin matches a weight (pytorch_model.bin) AND a non-weight (training_args.bin) -> neither. 
+ assert tw("*.bin") is False and nw("*.bin") is False + + def test_request_can_include_weights_string_form(): """Hugging Face accepts allow / ignore patterns as a bare string; it must be treated as one pattern, not iterated character by character (which would misclassify a subfolder diff --git a/unsloth_zoo/hf_cache_state.py b/unsloth_zoo/hf_cache_state.py index db276f3a9..0b21d3b62 100644 --- a/unsloth_zoo/hf_cache_state.py +++ b/unsloth_zoo/hf_cache_state.py @@ -23,6 +23,7 @@ from __future__ import annotations +import fnmatch import re import sys from pathlib import Path @@ -391,7 +392,10 @@ def snapshot_dir_is_complete( return False allow_patterns = _as_pattern_list(allow_patterns) ignore_patterns = _as_pattern_list(ignore_patterns) - has_patterns = bool(allow_patterns or ignore_patterns) + # An empty allow list is a real (select-nothing) filter, not "unpatterned": treat any + # non-None patterns as a scoped request so allow_patterns=[] does not fall into the full + # warmup branch (consistent with request_can_include_weights). + has_patterns = allow_patterns is not None or ignore_patterns is not None index_entries: list = [] weight_entries: list = [] # (entry, repo-relative path) @@ -426,13 +430,17 @@ def snapshot_dir_is_complete( # safetensors-only repo). A glob may legitimately select a subset, so only concrete # filenames are required, and one the ignore filter drops is not actually requested. if require_named_weights and allow_patterns: - present_rels = [rel for _, rel in weight_entries] + present_rels = set(rel for _, rel in weight_entries) for pat in allow_patterns: if _has_glob(pat) or not str(pat).lower().endswith(_WEIGHT_FILE_SUFFIXES): continue if ignore_patterns and not _filter_paths([pat], None, ignore_patterns): continue - if not _filter_paths(present_rels, [pat], None): + # pat is a concrete (glob-free) weight path, so presence is an exact match. 
A direct + # membership test (not _filter_paths, which fails OPEN by returning all paths on a + # filter error) keeps this strict check fail-SAFE: an unevaluable case requires the + # guarded download rather than silently accepting a stale cache as warm. + if pat not in present_rels: return False # Every selected numbered shard needs the sibling shards the request also selects (the @@ -527,8 +535,6 @@ def _bracket_member(content: str) -> str: return content[0] if content else "x" # Negated: pick a filler the class does not exclude (fnmatch mirrors HF's matcher). try: - import fnmatch - cls = "[" + content + "]" for cand in ("x", "0", "a", "z", "9", "_", "-", "A"): if fnmatch.fnmatch(cand, cls): @@ -566,6 +572,69 @@ def _concretize_glob(pattern: str) -> str: return "".join(out) +# Representative NON-weight files a catch-all ("*") or a config / tokenizer glob ("*.json") +# would also match -- used to tell a weight-specific basename (model.*, *.safetensors) from a +# catch-all when deciding whether a path-qualified request under a plain subfolder targets +# weights. Not exhaustive; just enough common names to disqualify a non-weight glob. +_NON_WEIGHT_PROBE_NAMES = ( + "config.json", + "tokenizer.json", + "tokenizer.model", + "tokenizer_config.json", + "special_tokens_map.json", + "generation_config.json", + "preprocessor_config.json", + "vocab.json", + "merges.txt", + "readme.md", + "training_args.bin", + "optimizer.pt", +) + + +# Subfolders that, by convention, hold only auxiliary / telemetry files -- never model weights. +# A catch-all glob under one of these (tokenizer/*, runs/*) is read as weightless. Kept +# deliberately narrow: an unknown subfolder (unet/, transformer/, original/, a new arch's +# component dir) must stay weight-including, so a weight-bearing dir is never misread as +# weightless (that would re-open the silent-Xet-hang accept-stale this module exists to prevent). 
+_NON_WEIGHT_DIR_NAMES = frozenset({ + "tokenizer", "runs", "run", "logs", "log", "samples", "sample", "tensorboard", "tb", + "events", "eval", "evals", "evaluation", "metrics", "wandb", "assets", "images", "media", +}) + + +def _basename_targets_weight(basename: str) -> bool: + """True when a path-qualified pattern's basename specifically selects a model weight + (``model.*``, ``adapter_model.*``, ``*.safetensors``), so the request is weight-including + even under a non-weight parent dir. A catch-all (``*``) matches both weights and non-weights + and a non-weight glob (``*.json``) matches no weight, so neither counts.""" + base = basename.lower() + if not any(fnmatch.fnmatchcase(name, base) for name in _WEIGHT_PROBE_NAMES): + return False + return not any(fnmatch.fnmatchcase(name, base) for name in _NON_WEIGHT_PROBE_NAMES) + + +def _basename_is_non_weight(basename: str) -> bool: + """True when a path-qualified pattern's basename clearly selects only non-weight files + (``*.json``, ``config.json``, ``tokenizer.*``, ``*.txt``): it matches a known non-weight + representative but no weight name. A catch-all (``*``) matches a weight too, so it is NOT + clearly non-weight and stays weight-including (the parent dir may hold weights).""" + base = basename.lower() + if not any(fnmatch.fnmatchcase(name, base) for name in _NON_WEIGHT_PROBE_NAMES): + return False + return not any(fnmatch.fnmatchcase(name, base) for name in _WEIGHT_PROBE_NAMES) + + +def _parent_is_non_weight_dir(prefix: str) -> bool: + """True when *prefix* is a known auxiliary / telemetry dir (tokenizer/, runs/, logs/) and no + component looks like a checkpoint / weight dir, so a catch-all glob under it holds no weights. 
+ An unknown subfolder returns False (stays weight-including) to avoid accept-stale.""" + parts = [p.lower() for p in prefix.split("/") if p] + if any(_looks_like_checkpoint_dir(p) for p in parts): + return False + return any(p in _NON_WEIGHT_DIR_NAMES for p in parts) + + def _weight_self_probe(pattern: str) -> "Optional[str]": """A concretized stand-in for *pattern* when it names a loadable model weight by suffix (``lora_*.safetensors`` -> ``lora_x.safetensors``, ``checkpoint-10/lora_*.bin`` -> @@ -606,7 +675,11 @@ def request_can_include_weights( Hugging Face itself accepts.""" allow_patterns = _as_pattern_list(allow_patterns) ignore_patterns = _as_pattern_list(ignore_patterns) - if not allow_patterns and not ignore_patterns: + # Only a truly unfiltered request (both None) is an unconditional weight warmup. An empty + # allow list is NOT None: Hugging Face's filter_repo_objects treats allow_patterns=[] as + # selecting NO objects, so the request is weightless -- collapsing [] with None here would + # reject a legitimately empty snapshot and loop the guarded download. + if allow_patterns is None and ignore_patterns is None: return True try: from huggingface_hub.utils import filter_repo_objects @@ -622,12 +695,21 @@ def request_can_include_weights( self_probe = _weight_self_probe(pat) if "/" in pat: # Path-qualified: re-root the canonical weight probes under the parent dir - # (concretized when the parent itself is globbed) so the request is checked inside - # that directory, not only at the root. No early return -- the final filter still - # applies ignore_patterns (e.g. allow=["checkpoint-*/*"] + an all-weights ignore). - prefix = pat.rsplit("/", 1)[0] - concrete_parent = _concretize_glob(prefix) if _has_glob(prefix) else prefix - probes.extend(f"{concrete_parent}/{name}" for name in _WEIGHT_PROBE_NAMES) + # (concretized when the parent is globbed) so the request is checked inside that + # directory. 
Default to re-rooting (weight-including), because an unknown subfolder + # (unet/, transformer/, original/, mp_rank_*/) may hold weights and reading it as + # weightless would accept a stale config-only cache -> the silent Xet hang. Skip the + # re-root only when the request is clearly non-weight: a non-weight basename glob + # (*.json, tokenizer.*, *.txt), or a catch-all under a known auxiliary dir + # (tokenizer/*, runs/*) that does not itself target a weight. A weight-suffix + # basename is still recognized by self_probe below; the final filter applies ignores. + prefix, base = pat.rsplit("/", 1) + clearly_weightless = _basename_is_non_weight(base) or ( + _parent_is_non_weight_dir(prefix) and not _basename_targets_weight(base) + ) + if not clearly_weightless: + concrete_parent = _concretize_glob(prefix) if _has_glob(prefix) else prefix + probes.extend(f"{concrete_parent}/{name}" for name in _WEIGHT_PROBE_NAMES) elif _has_glob(pat) and ("." not in pat or _looks_like_checkpoint_dir(pat)): # A no-slash DIRECTORY glob ("checkpoint-*", "global_step*", the dotted # "checkpoint-v1.*"): HF's fnmatch "*" spans "/", so it matches nested weights like diff --git a/unsloth_zoo/hf_xet_fallback.py b/unsloth_zoo/hf_xet_fallback.py index 95ad9d828..3f24f5698 100644 --- a/unsloth_zoo/hf_xet_fallback.py +++ b/unsloth_zoo/hf_xet_fallback.py @@ -60,6 +60,19 @@ logger = logging.getLogger(__name__) +# Public surface (Studio imports from this module, including a `import *` re-export shim), so +# an explicit list keeps the stdlib imports (os, re, signal, errno, ...) out of `import *`. +__all__ = [ + "DownloadStallError", + "hf_hub_download_with_xet_fallback", + "snapshot_download_with_xet_fallback", + "start_watchdog", + "get_hf_download_state", + "is_hf_xet_available", + "xet_force_disabled", + "child_should_disable_xet", +] + _CTX = mp.get_context("spawn") # Defaults match the existing Studio inference watchdog and hub shutdown deadline. 
@@ -378,11 +391,17 @@ def _measure() -> Optional[tuple[int, bool]]: sizes = _active_incomplete_blob_sizes(repo_type, single_repo_id, cache_dir) open_names = _child_open_incomplete_blobs(child_pid) if child_pid else None if open_names is not None: - # Precise: only the partials this child holds open (handles a resumed - # partial that reuses a baseline blob-hash name, and excludes siblings). + # Precise: only the partials this child process holds open (handles a resumed + # partial that reuses a baseline blob-hash name, and excludes siblings). hf_xet + # writes in-process and holds the .incomplete fd continuously, so an EMPTY set + # here means the child owns no partial YET (the connect / metadata phase), NOT + # that a helper process owns one -- it must own nothing this tick, so a stalled + # sibling's post-baseline partial cannot be misattributed and kill a connecting + # child. owned = {name: n for name, n in sizes.items() if name in open_names} else: - # Fallback (no psutil / no /proc): follow only newly-created partials. + # Cannot inspect the child (no psutil / no /proc): best-effort fall back to + # following only newly-created partials (not in the pre-spawn baseline). owned = {name: n for name, n in sizes.items() if name not in baseline} return (sum(owned.values()), len(owned) > 0) return get_hf_download_state(repo_ids, repo_type = repo_type, cache_dir = cache_dir) @@ -481,7 +500,9 @@ def _is_retryable_download_error(exc: BaseException) -> bool: if not isinstance(status, int): status = getattr(exc, "status_code", None) if isinstance(status, int): - return status >= 500 or status == 429 + # 5xx server errors, 429 rate-limit, 408 request-timeout are transient; other 4xx + # (401 / 403 / 404 / 416) are deterministic and would fail identically over HTTP. 
+ return status >= 500 or status in (408, 429) text = f"{name}: {exc}".lower() return any(hint in text for hint in _TRANSIENT_ERROR_HINTS) @@ -764,6 +785,14 @@ def _run_download_attempt( _terminate_process_group(proc, grace_period) return ("cancelled", None) if stalled.is_set(): + # Prefer a result the child enqueued in the same ~interval window the watchdog + # fired in over a late stall, so a download that just succeeded is not killed and + # needlessly retried over HTTP. + try: + result = result_queue.get_nowait() + break + except queue.Empty: + pass _terminate_process_group(proc, grace_period) return ("stall", None) try: @@ -790,6 +819,14 @@ def _run_download_attempt( # cancel/stall branch already terminated it is a harmless no-op. if proc.is_alive(): _terminate_process_group(proc, grace_period) + # Release the queue's pipe fds deterministically rather than waiting for GC (which is + # fragile when the child was killed mid-put). The result, if any, is already extracted, + # and a killed child has nothing more to flush, so cancel the feeder join before close. + try: + result_queue.cancel_join_thread() + result_queue.close() + except Exception: + pass if result is None: # The child exited without enqueuing a result: a process-level crash (e.g. a native From 69a777e4c0d5dd6547fd3bf58ac563392ef34098 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sat, 27 Jun 2026 12:44:33 +0000 Subject: [PATCH 29/82] Treat Diffusers config-only components as weightless; reject checkpoint-only snapshots as a warm root model Two more weight-detection edge cases from a fresh review round. 
request_can_include_weights: a catch-all under a Diffusers pipeline component that ships only config / vocab files (scheduler/*, feature_extractor/*, the extra tokenizers tokenizer_2/ tokenizer_3/) was given synthetic weight probes and read as weight-including, so a child that correctly fetched only scheduler/scheduler_config.json was then rejected by _snapshot_is_acceptable for lacking weights and looped the guarded download. These component dirs are added to the auxiliary-dir set, so a catch-all under them reads as weightless. The weight-bearing pipeline dirs (unet/, transformer/, vae/, text_encoder*/, image_encoder/, safety_checker/) are deliberately left out, so a catch-all under them stays weight-including. snapshot_dir_is_complete: an unpatterned warm is a bare from_pretrained, which reads ROOT model weights. A weight that lives only inside a per-checkpoint dir (checkpoint-500/model.safetensors, global_step1000/pytorch_model.bin, left behind by a prior allow_patterns=["checkpoint-500/*"] pull) is not a root weight, so it must not make a checkpoint-only snapshot read as a warm root model. The unpatterned branch now excludes checkpoint-dir weights (and their shard indexes) from the completeness check via a new _path_under_checkpoint_dir helper, so the guarded download still runs rather than handing from_pretrained a snapshot whose root weights are missing. A patterned checkpoint request (allow_patterns=["checkpoint-500/*"]) still resolves against that checkpoint's own weights, and a root weight present alongside the checkpoint dirs still completes the warm. 
Tests cover both: scheduler/ and the other config-only components reading as weightless (with an end-to-end config-only scheduler/ snapshot that is acceptable), the weight-bearing components staying weight-including, a checkpoint-only cache not completing an unpatterned warm while a patterned checkpoint request does, a global_step dir treated the same, and an incomplete checkpoint shard index not gating a complete root warm. --- tests/test_hf_xet_fallback.py | 130 ++++++++++++++++++++++++++++++++++ unsloth_zoo/hf_cache_state.py | 47 +++++++++--- 2 files changed, 168 insertions(+), 9 deletions(-) diff --git a/tests/test_hf_xet_fallback.py b/tests/test_hf_xet_fallback.py index d0c3dd825..bec402742 100644 --- a/tests/test_hf_xet_fallback.py +++ b/tests/test_hf_xet_fallback.py @@ -1363,6 +1363,51 @@ def test_snapshot_dir_is_complete_single_shard_request(tmp_path): assert hcs.snapshot_dir_is_complete(snap) is False +def test_snapshot_dir_is_complete_checkpoint_only_not_warm_root(tmp_path): + """An unpatterned (root-model) warm is not satisfied by a weight that lives only inside a + per-checkpoint dir. A cache left by a prior allow_patterns=["checkpoint-10/*"] pull holds + checkpoint-10/model.safetensors but no root weight; reading it as a warm root model would let + the guarded download be skipped and hand from_pretrained a snapshot whose root weights are + missing (Codex #829). A root weight (or a patterned checkpoint request) still completes.""" + snap = tmp_path / "snap" + (snap / "checkpoint-10").mkdir(parents = True) + blob = tmp_path / "blob" + blob.write_bytes(b"x") + (snap / "checkpoint-10" / "model.safetensors").symlink_to(blob) # checkpoint weight only + (snap / "config.json").write_text("{}") + # Unpatterned root warm: the checkpoint weight does not count -> incomplete. + assert hcs.snapshot_dir_is_complete(snap) is False + # A patterned request for that checkpoint IS satisfied by it (not a root warm). 
+ assert hcs.snapshot_dir_is_complete(snap, allow_patterns = ["checkpoint-10/*"]) is True + # Once a root weight is present, the unpatterned warm completes (checkpoint weight ignored). + (snap / "model.safetensors").symlink_to(blob) + assert hcs.snapshot_dir_is_complete(snap) is True + # A DeepSpeed-style global_step dir is treated the same way. + snap2 = tmp_path / "snap2" + (snap2 / "global_step500").mkdir(parents = True) + (snap2 / "global_step500" / "pytorch_model.bin").symlink_to(blob) + assert hcs.snapshot_dir_is_complete(snap2) is False + + +def test_snapshot_dir_is_complete_checkpoint_index_does_not_gate_root(tmp_path): + """A per-checkpoint shard index with missing shards must not fail an unpatterned root warm: + the root weights are what the load reads, so an incomplete checkpoint index is irrelevant to + root completeness (and a complete root weight set is enough).""" + snap = tmp_path / "snap" + (snap / "checkpoint-7").mkdir(parents = True) + blob = tmp_path / "blob" + blob.write_bytes(b"x") + (snap / "model.safetensors").symlink_to(blob) # complete root weight + # An incomplete checkpoint shard index (shard 2 missing) lives under checkpoint-7/. + (snap / "checkpoint-7" / "model-00001-of-00002.safetensors").symlink_to(blob) + (snap / "checkpoint-7" / "model.safetensors.index.json").write_text( + json.dumps({"weight_map": {"a": "model-00001-of-00002.safetensors", + "b": "model-00002-of-00002.safetensors"}}) + ) + # Root warm is complete: the checkpoint index is skipped, the root weight suffices. 
+ assert hcs.snapshot_dir_is_complete(snap) is True + + def test_snapshot_dir_is_complete_requires_each_named_weight(tmp_path): """require_named_weights makes a request naming multiple exact weights (base + adapter) need EACH on disk, so the pre-download cache probe does not short-circuit a stale snapshot @@ -1776,6 +1821,91 @@ def test_request_can_include_weights_weight_bearing_subfolders(tmp_path): ) is True +def test_request_can_include_weights_diffusers_config_only_components(tmp_path): + """The Diffusers pipeline components that ship only *_config.json / vocab files + (scheduler/, feature_extractor/, tokenizer_2/, tokenizer_3/) must read as weightless, so a + catch-all like scheduler/* is not given synthetic weight probes and a child that correctly + fetches only scheduler/scheduler_config.json is not rejected for lacking weights (Codex + #829). The weight-bearing pipeline dirs stay weight-including.""" + for d in ["scheduler/*", "feature_extractor/*", "tokenizer_2/*", "tokenizer_3/*"]: + assert hcs.request_can_include_weights([d], None) is False, d + # The weight-bearing pipeline components are NOT in the weightless set. + for d in ["unet/*", "transformer/*", "vae/*", "text_encoder/*", "text_encoder_2/*", + "image_encoder/*", "safety_checker/*"]: + assert hcs.request_can_include_weights([d], None) is True, d + # End to end: a config-only scheduler/ snapshot is acceptable for a scheduler/* request + # (no weight expected there), so the guarded download is not looped on it. 
+ snap = tmp_path / "snap" + (snap / "scheduler").mkdir(parents = True) + (snap / "scheduler" / "scheduler_config.json").write_text("{}") + assert xf._snapshot_is_acceptable( + snap, repo_type = "model", allow_patterns = ["scheduler/*"], ignore_patterns = None, + require_named_weights = True, + ) is True + + +def test_consumer_pattern_lists_accepted_end_to_end(tmp_path): + """Lock the cross-repo contract: the EXACT allow / ignore lists unsloth's + maybe_prefetch_hf_snapshot emits must be judged correctly by this module's acceptance, so a + future drift between the two repos cannot silently loop the guarded download. These lists + mirror unsloth's _ADAPTER_PREFETCH_PATTERNS / _ROOT_AUX_PREFETCH_PATTERNS / _SUBDIR_WEIGHT_ + IGNORE_PATTERNS; if unsloth changes them, this is where the mismatch surfaces.""" + blob = tmp_path / "blob" + blob.write_bytes(b"x") + + # --- adapter_only: allow = adapter files + root aux, ignore = None --- + root_aux = [ + "config.json", "generation_config.json", "tokenizer_config.json", "tokenizer.json", + "tokenizer.model", "special_tokens_map.json", "added_tokens.json", "vocab.json", + "vocab.txt", "merges.txt", "spiece.model", "chat_template.jinja", "chat_template.json", + "preprocessor_config.json", "processor_config.json", "configuration_*.py", "modeling_*.py", + "tokenization_*.py", "processing_*.py", "image_processing_*.py", "feature_extraction_*.py", + "video_processing_*.py", "*.tiktoken", + ] + adapter_allow = ["adapter_config.json", "adapter_model*", *root_aux] + assert hcs.request_can_include_weights(adapter_allow, None) is True + snap = tmp_path / "adapter" + snap.mkdir() + (snap / "adapter_config.json").write_text("{}") + (snap / "adapter_model.safetensors").symlink_to(blob) + (snap / "config.json").write_text("{}") + (snap / "tokenizer.json").write_text("{}") + # A merged full-model weight the adapter warm never requested is present but irrelevant. 
+ (snap / "model.safetensors").symlink_to(blob) + assert xf._snapshot_is_acceptable( + snap, repo_type = "model", allow_patterns = adapter_allow, ignore_patterns = None, + require_named_weights = True, + ) is True + # An adapter cache missing its weight (config only) is NOT acceptable -> guarded download. + snap_bad = tmp_path / "adapter_bad" + snap_bad.mkdir() + (snap_bad / "adapter_config.json").write_text("{}") + assert xf._snapshot_is_acceptable( + snap_bad, repo_type = "model", allow_patterns = adapter_allow, ignore_patterns = None, + ) is False + + # --- weights_at_root: allow = None, ignore = static skips + subdir-weight excludes --- + root_ignore = ["*.onnx", "onnx/*", "*.gguf", "checkpoint-*/*", "*/*.safetensors", "*/*.bin"] + assert hcs.request_can_include_weights(None, root_ignore) is True + rsnap = tmp_path / "root" + rsnap.mkdir() + (rsnap / "config.json").write_text("{}") + (rsnap / "model.safetensors").symlink_to(blob) # root weight present + (rsnap / "fp16").mkdir() + (rsnap / "fp16" / "model.safetensors").symlink_to(blob) # subdir weight (unread by root load) + assert xf._snapshot_is_acceptable( + rsnap, repo_type = "model", allow_patterns = None, ignore_patterns = root_ignore, + ) is True + # A subdir-only cache (no root weight) is NOT acceptable for a root load. 
+ rsnap_bad = tmp_path / "root_bad" + (rsnap_bad / "fp16").mkdir(parents = True) + (rsnap_bad / "config.json").write_text("{}") + (rsnap_bad / "fp16" / "model.safetensors").symlink_to(blob) + assert xf._snapshot_is_acceptable( + rsnap_bad, repo_type = "model", allow_patterns = None, ignore_patterns = root_ignore, + ) is False + + def test_basename_weight_classification_helpers(): """Lock the catch-all-vs-weight distinction the subfolder gating rests on: a weight-stem basename targets a weight, a config / tokenizer glob is clearly non-weight, and a bare diff --git a/unsloth_zoo/hf_cache_state.py b/unsloth_zoo/hf_cache_state.py index 0b21d3b62..3a8cdb109 100644 --- a/unsloth_zoo/hf_cache_state.py +++ b/unsloth_zoo/hf_cache_state.py @@ -417,7 +417,12 @@ def snapshot_dir_is_complete( if has_patterns: selected = set(_filter_paths([rel for _, rel in weight_entries], allow_patterns, ignore_patterns)) else: - selected = {rel for _, rel in weight_entries} + # Unpatterned warm = a bare from_pretrained, which reads ROOT model weights. A weight that + # lives only inside a per-checkpoint dir (checkpoint-500/model.safetensors, left behind by + # a prior allow_patterns=["checkpoint-500/*"] pull) is not a root weight, so it must not + # make a checkpoint-only snapshot read as a warm root model -- that would let the guarded + # download be skipped and hand from_pretrained a snapshot whose root weights are missing. + selected = {rel for _, rel in weight_entries if not _path_under_checkpoint_dir(rel)} if not selected: return False @@ -455,9 +460,17 @@ def snapshot_dir_is_complete( return False # A full (unpatterned) warm also validates any shard index ships all its shards; a - # patterned request may legitimately want only a subset, so the index is not enforced. + # patterned request may legitimately want only a subset, so the index is not enforced. 
A + # per-checkpoint index (checkpoint-500/model.safetensors.index.json) does not gate a root + # warm for the same reason its weights do not, so it is skipped here too. if not has_patterns: for index_entry in index_entries: + try: + index_rel = index_entry.relative_to(snapshot_dir).as_posix() + except ValueError: + index_rel = index_entry.name + if _path_under_checkpoint_dir(index_rel): + continue if not _weight_shard_index_complete(index_entry): return False return True @@ -523,6 +536,17 @@ def _looks_like_checkpoint_dir(pattern: str) -> bool: return any(lowered.startswith(prefix) for prefix in _CHECKPOINT_DIR_PREFIXES) +def _path_under_checkpoint_dir(rel: str) -> bool: + """True when a repo-relative *rel* lives inside a per-checkpoint directory + (``checkpoint-500/model.safetensors``, ``global_step1000/pytorch_model.bin``). Only the + PARENT components are checked -- the final component is the filename itself. Used to keep a + checkpoint-dir weight from satisfying an unpatterned (root-model) warmup: such a weight is + what a prior ``allow_patterns=["checkpoint-500/*"]`` pull leaves behind, not the root weight + a bare ``from_pretrained`` reads.""" + parts = rel.split("/") + return any(_looks_like_checkpoint_dir(p) for p in parts[:-1] if p) + + def _bracket_member(content: str) -> str: """A single character that a glob ``[...]`` class *matches*, for concretizing a bracket expression into a probe that still satisfies the caller's own pattern. ``[0-9]`` -> ``0``, @@ -592,14 +616,19 @@ def _concretize_glob(pattern: str) -> str: ) -# Subfolders that, by convention, hold only auxiliary / telemetry files -- never model weights. -# A catch-all glob under one of these (tokenizer/*, runs/*) is read as weightless. 
Kept -# deliberately narrow: an unknown subfolder (unet/, transformer/, original/, a new arch's -# component dir) must stay weight-including, so a weight-bearing dir is never misread as -# weightless (that would re-open the silent-Xet-hang accept-stale this module exists to prevent). +# Subfolders that, by convention, hold only auxiliary / telemetry / config files -- never model +# weights. A catch-all glob under one of these (tokenizer/*, runs/*, scheduler/*) is read as +# weightless. Kept deliberately narrow: an unknown subfolder (unet/, transformer/, original/, a +# new arch's component dir) must stay weight-including, so a weight-bearing dir is never misread +# as weightless (that would re-open the silent-Xet-hang accept-stale this module exists to +# prevent). The Diffusers pipeline components listed here (scheduler/, feature_extractor/, the +# extra tokenizers) ship only *_config.json / vocab files; the weight-bearing pipeline dirs +# (unet/, transformer/, vae/, text_encoder*/, image_encoder/, safety_checker/) are deliberately +# absent so a catch-all under them stays weight-including. _NON_WEIGHT_DIR_NAMES = frozenset({ - "tokenizer", "runs", "run", "logs", "log", "samples", "sample", "tensorboard", "tb", - "events", "eval", "evals", "evaluation", "metrics", "wandb", "assets", "images", "media", + "tokenizer", "tokenizer_2", "tokenizer_3", "runs", "run", "logs", "log", "samples", "sample", + "tensorboard", "tb", "events", "eval", "evals", "evaluation", "metrics", "wandb", "assets", + "images", "media", "scheduler", "feature_extractor", }) From ae398c609ef31b989596532e4117c89839c40a20 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 28 Jun 2026 03:06:19 +0000 Subject: [PATCH 30/82] Require named non-weight files in completeness; honor cancel before the warm-cache shortcut Two more review findings on the shared Xet fallback. 
snapshot_dir_is_complete: an exact-file request (no globs) that names a non-weight alongside a weight -- ["model.safetensors", "tokenizer.json"] -- now requires the non-weight on disk too. The check previously validated only the weights, so a stale cache holding just the weight short-circuited past the guarded download that should still fetch the explicitly named tokenizer / config. WHICH names are required depends on the request shape: an exact-file list names every file it wants, so each concrete name (weight or non-weight) must be present; a list containing ANY glob is a broad "warm what matches" selection where named aux files are best-effort (an optional vocab.txt / spiece.model the repo may simply lack), so only its concrete weight names are required. Requiring every optional aux file there would defeat the warm-cache short-circuit for normal repos (unsloth's adapter / tokenizer / subfolder warms always carry globs, so they keep short-circuiting on a warm cache). Enforced only at the pre-download probe, so the post-download check stays lenient. hf_xet_fallback: both wrappers now honor an already-set cancel_event BEFORE their offline and warm-cache short-circuits. Those short-circuits return without reaching _download_with_xet_fallback, which held the only other cancellation check, so a request cancelled before a cached probe (a Studio / FastModel flow where cancellation can arrive before the warm-cache step) previously resolved and handed back the cached file / snapshot instead of cancelling. Tests cover the named non-weight requirement (exact list requires the tokenizer, a glob-bearing list does not), and cancellation being honored even when the file / snapshot is already cached. 
--- tests/test_hf_xet_fallback.py | 57 ++++++++++++++++++++++++++++++++++ unsloth_zoo/hf_cache_state.py | 43 ++++++++++++++++++------- unsloth_zoo/hf_xet_fallback.py | 12 +++++++ 3 files changed, 100 insertions(+), 12 deletions(-) diff --git a/tests/test_hf_xet_fallback.py b/tests/test_hf_xet_fallback.py index bec402742..6912a4096 100644 --- a/tests/test_hf_xet_fallback.py +++ b/tests/test_hf_xet_fallback.py @@ -721,6 +721,37 @@ def test_cancel_before_start_raises_no_attempt(monkeypatch): assert fake.calls == [] +def test_cancel_honored_even_when_file_cached(monkeypatch, tmp_path): + """A cancel_event set before the call must raise even when the file is ALREADY cached: the + warm-cache short-circuit returns without reaching _download_with_xet_fallback (the other + cancel check), so it must honor cancellation first rather than hand back the cached path + (Codex #829).""" + cached = tmp_path / "cached.gguf" + cached.write_bytes(b"\0" * 8) + monkeypatch.setattr(huggingface_hub, "try_to_load_from_cache", lambda *a, **k: str(cached)) + fake = _install(monkeypatch, []) # the attempt must never run + ev = threading.Event() + ev.set() + with pytest.raises(RuntimeError, match = "Cancelled"): + xf.hf_hub_download_with_xet_fallback(DL_REPO, FILE, None, cancel_event = ev) + assert fake.calls == [] + + +def test_snapshot_cancel_honored_even_when_cached(monkeypatch, tmp_path): + """The snapshot wrapper must also honor a pre-set cancel before its warm-cache short-circuit, + so a cancelled request does not resolve a cached snapshot (Codex #829).""" + snap = tmp_path / "snap" + snap.mkdir() + (snap / "model.safetensors").write_bytes(b"x") + monkeypatch.setattr(huggingface_hub, "snapshot_download", lambda *a, **k: str(snap)) + fake = _install(monkeypatch, []) # the attempt must never run + ev = threading.Event() + ev.set() + with pytest.raises(RuntimeError, match = "Cancelled"): + xf.snapshot_download_with_xet_fallback(DL_REPO, token = None, cancel_event = ev) + assert fake.calls == 
[] + + def test_nonstall_error_propagates_without_fallback(monkeypatch): fake = _install(monkeypatch, [("error", "RepositoryNotFoundError: 404 not found")]) with pytest.raises(RuntimeError, match = "RepositoryNotFoundError"): @@ -1442,6 +1473,32 @@ def test_snapshot_dir_is_complete_requires_each_named_weight(tmp_path): assert hcs.snapshot_dir_is_complete(snap, allow_patterns = ["*.safetensors"], require_named_weights = True) is True +def test_snapshot_dir_is_complete_requires_named_non_weight_exact_only(tmp_path): + """An EXACT-file request (no globs) naming a non-weight alongside a weight requires the + non-weight on disk too: a stale cache holding only the weight must not short-circuit past the + guarded download that should still fetch the explicitly named tokenizer / config (Codex #829). + A request containing ANY glob is instead a broad selection where aux files are best-effort, so + only its concrete weights are required -- keeping unsloth's glob-bearing adapter / tokenizer + warms able to short-circuit on a warm cache rather than re-downloading on every load.""" + blob = tmp_path / "blob" + blob.write_bytes(b"x") + snap = tmp_path / "snap" + snap.mkdir() + (snap / "model.safetensors").symlink_to(blob) # weight present; tokenizer.json missing + pair = ["model.safetensors", "tokenizer.json"] + # Strict (pre-download): the named tokenizer.json is missing -> incomplete -> guarded download. + assert hcs.snapshot_dir_is_complete(snap, allow_patterns = pair, require_named_weights = True) is False + # Lenient (post-download): a present selected weight suffices (no error on a possibly-absent name). + assert hcs.snapshot_dir_is_complete(snap, allow_patterns = pair, require_named_weights = False) is True + # Once the named non-weight is on disk, strict is satisfied. 
+ (snap / "tokenizer.json").write_text("{}") + assert hcs.snapshot_dir_is_complete(snap, allow_patterns = pair, require_named_weights = True) is True + # A list containing ANY glob is a broad warm: optional aux names are NOT required, only weights. + (snap / "tokenizer.json").unlink() + globbed = ["model.safetensors", "tokenizer.json", "modeling_*.py"] + assert hcs.snapshot_dir_is_complete(snap, allow_patterns = globbed, require_named_weights = True) is True + + def test_fast_path_rejects_config_only_snapshot(hf_cache, monkeypatch): """HF's local_files_only returns a config-only snapshot (e.g. left by an earlier AutoConfig fetch) without checking weights. The fast path must reject it and complete diff --git a/unsloth_zoo/hf_cache_state.py b/unsloth_zoo/hf_cache_state.py index 3a8cdb109..815fa87c1 100644 --- a/unsloth_zoo/hf_cache_state.py +++ b/unsloth_zoo/hf_cache_state.py @@ -426,26 +426,45 @@ def snapshot_dir_is_complete( if not selected: return False - # A request that explicitly names exact weight files (e.g. a base model plus a PEFT - # adapter, ["model.safetensors", "adapter_model.safetensors"]) needs EACH of them, not - # just one: a stale cache holding only the first must not be accepted. Enforced only when - # the caller asks (the pre-download cache short-circuit), so the post-download check stays - # lenient and never errors when a named weight simply does not exist in the repo (an - # "either format" list like ["pytorch_model.bin", "model.safetensors"] against a - # safetensors-only repo). A glob may legitimately select a subset, so only concrete - # filenames are required, and one the ignore filter drops is not actually requested. + # A request that explicitly names exact files needs EACH of them on disk, not just one, so a + # stale cache holding a subset is not short-circuited past the guarded download. 
WHICH names + # are required depends on the request shape: + # * An exact-file request (no globs) -- ["model.safetensors", "tokenizer.json"], or a base + # plus a PEFT adapter ["model.safetensors", "adapter_model.safetensors"] -- names every + # file it wants, so each concrete name (weight OR non-weight) must be present. A cache with + # just the weight must not accept-warm while the named tokenizer / config is still missing. + # * A request containing ANY glob is a broad "warm what matches" selection where named aux + # files are best-effort (an optional vocab.txt / spiece.model the repo may simply lack), so + # only its concrete WEIGHT names are required -- demanding every optional aux file there + # would defeat the warm-cache short-circuit for normal repos. + # Enforced only at the pre-download probe (require_named_weights), so the post-download check + # stays lenient and never errors on an "either format" name (["pytorch_model.bin", + # "model.safetensors"] against a safetensors-only repo) that does not exist in the repo. A name + # the ignore filter drops is not actually requested. if require_named_weights and allow_patterns: - present_rels = set(rel for _, rel in weight_entries) + exact_only = not any(_has_glob(p) for p in allow_patterns) + if exact_only: + present = set() + for entry in entries: + if _safe_is_file(entry): + try: + present.add(entry.relative_to(snapshot_dir).as_posix()) + except ValueError: + present.add(entry.name) + else: + present = set(rel for _, rel in weight_entries) for pat in allow_patterns: - if _has_glob(pat) or not str(pat).lower().endswith(_WEIGHT_FILE_SUFFIXES): + if _has_glob(pat): + continue + if not exact_only and not str(pat).lower().endswith(_WEIGHT_FILE_SUFFIXES): continue if ignore_patterns and not _filter_paths([pat], None, ignore_patterns): continue - # pat is a concrete (glob-free) weight path, so presence is an exact match. A direct + # pat is a concrete (glob-free) path, so presence is an exact match. 
A direct # membership test (not _filter_paths, which fails OPEN by returning all paths on a # filter error) keeps this strict check fail-SAFE: an unevaluable case requires the # guarded download rather than silently accepting a stale cache as warm. - if pat not in present_rels: + if pat not in present: return False # Every selected numbered shard needs the sibling shards the request also selects (the diff --git a/unsloth_zoo/hf_xet_fallback.py b/unsloth_zoo/hf_xet_fallback.py index 3f24f5698..da85b4d35 100644 --- a/unsloth_zoo/hf_xet_fallback.py +++ b/unsloth_zoo/hf_xet_fallback.py @@ -1081,6 +1081,12 @@ def hf_hub_download_with_xet_fallback( # dirs are normalized too, since HF accepts pathlib.Path. if isinstance(cache_dir, (str, os.PathLike)): cache_dir = os.path.expanduser(os.fspath(cache_dir)) + # Honor an already-set cancellation before any cache probe or network work. The offline and + # warm-cache short-circuits below return without reaching _download_with_xet_fallback (which + # holds the only other cancel check), so a request cancelled before this point must not + # resolve and hand back a cached file. + if cancel_event is not None and cancel_event.is_set(): + raise RuntimeError("Cancelled") # Offline: resolve purely from the local cache, never reaching the network. HF # raises LocalEntryNotFoundError if it is not cached; let that propagate. if local_files_only: @@ -1169,6 +1175,12 @@ def snapshot_download_with_xet_fallback( # resolve to the same on-disk cache location. if isinstance(cache_dir, (str, os.PathLike)): cache_dir = os.path.expanduser(os.fspath(cache_dir)) + # Honor an already-set cancellation before any cache probe or network work. The offline and + # warm-cache short-circuits below return without reaching _download_with_xet_fallback (which + # holds the only other cancel check), so a request cancelled before this point must not + # resolve and hand back a snapshot. 
+ if cancel_event is not None and cancel_event.is_set(): + raise RuntimeError("Cancelled") # Offline: resolve purely from the local cache, never reaching the network. HF # raises if the snapshot is not cached; let that propagate. if local_files_only: From aaf7a2d33dbdbc0fe7eaa1fb5c96d9798e860cde Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 28 Jun 2026 03:56:55 +0000 Subject: [PATCH 31/82] Harden the Xet fallback child lifecycle and malformed-index handling Address the latest review round on the shared helper: - Reap the download child if start_watchdog raises after the child has already spawned, so a thread/FD-exhaustion failure mid-supervision can no longer leak a process. start_watchdog now runs inside the try whose finally reaps the child, and stop_watchdog stays None until it succeeds. - Send the post-grace SIGKILL only while the process-group leader is still alive. Once join() has reaped a leader that exited on SIGTERM, its pid is free and a busy host can recycle it into an unrelated setsid group within the grace window, so a killpg on it would signal the wrong group. hf_xet 1.5.x writes in process and spawns no helper procs, so a reaped leader leaves nothing to clean up. - On a watchdog stall, drain the result queue with a short timeout instead of get_nowait, so a download that just succeeded in the same interval the watchdog fired in is not killed and needlessly retried over HTTP. - Tolerate a non-string (and unhashable) weight_map value in a weight-shard index: filter to strings before de-duplicating, so arbitrary JSON cannot crash the completeness probe. The string shards still gate, so a real missing shard is still detected. - Widen the snapshot-payload path guard and the cache-dir expanduser guard so a malformed input degrades to a safe default instead of raising. - Annotate has_active_incomplete_blobs repo_type as Optional[str]. 
Tests: child reaped when start_watchdog raises; exported Xet knobs (is_hf_xet_available, xet_force_disabled, child_should_disable_xet); weight-shard index completeness tolerates a non-string weight_map value. --- tests/test_hf_xet_fallback.py | 120 +++++++++++++++++++++++++++++++++ unsloth_zoo/hf_cache_state.py | 16 +++-- unsloth_zoo/hf_xet_fallback.py | 55 ++++++++------- 3 files changed, 164 insertions(+), 27 deletions(-) diff --git a/tests/test_hf_xet_fallback.py b/tests/test_hf_xet_fallback.py index 6912a4096..8b1a10608 100644 --- a/tests/test_hf_xet_fallback.py +++ b/tests/test_hf_xet_fallback.py @@ -2090,3 +2090,123 @@ def test_default_leaves_xet_enabled(): f"without the env var, constants.HF_HUB_DISABLE_XET was not False " f"(rc={proc.returncode}): {proc.stderr}" ) + + +# --------------------------------------------------------------------------- # +# Exported Xet knobs + child-leak safety + malformed-index resilience. +# --------------------------------------------------------------------------- # +def test_xet_availability_and_disable_helpers(monkeypatch): + """The exported Xet knobs: child_should_disable_xet reads the per-worker config flag; + xet_force_disabled honors every documented env knob; is_hf_xet_available reflects the + importability of hf_xet and treats a probe error as 'unavailable'.""" + assert xf.child_should_disable_xet({"disable_xet": True}) is True + assert xf.child_should_disable_xet({"disable_xet": False}) is False + assert xf.child_should_disable_xet({}) is False + + for knob in ("UNSLOTH_DISABLE_XET", "UNSLOTH_STABLE_DOWNLOADS", "HF_HUB_DISABLE_XET"): + for k in ("UNSLOTH_DISABLE_XET", "UNSLOTH_STABLE_DOWNLOADS", "HF_HUB_DISABLE_XET"): + monkeypatch.delenv(k, raising = False) + assert xf.xet_force_disabled() is False + monkeypatch.setenv(knob, "1") + assert xf.xet_force_disabled() is True, knob + + monkeypatch.setattr(xf.importlib.util, "find_spec", lambda name: object()) + assert xf.is_hf_xet_available() is True + 
monkeypatch.setattr(xf.importlib.util, "find_spec", lambda name: None) + assert xf.is_hf_xet_available() is False + + def _raise(name): + raise ImportError("boom") + + monkeypatch.setattr(xf.importlib.util, "find_spec", _raise) + assert xf.is_hf_xet_available() is False # a probe exception -> treated as unavailable + + +def test_run_attempt_terminates_child_if_watchdog_start_raises(monkeypatch): + """If start_watchdog raises (e.g. thread/FD exhaustion: 'can't start new thread') AFTER the + download child has already spawned, the child must STILL be reaped -- no leaked process. The + error then propagates (a watchdog-start failure is not a transport fault to retry over HTTP).""" + rec = {"terminated": False} + + class _AliveProc: + def __init__(self): + self.pid = None # None -> _terminate_process_group skips killpg, uses terminate() + self.exitcode = None + self._alive = True + + def start(self): + pass + + def is_alive(self): + return self._alive + + def terminate(self): + rec["terminated"] = True + self._alive = False + + def kill(self): + rec["terminated"] = True + self._alive = False + + def join(self, timeout = None): + pass + + class _Ctx: + def Process(self, *, target = None, kwargs = None, daemon = None): + return _AliveProc() + + def Queue(self): + return _FakeQueue({"ok": True, "path": "/cache/x"}) + + monkeypatch.setattr(xf, "_CTX", _Ctx()) + + def _boom(*a, **k): + raise RuntimeError("can't start new thread") + + monkeypatch.setattr(xf, "start_watchdog", _boom) + + with pytest.raises(RuntimeError, match = "can't start new thread"): + xf._run_download_attempt( + DL_REPO, kind = "snapshot", params = {"repo_id": DL_REPO}, token = None, + repo_type = "model", disable_xet = False, cancel_event = None, + stall_timeout = 0.2, interval = 0.05, grace_period = 0.05, on_status = None, + ) + assert rec["terminated"] is True # child reaped despite the watchdog-start failure + + +def test_snapshot_dir_is_complete_tolerates_non_string_shard(tmp_path): + """A weight 
index whose ``weight_map`` carries a non-string value (malformed / arbitrary JSON) + must not crash the completeness probe: the bad entry is skipped, the string shards still + gate, so a real missing shard is still detected (Codex #829). Uses a non-numbered weight name + so only the index path -- not the numbered-shard-name expansion -- is exercised.""" + snap = tmp_path / "snap" + snap.mkdir() + blob = tmp_path / "blob" + blob.write_bytes(b"x") + (snap / "model.safetensors").symlink_to(blob) + (snap / "model.safetensors.index.json").write_text( + json.dumps( + { + "weight_map": { + "a": "model.safetensors", + "b": ["not", "a", "string"], # malformed entry -> skipped, no crash + } + } + ) + ) + # The one concrete file it names is present; the malformed entry is ignored, so no crash and + # the snapshot reads as complete (only demonstrably-missing string shards reject). + assert hcs.snapshot_dir_is_complete(snap) is True + # A genuinely missing string shard still gates, with the malformed entry still skipped. + (snap / "model.safetensors.index.json").write_text( + json.dumps( + { + "weight_map": { + "a": "model.safetensors", + "b": "absent-extra.safetensors", + "c": {"bad": "object"}, + } + } + ) + ) + assert hcs.snapshot_dir_is_complete(snap) is False # 'absent-extra' missing, bad entry skipped diff --git a/unsloth_zoo/hf_cache_state.py b/unsloth_zoo/hf_cache_state.py index 815fa87c1..8f13446b0 100644 --- a/unsloth_zoo/hf_cache_state.py +++ b/unsloth_zoo/hf_cache_state.py @@ -79,7 +79,12 @@ def hf_cache_root(*, create: bool = False, cache_dir: "Optional[str | Path]" = N if cache_dir is not None: # Match huggingface_hub, which expands ~ before writing; scanning the # literal path would otherwise miss a partial under e.g. ~/hf-cache. - root = Path(cache_dir).expanduser() + # Path.expanduser() raises RuntimeError when no home can be resolved (a restricted + # container with HOME unset); fall back to the literal path rather than crash the probe. 
+ try: + root = Path(cache_dir).expanduser() + except (RuntimeError, OSError): + root = Path(cache_dir) else: try: from huggingface_hub import constants as hf_constants @@ -339,9 +344,12 @@ def _weight_shard_index_complete(index_path: Path) -> bool: weight_map = data.get("weight_map") if isinstance(data, dict) else None if not isinstance(weight_map, dict): return True - # weight_map values are filenames relative to the index file's own directory. + # weight_map values are filenames relative to the index file's own directory. They come from + # arbitrary JSON: a non-string (e.g. list/dict) value is both unhashable -- so it would break + # set() -- and invalid for ``base / shard``, so filter to strings BEFORE de-duplicating rather + # than crash (consistent with the fail-open parse handling above). base = index_path.parent - for shard in set(weight_map.values()): + for shard in {s for s in weight_map.values() if isinstance(s, str)}: try: if not (base / shard).exists(): return False @@ -864,7 +872,7 @@ def repo_cache_dir_has_incomplete_blobs(repo_dir: Path) -> bool: def has_active_incomplete_blobs( - repo_type: str, repo_id: str, *, cache_dir: "Optional[str | Path]" = None + repo_type: "Optional[str]", repo_id: str, *, cache_dir: "Optional[str | Path]" = None ) -> bool: for entry in iter_active_repo_cache_dirs(repo_type, repo_id, cache_dir = cache_dir): if repo_cache_dir_has_incomplete_blobs(entry): diff --git a/unsloth_zoo/hf_xet_fallback.py b/unsloth_zoo/hf_xet_fallback.py index da85b4d35..b15b0f532 100644 --- a/unsloth_zoo/hf_xet_fallback.py +++ b/unsloth_zoo/hf_xet_fallback.py @@ -635,12 +635,13 @@ def _signal_group(sig: int) -> None: _signal_group(getattr(signal, "SIGTERM", signal.SIGINT)) proc.join(timeout = grace_period) - # Always send the post-grace SIGKILL to the whole group, even if the Python - # leader already exited on SIGTERM: a Xet helper left in the group can keep - # the stalled writer alive while the parent starts HTTP cleanup. 
killpg on an - # already-dead group is a no-op (ProcessLookupError is caught in _signal_group). - _signal_group(getattr(signal, "SIGKILL", signal.SIGTERM)) + # Post-grace SIGKILL only while the leader is still alive, so its pid (== pgid after setsid) is + # a live target. Once proc.join() reaps a leader that exited on SIGTERM, that pid is free and a + # busy host can recycle it into an unrelated setsid'd group within the grace window -- a + # killpg(pid) would then signal the WRONG group. hf_xet 1.5.x writes in-process and spawns no + # helper procs, so a reaped leader leaves nothing in the group to clean up. if proc.is_alive(): + _signal_group(getattr(signal, "SIGKILL", signal.SIGTERM)) proc.join(timeout = 5.0) @@ -764,22 +765,25 @@ def _run_download_attempt( pass stalled = threading.Event() - stop_watchdog = start_watchdog( - repo_ids = [repo_id], - on_stall = lambda msg: stalled.set(), - repo_type = repo_type, - cache_dir = params.get("cache_dir"), - interval = interval, - stall_timeout = stall_timeout, - xet_disabled = disable_xet, - on_heartbeat = on_status, - watch_new_partials_only = (kind == "file"), - baseline_incomplete_blobs = baseline_partials, - child_pid = proc.pid, - ) - + # start_watchdog creates and starts a thread; if that raises (e.g. "can't start new thread" + # under thread/FD exhaustion), the child already started above must STILL be terminated. So it + # runs inside the try whose finally reaps the child; stop_watchdog stays None until it succeeds. 
+ stop_watchdog = None result: Optional[dict] = None try: + stop_watchdog = start_watchdog( + repo_ids = [repo_id], + on_stall = lambda msg: stalled.set(), + repo_type = repo_type, + cache_dir = params.get("cache_dir"), + interval = interval, + stall_timeout = stall_timeout, + xet_disabled = disable_xet, + on_heartbeat = on_status, + watch_new_partials_only = (kind == "file"), + baseline_incomplete_blobs = baseline_partials, + child_pid = proc.pid, + ) while proc.is_alive(): if cancel_event is not None and cancel_event.is_set(): _terminate_process_group(proc, grace_period) @@ -787,9 +791,11 @@ def _run_download_attempt( if stalled.is_set(): # Prefer a result the child enqueued in the same ~interval window the watchdog # fired in over a late stall, so a download that just succeeded is not killed and - # needlessly retried over HTTP. + # needlessly retried over HTTP. A spawn Queue has a child-side feeder thread, so a + # result put microseconds earlier is not yet readable by get_nowait(); use a short + # timeout (matching the process-exit drain below) to let the pipe flush. try: - result = result_queue.get_nowait() + result = result_queue.get(timeout = 1.0) break except queue.Empty: pass @@ -810,7 +816,8 @@ def _run_download_attempt( except queue.Empty: result = None finally: - stop_watchdog.set() + if stop_watchdog is not None: + stop_watchdog.set() proc.join(timeout = grace_period) # Any exit from the loop -- normal completion, cancel/stall, or an # unexpected exception (e.g. KeyboardInterrupt) -- must not leak the child. @@ -893,7 +900,9 @@ def _snapshot_payload_incomplete( request.""" try: path = Path(payload) - except TypeError: + except (TypeError, ValueError, OSError): + # Non-path payload (unit-test sentinel) or, on Windows, a path with invalid characters + # (ValueError / OSError): trust it rather than reject -- production always returns a real dir. 
return False try: if not path.is_dir(): From 5ed4779281a3ba6e79c4132327d0cb129560c519 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 28 Jun 2026 04:59:47 +0000 Subject: [PATCH 32/82] Scope snapshot completeness to the requested files and preserve Hub error types Address the latest Codex review round on the shared helper: - Scope the broken-symlink and weight-presence checks to the files a request actually selects. A dangling symlink for an EXCLUDED file -- a stale root model.safetensors under an allow_patterns=["adapter_model.safetensors"] probe whose adapter weight is on disk -- no longer rejects the whole snapshot; a dangle for a REQUESTED file still does. - Treat an ignore-only root warm (ignore_patterns set, no allow_patterns) the same as an unpatterned one: a weight that lives only in a per-checkpoint dir never satisfies it, so a checkpoint-only cache is not read as a warm root model and the guarded download is not skipped. - Treat a no-slash metadata glob (tokenizer*, config*, vocab*, special_tokens*) as a FILE glob, not a weight-bearing directory glob, so a tokenizer*-only warm that fetched tokenizer.json is not rejected for lacking a weight. model* / pytorch_model* / checkpoint-* stay weight-including. - Require an EXACT-named weightless request (allow_patterns=["tokenizer.json"], no globs) to find its named files on disk before the snapshot is treated as warm, so a config-only snapshot dir HF hands back is not accepted for it. A glob-bearing list stays best-effort (an optional vocab.txt the repo may lack must not fail it). - Re-raise a deterministic Hub failure preserving its original exception TYPE (RepositoryNotFoundError / GatedRepoError / OSError ...) across the spawn-process boundary, reconstructed from the child's ": " report, so a caller's typed except clause still matches. Unknown types fall back to RuntimeError. 
- Stop the test module from permanently replacing sys.modules["unsloth_zoo"] with its loader placeholder, which shadowed the real package for the rest of the pytest process; restore sys.modules after loading the two files under test. Tests: ignore-only checkpoint exclusion; dangling symlink scoped to the request; metadata globs read as weightless; exact-named weightless presence; deterministic OSError type preserved (and unknown -> RuntimeError); cross-test package isolation. --- tests/test_hf_xet_fallback.py | 138 ++++++++++++++++++++++++++++++++- unsloth_zoo/hf_cache_state.py | 123 +++++++++++++++++++++++++---- unsloth_zoo/hf_xet_fallback.py | 58 +++++++++++++- 3 files changed, 298 insertions(+), 21 deletions(-) diff --git a/tests/test_hf_xet_fallback.py b/tests/test_hf_xet_fallback.py index 8b1a10608..6fd0fddba 100644 --- a/tests/test_hf_xet_fallback.py +++ b/tests/test_hf_xet_fallback.py @@ -39,8 +39,17 @@ def _load(name: str, filename: str): return module -# A package placeholder so ``from unsloth_zoo.hf_cache_state import ...`` inside -# hf_xet_fallback resolves to the file we load below, not the installed package. +# A package placeholder so ``from unsloth_zoo.hf_cache_state import ...`` inside hf_xet_fallback +# resolves to the file we load below, not the installed package. RESTORE sys.modules afterwards: +# leaving the placeholder (and the two submodule entries _load installs) in sys.modules would shadow +# the REAL unsloth_zoo for the rest of the pytest process -- its __init__ never runs -- so a later +# test importing unsloth_zoo (e.g. unsloth_zoo.FORCE_FLOAT32) would fail. The two loaded modules keep +# their own bound references (their intra-package import resolved during exec), so they work after +# the placeholder is removed (Codex #829). 
+_saved_modules = { + name: sys.modules.get(name) + for name in ("unsloth_zoo", "unsloth_zoo.hf_cache_state", "unsloth_zoo.hf_xet_fallback") +} if "unsloth_zoo" not in sys.modules: _pkg = _types.ModuleType("unsloth_zoo") _pkg.__path__ = [str(_ZOO_DIR)] @@ -49,6 +58,12 @@ def _load(name: str, filename: str): hcs = _load("unsloth_zoo.hf_cache_state", "hf_cache_state.py") xf = _load("unsloth_zoo.hf_xet_fallback", "hf_xet_fallback.py") +for _name, _mod in _saved_modules.items(): + if _mod is None: + sys.modules.pop(_name, None) + else: + sys.modules[_name] = _mod + # Real prep impl, captured before the autouse fixture stubs the module attribute. _REAL_DEFAULT_PREPARE = xf._default_prepare_for_http @@ -754,7 +769,12 @@ def test_snapshot_cancel_honored_even_when_cached(monkeypatch, tmp_path): def test_nonstall_error_propagates_without_fallback(monkeypatch): fake = _install(monkeypatch, [("error", "RepositoryNotFoundError: 404 not found")]) - with pytest.raises(RuntimeError, match = "RepositoryNotFoundError"): + # A deterministic Hub error is re-raised with its ORIGINAL type preserved across the spawn + # boundary (not flattened to a bare RuntimeError), so a caller's typed except clause still + # matches (Codex #829). The parent reconstructs the class from the child's ": ..." prefix. 
+ expected_cls = xf._resolve_exception_class("RepositoryNotFoundError") + assert expected_cls is not None and expected_cls is not RuntimeError + with pytest.raises(expected_cls, match = "RepositoryNotFoundError"): xf.hf_hub_download_with_xet_fallback(DL_REPO, FILE, None) assert len(fake.calls) == 1, "deterministic error must not trigger an HTTP fallback" assert fake.calls[0].disable_xet is False @@ -2210,3 +2230,115 @@ def test_snapshot_dir_is_complete_tolerates_non_string_shard(tmp_path): ) ) assert hcs.snapshot_dir_is_complete(snap) is False # 'absent-extra' missing, bad entry skipped + + +# --------------------------------------------------------------------------- # +# Codex review round: scoped completeness, weightless named files, type preservation. +# --------------------------------------------------------------------------- # +def test_snapshot_complete_ignore_only_root_excludes_checkpoint(tmp_path): + """An IGNORE-ONLY root warm (no allow_patterns, e.g. ignore=['*.onnx']) is still a bare + from_pretrained reading ROOT weights, so a snapshot whose only weight lives in a checkpoint dir + must read as INCOMPLETE rather than short-circuit the guarded download (Codex #829).""" + snap = tmp_path / "snap" + snap.mkdir() + blob = tmp_path / "blob" + blob.write_bytes(b"x") + (snap / "config.json").write_text("{}") + (snap / "checkpoint-500").mkdir() + (snap / "checkpoint-500" / "model.safetensors").symlink_to(blob) + # ignore-only -> has_patterns True but allow_patterns None: checkpoint weight must not satisfy it. + assert hcs.snapshot_dir_is_complete(snap, ignore_patterns = ["*.onnx"]) is False + # A real root weight makes the same ignore-only request complete. 
+ (snap / "model.safetensors").symlink_to(blob) + assert hcs.snapshot_dir_is_complete(snap, ignore_patterns = ["*.onnx"]) is True + + +def test_snapshot_complete_ignores_dangling_symlink_outside_request(tmp_path): + """A dangling symlink for a file the request does NOT select must not reject the snapshot: an + allow_patterns=['adapter_model.safetensors'] probe whose adapter weight is on disk stays complete + even with a stale dangling root model.safetensors. A dangle for the REQUESTED file still rejects + (Codex #829).""" + snap = tmp_path / "snap" + snap.mkdir() + blob = tmp_path / "blob" + blob.write_bytes(b"x") + (snap / "adapter_model.safetensors").symlink_to(blob) + (snap / "model.safetensors").symlink_to(tmp_path / "missing-blob") # dangling, NOT requested + assert hcs.snapshot_dir_is_complete(snap, allow_patterns = ["adapter_model.safetensors"]) is True + # When the dangling file IS the requested one, the snapshot is incomplete. + assert hcs.snapshot_dir_is_complete(snap, allow_patterns = ["model.safetensors"]) is False + + +def test_request_can_include_weights_metadata_glob_is_weightless(): + """A no-slash metadata glob (tokenizer*, config*, vocab*, special_tokens*) is a FILE glob, not a + weight-bearing directory glob, so a warm that fetched only tokenizer.json is not rejected for + lacking a weight. 
'model*' / 'pytorch_model*' / 'checkpoint-*' stay weight-including (Codex #829).""" + assert hcs.request_can_include_weights(allow_patterns = ["tokenizer*"]) is False + assert hcs.request_can_include_weights(allow_patterns = ["config*"]) is False + assert hcs.request_can_include_weights(allow_patterns = ["vocab*"]) is False + assert hcs.request_can_include_weights(allow_patterns = ["special_tokens*"]) is False + assert hcs.request_can_include_weights(allow_patterns = ["model*"]) is True + assert hcs.request_can_include_weights(allow_patterns = ["pytorch_model*"]) is True + assert hcs.request_can_include_weights(allow_patterns = ["checkpoint-*"]) is True + + +def test_requested_named_files_present_exact_request(tmp_path): + """An EXACT-named weightless request (allow=['tokenizer.json'], no glob) requires its named file + on disk; a config-only snapshot must not pass. A glob list or no allow_patterns is best-effort + (Codex #829).""" + snap = tmp_path / "snap" + snap.mkdir() + (snap / "config.json").write_text("{}") + assert hcs.requested_named_files_present(snap, allow_patterns = ["tokenizer.json"]) is False + (snap / "tokenizer.json").write_text("{}") + assert hcs.requested_named_files_present(snap, allow_patterns = ["tokenizer.json"]) is True + # A glob list is best-effort: a missing optional file does not fail it. + assert hcs.requested_named_files_present(snap, allow_patterns = ["tokenizer*", "vocab.txt"]) is True + # No allow_patterns -> trivially satisfied. + assert hcs.requested_named_files_present(snap) is True + # An ignore-filtered name is not actually requested, so its absence does not fail. 
+ assert hcs.requested_named_files_present( + snap, allow_patterns = ["tokenizer.json", "absent.json"], ignore_patterns = ["absent.json"] + ) is True + + +def test_snapshot_acceptable_weightless_requires_named_file(tmp_path): + """End-to-end: _snapshot_is_acceptable for a weightless exact-named request rejects a config-only + cache missing the requested tokenizer.json, so the guarded download is not skipped (Codex #829).""" + snap = tmp_path / "snap" + snap.mkdir() + (snap / "config.json").write_text("{}") + assert xf._snapshot_is_acceptable( + snap, repo_type = "model", allow_patterns = ["tokenizer.json"], ignore_patterns = None + ) is False + (snap / "tokenizer.json").write_text("{}") + assert xf._snapshot_is_acceptable( + snap, repo_type = "model", allow_patterns = ["tokenizer.json"], ignore_patterns = None + ) is True + + +def test_deterministic_oserror_type_preserved(monkeypatch): + """A deterministic disk-full OSError is re-raised as OSError (not flattened to RuntimeError), so a + caller's `except OSError` cleanup still runs across the spawn boundary (Codex #829).""" + fake = _install(monkeypatch, [("error", "OSError: [Errno 28] No space left on device")]) + with pytest.raises(OSError, match = "No space left"): + xf.hf_hub_download_with_xet_fallback(DL_REPO, FILE, None) + assert len(fake.calls) == 1, "a deterministic error must not trigger an HTTP fallback" + + +def test_unknown_error_falls_back_to_runtimeerror(monkeypatch): + """An unrecognized error class name still surfaces (as RuntimeError, the prior behavior) without + a transport fallback -- only KNOWN deterministic Hub / OS types are reconstructed (Codex #829).""" + fake = _install(monkeypatch, [("error", "SomeWeirdError: kaboom")]) + with pytest.raises(RuntimeError, match = "kaboom"): + xf.hf_hub_download_with_xet_fallback(DL_REPO, FILE, None) + assert len(fake.calls) == 1 + + +def test_resolve_exception_class_maps_known_names(): + """The reconstruction map resolves the documented deterministic Hub 
error names + OSError, and + returns None (-> RuntimeError) for an unknown name (Codex #829).""" + assert xf._resolve_exception_class("OSError") is OSError + cls = xf._resolve_exception_class("RepositoryNotFoundError") + assert cls is not None and issubclass(cls, BaseException) + assert xf._resolve_exception_class("NotARealErrorType") is None diff --git a/unsloth_zoo/hf_cache_state.py b/unsloth_zoo/hf_cache_state.py index 8f13446b0..7b009e7fa 100644 --- a/unsloth_zoo/hf_cache_state.py +++ b/unsloth_zoo/hf_cache_state.py @@ -358,6 +358,42 @@ def _weight_shard_index_complete(index_path: Path) -> bool: return True +def _broken_symlink_rel_paths(snapshot_dir: Path) -> list: + """Repo-relative posix paths of every dangling symlink in *snapshot_dir* -- a referenced file + whose blob is missing or still an ``.incomplete`` partial (an interrupted download). Empty when + none. Lets a completeness check scope the interrupted-download signal to the files a request + actually selects, rather than rejecting the whole snapshot for a dangle outside the request.""" + out: list = [] + try: + for entry in snapshot_dir.rglob("*"): + try: + if entry.is_symlink() and not entry.exists(): + try: + out.append(entry.relative_to(snapshot_dir).as_posix()) + except ValueError: + out.append(entry.name) + except OSError: + continue + except OSError: + return out + return out + + +def _requested_scope_filter( + rels: list, allow_patterns: "Optional[list]", ignore_patterns: "Optional[list]" +) -> list: + """The subset of repo-relative *rels* a request selects. Applies the allow / ignore filter, and + when there is no ``allow_patterns`` (an UNPATTERNED or IGNORE-ONLY request -- a bare + ``from_pretrained`` that reads ROOT weights) also drops per-checkpoint-dir paths the root load + never reads, so a checkpoint-dir file neither satisfies the warm nor (as a dangling symlink) + blocks it. 
An explicit ``allow_patterns`` is trusted as-is: a caller that names a checkpoint + path opts back into it.""" + kept = _filter_paths(list(rels), allow_patterns, ignore_patterns) + if allow_patterns is None: + kept = [r for r in kept if not _path_under_checkpoint_dir(r)] + return kept + + def snapshot_dir_is_complete( snapshot_dir: Path, *, @@ -392,8 +428,6 @@ def snapshot_dir_is_complete( *allow_patterns* (e.g. ``["model.safetensors", "adapter_model.safetensors"]``) to be on disk, so a stale cache holding only one of them is not treated as complete. Off by default (used by the pre-download cache probe); a glob still selects a subset freely.""" - if snapshot_dir_has_broken_symlinks(snapshot_dir): - return False try: entries = list(snapshot_dir.rglob("*")) except OSError: @@ -405,6 +439,16 @@ def snapshot_dir_is_complete( # warmup branch (consistent with request_can_include_weights). has_patterns = allow_patterns is not None or ignore_patterns is not None + # A dangling symlink marks an interrupted download, but only one for a file the request + # actually selects should reject the snapshot. A stale dangling root model.safetensors must + # not fail an allow_patterns=["adapter_model.safetensors"] probe whose adapter weight IS on + # disk, so scope the broken-symlink check to the requested files (and, for a root warm with no + # allow_patterns, drop checkpoint-dir paths the bare load never reads) -- the same selection + # _requested_scope_filter applies to the weights below. + broken = _broken_symlink_rel_paths(snapshot_dir) + if broken and _requested_scope_filter(broken, allow_patterns, ignore_patterns): + return False + index_entries: list = [] weight_entries: list = [] # (entry, repo-relative path) for entry in entries: @@ -419,18 +463,14 @@ def snapshot_dir_is_complete( rel = name weight_entries.append((entry, rel)) - # The weights the request selects that are present on disk (any present weight when the - # request is unpatterned). 
The snapshot can carry an unrelated weight while the requested - # one is missing, so a patterned request must find one it actually selects. - if has_patterns: - selected = set(_filter_paths([rel for _, rel in weight_entries], allow_patterns, ignore_patterns)) - else: - # Unpatterned warm = a bare from_pretrained, which reads ROOT model weights. A weight that - # lives only inside a per-checkpoint dir (checkpoint-500/model.safetensors, left behind by - # a prior allow_patterns=["checkpoint-500/*"] pull) is not a root weight, so it must not - # make a checkpoint-only snapshot read as a warm root model -- that would let the guarded - # download be skipped and hand from_pretrained a snapshot whose root weights are missing. - selected = {rel for _, rel in weight_entries if not _path_under_checkpoint_dir(rel)} + # The weights the request selects that are present on disk (any present root weight when the + # request is unpatterned). The snapshot can carry an unrelated weight while the requested one + # is missing, so a patterned request must find one it actually selects. _requested_scope_filter + # also excludes per-checkpoint-dir weights (checkpoint-500/model.safetensors, left behind by a + # prior allow_patterns=["checkpoint-500/*"] pull) whenever there is no allow_patterns -- an + # UNPATTERNED *or* IGNORE-ONLY root warm (e.g. ignore_patterns=["*.onnx"]) is still a bare + # from_pretrained reading ROOT weights, so a checkpoint-only snapshot must not read as warm. + selected = set(_requested_scope_filter([rel for _, rel in weight_entries], allow_patterns, ignore_patterns)) if not selected: return False @@ -766,12 +806,21 @@ def request_can_include_weights( if not clearly_weightless: concrete_parent = _concretize_glob(prefix) if _has_glob(prefix) else prefix probes.extend(f"{concrete_parent}/{name}" for name in _WEIGHT_PROBE_NAMES) - elif _has_glob(pat) and ("." not in pat or _looks_like_checkpoint_dir(pat)): + elif ( + _has_glob(pat) + and ("." 
not in pat or _looks_like_checkpoint_dir(pat)) + and not _basename_is_non_weight(pat) + ): # A no-slash DIRECTORY glob ("checkpoint-*", "global_step*", the dotted # "checkpoint-v1.*"): HF's fnmatch "*" spans "/", so it matches nested weights like # checkpoint-10/model.safetensors. Probe the canonical weights re-rooted under a # concretized form of the glob. A plain extension file glob ("*.json", "tokenizer.*") # is not a directory glob and stays weightless unless it names a weight (self_probe). + # A no-slash glob whose stem is a known metadata family ("tokenizer*", "config*", + # "vocab*", "special_tokens*") is a FILE glob, not a directory: _basename_is_non_weight + # excludes it so a tokenizer*-only warm that fetched tokenizer.json is not rejected for + # lacking a weight ("model*" / "pytorch_model*" stay weight-including -- they match a + # weight probe, so _basename_is_non_weight is False for them). concrete = _concretize_glob(pat) probes.extend(f"{concrete}/{name}" for name in _WEIGHT_PROBE_NAMES) # A pattern that itself names a loadable weight -- a bare filename, a path-qualified @@ -793,6 +842,50 @@ def request_can_include_weights( return len(kept) > 0 +def requested_named_files_present( + snapshot_dir: Path, + *, + allow_patterns: "Optional[object]" = None, + ignore_patterns: "Optional[object]" = None, +) -> bool: + """For a request that names EXACT files (every ``allow_patterns`` entry is glob-free), True only + when each named file the ignore filter keeps is on disk. + + ``snapshot_download(local_files_only=True)`` returns a snapshot dir whenever the revision folder + exists -- even a config-only one left by a prior ``AutoConfig`` fetch -- so for a weightless + request like ``allow_patterns=["tokenizer.json"]`` a dangling-symlink check alone would accept a + cache that does not actually contain the requested file. This makes that request require its + named file before the snapshot is treated as warm. 
+ + A request with ANY glob, or with no ``allow_patterns``, is a best-effort "warm what matches" and + cannot be turned into an exact manifest (an optional ``vocab.txt`` the repo may simply lack would + wrongly fail it), so it is trivially satisfied here -- the weight-bearing requests are gated by + ``snapshot_dir_is_complete`` instead.""" + allow_patterns = _as_pattern_list(allow_patterns) + ignore_patterns = _as_pattern_list(ignore_patterns) + if not allow_patterns or any(_has_glob(p) for p in allow_patterns): + return True + try: + entries = list(snapshot_dir.rglob("*")) + except OSError: + return True # cannot enumerate -> do not reject on an unreadable dir + present = set() + for entry in entries: + if _safe_is_file(entry): + try: + present.add(entry.relative_to(snapshot_dir).as_posix()) + except ValueError: + present.add(entry.name) + for pat in allow_patterns: + # A named file the ignore filter drops is not actually requested. _filter_paths fails OPEN + # (returns all on error), so an unevaluable filter keeps the strict presence check. 
+ if ignore_patterns and not _filter_paths([pat], None, ignore_patterns): + continue + if pat not in present: + return False + return True + + def _iter_snapshot_dirs(repo_dir: Path) -> Iterator[Path]: snapshots_dir = repo_dir / "snapshots" try: diff --git a/unsloth_zoo/hf_xet_fallback.py b/unsloth_zoo/hf_xet_fallback.py index b15b0f532..9657ee434 100644 --- a/unsloth_zoo/hf_xet_fallback.py +++ b/unsloth_zoo/hf_xet_fallback.py @@ -54,6 +54,7 @@ hf_cache_root, iter_active_repo_cache_dirs, request_can_include_weights, + requested_named_files_present, snapshot_dir_has_broken_symlinks, snapshot_dir_is_complete, ) @@ -482,6 +483,44 @@ def _scrub_in_child(text: str, token: Optional[str]) -> str: ) +def _resolve_exception_class(type_name: str) -> "Optional[type]": + """Map a deterministic Hub / OS error class NAME (as captured in the child) back to its class, + so the parent can re-raise the original type rather than a generic RuntimeError. Best-effort: an + unknown name returns None. Imports are local so the helper stays import-light when no error + occurs and never hard-depends on a specific huggingface_hub layout.""" + if type_name == "OSError": + return OSError + if type_name not in _DETERMINISTIC_ERROR_NAMES: + return None + for module_name in ("huggingface_hub.errors", "huggingface_hub.utils"): + try: + module = importlib.import_module(module_name) + except Exception: + continue + cls = getattr(module, type_name, None) + if isinstance(cls, type) and issubclass(cls, BaseException): + return cls + return None + + +def _raise_child_error(message: str) -> None: + """Re-raise a deterministic child download error, preserving its original exception TYPE when it + is a known Hub / OS error, so callers that catch ``RepositoryNotFoundError`` / ``GatedRepoError`` + / ``OSError`` (auth prompts, offline handling, disk cleanup) still see those types across the + spawn-process boundary. 
The child reports the failure as ``"TypeName: message"``, so the + type is reconstructed from that prefix; anything unrecognized falls back to ``RuntimeError`` (the + prior behavior). A class whose constructor rejects a lone string also degrades to RuntimeError.""" + type_name = message.split(":", 1)[0].strip() if ":" in message else "" + exc_cls = _resolve_exception_class(type_name) + if exc_cls is None: + raise RuntimeError(message) + try: + exc = exc_cls(message) + except Exception: + raise RuntimeError(message) + raise exc + + def _is_retryable_download_error(exc: BaseException) -> bool: """True when a captured download exception looks like a transient transport failure (an ``hf_xet`` / CAS error, connection reset, timeout, HTTP 5xx / 429) that the OTHER transport @@ -886,7 +925,17 @@ def _snapshot_is_acceptable( snapshot_dir, allow_patterns = allow_patterns, ignore_patterns = ignore_patterns, require_named_weights = require_named_weights, ) - return not snapshot_dir_has_broken_symlinks(snapshot_dir) + # Weightless / non-model request (a dataset, or a model repo whose patterns drop every weight + # format, e.g. a tokenizer-only allow list): no weight is expected, so completeness is "no + # dangling symlink". But an EXACT-named weightless request (allow_patterns=["tokenizer.json"], + # no globs) must still find its named files on disk -- HF can hand back a config-only snapshot + # dir that simply does not contain the requested file. A glob-bearing list stays best-effort. + return ( + not snapshot_dir_has_broken_symlinks(snapshot_dir) + and requested_named_files_present( + snapshot_dir, allow_patterns = allow_patterns, ignore_patterns = ignore_patterns + ) + ) def _snapshot_payload_incomplete( @@ -1007,8 +1056,11 @@ def _download_with_xet_fallback( raise RuntimeError("Cancelled") if kind_result == "error": # Deterministic failure (a captured Hub exception: auth, not-found, gated, disk - # full): the other transport would fail identically, so do not retry.
- raise RuntimeError(payload) + # full): the other transport would fail identically, so do not retry. Re-raise + # preserving the original exception type (RepositoryNotFoundError / GatedRepoError / + # OSError ...) where known, so callers' typed except clauses still match across the + # spawn boundary; unknown errors fall back to RuntimeError. + _raise_child_error(payload) if kind_result == "retryable_error": # A transient transport failure (hf_xet CAS timeout, 5xx, connection reset) rather # than a deterministic Hub error: disabling Xet and retrying over HTTP may recover, From bd9278607049dee6b6f6166551860d832efc438f Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 28 Jun 2026 05:41:32 +0000 Subject: [PATCH 33/82] Preserve Hub error types across the spawn boundary regardless of constructor signature The deterministic-error reconstruction instantiated the resolved Hub class with only the scrubbed message string. Hub errors (RepositoryNotFoundError, GatedRepoError, ...) subclass HfHubHTTPError, whose response argument is keyword-only and required on newer huggingface_hub versions, so exc_cls(message) raises TypeError there and the handler silently fell back to RuntimeError -- dropping the type a caller's except clause relies on. Reconstruct robustly: try the normal constructors first (best fidelity: they default response / server_message), then bypass __init__ via __new__ + BaseException.__init__ so the original TYPE and the message survive even when no constructor accepts a lone string. Only a class that cannot be instantiated at all still degrades to RuntimeError. Tests: a Hub error class whose constructor requires a keyword-only response is re-raised with its type preserved (not RuntimeError); direct coverage of the layered constructor / __new__ reconstruction paths. 
--- tests/test_hf_xet_fallback.py | 37 ++++++++++++++++++++++++++++++++++ unsloth_zoo/hf_xet_fallback.py | 33 +++++++++++++++++++++++++----- 2 files changed, 65 insertions(+), 5 deletions(-) diff --git a/tests/test_hf_xet_fallback.py b/tests/test_hf_xet_fallback.py index 6fd0fddba..7a01b9d50 100644 --- a/tests/test_hf_xet_fallback.py +++ b/tests/test_hf_xet_fallback.py @@ -2342,3 +2342,40 @@ def test_resolve_exception_class_maps_known_names(): cls = xf._resolve_exception_class("RepositoryNotFoundError") assert cls is not None and issubclass(cls, BaseException) assert xf._resolve_exception_class("NotARealErrorType") is None + + +def test_error_type_preserved_when_constructor_needs_kwarg(monkeypatch): + """A Hub error class whose constructor rejects a lone positional string (newer huggingface_hub + makes HfHubHTTPError's `response` required / keyword-only) must STILL be re-raised with its type + preserved -- via an __init__-bypassing reconstruction -- not silently downgraded to RuntimeError + (Codex #829).""" + class PickyHubError(Exception): + def __init__(self, message, *, response): # response required + keyword-only + super().__init__(message) + self.response = response + + monkeypatch.setattr( + xf, "_resolve_exception_class", + lambda name: PickyHubError if name == "PickyHubError" else None, + ) + fake = _install(monkeypatch, [("error", "PickyHubError: kaboom")]) + with pytest.raises(PickyHubError, match = "kaboom"): + xf.hf_hub_download_with_xet_fallback(DL_REPO, FILE, None) + assert len(fake.calls) == 1, "a deterministic error must not trigger an HTTP fallback" + + +def test_instantiate_preserving_type_paths(): + """Direct coverage of the layered reconstruction: a normal constructor is used when it accepts a + string; a keyword-only-required constructor falls through to the __new__ bypass; both yield an + instance of the requested type carrying the message (Codex #829).""" + class Plain(Exception): + pass + + class Picky(Exception): + def __init__(self, 
message, *, response): + super().__init__(message) + + for cls in (Plain, Picky): + exc = xf._instantiate_preserving_type(cls, "the message") + assert isinstance(exc, cls) + assert "the message" in str(exc) diff --git a/unsloth_zoo/hf_xet_fallback.py b/unsloth_zoo/hf_xet_fallback.py index 9657ee434..bede13939 100644 --- a/unsloth_zoo/hf_xet_fallback.py +++ b/unsloth_zoo/hf_xet_fallback.py @@ -503,20 +503,43 @@ def _resolve_exception_class(type_name: str) -> "Optional[type]": return None +def _instantiate_preserving_type(exc_cls: type, message: str) -> "Optional[BaseException]": + """Build an *exc_cls* instance carrying *message*, robust to a finicky constructor. Hub error + classes (``RepositoryNotFoundError`` ...) subclass ``HfHubHTTPError``, whose ``response`` arg is + keyword-only -- and required on some huggingface_hub versions -- so a plain ``exc_cls(message)`` + can raise ``TypeError``. Try the normal constructors first (best fidelity: they default + ``response`` / ``server_message``), then BYPASS ``__init__`` via ``__new__`` so the TYPE and the + message survive even when no constructor accepts a lone string. Returns None only if even + ``__new__`` fails, so the caller can fall back to ``RuntimeError``.""" + for build in ( + lambda: exc_cls(message), + lambda: exc_cls(message, response = None), + ): + try: + return build() + except Exception: + continue + try: + exc = exc_cls.__new__(exc_cls) + BaseException.__init__(exc, message) + return exc + except Exception: + return None + + def _raise_child_error(message: str) -> None: """Re-raise a deterministic child download error, preserving its original exception TYPE when it is a known Hub / OS error, so callers that catch ``RepositoryNotFoundError`` / ``GatedRepoError`` / ``OSError`` (auth prompts, offline handling, disk cleanup) still see those types across the spawn-process boundary. 
The child reports the failure as ``"TypeName: message"``, so the - type is reconstructed from that prefix; anything unrecognized falls back to ``RuntimeError`` (the - prior behavior). A class whose constructor rejects a lone string also degrades to RuntimeError.""" + type is reconstructed from that prefix; anything unrecognized -- or a class that cannot be + instantiated at all -- falls back to ``RuntimeError`` (the prior behavior).""" type_name = message.split(":", 1)[0].strip() if ":" in message else "" exc_cls = _resolve_exception_class(type_name) if exc_cls is None: raise RuntimeError(message) - try: - exc = exc_cls(message) - except Exception: + exc = _instantiate_preserving_type(exc_cls, message) + if exc is None: raise RuntimeError(message) raise exc From 76e7507856400327a42185c9af97b1e8c6283cfb Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 28 Jun 2026 06:42:11 +0000 Subject: [PATCH 34/82] Group weight formats, enforce named files post-download, handle dir patterns and errno Address the latest Codex review round on the shared helper: - Treat a trailing-slash directory allow pattern (unet/, checkpoint-10/) as a wildcard, not an exact filename. Hugging Face's filter_repo_objects expands it to match the directory contents, but _has_glob saw no glob char and the strict named-file checks then looked for a literal "unet/" entry -- rejecting a fully cached component directory and forcing an unnecessary network / offline-failing re-fetch. - Require explicitly named files in the POST-download acceptance check too, not just the pre-download probe, so a finished download that handed back a stale snapshot missing a named file (a base + adapter list where only the base materialized, or a weight plus a named tokenizer.json) is retried instead of returned with files missing.
Named weights are grouped by LOGICAL weight (format / shard variants share a key), so an "either format" list stays satisfied by whichever variant the repo actually ships and never errors forever on a format that does not exist. - Preserve errno when reconstructing a deterministic child OSError. A disk-full (ENOSPC) or quota (EDQUOT) error reported as "[Errno N] ..." now rebuilds OSError(errno, message) rather than OSError(message), so a caller's except OSError can still branch on exc.errno across the spawn-process boundary. Tests: dir/ pattern read as a wildcard; logical-weight grouping (either-format = one group, base + adapter = two); post-download rejects a stale named weight and retries while an either-format result stays accepted; errno parsed and preserved on the rethrown OSError. --- tests/test_hf_xet_fallback.py | 114 +++++++++++++++++++++++++++++---- unsloth_zoo/hf_cache_state.py | 87 ++++++++++++++++++------- unsloth_zoo/hf_xet_fallback.py | 39 +++++++++-- 3 files changed, 200 insertions(+), 40 deletions(-) diff --git a/tests/test_hf_xet_fallback.py b/tests/test_hf_xet_fallback.py index 7a01b9d50..66da14ee2 100644 --- a/tests/test_hf_xet_fallback.py +++ b/tests/test_hf_xet_fallback.py @@ -1460,27 +1460,27 @@ def test_snapshot_dir_is_complete_checkpoint_index_does_not_gate_root(tmp_path): def test_snapshot_dir_is_complete_requires_each_named_weight(tmp_path): - """require_named_weights makes a request naming multiple exact weights (base + adapter) - need EACH on disk, so the pre-download cache probe does not short-circuit a stale snapshot - holding only the base. 
Off (the post-download check) it stays lenient, so an "either - format" list (pytorch_model.bin + model.safetensors) against a safetensors-only repo is not - turned into a spurious incomplete-snapshot failure.""" + """require_named_weights makes a request naming multiple exact weights (base + adapter) need + EACH logical weight on disk -- so a stale snapshot holding only the base is rejected -- while + grouping format variants of one logical weight so an "either format" list (pytorch_model.bin + + model.safetensors) is satisfied by whichever format the repo actually ships (no error-forever on + a name that doesn't exist).""" snap = tmp_path / "snap" snap.mkdir() blob = tmp_path / "blob" blob.write_bytes(b"x") (snap / "model.safetensors").symlink_to(blob) # base only; adapter missing pair = ["model.safetensors", "adapter_model.safetensors"] - # Strict (pre-download): adapter missing -> incomplete -> do not short-circuit a stale cache. + # base + adapter are two LOGICAL weights -> both required; the adapter is missing -> incomplete. assert hcs.snapshot_dir_is_complete(snap, allow_patterns = pair, require_named_weights = True) is False - # Lenient (post-download default): a present selected weight suffices. + # Lenient (require_named_weights off): a present selected weight suffices. assert hcs.snapshot_dir_is_complete(snap, allow_patterns = pair, require_named_weights = False) is True - # Either-format list, safetensors-only repo: strict still won't short-circuit, but the - # lenient check must NOT reject it (no error-forever on a name that doesn't exist). + # Either-format list = ONE logical weight: whichever format is present satisfies it, under both + # the strict and lenient checks (no spurious failure on the absent format). 
either = ["pytorch_model.bin", "model.safetensors"] - assert hcs.snapshot_dir_is_complete(snap, allow_patterns = either, require_named_weights = True) is False + assert hcs.snapshot_dir_is_complete(snap, allow_patterns = either, require_named_weights = True) is True assert hcs.snapshot_dir_is_complete(snap, allow_patterns = either, require_named_weights = False) is True - # Both present -> strict is satisfied. + # Both present -> the base + adapter request is satisfied. (snap / "adapter_model.safetensors").symlink_to(blob) assert hcs.snapshot_dir_is_complete(snap, allow_patterns = pair, require_named_weights = True) is True # A named weight the ignore filter drops is not actually requested, so it is not required. @@ -2379,3 +2379,95 @@ def __init__(self, message, *, response): exc = xf._instantiate_preserving_type(cls, "the message") assert isinstance(exc, cls) assert "the message" in str(exc) + + +# --------------------------------------------------------------------------- # +# Codex round: dir/ wildcard, logical-weight grouping post-download, errno preservation. +# --------------------------------------------------------------------------- # +def test_dir_pattern_treated_as_wildcard(tmp_path): + """A trailing-slash directory allow pattern (unet/) is a wildcard -- HF's filter_repo_objects + expands it to unet/* -- not an exact filename, so the strict named-file checks must not reject a + fully cached component directory by looking for a literal 'unet/' entry (Codex #829).""" + assert hcs._has_glob("unet/") is True + assert hcs._has_glob("checkpoint-10/") is True + assert hcs._has_glob("config.json") is False + snap = tmp_path / "snap" + snap.mkdir() + (snap / "unet").mkdir() + blob = tmp_path / "blob" + blob.write_bytes(b"x") + (snap / "unet" / "diffusion_pytorch_model.safetensors").symlink_to(blob) + (snap / "unet" / "config.json").write_text("{}") + # A dir/ pattern is best-effort (glob), never an exact-name requirement. 
+ assert hcs.requested_named_files_present(snap, allow_patterns = ["unet/"]) is True + # The component-dir weight satisfies the request; not rejected for a literal 'unet/'. + assert hcs.snapshot_dir_is_complete(snap, allow_patterns = ["unet/"]) is True + + +def test_weight_logical_key_groups_formats(): + """Format / shard variants of one logical weight share a key; different stems or subdirs do + not, so an either-format list is one group while base + adapter are two (Codex #829).""" + k = hcs._weight_logical_key + assert k("pytorch_model.bin") == k("model.safetensors") # either-format -> 1 group + assert k("model-00001-of-00002.safetensors") == k("model.safetensors") # shard -> same group + assert k("model.safetensors") != k("adapter_model.safetensors") # base vs adapter + assert k("unet/model.safetensors") != k("vae/model.safetensors") # different subdirs + + +def test_post_download_rejects_stale_named_weight(hf_cache, monkeypatch): + """A finished child snapshot missing an explicitly named LOGICAL weight (base present, adapter + missing) is now treated as incomplete post-download and retried over HTTP, instead of being + returned with the adapter still missing (Codex #829).""" + blobs = _blobs_dir(hf_cache, DL_REPO) + base_only = blobs.parent / "snapshots" / "xet" + base_only.mkdir(parents = True) + w = blobs / "w" + w.write_bytes(b"x") + (base_only / "model.safetensors").symlink_to(w) # base only; adapter missing + complete = blobs.parent / "snapshots" / "http" + complete.mkdir(parents = True) + (complete / "model.safetensors").symlink_to(w) + (complete / "adapter_model.safetensors").symlink_to(w) + fake = _install(monkeypatch, [("ok", str(base_only)), ("ok", str(complete))]) + out = xf.snapshot_download_with_xet_fallback( + DL_REPO, token = None, force_download = True, + allow_patterns = ["model.safetensors", "adapter_model.safetensors"], + ) + assert out == str(complete) + assert [c.disable_xet for c in fake.calls] == [False, True] # the stale base-only 
result retried + + +def test_post_download_either_format_still_accepted(hf_cache, monkeypatch): + """An either-format list against a single-format child snapshot stays accepted post-download + (the formats group to one logical weight), so require_named_weights does not error-forever on a + format the repo never ships (Codex #829).""" + blobs = _blobs_dir(hf_cache, DL_REPO) + child = blobs.parent / "snapshots" / "only-st" + child.mkdir(parents = True) + w = blobs / "w" + w.write_bytes(b"x") + (child / "model.safetensors").symlink_to(w) # safetensors only; no pytorch_model.bin + fake = _install(monkeypatch, [("ok", str(child))]) + out = xf.snapshot_download_with_xet_fallback( + DL_REPO, token = None, force_download = True, + allow_patterns = ["pytorch_model.bin", "model.safetensors"], + ) + assert out == str(child) and len(fake.calls) == 1 + + +def test_parse_errno(): + assert xf._parse_errno("OSError: [Errno 28] No space left on device") == 28 + assert xf._parse_errno("OSError: [Errno 122] Disk quota exceeded") == 122 + assert xf._parse_errno("OSError: some message with no errno") is None + + +def test_oserror_errno_preserved(monkeypatch): + """A disk-full child OSError keeps its errno (ENOSPC) across the spawn boundary, so a caller's + `except OSError` cleanup can still branch on exc.errno -- not see errno=None (Codex #829).""" + import errno as _errno + + fake = _install(monkeypatch, [("error", "OSError: [Errno 28] No space left on device")]) + with pytest.raises(OSError) as excinfo: + xf.hf_hub_download_with_xet_fallback(DL_REPO, FILE, None) + assert excinfo.value.errno == _errno.ENOSPC + assert len(fake.calls) == 1, "a deterministic error must not trigger an HTTP fallback" diff --git a/unsloth_zoo/hf_cache_state.py b/unsloth_zoo/hf_cache_state.py index 7b009e7fa..2c8e3529e 100644 --- a/unsloth_zoo/hf_cache_state.py +++ b/unsloth_zoo/hf_cache_state.py @@ -474,21 +474,20 @@ def snapshot_dir_is_complete( if not selected: return False - # A request that explicitly 
names exact files needs EACH of them on disk, not just one, so a - # stale cache holding a subset is not short-circuited past the guarded download. WHICH names - # are required depends on the request shape: - # * An exact-file request (no globs) -- ["model.safetensors", "tokenizer.json"], or a base - # plus a PEFT adapter ["model.safetensors", "adapter_model.safetensors"] -- names every - # file it wants, so each concrete name (weight OR non-weight) must be present. A cache with - # just the weight must not accept-warm while the named tokenizer / config is still missing. - # * A request containing ANY glob is a broad "warm what matches" selection where named aux - # files are best-effort (an optional vocab.txt / spiece.model the repo may simply lack), so - # only its concrete WEIGHT names are required -- demanding every optional aux file there - # would defeat the warm-cache short-circuit for normal repos. - # Enforced only at the pre-download probe (require_named_weights), so the post-download check - # stays lenient and never errors on an "either format" name (["pytorch_model.bin", - # "model.safetensors"] against a safetensors-only repo) that does not exist in the repo. A name - # the ignore filter drops is not actually requested. + # A request that explicitly names exact files needs them on disk before a stale cache is + # short-circuited (pre-download) or accepted (post-download) past the guarded download. WHICH + # names are required depends on the request shape: + # * Each named NON-WEIGHT file (tokenizer.json, config.json) must be present -- but only for an + # exact-file request (no globs). A glob-bearing list treats aux names as best-effort (an + # optional vocab.txt / spiece.model the repo may lack), so unsloth's glob warms still + # short-circuit on a warm cache rather than re-downloading on every load. 
+ # * Named WEIGHT files are grouped by LOGICAL weight: format / shard variants of the same + # weight share a key, and each group needs at least ONE variant present. So an "either + # format" list (["pytorch_model.bin", "model.safetensors"]) is satisfied by whichever the + # repo actually ships -- never an error-forever on a name that does not exist -- while a + # base + adapter list (["model.safetensors", "adapter_model.safetensors"]) is TWO groups and + # needs both, so a stale cache holding only the base is rejected. + # A name the ignore filter drops is not actually requested. if require_named_weights and allow_patterns: exact_only = not any(_has_glob(p) for p in allow_patterns) if exact_only: @@ -501,18 +500,25 @@ def snapshot_dir_is_complete( present.add(entry.name) else: present = set(rel for _, rel in weight_entries) + weight_groups: dict = {} for pat in allow_patterns: if _has_glob(pat): continue - if not exact_only and not str(pat).lower().endswith(_WEIGHT_FILE_SUFFIXES): - continue if ignore_patterns and not _filter_paths([pat], None, ignore_patterns): - continue - # pat is a concrete (glob-free) path, so presence is an exact match. A direct - # membership test (not _filter_paths, which fails OPEN by returning all paths on a - # filter error) keeps this strict check fail-SAFE: an unevaluable case requires the - # guarded download rather than silently accepting a stale cache as warm. - if pat not in present: + continue # a name the ignore filter drops is not actually requested + if str(pat).lower().endswith(_WEIGHT_FILE_SUFFIXES): + weight_groups.setdefault(_weight_logical_key(pat), []).append(pat) + elif exact_only: + # A named non-weight (tokenizer.json, config.json) is required as-is. A direct + # membership test (not _filter_paths, which fails OPEN by returning all paths on a + # filter error) keeps this fail-SAFE: an unevaluable case requires the guarded + # download rather than silently accepting a stale cache as warm. 
+ if pat not in present: + return False + # Each logical weight group needs at least one of its named format / shard variants on disk + # (an interrupted shard SET is caught separately by the numbered-shard check below). + for names in weight_groups.values(): + if not any(n in present for n in names): return False # Every selected numbered shard needs the sibling shards the request also selects (the @@ -575,7 +581,40 @@ def snapshot_dir_is_complete( def _has_glob(text: str) -> bool: - return any(ch in text for ch in _GLOB_CHARS) + # A trailing-slash directory pattern ("unet/", "checkpoint-10/") is NOT an exact filename: + # Hugging Face's filter_repo_objects expands it to match everything under that directory (as + # if "unet/*"). Treat it as a wildcard so the strict exact-name checks do not look for a + # literal "unet/" entry and wrongly reject a fully cached directory / component download. + return text.endswith("/") or any(ch in text for ch in _GLOB_CHARS) + + +# Weight stems that are format-family variants of the SAME logical weight (Transformers reads one): +# the PyTorch / TF / Flax / safetensors "model" forms collapse to one key, so an "either format" +# named request is satisfied by whichever variant the repo actually ships. +_WEIGHT_FORMAT_FAMILY = { + "pytorch_model": "model", + "tf_model": "model", + "flax_model": "model", + "model": "model", +} + + +def _weight_logical_key(name: str) -> tuple: + """A grouping key for a named weight file so format / shard variants of the SAME logical weight + share it. Keyed by (directory, normalized stem): the weight suffix and any ``-NNNNN-of-NNNNN`` + shard suffix are stripped, and the pytorch_model / tf_model / flax_model / model family collapses + to ``model``. 
So ``["pytorch_model.bin", "model.safetensors"]`` is ONE group (either format + satisfies it) while ``["model.safetensors", "adapter_model.safetensors"]`` -- or the same stem in + two different subdirs -- are separate groups, each independently required.""" + norm = name.replace("\\", "/") + dirname, _, base = norm.rpartition("/") + base = base.lower() + for suf in _WEIGHT_FILE_SUFFIXES: + if base.endswith(suf): + base = base[: -len(suf)] + break + base = re.sub(r"-\d+-of-\d+$", "", base) + return (dirname, _WEIGHT_FORMAT_FAMILY.get(base, base)) def _as_pattern_list(patterns: "Optional[object]") -> "Optional[list]": diff --git a/unsloth_zoo/hf_xet_fallback.py b/unsloth_zoo/hf_xet_fallback.py index bede13939..b09ed47ab 100644 --- a/unsloth_zoo/hf_xet_fallback.py +++ b/unsloth_zoo/hf_xet_fallback.py @@ -527,6 +527,19 @@ def _instantiate_preserving_type(exc_cls: type, message: str) -> "Optional[BaseE return None +def _parse_errno(message: str) -> "Optional[int]": + """Pull the errno out of a stringified OSError. CPython formats it as ``[Errno 28] ...``, so a + disk-full (ENOSPC) / quota (EDQUOT) error keeps its code across the spawn boundary when the + parent reconstructs the OSError, letting callers branch on ``exc.errno``.""" + match = re.search(r"\[Errno (\d+)\]", message) + if match is None: + return None + try: + return int(match.group(1)) + except ValueError: + return None + + def _raise_child_error(message: str) -> None: """Re-raise a deterministic child download error, preserving its original exception TYPE when it is a known Hub / OS error, so callers that catch ``RepositoryNotFoundError`` / ``GatedRepoError`` @@ -538,6 +551,13 @@ def _raise_child_error(message: str) -> None: exc_cls = _resolve_exception_class(type_name) if exc_cls is None: raise RuntimeError(message) + if exc_cls is OSError: + # Preserve errno (ENOSPC / EDQUOT ...) so a caller's `except OSError` cleanup can still + # branch on exc.errno; OSError(message) alone would leave errno = None. 
+ errno_val = _parse_errno(message) + if errno_val is not None: + raise OSError(errno_val, message) + raise OSError(message) exc = _instantiate_preserving_type(exc_cls, message) if exc is None: raise RuntimeError(message) @@ -939,10 +959,12 @@ def _snapshot_is_acceptable( weight (e.g. ``allow_patterns=["adapter_model.safetensors"]`` or a checkpoint shard) is satisfied only when THAT weight is on disk, not by some other weight already cached. - ``require_named_weights`` makes a request that explicitly names multiple exact weights - require each of them on disk (set on the pre-download cache probe so a stale snapshot - missing one is not short-circuited; left off post-download so a named weight that simply - does not exist in the repo never turns a finished download into a spurious failure).""" + ``require_named_weights`` makes a request that explicitly names files require them on disk + (each named non-weight, and at least one format/shard variant of each named LOGICAL weight), + so a stale snapshot missing one is neither short-circuited (pre-download) nor accepted + (post-download). Format variants of one weight are grouped, so an "either format" name list + against a single-format repo is satisfied by whichever variant exists -- never an error-forever + on a name that does not exist in the repo.""" if repo_type == "model" and request_can_include_weights(allow_patterns, ignore_patterns): return snapshot_dir_is_complete( snapshot_dir, allow_patterns = allow_patterns, ignore_patterns = ignore_patterns, @@ -981,8 +1003,15 @@ def _snapshot_payload_incomplete( return False except OSError: return False + # require_named_weights so a finished download that handed back a stale snapshot missing an + # explicitly named file -- a base + adapter list (["model.safetensors", + # "adapter_model.safetensors"]) where only the base materialized, or a weight + named + # tokenizer.json -- is still treated as incomplete and retried, not returned with files + # missing. 
Format / shard variants of one logical weight are grouped, so an "either format" + # list stays satisfied by whichever variant the repo actually ships (no error-forever). return not _snapshot_is_acceptable( - path, repo_type = repo_type, allow_patterns = allow_patterns, ignore_patterns = ignore_patterns + path, repo_type = repo_type, allow_patterns = allow_patterns, + ignore_patterns = ignore_patterns, require_named_weights = True, ) From b1e81bf1fbbba8b6e92fa1c38f7055b57f57ce9e Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 28 Jun 2026 07:11:22 +0000 Subject: [PATCH 35/82] Scope the weightless broken-symlink check and recognize processor metadata globs Address the latest Codex review round on the shared helper: - Scope the broken-symlink check in the weightless / non-model acceptance branch to the REQUESTED files, matching snapshot_dir_is_complete. A dangling symlink for an EXCLUDED weight left by an earlier interrupted pull no longer rejects a complete config / tokenizer subset (e.g. allow_patterns=["config.json"] whose config is on disk); only a dangling REQUESTED file rejects. Adds snapshot_has_requested_broken_symlinks for the scoped check. - Recognize a processor-metadata glob as weightless. allow_patterns=["processor*"] selects processor_config.json and no weight, but the non-weight probe list had no processor_config.json representative, so the request was read as weight-including and a processor-only snapshot was wrongly rejected for lacking weights. Add processor_config.json (and video_preprocessor_config.json) to the non-weight probes. Tests: a weightless request accepts a snapshot with a dangling EXCLUDED weight but rejects a dangling REQUESTED file; a processor* glob reads as weightless while a model* glob stays weight-including. 
--- tests/test_hf_xet_fallback.py | 33 +++++++++++++++++++++++++++++++++ unsloth_zoo/hf_cache_state.py | 24 ++++++++++++++++++++++++ unsloth_zoo/hf_xet_fallback.py | 13 +++++++++---- 3 files changed, 66 insertions(+), 4 deletions(-) diff --git a/tests/test_hf_xet_fallback.py b/tests/test_hf_xet_fallback.py index 66da14ee2..2e14f430a 100644 --- a/tests/test_hf_xet_fallback.py +++ b/tests/test_hf_xet_fallback.py @@ -2471,3 +2471,36 @@ def test_oserror_errno_preserved(monkeypatch): xf.hf_hub_download_with_xet_fallback(DL_REPO, FILE, None) assert excinfo.value.errno == _errno.ENOSPC assert len(fake.calls) == 1, "a deterministic error must not trigger an HTTP fallback" + + +# --------------------------------------------------------------------------- # +# Codex round: weightless broken-symlink scoping + processor metadata glob. +# --------------------------------------------------------------------------- # +def test_weightless_broken_symlink_scoped_to_request(tmp_path): + """A weightless request (allow=['config.json']) must accept a snapshot whose config is present + even when an EXCLUDED weight left a dangling symlink from an earlier interrupted pull -- only a + dangling REQUESTED file rejects it (Codex #829).""" + snap = tmp_path / "snap" + snap.mkdir() + (snap / "config.json").write_text("{}") + (snap / "model.safetensors").symlink_to(tmp_path / "missing-blob") # dangling, NOT requested + assert hcs.snapshot_has_requested_broken_symlinks(snap, allow_patterns = ["config.json"]) is False + assert xf._snapshot_is_acceptable( + snap, repo_type = "model", allow_patterns = ["config.json"], ignore_patterns = None + ) is True + # A dangling REQUESTED file does reject. 
+ (snap / "config.json").unlink() + (snap / "config.json").symlink_to(tmp_path / "missing-cfg") + assert hcs.snapshot_has_requested_broken_symlinks(snap, allow_patterns = ["config.json"]) is True + assert xf._snapshot_is_acceptable( + snap, repo_type = "model", allow_patterns = ["config.json"], ignore_patterns = None + ) is False + + +def test_processor_glob_is_weightless(): + """A processor-only warm (allow=['processor*']) selects processor_config.json and no weight, so + it must read as weightless rather than be rejected for lacking weights (Codex #829).""" + assert hcs._basename_is_non_weight("processor*") is True + assert hcs.request_can_include_weights(allow_patterns = ["processor*"]) is False + # control: a real weight glob stays weight-including + assert hcs.request_can_include_weights(allow_patterns = ["model*"]) is True diff --git a/unsloth_zoo/hf_cache_state.py b/unsloth_zoo/hf_cache_state.py index 2c8e3529e..4a3c67850 100644 --- a/unsloth_zoo/hf_cache_state.py +++ b/unsloth_zoo/hf_cache_state.py @@ -394,6 +394,25 @@ def _requested_scope_filter( return kept +def snapshot_has_requested_broken_symlinks( + snapshot_dir: Path, + *, + allow_patterns: "Optional[object]" = None, + ignore_patterns: "Optional[object]" = None, +) -> bool: + """True iff a dangling symlink in *snapshot_dir* is for a file the request actually SELECTS. + + A dangling symlink marks an interrupted download, but for a scoped request only one for a + requested file should reject the snapshot: a dangling root ``model.safetensors`` left by an + earlier interrupted pull must not fail a weightless ``allow_patterns=["config.json"]`` request + whose config is on disk. 
Mirrors the scoped broken-symlink handling inside + ``snapshot_dir_is_complete`` so the weightless / non-model path is scoped the same way.""" + allow_patterns = _as_pattern_list(allow_patterns) + ignore_patterns = _as_pattern_list(ignore_patterns) + broken = _broken_symlink_rel_paths(snapshot_dir) + return bool(broken and _requested_scope_filter(broken, allow_patterns, ignore_patterns)) + + def snapshot_dir_is_complete( snapshot_dir: Path, *, @@ -714,6 +733,11 @@ def _concretize_glob(pattern: str) -> str: "special_tokens_map.json", "generation_config.json", "preprocessor_config.json", + # Processor metadata: a processor-only warm (allow_patterns=["processor*"]) selects these and no + # weight, so a representative must be here for _basename_is_non_weight to read the glob as + # metadata-only (else the snapshot is wrongly rejected for lacking a weight). + "processor_config.json", + "video_preprocessor_config.json", "vocab.json", "merges.txt", "readme.md", diff --git a/unsloth_zoo/hf_xet_fallback.py b/unsloth_zoo/hf_xet_fallback.py index b09ed47ab..2b548e31d 100644 --- a/unsloth_zoo/hf_xet_fallback.py +++ b/unsloth_zoo/hf_xet_fallback.py @@ -57,6 +57,7 @@ requested_named_files_present, snapshot_dir_has_broken_symlinks, snapshot_dir_is_complete, + snapshot_has_requested_broken_symlinks, ) logger = logging.getLogger(__name__) @@ -972,11 +973,15 @@ def _snapshot_is_acceptable( ) # Weightless / non-model request (a dataset, or a model repo whose patterns drop every weight # format, e.g. a tokenizer-only allow list): no weight is expected, so completeness is "no - # dangling symlink". But an EXACT-named weightless request (allow_patterns=["tokenizer.json"], - # no globs) must still find its named files on disk -- HF can hand back a config-only snapshot - # dir that simply does not contain the requested file. A glob-bearing list stays best-effort. + # dangling symlink among the REQUESTED files". 
The broken-symlink check is scoped to the request + # (like snapshot_dir_is_complete), so a dangling EXCLUDED weight left by an earlier interrupted + # pull does not reject a complete config/tokenizer subset. An EXACT-named weightless request + # (allow_patterns=["tokenizer.json"], no globs) must still find its named files on disk -- HF can + # hand back a config-only snapshot dir that simply lacks the requested file. Globs stay best-effort. return ( - not snapshot_dir_has_broken_symlinks(snapshot_dir) + not snapshot_has_requested_broken_symlinks( + snapshot_dir, allow_patterns = allow_patterns, ignore_patterns = ignore_patterns + ) and requested_named_files_present( snapshot_dir, allow_patterns = allow_patterns, ignore_patterns = ignore_patterns ) From 316fe69b40753458e07adf4afb3e538ace8c22a9 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 28 Jun 2026 07:32:28 +0000 Subject: [PATCH 36/82] Reject a diffusers pipeline snapshot missing a declared sub-model A full pipeline warm (no allow_patterns) killed mid-download can leave one component fully cached and another entirely absent. snapshot_dir_is_complete only checked that SOME loadable weight was present, so such a half-warmed pipeline read as complete -- and the in-process pipeline load would then fetch the missing component over unprotected Xet, the silent-hang case the fallback exists to prevent. A diffusers pipeline lists its sub-models in a root model_index.json (each non-_ key maps to a [library, class] pair). Add _diffusion_pipeline_complete: for an unpatterned warm, every declared non-null component's subfolder must exist with files, and every weight-bearing component (unet / transformer / vae / text_encoder / ...) must carry a weight. A repo with no readable model_index.json (plain transformers / GGUF) is unaffected, and a scoped subfolder request still validates only its own subset. 
Tests: a unet+vae warm missing the declared text_encoder reads incomplete until text_encoder lands; a vae subfolder holding only a config (no weight) reads incomplete; a scoped allow_patterns=["unet/*"] request reads complete regardless of the other components. --- tests/test_hf_xet_fallback.py | 87 +++++++++++++++++++++++++++++++++++ unsloth_zoo/hf_cache_state.py | 86 ++++++++++++++++++++++++++++++++++ 2 files changed, 173 insertions(+) diff --git a/tests/test_hf_xet_fallback.py b/tests/test_hf_xet_fallback.py index 2e14f430a..1029115c6 100644 --- a/tests/test_hf_xet_fallback.py +++ b/tests/test_hf_xet_fallback.py @@ -1298,6 +1298,93 @@ def test_snapshot_dir_is_complete_unit(tmp_path): assert hcs.snapshot_dir_is_complete(snap) is True +def _make_diffusion_component(snap, blob, name, weight_filename = None): + """Create a diffusers pipeline subfolder with a config and (optionally) a weight symlink.""" + comp = snap / name + comp.mkdir() + (comp / "config.json").write_text("{}") + if weight_filename is not None: + (comp / weight_filename).symlink_to(blob) + return comp + + +def test_snapshot_dir_is_complete_diffusion_missing_component(tmp_path): + """A full pipeline warm killed with one component absent reads as incomplete. 
model_index.json + declares unet / vae / text_encoder; the snapshot warmed unet + vae but never started + text_encoder, so the in-process pipeline load would fetch it over unprotected Xet.""" + snap = tmp_path / "snap" + snap.mkdir() + blob = tmp_path / "blob" + blob.write_bytes(b"weights") + (snap / "model_index.json").write_text( + json.dumps( + { + "_class_name": "StableDiffusionPipeline", + "_diffusers_version": "0.30.0", + "unet": ["diffusers", "UNet2DConditionModel"], + "vae": ["diffusers", "AutoencoderKL"], + "text_encoder": ["transformers", "CLIPTextModel"], + "scheduler": ["diffusers", "PNDMScheduler"], + "safety_checker": [None, None], + } + ) + ) + _make_diffusion_component(snap, blob, "unet", "diffusion_pytorch_model.safetensors") + _make_diffusion_component(snap, blob, "vae", "diffusion_pytorch_model.safetensors") + (snap / "scheduler").mkdir() + (snap / "scheduler" / "scheduler_config.json").write_text("{}") + # text_encoder subfolder never created -> interrupted pipeline warm + assert hcs.snapshot_dir_is_complete(snap) is False + # Once the missing component is on disk the pipeline reads complete. 
+ _make_diffusion_component(snap, blob, "text_encoder", "model.safetensors") + assert hcs.snapshot_dir_is_complete(snap) is True + + +def test_snapshot_dir_is_complete_diffusion_partial_weight_component(tmp_path): + """A weight-bearing component whose subfolder holds only a config (no weight) reads as + incomplete: the component started but its weight was never fetched.""" + snap = tmp_path / "snap" + snap.mkdir() + blob = tmp_path / "blob" + blob.write_bytes(b"weights") + (snap / "model_index.json").write_text( + json.dumps( + { + "_class_name": "StableDiffusionPipeline", + "unet": ["diffusers", "UNet2DConditionModel"], + "vae": ["diffusers", "AutoencoderKL"], + } + ) + ) + _make_diffusion_component(snap, blob, "unet", "diffusion_pytorch_model.safetensors") + _make_diffusion_component(snap, blob, "vae", weight_filename = None) # config only, no weight + assert hcs.snapshot_dir_is_complete(snap) is False + (snap / "vae" / "diffusion_pytorch_model.safetensors").symlink_to(blob) + assert hcs.snapshot_dir_is_complete(snap) is True + + +def test_snapshot_dir_is_complete_diffusion_scoped_request_not_blocked(tmp_path): + """A scoped subfolder request (allow_patterns=["unet/*"]) targets its own subset, so the + whole-pipeline completeness rule does not apply: a unet-only snapshot reads complete even + though the pipeline declares more components.""" + snap = tmp_path / "snap" + snap.mkdir() + blob = tmp_path / "blob" + blob.write_bytes(b"weights") + (snap / "model_index.json").write_text( + json.dumps( + { + "_class_name": "StableDiffusionPipeline", + "unet": ["diffusers", "UNet2DConditionModel"], + "vae": ["diffusers", "AutoencoderKL"], + "text_encoder": ["transformers", "CLIPTextModel"], + } + ) + ) + _make_diffusion_component(snap, blob, "unet", "diffusion_pytorch_model.safetensors") + assert hcs.snapshot_dir_is_complete(snap, allow_patterns = ["unet/*"]) is True + + def test_snapshot_dir_is_complete_broken_symlink(tmp_path): """A dangling weight symlink reads as 
incomplete.""" snap = tmp_path / "snap" diff --git a/unsloth_zoo/hf_cache_state.py b/unsloth_zoo/hf_cache_state.py index 4a3c67850..c765aa8a7 100644 --- a/unsloth_zoo/hf_cache_state.py +++ b/unsloth_zoo/hf_cache_state.py @@ -358,6 +358,82 @@ def _weight_shard_index_complete(index_path: Path) -> bool: return True +# Diffusers pipeline subfolders that carry loadable WEIGHTS (every other declared component -- +# scheduler, tokenizer, feature_extractor, processor -- is config-only). A weight-bearing +# component whose subfolder exists but holds no weight is a partially fetched component, so the +# in-process pipeline load would still fetch the weight in-process over Xet. +_WEIGHT_BEARING_PIPELINE_DIRS = frozenset({ + "unet", + "transformer", + "vae", + "vqvae", + "movq", + "prior", + "decoder", + "text_encoder", + "text_encoder_2", + "text_encoder_3", + "image_encoder", + "safety_checker", + "controlnet", +}) + + +def _dir_has_any_file(path: Path) -> bool: + """True if *path* contains at least one regular file (recursively). A dangling symlink left by + an interrupted blob fetch is NOT a regular file (``is_file()`` follows the link and returns + False), so a component subfolder that only ever received pointer symlinks reads as having no + files -- i.e. as an unfinished component.""" + try: + for entry in path.rglob("*"): + if _safe_is_file(entry): + return True + except OSError: + return False + return False + + +def _diffusion_pipeline_complete(snapshot_dir: Path, weight_dirs: set) -> bool: + """True unless a diffusers pipeline snapshot is missing a declared sub-model. A diffusers + pipeline lists its components in a root ``model_index.json`` where each non-``_`` key maps to a + ``[library, class]`` pair; a warm killed mid-pipeline can leave one component fully cached and + another entirely absent, and the in-process pipeline load would then fetch the missing + component over unprotected Xet (the silent-hang risk). 
Require every declared (non-null) + component's subfolder to exist with files, and every weight-bearing component + (unet / transformer / vae / text_encoder / ...) to carry a weight in *weight_dirs*. + + Returns True (do not block) when there is no readable ``model_index.json`` -- a plain + transformers / GGUF snapshot, or a non-diffusion repo -- so only an actual pipeline warm is + affected. Intended for a FULL pipeline warm (no allow_patterns); a scoped subfolder request is + already validated by its own selection.""" + import json + + index_path = snapshot_dir / "model_index.json" + if not _safe_is_file(index_path): + return True # not a diffusers pipeline (or an older layout) -- nothing pipeline-specific + try: + with open(index_path, "r", encoding = "utf-8") as f: + data = json.load(f) + except (OSError, ValueError): + return True # unreadable index: defer to the generic checks rather than over-reject + if not isinstance(data, dict): + return True + for key, value in data.items(): + if isinstance(key, str) and key.startswith("_"): + continue # _class_name / _diffusers_version metadata, not a component + if not (isinstance(value, (list, tuple)) and len(value) == 2): + continue # not a [library, class] component spec + library, class_name = value + if library is None or class_name is None: + continue # an explicitly absent component (e.g. 
a disabled safety_checker) + component_dir = snapshot_dir / key + if not _safe_is_dir(component_dir) or not _dir_has_any_file(component_dir): + return False # a declared component's subfolder is missing / empty -- interrupted warm + if key in _WEIGHT_BEARING_PIPELINE_DIRS and key not in weight_dirs: + return False # the component dir exists but carries no weight -- partial component + return True + + def _broken_symlink_rel_paths(snapshot_dir: Path) -> list: """Repo-relative posix paths of every dangling symlink in *snapshot_dir* -- a referenced file whose blob is missing or still an ``.incomplete`` partial (an interrupted download). Empty when @@ -565,6 +641,16 @@ def snapshot_dir_is_complete( continue if not _weight_shard_index_complete(index_entry): return False + + # A FULL pipeline warm (no allow_patterns) must carry every sub-model a diffusers + # model_index.json declares: a warm killed mid-pipeline can leave one component cached and + # another entirely absent, which the in-process pipeline load would then fetch over + # unprotected Xet. A scoped (allow_patterns) request targets its own subset, so the + # whole-pipeline rule does not apply -- only enforce it for the unpatterned warm. + if allow_patterns is None: + weight_dirs = {rel.split("/", 1)[0] for _, rel in weight_entries if "/" in rel} + if not _diffusion_pipeline_complete(snapshot_dir, weight_dirs): + return False return True From 24b03b44804ba47673bacfe031a3d3c849cb38e7 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 28 Jun 2026 08:16:44 +0000 Subject: [PATCH 37/82] Scope catch-all warms to root weights and treat processor subfolders as weightless Address the latest Codex review round on the shared completeness check: - Keep a pure catch-all allow list scoped to root weights. allow_patterns=["*"] selects the whole repo exactly like an unpatterned warm (HF's fnmatch "*" spans "/"), so a root from_pretrained still reads ROOT weights. 
The checkpoint-dir exclusion previously applied only when allow_patterns was None, so a cache left by a prior allow_patterns=["checkpoint-10/*"] pull (only checkpoint-10/model.safetensors, no root weight) read as complete for ["*"] and let the guarded download be skipped. _is_pure_catchall now treats ["*"] / ["**"] like an unpatterned root warm for both the checkpoint-dir drop and the diffusers pipeline check; a path-bearing pattern (checkpoint-10/*, model.safetensors) is still trusted as a deliberate selection. - Recognize a processor / image_processor subfolder as weightless. allow_patterns=["processor/*"] selects only *_config.json / vocab files, but the auxiliary-dir set omitted processor, so the synthetic processor/model.safetensors weight probe made the request look weight-bearing and a processor-only snapshot was wrongly rejected for lacking a weight. Add processor and image_processor alongside scheduler / feature_extractor; a weight name under the subfolder is still read as weight-including. Tests: a checkpoint-only cache reads incomplete for ["*"] / ["**"] but complete for ["checkpoint-10/*"]; processor/* and image_processor/* read as weightless while processor/model.safetensors stays weight-including. --- tests/test_hf_xet_fallback.py | 35 +++++++++++++++++++++++ unsloth_zoo/hf_cache_state.py | 53 ++++++++++++++++++++++++----------- 2 files changed, 72 insertions(+), 16 deletions(-) diff --git a/tests/test_hf_xet_fallback.py b/tests/test_hf_xet_fallback.py index 1029115c6..0cbdbac10 100644 --- a/tests/test_hf_xet_fallback.py +++ b/tests/test_hf_xet_fallback.py @@ -1527,6 +1527,41 @@ def test_snapshot_dir_is_complete_checkpoint_only_not_warm_root(tmp_path): assert hcs.snapshot_dir_is_complete(snap2) is False +def test_snapshot_dir_is_complete_catchall_not_warm_root(tmp_path): + """A pure catch-all allow list (["*"]) selects the whole repo just like an unpatterned warm, so + a root from_pretrained still reads ROOT weights. 
A checkpoint-only cache (left by a prior + allow_patterns=["checkpoint-10/*"] pull) must NOT read as complete for ["*"] -- HF's fnmatch + "*" spans "/" and would otherwise count the checkpoint weight as satisfying the catch-all + (Codex #829). A path-bearing pattern that names the checkpoint is still trusted.""" + snap = tmp_path / "snap" + (snap / "checkpoint-10").mkdir(parents = True) + blob = tmp_path / "blob" + blob.write_bytes(b"x") + (snap / "checkpoint-10" / "model.safetensors").symlink_to(blob) # checkpoint weight only + (snap / "config.json").write_text("{}") + # Catch-all is treated like an unpatterned root warm: the checkpoint weight does not count. + assert hcs.snapshot_dir_is_complete(snap, allow_patterns = ["*"]) is False + assert hcs.snapshot_dir_is_complete(snap, allow_patterns = ["**"]) is False + # A path-bearing checkpoint pattern IS satisfied by it (deliberate checkpoint selection). + assert hcs.snapshot_dir_is_complete(snap, allow_patterns = ["checkpoint-10/*"]) is True + # Once a root weight is present, the catch-all completes. + (snap / "model.safetensors").symlink_to(blob) + assert hcs.snapshot_dir_is_complete(snap, allow_patterns = ["*"]) is True + + +def test_request_can_include_weights_processor_subfolder(): + """A processor / image_processor subfolder ships only *_config.json + vocab files (no weights), + so a catch-all warm under it (processor/*) reads as WEIGHTLESS. Without this, the synthetic + processor/model.safetensors weight probe makes the request look weight-bearing and a + processor-only snapshot is wrongly rejected for lacking a weight (Codex #829). 
A weight under + the same subfolder is still recognized as weight-including.""" + assert hcs.request_can_include_weights(["processor/*"], None) is False + assert hcs.request_can_include_weights(["image_processor/*"], None) is False + assert hcs.request_can_include_weights(["processor/"], None) is False + # A weight name under the subfolder still reads as weight-including (no accept-stale). + assert hcs.request_can_include_weights(["processor/model.safetensors"], None) is True + + def test_snapshot_dir_is_complete_checkpoint_index_does_not_gate_root(tmp_path): """A per-checkpoint shard index with missing shards must not fail an unpatterned root warm: the root weights are what the load reads, so an incomplete checkpoint index is irrelevant to diff --git a/unsloth_zoo/hf_cache_state.py b/unsloth_zoo/hf_cache_state.py index c765aa8a7..287c6be46 100644 --- a/unsloth_zoo/hf_cache_state.py +++ b/unsloth_zoo/hf_cache_state.py @@ -455,17 +455,37 @@ def _broken_symlink_rel_paths(snapshot_dir: Path) -> list: return out +# Catch-all allow patterns that select the WHOLE repo, exactly like an unpatterned warm. HF's +# fnmatch ``*`` spans ``/``, so a bare ``*`` (or ``**``) matches every path including checkpoint +# subdirs -- but a root ``from_pretrained`` still reads ROOT weights, so such a request must be +# treated like an unpatterned root warm (drop checkpoint-dir paths), not trusted as a deliberate +# checkpoint selection. +_CATCHALL_ALLOW_PATTERNS = frozenset({"*", "**"}) + + +def _is_pure_catchall(allow_patterns: "Optional[list]") -> bool: + """True when *allow_patterns* is a non-empty list whose every entry is a bare catch-all + (``*`` / ``**``). Such a list selects the whole repo just like an unpatterned warm, so a root + load still reads ROOT weights and a checkpoint-dir-only cache must not satisfy it. 
A list with + any path-bearing or name-specific pattern (``checkpoint-10/*``, ``model.safetensors``) is + trusted as-is -- a caller that names a checkpoint path opts back into it.""" + if not allow_patterns: + return False + return all(isinstance(p, str) and p.strip() in _CATCHALL_ALLOW_PATTERNS for p in allow_patterns) + + def _requested_scope_filter( rels: list, allow_patterns: "Optional[list]", ignore_patterns: "Optional[list]" ) -> list: """The subset of repo-relative *rels* a request selects. Applies the allow / ignore filter, and when there is no ``allow_patterns`` (an UNPATTERNED or IGNORE-ONLY request -- a bare - ``from_pretrained`` that reads ROOT weights) also drops per-checkpoint-dir paths the root load - never reads, so a checkpoint-dir file neither satisfies the warm nor (as a dangling symlink) - blocks it. An explicit ``allow_patterns`` is trusted as-is: a caller that names a checkpoint - path opts back into it.""" + ``from_pretrained`` that reads ROOT weights) OR the allow list is a pure catch-all + (``["*"]``, which selects the whole repo just like an unpatterned warm) also drops + per-checkpoint-dir paths the root load never reads, so a checkpoint-dir file neither satisfies + the warm nor (as a dangling symlink) blocks it. 
A path-bearing ``allow_patterns`` is trusted + as-is: a caller that names a checkpoint path opts back into it.""" kept = _filter_paths(list(rels), allow_patterns, ignore_patterns) - if allow_patterns is None: + if allow_patterns is None or _is_pure_catchall(allow_patterns): kept = [r for r in kept if not _path_under_checkpoint_dir(r)] return kept @@ -642,12 +662,12 @@ def snapshot_dir_is_complete( if not _weight_shard_index_complete(index_entry): return False - # A FULL pipeline warm (no allow_patterns) must carry every sub-model a diffusers - # model_index.json declares: a warm killed mid-pipeline can leave one component cached and - # another entirely absent, which the in-process pipeline load would then fetch over - # unprotected Xet. A scoped (allow_patterns) request targets its own subset, so the - # whole-pipeline rule does not apply -- only enforce it for the unpatterned warm. - if allow_patterns is None: + # A FULL pipeline warm (no allow_patterns, or a pure catch-all ``["*"]`` that selects the + # whole repo the same way) must carry every sub-model a diffusers model_index.json declares: + # a warm killed mid-pipeline can leave one component cached and another entirely absent, which + # the in-process pipeline load would then fetch over unprotected Xet. A scoped (path-bearing) + # request targets its own subset, so the whole-pipeline rule does not apply there. + if allow_patterns is None or _is_pure_catchall(allow_patterns): weight_dirs = {rel.split("/", 1)[0] for _, rel in weight_entries if "/" in rel} if not _diffusion_pipeline_complete(snapshot_dir, weight_dirs): return False @@ -837,14 +857,15 @@ def _concretize_glob(pattern: str) -> str: # weightless. Kept deliberately narrow: an unknown subfolder (unet/, transformer/, original/, a # new arch's component dir) must stay weight-including, so a weight-bearing dir is never misread # as weightless (that would re-open the silent-Xet-hang accept-stale this module exists to -# prevent). 
The Diffusers pipeline components listed here (scheduler/, feature_extractor/, the -# extra tokenizers) ship only *_config.json / vocab files; the weight-bearing pipeline dirs -# (unet/, transformer/, vae/, text_encoder*/, image_encoder/, safety_checker/) are deliberately -# absent so a catch-all under them stays weight-including. +# prevent). The Diffusers / multimodal preprocessing components listed here (scheduler/, +# feature_extractor/, processor/, image_processor/, the extra tokenizers) ship only +# *_config.json / vocab files; the weight-bearing pipeline dirs (unet/, transformer/, vae/, +# text_encoder*/, image_encoder/, safety_checker/) are deliberately absent so a catch-all under +# them stays weight-including. _NON_WEIGHT_DIR_NAMES = frozenset({ "tokenizer", "tokenizer_2", "tokenizer_3", "runs", "run", "logs", "log", "samples", "sample", "tensorboard", "tb", "events", "eval", "evals", "evaluation", "metrics", "wandb", "assets", - "images", "media", "scheduler", "feature_extractor", + "images", "media", "scheduler", "feature_extractor", "processor", "image_processor", }) From 588c60fea755bb6be2059317125da7f33e6b58db Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 28 Jun 2026 09:10:44 +0000 Subject: [PATCH 38/82] Require root-level weights for root warms and validate pipelines for file globs Address the latest Codex review round: a root warm expressed without a path-bearing pattern read as complete when only a SUBFOLDER weight was on disk, so the guarded download was skipped and the in-process load fetched the missing root weight over unprotected Xet. Unify the "this is a full / root warm" decision behind _targets_root_only: a request reads from the repo ROOT when it has no allow_patterns at all, or every allow pattern is a no-slash name / glob (HF's fnmatch "*" spans "/", so such a glob still matches nested subdir files, but a bare from_pretrained reads only root-level files). 
A path-bearing pattern (checkpoint-10/*, BF16/*, unet/*) deliberately targets a subfolder and is trusted. - Require a root-level weight for a root warm of a NON-pipeline model. _requested_scope_filter gains a root_weights_only mode that drops EVERY subfolder weight, not just checkpoint dirs, so a prior allow_patterns=["BF16/*"] / ["fp16/*"] pull (BF16/model.safetensors, no root weight) no longer satisfies an unpatterned or no-slash-glob (["*.safetensors"]) root warm. A diffusers pipeline (model_index.json present) keeps the narrower checkpoint-dir scoping, since its component weights legitimately live in subfolders. The same scoping applies to the broken-symlink check, so a stale dangling subfolder weight does not block a root warm. - Validate diffusers components for a full-file-glob warm. The model_index.json component check now runs for any root warm (_targets_root_only), so a full pipeline warm written as ["*.safetensors", "*.json"] -- which HF spreads across nested component files -- is no longer treated as a scoped subset: a stale snapshot with only unet/ but a declared vae / text_encoder reads as incomplete. Tests: a checkpoint-only or BF16/-only cache reads incomplete for an unpatterned / ["*.safetensors"] root warm but complete for the matching path-bearing request; a diffusers pipeline missing a declared component reads incomplete for a ["*.safetensors", "*.json"] warm. 
--- tests/test_hf_xet_fallback.py | 70 +++++++++++++++++++++++++++ unsloth_zoo/hf_cache_state.py | 90 +++++++++++++++++++++++++---------- 2 files changed, 136 insertions(+), 24 deletions(-) diff --git a/tests/test_hf_xet_fallback.py b/tests/test_hf_xet_fallback.py index 0cbdbac10..e2b540d76 100644 --- a/tests/test_hf_xet_fallback.py +++ b/tests/test_hf_xet_fallback.py @@ -1549,6 +1549,76 @@ def test_snapshot_dir_is_complete_catchall_not_warm_root(tmp_path): assert hcs.snapshot_dir_is_complete(snap, allow_patterns = ["*"]) is True +def test_snapshot_dir_is_complete_root_file_glob_not_checkpoint(tmp_path): + """A no-slash root weight glob (["*.safetensors"]) reads from the repo ROOT, but HF's fnmatch + "*" spans "/" and also matches checkpoint-10/model.safetensors. A checkpoint-only cache must + NOT read as complete for the root safetensors warm (Codex #829) -- the root model.safetensors + from_pretrained reads is still missing. A path-bearing pattern naming the checkpoint is + trusted.""" + snap = tmp_path / "snap" + (snap / "checkpoint-10").mkdir(parents = True) + blob = tmp_path / "blob" + blob.write_bytes(b"x") + (snap / "checkpoint-10" / "model.safetensors").symlink_to(blob) + (snap / "config.json").write_text("{}") + assert hcs.snapshot_dir_is_complete(snap, allow_patterns = ["*.safetensors"]) is False + # A path-bearing checkpoint pattern IS satisfied by the checkpoint weight. + assert hcs.snapshot_dir_is_complete(snap, allow_patterns = ["checkpoint-10/*"]) is True + # A root weight completes the root glob. + (snap / "model.safetensors").symlink_to(blob) + assert hcs.snapshot_dir_is_complete(snap, allow_patterns = ["*.safetensors"]) is True + + +def test_snapshot_dir_is_complete_root_warm_ignores_subfolder_weight(tmp_path): + """A root warm (allow_patterns=None, or a no-slash glob) is not satisfied by a weight that + lives only in a NON-checkpoint subfolder such as a precision variant (BF16/, fp16/). 
A prior + allow_patterns=["BF16/*"] pull leaves BF16/model.safetensors but no root weight; a root + from_pretrained never reads that subfolder, so the snapshot must read as incomplete (Codex + #829). A root weight, or an explicit BF16/ request, still completes.""" + snap = tmp_path / "snap" + (snap / "BF16").mkdir(parents = True) + blob = tmp_path / "blob" + blob.write_bytes(b"x") + (snap / "BF16" / "model.safetensors").symlink_to(blob) + (snap / "config.json").write_text("{}") + # Unpatterned and no-slash-glob root warms both ignore the subfolder weight. + assert hcs.snapshot_dir_is_complete(snap) is False + assert hcs.snapshot_dir_is_complete(snap, allow_patterns = ["*.safetensors"]) is False + # An explicit subfolder request IS satisfied by it (deliberate subfolder selection). + assert hcs.snapshot_dir_is_complete(snap, allow_patterns = ["BF16/*"]) is True + # A root weight completes the root warm (the subfolder variant is ignored). + (snap / "model.safetensors").symlink_to(blob) + assert hcs.snapshot_dir_is_complete(snap) is True + + +def test_snapshot_dir_is_complete_diffusion_file_glob_validates_components(tmp_path): + """A full diffusers warm expressed as file globs (["*.safetensors", "*.json"]) still selects + nested component files, so the model_index.json component check must run: a stale snapshot with + only unet/ but a declared vae / text_encoder reads as incomplete (Codex #829), not accepted as + a scoped subset.""" + snap = tmp_path / "snap" + snap.mkdir() + blob = tmp_path / "blob" + blob.write_bytes(b"weights") + (snap / "model_index.json").write_text( + json.dumps( + { + "_class_name": "StableDiffusionPipeline", + "unet": ["diffusers", "UNet2DConditionModel"], + "vae": ["diffusers", "AutoencoderKL"], + "text_encoder": ["transformers", "CLIPTextModel"], + } + ) + ) + _make_diffusion_component(snap, blob, "unet", "diffusion_pytorch_model.safetensors") + globs = ["*.safetensors", "*.json"] + assert hcs.snapshot_dir_is_complete(snap, allow_patterns = 
globs) is False + # Once the other declared components land, the file-glob warm completes. + _make_diffusion_component(snap, blob, "vae", "diffusion_pytorch_model.safetensors") + _make_diffusion_component(snap, blob, "text_encoder", "model.safetensors") + assert hcs.snapshot_dir_is_complete(snap, allow_patterns = globs) is True + + def test_request_can_include_weights_processor_subfolder(): """A processor / image_processor subfolder ships only *_config.json + vocab files (no weights), so a catch-all warm under it (processor/*) reads as WEIGHTLESS. Without this, the synthetic diff --git a/unsloth_zoo/hf_cache_state.py b/unsloth_zoo/hf_cache_state.py index 287c6be46..4afbac3a0 100644 --- a/unsloth_zoo/hf_cache_state.py +++ b/unsloth_zoo/hf_cache_state.py @@ -474,18 +474,46 @@ def _is_pure_catchall(allow_patterns: "Optional[list]") -> bool: return all(isinstance(p, str) and p.strip() in _CATCHALL_ALLOW_PATTERNS for p in allow_patterns) +def _targets_root_only(allow_patterns: "Optional[list]") -> bool: + """True when a request reads from the repo ROOT rather than a specific subfolder: there is no + ``allow_patterns`` at all, or every allow pattern is a no-slash name / glob (``*``, + ``*.safetensors``, ``model*``, ``config.json``). HF's fnmatch ``*`` spans ``/``, so such a glob + also matches nested ``subdir/...`` files, but the LOAD a bare ``from_pretrained`` performs reads + only root-level files. 
A path-bearing pattern (``checkpoint-10/*``, ``BF16/*``, ``unet/*``) + deliberately targets a subfolder and is trusted as-is.""" + if allow_patterns is None: + return True + if not allow_patterns: + return False # allow_patterns=[] selects nothing -- a scoped (empty) request, not a root warm + return all(isinstance(p, str) and "/" not in p for p in allow_patterns) + + def _requested_scope_filter( - rels: list, allow_patterns: "Optional[list]", ignore_patterns: "Optional[list]" + rels: list, + allow_patterns: "Optional[list]", + ignore_patterns: "Optional[list]", + *, + root_weights_only: bool = False, ) -> list: - """The subset of repo-relative *rels* a request selects. Applies the allow / ignore filter, and - when there is no ``allow_patterns`` (an UNPATTERNED or IGNORE-ONLY request -- a bare - ``from_pretrained`` that reads ROOT weights) OR the allow list is a pure catch-all - (``["*"]``, which selects the whole repo just like an unpatterned warm) also drops - per-checkpoint-dir paths the root load never reads, so a checkpoint-dir file neither satisfies - the warm nor (as a dangling symlink) blocks it. A path-bearing ``allow_patterns`` is trusted - as-is: a caller that names a checkpoint path opts back into it.""" + """The subset of repo-relative *rels* a request selects. Applies the allow / ignore filter, then + drops paths a root load never reads: + + * *root_weights_only* (a root-level warm of a NON-pipeline model) drops EVERY subfolder path: + a bare ``from_pretrained`` reads only repo-ROOT files, so a weight that lives solely in a + subfolder (``BF16/model.safetensors``, ``fp16/``, ``checkpoint-500/``, an ``onnx/`` export) is + an alternate the load never reads and must neither satisfy the warm nor (as a dangling + symlink) block it. 
+ * otherwise, when there is no ``allow_patterns`` (an UNPATTERNED or IGNORE-ONLY request) or the + allow list is a pure catch-all (``["*"]``), drop only per-checkpoint-dir paths -- this is the + diffusers-pipeline / path-trusting case where genuine component subfolders (``unet/``, + ``vae/``) must survive. + + A path-bearing ``allow_patterns`` is otherwise trusted as-is: a caller that names a subfolder + path opts back into it.""" kept = _filter_paths(list(rels), allow_patterns, ignore_patterns) - if allow_patterns is None or _is_pure_catchall(allow_patterns): + if root_weights_only: + kept = [r for r in kept if "/" not in r] + elif allow_patterns is None or _is_pure_catchall(allow_patterns): kept = [r for r in kept if not _path_under_checkpoint_dir(r)] return kept @@ -554,14 +582,24 @@ def snapshot_dir_is_complete( # warmup branch (consistent with request_can_include_weights). has_patterns = allow_patterns is not None or ignore_patterns is not None + # A root-level warm (no path-bearing allow pattern) of a NON-pipeline model reads only repo-ROOT + # files, so a weight under any subfolder (BF16/, fp16/, a checkpoint dir) is an alternate the + # load never reads. A diffusers pipeline is the exception -- its component weights legitimately + # live in subfolders (unet/, vae/, text_encoder/), validated by _diffusion_pipeline_complete -- + # so it keeps the (narrower) checkpoint-dir scoping instead. + is_pipeline = _safe_is_file(snapshot_dir / "model_index.json") + root_weights_only = _targets_root_only(allow_patterns) and not is_pipeline + # A dangling symlink marks an interrupted download, but only one for a file the request # actually selects should reject the snapshot. 
A stale dangling root model.safetensors must # not fail an allow_patterns=["adapter_model.safetensors"] probe whose adapter weight IS on - # disk, so scope the broken-symlink check to the requested files (and, for a root warm with no - # allow_patterns, drop checkpoint-dir paths the bare load never reads) -- the same selection + # disk, so scope the broken-symlink check to the requested files (and, for a root warm, drop the + # subfolder / checkpoint-dir paths the bare load never reads) -- the same selection # _requested_scope_filter applies to the weights below. broken = _broken_symlink_rel_paths(snapshot_dir) - if broken and _requested_scope_filter(broken, allow_patterns, ignore_patterns): + if broken and _requested_scope_filter( + broken, allow_patterns, ignore_patterns, root_weights_only = root_weights_only + ): return False index_entries: list = [] @@ -580,12 +618,15 @@ def snapshot_dir_is_complete( # The weights the request selects that are present on disk (any present root weight when the # request is unpatterned). The snapshot can carry an unrelated weight while the requested one - # is missing, so a patterned request must find one it actually selects. _requested_scope_filter - # also excludes per-checkpoint-dir weights (checkpoint-500/model.safetensors, left behind by a - # prior allow_patterns=["checkpoint-500/*"] pull) whenever there is no allow_patterns -- an - # UNPATTERNED *or* IGNORE-ONLY root warm (e.g. ignore_patterns=["*.onnx"]) is still a bare - # from_pretrained reading ROOT weights, so a checkpoint-only snapshot must not read as warm. - selected = set(_requested_scope_filter([rel for _, rel in weight_entries], allow_patterns, ignore_patterns)) + # is missing, so a patterned request must find one it actually selects. 
For a root-level warm of + # a non-pipeline model, _requested_scope_filter drops EVERY subfolder weight (BF16/, fp16/, + # checkpoint-500/, an onnx/ export, left behind by a prior patterned pull): an UNPATTERNED, + # IGNORE-ONLY, or no-slash-glob (["*.safetensors"]) root warm is still a bare from_pretrained + # reading ROOT weights, so a subfolder-only snapshot must not read as warm. + selected = set(_requested_scope_filter( + [rel for _, rel in weight_entries], allow_patterns, ignore_patterns, + root_weights_only = root_weights_only, + )) if not selected: return False @@ -662,12 +703,13 @@ def snapshot_dir_is_complete( if not _weight_shard_index_complete(index_entry): return False - # A FULL pipeline warm (no allow_patterns, or a pure catch-all ``["*"]`` that selects the - # whole repo the same way) must carry every sub-model a diffusers model_index.json declares: - # a warm killed mid-pipeline can leave one component cached and another entirely absent, which - # the in-process pipeline load would then fetch over unprotected Xet. A scoped (path-bearing) - # request targets its own subset, so the whole-pipeline rule does not apply there. - if allow_patterns is None or _is_pure_catchall(allow_patterns): + # A FULL pipeline warm -- no allow_patterns, a pure catch-all ``["*"]``, or a no-slash file glob + # (``["*.safetensors", "*.json"]``) that HF's matcher spreads across every nested component -- + # must carry every sub-model a diffusers model_index.json declares: a warm killed mid-pipeline + # can leave one component cached and another entirely absent, which the in-process pipeline load + # would then fetch over unprotected Xet. A scoped (path-bearing) request targets its own subset, + # so the whole-pipeline rule does not apply there. 
+ if _targets_root_only(allow_patterns): weight_dirs = {rel.split("/", 1)[0] for _, rel in weight_entries if "/" in rel} if not _diffusion_pipeline_complete(snapshot_dir, weight_dirs): return False From 66cb43e7406017089c5f6141f669f8f6912ef1d3 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 28 Jun 2026 09:42:03 +0000 Subject: [PATCH 39/82] Validate shard indexes for all root-wide warms and require the sharded index sidecar Address the latest Codex review round, reusing the _targets_root_only abstraction to close a class of shard-index gaps rather than one pattern at a time: - Validate weight-shard indexes for every root-wide warm, not just the unpatterned one. The index check was gated on "no patterns at all", so a root-wide PATTERNED warm (ignore_patterns=["*.onnx"], allow_patterns=["*.safetensors", "*.json"]) skipped it. A stale cache whose model.safetensors.index.json lists NON-numbered shards (foo.safetensors, bar.safetensors) with one missing then read as complete -- the numbered-shard check cannot catch a non-numbered name. The check now runs for every _targets_root_only request, scoped to root-level indexes for a non-pipeline root load. - Require the index sidecar for a sharded root warm. transformers' local from_pretrained resolves a directory by probing model.safetensors then model.safetensors.index.json (then the .bin pair) and never globs model-*-of-*.safetensors, so a cache with every numbered shard but no model.safetensors.index.json raises rather than loads and would fetch the index in-process over Xet. A full (unpatterned or glob-bearing) root-level warm now requires the matching root index for each numbered-shard set; a deliberate exact single-shard request is exempt (it wants only that file). Updates test_snapshot_dir_is_complete_missing_shard_without_index to the verified transformers behavior (all shards present but no index reads incomplete until the index lands). - Recognize a chat-template glob as weightless. 
allow_patterns=["chat_template*"] selects only chat_template.json / .jinja and no weight; add those (and added_tokens.json) to the non-weight probes alongside processor metadata, so a template-only snapshot is not rejected for lacking a weight. Tests: a non-numbered shard index missing a shard rejects an ["*.safetensors", "*.json"] / ignore-only root warm; a complete numbered-shard set without its index rejects a glob warm but a single-shard exact request still completes; chat_template* / added_tokens.json read as weightless. --- tests/test_hf_xet_fallback.py | 79 +++++++++++++++++++++++++++++++++-- unsloth_zoo/hf_cache_state.py | 50 +++++++++++++++++++--- 2 files changed, 120 insertions(+), 9 deletions(-) diff --git a/tests/test_hf_xet_fallback.py b/tests/test_hf_xet_fallback.py index e2b540d76..e428a3b2e 100644 --- a/tests/test_hf_xet_fallback.py +++ b/tests/test_hf_xet_fallback.py @@ -1416,9 +1416,13 @@ def test_snapshot_dir_is_complete_missing_shard(tmp_path): def test_snapshot_dir_is_complete_missing_shard_without_index(tmp_path): - """A leftover single numbered shard with NO index sidecar (an interrupted multi-shard - pull where the index was never cached) must read as incomplete: the shard name itself - states the full set, so the missing siblings are detectable without a manifest.""" + """An interrupted multi-shard pull with NO index sidecar reads as incomplete. While the shards + are partial, the numbered shard name itself states the full set, so missing siblings are + detectable without a manifest. 
But even with EVERY shard on disk, a full warm is still + incomplete until model.safetensors.index.json is present: transformers' local from_pretrained + resolves a directory by probing model.safetensors then model.safetensors.index.json (never by + globbing model-*-of-*.safetensors), so a sharded checkpoint without its index raises rather than + loads, and the missing index would otherwise be fetched in-process over Xet.""" snap = tmp_path / "snap" snap.mkdir() blob = tmp_path / "blob" @@ -1428,7 +1432,19 @@ def test_snapshot_dir_is_complete_missing_shard_without_index(tmp_path): (snap / "model-00002-of-00003.safetensors").symlink_to(blob) assert hcs.snapshot_dir_is_complete(snap) is False # shard 3 still missing (snap / "model-00003-of-00003.safetensors").symlink_to(blob) - assert hcs.snapshot_dir_is_complete(snap) is True + assert hcs.snapshot_dir_is_complete(snap) is False # all shards present but no index sidecar + (snap / "model.safetensors.index.json").write_text( + json.dumps( + { + "weight_map": { + "a": "model-00001-of-00003.safetensors", + "b": "model-00002-of-00003.safetensors", + "c": "model-00003-of-00003.safetensors", + } + } + ) + ) + assert hcs.snapshot_dir_is_complete(snap) is True # index present -> loadable def test_snapshot_dir_is_complete_ignores_trainer_artifacts(tmp_path): @@ -1619,6 +1635,61 @@ def test_snapshot_dir_is_complete_diffusion_file_glob_validates_components(tmp_p assert hcs.snapshot_dir_is_complete(snap, allow_patterns = globs) is True +def test_request_can_include_weights_chat_template_glob(): + """A chat-template glob (["chat_template*"]) selects only chat_template.json / .jinja and no + weight, so it reads as WEIGHTLESS. 
Without a representative in the non-weight probes the glob is + misread as a weight directory and a template-only snapshot is wrongly rejected (Codex #829).""" + assert hcs.request_can_include_weights(["chat_template*"], None) is False + assert hcs.request_can_include_weights(["chat_template.jinja"], None) is False + assert hcs.request_can_include_weights(["added_tokens.json"], None) is False + # A weight glob is still weight-including. + assert hcs.request_can_include_weights(["*.safetensors"], None) is True + + +def test_snapshot_dir_is_complete_root_glob_validates_shard_index(tmp_path): + """A root-wide patterned warm (allow_patterns=["*.safetensors", "*.json"], or + ignore_patterns=["*.onnx"]) still warms the root model, so a shard index whose shards are not + all on disk must reject it -- even for NON-numbered shard names the numbered-shard check cannot + catch (Codex #829). The index validation now runs for every _targets_root_only request, not + only the unpatterned one.""" + snap = tmp_path / "snap" + snap.mkdir() + blob = tmp_path / "blob" + blob.write_bytes(b"x") + (snap / "foo.safetensors").symlink_to(blob) # one non-numbered shard present + (snap / "model.safetensors.index.json").write_text( + json.dumps({"weight_map": {"a": "foo.safetensors", "b": "bar.safetensors"}}) + ) + globs = ["*.safetensors", "*.json"] + assert hcs.snapshot_dir_is_complete(snap, allow_patterns = globs) is False # bar missing + assert hcs.snapshot_dir_is_complete(snap, ignore_patterns = ["*.onnx"]) is False # bar missing + (snap / "bar.safetensors").symlink_to(blob) + assert hcs.snapshot_dir_is_complete(snap, allow_patterns = globs) is True + + +def test_snapshot_dir_is_complete_sharded_glob_requires_index(tmp_path): + """A full warm expressed as a weight glob (["*.safetensors"]) over a complete numbered-shard set + is still incomplete without the index sidecar (transformers cannot load a local sharded + checkpoint without it). 
A deliberate exact single-shard request is exempt -- it wants only that + file, not the whole model.""" + snap = tmp_path / "snap" + snap.mkdir() + blob = tmp_path / "blob" + blob.write_bytes(b"x") + (snap / "model-00001-of-00002.safetensors").symlink_to(blob) + (snap / "model-00002-of-00002.safetensors").symlink_to(blob) + assert hcs.snapshot_dir_is_complete(snap, allow_patterns = ["*.safetensors"]) is False + # An exact single-shard request is satisfied by that shard alone (no index required). + assert hcs.snapshot_dir_is_complete( + snap, allow_patterns = ["model-00001-of-00002.safetensors"] + ) is True + (snap / "model.safetensors.index.json").write_text( + json.dumps({"weight_map": {"a": "model-00001-of-00002.safetensors", + "b": "model-00002-of-00002.safetensors"}}) + ) + assert hcs.snapshot_dir_is_complete(snap, allow_patterns = ["*.safetensors"]) is True + + def test_request_can_include_weights_processor_subfolder(): """A processor / image_processor subfolder ships only *_config.json + vocab files (no weights), so a catch-all warm under it (processor/*) reads as WEIGHTLESS. Without this, the synthetic diff --git a/unsloth_zoo/hf_cache_state.py b/unsloth_zoo/hf_cache_state.py index 4afbac3a0..c05868b6f 100644 --- a/unsloth_zoo/hf_cache_state.py +++ b/unsloth_zoo/hf_cache_state.py @@ -688,11 +688,15 @@ def snapshot_dir_is_complete( ): return False - # A full (unpatterned) warm also validates any shard index ships all its shards; a - # patterned request may legitimately want only a subset, so the index is not enforced. A - # per-checkpoint index (checkpoint-500/model.safetensors.index.json) does not gate a root - # warm for the same reason its weights do not, so it is skipped here too. 
- if not has_patterns: + # A root-wide warm -- no allow_patterns, a catch-all, OR a no-slash glob such as + # ignore_patterns=["*.onnx"] / allow_patterns=["*.safetensors", "*.json"] (all _targets_root_only) + # -- validates that every weight-shard index it would read ships all its shards: an index whose + # shards (numbered OR arbitrarily named) are not all on disk is an interrupted pull the + # in-process load would finish over Xet, which the numbered-shard check alone cannot catch for + # non-numbered shard names. A PATH-BEARING (scoped) request may legitimately want only a subset, + # so the index is not enforced there. A per-checkpoint index, and -- for a non-pipeline root load + # -- any subfolder index, is not what the root load reads, so it is skipped. + if _targets_root_only(allow_patterns): for index_entry in index_entries: try: index_rel = index_entry.relative_to(snapshot_dir).as_posix() @@ -700,9 +704,39 @@ def snapshot_dir_is_complete( index_rel = index_entry.name if _path_under_checkpoint_dir(index_rel): continue + if root_weights_only and "/" in index_rel: + continue # a subfolder index a bare root from_pretrained never reads if not _weight_shard_index_complete(index_entry): return False + # A sharded checkpoint is loadable locally ONLY through its index sidecar: transformers' + # from_pretrained resolves a local directory by probing model.safetensors then + # model.safetensors.index.json (then the .bin pair) -- it never globs model-*-of-*.safetensors -- + # so a cache holding every numbered shard but missing model.safetensors.index.json raises + # "no file named ..." or fetches the index in-process over Xet. For a root-level (non-pipeline) + # FULL warm (unpatterned or glob-bearing -- never a deliberate exact single-shard request, which + # wants only that file), require the index sidecar of each root-level numbered-shard set. 
+ if root_weights_only and ( + allow_patterns is None or any(_has_glob(p) for p in allow_patterns) + ): + root_index_names = set() + for index_entry in index_entries: + try: + irel = index_entry.relative_to(snapshot_dir).as_posix() + except ValueError: + irel = index_entry.name + if "/" not in irel: + root_index_names.add(irel) + for _, rel in weight_entries: + if "/" in rel: + continue + shard_match = _NUMBERED_SHARD_RE.match(rel) + if shard_match is None: + continue + index_name = f"{shard_match.group('prefix')}{shard_match.group('suffix')}.index.json" + if index_name not in root_index_names: + return False + # A FULL pipeline warm -- no allow_patterns, a pure catch-all ``["*"]``, or a no-slash file glob # (``["*.safetensors", "*.json"]``) that HF's matcher spreads across every nested component -- # must carry every sub-model a diffusers model_index.json declares: a warm killed mid-pipeline @@ -886,6 +920,12 @@ def _concretize_glob(pattern: str) -> str: # metadata-only (else the snapshot is wrongly rejected for lacking a weight). "processor_config.json", "video_preprocessor_config.json", + # Chat-template metadata: a template-only warm (allow_patterns=["chat_template*"]) selects these + # and no weight, so a representative must be here -- otherwise the glob is misread as a weight + # directory and a template-only snapshot is wrongly rejected for lacking a weight. + "chat_template.json", + "chat_template.jinja", + "added_tokens.json", "vocab.json", "merges.txt", "readme.md", From 74d9179a7cbf64690a90c2a0fab6e0940ae6b226 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 28 Jun 2026 10:05:09 +0000 Subject: [PATCH 40/82] Make the sharded-index requirement variant-aware The shard-index requirement added in the previous commit derived a fixed index name (model.safetensors.index.json) and matched only that suffix, which is wrong for a variant checkpoint. 
transformers' _add_variant inserts the variant token before the trailing extension, so a variant index is model.safetensors.index.fp16.json and its shards are model.fp16-00001-of-00002.safetensors (variant in the regex prefix). A complete variant sharded cache -- shards on disk and its variant index present -- was therefore falsely rejected as index-less, and that path is reachable: unsloth forwards variant into the warm. Replace the exact-name derivation with _has_root_shard_index, which recognizes a root-level shard index in either form (canonical model.safetensors.index.json / pytorch_model.bin.index.json and the variant model.safetensors.index.fp16.json) via a ".index." infix + ".json" test. A root warm with root numbered shards still requires SOME root index to be present (the no-index case the check exists for), while a variant sharded set with its variant index now reads complete. The numbered-shard SET check already validates variant shard completeness, since the variant shard name carries the variant in the prefix and matches the numbered-shard regex. Tests: a variant sharded set (model.fp16-*-of-*.safetensors) reads incomplete with no index and complete once model.safetensors.index.fp16.json is present. --- tests/test_hf_xet_fallback.py | 24 +++++++++++++++ unsloth_zoo/hf_cache_state.py | 57 ++++++++++++++++++++++------------- 2 files changed, 60 insertions(+), 21 deletions(-) diff --git a/tests/test_hf_xet_fallback.py b/tests/test_hf_xet_fallback.py index e428a3b2e..37e2bc9db 100644 --- a/tests/test_hf_xet_fallback.py +++ b/tests/test_hf_xet_fallback.py @@ -1690,6 +1690,30 @@ def test_snapshot_dir_is_complete_sharded_glob_requires_index(tmp_path): assert hcs.snapshot_dir_is_complete(snap, allow_patterns = ["*.safetensors"]) is True +def test_snapshot_dir_is_complete_variant_sharded_index(tmp_path): + """A variant sharded checkpoint must not be falsely rejected for lacking an index. 
Transformers' + _add_variant names the variant index model.safetensors.index.fp16.json (variant before the + trailing .json) and the shards model.fp16-00001-of-00002.safetensors (variant in the regex + prefix). The index-sidecar requirement recognizes the variant index, so a complete variant + sharded set with its index reads complete, while the same set without ANY index reads + incomplete.""" + snap = tmp_path / "snap" + snap.mkdir() + blob = tmp_path / "blob" + blob.write_bytes(b"x") + (snap / "model.fp16-00001-of-00002.safetensors").symlink_to(blob) + (snap / "model.fp16-00002-of-00002.safetensors").symlink_to(blob) + # Every shard present but no index of any kind -> incomplete. + assert hcs.snapshot_dir_is_complete(snap) is False + # The variant index (note: token before the trailing .json) makes it loadable. + (snap / "model.safetensors.index.fp16.json").write_text( + json.dumps({"weight_map": {"a": "model.fp16-00001-of-00002.safetensors", + "b": "model.fp16-00002-of-00002.safetensors"}}) + ) + assert hcs.snapshot_dir_is_complete(snap) is True + assert hcs.snapshot_dir_is_complete(snap, allow_patterns = ["*.safetensors"]) is True + + def test_request_can_include_weights_processor_subfolder(): """A processor / image_processor subfolder ships only *_config.json + vocab files (no weights), so a catch-all warm under it (processor/*) reads as WEIGHTLESS. Without this, the synthetic diff --git a/unsloth_zoo/hf_cache_state.py b/unsloth_zoo/hf_cache_state.py index c05868b6f..9f9c478e9 100644 --- a/unsloth_zoo/hf_cache_state.py +++ b/unsloth_zoo/hf_cache_state.py @@ -358,6 +358,29 @@ def _weight_shard_index_complete(index_path: Path) -> bool: return True +def _has_root_shard_index(snapshot_dir: Path, entries: list) -> bool: + """True if a root-level weight-shard index sidecar is present on disk. 
Matches the canonical + ``model.safetensors.index.json`` / ``pytorch_model.bin.index.json`` AND the variant form + ``model.safetensors.index.fp16.json`` -- transformers' ``_add_variant`` inserts the variant + token before the trailing ``.json``, so a plain ``*.index.json`` suffix test would miss it and + a variant sharded checkpoint (whose shards ARE on disk and whose variant index IS present) would + be wrongly judged index-less. A subfolder index does not count -- a bare root load never reads + it. *entries* is the already-collected ``rglob`` listing, reused to avoid a second walk.""" + for entry in entries: + name = entry.name + if ".index." not in name or not name.endswith(".json"): + continue + try: + rel = entry.relative_to(snapshot_dir).as_posix() + except ValueError: + rel = name + if "/" in rel: + continue # a subfolder index the bare root load never reads + if _safe_is_file(entry): + return True + return False + + # Diffusers pipeline subfolders that carry loadable WEIGHTS (every other declared component -- # scheduler, tokenizer, feature_extractor, processor -- is config-only). A weight-bearing # component whose subfolder exists but holds no weight is a partially fetched component, so the @@ -712,30 +735,22 @@ def snapshot_dir_is_complete( # A sharded checkpoint is loadable locally ONLY through its index sidecar: transformers' # from_pretrained resolves a local directory by probing model.safetensors then # model.safetensors.index.json (then the .bin pair) -- it never globs model-*-of-*.safetensors -- - # so a cache holding every numbered shard but missing model.safetensors.index.json raises - # "no file named ..." or fetches the index in-process over Xet. For a root-level (non-pipeline) - # FULL warm (unpatterned or glob-bearing -- never a deliberate exact single-shard request, which - # wants only that file), require the index sidecar of each root-level numbered-shard set. 
+ # so a cache holding every numbered shard but missing the index raises "no file named ..." or + # fetches the index in-process over Xet. For a root-level (non-pipeline) FULL warm (unpatterned + # or glob-bearing -- never a deliberate exact single-shard request, which wants only that file), + # require a root-level shard index when root numbered shards are present. _has_root_shard_index + # matches the variant form too (model.safetensors.index.fp16.json), so a variant sharded cache -- + # whose shards (model.fp16-00001-of-00002.safetensors) carry the variant in the regex prefix -- + # is not falsely rejected. if root_weights_only and ( allow_patterns is None or any(_has_glob(p) for p in allow_patterns) ): - root_index_names = set() - for index_entry in index_entries: - try: - irel = index_entry.relative_to(snapshot_dir).as_posix() - except ValueError: - irel = index_entry.name - if "/" not in irel: - root_index_names.add(irel) - for _, rel in weight_entries: - if "/" in rel: - continue - shard_match = _NUMBERED_SHARD_RE.match(rel) - if shard_match is None: - continue - index_name = f"{shard_match.group('prefix')}{shard_match.group('suffix')}.index.json" - if index_name not in root_index_names: - return False + has_root_numbered_shard = any( + "/" not in rel and _NUMBERED_SHARD_RE.match(rel) is not None + for _, rel in weight_entries + ) + if has_root_numbered_shard and not _has_root_shard_index(snapshot_dir, entries): + return False # A FULL pipeline warm -- no allow_patterns, a pure catch-all ``["*"]``, or a no-slash file glob # (``["*.safetensors", "*.json"]``) that HF's matcher spreads across every nested component -- From e625b585d46af6acd87eeefd89c86439871637e4 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 28 Jun 2026 10:50:04 +0000 Subject: [PATCH 41/82] Recognize SentencePiece tokenizer globs and pin the parent transport before spawn Address the latest Codex review round: - Recognize a SentencePiece / slow-tokenizer vocab glob as weightless. 
A tokenizer-only warm with a no-slash glob (allow_patterns=["spiece*"], ["sentencepiece*"], ["spm*"]) selects only SentencePiece assets and no weight, but the non-weight probe list had no representative, so the glob read as a weight directory and a tokenizer-only snapshot was wrongly rejected for lacking a weight. Add spiece.model, sentencepiece.bpe.model, spm.model, source.spm, target.spm, bpe.codes, vocab.bpe and normalizer.json alongside the processor / chat-template probes. - Pin huggingface_hub's transport constant in the PARENT before the spawn env window. The spawn briefly sets the child-only HF_HUB_DISABLE_XET=1 in os.environ so the child inherits it at creation (Hub reads it into a module constant at import time). A concurrent thread doing its FIRST import of huggingface_hub inside that window could cache the disabled-Xet value in the parent and silently route later in-process downloads over HTTP. Import huggingface_hub.constants in the parent before mutating the env (under the existing spawn lock) so its transport flags are cached from the real environment first; a concurrent import in the window then re-reads nothing. Tests: spiece* / sentencepiece* / spm* read as weightless while a model* glob stays weight-including. 
--- tests/test_hf_xet_fallback.py | 13 +++++++++++++ unsloth_zoo/hf_cache_state.py | 12 ++++++++++++ unsloth_zoo/hf_xet_fallback.py | 10 ++++++++++ 3 files changed, 35 insertions(+) diff --git a/tests/test_hf_xet_fallback.py b/tests/test_hf_xet_fallback.py index 37e2bc9db..ee475b3dd 100644 --- a/tests/test_hf_xet_fallback.py +++ b/tests/test_hf_xet_fallback.py @@ -1635,6 +1635,19 @@ def test_snapshot_dir_is_complete_diffusion_file_glob_validates_components(tmp_p assert hcs.snapshot_dir_is_complete(snap, allow_patterns = globs) is True +def test_request_can_include_weights_sentencepiece_glob(): + """A SentencePiece / slow-tokenizer vocab glob (["spiece*"], ["sentencepiece*"], ["spm*"]) + selects only tokenizer assets and no weight, so it reads as WEIGHTLESS. Without a representative + the no-slash glob is misread as a weight directory and a tokenizer-only snapshot is wrongly + rejected for lacking a weight (Codex #829).""" + assert hcs.request_can_include_weights(["spiece*"], None) is False + assert hcs.request_can_include_weights(["sentencepiece*"], None) is False + assert hcs.request_can_include_weights(["spm*"], None) is False + assert hcs.request_can_include_weights(["spiece.model", "tokenizer*"], None) is False + # A weight glob is still weight-including (no accept-stale). + assert hcs.request_can_include_weights(["model*"], None) is True + + def test_request_can_include_weights_chat_template_glob(): """A chat-template glob (["chat_template*"]) selects only chat_template.json / .jinja and no weight, so it reads as WEIGHTLESS. 
Without a representative in the non-weight probes the glob is diff --git a/unsloth_zoo/hf_cache_state.py b/unsloth_zoo/hf_cache_state.py index 9f9c478e9..dfc9fc33e 100644 --- a/unsloth_zoo/hf_cache_state.py +++ b/unsloth_zoo/hf_cache_state.py @@ -941,6 +941,18 @@ def _concretize_glob(pattern: str) -> str: "chat_template.json", "chat_template.jinja", "added_tokens.json", + # SentencePiece / slow-tokenizer vocab assets a tokenizer-only warm selects with a no-slash glob + # (allow_patterns=["spiece*"], ["sentencepiece*"], ["spm*"]). Without a representative the glob + # reads as a weight directory and a tokenizer-only snapshot is wrongly rejected for lacking a + # weight. tokenizer.model is already listed above. + "spiece.model", + "sentencepiece.bpe.model", + "spm.model", + "source.spm", + "target.spm", + "bpe.codes", + "vocab.bpe", + "normalizer.json", "vocab.json", "merges.txt", "readme.md", diff --git a/unsloth_zoo/hf_xet_fallback.py b/unsloth_zoo/hf_xet_fallback.py index 2b548e31d..a03861c54 100644 --- a/unsloth_zoo/hf_xet_fallback.py +++ b/unsloth_zoo/hf_xet_fallback.py @@ -788,6 +788,16 @@ def _run_download_attempt( child_env["HF_HUB_DISABLE_XET"] = "1" child_env["HF_HUB_ENABLE_HF_TRANSFER"] = "0" with _SPAWN_ENV_LOCK: + # Cache huggingface_hub's transport constants in the PARENT from the REAL environment NOW, + # before the child-only env (HF_HUB_DISABLE_XET=1) is briefly set below. Hub reads + # HF_HUB_DISABLE_XET into a module constant at import time; without this, a concurrent thread + # doing its FIRST `import huggingface_hub` inside the spawn window could cache the child-only + # disabled-Xet value in the parent and silently route later in-process downloads over HTTP. + # Once imported it is a no-op, so a concurrent import in the window then re-reads nothing. 
+ try: + import huggingface_hub.constants # noqa: F401 + except Exception: + pass saved_env = {k: os.environ.get(k) for k in child_env} # multiprocessing 'spawn' reconstructs __main__ in the child from # __main__.__file__. If that is a pseudo-path ('', a notebook) the From e130ca725af4c99b84dd9a410512803ad373537e Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 28 Jun 2026 11:24:55 +0000 Subject: [PATCH 42/82] Recognize variant weight globs, diffusers weight-only warms, and dataset broken symlinks Address the latest Codex review round (one finding rejected as factually incorrect, three fixed): - Treat a variant weight prefix glob as weight-bearing. allow_patterns=["model.fp16*"] / ["pytorch_model.fp16*"] selects a real variant weight (model.fp16.safetensors) but ends in the wildcard, not a weight suffix, so request_can_include_weights read it as weightless and the fast path could accept a config-only snapshot without the variant weights. _weight_self_probe now also matches a trailing-wildcard glob whose wildcard would absorb a weight suffix, while a clearly non-weight glob (tokenizer*, config*, spiece*) stays weightless. - Recognize a diffusers pipeline from its component weights, not only model_index.json. A weights-only root glob (allow_patterns=["*.safetensors"]) on a diffusers repo downloads the component weights (unet/, vae/, ...) but not model_index.json, so is_pipeline read False, the request was treated as a non-pipeline root load, its subfolder weights were dropped, and the complete download was reported incomplete. is_pipeline is now also true when a recognized weight-bearing component subfolder holds a weight; an arbitrary precision/checkpoint subfolder (BF16/, checkpoint-10/) is still not mistaken for a pipeline. - Keep dataset checkpoint paths in the broken-symlink check. 
snapshot_has_requested_broken_symlinks dropped checkpoint-dir paths for every repo type, so a dataset snapshot with a dangling checkpoint-10/data.parquet symlink read as complete and a broken cache was returned. The checkpoint-dir drop (a root-model-load notion) now applies only to repo_type="model"; a dataset rejects any dangling requested file. Note on the rejected finding: Codex claimed transformers' _add_variant places the variant AFTER the shard count (model-00001-of-00002.fp16.safetensors). Verified against transformers 4.57.6 and 5.12.1 by saving a sharded variant checkpoint: the real form is model.fp16-00001-of-00002.safetensors (variant in the prefix), which _NUMBERED_SHARD_RE already matches, so the sibling-shard and index requirements do apply. No change needed. Tests: model.fp16* / pytorch_model.fp16* read as weight-including while tokenizer* stays weightless; a diffusers *.safetensors warm without model_index.json reads complete while a BF16/-only snapshot stays incomplete; a dangling checkpoint-dir symlink blocks a dataset snapshot but not a model warm. --- tests/test_hf_xet_fallback.py | 52 +++++++++++++++++++++++ unsloth_zoo/hf_cache_state.py | 77 ++++++++++++++++++++++++++++------ unsloth_zoo/hf_xet_fallback.py | 3 +- 3 files changed, 119 insertions(+), 13 deletions(-) diff --git a/tests/test_hf_xet_fallback.py b/tests/test_hf_xet_fallback.py index ee475b3dd..22e7c074f 100644 --- a/tests/test_hf_xet_fallback.py +++ b/tests/test_hf_xet_fallback.py @@ -1648,6 +1648,58 @@ def test_request_can_include_weights_sentencepiece_glob(): assert hcs.request_can_include_weights(["model*"], None) is True +def test_request_can_include_weights_variant_prefix_glob(): + """A variant weight prefix glob (["model.fp16*"], ["pytorch_model.fp16*"]) selects a real variant + weight (model.fp16.safetensors) even though the pattern ends in the wildcard, not a weight + suffix. 
It must read as weight-including so the fast path does not accept a config-only snapshot + without the requested variant weights (Codex #829). A metadata prefix glob stays weightless.""" + assert hcs.request_can_include_weights(["model.fp16*"], None) is True + assert hcs.request_can_include_weights(["pytorch_model.fp16*"], None) is True + assert hcs.request_can_include_weights(["model.bf16*"], None) is True + # Metadata / non-weight prefix globs stay weightless (no over-reject of a tokenizer-only warm). + assert hcs.request_can_include_weights(["tokenizer*"], None) is False + assert hcs.request_can_include_weights(["config*"], None) is False + assert hcs.request_can_include_weights(["spiece*"], None) is False + + +def test_snapshot_dir_is_complete_diffusion_weights_only_glob(tmp_path): + """A weights-only root glob (["*.safetensors"]) on a diffusers repo downloads the component + weights (unet/, vae/) but NOT model_index.json, so the pipeline layout must be recognized from + the component weights themselves -- otherwise the snapshot is misread as a non-pipeline root + load, its subfolder weights dropped, and a complete download reported incomplete (Codex #829). + A genuine non-pipeline subfolder-only snapshot (BF16/) is still rejected for a root glob.""" + snap = tmp_path / "snap" + snap.mkdir() + blob = tmp_path / "blob" + blob.write_bytes(b"w") + _make_diffusion_component(snap, blob, "unet", "diffusion_pytorch_model.safetensors") + _make_diffusion_component(snap, blob, "vae", "diffusion_pytorch_model.safetensors") + # No model_index.json on disk (a *.safetensors warm would not have selected it). + assert hcs.snapshot_dir_is_complete(snap, allow_patterns = ["*.safetensors"]) is True + # A precision-variant subfolder (not a pipeline component) must NOT be mistaken for a pipeline. 
+ snap2 = tmp_path / "snap2" + (snap2 / "BF16").mkdir(parents = True) + (snap2 / "BF16" / "model.safetensors").symlink_to(blob) + assert hcs.snapshot_dir_is_complete(snap2, allow_patterns = ["*.safetensors"]) is False + + +def test_snapshot_has_requested_broken_symlinks_dataset_vs_model(tmp_path): + """A dangling checkpoint-dir symlink blocks a DATASET snapshot (every requested file matters) + but not a MODEL root warm (a root load never reads a checkpoint-dir file), so the checkpoint-dir + drop applies only to models (Codex #829).""" + snap = tmp_path / "snap" + (snap / "checkpoint-10").mkdir(parents = True) + (snap / "checkpoint-10" / "data.parquet").symlink_to(tmp_path / "missing") # dangling + assert hcs.snapshot_has_requested_broken_symlinks(snap, repo_type = "model") is False + assert hcs.snapshot_has_requested_broken_symlinks(snap, repo_type = "dataset") is True + # A dangling ROOT file blocks both (it is a requested interrupted file in either case). + snap2 = tmp_path / "snap2" + snap2.mkdir() + (snap2 / "data.parquet").symlink_to(tmp_path / "missing") + assert hcs.snapshot_has_requested_broken_symlinks(snap2, repo_type = "model") is True + assert hcs.snapshot_has_requested_broken_symlinks(snap2, repo_type = "dataset") is True + + def test_request_can_include_weights_chat_template_glob(): """A chat-template glob (["chat_template*"]) selects only chat_template.json / .jinja and no weight, so it reads as WEIGHTLESS. Without a representative in the non-weight probes the glob is diff --git a/unsloth_zoo/hf_cache_state.py b/unsloth_zoo/hf_cache_state.py index dfc9fc33e..118d57531 100644 --- a/unsloth_zoo/hf_cache_state.py +++ b/unsloth_zoo/hf_cache_state.py @@ -457,6 +457,30 @@ def _diffusion_pipeline_complete(snapshot_dir: Path, weight_dirs: set) -> bool: return True +def _has_pipeline_component_weight(snapshot_dir: Path, entries: list) -> bool: + """True if a recognized diffusers component subfolder (``unet/``, ``vae/``, ``text_encoder/``, + ...) 
holds a loadable weight. A weights-only warm (``allow_patterns=["*.safetensors"]``, or an + ignore list that drops ``model_index.json``) downloads a diffusers pipeline's component weights + but NOT ``model_index.json`` -- so the pipeline layout has to be recognized from the component + weights themselves, else the snapshot is misread as a non-pipeline root load and its (legitimate) + subfolder weights are dropped. Kept to the known weight-bearing component names so an arbitrary + precision/checkpoint subfolder (``BF16/``, ``checkpoint-10/``) is NOT mistaken for a pipeline.""" + for entry in entries: + try: + rel = entry.relative_to(snapshot_dir).as_posix() + except ValueError: + continue + parts = rel.split("/") + if ( + len(parts) >= 2 + and parts[0] in _WEIGHT_BEARING_PIPELINE_DIRS + and _is_loadable_weight_file(parts[-1]) + and _safe_is_file(entry) + ): + return True + return False + + def _broken_symlink_rel_paths(snapshot_dir: Path) -> list: """Repo-relative posix paths of every dangling symlink in *snapshot_dir* -- a referenced file whose blob is missing or still an ``.incomplete`` partial (an interrupted download). Empty when @@ -546,18 +570,31 @@ def snapshot_has_requested_broken_symlinks( *, allow_patterns: "Optional[object]" = None, ignore_patterns: "Optional[object]" = None, + repo_type: "Optional[str]" = "model", ) -> bool: """True iff a dangling symlink in *snapshot_dir* is for a file the request actually SELECTS. A dangling symlink marks an interrupted download, but for a scoped request only one for a requested file should reject the snapshot: a dangling root ``model.safetensors`` left by an earlier interrupted pull must not fail a weightless ``allow_patterns=["config.json"]`` request - whose config is on disk. Mirrors the scoped broken-symlink handling inside - ``snapshot_dir_is_complete`` so the weightless / non-model path is scoped the same way.""" + whose config is on disk. 
+ + For a MODEL repo the scoping mirrors ``snapshot_dir_is_complete``: a root load reads only root + files, so a dangling checkpoint-dir / subfolder symlink does not block it. A DATASET (or other + non-model) snapshot has no "root load reads only root files" notion -- every dangling symlink + for a selected path is an interrupted file that must reject the cache (e.g. a dangling + ``checkpoint-10/data.parquet`` in an unpatterned dataset pull), so the checkpoint-dir drop is + NOT applied there.""" allow_patterns = _as_pattern_list(allow_patterns) ignore_patterns = _as_pattern_list(ignore_patterns) broken = _broken_symlink_rel_paths(snapshot_dir) - return bool(broken and _requested_scope_filter(broken, allow_patterns, ignore_patterns)) + if not broken: + return False + if repo_type == "model": + requested = _requested_scope_filter(broken, allow_patterns, ignore_patterns) + else: + requested = _filter_paths(broken, allow_patterns, ignore_patterns) + return bool(requested) def snapshot_dir_is_complete( @@ -610,7 +647,9 @@ def snapshot_dir_is_complete( # load never reads. A diffusers pipeline is the exception -- its component weights legitimately # live in subfolders (unet/, vae/, text_encoder/), validated by _diffusion_pipeline_complete -- # so it keeps the (narrower) checkpoint-dir scoping instead. 
- is_pipeline = _safe_is_file(snapshot_dir / "model_index.json") + is_pipeline = _safe_is_file(snapshot_dir / "model_index.json") or _has_pipeline_component_weight( + snapshot_dir, entries + ) root_weights_only = _targets_root_only(allow_patterns) and not is_pipeline # A dangling symlink marks an interrupted download, but only one for a file the request @@ -1018,14 +1057,28 @@ def _weight_self_probe(pattern: str) -> "Optional[str]": suffix is not a weight suffix, or when the (concretized) basename is a known trainer / optimizer artifact (``optimizer.pt``, ``training_args.bin``, ``rng_state_*.pth``): those carry weight suffixes but the snapshot completeness check filters them out as non-weights, - so classifying such a request as weight-including would loop the guarded download.""" - if not pattern.lower().endswith(_WEIGHT_FILE_SUFFIXES): - return None - concrete = _concretize_glob(pattern) - basename = concrete.rsplit("/", 1)[-1] - if not _is_loadable_weight_file(basename): - return None - return concrete + so classifying such a request as weight-including would loop the guarded download. + + Also recognizes a trailing-wildcard variant glob whose wildcard would ABSORB the weight suffix + (``model.fp16*`` -> ``model.fp16.safetensors``, ``pytorch_model.fp16*`` -> + ``pytorch_model.fp16.bin``): such a glob does not literally end in a weight suffix, but it does + select a real variant weight, so it must read as weight-including. A clearly non-weight glob + (``tokenizer*``, ``config*``, ``spiece*``) is excluded so a metadata-only warm stays weightless.""" + if pattern.lower().endswith(_WEIGHT_FILE_SUFFIXES): + concrete = _concretize_glob(pattern) + basename = concrete.rsplit("/", 1)[-1] + if not _is_loadable_weight_file(basename): + return None + return concrete + # Trailing-wildcard glob: try each weight suffix in place of the absorbing wildcard. 
+ if "/" not in pattern and pattern.endswith(("*", "?")) and not _basename_is_non_weight(pattern): + stem = _concretize_glob(pattern.rstrip("*?")) + if stem: + for suffix in _WEIGHT_FILE_SUFFIXES: + candidate = stem + suffix + if fnmatch.fnmatchcase(candidate, pattern) and _is_loadable_weight_file(candidate): + return candidate + return None def request_can_include_weights( diff --git a/unsloth_zoo/hf_xet_fallback.py b/unsloth_zoo/hf_xet_fallback.py index a03861c54..2e52a32d6 100644 --- a/unsloth_zoo/hf_xet_fallback.py +++ b/unsloth_zoo/hf_xet_fallback.py @@ -990,7 +990,8 @@ def _snapshot_is_acceptable( # hand back a config-only snapshot dir that simply lacks the requested file. Globs stay best-effort. return ( not snapshot_has_requested_broken_symlinks( - snapshot_dir, allow_patterns = allow_patterns, ignore_patterns = ignore_patterns + snapshot_dir, allow_patterns = allow_patterns, ignore_patterns = ignore_patterns, + repo_type = repo_type, ) and requested_named_files_present( snapshot_dir, allow_patterns = allow_patterns, ignore_patterns = ignore_patterns From 573b97872f56acdf478849416b377d6ffbed3bff Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 28 Jun 2026 11:50:33 +0000 Subject: [PATCH 43/82] Harden the shared completeness check and spawn against the 3-reviewer round Address the findings from three independent deep reviews of the shared helper. All are real edges of the shared cache-state / spawn core (none reachable through unsloth main's current wiring, but each is a genuine correctness or resource bug for a direct caller). One finding was rejected as factually incorrect (see the note below). Completeness check (hf_cache_state.py): - Recognize the variant shard-index sidecar everywhere. A new _is_weight_shard_index treats both the canonical model.safetensors.index.json / pytorch_model.bin.index.json and the variant form model.safetensors.index.fp16.json (transformers' _add_variant inserts the variant token before the trailing .json) as a weight index. 
The old ".index." + ".json" / suffix tests missed the variant form, so a variant index went uncollected and its listed shards were left unvalidated. Used at both the index-entry collection and _has_root_shard_index. - Require a diffusers component's own index when it is sharded. A weight-bearing pipeline component (unet/, transformer/, ...) holding a complete numbered-shard set but no diffusion_pytorch_model.safetensors.index.json is now incomplete: diffusers, like transformers, cannot load a local sharded component without its index, so accepting it would warm a cache the in-process load then re-fetches over unprotected Xet. Mirrors the root index requirement. - Guard the trailing-wildcard weight self-probe. The absorbed-suffix probe (model.fp16* -> model.fp16.safetensors) now skips a stem that is a trainer artifact (scheduler*, rng_state*, scaler*, optimizer*) or an index / weight sidecar (model.safetensors.index*, model.safetensors*), so it no longer synthesizes a fake scheduler.safetensors / model.safetensors.index.safetensors weight and over-rejects an artifact-only or index-only warm. A real model.safetensors* glob stays weight-including through the canonical probe. - Re-root the self-probe under a subfolder. A path-qualified variant glob (unet/diffusion_pytorch_model.fp16*, text_encoder/model.fp16*) now resolves to its concrete weight under that directory, so a variant-component warm is read as weight-including instead of weightless. Previously such a request could accept a config-only component snapshot and let the load hang on the missing weights (accept-stale). Spawn / transport (hf_xet_fallback.py): - Close the result queue if the spawn fails. proc.start() can raise (OSError "can't start new process" under fd / thread exhaustion). The queue's OS pipe fds are allocated before the spawn, but the lifecycle try/finally that closes them is only entered after a successful start, so a failed spawn leaked the fds. 
A dedicated except around the spawn now closes the queue and re-raises, making the failure deterministic. - Read xet_force_disabled() under the spawn lock. The live os.environ["HF_HUB_DISABLE_XET"] read is now serialized against the window in which a concurrent download briefly sets that var in the parent env around its own spawn, so a download can no longer observe the other's child-only value and wrongly force itself onto HTTP from the first attempt. Note on the rejected finding: one reviewer claimed transformers' _add_variant places the variant after the shard count (model-00001-of-00002.fp16.safetensors). Verified against transformers 4.57.6 and 5.12.1 by saving a sharded variant checkpoint: the real form is model.fp16-00001-of-00002.safetensors (variant in the prefix), which _NUMBERED_SHARD_RE already matches. No change made. Tests (+6): subfolder variant component glob reads weight-including while a non-weight subfolder glob stays weightless; an index glob stays weightless while a weight glob is kept; the self-probe returns no synthetic weight for artifact / sidecar stems; a sharded diffusers component requires its index; a failed spawn closes the result queue; the disable-Xet read happens under the spawn lock. 149 pass; the 40k-layout safety-invariant fuzz stays at 0 violations. 
--- tests/test_hf_xet_fallback.py | 162 +++++++++++++++++++++++++++++++++ unsloth_zoo/hf_cache_state.py | 77 +++++++++++++--- unsloth_zoo/hf_xet_fallback.py | 21 ++++- 3 files changed, 244 insertions(+), 16 deletions(-) diff --git a/tests/test_hf_xet_fallback.py b/tests/test_hf_xet_fallback.py index 22e7c074f..8d6daa76c 100644 --- a/tests/test_hf_xet_fallback.py +++ b/tests/test_hf_xet_fallback.py @@ -2856,3 +2856,165 @@ def test_processor_glob_is_weightless(): assert hcs.request_can_include_weights(allow_patterns = ["processor*"]) is False # control: a real weight glob stays weight-including assert hcs.request_can_include_weights(allow_patterns = ["model*"]) is True + + +# --------------------------------------------------------------------------- +# Regression tests for the 3-reviewer (Opus) review round on #829. +# Each maps to one accepted finding; the rejected "variant shard after the count" +# finding is covered (and disproven) by test_snapshot_dir_is_complete_variant_sharded_index. +# --------------------------------------------------------------------------- +def test_request_can_include_weights_subfolder_variant_component_glob(): + """R3-4: a path-qualified variant component glob (["unet/diffusion_pytorch_model.fp16*"], + ["text_encoder/model.fp16*"]) selects a real variant weight inside a pipeline subfolder. It does + not end in a weight suffix and the re-rooted canonical probes (unet/model.safetensors, ...) do not + match the variant name, so without a path-qualified self-probe it would read as WEIGHTLESS and the + fast path could accept a config-only component snapshot -> the silent Xet hang (accept-stale). It + must read as weight-including. 
A non-weight subfolder glob (unet/config*) stays weightless.""" + assert hcs.request_can_include_weights(["unet/diffusion_pytorch_model.fp16*"], None) is True + assert hcs.request_can_include_weights(["text_encoder/model.fp16*"], None) is True + assert hcs.request_can_include_weights(["transformer/diffusion_pytorch_model.bf16*"], None) is True + # A non-weight subfolder glob is not over-classified. + assert hcs.request_can_include_weights(["unet/config*"], None) is False + assert hcs.request_can_include_weights(["text_encoder/tokenizer*"], None) is False + # The self-probe re-roots the synthetic weight under the requested subfolder. + assert hcs._weight_self_probe("unet/diffusion_pytorch_model.fp16*") == \ + "unet/diffusion_pytorch_model.fp16.safetensors" + + +def test_request_can_include_weights_index_glob_weightless_weight_glob_kept(): + """R3-2: a trailing-wildcard shard-index glob (["model.safetensors.index*"], + ["model.bin.index*"]) selects only the index sidecar (no weight), so it reads as WEIGHTLESS -- the + synthetic-suffix branch must not turn the .index stem into a fake model.safetensors.index.safetensors + weight. 
A plain weight-stem glob (["model.safetensors*"], ["pytorch_model.bin*"]) still reads as + weight-including via the canonical weight probe (no accept-stale).""" + assert hcs.request_can_include_weights(["model.safetensors.index*"], None) is False + assert hcs.request_can_include_weights(["model.bin.index*"], None) is False + assert hcs.request_can_include_weights(["model.safetensors*"], None) is True + assert hcs.request_can_include_weights(["pytorch_model.bin*"], None) is True + + +def test_weight_self_probe_artifact_and_sidecar_stems(): + """R3-3: _weight_self_probe must not synthesize a fake weight for a trailing-wildcard glob whose + stem is a trainer artifact (scheduler*, rng_state*, scaler*, optimizer*) or an index / weight + sidecar (model.safetensors.index*, model.safetensors*) -- the absorbed-suffix branch would + otherwise return scheduler.safetensors / model.safetensors.index.safetensors and make the request + require a weight that does not exist (over-reject). A genuine variant weight stem still resolves.""" + for artifact in ("scheduler*", "rng_state*", "rng_state_*", "scaler*", "optimizer*", "training_args*"): + assert hcs._weight_self_probe(artifact) is None, artifact + for sidecar in ("model.safetensors.index*", "model.bin.index*", "model.safetensors*", + "pytorch_model.bin*"): + assert hcs._weight_self_probe(sidecar) is None, sidecar + # A real variant weight stem still resolves to its concrete weight name (the suffix loop tries + # .safetensors first, so a stem that admits either suffix resolves to the safetensors form). 
+ assert hcs._weight_self_probe("model.fp16*") == "model.fp16.safetensors" + assert hcs._weight_self_probe("pytorch_model.fp16*") == "pytorch_model.fp16.safetensors" + + +def test_snapshot_dir_is_complete_diffusion_sharded_component_requires_index(tmp_path): + """R1: a diffusers pipeline whose weight-bearing component (transformer/, unet/) holds a complete + NUMBERED-shard set but no index sidecar is INCOMPLETE -- transformers cannot load a local sharded + component without its index, so reporting complete would warm a cache the in-process load then + re-fetches (the silent Xet hang). Adding the component index makes it complete.""" + snap = tmp_path / "snap" + snap.mkdir() + blob = tmp_path / "blob" + blob.write_bytes(b"w") + (snap / "model_index.json").write_text( + json.dumps({"_class_name": "FluxPipeline", + "transformer": ["diffusers", "FluxTransformer2DModel"]}) + ) + comp = _make_diffusion_component(snap, blob, "transformer") + (comp / "diffusion_pytorch_model-00001-of-00002.safetensors").symlink_to(blob) + (comp / "diffusion_pytorch_model-00002-of-00002.safetensors").symlink_to(blob) + # Complete shard set, NO index -> incomplete. + assert hcs.snapshot_dir_is_complete(snap) is False + (comp / "diffusion_pytorch_model.safetensors.index.json").write_text( + json.dumps({"weight_map": {"a": "diffusion_pytorch_model-00001-of-00002.safetensors", + "b": "diffusion_pytorch_model-00002-of-00002.safetensors"}}) + ) + assert hcs.snapshot_dir_is_complete(snap) is True + + +def test_failed_spawn_closes_result_queue(monkeypatch): + """R2-2: if proc.start() raises (e.g. OSError "can't start new process" under fd / thread + exhaustion), the result_queue's OS pipe fds -- allocated before the spawn -- must be closed + rather than leaked. 
The lifecycle try/finally that closes them is only entered after a + successful start, so a dedicated except around the spawn must close the queue and re-raise.""" + closed = {"cancel_join": False, "close": False} + + class _FakeQueue: + def cancel_join_thread(self): + closed["cancel_join"] = True + + def close(self): + closed["close"] = True + + class _FakeProc: + def __init__(self, *a, **k): + self.pid = None + + def start(self): + raise OSError(errno.EAGAIN, "Resource temporarily unavailable") + + class _FakeCtx: + def Queue(self): + return _FakeQueue() + + def Process(self, *a, **k): + return _FakeProc() + + monkeypatch.setattr(xf, "_CTX", _FakeCtx()) + with pytest.raises(OSError): + xf._run_download_attempt( + "owner/repo", + kind = "snapshot", + params = {}, + token = None, + repo_type = "model", + disable_xet = False, + cancel_event = None, + stall_timeout = 1.0, + interval = 0.1, + grace_period = 0.1, + on_status = None, + ) + assert closed["close"] is True + assert closed["cancel_join"] is True + + +def test_disable_xet_read_under_spawn_lock(monkeypatch): + """R2-1: _download_with_xet_fallback must read xet_force_disabled() while holding + _SPAWN_ENV_LOCK. A concurrent download briefly sets the child-only HF_HUB_DISABLE_XET=1 in the + parent os.environ around its spawn (under the same lock); reading the live env outside the lock + could observe that value and wrongly force THIS download onto HTTP from the first attempt.""" + seen = {} + real = xf.xet_force_disabled + + def _spy(): + # A plain (non-reentrant) Lock cannot be re-acquired by its owner, so a non-blocking acquire + # FAILS iff the read is happening inside `with _SPAWN_ENV_LOCK:`. If the read were outside the + # lock, the acquire would succeed. 
+ got = xf._SPAWN_ENV_LOCK.acquire(blocking = False) + if got: + xf._SPAWN_ENV_LOCK.release() + seen["held"] = not got + return real() + + monkeypatch.setattr(xf, "xet_force_disabled", _spy) + monkeypatch.setattr(xf, "_run_download_attempt", lambda *a, **k: ("ok", "/tmp/warm")) + out = xf._download_with_xet_fallback( + repo_id = "owner/repo", + label = "test", + kind = "snapshot", + params = {}, + token = None, + repo_type = "model", + cancel_event = None, + stall_timeout = 1.0, + interval = 0.1, + grace_period = 0.1, + on_status = None, + prepare_for_http_fn = None, + ) + assert out == "/tmp/warm" + assert seen.get("held") is True diff --git a/unsloth_zoo/hf_cache_state.py b/unsloth_zoo/hf_cache_state.py index 118d57531..25f743252 100644 --- a/unsloth_zoo/hf_cache_state.py +++ b/unsloth_zoo/hf_cache_state.py @@ -231,6 +231,11 @@ def snapshot_dir_has_broken_symlinks(snapshot_dir: Path) -> bool: }) # Distributed trainer runs shard the RNG state as rng_state_0.pth, rng_state_1.pth, ... _NON_WEIGHT_BASENAME_PREFIXES = ("rng_state_",) +# The stems (basename without the suffix) of the trainer-artifact names above -- optimizer, +# scheduler, scaler, rng_state, training_args. A trailing-wildcard glob over one of these +# (``scheduler*``, ``rng_state*``) selects only a trainer artifact, so the synthetic-weight-suffix +# probe must NOT classify it as weight-including. +_NON_WEIGHT_STEMS = frozenset(name.rsplit(".", 1)[0] for name in _NON_WEIGHT_BASENAMES) def _is_loadable_weight_file(name: str) -> bool: @@ -329,6 +334,15 @@ def _numbered_shard_set_present( return True +def _is_weight_shard_index(name: str) -> bool: + """True if *name* is a weight-shard index sidecar: the canonical + ``model.safetensors.index.json`` / ``pytorch_model.bin.index.json`` AND the variant form + ``model.safetensors.index.fp16.json`` (transformers' ``_add_variant`` inserts the variant token + before the trailing ``.json``). 
A plain ``*.safetensors.index.json`` suffix test misses the + variant form, leaving its listed shards unvalidated.""" + return name.endswith(".json") and (".safetensors.index." in name or ".bin.index." in name) + + def _weight_shard_index_complete(index_path: Path) -> bool: """True if every shard a HF weight index (``model.safetensors.index.json`` / ``pytorch_model.bin.index.json``) lists is present next to the index. An unreadable @@ -367,13 +381,12 @@ def _has_root_shard_index(snapshot_dir: Path, entries: list) -> bool: be wrongly judged index-less. A subfolder index does not count -- a bare root load never reads it. *entries* is the already-collected ``rglob`` listing, reused to avoid a second walk.""" for entry in entries: - name = entry.name - if ".index." not in name or not name.endswith(".json"): + if not _is_weight_shard_index(entry.name): continue try: rel = entry.relative_to(snapshot_dir).as_posix() except ValueError: - rel = name + rel = entry.name if "/" in rel: continue # a subfolder index the bare root load never reads if _safe_is_file(entry): @@ -452,8 +465,26 @@ def _diffusion_pipeline_complete(snapshot_dir: Path, weight_dirs: set) -> bool: component_dir = snapshot_dir / key if not _safe_is_dir(component_dir) or not _dir_has_any_file(component_dir): return False # a declared component's subfolder is missing / empty -- interrupted warm - if key in _WEIGHT_BEARING_PIPELINE_DIRS and key not in weight_dirs: - return False # the component dir exists but carries no weight -- partial component + if key in _WEIGHT_BEARING_PIPELINE_DIRS: + if key not in weight_dirs: + return False # the component dir exists but carries no weight -- partial component + # A SHARDED component is loadable locally only via its in-component index sidecar + # (diffusers, like transformers, never globs shard files), so a component holding + # numbered shards but no diffusion_pytorch_model.safetensors.index.json is incomplete -- + # an interrupted warm that dropped the tiny index 
blob would make the pipeline load fetch + # it in-process over unprotected Xet. Mirrors the root transformers index requirement. + try: + comp_entries = list(component_dir.rglob("*")) + except OSError: + comp_entries = [] + has_numbered_shard = any( + _NUMBERED_SHARD_RE.match(e.name) is not None + and _is_loadable_weight_file(e.name) + and _safe_is_file(e) + for e in comp_entries + ) + if has_numbered_shard and not _has_root_shard_index(component_dir, comp_entries): + return False return True @@ -668,7 +699,7 @@ def snapshot_dir_is_complete( weight_entries: list = [] # (entry, repo-relative path) for entry in entries: name = entry.name - if name.endswith((".safetensors.index.json", ".bin.index.json")): + if _is_weight_shard_index(name): if _safe_is_file(entry): index_entries.append(entry) elif _is_loadable_weight_file(name) and _safe_is_file(entry): @@ -1070,14 +1101,32 @@ def _weight_self_probe(pattern: str) -> "Optional[str]": if not _is_loadable_weight_file(basename): return None return concrete - # Trailing-wildcard glob: try each weight suffix in place of the absorbing wildcard. - if "/" not in pattern and pattern.endswith(("*", "?")) and not _basename_is_non_weight(pattern): - stem = _concretize_glob(pattern.rstrip("*?")) - if stem: - for suffix in _WEIGHT_FILE_SUFFIXES: - candidate = stem + suffix - if fnmatch.fnmatchcase(candidate, pattern) and _is_loadable_weight_file(candidate): - return candidate + # Trailing-wildcard glob whose wildcard ABSORBS the weight suffix (``model.fp16*`` -> + # ``model.fp16.safetensors``, ``unet/diffusion_pytorch_model.fp16*`` -> the same re-rooted under + # ``unet/``): try each weight suffix in place of the trailing wildcard. Applies to a path-qualified + # basename too, so a subfolder variant-component glob is not misread as weightless. 
+ if pattern.endswith(("*", "?")): + prefix, _, base = pattern.rpartition("/") + if not _basename_is_non_weight(base): + stem = _concretize_glob(base.rstrip("*?")) + stem_lower = stem.lower() + # A stem that is itself a trainer artifact (scheduler*, rng_state*, optimizer*) selects no + # weight; a stem already ending in ``.index`` or a weight suffix is an index sidecar / + # canonical-probe case the synthetic suffix would only corrupt (model.safetensors.index*, + # model.safetensors*). Skip both so the request is not over-classified weight-including. + is_artifact_stem = ( + stem_lower in _NON_WEIGHT_STEMS + or stem_lower.startswith(_NON_WEIGHT_BASENAME_PREFIXES) + ) + is_sidecar_stem = stem_lower.endswith(".index") or stem_lower.endswith(_WEIGHT_FILE_SUFFIXES) + if stem and not is_artifact_stem and not is_sidecar_stem: + for suffix in _WEIGHT_FILE_SUFFIXES: + candidate_base = stem + suffix + if fnmatch.fnmatchcase(candidate_base, base) and _is_loadable_weight_file(candidate_base): + if not prefix: + return candidate_base + concrete_prefix = _concretize_glob(prefix) if _has_glob(prefix) else prefix + return f"{concrete_prefix}/{candidate_base}" return None diff --git a/unsloth_zoo/hf_xet_fallback.py b/unsloth_zoo/hf_xet_fallback.py index 2e52a32d6..f11edbce7 100644 --- a/unsloth_zoo/hf_xet_fallback.py +++ b/unsloth_zoo/hf_xet_fallback.py @@ -826,6 +826,18 @@ def _run_download_attempt( try: os.environ.update(child_env) proc.start() + except BaseException: + # proc.start() can raise (e.g. OSError "can't start new process" under fd / + # thread exhaustion). The result_queue's OS pipe fds were allocated above, but + # the lifecycle try/finally that closes them is only entered AFTER a successful + # start, so on a failed spawn that cleanup never runs and the fds leak. Close + # the queue here so a failed spawn is deterministic rather than fd-leaking. 
+ try: + result_queue.cancel_join_thread() + result_queue.close() + except Exception: + pass + raise finally: for k, v in saved_env.items(): if v is None: @@ -1051,8 +1063,13 @@ def _download_with_xet_fallback( raise RuntimeError("Cancelled") cache_dir = params.get("cache_dir") - # The Unsloth/HF knobs can force HTTP from the very first attempt. - disable_xet = xet_force_disabled() + # The Unsloth/HF knobs can force HTTP from the very first attempt. xet_force_disabled() reads + # os.environ["HF_HUB_DISABLE_XET"] live, and a CONCURRENT download briefly sets that var in the + # parent env around its spawn (under _SPAWN_ENV_LOCK) so its child inherits it. Read under the + # same lock so this download cannot observe the other's child-only value and wrongly force itself + # onto HTTP from the start. + with _SPAWN_ENV_LOCK: + disable_xet = xet_force_disabled() for attempt in range(2): if disable_xet: From c288957c87974b0dc308863f2d6a8ce22b95b9ef Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 28 Jun 2026 13:56:27 +0000 Subject: [PATCH 44/82] Trim the cache-completeness checker to a conservative canonical-only fast-path hf_cache_state.py is only a fast-path optimization: it decides whether the local cache is already warm enough to SKIP the protective spawn child and load in-process. It had grown to 1355 lines of layout classification (variant shards, diffusers pipelines, glob/pattern weight inference, subfolder probes) because each review round added another exotic case. None of that is necessary: the watched snapshot_download child already does the authoritative manifest-vs-cache compare and resume, so the local gate only has to recognize the unambiguous common case and defer everything else to the child. Rewrite snapshot_dir_is_complete to a small conservative gate. 
It returns complete only for an unpatterned model request (allow_patterns is None) that is not a diffusers pipeline, has no dangling symlink, and whose canonical root weights are present: a root model.safetensors / pytorch_model.bin single file, or a root *.safetensors.index.json / *.bin.index.json whose every listed shard is on disk. Ignore patterns need no eligibility gate -- the canonical-weight presence check verifies what the in-process load actually reads is on disk, so the common bare from_pretrained ignore list (the *.onnx / *.gguf / *.pt / *.bin format excludes plus the */*.safetensors subdir drops) keeps the warm cache fast-path eligible. Everything else (weight variants, diffusers, datasets, any allow pattern, numbered shards without an index) returns False and is handled by the watched child. A false complete is the only dangerous error (it lets the in-process load fetch a missing weight over un-killable Xet); a false not-complete only spawns the cheap child, so the gate errs that way. request_can_include_weights is reduced to a small per-pattern selector (a wildcard / weight-suffix / directory basename can select a weight; a concrete non-weight name cannot), and an ignore-only request reads weightless only when it strips every weight format -- a partial strip stays weight-bearing, so the gate never skips the child on a cache whose surviving weight is a variant or a non-safetensors format. A tokenizer or config allow list stays weightless so its offline short-circuit survives. Split the acceptance wiring in hf_xet_fallback.py, because the check is used at two sites with opposite failure modes: - Pre-download (_cache_can_skip_download): may skip the child only when the conservative gate proves a complete canonical model cache, or, for a weightless / non-model request, when the requested subset is intact. A false accept here would hang the in-process load. 
- Post-download (_download_result_usable): after the child's snapshot_download (which already did the authoritative resume), accept the result unless there is positive breakage evidence -- a dangling requested symlink, or a weight-bearing model warm that came back with no weight at all. Lenient on purpose so a finished diffusers / variant / either-format download is not rejected and re-looped into a DownloadStallError. The watchdog, spawn child, process-group termination, and HTTP-retry loop are unchanged -- this only simplifies the cache-state optimization that decides whether the safety mechanism can be skipped. The public surface is preserved: hf_cache_state.__all__, the names hf_xet_fallback imports, and the snapshot_download_with_xet_fallback / hf_hub_download_with_xet_fallback / DownloadStallError / start_watchdog entrypoints the unsloth-main wiring and the Studio shim consume. unsloth's common bare from_pretrained warm (allow=None plus the subdir-scoped */*.safetensors ignores) stays fast-path eligible, so warm and offline loads are unaffected. Tests: drop the exotic-layout completeness tests; keep the watchdog / spawn / transport / stall recovery suite; add focused tests for the canonical gate (single-file, sharded-with-index, rejects config-only / diffusers / patterned / sharded-without-index / broken-symlink, eligible under subdir-scoped ignores), the trimmed request_can_include_weights, and the pre/post-download split (a diffusers warm is not fast-pathed but its complete result is accepted; a config-only model result is rejected). hf_cache_state.py drops from 1355 to ~665 lines. 
--- tests/test_hf_xet_fallback.py | 1212 +++++--------------------------- unsloth_zoo/hf_cache_state.py | 1055 +++++---------------------- unsloth_zoo/hf_xet_fallback.py | 143 ++-- 3 files changed, 460 insertions(+), 1950 deletions(-) diff --git a/tests/test_hf_xet_fallback.py b/tests/test_hf_xet_fallback.py index 8d6daa76c..bb0632730 100644 --- a/tests/test_hf_xet_fallback.py +++ b/tests/test_hf_xet_fallback.py @@ -1298,93 +1298,6 @@ def test_snapshot_dir_is_complete_unit(tmp_path): assert hcs.snapshot_dir_is_complete(snap) is True -def _make_diffusion_component(snap, blob, name, weight_filename = None): - """Create a diffusers pipeline subfolder with a config and (optionally) a weight symlink.""" - comp = snap / name - comp.mkdir() - (comp / "config.json").write_text("{}") - if weight_filename is not None: - (comp / weight_filename).symlink_to(blob) - return comp - - -def test_snapshot_dir_is_complete_diffusion_missing_component(tmp_path): - """A full pipeline warm killed with one component absent reads as incomplete. 
model_index.json - declares unet / vae / text_encoder; the snapshot warmed unet + vae but never started - text_encoder, so the in-process pipeline load would fetch it over unprotected Xet.""" - snap = tmp_path / "snap" - snap.mkdir() - blob = tmp_path / "blob" - blob.write_bytes(b"weights") - (snap / "model_index.json").write_text( - json.dumps( - { - "_class_name": "StableDiffusionPipeline", - "_diffusers_version": "0.30.0", - "unet": ["diffusers", "UNet2DConditionModel"], - "vae": ["diffusers", "AutoencoderKL"], - "text_encoder": ["transformers", "CLIPTextModel"], - "scheduler": ["diffusers", "PNDMScheduler"], - "safety_checker": [None, None], - } - ) - ) - _make_diffusion_component(snap, blob, "unet", "diffusion_pytorch_model.safetensors") - _make_diffusion_component(snap, blob, "vae", "diffusion_pytorch_model.safetensors") - (snap / "scheduler").mkdir() - (snap / "scheduler" / "scheduler_config.json").write_text("{}") - # text_encoder subfolder never created -> interrupted pipeline warm - assert hcs.snapshot_dir_is_complete(snap) is False - # Once the missing component is on disk the pipeline reads complete. 
- _make_diffusion_component(snap, blob, "text_encoder", "model.safetensors") - assert hcs.snapshot_dir_is_complete(snap) is True - - -def test_snapshot_dir_is_complete_diffusion_partial_weight_component(tmp_path): - """A weight-bearing component whose subfolder holds only a config (no weight) reads as - incomplete: the component started but its weight was never fetched.""" - snap = tmp_path / "snap" - snap.mkdir() - blob = tmp_path / "blob" - blob.write_bytes(b"weights") - (snap / "model_index.json").write_text( - json.dumps( - { - "_class_name": "StableDiffusionPipeline", - "unet": ["diffusers", "UNet2DConditionModel"], - "vae": ["diffusers", "AutoencoderKL"], - } - ) - ) - _make_diffusion_component(snap, blob, "unet", "diffusion_pytorch_model.safetensors") - _make_diffusion_component(snap, blob, "vae", weight_filename = None) # config only, no weight - assert hcs.snapshot_dir_is_complete(snap) is False - (snap / "vae" / "diffusion_pytorch_model.safetensors").symlink_to(blob) - assert hcs.snapshot_dir_is_complete(snap) is True - - -def test_snapshot_dir_is_complete_diffusion_scoped_request_not_blocked(tmp_path): - """A scoped subfolder request (allow_patterns=["unet/*"]) targets its own subset, so the - whole-pipeline completeness rule does not apply: a unet-only snapshot reads complete even - though the pipeline declares more components.""" - snap = tmp_path / "snap" - snap.mkdir() - blob = tmp_path / "blob" - blob.write_bytes(b"weights") - (snap / "model_index.json").write_text( - json.dumps( - { - "_class_name": "StableDiffusionPipeline", - "unet": ["diffusers", "UNet2DConditionModel"], - "vae": ["diffusers", "AutoencoderKL"], - "text_encoder": ["transformers", "CLIPTextModel"], - } - ) - ) - _make_diffusion_component(snap, blob, "unet", "diffusion_pytorch_model.safetensors") - assert hcs.snapshot_dir_is_complete(snap, allow_patterns = ["unet/*"]) is True - - def test_snapshot_dir_is_complete_broken_symlink(tmp_path): """A dangling weight symlink reads as 
incomplete.""" snap = tmp_path / "snap" @@ -1466,332 +1379,6 @@ def test_snapshot_dir_is_complete_ignores_trainer_artifacts(tmp_path): assert hcs.snapshot_dir_is_complete(snap) is True # real weight present -def test_snapshot_dir_is_complete_requires_the_requested_weight(tmp_path): - """A patterned request is satisfied only by a weight it actually selects: a snapshot that - carries some other weight but not the requested one (e.g. adapter / checkpoint) reads as - incomplete, so the guarded download still runs.""" - snap = tmp_path / "snap" - snap.mkdir() - blob = tmp_path / "blob" - blob.write_bytes(b"x") - (snap / "model.safetensors").symlink_to(blob) # base weight only - # Requesting the adapter weight: the base weight does not satisfy it. - assert hcs.snapshot_dir_is_complete( - snap, allow_patterns = ["adapter_model.safetensors"] - ) is False - # No patterns: any loadable weight suffices. - assert hcs.snapshot_dir_is_complete(snap) is True - # Once the requested adapter weight is present, the request is satisfied. 
- (snap / "adapter_model.safetensors").symlink_to(blob) - assert hcs.snapshot_dir_is_complete( - snap, allow_patterns = ["adapter_model.safetensors"] - ) is True - - -def test_snapshot_dir_is_complete_requires_requested_subfolder_weight(tmp_path): - """A subfolder request is satisfied only by a weight under that subfolder, not by a - root-level weight the snapshot also carries.""" - snap = tmp_path / "snap" - (snap / "checkpoint-10").mkdir(parents = True) - blob = tmp_path / "blob" - blob.write_bytes(b"x") - (snap / "model.safetensors").symlink_to(blob) # root weight only - assert hcs.snapshot_dir_is_complete(snap, allow_patterns = ["checkpoint-10/*"]) is False - (snap / "checkpoint-10" / "model.safetensors").symlink_to(blob) - assert hcs.snapshot_dir_is_complete(snap, allow_patterns = ["checkpoint-10/*"]) is True - - -def test_snapshot_dir_is_complete_single_shard_request(tmp_path): - """A deliberate single-shard request is satisfied by that one shard; the full -of-NNNNN - set is required only for an unpatterned full warm.""" - snap = tmp_path / "snap" - snap.mkdir() - blob = tmp_path / "blob" - blob.write_bytes(b"x") - (snap / "model-00002-of-00005.safetensors").symlink_to(blob) - # Just the requested shard present -> complete for that request. - assert hcs.snapshot_dir_is_complete( - snap, allow_patterns = ["model-00002-of-00005.safetensors"] - ) is True - # An unpatterned full warm requires the whole set -> incomplete (4 shards missing). - assert hcs.snapshot_dir_is_complete(snap) is False - - -def test_snapshot_dir_is_complete_checkpoint_only_not_warm_root(tmp_path): - """An unpatterned (root-model) warm is not satisfied by a weight that lives only inside a - per-checkpoint dir. A cache left by a prior allow_patterns=["checkpoint-10/*"] pull holds - checkpoint-10/model.safetensors but no root weight; reading it as a warm root model would let - the guarded download be skipped and hand from_pretrained a snapshot whose root weights are - missing (Codex #829). 
A root weight (or a patterned checkpoint request) still completes.""" - snap = tmp_path / "snap" - (snap / "checkpoint-10").mkdir(parents = True) - blob = tmp_path / "blob" - blob.write_bytes(b"x") - (snap / "checkpoint-10" / "model.safetensors").symlink_to(blob) # checkpoint weight only - (snap / "config.json").write_text("{}") - # Unpatterned root warm: the checkpoint weight does not count -> incomplete. - assert hcs.snapshot_dir_is_complete(snap) is False - # A patterned request for that checkpoint IS satisfied by it (not a root warm). - assert hcs.snapshot_dir_is_complete(snap, allow_patterns = ["checkpoint-10/*"]) is True - # Once a root weight is present, the unpatterned warm completes (checkpoint weight ignored). - (snap / "model.safetensors").symlink_to(blob) - assert hcs.snapshot_dir_is_complete(snap) is True - # A DeepSpeed-style global_step dir is treated the same way. - snap2 = tmp_path / "snap2" - (snap2 / "global_step500").mkdir(parents = True) - (snap2 / "global_step500" / "pytorch_model.bin").symlink_to(blob) - assert hcs.snapshot_dir_is_complete(snap2) is False - - -def test_snapshot_dir_is_complete_catchall_not_warm_root(tmp_path): - """A pure catch-all allow list (["*"]) selects the whole repo just like an unpatterned warm, so - a root from_pretrained still reads ROOT weights. A checkpoint-only cache (left by a prior - allow_patterns=["checkpoint-10/*"] pull) must NOT read as complete for ["*"] -- HF's fnmatch - "*" spans "/" and would otherwise count the checkpoint weight as satisfying the catch-all - (Codex #829). A path-bearing pattern that names the checkpoint is still trusted.""" - snap = tmp_path / "snap" - (snap / "checkpoint-10").mkdir(parents = True) - blob = tmp_path / "blob" - blob.write_bytes(b"x") - (snap / "checkpoint-10" / "model.safetensors").symlink_to(blob) # checkpoint weight only - (snap / "config.json").write_text("{}") - # Catch-all is treated like an unpatterned root warm: the checkpoint weight does not count. 
- assert hcs.snapshot_dir_is_complete(snap, allow_patterns = ["*"]) is False - assert hcs.snapshot_dir_is_complete(snap, allow_patterns = ["**"]) is False - # A path-bearing checkpoint pattern IS satisfied by it (deliberate checkpoint selection). - assert hcs.snapshot_dir_is_complete(snap, allow_patterns = ["checkpoint-10/*"]) is True - # Once a root weight is present, the catch-all completes. - (snap / "model.safetensors").symlink_to(blob) - assert hcs.snapshot_dir_is_complete(snap, allow_patterns = ["*"]) is True - - -def test_snapshot_dir_is_complete_root_file_glob_not_checkpoint(tmp_path): - """A no-slash root weight glob (["*.safetensors"]) reads from the repo ROOT, but HF's fnmatch - "*" spans "/" and also matches checkpoint-10/model.safetensors. A checkpoint-only cache must - NOT read as complete for the root safetensors warm (Codex #829) -- the root model.safetensors - from_pretrained reads is still missing. A path-bearing pattern naming the checkpoint is - trusted.""" - snap = tmp_path / "snap" - (snap / "checkpoint-10").mkdir(parents = True) - blob = tmp_path / "blob" - blob.write_bytes(b"x") - (snap / "checkpoint-10" / "model.safetensors").symlink_to(blob) - (snap / "config.json").write_text("{}") - assert hcs.snapshot_dir_is_complete(snap, allow_patterns = ["*.safetensors"]) is False - # A path-bearing checkpoint pattern IS satisfied by the checkpoint weight. - assert hcs.snapshot_dir_is_complete(snap, allow_patterns = ["checkpoint-10/*"]) is True - # A root weight completes the root glob. - (snap / "model.safetensors").symlink_to(blob) - assert hcs.snapshot_dir_is_complete(snap, allow_patterns = ["*.safetensors"]) is True - - -def test_snapshot_dir_is_complete_root_warm_ignores_subfolder_weight(tmp_path): - """A root warm (allow_patterns=None, or a no-slash glob) is not satisfied by a weight that - lives only in a NON-checkpoint subfolder such as a precision variant (BF16/, fp16/). 
A prior - allow_patterns=["BF16/*"] pull leaves BF16/model.safetensors but no root weight; a root - from_pretrained never reads that subfolder, so the snapshot must read as incomplete (Codex - #829). A root weight, or an explicit BF16/ request, still completes.""" - snap = tmp_path / "snap" - (snap / "BF16").mkdir(parents = True) - blob = tmp_path / "blob" - blob.write_bytes(b"x") - (snap / "BF16" / "model.safetensors").symlink_to(blob) - (snap / "config.json").write_text("{}") - # Unpatterned and no-slash-glob root warms both ignore the subfolder weight. - assert hcs.snapshot_dir_is_complete(snap) is False - assert hcs.snapshot_dir_is_complete(snap, allow_patterns = ["*.safetensors"]) is False - # An explicit subfolder request IS satisfied by it (deliberate subfolder selection). - assert hcs.snapshot_dir_is_complete(snap, allow_patterns = ["BF16/*"]) is True - # A root weight completes the root warm (the subfolder variant is ignored). - (snap / "model.safetensors").symlink_to(blob) - assert hcs.snapshot_dir_is_complete(snap) is True - - -def test_snapshot_dir_is_complete_diffusion_file_glob_validates_components(tmp_path): - """A full diffusers warm expressed as file globs (["*.safetensors", "*.json"]) still selects - nested component files, so the model_index.json component check must run: a stale snapshot with - only unet/ but a declared vae / text_encoder reads as incomplete (Codex #829), not accepted as - a scoped subset.""" - snap = tmp_path / "snap" - snap.mkdir() - blob = tmp_path / "blob" - blob.write_bytes(b"weights") - (snap / "model_index.json").write_text( - json.dumps( - { - "_class_name": "StableDiffusionPipeline", - "unet": ["diffusers", "UNet2DConditionModel"], - "vae": ["diffusers", "AutoencoderKL"], - "text_encoder": ["transformers", "CLIPTextModel"], - } - ) - ) - _make_diffusion_component(snap, blob, "unet", "diffusion_pytorch_model.safetensors") - globs = ["*.safetensors", "*.json"] - assert hcs.snapshot_dir_is_complete(snap, allow_patterns = 
globs) is False - # Once the other declared components land, the file-glob warm completes. - _make_diffusion_component(snap, blob, "vae", "diffusion_pytorch_model.safetensors") - _make_diffusion_component(snap, blob, "text_encoder", "model.safetensors") - assert hcs.snapshot_dir_is_complete(snap, allow_patterns = globs) is True - - -def test_request_can_include_weights_sentencepiece_glob(): - """A SentencePiece / slow-tokenizer vocab glob (["spiece*"], ["sentencepiece*"], ["spm*"]) - selects only tokenizer assets and no weight, so it reads as WEIGHTLESS. Without a representative - the no-slash glob is misread as a weight directory and a tokenizer-only snapshot is wrongly - rejected for lacking a weight (Codex #829).""" - assert hcs.request_can_include_weights(["spiece*"], None) is False - assert hcs.request_can_include_weights(["sentencepiece*"], None) is False - assert hcs.request_can_include_weights(["spm*"], None) is False - assert hcs.request_can_include_weights(["spiece.model", "tokenizer*"], None) is False - # A weight glob is still weight-including (no accept-stale). - assert hcs.request_can_include_weights(["model*"], None) is True - - -def test_request_can_include_weights_variant_prefix_glob(): - """A variant weight prefix glob (["model.fp16*"], ["pytorch_model.fp16*"]) selects a real variant - weight (model.fp16.safetensors) even though the pattern ends in the wildcard, not a weight - suffix. It must read as weight-including so the fast path does not accept a config-only snapshot - without the requested variant weights (Codex #829). A metadata prefix glob stays weightless.""" - assert hcs.request_can_include_weights(["model.fp16*"], None) is True - assert hcs.request_can_include_weights(["pytorch_model.fp16*"], None) is True - assert hcs.request_can_include_weights(["model.bf16*"], None) is True - # Metadata / non-weight prefix globs stay weightless (no over-reject of a tokenizer-only warm). 
- assert hcs.request_can_include_weights(["tokenizer*"], None) is False - assert hcs.request_can_include_weights(["config*"], None) is False - assert hcs.request_can_include_weights(["spiece*"], None) is False - - -def test_snapshot_dir_is_complete_diffusion_weights_only_glob(tmp_path): - """A weights-only root glob (["*.safetensors"]) on a diffusers repo downloads the component - weights (unet/, vae/) but NOT model_index.json, so the pipeline layout must be recognized from - the component weights themselves -- otherwise the snapshot is misread as a non-pipeline root - load, its subfolder weights dropped, and a complete download reported incomplete (Codex #829). - A genuine non-pipeline subfolder-only snapshot (BF16/) is still rejected for a root glob.""" - snap = tmp_path / "snap" - snap.mkdir() - blob = tmp_path / "blob" - blob.write_bytes(b"w") - _make_diffusion_component(snap, blob, "unet", "diffusion_pytorch_model.safetensors") - _make_diffusion_component(snap, blob, "vae", "diffusion_pytorch_model.safetensors") - # No model_index.json on disk (a *.safetensors warm would not have selected it). - assert hcs.snapshot_dir_is_complete(snap, allow_patterns = ["*.safetensors"]) is True - # A precision-variant subfolder (not a pipeline component) must NOT be mistaken for a pipeline. 
- snap2 = tmp_path / "snap2" - (snap2 / "BF16").mkdir(parents = True) - (snap2 / "BF16" / "model.safetensors").symlink_to(blob) - assert hcs.snapshot_dir_is_complete(snap2, allow_patterns = ["*.safetensors"]) is False - - -def test_snapshot_has_requested_broken_symlinks_dataset_vs_model(tmp_path): - """A dangling checkpoint-dir symlink blocks a DATASET snapshot (every requested file matters) - but not a MODEL root warm (a root load never reads a checkpoint-dir file), so the checkpoint-dir - drop applies only to models (Codex #829).""" - snap = tmp_path / "snap" - (snap / "checkpoint-10").mkdir(parents = True) - (snap / "checkpoint-10" / "data.parquet").symlink_to(tmp_path / "missing") # dangling - assert hcs.snapshot_has_requested_broken_symlinks(snap, repo_type = "model") is False - assert hcs.snapshot_has_requested_broken_symlinks(snap, repo_type = "dataset") is True - # A dangling ROOT file blocks both (it is a requested interrupted file in either case). - snap2 = tmp_path / "snap2" - snap2.mkdir() - (snap2 / "data.parquet").symlink_to(tmp_path / "missing") - assert hcs.snapshot_has_requested_broken_symlinks(snap2, repo_type = "model") is True - assert hcs.snapshot_has_requested_broken_symlinks(snap2, repo_type = "dataset") is True - - -def test_request_can_include_weights_chat_template_glob(): - """A chat-template glob (["chat_template*"]) selects only chat_template.json / .jinja and no - weight, so it reads as WEIGHTLESS. Without a representative in the non-weight probes the glob is - misread as a weight directory and a template-only snapshot is wrongly rejected (Codex #829).""" - assert hcs.request_can_include_weights(["chat_template*"], None) is False - assert hcs.request_can_include_weights(["chat_template.jinja"], None) is False - assert hcs.request_can_include_weights(["added_tokens.json"], None) is False - # A weight glob is still weight-including. 
- assert hcs.request_can_include_weights(["*.safetensors"], None) is True - - -def test_snapshot_dir_is_complete_root_glob_validates_shard_index(tmp_path): - """A root-wide patterned warm (allow_patterns=["*.safetensors", "*.json"], or - ignore_patterns=["*.onnx"]) still warms the root model, so a shard index whose shards are not - all on disk must reject it -- even for NON-numbered shard names the numbered-shard check cannot - catch (Codex #829). The index validation now runs for every _targets_root_only request, not - only the unpatterned one.""" - snap = tmp_path / "snap" - snap.mkdir() - blob = tmp_path / "blob" - blob.write_bytes(b"x") - (snap / "foo.safetensors").symlink_to(blob) # one non-numbered shard present - (snap / "model.safetensors.index.json").write_text( - json.dumps({"weight_map": {"a": "foo.safetensors", "b": "bar.safetensors"}}) - ) - globs = ["*.safetensors", "*.json"] - assert hcs.snapshot_dir_is_complete(snap, allow_patterns = globs) is False # bar missing - assert hcs.snapshot_dir_is_complete(snap, ignore_patterns = ["*.onnx"]) is False # bar missing - (snap / "bar.safetensors").symlink_to(blob) - assert hcs.snapshot_dir_is_complete(snap, allow_patterns = globs) is True - - -def test_snapshot_dir_is_complete_sharded_glob_requires_index(tmp_path): - """A full warm expressed as a weight glob (["*.safetensors"]) over a complete numbered-shard set - is still incomplete without the index sidecar (transformers cannot load a local sharded - checkpoint without it). 
A deliberate exact single-shard request is exempt -- it wants only that - file, not the whole model.""" - snap = tmp_path / "snap" - snap.mkdir() - blob = tmp_path / "blob" - blob.write_bytes(b"x") - (snap / "model-00001-of-00002.safetensors").symlink_to(blob) - (snap / "model-00002-of-00002.safetensors").symlink_to(blob) - assert hcs.snapshot_dir_is_complete(snap, allow_patterns = ["*.safetensors"]) is False - # An exact single-shard request is satisfied by that shard alone (no index required). - assert hcs.snapshot_dir_is_complete( - snap, allow_patterns = ["model-00001-of-00002.safetensors"] - ) is True - (snap / "model.safetensors.index.json").write_text( - json.dumps({"weight_map": {"a": "model-00001-of-00002.safetensors", - "b": "model-00002-of-00002.safetensors"}}) - ) - assert hcs.snapshot_dir_is_complete(snap, allow_patterns = ["*.safetensors"]) is True - - -def test_snapshot_dir_is_complete_variant_sharded_index(tmp_path): - """A variant sharded checkpoint must not be falsely rejected for lacking an index. Transformers' - _add_variant names the variant index model.safetensors.index.fp16.json (variant before the - trailing .json) and the shards model.fp16-00001-of-00002.safetensors (variant in the regex - prefix). The index-sidecar requirement recognizes the variant index, so a complete variant - sharded set with its index reads complete, while the same set without ANY index reads - incomplete.""" - snap = tmp_path / "snap" - snap.mkdir() - blob = tmp_path / "blob" - blob.write_bytes(b"x") - (snap / "model.fp16-00001-of-00002.safetensors").symlink_to(blob) - (snap / "model.fp16-00002-of-00002.safetensors").symlink_to(blob) - # Every shard present but no index of any kind -> incomplete. - assert hcs.snapshot_dir_is_complete(snap) is False - # The variant index (note: token before the trailing .json) makes it loadable. 
- (snap / "model.safetensors.index.fp16.json").write_text( - json.dumps({"weight_map": {"a": "model.fp16-00001-of-00002.safetensors", - "b": "model.fp16-00002-of-00002.safetensors"}}) - ) - assert hcs.snapshot_dir_is_complete(snap) is True - assert hcs.snapshot_dir_is_complete(snap, allow_patterns = ["*.safetensors"]) is True - - -def test_request_can_include_weights_processor_subfolder(): - """A processor / image_processor subfolder ships only *_config.json + vocab files (no weights), - so a catch-all warm under it (processor/*) reads as WEIGHTLESS. Without this, the synthetic - processor/model.safetensors weight probe makes the request look weight-bearing and a - processor-only snapshot is wrongly rejected for lacking a weight (Codex #829). A weight under - the same subfolder is still recognized as weight-including.""" - assert hcs.request_can_include_weights(["processor/*"], None) is False - assert hcs.request_can_include_weights(["image_processor/*"], None) is False - assert hcs.request_can_include_weights(["processor/"], None) is False - # A weight name under the subfolder still reads as weight-including (no accept-stale). 
- assert hcs.request_can_include_weights(["processor/model.safetensors"], None) is True - - def test_snapshot_dir_is_complete_checkpoint_index_does_not_gate_root(tmp_path): """A per-checkpoint shard index with missing shards must not fail an unpatterned root warm: the root weights are what the load reads, so an incomplete checkpoint index is irrelevant to @@ -1811,66 +1398,6 @@ def test_snapshot_dir_is_complete_checkpoint_index_does_not_gate_root(tmp_path): assert hcs.snapshot_dir_is_complete(snap) is True -def test_snapshot_dir_is_complete_requires_each_named_weight(tmp_path): - """require_named_weights makes a request naming multiple exact weights (base + adapter) need - EACH logical weight on disk -- so a stale snapshot holding only the base is rejected -- while - grouping format variants of one logical weight so an "either format" list (pytorch_model.bin + - model.safetensors) is satisfied by whichever format the repo actually ships (no error-forever on - a name that doesn't exist).""" - snap = tmp_path / "snap" - snap.mkdir() - blob = tmp_path / "blob" - blob.write_bytes(b"x") - (snap / "model.safetensors").symlink_to(blob) # base only; adapter missing - pair = ["model.safetensors", "adapter_model.safetensors"] - # base + adapter are two LOGICAL weights -> both required; the adapter is missing -> incomplete. - assert hcs.snapshot_dir_is_complete(snap, allow_patterns = pair, require_named_weights = True) is False - # Lenient (require_named_weights off): a present selected weight suffices. - assert hcs.snapshot_dir_is_complete(snap, allow_patterns = pair, require_named_weights = False) is True - # Either-format list = ONE logical weight: whichever format is present satisfies it, under both - # the strict and lenient checks (no spurious failure on the absent format). 
- either = ["pytorch_model.bin", "model.safetensors"] - assert hcs.snapshot_dir_is_complete(snap, allow_patterns = either, require_named_weights = True) is True - assert hcs.snapshot_dir_is_complete(snap, allow_patterns = either, require_named_weights = False) is True - # Both present -> the base + adapter request is satisfied. - (snap / "adapter_model.safetensors").symlink_to(blob) - assert hcs.snapshot_dir_is_complete(snap, allow_patterns = pair, require_named_weights = True) is True - # A named weight the ignore filter drops is not actually requested, so it is not required. - (snap / "adapter_model.safetensors").unlink() - assert hcs.snapshot_dir_is_complete( - snap, allow_patterns = pair, ignore_patterns = ["adapter_model.safetensors"], - require_named_weights = True, - ) is True - # A glob may legitimately select a subset, so it is never forced to be exhaustive. - assert hcs.snapshot_dir_is_complete(snap, allow_patterns = ["*.safetensors"], require_named_weights = True) is True - - -def test_snapshot_dir_is_complete_requires_named_non_weight_exact_only(tmp_path): - """An EXACT-file request (no globs) naming a non-weight alongside a weight requires the - non-weight on disk too: a stale cache holding only the weight must not short-circuit past the - guarded download that should still fetch the explicitly named tokenizer / config (Codex #829). - A request containing ANY glob is instead a broad selection where aux files are best-effort, so - only its concrete weights are required -- keeping unsloth's glob-bearing adapter / tokenizer - warms able to short-circuit on a warm cache rather than re-downloading on every load.""" - blob = tmp_path / "blob" - blob.write_bytes(b"x") - snap = tmp_path / "snap" - snap.mkdir() - (snap / "model.safetensors").symlink_to(blob) # weight present; tokenizer.json missing - pair = ["model.safetensors", "tokenizer.json"] - # Strict (pre-download): the named tokenizer.json is missing -> incomplete -> guarded download. 
- assert hcs.snapshot_dir_is_complete(snap, allow_patterns = pair, require_named_weights = True) is False - # Lenient (post-download): a present selected weight suffices (no error on a possibly-absent name). - assert hcs.snapshot_dir_is_complete(snap, allow_patterns = pair, require_named_weights = False) is True - # Once the named non-weight is on disk, strict is satisfied. - (snap / "tokenizer.json").write_text("{}") - assert hcs.snapshot_dir_is_complete(snap, allow_patterns = pair, require_named_weights = True) is True - # A list containing ANY glob is a broad warm: optional aux names are NOT required, only weights. - (snap / "tokenizer.json").unlink() - globbed = ["model.safetensors", "tokenizer.json", "modeling_*.py"] - assert hcs.snapshot_dir_is_complete(snap, allow_patterns = globbed, require_named_weights = True) is True - - def test_fast_path_rejects_config_only_snapshot(hf_cache, monkeypatch): """HF's local_files_only returns a config-only snapshot (e.g. left by an earlier AutoConfig fetch) without checking weights. 
The fast path must reject it and complete @@ -1906,25 +1433,6 @@ def test_fast_path_requires_each_named_weight(hf_cache, monkeypatch): assert out == "/cache/snap-fresh" and len(fake.calls) == 1 -def test_fast_path_either_format_not_failed_post_download(hf_cache, monkeypatch): - """An "either format" request (pytorch_model.bin + model.safetensors) against a repo that - only ships safetensors must not error: the child's safetensors-only snapshot is accepted - post-download, since the strict named-weight rule is pre-download only (no spurious - incomplete-snapshot failure for a name that does not exist, Codex #829).""" - blobs = _blobs_dir(hf_cache, DL_REPO) - child = blobs.parent / "snapshots" / "fresh" - child.mkdir(parents = True) - w = blobs / "w" - w.write_bytes(b"x") - (child / "model.safetensors").symlink_to(w) # safetensors only; no pytorch_model.bin - fake = _install(monkeypatch, [("ok", str(child))]) - out = xf.snapshot_download_with_xet_fallback( - DL_REPO, token = None, force_download = True, - allow_patterns = ["pytorch_model.bin", "model.safetensors"], - ) - assert out == str(child) and len(fake.calls) == 1 - - def test_child_broken_snapshot_retries_over_http(monkeypatch, tmp_path): """A real but broken child snapshot result (HF offline-fallback returning a dir with dangling symlinks) is rejected on the Xet attempt and retried over HTTP; a clean @@ -2067,92 +1575,6 @@ def test_request_can_include_weights_path_qualified(): ) is True -def test_request_can_include_weights_no_slash_dir_glob(): - """A no-slash directory glob (checkpoint-*, global_step*) matches nested weights via HF's - fnmatch '*'-spans-'/' rule, so it must read as weight-including; a no-slash file glob with - an extension (tokenizer.*, *.json) stays weightless.""" - assert hcs.request_can_include_weights(["checkpoint-*"], None) is True - assert hcs.request_can_include_weights(["epoch-*"], None) is True - assert hcs.request_can_include_weights(["global_step*"], None) is True - assert 
hcs.request_can_include_weights(["*"], None) is True - # ignore_patterns that drop every weight format still wins over the dir glob. - assert hcs.request_can_include_weights( - ["checkpoint-*"], - ["*.safetensors", "*.bin", "*.pt", "*.pth", "*.gguf", - "*.h5", "*.msgpack", "*.ckpt", "*.onnx", "*.pdparams"], - ) is False - # File globs with an extension are not directory globs. - assert hcs.request_can_include_weights(["tokenizer.*"], None) is False - assert hcs.request_can_include_weights(["*.json"], None) is False - # A dotted no-slash glob whose stem names a checkpoint DIRECTORY still includes weights. - assert hcs.request_can_include_weights(["checkpoint-v1.*"], None) is True - assert hcs.request_can_include_weights(["global_step100.*"], None) is True - - -def test_request_can_include_weights_wildcard_parent(): - """A wildcard parent dir with a weight basename glob (checkpoint-*/adapter_model.*, - */model.*) must read as weight-including, and ignore_patterns must still be applied to a - wildcard-parent request rather than bypassed by an early return.""" - assert hcs.request_can_include_weights(["checkpoint-*/adapter_model.*"], None) is True - assert hcs.request_can_include_weights(["*/model.*"], None) is True - assert hcs.request_can_include_weights(["checkpoint-*/*.safetensors"], None) is True - # ignore_patterns applies under a wildcard parent: dropping every weight format -> weightless. - assert hcs.request_can_include_weights( - ["checkpoint-*/*"], - ["*.safetensors", "*.bin", "*.pt", "*.pth", "*.gguf", - "*.h5", "*.msgpack", "*.ckpt", "*.onnx", "*.pdparams"], - ) is False - # Dropping only some formats leaves the request able to include the others. - assert hcs.request_can_include_weights( - ["checkpoint-*/*"], ["*.safetensors", "*.bin"] - ) is True - # A non-weight basename under a wildcard parent stays weightless. 
- assert hcs.request_can_include_weights(["checkpoint-*/tokenizer.json"], None) is False - - -def test_request_can_include_weights_weight_selecting_globs(): - """Weight-selecting basename globs whose stem is not the canonical 'model' -- PEFT - adapters, consolidated / original checkpoints, diffusers -- must read as including - weights, so a stale snapshot missing them is not accepted on the weightless path.""" - assert hcs.request_can_include_weights(["adapter_model.*"], None) is True - assert hcs.request_can_include_weights(["adapter_model.safetensors"], None) is True - assert hcs.request_can_include_weights(["consolidated.*"], None) is True - assert hcs.request_can_include_weights(["consolidated.00.pth"], None) is True - assert hcs.request_can_include_weights(["diffusion_pytorch_model.*"], None) is True - assert hcs.request_can_include_weights(["adapter*.safetensors"], None) is True - # A non-weight basename glob stays weightless. - assert hcs.request_can_include_weights(["tokenizer.*"], None) is False - - -def test_request_can_include_weights_custom_weight_suffix_globs(): - """A no-slash FILE glob whose stem matches no canonical probe but whose suffix is a weight - suffix (lora_*.safetensors, *.bin, model-*.safetensors, custom_*.pt) must read as - weight-including, so a stale snapshot missing it is not accepted on the weightless path. A - non-weight-suffix file glob (*.json) stays weightless, and ignore_patterns still wins.""" - assert hcs.request_can_include_weights(["lora_*.safetensors"], None) is True - assert hcs.request_can_include_weights(["*.bin"], None) is True - assert hcs.request_can_include_weights(["model-*.safetensors"], None) is True - assert hcs.request_can_include_weights(["my_custom_*.pt"], None) is True - assert hcs.request_can_include_weights(["*.json"], None) is False - # ignore_patterns dropping that very format wins over the weight-suffix glob. 
- assert hcs.request_can_include_weights(["lora_*.safetensors"], ["*.safetensors"]) is False - - -def test_request_can_include_weights_bracket_globs(): - """A bracket / range glob (checkpoint-[0-9]/*.safetensors, model-[0-9].safetensors) is - concretized to a member the class actually matches, so the weight probe still satisfies the - caller's own pattern and the request reads as weight-including, not misclassified - weightless.""" - assert hcs.request_can_include_weights(["checkpoint-[0-9]/*.safetensors"], None) is True - assert hcs.request_can_include_weights(["model-[0-9].safetensors"], None) is True - assert hcs.request_can_include_weights(["ckpt-[0-9][0-9]/*"], None) is True - # The concretizer picks an in-class member, not a literal 'x' that the class would reject. - assert hcs._concretize_glob("checkpoint-[0-9]") == "checkpoint-0" - assert hcs._concretize_glob("layer_[a-f]") == "layer_a" - # A negated class yields a filler the class does not exclude (here a non-digit). - assert not hcs._concretize_glob("checkpoint-[!0-9]").endswith(tuple("0123456789")) - - def test_request_can_include_weights_path_qualified_custom_globs(): """A path-qualified custom weight glob (checkpoint-10/lora_*.safetensors, with a globbed parent too) names a weight whose basename matches no canonical probe; it must read as @@ -2165,26 +1587,6 @@ def test_request_can_include_weights_path_qualified_custom_globs(): assert hcs.request_can_include_weights(["checkpoint-10/tokenizer.json"], None) is False -def test_request_can_include_weights_trainer_artifacts_weightless(tmp_path): - """A trainer / optimizer artifact (optimizer.pt, training_args.bin, rng_state_*.pth) carries - a weight suffix but is not a loadable weight: request_can_include_weights must read it as - weightless, matching snapshot_dir_is_complete -- otherwise a child that fetches exactly that - artifact loops as an 'incomplete snapshot' (Codex #829).""" - for art in ["optimizer.pt", "training_args.bin", "scheduler.pt", 
"rng_state_0.pth", - "checkpoint-10/optimizer.pt"]: - assert hcs.request_can_include_weights([art], None) is False, art - # Consistency: a snapshot holding only the requested artifact is acceptable (weightless - # path), so the guarded download is not retried into an incomplete-snapshot failure. - snap = tmp_path / "snap" - snap.mkdir() - blob = tmp_path / "blob" - blob.write_bytes(b"x") - (snap / "optimizer.pt").symlink_to(blob) - assert xf._snapshot_is_acceptable( - snap, repo_type = "model", allow_patterns = ["optimizer.pt"], ignore_patterns = None - ) is True - - def test_request_can_include_weights_empty_allow_list(tmp_path): """allow_patterns=[] is a real filter that selects NO objects (Hugging Face semantics), so the request is weightless -- it must not collapse with None (an unfiltered warmup) and @@ -2205,153 +1607,6 @@ def test_request_can_include_weights_empty_allow_list(tmp_path): assert hcs.snapshot_dir_is_complete(snap, allow_patterns = None) is True -def test_request_can_include_weights_non_weight_subfolders(): - """A generic glob under a plain (non-checkpoint) subfolder such as tokenizer/* or runs/* - must read as weightless -- the unconditional canonical re-rooting would otherwise add a - synthetic tokenizer/model.safetensors probe and misclassify a tokenizer-only download as a - model warmup (Codex #829). A checkpoint/weight dir or a weight-targeting basename still - includes weights.""" - assert hcs.request_can_include_weights(["tokenizer/*"], None) is False - assert hcs.request_can_include_weights(["runs/*"], None) is False - assert hcs.request_can_include_weights(["logs/*.txt"], None) is False - # Weight-bearing cases stay weight-including. 
- assert hcs.request_can_include_weights(["checkpoint-10/*"], None) is True - assert hcs.request_can_include_weights(["*/model.*"], None) is True - assert hcs.request_can_include_weights(["models/*.safetensors"], None) is True - # A weight-suffix basename under a plain subfolder is still recognized (self-probe). - assert hcs.request_can_include_weights(["tokenizer/*.safetensors"], None) is True - # A checkpoint dir nested anywhere in the parent path counts. - assert hcs.request_can_include_weights(["runs/checkpoint-5/*"], None) is True - - -def test_request_can_include_weights_weight_bearing_subfolders(tmp_path): - """A component / quant subfolder (unet/, transformer/, original/, BF16/, Q8_0/) holds - weights, so a bare catch-all under it must stay weight-including -- reading an unknown - subfolder as weightless would accept a stale config-only cache and re-open the silent Xet - hang. Only KNOWN auxiliary dirs (tokenizer/, runs/) are weightless (Codex #829).""" - for d in ["unet/*", "transformer/*", "text_encoder/*", "vae/*", "original/*", - "mp_rank_00/*", "BF16/*", "Q8_0/*", "Q4_K_M/*", "unknown_component/*"]: - assert hcs.request_can_include_weights([d], None) is True, d - # End to end: a stale config-only BF16/ snapshot (weight missing, no dangling symlink) must - # NOT be short-circuited as warm -- the guarded download still runs. - snap = tmp_path / "snap" - (snap / "BF16").mkdir(parents = True) - (snap / "BF16" / "config.json").write_text("{}") # config only, no weight - assert xf._snapshot_is_acceptable( - snap, repo_type = "model", allow_patterns = ["BF16/*"], ignore_patterns = None, - require_named_weights = True, - ) is False - # Once the weight is present, it is acceptable. 
- blob = tmp_path / "blob" - blob.write_bytes(b"x") - (snap / "BF16" / "model.safetensors").symlink_to(blob) - assert xf._snapshot_is_acceptable( - snap, repo_type = "model", allow_patterns = ["BF16/*"], ignore_patterns = None, - ) is True - - -def test_request_can_include_weights_diffusers_config_only_components(tmp_path): - """The Diffusers pipeline components that ship only *_config.json / vocab files - (scheduler/, feature_extractor/, tokenizer_2/, tokenizer_3/) must read as weightless, so a - catch-all like scheduler/* is not given synthetic weight probes and a child that correctly - fetches only scheduler/scheduler_config.json is not rejected for lacking weights (Codex - #829). The weight-bearing pipeline dirs stay weight-including.""" - for d in ["scheduler/*", "feature_extractor/*", "tokenizer_2/*", "tokenizer_3/*"]: - assert hcs.request_can_include_weights([d], None) is False, d - # The weight-bearing pipeline components are NOT in the weightless set. - for d in ["unet/*", "transformer/*", "vae/*", "text_encoder/*", "text_encoder_2/*", - "image_encoder/*", "safety_checker/*"]: - assert hcs.request_can_include_weights([d], None) is True, d - # End to end: a config-only scheduler/ snapshot is acceptable for a scheduler/* request - # (no weight expected there), so the guarded download is not looped on it. - snap = tmp_path / "snap" - (snap / "scheduler").mkdir(parents = True) - (snap / "scheduler" / "scheduler_config.json").write_text("{}") - assert xf._snapshot_is_acceptable( - snap, repo_type = "model", allow_patterns = ["scheduler/*"], ignore_patterns = None, - require_named_weights = True, - ) is True - - -def test_consumer_pattern_lists_accepted_end_to_end(tmp_path): - """Lock the cross-repo contract: the EXACT allow / ignore lists unsloth's - maybe_prefetch_hf_snapshot emits must be judged correctly by this module's acceptance, so a - future drift between the two repos cannot silently loop the guarded download. 
These lists - mirror unsloth's _ADAPTER_PREFETCH_PATTERNS / _ROOT_AUX_PREFETCH_PATTERNS / _SUBDIR_WEIGHT_ - IGNORE_PATTERNS; if unsloth changes them, this is where the mismatch surfaces.""" - blob = tmp_path / "blob" - blob.write_bytes(b"x") - - # --- adapter_only: allow = adapter files + root aux, ignore = None --- - root_aux = [ - "config.json", "generation_config.json", "tokenizer_config.json", "tokenizer.json", - "tokenizer.model", "special_tokens_map.json", "added_tokens.json", "vocab.json", - "vocab.txt", "merges.txt", "spiece.model", "chat_template.jinja", "chat_template.json", - "preprocessor_config.json", "processor_config.json", "configuration_*.py", "modeling_*.py", - "tokenization_*.py", "processing_*.py", "image_processing_*.py", "feature_extraction_*.py", - "video_processing_*.py", "*.tiktoken", - ] - adapter_allow = ["adapter_config.json", "adapter_model*", *root_aux] - assert hcs.request_can_include_weights(adapter_allow, None) is True - snap = tmp_path / "adapter" - snap.mkdir() - (snap / "adapter_config.json").write_text("{}") - (snap / "adapter_model.safetensors").symlink_to(blob) - (snap / "config.json").write_text("{}") - (snap / "tokenizer.json").write_text("{}") - # A merged full-model weight the adapter warm never requested is present but irrelevant. - (snap / "model.safetensors").symlink_to(blob) - assert xf._snapshot_is_acceptable( - snap, repo_type = "model", allow_patterns = adapter_allow, ignore_patterns = None, - require_named_weights = True, - ) is True - # An adapter cache missing its weight (config only) is NOT acceptable -> guarded download. 
- snap_bad = tmp_path / "adapter_bad" - snap_bad.mkdir() - (snap_bad / "adapter_config.json").write_text("{}") - assert xf._snapshot_is_acceptable( - snap_bad, repo_type = "model", allow_patterns = adapter_allow, ignore_patterns = None, - ) is False - - # --- weights_at_root: allow = None, ignore = static skips + subdir-weight excludes --- - root_ignore = ["*.onnx", "onnx/*", "*.gguf", "checkpoint-*/*", "*/*.safetensors", "*/*.bin"] - assert hcs.request_can_include_weights(None, root_ignore) is True - rsnap = tmp_path / "root" - rsnap.mkdir() - (rsnap / "config.json").write_text("{}") - (rsnap / "model.safetensors").symlink_to(blob) # root weight present - (rsnap / "fp16").mkdir() - (rsnap / "fp16" / "model.safetensors").symlink_to(blob) # subdir weight (unread by root load) - assert xf._snapshot_is_acceptable( - rsnap, repo_type = "model", allow_patterns = None, ignore_patterns = root_ignore, - ) is True - # A subdir-only cache (no root weight) is NOT acceptable for a root load. - rsnap_bad = tmp_path / "root_bad" - (rsnap_bad / "fp16").mkdir(parents = True) - (rsnap_bad / "config.json").write_text("{}") - (rsnap_bad / "fp16" / "model.safetensors").symlink_to(blob) - assert xf._snapshot_is_acceptable( - rsnap_bad, repo_type = "model", allow_patterns = None, ignore_patterns = root_ignore, - ) is False - - -def test_basename_weight_classification_helpers(): - """Lock the catch-all-vs-weight distinction the subfolder gating rests on: a weight-stem - basename targets a weight, a config / tokenizer glob is clearly non-weight, and a bare - catch-all ('*') is NEITHER (so it defaults to weight-including under an unknown dir).""" - tw, nw = hcs._basename_targets_weight, hcs._basename_is_non_weight - assert tw("model.*") is True and nw("model.*") is False - assert tw("*.safetensors") is True and nw("*.safetensors") is False - assert tw("adapter_model.*") is True - assert tw("*.json") is False and nw("*.json") is True - assert tw("tokenizer.*") is False and nw("tokenizer.*") 
is True - assert tw("config.json") is False and nw("config.json") is True - # A catch-all matches both a weight and a non-weight representative -> neither classifier. - assert tw("*") is False and nw("*") is False - # *.bin matches a weight (pytorch_model.bin) AND a non-weight (training_args.bin) -> neither. - assert tw("*.bin") is False and nw("*.bin") is False - - def test_request_can_include_weights_string_form(): """Hugging Face accepts allow / ignore patterns as a bare string; it must be treated as one pattern, not iterated character by character (which would misclassify a subfolder @@ -2546,92 +1801,9 @@ def _boom(*a, **k): assert rec["terminated"] is True # child reaped despite the watchdog-start failure -def test_snapshot_dir_is_complete_tolerates_non_string_shard(tmp_path): - """A weight index whose ``weight_map`` carries a non-string value (malformed / arbitrary JSON) - must not crash the completeness probe: the bad entry is skipped, the string shards still - gate, so a real missing shard is still detected (Codex #829). Uses a non-numbered weight name - so only the index path -- not the numbered-shard-name expansion -- is exercised.""" - snap = tmp_path / "snap" - snap.mkdir() - blob = tmp_path / "blob" - blob.write_bytes(b"x") - (snap / "model.safetensors").symlink_to(blob) - (snap / "model.safetensors.index.json").write_text( - json.dumps( - { - "weight_map": { - "a": "model.safetensors", - "b": ["not", "a", "string"], # malformed entry -> skipped, no crash - } - } - ) - ) - # The one concrete file it names is present; the malformed entry is ignored, so no crash and - # the snapshot reads as complete (only demonstrably-missing string shards reject). - assert hcs.snapshot_dir_is_complete(snap) is True - # A genuinely missing string shard still gates, with the malformed entry still skipped. 
- (snap / "model.safetensors.index.json").write_text( - json.dumps( - { - "weight_map": { - "a": "model.safetensors", - "b": "absent-extra.safetensors", - "c": {"bad": "object"}, - } - } - ) - ) - assert hcs.snapshot_dir_is_complete(snap) is False # 'absent-extra' missing, bad entry skipped - - # --------------------------------------------------------------------------- # # Codex review round: scoped completeness, weightless named files, type preservation. # --------------------------------------------------------------------------- # -def test_snapshot_complete_ignore_only_root_excludes_checkpoint(tmp_path): - """An IGNORE-ONLY root warm (no allow_patterns, e.g. ignore=['*.onnx']) is still a bare - from_pretrained reading ROOT weights, so a snapshot whose only weight lives in a checkpoint dir - must read as INCOMPLETE rather than short-circuit the guarded download (Codex #829).""" - snap = tmp_path / "snap" - snap.mkdir() - blob = tmp_path / "blob" - blob.write_bytes(b"x") - (snap / "config.json").write_text("{}") - (snap / "checkpoint-500").mkdir() - (snap / "checkpoint-500" / "model.safetensors").symlink_to(blob) - # ignore-only -> has_patterns True but allow_patterns None: checkpoint weight must not satisfy it. - assert hcs.snapshot_dir_is_complete(snap, ignore_patterns = ["*.onnx"]) is False - # A real root weight makes the same ignore-only request complete. - (snap / "model.safetensors").symlink_to(blob) - assert hcs.snapshot_dir_is_complete(snap, ignore_patterns = ["*.onnx"]) is True - - -def test_snapshot_complete_ignores_dangling_symlink_outside_request(tmp_path): - """A dangling symlink for a file the request does NOT select must not reject the snapshot: an - allow_patterns=['adapter_model.safetensors'] probe whose adapter weight is on disk stays complete - even with a stale dangling root model.safetensors. 
A dangle for the REQUESTED file still rejects - (Codex #829).""" - snap = tmp_path / "snap" - snap.mkdir() - blob = tmp_path / "blob" - blob.write_bytes(b"x") - (snap / "adapter_model.safetensors").symlink_to(blob) - (snap / "model.safetensors").symlink_to(tmp_path / "missing-blob") # dangling, NOT requested - assert hcs.snapshot_dir_is_complete(snap, allow_patterns = ["adapter_model.safetensors"]) is True - # When the dangling file IS the requested one, the snapshot is incomplete. - assert hcs.snapshot_dir_is_complete(snap, allow_patterns = ["model.safetensors"]) is False - - -def test_request_can_include_weights_metadata_glob_is_weightless(): - """A no-slash metadata glob (tokenizer*, config*, vocab*, special_tokens*) is a FILE glob, not a - weight-bearing directory glob, so a warm that fetched only tokenizer.json is not rejected for - lacking a weight. 'model*' / 'pytorch_model*' / 'checkpoint-*' stay weight-including (Codex #829).""" - assert hcs.request_can_include_weights(allow_patterns = ["tokenizer*"]) is False - assert hcs.request_can_include_weights(allow_patterns = ["config*"]) is False - assert hcs.request_can_include_weights(allow_patterns = ["vocab*"]) is False - assert hcs.request_can_include_weights(allow_patterns = ["special_tokens*"]) is False - assert hcs.request_can_include_weights(allow_patterns = ["model*"]) is True - assert hcs.request_can_include_weights(allow_patterns = ["pytorch_model*"]) is True - assert hcs.request_can_include_weights(allow_patterns = ["checkpoint-*"]) is True def test_requested_named_files_present_exact_request(tmp_path): @@ -2654,21 +1826,6 @@ def test_requested_named_files_present_exact_request(tmp_path): ) is True -def test_snapshot_acceptable_weightless_requires_named_file(tmp_path): - """End-to-end: _snapshot_is_acceptable for a weightless exact-named request rejects a config-only - cache missing the requested tokenizer.json, so the guarded download is not skipped (Codex #829).""" - snap = tmp_path / "snap" - 
snap.mkdir() - (snap / "config.json").write_text("{}") - assert xf._snapshot_is_acceptable( - snap, repo_type = "model", allow_patterns = ["tokenizer.json"], ignore_patterns = None - ) is False - (snap / "tokenizer.json").write_text("{}") - assert xf._snapshot_is_acceptable( - snap, repo_type = "model", allow_patterns = ["tokenizer.json"], ignore_patterns = None - ) is True - - def test_deterministic_oserror_type_preserved(monkeypatch): """A deterministic disk-full OSError is re-raised as OSError (not flattened to RuntimeError), so a caller's `except OSError` cleanup still runs across the spawn boundary (Codex #829).""" @@ -2736,75 +1893,6 @@ def __init__(self, message, *, response): # --------------------------------------------------------------------------- # # Codex round: dir/ wildcard, logical-weight grouping post-download, errno preservation. # --------------------------------------------------------------------------- # -def test_dir_pattern_treated_as_wildcard(tmp_path): - """A trailing-slash directory allow pattern (unet/) is a wildcard -- HF's filter_repo_objects - expands it to unet/* -- not an exact filename, so the strict named-file checks must not reject a - fully cached component directory by looking for a literal 'unet/' entry (Codex #829).""" - assert hcs._has_glob("unet/") is True - assert hcs._has_glob("checkpoint-10/") is True - assert hcs._has_glob("config.json") is False - snap = tmp_path / "snap" - snap.mkdir() - (snap / "unet").mkdir() - blob = tmp_path / "blob" - blob.write_bytes(b"x") - (snap / "unet" / "diffusion_pytorch_model.safetensors").symlink_to(blob) - (snap / "unet" / "config.json").write_text("{}") - # A dir/ pattern is best-effort (glob), never an exact-name requirement. - assert hcs.requested_named_files_present(snap, allow_patterns = ["unet/"]) is True - # The component-dir weight satisfies the request; not rejected for a literal 'unet/'. 
- assert hcs.snapshot_dir_is_complete(snap, allow_patterns = ["unet/"]) is True - - -def test_weight_logical_key_groups_formats(): - """Format / shard variants of one logical weight share a key; different stems or subdirs do - not, so an either-format list is one group while base + adapter are two (Codex #829).""" - k = hcs._weight_logical_key - assert k("pytorch_model.bin") == k("model.safetensors") # either-format -> 1 group - assert k("model-00001-of-00002.safetensors") == k("model.safetensors") # shard -> same group - assert k("model.safetensors") != k("adapter_model.safetensors") # base vs adapter - assert k("unet/model.safetensors") != k("vae/model.safetensors") # different subdirs - - -def test_post_download_rejects_stale_named_weight(hf_cache, monkeypatch): - """A finished child snapshot missing an explicitly named LOGICAL weight (base present, adapter - missing) is now treated as incomplete post-download and retried over HTTP, instead of being - returned with the adapter still missing (Codex #829).""" - blobs = _blobs_dir(hf_cache, DL_REPO) - base_only = blobs.parent / "snapshots" / "xet" - base_only.mkdir(parents = True) - w = blobs / "w" - w.write_bytes(b"x") - (base_only / "model.safetensors").symlink_to(w) # base only; adapter missing - complete = blobs.parent / "snapshots" / "http" - complete.mkdir(parents = True) - (complete / "model.safetensors").symlink_to(w) - (complete / "adapter_model.safetensors").symlink_to(w) - fake = _install(monkeypatch, [("ok", str(base_only)), ("ok", str(complete))]) - out = xf.snapshot_download_with_xet_fallback( - DL_REPO, token = None, force_download = True, - allow_patterns = ["model.safetensors", "adapter_model.safetensors"], - ) - assert out == str(complete) - assert [c.disable_xet for c in fake.calls] == [False, True] # the stale base-only result retried - - -def test_post_download_either_format_still_accepted(hf_cache, monkeypatch): - """An either-format list against a single-format child snapshot stays accepted 
post-download - (the formats group to one logical weight), so require_named_weights does not error-forever on a - format the repo never ships (Codex #829).""" - blobs = _blobs_dir(hf_cache, DL_REPO) - child = blobs.parent / "snapshots" / "only-st" - child.mkdir(parents = True) - w = blobs / "w" - w.write_bytes(b"x") - (child / "model.safetensors").symlink_to(w) # safetensors only; no pytorch_model.bin - fake = _install(monkeypatch, [("ok", str(child))]) - out = xf.snapshot_download_with_xet_fallback( - DL_REPO, token = None, force_download = True, - allow_patterns = ["pytorch_model.bin", "model.safetensors"], - ) - assert out == str(child) and len(fake.calls) == 1 def test_parse_errno(): @@ -2826,113 +1914,8 @@ def test_oserror_errno_preserved(monkeypatch): # --------------------------------------------------------------------------- # -# Codex round: weightless broken-symlink scoping + processor metadata glob. +# Spawn-safety regressions: failed-spawn queue cleanup + disable-Xet env-race lock. # --------------------------------------------------------------------------- # -def test_weightless_broken_symlink_scoped_to_request(tmp_path): - """A weightless request (allow=['config.json']) must accept a snapshot whose config is present - even when an EXCLUDED weight left a dangling symlink from an earlier interrupted pull -- only a - dangling REQUESTED file rejects it (Codex #829).""" - snap = tmp_path / "snap" - snap.mkdir() - (snap / "config.json").write_text("{}") - (snap / "model.safetensors").symlink_to(tmp_path / "missing-blob") # dangling, NOT requested - assert hcs.snapshot_has_requested_broken_symlinks(snap, allow_patterns = ["config.json"]) is False - assert xf._snapshot_is_acceptable( - snap, repo_type = "model", allow_patterns = ["config.json"], ignore_patterns = None - ) is True - # A dangling REQUESTED file does reject. 
- (snap / "config.json").unlink() - (snap / "config.json").symlink_to(tmp_path / "missing-cfg") - assert hcs.snapshot_has_requested_broken_symlinks(snap, allow_patterns = ["config.json"]) is True - assert xf._snapshot_is_acceptable( - snap, repo_type = "model", allow_patterns = ["config.json"], ignore_patterns = None - ) is False - - -def test_processor_glob_is_weightless(): - """A processor-only warm (allow=['processor*']) selects processor_config.json and no weight, so - it must read as weightless rather than be rejected for lacking weights (Codex #829).""" - assert hcs._basename_is_non_weight("processor*") is True - assert hcs.request_can_include_weights(allow_patterns = ["processor*"]) is False - # control: a real weight glob stays weight-including - assert hcs.request_can_include_weights(allow_patterns = ["model*"]) is True - - -# --------------------------------------------------------------------------- -# Regression tests for the 3-reviewer (Opus) review round on #829. -# Each maps to one accepted finding; the rejected "variant shard after the count" -# finding is covered (and disproven) by test_snapshot_dir_is_complete_variant_sharded_index. -# --------------------------------------------------------------------------- -def test_request_can_include_weights_subfolder_variant_component_glob(): - """R3-4: a path-qualified variant component glob (["unet/diffusion_pytorch_model.fp16*"], - ["text_encoder/model.fp16*"]) selects a real variant weight inside a pipeline subfolder. It does - not end in a weight suffix and the re-rooted canonical probes (unet/model.safetensors, ...) do not - match the variant name, so without a path-qualified self-probe it would read as WEIGHTLESS and the - fast path could accept a config-only component snapshot -> the silent Xet hang (accept-stale). It - must read as weight-including. 
A non-weight subfolder glob (unet/config*) stays weightless.""" - assert hcs.request_can_include_weights(["unet/diffusion_pytorch_model.fp16*"], None) is True - assert hcs.request_can_include_weights(["text_encoder/model.fp16*"], None) is True - assert hcs.request_can_include_weights(["transformer/diffusion_pytorch_model.bf16*"], None) is True - # A non-weight subfolder glob is not over-classified. - assert hcs.request_can_include_weights(["unet/config*"], None) is False - assert hcs.request_can_include_weights(["text_encoder/tokenizer*"], None) is False - # The self-probe re-roots the synthetic weight under the requested subfolder. - assert hcs._weight_self_probe("unet/diffusion_pytorch_model.fp16*") == \ - "unet/diffusion_pytorch_model.fp16.safetensors" - - -def test_request_can_include_weights_index_glob_weightless_weight_glob_kept(): - """R3-2: a trailing-wildcard shard-index glob (["model.safetensors.index*"], - ["model.bin.index*"]) selects only the index sidecar (no weight), so it reads as WEIGHTLESS -- the - synthetic-suffix branch must not turn the .index stem into a fake model.safetensors.index.safetensors - weight. 
A plain weight-stem glob (["model.safetensors*"], ["pytorch_model.bin*"]) still reads as - weight-including via the canonical weight probe (no accept-stale).""" - assert hcs.request_can_include_weights(["model.safetensors.index*"], None) is False - assert hcs.request_can_include_weights(["model.bin.index*"], None) is False - assert hcs.request_can_include_weights(["model.safetensors*"], None) is True - assert hcs.request_can_include_weights(["pytorch_model.bin*"], None) is True - - -def test_weight_self_probe_artifact_and_sidecar_stems(): - """R3-3: _weight_self_probe must not synthesize a fake weight for a trailing-wildcard glob whose - stem is a trainer artifact (scheduler*, rng_state*, scaler*, optimizer*) or an index / weight - sidecar (model.safetensors.index*, model.safetensors*) -- the absorbed-suffix branch would - otherwise return scheduler.safetensors / model.safetensors.index.safetensors and make the request - require a weight that does not exist (over-reject). A genuine variant weight stem still resolves.""" - for artifact in ("scheduler*", "rng_state*", "rng_state_*", "scaler*", "optimizer*", "training_args*"): - assert hcs._weight_self_probe(artifact) is None, artifact - for sidecar in ("model.safetensors.index*", "model.bin.index*", "model.safetensors*", - "pytorch_model.bin*"): - assert hcs._weight_self_probe(sidecar) is None, sidecar - # A real variant weight stem still resolves to its concrete weight name (the suffix loop tries - # .safetensors first, so a stem that admits either suffix resolves to the safetensors form). 
- assert hcs._weight_self_probe("model.fp16*") == "model.fp16.safetensors" - assert hcs._weight_self_probe("pytorch_model.fp16*") == "pytorch_model.fp16.safetensors" - - -def test_snapshot_dir_is_complete_diffusion_sharded_component_requires_index(tmp_path): - """R1: a diffusers pipeline whose weight-bearing component (transformer/, unet/) holds a complete - NUMBERED-shard set but no index sidecar is INCOMPLETE -- transformers cannot load a local sharded - component without its index, so reporting complete would warm a cache the in-process load then - re-fetches (the silent Xet hang). Adding the component index makes it complete.""" - snap = tmp_path / "snap" - snap.mkdir() - blob = tmp_path / "blob" - blob.write_bytes(b"w") - (snap / "model_index.json").write_text( - json.dumps({"_class_name": "FluxPipeline", - "transformer": ["diffusers", "FluxTransformer2DModel"]}) - ) - comp = _make_diffusion_component(snap, blob, "transformer") - (comp / "diffusion_pytorch_model-00001-of-00002.safetensors").symlink_to(blob) - (comp / "diffusion_pytorch_model-00002-of-00002.safetensors").symlink_to(blob) - # Complete shard set, NO index -> incomplete. - assert hcs.snapshot_dir_is_complete(snap) is False - (comp / "diffusion_pytorch_model.safetensors.index.json").write_text( - json.dumps({"weight_map": {"a": "diffusion_pytorch_model-00001-of-00002.safetensors", - "b": "diffusion_pytorch_model-00002-of-00002.safetensors"}}) - ) - assert hcs.snapshot_dir_is_complete(snap) is True def test_failed_spawn_closes_result_queue(monkeypatch): @@ -3018,3 +2001,196 @@ def _spy(): ) assert out == "/tmp/warm" assert seen.get("held") is True + + +# --------------------------------------------------------------------------- # +# Conservative fast-path gate + pre/post-download acceptance split (PR #829 trim). +# The completeness gate is intentionally narrow: it fast-paths ONLY the unambiguous +# canonical model cache, deferring everything else to the watched snapshot_download +# child. 
The pre-download (skip the child?) and post-download (accept the result?) +# checks are deliberately asymmetric -- strict pre, lenient post. +# --------------------------------------------------------------------------- # +def _mk_snapshot(tmp_path, name): + blob = tmp_path / "_blob" + if not blob.exists(): + blob.write_bytes(b"w") + snap = tmp_path / name + snap.mkdir() + return snap, blob + + +def test_gate_fast_paths_canonical_single_file(tmp_path): + """A complete, unpatterned single-file model cache is fast-path eligible (skip the child).""" + snap, blob = _mk_snapshot(tmp_path, "single") + (snap / "model.safetensors").symlink_to(blob) + (snap / "config.json").write_text("{}") + assert hcs.snapshot_dir_is_complete(snap) is True + + +def test_gate_fast_paths_canonical_sharded_with_index(tmp_path): + """A complete sharded model with its index sidecar is fast-path eligible; without the index, or + with a listed shard missing, it is not (transformers cannot load a local sharded checkpoint + without a complete index).""" + snap, blob = _mk_snapshot(tmp_path, "shard") + (snap / "model-00001-of-00002.safetensors").symlink_to(blob) + (snap / "model-00002-of-00002.safetensors").symlink_to(blob) + assert hcs.snapshot_dir_is_complete(snap) is False # numbered shards, no index + (snap / "model.safetensors.index.json").write_text( + json.dumps({"weight_map": {"a": "model-00001-of-00002.safetensors", + "b": "model-00002-of-00002.safetensors"}})) + assert hcs.snapshot_dir_is_complete(snap) is True + snap2, _ = _mk_snapshot(tmp_path, "shard2") + (snap2 / "model-00001-of-00002.safetensors").symlink_to(blob) # shard 2 absent + (snap2 / "model.safetensors.index.json").write_text( + json.dumps({"weight_map": {"a": "model-00001-of-00002.safetensors", + "b": "model-00002-of-00002.safetensors"}})) + assert hcs.snapshot_dir_is_complete(snap2) is False + + +def test_gate_rejects_config_only(tmp_path): + snap, _ = _mk_snapshot(tmp_path, "cfg") + (snap / "config.json").write_text("{}") 
+ assert hcs.snapshot_dir_is_complete(snap) is False + + +def test_gate_rejects_diffusers_marker(tmp_path): + """A diffusers pipeline (root model_index.json) is never fast-pathed -> defer to the child, + even when a root-level weight happens to be present.""" + snap, blob = _mk_snapshot(tmp_path, "diff") + (snap / "model_index.json").write_text("{}") + (snap / "model.safetensors").symlink_to(blob) + assert hcs.snapshot_dir_is_complete(snap) is False + + +def test_gate_rejects_any_allow_pattern(tmp_path): + """Any allow_patterns makes the request non-trivial -> defer to the child (no fast-path).""" + snap, blob = _mk_snapshot(tmp_path, "pat") + (snap / "model.safetensors").symlink_to(blob) + assert hcs.snapshot_dir_is_complete(snap, allow_patterns = ["*.safetensors"]) is False + assert hcs.snapshot_dir_is_complete(snap, allow_patterns = ["model.safetensors"]) is False + + +def test_gate_eligible_under_ignore_patterns(tmp_path): + """allow=None with ANY ignore patterns stays fast-path eligible: the canonical-weight presence + check verifies the surviving root weight is on disk, so ignores that drop other formats cannot + make an incomplete cache read complete. This covers the common bare from_pretrained warm, whose + real ignore list mixes root-level format excludes (*.onnx, *.gguf, *.pt, *.bin) with subdir + (*/*.safetensors) drops.""" + snap, blob = _mk_snapshot(tmp_path, "ign") + (snap / "model.safetensors").symlink_to(blob) + assert hcs.snapshot_dir_is_complete(snap, ignore_patterns = ["*/*.safetensors", "*/*.bin"]) is True + assert hcs.snapshot_dir_is_complete(snap, ignore_patterns = ["*.onnx"]) is True + # The actual unsloth bare-from_pretrained combined ignore list (root-level format excludes + + # subdir-weight drops) -- the warm root model.safetensors must still fast-path. 
+ unsloth_ignore = [ + "*.onnx", "*.h5", "*.msgpack", "*.tflite", "*.mlmodel", "*.gguf", "*.pt", "*.pth", + "*.ckpt", "optimizer.*", "scheduler.*", "rng_state*", "trainer_state.json", + "events.out.tfevents*", "*.bin", + "*/*.safetensors", "*/*.bin", "*/*.h5", "*/*.msgpack", "*/*.pt", "*/*.pth", + ] + assert hcs.snapshot_dir_is_complete(snap, ignore_patterns = unsloth_ignore) is True + + +def test_gate_rejects_broken_symlink(tmp_path): + snap, _ = _mk_snapshot(tmp_path, "broken") + (snap / "model.safetensors").symlink_to(tmp_path / "_missing") + assert hcs.snapshot_dir_is_complete(snap) is False + + +def test_request_can_include_weights_trim_semantics(): + r = hcs.request_can_include_weights + assert r(None, None) is True # bare unpatterned + assert r(None, ["*/*.safetensors", "*/*.bin"]) is True # subdir ignore (common bare) + assert r(["*.safetensors"], None) is True # weight glob + assert r(["model.fp16.safetensors"], None) is True # variant exact weight + assert r(["unet/*"], None) is True # subfolder weight glob + assert r(["model.gguf"], None) is True # gguf is a weight + assert r(["config.json", "tokenizer.json", "*.py"], None) is False # tokenizer-only + assert r(["adapter_model*", "adapter_config.json"], None) is True # adapter + assert r([], None) is False # empty allow selects nothing + + +def test_request_can_include_weights_partial_ignore_strip_is_weight_bearing(): + """An ignore-only request is weightless ONLY when it strips EVERY weight format. 
A partial strip + -- only the canonical model.safetensors/pytorch_model.bin names while a variant survives, or only + some suffixes while a .pt / .gguf weight survives -- must read as weight-bearing, so the fast path + requires a real weight and never skips the protective child on a config-only cache (the Xet hang).""" + r = hcs.request_can_include_weights + assert r(None, ["model.safetensors", "pytorch_model.bin"]) is True # variant / other-format survives + assert r(None, ["*.safetensors", "*.bin"]) is True # .pt / .gguf / .pth / ... survive + # Only stripping EVERY weight format is weightless. + all_formats = ["*.safetensors", "*.bin", "*.pt", "*.pth", "*.gguf", "*.ckpt", + "*.onnx", "*.msgpack", "*.h5", "*.pdparams"] + assert r(None, all_formats) is False + + +def test_pre_download_skips_complete_model(tmp_path): + snap, blob = _mk_snapshot(tmp_path, "m") + (snap / "model.safetensors").symlink_to(blob) + assert xf._cache_can_skip_download( + snap, repo_type = "model", allow_patterns = None, ignore_patterns = None) is True + + +def test_pre_download_does_not_skip_diffusers_but_post_accepts(tmp_path): + """The pre/post asymmetry: a diffusers warm is NOT fast-pathed (spawn the child), but the same + complete diffusers result IS accepted post-download (it has component weights), so a good + download is never re-looped into a stall error.""" + snap, blob = _mk_snapshot(tmp_path, "diff") + (snap / "model_index.json").write_text("{}") + comp = snap / "unet" + comp.mkdir() + (comp / "diffusion_pytorch_model.safetensors").symlink_to(blob) + assert xf._cache_can_skip_download( + snap, repo_type = "model", allow_patterns = None, ignore_patterns = None) is False + assert xf._download_result_usable( + snap, repo_type = "model", allow_patterns = None, ignore_patterns = None) is True + + +def test_post_download_rejects_config_only_model(tmp_path): + """A model warm that came back with NO weight (HF handed back a stale config-only snapshot) is + rejected post-download and 
retried over HTTP.""" + snap, _ = _mk_snapshot(tmp_path, "cfg") + (snap / "config.json").write_text("{}") + assert xf._download_result_usable( + snap, repo_type = "model", allow_patterns = None, ignore_patterns = None) is False + + +def test_post_download_accepts_dataset_without_weight(tmp_path): + snap, blob = _mk_snapshot(tmp_path, "ds") + (snap / "data.parquet").symlink_to(blob) + assert xf._download_result_usable( + snap, repo_type = "dataset", allow_patterns = None, ignore_patterns = None) is True + + +def test_post_download_accepts_either_format_single_present(tmp_path): + """An either-format named request (['pytorch_model.bin','model.safetensors']) against a repo that + ships only safetensors: the finished download has a weight, so it is accepted -- not re-looped + for the absent .bin the repo simply does not publish.""" + snap, blob = _mk_snapshot(tmp_path, "either") + (snap / "model.safetensors").symlink_to(blob) + assert xf._download_result_usable( + snap, repo_type = "model", + allow_patterns = ["pytorch_model.bin", "model.safetensors"], ignore_patterns = None) is True + + +def test_pre_download_skips_intact_tokenizer_only(tmp_path): + """A tokenizer-only (weightless) warm short-circuits offline: an intact requested subset is + enough, no weight required.""" + snap, _ = _mk_snapshot(tmp_path, "tok") + (snap / "tokenizer.json").write_text("{}") + (snap / "config.json").write_text("{}") + assert xf._cache_can_skip_download( + snap, repo_type = "model", + allow_patterns = ["tokenizer.json", "config.json"], ignore_patterns = None) is True + + +def test_pre_download_partial_ignore_does_not_skip_config_only(tmp_path): + """Over-accept guard (safety reviewer finding): an ignore-only request stripping only SOME weight + formats (ignore=['*.safetensors','*.bin']) on a config-only cache must NOT skip the child -- a + repo whose surviving weight is e.g. 
model.gguf / model.fp16.safetensors / a .pt checkpoint would + otherwise be fetched in-process over un-killable Xet (the hang).""" + snap, _ = _mk_snapshot(tmp_path, "cfgign") + (snap / "config.json").write_text("{}") + assert xf._cache_can_skip_download( + snap, repo_type = "model", allow_patterns = None, + ignore_patterns = ["*.safetensors", "*.bin"]) is False diff --git a/unsloth_zoo/hf_cache_state.py b/unsloth_zoo/hf_cache_state.py index 25f743252..13d83dc13 100644 --- a/unsloth_zoo/hf_cache_state.py +++ b/unsloth_zoo/hf_cache_state.py @@ -16,6 +16,16 @@ whether an ``.incomplete`` partial is present. The no-progress download watchdog is built on exactly these two signals. +The completeness check here is intentionally a CONSERVATIVE fast-path gate, not an +authoritative snapshot verifier. It returns "complete" only for the unambiguous +canonical model-cache layouts whose local evidence proves an in-process load will +not fetch a weight. Everything else (diffusers pipelines, weight variants, +non-trivial allow/ignore patterns, datasets, any layout needing inference) returns +"not complete" so the caller runs the authoritative Hugging Face download/resume in +the watched child. Returning a false "complete" is the only dangerous error (it can +send an in-process load to fetch a missing weight over un-killable Xet); returning a +false "not complete" only spawns the cheap watched child, so the gate errs that way. + Only the single active cache root (``huggingface_hub.constants.HF_HUB_CACHE``) is scanned here; multi-root / legacy-cache enumeration and transport-marker logic are download-manager concerns that live in the consumer, not in this module. 
@@ -24,7 +34,6 @@ from __future__ import annotations import fnmatch -import re import sys from pathlib import Path from typing import Iterator, Optional @@ -200,6 +209,10 @@ def snapshot_dir_has_broken_symlinks(snapshot_dir: Path) -> bool: return False +# --------------------------------------------------------------------------- +# Weight-file recognition (small helpers the conservative completeness gate needs) +# --------------------------------------------------------------------------- + # Model weight file extensions. A snapshot with none of these is config/tokenizer # only (e.g. a prior AutoConfig fetch), so it is not a warm cache for a weight load. _WEIGHT_FILE_SUFFIXES = ( @@ -217,8 +230,7 @@ def snapshot_dir_has_broken_symlinks(snapshot_dir: Path) -> bool: # Trainer / optimizer state files carry weight suffixes (.bin / .pt / .pth) but are NOT # loadable model weights. A checkpoint dir or a patterned pull can leave only these behind, -# so they must not satisfy the "snapshot holds its weights" check (which would skip the -# killable download while from_pretrained still lacks real weights). +# so they must not count as a model weight on disk. _NON_WEIGHT_BASENAMES = frozenset({ "training_args.bin", "optimizer.bin", @@ -231,11 +243,6 @@ def snapshot_dir_has_broken_symlinks(snapshot_dir: Path) -> bool: }) # Distributed trainer runs shard the RNG state as rng_state_0.pth, rng_state_1.pth, ... _NON_WEIGHT_BASENAME_PREFIXES = ("rng_state_",) -# The stems (basename without the suffix) of the trainer-artifact names above -- optimizer, -# scheduler, scaler, rng_state, training_args. A trailing-wildcard glob over one of these -# (``scheduler*``, ``rng_state*``) selects only a trainer artifact, so the synthetic-weight-suffix -# probe must NOT classify it as weight-including. 
-_NON_WEIGHT_STEMS = frozenset(name.rsplit(".", 1)[0] for name in _NON_WEIGHT_BASENAMES) def _is_loadable_weight_file(name: str) -> bool: @@ -253,87 +260,6 @@ def _is_loadable_weight_file(name: str) -> bool: return True -# Numbered shard naming, e.g. ``model-00001-of-00002.safetensors`` or -# ``pytorch_model-00003-of-00004.bin``: prefix, 1-based index, total, suffix. -_NUMBERED_SHARD_RE = re.compile( - r"^(?P.+)-(?P\d+)-of-(?P\d+)(?P\.[^.]+)$" -) - - -def _filter_paths( - paths: list, - allow_patterns: "Optional[list]" = None, - ignore_patterns: "Optional[list]" = None, -) -> list: - """Filter repo-relative *paths* by Hugging Face allow / ignore patterns, mirroring how - ``snapshot_download`` selects files. On any failure, treat all paths as selected so a - snapshot that does hold weights is never rejected for an unevaluable filter.""" - try: - from huggingface_hub.utils import filter_repo_objects - - return list( - filter_repo_objects( - paths, allow_patterns = allow_patterns, ignore_patterns = ignore_patterns - ) - ) - except Exception: - return list(paths) - - -def _numbered_shard_set_present( - entry: Path, - *, - snapshot_dir: "Optional[Path]" = None, - allow_patterns: "Optional[list]" = None, - ignore_patterns: "Optional[list]" = None, -) -> bool: - """For a numbered weight shard (``model-00001-of-00002.safetensors``), True only when - every shard in its ``-of-NNNNN`` set that the request selects is present in the same - directory. - - A leftover single shard from an interrupted multi-shard download reads as a weight - file on its own, so without this an incomplete pull (one shard on disk, the rest - never fetched) would short-circuit as a warm cache. This catches that even when the - shard *index* sidecar was never cached (so ``_weight_shard_index_complete`` has - nothing to check). A non-numbered / single-file weight matches no shard pattern and - is trivially satisfied. 
- - When *allow_patterns* / *ignore_patterns* are given, a sibling shard is required only - if the request actually selects it: a deliberate single-shard request - (``allow_patterns=["model-00002-of-00005.safetensors"]``) is satisfied by that one shard - and must not demand the rest.""" - match = _NUMBERED_SHARD_RE.match(entry.name) - if match is None: - return True - total_str = match.group("total") - try: - total = int(total_str) - except ValueError: - return True - if total <= 0: - return True - prefix = match.group("prefix") - suffix = match.group("suffix") - width = len(total_str) - base = entry.parent - scoped = bool(allow_patterns or ignore_patterns) and snapshot_dir is not None - for i in range(1, total + 1): - shard_path = base / f"{prefix}-{i:0{width}d}-of-{total_str}{suffix}" - if scoped: - try: - rel = shard_path.relative_to(snapshot_dir).as_posix() - except ValueError: - rel = shard_path.name - if not _filter_paths([rel], allow_patterns, ignore_patterns): - continue # this sibling is not part of the request -> do not require it - try: - if not shard_path.exists(): - return False - except OSError: - return False - return True - - def _is_weight_shard_index(name: str) -> bool: """True if *name* is a weight-shard index sidecar: the canonical ``model.safetensors.index.json`` / ``pytorch_model.bin.index.json`` AND the variant form @@ -372,150 +298,56 @@ def _weight_shard_index_complete(index_path: Path) -> bool: return True -def _has_root_shard_index(snapshot_dir: Path, entries: list) -> bool: - """True if a root-level weight-shard index sidecar is present on disk. 
Matches the canonical - ``model.safetensors.index.json`` / ``pytorch_model.bin.index.json`` AND the variant form - ``model.safetensors.index.fp16.json`` -- transformers' ``_add_variant`` inserts the variant - token before the trailing ``.json``, so a plain ``*.index.json`` suffix test would miss it and - a variant sharded checkpoint (whose shards ARE on disk and whose variant index IS present) would - be wrongly judged index-less. A subfolder index does not count -- a bare root load never reads - it. *entries* is the already-collected ``rglob`` listing, reused to avoid a second walk.""" - for entry in entries: - if not _is_weight_shard_index(entry.name): - continue - try: - rel = entry.relative_to(snapshot_dir).as_posix() - except ValueError: - rel = entry.name - if "/" in rel: - continue # a subfolder index the bare root load never reads - if _safe_is_file(entry): - return True - return False +# --------------------------------------------------------------------------- +# Pattern helpers (kept small: normalization + glob detection + HF filtering) +# --------------------------------------------------------------------------- - -# Diffusers pipeline subfolders that carry loadable WEIGHTS (every other declared component -- -# scheduler, tokenizer, feature_extractor, processor -- is config-only). A weight-bearing -# component whose subfolder exists but holds no weight is a partially fetched component, so the -# in-process pipeline load would still fetch the weight in-process over Xet. -_WEIGHT_BEARING_PIPELINE_DIRS = frozenset({ - "unet", - "transformer", - "vae", - "vqvae", - "movq", - "prior", - "decoder", - "text_encoder", - "text_encoder_2", - "text_encoder_3", - "image_encoder", - "safety_checker", - "controlnet", -}) +_GLOB_CHARS = ("*", "?", "[") -def _dir_has_any_file(path: Path) -> bool: - """True if *path* contains at least one regular file (recursively). 
A dangling symlink left by - an interrupted blob fetch is NOT a regular file (``is_file()`` follows the link and returns - False), so a component subfolder that only ever received pointer symlinks reads as having no - files -- i.e. as an unfinished component.""" - try: - for entry in path.rglob("*"): - if _safe_is_file(entry): - return True - except OSError: - return False - return False +def _has_glob(text: str) -> bool: + # A trailing-slash directory pattern ("unet/", "checkpoint-10/") is NOT an exact filename: + # Hugging Face's filter_repo_objects expands it to match everything under that directory (as + # if "unet/*"). Treat it as a wildcard so the strict exact-name checks do not look for a + # literal "unet/" entry and wrongly reject a fully cached directory / component download. + return text.endswith("/") or any(ch in text for ch in _GLOB_CHARS) -def _diffusion_pipeline_complete(snapshot_dir: Path, weight_dirs: set) -> bool: - """True unless a diffusers pipeline snapshot is missing a declared sub-model. A diffusers - pipeline lists its components in a root ``model_index.json`` where each non-``_`` key maps to a - ``[library, class]`` pair; a warm killed mid-pipeline can leave one component fully cached and - another entirely absent, and the in-process pipeline load would then fetch the missing - component over unprotected Xet (the silent-hang risk). Require every declared (non-null) - component's subfolder to exist with files, and every weight-bearing component - (unet / transformer / vae / text_encoder / ...) to carry a weight in *weight_dirs*. +def _as_pattern_list(patterns: "Optional[object]") -> "Optional[list]": + """Normalize an allow / ignore pattern argument to a list. 
Hugging Face accepts a bare + ``str`` as well as a list, and iterating the ``str`` form would walk it character by + character (so ``"checkpoint-10/*"`` would never match), misclassifying the request.""" + if patterns is None: + return None + if isinstance(patterns, str): + return [patterns] + return list(patterns) - Returns True (do not block) when there is no readable ``model_index.json`` -- a plain - transformers / GGUF snapshot, or a non-diffusion repo -- so only an actual pipeline warm is - affected. Intended for a FULL pipeline warm (no allow_patterns); a scoped subfolder request is - already validated by its own selection.""" - import json - index_path = snapshot_dir / "model_index.json" - if not _safe_is_file(index_path): - return True # not a diffusers pipeline (or an older layout) -- nothing pipeline-specific +def _filter_paths( + paths: list, + allow_patterns: "Optional[list]" = None, + ignore_patterns: "Optional[list]" = None, +) -> list: + """Filter repo-relative *paths* by Hugging Face allow / ignore patterns, mirroring how + ``snapshot_download`` selects files. On any failure, treat all paths as selected so a + snapshot that does hold weights is never rejected for an unevaluable filter.""" try: - with open(index_path, "r", encoding = "utf-8") as f: - data = json.load(f) - except (OSError, ValueError): - return True # unreadable index: defer to the generic checks rather than over-reject - if not isinstance(data, dict): - return True - for key, value in data.items(): - if isinstance(key, str) and key.startswith("_"): - continue # _class_name / _diffusers_version metadata, not a component - if not (isinstance(value, (list, tuple)) and len(value) == 2): - continue # not a [library, class] component spec - library, class_name = value - if library is None or class_name is None: - continue # an explicitly absent component (e.g. 
a disabled safety_checker) - component_dir = snapshot_dir / key - if not _safe_is_dir(component_dir) or not _dir_has_any_file(component_dir): - return False # a declared component's subfolder is missing / empty -- interrupted warm - if key in _WEIGHT_BEARING_PIPELINE_DIRS: - if key not in weight_dirs: - return False # the component dir exists but carries no weight -- partial component - # A SHARDED component is loadable locally only via its in-component index sidecar - # (diffusers, like transformers, never globs shard files), so a component holding - # numbered shards but no diffusion_pytorch_model.safetensors.index.json is incomplete -- - # an interrupted warm that dropped the tiny index blob would make the pipeline load fetch - # it in-process over unprotected Xet. Mirrors the root transformers index requirement. - try: - comp_entries = list(component_dir.rglob("*")) - except OSError: - comp_entries = [] - has_numbered_shard = any( - _NUMBERED_SHARD_RE.match(e.name) is not None - and _is_loadable_weight_file(e.name) - and _safe_is_file(e) - for e in comp_entries - ) - if has_numbered_shard and not _has_root_shard_index(component_dir, comp_entries): - return False - return True - + from huggingface_hub.utils import filter_repo_objects -def _has_pipeline_component_weight(snapshot_dir: Path, entries: list) -> bool: - """True if a recognized diffusers component subfolder (``unet/``, ``vae/``, ``text_encoder/``, - ...) holds a loadable weight. A weights-only warm (``allow_patterns=["*.safetensors"]``, or an - ignore list that drops ``model_index.json``) downloads a diffusers pipeline's component weights - but NOT ``model_index.json`` -- so the pipeline layout has to be recognized from the component - weights themselves, else the snapshot is misread as a non-pipeline root load and its (legitimate) - subfolder weights are dropped. 
Kept to the known weight-bearing component names so an arbitrary - precision/checkpoint subfolder (``BF16/``, ``checkpoint-10/``) is NOT mistaken for a pipeline.""" - for entry in entries: - try: - rel = entry.relative_to(snapshot_dir).as_posix() - except ValueError: - continue - parts = rel.split("/") - if ( - len(parts) >= 2 - and parts[0] in _WEIGHT_BEARING_PIPELINE_DIRS - and _is_loadable_weight_file(parts[-1]) - and _safe_is_file(entry) - ): - return True - return False + return list( + filter_repo_objects( + paths, allow_patterns = allow_patterns, ignore_patterns = ignore_patterns + ) + ) + except Exception: + return list(paths) def _broken_symlink_rel_paths(snapshot_dir: Path) -> list: """Repo-relative posix paths of every dangling symlink in *snapshot_dir* -- a referenced file whose blob is missing or still an ``.incomplete`` partial (an interrupted download). Empty when - none. Lets a completeness check scope the interrupted-download signal to the files a request + none. Lets the broken-symlink check scope the interrupted-download signal to the files a request actually selects, rather than rejecting the whole snapshot for a dangle outside the request.""" out: list = [] try: @@ -533,69 +365,6 @@ def _broken_symlink_rel_paths(snapshot_dir: Path) -> list: return out -# Catch-all allow patterns that select the WHOLE repo, exactly like an unpatterned warm. HF's -# fnmatch ``*`` spans ``/``, so a bare ``*`` (or ``**``) matches every path including checkpoint -# subdirs -- but a root ``from_pretrained`` still reads ROOT weights, so such a request must be -# treated like an unpatterned root warm (drop checkpoint-dir paths), not trusted as a deliberate -# checkpoint selection. -_CATCHALL_ALLOW_PATTERNS = frozenset({"*", "**"}) - - -def _is_pure_catchall(allow_patterns: "Optional[list]") -> bool: - """True when *allow_patterns* is a non-empty list whose every entry is a bare catch-all - (``*`` / ``**``). 
Such a list selects the whole repo just like an unpatterned warm, so a root - load still reads ROOT weights and a checkpoint-dir-only cache must not satisfy it. A list with - any path-bearing or name-specific pattern (``checkpoint-10/*``, ``model.safetensors``) is - trusted as-is -- a caller that names a checkpoint path opts back into it.""" - if not allow_patterns: - return False - return all(isinstance(p, str) and p.strip() in _CATCHALL_ALLOW_PATTERNS for p in allow_patterns) - - -def _targets_root_only(allow_patterns: "Optional[list]") -> bool: - """True when a request reads from the repo ROOT rather than a specific subfolder: there is no - ``allow_patterns`` at all, or every allow pattern is a no-slash name / glob (``*``, - ``*.safetensors``, ``model*``, ``config.json``). HF's fnmatch ``*`` spans ``/``, so such a glob - also matches nested ``subdir/...`` files, but the LOAD a bare ``from_pretrained`` performs reads - only root-level files. A path-bearing pattern (``checkpoint-10/*``, ``BF16/*``, ``unet/*``) - deliberately targets a subfolder and is trusted as-is.""" - if allow_patterns is None: - return True - if not allow_patterns: - return False # allow_patterns=[] selects nothing -- a scoped (empty) request, not a root warm - return all(isinstance(p, str) and "/" not in p for p in allow_patterns) - - -def _requested_scope_filter( - rels: list, - allow_patterns: "Optional[list]", - ignore_patterns: "Optional[list]", - *, - root_weights_only: bool = False, -) -> list: - """The subset of repo-relative *rels* a request selects. 
Applies the allow / ignore filter, then - drops paths a root load never reads: - - * *root_weights_only* (a root-level warm of a NON-pipeline model) drops EVERY subfolder path: - a bare ``from_pretrained`` reads only repo-ROOT files, so a weight that lives solely in a - subfolder (``BF16/model.safetensors``, ``fp16/``, ``checkpoint-500/``, an ``onnx/`` export) is - an alternate the load never reads and must neither satisfy the warm nor (as a dangling - symlink) block it. - * otherwise, when there is no ``allow_patterns`` (an UNPATTERNED or IGNORE-ONLY request) or the - allow list is a pure catch-all (``["*"]``), drop only per-checkpoint-dir paths -- this is the - diffusers-pipeline / path-trusting case where genuine component subfolders (``unet/``, - ``vae/``) must survive. - - A path-bearing ``allow_patterns`` is otherwise trusted as-is: a caller that names a subfolder - path opts back into it.""" - kept = _filter_paths(list(rels), allow_patterns, ignore_patterns) - if root_weights_only: - kept = [r for r in kept if "/" not in r] - elif allow_patterns is None or _is_pure_catchall(allow_patterns): - kept = [r for r in kept if not _path_under_checkpoint_dir(r)] - return kept - - def snapshot_has_requested_broken_symlinks( snapshot_dir: Path, *, @@ -608,619 +377,159 @@ def snapshot_has_requested_broken_symlinks( A dangling symlink marks an interrupted download, but for a scoped request only one for a requested file should reject the snapshot: a dangling root ``model.safetensors`` left by an earlier interrupted pull must not fail a weightless ``allow_patterns=["config.json"]`` request - whose config is on disk. - - For a MODEL repo the scoping mirrors ``snapshot_dir_is_complete``: a root load reads only root - files, so a dangling checkpoint-dir / subfolder symlink does not block it. 
A DATASET (or other - non-model) snapshot has no "root load reads only root files" notion -- every dangling symlink - for a selected path is an interrupted file that must reject the cache (e.g. a dangling - ``checkpoint-10/data.parquet`` in an unpatterned dataset pull), so the checkpoint-dir drop is - NOT applied there.""" + whose config is on disk. The allow / ignore filter mirrors ``snapshot_download`` selection, so a + dangle for an excluded file does not reject the cache. (``repo_type`` is accepted for signature + compatibility; the scoping is now purely the allow/ignore filter.)""" allow_patterns = _as_pattern_list(allow_patterns) ignore_patterns = _as_pattern_list(ignore_patterns) broken = _broken_symlink_rel_paths(snapshot_dir) if not broken: return False - if repo_type == "model": - requested = _requested_scope_filter(broken, allow_patterns, ignore_patterns) - else: - requested = _filter_paths(broken, allow_patterns, ignore_patterns) - return bool(requested) - + return bool(_filter_paths(broken, allow_patterns, ignore_patterns)) -def snapshot_dir_is_complete( - snapshot_dir: Path, - *, - allow_patterns: "Optional[object]" = None, - ignore_patterns: "Optional[object]" = None, - require_named_weights: bool = False, -) -> bool: - """Best-effort check that a cached snapshot actually holds the requested model weights. - - ``snapshot_download(local_files_only=True)`` returns a snapshot dir whenever - ``refs/`` and ``snapshots/`` exist, even one left by a prior interrupted - or patterned download (a config-only snapshot from an ``AutoConfig`` fetch, or a - partial shard pull). A dangling-symlink check alone misses those: the missing files - were never symlinked, so nothing dangles. Treating such a snapshot as a warm cache - skips the killable child and lets the in-process load hit Xet on the absent weights. 
- - A snapshot is complete only when it has no dangling symlinks, every weight-shard - index it ships resolves all its shards on disk, every numbered shard set present has - all its members on disk (even with no index sidecar), and it contains at least one - weight file. This does NOT assert that every non-weight file is present (no offline - manifest exists for that); the killable child completes anything else still missing. - The aim is simply to never short-circuit a snapshot whose weights are not on disk. - - When *allow_patterns* / *ignore_patterns* are given, the weight that must be present is - one the request actually selects: a request for ``adapter_model.safetensors`` (or a - specific checkpoint shard) is satisfied only by that weight on disk, not by some other - weight the snapshot happens to also carry. A deliberate single-shard request likewise - requires only that shard, not its whole ``-of-NNNNN`` set. With no patterns, any loadable - weight does, and every numbered shard set present must be complete. - - *require_named_weights* additionally requires every explicitly named exact weight in - *allow_patterns* (e.g. ``["model.safetensors", "adapter_model.safetensors"]``) to be on - disk, so a stale cache holding only one of them is not treated as complete. Off by default - (used by the pre-download cache probe); a glob still selects a subset freely.""" - try: - entries = list(snapshot_dir.rglob("*")) - except OSError: - return False - allow_patterns = _as_pattern_list(allow_patterns) - ignore_patterns = _as_pattern_list(ignore_patterns) - # An empty allow list is a real (select-nothing) filter, not "unpatterned": treat any - # non-None patterns as a scoped request so allow_patterns=[] does not fall into the full - # warmup branch (consistent with request_can_include_weights). 
- has_patterns = allow_patterns is not None or ignore_patterns is not None - - # A root-level warm (no path-bearing allow pattern) of a NON-pipeline model reads only repo-ROOT - # files, so a weight under any subfolder (BF16/, fp16/, a checkpoint dir) is an alternate the - # load never reads. A diffusers pipeline is the exception -- its component weights legitimately - # live in subfolders (unet/, vae/, text_encoder/), validated by _diffusion_pipeline_complete -- - # so it keeps the (narrower) checkpoint-dir scoping instead. - is_pipeline = _safe_is_file(snapshot_dir / "model_index.json") or _has_pipeline_component_weight( - snapshot_dir, entries - ) - root_weights_only = _targets_root_only(allow_patterns) and not is_pipeline - - # A dangling symlink marks an interrupted download, but only one for a file the request - # actually selects should reject the snapshot. A stale dangling root model.safetensors must - # not fail an allow_patterns=["adapter_model.safetensors"] probe whose adapter weight IS on - # disk, so scope the broken-symlink check to the requested files (and, for a root warm, drop the - # subfolder / checkpoint-dir paths the bare load never reads) -- the same selection - # _requested_scope_filter applies to the weights below. - broken = _broken_symlink_rel_paths(snapshot_dir) - if broken and _requested_scope_filter( - broken, allow_patterns, ignore_patterns, root_weights_only = root_weights_only - ): - return False - - index_entries: list = [] - weight_entries: list = [] # (entry, repo-relative path) - for entry in entries: - name = entry.name - if _is_weight_shard_index(name): - if _safe_is_file(entry): - index_entries.append(entry) - elif _is_loadable_weight_file(name) and _safe_is_file(entry): - try: - rel = entry.relative_to(snapshot_dir).as_posix() - except ValueError: - rel = name - weight_entries.append((entry, rel)) - - # The weights the request selects that are present on disk (any present root weight when the - # request is unpatterned). 
The snapshot can carry an unrelated weight while the requested one - # is missing, so a patterned request must find one it actually selects. For a root-level warm of - # a non-pipeline model, _requested_scope_filter drops EVERY subfolder weight (BF16/, fp16/, - # checkpoint-500/, an onnx/ export, left behind by a prior patterned pull): an UNPATTERNED, - # IGNORE-ONLY, or no-slash-glob (["*.safetensors"]) root warm is still a bare from_pretrained - # reading ROOT weights, so a subfolder-only snapshot must not read as warm. - selected = set(_requested_scope_filter( - [rel for _, rel in weight_entries], allow_patterns, ignore_patterns, - root_weights_only = root_weights_only, - )) - if not selected: - return False - - # A request that explicitly names exact files needs them on disk before a stale cache is - # short-circuited (pre-download) or accepted (post-download) past the guarded download. WHICH - # names are required depends on the request shape: - # * Each named NON-WEIGHT file (tokenizer.json, config.json) must be present -- but only for an - # exact-file request (no globs). A glob-bearing list treats aux names as best-effort (an - # optional vocab.txt / spiece.model the repo may lack), so unsloth's glob warms still - # short-circuit on a warm cache rather than re-downloading on every load. - # * Named WEIGHT files are grouped by LOGICAL weight: format / shard variants of the same - # weight share a key, and each group needs at least ONE variant present. So an "either - # format" list (["pytorch_model.bin", "model.safetensors"]) is satisfied by whichever the - # repo actually ships -- never an error-forever on a name that does not exist -- while a - # base + adapter list (["model.safetensors", "adapter_model.safetensors"]) is TWO groups and - # needs both, so a stale cache holding only the base is rejected. - # A name the ignore filter drops is not actually requested. 
- if require_named_weights and allow_patterns: - exact_only = not any(_has_glob(p) for p in allow_patterns) - if exact_only: - present = set() - for entry in entries: - if _safe_is_file(entry): - try: - present.add(entry.relative_to(snapshot_dir).as_posix()) - except ValueError: - present.add(entry.name) - else: - present = set(rel for _, rel in weight_entries) - weight_groups: dict = {} - for pat in allow_patterns: - if _has_glob(pat): - continue - if ignore_patterns and not _filter_paths([pat], None, ignore_patterns): - continue # a name the ignore filter drops is not actually requested - if str(pat).lower().endswith(_WEIGHT_FILE_SUFFIXES): - weight_groups.setdefault(_weight_logical_key(pat), []).append(pat) - elif exact_only: - # A named non-weight (tokenizer.json, config.json) is required as-is. A direct - # membership test (not _filter_paths, which fails OPEN by returning all paths on a - # filter error) keeps this fail-SAFE: an unevaluable case requires the guarded - # download rather than silently accepting a stale cache as warm. - if pat not in present: - return False - # Each logical weight group needs at least one of its named format / shard variants on disk - # (an interrupted shard SET is caught separately by the numbered-shard check below). - for names in weight_groups.values(): - if not any(n in present for n in names): - return False - # Every selected numbered shard needs the sibling shards the request also selects (the - # whole set when unpatterned), so an interrupted multi-shard pull is not read as warm. 
- for entry, rel in weight_entries: - if rel not in selected: - continue - if not _numbered_shard_set_present( - entry, snapshot_dir = snapshot_dir, - allow_patterns = allow_patterns, ignore_patterns = ignore_patterns, - ): - return False +# --------------------------------------------------------------------------- +# The conservative fast-path completeness gate +# --------------------------------------------------------------------------- - # A root-wide warm -- no allow_patterns, a catch-all, OR a no-slash glob such as - # ignore_patterns=["*.onnx"] / allow_patterns=["*.safetensors", "*.json"] (all _targets_root_only) - # -- validates that every weight-shard index it would read ships all its shards: an index whose - # shards (numbered OR arbitrarily named) are not all on disk is an interrupted pull the - # in-process load would finish over Xet, which the numbered-shard check alone cannot catch for - # non-numbered shard names. A PATH-BEARING (scoped) request may legitimately want only a subset, - # so the index is not enforced there. A per-checkpoint index, and -- for a non-pipeline root load - # -- any subfolder index, is not what the root load reads, so it is skipped. - if _targets_root_only(allow_patterns): - for index_entry in index_entries: - try: - index_rel = index_entry.relative_to(snapshot_dir).as_posix() - except ValueError: - index_rel = index_entry.name - if _path_under_checkpoint_dir(index_rel): - continue - if root_weights_only and "/" in index_rel: - continue # a subfolder index a bare root from_pretrained never reads - if not _weight_shard_index_complete(index_entry): - return False +# Canonical root weight filenames an in-process model load reads. Used to prove a warm cache (the +# file or its shard index is present). 
+_CANONICAL_SINGLE_WEIGHTS = ("model.safetensors", "pytorch_model.bin") - # A sharded checkpoint is loadable locally ONLY through its index sidecar: transformers' - # from_pretrained resolves a local directory by probing model.safetensors then - # model.safetensors.index.json (then the .bin pair) -- it never globs model-*-of-*.safetensors -- - # so a cache holding every numbered shard but missing the index raises "no file named ..." or - # fetches the index in-process over Xet. For a root-level (non-pipeline) FULL warm (unpatterned - # or glob-bearing -- never a deliberate exact single-shard request, which wants only that file), - # require a root-level shard index when root numbered shards are present. _has_root_shard_index - # matches the variant form too (model.safetensors.index.fp16.json), so a variant sharded cache -- - # whose shards (model.fp16-00001-of-00002.safetensors) carry the variant in the regex prefix -- - # is not falsely rejected. - if root_weights_only and ( - allow_patterns is None or any(_has_glob(p) for p in allow_patterns) - ): - has_root_numbered_shard = any( - "/" not in rel and _NUMBERED_SHARD_RE.match(rel) is not None - for _, rel in weight_entries - ) - if has_root_numbered_shard and not _has_root_shard_index(snapshot_dir, entries): - return False - # A FULL pipeline warm -- no allow_patterns, a pure catch-all ``["*"]``, or a no-slash file glob - # (``["*.safetensors", "*.json"]``) that HF's matcher spreads across every nested component -- - # must carry every sub-model a diffusers model_index.json declares: a warm killed mid-pipeline - # can leave one component cached and another entirely absent, which the in-process pipeline load - # would then fetch over unprotected Xet. A scoped (path-bearing) request targets its own subset, - # so the whole-pipeline rule does not apply there. 
- if _targets_root_only(allow_patterns): - weight_dirs = {rel.split("/", 1)[0] for _, rel in weight_entries if "/" in rel} - if not _diffusion_pipeline_complete(snapshot_dir, weight_dirs): +def _ignore_strips_all_weights(ignore_patterns: "list") -> bool: + """True iff the ignore set provably excludes EVERY weight format: for each weight suffix there is + a pattern matching a representative filename with that suffix. Only then is an ignore-only request + weightless. A partial strip -- only some suffixes, or only the canonical ``model.safetensors`` / + ``pytorch_model.bin`` names while a variant (``model.fp16.safetensors``) or an other-format weight + (``model.gguf``, a ``*.pt`` checkpoint) survives -- is NOT weightless, so the request reads as + weight-bearing (conservative: never under-classify a request that could still pull a weight, which + would let the fast path skip the protective child on a config-only cache and hang on Xet).""" + for suffix in _WEIGHT_FILE_SUFFIXES: + probe = "weight" + suffix + if not any(isinstance(p, str) and fnmatch.fnmatchcase(probe, p) for p in ignore_patterns): return False return True -# Representative loadable-weight filenames -- the probe set for "can this request include a -# weight file". One per recognized format and naming convention (full model, sharded, PEFT -# adapter, consolidated / original checkpoint, diffusers), so a weight-selecting glob like -# ``adapter_model.*`` or ``consolidated.*`` matches a probe and is not misread as weightless. -# The shard *index* sidecars (``*.safetensors.index.json`` / ``*.bin.index.json``) are -# intentionally absent: they are JSON metadata, not weights, so a metadata-only request such -# as ``allow_patterns=["*.json"]`` (or ``["*.index.json"]``) must read as weightless. 
-_WEIGHT_PROBE_NAMES = ( - "model.safetensors", - "model-00001-of-00002.safetensors", - "pytorch_model.bin", - "pytorch_model-00001-of-00002.bin", - "adapter_model.safetensors", - "adapter_model.bin", - "consolidated.00.pth", - "consolidated.safetensors", - "diffusion_pytorch_model.safetensors", - "diffusion_pytorch_model.bin", - "tf_model.h5", - "flax_model.msgpack", - "model.gguf", - "model.pt", - "model.pth", - "model.ckpt", - "model.onnx", - "model.pdparams", -) - -_GLOB_CHARS = ("*", "?", "[") - - -def _has_glob(text: str) -> bool: - # A trailing-slash directory pattern ("unet/", "checkpoint-10/") is NOT an exact filename: - # Hugging Face's filter_repo_objects expands it to match everything under that directory (as - # if "unet/*"). Treat it as a wildcard so the strict exact-name checks do not look for a - # literal "unet/" entry and wrongly reject a fully cached directory / component download. - return text.endswith("/") or any(ch in text for ch in _GLOB_CHARS) - - -# Weight stems that are format-family variants of the SAME logical weight (Transformers reads one): -# the PyTorch / TF / Flax / safetensors "model" forms collapse to one key, so an "either format" -# named request is satisfied by whichever variant the repo actually ships. -_WEIGHT_FORMAT_FAMILY = { - "pytorch_model": "model", - "tf_model": "model", - "flax_model": "model", - "model": "model", -} - - -def _weight_logical_key(name: str) -> tuple: - """A grouping key for a named weight file so format / shard variants of the SAME logical weight - share it. Keyed by (directory, normalized stem): the weight suffix and any ``-NNNNN-of-NNNNN`` - shard suffix are stripped, and the pytorch_model / tf_model / flax_model / model family collapses - to ``model``. 
So ``["pytorch_model.bin", "model.safetensors"]`` is ONE group (either format - satisfies it) while ``["model.safetensors", "adapter_model.safetensors"]`` -- or the same stem in - two different subdirs -- are separate groups, each independently required.""" - norm = name.replace("\\", "/") - dirname, _, base = norm.rpartition("/") - base = base.lower() - for suf in _WEIGHT_FILE_SUFFIXES: - if base.endswith(suf): - base = base[: -len(suf)] - break - base = re.sub(r"-\d+-of-\d+$", "", base) - return (dirname, _WEIGHT_FORMAT_FAMILY.get(base, base)) - - -def _as_pattern_list(patterns: "Optional[object]") -> "Optional[list]": - """Normalize an allow / ignore pattern argument to a list. Hugging Face accepts a bare - ``str`` as well as a list, and iterating the ``str`` form would walk it character by - character (so ``"checkpoint-10/*"`` would never match), misclassifying the request.""" - if patterns is None: - return None - if isinstance(patterns, str): - return [patterns] - return list(patterns) - - -# Stems that, by convention, name a per-checkpoint DIRECTORY (whose weights live inside), -# not a file. Used to disambiguate a dotted no-slash glob like ``checkpoint-v1.*`` (a -# checkpoint directory, weights nested) from a file glob like ``tokenizer.*`` -- both are -# structurally ``.*`` but only the former can include weights. -_CHECKPOINT_DIR_PREFIXES = ( - "checkpoint", "ckpt", "epoch", "step", "global_step", "iter", "iteration", -) - - -def _looks_like_checkpoint_dir(pattern: str) -> bool: - lowered = pattern.lower() - return any(lowered.startswith(prefix) for prefix in _CHECKPOINT_DIR_PREFIXES) - - -def _path_under_checkpoint_dir(rel: str) -> bool: - """True when a repo-relative *rel* lives inside a per-checkpoint directory - (``checkpoint-500/model.safetensors``, ``global_step1000/pytorch_model.bin``). Only the - PARENT components are checked -- the final component is the filename itself. 
Used to keep a - checkpoint-dir weight from satisfying an unpatterned (root-model) warmup: such a weight is - what a prior ``allow_patterns=["checkpoint-500/*"]`` pull leaves behind, not the root weight - a bare ``from_pretrained`` reads.""" - parts = rel.split("/") - return any(_looks_like_checkpoint_dir(p) for p in parts[:-1] if p) - - -def _bracket_member(content: str) -> str: - """A single character that a glob ``[...]`` class *matches*, for concretizing a bracket - expression into a probe that still satisfies the caller's own pattern. ``[0-9]`` -> ``0``, - ``[a-z]`` -> ``a``; a negated class (``[!...]`` / ``[^...]``) -> a filler the class does - not exclude. Replacing the class with a non-member (a literal ``x`` for ``[0-9]``) would - make the probe fail the caller's pattern and misread the request as weightless.""" - negated = content[:1] in ("!", "^") - if not negated: - # The first listed item is a member: a literal char, or the low end of a leading range. - return content[0] if content else "x" - # Negated: pick a filler the class does not exclude (fnmatch mirrors HF's matcher). - try: - cls = "[" + content + "]" - for cand in ("x", "0", "a", "z", "9", "_", "-", "A"): - if fnmatch.fnmatch(cand, cls): - return cand - except Exception: - pass - return "x" - - -def _concretize_glob(pattern: str) -> str: - """Replace glob wildcards in *pattern* with a literal filler so it can stand in as a - concrete directory name (e.g. ``checkpoint-*`` -> ``checkpoint-x``). A ``[...]`` class - collapses to one member char it actually matches (so the probe still satisfies the - pattern). 
Used to probe weights nested under a no-slash directory glob, since Hugging - Face's ``fnmatch`` ``*`` spans ``/``.""" - out = [] - i = 0 - n = len(pattern) - while i < n: - ch = pattern[i] - if ch in ("*", "?"): - out.append("x") - i += 1 - elif ch == "[": - j = pattern.find("]", i + 1) - if j != -1: - out.append(_bracket_member(pattern[i + 1 : j])) - i = j + 1 - else: - out.append("x") # unterminated class: treat "[" as a literal filler - i += 1 - else: - out.append(ch) - i += 1 - return "".join(out) - - -# Representative NON-weight files a catch-all ("*") or a config / tokenizer glob ("*.json") -# would also match -- used to tell a weight-specific basename (model.*, *.safetensors) from a -# catch-all when deciding whether a path-qualified request under a plain subfolder targets -# weights. Not exhaustive; just enough common names to disqualify a non-weight glob. -_NON_WEIGHT_PROBE_NAMES = ( - "config.json", - "tokenizer.json", - "tokenizer.model", - "tokenizer_config.json", - "special_tokens_map.json", - "generation_config.json", - "preprocessor_config.json", - # Processor metadata: a processor-only warm (allow_patterns=["processor*"]) selects these and no - # weight, so a representative must be here for _basename_is_non_weight to read the glob as - # metadata-only (else the snapshot is wrongly rejected for lacking a weight). - "processor_config.json", - "video_preprocessor_config.json", - # Chat-template metadata: a template-only warm (allow_patterns=["chat_template*"]) selects these - # and no weight, so a representative must be here -- otherwise the glob is misread as a weight - # directory and a template-only snapshot is wrongly rejected for lacking a weight. - "chat_template.json", - "chat_template.jinja", - "added_tokens.json", - # SentencePiece / slow-tokenizer vocab assets a tokenizer-only warm selects with a no-slash glob - # (allow_patterns=["spiece*"], ["sentencepiece*"], ["spm*"]). 
Without a representative the glob - # reads as a weight directory and a tokenizer-only snapshot is wrongly rejected for lacking a - # weight. tokenizer.model is already listed above. - "spiece.model", - "sentencepiece.bpe.model", - "spm.model", - "source.spm", - "target.spm", - "bpe.codes", - "vocab.bpe", - "normalizer.json", - "vocab.json", - "merges.txt", - "readme.md", - "training_args.bin", - "optimizer.pt", -) - - -# Subfolders that, by convention, hold only auxiliary / telemetry / config files -- never model -# weights. A catch-all glob under one of these (tokenizer/*, runs/*, scheduler/*) is read as -# weightless. Kept deliberately narrow: an unknown subfolder (unet/, transformer/, original/, a -# new arch's component dir) must stay weight-including, so a weight-bearing dir is never misread -# as weightless (that would re-open the silent-Xet-hang accept-stale this module exists to -# prevent). The Diffusers / multimodal preprocessing components listed here (scheduler/, -# feature_extractor/, processor/, image_processor/, the extra tokenizers) ship only -# *_config.json / vocab files; the weight-bearing pipeline dirs (unet/, transformer/, vae/, -# text_encoder*/, image_encoder/, safety_checker/) are deliberately absent so a catch-all under -# them stays weight-including. -_NON_WEIGHT_DIR_NAMES = frozenset({ - "tokenizer", "tokenizer_2", "tokenizer_3", "runs", "run", "logs", "log", "samples", "sample", - "tensorboard", "tb", "events", "eval", "evals", "evaluation", "metrics", "wandb", "assets", - "images", "media", "scheduler", "feature_extractor", "processor", "image_processor", -}) - - -def _basename_targets_weight(basename: str) -> bool: - """True when a path-qualified pattern's basename specifically selects a model weight - (``model.*``, ``adapter_model.*``, ``*.safetensors``), so the request is weight-including - even under a non-weight parent dir. 
A catch-all (``*``) matches both weights and non-weights - and a non-weight glob (``*.json``) matches no weight, so neither counts.""" - base = basename.lower() - if not any(fnmatch.fnmatchcase(name, base) for name in _WEIGHT_PROBE_NAMES): - return False - return not any(fnmatch.fnmatchcase(name, base) for name in _NON_WEIGHT_PROBE_NAMES) - - -def _basename_is_non_weight(basename: str) -> bool: - """True when a path-qualified pattern's basename clearly selects only non-weight files - (``*.json``, ``config.json``, ``tokenizer.*``, ``*.txt``): it matches a known non-weight - representative but no weight name. A catch-all (``*``) matches a weight too, so it is NOT - clearly non-weight and stays weight-including (the parent dir may hold weights).""" - base = basename.lower() - if not any(fnmatch.fnmatchcase(name, base) for name in _NON_WEIGHT_PROBE_NAMES): - return False - return not any(fnmatch.fnmatchcase(name, base) for name in _WEIGHT_PROBE_NAMES) - - -def _parent_is_non_weight_dir(prefix: str) -> bool: - """True when *prefix* is a known auxiliary / telemetry dir (tokenizer/, runs/, logs/) and no - component looks like a checkpoint / weight dir, so a catch-all glob under it holds no weights. - An unknown subfolder returns False (stays weight-including) to avoid accept-stale.""" - parts = [p.lower() for p in prefix.split("/") if p] - if any(_looks_like_checkpoint_dir(p) for p in parts): - return False - return any(p in _NON_WEIGHT_DIR_NAMES for p in parts) - - -def _weight_self_probe(pattern: str) -> "Optional[str]": - """A concretized stand-in for *pattern* when it names a loadable model weight by suffix - (``lora_*.safetensors`` -> ``lora_x.safetensors``, ``checkpoint-10/lora_*.bin`` -> - ``checkpoint-10/lora_x.bin``, a bare ``model-00002-of-00005.safetensors``), so a custom - weight basename that matches no canonical probe is still recognized. 
Returns None when the - suffix is not a weight suffix, or when the (concretized) basename is a known trainer / - optimizer artifact (``optimizer.pt``, ``training_args.bin``, ``rng_state_*.pth``): those - carry weight suffixes but the snapshot completeness check filters them out as non-weights, - so classifying such a request as weight-including would loop the guarded download. - - Also recognizes a trailing-wildcard variant glob whose wildcard would ABSORB the weight suffix - (``model.fp16*`` -> ``model.fp16.safetensors``, ``pytorch_model.fp16*`` -> - ``pytorch_model.fp16.bin``): such a glob does not literally end in a weight suffix, but it does - select a real variant weight, so it must read as weight-including. A clearly non-weight glob - (``tokenizer*``, ``config*``, ``spiece*``) is excluded so a metadata-only warm stays weightless.""" - if pattern.lower().endswith(_WEIGHT_FILE_SUFFIXES): - concrete = _concretize_glob(pattern) - basename = concrete.rsplit("/", 1)[-1] - if not _is_loadable_weight_file(basename): - return None - return concrete - # Trailing-wildcard glob whose wildcard ABSORBS the weight suffix (``model.fp16*`` -> - # ``model.fp16.safetensors``, ``unet/diffusion_pytorch_model.fp16*`` -> the same re-rooted under - # ``unet/``): try each weight suffix in place of the trailing wildcard. Applies to a path-qualified - # basename too, so a subfolder variant-component glob is not misread as weightless. - if pattern.endswith(("*", "?")): - prefix, _, base = pattern.rpartition("/") - if not _basename_is_non_weight(base): - stem = _concretize_glob(base.rstrip("*?")) - stem_lower = stem.lower() - # A stem that is itself a trainer artifact (scheduler*, rng_state*, optimizer*) selects no - # weight; a stem already ending in ``.index`` or a weight suffix is an index sidecar / - # canonical-probe case the synthetic suffix would only corrupt (model.safetensors.index*, - # model.safetensors*). Skip both so the request is not over-classified weight-including. 
- is_artifact_stem = ( - stem_lower in _NON_WEIGHT_STEMS - or stem_lower.startswith(_NON_WEIGHT_BASENAME_PREFIXES) - ) - is_sidecar_stem = stem_lower.endswith(".index") or stem_lower.endswith(_WEIGHT_FILE_SUFFIXES) - if stem and not is_artifact_stem and not is_sidecar_stem: - for suffix in _WEIGHT_FILE_SUFFIXES: - candidate_base = stem + suffix - if fnmatch.fnmatchcase(candidate_base, base) and _is_loadable_weight_file(candidate_base): - if not prefix: - return candidate_base - concrete_prefix = _concretize_glob(prefix) if _has_glob(prefix) else prefix - return f"{concrete_prefix}/{candidate_base}" - return None +def _pattern_can_select_weight(pattern: "object") -> bool: + """Whether a single allow pattern could select a model weight file. Conservative: a wildcard + basename (``unet/*``, ``adapter_model*``) or a directory pattern (``unet/``) could absorb a + weight, as does a basename ending in a weight suffix (``*.safetensors``, ``model.fp16.safetensors``, + ``model.gguf``). Only a basename ending in a concrete NON-weight extension (``config.json``, + ``*.py``, ``tokenizer.model``, ``additional_chat_templates/*.jinja``) is treated as weightless, + so a tokenizer / config allow list keeps its offline short-circuit while a real (sub)weight glob + is never under-classified.""" + if not isinstance(pattern, str): + return True # unknown shape -> conservative + if pattern.endswith("/"): + return True # a bare directory pattern expands to everything under it, incl. 
weights + base = pattern.rsplit("/", 1)[-1] + if base.endswith(("*", "?")): + return True # a trailing wildcard could absorb a weight suffix + return base.endswith(_WEIGHT_FILE_SUFFIXES) def request_can_include_weights( - allow_patterns: "Optional[list]" = None, ignore_patterns: "Optional[list]" = None + allow_patterns: "Optional[object]" = None, ignore_patterns: "Optional[object]" = None ) -> bool: - """Whether a download restricted by *allow_patterns* / *ignore_patterns* can still - include a model weight file. - - Used to decide whether snapshot completeness should require weights: a request that - filters every weight format out (e.g. ``ignore_patterns`` covering ``*.safetensors`` - and ``*.bin`` to fetch only config / tokenizer files from a model repo) legitimately - yields a weightless snapshot, so requiring a weight there would reject a valid result. - An unfiltered request -- or one any weight filename survives -- includes weights. - - Path-qualified requests are handled too: ``allow_patterns`` such as - ``["checkpoint-10/*"]`` or ``["models/*.safetensors"]`` probe the canonical weight - names re-rooted under that directory, and a bare non-first shard like - ``["model-00002-of-00005.safetensors"]`` is probed verbatim, so a request that does - target weights inside a subfolder / at a specific shard is not misread as weightless. - - *allow_patterns* / *ignore_patterns* accept the ``str`` or ``list[str]`` forms that - Hugging Face itself accepts.""" + """Whether a request restricted by *allow_patterns* / *ignore_patterns* can still include a model + weight. Used to pick the weight-requiring vs weightless branch of the acceptance check. + + Conservative by design: when uncertain it returns True (treat the request as weight-bearing), so + the acceptance check requires a weight and never short-circuits a config-only cache for a real + weight load. 
It returns False only when the request is clearly weightless (a tokenizer / config + allow list that matches no weight name, or an ignore list that drops every weight format), which + preserves the offline short-circuit for a genuine tokenizer-only warm.""" allow_patterns = _as_pattern_list(allow_patterns) ignore_patterns = _as_pattern_list(ignore_patterns) - # Only a truly unfiltered request (both None) is an unconditional weight warmup. An empty - # allow list is NOT None: Hugging Face's filter_repo_objects treats allow_patterns=[] as - # selecting NO objects, so the request is weightless -- collapsing [] with None here would - # reject a legitimately empty snapshot and loop the guarded download. if allow_patterns is None and ignore_patterns is None: return True - try: - from huggingface_hub.utils import filter_repo_objects - except Exception: - return True # cannot evaluate the filter -> assume weights are expected - - probes = list(_WEIGHT_PROBE_NAMES) - for pat in (allow_patterns or ()): - # A concretized stand-in when the pattern itself names a loadable weight by suffix - # (lora_*.safetensors, checkpoint-10/lora_*.bin, a bare non-first shard). None for a - # non-weight suffix and for a known trainer artifact (optimizer.pt, training_args.bin), - # keeping this consistent with the snapshot completeness check. - self_probe = _weight_self_probe(pat) - if "/" in pat: - # Path-qualified: re-root the canonical weight probes under the parent dir - # (concretized when the parent is globbed) so the request is checked inside that - # directory. Default to re-rooting (weight-including), because an unknown subfolder - # (unet/, transformer/, original/, mp_rank_*/) may hold weights and reading it as - # weightless would accept a stale config-only cache -> the silent Xet hang. 
Skip the - # re-root only when the request is clearly non-weight: a non-weight basename glob - # (*.json, tokenizer.*, *.txt), or a catch-all under a known auxiliary dir - # (tokenizer/*, runs/*) that does not itself target a weight. A weight-suffix - # basename is still recognized by self_probe below; the final filter applies ignores. - prefix, base = pat.rsplit("/", 1) - clearly_weightless = _basename_is_non_weight(base) or ( - _parent_is_non_weight_dir(prefix) and not _basename_targets_weight(base) - ) - if not clearly_weightless: - concrete_parent = _concretize_glob(prefix) if _has_glob(prefix) else prefix - probes.extend(f"{concrete_parent}/{name}" for name in _WEIGHT_PROBE_NAMES) - elif ( - _has_glob(pat) - and ("." not in pat or _looks_like_checkpoint_dir(pat)) - and not _basename_is_non_weight(pat) - ): - # A no-slash DIRECTORY glob ("checkpoint-*", "global_step*", the dotted - # "checkpoint-v1.*"): HF's fnmatch "*" spans "/", so it matches nested weights like - # checkpoint-10/model.safetensors. Probe the canonical weights re-rooted under a - # concretized form of the glob. A plain extension file glob ("*.json", "tokenizer.*") - # is not a directory glob and stays weightless unless it names a weight (self_probe). - # A no-slash glob whose stem is a known metadata family ("tokenizer*", "config*", - # "vocab*", "special_tokens*") is a FILE glob, not a directory: _basename_is_non_weight - # excludes it so a tokenizer*-only warm that fetched tokenizer.json is not rejected for - # lacking a weight ("model*" / "pytorch_model*" stay weight-including -- they match a - # weight probe, so _basename_is_non_weight is False for them). 
- concrete = _concretize_glob(pat) - probes.extend(f"{concrete}/{name}" for name in _WEIGHT_PROBE_NAMES) - # A pattern that itself names a loadable weight -- a bare filename, a path-qualified - # name, or a weight-suffix glob whose stem matches no canonical probe (lora_*.safetensors, - # checkpoint-*/lora_*.bin) -- is recognized via its self-probe. "adapter_model.*" rides - # the canonical adapter probe instead, and a trainer artifact yields no self-probe and - # stays weightless. Everything is subject to the final ignore_patterns filter below. - if self_probe is not None: - probes.append(self_probe) + if allow_patterns is None: + # Ignore-only request: weight-bearing unless the ignore list strips every weight format. + return not _ignore_strips_all_weights(ignore_patterns or []) + if not allow_patterns: + # allow_patterns=[] selects nothing -> no weight (HF filter selects no objects). + return False + # An allow list includes weights iff SOME pattern could select a weight (wildcard basename, + # weight-suffix basename, or a bare directory pattern). A list of only concrete non-weight names + # (a tokenizer / config warm) is weightless and keeps its offline short-circuit. + return any(_pattern_can_select_weight(pat) for pat in allow_patterns) + + +def _canonical_root_weights_complete(snapshot_dir: Path, entries: list) -> bool: + """True iff the snapshot holds a complete canonical ROOT model weight set: a root + ``model.safetensors`` / ``pytorch_model.bin`` single file, OR a root weight-shard index whose + every listed shard is present. 
Numbered shard files without a valid index, or weights that live + only in a subfolder, do NOT count -- those are deferred to the watched child.""" + root_files: set = set() + root_indices: list = [] + for entry in entries: + try: + rel = entry.relative_to(snapshot_dir).as_posix() + except ValueError: + rel = entry.name + if "/" in rel: + continue # a bare from_pretrained reads ROOT files only + if _is_weight_shard_index(entry.name): + if _safe_is_file(entry): + root_indices.append(entry) + elif _safe_is_file(entry): + root_files.add(entry.name) + # Sharded: a canonical root index whose every listed shard is on disk. + for index_entry in root_indices: + if _weight_shard_index_complete(index_entry): + return True + # Single-file canonical weight. + return any(name in root_files for name in _CANONICAL_SINGLE_WEIGHTS) + +def snapshot_dir_is_complete( + snapshot_dir: Path, + *, + allow_patterns: "Optional[object]" = None, + ignore_patterns: "Optional[object]" = None, + require_named_weights: bool = False, +) -> bool: + """Conservative fast-path gate: True only when *snapshot_dir* is an unambiguously complete + canonical ROOT model cache, so an in-process load will not fetch any weight. + + This is intentionally NOT an authoritative snapshot verifier. It returns True only for: + - an UNPATTERNED request (allow_patterns is None; ignore_patterns are fine), + - that is not a diffusers pipeline (no root ``model_index.json``), + - with no dangling symlink (interrupted blob), + - whose canonical root weights are present (single file, or a shard index with every shard). + Every other layout -- variants, diffusers, datasets, any allow pattern, sharded weights without + an index -- returns False, deferring to the watched ``snapshot_download`` child (the authoritative + manifest compare + resume). A false True risks a silent un-killable Xet fetch during the in-process + load; a false False only spawns the cheap child. 
``require_named_weights`` is accepted for signature + compatibility (a named-weight request is non-trivially patterned and so is never fast-pathed here). + + ``ignore_patterns`` need no eligibility gate: the canonical-weight presence check below verifies + what the in-process load actually reads (root ``model.safetensors`` / ``pytorch_model.bin`` or a + complete shard index) is on disk, so an ignore that dropped some weight format (the common + ``*.onnx`` / ``*.gguf`` / ``*.pt`` / ``*.bin`` prefetch ignores, or the subdir ``*/*.safetensors`` + drops) cannot make an incomplete cache read complete -- the surviving canonical weight is what is + checked. This keeps the common warm ``from_pretrained`` cache fast-path eligible.""" + allow_patterns = _as_pattern_list(allow_patterns) + ignore_patterns = _as_pattern_list(ignore_patterns) + # 1. Only an UNPATTERNED request is eligible. Any allow list scopes the on-disk set to a subset + # whose relationship to the in-process load is not locally provable -> defer to the child. + if allow_patterns is not None: + return False try: - kept = list( - filter_repo_objects( - probes, allow_patterns = allow_patterns, ignore_patterns = ignore_patterns - ) - ) - except Exception: - return True - return len(kept) > 0 + entries = list(snapshot_dir.rglob("*")) + except OSError: + return False + # 2. A diffusers pipeline (root model_index.json) needs component-completeness reasoning we do + # not fast-path -> defer to the child. + if _safe_is_file(snapshot_dir / "model_index.json"): + return False + # 3. A dangling symlink = an interrupted blob (missing or still .incomplete) -> not complete. + if snapshot_dir_has_broken_symlinks(snapshot_dir): + return False + # 4. Canonical root weights present and complete. 
+ return _canonical_root_weights_complete(snapshot_dir, entries) def requested_named_files_present( @@ -1267,6 +576,10 @@ def requested_named_files_present( return True +# --------------------------------------------------------------------------- +# Active-cache enumeration primitives (download-manager / watchdog support) +# --------------------------------------------------------------------------- + def _iter_snapshot_dirs(repo_dir: Path) -> Iterator[Path]: snapshots_dir = repo_dir / "snapshots" try: diff --git a/unsloth_zoo/hf_xet_fallback.py b/unsloth_zoo/hf_xet_fallback.py index f11edbce7..a9e5b1edf 100644 --- a/unsloth_zoo/hf_xet_fallback.py +++ b/unsloth_zoo/hf_xet_fallback.py @@ -49,13 +49,13 @@ from unsloth_zoo.hf_cache_state import ( INCOMPLETE_SUFFIX, + _is_loadable_weight_file, blob_bytes_present, has_active_incomplete_blobs, hf_cache_root, iter_active_repo_cache_dirs, request_can_include_weights, requested_named_files_present, - snapshot_dir_has_broken_symlinks, snapshot_dir_is_complete, snapshot_has_requested_broken_symlinks, ) @@ -959,47 +959,13 @@ def _run_download_attempt( return ("error", message) -def _snapshot_is_acceptable( +def _intact_subset( snapshot_dir: Path, *, repo_type: str, allow_patterns: Any, ignore_patterns: Any, - require_named_weights: bool = False, ) -> bool: - """Whether a cached / downloaded snapshot dir is complete enough to use, scoped to the - caller's intent. - - Weight files are required only when the request can actually include them: a model - repo (``repo_type == "model"``) whose ``allow_patterns`` / ``ignore_patterns`` do not - filter every weight format out. This wrapper exists to warm those weights before an - in-process load, so a result with no weights then means the download did not finish (HF - silently returns a stale local snapshot on an offline / timed-out request rather than - raising). 
- - A PATTERNED or non-model snapshot that legitimately holds only a subset -- a dataset, or - a model repo fetched with ``allow_patterns=["config.json"]`` or ``ignore_patterns`` that - drop all weights -- would be wrongly rejected by a weight requirement, so for those it is - enough that no symlink dangles (every file the snapshot references is on disk). - - The completeness check is scoped to the requested patterns, so a request for a specific - weight (e.g. ``allow_patterns=["adapter_model.safetensors"]`` or a checkpoint shard) is - satisfied only when THAT weight is on disk, not by some other weight already cached. - - ``require_named_weights`` makes a request that explicitly names files require them on disk - (each named non-weight, and at least one format/shard variant of each named LOGICAL weight), - so a stale snapshot missing one is neither short-circuited (pre-download) nor accepted - (post-download). Format variants of one weight are grouped, so an "either format" name list - against a single-format repo is satisfied by whichever variant exists -- never an error-forever - on a name that does not exist in the repo.""" - if repo_type == "model" and request_can_include_weights(allow_patterns, ignore_patterns): - return snapshot_dir_is_complete( - snapshot_dir, allow_patterns = allow_patterns, ignore_patterns = ignore_patterns, - require_named_weights = require_named_weights, - ) - # Weightless / non-model request (a dataset, or a model repo whose patterns drop every weight - # format, e.g. a tokenizer-only allow list): no weight is expected, so completeness is "no - # dangling symlink among the REQUESTED files". The broken-symlink check is scoped to the request - # (like snapshot_dir_is_complete), so a dangling EXCLUDED weight left by an earlier interrupted - # pull does not reject a complete config/tokenizer subset. 
An EXACT-named weightless request - # (allow_patterns=["tokenizer.json"], no globs) must still find its named files on disk -- HF can - # hand back a config-only snapshot dir that simply lacks the requested file. Globs stay best-effort. + """No interrupted-download evidence for the files the request SELECTS: no dangling requested + symlink, and every EXACT-named requested file present. Used for a weightless / non-model request + (a dataset, a tokenizer-only allow list) and as the breakage check for a finished download. A + dangling EXCLUDED weight from an earlier interrupted pull does not reject a complete subset.""" return ( not snapshot_has_requested_broken_symlinks( snapshot_dir, allow_patterns = allow_patterns, ignore_patterns = ignore_patterns, @@ -1011,15 +977,78 @@ def _snapshot_is_acceptable( ) +def _has_any_weight(snapshot_dir: Path) -> bool: + """True if the snapshot holds at least one loadable model weight anywhere (root or a component + subfolder). Lenient on purpose: it only distinguishes a real model warm from the config-only + stale snapshot HF can hand back on an offline / timed-out request, without classifying layout.""" + try: + for entry in snapshot_dir.rglob("*"): + if _is_loadable_weight_file(entry.name): + try: + if entry.is_file(): + return True + except OSError: + continue + except OSError: + return False + return False + + +def _cache_can_skip_download( + snapshot_dir: Path, *, repo_type: str, allow_patterns: Any, ignore_patterns: Any, +) -> bool: + """PRE-download: whether a locally cached snapshot is complete enough that the in-process load + will not fetch anything, so the protective child can be skipped. + + STRICT for a weight-bearing model request: only the conservative canonical fast-path + (``snapshot_dir_is_complete``) may skip the child; anything uncertain (diffusers, variants, + non-trivial patterns, sharded-without-index) returns False -> spawn the child. 
A false True here + would let the in-process load fetch a missing weight over un-killable Xet (the hang). A weightless + / non-model request has no weight to hang on, so an intact requested subset is enough -- this + preserves the offline short-circuit for a tokenizer-only / dataset warm.""" + if repo_type in (None, "model") and request_can_include_weights(allow_patterns, ignore_patterns): + return snapshot_dir_is_complete( + snapshot_dir, allow_patterns = allow_patterns, ignore_patterns = ignore_patterns, + ) + return _intact_subset( + snapshot_dir, repo_type = repo_type, allow_patterns = allow_patterns, + ignore_patterns = ignore_patterns, + ) + + +def _download_result_usable( + snapshot_dir: Path, *, repo_type: str, allow_patterns: Any, ignore_patterns: Any, +) -> bool: + """POST-download: whether the child's ``snapshot_download`` result is usable, or should be retried + over HTTP. snapshot_download already did the authoritative manifest compare + resume, so accept + unless there is POSITIVE evidence of a silent-Xet partial: a dangling REQUESTED symlink (a blob + that is missing or still ``.incomplete``), or a weight-bearing model warm that came back with NO + weight at all (HF handed back a stale config-only snapshot on an offline / timed-out request). 
+ LENIENT otherwise -- a finished diffusers / variant / either-format download passes, and a named + file simply absent from the repo is not treated as missing -- so a good download is never failed + and re-looped into a ``DownloadStallError``.""" + if snapshot_has_requested_broken_symlinks( + snapshot_dir, allow_patterns = allow_patterns, ignore_patterns = ignore_patterns, + repo_type = repo_type, + ): + return False + if ( + repo_type in (None, "model") + and request_can_include_weights(allow_patterns, ignore_patterns) + and not _has_any_weight(snapshot_dir) + ): + return False + return True + + def _snapshot_payload_incomplete( payload: Any, *, repo_type: str, allow_patterns: Any, ignore_patterns: Any ) -> bool: - """True when a snapshot download returned a real directory that is not acceptable for - the request (see ``_snapshot_is_acceptable``). Guarded to an existing directory so a - mocked / non-path payload (unit tests) or an unexpected return is trusted rather than - rejected; in production the child always returns a real snapshot dir, where this - catches HF handing back an existing partial snapshot on an offline / timed-out - request.""" + """True when a snapshot download returned a real directory that is not usable for the request + (see ``_download_result_usable``). 
Guarded to an existing directory so a mocked / non-path + payload (unit tests) or an unexpected return is trusted rather than rejected; in production the + child always returns a real snapshot dir, where this catches HF handing back an existing partial + snapshot on an offline / timed-out request.""" try: path = Path(payload) except (TypeError, ValueError, OSError): @@ -1031,15 +1060,9 @@ def _snapshot_payload_incomplete( return False except OSError: return False - # require_named_weights so a finished download that handed back a stale snapshot missing an - # explicitly named file -- a base + adapter list (["model.safetensors", - # "adapter_model.safetensors"]) where only the base materialized, or a weight + named - # tokenizer.json -- is still treated as incomplete and retried, not returned with files - # missing. Format / shard variants of one logical weight are grouped, so an "either format" - # list stays satisfied by whichever variant the repo actually ships (no error-forever). - return not _snapshot_is_acceptable( + return not _download_result_usable( path, repo_type = repo_type, allow_patterns = allow_patterns, - ignore_patterns = ignore_patterns, require_named_weights = True, + ignore_patterns = ignore_patterns, ) @@ -1360,20 +1383,18 @@ def snapshot_download_with_xet_fallback( # snapshots/ exist, even one left by a prior interrupted or patterned # download (a config-only snapshot from an AutoConfig fetch, or a partial # shard pull). Validate the EXACT returned revision dir against the request: - # a full model warmup requires its weight files on disk, a patterned / non-model - # request only its referenced files (no dangling symlinks). 
Complete it in the - # killable child otherwise, so the in-process load never proceeds with missing + # a full model warmup may skip the child only when its canonical weights are + # provably complete (the conservative fast-path gate); a patterned / non-model + # request only needs its referenced files (no dangling symlinks). Complete it in + # the killable child otherwise, so the in-process load never proceeds with missing # files. Scope the check to the returned snapshot, NOT the whole repo: an # unrelated revision mid-download (a stale .incomplete blob or a broken older # snapshot elsewhere in the same repo cache) must not force a needless re-fetch. - if _snapshot_is_acceptable( + if _cache_can_skip_download( Path(cached_dir), repo_type = repo_type, allow_patterns = allow_patterns, ignore_patterns = ignore_patterns, - # Pre-download short-circuit: require each explicitly named weight so a stale - # snapshot missing one (base present, adapter not) is completed, not accepted. - require_named_weights = True, ): return cached_dir logger.debug("Cached snapshot for %s is incomplete; downloading.", repo_id) From e0f7b431544945bcbf16d0b611c0e09f2a92fad5 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 29 Jun 2026 02:46:58 +0000 Subject: [PATCH 45/82] Fix five review-round findings in the cache-completeness gate and spawn wiring A 10-reviewer pass over the trimmed gate surfaced three over-accept / over-reject correctness holes plus two robustness gaps. All five are addressed without re-growing the classifier or touching the watchdog / spawn / HTTP-retry mechanism. Over-accept (the cardinal-invariant direction -- a false "complete" lets the in-process load fetch a missing weight over un-killable Xet): - _weight_shard_index_complete is now fail-CLOSED. A truncated / unreadable index, a non-dict payload, a missing or non-dict weight_map, or an empty shard set all return False, so a malformed index defers to the watched child instead of reading as a complete shard set. 
- _canonical_root_weights_complete is ignore-aware. A root weight (or weight-shard index) whose FORMAT the request's ignore filter drops no longer counts: a stale pytorch_model.bin under ignore=['*.bin'] with no safetensors on disk defers to the child, so a use_safetensors load cannot silently fetch the real weight over Xet. The format probe also discards a pytorch_model.bin.index.json (whose .json sidecar name would slip the raw filter). The common safetensors warm under the bare from_pretrained ignore list (which includes *.bin) stays fast-path eligible -- model.safetensors survives the *.bin drop. Over-reject (a false "not usable" loops a good download into a spurious DownloadStallError): - _download_result_usable scopes its no-weight rejection to an UNPATTERNED model request (allow_patterns is None), the only shape that should always yield a weight. A patterned request is trusted: a genuinely weightless result (e.g. allow=['tokenizer*']) is taken as intended, not failed. The config-only stale-snapshot rejection stays in force for the unpatterned warm. Robustness: - HFValidationError is added to the deterministic error set. A malformed repo id / revision fails identically over either transport, so it is surfaced with its real type across the spawn boundary (no pointless HTTP retry, no degrade to a generic RuntimeError). - The UNSLOTH_HF_XET_FORCE_STALL test hook now holds the fake .incomplete blob OPEN for the whole stall. The snapshot watchdog finds it by filename, but the single-file watchdog (watch_new_partials_only) counts only partials the child PID holds open via _child_open_incomplete_blobs -- a closed file there was ignored, so single-file fault injection never tripped. Keeping the fd open lets both modes see it. 
Tests: add focused regression guards for each finding (malformed/empty/non-dict shard index defers; ignored canonical weight and ignored .bin index do not prove completeness while a safetensors warm under *.bin ignore still does; a weightless tokenizer* result is accepted post-download while an unpatterned config-only warm is still rejected; HFValidationError type is preserved and treated as non-retryable). Full suite 120 passed / 1 skipped; the 40k-layout safety fuzz stays at 0 violations; the real stall->HTTP e2e and the Studio de-dup surface check stay green. --- tests/test_hf_xet_fallback.py | 77 ++++++++++++++++++++++++++++++++++ unsloth_zoo/hf_cache_state.py | 67 ++++++++++++++++++++++------- unsloth_zoo/hf_xet_fallback.py | 24 +++++++++-- 3 files changed, 149 insertions(+), 19 deletions(-) diff --git a/tests/test_hf_xet_fallback.py b/tests/test_hf_xet_fallback.py index bb0632730..b7978babe 100644 --- a/tests/test_hf_xet_fallback.py +++ b/tests/test_hf_xet_fallback.py @@ -2194,3 +2194,80 @@ def test_pre_download_partial_ignore_does_not_skip_config_only(tmp_path): assert xf._cache_can_skip_download( snap, repo_type = "model", allow_patterns = None, ignore_patterns = ["*.safetensors", "*.bin"]) is False + + +# --------------------------------------------------------------------------- +# Review-round regression guards (10-reviewer findings) +# --------------------------------------------------------------------------- + +def test_gate_rejects_malformed_shard_index(tmp_path): + """Finding 2 (over-accept): a truncated / non-dict / empty weight-shard index must NOT read as + complete. 
_weight_shard_index_complete is fail-CLOSED so the fast path defers a malformed index + to the watched child rather than skipping it and failing the in-process load on the bad index.""" + snap, blob = _mk_snapshot(tmp_path, "malidx") + (snap / "model-00001-of-00002.safetensors").symlink_to(blob) + (snap / "model.safetensors.index.json").write_text("{not valid json") + assert hcs.snapshot_dir_is_complete(snap) is False + # Empty weight_map proves nothing. + snap2, blob2 = _mk_snapshot(tmp_path, "emptyidx") + (snap2 / "model-00001-of-00002.safetensors").symlink_to(blob2) + (snap2 / "model.safetensors.index.json").write_text(json.dumps({"weight_map": {}})) + assert hcs.snapshot_dir_is_complete(snap2) is False + # weight_map present but not a dict. + snap3, blob3 = _mk_snapshot(tmp_path, "listidx") + (snap3 / "model.safetensors.index.json").write_text(json.dumps({"weight_map": ["a", "b"]})) + assert hcs._weight_shard_index_complete(snap3 / "model.safetensors.index.json") is False + + +def test_gate_ignored_canonical_weight_does_not_prove_complete(tmp_path): + """Finding 3 (over-accept): a stale canonical weight whose FORMAT the request ignores must not + count as proof of completeness. ignore=['*.bin'] with only a pytorch_model.bin on disk (no + safetensors) defers to the child, so a use_safetensors load cannot silently fetch over Xet.""" + snap, blob = _mk_snapshot(tmp_path, "ignbin") + (snap / "config.json").write_text("{}") + (snap / "pytorch_model.bin").symlink_to(blob) + assert hcs.snapshot_dir_is_complete(snap, ignore_patterns = ["*.bin"]) is False + # Without the ignore, the present .bin is what a default load reads -> complete. + assert hcs.snapshot_dir_is_complete(snap) is True + # A .bin shard index is also discarded when *.bin is ignored (its .json sidecar would slip the + # raw name filter, but the format probe catches it). 
+ snap2, blob2 = _mk_snapshot(tmp_path, "ignbinshard") + (snap2 / "pytorch_model-00001-of-00001.bin").symlink_to(blob2) + (snap2 / "pytorch_model.bin.index.json").write_text( + json.dumps({"weight_map": {"a": "pytorch_model-00001-of-00001.bin"}})) + assert hcs.snapshot_dir_is_complete(snap2, ignore_patterns = ["*.bin"]) is False + assert hcs.snapshot_dir_is_complete(snap2) is True + # A safetensors warm survives an *.bin ignore (the common bare from_pretrained case). + snap3, blob3 = _mk_snapshot(tmp_path, "stignbin") + (snap3 / "config.json").write_text("{}") + (snap3 / "model.safetensors").symlink_to(blob3) + assert hcs.snapshot_dir_is_complete(snap3, ignore_patterns = ["*.bin"]) is True + + +def test_post_download_accepts_weightless_patterned_result(tmp_path): + """Finding 1 (over-reject): a genuinely weightless PATTERNED result (e.g. allow=['tokenizer*']) + must be accepted post-download -- the caller scoped it, so 'no weight' is intended, not a stale + config-only snapshot. Rejecting it would loop into a spurious DownloadStallError on a good + download. The no-weight rejection stays in force for an UNPATTERNED model warm.""" + snap, _ = _mk_snapshot(tmp_path, "tokglob") + (snap / "tokenizer.json").write_text("{}") + (snap / "tokenizer_config.json").write_text("{}") + assert xf._download_result_usable( + snap, repo_type = "model", allow_patterns = ["tokenizer*"], ignore_patterns = None) is True + # An unpatterned model warm with no weight is still rejected (stale config-only snapshot). 
+ cfg, _ = _mk_snapshot(tmp_path, "cfgonly") + (cfg / "config.json").write_text("{}") + assert xf._download_result_usable( + cfg, repo_type = "model", allow_patterns = None, ignore_patterns = None) is False + + +def test_hfvalidationerror_type_preserved_across_spawn(): + """Finding 4: a malformed repo id fails identically over either transport, so HFValidationError is + deterministic (not retried) and its TYPE is reconstructed across the spawn boundary instead of + degrading to a generic RuntimeError.""" + assert "HFValidationError" in xf._DETERMINISTIC_ERROR_NAMES + cls = xf._resolve_exception_class("HFValidationError") + assert cls is not None and issubclass(cls, BaseException) + inst = xf._instantiate_preserving_type(cls, "HFValidationError: bad repo id") + assert type(inst).__name__ == "HFValidationError" + assert xf._is_retryable_download_error(inst) is False diff --git a/unsloth_zoo/hf_cache_state.py b/unsloth_zoo/hf_cache_state.py index 13d83dc13..7a4e2464e 100644 --- a/unsloth_zoo/hf_cache_state.py +++ b/unsloth_zoo/hf_cache_state.py @@ -270,26 +270,34 @@ def _is_weight_shard_index(name: str) -> bool: def _weight_shard_index_complete(index_path: Path) -> bool: - """True if every shard a HF weight index (``model.safetensors.index.json`` / - ``pytorch_model.bin.index.json``) lists is present next to the index. An unreadable - or non-sharded index is treated as satisfied (nothing extra to verify), so this only - ever rejects an index whose shards are demonstrably missing on disk.""" + """True only if every shard a HF weight index (``model.safetensors.index.json`` / + ``pytorch_model.bin.index.json``) lists is present next to the index. + + Fail-CLOSED: an unreadable / truncated index, a non-dict payload, a missing or non-dict + ``weight_map``, or an empty shard set all return False. 
This function feeds the fast-path + completeness gate, where a malformed index proves nothing -- treating it as complete would let + the in-process load skip the protective child and then fail (or fetch over Xet) on a truncated + index, so the safe direction is to defer such an index to the watched ``snapshot_download`` + child. Only an index whose every listed shard is demonstrably on disk returns True.""" import json try: with open(index_path, "r", encoding = "utf-8") as f: data = json.load(f) except (OSError, ValueError): - return True + return False weight_map = data.get("weight_map") if isinstance(data, dict) else None if not isinstance(weight_map, dict): - return True + return False # weight_map values are filenames relative to the index file's own directory. They come from # arbitrary JSON: a non-string (e.g. list/dict) value is both unhashable -- so it would break # set() -- and invalid for ``base / shard``, so filter to strings BEFORE de-duplicating rather - # than crash (consistent with the fail-open parse handling above). + # than crash. + shards = {s for s in weight_map.values() if isinstance(s, str)} + if not shards: + return False # an empty / all-non-string weight_map cannot prove a complete shard set base = index_path.parent - for shard in {s for s in weight_map.values() if isinstance(s, str)}: + for shard in shards: try: if not (base / shard).exists(): return False @@ -457,11 +465,21 @@ def request_can_include_weights( return any(_pattern_can_select_weight(pat) for pat in allow_patterns) -def _canonical_root_weights_complete(snapshot_dir: Path, entries: list) -> bool: +def _canonical_root_weights_complete( + snapshot_dir: Path, entries: list, ignore_patterns: "Optional[list]" = None +) -> bool: """True iff the snapshot holds a complete canonical ROOT model weight set: a root ``model.safetensors`` / ``pytorch_model.bin`` single file, OR a root weight-shard index whose every listed shard is present. 
Numbered shard files without a valid index, or weights that live - only in a subfolder, do NOT count -- those are deferred to the watched child.""" + only in a subfolder, do NOT count -- those are deferred to the watched child. + + A root weight (or weight-shard index) whose FORMAT the request's ignore filter drops does NOT + count: a stale ``pytorch_model.bin`` under ``ignore=['*.bin']`` is not proof that the + safetensors weights an in-process load (e.g. ``use_safetensors=True``) will actually read are on + disk, so it must not let the fast path skip the protective child and then hang fetching the real + weight over Xet. The surviving-format check uses a representative weight name per format, so a + ``*.bin`` ignore also discards a ``pytorch_model.bin.index.json`` (whose ``.json`` sidecar name + would otherwise slip past the filter).""" root_files: set = set() root_indices: list = [] for entry in entries: @@ -476,12 +494,28 @@ def _canonical_root_weights_complete(snapshot_dir: Path, entries: list) -> bool: root_indices.append(entry) elif _safe_is_file(entry): root_files.add(entry.name) - # Sharded: a canonical root index whose every listed shard is on disk. + + def _format_kept(weight_name: str) -> bool: + # The weight format an in-process load reads from *weight_name* must survive the request's + # ignore filter; otherwise the file is a stale artifact for an excluded format and proves + # nothing about what the load will fetch. + if not ignore_patterns: + return True + return bool(_filter_paths([weight_name], None, ignore_patterns)) + + # Sharded: a canonical root index whose format is kept and whose every listed shard is on disk. for index_entry in root_indices: - if _weight_shard_index_complete(index_entry): + fmt_probe = ( + "model.safetensors" + if ".safetensors.index." in index_entry.name + else "pytorch_model.bin" + ) + if _format_kept(fmt_probe) and _weight_shard_index_complete(index_entry): return True - # Single-file canonical weight. 
- return any(name in root_files for name in _CANONICAL_SINGLE_WEIGHTS) + # Single-file canonical weight (the file itself must survive the ignore filter). + return any( + name in root_files and _format_kept(name) for name in _CANONICAL_SINGLE_WEIGHTS + ) def snapshot_dir_is_complete( @@ -528,8 +562,9 @@ def snapshot_dir_is_complete( # 3. A dangling symlink = an interrupted blob (missing or still .incomplete) -> not complete. if snapshot_dir_has_broken_symlinks(snapshot_dir): return False - # 4. Canonical root weights present and complete. - return _canonical_root_weights_complete(snapshot_dir, entries) + # 4. Canonical root weights present and complete (a weight whose format the request ignores + # does not count -- see _canonical_root_weights_complete). + return _canonical_root_weights_complete(snapshot_dir, entries, ignore_patterns) def requested_named_files_present( diff --git a/unsloth_zoo/hf_xet_fallback.py b/unsloth_zoo/hf_xet_fallback.py index a9e5b1edf..d376900f9 100644 --- a/unsloth_zoo/hf_xet_fallback.py +++ b/unsloth_zoo/hf_xet_fallback.py @@ -472,6 +472,10 @@ def _scrub_in_child(text: str, token: Optional[str]) -> str: "DisabledRepoError", "LocalEntryNotFoundError", "BadRequestError", + # A malformed repo id / revision fails identically over either transport (it never reaches the + # network), so surface it with its real type instead of a generic RuntimeError or a pointless + # HTTP retry. + "HFValidationError", }) # Substrings that mark a transient transport failure (hf_xet / CAS error, timeout, reset, # HTTP 5xx / 429) that disabling Xet and retrying over HTTP may recover. @@ -662,6 +666,7 @@ def _download_child_entry( # Test-only fault injection (never set in production): stall the Xet attempt # so the watchdog + HTTP fallback can be exercised against a real repo. 
if not disable_xet and os.environ.get("UNSLOTH_HF_XET_FORCE_STALL") == "1": + _stall_fh = None try: from huggingface_hub.constants import HF_HUB_CACHE @@ -673,8 +678,15 @@ def _download_child_entry( repo_dir_name = f"{repo_type or 'model'}s--" + repo_id.replace("/", "--") blobs = os.path.join(cache_root, repo_dir_name, "blobs") os.makedirs(blobs, exist_ok = True) - with open(os.path.join(blobs, "xet-force-stall.incomplete"), "wb") as fh: - fh.write(b"\0" * 4096) + # Hold the fake partial OPEN for the whole stall. The snapshot watchdog finds it by + # filename (has_active_incomplete_blobs), but the single-file watchdog + # (watch_new_partials_only) counts ONLY partials this child PID holds open via + # _child_open_incomplete_blobs -- a closed file there is ignored and the stall never + # trips. Keeping the fd open lets BOTH modes see it. The handle is bound to a local so + # it stays open across the sleep below. + _stall_fh = open(os.path.join(blobs, "xet-force-stall.incomplete"), "wb") + _stall_fh.write(b"\0" * 4096) + _stall_fh.flush() except OSError: pass while True: @@ -1026,7 +1038,12 @@ def _download_result_usable( weight at all (HF handed back a stale config-only snapshot on an offline / timed-out request). LENIENT otherwise -- a finished diffusers / variant / either-format download passes, and a named file simply absent from the repo is not treated as missing -- so a good download is never failed - and re-looped into a ``DownloadStallError``.""" + and re-looped into a ``DownloadStallError``. + + The no-weight rejection is scoped to an UNPATTERNED model request (``allow_patterns is None``), + which is the only shape that should always yield a weight. A patterned request is trusted: the + caller scoped it, so a weightless result is taken as intended (e.g. 
``allow_patterns=['tokenizer*']`` + legitimately returns no weight) rather than failed into a spurious ``DownloadStallError``.""" if snapshot_has_requested_broken_symlinks( snapshot_dir, allow_patterns = allow_patterns, ignore_patterns = ignore_patterns, repo_type = repo_type, @@ -1034,6 +1051,7 @@ def _download_result_usable( return False if ( repo_type in (None, "model") + and allow_patterns is None and request_can_include_weights(allow_patterns, ignore_patterns) and not _has_any_weight(snapshot_dir) ): From b6ec9fdd88580ba9a748692fc7566761e5befb1a Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 29 Jun 2026 03:11:39 +0000 Subject: [PATCH 46/82] Address second review round: weight-pattern selector, sharded/variant gate, error types, concurrent cleanup A further bot pass (codex inline) and a 10-reviewer pass surfaced more over-accept / over-reject holes and two robustness gaps. Fixes keep the conservative-gate philosophy and the watchdog / spawn / retry mechanism untouched. Cache-state gate (hf_cache_state.py): - _pattern_can_select_weight is now a probe + fnmatch selector. tokenizer / config / *.json globs read weightless (keeping their offline short-circuit), while every standard weight name and single-char (?) / character-class ([]) globs (model.?afetensors, model.[sp]afetensors, checkpoint-*/model.?afetensors) read weight-bearing. This replaces the old "any trailing wildcard is weight-bearing" rule, which both misclassified tokenizer* as weight-bearing and missed ? / [] globs. - The canonical fast path now accepts only a CANONICAL (non-variant) shard index. A variant-only cache (model.safetensors.index.fp16.json with no canonical index) no longer reads complete: the wrapper takes no variant parameter, so a default load probes the canonical index whose weights are still missing -- skipping the child there would reintroduce the unprotected in-process Xet fetch the gate prevents. 
- New _has_incomplete_canonical_root_shards: canonical numbered shards at the root with no covering index (an interrupted sharded download) are detected. Variant shards are excluded so a variant-only repo is never force-failed. Acceptance wiring (hf_xet_fallback.py): - The post-download no-weight rejection again keys on request_can_include_weights (not "allow is None"). An explicit weight request that came back config-only (allow=["model.safetensors"]) is correctly rejected and retried, while a genuinely weightless patterned result (allow=["tokenizer*"]) is accepted. This restores the explicit-weight guard a prior fix had dropped, now without the tokenizer* false reject (handled by the selector above). - Post-download also rejects an interrupted canonical root-sharded warm for an unpatterned request, so a loose model-00001-of-00002.safetensors without its index is retried over HTTP rather than handed back. - LocalTokenNotFoundError joins the deterministic error set, and a bare HfHubHTTPError is type-resolvable (via a type-preserve-only set, so a transient 5xx bare HfHubHTTPError still retries by status code). - _default_prepare_for_http no longer unconditionally unlinks every dangling snapshot symlink. It spares a concurrent sibling download's link whose target blob still has a fresh .incomplete partner (common in multi-rank training), mirroring the active-grace guard the .incomplete blob purge already uses; our own stale interrupted link (partner already purged) is still cleared. Deliberately NOT changed (the suggested fixes would regress a load-bearing leniency): - _has_any_weight stays layout-agnostic: filtering it by the ignore patterns would reject a complete diffusers download, whose component weights the common */*.safetensors ignore strips. - requested_named_files_present stays lenient for globs: requiring a glob to prove a match would loop a complete no-match-glob download (e.g. allow=["*.txt"] on a repo with no .txt) into a DownloadStallError. 
- The wrappers keep their focused signature (local_dir et al. not added): local_dir uses a non-blob layout the watchdog's .incomplete detection cannot see, so a naive passthrough would silently disable stall detection. A full drop-in surface is out of scope for this PR. Tests: add regression guards for the glob selector, the explicit-weight-pattern post reject, the incomplete / complete / variant-only sharded cases, LocalTokenNotFoundError, the bare-HfHubHTTPError type/retry split, and the concurrent-sibling symlink spare. Suite 127 passed / 1 skipped; the 40k-layout safety fuzz stays at 0 violations; the real stall->HTTP e2e and the Studio de-dup surface stay green. --- tests/test_hf_xet_fallback.py | 127 +++++++++++++++++++++++++++++++++ unsloth_zoo/hf_cache_state.py | 100 +++++++++++++++++++++++--- unsloth_zoo/hf_xet_fallback.py | 72 +++++++++++++++---- 3 files changed, 276 insertions(+), 23 deletions(-) diff --git a/tests/test_hf_xet_fallback.py b/tests/test_hf_xet_fallback.py index b7978babe..db5b63f07 100644 --- a/tests/test_hf_xet_fallback.py +++ b/tests/test_hf_xet_fallback.py @@ -471,6 +471,33 @@ def test_prepare_for_http_clears_broken_snapshot_symlink(tmp_path): assert xf.get_hf_download_state([repo], cache_dir = str(tmp_path)) == (0, False) +def test_prepare_for_http_spares_concurrent_sibling_active_symlink(tmp_path): + """Round-2 F1: HTTP prep must NOT delete a concurrent sibling download's dangling snapshot + symlink while that sibling is still writing the target blob (a fresh .incomplete partner exists). + Our own stale interrupted link (no .incomplete partner) is still cleared in the same sweep.""" + repo = "ztest/concurrent" + repo_dir = tmp_path / f"models--{repo.replace('/', '--')}" + blobs = repo_dir / "blobs" + snap = repo_dir / "snapshots" / "sha" + blobs.mkdir(parents = True) + snap.mkdir(parents = True) + + # Sibling mid-download: a dangling link to a blob whose .incomplete partner is being written now. 
+ active_partner = blobs / "activehash.incomplete" + active_partner.write_bytes(b"active") + sibling_link = snap / "active.safetensors" + sibling_link.symlink_to(blobs / "activehash") + + # Our own stale interrupted link: target blob has no .incomplete partner. + stale_link = snap / "stale.safetensors" + stale_link.symlink_to(blobs / "stalehash") + + _REAL_DEFAULT_PREPARE("model", repo, cache_dir = str(tmp_path), active_grace = 180) + + assert sibling_link.is_symlink(), "active sibling's dangling symlink must be preserved" + assert not stale_link.is_symlink(), "our own stale dangling symlink must still be cleared" + + def test_snapshot_dir_has_broken_symlinks_unit(tmp_path): """The new per-snapshot primitive flags a dangling link and is clean otherwise.""" snap = tmp_path / "snapshots" / "sha" @@ -2261,6 +2288,43 @@ def test_post_download_accepts_weightless_patterned_result(tmp_path): cfg, repo_type = "model", allow_patterns = None, ignore_patterns = None) is False +def test_gate_rejects_variant_only_shard_index(tmp_path): + """codex :269 (over-accept): a variant-only shard index (model.safetensors.index.fp16.json) must + NOT satisfy the canonical allow=None fast path -- the fallback wrapper takes no variant param, so + a default load probes the canonical index whose weights are absent. Only a canonical + (non-variant) index counts; the variant layout defers to the watched child.""" + snap, blob = _mk_snapshot(tmp_path, "variant") + (snap / "config.json").write_text("{}") + (snap / "model-00001-of-00001.fp16.safetensors").symlink_to(blob) + (snap / "model.safetensors.index.fp16.json").write_text( + json.dumps({"weight_map": {"a": "model-00001-of-00001.fp16.safetensors"}})) + assert hcs.snapshot_dir_is_complete(snap) is False + # The canonical index for the same model still fast-paths. 
+ (snap / "model-00001-of-00001.safetensors").symlink_to(blob) + (snap / "model.safetensors.index.json").write_text( + json.dumps({"weight_map": {"a": "model-00001-of-00001.safetensors"}})) + assert hcs.snapshot_dir_is_complete(snap) is True + + +def test_generic_hub_http_error_type_preserved_but_status_drives_retry(): + """codex :499: a deterministic 4xx surfaced as a bare HfHubHTTPError keeps its TYPE across the + spawn boundary (so `except HfHubHTTPError` still works) WITHOUT joining the retry-deterministic + name shortcut -- a transient 5xx bare HfHubHTTPError must still retry over HTTP via its status + code.""" + assert "HfHubHTTPError" not in xf._DETERMINISTIC_ERROR_NAMES # status-driven, not name-driven + cls = xf._resolve_exception_class("HfHubHTTPError") + assert cls is not None and issubclass(cls, BaseException) + + class _Resp: + def __init__(self, code): self.status_code = code + e503 = xf._instantiate_preserving_type(cls, "HfHubHTTPError: 503 service unavailable") + e503.response = _Resp(503) + assert xf._is_retryable_download_error(e503) is True # 5xx still retryable + e403 = xf._instantiate_preserving_type(cls, "HfHubHTTPError: 403 forbidden") + e403.response = _Resp(403) + assert xf._is_retryable_download_error(e403) is False # 4xx deterministic + + def test_hfvalidationerror_type_preserved_across_spawn(): """Finding 4: a malformed repo id fails identically over either transport, so HFValidationError is deterministic (not retried) and its TYPE is reconstructed across the spawn boundary instead of @@ -2271,3 +2335,66 @@ def test_hfvalidationerror_type_preserved_across_spawn(): inst = xf._instantiate_preserving_type(cls, "HFValidationError: bad repo id") assert type(inst).__name__ == "HFValidationError" assert xf._is_retryable_download_error(inst) is False + + +def test_weight_pattern_selector_handles_globs(tmp_path): + """Round-2 F1(round1)/F3/F6: the weight-pattern selector reads tokenizer / config / json globs as + weightless (keeps their offline 
short-circuit) while classifying every standard weight name and + single-char (?) / class ([]) globs as weight-bearing.""" + weightless = ["tokenizer*", "*.json", "config.json", "tokenizer.model", "*.txt"] + weighty = [ + "model.safetensors", "*.safetensors", "model.?afetensors", "model.[sp]afetensors", + "checkpoint-*/model.?afetensors", "unet/*", "adapter_model*", "consolidated*", "model.gguf", + ] + for pat in weightless: + assert hcs.request_can_include_weights([pat], None) is False, pat + for pat in weighty: + assert hcs.request_can_include_weights([pat], None) is True, pat + + +def test_post_download_rejects_config_only_for_explicit_weight_pattern(tmp_path): + """Round-2 F3: an explicit weight request (allow=['model.safetensors']) that came back with only + config.json is a stale config-only snapshot and must be rejected (retry over HTTP), NOT accepted. + A genuinely weightless patterned request stays accepted (test_post_download_accepts_weightless...).""" + snap, _ = _mk_snapshot(tmp_path, "patcfg") + (snap / "config.json").write_text("{}") + assert xf._download_result_usable( + snap, repo_type = "model", allow_patterns = ["model.safetensors"], ignore_patterns = None) is False + assert xf._download_result_usable( + snap, repo_type = "model", allow_patterns = ["adapter_model.safetensors"], + ignore_patterns = None) is False + + +def test_post_download_rejects_incomplete_canonical_root_shards(tmp_path): + """Round-2 F2: an interrupted canonical sharded warm (a loose model-00001-of-00002.safetensors + with no index / missing sibling) has a loadable weight file but a default load cannot read it and + would fetch the rest over un-killable Xet, so it is rejected post-download. 
A complete sharded set + is accepted; a variant-only shard layout is not force-failed (it simply has no canonical shard).""" + snap, blob = _mk_snapshot(tmp_path, "incshard") + (snap / "config.json").write_text("{}") + (snap / "model-00001-of-00002.safetensors").symlink_to(blob) + assert xf._download_result_usable( + snap, repo_type = "model", allow_patterns = None, ignore_patterns = None) is False + # Complete the set with its index -> accepted. + (snap / "model-00002-of-00002.safetensors").symlink_to(blob) + (snap / "model.safetensors.index.json").write_text( + json.dumps({"weight_map": {"a": "model-00001-of-00002.safetensors", + "b": "model-00002-of-00002.safetensors"}})) + assert xf._download_result_usable( + snap, repo_type = "model", allow_patterns = None, ignore_patterns = None) is True + # A variant-only shard layout has no canonical shard -> not force-failed here. + vsnap, vblob = _mk_snapshot(tmp_path, "vshard") + (vsnap / "config.json").write_text("{}") + (vsnap / "model-00001-of-00001.fp16.safetensors").symlink_to(vblob) + assert xf._download_result_usable( + vsnap, repo_type = "model", allow_patterns = None, ignore_patterns = None) is True + + +def test_local_token_not_found_error_type_preserved(): + """Round-2 F4: a missing required token fails identically over either transport, so + LocalTokenNotFoundError is deterministic and its type is reconstructed across the spawn boundary.""" + assert "LocalTokenNotFoundError" in xf._DETERMINISTIC_ERROR_NAMES + cls = xf._resolve_exception_class("LocalTokenNotFoundError") + assert cls is not None and issubclass(cls, BaseException) + assert xf._is_retryable_download_error( + xf._instantiate_preserving_type(cls, "LocalTokenNotFoundError: token required")) is False diff --git a/unsloth_zoo/hf_cache_state.py b/unsloth_zoo/hf_cache_state.py index 7a4e2464e..39b201ff9 100644 --- a/unsloth_zoo/hf_cache_state.py +++ b/unsloth_zoo/hf_cache_state.py @@ -34,6 +34,7 @@ from __future__ import annotations import fnmatch 
+import re import sys from pathlib import Path from typing import Iterator, Optional @@ -269,6 +270,17 @@ def _is_weight_shard_index(name: str) -> bool: return name.endswith(".json") and (".safetensors.index." in name or ".bin.index." in name) +def _is_canonical_weight_shard_index(name: str) -> bool: + """True only for the CANONICAL (non-variant) shard index a default in-process load probes: + ``model.safetensors.index.json`` / ``pytorch_model.bin.index.json`` (any stem). A variant form + such as ``model.safetensors.index.fp16.json`` ends in ``.index.fp16.json`` and is rejected: the + fallback wrapper takes no variant parameter, so a default ``from_pretrained`` reads the canonical + index, and a variant-only cache must NOT satisfy the canonical fast path (its canonical weights + are still missing -- skipping the child there would reintroduce the unprotected in-process fetch + this gate prevents).""" + return name.endswith(".safetensors.index.json") or name.endswith(".bin.index.json") + + def _weight_shard_index_complete(index_path: Path) -> bool: """True only if every shard a HF weight index (``model.safetensors.index.json`` / ``pytorch_model.bin.index.json``) lists is present next to the index. @@ -420,22 +432,58 @@ def _ignore_strips_all_weights(ignore_patterns: "list") -> bool: return True +# Representative weight filenames a glob allow pattern is probed against (via fnmatch). A glob that +# matches one of these can select a weight; one that matches none (``tokenizer*``, ``*.json``) is +# weightless. Covers the canonical / variant / sharded / adapter / diffusers / mistral-consolidated +# and the non-safetensors weight formats so a real weight glob is never under-classified. 
+_WEIGHT_PATTERN_PROBES = ( + "model.safetensors", + "model.fp16.safetensors", + "model-00001-of-00002.safetensors", + "pytorch_model.bin", + "pytorch_model-00001-of-00002.bin", + "adapter_model.safetensors", + "adapter_model.bin", + "consolidated.safetensors", + "consolidated.00.pth", + "diffusion_pytorch_model.safetensors", + "model.gguf", + "model.pt", + "model.pth", + "model.h5", + "model.msgpack", + "tf_model.h5", + "flax_model.msgpack", +) + + def _pattern_can_select_weight(pattern: "object") -> bool: - """Whether a single allow pattern could select a model weight file. Conservative: a wildcard - basename (``unet/*``, ``adapter_model*``) or a directory pattern (``unet/``) could absorb a - weight, as does a basename ending in a weight suffix (``*.safetensors``, ``model.fp16.safetensors``, - ``model.gguf``). Only a basename ending in a concrete NON-weight extension (``config.json``, - ``*.py``, ``tokenizer.model``, ``additional_chat_templates/*.jinja``) is treated as weightless, - so a tokenizer / config allow list keeps its offline short-circuit while a real (sub)weight glob - is never under-classified.""" + """Whether a single allow pattern could select a model weight file. + + - a non-string (unknown shape) -> conservative True; + - a bare directory pattern (``unet/``) -> True (expands to everything under it, incl. weights); + - a basename ending in a weight suffix (``*.safetensors``, ``model.gguf``) -> True; + - a glob basename (``model.?afetensors``, ``model.[sp]afetensors``, ``*``) -> True iff it matches a + representative weight name in ``_WEIGHT_PATTERN_PROBES`` -- so ``tokenizer*`` / ``*.json`` read + weightless and keep their offline short-circuit, while ``model.?afetensors`` / ``unet/*`` are + weight-bearing; + - a concrete non-weight name (``config.json``, ``tokenizer.model``) -> False. + + A glob is matched on its basename so ``checkpoint-*/model.?afetensors`` is still recognized. 
Both + directions are bounded: a false weight-bearing only makes the pre-download gate spawn the cheap + child; a false weightless is avoided for every standard weight name by the probe set.""" if not isinstance(pattern, str): return True # unknown shape -> conservative if pattern.endswith("/"): return True # a bare directory pattern expands to everything under it, incl. weights base = pattern.rsplit("/", 1)[-1] - if base.endswith(("*", "?")): - return True # a trailing wildcard could absorb a weight suffix - return base.endswith(_WEIGHT_FILE_SUFFIXES) + if base.endswith(_WEIGHT_FILE_SUFFIXES): + return True # a concrete or wildcard-stem weight suffix + if any(ch in base for ch in _GLOB_CHARS): + # A glob basename selects a weight only if it can actually match a weight filename. This keeps + # tokenizer / config globs weightless while catching single-char (?) and class ([]) globs. + return any(fnmatch.fnmatchcase(probe, base) for probe in _WEIGHT_PATTERN_PROBES) + return False def request_can_include_weights( @@ -489,7 +537,9 @@ def _canonical_root_weights_complete( rel = entry.name if "/" in rel: continue # a bare from_pretrained reads ROOT files only - if _is_weight_shard_index(entry.name): + # Only the CANONICAL (non-variant) index counts here: a default load probes + # model.safetensors.index.json, not a variant like model.safetensors.index.fp16.json. + if _is_canonical_weight_shard_index(entry.name): if _safe_is_file(entry): root_indices.append(entry) elif _safe_is_file(entry): @@ -567,6 +617,34 @@ def snapshot_dir_is_complete( return _canonical_root_weights_complete(snapshot_dir, entries, ignore_patterns) +# A canonical numbered weight shard at the snapshot root: the shard index sits IMMEDIATELY before the +# extension (no variant token), so ``model-00001-of-00002.safetensors`` matches but the variant +# ``model-00001-of-00002.fp16.safetensors`` does NOT. 
+_CANONICAL_ROOT_SHARD_RE = re.compile( + r"^(?:model|pytorch_model)-\d{5}-of-\d{5}\.(?:safetensors|bin)$" +) + + +def _has_incomplete_canonical_root_shards(snapshot_dir: Path) -> bool: + """True when the snapshot root holds canonical numbered weight shards + (``model-00001-of-00002.safetensors`` / ``pytorch_model-...bin``) but is NOT a complete canonical + model -- the shard index is missing or a listed shard is absent. + + Such a loose-shard layout is a stale / interrupted download: a default in-process load cannot read + bare numbered shards without their index and would fetch the rest over un-killable Xet, so the + post-download acceptance check rejects it and retries over HTTP. Variant shards + (``model-...fp16.safetensors``) are intentionally excluded -- they never satisfy a default load, so + a variant-only repo must not be force-failed here (it simply defers, like any non-canonical warm).""" + try: + names = [entry.name for entry in snapshot_dir.iterdir()] + except OSError: + return False + if not any(_CANONICAL_ROOT_SHARD_RE.match(name) for name in names): + return False + # Canonical shards exist but no complete single-file / indexed canonical set covers them. 
+ return not snapshot_dir_is_complete(snapshot_dir) + + def requested_named_files_present( snapshot_dir: Path, *, diff --git a/unsloth_zoo/hf_xet_fallback.py b/unsloth_zoo/hf_xet_fallback.py index d376900f9..90cfda009 100644 --- a/unsloth_zoo/hf_xet_fallback.py +++ b/unsloth_zoo/hf_xet_fallback.py @@ -52,6 +52,7 @@ _is_loadable_weight_file, blob_bytes_present, has_active_incomplete_blobs, + _has_incomplete_canonical_root_shards, hf_cache_root, iter_active_repo_cache_dirs, request_can_include_weights, @@ -173,6 +174,29 @@ def _redact_signed_query(match: "re.Match") -> str: return out +def _broken_link_has_active_partner(link: Path, *, active_grace: float) -> bool: + """True if a dangling snapshot symlink should be SPARED from the HTTP-prep purge because a + concurrent sibling download (a different process pulling the same repo into the same cache, common + in multi-rank training) is still writing the blob it points at. + + The reliable discriminator is a FRESH ``.incomplete`` partner of the link's target blob (mirroring + the active-grace guard the ``.incomplete`` blob purge already uses), NOT the link's own mtime: our + OWN killed child's link is freshly created too, but by this point its ``.incomplete`` has been + static for the full stall timeout and is purged first, so the target has no partner and the link is + correctly cleared -- while a sibling mid-download still has a growing ``.incomplete`` partner, so + its link is spared.""" + try: + target = Path(os.readlink(link)) + if not target.is_absolute(): + target = link.parent / target + incomplete_partner = target.with_name(target.name + INCOMPLETE_SUFFIX) + if incomplete_partner.is_file(): + return time.time() - incomplete_partner.stat().st_mtime < active_grace + except OSError: + return False + return False + + def _default_prepare_for_http( repo_type: str, repo_id: str, @@ -228,6 +252,11 @@ def _default_prepare_for_http( try: for link in snapshot.rglob("*"): if link.is_symlink() and not link.exists(): + # 
Spare a concurrent sibling's active dangling link (its target blob still + # has a fresh .incomplete partner); only purge our own stale + # interrupted-download links so the HTTP retry reads clean. + if _broken_link_has_active_partner(link, active_grace = active_grace): + continue try: link.unlink() except OSError: @@ -471,12 +500,25 @@ def _scrub_in_child(text: str, token: Optional[str]) -> str: "GatedRepoError", "DisabledRepoError", "LocalEntryNotFoundError", + # A required token that is absent locally fails identically over either transport (it never + # reaches the network), so surface it deterministically with its real type. + "LocalTokenNotFoundError", "BadRequestError", # A malformed repo id / revision fails identically over either transport (it never reaches the # network), so surface it with its real type instead of a generic RuntimeError or a pointless # HTTP retry. "HFValidationError", }) +# Names whose TYPE should be reconstructed across the spawn boundary but which must NOT join the +# retry-deterministic shortcut above. ``HfHubHTTPError`` is the base of both the deterministic 4xx +# (401 / 403 / 404 / 416) and the transient 5xx / 429 errors, so the retry decision for it must stay +# status-code driven (``_is_retryable_download_error`` falls through to the status check). But once an +# error has been classified deterministic and surfaced as ``"HfHubHTTPError: "``, the parent +# should still re-raise the original type so a caller's ``except HfHubHTTPError`` (auth / quota / +# permission handling) keeps working instead of seeing a generic ``RuntimeError``. +_TYPE_PRESERVE_ONLY_NAMES = frozenset({ + "HfHubHTTPError", +}) # Substrings that mark a transient transport failure (hf_xet / CAS error, timeout, reset, # HTTP 5xx / 429) that disabling Xet and retrying over HTTP may recover. 
_TRANSIENT_ERROR_HINTS = ( @@ -495,7 +537,7 @@ def _resolve_exception_class(type_name: str) -> "Optional[type]": occurs and never hard-depends on a specific huggingface_hub layout.""" if type_name == "OSError": return OSError - if type_name not in _DETERMINISTIC_ERROR_NAMES: + if type_name not in _DETERMINISTIC_ERROR_NAMES and type_name not in _TYPE_PRESERVE_ONLY_NAMES: return None for module_name in ("huggingface_hub.errors", "huggingface_hub.utils"): try: @@ -1040,22 +1082,28 @@ def _download_result_usable( file simply absent from the repo is not treated as missing -- so a good download is never failed and re-looped into a ``DownloadStallError``. - The no-weight rejection is scoped to an UNPATTERNED model request (``allow_patterns is None``), - which is the only shape that should always yield a weight. A patterned request is trusted: the - caller scoped it, so a weightless result is taken as intended (e.g. ``allow_patterns=['tokenizer*']`` - legitimately returns no weight) rather than failed into a spurious ``DownloadStallError``.""" + The no-weight rejection fires whenever the request can include weights (``request_can_include_weights``): + an unpatterned model warm, or an explicit weight request (``allow_patterns=['model.safetensors']``) + that came back with no weight, is a stale config-only snapshot and is retried. A genuinely weightless + request (``allow_patterns=['tokenizer*']`` / ``['*.json']``) reads weightless there, so its valid + no-weight result is accepted rather than failed. 
+ + It also rejects an interrupted CANONICAL sharded warm (loose ``model-00001-of-00002.safetensors`` + without its index or a sibling shard) for an unpatterned request: that layout has a loadable weight + file but a default load still cannot read it and would fetch the rest over un-killable Xet.""" if snapshot_has_requested_broken_symlinks( snapshot_dir, allow_patterns = allow_patterns, ignore_patterns = ignore_patterns, repo_type = repo_type, ): return False - if ( - repo_type in (None, "model") - and allow_patterns is None - and request_can_include_weights(allow_patterns, ignore_patterns) - and not _has_any_weight(snapshot_dir) - ): - return False + if repo_type in (None, "model"): + if ( + request_can_include_weights(allow_patterns, ignore_patterns) + and not _has_any_weight(snapshot_dir) + ): + return False + if allow_patterns is None and _has_incomplete_canonical_root_shards(snapshot_dir): + return False return True From d55d32052bf47167ebde5c692d9378ff55c9b9c0 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 29 Jun 2026 04:04:54 +0000 Subject: [PATCH 47/82] Tighten post-download acceptance and HTTP-prep cleanup for the third review round A third bot pass (codex inline) plus a 10-reviewer pass found more over-accept / over-reject holes in the cache-skip and post-download acceptance paths, and one concurrency bug in the HTTP-prep cleanup. The watchdog / spawn / retry mechanism stays in place; this sharpens what each layer accepts. Post-download acceptance (hf_xet_fallback._download_result_usable): - Every EXACT-named requested file must be present, grouped by weight equivalence. The either-format pair ["pytorch_model.bin", "model.safetensors"] needs only one, but a base weight AND an adapter_model.safetensors must each be present, and an exact non-weight request like ["tokenizer.json"] must find its file. A stale snapshot that merely shares one weight no longer satisfies a request for a different one (e.g. an adapter-only request matched by a stale base). 
- For an UNPATTERNED model warm the weight must be ROOT-readable (or, for a diffusers pipeline gated by model_index.json, a component-subfolder weight). A stale training-checkpoint-only snapshot (weights only under checkpoint-*/) is rejected -- a default from_pretrained ignores those subfolders and would fetch the missing root weights over un-killable Xet. - For a PATTERNED weight request the present weight must fall WITHIN the requested scope (the allow / ignore filter), so a stale out-of-scope weight does not pass it. - A metadata-directory allow pattern (tokenizer/, processor/, ...) now reads weightless, so a complete tokenizer-only download is accepted instead of being looped into a DownloadStallError; component and checkpoint directory patterns stay conservatively weight-bearing. Pre-download cache skip (hf_xet_fallback._cache_can_skip_download): - A non-model (dataset / space) request, or a weightless model request, may skip the watched child only when it names EXACT files whose subset is intact. An allow=None or glob dataset request cannot be proven complete from local files alone, so it defers to the child for the authoritative manifest compare + resume instead of returning a partial snapshot. HTTP-prep cleanup (hf_xet_fallback._default_prepare_for_http): - The purge of stale *.incomplete blobs and dangling snapshot symlinks is now scoped to the partials the stalled child actually OWNED (captured from its open fds before it is killed, with a since-spawn baseline fallback). A concurrent same-repo sibling download writing a DIFFERENT blob -- common in multi-rank training -- is never deleted, even if its partial has aged past the stall timeout. When ownership cannot be determined the coarser mtime guard is used, as before. Deliberately NOT changed: the warm / offline fast path still serves a cached snapshot for a mutable revision (main / a branch / a tag) without a remote freshness check. 
Revalidating mutable refs would add a network round-trip to every load and regress the offline and warm-cache short-circuit the design preserves; commit-pinned reproducibility and offline use are favored over always pulling the newest main. Pin a revision (or clear the cache) to force a refresh. Flagging this as a policy choice rather than a silent change. Tests: add regression guards for the grouped exact-file validation (adapter-vs-base, either-format), the checkpoint-only and metadata-directory cases, the weightless / dataset exact-subset validation, the dataset unpatterned/glob no-skip, and the owned-scoped HTTP-prep cleanup (sibling blob + link spared, None-ownership falls back to mtime). Suite 133 passed / 1 skipped; the 40k-layout safety fuzz stays at 0 violations; the real stall->HTTP e2e and the Studio de-dup surface stay green. --- tests/test_hf_xet_fallback.py | 132 ++++++++++++++++++++ unsloth_zoo/hf_cache_state.py | 20 ++- unsloth_zoo/hf_xet_fallback.py | 222 +++++++++++++++++++++++++++++---- 3 files changed, 350 insertions(+), 24 deletions(-) diff --git a/tests/test_hf_xet_fallback.py b/tests/test_hf_xet_fallback.py index db5b63f07..53092ca27 100644 --- a/tests/test_hf_xet_fallback.py +++ b/tests/test_hf_xet_fallback.py @@ -2398,3 +2398,135 @@ def test_local_token_not_found_error_type_preserved(): assert cls is not None and issubclass(cls, BaseException) assert xf._is_retryable_download_error( xf._instantiate_preserving_type(cls, "LocalTokenNotFoundError: token required")) is False + + +def test_metadata_directory_pattern_is_weightless(tmp_path): + """Round-3 A: a trailing-slash metadata directory pattern (allow=['tokenizer/']) reads weightless + so a complete tokenizer-only download is accepted, not looped into a DownloadStallError. 
A + component / checkpoint directory pattern stays conservatively weight-bearing.""" + assert hcs.request_can_include_weights(["tokenizer/"], None) is False + assert hcs.request_can_include_weights(["processor/"], None) is False + assert hcs.request_can_include_weights(["unet/"], None) is True + assert hcs.request_can_include_weights(["checkpoint-10/"], None) is True + snap, _ = _mk_snapshot(tmp_path, "tokdir") + (snap / "tokenizer").mkdir() + (snap / "tokenizer" / "tokenizer.json").write_text("{}") + assert xf._download_result_usable( + snap, repo_type = "model", allow_patterns = ["tokenizer/"], ignore_patterns = None) is True + + +def test_post_download_rejects_checkpoint_only_root_model(tmp_path): + """Round-3 B (over-accept): a stale snapshot whose only weight is under checkpoint-7/ is rejected + for an unpatterned root warm -- a default from_pretrained ignores checkpoint-*/ and would fetch the + missing root weights over un-killable Xet. The same checkpoint is accepted when explicitly scoped.""" + snap, blob = _mk_snapshot(tmp_path, "ckonly") + (snap / "config.json").write_text("{}") + (snap / "checkpoint-7").mkdir() + (snap / "checkpoint-7" / "model.safetensors").symlink_to(blob) + assert xf._download_result_usable( + snap, repo_type = "model", allow_patterns = None, ignore_patterns = None) is False + assert xf._download_result_usable( + snap, repo_type = "model", allow_patterns = ["checkpoint-7/*"], ignore_patterns = None) is True + # A diffusers pipeline's subfolder weights still count (model_index.json gates that). 
+ dsnap, dblob = _mk_snapshot(tmp_path, "diff") + (dsnap / "model_index.json").write_text("{}") + (dsnap / "unet").mkdir() + (dsnap / "unet" / "diffusion_pytorch_model.safetensors").symlink_to(dblob) + assert xf._download_result_usable( + dsnap, repo_type = "model", allow_patterns = None, ignore_patterns = None) is True + + +def test_post_download_validates_weightless_named_subset(tmp_path): + """Round-3 C: an exact weightless request (allow=['tokenizer.json'], or a dataset file) that came + back as a stale config-only snapshot missing the named file is rejected and retried. A glob allow + list stays lenient (cannot be turned into an exact manifest).""" + snap, _ = _mk_snapshot(tmp_path, "noname") + (snap / "config.json").write_text("{}") + assert xf._download_result_usable( + snap, repo_type = "model", allow_patterns = ["tokenizer.json"], ignore_patterns = None) is False + assert xf._download_result_usable( + snap, repo_type = "dataset", allow_patterns = ["data.parquet"], ignore_patterns = None) is False + # Present named file -> accepted; a glob stays lenient. + (snap / "tokenizer.json").write_text("{}") + assert xf._download_result_usable( + snap, repo_type = "model", allow_patterns = ["tokenizer.json"], ignore_patterns = None) is True + assert xf._download_result_usable( + snap, repo_type = "model", allow_patterns = ["*.json"], ignore_patterns = None) is True + + +def test_post_download_rejects_missing_exact_weight_request(tmp_path): + """Round-3 F2: an exact weight request whose file is missing is rejected even when a different + weight is present -- allow=['adapter_model.safetensors'] is NOT satisfied by a stale base + model.safetensors, and ['model.safetensors','adapter_model.safetensors'] needs both. 
The classic + either-format ['model.safetensors','pytorch_model.bin'] pair stays satisfied by one (equivalence).""" + base, blob = _mk_snapshot(tmp_path, "baseonly") + (base / "model.safetensors").symlink_to(blob) + assert xf._download_result_usable( + base, repo_type = "model", allow_patterns = ["adapter_model.safetensors"], + ignore_patterns = None) is False + assert xf._download_result_usable( + base, repo_type = "model", + allow_patterns = ["model.safetensors", "adapter_model.safetensors"], ignore_patterns = None) is False + assert xf._download_result_usable( + base, repo_type = "model", + allow_patterns = ["model.safetensors", "pytorch_model.bin"], ignore_patterns = None) is True + # Both present -> accepted. + (base / "adapter_model.safetensors").symlink_to(blob) + assert xf._download_result_usable( + base, repo_type = "model", + allow_patterns = ["model.safetensors", "adapter_model.safetensors"], ignore_patterns = None) is True + + +def test_dataset_unpatterned_or_glob_partial_does_not_skip_child(tmp_path): + """Round-3 F3: a dataset/space snapshot whose completeness cannot be proven from local files + (allow=None or a glob) must defer to the watched child -- a partial cache must not be returned as + complete. An intact exact-named subset still short-circuits.""" + snap, _ = _mk_snapshot(tmp_path, "dspart") + (snap / "README.md").write_text("partial") + assert xf._cache_can_skip_download( + snap, repo_type = "dataset", allow_patterns = None, ignore_patterns = None) is False + assert xf._cache_can_skip_download( + snap, repo_type = "dataset", allow_patterns = ["*.parquet"], ignore_patterns = None) is False + assert xf._cache_can_skip_download( + snap, repo_type = "dataset", allow_patterns = ["README.md"], ignore_patterns = None) is True + + +def test_http_prep_scopes_blob_cleanup_to_owned_partials(tmp_path): + """Round-3 F1: HTTP prep must purge only the stalled child's OWN partials, never a concurrent + same-repo sibling's blob (multi-rank). 
With ownership known, a sibling's aged partial and its + dangling link are spared; with ownership unknown (None), the coarser mtime guard purges both.""" + repo = "ztest/concurrent-blobs" + repo_dir = tmp_path / f"models--{repo.replace('/', '--')}" + blobs = repo_dir / "blobs" + snap = repo_dir / "snapshots" / "sha" + blobs.mkdir(parents = True) + snap.mkdir(parents = True) + old = time.time() - 600 + + def _seed(): + owned = blobs / "ownedhash.incomplete" + sibling = blobs / "siblinghash.incomplete" + owned.write_bytes(b"o") + sibling.write_bytes(b"s") + os.utime(owned, (old, old)) + os.utime(sibling, (old, old)) + for name in list(snap.iterdir()): + name.unlink() + (snap / "our.safetensors").symlink_to(blobs / "ownedhash") + (snap / "sib.safetensors").symlink_to(blobs / "siblinghash") + return owned, sibling + + owned, sibling = _seed() + _REAL_DEFAULT_PREPARE( + "model", repo, cache_dir = str(tmp_path), active_grace = 180, + owned_incomplete_blobs = {"ownedhash.incomplete"}) + assert not owned.exists(), "our own stalled partial must be purged" + assert sibling.exists(), "a concurrent sibling's partial must be spared" + assert not (snap / "our.safetensors").is_symlink() + assert (snap / "sib.safetensors").is_symlink(), "sibling's dangling link must be spared" + + # No ownership info -> coarse mtime guard purges both aged partials (prior behavior). 
+ owned, sibling = _seed() + _REAL_DEFAULT_PREPARE( + "model", repo, cache_dir = str(tmp_path), active_grace = 180, owned_incomplete_blobs = None) + assert not owned.exists() and not sibling.exists() diff --git a/unsloth_zoo/hf_cache_state.py b/unsloth_zoo/hf_cache_state.py index 39b201ff9..2cd77f4c6 100644 --- a/unsloth_zoo/hf_cache_state.py +++ b/unsloth_zoo/hf_cache_state.py @@ -456,6 +456,21 @@ def _ignore_strips_all_weights(ignore_patterns: "list") -> bool: "flax_model.msgpack", ) +# Snapshot subdirectories that hold only metadata / config (never a loadable model weight), so a +# trailing-slash directory pattern scoped to one of them (``allow_patterns=['tokenizer/']``) is +# weightless. Any OTHER directory pattern stays conservatively weight-bearing: a component dir +# (``unet/``, ``vae/``) or a training-checkpoint dir (``checkpoint-10/``) can hold a weight, so the +# fast path must not skip the child on it. +_NON_WEIGHT_DIRS = frozenset({ + "tokenizer", + "processor", + "preprocessor", + "feature_extractor", + "image_processor", + "video_processor", + "scheduler", +}) + def _pattern_can_select_weight(pattern: "object") -> bool: """Whether a single allow pattern could select a model weight file. @@ -475,7 +490,10 @@ def _pattern_can_select_weight(pattern: "object") -> bool: if not isinstance(pattern, str): return True # unknown shape -> conservative if pattern.endswith("/"): - return True # a bare directory pattern expands to everything under it, incl. weights + # A bare directory pattern expands to everything under it. A known metadata dir holds no + # weight (so it stays weightless and keeps its offline short-circuit); any other dir could. 
+ dir_name = pattern.rstrip("/").rsplit("/", 1)[-1].lower() + return dir_name not in _NON_WEIGHT_DIRS base = pattern.rsplit("/", 1)[-1] if base.endswith(_WEIGHT_FILE_SUFFIXES): return True # a concrete or wildcard-stem weight suffix diff --git a/unsloth_zoo/hf_xet_fallback.py b/unsloth_zoo/hf_xet_fallback.py index 90cfda009..7e786027d 100644 --- a/unsloth_zoo/hf_xet_fallback.py +++ b/unsloth_zoo/hf_xet_fallback.py @@ -49,10 +49,13 @@ from unsloth_zoo.hf_cache_state import ( INCOMPLETE_SUFFIX, + _as_pattern_list, + _filter_paths, + _has_glob, + _has_incomplete_canonical_root_shards, _is_loadable_weight_file, blob_bytes_present, has_active_incomplete_blobs, - _has_incomplete_canonical_root_shards, hf_cache_root, iter_active_repo_cache_dirs, request_can_include_weights, @@ -197,12 +200,22 @@ def _broken_link_has_active_partner(link: Path, *, active_grace: float) -> bool: return False +def _link_incomplete_partner_name(link: Path) -> Optional[str]: + """The ``.incomplete`` basename for a dangling snapshot symlink's target blob, or None.""" + try: + target = Path(os.readlink(link)) + return target.name + INCOMPLETE_SUFFIX + except OSError: + return None + + def _default_prepare_for_http( repo_type: str, repo_id: str, *, cache_dir: Optional[str] = None, active_grace: float = DEFAULT_STALL_TIMEOUT, + owned_incomplete_blobs: Optional[set] = None, ) -> None: """Generic 'make the partial safe for an HTTP resume': delete the repo's active ``*.incomplete`` blobs (an HTTP resume over a sparse Xet/hf_transfer partial @@ -213,6 +226,12 @@ def _default_prepare_for_http( ``iter_active_repo_cache_dirs`` is case-collision safe, so this destructive purge only touches an exact-case (or single unambiguous) repo cache dir. 
+ + When *owned_incomplete_blobs* is given (the ``.incomplete`` basenames the stalled child actually + held open, captured before it was killed), the purge is SCOPED to those blobs: a concurrent + same-repo sibling download (common in multi-rank training) writing a DIFFERENT blob is never + touched, even if its partial has aged past *active_grace*. When it is None (ownership could not be + determined), the coarser ``active_grace`` mtime guard alone is used, as before. """ try: for entry in iter_active_repo_cache_dirs(repo_type, repo_id, cache_dir = cache_dir): @@ -220,6 +239,10 @@ def _default_prepare_for_http( if blobs_dir.is_dir(): for blob in blobs_dir.iterdir(): if blob.is_file() and blob.name.endswith(INCOMPLETE_SUFFIX): + # Scope to the stalled child's own partials when known: never delete a + # sibling's blob, even an aged one. + if owned_incomplete_blobs is not None and blob.name not in owned_incomplete_blobs: + continue try: # Do not unlink a partial another concurrent download is # still actively writing: on POSIX that lets the sibling keep @@ -252,6 +275,12 @@ def _default_prepare_for_http( try: for link in snapshot.rglob("*"): if link.is_symlink() and not link.exists(): + # Scope to our own partials when known: a link to a sibling's blob is left + # alone (it is the sibling's snapshot reference, not our stale state). + if owned_incomplete_blobs is not None and ( + _link_incomplete_partner_name(link) not in owned_incomplete_blobs + ): + continue # Spare a concurrent sibling's active dangling link (its target blob still # has a fresh .incomplete partner); only purge our own stale # interrupted-download links so the HTTP retry reads clean. @@ -958,6 +987,18 @@ def _run_download_attempt( break except queue.Empty: pass + # Capture the partials THIS child owns BEFORE killing it, so the HTTP-prep purge can + # scope its blob/symlink cleanup to them and never delete a concurrent sibling's + # partial. 
Prefer the precise per-pid open-fd set; fall back to the partials that + # appeared since this child spawned (kind=="file" tracks a baseline) when the child + # cannot be inspected. None -> prep keeps its coarser mtime-only guard. + owned = _child_open_incomplete_blobs(proc.pid) if proc.pid else None + if owned is None and baseline_partials is not None: + current = set( + _active_incomplete_blob_sizes(repo_type, repo_id, params.get("cache_dir")) + ) + owned = current - baseline_partials + params["_owned_incomplete_blobs"] = owned _terminate_process_group(proc, grace_period) return ("stall", None) try: @@ -1048,6 +1089,119 @@ def _has_any_weight(snapshot_dir: Path) -> bool: return False +def _root_has_loadable_weight(snapshot_dir: Path) -> bool: + """True if a loadable weight sits at the snapshot ROOT (where a default ``from_pretrained`` reads + it). Unlike ``_has_any_weight`` this ignores subfolders, so a stale training-checkpoint-only + snapshot (weights only under ``checkpoint-7/``) is not mistaken for a usable root model.""" + try: + for entry in snapshot_dir.iterdir(): + if _is_loadable_weight_file(entry.name): + try: + if entry.is_file(): + return True + except OSError: + continue + except OSError: + return False + return False + + +def _root_model_has_weight(snapshot_dir: Path) -> bool: + """Whether an UNPATTERNED model warm holds a weight a default load will actually read: a ROOT + weight, or -- for a diffusers pipeline (root ``model_index.json``) -- a component-subfolder weight. + + A bare ``from_pretrained`` reads root weights and ignores arbitrary subfolders (``checkpoint-*/`` ...), + so counting any subtree weight (as ``_has_any_weight`` does) would accept a stale checkpoint-only + snapshot and then fetch the missing root weights over un-killable Xet. 
Diffusers is the one layout + whose weights legitimately live in subfolders, and its ``model_index.json`` marker gates that.""" + try: + is_diffusers = (snapshot_dir / "model_index.json").is_file() + except OSError: + is_diffusers = False + if is_diffusers: + return _has_any_weight(snapshot_dir) + return _root_has_loadable_weight(snapshot_dir) + + +# Exact weight filenames that are interchangeable: a request naming several of the same logical +# weight (the classic ``["pytorch_model.bin", "model.safetensors"]`` either-format pair) is satisfied +# by ANY one of them, while distinct logical weights (a base ``model.safetensors`` AND an +# ``adapter_model.safetensors``) must each be present. +_EQUIVALENT_EXACT_WEIGHT_NAMES = { + "model.safetensors": "root_model", + "pytorch_model.bin": "root_model", + "adapter_model.safetensors": "adapter_model", + "adapter_model.bin": "adapter_model", +} + + +def _requested_exact_files_present_grouped( + snapshot_dir: Path, *, allow_patterns: Any, ignore_patterns: Any, +) -> bool: + """True unless an EXACT-named requested file is missing. A request that names several + interchangeable weights (``["pytorch_model.bin", "model.safetensors"]``) is satisfied by any one + of them; distinct logical files (a base weight AND an adapter, or a tokenizer file) must each be + present. 
A request with ANY glob, or no allow list, is a best-effort warm and is trivially + satisfied here -- the weight-presence checks below cover those.""" + allow = _as_pattern_list(allow_patterns) + ignore = _as_pattern_list(ignore_patterns) + if not allow or any(not isinstance(p, str) or _has_glob(p) for p in allow): + return True + requested = _filter_paths(allow, None, ignore) + if not requested: + return True # the ignore filter dropped every named file -> nothing to require + try: + present = { + entry.relative_to(snapshot_dir).as_posix() + for entry in snapshot_dir.rglob("*") + if entry.is_file() + } + except OSError: + return True # cannot enumerate -> do not reject on an unreadable dir + groups: "dict[tuple[str, str], list[str]]" = {} + for rel in requested: + parent, base = rel.rsplit("/", 1) if "/" in rel else ("", rel) + logical = _EQUIVALENT_EXACT_WEIGHT_NAMES.get(base, base) + groups.setdefault((parent, logical), []).append(rel) + return all( + any(candidate in present for candidate in candidates) for candidates in groups.values() + ) + + +def _has_selected_weight( + snapshot_dir: Path, *, allow_patterns: Any, ignore_patterns: Any, +) -> bool: + """True if at least one loadable weight the request actually SELECTS is present. 
Unlike + ``_has_any_weight`` this applies the allow / ignore filter, so a patterned request + (``["*.safetensors"]``, ``["unet/*"]``) is not satisfied by an out-of-scope weight (a stale + ``.bin`` left behind, a checkpoint subfolder the request did not ask for).""" + weights: list = [] + try: + for entry in snapshot_dir.rglob("*"): + if not _is_loadable_weight_file(entry.name): + continue + try: + if entry.is_file(): + weights.append(entry.relative_to(snapshot_dir).as_posix()) + except (OSError, ValueError): + continue + except OSError: + return False + return bool(_filter_paths(weights, allow_patterns, ignore_patterns)) + + +def _patterns_are_exact_names(patterns: Any) -> bool: + """True only for a non-empty allow list of EXACT filenames (no ``None``, no glob, no trailing-slash + directory pattern). Only such a request can be proven complete from local files alone; ``None`` or a + glob needs the Hub manifest, so it must defer to the watched child.""" + patterns = _as_pattern_list(patterns) + if patterns is None: + return False + if not patterns: + return True # selects nothing -> trivially satisfied, nothing to fetch + return all(isinstance(p, str) and not _has_glob(p) for p in patterns) + + def _cache_can_skip_download( snapshot_dir: Path, *, repo_type: str, allow_patterns: Any, ignore_patterns: Any, ) -> bool: @@ -1058,12 +1212,20 @@ def _cache_can_skip_download( (``snapshot_dir_is_complete``) may skip the child; anything uncertain (diffusers, variants, non-trivial patterns, sharded-without-index) returns False -> spawn the child. A false True here would let the in-process load fetch a missing weight over un-killable Xet (the hang). 
A weightless - / non-model request has no weight to hang on, so an intact requested subset is enough -- this - preserves the offline short-circuit for a tokenizer-only / dataset warm.""" + model request (a tokenizer / config / metadata-dir allow list) or a non-model (dataset / space) + request has no weight to hang on, but its completeness is only locally provable when it names + EXACT files: an unpatterned or glob request cannot be proven complete without the Hub manifest, so + it defers to the watched child rather than hand back a partial cache. An exact-named subset that is + intact still short-circuits (preserving the offline tokenizer-only / named-file warm).""" if repo_type in (None, "model") and request_can_include_weights(allow_patterns, ignore_patterns): return snapshot_dir_is_complete( snapshot_dir, allow_patterns = allow_patterns, ignore_patterns = ignore_patterns, ) + # Weightless model / non-model request: skip only when it names exact files whose subset is intact. + # A None / glob request (e.g. a whole-dataset ``allow_patterns=None``) cannot be proven complete + # from local files alone, so defer to the child for the authoritative manifest compare + resume. + if not _patterns_are_exact_names(allow_patterns): + return False return _intact_subset( snapshot_dir, repo_type = repo_type, allow_patterns = allow_patterns, ignore_patterns = ignore_patterns, @@ -1078,31 +1240,41 @@ def _download_result_usable( unless there is POSITIVE evidence of a silent-Xet partial: a dangling REQUESTED symlink (a blob that is missing or still ``.incomplete``), or a weight-bearing model warm that came back with NO weight at all (HF handed back a stale config-only snapshot on an offline / timed-out request). - LENIENT otherwise -- a finished diffusers / variant / either-format download passes, and a named - file simply absent from the repo is not treated as missing -- so a good download is never failed - and re-looped into a ``DownloadStallError``. 
- - The no-weight rejection fires whenever the request can include weights (``request_can_include_weights``): - an unpatterned model warm, or an explicit weight request (``allow_patterns=['model.safetensors']``) - that came back with no weight, is a stale config-only snapshot and is retried. A genuinely weightless - request (``allow_patterns=['tokenizer*']`` / ``['*.json']``) reads weightless there, so its valid - no-weight result is accepted rather than failed. - - It also rejects an interrupted CANONICAL sharded warm (loose ``model-00001-of-00002.safetensors`` - without its index or a sibling shard) for an unpatterned request: that layout has a loadable weight - file but a default load still cannot read it and would fetch the rest over un-killable Xet.""" + LENIENT otherwise -- a finished diffusers / variant / either-format download passes, and an + OPTIONAL file simply absent from the repo is not treated as missing -- so a good download is never + failed and re-looped into a ``DownloadStallError``. + + Positive-breakage checks: + - Any dangling REQUESTED symlink (a missing / still-``.incomplete`` blob). + - Every EXACT-named requested file present (grouped by weight equivalence, so the either-format + ``["pytorch_model.bin", "model.safetensors"]`` pair needs only one, but a base weight AND an + ``adapter_model.safetensors``, or a ``["tokenizer.json"]`` config request, must each be present). + A glob allow list cannot be turned into an exact manifest, so it stays lenient there. + - A weight-bearing MODEL request that came back with no usable weight. For an UNPATTERNED warm the + weight must be ROOT-readable (or a diffusers component) -- a stale ``checkpoint-7/``-only snapshot + does not count, since a default load ignores it -- and an interrupted CANONICAL sharded warm + (loose ``model-00001-of-00002.safetensors`` with no index) is rejected. 
A patterned weight request + must have a weight WITHIN its requested scope (not a stale out-of-scope ``.bin`` / checkpoint).""" if snapshot_has_requested_broken_symlinks( snapshot_dir, allow_patterns = allow_patterns, ignore_patterns = ignore_patterns, repo_type = repo_type, ): return False - if repo_type in (None, "model"): - if ( - request_can_include_weights(allow_patterns, ignore_patterns) - and not _has_any_weight(snapshot_dir) + if not _requested_exact_files_present_grouped( + snapshot_dir, allow_patterns = allow_patterns, ignore_patterns = ignore_patterns + ): + return False + if repo_type in (None, "model") and request_can_include_weights(allow_patterns, ignore_patterns): + if allow_patterns is None: + # Default root load: a root (or diffusers-component) weight, sharded set complete. + if not _root_model_has_weight(snapshot_dir): + return False + if _has_incomplete_canonical_root_shards(snapshot_dir): + return False + elif not _has_selected_weight( + snapshot_dir, allow_patterns = allow_patterns, ignore_patterns = ignore_patterns ): - return False - if allow_patterns is None and _has_incomplete_canonical_root_shards(snapshot_dir): + # Patterned weight request: a weight WITHIN the requested scope must be present. return False return True @@ -1166,10 +1338,14 @@ def _download_with_xet_fallback( # over a sparse Xet/hf_transfer partial silently corrupts the blob. # The generic purge is cache_dir-aware; an injected (Studio) hook owns # its own cache accounting and keeps the (repo_type, repo_id) signature. + # The previous attempt's stall recorded the partials its child owned (if it could). + # Scope the cleanup to them so a concurrent same-repo sibling's partial is never purged. 
+ owned_incomplete = params.pop("_owned_incomplete_blobs", None) try: if prepare_for_http_fn is None: _default_prepare_for_http( - repo_type, repo_id, cache_dir = cache_dir, active_grace = stall_timeout + repo_type, repo_id, cache_dir = cache_dir, active_grace = stall_timeout, + owned_incomplete_blobs = owned_incomplete, ) else: prepare_for_http_fn(repo_type, repo_id) From 6fe5acb4be29c97feda5272e63b7d5d278de78ed Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 29 Jun 2026 04:51:06 +0000 Subject: [PATCH 48/82] Trim comments to a succinct, load-bearing core Comment-only pass over the shared fallback. No code, control flow, or test behavior changes (AST-verified: only comments / docstrings / whitespace differ). - Collapse the verbose multi-line blocks in hf_cache_state.py and hf_xet_fallback.py to short explanations, keeping the load-bearing rationale (the pre/post asymmetry, the watchdog ownership scoping, the equivalence grouping, the un-killable-Xet gotchas) and dropping restatements of obvious code. - Tighten the public wrapper and helper docstrings. - Drop the review-round prefixes ("Round-2 F1", "Round-3 A", ...) from the test docstrings so they describe what each test verifies, not the review history. Suite: 133 passed / 1 skipped; ruff clean. --- tests/test_hf_xet_fallback.py | 72 ++-- unsloth_zoo/hf_cache_state.py | 399 +++++++------------ unsloth_zoo/hf_xet_fallback.py | 677 ++++++++++++--------------------- 3 files changed, 405 insertions(+), 743 deletions(-) diff --git a/tests/test_hf_xet_fallback.py b/tests/test_hf_xet_fallback.py index 53092ca27..aff3a3c29 100644 --- a/tests/test_hf_xet_fallback.py +++ b/tests/test_hf_xet_fallback.py @@ -40,12 +40,10 @@ def _load(name: str, filename: str): # A package placeholder so ``from unsloth_zoo.hf_cache_state import ...`` inside hf_xet_fallback -# resolves to the file we load below, not the installed package. 
RESTORE sys.modules afterwards: -# leaving the placeholder (and the two submodule entries _load installs) in sys.modules would shadow -# the REAL unsloth_zoo for the rest of the pytest process -- its __init__ never runs -- so a later -# test importing unsloth_zoo (e.g. unsloth_zoo.FORCE_FLOAT32) would fail. The two loaded modules keep -# their own bound references (their intra-package import resolved during exec), so they work after -# the placeholder is removed (Codex #829). +# resolves to the file we load below, not the installed package. RESTORE sys.modules afterwards: a +# leftover placeholder would shadow the REAL unsloth_zoo (its __init__ never runs) and fail a later +# test that imports it. The two loaded modules keep their own bound references, so they work after +# the placeholder is removed. _saved_modules = { name: sys.modules.get(name) for name in ("unsloth_zoo", "unsloth_zoo.hf_cache_state", "unsloth_zoo.hf_xet_fallback") @@ -472,9 +470,9 @@ def test_prepare_for_http_clears_broken_snapshot_symlink(tmp_path): def test_prepare_for_http_spares_concurrent_sibling_active_symlink(tmp_path): - """Round-2 F1: HTTP prep must NOT delete a concurrent sibling download's dangling snapshot - symlink while that sibling is still writing the target blob (a fresh .incomplete partner exists). - Our own stale interrupted link (no .incomplete partner) is still cleared in the same sweep.""" + """HTTP prep must NOT delete a concurrent sibling's dangling snapshot symlink while that sibling is + still writing the target blob (a fresh .incomplete partner exists). 
Our own stale interrupted link + (no .incomplete partner) is still cleared in the same sweep.""" repo = "ztest/concurrent" repo_dir = tmp_path / f"models--{repo.replace('/', '--')}" blobs = repo_dir / "blobs" @@ -2338,9 +2336,8 @@ def test_hfvalidationerror_type_preserved_across_spawn(): def test_weight_pattern_selector_handles_globs(tmp_path): - """Round-2 F1(round1)/F3/F6: the weight-pattern selector reads tokenizer / config / json globs as - weightless (keeps their offline short-circuit) while classifying every standard weight name and - single-char (?) / class ([]) globs as weight-bearing.""" + """The weight-pattern selector reads tokenizer / config / json globs as weightless (keeps their + offline short-circuit) but classifies standard weight names and ? / [] globs as weight-bearing.""" weightless = ["tokenizer*", "*.json", "config.json", "tokenizer.model", "*.txt"] weighty = [ "model.safetensors", "*.safetensors", "model.?afetensors", "model.[sp]afetensors", @@ -2353,9 +2350,8 @@ def test_weight_pattern_selector_handles_globs(tmp_path): def test_post_download_rejects_config_only_for_explicit_weight_pattern(tmp_path): - """Round-2 F3: an explicit weight request (allow=['model.safetensors']) that came back with only - config.json is a stale config-only snapshot and must be rejected (retry over HTTP), NOT accepted. - A genuinely weightless patterned request stays accepted (test_post_download_accepts_weightless...).""" + """An explicit weight request (allow=['model.safetensors']) returning only config.json is a stale + config-only snapshot: reject and retry over HTTP. 
A weightless patterned request stays accepted.""" snap, _ = _mk_snapshot(tmp_path, "patcfg") (snap / "config.json").write_text("{}") assert xf._download_result_usable( @@ -2366,10 +2362,9 @@ def test_post_download_rejects_config_only_for_explicit_weight_pattern(tmp_path) def test_post_download_rejects_incomplete_canonical_root_shards(tmp_path): - """Round-2 F2: an interrupted canonical sharded warm (a loose model-00001-of-00002.safetensors - with no index / missing sibling) has a loadable weight file but a default load cannot read it and - would fetch the rest over un-killable Xet, so it is rejected post-download. A complete sharded set - is accepted; a variant-only shard layout is not force-failed (it simply has no canonical shard).""" + """An interrupted canonical sharded warm (loose model-00001-of-00002.safetensors, no index) has a + loadable file but a default load cannot read it and would fetch the rest over un-killable Xet, so + it is rejected. A complete sharded set is accepted; a variant-only shard layout is not force-failed.""" snap, blob = _mk_snapshot(tmp_path, "incshard") (snap / "config.json").write_text("{}") (snap / "model-00001-of-00002.safetensors").symlink_to(blob) @@ -2391,8 +2386,8 @@ def test_post_download_rejects_incomplete_canonical_root_shards(tmp_path): def test_local_token_not_found_error_type_preserved(): - """Round-2 F4: a missing required token fails identically over either transport, so - LocalTokenNotFoundError is deterministic and its type is reconstructed across the spawn boundary.""" + """A missing required token fails identically over either transport, so LocalTokenNotFoundError is + deterministic and its type is reconstructed across the spawn boundary.""" assert "LocalTokenNotFoundError" in xf._DETERMINISTIC_ERROR_NAMES cls = xf._resolve_exception_class("LocalTokenNotFoundError") assert cls is not None and issubclass(cls, BaseException) @@ -2401,9 +2396,8 @@ def test_local_token_not_found_error_type_preserved(): def 
test_metadata_directory_pattern_is_weightless(tmp_path): - """Round-3 A: a trailing-slash metadata directory pattern (allow=['tokenizer/']) reads weightless - so a complete tokenizer-only download is accepted, not looped into a DownloadStallError. A - component / checkpoint directory pattern stays conservatively weight-bearing.""" + """A trailing-slash metadata dir pattern (allow=['tokenizer/']) reads weightless, so a complete + tokenizer-only download is accepted. Component / checkpoint dir patterns stay weight-bearing.""" assert hcs.request_can_include_weights(["tokenizer/"], None) is False assert hcs.request_can_include_weights(["processor/"], None) is False assert hcs.request_can_include_weights(["unet/"], None) is True @@ -2416,9 +2410,9 @@ def test_metadata_directory_pattern_is_weightless(tmp_path): def test_post_download_rejects_checkpoint_only_root_model(tmp_path): - """Round-3 B (over-accept): a stale snapshot whose only weight is under checkpoint-7/ is rejected - for an unpatterned root warm -- a default from_pretrained ignores checkpoint-*/ and would fetch the - missing root weights over un-killable Xet. The same checkpoint is accepted when explicitly scoped.""" + """A stale snapshot whose only weight is under checkpoint-7/ is rejected for an unpatterned root + warm -- a default from_pretrained ignores checkpoint-*/ and would fetch the missing root weights + over un-killable Xet. The same checkpoint is accepted when explicitly scoped.""" snap, blob = _mk_snapshot(tmp_path, "ckonly") (snap / "config.json").write_text("{}") (snap / "checkpoint-7").mkdir() @@ -2437,9 +2431,8 @@ def test_post_download_rejects_checkpoint_only_root_model(tmp_path): def test_post_download_validates_weightless_named_subset(tmp_path): - """Round-3 C: an exact weightless request (allow=['tokenizer.json'], or a dataset file) that came - back as a stale config-only snapshot missing the named file is rejected and retried. 
A glob allow - list stays lenient (cannot be turned into an exact manifest).""" + """An exact weightless request (allow=['tokenizer.json'], or a dataset file) returning a stale + snapshot missing the named file is rejected and retried. A glob allow list stays lenient.""" snap, _ = _mk_snapshot(tmp_path, "noname") (snap / "config.json").write_text("{}") assert xf._download_result_usable( @@ -2455,10 +2448,10 @@ def test_post_download_validates_weightless_named_subset(tmp_path): def test_post_download_rejects_missing_exact_weight_request(tmp_path): - """Round-3 F2: an exact weight request whose file is missing is rejected even when a different - weight is present -- allow=['adapter_model.safetensors'] is NOT satisfied by a stale base - model.safetensors, and ['model.safetensors','adapter_model.safetensors'] needs both. The classic - either-format ['model.safetensors','pytorch_model.bin'] pair stays satisfied by one (equivalence).""" + """An exact weight request whose file is missing is rejected even when a different weight is present: + allow=['adapter_model.safetensors'] is NOT satisfied by a stale base model.safetensors, and + ['model.safetensors','adapter_model.safetensors'] needs both. The either-format + ['model.safetensors','pytorch_model.bin'] pair stays satisfied by one (equivalence).""" base, blob = _mk_snapshot(tmp_path, "baseonly") (base / "model.safetensors").symlink_to(blob) assert xf._download_result_usable( @@ -2478,9 +2471,8 @@ def test_post_download_rejects_missing_exact_weight_request(tmp_path): def test_dataset_unpatterned_or_glob_partial_does_not_skip_child(tmp_path): - """Round-3 F3: a dataset/space snapshot whose completeness cannot be proven from local files - (allow=None or a glob) must defer to the watched child -- a partial cache must not be returned as - complete. 
An intact exact-named subset still short-circuits.""" + """A dataset/space snapshot whose completeness cannot be proven from local files (allow=None or a + glob) must defer to the watched child. An intact exact-named subset still short-circuits.""" snap, _ = _mk_snapshot(tmp_path, "dspart") (snap / "README.md").write_text("partial") assert xf._cache_can_skip_download( @@ -2492,9 +2484,9 @@ def test_dataset_unpatterned_or_glob_partial_does_not_skip_child(tmp_path): def test_http_prep_scopes_blob_cleanup_to_owned_partials(tmp_path): - """Round-3 F1: HTTP prep must purge only the stalled child's OWN partials, never a concurrent - same-repo sibling's blob (multi-rank). With ownership known, a sibling's aged partial and its - dangling link are spared; with ownership unknown (None), the coarser mtime guard purges both.""" + """HTTP prep must purge only the stalled child's OWN partials, never a concurrent same-repo + sibling's blob (multi-rank). With ownership known, a sibling's aged partial and dangling link are + spared; with ownership unknown (None), the coarser mtime guard purges both.""" repo = "ztest/concurrent-blobs" repo_dir = tmp_path / f"models--{repo.replace('/', '--')}" blobs = repo_dir / "blobs" diff --git a/unsloth_zoo/hf_cache_state.py b/unsloth_zoo/hf_cache_state.py index 2cd77f4c6..3ef491ae6 100644 --- a/unsloth_zoo/hf_cache_state.py +++ b/unsloth_zoo/hf_cache_state.py @@ -10,25 +10,18 @@ """Sparse-aware introspection of the active Hugging Face hub cache. -These helpers answer two questions for a repo's blobs under ``HF_HUB_CACHE``: -how many bytes are actually on disk (sparse-aware, so a partially written Xet / -``hf_transfer`` ``.incomplete`` is not mistaken for full-size progress) and -whether an ``.incomplete`` partial is present. The no-progress download watchdog -is built on exactly these two signals. - -The completeness check here is intentionally a CONSERVATIVE fast-path gate, not an -authoritative snapshot verifier. 
It returns "complete" only for the unambiguous -canonical model-cache layouts whose local evidence proves an in-process load will -not fetch a weight. Everything else (diffusers pipelines, weight variants, -non-trivial allow/ignore patterns, datasets, any layout needing inference) returns -"not complete" so the caller runs the authoritative Hugging Face download/resume in -the watched child. Returning a false "complete" is the only dangerous error (it can -send an in-process load to fetch a missing weight over un-killable Xet); returning a -false "not complete" only spawns the cheap watched child, so the gate errs that way. - -Only the single active cache root (``huggingface_hub.constants.HF_HUB_CACHE``) is -scanned here; multi-root / legacy-cache enumeration and transport-marker logic -are download-manager concerns that live in the consumer, not in this module. +These helpers report, for a repo's blobs under ``HF_HUB_CACHE``, how many bytes are actually on disk +(sparse-aware, so a partial Xet / ``hf_transfer`` ``.incomplete`` is not read as full progress) and +whether an ``.incomplete`` partial is present -- the two signals the no-progress watchdog runs on. + +``snapshot_dir_is_complete`` is a CONSERVATIVE fast-path gate, not an authoritative verifier: it +returns "complete" only for unambiguous canonical model layouts, and defers everything else +(diffusers, variants, patterns, datasets) to the watched ``snapshot_download`` child. A false +"complete" is the only dangerous error (an in-process load could then fetch a missing weight over +un-killable Xet); a false "not complete" only spawns the cheap child, so the gate errs that way. + +Only the active ``HF_HUB_CACHE`` root is scanned; multi-root / transport-marker logic is a +download-manager concern that lives in the consumer. 
""" from __future__ import annotations @@ -60,9 +53,7 @@ def _safe_is_dir(path: Path) -> bool: - """``Path.is_dir()`` returning False instead of raising when the path or a - parent is unreadable (e.g. a restricted ``~/.cache/huggingface/hub``), so - cache enumeration skips that root rather than erroring.""" + """``Path.is_dir()`` that returns False instead of raising on an unreadable path.""" try: return path.is_dir() except OSError: @@ -70,8 +61,7 @@ def _safe_is_dir(path: Path) -> bool: def _safe_is_file(path: Path) -> bool: - """``Path.is_file()`` (follows symlinks) returning False instead of raising on an - unreadable path or a dangling link, so snapshot enumeration never errors out.""" + """``Path.is_file()`` that returns False instead of raising on an unreadable / dangling path.""" try: return path.is_file() except OSError: @@ -79,18 +69,11 @@ def _safe_is_file(path: Path) -> bool: def hf_cache_root(*, create: bool = False, cache_dir: "Optional[str | Path]" = None) -> Optional[Path]: - """The hub cache root to scan, or None if unavailable. - - When *cache_dir* is given (a caller-supplied ``snapshot_download`` cache), it - is used verbatim; otherwise the active ``HF_HUB_CACHE`` is read lazily so any - redirect applied at import time (see - ``unsloth_zoo.hf_cache.redirect_hf_cache_if_readonly``) is honored. - """ + """The hub cache root to scan, or None if unavailable. A given *cache_dir* is used verbatim; + otherwise ``HF_HUB_CACHE`` is read lazily so an import-time redirect is honored.""" if cache_dir is not None: - # Match huggingface_hub, which expands ~ before writing; scanning the - # literal path would otherwise miss a partial under e.g. ~/hf-cache. - # Path.expanduser() raises RuntimeError when no home can be resolved (a restricted - # container with HOME unset); fall back to the literal path rather than crash the probe. + # Match huggingface_hub, which expands ~ before writing. 
expanduser() raises if no home can + # be resolved (HOME unset in a container); fall back to the literal path rather than crash. try: root = Path(cache_dir).expanduser() except (RuntimeError, OSError): @@ -115,9 +98,7 @@ def target_dir_name(repo_type: Optional[str], repo_id: str) -> str: def repo_cache_dir_name(repo_type: Optional[str], repo_id: str) -> str: - # Hugging Face treats repo_type=None as the default "model"; mirror that here - # so a caller forwarding repo_type=None still resolves models-- (not - # Nones--, which would make the cache-state probe miss real partials). + # repo_type=None is HF's default "model"; mirror that so None resolves models--, not Nones--. repo_type = repo_type or "model" return f"{repo_type}s--{repo_id.replace('/', '--')}" @@ -133,17 +114,14 @@ def _blob_dir_is_partial(blobs_dir: Path) -> bool: def blob_bytes_present(path: Path) -> int: - """Sparse-aware on-disk size: XET / ``hf_transfer`` ``.incomplete`` partials - report a full ``st_size`` while only some blocks are allocated, so prefer - ``st_blocks``, falling back to ``st_size`` where it is unreported (Windows, - some network filesystems).""" + """Sparse-aware on-disk size: a Xet / ``hf_transfer`` ``.incomplete`` reports full ``st_size`` + while only some blocks are allocated, so prefer ``st_blocks``, falling back to ``st_size`` where + it is unreported (Windows, some network filesystems).""" st = path.stat() blocks = getattr(st, "st_blocks", None) if blocks is not None: - # st_blocks is reported (POSIX): trust it even when 0. A freshly truncated - # sparse .incomplete reports st_size == full but 0 allocated blocks, and - # must count as 0 bytes present, not full size (a > 0 guard would fall - # through to st_size and read an empty partial as complete). + # Trust st_blocks even when 0: a truncated sparse .incomplete reports full st_size but 0 + # blocks and must read as 0 bytes present (a > 0 guard would fall through to st_size). 
return min(blocks * 512, st.st_size) if sys.platform == "win32": allocated = _windows_allocated_size(path) @@ -179,11 +157,7 @@ def _windows_allocated_size(path: Path) -> Optional[int]: def latest_snapshot_dir(repo_dir: Path) -> Optional[Path]: - """Newest immediate child of ``repo_dir/snapshots`` by mtime, or None. - - mtime is the signal huggingface_hub's from_pretrained resolves to, so this - points at whatever snapshot most recently landed on disk. - """ + """Newest child of ``repo_dir/snapshots`` by mtime (the signal from_pretrained resolves to), or None.""" snapshots_dir = repo_dir / "snapshots" try: if not snapshots_dir.is_dir(): @@ -197,10 +171,8 @@ def latest_snapshot_dir(repo_dir: Path) -> Optional[Path]: def snapshot_dir_has_broken_symlinks(snapshot_dir: Path) -> bool: - """True if *snapshot_dir* contains a dangling symlink -- a file the snapshot - references whose blob is missing or still an ``.incomplete`` partial, i.e. an - interrupted download. Used to validate one specific (caller-requested) - revision, not just the newest one on disk.""" + """True if *snapshot_dir* holds a dangling symlink (a referenced blob that is missing or still + ``.incomplete``) -- an interrupted download. Validates one requested revision, not just the newest.""" try: for entry in snapshot_dir.rglob("*"): if entry.is_symlink() and not entry.exists(): @@ -211,11 +183,9 @@ def snapshot_dir_has_broken_symlinks(snapshot_dir: Path) -> bool: # --------------------------------------------------------------------------- -# Weight-file recognition (small helpers the conservative completeness gate needs) +# Weight-file recognition # --------------------------------------------------------------------------- -# Model weight file extensions. A snapshot with none of these is config/tokenizer -# only (e.g. a prior AutoConfig fetch), so it is not a warm cache for a weight load. 
_WEIGHT_FILE_SUFFIXES = ( ".safetensors", ".bin", @@ -229,9 +199,8 @@ def snapshot_dir_has_broken_symlinks(snapshot_dir: Path) -> bool: ".pdparams", ) -# Trainer / optimizer state files carry weight suffixes (.bin / .pt / .pth) but are NOT -# loadable model weights. A checkpoint dir or a patterned pull can leave only these behind, -# so they must not count as a model weight on disk. +# Trainer / optimizer state carries weight suffixes (.bin / .pt / .pth) but is NOT a loadable weight, +# so a cache holding only these is not a warm model cache. _NON_WEIGHT_BASENAMES = frozenset({ "training_args.bin", "optimizer.bin", @@ -242,15 +211,13 @@ def snapshot_dir_has_broken_symlinks(snapshot_dir: Path) -> bool: "rng_state.pt", "rng_state.pth", }) -# Distributed trainer runs shard the RNG state as rng_state_0.pth, rng_state_1.pth, ... +# Distributed runs shard the RNG state as rng_state_0.pth, rng_state_1.pth, ... _NON_WEIGHT_BASENAME_PREFIXES = ("rng_state_",) def _is_loadable_weight_file(name: str) -> bool: - """True if *name* is a loadable model-weight file: a recognized weight suffix that is - not a known trainer / optimizer state artifact (training_args.bin, optimizer.pt, - scheduler.pt, rng_state.pth, ...). Those share weight suffixes but are not model - weights, so a cache holding only them is not a warm model cache.""" + """True if *name* is a loadable model weight: a weight suffix that is not a trainer / optimizer + state artifact (training_args.bin, optimizer.pt, rng_state.pth, ...).""" if not name.endswith(_WEIGHT_FILE_SUFFIXES): return False lowered = name.lower() @@ -262,35 +229,25 @@ def _is_loadable_weight_file(name: str) -> bool: def _is_weight_shard_index(name: str) -> bool: - """True if *name* is a weight-shard index sidecar: the canonical - ``model.safetensors.index.json`` / ``pytorch_model.bin.index.json`` AND the variant form - ``model.safetensors.index.fp16.json`` (transformers' ``_add_variant`` inserts the variant token - before the trailing ``.json``). 
A plain ``*.safetensors.index.json`` suffix test misses the - variant form, leaving its listed shards unvalidated.""" + """True for a weight-shard index sidecar, canonical or variant (``model.safetensors.index.json`` + and ``model.safetensors.index.fp16.json``); a plain suffix test would miss the variant form.""" return name.endswith(".json") and (".safetensors.index." in name or ".bin.index." in name) def _is_canonical_weight_shard_index(name: str) -> bool: - """True only for the CANONICAL (non-variant) shard index a default in-process load probes: - ``model.safetensors.index.json`` / ``pytorch_model.bin.index.json`` (any stem). A variant form - such as ``model.safetensors.index.fp16.json`` ends in ``.index.fp16.json`` and is rejected: the - fallback wrapper takes no variant parameter, so a default ``from_pretrained`` reads the canonical - index, and a variant-only cache must NOT satisfy the canonical fast path (its canonical weights - are still missing -- skipping the child there would reintroduce the unprotected in-process fetch - this gate prevents).""" + """True only for the CANONICAL (non-variant) index a default load probes + (``model.safetensors.index.json`` / ``pytorch_model.bin.index.json``). A variant + (``...index.fp16.json``) is rejected: the wrapper takes no variant param, so a variant-only cache + must not satisfy the canonical fast path (its canonical weights are still missing).""" return name.endswith(".safetensors.index.json") or name.endswith(".bin.index.json") def _weight_shard_index_complete(index_path: Path) -> bool: - """True only if every shard a HF weight index (``model.safetensors.index.json`` / - ``pytorch_model.bin.index.json``) lists is present next to the index. - - Fail-CLOSED: an unreadable / truncated index, a non-dict payload, a missing or non-dict - ``weight_map``, or an empty shard set all return False. 
This function feeds the fast-path - completeness gate, where a malformed index proves nothing -- treating it as complete would let - the in-process load skip the protective child and then fail (or fetch over Xet) on a truncated - index, so the safe direction is to defer such an index to the watched ``snapshot_download`` - child. Only an index whose every listed shard is demonstrably on disk returns True.""" + """True only if every shard a HF weight index lists is present next to it. + + Fail-CLOSED: an unreadable / truncated index, a non-dict payload or ``weight_map``, or an empty + shard set return False, so a malformed index defers to the watched child rather than letting the + in-process load skip it and then fail (or fetch over Xet).""" import json try: @@ -301,13 +258,10 @@ def _weight_shard_index_complete(index_path: Path) -> bool: weight_map = data.get("weight_map") if isinstance(data, dict) else None if not isinstance(weight_map, dict): return False - # weight_map values are filenames relative to the index file's own directory. They come from - # arbitrary JSON: a non-string (e.g. list/dict) value is both unhashable -- so it would break - # set() -- and invalid for ``base / shard``, so filter to strings BEFORE de-duplicating rather - # than crash. + # A non-string value is unhashable (breaks set()) and invalid for ``base / shard``; filter first. 
shards = {s for s in weight_map.values() if isinstance(s, str)} if not shards: - return False # an empty / all-non-string weight_map cannot prove a complete shard set + return False base = index_path.parent for shard in shards: try: @@ -319,24 +273,21 @@ def _weight_shard_index_complete(index_path: Path) -> bool: # --------------------------------------------------------------------------- -# Pattern helpers (kept small: normalization + glob detection + HF filtering) +# Pattern helpers (normalization + glob detection + HF filtering) # --------------------------------------------------------------------------- _GLOB_CHARS = ("*", "?", "[") def _has_glob(text: str) -> bool: - # A trailing-slash directory pattern ("unet/", "checkpoint-10/") is NOT an exact filename: - # Hugging Face's filter_repo_objects expands it to match everything under that directory (as - # if "unet/*"). Treat it as a wildcard so the strict exact-name checks do not look for a - # literal "unet/" entry and wrongly reject a fully cached directory / component download. + # A trailing-slash dir pattern ("unet/") is not an exact filename: HF expands it like "unet/*", + # so treat it as a wildcard rather than look for a literal "unet/" entry. return text.endswith("/") or any(ch in text for ch in _GLOB_CHARS) def _as_pattern_list(patterns: "Optional[object]") -> "Optional[list]": - """Normalize an allow / ignore pattern argument to a list. Hugging Face accepts a bare - ``str`` as well as a list, and iterating the ``str`` form would walk it character by - character (so ``"checkpoint-10/*"`` would never match), misclassifying the request.""" + """Normalize an allow / ignore argument to a list. 
HF accepts a bare ``str``; iterating it would + walk it character by character ("checkpoint-10/*" would never match).""" if patterns is None: return None if isinstance(patterns, str): @@ -349,9 +300,9 @@ def _filter_paths( allow_patterns: "Optional[list]" = None, ignore_patterns: "Optional[list]" = None, ) -> list: - """Filter repo-relative *paths* by Hugging Face allow / ignore patterns, mirroring how - ``snapshot_download`` selects files. On any failure, treat all paths as selected so a - snapshot that does hold weights is never rejected for an unevaluable filter.""" + """Filter repo-relative *paths* by HF allow / ignore patterns (as ``snapshot_download`` selects). + Fails OPEN (returns all paths) so a snapshot that does hold weights is never rejected on an + unevaluable filter.""" try: from huggingface_hub.utils import filter_repo_objects @@ -365,10 +316,8 @@ def _filter_paths( def _broken_symlink_rel_paths(snapshot_dir: Path) -> list: - """Repo-relative posix paths of every dangling symlink in *snapshot_dir* -- a referenced file - whose blob is missing or still an ``.incomplete`` partial (an interrupted download). Empty when - none. Lets the broken-symlink check scope the interrupted-download signal to the files a request - actually selects, rather than rejecting the whole snapshot for a dangle outside the request.""" + """Repo-relative posix paths of every dangling symlink in *snapshot_dir* (empty when none), so the + interrupted-download signal can be scoped to the files a request actually selects.""" out: list = [] try: for entry in snapshot_dir.rglob("*"): @@ -392,14 +341,9 @@ def snapshot_has_requested_broken_symlinks( ignore_patterns: "Optional[object]" = None, repo_type: "Optional[str]" = "model", ) -> bool: - """True iff a dangling symlink in *snapshot_dir* is for a file the request actually SELECTS. 
- - A dangling symlink marks an interrupted download, but for a scoped request only one for a - requested file should reject the snapshot: a dangling root ``model.safetensors`` left by an - earlier interrupted pull must not fail a weightless ``allow_patterns=["config.json"]`` request - whose config is on disk. The allow / ignore filter mirrors ``snapshot_download`` selection, so a - dangle for an excluded file does not reject the cache. (``repo_type`` is accepted for signature - compatibility; the scoping is now purely the allow/ignore filter.)""" + """True iff a dangling symlink in *snapshot_dir* is for a file the request actually SELECTS, so a + dangling root ``model.safetensors`` does not fail a weightless ``allow=["config.json"]`` request + whose config is on disk. (*repo_type* is kept for signature compatibility.)""" allow_patterns = _as_pattern_list(allow_patterns) ignore_patterns = _as_pattern_list(ignore_patterns) broken = _broken_symlink_rel_paths(snapshot_dir) @@ -412,19 +356,14 @@ def snapshot_has_requested_broken_symlinks( # The conservative fast-path completeness gate # --------------------------------------------------------------------------- -# Canonical root weight filenames an in-process model load reads. Used to prove a warm cache (the -# file or its shard index is present). +# Canonical root weight filenames a default load reads (the single file or its shard index proves warm). _CANONICAL_SINGLE_WEIGHTS = ("model.safetensors", "pytorch_model.bin") def _ignore_strips_all_weights(ignore_patterns: "list") -> bool: - """True iff the ignore set provably excludes EVERY weight format: for each weight suffix there is - a pattern matching a representative filename with that suffix. Only then is an ignore-only request - weightless. 
A partial strip -- only some suffixes, or only the canonical ``model.safetensors`` / - ``pytorch_model.bin`` names while a variant (``model.fp16.safetensors``) or an other-format weight - (``model.gguf``, a ``*.pt`` checkpoint) survives -- is NOT weightless, so the request reads as - weight-bearing (conservative: never under-classify a request that could still pull a weight, which - would let the fast path skip the protective child on a config-only cache and hang on Xet).""" + """True iff the ignore set provably excludes EVERY weight format (a probe of each suffix matches a + pattern). A partial strip is NOT weightless -- a surviving variant / .gguf / .pt weight could + still be pulled, so the request stays weight-bearing (conservative).""" for suffix in _WEIGHT_FILE_SUFFIXES: probe = "weight" + suffix if not any(isinstance(p, str) and fnmatch.fnmatchcase(probe, p) for p in ignore_patterns): @@ -432,10 +371,9 @@ def _ignore_strips_all_weights(ignore_patterns: "list") -> bool: return True -# Representative weight filenames a glob allow pattern is probed against (via fnmatch). A glob that -# matches one of these can select a weight; one that matches none (``tokenizer*``, ``*.json``) is -# weightless. Covers the canonical / variant / sharded / adapter / diffusers / mistral-consolidated -# and the non-safetensors weight formats so a real weight glob is never under-classified. +# Representative weight names a glob allow pattern is probed against (via fnmatch): a glob matching one +# can select a weight; one matching none (``tokenizer*``, ``*.json``) is weightless. Covers canonical / +# variant / sharded / adapter / diffusers / consolidated and the non-safetensors formats. 
_WEIGHT_PATTERN_PROBES = ( "model.safetensors", "model.fp16.safetensors", @@ -456,11 +394,8 @@ def _ignore_strips_all_weights(ignore_patterns: "list") -> bool: "flax_model.msgpack", ) -# Snapshot subdirectories that hold only metadata / config (never a loadable model weight), so a -# trailing-slash directory pattern scoped to one of them (``allow_patterns=['tokenizer/']``) is -# weightless. Any OTHER directory pattern stays conservatively weight-bearing: a component dir -# (``unet/``, ``vae/``) or a training-checkpoint dir (``checkpoint-10/``) can hold a weight, so the -# fast path must not skip the child on it. +# Subdirs that hold only metadata / config, so a ``dir/`` pattern scoped to one is weightless. Any +# other dir pattern stays weight-bearing (a component dir or a checkpoint dir can hold a weight). _NON_WEIGHT_DIRS = frozenset({ "tokenizer", "processor", @@ -473,33 +408,20 @@ def _ignore_strips_all_weights(ignore_patterns: "list") -> bool: def _pattern_can_select_weight(pattern: "object") -> bool: - """Whether a single allow pattern could select a model weight file. - - - a non-string (unknown shape) -> conservative True; - - a bare directory pattern (``unet/``) -> True (expands to everything under it, incl. weights); - - a basename ending in a weight suffix (``*.safetensors``, ``model.gguf``) -> True; - - a glob basename (``model.?afetensors``, ``model.[sp]afetensors``, ``*``) -> True iff it matches a - representative weight name in ``_WEIGHT_PATTERN_PROBES`` -- so ``tokenizer*`` / ``*.json`` read - weightless and keep their offline short-circuit, while ``model.?afetensors`` / ``unet/*`` are - weight-bearing; - - a concrete non-weight name (``config.json``, ``tokenizer.model``) -> False. - - A glob is matched on its basename so ``checkpoint-*/model.?afetensors`` is still recognized. 
Both - directions are bounded: a false weight-bearing only makes the pre-download gate spawn the cheap - child; a false weightless is avoided for every standard weight name by the probe set.""" + """Whether a single allow pattern could select a weight. A weight-suffix basename or a non-metadata + directory pattern is weight-bearing; a glob basename is weight-bearing only if it matches a + ``_WEIGHT_PATTERN_PROBES`` name (so ``tokenizer*`` / ``*.json`` stay weightless while + ``model.?afetensors`` / ``unet/*`` do not); a concrete non-weight name is weightless. A false + weight-bearing only spawns the cheap child; the probe set avoids a false weightless on real weights.""" if not isinstance(pattern, str): - return True # unknown shape -> conservative + return True if pattern.endswith("/"): - # A bare directory pattern expands to everything under it. A known metadata dir holds no - # weight (so it stays weightless and keeps its offline short-circuit); any other dir could. dir_name = pattern.rstrip("/").rsplit("/", 1)[-1].lower() return dir_name not in _NON_WEIGHT_DIRS base = pattern.rsplit("/", 1)[-1] if base.endswith(_WEIGHT_FILE_SUFFIXES): - return True # a concrete or wildcard-stem weight suffix + return True if any(ch in base for ch in _GLOB_CHARS): - # A glob basename selects a weight only if it can actually match a weight filename. This keeps - # tokenizer / config globs weightless while catching single-char (?) and class ([]) globs. return any(fnmatch.fnmatchcase(probe, base) for probe in _WEIGHT_PATTERN_PROBES) return False @@ -507,45 +429,31 @@ def _pattern_can_select_weight(pattern: "object") -> bool: def request_can_include_weights( allow_patterns: "Optional[object]" = None, ignore_patterns: "Optional[object]" = None ) -> bool: - """Whether a request restricted by *allow_patterns* / *ignore_patterns* can still include a model - weight. Used to pick the weight-requiring vs weightless branch of the acceptance check. 
- - Conservative by design: when uncertain it returns True (treat the request as weight-bearing), so - the acceptance check requires a weight and never short-circuits a config-only cache for a real - weight load. It returns False only when the request is clearly weightless (a tokenizer / config - allow list that matches no weight name, or an ignore list that drops every weight format), which - preserves the offline short-circuit for a genuine tokenizer-only warm.""" + """Whether a request restricted by *allow_patterns* / *ignore_patterns* can still include a weight. + Conservative: True when uncertain, so the acceptance check requires a weight; False only for a + clearly weightless request (a tokenizer / config allow list, or an ignore list dropping every + weight format), which preserves the offline short-circuit for a tokenizer-only warm.""" allow_patterns = _as_pattern_list(allow_patterns) ignore_patterns = _as_pattern_list(ignore_patterns) if allow_patterns is None and ignore_patterns is None: return True if allow_patterns is None: - # Ignore-only request: weight-bearing unless the ignore list strips every weight format. return not _ignore_strips_all_weights(ignore_patterns or []) if not allow_patterns: - # allow_patterns=[] selects nothing -> no weight (HF filter selects no objects). - return False - # An allow list includes weights iff SOME pattern could select a weight (wildcard basename, - # weight-suffix basename, or a bare directory pattern). A list of only concrete non-weight names - # (a tokenizer / config warm) is weightless and keeps its offline short-circuit. 
+ return False # allow=[] selects nothing return any(_pattern_can_select_weight(pat) for pat in allow_patterns) def _canonical_root_weights_complete( snapshot_dir: Path, entries: list, ignore_patterns: "Optional[list]" = None ) -> bool: - """True iff the snapshot holds a complete canonical ROOT model weight set: a root - ``model.safetensors`` / ``pytorch_model.bin`` single file, OR a root weight-shard index whose - every listed shard is present. Numbered shard files without a valid index, or weights that live - only in a subfolder, do NOT count -- those are deferred to the watched child. - - A root weight (or weight-shard index) whose FORMAT the request's ignore filter drops does NOT - count: a stale ``pytorch_model.bin`` under ``ignore=['*.bin']`` is not proof that the - safetensors weights an in-process load (e.g. ``use_safetensors=True``) will actually read are on - disk, so it must not let the fast path skip the protective child and then hang fetching the real - weight over Xet. The surviving-format check uses a representative weight name per format, so a - ``*.bin`` ignore also discards a ``pytorch_model.bin.index.json`` (whose ``.json`` sidecar name - would otherwise slip past the filter).""" + """True iff the snapshot holds a complete canonical ROOT weight set: a root + ``model.safetensors`` / ``pytorch_model.bin``, OR a root shard index whose every shard is present. + Numbered shards without an index, or subfolder-only weights, do NOT count. + + A weight whose FORMAT the ignore filter drops does not count (a stale ``pytorch_model.bin`` under + ``ignore=['*.bin']`` is not proof the requested safetensors are on disk). 
The format probe also + discards a ``pytorch_model.bin.index.json`` whose ``.json`` name would slip the raw filter.""" root_files: set = set() root_indices: list = [] for entry in entries: @@ -554,9 +462,7 @@ def _canonical_root_weights_complete( except ValueError: rel = entry.name if "/" in rel: - continue # a bare from_pretrained reads ROOT files only - # Only the CANONICAL (non-variant) index counts here: a default load probes - # model.safetensors.index.json, not a variant like model.safetensors.index.fp16.json. + continue # ROOT files only if _is_canonical_weight_shard_index(entry.name): if _safe_is_file(entry): root_indices.append(entry) @@ -564,14 +470,12 @@ def _canonical_root_weights_complete( root_files.add(entry.name) def _format_kept(weight_name: str) -> bool: - # The weight format an in-process load reads from *weight_name* must survive the request's - # ignore filter; otherwise the file is a stale artifact for an excluded format and proves - # nothing about what the load will fetch. + # The format a load reads from *weight_name* must survive the ignore filter, else the file is + # a stale artifact for an excluded format and proves nothing. if not ignore_patterns: return True return bool(_filter_paths([weight_name], None, ignore_patterns)) - # Sharded: a canonical root index whose format is kept and whose every listed shard is on disk. for index_entry in root_indices: fmt_probe = ( "model.safetensors" @@ -580,7 +484,6 @@ def _format_kept(weight_name: str) -> bool: ) if _format_kept(fmt_probe) and _weight_shard_index_complete(index_entry): return True - # Single-file canonical weight (the file itself must survive the ignore filter). 
return any( name in root_files and _format_kept(name) for name in _CANONICAL_SINGLE_WEIGHTS ) @@ -593,73 +496,49 @@ def snapshot_dir_is_complete( ignore_patterns: "Optional[object]" = None, require_named_weights: bool = False, ) -> bool: - """Conservative fast-path gate: True only when *snapshot_dir* is an unambiguously complete - canonical ROOT model cache, so an in-process load will not fetch any weight. - - This is intentionally NOT an authoritative snapshot verifier. It returns True only for: - - an UNPATTERNED request (allow_patterns is None; ignore_patterns are fine), - - that is not a diffusers pipeline (no root ``model_index.json``), - - with no dangling symlink (interrupted blob), - - whose canonical root weights are present (single file, or a shard index with every shard). - Every other layout -- variants, diffusers, datasets, any allow pattern, sharded weights without - an index -- returns False, deferring to the watched ``snapshot_download`` child (the authoritative - manifest compare + resume). A false True risks a silent un-killable Xet fetch during the in-process - load; a false False only spawns the cheap child. ``require_named_weights`` is accepted for signature - compatibility (a named-weight request is non-trivially patterned and so is never fast-pathed here). - - ``ignore_patterns`` need no eligibility gate: the canonical-weight presence check below verifies - what the in-process load actually reads (root ``model.safetensors`` / ``pytorch_model.bin`` or a - complete shard index) is on disk, so an ignore that dropped some weight format (the common - ``*.onnx`` / ``*.gguf`` / ``*.pt`` / ``*.bin`` prefetch ignores, or the subdir ``*/*.safetensors`` - drops) cannot make an incomplete cache read complete -- the surviving canonical weight is what is - checked. 
This keeps the common warm ``from_pretrained`` cache fast-path eligible.""" + """Conservative fast-path gate: True only for an unambiguously complete canonical ROOT model cache, + so an in-process load will not fetch a weight. True requires: an UNPATTERNED request + (``allow_patterns is None``), not a diffusers pipeline (no root ``model_index.json``), no dangling + symlink, and canonical root weights present. Everything else defers to the watched child. A false + True risks a silent Xet fetch; a false False only spawns the cheap child. *require_named_weights* + is accepted for signature compatibility (a named-weight request is patterned, so never fast-pathed). + + *ignore_patterns* need no eligibility gate: the canonical-weight check below is what the load reads, + so an ignore that dropped some format (the common ``*.bin`` / subdir prefetch ignores) cannot make + an incomplete cache read complete -- keeping the common warm ``from_pretrained`` cache eligible.""" allow_patterns = _as_pattern_list(allow_patterns) ignore_patterns = _as_pattern_list(ignore_patterns) - # 1. Only an UNPATTERNED request is eligible. Any allow list scopes the on-disk set to a subset - # whose relationship to the in-process load is not locally provable -> defer to the child. if allow_patterns is not None: - return False + return False # any allow list scopes the on-disk set unprovably -> defer try: entries = list(snapshot_dir.rglob("*")) except OSError: return False - # 2. A diffusers pipeline (root model_index.json) needs component-completeness reasoning we do - # not fast-path -> defer to the child. if _safe_is_file(snapshot_dir / "model_index.json"): - return False - # 3. A dangling symlink = an interrupted blob (missing or still .incomplete) -> not complete. + return False # diffusers needs component reasoning we do not fast-path if snapshot_dir_has_broken_symlinks(snapshot_dir): - return False - # 4. 
Canonical root weights present and complete (a weight whose format the request ignores - # does not count -- see _canonical_root_weights_complete). + return False # interrupted blob return _canonical_root_weights_complete(snapshot_dir, entries, ignore_patterns) -# A canonical numbered weight shard at the snapshot root: the shard index sits IMMEDIATELY before the -# extension (no variant token), so ``model-00001-of-00002.safetensors`` matches but the variant -# ``model-00001-of-00002.fp16.safetensors`` does NOT. +# A canonical numbered root shard: the index sits IMMEDIATELY before the extension (no variant token), +# so ``model-00001-of-00002.safetensors`` matches but ``model-00001-of-00002.fp16.safetensors`` does not. _CANONICAL_ROOT_SHARD_RE = re.compile( r"^(?:model|pytorch_model)-\d{5}-of-\d{5}\.(?:safetensors|bin)$" ) def _has_incomplete_canonical_root_shards(snapshot_dir: Path) -> bool: - """True when the snapshot root holds canonical numbered weight shards - (``model-00001-of-00002.safetensors`` / ``pytorch_model-...bin``) but is NOT a complete canonical - model -- the shard index is missing or a listed shard is absent. - - Such a loose-shard layout is a stale / interrupted download: a default in-process load cannot read - bare numbered shards without their index and would fetch the rest over un-killable Xet, so the - post-download acceptance check rejects it and retries over HTTP. Variant shards - (``model-...fp16.safetensors``) are intentionally excluded -- they never satisfy a default load, so - a variant-only repo must not be force-failed here (it simply defers, like any non-canonical warm).""" + """True when the root holds canonical numbered shards but is NOT a complete canonical model (index + missing or a shard absent) -- a stale interrupted download a default load cannot read, so the + post-download check rejects it and retries over HTTP. 
Variant shards are excluded, so a + variant-only repo is not force-failed here.""" try: names = [entry.name for entry in snapshot_dir.iterdir()] except OSError: return False if not any(_CANONICAL_ROOT_SHARD_RE.match(name) for name in names): return False - # Canonical shards exist but no complete single-file / indexed canonical set covers them. return not snapshot_dir_is_complete(snapshot_dir) @@ -669,19 +548,10 @@ def requested_named_files_present( allow_patterns: "Optional[object]" = None, ignore_patterns: "Optional[object]" = None, ) -> bool: - """For a request that names EXACT files (every ``allow_patterns`` entry is glob-free), True only - when each named file the ignore filter keeps is on disk. - - ``snapshot_download(local_files_only=True)`` returns a snapshot dir whenever the revision folder - exists -- even a config-only one left by a prior ``AutoConfig`` fetch -- so for a weightless - request like ``allow_patterns=["tokenizer.json"]`` a dangling-symlink check alone would accept a - cache that does not actually contain the requested file. This makes that request require its - named file before the snapshot is treated as warm. - - A request with ANY glob, or with no ``allow_patterns``, is a best-effort "warm what matches" and - cannot be turned into an exact manifest (an optional ``vocab.txt`` the repo may simply lack would - wrongly fail it), so it is trivially satisfied here -- the weight-bearing requests are gated by - ``snapshot_dir_is_complete`` instead.""" + """For a request naming EXACT files (every entry glob-free), True only when each named file the + ignore filter keeps is on disk -- ``snapshot_download(local_files_only=True)`` returns the revision + dir even when config-only, so a ``["tokenizer.json"]`` request needs its file present. 
A request + with ANY glob, or no allow list, is trivially satisfied (it cannot be turned into an exact manifest).""" allow_patterns = _as_pattern_list(allow_patterns) ignore_patterns = _as_pattern_list(ignore_patterns) if not allow_patterns or any(_has_glob(p) for p in allow_patterns): @@ -689,7 +559,7 @@ def requested_named_files_present( try: entries = list(snapshot_dir.rglob("*")) except OSError: - return True # cannot enumerate -> do not reject on an unreadable dir + return True present = set() for entry in entries: if _safe_is_file(entry): @@ -698,8 +568,7 @@ def requested_named_files_present( except ValueError: present.add(entry.name) for pat in allow_patterns: - # A named file the ignore filter drops is not actually requested. _filter_paths fails OPEN - # (returns all on error), so an unevaluable filter keeps the strict presence check. + # A named file the ignore filter drops is not actually requested. if ignore_patterns and not _filter_paths([pat], None, ignore_patterns): continue if pat not in present: @@ -723,10 +592,8 @@ def _iter_snapshot_dirs(repo_dir: Path) -> Iterator[Path]: def _repo_dir_has_broken_snapshot_symlinks(repo_dir: Path) -> bool: - # Check every snapshot, not just the newest by mtime: a caller may request an - # older revision whose snapshot is broken while a more recent one is clean, so - # a latest-only check would report the repo healthy and let the interrupted - # revision load with missing files. + # Check every snapshot, not just the newest: a requested older revision may be broken while a + # newer one is clean, and a latest-only check would report the repo healthy. return any( snapshot_dir_has_broken_symlinks(snapshot) for snapshot in _iter_snapshot_dirs(repo_dir) @@ -734,17 +601,13 @@ def _repo_dir_has_broken_snapshot_symlinks(repo_dir: Path) -> bool: def _case_safe_repo_cache_dirs(root: Path, repo_type: Optional[str], repo_id: str) -> list: - """Cache dirs that can be safely attributed to this exact repo id. 
- - The cache dir name is case-folded by the Hub, so a case-insensitive match is - needed for compatibility, but a bare case-insensitive match is unsafe: on a - case-sensitive filesystem ``models--Org--Repo`` and ``models--org--repo`` are - distinct repos. Prefer an exact-case match; otherwise accept a single folded - match ONLY when the filesystem is case-insensitive (so the folded dir really is - the same entry); on a 2+ way collision attribute to neither, so a stale partial - in one repo cannot be charged to the other (which would let the watchdog kill an - unrelated active download or HTTP-prep purge the wrong repo). - """ + """Cache dirs safely attributable to this exact repo id. + + The Hub case-folds the dir name, so a case-insensitive match is needed, but on a case-sensitive + filesystem ``models--Org--Repo`` and ``models--org--repo`` are distinct repos. Prefer an + exact-case match; otherwise accept a single folded match ONLY when the filesystem is + case-insensitive (the exact-case name resolves to it); on a 2+ way collision attribute to neither, + so a stale partial in one repo cannot make the watchdog kill / purge the other.""" target = repo_cache_dir_name(repo_type, repo_id) folded_target = target.lower() try: @@ -755,11 +618,8 @@ def _case_safe_repo_cache_dirs(root: Path, repo_type: Optional[str], repo_id: st if exact: return exact if len(entries) == 1: - # A single folded-but-not-exact match. Attribute it to this repo only when - # the filesystem is case-insensitive: looking up the exact-case name then - # resolves to that same directory. On a case-sensitive filesystem the - # exact-case path does not exist, so the folded dir is a DIFFERENT repo and - # must not be charged here. + # Attribute a single folded-but-not-exact match only on a case-insensitive filesystem, where + # the exact-case path resolves to the same dir; on a case-sensitive fs it is a DIFFERENT repo. 
try: if (root / target).exists(): return entries @@ -771,11 +631,8 @@ def _case_safe_repo_cache_dirs(root: Path, repo_type: Optional[str], repo_id: st def iter_active_repo_cache_dirs( repo_type: Optional[str], repo_id: str, *, cache_dir: "Optional[str | Path]" = None ) -> Iterator[Path]: - """Yield the repo's cache dir(s) under *cache_dir* (or the active ``HF_HUB_CACHE``). - - Case-collision safe (see ``_case_safe_repo_cache_dirs``), so both the read / - watchdog path and the destructive HTTP-prep path share one attribution rule. - """ + """Yield the repo's cache dir(s) under *cache_dir* (or the active ``HF_HUB_CACHE``). Case-collision + safe, so the read / watchdog path and the destructive HTTP-prep path share one attribution rule.""" root = hf_cache_root(cache_dir = cache_dir) if root is None: return diff --git a/unsloth_zoo/hf_xet_fallback.py b/unsloth_zoo/hf_xet_fallback.py index 7e786027d..a4cdb9f45 100644 --- a/unsloth_zoo/hf_xet_fallback.py +++ b/unsloth_zoo/hf_xet_fallback.py @@ -9,26 +9,19 @@ """Xet-primary HF downloads with an automatic HTTP fallback on a no-progress stall. -Xet (``hf_xet``) is the fast default but can hang with no progress and no -exception, and a blocked native thread cannot be killed. Keep Xet primary; fall -back to plain HTTP only when the parent observes a stall. ``HF_HUB_DISABLE_XET`` -is read at import time, so the fallback runs in a fresh ``spawn`` child (not a -thread) that sets the env before importing ``huggingface_hub``. Cached files -short-circuit with no child; deterministic errors (401/403/404/disk-full) and -cancellation propagate without a fallback. - -``hf_hub_download_with_xet_fallback`` downloads a single file; the new -``snapshot_download_with_xet_fallback`` does a whole repo (the entrypoint -Unsloth's ``from_pretrained`` uses to warm the cache in a killable child before -the in-process load). 
Studio-specific cache/secret/process helpers are used -best-effort (imported only if present) or injected, so the same code runs both -inside Unsloth Studio and in Unsloth itself. - -Like the rest of ``unsloth_zoo``, this module is imported with ``unsloth`` -installed; the package ``__init__`` runs its device init on first import. The -download spawn child does not need that and sets ``UNSLOTH_ZOO_DISABLE_GPU_INIT=1`` -before it imports the package, which selects ``unsloth_zoo``'s lightweight import -path (no torch/transformers), keeping each child fast. +Xet (``hf_xet``) is the fast default but can hang with no progress, no exception, and a native thread +that cannot be killed. Keep Xet primary and fall back to plain HTTP only when the parent observes a +stall. ``HF_HUB_DISABLE_XET`` is read at import time, so the fallback runs in a fresh ``spawn`` child +(not a thread) that sets the env before importing ``huggingface_hub``. Cached files short-circuit with +no child; deterministic errors (401/403/404/disk-full) and cancellation propagate without a fallback. + +``hf_hub_download_with_xet_fallback`` does a single file; ``snapshot_download_with_xet_fallback`` does +a whole repo (the entrypoint Unsloth's ``from_pretrained`` uses to warm the cache in a killable child +before the in-process load). Studio cache / secret / process helpers are used best-effort (imported +only if present) or injected, so one body runs both inside Studio and in Unsloth. + +The spawn child sets ``UNSLOTH_ZOO_DISABLE_GPU_INIT=1`` before importing the package, selecting +``unsloth_zoo``'s lightweight import path (no torch / transformers) so each child stays fast. """ from __future__ import annotations @@ -104,10 +97,8 @@ def _is_true(value: Optional[str]) -> bool: def _safe_status(callback: Optional[Callable[[str], None]], message: str) -> None: - """Invoke a caller status/heartbeat callback without letting it kill the - daemon watchdog thread. 
A disconnected Studio client can make on_status raise; - if that propagated, stall detection for a genuinely hung child would stop and - the HTTP retry would never fire.""" + """Invoke a status / heartbeat callback without letting it kill the daemon watchdog thread: a + disconnected Studio client can make on_status raise, which would stop stall detection.""" if callback is None: return try: @@ -117,10 +108,8 @@ def _safe_status(callback: Optional[Callable[[str], None]], message: str) -> Non class DownloadStallError(RuntimeError): - """Raised when no download progress is observed for too long. - - Canonical home; Studio's orchestrator re-imports it so all paths share one type. - """ + """Raised when no download progress is observed for too long. Canonical home; Studio re-imports it + so all paths share one type.""" def is_hf_xet_available() -> bool: @@ -132,11 +121,8 @@ def is_hf_xet_available() -> bool: def xet_force_disabled() -> bool: - """Whether the user has asked us to skip Xet up front (force HTTP). - - Honors the Unsloth knobs ``UNSLOTH_DISABLE_XET`` / ``UNSLOTH_STABLE_DOWNLOADS`` - and Hugging Face's own ``HF_HUB_DISABLE_XET``. - """ + """Whether the user asked to skip Xet up front (force HTTP), via ``UNSLOTH_DISABLE_XET`` / + ``UNSLOTH_STABLE_DOWNLOADS`` or HF's own ``HF_HUB_DISABLE_XET``.""" return ( _is_true(os.environ.get("UNSLOTH_DISABLE_XET")) or _is_true(os.environ.get("UNSLOTH_STABLE_DOWNLOADS")) @@ -178,16 +164,11 @@ def _redact_signed_query(match: "re.Match") -> str: def _broken_link_has_active_partner(link: Path, *, active_grace: float) -> bool: - """True if a dangling snapshot symlink should be SPARED from the HTTP-prep purge because a - concurrent sibling download (a different process pulling the same repo into the same cache, common - in multi-rank training) is still writing the blob it points at. 
- - The reliable discriminator is a FRESH ``.incomplete`` partner of the link's target blob (mirroring - the active-grace guard the ``.incomplete`` blob purge already uses), NOT the link's own mtime: our - OWN killed child's link is freshly created too, but by this point its ``.incomplete`` has been - static for the full stall timeout and is purged first, so the target has no partner and the link is - correctly cleared -- while a sibling mid-download still has a growing ``.incomplete`` partner, so - its link is spared.""" + """True if a dangling snapshot symlink should be SPARED because a concurrent sibling download is + still writing the blob it points at. The discriminator is a FRESH ``.incomplete`` partner of the + target blob, NOT the link's own mtime: our own killed child's ``.incomplete`` was static for the + full stall timeout and is purged first (no partner -> link cleared), while a sibling mid-download + still has a growing partner (link spared).""" try: target = Path(os.readlink(link)) if not target.is_absolute(): @@ -217,21 +198,15 @@ def _default_prepare_for_http( active_grace: float = DEFAULT_STALL_TIMEOUT, owned_incomplete_blobs: Optional[set] = None, ) -> None: - """Generic 'make the partial safe for an HTTP resume': delete the repo's active - ``*.incomplete`` blobs (an HTTP resume over a sparse Xet/hf_transfer partial - silently corrupts the blob) and any broken snapshot symlinks the incomplete - detector counts as active (else the HTTP retry inherits stale 'incomplete' - state and trips the watchdog again). Studio injects its marker-aware version - instead. - - ``iter_active_repo_cache_dirs`` is case-collision safe, so this destructive - purge only touches an exact-case (or single unambiguous) repo cache dir. 
- - When *owned_incomplete_blobs* is given (the ``.incomplete`` basenames the stalled child actually - held open, captured before it was killed), the purge is SCOPED to those blobs: a concurrent - same-repo sibling download (common in multi-rank training) writing a DIFFERENT blob is never - touched, even if its partial has aged past *active_grace*. When it is None (ownership could not be - determined), the coarser ``active_grace`` mtime guard alone is used, as before. + """Make the partial safe for an HTTP resume: delete the repo's active ``*.incomplete`` blobs (an + HTTP resume over a sparse Xet / hf_transfer partial silently corrupts the blob) and the broken + snapshot symlinks the detector counts as active (else the retry inherits stale state and re-trips). + Studio injects its marker-aware version instead. ``iter_active_repo_cache_dirs`` is case-collision + safe, so this destructive purge only touches an unambiguous repo cache dir. + + When *owned_incomplete_blobs* is given (the basenames the stalled child held open, captured before + it was killed), the purge is SCOPED to them, so a concurrent same-repo sibling writing a DIFFERENT + blob is never touched even if its partial aged past *active_grace*. None -> coarser mtime guard only. """ try: for entry in iter_active_repo_cache_dirs(repo_type, repo_id, cache_dir = cache_dir): @@ -239,33 +214,21 @@ def _default_prepare_for_http( if blobs_dir.is_dir(): for blob in blobs_dir.iterdir(): if blob.is_file() and blob.name.endswith(INCOMPLETE_SUFFIX): - # Scope to the stalled child's own partials when known: never delete a - # sibling's blob, even an aged one. + # Scope to our own partials when known: never delete a sibling's blob. 
if owned_incomplete_blobs is not None and blob.name not in owned_incomplete_blobs: continue try: - # Do not unlink a partial another concurrent download is - # still actively writing: on POSIX that lets the sibling keep - # writing to an unlinked path and then fail when the Hub moves - # its temp file into place. Spare any partial written within - # active_grace (the stall timeout in use): the watchdog only - # declares a download stalled after that long with no growth, - # so a slower sibling that simply has not written recently is - # not stalled and must be left alone. Our own killed partial - # has been static for the full stall timeout, so it is purged. + # Spare a partial written within active_grace: a slower sibling that just + # has not written recently is not stalled. Our own killed partial has been + # static for the full stall timeout, so it is purged. if time.time() - blob.stat().st_mtime < active_grace: continue blob.unlink() except OSError: - # A locked / permission-denied blob (common on Windows) - # must not abort cleanup of the rest of the partials. - continue - # repo_cache_dir_has_incomplete_blobs() also flags a broken snapshot - # symlink as active incomplete state; clear those too so the detector - # reads clean after prep. Sweep EVERY snapshot, not just the newest: - # the broken-symlink detector now inspects all of them, so a stale - # dangling link under an older revision would otherwise keep the repo - # marked incomplete after prep and re-trip the watchdog. + continue # a locked / permission-denied blob must not abort the rest + # A broken snapshot symlink also reads as active incomplete state; clear those too. Sweep + # EVERY snapshot (the detector inspects all), else a dangling link under an older revision + # keeps the repo marked incomplete and re-trips the watchdog. 
snapshots_dir = entry / "snapshots" try: snapshot_dirs = [s for s in snapshots_dir.iterdir() if s.is_dir()] @@ -275,15 +238,12 @@ def _default_prepare_for_http( try: for link in snapshot.rglob("*"): if link.is_symlink() and not link.exists(): - # Scope to our own partials when known: a link to a sibling's blob is left - # alone (it is the sibling's snapshot reference, not our stale state). + # Scope to our own partials when known; a link to a sibling's blob is theirs. if owned_incomplete_blobs is not None and ( _link_incomplete_partner_name(link) not in owned_incomplete_blobs ): continue - # Spare a concurrent sibling's active dangling link (its target blob still - # has a fresh .incomplete partner); only purge our own stale - # interrupted-download links so the HTTP retry reads clean. + # Spare a sibling's active link (target blob still has a fresh .incomplete). if _broken_link_has_active_partner(link, active_grace = active_grace): continue try: @@ -299,13 +259,9 @@ def _default_prepare_for_http( def _active_incomplete_blob_sizes( repo_type: Optional[str], repo_id: str, cache_dir: Optional[str] = None ) -> dict[str, int]: - """Map ``{blob_filename: bytes_present}`` for the repo's ``*.incomplete`` partials. - - Sparse-aware (st_blocks based). The single-file watchdog uses this to follow only the - partials its own child created, so a concurrent sibling download of a different file in - the same repo (its partial already present when this download began) cannot mask this - file's stall by contributing its own progress. - """ + """Map ``{blob_filename: bytes_present}`` (sparse-aware) for the repo's ``*.incomplete`` partials. 
+ The single-file watchdog uses it to follow only its own child's partials, so a concurrent sibling + download of a different file cannot mask this file's stall with its own progress.""" sizes: dict[str, int] = {} try: for entry in iter_active_repo_cache_dirs(repo_type, repo_id, cache_dir = cache_dir): @@ -324,17 +280,11 @@ def _active_incomplete_blob_sizes( def _child_open_incomplete_blobs(pid: int) -> Optional[set]: - """Basenames of the ``*.incomplete`` blob files the download child *pid* currently has - open. - - This pinpoints exactly the partial THIS child is writing -- including a resumed prior - partial that reuses the same blob-hash filename (which Hugging Face does on a retry), so - a hung resume is still detected -- without confusing it for a concurrent sibling - download's partial (held open by a different pid). Returns ``None`` when it cannot be - determined (no ``psutil`` and no ``/proc``, or the process is gone), so the caller falls - back to a coarser measure; an empty set means the child is running but not yet writing a - partial (connect / metadata phase). - """ + """Basenames of the ``*.incomplete`` blobs the download child *pid* currently has open -- exactly + the partial THIS child is writing (incl. a resumed partial that reuses a prior blob-hash name), + not a sibling's (held by a different pid). ``None`` when undeterminable (no ``psutil`` / ``/proc``, + or the process is gone) -> caller uses a coarser measure; an empty set means the child is not yet + writing a partial (connect / metadata phase).""" # Cross-platform (Linux / macOS / Windows) when psutil is available. try: import psutil # type: ignore @@ -369,13 +319,9 @@ def get_hf_download_state( repo_type: Optional[str] = "model", cache_dir: Optional[str] = None, ) -> Optional[tuple[int, bool]]: - """Return ``(total_on_disk_bytes, has_incomplete)`` for the HF cache being written. - - Scans *cache_dir* when the download targets a caller-supplied cache, else the - active ``HF_HUB_CACHE``. 
Sparse-aware (st_blocks based) so a sparse Xet/ - ``hf_transfer`` ``.incomplete`` is not mistaken for full-size progress. ``None`` - means the state could not be measured, so callers skip stall logic for that tick. - """ + """Return ``(total_on_disk_bytes, has_incomplete)`` for the HF cache being written (sparse-aware, + so a partial Xet / ``hf_transfer`` blob is not read as full progress). Scans *cache_dir* or the + active ``HF_HUB_CACHE``. ``None`` -> unmeasurable, so callers skip stall logic this tick.""" try: if hf_cache_root(cache_dir = cache_dir) is None: return (0, False) @@ -383,8 +329,7 @@ def get_hf_download_state( total = 0 has_incomplete = False for repo_id in repo_ids or []: - # Skip local paths: HF IDs never start with / . ~, contain "\", or a - # drive-letter ":" (e.g. C:/models or C:\models on Windows). + # Skip local paths: HF IDs never start with / . ~, contain "\", or a drive-letter ":". if ( not repo_id or repo_id.startswith(("/", ".", "~")) @@ -424,22 +369,16 @@ def start_watchdog( baseline_incomplete_blobs: Optional[set] = None, child_pid: Optional[int] = None, ) -> threading.Event: - """Start a daemon thread that fires ``on_stall(message)`` exactly once iff a - ``*.incomplete`` is present AND the on-disk size is unchanged for - *stall_timeout* seconds. The timer resets while no ``*.incomplete`` exists, so - post-download init is never misread as a stall. Scans *cache_dir* when the - download targets a caller-supplied cache, else the active ``HF_HUB_CACHE``. - Returns a stop event the caller sets when the download phase ends. - - When *watch_new_partials_only* is set (single-file downloads), progress is measured only - over the child's own partial, so a concurrent sibling download of a different file in the - same repo cannot reset the stall timer with its progress (which would keep a hung child - alive forever). 
The child's partial is identified, in order of preference, by the - ``*.incomplete`` blobs the *child_pid* process actually has open (precise across a - resumed download that reuses a prior blob-hash filename), else by the partials that did - NOT already exist in *baseline_incomplete_blobs* (captured before the child started). - Snapshot downloads keep the repo-wide measurement (every blob is part of the one pull). - """ + """Start a daemon thread that fires ``on_stall(message)`` exactly once iff a ``*.incomplete`` is + present AND the on-disk size is unchanged for *stall_timeout* seconds. The timer resets while no + ``*.incomplete`` exists, so post-download init is not misread as a stall. Returns a stop event the + caller sets when the download phase ends. + + With *watch_new_partials_only* (single-file), progress is measured only over the child's own + partial, so a concurrent sibling pull of a different file cannot reset the timer and keep a hung + child alive. The child's partial is identified by the blobs *child_pid* has open (precise across a + resumed download), else by the partials not in *baseline_incomplete_blobs* (captured pre-spawn). + Snapshots keep the repo-wide measurement (every blob is part of the one pull).""" stop = threading.Event() transport = "https" if xet_disabled else "xet" fired = False @@ -451,17 +390,12 @@ def _measure() -> Optional[tuple[int, bool]]: sizes = _active_incomplete_blob_sizes(repo_type, single_repo_id, cache_dir) open_names = _child_open_incomplete_blobs(child_pid) if child_pid else None if open_names is not None: - # Precise: only the partials this child process holds open (handles a resumed - # partial that reuses a baseline blob-hash name, and excludes siblings). 
hf_xet - # writes in-process and holds the .incomplete fd continuously, so an EMPTY set - # here means the child owns no partial YET (the connect / metadata phase), NOT - # that a helper process owns one -- it must own nothing this tick, so a stalled - # sibling's post-baseline partial cannot be misattributed and kill a connecting - # child. + # Only the partials this child holds open (handles a resumed partial reusing a baseline + # name, excludes siblings). hf_xet holds the .incomplete fd continuously, so an EMPTY + # set means the child owns no partial YET (connect / metadata phase), not a sibling's. owned = {name: n for name, n in sizes.items() if name in open_names} else: - # Cannot inspect the child (no psutil / no /proc): best-effort fall back to - # following only newly-created partials (not in the pre-spawn baseline). + # No psutil / /proc: fall back to following only newly-created (post-baseline) partials. owned = {name: n for name, n in sizes.items() if name not in baseline} return (sum(owned.values()), len(owned) > 0) return get_hf_download_state(repo_ids, repo_type = repo_type, cache_dir = cache_dir) @@ -477,9 +411,8 @@ def _beat() -> None: now = time.monotonic() if state is None: - # Unmeasurable this tick (transient FS error): treat as progress - # so a long unmeasurable gap cannot trip a false stall the instant - # the state becomes readable again. + # Unmeasurable this tick (transient FS error): treat as progress so the gap cannot + # trip a false stall once the state becomes readable again. last_change = now _safe_status(on_heartbeat, f"Downloading ({transport} transport)...") continue @@ -529,22 +462,15 @@ def _scrub_in_child(text: str, token: Optional[str]) -> str: "GatedRepoError", "DisabledRepoError", "LocalEntryNotFoundError", - # A required token that is absent locally fails identically over either transport (it never - # reaches the network), so surface it deterministically with its real type. 
- "LocalTokenNotFoundError", + "LocalTokenNotFoundError", # a missing required token fails identically over either transport "BadRequestError", - # A malformed repo id / revision fails identically over either transport (it never reaches the - # network), so surface it with its real type instead of a generic RuntimeError or a pointless - # HTTP retry. - "HFValidationError", + "HFValidationError", # a malformed repo id / revision never reaches the network }) -# Names whose TYPE should be reconstructed across the spawn boundary but which must NOT join the -# retry-deterministic shortcut above. ``HfHubHTTPError`` is the base of both the deterministic 4xx -# (401 / 403 / 404 / 416) and the transient 5xx / 429 errors, so the retry decision for it must stay -# status-code driven (``_is_retryable_download_error`` falls through to the status check). But once an -# error has been classified deterministic and surfaced as ``"HfHubHTTPError: "``, the parent -# should still re-raise the original type so a caller's ``except HfHubHTTPError`` (auth / quota / -# permission handling) keeps working instead of seeing a generic ``RuntimeError``. +# Names whose TYPE is reconstructed across the spawn boundary but which must NOT join the +# retry-deterministic set above: ``HfHubHTTPError`` is the base of both deterministic 4xx and transient +# 5xx / 429, so its retry decision stays status-code driven. Once classified deterministic and surfaced +# as ``"HfHubHTTPError: "``, the parent still re-raises the real type so ``except HfHubHTTPError`` +# keeps working instead of seeing ``RuntimeError``. _TYPE_PRESERVE_ONLY_NAMES = frozenset({ "HfHubHTTPError", }) @@ -580,13 +506,10 @@ def _resolve_exception_class(type_name: str) -> "Optional[type]": def _instantiate_preserving_type(exc_cls: type, message: str) -> "Optional[BaseException]": - """Build an *exc_cls* instance carrying *message*, robust to a finicky constructor. Hub error - classes (``RepositoryNotFoundError`` ...) 
subclass ``HfHubHTTPError``, whose ``response`` arg is - keyword-only -- and required on some huggingface_hub versions -- so a plain ``exc_cls(message)`` - can raise ``TypeError``. Try the normal constructors first (best fidelity: they default - ``response`` / ``server_message``), then BYPASS ``__init__`` via ``__new__`` so the TYPE and the - message survive even when no constructor accepts a lone string. Returns None only if even - ``__new__`` fails, so the caller can fall back to ``RuntimeError``.""" + """Build an *exc_cls* instance carrying *message*, robust to a finicky constructor: Hub error + classes subclass ``HfHubHTTPError`` whose ``response`` arg is keyword-only (required on some + versions), so ``exc_cls(message)`` can raise ``TypeError``. Try the normal constructors first, then + BYPASS ``__init__`` via ``__new__`` so the TYPE and message survive. None only if ``__new__`` fails.""" for build in ( lambda: exc_cls(message), lambda: exc_cls(message, response = None), @@ -617,12 +540,10 @@ def _parse_errno(message: str) -> "Optional[int]": def _raise_child_error(message: str) -> None: - """Re-raise a deterministic child download error, preserving its original exception TYPE when it - is a known Hub / OS error, so callers that catch ``RepositoryNotFoundError`` / ``GatedRepoError`` - / ``OSError`` (auth prompts, offline handling, disk cleanup) still see those types across the - spawn-process boundary. The child reports the failure as ``": "``, so the - type is reconstructed from that prefix; anything unrecognized -- or a class that cannot be - instantiated at all -- falls back to ``RuntimeError`` (the prior behavior).""" + """Re-raise a deterministic child error preserving its original TYPE when it is a known Hub / OS + error, so callers catching ``RepositoryNotFoundError`` / ``GatedRepoError`` / ``OSError`` still + match across the spawn boundary. 
The child reports ``": "``; an unrecognized or + uninstantiable class falls back to ``RuntimeError``.""" type_name = message.split(":", 1)[0].strip() if ":" in message else "" exc_cls = _resolve_exception_class(type_name) if exc_cls is None: @@ -704,14 +625,10 @@ def _download_child_entry( disable_xet: bool, result_queue: Any, ) -> None: - """Spawn-child entrypoint: download and report the result. - - Top-level and picklable. Sets the Xet env BEFORE importing huggingface_hub, - forms its own process group so the parent can kill the whole transfer, and - never logs the token or signed URLs. - """ - # Die with the parent on Linux when running under Studio (best-effort; the - # module is absent standalone, in which case there is nothing to bind to). + """Spawn-child entrypoint (top-level + picklable): set the Xet env BEFORE importing + huggingface_hub, form its own process group so the parent can kill the whole transfer, and never + log the token or signed URLs.""" + # Die with the parent on Linux under Studio (best-effort; the module is absent standalone). try: from utils.process_lifetime import bind_current_process_to_parent_lifetime # type: ignore @@ -734,27 +651,22 @@ def _download_child_entry( repo_id = params["repo_id"] - # Test-only fault injection (never set in production): stall the Xet attempt - # so the watchdog + HTTP fallback can be exercised against a real repo. + # Test-only fault injection (never set in production): stall the Xet attempt so the watchdog + + # HTTP fallback can be exercised against a real repo. if not disable_xet and os.environ.get("UNSLOTH_HF_XET_FORCE_STALL") == "1": _stall_fh = None try: from huggingface_hub.constants import HF_HUB_CACHE - # Write the fake partial under the SAME cache the watchdog scans - # (params["cache_dir"] when the caller set one, else HF_HUB_CACHE) and - # under the repo_type-correct dir name, so has_active_incomplete_blobs - # sees it and the stall/HTTP fallback actually fires in tests. 
+ # Write the fake partial under the cache the watchdog scans, under the repo_type-correct + # dir, so the stall / HTTP fallback fires in tests. cache_root = params.get("cache_dir") or HF_HUB_CACHE repo_dir_name = f"{repo_type or 'model'}s--" + repo_id.replace("/", "--") blobs = os.path.join(cache_root, repo_dir_name, "blobs") os.makedirs(blobs, exist_ok = True) - # Hold the fake partial OPEN for the whole stall. The snapshot watchdog finds it by - # filename (has_active_incomplete_blobs), but the single-file watchdog - # (watch_new_partials_only) counts ONLY partials this child PID holds open via - # _child_open_incomplete_blobs -- a closed file there is ignored and the stall never - # trips. Keeping the fd open lets BOTH modes see it. The handle is bound to a local so - # it stays open across the sleep below. + # Hold the partial OPEN for the whole stall: the snapshot watchdog finds it by filename, but + # the single-file watchdog counts only partials this PID holds open (a closed file is + # ignored). The handle is bound to a local so it stays open across the sleep. _stall_fh = open(os.path.join(blobs, "xet-force-stall.incomplete"), "wb") _stall_fh.write(b"\0" * 4096) _stall_fh.flush() @@ -767,9 +679,8 @@ def _download_child_entry( path = _child_download(kind = kind, params = params, token = token, repo_type = repo_type) result_queue.put({"ok": True, "path": path}) except BaseException as e: # noqa: BLE001 - report every failure to the parent - # Classify here, where the exception object (status code, errno, type) is intact, so the - # parent can retry a transient Xet transport failure over HTTP and still surface a - # deterministic Hub error without a pointless second attempt. + # Classify here, where the exception object (status, errno, type) is intact, so the parent can + # retry a transient failure over HTTP yet surface a deterministic error without a second attempt. 
result_queue.put({ "ok": False, "error": _scrub_in_child(f"{type(e).__name__}: {e}", token), @@ -778,12 +689,9 @@ def _download_child_entry( def _terminate_process_group(proc: "mp.process.BaseProcess", grace_period: float) -> None: - """Kill *proc* and its whole process group (Xet may spawn helper procs). - - The child calls ``os.setsid()`` so its pgid equals its pid; signal via - ``os.killpg(pid, ...)`` -- NOT ``getpgid``, which before the child becomes a - group leader resolves to OUR group. SIGTERM, then SIGKILL after *grace_period*. - """ + """Kill *proc* and its whole process group (Xet may spawn helpers). The child ``os.setsid()``s so + its pgid equals its pid; signal via ``os.killpg(pid, ...)`` -- NOT ``getpgid``, which before the + child is a group leader resolves to OUR group. SIGTERM, then SIGKILL after *grace_period*.""" pid = proc.pid def _signal_group(sig: int) -> None: @@ -801,11 +709,10 @@ def _signal_group(sig: int) -> None: _signal_group(getattr(signal, "SIGTERM", signal.SIGINT)) proc.join(timeout = grace_period) - # Post-grace SIGKILL only while the leader is still alive, so its pid (== pgid after setsid) is - # a live target. Once proc.join() reaps a leader that exited on SIGTERM, that pid is free and a - # busy host can recycle it into an unrelated setsid'd group within the grace window -- a - # killpg(pid) would then signal the WRONG group. hf_xet 1.5.x writes in-process and spawns no - # helper procs, so a reaped leader leaves nothing in the group to clean up. + # SIGKILL only while the leader is alive, so its pid (== pgid after setsid) is a live target. Once + # join() reaps a leader that exited on SIGTERM, that pid is free and a busy host could recycle it + # into an unrelated group -- killpg(pid) would then signal the WRONG group. hf_xet 1.5.x spawns no + # helpers, so a reaped leader leaves nothing to clean up. 
if proc.is_alive(): _signal_group(getattr(signal, "SIGKILL", signal.SIGTERM)) proc.join(timeout = 5.0) @@ -825,18 +732,12 @@ def _run_download_attempt( grace_period: float, on_status: Optional[Callable[[str], None]], ) -> tuple[str, Optional[str]]: - """Run one download in a spawn child supervised by the no-progress watchdog. - - Returns ``("ok", path)``, ``("stall", None)``, ``("cancelled", None)``, - ``("crashed", message)`` (process-level crash, no captured exception), - ``("retryable_error", message)`` (a transient Xet transport failure worth an HTTP retry), - or ``("error", message)`` (a deterministic Hub error). This is the seam tests monkeypatch - to avoid spawning. - """ - # A single-file download scopes its stall detection to its own child's partials. - # Capture the partials already on disk for this repo BEFORE spawning, so the watchdog - # can ignore a concurrent sibling's in-flight partial (a different file in the same - # repo) and only follow the blob(s) this child newly writes. Snapshots stay repo-wide. + """Run one download in a spawn child supervised by the no-progress watchdog. Returns ``("ok", + path)``, ``("stall", None)``, ``("cancelled", None)``, ``("crashed", message)`` (process crash, no + captured exception), ``("retryable_error", message)`` (transient, worth an HTTP retry), or + ``("error", message)`` (deterministic Hub error). The seam tests monkeypatch to avoid spawning.""" + # Single-file: capture the partials on disk BEFORE spawning so the watchdog ignores a sibling's + # in-flight partial and follows only the blob(s) this child writes. Snapshots stay repo-wide. baseline_partials: Optional[set] = None if kind == "file": baseline_partials = set( @@ -855,66 +756,50 @@ def _run_download_attempt( ), daemon = True, ) - # Set the transport env in THIS process around the spawn so the child inherits - # it from creation. 
HF reads HF_HUB_DISABLE_XET into constants at import time, - # and a spawn child re-imports the (heavy) unsloth_zoo package -- importing - # huggingface_hub -- before the child body runs, so a child-side os.environ - # assignment would land too late. The child still sets it too, defensively. + # Set the transport env in THIS process around the spawn so the child inherits it from creation: + # HF reads HF_HUB_DISABLE_XET into a constant at import time, and the child re-imports + # huggingface_hub before its body runs, so a child-side assignment would land too late. The child + # still sets it defensively. child_env = { "HF_HUB_DISABLE_PROGRESS_BARS": "1", - # The download child is a fresh spawn interpreter that only needs - # huggingface_hub; tell unsloth_zoo's __init__ to skip its heavy torch / - # transformers / device init in that process (the parent keeps full init). + # Tell unsloth_zoo's __init__ to skip its heavy torch / transformers / device init in the child. "UNSLOTH_ZOO_DISABLE_GPU_INIT": "1", } if disable_xet: child_env["HF_HUB_DISABLE_XET"] = "1" child_env["HF_HUB_ENABLE_HF_TRANSFER"] = "0" with _SPAWN_ENV_LOCK: - # Cache huggingface_hub's transport constants in the PARENT from the REAL environment NOW, - # before the child-only env (HF_HUB_DISABLE_XET=1) is briefly set below. Hub reads - # HF_HUB_DISABLE_XET into a module constant at import time; without this, a concurrent thread - # doing its FIRST `import huggingface_hub` inside the spawn window could cache the child-only - # disabled-Xet value in the parent and silently route later in-process downloads over HTTP. - # Once imported it is a no-op, so a concurrent import in the window then re-reads nothing. 
+ # Cache Hub's transport constants in the PARENT from the REAL env NOW, before the child-only + # HF_HUB_DISABLE_XET=1 is briefly set below: a concurrent thread's FIRST `import huggingface_hub` + # in the spawn window would otherwise cache the disabled-Xet value and route later in-process + # downloads over HTTP. Once imported this is a no-op. try: import huggingface_hub.constants # noqa: F401 except Exception: pass saved_env = {k: os.environ.get(k) for k in child_env} - # multiprocessing 'spawn' reconstructs __main__ in the child from - # __main__.__file__. If that is a pseudo-path ('', a notebook) the - # child fails to start; if it is a real but UNGUARDED caller script the - # child re-imports it as __mp_main__ and re-runs the top-level - # from_pretrained/download, hitting the "start a new process before - # bootstrapping" error -> the parent then sees the child exit without a - # result. In every case we only need the child to unpickle and run - # _download_child_entry, so point __main__ at THIS importable, side-effect - # -free module for the spawn (and restore it after). The child imports us - # as __mp_main__ instead of re-executing the caller's script. + # 'spawn' reconstructs __main__ from __main__.__file__. A pseudo-path ('', a notebook) + # fails to start; a real but UNGUARDED caller script gets re-imported as __mp_main__, re-running + # the top-level from_pretrained and hitting the "start a process before bootstrapping" error -> + # the parent sees the child exit without a result. We only need the child to run + # _download_child_entry, so point __main__ at THIS side-effect-free module for the spawn. 
main_module = sys.modules.get("__main__") saved_main_file = _UNSET saved_main_spec = _UNSET if main_module is not None: saved_main_file = getattr(main_module, "__file__", _UNSET) main_module.__file__ = __file__ - # When the caller was launched as a module (python -m pkg), spawn's - # preparation prefers __main__.__spec__.name over __file__ and re-imports - # the user's module BY NAME -> re-runs its top-level from_pretrained in - # the child and hits the bootstrapping error. Clearing __spec__ forces - # the path branch, which uses the __file__ we just repointed at this - # side-effect-free helper module. + # Launched as `python -m pkg`: spawn prefers __spec__.name and re-imports the module BY + # NAME (re-running its top-level code). Clearing __spec__ forces the __file__ path branch. saved_main_spec = getattr(main_module, "__spec__", _UNSET) main_module.__spec__ = None try: os.environ.update(child_env) proc.start() except BaseException: - # proc.start() can raise (e.g. OSError "can't start new process" under fd / - # thread exhaustion). The result_queue's OS pipe fds were allocated above, but - # the lifecycle try/finally that closes them is only entered AFTER a successful - # start, so on a failed spawn that cleanup never runs and the fds leak. Close - # the queue here so a failed spawn is deterministic rather than fd-leaking. + # proc.start() can raise (OSError "can't start new process" under fd / thread exhaustion). + # The result_queue's pipe fds were allocated above but the lifecycle try/finally that + # closes them runs only after a successful start, so close the queue here to avoid an fd leak. try: result_queue.cancel_join_thread() result_queue.close() @@ -953,9 +838,8 @@ def _run_download_attempt( pass stalled = threading.Event() - # start_watchdog creates and starts a thread; if that raises (e.g. "can't start new thread" - # under thread/FD exhaustion), the child already started above must STILL be terminated. 
So it - # runs inside the try whose finally reaps the child; stop_watchdog stays None until it succeeds. + # If start_watchdog raises ("can't start new thread"), the already-started child must STILL be + # reaped, so it runs inside the try whose finally reaps it; stop_watchdog stays None until it works. stop_watchdog = None result: Optional[dict] = None try: @@ -977,21 +861,17 @@ def _run_download_attempt( _terminate_process_group(proc, grace_period) return ("cancelled", None) if stalled.is_set(): - # Prefer a result the child enqueued in the same ~interval window the watchdog - # fired in over a late stall, so a download that just succeeded is not killed and - # needlessly retried over HTTP. A spawn Queue has a child-side feeder thread, so a - # result put microseconds earlier is not yet readable by get_nowait(); use a short - # timeout (matching the process-exit drain below) to let the pipe flush. + # Prefer a result the child enqueued in the same window the watchdog fired in, so a + # download that just succeeded is not killed. The Queue's feeder thread may not have + # flushed a microseconds-earlier put, so use a short timeout, not get_nowait(). try: result = result_queue.get(timeout = 1.0) break except queue.Empty: pass - # Capture the partials THIS child owns BEFORE killing it, so the HTTP-prep purge can - # scope its blob/symlink cleanup to them and never delete a concurrent sibling's - # partial. Prefer the precise per-pid open-fd set; fall back to the partials that - # appeared since this child spawned (kind=="file" tracks a baseline) when the child - # cannot be inspected. None -> prep keeps its coarser mtime-only guard. + # Capture the partials THIS child owns BEFORE killing it, so HTTP prep can scope its + # purge to them. Prefer the per-pid open-fd set; fall back to post-baseline partials + # when the child can't be inspected. None -> prep keeps its coarser mtime guard. 
owned = _child_open_incomplete_blobs(proc.pid) if proc.pid else None if owned is None and baseline_partials is not None: current = set( @@ -1007,10 +887,8 @@ def _run_download_attempt( except queue.Empty: continue else: - # Process exited; drain any result it enqueued. Use a short timeout, - # not get_nowait(): the child can exit microseconds before its queue - # feeder flushes the pipe, and a bare get_nowait() would then spuriously - # report "exited without a result" on an otherwise successful download. + # Process exited; drain any result it enqueued. Short timeout, not get_nowait(): the child + # can exit just before its feeder flushes the pipe, which would spuriously look resultless. try: result = result_queue.get(timeout = 1.0) except queue.Empty: @@ -1019,16 +897,12 @@ def _run_download_attempt( if stop_watchdog is not None: stop_watchdog.set() proc.join(timeout = grace_period) - # Any exit from the loop -- normal completion, cancel/stall, or an - # unexpected exception (e.g. KeyboardInterrupt) -- must not leak the child. - # If it is still alive after the grace join, kill its whole process group. - # _terminate_process_group is idempotent, so a redundant call after the - # cancel/stall branch already terminated it is a harmless no-op. + # Any loop exit (completion, cancel/stall, KeyboardInterrupt) must not leak the child. + # _terminate_process_group is idempotent, so a redundant call here is a harmless no-op. if proc.is_alive(): _terminate_process_group(proc, grace_period) - # Release the queue's pipe fds deterministically rather than waiting for GC (which is - # fragile when the child was killed mid-put). The result, if any, is already extracted, - # and a killed child has nothing more to flush, so cancel the feeder join before close. + # Release the queue's pipe fds deterministically rather than waiting for GC. The result is + # already extracted and a killed child has nothing to flush, so cancel the feeder before close. 
try: result_queue.cancel_join_thread() result_queue.close() @@ -1036,10 +910,8 @@ def _run_download_attempt( pass if result is None: - # The child exited without enqueuing a result: a process-level crash (e.g. a native - # hf_xet abort / segfault), NOT a captured Hub exception. No deterministic error was - # observed, so the other transport may still succeed -- report it as "crashed" so the - # caller can retry over HTTP rather than surfacing a hard error. + # The child exited without a result: a process-level crash (a native hf_xet abort / segfault), + # not a captured exception, so the other transport may still succeed -- report "crashed". return ( "crashed", f"download process for '{repo_id}' exited " @@ -1057,10 +929,9 @@ def _run_download_attempt( def _intact_subset( snapshot_dir: Path, *, repo_type: str, allow_patterns: Any, ignore_patterns: Any, ) -> bool: - """No interrupted-download evidence for the files the request SELECTS: no dangling requested - symlink, and every EXACT-named requested file present. Used for a weightless / non-model request - (a dataset, a tokenizer-only allow list) and as the breakage check for a finished download. A - dangling EXCLUDED weight from an earlier interrupted pull does not reject a complete subset.""" + """No interrupted-download evidence for the SELECTED files: no dangling requested symlink, and + every EXACT-named requested file present. A dangling EXCLUDED weight does not reject a complete + subset.""" return ( not snapshot_has_requested_broken_symlinks( snapshot_dir, allow_patterns = allow_patterns, ignore_patterns = ignore_patterns, @@ -1073,9 +944,8 @@ def _intact_subset( def _has_any_weight(snapshot_dir: Path) -> bool: - """True if the snapshot holds at least one loadable model weight anywhere (root or a component - subfolder). 
Lenient on purpose: it only distinguishes a real model warm from the config-only - stale snapshot HF can hand back on an offline / timed-out request, without classifying layout.""" + """True if the snapshot holds at least one loadable weight anywhere (root or subfolder). Lenient: + it only tells a real model warm from a config-only stale snapshot, without classifying layout.""" try: for entry in snapshot_dir.rglob("*"): if _is_loadable_weight_file(entry.name): @@ -1090,9 +960,8 @@ def _has_any_weight(snapshot_dir: Path) -> bool: def _root_has_loadable_weight(snapshot_dir: Path) -> bool: - """True if a loadable weight sits at the snapshot ROOT (where a default ``from_pretrained`` reads - it). Unlike ``_has_any_weight`` this ignores subfolders, so a stale training-checkpoint-only - snapshot (weights only under ``checkpoint-7/``) is not mistaken for a usable root model.""" + """True if a loadable weight sits at the snapshot ROOT (where a default load reads it). Ignores + subfolders, so a stale ``checkpoint-7/``-only snapshot is not mistaken for a usable root model.""" try: for entry in snapshot_dir.iterdir(): if _is_loadable_weight_file(entry.name): @@ -1107,13 +976,10 @@ def _root_has_loadable_weight(snapshot_dir: Path) -> bool: def _root_model_has_weight(snapshot_dir: Path) -> bool: - """Whether an UNPATTERNED model warm holds a weight a default load will actually read: a ROOT - weight, or -- for a diffusers pipeline (root ``model_index.json``) -- a component-subfolder weight. - - A bare ``from_pretrained`` reads root weights and ignores arbitrary subfolders (``checkpoint-*/`` ...), - so counting any subtree weight (as ``_has_any_weight`` does) would accept a stale checkpoint-only - snapshot and then fetch the missing root weights over un-killable Xet. 
Diffusers is the one layout - whose weights legitimately live in subfolders, and its ``model_index.json`` marker gates that.""" + """Whether an UNPATTERNED model warm holds a weight a default load reads: a ROOT weight, or -- for a + diffusers pipeline (root ``model_index.json``) -- a component-subfolder weight. Counting any subtree + weight (as ``_has_any_weight`` does) would accept a stale checkpoint-only snapshot and then fetch + the root weights over un-killable Xet; diffusers is the one layout whose weights live in subfolders.""" try: is_diffusers = (snapshot_dir / "model_index.json").is_file() except OSError: @@ -1123,10 +989,8 @@ def _root_model_has_weight(snapshot_dir: Path) -> bool: return _root_has_loadable_weight(snapshot_dir) -# Exact weight filenames that are interchangeable: a request naming several of the same logical -# weight (the classic ``["pytorch_model.bin", "model.safetensors"]`` either-format pair) is satisfied -# by ANY one of them, while distinct logical weights (a base ``model.safetensors`` AND an -# ``adapter_model.safetensors``) must each be present. +# Interchangeable exact weight names: the either-format ``["pytorch_model.bin", "model.safetensors"]`` +# pair is satisfied by ANY one, while distinct logical weights (base AND adapter) must each be present. _EQUIVALENT_EXACT_WEIGHT_NAMES = { "model.safetensors": "root_model", "pytorch_model.bin": "root_model", @@ -1138,11 +1002,9 @@ def _root_model_has_weight(snapshot_dir: Path) -> bool: def _requested_exact_files_present_grouped( snapshot_dir: Path, *, allow_patterns: Any, ignore_patterns: Any, ) -> bool: - """True unless an EXACT-named requested file is missing. A request that names several - interchangeable weights (``["pytorch_model.bin", "model.safetensors"]``) is satisfied by any one - of them; distinct logical files (a base weight AND an adapter, or a tokenizer file) must each be - present. 
A request with ANY glob, or no allow list, is a best-effort warm and is trivially - satisfied here -- the weight-presence checks below cover those.""" + """True unless an EXACT-named requested file is missing. Interchangeable weights + (``["pytorch_model.bin", "model.safetensors"]``) need any one; distinct logical files (base AND + adapter, a tokenizer file) each. A glob / unpatterned request is trivially satisfied here.""" allow = _as_pattern_list(allow_patterns) ignore = _as_pattern_list(ignore_patterns) if not allow or any(not isinstance(p, str) or _has_glob(p) for p in allow): @@ -1171,10 +1033,9 @@ def _requested_exact_files_present_grouped( def _has_selected_weight( snapshot_dir: Path, *, allow_patterns: Any, ignore_patterns: Any, ) -> bool: - """True if at least one loadable weight the request actually SELECTS is present. Unlike - ``_has_any_weight`` this applies the allow / ignore filter, so a patterned request - (``["*.safetensors"]``, ``["unet/*"]``) is not satisfied by an out-of-scope weight (a stale - ``.bin`` left behind, a checkpoint subfolder the request did not ask for).""" + """True if a loadable weight the request SELECTS is present. Applies the allow / ignore filter (vs + ``_has_any_weight``), so a patterned request is not satisfied by an out-of-scope weight (a stale + ``.bin``, an unrequested checkpoint subfolder).""" weights: list = [] try: for entry in snapshot_dir.rglob("*"): @@ -1191,9 +1052,8 @@ def _has_selected_weight( def _patterns_are_exact_names(patterns: Any) -> bool: - """True only for a non-empty allow list of EXACT filenames (no ``None``, no glob, no trailing-slash - directory pattern). Only such a request can be proven complete from local files alone; ``None`` or a - glob needs the Hub manifest, so it must defer to the watched child.""" + """True only for a non-empty allow list of EXACT filenames (no ``None`` / glob / trailing-slash + dir). 
Only such a request is locally provable complete; ``None`` / a glob needs the Hub manifest.""" patterns = _as_pattern_list(patterns) if patterns is None: return False @@ -1205,25 +1065,21 @@ def _patterns_are_exact_names(patterns: Any) -> bool: def _cache_can_skip_download( snapshot_dir: Path, *, repo_type: str, allow_patterns: Any, ignore_patterns: Any, ) -> bool: - """PRE-download: whether a locally cached snapshot is complete enough that the in-process load - will not fetch anything, so the protective child can be skipped. - - STRICT for a weight-bearing model request: only the conservative canonical fast-path - (``snapshot_dir_is_complete``) may skip the child; anything uncertain (diffusers, variants, - non-trivial patterns, sharded-without-index) returns False -> spawn the child. A false True here - would let the in-process load fetch a missing weight over un-killable Xet (the hang). A weightless - model request (a tokenizer / config / metadata-dir allow list) or a non-model (dataset / space) - request has no weight to hang on, but its completeness is only locally provable when it names - EXACT files: an unpatterned or glob request cannot be proven complete without the Hub manifest, so - it defers to the watched child rather than hand back a partial cache. An exact-named subset that is - intact still short-circuits (preserving the offline tokenizer-only / named-file warm).""" + """PRE-download: whether a cached snapshot is complete enough to skip the protective child. + + STRICT for a weight-bearing model request: only the conservative canonical gate + (``snapshot_dir_is_complete``) skips; anything uncertain (diffusers, variants, patterns, + sharded-without-index) spawns the child. A false True would let the load fetch a missing weight over + un-killable Xet (the hang). 
A weightless model or non-model (dataset) request has no weight to hang + on, but is locally provable complete only when it names EXACT files -- an unpatterned / glob request + defers to the child rather than hand back a partial cache. An intact exact-named subset still + short-circuits (offline tokenizer-only / named-file warm).""" if repo_type in (None, "model") and request_can_include_weights(allow_patterns, ignore_patterns): return snapshot_dir_is_complete( snapshot_dir, allow_patterns = allow_patterns, ignore_patterns = ignore_patterns, ) - # Weightless model / non-model request: skip only when it names exact files whose subset is intact. - # A None / glob request (e.g. a whole-dataset ``allow_patterns=None``) cannot be proven complete - # from local files alone, so defer to the child for the authoritative manifest compare + resume. + # Weightless / non-model: skip only for an intact exact-named subset. A None / glob request cannot + # be proven complete from local files, so defer to the child for the manifest compare + resume. if not _patterns_are_exact_names(allow_patterns): return False return _intact_subset( @@ -1235,26 +1091,18 @@ def _cache_can_skip_download( def _download_result_usable( snapshot_dir: Path, *, repo_type: str, allow_patterns: Any, ignore_patterns: Any, ) -> bool: - """POST-download: whether the child's ``snapshot_download`` result is usable, or should be retried - over HTTP. snapshot_download already did the authoritative manifest compare + resume, so accept - unless there is POSITIVE evidence of a silent-Xet partial: a dangling REQUESTED symlink (a blob - that is missing or still ``.incomplete``), or a weight-bearing model warm that came back with NO - weight at all (HF handed back a stale config-only snapshot on an offline / timed-out request). 
- LENIENT otherwise -- a finished diffusers / variant / either-format download passes, and an - OPTIONAL file simply absent from the repo is not treated as missing -- so a good download is never - failed and re-looped into a ``DownloadStallError``. - - Positive-breakage checks: - - Any dangling REQUESTED symlink (a missing / still-``.incomplete`` blob). - - Every EXACT-named requested file present (grouped by weight equivalence, so the either-format - ``["pytorch_model.bin", "model.safetensors"]`` pair needs only one, but a base weight AND an - ``adapter_model.safetensors``, or a ``["tokenizer.json"]`` config request, must each be present). - A glob allow list cannot be turned into an exact manifest, so it stays lenient there. - - A weight-bearing MODEL request that came back with no usable weight. For an UNPATTERNED warm the - weight must be ROOT-readable (or a diffusers component) -- a stale ``checkpoint-7/``-only snapshot - does not count, since a default load ignores it -- and an interrupted CANONICAL sharded warm - (loose ``model-00001-of-00002.safetensors`` with no index) is rejected. A patterned weight request - must have a weight WITHIN its requested scope (not a stale out-of-scope ``.bin`` / checkpoint).""" + """POST-download: whether the child's result is usable, or should be retried over HTTP. + snapshot_download already did the authoritative manifest compare, so accept unless there is + POSITIVE breakage evidence; LENIENT otherwise (a finished diffusers / variant / either-format + download passes, an optional missing file is not treated as broken) so a good download is never + looped into a ``DownloadStallError``. Breakage checks: + + - A dangling REQUESTED symlink (a missing / still-``.incomplete`` blob). + - A missing EXACT-named requested file (grouped by weight equivalence: the either-format pair needs + one; base AND adapter, or a ``["tokenizer.json"]`` request, each). Globs stay lenient. + - A weight-bearing MODEL request with no usable weight. 
UNPATTERNED -> the weight must be + ROOT-readable (or a diffusers component; a stale ``checkpoint-7/``-only snapshot does not count) + and a loose canonical-sharded warm (no index) is rejected. Patterned -> a weight WITHIN scope.""" if snapshot_has_requested_broken_symlinks( snapshot_dir, allow_patterns = allow_patterns, ignore_patterns = ignore_patterns, repo_type = repo_type, @@ -1282,17 +1130,13 @@ def _download_result_usable( def _snapshot_payload_incomplete( payload: Any, *, repo_type: str, allow_patterns: Any, ignore_patterns: Any ) -> bool: - """True when a snapshot download returned a real directory that is not usable for the request - (see ``_download_result_usable``). Guarded to an existing directory so a mocked / non-path - payload (unit tests) or an unexpected return is trusted rather than rejected; in production the - child always returns a real snapshot dir, where this catches HF handing back an existing partial - snapshot on an offline / timed-out request.""" + """True when a snapshot download returned a real directory not usable for the request (see + ``_download_result_usable``). Guarded to an existing dir, so a mocked / non-path payload (tests) is + trusted rather than rejected; in production the child always returns a real snapshot dir.""" try: path = Path(payload) except (TypeError, ValueError, OSError): - # Non-path payload (unit-test sentinel) or, on Windows, a path with invalid characters - # (ValueError / OSError): trust it rather than reject -- production always returns a real dir. - return False + return False # non-path payload (test sentinel) or invalid path -> trust it try: if not path.is_dir(): return False @@ -1334,12 +1178,10 @@ def _download_with_xet_fallback( for attempt in range(2): if disable_xet: - # Purge a non-HTTP partial before resuming over HTTP: an HTTP resume - # over a sparse Xet/hf_transfer partial silently corrupts the blob. 
- # The generic purge is cache_dir-aware; an injected (Studio) hook owns - # its own cache accounting and keeps the (repo_type, repo_id) signature. - # The previous attempt's stall recorded the partials its child owned (if it could). - # Scope the cleanup to them so a concurrent same-repo sibling's partial is never purged. + # Purge a non-HTTP partial first: an HTTP resume over a sparse Xet/hf_transfer partial + # silently corrupts the blob. Scope the purge to the partials the stalled child owned, so + # a concurrent same-repo sibling's partial is spared. An injected (Studio) hook owns its + # own cache accounting, so it keeps the plain (repo_type, repo_id) signature. owned_incomplete = params.pop("_owned_incomplete_blobs", None) try: if prepare_for_http_fn is None: @@ -1351,10 +1193,8 @@ def _download_with_xet_fallback( prepare_for_http_fn(repo_type, repo_id) except Exception as e: logger.debug("prepare_for_http failed for %s: %s", repo_id, e) - # If an unsafe partial could not be cleared (e.g. a locked file or a - # permission error), an HTTP resume over a sparse Xet/hf_transfer - # partial would silently corrupt the blob. Force a clean re-download - # for this HTTP attempt instead of resuming over it. + # An unsafe partial that could not be cleared (locked file, permission error) would + # corrupt the blob on an HTTP resume: force a clean re-download instead. if has_active_incomplete_blobs(repo_type, repo_id, cache_dir = cache_dir): logger.warning( "Unsafe partial for '%s' could not be cleared; forcing a clean " @@ -1383,12 +1223,10 @@ def _download_with_xet_fallback( allow_patterns = params.get("allow_patterns"), ignore_patterns = params.get("ignore_patterns"), ): - # HF can return an existing, incomplete snapshot dir on an offline or - # timed-out request instead of fetching the missing files. 
Never hand an - # incomplete snapshot to the in-process load: retry over HTTP, and if it - # still comes back incomplete, fail loudly rather than silently loading a - # broken cache. (A patterned / non-model request is judged by its own - # requested subset, so this never rejects a valid weightless snapshot.) + # HF can hand back an existing incomplete snapshot dir (offline / timed-out request) + # instead of fetching the missing files. Never load that in-process: retry over HTTP, + # then fail loudly rather than load a broken cache. (Patterned / non-model requests are + # judged by their own subset, so a valid weightless snapshot is not rejected.) if not disable_xet: logger.warning( "Download for '%s' returned an incomplete snapshot -- " @@ -1405,17 +1243,13 @@ def _download_with_xet_fallback( if kind_result == "cancelled": raise RuntimeError("Cancelled") if kind_result == "error": - # Deterministic failure (a captured Hub exception: auth, not-found, gated, disk - # full): the other transport would fail identically, so do not retry. Re-raise - # preserving the original exception type (RepositoryNotFoundError / GatedRepoError / - # OSError ...) where known, so callers' typed except clauses still match across the - # spawn boundary; unknown errors fall back to RuntimeError. + # Deterministic failure (auth / not-found / gated / disk-full): the other transport fails + # identically, so do not retry. _raise_child_error preserves the original exception type + # across the spawn boundary so callers' typed except clauses still match. _raise_child_error(payload) if kind_result == "retryable_error": - # A transient transport failure (hf_xet CAS timeout, 5xx, connection reset) rather - # than a deterministic Hub error: disabling Xet and retrying over HTTP may recover, - # so try the other transport once before surfacing it (mirrors the crash / stall - # paths). If HTTP also failed, there is no other transport left -- raise. 
+ # Transient transport failure (hf_xet CAS timeout, 5xx, reset): HTTP may recover, so retry + # once before surfacing it; if HTTP also failed there is no transport left -> raise. if not disable_xet: logger.warning( "Download for '%s' hit a transient Xet transport error -- retrying " @@ -1426,8 +1260,7 @@ def _download_with_xet_fallback( continue raise RuntimeError(payload) if kind_result == "crashed": - # A process-level crash with no captured exception: HTTP may still succeed, so - # retry over it once before surfacing a hard error (mirrors the stall path). + # Process-level crash with no captured exception: HTTP may still succeed, so retry once. if not disable_xet: logger.warning( "Download process for '%s' crashed without a result -- " @@ -1442,9 +1275,8 @@ def _download_with_xet_fallback( logger.warning( "Download stalled for '%s' -- retrying with HF_HUB_DISABLE_XET=1", label ) - # _safe_status: a raising status hook (e.g. a disconnected client) must - # not abort the retry before disable_xet is set, turning a recoverable - # stall into a failed download. + # _safe_status: a raising status hook (disconnected client) must not abort the retry + # before disable_xet is set, turning a recoverable stall into a failed download. _safe_status(on_status, f"{label}: Xet stalled, retrying over HTTP") disable_xet = True continue @@ -1477,29 +1309,22 @@ def hf_hub_download_with_xet_fallback( ) -> str: """Download a single file with Xet primary and HTTP as a stall-only fallback. - Returns the local cache path. Raises ``RuntimeError("Cancelled")`` if - *cancel_event* is set, re-raises a deterministic child error unchanged (no - fallback), and raises ``DownloadStallError`` only if BOTH transports stall. - ``force_download=True`` re-fetches even if cached (skips the cache short-circuit). - ``local_files_only=True`` resolves from cache in-process and never spawns a - network child (matching Hugging Face offline semantics). 
``subfolder`` is - forwarded to ``hf_hub_download`` for files stored under a repo subdirectory. + Returns the local cache path. Raises ``RuntimeError("Cancelled")`` if *cancel_event* is set, + re-raises a deterministic child error unchanged (no fallback), and raises ``DownloadStallError`` + only if BOTH transports stall. ``force_download=True`` re-fetches even if cached; + ``local_files_only=True`` resolves from cache in-process with no child (HF offline semantics); + ``subfolder`` is forwarded to ``hf_hub_download``. """ repo_type = repo_type or "model" # HF treats None as the default model repo. - # Expand ~ as huggingface_hub does before writing, so the cache probe below and - # the child both resolve to the same on-disk location (else a warm ~/hf-cache - # is missed and we spawn a child for an already-cached file). Path-like cache - # dirs are normalized too, since HF accepts pathlib.Path. + # Expand ~ (and normalize Path) as huggingface_hub does, so the probe and the child resolve to + # the same on-disk location (else a warm cache is missed and we spawn a child for a cached file). if isinstance(cache_dir, (str, os.PathLike)): cache_dir = os.path.expanduser(os.fspath(cache_dir)) - # Honor an already-set cancellation before any cache probe or network work. The offline and - # warm-cache short-circuits below return without reaching _download_with_xet_fallback (which - # holds the only other cancel check), so a request cancelled before this point must not - # resolve and hand back a cached file. + # Honor an already-set cancellation before any probe: the short-circuits below return without + # reaching _download_with_xet_fallback (which holds the only other cancel check). if cancel_event is not None and cancel_event.is_set(): raise RuntimeError("Cancelled") - # Offline: resolve purely from the local cache, never reaching the network. HF - # raises LocalEntryNotFoundError if it is not cached; let that propagate. + # Offline: resolve purely from cache. 
HF raises LocalEntryNotFoundError if uncached; let it propagate. if local_files_only: from huggingface_hub import hf_hub_download @@ -1513,9 +1338,8 @@ def hf_hub_download_with_xet_fallback( cache_dir = cache_dir, local_files_only = True, ) - # Finalized blob already cached: return it with no child and no network - # (skipped when force_download re-fetches unconditionally). The cache stores a - # subfolder file under "<subfolder>/<filename>", which is what the probe wants. + # Finalized blob already cached: return it with no child and no network (skipped under + # force_download). The cache stores a subfolder file under "<subfolder>/<filename>". if not force_download: try: from huggingface_hub import try_to_load_from_cache @@ -1570,30 +1394,24 @@ def snapshot_download_with_xet_fallback( on_status: Optional[Callable[[str], None]] = None, prepare_for_http_fn: Optional[Callable[[str, str], None]] = None, ) -> str: - """Download a whole repo snapshot with Xet primary and HTTP as a stall-only - fallback, returning the local snapshot dir. - - Used by Unsloth's ``from_pretrained`` to warm the cache in a killable child - BEFORE the in-process model load (which then hits a warm cache and cannot - hang on a native Xet thread). A fully cached repo short-circuits in-process - via ``local_files_only`` with no child and no network. ``force_download=True`` - re-fetches in the killable child even if cached (skips that short-circuit). - ``local_files_only=True`` resolves from cache in-process and never spawns a - network child (matching Hugging Face offline semantics). + """Download a whole repo snapshot with Xet primary and HTTP as a stall-only fallback, returning + the local snapshot dir. + + Used by Unsloth's ``from_pretrained`` to warm the cache in a killable child BEFORE the in-process + model load (which then hits a warm cache and cannot hang on a native Xet thread). A fully cached + repo short-circuits in-process via ``local_files_only`` with no child.
``force_download=True`` + re-fetches in the killable child even if cached; ``local_files_only=True`` resolves from cache + in-process with no child (HF offline semantics). """ repo_type = repo_type or "model" # HF treats None as the default model repo. - # Expand ~ as huggingface_hub does before writing, so the probe and the child - # resolve to the same on-disk cache location. + # Expand ~ as huggingface_hub does, so the probe and the child resolve to the same cache location. if isinstance(cache_dir, (str, os.PathLike)): cache_dir = os.path.expanduser(os.fspath(cache_dir)) - # Honor an already-set cancellation before any cache probe or network work. The offline and - # warm-cache short-circuits below return without reaching _download_with_xet_fallback (which - # holds the only other cancel check), so a request cancelled before this point must not - # resolve and hand back a snapshot. + # Honor an already-set cancellation before any probe: the short-circuits below return without + # reaching _download_with_xet_fallback (which holds the only other cancel check). if cancel_event is not None and cancel_event.is_set(): raise RuntimeError("Cancelled") - # Offline: resolve purely from the local cache, never reaching the network. HF - # raises if the snapshot is not cached; let that propagate. + # Offline: resolve purely from cache. HF raises if uncached; let it propagate. if local_files_only: from huggingface_hub import snapshot_download @@ -1606,8 +1424,8 @@ def snapshot_download_with_xet_fallback( ignore_patterns = ignore_patterns, local_files_only = True, ) - # Fast path: everything already on disk -> resolve in-process (no Xet, no - # hang). Skipped when force_download re-fetches unconditionally. + # Fast path: everything already on disk -> resolve in-process (no Xet, no hang). Skipped under + # force_download. 
if not force_download: try: from huggingface_hub import snapshot_download @@ -1621,17 +1439,12 @@ def snapshot_download_with_xet_fallback( ignore_patterns = ignore_patterns, local_files_only = True, ) - # local_files_only returns a snapshot dir whenever refs/ and - # snapshots/ exist, even one left by a prior interrupted or patterned - # download (a config-only snapshot from an AutoConfig fetch, or a partial - # shard pull). Validate the EXACT returned revision dir against the request: - # a full model warmup may skip the child only when its canonical weights are - # provably complete (the conservative fast-path gate); a patterned / non-model - # request only needs its referenced files (no dangling symlinks). Complete it in - # the killable child otherwise, so the in-process load never proceeds with missing - # files. Scope the check to the returned snapshot, NOT the whole repo: an - # unrelated revision mid-download (a stale .incomplete blob or a broken older - # snapshot elsewhere in the same repo cache) must not force a needless re-fetch. + # local_files_only returns a snapshot dir whenever refs/ + snapshots/ exist, + # even one left by a prior interrupted or patterned download (config-only, partial shards). + # Validate the EXACT returned revision dir: a full model warmup skips the child only when + # its canonical weights are provably complete; a patterned / non-model request only needs + # its referenced files. Scope to this snapshot, NOT the whole repo, so an unrelated + # revision mid-download elsewhere in the repo cache does not force a needless re-fetch. 
if _cache_can_skip_download( Path(cached_dir), repo_type = repo_type, From 7db059af7183d7d51379f2d08dce035c81362f4b Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 29 Jun 2026 05:39:27 +0000 Subject: [PATCH 49/82] Stop misclassifying two weightless request shapes as weight-bearing A 10-reviewer pass plus a parallel safety/concurrency/correctness audit found the acceptance path solid (the pre/post split, the watchdog ownership scoping, error-type preservation, cancellation, token scrubbing, and the process-group teardown are all intact). Two real false-positive classifications remained in request_can_include_weights / _pattern_can_select_weight, both producing a misleading DownloadStallError on a COMPLETE weightless download: - A metadata-directory GLOB (allow=["tokenizer/*"], "processor/*.json") was treated as weight-bearing because the "*" basename matched a weight probe, even though the trailing-slash form ("tokenizer/") already reads weightless. It now inherits its directory's weightlessness, so a complete tokenizer-only download is accepted instead of looped into a DownloadStallError. - A root-reachable allow that the ignore filter strips of every weight (allow=["*"] + ignore=[every weight suffix]) was classified weight-bearing because the ignore set was disregarded whenever allow_patterns was present. It now applies HF's allow-then-ignore semantics to the weight probes. A subdir-scoped allow (unet/*, checkpoint-*/*) stays weight-bearing, and unsloth's weights_at_root warm (allow=None + ignore=["*/*.safetensors", ...]) is unchanged. The shared classifier gates both the pre-download skip and the post-download accept, so the fix lands consistently. A genuinely unfiltered model warm (allow=None, no usable root weight) is still rejected -- the hang-avoidance invariant is unchanged. 
Deliberately NOT changed (reproduced, then rejected because the suggested fix regresses a real case rather than fixing a hang): - Post-download accepting a complete variant-only / single-format root snapshot (model.fp16.safetensors). Forcing the strict canonical gate here would fail a valid variant="fp16" warm; the variant=None mismatch it targets surfaces as a clean OSError, not an un-killable hang. The pre-strict / post-lenient asymmetry is intentional. - Gating the post-download weight check on "ignore_patterns is None". That would disable the root-weight check for the common weights_at_root warm (which always sets an ignore list) and reintroduce the silent Xet hang. - SIGKILLing the child process group after the leader is already reaped. The leader's pid is freed by join(), so a busy host could recycle it into an unrelated group; SIGTERM already hit the whole group, and hf_xet 1.5.x spawns no helpers. Tests: add regression guards for the metadata-dir glob and the allow+ignore weightless cases (135 passed / 1 skipped). The 40k-layout safety fuzz stays at 0 violations; the Studio de-dup surface stays 7/7; ruff clean. --- tests/test_hf_xet_fallback.py | 36 +++++++++++++++++++++++++++++++++++ unsloth_zoo/hf_cache_state.py | 20 ++++++++++++++++--- 2 files changed, 53 insertions(+), 3 deletions(-) diff --git a/tests/test_hf_xet_fallback.py b/tests/test_hf_xet_fallback.py index aff3a3c29..fadcb4324 100644 --- a/tests/test_hf_xet_fallback.py +++ b/tests/test_hf_xet_fallback.py @@ -2409,6 +2409,42 @@ def test_metadata_directory_pattern_is_weightless(tmp_path): snap, repo_type = "model", allow_patterns = ["tokenizer/"], ignore_patterns = None) is True +def test_metadata_directory_glob_is_weightless(tmp_path): + """A metadata-dir GLOB (allow=['tokenizer/*'], 'processor/*.json') reads weightless like its + trailing-slash form, so a complete tokenizer-only download is accepted instead of looped into a + DownloadStallError. 
A component / checkpoint dir glob stays weight-bearing.""" + assert hcs.request_can_include_weights(["tokenizer/*"], None) is False + assert hcs.request_can_include_weights(["tokenizer/*.json"], None) is False + assert hcs.request_can_include_weights(["processor/*"], None) is False + assert hcs.request_can_include_weights(["unet/*"], None) is True + assert hcs.request_can_include_weights(["checkpoint-10/*"], None) is True + snap, _ = _mk_snapshot(tmp_path, "tokglob") + (snap / "tokenizer").mkdir() + (snap / "tokenizer" / "tokenizer.json").write_text("{}") + assert xf._download_result_usable( + snap, repo_type = "model", allow_patterns = ["tokenizer/*"], ignore_patterns = None) is True + + +def test_allow_star_with_all_weights_ignored_is_weightless(tmp_path): + """A root-reachable allow that the ignore filter strips of every weight (allow=['*'] + + ignore=[every weight suffix]) reads weightless, so a complete config-only download is accepted, not + looped into a DownloadStallError. A subdir-scoped allow stays weight-bearing, and an allow whose + weights survive the ignore stays weight-bearing.""" + all_weight_ignores = [ + "*.safetensors", "*.bin", "*.pt", "*.pth", "*.gguf", + "*.ckpt", "*.onnx", "*.msgpack", "*.h5", "*.pdparams", + ] + assert hcs.request_can_include_weights(["*"], all_weight_ignores) is False + assert hcs.request_can_include_weights(["*"], None) is True + assert hcs.request_can_include_weights(["unet/*"], all_weight_ignores) is True + assert hcs.request_can_include_weights(["*.safetensors"], ["*.bin"]) is True + snap, _ = _mk_snapshot(tmp_path, "cfgonly") + (snap / "config.json").write_text("{}") + (snap / "tokenizer.json").write_text("{}") + assert xf._download_result_usable( + snap, repo_type = "model", allow_patterns = ["*"], ignore_patterns = all_weight_ignores) is True + + def test_post_download_rejects_checkpoint_only_root_model(tmp_path): """A stale snapshot whose only weight is under checkpoint-7/ is rejected for an unpatterned root warm -- 
a default from_pretrained ignores checkpoint-*/ and would fetch the missing root weights diff --git a/unsloth_zoo/hf_cache_state.py b/unsloth_zoo/hf_cache_state.py index 3ef491ae6..2e190bc20 100644 --- a/unsloth_zoo/hf_cache_state.py +++ b/unsloth_zoo/hf_cache_state.py @@ -418,6 +418,12 @@ def _pattern_can_select_weight(pattern: "object") -> bool: if pattern.endswith("/"): dir_name = pattern.rstrip("/").rsplit("/", 1)[-1].lower() return dir_name not in _NON_WEIGHT_DIRS + # A pattern scoped under a metadata dir ("tokenizer/*", "processor/*.json") is weightless like the + # "tokenizer/" form, instead of letting a "*" basename match a weight probe. + if "/" in pattern: + parent = pattern.rsplit("/", 1)[0].rstrip("/").rsplit("/", 1)[-1].lower() + if parent in _NON_WEIGHT_DIRS: + return False base = pattern.rsplit("/", 1)[-1] if base.endswith(_WEIGHT_FILE_SUFFIXES): return True @@ -431,8 +437,8 @@ def request_can_include_weights( ) -> bool: """Whether a request restricted by *allow_patterns* / *ignore_patterns* can still include a weight. 
Conservative: True when uncertain, so the acceptance check requires a weight; False only for a - clearly weightless request (a tokenizer / config allow list, or an ignore list dropping every - weight format), which preserves the offline short-circuit for a tokenizer-only warm.""" + clearly weightless request (a tokenizer / config allow list, an ignore list dropping every weight + format, or an allow + ignore pair that strips them all), preserving the tokenizer-only short-circuit.""" allow_patterns = _as_pattern_list(allow_patterns) ignore_patterns = _as_pattern_list(ignore_patterns) if allow_patterns is None and ignore_patterns is None: @@ -441,7 +447,15 @@ def request_can_include_weights( return not _ignore_strips_all_weights(ignore_patterns or []) if not allow_patterns: return False # allow=[] selects nothing - return any(_pattern_can_select_weight(pat) for pat in allow_patterns) + if not any(_pattern_can_select_weight(pat) for pat in allow_patterns): + return False + # A root-reachable allow (no required subdir) can still be left weightless by the ignore filter + # (allow=["*"] + ignore=[every weight suffix]). Apply HF's allow-then-ignore semantics to the weight + # probes; a subdir-scoped allow stays weight-bearing (its required dir is absent from the root probes). + if ignore_patterns and all(isinstance(p, str) and "/" not in p for p in allow_patterns): + if not _filter_paths(list(_WEIGHT_PATTERN_PROBES), allow_patterns, ignore_patterns): + return False + return True def _canonical_root_weights_complete( From b4ed391aab6cbce845001e1d1fe4f3fd02a1ada1 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 29 Jun 2026 06:40:56 +0000 Subject: [PATCH 50/82] Only canonical model/pytorch_model shard indexes gate the fast path A correctness audit found a real PRE-download over-skip that could reintroduce the un-killable Xet hang the PR exists to prevent. 
_is_canonical_weight_shard_index matched ANY "*.safetensors.index.json" / "*.bin.index.json", so an adapter_model.safetensors.index.json counted as a canonical base index. A base+adapter repo whose cache held only a sharded adapter (no base weights) therefore passed snapshot_dir_is_complete -> _cache_can_skip_download returned True -> the protective child was skipped -> a default from_pretrained base load fetched the missing base weights over un-killable Xet and hung. Tighten it to the two exact canonical index names (model.safetensors.index.json, pytorch_model.bin.index.json), matching its own docstring. A sharded-adapter-only or variant-only cache now correctly defers to the watched child. Canonical single and canonical sharded (with index) caches still fast-path; the POST side already excludes adapter shards via _CANONICAL_ROOT_SHARD_RE, so no POST change is needed. Reproduced before/after; add a regression test (adapter-only root cache must defer). Suite 136 passed / 1 skipped; the 40k-layout safety fuzz stays at 0 violations; the Studio de-dup surface stays 7/7; ruff clean. A 10-reviewer pass plus two more audit forks otherwise found the PR solid (the warm/offline short-circuit, the 2-attempt retry's incomplete-snapshot guard, the watchdog, error-type preservation, cancellation, and the round-4 classifier change are all confirmed sound). Two low-consensus multi-rank concerns were reproduced and deliberately not changed: the injected (Studio) HTTP-prep hook scoping is a Studio-internal follow-up (the zoo default path is already owned-scoped), and scoping the snapshot watchdog to per-child partials would risk missing a real stall on a resumed shard for only a needless HTTP downgrade. 
--- tests/test_hf_xet_fallback.py | 20 ++++++++++++++++++++ unsloth_zoo/hf_cache_state.py | 9 +++++---- 2 files changed, 25 insertions(+), 4 deletions(-) diff --git a/tests/test_hf_xet_fallback.py b/tests/test_hf_xet_fallback.py index fadcb4324..1733e4018 100644 --- a/tests/test_hf_xet_fallback.py +++ b/tests/test_hf_xet_fallback.py @@ -2072,6 +2072,26 @@ def test_gate_fast_paths_canonical_sharded_with_index(tmp_path): assert hcs.snapshot_dir_is_complete(snap2) is False +def test_gate_rejects_sharded_adapter_only_root_cache(tmp_path): + """A complete sharded ADAPTER at the root (adapter_model.safetensors.index.json + its shards) is + NOT a canonical base model: only model/pytorch_model index names gate the fast path. A base+adapter + repo whose cache holds only the adapter must defer to the child, else a default from_pretrained + base load fetches the missing base weights over un-killable Xet.""" + assert hcs._is_canonical_weight_shard_index("adapter_model.safetensors.index.json") is False + assert hcs._is_canonical_weight_shard_index("model.safetensors.index.json") is True + assert hcs._is_canonical_weight_shard_index("pytorch_model.bin.index.json") is True + snap, blob = _mk_snapshot(tmp_path, "adapteronly") + (snap / "config.json").write_text("{}") + (snap / "adapter_model-00001-of-00002.safetensors").symlink_to(blob) + (snap / "adapter_model-00002-of-00002.safetensors").symlink_to(blob) + (snap / "adapter_model.safetensors.index.json").write_text( + json.dumps({"weight_map": {"a": "adapter_model-00001-of-00002.safetensors", + "b": "adapter_model-00002-of-00002.safetensors"}})) + assert hcs.snapshot_dir_is_complete(snap) is False + assert xf._cache_can_skip_download( + snap, repo_type = "model", allow_patterns = None, ignore_patterns = None) is False + + def test_gate_rejects_config_only(tmp_path): snap, _ = _mk_snapshot(tmp_path, "cfg") (snap / "config.json").write_text("{}") diff --git a/unsloth_zoo/hf_cache_state.py b/unsloth_zoo/hf_cache_state.py index 
2e190bc20..07aa9c6da 100644 --- a/unsloth_zoo/hf_cache_state.py +++ b/unsloth_zoo/hf_cache_state.py @@ -236,10 +236,11 @@ def _is_weight_shard_index(name: str) -> bool: def _is_canonical_weight_shard_index(name: str) -> bool: """True only for the CANONICAL (non-variant) index a default load probes - (``model.safetensors.index.json`` / ``pytorch_model.bin.index.json``). A variant - (``...index.fp16.json``) is rejected: the wrapper takes no variant param, so a variant-only cache - must not satisfy the canonical fast path (its canonical weights are still missing).""" - return name.endswith(".safetensors.index.json") or name.endswith(".bin.index.json") + (``model.safetensors.index.json`` / ``pytorch_model.bin.index.json``). Exact names only: an + ``adapter_model.safetensors.index.json`` (or a variant ``...index.fp16.json``) is rejected, so a + sharded-adapter-only / variant-only cache does not satisfy the canonical fast path (its base + canonical weights are still missing -> the load would fetch them over un-killable Xet).""" + return name in ("model.safetensors.index.json", "pytorch_model.bin.index.json") def _weight_shard_index_complete(index_path: Path) -> bool: From 2ffeb7f42576da9a6193a2326997323b4554c1a3 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 30 Jun 2026 08:40:54 +0000 Subject: [PATCH 51/82] Clarify get_hf_download_state return-value docstring A missing or empty HF cache returns (0, False), not None; None is returned only on a probe exception. The docstring conflated the two. Comment-only; no behavior change (AST-verified). 
--- unsloth_zoo/hf_xet_fallback.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/unsloth_zoo/hf_xet_fallback.py b/unsloth_zoo/hf_xet_fallback.py index a4cdb9f45..a521e7720 100644 --- a/unsloth_zoo/hf_xet_fallback.py +++ b/unsloth_zoo/hf_xet_fallback.py @@ -321,7 +321,8 @@ def get_hf_download_state( ) -> Optional[tuple[int, bool]]: """Return ``(total_on_disk_bytes, has_incomplete)`` for the HF cache being written (sparse-aware, so a partial Xet / ``hf_transfer`` blob is not read as full progress). Scans *cache_dir* or the - active ``HF_HUB_CACHE``. ``None`` -> unmeasurable, so callers skip stall logic this tick.""" + active ``HF_HUB_CACHE``. A missing / empty cache reads as ``(0, False)``; ``None`` is returned only + on a probe exception (unmeasurable -> callers skip stall logic this tick).""" try: if hf_cache_root(cache_dir = cache_dir) is None: return (0, False) From 8b58fbe5126e16774485cf241ce433b32672f4d1 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 1 Jul 2026 00:51:40 +0000 Subject: [PATCH 52/82] Defer the pre-download cache skip on a variant load A variant load (variant="fp16") reads variant-named weights (model.{variant}.safetensors) that the canonical completeness gate does not check. A cache holding only the non-variant canonical weight read as complete, so the protective child was skipped and the in-process load fetched the missing variant weight over un-killable Xet: the exact hang this fallback exists to prevent. Plumb variant through snapshot_download_with_xet_fallback into _cache_can_skip_download and defer (spawn the child) when it is set, matching the conservative posture already used for diffusers and non-trivial patterns. The child warms every root weight including the variant, and the post-download check is untouched, so a complete download is never looped into a stall error.
Also raise the never-stall watchdog tests' stall_timeout to 0.5s so a GIL-contended CI runner cannot trip a false stall, and add a test for the transient-unmeasurable watchdog tick (get_hf_download_state returning None is treated as progress). --- tests/test_hf_xet_fallback.py | 57 +++++++++++++++++++++++++++++----- unsloth_zoo/hf_xet_fallback.py | 12 ++++++- 2 files changed, 60 insertions(+), 9 deletions(-) diff --git a/tests/test_hf_xet_fallback.py b/tests/test_hf_xet_fallback.py index 1733e4018..1c60fa00e 100644 --- a/tests/test_hf_xet_fallback.py +++ b/tests/test_hf_xet_fallback.py @@ -128,7 +128,7 @@ def _grow(): calls: list[str] = [] stop = xf.start_watchdog( - repo_ids = [REPO], on_stall = calls.append, interval = 0.05, stall_timeout = 0.3 + repo_ids = [REPO], on_stall = calls.append, interval = 0.05, stall_timeout = 0.5 ) try: time.sleep(1.0) # well past stall_timeout, but bytes keep growing @@ -144,7 +144,7 @@ def test_no_incomplete_never_stalls(hf_cache): calls: list[str] = [] stop = xf.start_watchdog( - repo_ids = [REPO], on_stall = calls.append, interval = 0.05, stall_timeout = 0.3 + repo_ids = [REPO], on_stall = calls.append, interval = 0.05, stall_timeout = 0.5 ) try: time.sleep(0.8) @@ -153,6 +153,32 @@ def test_no_incomplete_never_stalls(hf_cache): stop.set() +def test_transient_unmeasurable_tick_is_progress(hf_cache, monkeypatch): + """A tick whose cache state is momentarily unmeasurable (get_hf_download_state -> None on a + transient FS error) is treated as progress, so a run of None ticks cannot trip a false stall. 
+ Once the state is readable again and confirms a frozen .incomplete, the real stall still fires -- + the None-handling must not permanently mask a genuine stall.""" + seq = {"n": 0} + frozen = (2048, True) # constant size + active .incomplete: would stall if measured every tick + + def fake_state(*args, **kwargs): + seq["n"] += 1 + return None if seq["n"] <= 8 else frozen # first ~8 ticks unmeasurable, then measurable+frozen + + monkeypatch.setattr(xf, "get_hf_download_state", fake_state) + + calls: list[str] = [] + stop = xf.start_watchdog( + repo_ids = [REPO], on_stall = calls.append, interval = 0.05, stall_timeout = 0.3, + ) + try: + time.sleep(0.3) # within the unmeasurable window: no false stall despite no measured progress + assert calls == [], "watchdog fired during a transient-unmeasurable window" + assert _wait(lambda: len(calls) >= 1, timeout = 3.0), "stall never fired after state recovered" + finally: + stop.set() + + def test_stall_fires_at_most_once(hf_cache): blobs = _blobs_dir(hf_cache) (blobs / "frozen.incomplete").write_bytes(b"\0" * 2048) @@ -229,7 +255,7 @@ def _grow(): calls: list[str] = [] stop = xf.start_watchdog( # default: repo-wide (watch_new_partials_only = False) - repo_ids = [REPO], on_stall = calls.append, interval = 0.05, stall_timeout = 0.3, + repo_ids = [REPO], on_stall = calls.append, interval = 0.05, stall_timeout = 0.5, ) try: time.sleep(1.0) # well past stall_timeout, but repo-wide bytes keep growing @@ -248,7 +274,7 @@ def test_file_watchdog_ignores_baseline_only_partials(hf_cache): calls: list[str] = [] stop = xf.start_watchdog( - repo_ids = [REPO], on_stall = calls.append, interval = 0.05, stall_timeout = 0.2, + repo_ids = [REPO], on_stall = calls.append, interval = 0.05, stall_timeout = 0.5, watch_new_partials_only = True, baseline_incomplete_blobs = {"sibling.incomplete"}, ) try: @@ -355,7 +381,7 @@ def test_file_watchdog_empty_open_set_ignores_sibling(hf_cache, monkeypatch): calls: list[str] = [] stop = xf.start_watchdog( 
- repo_ids = [REPO], on_stall = calls.append, interval = 0.05, stall_timeout = 0.2, + repo_ids = [REPO], on_stall = calls.append, interval = 0.05, stall_timeout = 0.5, watch_new_partials_only = True, baseline_incomplete_blobs = set(), child_pid = 4242, # non-None so the precise child-open path is taken ) @@ -2176,6 +2202,20 @@ def test_pre_download_skips_complete_model(tmp_path): snap, repo_type = "model", allow_patterns = None, ignore_patterns = None) is True +def test_pre_download_defers_variant_on_canonical_cache(tmp_path): + """A variant load (variant="fp16") reads model.{variant}.safetensors, which the canonical gate + does not check. A cache holding only the non-variant canonical weight must NOT fast-path when a + variant is requested -- else the in-process load fetches the missing variant over un-killable Xet. + Same cache, no variant, still fast-paths (the child is only spawned when actually needed).""" + snap, blob = _mk_snapshot(tmp_path, "var") + (snap / "model.safetensors").symlink_to(blob) + assert xf._cache_can_skip_download( + snap, repo_type = "model", allow_patterns = None, ignore_patterns = None) is True + assert xf._cache_can_skip_download( + snap, repo_type = "model", allow_patterns = None, ignore_patterns = None, + variant = "fp16") is False + + def test_pre_download_does_not_skip_diffusers_but_post_accepts(tmp_path): """The pre/post asymmetry: a diffusers warm is NOT fast-pathed (spawn the child), but the same complete diffusers result IS accepted post-download (it has component weights), so a good @@ -2308,9 +2348,10 @@ def test_post_download_accepts_weightless_patterned_result(tmp_path): def test_gate_rejects_variant_only_shard_index(tmp_path): """codex :269 (over-accept): a variant-only shard index (model.safetensors.index.fp16.json) must - NOT satisfy the canonical allow=None fast path -- the fallback wrapper takes no variant param, so - a default load probes the canonical index whose weights are absent.
Only a canonical - (non-variant) index counts; the variant layout defers to the watched child.""" + NOT satisfy the canonical allow=None fast path -- snapshot_dir_is_complete is variant-blind (a + default load probes the canonical index whose weights are absent). Only a canonical (non-variant) + index counts here; a variant REQUEST is deferred one level up in _cache_can_skip_download (see + test_pre_download_defers_variant_on_canonical_cache).""" snap, blob = _mk_snapshot(tmp_path, "variant") (snap / "config.json").write_text("{}") (snap / "model-00001-of-00001.fp16.safetensors").symlink_to(blob) diff --git a/unsloth_zoo/hf_xet_fallback.py b/unsloth_zoo/hf_xet_fallback.py index a521e7720..60f5bb085 100644 --- a/unsloth_zoo/hf_xet_fallback.py +++ b/unsloth_zoo/hf_xet_fallback.py @@ -1065,6 +1065,7 @@ def _patterns_are_exact_names(patterns: Any) -> bool: def _cache_can_skip_download( snapshot_dir: Path, *, repo_type: str, allow_patterns: Any, ignore_patterns: Any, + variant: Optional[str] = None, ) -> bool: """PRE-download: whether a cached snapshot is complete enough to skip the protective child. @@ -1076,6 +1077,12 @@ def _cache_can_skip_download( defers to the child rather than hand back a partial cache. An intact exact-named subset still short-circuits (offline tokenizer-only / named-file warm).""" if repo_type in (None, "model") and request_can_include_weights(allow_patterns, ignore_patterns): + # A variant load reads variant-named weights (model.{variant}.safetensors) that the canonical + # gate does not check: a cache holding only the canonical weight reads as complete, so the + # in-process load would fetch the variant over un-killable Xet. Defer to the child (it warms + # the variant too).
+ if variant: + return False return snapshot_dir_is_complete( snapshot_dir, allow_patterns = allow_patterns, ignore_patterns = ignore_patterns, ) @@ -1388,6 +1395,7 @@ def snapshot_download_with_xet_fallback( ignore_patterns: Optional[Any] = None, force_download: bool = False, local_files_only: bool = False, + variant: Optional[str] = None, cancel_event: Optional[threading.Event] = None, stall_timeout: float = DEFAULT_STALL_TIMEOUT, interval: float = DEFAULT_HEARTBEAT_INTERVAL, @@ -1402,7 +1410,8 @@ def snapshot_download_with_xet_fallback( model load (which then hits a warm cache and cannot hang on a native Xet thread). A fully cached repo short-circuits in-process via ``local_files_only`` with no child. ``force_download=True`` re-fetches in the killable child even if cached; ``local_files_only=True`` resolves from cache - in-process with no child (HF offline semantics). + in-process with no child (HF offline semantics). ``variant`` (e.g. "fp16") forces the child even + on a warm canonical cache, since the canonical gate cannot prove the variant-named weights present. """ repo_type = repo_type or "model" # HF treats None as the default model repo. # Expand ~ as huggingface_hub does, so the probe and the child resolve to the same cache location. 
@@ -1451,6 +1460,7 @@ def snapshot_download_with_xet_fallback( repo_type = repo_type, allow_patterns = allow_patterns, ignore_patterns = ignore_patterns, + variant = variant, ): return cached_dir logger.debug("Cached snapshot for %s is incomplete; downloading.", repo_id) From 98c8ad18c26151becccc074684985d7907ac9dbc Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 1 Jul 2026 03:29:21 +0000 Subject: [PATCH 53/82] Preserve builtin OSError subclasses across the spawn boundary _resolve_exception_class special-cased the exact name "OSError" but not its builtin subclasses, so a deterministic PermissionError (an unwritable custom cache), FileNotFoundError, or similar was reconstructed in the parent as a generic RuntimeError, breaking callers that catch OSError / PermissionError for cache cleanup or fallback. Map any builtin OSError subclass back to its real type, consistent with the existing OSError and HfHubHTTPError type preservation. The retry decision is unchanged: these stay deterministic (not retried). --- tests/test_hf_xet_fallback.py | 16 ++++++++++++++++ unsloth_zoo/hf_xet_fallback.py | 8 ++++++++ 2 files changed, 24 insertions(+) diff --git a/tests/test_hf_xet_fallback.py b/tests/test_hf_xet_fallback.py index 1c60fa00e..e8d916dcf 100644 --- a/tests/test_hf_xet_fallback.py +++ b/tests/test_hf_xet_fallback.py @@ -2396,6 +2396,22 @@ def test_hfvalidationerror_type_preserved_across_spawn(): assert xf._is_retryable_download_error(inst) is False +def test_oserror_subclass_type_preserved_across_spawn(): + """A deterministic builtin OSError subclass (PermissionError from an unwritable cache, + FileNotFoundError, ...) keeps its TYPE across the spawn boundary so a caller's `except OSError` / + `except PermissionError` still fires instead of seeing a generic RuntimeError. 
Non-OSError builtins + are not spuriously resolved (they fall through to the Hub-name lookup / None).""" + for name in ("PermissionError", "FileNotFoundError", "IsADirectoryError", "NotADirectoryError"): + cls = xf._resolve_exception_class(name) + assert cls is not None and issubclass(cls, OSError) and cls.__name__ == name + # A deterministic PermissionError is reconstructed as a real PermissionError and not retried. + perm = xf._instantiate_preserving_type(xf._resolve_exception_class("PermissionError"), "denied") + assert isinstance(perm, PermissionError) + assert xf._is_retryable_download_error(perm) is False + # An unrelated builtin (not an OSError subclass, not a Hub error name) is not resolved here. + assert xf._resolve_exception_class("ValueError") is None + + def test_weight_pattern_selector_handles_globs(tmp_path): """The weight-pattern selector reads tokenizer / config / json globs as weightless (keeps their offline short-circuit) but classifies standard weight names and ? / [] globs as weight-bearing.""" diff --git a/unsloth_zoo/hf_xet_fallback.py b/unsloth_zoo/hf_xet_fallback.py index 60f5bb085..3b481daef 100644 --- a/unsloth_zoo/hf_xet_fallback.py +++ b/unsloth_zoo/hf_xet_fallback.py @@ -26,6 +26,7 @@ from __future__ import annotations +import builtins import errno import importlib.util import multiprocessing as mp @@ -493,6 +494,13 @@ def _resolve_exception_class(type_name: str) -> "Optional[type]": occurs and never hard-depends on a specific huggingface_hub layout.""" if type_name == "OSError": return OSError + # Preserve builtin OSError subclasses (PermissionError, FileNotFoundError, ...): these are + # deterministic filesystem failures (e.g. an unwritable custom cache) the child cannot retry away, + # so a caller's `except OSError` / `except PermissionError` must still fire rather than see the + # generic RuntimeError the resolver would otherwise fall through to. 
+ builtin_cls = getattr(builtins, type_name, None) + if isinstance(builtin_cls, type) and issubclass(builtin_cls, OSError): + return builtin_cls if type_name not in _DETERMINISTIC_ERROR_NAMES and type_name not in _TYPE_PRESERVE_ONLY_NAMES: return None for module_name in ("huggingface_hub.errors", "huggingface_hub.utils"): From b1628a24a4ea29bf6cd548ff6e8b3d13f2c20d7e Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 1 Jul 2026 04:05:41 +0000 Subject: [PATCH 54/82] Tighten the post-download completeness check against offline-fallback partials On a transient connection error during the child's metadata call, snapshot_download silently returns an existing (stale / partial) cache instead of fetching, and the child reports it as success. The post-download check is the only guard, so it now validates the weight the load will actually read: - Apply the request's ignore filter to the root-weight check, so a partial holding only the format the load will NOT read (an ignored pytorch_model.bin when safetensors was requested) does not count as a usable weight. - Thread the variant through the post-download check: a variant load whose partial kept only the canonical model.safetensors, not model.{variant}.safetensors, is retried rather than accepted (else the in-process load fetches the variant over un-killable Xet). - Apply the canonical shard-completeness check to patterned (globbed) weight requests too, not just the unpatterned root path, so a partial with a shard index but a missing shard is retried. A complete download always passes every check (verified by the e2e recovery sim and the 40k-layout safety fuzz), so a good result is never looped into a stall error. The diffusers all-components case is intentionally left lenient (determining the required components locally would risk failing a complete pipeline).
--- tests/test_hf_xet_fallback.py | 53 ++++++++++++++++ unsloth_zoo/hf_xet_fallback.py | 112 ++++++++++++++++++++++----------- 2 files changed, 130 insertions(+), 35 deletions(-) diff --git a/tests/test_hf_xet_fallback.py b/tests/test_hf_xet_fallback.py index e8d916dcf..ed45906d3 100644 --- a/tests/test_hf_xet_fallback.py +++ b/tests/test_hf_xet_fallback.py @@ -2240,6 +2240,59 @@ def test_post_download_rejects_config_only_model(tmp_path): snap, repo_type = "model", allow_patterns = None, ignore_patterns = None) is False +def test_post_download_rejects_ignored_only_format(tmp_path): + """snapshot_download silently returns a stale cache on a transient metadata error. A safetensors + load (ignore=['*.bin']) whose returned partial kept only the ignored pytorch_model.bin -- not the + requested model.safetensors -- must be rejected (the weight check applies the ignore filter) and + retried over HTTP, not loaded in-process (Codex #829).""" + snap, blob = _mk_snapshot(tmp_path, "ignfmt") + (snap / "pytorch_model.bin").symlink_to(blob) + assert xf._download_result_usable( + snap, repo_type = "model", allow_patterns = None, ignore_patterns = ["*.bin"]) is False + # The requested safetensors present -> accepted. + (snap / "model.safetensors").symlink_to(blob) + assert xf._download_result_usable( + snap, repo_type = "model", allow_patterns = None, ignore_patterns = ["*.bin"]) is True + + +def test_post_download_rejects_canonical_only_for_variant(tmp_path): + """A variant load (variant='fp16') whose returned partial kept only the canonical model.safetensors + -- not model.fp16.safetensors -- must be rejected and retried, else the in-process load fetches the + missing variant over un-killable Xet (Codex #829). 
A present variant weight (single or sharded + infix) is accepted.""" + snap, blob = _mk_snapshot(tmp_path, "varpost") + (snap / "model.safetensors").symlink_to(blob) + assert xf._download_result_usable( + snap, repo_type = "model", allow_patterns = None, ignore_patterns = None, + variant = "fp16") is False + (snap / "model.fp16.safetensors").symlink_to(blob) + assert xf._download_result_usable( + snap, repo_type = "model", allow_patterns = None, ignore_patterns = None, + variant = "fp16") is True + snap2, blob2 = _mk_snapshot(tmp_path, "varshard") + (snap2 / "model.fp16-00001-of-00002.safetensors").symlink_to(blob2) + assert xf._download_result_usable( + snap2, repo_type = "model", allow_patterns = None, ignore_patterns = None, + variant = "fp16") is True + + +def test_post_download_rejects_incomplete_sharded_glob(tmp_path): + """A globbed weight request (allow=['*.safetensors']) whose returned partial has a canonical shard + index but is missing a shard must be rejected -- globs get the same shard-completeness check as the + unpatterned root path -- so the load does not finish the missing shard over Xet (Codex #829).""" + snap, blob = _mk_snapshot(tmp_path, "shardglob") + (snap / "model-00001-of-00002.safetensors").symlink_to(blob) + (snap / "model.safetensors.index.json").write_text( + json.dumps({"weight_map": {"a": "model-00001-of-00002.safetensors", + "b": "model-00002-of-00002.safetensors"}})) + assert xf._download_result_usable( + snap, repo_type = "model", allow_patterns = ["*.safetensors"], ignore_patterns = None) is False + # The missing shard present -> complete -> accepted. 
+ (snap / "model-00002-of-00002.safetensors").symlink_to(blob) + assert xf._download_result_usable( + snap, repo_type = "model", allow_patterns = ["*.safetensors"], ignore_patterns = None) is True + + def test_post_download_accepts_dataset_without_weight(tmp_path): snap, blob = _mk_snapshot(tmp_path, "ds") (snap / "data.parquet").symlink_to(blob) diff --git a/unsloth_zoo/hf_xet_fallback.py b/unsloth_zoo/hf_xet_fallback.py index 3b481daef..d968e720e 100644 --- a/unsloth_zoo/hf_xet_fallback.py +++ b/unsloth_zoo/hf_xet_fallback.py @@ -968,34 +968,57 @@ def _has_any_weight(snapshot_dir: Path) -> bool: return False -def _root_has_loadable_weight(snapshot_dir: Path) -> bool: - """True if a loadable weight sits at the snapshot ROOT (where a default load reads it). Ignores - subfolders, so a stale ``checkpoint-7/``-only snapshot is not mistaken for a usable root model.""" - try: - for entry in snapshot_dir.iterdir(): - if _is_loadable_weight_file(entry.name): - try: - if entry.is_file(): - return True - except OSError: - continue - except OSError: - return False - return False - - -def _root_model_has_weight(snapshot_dir: Path) -> bool: +def _root_model_has_weight(snapshot_dir: Path, *, ignore_patterns: Any = None) -> bool: """Whether an UNPATTERNED model warm holds a weight a default load reads: a ROOT weight, or -- for a diffusers pipeline (root ``model_index.json``) -- a component-subfolder weight. Counting any subtree weight (as ``_has_any_weight`` does) would accept a stale checkpoint-only snapshot and then fetch - the root weights over un-killable Xet; diffusers is the one layout whose weights live in subfolders.""" + the root weights over un-killable Xet; diffusers is the one layout whose weights live in subfolders. 
+ The request's ignore filter is applied to the ROOT weights, so an offline-fallback partial holding + only the format the load will NOT read (an ignored ``*.bin`` when safetensors was requested) does not + count as a usable weight -- the incomplete result is retried over HTTP instead of loaded in-process.""" try: is_diffusers = (snapshot_dir / "model_index.json").is_file() except OSError: is_diffusers = False if is_diffusers: return _has_any_weight(snapshot_dir) - return _root_has_loadable_weight(snapshot_dir) + rels: list = [] + try: + for entry in snapshot_dir.iterdir(): + if not _is_loadable_weight_file(entry.name): + continue + try: + if entry.is_file(): + rels.append(entry.name) + except OSError: + continue + except OSError: + return False + return bool(_filter_paths(rels, None, ignore_patterns)) + + +def _root_has_variant_weight(snapshot_dir: Path, variant: str) -> bool: + """True if a ROOT weight carrying the requested *variant* token is present. transformers inserts the + variant before the extension (a ``.<variant>.`` infix: ``model.fp16.safetensors``) or before a shard + suffix (a ``.<variant>-`` infix: ``model.fp16-00001-of-00002.safetensors``), so an offline-fallback + partial that kept only the canonical weight does not satisfy a variant request.""" + infix_dot = f".{variant}." 
+ infix_dash = f".{variant}-" + try: + for entry in snapshot_dir.iterdir(): + name = entry.name + if not _is_loadable_weight_file(name): + continue + if infix_dot not in name and infix_dash not in name: + continue + try: + if entry.is_file(): + return True + except OSError: + continue + except OSError: + return False + return False # Interchangeable exact weight names: the either-format ``["pytorch_model.bin", "model.safetensors"]`` @@ -1106,19 +1129,23 @@ def _cache_can_skip_download( def _download_result_usable( snapshot_dir: Path, *, repo_type: str, allow_patterns: Any, ignore_patterns: Any, + variant: Optional[str] = None, ) -> bool: """POST-download: whether the child's result is usable, or should be retried over HTTP. snapshot_download already did the authoritative manifest compare, so accept unless there is - POSITIVE breakage evidence; LENIENT otherwise (a finished diffusers / variant / either-format - download passes, an optional missing file is not treated as broken) so a good download is never - looped into a ``DownloadStallError``. Breakage checks: + POSITIVE breakage evidence; LENIENT otherwise (a finished diffusers / either-format download passes, + an optional missing file is not treated as broken) so a good download is never looped into a + ``DownloadStallError``. A transient connection error during the child's metadata call makes + ``snapshot_download`` silently return an existing (stale / partial) cache instead of fetching, so + the checks below apply the request's filters to the weight the load will actually read. Breakage: - A dangling REQUESTED symlink (a missing / still-``.incomplete`` blob). - A missing EXACT-named requested file (grouped by weight equivalence: the either-format pair needs one; base AND adapter, or a ``["tokenizer.json"]`` request, each). Globs stay lenient. - - A weight-bearing MODEL request with no usable weight. 
UNPATTERNED -> the weight must be - ROOT-readable (or a diffusers component; a stale ``checkpoint-7/``-only snapshot does not count) - and a loose canonical-sharded warm (no index) is rejected. Patterned -> a weight WITHIN scope.""" + - A weight-bearing MODEL request with no usable weight. A variant load needs a variant-named ROOT + weight (the canonical weight a partial kept does not satisfy it). UNPATTERNED -> a ROOT-readable + weight the load reads (ignore filter applied; a stale ``checkpoint-7/``-only snapshot does not + count) with a complete canonical shard set. Patterned -> a weight WITHIN scope, shard set complete.""" if snapshot_has_requested_broken_symlinks( snapshot_dir, allow_patterns = allow_patterns, ignore_patterns = ignore_patterns, repo_type = repo_type, @@ -1129,22 +1156,34 @@ def _download_result_usable( ): return False if repo_type in (None, "model") and request_can_include_weights(allow_patterns, ignore_patterns): - if allow_patterns is None: - # Default root load: a root (or diffusers-component) weight, sharded set complete. - if not _root_model_has_weight(snapshot_dir): + if allow_patterns is None and variant: + # Variant root load: a partial that kept only the canonical weight would leave the load to + # fetch the requested variant over un-killable Xet -> require a variant-named root weight. + if not _root_has_variant_weight(snapshot_dir, variant): + return False + elif allow_patterns is None: + # Default root load: a root (or diffusers-component) weight the load reads (ignore filter + # applied), sharded set complete. + if not _root_model_has_weight(snapshot_dir, ignore_patterns = ignore_patterns): + return False + if _has_incomplete_canonical_root_shards(snapshot_dir): + return False + else: + # Patterned weight request: a selected weight must be present AND a selected canonical shard + # set must be complete (a lone ``model-00001-of-0000N`` without its index / remaining shards + # is a partial the in-process load would finish over Xet). 
+ if not _has_selected_weight( + snapshot_dir, allow_patterns = allow_patterns, ignore_patterns = ignore_patterns + ): return False if _has_incomplete_canonical_root_shards(snapshot_dir): return False - elif not _has_selected_weight( - snapshot_dir, allow_patterns = allow_patterns, ignore_patterns = ignore_patterns - ): - # Patterned weight request: a weight WITHIN the requested scope must be present. - return False return True def _snapshot_payload_incomplete( - payload: Any, *, repo_type: str, allow_patterns: Any, ignore_patterns: Any + payload: Any, *, repo_type: str, allow_patterns: Any, ignore_patterns: Any, + variant: Optional[str] = None, ) -> bool: """True when a snapshot download returned a real directory not usable for the request (see ``_download_result_usable``). Guarded to an existing dir, so a mocked / non-path payload (tests) is @@ -1160,7 +1199,7 @@ def _snapshot_payload_incomplete( return False return not _download_result_usable( path, repo_type = repo_type, allow_patterns = allow_patterns, - ignore_patterns = ignore_patterns, + ignore_patterns = ignore_patterns, variant = variant, ) @@ -1178,6 +1217,7 @@ def _download_with_xet_fallback( grace_period: float, on_status: Optional[Callable[[str], None]], prepare_for_http_fn: Optional[Callable[[str, str], None]], + variant: Optional[str] = None, ) -> str: """Shared 2-attempt loop: Xet primary, HTTP on a stall. Returns the local path.""" if cancel_event is not None and cancel_event.is_set(): @@ -1238,6 +1278,7 @@ def _download_with_xet_fallback( repo_type = repo_type, allow_patterns = params.get("allow_patterns"), ignore_patterns = params.get("ignore_patterns"), + variant = variant, ): # HF can hand back an existing incomplete snapshot dir (offline / timed-out request) # instead of fetching the missing files. 
Never load that in-process: retry over HTTP, @@ -1495,4 +1536,5 @@ def snapshot_download_with_xet_fallback( grace_period = grace_period, on_status = on_status, prepare_for_http_fn = prepare_for_http_fn, + variant = variant, ) From b8cc019fe50dbcf85c9976c571e99ce02c8baf08 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 1 Jul 2026 05:38:57 +0000 Subject: [PATCH 55/82] Honor the variant check for patterned snapshot warms The post-download variant guard only fired when allow_patterns was None, so a patterned variant load (a subfolder= + variant= request, whose warm passes allow_patterns=['<subfolder>/*', ...]) fell into the patterned branch where a stale cache holding only the canonical model.safetensors satisfied _has_selected_weight. On a transient connection error the child's snapshot_download can silently return that stale cache as success, so the in-process from_pretrained(variant=...) then fetches model.<variant>.safetensors over un-killable Xet. Thread the variant into the patterned branch with a scope-aware check (_has_selected_variant_weight): a SELECTED weight must also carry the variant infix, so a partial that kept only the canonical weight in scope is retried over HTTP. A complete variant download (the in-scope model.<variant>.safetensors, single or sharded) still passes, so a good result is never looped into a stall error. This extends the existing allow_patterns-None variant guard to the patterned case with the same risk profile. 
--- tests/test_hf_xet_fallback.py | 35 +++++++++++++++++++++++++ unsloth_zoo/hf_xet_fallback.py | 48 +++++++++++++++++++++++++++++++--- 2 files changed, 79 insertions(+), 4 deletions(-) diff --git a/tests/test_hf_xet_fallback.py b/tests/test_hf_xet_fallback.py index ed45906d3..6721c02a4 100644 --- a/tests/test_hf_xet_fallback.py +++ b/tests/test_hf_xet_fallback.py @@ -2276,6 +2276,41 @@ def test_post_download_rejects_canonical_only_for_variant(tmp_path): variant = "fp16") is True +def test_post_download_rejects_patterned_canonical_only_for_variant(tmp_path): + """A PATTERNED variant load (subfolder= + variant=, so allow=['weights/*']) whose returned partial + kept only the canonical weight in scope must be rejected -- the variant check applies to the + patterned branch too, not only allow=None (Codex #829). A present in-scope variant weight is + accepted, and a complete variant download is never false-rejected.""" + snap, blob = _mk_snapshot(tmp_path, "subvar") + sub = snap / "weights" + sub.mkdir() + (sub / "model.safetensors").symlink_to(blob) # canonical only, no variant + assert xf._download_result_usable( + snap, repo_type = "model", allow_patterns = ["weights/*"], ignore_patterns = None, + variant = "fp16") is False + # The in-scope variant weight present -> complete -> accepted (no false-reject). + (sub / "model.fp16.safetensors").symlink_to(blob) + assert xf._download_result_usable( + snap, repo_type = "model", allow_patterns = ["weights/*"], ignore_patterns = None, + variant = "fp16") is True + # A sharded in-scope variant weight (dash infix) is likewise accepted. + snap2, blob2 = _mk_snapshot(tmp_path, "subvarshard") + sub2 = snap2 / "weights" + sub2.mkdir() + (sub2 / "model.fp16-00001-of-00002.safetensors").symlink_to(blob2) + assert xf._download_result_usable( + snap2, repo_type = "model", allow_patterns = ["weights/*"], ignore_patterns = None, + variant = "fp16") is True + # An out-of-scope variant weight does NOT satisfy an in-scope variant request. 
+ snap3, blob3 = _mk_snapshot(tmp_path, "subvaroos") + (snap3 / "model.fp16.safetensors").symlink_to(blob3) # at root, but request scopes to weights/ + (snap3 / "weights").mkdir() + (snap3 / "weights" / "model.safetensors").symlink_to(blob3) + assert xf._download_result_usable( + snap3, repo_type = "model", allow_patterns = ["weights/*"], ignore_patterns = None, + variant = "fp16") is False + + def test_post_download_rejects_incomplete_sharded_glob(tmp_path): """A globbed weight request (allow=['*.safetensors']) whose returned partial has a canonical shard index but is missing a shard must be rejected -- globs get the same shard-completeness check as the diff --git a/unsloth_zoo/hf_xet_fallback.py b/unsloth_zoo/hf_xet_fallback.py index d968e720e..1943d927a 100644 --- a/unsloth_zoo/hf_xet_fallback.py +++ b/unsloth_zoo/hf_xet_fallback.py @@ -1083,6 +1083,34 @@ def _has_selected_weight( return bool(_filter_paths(weights, allow_patterns, ignore_patterns)) +def _has_selected_variant_weight( + snapshot_dir: Path, *, allow_patterns: Any, ignore_patterns: Any, variant: str, +) -> bool: + """True if a SELECTED loadable weight carrying the *variant* token is present. Combines the request's + allow / ignore scope (as ``_has_selected_weight``) with the variant infix check (as + ``_root_has_variant_weight``): a patterned variant load (e.g. ``subfolder=`` + ``variant=``) whose + offline-fallback partial kept only the canonical weight in scope is retried over HTTP rather than + loaded, else the in-process load fetches ``model.<variant>.safetensors`` over un-killable Xet.""" + infix_dot = f".{variant}." 
+ infix_dash = f".{variant}-" + weights: list = [] + try: + for entry in snapshot_dir.rglob("*"): + name = entry.name + if not _is_loadable_weight_file(name): + continue + if infix_dot not in name and infix_dash not in name: + continue + try: + if entry.is_file(): + weights.append(entry.relative_to(snapshot_dir).as_posix()) + except (OSError, ValueError): + continue + except OSError: + return False + return bool(_filter_paths(weights, allow_patterns, ignore_patterns)) + + def _patterns_are_exact_names(patterns: Any) -> bool: """True only for a non-empty allow list of EXACT filenames (no ``None`` / glob / trailing-slash dir). Only such a request is locally provable complete; ``None`` / a glob needs the Hub manifest.""" @@ -1142,10 +1170,11 @@ def _download_result_usable( - A dangling REQUESTED symlink (a missing / still-``.incomplete`` blob). - A missing EXACT-named requested file (grouped by weight equivalence: the either-format pair needs one; base AND adapter, or a ``["tokenizer.json"]`` request, each). Globs stay lenient. - - A weight-bearing MODEL request with no usable weight. A variant load needs a variant-named ROOT - weight (the canonical weight a partial kept does not satisfy it). UNPATTERNED -> a ROOT-readable - weight the load reads (ignore filter applied; a stale ``checkpoint-7/``-only snapshot does not - count) with a complete canonical shard set. Patterned -> a weight WITHIN scope, shard set complete.""" + - A weight-bearing MODEL request with no usable weight. A variant load needs a variant-named weight + (the canonical weight a partial kept does not satisfy it): a ROOT one when UNPATTERNED, else one + WITHIN scope. UNPATTERNED non-variant -> a ROOT-readable weight the load reads (ignore filter + applied; a stale ``checkpoint-7/``-only snapshot does not count) with a complete canonical shard + set. 
Patterned non-variant -> a weight WITHIN scope, shard set complete.""" if snapshot_has_requested_broken_symlinks( snapshot_dir, allow_patterns = allow_patterns, ignore_patterns = ignore_patterns, repo_type = repo_type, @@ -1168,6 +1197,17 @@ def _download_result_usable( return False if _has_incomplete_canonical_root_shards(snapshot_dir): return False + elif variant: + # Patterned variant load (e.g. subfolder= + variant=): a selected weight is not enough -- + # a partial that kept only the canonical weight in scope would leave the load to fetch the + # requested variant over un-killable Xet. Require a selected weight carrying the variant. + if not _has_selected_variant_weight( + snapshot_dir, allow_patterns = allow_patterns, + ignore_patterns = ignore_patterns, variant = variant, + ): + return False + if _has_incomplete_canonical_root_shards(snapshot_dir): + return False else: # Patterned weight request: a selected weight must be present AND a selected canonical shard # set must be complete (a lone ``model-00001-of-0000N`` without its index / remaining shards From 18b35d61320a5b215aa023cb7ebf1e7ec9017820 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 1 Jul 2026 05:53:23 +0000 Subject: [PATCH 56/82] Preserve errno across the spawn boundary for OSError subclasses _raise_child_error only preserved errno for exact OSError, so once the reconstruction map began resolving OSError subclasses (PermissionError, FileNotFoundError, hf_hub's LocalEntryNotFoundError), their errno was dropped and a caller's except OSError cleanup that branches on exc.errno stopped matching. Reconstruct every deterministic error through the robust _instantiate_preserving_type path and, for any OSError result missing an errno, parse it from the message and set it as an attribute. The attribute set works for every subclass, including one whose __init__ rejects the (errno, strerror) form, and keeps the message from being double-prefixed with [Errno N]. 
Also harmonize the SIGKILL fallback sentinel in _terminate_process_group: the force-kill call site passed getattr(signal, "SIGKILL", signal.SIGTERM) while the branch inside _signal_group compares against getattr(signal, "SIGKILL", -9), so on Windows (no signal.SIGKILL) the two defaults diverged and the second attempt re-took the terminate() branch. Functionally moot on Windows (multiprocessing maps proc.kill() to proc.terminate() to TerminateProcess, a hard kill either way), but the call site now matches the check. --- tests/test_hf_xet_fallback.py | 33 +++++++++++++++++++++++++++++++++ unsloth_zoo/hf_xet_fallback.py | 26 ++++++++++++++++++-------- 2 files changed, 51 insertions(+), 8 deletions(-) diff --git a/tests/test_hf_xet_fallback.py b/tests/test_hf_xet_fallback.py index 6721c02a4..c6f726f85 100644 --- a/tests/test_hf_xet_fallback.py +++ b/tests/test_hf_xet_fallback.py @@ -1964,6 +1964,39 @@ def test_oserror_errno_preserved(monkeypatch): assert len(fake.calls) == 1, "a deterministic error must not trigger an HTTP fallback" +def test_oserror_subclass_errno_preserved(monkeypatch): + """An OSError SUBCLASS (PermissionError from an unwritable cache) keeps BOTH its type AND its errno + across the spawn boundary, so a caller branching on exc.errno still matches (Gemini #829). 
Errno is + set as an attribute, so it survives even for a subclass whose constructor rejects (errno, strerror); + the message is not double-prefixed with the errno.""" + import errno as _errno + + fake = _install(monkeypatch, [("error", "PermissionError: [Errno 13] Permission denied")]) + with pytest.raises(PermissionError) as excinfo: + xf.hf_hub_download_with_xet_fallback(DL_REPO, FILE, None) + assert excinfo.value.errno == _errno.EACCES + assert "[Errno 13] [Errno 13]" not in str(excinfo.value) + assert len(fake.calls) == 1, "a deterministic error must not trigger an HTTP fallback" + + +def test_raise_child_error_sets_errno_via_attribute(): + """_raise_child_error preserves errno on an OSError subclass even when its __init__ takes a single + positional (like hf_hub's LocalEntryNotFoundError), where the (errno, strerror) constructor Gemini + proposed would raise TypeError.""" + class _SingleArgOSError(OSError): + def __init__(self, message): # rejects the two-arg (errno, strerror) form + super().__init__(message) + + orig = xf._resolve_exception_class + try: + xf._resolve_exception_class = lambda name: _SingleArgOSError if name == "LocalEntryNotFoundError" else orig(name) + with pytest.raises(_SingleArgOSError) as excinfo: + xf._raise_child_error("LocalEntryNotFoundError: [Errno 2] No such file or directory") + assert excinfo.value.errno == 2 + finally: + xf._resolve_exception_class = orig + + # --------------------------------------------------------------------------- # # Spawn-safety regressions: failed-spawn queue cleanup + disable-Xet env-race lock. 
# --------------------------------------------------------------------------- # diff --git a/unsloth_zoo/hf_xet_fallback.py b/unsloth_zoo/hf_xet_fallback.py index 1943d927a..61e85278b 100644 --- a/unsloth_zoo/hf_xet_fallback.py +++ b/unsloth_zoo/hf_xet_fallback.py @@ -557,16 +557,22 @@ def _raise_child_error(message: str) -> None: exc_cls = _resolve_exception_class(type_name) if exc_cls is None: raise RuntimeError(message) - if exc_cls is OSError: - # Preserve errno (ENOSPC / EDQUOT ...) so a caller's `except OSError` cleanup can still - # branch on exc.errno; OSError(message) alone would leave errno = None. - errno_val = _parse_errno(message) - if errno_val is not None: - raise OSError(errno_val, message) - raise OSError(message) exc = _instantiate_preserving_type(exc_cls, message) if exc is None: raise RuntimeError(message) + if isinstance(exc, OSError) and getattr(exc, "errno", None) is None: + # Preserve errno (ENOSPC / EDQUOT ...) across the spawn boundary so a caller's `except OSError` + # cleanup can still branch on exc.errno -- for EVERY OSError subclass (PermissionError, + # FileNotFoundError, hf_hub's LocalEntryNotFoundError ...), not just exact OSError. Set it as an + # attribute rather than via the (errno, strerror) constructor: a subclass with a single-arg + # __init__ (LocalEntryNotFoundError) rejects the two-arg form, and this keeps the message clean + # (no doubled "[Errno N]" prefix). + errno_val = _parse_errno(message) + if errno_val is not None: + try: + exc.errno = errno_val + except Exception: + pass raise exc @@ -723,7 +729,11 @@ def _signal_group(sig: int) -> None: # into an unrelated group -- killpg(pid) would then signal the WRONG group. hf_xet 1.5.x spawns no # helpers, so a reaped leader leaves nothing to clean up. 
if proc.is_alive(): - _signal_group(getattr(signal, "SIGKILL", signal.SIGTERM)) + # Match _signal_group's own SIGKILL sentinel (-9) so the force-kill branch (proc.kill()) is + # taken on Windows, where signal.SIGKILL is undefined. Functionally moot there (multiprocessing + # maps proc.kill() == proc.terminate() == TerminateProcess, a hard kill either way), but keeps + # the call site and the check consistent. + _signal_group(getattr(signal, "SIGKILL", -9)) proc.join(timeout = 5.0) From 8722146e4113f645a00458d3d2c91edb1cf0bd0f Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 1 Jul 2026 06:27:17 +0000 Subject: [PATCH 57/82] Scope the post-download shard-completeness gate to what the load reads A fresh review found the canonical-root-shard gate in the post-download check was request-agnostic, format-blind, and variant-blind, causing three defects: - POST false-reject (High): a genuinely complete patterned download (adapter / gguf / subfolder) was force-failed into a DownloadStallError when an unrelated partial canonical base shard set (a leftover from a prior interrupted base pull) was co-resident at root. The gate scanned the whole root regardless of what the request selected. Now the patterned branch applies it only when the request actually selects canonical root shards (a globbed weight request), and the variant-patterned branch drops it entirely (the load reads variant weights, and any co-resident canonical shard set is out of scope). - POST false-accept (Low): an unpatterned load that ignores safetensors (so it reads .bin) accepted a snapshot whose .bin shard set was incomplete but whose safetensors set was complete, because the gate called snapshot_dir_is_complete with no ignore filter, so the complete safetensors masked the incomplete .bin the load reads. The gate now threads the request's ignore filter through. 
- POST false-accept (Low): an unpatterned variant load accepted an incomplete variant shard set (a variant index present with a listed shard missing). Add a positive-evidence-only variant-shard check: a single-file variant or a complete variant shard set is never rejected, so a complete variant download is not false-rejected. Also harden two items the review surfaced: - Restrict the errno-preservation in _raise_child_error to BUILTIN OSError types. An HF HTTP error (HfHubHTTPError / RepositoryNotFoundError ...) subclasses OSError via requests -> IOError, so isinstance(exc, OSError) matched it; a bracketed [Errno N] in such a message would have been mistaken for a real OS errno. Genuine OS errors (disk full, permission, file-not-found) are builtins and keep their errno. - Reject an absolute or parent-escaping shard path in a weight index before the base / shard existence probe, so a malformed or crafted index cannot make a file outside the snapshot read as a present shard. --- tests/test_hf_xet_fallback.py | 113 ++++++++++++++++++++++++++++++--- unsloth_zoo/hf_cache_state.py | 41 ++++++++++-- unsloth_zoo/hf_xet_fallback.py | 68 ++++++++++++++------ 3 files changed, 189 insertions(+), 33 deletions(-) diff --git a/tests/test_hf_xet_fallback.py b/tests/test_hf_xet_fallback.py index c6f726f85..649ec6d0f 100644 --- a/tests/test_hf_xet_fallback.py +++ b/tests/test_hf_xet_fallback.py @@ -1979,20 +1979,30 @@ def test_oserror_subclass_errno_preserved(monkeypatch): assert len(fake.calls) == 1, "a deterministic error must not trigger an HTTP fallback" -def test_raise_child_error_sets_errno_via_attribute(): - """_raise_child_error preserves errno on an OSError subclass even when its __init__ takes a single - positional (like hf_hub's LocalEntryNotFoundError), where the (errno, strerror) constructor Gemini - proposed would raise TypeError.""" - class _SingleArgOSError(OSError): - def __init__(self, message): # rejects the two-arg (errno, strerror) form +def 
test_raise_child_error_errno_only_for_builtin_oserror(): + """errno is preserved only for a BUILTIN OSError type (a real OS errno), set via attribute so it + survives a builtin whose __init__ rejects the (errno, strerror) form. A NON-builtin OSError subclass + -- an HF HTTP error subclasses OSError via requests -> IOError -- with a bracketed [Errno N] in its + message must NOT get a spurious errno (#829 re-review).""" + # Builtin OSError subclass -> errno preserved. + with pytest.raises(FileNotFoundError) as excinfo: + xf._raise_child_error("FileNotFoundError: [Errno 2] No such file or directory") + assert excinfo.value.errno == 2 + + # A non-builtin OSError subclass (simulating HfHubHTTPError) whose message merely contains a + # bracketed [Errno N] must NOT have it mistaken for a real OS errno. + class _FakeHubHTTPError(OSError): + def __init__(self, message): # single-arg, like hf_hub's error types super().__init__(message) orig = xf._resolve_exception_class try: - xf._resolve_exception_class = lambda name: _SingleArgOSError if name == "LocalEntryNotFoundError" else orig(name) - with pytest.raises(_SingleArgOSError) as excinfo: - xf._raise_child_error("LocalEntryNotFoundError: [Errno 2] No such file or directory") - assert excinfo.value.errno == 2 + xf._resolve_exception_class = ( + lambda name: _FakeHubHTTPError if name == "HfHubHTTPError" else orig(name) + ) + with pytest.raises(_FakeHubHTTPError) as excinfo2: + xf._raise_child_error("HfHubHTTPError: 500 Server Error [Errno 111] for url https://x") + assert excinfo2.value.errno is None finally: xf._resolve_exception_class = orig @@ -2361,6 +2371,89 @@ def test_post_download_rejects_incomplete_sharded_glob(tmp_path): snap, repo_type = "model", allow_patterns = ["*.safetensors"], ignore_patterns = None) is True +def test_post_download_accepts_patterned_with_coresident_partial_canonical_shards(tmp_path): + """A COMPLETE patterned download (adapter / gguf / subfolder) whose selected weight the load reads is + 
present must be ACCEPTED even when an UNRELATED partial canonical base shard set is co-resident at + root (a leftover from a prior interrupted base pull). The canonical-shard gate is request-agnostic; + scoping it to requests that actually select canonical root shards avoids failing a working download + into a DownloadStallError (#829 re-review).""" + def _partial_base_shards(snap, blob): + (snap / "model-00001-of-00002.safetensors").symlink_to(blob) # shard 1 present + (snap / "model-00002-of-00002.safetensors").symlink_to(snap / "MISSING") # dangling shard 2 + (snap / "model.safetensors.index.json").write_text(json.dumps( + {"weight_map": {"a": "model-00001-of-00002.safetensors", + "b": "model-00002-of-00002.safetensors"}})) + # Adapter request completes; co-resident partial base shards must not force-reject it. + snap, blob = _mk_snapshot(tmp_path, "adapter_coresident") + (snap / "adapter_model.safetensors").symlink_to(blob) + (snap / "adapter_config.json").write_text("{}") + _partial_base_shards(snap, blob) + assert xf._download_result_usable( + snap, repo_type = "model", + allow_patterns = ["adapter_model.safetensors", "adapter_config.json", "*.json"], + ignore_patterns = None) is True + # gguf request completes; same co-resident partial base shards. + snap2, blob2 = _mk_snapshot(tmp_path, "gguf_coresident") + (snap2 / "model.Q4_K_M.gguf").symlink_to(blob2) + (snap2 / "config.json").write_text("{}") + _partial_base_shards(snap2, blob2) + assert xf._download_result_usable( + snap2, repo_type = "model", + allow_patterns = ["model.Q4_K_M.gguf", "config.json", "*.json"], + ignore_patterns = None) is True + # A globbed weight request that DOES select canonical root shards still gets the completeness gate. 
+ snap3, blob3 = _mk_snapshot(tmp_path, "glob_still_gated") + _partial_base_shards(snap3, blob3) + assert xf._download_result_usable( + snap3, repo_type = "model", allow_patterns = ["*.safetensors"], ignore_patterns = None) is False + + +def test_post_download_rejects_incomplete_ignored_format_shards(tmp_path): + """An unpatterned load that ignores safetensors (so it reads .bin) whose returned partial has a + COMPLETE safetensors shard set but an INCOMPLETE .bin set must be rejected -- the shard gate applies + the ignore filter, so the complete safetensors does not mask the incomplete .bin the load actually + reads (else the in-process load finishes the missing .bin shard over Xet) (#829 re-review).""" + snap, blob = _mk_snapshot(tmp_path, "ignored_format_shards") + (snap / "model-00001-of-00002.safetensors").symlink_to(blob) + (snap / "model-00002-of-00002.safetensors").symlink_to(blob) + (snap / "model.safetensors.index.json").write_text(json.dumps( + {"weight_map": {"a": "model-00001-of-00002.safetensors", + "b": "model-00002-of-00002.safetensors"}})) + (snap / "pytorch_model-00001-of-00002.bin").symlink_to(blob) # bin shard 1 only, no index/shard 2 + assert xf._download_result_usable( + snap, repo_type = "model", allow_patterns = None, + ignore_patterns = ["*.safetensors"]) is False + # Ignoring the .bin instead (load reads the complete safetensors) -> accepted. + assert xf._download_result_usable( + snap, repo_type = "model", allow_patterns = None, ignore_patterns = ["*.bin"]) is True + + +def test_post_download_rejects_incomplete_variant_shards(tmp_path): + """An unpatterned variant load whose returned partial has a variant shard INDEX but is missing a + listed variant shard must be rejected, else the in-process load finishes the missing variant shard + over Xet (#829 re-review). 
Positive-evidence only: a COMPLETE variant shard set and a SINGLE-FILE + variant are both accepted (a complete variant download is never false-rejected).""" + snap, blob = _mk_snapshot(tmp_path, "variant_incomplete") + (snap / "model.fp16-00001-of-00002.safetensors").symlink_to(blob) # shard 1; shard 2 absent + (snap / "model.safetensors.index.fp16.json").write_text(json.dumps( + {"weight_map": {"a": "model.fp16-00001-of-00002.safetensors", + "b": "model.fp16-00002-of-00002.safetensors"}})) + assert xf._download_result_usable( + snap, repo_type = "model", allow_patterns = None, ignore_patterns = None, + variant = "fp16") is False + # The missing variant shard present -> complete set -> accepted (no false-reject). + (snap / "model.fp16-00002-of-00002.safetensors").symlink_to(blob) + assert xf._download_result_usable( + snap, repo_type = "model", allow_patterns = None, ignore_patterns = None, + variant = "fp16") is True + # A single-file variant (no index) is accepted. + snap2, blob2 = _mk_snapshot(tmp_path, "variant_single") + (snap2 / "model.fp16.safetensors").symlink_to(blob2) + assert xf._download_result_usable( + snap2, repo_type = "model", allow_patterns = None, ignore_patterns = None, + variant = "fp16") is True + + def test_post_download_accepts_dataset_without_weight(tmp_path): snap, blob = _mk_snapshot(tmp_path, "ds") (snap / "data.parquet").symlink_to(blob) diff --git a/unsloth_zoo/hf_cache_state.py b/unsloth_zoo/hf_cache_state.py index 07aa9c6da..778967c3c 100644 --- a/unsloth_zoo/hf_cache_state.py +++ b/unsloth_zoo/hf_cache_state.py @@ -265,6 +265,11 @@ def _weight_shard_index_complete(index_path: Path) -> bool: return False base = index_path.parent for shard in shards: + # A well-formed HF index lists a relative shard basename. Reject an absolute / parent-escaping + # value (a malformed or crafted index) rather than let ``base / shard`` resolve to an unrelated + # existing file OUTSIDE the snapshot and read as "present". 
+ if shard.startswith(("/", "\\")) or ".." in shard.replace("\\", "/").split("/"): + return False try: if not (base / shard).exists(): return False @@ -543,18 +548,44 @@ def snapshot_dir_is_complete( ) -def _has_incomplete_canonical_root_shards(snapshot_dir: Path) -> bool: +def _has_incomplete_canonical_root_shards( + snapshot_dir: Path, *, ignore_patterns: "Optional[object]" = None +) -> bool: """True when the root holds canonical numbered shards but is NOT a complete canonical model (index - missing or a shard absent) -- a stale interrupted download a default load cannot read, so the - post-download check rejects it and retries over HTTP. Variant shards are excluded, so a - variant-only repo is not force-failed here.""" + missing or a shard absent) for the format the request READS -- a stale interrupted download a + default load cannot read, so the post-download check rejects it and retries over HTTP. The request's + ignore filter is applied, so a complete safetensors set does not mask an incomplete ``.bin`` set the + load reads under ``ignore=['*.safetensors']``. Variant shards are excluded (their names carry a + ``.-`` infix), so a variant-only repo is not force-failed here.""" try: names = [entry.name for entry in snapshot_dir.iterdir()] except OSError: return False if not any(_CANONICAL_ROOT_SHARD_RE.match(name) for name in names): return False - return not snapshot_dir_is_complete(snapshot_dir) + return not snapshot_dir_is_complete(snapshot_dir, ignore_patterns = ignore_patterns) + + +def _has_incomplete_variant_root_shards(snapshot_dir: Path, variant: str) -> bool: + """True when the root holds a VARIANT weight shard index whose set is incomplete (a listed shard + missing). Positive-evidence ONLY: a single-file variant (no index) or a complete variant shard set + returns False, so a complete or single-file variant download is never rejected. 
transformers writes + the variant index with the variant token before ``.json`` (``model.safetensors.index..json`` + / ``pytorch_model.bin.index..json``), so it carries both the shard-index marker and the + ``..`` infix.""" + infix = f".{variant}." + try: + entries = list(snapshot_dir.iterdir()) + except OSError: + return False + for entry in entries: + name = entry.name + # _is_weight_shard_index matches canonical AND variant indices; the infix restricts to the + # requested variant (the canonical index has no ".." token). + if infix in name and _is_weight_shard_index(name): + if _safe_is_file(entry) and not _weight_shard_index_complete(entry): + return True + return False def requested_named_files_present( diff --git a/unsloth_zoo/hf_xet_fallback.py b/unsloth_zoo/hf_xet_fallback.py index 61e85278b..d1e949973 100644 --- a/unsloth_zoo/hf_xet_fallback.py +++ b/unsloth_zoo/hf_xet_fallback.py @@ -47,6 +47,7 @@ _filter_paths, _has_glob, _has_incomplete_canonical_root_shards, + _has_incomplete_variant_root_shards, _is_loadable_weight_file, blob_bytes_present, has_active_incomplete_blobs, @@ -548,6 +549,17 @@ def _parse_errno(message: str) -> "Optional[int]": return None +def _is_builtin_oserror(exc: BaseException) -> bool: + """True iff *exc*'s type is a BUILTIN ``OSError`` (or subclass): a genuine OS-level error whose + ``[Errno N]`` is a real errno. 
Excludes HF/requests HTTP errors, which subclass ``OSError`` via + ``requests -> IOError`` yet carry no OS errno, so a bracketed ``[Errno N]`` in their message is not + mistaken for one.""" + if not isinstance(exc, OSError): + return False + builtin = getattr(builtins, type(exc).__name__, None) + return isinstance(builtin, type) and issubclass(builtin, OSError) and isinstance(exc, builtin) + + def _raise_child_error(message: str) -> None: """Re-raise a deterministic child error preserving its original TYPE when it is a known Hub / OS error, so callers catching ``RepositoryNotFoundError`` / ``GatedRepoError`` / ``OSError`` still @@ -560,13 +572,15 @@ def _raise_child_error(message: str) -> None: exc = _instantiate_preserving_type(exc_cls, message) if exc is None: raise RuntimeError(message) - if isinstance(exc, OSError) and getattr(exc, "errno", None) is None: + if _is_builtin_oserror(exc) and getattr(exc, "errno", None) is None: # Preserve errno (ENOSPC / EDQUOT ...) across the spawn boundary so a caller's `except OSError` - # cleanup can still branch on exc.errno -- for EVERY OSError subclass (PermissionError, - # FileNotFoundError, hf_hub's LocalEntryNotFoundError ...), not just exact OSError. Set it as an - # attribute rather than via the (errno, strerror) constructor: a subclass with a single-arg - # __init__ (LocalEntryNotFoundError) rejects the two-arg form, and this keeps the message clean - # (no doubled "[Errno N]" prefix). + # cleanup can still branch on exc.errno -- for EVERY builtin OSError subclass (PermissionError, + # FileNotFoundError, ...), not just exact OSError. Restricted to BUILTIN OSError types: an HF + # HTTP error (HfHubHTTPError / RepositoryNotFoundError ...) is ALSO an OSError subclass (via + # requests -> IOError), and a bracketed "[Errno N]" in its message must not be mistaken for a + # real OS errno. 
Set it as an attribute rather than via the (errno, strerror) constructor: a + # subclass with a single-arg __init__ (hf_hub's LocalEntryNotFoundError) rejects the two-arg + # form, and this keeps the message clean (no doubled "[Errno N]" prefix). errno_val = _parse_errno(message) if errno_val is not None: try: @@ -1132,6 +1146,16 @@ def _patterns_are_exact_names(patterns: Any) -> bool: return all(isinstance(p, str) and not _has_glob(p) for p in patterns) +def _request_selects_canonical_root_shards(allow_patterns: Any, ignore_patterns: Any) -> bool: + """Whether the request's allow / ignore filter keeps a canonical ROOT shard name. When False, an + incomplete canonical root shard set is OUT of the request's scope -- a co-resident leftover from a + prior interrupted base pull that a patterned load (adapter / gguf / subfolder) never reads -- so the + canonical-shard-completeness gate must NOT reject on it, else a genuinely complete patterned download + is failed into a DownloadStallError.""" + probes = ["model-00001-of-00002.safetensors", "pytorch_model-00001-of-00002.bin"] + return bool(_filter_paths(probes, allow_patterns, ignore_patterns)) + + def _cache_can_skip_download( snapshot_dir: Path, *, repo_type: str, allow_patterns: Any, ignore_patterns: Any, variant: Optional[str] = None, @@ -1197,36 +1221,44 @@ def _download_result_usable( if repo_type in (None, "model") and request_can_include_weights(allow_patterns, ignore_patterns): if allow_patterns is None and variant: # Variant root load: a partial that kept only the canonical weight would leave the load to - # fetch the requested variant over un-killable Xet -> require a variant-named root weight. + # fetch the requested variant over un-killable Xet -> require a variant-named root weight, + # and reject an incomplete variant shard set (index present, a listed variant shard missing). 
if not _root_has_variant_weight(snapshot_dir, variant): return False + if _has_incomplete_variant_root_shards(snapshot_dir, variant): + return False elif allow_patterns is None: # Default root load: a root (or diffusers-component) weight the load reads (ignore filter - # applied), sharded set complete. + # applied), with the canonical shard set complete for the format the load READS (ignore + # filter applied, so a complete safetensors set does not mask an incomplete ``.bin`` the load + # reads under ignore=['*.safetensors']). if not _root_model_has_weight(snapshot_dir, ignore_patterns = ignore_patterns): return False - if _has_incomplete_canonical_root_shards(snapshot_dir): + if _has_incomplete_canonical_root_shards(snapshot_dir, ignore_patterns = ignore_patterns): return False elif variant: - # Patterned variant load (e.g. subfolder= + variant=): a selected weight is not enough -- - # a partial that kept only the canonical weight in scope would leave the load to fetch the - # requested variant over un-killable Xet. Require a selected weight carrying the variant. + # Patterned variant load (e.g. subfolder= + variant=): require a SELECTED weight carrying the + # variant -- a partial that kept only the canonical weight in scope would leave the load to + # fetch the requested variant over un-killable Xet. The canonical-root-shard gate does NOT + # apply here: the load reads variant weights, and any co-resident canonical root shard set is + # out of scope (checking it would false-reject a complete variant download). 
if not _has_selected_variant_weight( snapshot_dir, allow_patterns = allow_patterns, ignore_patterns = ignore_patterns, variant = variant, ): return False - if _has_incomplete_canonical_root_shards(snapshot_dir): - return False else: - # Patterned weight request: a selected weight must be present AND a selected canonical shard - # set must be complete (a lone ``model-00001-of-0000N`` without its index / remaining shards - # is a partial the in-process load would finish over Xet). + # Patterned weight request: a selected weight must be present AND -- only when the request + # SELECTS canonical root shards (a globbed weight request, not an adapter / gguf / subfolder + # request whose co-resident canonical shards it never reads) -- that shard set must be + # complete (a lone ``model-00001-of-0000N`` without its index / remaining shards is a partial + # the in-process load would finish over Xet). if not _has_selected_weight( snapshot_dir, allow_patterns = allow_patterns, ignore_patterns = ignore_patterns ): return False - if _has_incomplete_canonical_root_shards(snapshot_dir): + if _request_selects_canonical_root_shards(allow_patterns, ignore_patterns) \ + and _has_incomplete_canonical_root_shards(snapshot_dir): return False return True From e24641dec2f5220951001bee2cbf3e279d777fec Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 1 Jul 2026 06:39:06 +0000 Subject: [PATCH 58/82] Tighten the variant-shard and exact-subset post-download checks Two follow-ups to the post-download completeness gate: - A variant load accepted an incomplete SHARDED variant even when its index was absent: _has_incomplete_variant_root_shards only fired when a variant shard index was present with a listed shard missing, so a lone model.{variant}-00001-of-000NN with no index (and no remaining shards) passed.
It now fires whenever a variant weight SHARD is present without a COMPLETE variant shard index, so a partial that never fetched the index is retried rather than handed to the in-process load, which would fetch the missing variant shards over un-killable Xet. Still positive-evidence only: a single-file variant or a complete variant shard set is never rejected. - An EXACT-named shard request (allow_patterns=["model-00001-of-00002.safetensors"]) was force-rejected once its sibling shard / index was absent, even though the caller asked for precisely that one file and it was present, so the wrapper could retry and finally raise DownloadStallError on a satisfied request. The canonical-shard-completeness gate now applies only to GLOBBED weight warms where a later load expects the whole sharded checkpoint, not to an exact-named subset (whose presence is already verified by the exact-files check). --- tests/test_hf_xet_fallback.py | 31 +++++++++++++++++++++++++++++++ unsloth_zoo/hf_cache_state.py | 30 ++++++++++++++++++------------ unsloth_zoo/hf_xet_fallback.py | 14 ++++++++------ 3 files changed, 57 insertions(+), 18 deletions(-) diff --git a/tests/test_hf_xet_fallback.py b/tests/test_hf_xet_fallback.py index 649ec6d0f..cddd9f1fa 100644 --- a/tests/test_hf_xet_fallback.py +++ b/tests/test_hf_xet_fallback.py @@ -2312,8 +2312,14 @@ def test_post_download_rejects_canonical_only_for_variant(tmp_path): assert xf._download_result_usable( snap, repo_type = "model", allow_patterns = None, ignore_patterns = None, variant = "fp16") is True + # A COMPLETE sharded variant set (shards + index) is accepted; an incomplete one is retried + # (covered by test_post_download_rejects_incomplete_variant_shards). 
snap2, blob2 = _mk_snapshot(tmp_path, "varshard") (snap2 / "model.fp16-00001-of-00002.safetensors").symlink_to(blob2) + (snap2 / "model.fp16-00002-of-00002.safetensors").symlink_to(blob2) + (snap2 / "model.safetensors.index.fp16.json").write_text(json.dumps( + {"weight_map": {"a": "model.fp16-00001-of-00002.safetensors", + "b": "model.fp16-00002-of-00002.safetensors"}})) assert xf._download_result_usable( snap2, repo_type = "model", allow_patterns = None, ignore_patterns = None, variant = "fp16") is True @@ -2441,6 +2447,12 @@ def test_post_download_rejects_incomplete_variant_shards(tmp_path): assert xf._download_result_usable( snap, repo_type = "model", allow_patterns = None, ignore_patterns = None, variant = "fp16") is False + # A lone variant shard with NO index (a partial that never fetched the index) is also incomplete. + snap_noidx, blob_ni = _mk_snapshot(tmp_path, "variant_no_index") + (snap_noidx / "model.fp16-00001-of-00002.safetensors").symlink_to(blob_ni) + assert xf._download_result_usable( + snap_noidx, repo_type = "model", allow_patterns = None, ignore_patterns = None, + variant = "fp16") is False # The missing variant shard present -> complete set -> accepted (no false-reject). (snap / "model.fp16-00002-of-00002.safetensors").symlink_to(blob) assert xf._download_result_usable( @@ -2454,6 +2466,25 @@ def test_post_download_rejects_incomplete_variant_shards(tmp_path): variant = "fp16") is True +def test_post_download_accepts_exact_named_shard_subset(tmp_path): + """A caller naming an EXACT shard file (allow=['model-00001-of-00002.safetensors']) asked for + precisely that file; once it is present the result is accepted, even though its sibling shard / index + is absent -- the whole-checkpoint completeness gate applies only to GLOBBED weight warms, not an + exact-named subset (else a satisfied request is failed into a DownloadStallError) (#829 re-review). 
+ A named shard that is ABSENT is still rejected by the exact-files check.""" + snap, blob = _mk_snapshot(tmp_path, "exact_shard_present") + (snap / "model-00001-of-00002.safetensors").symlink_to(blob) + assert xf._download_result_usable( + snap, repo_type = "model", + allow_patterns = ["model-00001-of-00002.safetensors"], ignore_patterns = None) is True + # The exact-named shard absent -> rejected (nothing to load). + snap2, _ = _mk_snapshot(tmp_path, "exact_shard_absent") + (snap2 / "config.json").write_text("{}") + assert xf._download_result_usable( + snap2, repo_type = "model", + allow_patterns = ["model-00001-of-00002.safetensors"], ignore_patterns = None) is False + + def test_post_download_accepts_dataset_without_weight(tmp_path): snap, blob = _mk_snapshot(tmp_path, "ds") (snap / "data.parquet").symlink_to(blob) diff --git a/unsloth_zoo/hf_cache_state.py b/unsloth_zoo/hf_cache_state.py index 778967c3c..5166789b0 100644 --- a/unsloth_zoo/hf_cache_state.py +++ b/unsloth_zoo/hf_cache_state.py @@ -567,25 +567,31 @@ def _has_incomplete_canonical_root_shards( def _has_incomplete_variant_root_shards(snapshot_dir: Path, variant: str) -> bool: - """True when the root holds a VARIANT weight shard index whose set is incomplete (a listed shard - missing). Positive-evidence ONLY: a single-file variant (no index) or a complete variant shard set - returns False, so a complete or single-file variant download is never rejected. transformers writes - the variant index with the variant token before ``.json`` (``model.safetensors.index..json`` - / ``pytorch_model.bin.index..json``), so it carries both the shard-index marker and the - ``..`` infix.""" - infix = f".{variant}." + """True when the root holds a VARIANT weight SHARD (a ``.-NNNNN-of-...`` file) that is NOT + backed by a COMPLETE variant shard index -- the index is missing, or one of its listed shards is + absent. 
Positive-evidence ONLY: a single-file variant (no shard files) or a complete variant shard + set returns False, so a complete or single-file variant download is never rejected. transformers + writes a sharded variant weight with a ``.-`` infix and its index as + ``model.safetensors.index..json`` (a ``..`` infix before ``.json``).""" + dot_infix = f".{variant}." # the variant shard index: model.safetensors.index..json + dash_infix = f".{variant}-" # a sharded variant weight: model.-00001-of-00002.safetensors try: entries = list(snapshot_dir.iterdir()) except OSError: return False + has_variant_shard = False + has_complete_variant_index = False for entry in entries: name = entry.name - # _is_weight_shard_index matches canonical AND variant indices; the infix restricts to the + # _is_weight_shard_index matches canonical AND variant indices; the dot infix restricts to the # requested variant (the canonical index has no ".." token). - if infix in name and _is_weight_shard_index(name): - if _safe_is_file(entry) and not _weight_shard_index_complete(entry): - return True - return False + if dot_infix in name and _is_weight_shard_index(name): + if _safe_is_file(entry) and _weight_shard_index_complete(entry): + has_complete_variant_index = True + elif dash_infix in name and _is_loadable_weight_file(name): + if _safe_is_file(entry): + has_variant_shard = True + return has_variant_shard and not has_complete_variant_index def requested_named_files_present( diff --git a/unsloth_zoo/hf_xet_fallback.py b/unsloth_zoo/hf_xet_fallback.py index d1e949973..eb210093d 100644 --- a/unsloth_zoo/hf_xet_fallback.py +++ b/unsloth_zoo/hf_xet_fallback.py @@ -1248,16 +1248,18 @@ def _download_result_usable( ): return False else: - # Patterned weight request: a selected weight must be present AND -- only when the request - # SELECTS canonical root shards (a globbed weight request, not an adapter / gguf / subfolder - # request whose co-resident canonical shards it never reads) -- that shard set 
must be - # complete (a lone ``model-00001-of-0000N`` without its index / remaining shards is a partial - # the in-process load would finish over Xet). + # Patterned weight request: a selected weight must be present AND -- only for a GLOBBED + # request that SELECTS canonical root shards (so a later load expects the whole sharded + # checkpoint), not an adapter / gguf / subfolder request whose co-resident canonical shards + # it never reads, and not an EXACT-named request that asked for precisely those files -- the + # canonical shard set must be complete (a lone ``model-00001-of-0000N`` without its index / + # remaining shards is a partial the in-process load would finish over Xet). if not _has_selected_weight( snapshot_dir, allow_patterns = allow_patterns, ignore_patterns = ignore_patterns ): return False - if _request_selects_canonical_root_shards(allow_patterns, ignore_patterns) \ + if not _patterns_are_exact_names(allow_patterns) \ + and _request_selects_canonical_root_shards(allow_patterns, ignore_patterns) \ and _has_incomplete_canonical_root_shards(snapshot_dir): return False return True From 3cf8a063516c0a37cecdea8bc941c8513c04b269 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 1 Jul 2026 07:38:15 +0000 Subject: [PATCH 59/82] Defer an incomplete preferred index, and complete patterned-variant shards Three follow-ups to the cache-completeness checks: - A cache with an incomplete safetensors index (a listed shard missing) plus a complete pytorch_model.bin was read as complete, because a complete single weight was accepted after the broken safetensors index failed its own check. But a default transformers load probes the safetensors index BEFORE the bin, so it would fetch the missing safetensors shards over un-killable Xet. Treat a present-but-incomplete preferred (safetensors) index as breakage and defer to the watched child, unless safetensors is explicitly ignored (then the load reads the bin). 
- The patterned variant branch of the post-download check required only one selected variant weight and skipped the variant shard-completeness check, so a globbed variant warm (allow=['*.safetensors'] with variant=) accepted a lone root variant shard with no index / missing shards. Apply the same _has_incomplete_variant_root_shards check there; a root-only check never false-rejects a complete or subfolder-scoped variant download. - A weight shard index mapping a tensor to a non-string value (for example null) was read as complete once the remaining string-mapped shards were present, by silently filtering the bad entry. transformers cannot load such an index, so fail closed and defer to the child instead. --- tests/test_hf_xet_fallback.py | 51 ++++++++++++++++++++++++++++++++++ unsloth_zoo/hf_cache_state.py | 29 +++++++++++++------ unsloth_zoo/hf_xet_fallback.py | 15 ++++++---- 3 files changed, 81 insertions(+), 14 deletions(-) diff --git a/tests/test_hf_xet_fallback.py b/tests/test_hf_xet_fallback.py index cddd9f1fa..0911b2f1e 100644 --- a/tests/test_hf_xet_fallback.py +++ b/tests/test_hf_xet_fallback.py @@ -2141,6 +2141,37 @@ def test_gate_fast_paths_canonical_sharded_with_index(tmp_path): assert hcs.snapshot_dir_is_complete(snap2) is False +def test_shard_index_with_non_string_value_is_incomplete(tmp_path): + """A malformed shard index mapping a tensor to a non-string value (e.g. 
null) is NOT complete even + when the remaining string-mapped shard is present -- transformers cannot load it, so fail closed and + defer to the watched child rather than silently dropping the bad entry (Codex #829).""" + snap, blob = _mk_snapshot(tmp_path, "badindex") + (snap / "model-00001-of-00002.safetensors").symlink_to(blob) + (snap / "model.safetensors.index.json").write_text( + json.dumps({"weight_map": {"a": "model-00001-of-00002.safetensors", "b": None}})) + assert hcs._weight_shard_index_complete(snap / "model.safetensors.index.json") is False + assert hcs.snapshot_dir_is_complete(snap) is False + + +def test_gate_defers_incomplete_preferred_index_masked_by_complete_bin(tmp_path): + """A present-but-incomplete safetensors index must not be masked by a complete pytorch_model.bin: + transformers probes the safetensors index BEFORE the bin, so the load would fetch the missing + safetensors shards over un-killable Xet. The gate defers unless safetensors is explicitly ignored + (then the load reads the bin) (Codex #829).""" + snap, blob = _mk_snapshot(tmp_path, "prefidx") + (snap / "model-00001-of-00002.safetensors").symlink_to(blob) # ST shard 2 absent -> incomplete index + (snap / "model.safetensors.index.json").write_text( + json.dumps({"weight_map": {"a": "model-00001-of-00002.safetensors", + "b": "model-00002-of-00002.safetensors"}})) + (snap / "pytorch_model.bin").symlink_to(blob) # complete bin co-resident + assert hcs.snapshot_dir_is_complete(snap) is False # load prefers the incomplete safetensors + # safetensors explicitly ignored -> the load reads the complete bin -> eligible. + assert hcs.snapshot_dir_is_complete(snap, ignore_patterns = ["*.safetensors"]) is True + # A COMPLETE safetensors index alongside the bin is eligible. 
+ (snap / "model-00002-of-00002.safetensors").symlink_to(blob) + assert hcs.snapshot_dir_is_complete(snap) is True + + def test_gate_rejects_sharded_adapter_only_root_cache(tmp_path): """A complete sharded ADAPTER at the root (adapter_model.safetensors.index.json + its shards) is NOT a canonical base model: only model/pytorch_model index names gate the fast path. A base+adapter @@ -2485,6 +2516,26 @@ def test_post_download_accepts_exact_named_shard_subset(tmp_path): allow_patterns = ["model-00001-of-00002.safetensors"], ignore_patterns = None) is False +def test_post_download_rejects_patterned_incomplete_variant_shards(tmp_path): + """A GLOBBED variant request (allow=['*.safetensors'] + variant='fp16') whose partial kept only a + lone root variant shard without its index / remaining shards must be rejected too -- the + variant-shard completeness check applies to the patterned variant branch, not only allow=None (Codex + #829). A complete root variant shard set in scope is accepted (no false-reject).""" + snap, blob = _mk_snapshot(tmp_path, "pat_var_incomplete") + (snap / "model.fp16-00001-of-00002.safetensors").symlink_to(blob) # lone shard, no index + assert xf._download_result_usable( + snap, repo_type = "model", allow_patterns = ["*.safetensors"], ignore_patterns = None, + variant = "fp16") is False + # Complete variant shard set -> accepted. 
+ (snap / "model.fp16-00002-of-00002.safetensors").symlink_to(blob) + (snap / "model.safetensors.index.fp16.json").write_text(json.dumps( + {"weight_map": {"a": "model.fp16-00001-of-00002.safetensors", + "b": "model.fp16-00002-of-00002.safetensors"}})) + assert xf._download_result_usable( + snap, repo_type = "model", allow_patterns = ["*.safetensors"], ignore_patterns = None, + variant = "fp16") is True + + def test_post_download_accepts_dataset_without_weight(tmp_path): snap, blob = _mk_snapshot(tmp_path, "ds") (snap / "data.parquet").symlink_to(blob) diff --git a/unsloth_zoo/hf_cache_state.py b/unsloth_zoo/hf_cache_state.py index 5166789b0..82127965d 100644 --- a/unsloth_zoo/hf_cache_state.py +++ b/unsloth_zoo/hf_cache_state.py @@ -259,10 +259,13 @@ def _weight_shard_index_complete(index_path: Path) -> bool: weight_map = data.get("weight_map") if isinstance(data, dict) else None if not isinstance(weight_map, dict): return False - # A non-string value is unhashable (breaks set()) and invalid for ``base / shard``; filter first. - shards = {s for s in weight_map.values() if isinstance(s, str)} - if not shards: + values = list(weight_map.values()) + # A non-string shard value is a malformed index transformers cannot load; fail CLOSED (defer to the + # watched child) rather than silently dropping the bad entry and reading the remaining shards as a + # complete set. + if not values or not all(isinstance(s, str) for s in values): return False + shards = set(values) base = index_path.parent for shard in shards: # A well-formed HF index lists a relative shard basename. Reject an absolute / parent-escaping @@ -496,14 +499,22 @@ def _format_kept(weight_name: str) -> bool: return True return bool(_filter_paths([weight_name], None, ignore_patterns)) + incomplete_preferred_index = False for index_entry in root_indices: - fmt_probe = ( - "model.safetensors" - if ".safetensors.index." 
in index_entry.name - else "pytorch_model.bin" - ) - if _format_kept(fmt_probe) and _weight_shard_index_complete(index_entry): + is_safetensors = ".safetensors.index." in index_entry.name + fmt_probe = "model.safetensors" if is_safetensors else "pytorch_model.bin" + if not _format_kept(fmt_probe): + continue # this format is ignored -> the load will not read it + if _weight_shard_index_complete(index_entry): return True + if is_safetensors: + # transformers probes the safetensors index BEFORE the .bin, so a present-but-incomplete + # safetensors index means the load prefers (and fetches) safetensors -- a complete .bin must + # NOT mask it. Treat it as breakage (defer to the watched child) unless safetensors is + # explicitly ignored (handled by _format_kept above). + incomplete_preferred_index = True + if incomplete_preferred_index: + return False return any( name in root_files and _format_kept(name) for name in _CANONICAL_SINGLE_WEIGHTS ) diff --git a/unsloth_zoo/hf_xet_fallback.py b/unsloth_zoo/hf_xet_fallback.py index eb210093d..e49b53362 100644 --- a/unsloth_zoo/hf_xet_fallback.py +++ b/unsloth_zoo/hf_xet_fallback.py @@ -1237,16 +1237,21 @@ def _download_result_usable( if _has_incomplete_canonical_root_shards(snapshot_dir, ignore_patterns = ignore_patterns): return False elif variant: - # Patterned variant load (e.g. subfolder= + variant=): require a SELECTED weight carrying the - # variant -- a partial that kept only the canonical weight in scope would leave the load to - # fetch the requested variant over un-killable Xet. The canonical-root-shard gate does NOT - # apply here: the load reads variant weights, and any co-resident canonical root shard set is - # out of scope (checking it would false-reject a complete variant download). + # Patterned variant load (e.g. 
allow=['*.safetensors'] or subfolder= with variant=): require + # a SELECTED weight carrying the variant -- a partial that kept only the canonical weight in + # scope would leave the load to fetch the requested variant over un-killable Xet. Also reject + # an incomplete ROOT variant shard set (a lone shard with no index / a missing shard), same as + # the unpatterned variant branch. The CANONICAL-root-shard gate does not apply (the load reads + # variant weights; a co-resident canonical shard set is out of scope). A variant shard set in + # a subfolder is not root-checked, but a root-only check never false-rejects a complete + # download. if not _has_selected_variant_weight( snapshot_dir, allow_patterns = allow_patterns, ignore_patterns = ignore_patterns, variant = variant, ): return False + if _has_incomplete_variant_root_shards(snapshot_dir, variant): + return False else: # Patterned weight request: a selected weight must be present AND -- only for a GLOBBED # request that SELECTS canonical root shards (so a later load expects the whole sharded From bf243870d777f36db09fc43d13577ca67978c675 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 1 Jul 2026 08:09:31 +0000 Subject: [PATCH 60/82] Apply ignore filters and request scope consistently in the acceptance check Five consistency fixes to the post-download acceptance check, each closing a false-accept (silent Xet hang) or a false-reject (DownloadStallError on a working download): - Diffusers component weights are now filtered through the request's ignore list, so an unpatterned diffusers warm that ignores a format is not satisfied by a component weight in that ignored format (a unet/*.bin under ignore=['*.bin']). 
- The canonical-shard completeness gate now treats an index-only partial (a canonical shard index present with none of its shards yet) as incomplete, even when a complete pytorch_model.bin is co-resident: transformers probes the safetensors index before the bin, so the load would fetch the missing shards. - The patterned-request canonical-shard check now applies the ignore filter, so allow=['*'] + ignore=['*.safetensors'] (which selects the complete .bin) is not rejected because a stale incomplete safetensors shard set is co-resident. - The root variant-shard check is now scoped to requests that select a root variant weight: a subfolder variant request (allow=['unet/*']) with a complete in-scope weight is no longer rejected by an unrelated stale root variant shard. - The unpatterned root variant weight probe now applies the ignore list, so a partial holding only model.fp16.bin under ignore=['*.bin'] does not satisfy a safetensors variant request. --- tests/test_hf_xet_fallback.py | 83 ++++++++++++++++++++++++++++++++++ unsloth_zoo/hf_cache_state.py | 9 +++- unsloth_zoo/hf_xet_fallback.py | 76 +++++++++++++++++++++---------- 3 files changed, 144 insertions(+), 24 deletions(-) diff --git a/tests/test_hf_xet_fallback.py b/tests/test_hf_xet_fallback.py index 0911b2f1e..1b730aa1e 100644 --- a/tests/test_hf_xet_fallback.py +++ b/tests/test_hf_xet_fallback.py @@ -2536,6 +2536,89 @@ def test_post_download_rejects_patterned_incomplete_variant_shards(tmp_path): variant = "fp16") is True +def test_post_download_applies_ignore_to_diffusers_components(tmp_path): + """An unpatterned diffusers warm that ignores a format must not be satisfied by a component weight in + that ignored format: only unet/*.bin present under ignore=['*.bin'] (safetensors requested) is + rejected, else the load fetches the missing safetensors components over Xet (Codex #829).""" + snap, blob = _mk_snapshot(tmp_path, "diff_ignore") + (snap / "model_index.json").write_text("{}") + (snap / "unet").mkdir() + 
(snap / "unet" / "diffusion_pytorch_model.bin").symlink_to(blob) + assert xf._download_result_usable( + snap, repo_type = "model", allow_patterns = None, ignore_patterns = ["*.bin"]) is False + # The safetensors component present -> usable. + (snap / "unet" / "diffusion_pytorch_model.safetensors").symlink_to(blob) + assert xf._download_result_usable( + snap, repo_type = "model", allow_patterns = None, ignore_patterns = ["*.bin"]) is True + + +def test_post_download_rejects_index_only_sharded_masked_by_bin(tmp_path): + """A safetensors index present with NONE of its shards (an index-only partial), co-resident with a + complete pytorch_model.bin, must be rejected: transformers probes the safetensors index before the + bin, so the load would fetch the missing shards over Xet (Codex #829). The shard-completeness gate + fires on a present index even before any shard file exists.""" + snap, blob = _mk_snapshot(tmp_path, "index_only") + (snap / "model.safetensors.index.json").write_text(json.dumps( + {"weight_map": {"a": "model-00001-of-00002.safetensors", + "b": "model-00002-of-00002.safetensors"}})) + (snap / "pytorch_model.bin").symlink_to(blob) # complete bin, no ST shards at all + assert xf._download_result_usable( + snap, repo_type = "model", allow_patterns = None, ignore_patterns = None) is False + # safetensors explicitly ignored -> load reads the complete bin -> usable. 
+ assert xf._download_result_usable( + snap, repo_type = "model", allow_patterns = None, ignore_patterns = ["*.safetensors"]) is True + + +def test_post_download_patterned_shard_check_honors_ignore(tmp_path): + """A patterned request that ignores safetensors (allow=['*'], ignore=['*.safetensors']) selects the + complete .bin; a co-resident incomplete safetensors shard set must NOT force-reject it -- the + patterned shard-completeness check applies the ignore filter, so a satisfied request is not failed + into a DownloadStallError (Codex #829).""" + snap, blob = _mk_snapshot(tmp_path, "pat_ignore") + (snap / "model-00001-of-00002.safetensors").symlink_to(blob) # incomplete ST (shard 2 absent) + (snap / "model.safetensors.index.json").write_text(json.dumps( + {"weight_map": {"a": "model-00001-of-00002.safetensors", + "b": "model-00002-of-00002.safetensors"}})) + (snap / "pytorch_model.bin").symlink_to(blob) # complete bin, the selected format + assert xf._download_result_usable( + snap, repo_type = "model", allow_patterns = ["*"], ignore_patterns = ["*.safetensors"]) is True + + +def test_post_download_variant_root_shard_check_scoped_to_selection(tmp_path): + """A subfolder variant request (allow=['unet/*'] + variant) whose selected weight is complete must be + accepted even when a stale ROOT variant shard (out of scope) is co-resident -- the root variant-shard + check applies only when the request selects a root variant weight (Codex #829).""" + snap, blob = _mk_snapshot(tmp_path, "var_scope") + (snap / "unet").mkdir() + (snap / "unet" / "model.fp16.safetensors").symlink_to(blob) # complete in-scope variant + (snap / "model.fp16-00001-of-00002.safetensors").symlink_to(blob) # stale ROOT variant shard, oos + assert xf._download_result_usable( + snap, repo_type = "model", allow_patterns = ["unet/*"], ignore_patterns = None, + variant = "fp16") is True + # A GLOBBED root variant request DOES get the root variant-shard check. 
+ snap2, blob2 = _mk_snapshot(tmp_path, "var_scope_glob") + (snap2 / "model.fp16-00001-of-00002.safetensors").symlink_to(blob2) # lone root variant shard + assert xf._download_result_usable( + snap2, repo_type = "model", allow_patterns = ["*.safetensors"], ignore_patterns = None, + variant = "fp16") is False + + +def test_post_download_root_variant_weight_honors_ignore(tmp_path): + """An unpatterned variant load that ignores .bin must not be satisfied by a variant .bin: only + model.fp16.bin present under ignore=['*.bin'] is rejected, else the load fetches + model.fp16.safetensors over Xet (Codex #829).""" + snap, blob = _mk_snapshot(tmp_path, "var_ignore") + (snap / "model.fp16.bin").symlink_to(blob) # only the ignored-format variant + assert xf._download_result_usable( + snap, repo_type = "model", allow_patterns = None, ignore_patterns = ["*.bin"], + variant = "fp16") is False + # The safetensors variant present -> usable. + (snap / "model.fp16.safetensors").symlink_to(blob) + assert xf._download_result_usable( + snap, repo_type = "model", allow_patterns = None, ignore_patterns = ["*.bin"], + variant = "fp16") is True + + def test_post_download_accepts_dataset_without_weight(tmp_path): snap, blob = _mk_snapshot(tmp_path, "ds") (snap / "data.parquet").symlink_to(blob) diff --git a/unsloth_zoo/hf_cache_state.py b/unsloth_zoo/hf_cache_state.py index 82127965d..69bc93622 100644 --- a/unsloth_zoo/hf_cache_state.py +++ b/unsloth_zoo/hf_cache_state.py @@ -572,7 +572,14 @@ def _has_incomplete_canonical_root_shards( names = [entry.name for entry in snapshot_dir.iterdir()] except OSError: return False - if not any(_CANONICAL_ROOT_SHARD_RE.match(name) for name in names): + # Canonical shard evidence = a numbered shard FILE, or a canonical shard INDEX. An index-only + # partial (index present, no shards yet) is still an incomplete sharded checkpoint the load would + # finish over Xet, so it must be caught here even before any shard file exists. 
+ has_shard_evidence = ( + any(_CANONICAL_ROOT_SHARD_RE.match(name) for name in names) + or any(_is_canonical_weight_shard_index(name) for name in names) + ) + if not has_shard_evidence: return False return not snapshot_dir_is_complete(snapshot_dir, ignore_patterns = ignore_patterns) diff --git a/unsloth_zoo/hf_xet_fallback.py b/unsloth_zoo/hf_xet_fallback.py index e49b53362..4b24d102a 100644 --- a/unsloth_zoo/hf_xet_fallback.py +++ b/unsloth_zoo/hf_xet_fallback.py @@ -976,20 +976,25 @@ def _intact_subset( ) -def _has_any_weight(snapshot_dir: Path) -> bool: - """True if the snapshot holds at least one loadable weight anywhere (root or subfolder). Lenient: - it only tells a real model warm from a config-only stale snapshot, without classifying layout.""" +def _has_any_weight(snapshot_dir: Path, *, ignore_patterns: Any = None) -> bool: + """True if the snapshot holds at least one loadable weight anywhere (root or subfolder) that the + request's ignore filter keeps. Lenient: it only tells a real model warm from a config-only stale + snapshot, without classifying layout. 
The ignore filter matters for diffusers, whose component + weights live in subfolders -- a partial holding only the ignored format (``unet/*.bin`` under + ``ignore=['*.bin']``) is not a usable weight for a safetensors load.""" + rels: list = [] try: for entry in snapshot_dir.rglob("*"): - if _is_loadable_weight_file(entry.name): - try: - if entry.is_file(): - return True - except OSError: - continue + if not _is_loadable_weight_file(entry.name): + continue + try: + if entry.is_file(): + rels.append(entry.relative_to(snapshot_dir).as_posix()) + except (OSError, ValueError): + continue except OSError: return False - return False + return bool(_filter_paths(rels, None, ignore_patterns)) def _root_model_has_weight(snapshot_dir: Path, *, ignore_patterns: Any = None) -> bool: @@ -1005,7 +1010,7 @@ def _root_model_has_weight(snapshot_dir: Path, *, ignore_patterns: Any = None) - except OSError: is_diffusers = False if is_diffusers: - return _has_any_weight(snapshot_dir) + return _has_any_weight(snapshot_dir, ignore_patterns = ignore_patterns) rels: list = [] try: for entry in snapshot_dir.iterdir(): @@ -1021,13 +1026,18 @@ def _root_model_has_weight(snapshot_dir: Path, *, ignore_patterns: Any = None) - return bool(_filter_paths(rels, None, ignore_patterns)) -def _root_has_variant_weight(snapshot_dir: Path, variant: str) -> bool: - """True if a ROOT weight carrying the requested *variant* token is present. transformers inserts the - variant before the extension (a ``.<variant>.`` infix: ``model.fp16.safetensors``) or before a shard - suffix (a ``.<variant>-`` infix: ``model.fp16-00001-of-00002.safetensors``), so an offline-fallback - partial that kept only the canonical weight does not satisfy a variant request.""" +def _root_has_variant_weight( + snapshot_dir: Path, variant: str, *, ignore_patterns: Any = None +) -> bool: + """True if a ROOT weight carrying the requested *variant* token, and kept by the ignore filter, is + present.
transformers inserts the variant before the extension (a ``.<variant>.`` infix: + ``model.fp16.safetensors``) or before a shard suffix (a ``.<variant>-`` infix: + ``model.fp16-00001-of-00002.safetensors``), so an offline-fallback partial that kept only the + canonical weight does not satisfy a variant request. The ignore filter is applied so a partial + holding only the ignored format (``model.fp16.bin`` under ``ignore=['*.bin']``) does not count.""" infix_dot = f".{variant}." infix_dash = f".{variant}-" + rels: list = [] try: for entry in snapshot_dir.iterdir(): name = entry.name @@ -1037,12 +1047,12 @@ def _root_has_variant_weight(snapshot_dir: Path, variant: str) -> bool: continue try: if entry.is_file(): - return True + rels.append(name) except OSError: continue except OSError: return False - return False + return bool(_filter_paths(rels, None, ignore_patterns)) # Interchangeable exact weight names: the either-format ``["pytorch_model.bin", "model.safetensors"]`` @@ -1156,6 +1166,20 @@ def _request_selects_canonical_root_shards(allow_patterns: Any, ignore_patterns: return bool(_filter_paths(probes, allow_patterns, ignore_patterns)) +def _request_selects_root_variant_weight( + allow_patterns: Any, ignore_patterns: Any, variant: str, +) -> bool: + """Whether the request's allow / ignore filter keeps a ROOT variant weight name. When False, a stale + incomplete root variant shard set is OUT of the request's scope (e.g.
a subfolder request + ``allow=['unet/*']`` whose variant weights live under ``unet/``), so the ROOT variant-shard gate must + not reject on it, else a complete in-scope variant download is failed.""" + probes = [ + f"model.{variant}.safetensors", f"model.{variant}-00001-of-00002.safetensors", + f"pytorch_model.{variant}.bin", f"pytorch_model.{variant}-00001-of-00002.bin", + ] + return bool(_filter_paths(probes, allow_patterns, ignore_patterns)) + + def _cache_can_skip_download( snapshot_dir: Path, *, repo_type: str, allow_patterns: Any, ignore_patterns: Any, variant: Optional[str] = None, @@ -1223,7 +1247,9 @@ def _download_result_usable( # Variant root load: a partial that kept only the canonical weight would leave the load to # fetch the requested variant over un-killable Xet -> require a variant-named root weight, # and reject an incomplete variant shard set (index present, a listed variant shard missing). - if not _root_has_variant_weight(snapshot_dir, variant): + if not _root_has_variant_weight( + snapshot_dir, variant, ignore_patterns = ignore_patterns + ): return False if _has_incomplete_variant_root_shards(snapshot_dir, variant): return False @@ -1243,14 +1269,17 @@ def _download_result_usable( # an incomplete ROOT variant shard set (a lone shard with no index / a missing shard), same as # the unpatterned variant branch. The CANONICAL-root-shard gate does not apply (the load reads # variant weights; a co-resident canonical shard set is out of scope). A variant shard set in - # a subfolder is not root-checked, but a root-only check never false-rejects a complete - # download. + # a subfolder is not root-checked. The ROOT variant-shard check applies only when the request + # SELECTS a root variant weight (a globbed allow like ['*.safetensors']); for a subfolder + # request (allow=['unet/*']) a stale root variant shard is out of scope and must not + # false-reject a complete in-scope download. 
if not _has_selected_variant_weight( snapshot_dir, allow_patterns = allow_patterns, ignore_patterns = ignore_patterns, variant = variant, ): return False - if _has_incomplete_variant_root_shards(snapshot_dir, variant): + if _request_selects_root_variant_weight(allow_patterns, ignore_patterns, variant) \ + and _has_incomplete_variant_root_shards(snapshot_dir, variant): return False else: # Patterned weight request: a selected weight must be present AND -- only for a GLOBBED @@ -1265,7 +1294,8 @@ def _download_result_usable( return False if not _patterns_are_exact_names(allow_patterns) \ and _request_selects_canonical_root_shards(allow_patterns, ignore_patterns) \ - and _has_incomplete_canonical_root_shards(snapshot_dir): + and _has_incomplete_canonical_root_shards( + snapshot_dir, ignore_patterns = ignore_patterns): return False return True From 428b5c5dd9e68d1eb15bd3af12d25487ea5f1d9d Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 1 Jul 2026 08:17:45 +0000 Subject: [PATCH 61/82] Funnel the weight acceptance check through one helper with two invariants The post-download weight-bearing model check had grown to four parallel branches (unpatterned/patterned x variant/plain), each independently responsible for applying the request's ignore filter and matching the request's scope. Recent rounds fixed several branches that had forgotten one or the other (an ignored format counted as a usable weight; a co-resident out-of-scope shard set false-rejected a complete download). Refactor the four branches into one entry point, _selected_readable_weight_complete, built from two orthogonal, uniformly-applied invariants: - _has_readable_weight (presence): the weight the load will READ is on disk, with the ignore filter always applied and the scope matched to the request (root variant / selected variant / root-or-diffusers / selected). 
- _readable_shard_set_incomplete (completeness): an in-scope shard set the load reads is incomplete, always scoped to what the request selects and ignore-aware, so a co-resident out-of-scope or ignored-format partial cannot false-reject. _download_result_usable's weight check is now a single delegation. Behavior is unchanged: the full test suite, the layout safety fuzz (0 invariant violations across 40k layouts), and the end-to-end recovery sim all pass identically. The two invariants now live in one place each, so a future gap is closed once rather than per branch. --- tests/test_hf_xet_fallback.py | 24 ++++++ unsloth_zoo/hf_xet_fallback.py | 153 ++++++++++++++++++++------------- 2 files changed, 118 insertions(+), 59 deletions(-) diff --git a/tests/test_hf_xet_fallback.py b/tests/test_hf_xet_fallback.py index 1b730aa1e..9f3cf1097 100644 --- a/tests/test_hf_xet_fallback.py +++ b/tests/test_hf_xet_fallback.py @@ -2619,6 +2619,30 @@ def test_post_download_root_variant_weight_honors_ignore(tmp_path): variant = "fp16") is True +def test_selected_readable_weight_complete_entry_point(tmp_path): + """The weight-bearing acceptance check funnels through one helper enforcing two invariants: + (A) a readable weight is present (ignore + scope applied), (B) its in-scope shard set is complete. + Directly exercise the entry point for a present+complete, an absent, and an incomplete-shard case.""" + # Present + complete single weight -> True. + snap, blob = _mk_snapshot(tmp_path, "srwc_ok") + (snap / "model.safetensors").symlink_to(blob) + assert xf._selected_readable_weight_complete( + snap, allow_patterns = None, ignore_patterns = None, variant = None) is True + # Invariant A fails: only an ignored-format weight present -> False. 
+ snap2, blob2 = _mk_snapshot(tmp_path, "srwc_ignored") + (snap2 / "pytorch_model.bin").symlink_to(blob2) + assert xf._selected_readable_weight_complete( + snap2, allow_patterns = None, ignore_patterns = ["*.bin"], variant = None) is False + # Invariant B fails: readable weight present but its shard set incomplete -> False. + snap3, blob3 = _mk_snapshot(tmp_path, "srwc_incomplete") + (snap3 / "model-00001-of-00002.safetensors").symlink_to(blob3) + (snap3 / "model.safetensors.index.json").write_text(json.dumps( + {"weight_map": {"a": "model-00001-of-00002.safetensors", + "b": "model-00002-of-00002.safetensors"}})) + assert xf._selected_readable_weight_complete( + snap3, allow_patterns = None, ignore_patterns = None, variant = None) is False + + def test_post_download_accepts_dataset_without_weight(tmp_path): snap, blob = _mk_snapshot(tmp_path, "ds") (snap / "data.parquet").symlink_to(blob) diff --git a/unsloth_zoo/hf_xet_fallback.py b/unsloth_zoo/hf_xet_fallback.py index 4b24d102a..7f5de0730 100644 --- a/unsloth_zoo/hf_xet_fallback.py +++ b/unsloth_zoo/hf_xet_fallback.py @@ -1213,6 +1213,90 @@ def _cache_can_skip_download( ) +def _has_readable_weight( + snapshot_dir: Path, *, allow_patterns: Any, ignore_patterns: Any, variant: Optional[str], +) -> bool: + """Invariant A (presence): a weight the in-process load will READ is present on disk, with the + request's ignore filter ALWAYS applied and the scope matched to the request: + + - variant + UNPATTERNED -> a ROOT variant weight (``model.<variant>.*``); + - variant + PATTERNED -> a SELECTED variant weight (within the allow scope); + - plain + UNPATTERNED -> a ROOT (or diffusers-component) weight, NOT a stray subfolder checkpoint; + - plain + PATTERNED -> a SELECTED weight (within the allow scope).
+ + A partial that kept only the ignored format (an ``*.bin`` under ``ignore=['*.bin']``) does not count, + so the incomplete result is retried over HTTP rather than loaded in-process.""" + if variant: + if allow_patterns is None: + return _root_has_variant_weight(snapshot_dir, variant, ignore_patterns = ignore_patterns) + return _has_selected_variant_weight( + snapshot_dir, allow_patterns = allow_patterns, + ignore_patterns = ignore_patterns, variant = variant, + ) + if allow_patterns is None: + return _root_model_has_weight(snapshot_dir, ignore_patterns = ignore_patterns) + return _has_selected_weight( + snapshot_dir, allow_patterns = allow_patterns, ignore_patterns = ignore_patterns + ) + + +def _readable_shard_set_incomplete( + snapshot_dir: Path, *, allow_patterns: Any, ignore_patterns: Any, variant: Optional[str], +) -> bool: + """Invariant B (shard completeness): an IN-SCOPE shard set the load reads is incomplete (an index + present with a shard missing, or a lone numbered shard without its index) and must be retried. The + check is ALWAYS scoped to what the request selects, so a co-resident stale shard set the load never + reads (a leftover root checkpoint under a subfolder/adapter/gguf request) does not false-reject a + complete download: + + - variant: the ROOT variant-shard check applies for an UNPATTERNED request, or a PATTERNED request + that selects a ROOT variant weight (a globbed ``['*.safetensors']``); a subfolder-scoped variant + request does not root-check. + - plain: the canonical-root-shard check applies for an UNPATTERNED request, or a GLOBBED request that + selects canonical root shards; an exact-named subset or an out-of-scope request does not. 
+ + The ignore filter is threaded through so completeness is judged for the FORMAT the load reads (a + complete safetensors set does not mask an incomplete ``.bin`` under ``ignore=['*.safetensors']``).""" + if variant: + if allow_patterns is None or _request_selects_root_variant_weight( + allow_patterns, ignore_patterns, variant + ): + return _has_incomplete_variant_root_shards(snapshot_dir, variant) + return False + if allow_patterns is None: + return _has_incomplete_canonical_root_shards( + snapshot_dir, ignore_patterns = ignore_patterns + ) + if not _patterns_are_exact_names(allow_patterns) and _request_selects_canonical_root_shards( + allow_patterns, ignore_patterns + ): + return _has_incomplete_canonical_root_shards( + snapshot_dir, ignore_patterns = ignore_patterns + ) + return False + + +def _selected_readable_weight_complete( + snapshot_dir: Path, *, allow_patterns: Any, ignore_patterns: Any, variant: Optional[str], +) -> bool: + """Single entry point for the weight-bearing MODEL acceptance check: the weight the in-process load + will READ is present (Invariant A) AND its in-scope shard set is complete (Invariant B). 
Both + invariants apply the request's ignore filter and match its scope uniformly, so a co-resident + out-of-scope / ignored-format partial neither masks an incomplete readable weight (a silent Xet hang) + nor false-rejects a complete download (a spurious ``DownloadStallError``).""" + if not _has_readable_weight( + snapshot_dir, allow_patterns = allow_patterns, + ignore_patterns = ignore_patterns, variant = variant, + ): + return False + if _readable_shard_set_incomplete( + snapshot_dir, allow_patterns = allow_patterns, + ignore_patterns = ignore_patterns, variant = variant, + ): + return False + return True + + def _download_result_usable( snapshot_dir: Path, *, repo_type: str, allow_patterns: Any, ignore_patterns: Any, variant: Optional[str] = None, @@ -1228,11 +1312,11 @@ def _download_result_usable( - A dangling REQUESTED symlink (a missing / still-``.incomplete`` blob). - A missing EXACT-named requested file (grouped by weight equivalence: the either-format pair needs one; base AND adapter, or a ``["tokenizer.json"]`` request, each). Globs stay lenient. - - A weight-bearing MODEL request with no usable weight. A variant load needs a variant-named weight - (the canonical weight a partial kept does not satisfy it): a ROOT one when UNPATTERNED, else one - WITHIN scope. UNPATTERNED non-variant -> a ROOT-readable weight the load reads (ignore filter - applied; a stale ``checkpoint-7/``-only snapshot does not count) with a complete canonical shard - set. Patterned non-variant -> a weight WITHIN scope, shard set complete.""" + - A weight-bearing MODEL request whose READABLE weight is absent or incomplete. Delegated to + ``_selected_readable_weight_complete``, which applies the request's ignore filter and scope + uniformly: the weight the load reads (variant vs canonical, root vs in-scope) must be present, and + its in-scope shard set complete. 
A co-resident out-of-scope / ignored-format partial neither masks + an incomplete readable weight nor false-rejects a complete download.""" if snapshot_has_requested_broken_symlinks( snapshot_dir, allow_patterns = allow_patterns, ignore_patterns = ignore_patterns, repo_type = repo_type, @@ -1243,60 +1327,11 @@ def _download_result_usable( ): return False if repo_type in (None, "model") and request_can_include_weights(allow_patterns, ignore_patterns): - if allow_patterns is None and variant: - # Variant root load: a partial that kept only the canonical weight would leave the load to - # fetch the requested variant over un-killable Xet -> require a variant-named root weight, - # and reject an incomplete variant shard set (index present, a listed variant shard missing). - if not _root_has_variant_weight( - snapshot_dir, variant, ignore_patterns = ignore_patterns - ): - return False - if _has_incomplete_variant_root_shards(snapshot_dir, variant): - return False - elif allow_patterns is None: - # Default root load: a root (or diffusers-component) weight the load reads (ignore filter - # applied), with the canonical shard set complete for the format the load READS (ignore - # filter applied, so a complete safetensors set does not mask an incomplete ``.bin`` the load - # reads under ignore=['*.safetensors']). - if not _root_model_has_weight(snapshot_dir, ignore_patterns = ignore_patterns): - return False - if _has_incomplete_canonical_root_shards(snapshot_dir, ignore_patterns = ignore_patterns): - return False - elif variant: - # Patterned variant load (e.g. allow=['*.safetensors'] or subfolder= with variant=): require - # a SELECTED weight carrying the variant -- a partial that kept only the canonical weight in - # scope would leave the load to fetch the requested variant over un-killable Xet. Also reject - # an incomplete ROOT variant shard set (a lone shard with no index / a missing shard), same as - # the unpatterned variant branch. 
The CANONICAL-root-shard gate does not apply (the load reads - # variant weights; a co-resident canonical shard set is out of scope). A variant shard set in - # a subfolder is not root-checked. The ROOT variant-shard check applies only when the request - # SELECTS a root variant weight (a globbed allow like ['*.safetensors']); for a subfolder - # request (allow=['unet/*']) a stale root variant shard is out of scope and must not - # false-reject a complete in-scope download. - if not _has_selected_variant_weight( - snapshot_dir, allow_patterns = allow_patterns, - ignore_patterns = ignore_patterns, variant = variant, - ): - return False - if _request_selects_root_variant_weight(allow_patterns, ignore_patterns, variant) \ - and _has_incomplete_variant_root_shards(snapshot_dir, variant): - return False - else: - # Patterned weight request: a selected weight must be present AND -- only for a GLOBBED - # request that SELECTS canonical root shards (so a later load expects the whole sharded - # checkpoint), not an adapter / gguf / subfolder request whose co-resident canonical shards - # it never reads, and not an EXACT-named request that asked for precisely those files -- the - # canonical shard set must be complete (a lone ``model-00001-of-0000N`` without its index / - # remaining shards is a partial the in-process load would finish over Xet). 
- if not _has_selected_weight( - snapshot_dir, allow_patterns = allow_patterns, ignore_patterns = ignore_patterns - ): - return False - if not _patterns_are_exact_names(allow_patterns) \ - and _request_selects_canonical_root_shards(allow_patterns, ignore_patterns) \ - and _has_incomplete_canonical_root_shards( - snapshot_dir, ignore_patterns = ignore_patterns): - return False + if not _selected_readable_weight_complete( + snapshot_dir, allow_patterns = allow_patterns, + ignore_patterns = ignore_patterns, variant = variant, + ): + return False return True From 1b877ee6ac675f4dc203c7f2b0e2464f9bd9f2f6 Mon Sep 17 00:00:00 2001 From: danielhanchen Date: Wed, 1 Jul 2026 09:02:22 +0000 Subject: [PATCH 62/82] Complete variant and non-root shard-set checks in the post-download gate The post-download completeness check judged the ROOT variant weight and the canonical root shards, but three cases it reads were not fully covered, so a partial the child returned on a transient connection error could either be force-rejected (a good download looped into DownloadStallError) or accepted while incomplete (the in-process load then finishes the missing shard over un-killable Xet). - _has_incomplete_variant_root_shards now threads the request ignore filter, so a stale variant shard in an ignored format (a leftover model.fp16-00001-of-... .bin under ignore=['*.bin']) no longer force-rejects a complete variant safetensors the load actually reads. - It also fires on a present-but-incomplete variant safetensors INDEX even when no variant shard file exists yet (an index-only partial), and treats the variant safetensors as read before the variant bin: a stale variant safetensors index co-resident with a complete variant bin is breakage, because transformers probes the safetensors index first and would fetch the missing variant shards. This mirrors the canonical index-only / precedence handling. 
- A new _selected_shard_index_incomplete covers the shard indices the root-model checks do not: a sharded PEFT adapter (adapter_model.safetensors.index.json under allow=['adapter_model*']) and a component subfolder. It is scoped to the request (variant token, allow/ignore on the listed shards) and applies safetensors-before-bin precedence per directory, so a complete download is never false-rejected while an incomplete selected set is caught. Adds regression tests for the variant ignore-filtered set, the variant index-only partial masked by a complete bin, the incomplete sharded adapter, and the incomplete component subfolder, plus the complete-set acceptance for each. --- tests/test_hf_xet_fallback.py | 74 +++++++++++++ unsloth_zoo/hf_cache_state.py | 185 ++++++++++++++++++++++++++++++--- unsloth_zoo/hf_xet_fallback.py | 30 ++++-- 3 files changed, 266 insertions(+), 23 deletions(-) diff --git a/tests/test_hf_xet_fallback.py b/tests/test_hf_xet_fallback.py index 9f3cf1097..4313b94d8 100644 --- a/tests/test_hf_xet_fallback.py +++ b/tests/test_hf_xet_fallback.py @@ -2619,6 +2619,80 @@ def test_post_download_root_variant_weight_honors_ignore(tmp_path): variant = "fp16") is True +def test_post_download_variant_shard_check_honors_ignore(tmp_path): + """A variant load that ignores .bin must judge the variant shard set for the READ format only: a + complete model.fp16.safetensors co-resident with a stale IGNORED model.fp16-00001-of-00002.bin shard + (no index) is accepted, not force-rejected -- the variant shard-completeness check applies the ignore + filter, so a satisfied variant download is not failed into a DownloadStallError (Codex #829).""" + snap, blob = _mk_snapshot(tmp_path, "var_shard_ignore") + (snap / "model.fp16.safetensors").symlink_to(blob) # complete, the read format + (snap / "model.fp16-00001-of-00002.bin").symlink_to(blob) # stale ignored .bin shard, no index + assert xf._download_result_usable( + snap, repo_type = "model", allow_patterns = None, 
ignore_patterns = ["*.bin"], + variant = "fp16") is True + # Without the ignore it is still accepted: the stale lone .bin variant shard is never read. + assert xf._download_result_usable( + snap, repo_type = "model", allow_patterns = None, ignore_patterns = None, + variant = "fp16") is True # the complete safetensors variant is preferred and read + + +def test_post_download_rejects_variant_index_only_masked_by_bin(tmp_path): + """A VARIANT safetensors index present with NONE of its shards (an index-only partial), co-resident + with a complete variant pytorch_model.fp16.bin, must be rejected: transformers probes the variant + safetensors index before the variant bin, so the load would fetch the missing variant shards over Xet + (the variant analog of the canonical index-only case) (Codex #829).""" + snap, blob = _mk_snapshot(tmp_path, "var_index_only") + (snap / "model.safetensors.index.fp16.json").write_text(json.dumps( + {"weight_map": {"a": "model.fp16-00001-of-00002.safetensors", + "b": "model.fp16-00002-of-00002.safetensors"}})) + (snap / "pytorch_model.fp16.bin").symlink_to(blob) # complete bin, no ST variant shards at all + assert xf._download_result_usable( + snap, repo_type = "model", allow_patterns = None, ignore_patterns = None, + variant = "fp16") is False + # The variant safetensors explicitly ignored -> load reads the complete variant bin -> usable. + assert xf._download_result_usable( + snap, repo_type = "model", allow_patterns = None, ignore_patterns = ["*.safetensors"], + variant = "fp16") is True + + +def test_post_download_rejects_incomplete_sharded_adapter(tmp_path): + """A PEFT adapter load (allow=['adapter_config.json', 'adapter_model*']) whose partial kept a sharded + adapter INDEX but is missing a listed adapter shard must be rejected, else the in-process load + finishes the missing adapter shard over Xet.
The canonical/variant ROOT-model checks do not cover a + non-model 'adapter_model' index, so the selected-index check catches it (Codex #829). A complete + adapter shard set in scope is accepted (no false-reject).""" + snap, blob = _mk_snapshot(tmp_path, "adapter_incomplete") + (snap / "adapter_config.json").write_text("{}") + (snap / "adapter_model-00001-of-00002.safetensors").symlink_to(blob) # shard 1; shard 2 absent + (snap / "adapter_model.safetensors.index.json").write_text(json.dumps( + {"weight_map": {"a": "adapter_model-00001-of-00002.safetensors", + "b": "adapter_model-00002-of-00002.safetensors"}})) + allow = ["adapter_config.json", "adapter_model*"] + assert xf._download_result_usable( + snap, repo_type = "model", allow_patterns = allow, ignore_patterns = None) is False + # The missing adapter shard present -> complete set -> accepted (no false-reject). + (snap / "adapter_model-00002-of-00002.safetensors").symlink_to(blob) + assert xf._download_result_usable( + snap, repo_type = "model", allow_patterns = allow, ignore_patterns = None) is True + + +def test_post_download_rejects_incomplete_component_subfolder_shards(tmp_path): + """A subfolder-scoped request (allow=['unet/*']) whose selected component has a shard INDEX missing a + listed shard must be rejected -- the selected-index check covers component subfolders the root-model + checks do not (Codex #829). 
A complete component shard set in scope is accepted.""" + snap, blob = _mk_snapshot(tmp_path, "component_incomplete") + (snap / "unet").mkdir() + (snap / "unet" / "diffusion_pytorch_model-00001-of-00002.safetensors").symlink_to(blob) + (snap / "unet" / "diffusion_pytorch_model.safetensors.index.json").write_text(json.dumps( + {"weight_map": {"a": "diffusion_pytorch_model-00001-of-00002.safetensors", + "b": "diffusion_pytorch_model-00002-of-00002.safetensors"}})) + assert xf._download_result_usable( + snap, repo_type = "model", allow_patterns = ["unet/*"], ignore_patterns = None) is False + (snap / "unet" / "diffusion_pytorch_model-00002-of-00002.safetensors").symlink_to(blob) + assert xf._download_result_usable( + snap, repo_type = "model", allow_patterns = ["unet/*"], ignore_patterns = None) is True + + def test_selected_readable_weight_complete_entry_point(tmp_path): """The weight-bearing acceptance check funnels through one helper enforcing two invariants: (A) a readable weight is present (ignore + scope applied), (B) its in-scope shard set is complete. diff --git a/unsloth_zoo/hf_cache_state.py b/unsloth_zoo/hf_cache_state.py index 69bc93622..acbbdd599 100644 --- a/unsloth_zoo/hf_cache_state.py +++ b/unsloth_zoo/hf_cache_state.py @@ -584,32 +584,185 @@ def _has_incomplete_canonical_root_shards( return not snapshot_dir_is_complete(snapshot_dir, ignore_patterns = ignore_patterns) -def _has_incomplete_variant_root_shards(snapshot_dir: Path, variant: str) -> bool: - """True when the root holds a VARIANT weight SHARD (a ``.-NNNNN-of-...`` file) that is NOT - backed by a COMPLETE variant shard index -- the index is missing, or one of its listed shards is - absent. Positive-evidence ONLY: a single-file variant (no shard files) or a complete variant shard - set returns False, so a complete or single-file variant download is never rejected. 
transformers - writes a sharded variant weight with a ``.-`` infix and its index as - ``model.safetensors.index..json`` (a ``..`` infix before ``.json``).""" - dot_infix = f".{variant}." # the variant shard index: model.safetensors.index..json - dash_infix = f".{variant}-" # a sharded variant weight: model.-00001-of-00002.safetensors +def _has_incomplete_variant_root_shards( + snapshot_dir: Path, variant: str, *, ignore_patterns: "Optional[object]" = None +) -> bool: + """True when the ROOT variant weight the load READS is an incomplete sharded set. transformers writes + a sharded variant weight with a ``.-`` shard infix and its index as + ``model.safetensors.index..json`` (a ``..`` infix before ``.json``); a single-file + variant is ``model..safetensors``. Incomplete means: a present variant shard INDEX whose + listed shards are not all present (an index-only partial with no shard files counts), OR variant shard + FILES with no complete index. + + The request's ignore filter is applied so a variant weight in an ignored format is not the read + format, and safetensors is treated as read BEFORE bin (transformers' probe order): a present-but- + incomplete variant safetensors index is breakage even with a complete variant bin. Positive-evidence: + a single-file variant or a complete variant shard set returns False, so a complete or single-file + variant download is never rejected.""" + dot_infix = f".{variant}." # variant index (model.safetensors.index..json) or single file + dash_infix = f".{variant}-" # a sharded variant weight (model.-00001-of-00002.safetensors) + ignore_patterns = _as_pattern_list(ignore_patterns) + + def _format_kept(weight_name: str) -> bool: + # The format a load reads from *weight_name* must survive the ignore filter, else the file is a + # stale artifact for an excluded format the load does not read. 
+ if not ignore_patterns: + return True + return bool(_filter_paths([weight_name], None, ignore_patterns)) + try: entries = list(snapshot_dir.iterdir()) except OSError: return False - has_variant_shard = False - has_complete_variant_index = False + st_index_incomplete = None # tri-state: None absent, else present & (in)complete + bin_index_incomplete = None + has_st_shard = has_bin_shard = False + has_single_st = False for entry in entries: name = entry.name # _is_weight_shard_index matches canonical AND variant indices; the dot infix restricts to the # requested variant (the canonical index has no ".." token). if dot_infix in name and _is_weight_shard_index(name): - if _safe_is_file(entry) and _weight_shard_index_complete(entry): - has_complete_variant_index = True + is_safetensors = ".safetensors.index." in name + fmt_probe = ( + f"model.{variant}.safetensors" if is_safetensors else f"pytorch_model.{variant}.bin" + ) + if not _format_kept(fmt_probe): + continue # this format is ignored -> the load does not read it + incomplete = not (_safe_is_file(entry) and _weight_shard_index_complete(entry)) + if is_safetensors: + st_index_incomplete = incomplete + else: + bin_index_incomplete = incomplete elif dash_infix in name and _is_loadable_weight_file(name): - if _safe_is_file(entry): - has_variant_shard = True - return has_variant_shard and not has_complete_variant_index + if _safe_is_file(entry) and _format_kept(name): + if name.endswith(".safetensors"): + has_st_shard = True + else: + has_bin_shard = True + elif dot_infix in name and _is_loadable_weight_file(name): + # a single-file variant weight; only a safetensors single-file matters for precedence (a + # single-file bin variant is complete and handled by the fall-through ``return False``). 
+ if name.endswith(".safetensors") and _safe_is_file(entry) and _format_kept(name): + has_single_st = True + # transformers reads safetensors before bin: judge the safetensors variant first, and fall to bin + # only when no safetensors variant is present in any form. + if st_index_incomplete is not None: + return st_index_incomplete + if has_st_shard: + return True # variant safetensors shard files with no index -> incomplete + if has_single_st: + return False # a complete single-file safetensors variant + if bin_index_incomplete is not None: + return bin_index_incomplete + if has_bin_shard: + return True # variant bin shard files with no index -> incomplete + return False + + +_VARIANT_SHARD_INDEX_RE = re.compile(r"\.(?:safetensors|bin)\.index\.([^.]+)\.json$") + +# The ROOT canonical / variant MODEL shard index (owned by the canonical / variant root checks): +# model.safetensors.index.json, pytorch_model.bin.index.json, and their variant forms. +_ROOT_MODEL_SHARD_INDEX_RE = re.compile( + r"^(?:model\.safetensors|pytorch_model\.bin)\.index(?:\.[^.]+)?\.json$" +) + + +def _index_variant_token(name: str) -> "Optional[str]": + """The variant token of a weight-shard INDEX basename, or None for the canonical (non-variant) form. + ``model.safetensors.index.json`` -> None; ``model.safetensors.index.fp16.json`` -> ``"fp16"``. 
Lets + the selected-index check read only the indices a load reads (a variant load reads variant indices, a + plain load reads canonical ones).""" + if name.endswith(".safetensors.index.json") or name.endswith(".bin.index.json"): + return None + m = _VARIANT_SHARD_INDEX_RE.search(name) + return m.group(1) if m else None + + +def _index_shard_rel_paths(index_path: Path, dir_rel: str) -> "Optional[list]": + """The snapshot-relative posix paths of the shards a weight index lists, or None if the index is + unreadable / malformed -- mirrors the fail-CLOSED rules of ``_weight_shard_index_complete`` (a + non-dict payload or ``weight_map``, an empty shard set, or a non-string / absolute / parent-escaping + shard value all return None). *dir_rel* is the index's snapshot-relative directory ("" at root), so a + listed basename is joined back to a full repo-relative path for the request filter.""" + import json + + try: + with open(index_path, "r", encoding = "utf-8") as f: + data = json.load(f) + except (OSError, ValueError): + return None + weight_map = data.get("weight_map") if isinstance(data, dict) else None + if not isinstance(weight_map, dict): + return None + values = list(weight_map.values()) + if not values or not all(isinstance(s, str) for s in values): + return None + prefix = f"{dir_rel}/" if dir_rel else "" + out: list = [] + for shard in set(values): + if shard.startswith(("/", "\\")) or ".." in shard.replace("\\", "/").split("/"): + return None + out.append(f"{prefix}{shard}") + return out + + +def _selected_shard_index_incomplete( + snapshot_dir: Path, *, allow_patterns: "Optional[object]", ignore_patterns: "Optional[object]", + variant: "Optional[str]", +) -> bool: + """True when a weight-shard INDEX the in-process load READS -- a sharded ADAPTER or a component + SUBFOLDER set that the canonical / variant ROOT-model checks do not cover -- lists a shard that is + absent (or the index is malformed). 
Scoped to the request so a complete download is never + false-rejected: + + - variant: a variant load reads only variant indices (token == variant); a plain load reads only + canonical (token is None) indices. + - allow / ignore: an index is read only when its listed shards survive the request filter. + - precedence: within a directory transformers reads safetensors before bin, so when both a + safetensors and a bin index are selected only the safetensors set's completeness is required. + + The ROOT canonical / variant MODEL index is skipped -- ``_has_incomplete_canonical_root_shards`` / + ``_has_incomplete_variant_root_shards`` own it (with their own precedence handling).""" + allow_patterns = _as_pattern_list(allow_patterns) + ignore_patterns = _as_pattern_list(ignore_patterns) + want_variant = variant or None + try: + entries = list(snapshot_dir.rglob("*")) + except OSError: + return False + per_dir: dict = {} # dir_rel -> {"safetensors": [shard_rels, ...], "bin": [...]} + for entry in entries: + name = entry.name + if not _is_weight_shard_index(name) or not _safe_is_file(entry): + continue + if _index_variant_token(name) != want_variant: + continue # a wrong-variant index the load does not read + try: + rel = entry.relative_to(snapshot_dir).as_posix() + except ValueError: + continue + dir_rel = rel.rsplit("/", 1)[0] if "/" in rel else "" + if dir_rel == "" and _ROOT_MODEL_SHARD_INDEX_RE.match(name): + continue # the ROOT model index -- owned by the canonical / variant root checks + shard_rels = _index_shard_rel_paths(entry, dir_rel) + if shard_rels is None: + return True # a malformed / non-string index -> defer to the watched child + if not _filter_paths(shard_rels, allow_patterns, ignore_patterns): + continue # the load does not read this set (out of scope / ignored format) + fmt = "safetensors" if ".safetensors.index." 
in name else "bin" + per_dir.setdefault(dir_rel, {}).setdefault(fmt, []).append(shard_rels) + for by_fmt in per_dir.values(): + # safetensors read before bin: require only the preferred format present in this directory. + for shard_rels in by_fmt.get("safetensors") or by_fmt.get("bin") or []: + for shard in shard_rels: + try: + if not (snapshot_dir / shard).exists(): + return True + except OSError: + return True + return False def requested_named_files_present( diff --git a/unsloth_zoo/hf_xet_fallback.py b/unsloth_zoo/hf_xet_fallback.py index 7f5de0730..2e8c554c5 100644 --- a/unsloth_zoo/hf_xet_fallback.py +++ b/unsloth_zoo/hf_xet_fallback.py @@ -49,6 +49,7 @@ _has_incomplete_canonical_root_shards, _has_incomplete_variant_root_shards, _is_loadable_weight_file, + _selected_shard_index_incomplete, blob_bytes_present, has_active_incomplete_blobs, hf_cache_root, @@ -1254,6 +1255,10 @@ def _readable_shard_set_incomplete( request does not root-check. - plain: the canonical-root-shard check applies for an UNPATTERNED request, or a GLOBBED request that selects canonical root shards; an exact-named subset or an out-of-scope request does not. + - non-root: a PATTERNED request additionally checks any SELECTED shard index the root-model checks do + not cover (a sharded adapter under ``['adapter_model*']``, a component subfolder) via + ``_selected_shard_index_incomplete``. An UNPATTERNED request reads only the root model weight, so it + does not; an exact-named subset defers to the exact-file presence check. 
The ignore filter is threaded through so completeness is judged for the FORMAT the load reads (a complete safetensors set does not mask an incomplete ``.bin`` under ``ignore=['*.safetensors']``).""" @@ -1261,19 +1266,30 @@ def _readable_shard_set_incomplete( if allow_patterns is None or _request_selects_root_variant_weight( allow_patterns, ignore_patterns, variant ): - return _has_incomplete_variant_root_shards(snapshot_dir, variant) + if _has_incomplete_variant_root_shards( + snapshot_dir, variant, ignore_patterns = ignore_patterns + ): + return True + if allow_patterns is not None and _selected_shard_index_incomplete( + snapshot_dir, allow_patterns = allow_patterns, + ignore_patterns = ignore_patterns, variant = variant, + ): + return True return False if allow_patterns is None: return _has_incomplete_canonical_root_shards( snapshot_dir, ignore_patterns = ignore_patterns ) - if not _patterns_are_exact_names(allow_patterns) and _request_selects_canonical_root_shards( - allow_patterns, ignore_patterns + if _patterns_are_exact_names(allow_patterns): + return False # an exact-named subset defers to the exact-file presence check + if _request_selects_canonical_root_shards(allow_patterns, ignore_patterns) and ( + _has_incomplete_canonical_root_shards(snapshot_dir, ignore_patterns = ignore_patterns) ): - return _has_incomplete_canonical_root_shards( - snapshot_dir, ignore_patterns = ignore_patterns - ) - return False + return True + return _selected_shard_index_incomplete( + snapshot_dir, allow_patterns = allow_patterns, + ignore_patterns = ignore_patterns, variant = None, + ) def _selected_readable_weight_complete( From 2757211a564cdd222d9d3841364ff288e05b37eb Mon Sep 17 00:00:00 2001 From: danielhanchen Date: Wed, 1 Jul 2026 09:45:49 +0000 Subject: [PATCH 63/82] Tighten default-load weight presence and either-format / weightless scoping Four post-download acceptance gaps a stale partial from the child could slip through, each guarded now for the layout the in-process 
load actually reads: - A DEFAULT (no-variant) load reads the canonical model.safetensors / pytorch_model.bin, not a variant-named model.fp16.safetensors. The root presence check accepted any loadable weight suffix, so a variant-only cache passed and the load then fetched the absent canonical weight over Xet. _root_model_has_weight now excludes a variant of a canonical base (model.{variant}.safetensors and its sharded form); canonical names and non-canonical bases (consolidated.*, tf_model.h5) are unaffected. - A diffusers pipeline read any loadable weight anywhere as proof of a component, so a stale partial holding only a training-checkpoint subtree (checkpoint-N/) or a root adapter passed while the unet/vae/text-encoder weights the pipeline reads were missing. A new _has_diffusers_component_weight counts only weights in a component SUBFOLDER, excluding root-level files and checkpoint-N/ subtrees. It stays lenient on WHICH components are required (they can be optional). - An exact request listing both variant formats (["model.fp16.safetensors", "pytorch_model.fp16.bin"]) is an alternative over the repo, like the canonical either-format pair, but the equivalence grouping keyed only the canonical basenames, so a repo publishing only the safetensors variant was failed into a DownloadStallError. _exact_weight_logical now groups either-format alternatives of a shared variant token too. - request_can_include_weights treated every subdir-scoped allow as weight-bearing because the allow contained "/", so a metadata-only subdir warm (allow=["unet/*"] + ignore=[every weight suffix]) was required to hold a weight and a complete weightless subset was rejected. It now probes weight names under each subdir allow, so a genuinely weightless request accepts; a subdir allow whose weight suffixes survive the ignore stays weight-bearing.
Adds regression tests for the variant-only default load, the diffusers checkpoint-only partial, the variant either-format alternatives, and the subdir-scoped weightless warm. --- tests/test_hf_xet_fallback.py | 81 ++++++++++++++++++++++++-- unsloth_zoo/hf_cache_state.py | 19 ++++-- unsloth_zoo/hf_xet_fallback.py | 102 ++++++++++++++++++++++++++------- 3 files changed, 172 insertions(+), 30 deletions(-) diff --git a/tests/test_hf_xet_fallback.py b/tests/test_hf_xet_fallback.py index 4313b94d8..fb86a3356 100644 --- a/tests/test_hf_xet_fallback.py +++ b/tests/test_hf_xet_fallback.py @@ -2980,17 +2980,21 @@ def test_metadata_directory_glob_is_weightless(tmp_path): def test_allow_star_with_all_weights_ignored_is_weightless(tmp_path): - """A root-reachable allow that the ignore filter strips of every weight (allow=['*'] + - ignore=[every weight suffix]) reads weightless, so a complete config-only download is accepted, not - looped into a DownloadStallError. A subdir-scoped allow stays weight-bearing, and an allow whose - weights survive the ignore stays weight-bearing.""" + """An allow that the ignore filter strips of every weight reads weightless, so a complete config-only + download is accepted, not looped into a DownloadStallError. This holds for a ROOT allow (allow=['*']) + AND a subdir-scoped allow (allow=['unet/*']) -- a subdir warm that ignores every weight suffix selects + only that subdir's metadata (Codex #829). 
A subdir allow whose weight suffixes SURVIVE the ignore + stays weight-bearing, as does a root allow whose weights survive.""" all_weight_ignores = [ "*.safetensors", "*.bin", "*.pt", "*.pth", "*.gguf", "*.ckpt", "*.onnx", "*.msgpack", "*.h5", "*.pdparams", ] assert hcs.request_can_include_weights(["*"], all_weight_ignores) is False assert hcs.request_can_include_weights(["*"], None) is True - assert hcs.request_can_include_weights(["unet/*"], all_weight_ignores) is True + # A subdir allow that ignores every weight suffix is weightless too (only unet/ metadata selected)... + assert hcs.request_can_include_weights(["unet/*"], all_weight_ignores) is False + # ...but one whose weight suffixes survive the ignore stays weight-bearing. + assert hcs.request_can_include_weights(["unet/*"], ["*.bin"]) is True assert hcs.request_can_include_weights(["*.safetensors"], ["*.bin"]) is True snap, _ = _mk_snapshot(tmp_path, "cfgonly") (snap / "config.json").write_text("{}") @@ -3018,6 +3022,73 @@ def test_post_download_rejects_checkpoint_only_root_model(tmp_path): (dsnap / "unet" / "diffusion_pytorch_model.safetensors").symlink_to(dblob) assert xf._download_result_usable( dsnap, repo_type = "model", allow_patterns = None, ignore_patterns = None) is True + # A diffusers snapshot whose ONLY weight is under checkpoint-N/ (a training artifact, not a pipeline + # component) is rejected: DiffusionPipeline reads component subfolders (unet/vae/...), so the load + # would fetch the missing components over Xet (Codex #829). 
+ dck, dckb = _mk_snapshot(tmp_path, "diff_ckpt") + (dck / "model_index.json").write_text("{}") + (dck / "checkpoint-7").mkdir() + (dck / "checkpoint-7" / "diffusion_pytorch_model.safetensors").symlink_to(dckb) + assert xf._download_result_usable( + dck, repo_type = "model", allow_patterns = None, ignore_patterns = None) is False + + +def test_post_download_rejects_variant_only_root_for_default_load(tmp_path): + """A DEFAULT (no-variant) load reads the canonical model.safetensors / pytorch_model.bin, NOT a + variant-named model.fp16.safetensors. A stale snapshot holding only the variant weight must be + rejected, else the in-process default load fetches the absent canonical weight over un-killable Xet + (Codex #829). A single-file or sharded variant name is excluded; canonical names still pass.""" + snap, blob = _mk_snapshot(tmp_path, "var_only") + (snap / "model.fp16.safetensors").symlink_to(blob) # variant-named only + (snap / "config.json").write_text("{}") + assert xf._download_result_usable( + snap, repo_type = "model", allow_patterns = None, ignore_patterns = None) is False + # A sharded variant name is likewise not a default weight. + snap_sh, blob_sh = _mk_snapshot(tmp_path, "var_only_sharded") + (snap_sh / "model.fp16-00001-of-00002.safetensors").symlink_to(blob_sh) + (snap_sh / "config.json").write_text("{}") + assert xf._download_result_usable( + snap_sh, repo_type = "model", allow_patterns = None, ignore_patterns = None) is False + # The canonical weight present -> accepted (no false-reject), even beside the variant. + (snap / "model.safetensors").symlink_to(blob) + assert xf._download_result_usable( + snap, repo_type = "model", allow_patterns = None, ignore_patterns = None) is True + # A variant load (variant='fp16') DOES read the variant weight -> accepted. 
+ assert xf._download_result_usable( + snap_sh, repo_type = "model", allow_patterns = None, ignore_patterns = None, + variant = "fp16") is False # sharded variant with no index is still an incomplete set + snap_v, blob_v = _mk_snapshot(tmp_path, "var_single") + (snap_v / "model.fp16.safetensors").symlink_to(blob_v) + (snap_v / "config.json").write_text("{}") + assert xf._download_result_usable( + snap_v, repo_type = "model", allow_patterns = None, ignore_patterns = None, + variant = "fp16") is True + + +def test_post_download_variant_either_format_exact_alternatives(tmp_path): + """An exact request listing both variant formats (allow=['model.fp16.safetensors', + 'pytorch_model.fp16.bin']) is an ALTERNATIVE over the repo, like the canonical either-format pair: + a repo publishing only the safetensors variant is complete and must not be failed into a + DownloadStallError (Codex #829). A distinct-variant / base+adapter request still requires each.""" + snap, blob = _mk_snapshot(tmp_path, "var_either") + (snap / "model.fp16.safetensors").symlink_to(blob) # only the safetensors variant present + assert xf._download_result_usable( + snap, repo_type = "model", + allow_patterns = ["model.fp16.safetensors", "pytorch_model.fp16.bin"], + ignore_patterns = None, variant = "fp16") is True + # The canonical either-format pair keeps working (regression). + snap_c, blob_c = _mk_snapshot(tmp_path, "canon_either") + (snap_c / "pytorch_model.bin").symlink_to(blob_c) + assert xf._download_result_usable( + snap_c, repo_type = "model", + allow_patterns = ["model.safetensors", "pytorch_model.bin"], ignore_patterns = None) is True + # Base AND adapter are distinct groups: the adapter present but base absent -> rejected. 
+ snap_d, blob_d = _mk_snapshot(tmp_path, "base_and_adapter") + (snap_d / "adapter_model.safetensors").symlink_to(blob_d) + assert xf._download_result_usable( + snap_d, repo_type = "model", + allow_patterns = ["model.safetensors", "adapter_model.safetensors"], + ignore_patterns = None) is False def test_post_download_validates_weightless_named_subset(tmp_path): diff --git a/unsloth_zoo/hf_cache_state.py b/unsloth_zoo/hf_cache_state.py index acbbdd599..04e4db448 100644 --- a/unsloth_zoo/hf_cache_state.py +++ b/unsloth_zoo/hf_cache_state.py @@ -458,11 +458,20 @@ def request_can_include_weights( return False # allow=[] selects nothing if not any(_pattern_can_select_weight(pat) for pat in allow_patterns): return False - # A root-reachable allow (no required subdir) can still be left weightless by the ignore filter - # (allow=["*"] + ignore=[every weight suffix]). Apply HF's allow-then-ignore semantics to the weight - # probes; a subdir-scoped allow stays weight-bearing (its required dir is absent from the root probes). - if ignore_patterns and all(isinstance(p, str) and "/" not in p for p in allow_patterns): - if not _filter_paths(list(_WEIGHT_PATTERN_PROBES), allow_patterns, ignore_patterns): + # An allow that can reach a weight can still be left weightless by the ignore filter: allow=["*"] + + # ignore=[every weight suffix], OR a subdir warm allow=["unet/*"] that ignores every weight suffix to + # fetch only that subdir's metadata / configs. Apply HF's allow-then-ignore semantics to representative + # weight probes at the ROOT and UNDER each subdir-scoped allow, so a genuinely weightless request is not + # required to hold a weight (which would false-reject a complete metadata-only subset after both + # transports). A subdir allow that keeps its weight suffixes still matches a subdir probe and stays + # weight-bearing. 
+ if ignore_patterns: + probes = list(_WEIGHT_PATTERN_PROBES) + for pat in allow_patterns: + if isinstance(pat, str) and "/" in pat: + head = pat.rsplit("/", 1)[0] + probes.extend(f"{head}/{name}" for name in _WEIGHT_PATTERN_PROBES) + if not _filter_paths(probes, allow_patterns, ignore_patterns): return False return True diff --git a/unsloth_zoo/hf_xet_fallback.py b/unsloth_zoo/hf_xet_fallback.py index 2e8c554c5..27aebb1db 100644 --- a/unsloth_zoo/hf_xet_fallback.py +++ b/unsloth_zoo/hf_xet_fallback.py @@ -998,28 +998,77 @@ def _has_any_weight(snapshot_dir: Path, *, ignore_patterns: Any = None) -> bool: return bool(_filter_paths(rels, None, ignore_patterns)) +# A VARIANT of a canonical root weight: the variant token sits between the base and the extension / +# shard suffix (model.fp16.safetensors, pytorch_model.fp16-00001-of-00002.bin). A DEFAULT (no-variant) +# load reads the canonical model.safetensors / pytorch_model.bin, NOT these, so a variant-only cache +# must not satisfy a default load's presence check (else the load fetches the absent canonical weight +# over un-killable Xet). Canonical names (model.safetensors, model-00001-of-00002.safetensors -- a dash, +# not a dotted token) and non-canonical bases (consolidated.*, tf_model.h5) are deliberately NOT matched. +_CANONICAL_BASE_VARIANT_RE = re.compile( + r"^(?:model|pytorch_model)\.[^.]+(?:-\d{5}-of-\d{5})?\.(?:safetensors|bin)$" +) + +# A training-checkpoint subdir (checkpoint-500/, checkpoint_7/): its weights are never read as diffusers +# pipeline COMPONENTS, so they must not mask missing unet/vae/text-encoder weights. +_CHECKPOINT_DIR_RE = re.compile(r"^checkpoint[-_]\d+$") + + +def _has_diffusers_component_weight(snapshot_dir: Path, *, ignore_patterns: Any = None) -> bool: + """True if a diffusers pipeline COMPONENT weight (a loadable weight in a component SUBFOLDER: unet/, + vae/, text_encoder/, ...) that the ignore filter keeps is present. 
Excludes ROOT-level weights (an + adapter / merged file a ``DiffusionPipeline`` does not read as a component) and training-checkpoint + subtrees (checkpoint-N/), so a stale partial holding only those does not mask the missing component + weights the pipeline reads -- which the in-process load would then fetch over un-killable Xet. Stays + lenient on WHICH components are required (a pipeline's components can be optional): it only tells a + real component warm from a checkpoint-only / config-only stale snapshot.""" + rels: list = [] + try: + for entry in snapshot_dir.rglob("*"): + if not _is_loadable_weight_file(entry.name): + continue + try: + if not entry.is_file(): + continue + rel = entry.relative_to(snapshot_dir).as_posix() + except (OSError, ValueError): + continue + parts = rel.split("/") + if len(parts) < 2: + continue # a ROOT-level weight is not a pipeline component + if any(_CHECKPOINT_DIR_RE.match(p) for p in parts[:-1]): + continue # under a training-checkpoint subtree, not a component + rels.append(rel) + except OSError: + return False + return bool(_filter_paths(rels, None, ignore_patterns)) + + def _root_model_has_weight(snapshot_dir: Path, *, ignore_patterns: Any = None) -> bool: - """Whether an UNPATTERNED model warm holds a weight a default load reads: a ROOT weight, or -- for a - diffusers pipeline (root ``model_index.json``) -- a component-subfolder weight. Counting any subtree - weight (as ``_has_any_weight`` does) would accept a stale checkpoint-only snapshot and then fetch - the root weights over un-killable Xet; diffusers is the one layout whose weights live in subfolders. 
- The request's ignore filter is applied to the ROOT weights, so an offline-fallback partial holding - only the format the load will NOT read (an ignored ``*.bin`` when safetensors was requested) does not - count as a usable weight -- the incomplete result is retried over HTTP instead of loaded in-process.""" + """Whether an UNPATTERNED model warm holds a weight a default load reads: a canonical ROOT weight, or + -- for a diffusers pipeline (root ``model_index.json``) -- a component-subfolder weight. Counting any + subtree weight (as ``_has_any_weight`` does) would accept a stale checkpoint-only snapshot and then + fetch the root weights over un-killable Xet; diffusers is the one layout whose weights live in + subfolders. A VARIANT-named root weight (``model.fp16.safetensors``) is excluded: a default load does + not read it, so a variant-only cache is retried over HTTP rather than loaded. The request's ignore + filter is applied to the ROOT weights, so an offline-fallback partial holding only the format the load + will NOT read (an ignored ``*.bin`` when safetensors was requested) does not count as a usable weight.""" try: is_diffusers = (snapshot_dir / "model_index.json").is_file() except OSError: is_diffusers = False if is_diffusers: - return _has_any_weight(snapshot_dir, ignore_patterns = ignore_patterns) + return _has_diffusers_component_weight(snapshot_dir, ignore_patterns = ignore_patterns) rels: list = [] try: for entry in snapshot_dir.iterdir(): - if not _is_loadable_weight_file(entry.name): + name = entry.name + if not _is_loadable_weight_file(name): continue + if _CANONICAL_BASE_VARIANT_RE.match(name): + continue # a variant of a canonical weight is not read by a default (no-variant) load try: if entry.is_file(): - rels.append(entry.name) + rels.append(name) except OSError: continue except OSError: @@ -1056,14 +1105,27 @@ def _root_has_variant_weight( return bool(_filter_paths(rels, None, ignore_patterns)) -# Interchangeable exact weight names: the 
either-format ``["pytorch_model.bin", "model.safetensors"]`` -# pair is satisfied by ANY one, while distinct logical weights (base AND adapter) must each be present. -_EQUIVALENT_EXACT_WEIGHT_NAMES = { - "model.safetensors": "root_model", - "pytorch_model.bin": "root_model", - "adapter_model.safetensors": "adapter_model", - "adapter_model.bin": "adapter_model", -} +# Interchangeable exact weight names collapse to one equivalence group: the either-format pair +# ``["pytorch_model.bin", "model.safetensors"]`` is satisfied by ANY one -- and so is the variant pair +# ``["model.fp16.safetensors", "pytorch_model.fp16.bin"]`` (HF allow patterns are ALTERNATIVES over the +# repo, so a repo publishing only one format is complete). Distinct logical weights (base AND adapter, a +# different variant token) stay separate groups (each required). +_EITHER_FORMAT_WEIGHT_RE = re.compile( + r"^(model|pytorch_model|adapter_model)(?:\.([^.]+))?\.(?:safetensors|bin)$" +) + + +def _exact_weight_logical(base: str) -> Any: + """Equivalence key for an EXACT-named weight so the either-format alternatives share a group: + ``model.safetensors`` / ``pytorch_model.bin`` -> ``("root_model", None)``; the same variant token in + both formats shares ``("root_model", "")``; ``adapter_model.*`` -> ``("adapter_model", ...)``. 
+ A non-weight (or sharded) name maps to itself, so each distinct file is still required.""" + m = _EITHER_FORMAT_WEIGHT_RE.match(base) + if m is None: + return base + stem, variant = m.group(1), m.group(2) + logical = "adapter_model" if stem == "adapter_model" else "root_model" + return (logical, variant) def _requested_exact_files_present_grouped( @@ -1087,10 +1149,10 @@ def _requested_exact_files_present_grouped( } except OSError: return True # cannot enumerate -> do not reject on an unreadable dir - groups: "dict[tuple[str, str], list[str]]" = {} + groups: "dict[tuple[str, Any], list[str]]" = {} for rel in requested: parent, base = rel.rsplit("/", 1) if "/" in rel else ("", rel) - logical = _EQUIVALENT_EXACT_WEIGHT_NAMES.get(base, base) + logical = _exact_weight_logical(base) groups.setdefault((parent, logical), []).append(rel) return all( any(candidate in present for candidate in candidates) for candidates in groups.values() From b8edaf3685eb930f3886e778b3b8dac2bcd0cdbd Mon Sep 17 00:00:00 2001 From: danielhanchen Date: Wed, 1 Jul 2026 10:28:45 +0000 Subject: [PATCH 64/82] Scope default and variant root-weight checks to model, not adapter Two post-download acceptance gaps where a PEFT adapter was mistaken for the default model weight: - A DEFAULT (unpatterned) model warm read any loadable root file as proof the base weight was present, so a stale snapshot holding only adapter_model.* was accepted while the base model.safetensors / pytorch_model.bin was still missing; the in-process base load then fetched it over un-killable Xet. _root_model_has_weight now skips adapter_* (as it already skips a variant of a canonical weight). An adapter-scoped load (allow=['adapter_model*']) is unaffected -- it goes through the selected-weight path, not this default check. 
- The root variant-shard completeness check treated any shard index carrying the variant token as the model's variant index, so a complete model.fp16.safetensors co-resident with a stale, incomplete adapter_model.safetensors.index.fp16.json was reported unusable and looped into a DownloadStallError. Restrict the check to the root model / pytorch_model variant index and weight names (_ROOT_MODEL_SHARD_INDEX_RE / a new _ROOT_MODEL_VARIANT_WEIGHT_RE), so an adapter or other non-model variant set the default load never reads cannot force-fail a complete model variant. Adds regression tests for the adapter-only default warm (and the unaffected adapter-scoped load) and the variant root check ignoring an adapter index (while still catching an incomplete root model variant index). --- tests/test_hf_xet_fallback.py | 46 ++++++++++++++++++++++++++++++++++ unsloth_zoo/hf_cache_state.py | 26 +++++++++++++------ unsloth_zoo/hf_xet_fallback.py | 12 ++++++--- 3 files changed, 72 insertions(+), 12 deletions(-) diff --git a/tests/test_hf_xet_fallback.py b/tests/test_hf_xet_fallback.py index fb86a3356..2987905b9 100644 --- a/tests/test_hf_xet_fallback.py +++ b/tests/test_hf_xet_fallback.py @@ -3033,6 +3033,52 @@ def test_post_download_rejects_checkpoint_only_root_model(tmp_path): dck, repo_type = "model", allow_patterns = None, ignore_patterns = None) is False +def test_post_download_rejects_adapter_only_for_default_load(tmp_path): + """A DEFAULT (unpatterned) model warm reads the base model.safetensors / pytorch_model.bin, not a PEFT + adapter. A stale snapshot holding only adapter_model.safetensors must be rejected, else the in-process + base load fetches the absent base weight over un-killable Xet (Codex #829). 
An adapter-scoped request + (allow=['adapter_model*']) is unaffected: it reads the adapter and still accepts it.""" + snap, blob = _mk_snapshot(tmp_path, "adapter_only_default") + (snap / "adapter_model.safetensors").symlink_to(blob) + (snap / "adapter_config.json").write_text("{}") + (snap / "config.json").write_text("{}") + assert xf._download_result_usable( + snap, repo_type = "model", allow_patterns = None, ignore_patterns = None) is False + # An ADAPTER load (patterned) reads the adapter and accepts it (no regression to the PEFT path). + assert xf._download_result_usable( + snap, repo_type = "model", allow_patterns = ["adapter_model*", "adapter_config.json"], + ignore_patterns = None) is True + # The base weight present -> the default warm accepts (no false-reject), even beside the adapter. + (snap / "model.safetensors").symlink_to(blob) + assert xf._download_result_usable( + snap, repo_type = "model", allow_patterns = None, ignore_patterns = None) is True + + +def test_post_download_variant_root_check_ignores_adapter_index(tmp_path): + """An unpatterned variant load reads the ROOT model variant, not a PEFT adapter. A complete + model.fp16.safetensors co-resident with a STALE, incomplete adapter_model.safetensors.index.fp16.json + must still be accepted -- the root variant-shard check is restricted to model / pytorch_model variant + names, so the adapter index's incompleteness does not force a spurious DownloadStallError (Codex #829). 
+ A genuinely incomplete ROOT model variant index is still rejected.""" + snap, blob = _mk_snapshot(tmp_path, "var_adapter_idx") + (snap / "model.fp16.safetensors").symlink_to(blob) # complete root model variant (the read weight) + (snap / "adapter_model.safetensors.index.fp16.json").write_text(json.dumps( + {"weight_map": {"a": "adapter_model.fp16-00001-of-00002.safetensors", + "b": "adapter_model.fp16-00002-of-00002.safetensors"}})) # stale, shards absent + assert xf._download_result_usable( + snap, repo_type = "model", allow_patterns = None, ignore_patterns = None, + variant = "fp16") is True + # An incomplete ROOT model variant index is still caught (the restriction did not disable the check). + snap2, blob2 = _mk_snapshot(tmp_path, "var_root_idx_incomplete") + (snap2 / "model.safetensors.index.fp16.json").write_text(json.dumps( + {"weight_map": {"a": "model.fp16-00001-of-00002.safetensors", + "b": "model.fp16-00002-of-00002.safetensors"}})) + (snap2 / "model.fp16-00001-of-00002.safetensors").symlink_to(blob2) # shard 2 absent + assert xf._download_result_usable( + snap2, repo_type = "model", allow_patterns = None, ignore_patterns = None, + variant = "fp16") is False + + def test_post_download_rejects_variant_only_root_for_default_load(tmp_path): """A DEFAULT (no-variant) load reads the canonical model.safetensors / pytorch_model.bin, NOT a variant-named model.fp16.safetensors. A stale snapshot holding only the variant weight must be diff --git a/unsloth_zoo/hf_cache_state.py b/unsloth_zoo/hf_cache_state.py index 04e4db448..bb7ff91d9 100644 --- a/unsloth_zoo/hf_cache_state.py +++ b/unsloth_zoo/hf_cache_state.py @@ -607,7 +607,9 @@ def _has_incomplete_variant_root_shards( format, and safetensors is treated as read BEFORE bin (transformers' probe order): a present-but- incomplete variant safetensors index is breakage even with a complete variant bin. 
Positive-evidence: a single-file variant or a complete variant shard set returns False, so a complete or single-file - variant download is never rejected.""" + variant download is never rejected. Only the ROOT ``model`` / ``pytorch_model`` variant weight is + considered: a co-resident stale ``adapter_model`` variant index / shard set (which a default variant + model load does not read) must not force-fail a complete model variant.""" dot_infix = f".{variant}." # variant index (model.safetensors.index.<variant>.json) or single file dash_infix = f".{variant}-" # a sharded variant weight (model.<variant>-00001-of-00002.safetensors) ignore_patterns = _as_pattern_list(ignore_patterns) @@ -629,9 +631,10 @@ def _format_kept(weight_name: str) -> bool: has_single_st = False for entry in entries: name = entry.name - # _is_weight_shard_index matches canonical AND variant indices; the dot infix restricts to the - # requested variant (the canonical index has no ".<variant>." token). - if dot_infix in name and _is_weight_shard_index(name): + # Restrict to the ROOT model index (model.safetensors.index.<variant>.json / + # pytorch_model.bin.index.<variant>.json); an adapter_model / other non-model variant index the + # default load does not read is skipped so its incompleteness cannot force-fail the model variant. + if dot_infix in name and _ROOT_MODEL_SHARD_INDEX_RE.match(name): is_safetensors = ".safetensors.index."
in name fmt_probe = ( f"model.{variant}.safetensors" if is_safetensors else f"pytorch_model.{variant}.bin" @@ -643,15 +646,15 @@ def _format_kept(weight_name: str) -> bool: st_index_incomplete = incomplete else: bin_index_incomplete = incomplete - elif dash_infix in name and _is_loadable_weight_file(name): + elif dash_infix in name and _ROOT_MODEL_VARIANT_WEIGHT_RE.match(name): if _safe_is_file(entry) and _format_kept(name): if name.endswith(".safetensors"): has_st_shard = True else: has_bin_shard = True - elif dot_infix in name and _is_loadable_weight_file(name): - # a single-file variant weight; only a safetensors single-file matters for precedence (a - # single-file bin variant is complete and handled by the fall-through ``return False``). + elif dot_infix in name and _ROOT_MODEL_VARIANT_WEIGHT_RE.match(name): + # a single-file ROOT model variant weight; only a safetensors single-file matters for + # precedence (a single-file bin variant is complete and handled by the fall-through). if name.endswith(".safetensors") and _safe_is_file(entry) and _format_kept(name): has_single_st = True # transformers reads safetensors before bin: judge the safetensors variant first, and fall to bin @@ -677,6 +680,13 @@ def _format_kept(weight_name: str) -> bool: r"^(?:model\.safetensors|pytorch_model\.bin)\.index(?:\.[^.]+)?\.json$" ) +# A ROOT model VARIANT weight (single or sharded): the variant token sits between the model / pytorch_model +# base and the extension / shard suffix (model.fp16.safetensors, pytorch_model.fp16-00001-of-00002.bin). +# Excludes a PEFT adapter (adapter_model.<variant>.*) the default variant model load does not read. +_ROOT_MODEL_VARIANT_WEIGHT_RE = re.compile( + r"^(?:model|pytorch_model)\.[^.]+(?:-\d{5}-of-\d{5})?\.(?:safetensors|bin)$" +) + def _index_variant_token(name: str) -> "Optional[str]": """The variant token of a weight-shard INDEX basename, or None for the canonical (non-variant) form.
diff --git a/unsloth_zoo/hf_xet_fallback.py b/unsloth_zoo/hf_xet_fallback.py index 27aebb1db..f3867eabd 100644 --- a/unsloth_zoo/hf_xet_fallback.py +++ b/unsloth_zoo/hf_xet_fallback.py @@ -1048,10 +1048,12 @@ def _root_model_has_weight(snapshot_dir: Path, *, ignore_patterns: Any = None) - -- for a diffusers pipeline (root ``model_index.json``) -- a component-subfolder weight. Counting any subtree weight (as ``_has_any_weight`` does) would accept a stale checkpoint-only snapshot and then fetch the root weights over un-killable Xet; diffusers is the one layout whose weights live in - subfolders. A VARIANT-named root weight (``model.fp16.safetensors``) is excluded: a default load does - not read it, so a variant-only cache is retried over HTTP rather than loaded. The request's ignore - filter is applied to the ROOT weights, so an offline-fallback partial holding only the format the load - will NOT read (an ignored ``*.bin`` when safetensors was requested) does not count as a usable weight.""" + subfolders. A VARIANT-named root weight (``model.fp16.safetensors``) and a PEFT adapter + (``adapter_model.*``) are excluded: a default base-model load reads neither, so a cache holding only + those is retried over HTTP rather than loaded (its base ``model.safetensors`` / ``pytorch_model.bin`` + is still missing). 
The request's ignore filter is applied to the ROOT weights, so an offline-fallback + partial holding only the format the load will NOT read (an ignored ``*.bin`` under a safetensors + request) does not count as a usable weight.""" try: is_diffusers = (snapshot_dir / "model_index.json").is_file() except OSError: @@ -1064,6 +1066,8 @@ def _root_model_has_weight(snapshot_dir: Path, *, ignore_patterns: Any = None) - name = entry.name if not _is_loadable_weight_file(name): continue + if name.startswith("adapter_"): + continue # a PEFT adapter (adapter_model.*) is not read by a default base-model load if _CANONICAL_BASE_VARIANT_RE.match(name): continue # a variant of a canonical weight is not read by a default (no-variant) load try: From 074b808ef6d996050cc2ef4817b6a27bde2e67ac Mon Sep 17 00:00:00 2001 From: danielhanchen Date: Wed, 1 Jul 2026 11:17:12 +0000 Subject: [PATCH 65/82] Match post-download acceptance to what a default load reads (format, adapter, diffusers) Four post-download acceptance gaps where the child's snapshot_download result was judged against the wrong set of weights, so a stale / partial cache was either accepted (a later in-process load then fetches the missing weight over un-killable Xet) or a complete download was rejected into a spurious DownloadStallError. - A DEFAULT (unpatterned) warm accepted any loadable weight suffix as proof the weight was present, so a stale cache holding only model.Q4_K_M.gguf passed while the base model.safetensors / pytorch_model.bin was still missing. A default transformers / diffusers from_pretrained reads only safetensors / bin (a GGUF is read only for a GGUF-specific request), so the default root / component / variant presence checks now use _is_default_load_weight_file (safetensors / bin only). - An unpatterned VARIANT warm accepted a root filename carrying the variant infix, so a stale adapter-only snapshot (adapter_model.fp16.safetensors) passed even though the base model.fp16.safetensors was absent. 
The root variant check now skips adapter_* (as the plain root check already did), so an adapter the default base-model variant load never reads cannot mask the missing base variant. - A diffusers pipeline VARIANT warm's weights are component-scoped (unet/diffusion_pytorch_model.fp16.safetensors), not root model..* files, so the root-only variant presence check reported a complete diffusers variant download as incomplete and looped it into a DownloadStallError. The unpatterned variant presence check is now diffusers-aware (_root_model_has_variant_weight -> _has_diffusers_component_variant_weight), mirroring the plain path. - An unpatterned diffusers warm reads component subfolders, but the shard-set completeness check only looked at canonical ROOT shards, so a component with a shard INDEX listing a missing shard (unet/....index.json) was accepted and the pipeline load then fetched the missing shard over Xet. The plain and variant unpatterned diffusers branches now also run _diffusers_component_shards_incomplete, which flags a component index that lists an absent shard (scoped to component subfolders, skipping the root index and any training-checkpoint subtree, ignore filter applied, safetensors read before bin) so a complete pipeline still passes. Adds regression tests for each: a gguf-only default warm, an adapter-variant-only variant warm, a complete diffusers variant accepted (and a non-variant-only pipeline rejected for a variant load), and an incomplete diffusers component shard index (plain and variant) rejected then accepted once complete. 
--- tests/test_hf_xet_fallback.py | 91 ++++++++++++++++++++++++++ unsloth_zoo/hf_cache_state.py | 60 +++++++++++++++++ unsloth_zoo/hf_xet_fallback.py | 116 +++++++++++++++++++++++++++++---- 3 files changed, 254 insertions(+), 13 deletions(-) diff --git a/tests/test_hf_xet_fallback.py b/tests/test_hf_xet_fallback.py index 2987905b9..a3abf9e02 100644 --- a/tests/test_hf_xet_fallback.py +++ b/tests/test_hf_xet_fallback.py @@ -2693,6 +2693,97 @@ def test_post_download_rejects_incomplete_component_subfolder_shards(tmp_path): snap, repo_type = "model", allow_patterns = ["unet/*"], ignore_patterns = None) is True +def test_post_download_rejects_gguf_only_default_load(tmp_path): + """A DEFAULT (unpatterned) transformers warm reads model.safetensors / pytorch_model.bin, not a GGUF + file (only a GGUF-specific request does). A stale snapshot holding only model.Q4_K_M.gguf must be + rejected, else the in-process default load fetches the absent safetensors / bin over un-killable Xet + (Codex #829).""" + snap, blob = _mk_snapshot(tmp_path, "gguf_only") + (snap / "model.Q4_K_M.gguf").symlink_to(blob) + (snap / "config.json").write_text("{}") + assert xf._download_result_usable( + snap, repo_type = "model", allow_patterns = None, ignore_patterns = None) is False + # The safetensors weight present -> the default warm accepts (no false-reject), even beside the gguf. + (snap / "model.safetensors").symlink_to(blob) + assert xf._download_result_usable( + snap, repo_type = "model", allow_patterns = None, ignore_patterns = None) is True + + +def test_post_download_rejects_adapter_variant_for_default_variant_load(tmp_path): + """An unpatterned variant warm reads the ROOT model variant (model.fp16.safetensors), not a PEFT + adapter variant. A stale snapshot holding only adapter_model.fp16.safetensors must be rejected, else + the in-process base-model variant load fetches the absent model.fp16.safetensors over un-killable Xet + (Codex #829). 
The base model variant present -> accepted (no false-reject).""" + snap, blob = _mk_snapshot(tmp_path, "adapter_variant_only") + (snap / "adapter_model.fp16.safetensors").symlink_to(blob) + (snap / "adapter_config.json").write_text("{}") + (snap / "config.json").write_text("{}") + assert xf._download_result_usable( + snap, repo_type = "model", allow_patterns = None, ignore_patterns = None, + variant = "fp16") is False + (snap / "model.fp16.safetensors").symlink_to(blob) + assert xf._download_result_usable( + snap, repo_type = "model", allow_patterns = None, ignore_patterns = None, + variant = "fp16") is True + + +def test_post_download_accepts_complete_diffusers_variant(tmp_path): + """A diffusers pipeline variant warm's weights are COMPONENT-scoped (unet/....fp16.safetensors), not + root model.<variant>.* files. A complete diffusers variant download must be accepted -- the root-only + variant presence check would false-reject it into a spurious DownloadStallError (Codex #829). A + pipeline holding only the NON-variant component weight does not satisfy a variant load.""" + snap, blob = _mk_snapshot(tmp_path, "diffusers_variant") + (snap / "model_index.json").write_text("{}") + (snap / "unet").mkdir() + (snap / "unet" / "config.json").write_text("{}") + (snap / "unet" / "diffusion_pytorch_model.fp16.safetensors").symlink_to(blob) + assert xf._download_result_usable( + snap, repo_type = "model", allow_patterns = None, ignore_patterns = None, + variant = "fp16") is True + snap2, blob2 = _mk_snapshot(tmp_path, "diffusers_variant_missing") + (snap2 / "model_index.json").write_text("{}") + (snap2 / "unet").mkdir() + (snap2 / "unet" / "diffusion_pytorch_model.safetensors").symlink_to(blob2) # non-variant only + assert xf._download_result_usable( + snap2, repo_type = "model", allow_patterns = None, ignore_patterns = None, + variant = "fp16") is False + + +def test_post_download_rejects_incomplete_diffusers_component_shards_unpatterned(tmp_path): + """An UNPATTERNED diffusers
pipeline warm reads component subfolders (unet/, vae/, ...). A component + shard INDEX listing a shard that is absent -- which the canonical ROOT-shard check does not cover -- + must be rejected, else the in-process pipeline load fetches the missing shard over un-killable Xet + (Codex #829). Both the plain and the variant component index are covered; a complete set is accepted.""" + snap, blob = _mk_snapshot(tmp_path, "diffusers_comp_incomplete") + (snap / "model_index.json").write_text("{}") + (snap / "unet").mkdir() + (snap / "unet" / "config.json").write_text("{}") + (snap / "unet" / "diffusion_pytorch_model-00001-of-00002.safetensors").symlink_to(blob) + (snap / "unet" / "diffusion_pytorch_model.safetensors.index.json").write_text(json.dumps( + {"weight_map": {"a": "diffusion_pytorch_model-00001-of-00002.safetensors", + "b": "diffusion_pytorch_model-00002-of-00002.safetensors"}})) + assert xf._download_result_usable( + snap, repo_type = "model", allow_patterns = None, ignore_patterns = None) is False + (snap / "unet" / "diffusion_pytorch_model-00002-of-00002.safetensors").symlink_to(blob) + assert xf._download_result_usable( + snap, repo_type = "model", allow_patterns = None, ignore_patterns = None) is True + # The same for a VARIANT component index (variant='fp16', unpatterned). 
+ snapv, blobv = _mk_snapshot(tmp_path, "diffusers_comp_variant_incomplete") + (snapv / "model_index.json").write_text("{}") + (snapv / "unet").mkdir() + (snapv / "unet" / "diffusion_pytorch_model.fp16-00001-of-00002.safetensors").symlink_to(blobv) + (snapv / "unet" / "diffusion_pytorch_model.safetensors.index.fp16.json").write_text(json.dumps( + {"weight_map": {"a": "diffusion_pytorch_model.fp16-00001-of-00002.safetensors", + "b": "diffusion_pytorch_model.fp16-00002-of-00002.safetensors"}})) + assert xf._download_result_usable( + snapv, repo_type = "model", allow_patterns = None, ignore_patterns = None, + variant = "fp16") is False + (snapv / "unet" / "diffusion_pytorch_model.fp16-00002-of-00002.safetensors").symlink_to(blobv) + assert xf._download_result_usable( + snapv, repo_type = "model", allow_patterns = None, ignore_patterns = None, + variant = "fp16") is True + + def test_selected_readable_weight_complete_entry_point(tmp_path): """The weight-bearing acceptance check funnels through one helper enforcing two invariants: (A) a readable weight is present (ignore + scope applied), (B) its in-scope shard set is complete. diff --git a/unsloth_zoo/hf_cache_state.py b/unsloth_zoo/hf_cache_state.py index bb7ff91d9..ff21c7ed0 100644 --- a/unsloth_zoo/hf_cache_state.py +++ b/unsloth_zoo/hf_cache_state.py @@ -784,6 +784,66 @@ def _selected_shard_index_incomplete( return False +# A training-checkpoint subdir (checkpoint-500/, checkpoint_7/): its weights are never read as diffusers +# pipeline COMPONENTS, so an incomplete shard index under it must not force-fail a complete pipeline. +_CHECKPOINT_DIR_RE = re.compile(r"^checkpoint[-_]\d+$") + + +def _diffusers_component_shards_incomplete( + snapshot_dir: Path, *, variant: "Optional[str]" = None, + ignore_patterns: "Optional[object]" = None, +) -> bool: + """True when a diffusers pipeline COMPONENT subfolder (unet/, vae/, text_encoder/, ...) 
holds a + weight-shard INDEX of the read variant that lists a shard that is absent (or the index is malformed) + -- an interrupted component pull the in-process pipeline load would finish over un-killable Xet. + + Scoped so a complete pipeline is never false-rejected: a ROOT index (owned by the canonical / variant + root-model checks) and a training-checkpoint subtree (checkpoint-N/, never read as a pipeline + component) are skipped, and the request's ignore filter selects the read format. Per directory, + safetensors is read before bin, so only the preferred format's set must be complete. A plain load + reads canonical component indices (token None); a variant load reads variant ones. Positive-evidence: + a single-file component or a complete component shard set is not flagged, so a complete download passes.""" + want_variant = variant or None + ignore_patterns = _as_pattern_list(ignore_patterns) + try: + entries = list(snapshot_dir.rglob("*")) + except OSError: + return False + per_dir: dict = {} + for entry in entries: + name = entry.name + if not _is_weight_shard_index(name) or not _safe_is_file(entry): + continue + if _index_variant_token(name) != want_variant: + continue # a wrong-variant index the load does not read + try: + rel = entry.relative_to(snapshot_dir).as_posix() + except ValueError: + continue + parts = rel.split("/") + if len(parts) < 2: + continue # a ROOT index -- owned by the canonical / variant root-model checks + if any(_CHECKPOINT_DIR_RE.match(p) for p in parts[:-1]): + continue # a training-checkpoint subtree, not a pipeline component + dir_rel = rel.rsplit("/", 1)[0] + shard_rels = _index_shard_rel_paths(entry, dir_rel) + if shard_rels is None: + return True # a malformed / non-string index -> defer to the watched child + if not _filter_paths(shard_rels, None, ignore_patterns): + continue # the load does not read this set (ignored format) + fmt = "safetensors" if ".safetensors.index." 
in name else "bin" + per_dir.setdefault(dir_rel, {}).setdefault(fmt, []).append(shard_rels) + for by_fmt in per_dir.values(): + for shard_rels in by_fmt.get("safetensors") or by_fmt.get("bin") or []: + for shard in shard_rels: + try: + if not (snapshot_dir / shard).exists(): + return True + except OSError: + return True + return False + + def requested_named_files_present( snapshot_dir: Path, *, diff --git a/unsloth_zoo/hf_xet_fallback.py b/unsloth_zoo/hf_xet_fallback.py index f3867eabd..cd0d84e78 100644 --- a/unsloth_zoo/hf_xet_fallback.py +++ b/unsloth_zoo/hf_xet_fallback.py @@ -44,6 +44,7 @@ from unsloth_zoo.hf_cache_state import ( INCOMPLETE_SUFFIX, _as_pattern_list, + _diffusers_component_shards_incomplete, _filter_paths, _has_glob, _has_incomplete_canonical_root_shards, @@ -998,6 +999,15 @@ def _has_any_weight(snapshot_dir: Path, *, ignore_patterns: Any = None) -> bool: return bool(_filter_paths(rels, None, ignore_patterns)) +def _is_default_load_weight_file(name: str) -> bool: + """A weight in a format a DEFAULT ``from_pretrained`` reads: safetensors or bin only. Excludes gguf / + pt / pth / onnx / msgpack / ... -- a default (non-format-specific) transformers / diffusers load does + not read those, so a stale cache holding only e.g. ``model.Q4_K_M.gguf`` does not satisfy the load, + which would then fetch the missing ``model.safetensors`` / ``pytorch_model.bin`` over un-killable Xet. + Trainer / optimizer state (``optimizer.bin``, ...) is excluded by ``_is_loadable_weight_file``.""" + return _is_loadable_weight_file(name) and name.endswith((".safetensors", ".bin")) + + # A VARIANT of a canonical root weight: the variant token sits between the base and the extension / # shard suffix (model.fp16.safetensors, pytorch_model.fp16-00001-of-00002.bin). 
A DEFAULT (no-variant) # load reads the canonical model.safetensors / pytorch_model.bin, NOT these, so a variant-only cache @@ -1024,7 +1034,7 @@ def _has_diffusers_component_weight(snapshot_dir: Path, *, ignore_patterns: Any rels: list = [] try: for entry in snapshot_dir.rglob("*"): - if not _is_loadable_weight_file(entry.name): + if not _is_default_load_weight_file(entry.name): continue try: if not entry.is_file(): @@ -1064,8 +1074,8 @@ def _root_model_has_weight(snapshot_dir: Path, *, ignore_patterns: Any = None) - try: for entry in snapshot_dir.iterdir(): name = entry.name - if not _is_loadable_weight_file(name): - continue + if not _is_default_load_weight_file(name): + continue # a default load reads safetensors / bin, not gguf / pt / ... if name.startswith("adapter_"): continue # a PEFT adapter (adapter_model.*) is not read by a default base-model load if _CANONICAL_BASE_VARIANT_RE.match(name): @@ -1095,8 +1105,10 @@ def _root_has_variant_weight( try: for entry in snapshot_dir.iterdir(): name = entry.name - if not _is_loadable_weight_file(name): - continue + if not _is_default_load_weight_file(name): + continue # a default variant load reads safetensors / bin, not gguf / pt / ... + if name.startswith("adapter_"): + continue # a PEFT adapter variant is not read by a default base-model variant load if infix_dot not in name and infix_dash not in name: continue try: @@ -1109,6 +1121,60 @@ def _root_has_variant_weight( return bool(_filter_paths(rels, None, ignore_patterns)) +def _has_diffusers_component_variant_weight( + snapshot_dir: Path, variant: str, *, ignore_patterns: Any = None +) -> bool: + """Variant analog of ``_has_diffusers_component_weight``: True if a diffusers pipeline COMPONENT + subfolder (unet/, vae/, text_encoder/, ...) holds a weight carrying the requested *variant* token + (``unet/diffusion_pytorch_model.fp16.safetensors``). 
A variant pipeline warm's weights are + component-scoped, not root ``model.<variant>.*`` files, so a root-only variant check would + false-reject a complete diffusers variant download into a ``DownloadStallError``. Excludes ROOT-level + and training-checkpoint weights (as the plain component check does) and reads only safetensors / bin.""" + infix_dot = f".{variant}." + infix_dash = f".{variant}-" + rels: list = [] + try: + for entry in snapshot_dir.rglob("*"): + name = entry.name + if not _is_default_load_weight_file(name): + continue + if infix_dot not in name and infix_dash not in name: + continue + try: + if not entry.is_file(): + continue + rel = entry.relative_to(snapshot_dir).as_posix() + except (OSError, ValueError): + continue + parts = rel.split("/") + if len(parts) < 2: + continue # a ROOT-level variant weight is not a pipeline component + if any(_CHECKPOINT_DIR_RE.match(p) for p in parts[:-1]): + continue # under a training-checkpoint subtree, not a component + rels.append(rel) + except OSError: + return False + return bool(_filter_paths(rels, None, ignore_patterns)) + + +def _root_model_has_variant_weight( + snapshot_dir: Path, variant: str, *, ignore_patterns: Any = None +) -> bool: + """Whether an UNPATTERNED variant warm holds a variant weight a default load reads: a ROOT variant + weight, or -- for a diffusers pipeline (root ``model_index.json``) -- a component-subfolder variant + weight.
Variant analog of ``_root_model_has_weight``: a diffusers variant's weights live in component + subfolders, not root ``model.<variant>.*`` files, so the root-only check would false-reject them.""" + try: + is_diffusers = (snapshot_dir / "model_index.json").is_file() + except OSError: + is_diffusers = False + if is_diffusers: + return _has_diffusers_component_variant_weight( + snapshot_dir, variant, ignore_patterns = ignore_patterns + ) + return _root_has_variant_weight(snapshot_dir, variant, ignore_patterns = ignore_patterns) + + # Interchangeable exact weight names collapse to one equivalence group: the either-format pair # ``["pytorch_model.bin", "model.safetensors"]`` is satisfied by ANY one -- and so is the variant pair # ``["model.fp16.safetensors", "pytorch_model.fp16.bin"]`` (HF allow patterns are ALTERNATIVES over the @@ -1295,7 +1361,9 @@ def _has_readable_weight( so the incomplete result is retried over HTTP rather than loaded in-process.""" if variant: if allow_patterns is None: - return _root_has_variant_weight(snapshot_dir, variant, ignore_patterns = ignore_patterns) + return _root_model_has_variant_weight( + snapshot_dir, variant, ignore_patterns = ignore_patterns + ) return _has_selected_variant_weight( snapshot_dir, allow_patterns = allow_patterns, ignore_patterns = ignore_patterns, variant = variant, @@ -1323,11 +1391,18 @@ def _readable_shard_set_incomplete( selects canonical root shards; an exact-named subset or an out-of-scope request does not. - non-root: a PATTERNED request additionally checks any SELECTED shard index the root-model checks do not cover (a sharded adapter under ``['adapter_model*']``, a component subfolder) via - ``_selected_shard_index_incomplete``. An UNPATTERNED request reads only the root model weight, so it - does not; an exact-named subset defers to the exact-file presence check. + ``_selected_shard_index_incomplete``; an exact-named subset defers to the exact-file presence check.
+ - diffusers: an UNPATTERNED plain / variant warm of a pipeline (root ``model_index.json``) reads + COMPONENT subfolders (unet/, vae/, ...), so a component shard index missing a shard -- which the + root-model checks do not cover -- is caught via ``_diffusers_component_shards_incomplete``. A + non-diffusers unpatterned request reads only the root model weight, so it does not sub-check. The ignore filter is threaded through so completeness is judged for the FORMAT the load reads (a complete safetensors set does not mask an incomplete ``.bin`` under ``ignore=['*.safetensors']``).""" + try: + is_diffusers = (snapshot_dir / "model_index.json").is_file() + except OSError: + is_diffusers = False if variant: if allow_patterns is None or _request_selects_root_variant_weight( allow_patterns, ignore_patterns, variant @@ -1336,16 +1411,31 @@ def _readable_shard_set_incomplete( snapshot_dir, variant, ignore_patterns = ignore_patterns ): return True - if allow_patterns is not None and _selected_shard_index_incomplete( - snapshot_dir, allow_patterns = allow_patterns, - ignore_patterns = ignore_patterns, variant = variant, + if allow_patterns is not None: + if _selected_shard_index_incomplete( + snapshot_dir, allow_patterns = allow_patterns, + ignore_patterns = ignore_patterns, variant = variant, + ): + return True + elif is_diffusers and _diffusers_component_shards_incomplete( + snapshot_dir, variant = variant, ignore_patterns = ignore_patterns ): + # an UNPATTERNED variant diffusers warm: a component subfolder's variant shard index is + # incomplete (the root variant check above only covers root model.<variant> shards).
return True return False if allow_patterns is None: - return _has_incomplete_canonical_root_shards( + if _has_incomplete_canonical_root_shards( snapshot_dir, ignore_patterns = ignore_patterns - ) + ): + return True + if is_diffusers and _diffusers_component_shards_incomplete( + snapshot_dir, variant = None, ignore_patterns = ignore_patterns + ): + # an UNPATTERNED plain diffusers warm reads component subfolders (unet/, vae/, ...); a + # component shard index missing a shard is not covered by the canonical ROOT-shard check. + return True + return False if _patterns_are_exact_names(allow_patterns): return False # an exact-named subset defers to the exact-file presence check if _request_selects_canonical_root_shards(allow_patterns, ignore_patterns) and ( From b5e5baa76e1def9554665d4633648043b17def5e Mon Sep 17 00:00:00 2001 From: danielhanchen Date: Wed, 1 Jul 2026 12:00:46 +0000 Subject: [PATCH 66/82] Align canonical acceptance with transformers precedence; scope diffusers checks to declared components Three post-download acceptance corrections so the check matches what a default load actually reads: - Probe the single safetensors weight BEFORE the safetensors index. transformers' local weight-file precedence is single model.safetensors, then the safetensors index, then single pytorch_model.bin, then the bin index (verified against modeling_utils.py). A complete single model.safetensors co-resident with a STALE incomplete model.safetensors.index.json was reported incomplete and looped into a DownloadStallError, even though the load reads the single file. _canonical_root_ weights_complete now mirrors that order exactly, so a complete single weight is never masked by a stale index (and an incomplete PREFERRED safetensors index is still breakage a complete .bin must not mask). - Count only the CANONICAL root model weight for a default warm. 
The root presence check accepted any root safetensors / bin name, so a stale cache holding only a non-canonical root weight (consolidated.safetensors) passed while the base model.safetensors / pytorch_model.bin was missing, and the default load would then fetch it over un-killable Xet. _root_model_has_weight now matches only model.safetensors / pytorch_model.bin (single or numbered shard) via _CANONICAL_ROOT_MODEL_WEIGHT_RE; an adapter, a variant, a gguf, and a consolidated.* are excluded (a default from_pretrained probes only the canonical names). - Scope the diffusers component shard check to model_index.json-declared components. It treated every non-checkpoint subfolder as a component, so a complete pipeline co-resident with a stale UNDECLARED subtree (a leftover controlnet/ with an incomplete shard index the DiffusionPipeline load never reads) was force-failed. _diffusers_component_shards_incomplete now skips subfolders the pipeline does not declare; a malformed / empty model_index.json fails OPEN (every subfolder checked) so hang protection is preserved. Adds regression tests for each, plus a control that an incomplete DECLARED component is still rejected. Verified against the safety-invariant fuzz (0 false-accepts). 
--- tests/test_hf_xet_fallback.py | 79 ++++++++++++++++++++++++++++++++-- unsloth_zoo/hf_cache_state.py | 71 ++++++++++++++++++++---------- unsloth_zoo/hf_xet_fallback.py | 47 ++++++++++---------- 3 files changed, 148 insertions(+), 49 deletions(-) diff --git a/tests/test_hf_xet_fallback.py b/tests/test_hf_xet_fallback.py index a3abf9e02..a48fc050e 100644 --- a/tests/test_hf_xet_fallback.py +++ b/tests/test_hf_xet_fallback.py @@ -2784,6 +2784,77 @@ def test_post_download_rejects_incomplete_diffusers_component_shards_unpatterned variant = "fp16") is True +def test_post_download_single_safetensors_beats_stale_index(tmp_path): + """transformers probes single model.safetensors BEFORE model.safetensors.index.json, so a complete + single weight co-resident with a STALE incomplete index is usable and must not be looped into a + DownloadStallError (Codex #829). A stale index with NO single weight is still breakage.""" + snap, blob = _mk_snapshot(tmp_path, "single_beats_index") + (snap / "config.json").write_text("{}") + (snap / "model.safetensors").symlink_to(blob) + (snap / "model.safetensors.index.json").write_text(json.dumps( + {"weight_map": {"a": "model-00001-of-00002.safetensors", + "b": "model-00002-of-00002.safetensors"}})) # shards absent + assert xf._download_result_usable( + snap, repo_type = "model", allow_patterns = None, ignore_patterns = None) is True + assert hcs.snapshot_dir_is_complete(snap) is True # the PRE gate agrees (offline warm short-circuit) + # No single weight, only the stale index -> the sharded-safetensors load would fetch missing shards. 
+ snap2, _ = _mk_snapshot(tmp_path, "stale_index_only") + (snap2 / "config.json").write_text("{}") + (snap2 / "model.safetensors.index.json").write_text(json.dumps( + {"weight_map": {"a": "model-00001-of-00002.safetensors", + "b": "model-00002-of-00002.safetensors"}})) + assert xf._download_result_usable( + snap2, repo_type = "model", allow_patterns = None, ignore_patterns = None) is False + + +def test_post_download_rejects_noncanonical_root_weight_for_default_load(tmp_path): + """A DEFAULT load probes only the canonical model.safetensors / pytorch_model.bin (single or numbered + shard). A stale cache holding only a NON-canonical root weight (consolidated.safetensors) must be + rejected, else the default load fetches the absent canonical weight over un-killable Xet (Codex #829). + The canonical weight present -> accepted.""" + snap, blob = _mk_snapshot(tmp_path, "noncanonical") + (snap / "config.json").write_text("{}") + (snap / "consolidated.safetensors").symlink_to(blob) + assert xf._download_result_usable( + snap, repo_type = "model", allow_patterns = None, ignore_patterns = None) is False + (snap / "model.safetensors").symlink_to(blob) + assert xf._download_result_usable( + snap, repo_type = "model", allow_patterns = None, ignore_patterns = None) is True + + +def test_diffusers_component_check_scoped_to_declared_components(tmp_path): + """The component shard check is scoped to the components model_index.json declares. 
A complete + pipeline (declared unet+vae present) co-resident with a STALE UNDECLARED subtree (a leftover + controlnet/ with an incomplete shard index the DiffusionPipeline load never reads) must still be + accepted (Codex #829); an incomplete DECLARED component is still rejected (hang protection kept).""" + snap, blob = _mk_snapshot(tmp_path, "declared_scope") + (snap / "model_index.json").write_text(json.dumps( + {"_class_name": "StableDiffusionPipeline", + "unet": ["diffusers", "UNet2DConditionModel"], "vae": ["diffusers", "AutoencoderKL"]})) + for comp in ("unet", "vae"): + (snap / comp).mkdir() + (snap / comp / "config.json").write_text("{}") + (snap / comp / "diffusion_pytorch_model.safetensors").symlink_to(blob) + (snap / "controlnet").mkdir() # UNDECLARED leftover + (snap / "controlnet" / "diffusion_pytorch_model.safetensors.index.json").write_text(json.dumps( + {"weight_map": {"a": "diffusion_pytorch_model-00001-of-00002.safetensors", + "b": "diffusion_pytorch_model-00002-of-00002.safetensors"}})) + (snap / "controlnet" / "diffusion_pytorch_model-00001-of-00002.safetensors").symlink_to(blob) + assert xf._download_result_usable( + snap, repo_type = "model", allow_patterns = None, ignore_patterns = None) is True + # An incomplete DECLARED component (unet index missing a shard) is still caught. 
+ snap2, blob2 = _mk_snapshot(tmp_path, "declared_incomplete") + (snap2 / "model_index.json").write_text(json.dumps( + {"_class_name": "P", "unet": ["diffusers", "UNet2DConditionModel"]})) + (snap2 / "unet").mkdir() + (snap2 / "unet" / "diffusion_pytorch_model.safetensors.index.json").write_text(json.dumps( + {"weight_map": {"a": "diffusion_pytorch_model-00001-of-00002.safetensors", + "b": "diffusion_pytorch_model-00002-of-00002.safetensors"}})) + (snap2 / "unet" / "diffusion_pytorch_model-00001-of-00002.safetensors").symlink_to(blob2) + assert xf._download_result_usable( + snap2, repo_type = "model", allow_patterns = None, ignore_patterns = None) is False + + def test_selected_readable_weight_complete_entry_point(tmp_path): """The weight-bearing acceptance check funnels through one helper enforcing two invariants: (A) a readable weight is present (ignore + scope applied), (B) its in-scope shard set is complete. @@ -3009,7 +3080,8 @@ def test_post_download_rejects_config_only_for_explicit_weight_pattern(tmp_path) def test_post_download_rejects_incomplete_canonical_root_shards(tmp_path): """An interrupted canonical sharded warm (loose model-00001-of-00002.safetensors, no index) has a loadable file but a default load cannot read it and would fetch the rest over un-killable Xet, so - it is rejected. A complete sharded set is accepted; a variant-only shard layout is not force-failed.""" + it is rejected. 
A complete sharded set is accepted; a variant-only shard layout does not satisfy a + default (no-variant) load, which reads only canonical names.""" snap, blob = _mk_snapshot(tmp_path, "incshard") (snap / "config.json").write_text("{}") (snap / "model-00001-of-00002.safetensors").symlink_to(blob) @@ -3022,12 +3094,13 @@ def test_post_download_rejects_incomplete_canonical_root_shards(tmp_path): "b": "model-00002-of-00002.safetensors"}})) assert xf._download_result_usable( snap, repo_type = "model", allow_patterns = None, ignore_patterns = None) is True - # A variant-only shard layout has no canonical shard -> not force-failed here. + # A variant-named shard is NOT a canonical weight a default load reads, so a variant-only cache is + # rejected (the default load would fetch the absent canonical model.safetensors over Xet). vsnap, vblob = _mk_snapshot(tmp_path, "vshard") (vsnap / "config.json").write_text("{}") (vsnap / "model-00001-of-00001.fp16.safetensors").symlink_to(vblob) assert xf._download_result_usable( - vsnap, repo_type = "model", allow_patterns = None, ignore_patterns = None) is True + vsnap, repo_type = "model", allow_patterns = None, ignore_patterns = None) is False def test_local_token_not_found_error_type_preserved(): diff --git a/unsloth_zoo/hf_cache_state.py b/unsloth_zoo/hf_cache_state.py index ff21c7ed0..4a172085c 100644 --- a/unsloth_zoo/hf_cache_state.py +++ b/unsloth_zoo/hf_cache_state.py @@ -508,25 +508,23 @@ def _format_kept(weight_name: str) -> bool: return True return bool(_filter_paths([weight_name], None, ignore_patterns)) - incomplete_preferred_index = False - for index_entry in root_indices: - is_safetensors = ".safetensors.index." 
in index_entry.name - fmt_probe = "model.safetensors" if is_safetensors else "pytorch_model.bin" - if not _format_kept(fmt_probe): - continue # this format is ignored -> the load will not read it - if _weight_shard_index_complete(index_entry): - return True - if is_safetensors: - # transformers probes the safetensors index BEFORE the .bin, so a present-but-incomplete - # safetensors index means the load prefers (and fetches) safetensors -- a complete .bin must - # NOT mask it. Treat it as breakage (defer to the watched child) unless safetensors is - # explicitly ignored (handled by _format_kept above). - incomplete_preferred_index = True - if incomplete_preferred_index: - return False - return any( - name in root_files and _format_kept(name) for name in _CANONICAL_SINGLE_WEIGHTS - ) + st_index = next((e for e in root_indices if ".safetensors.index." in e.name), None) + bin_index = next((e for e in root_indices if ".bin.index." in e.name), None) + # transformers' local weight-file precedence, mirrored exactly: a single model.safetensors is probed + # BEFORE the safetensors index, safetensors before the .bin single, and the .bin single before the + # .bin index. So a complete single weight is never masked by a co-resident stale index, and an + # incomplete PREFERRED (safetensors) index is breakage a complete .bin must not mask (transformers + # takes the safetensors-index branch and does not fall back to .bin). A format the ignore filter + # drops is skipped so the next format the load actually reads is judged. 
+ if "model.safetensors" in root_files and _format_kept("model.safetensors"): + return True + if st_index is not None and _format_kept("model.safetensors"): + return _weight_shard_index_complete(st_index) + if "pytorch_model.bin" in root_files and _format_kept("pytorch_model.bin"): + return True + if bin_index is not None and _format_kept("pytorch_model.bin"): + return _weight_shard_index_complete(bin_index) + return False def snapshot_dir_is_complete( @@ -789,6 +787,31 @@ def _selected_shard_index_incomplete( _CHECKPOINT_DIR_RE = re.compile(r"^checkpoint[-_]\d+$") +def _diffusers_declared_components(snapshot_dir: Path) -> "Optional[set]": + """The component subfolder names a diffusers ``model_index.json`` declares (top-level keys mapping to + a ``[library, class]`` list; ``_``-prefixed metadata keys excluded). None when the file is absent / + unreadable / malformed, so the caller falls back to treating every subfolder as a component (fail + OPEN, preserving hang protection). Scopes the component shard check to what the pipeline actually + reads, so a co-resident stale UNDECLARED subtree (a leftover adapter / controlnet dir the + ``DiffusionPipeline`` load never reads) cannot force-fail a complete pipeline download.""" + import json + + try: + with open(snapshot_dir / "model_index.json", "r", encoding = "utf-8") as f: + data = json.load(f) + except (OSError, ValueError): + return None + if not isinstance(data, dict): + return None + components = { + key for key, value in data.items() + if not key.startswith("_") and isinstance(value, (list, tuple)) + } + # A real pipeline always declares components; an empty / all-metadata model_index.json is degenerate + # or malformed -> fail OPEN (None) so the caller checks every subfolder, preserving hang protection. 
+ return components or None + + def _diffusers_component_shards_incomplete( snapshot_dir: Path, *, variant: "Optional[str]" = None, ignore_patterns: "Optional[object]" = None, @@ -797,14 +820,16 @@ def _diffusers_component_shards_incomplete( weight-shard INDEX of the read variant that lists a shard that is absent (or the index is malformed) -- an interrupted component pull the in-process pipeline load would finish over un-killable Xet. - Scoped so a complete pipeline is never false-rejected: a ROOT index (owned by the canonical / variant - root-model checks) and a training-checkpoint subtree (checkpoint-N/, never read as a pipeline - component) are skipped, and the request's ignore filter selects the read format. Per directory, + Scoped so a complete pipeline is never false-rejected: the check is limited to the components + ``model_index.json`` declares (a stale UNDECLARED subtree the pipeline load never reads is skipped), + a ROOT index (owned by the canonical / variant root-model checks) and a training-checkpoint subtree + (checkpoint-N/) are skipped, and the request's ignore filter selects the read format. Per directory, safetensors is read before bin, so only the preferred format's set must be complete. A plain load reads canonical component indices (token None); a variant load reads variant ones. 
Positive-evidence: a single-file component or a complete component shard set is not flagged, so a complete download passes.""" want_variant = variant or None ignore_patterns = _as_pattern_list(ignore_patterns) + declared = _diffusers_declared_components(snapshot_dir) try: entries = list(snapshot_dir.rglob("*")) except OSError: @@ -823,6 +848,8 @@ def _diffusers_component_shards_incomplete( parts = rel.split("/") if len(parts) < 2: continue # a ROOT index -- owned by the canonical / variant root-model checks + if declared is not None and parts[0] not in declared: + continue # an UNDECLARED subtree the DiffusionPipeline load does not read if any(_CHECKPOINT_DIR_RE.match(p) for p in parts[:-1]): continue # a training-checkpoint subtree, not a pipeline component dir_rel = rel.rsplit("/", 1)[0] diff --git a/unsloth_zoo/hf_xet_fallback.py b/unsloth_zoo/hf_xet_fallback.py index cd0d84e78..043f59393 100644 --- a/unsloth_zoo/hf_xet_fallback.py +++ b/unsloth_zoo/hf_xet_fallback.py @@ -1008,14 +1008,14 @@ def _is_default_load_weight_file(name: str) -> bool: return _is_loadable_weight_file(name) and name.endswith((".safetensors", ".bin")) -# A VARIANT of a canonical root weight: the variant token sits between the base and the extension / -# shard suffix (model.fp16.safetensors, pytorch_model.fp16-00001-of-00002.bin). A DEFAULT (no-variant) -# load reads the canonical model.safetensors / pytorch_model.bin, NOT these, so a variant-only cache -# must not satisfy a default load's presence check (else the load fetches the absent canonical weight -# over un-killable Xet). Canonical names (model.safetensors, model-00001-of-00002.safetensors -- a dash, -# not a dotted token) and non-canonical bases (consolidated.*, tf_model.h5) are deliberately NOT matched. 
-_CANONICAL_BASE_VARIANT_RE = re.compile( - r"^(?:model|pytorch_model)\.[^.]+(?:-\d{5}-of-\d{5})?\.(?:safetensors|bin)$" +# The CANONICAL root model weight a DEFAULT (no-variant) load reads: model.safetensors / +# pytorch_model.bin as a single file, or a numbered shard (model-00001-of-00002.safetensors -- a dash, +# not a dotted variant token). A PEFT adapter (adapter_model.*), a variant (model.fp16.safetensors), a +# gguf, and a non-canonical root weight (consolidated.safetensors, tf_model.h5) are NOT matched: a +# default from_pretrained probes only these canonical names, so a cache holding only something else does +# not satisfy the load, which would then fetch the missing canonical weight over un-killable Xet. +_CANONICAL_ROOT_MODEL_WEIGHT_RE = re.compile( + r"^(?:model|pytorch_model)(?:-\d{5}-of-\d{5})?\.(?:safetensors|bin)$" ) # A training-checkpoint subdir (checkpoint-500/, checkpoint_7/): its weights are never read as diffusers @@ -1054,16 +1054,18 @@ def _has_diffusers_component_weight(snapshot_dir: Path, *, ignore_patterns: Any def _root_model_has_weight(snapshot_dir: Path, *, ignore_patterns: Any = None) -> bool: - """Whether an UNPATTERNED model warm holds a weight a default load reads: a canonical ROOT weight, or - -- for a diffusers pipeline (root ``model_index.json``) -- a component-subfolder weight. Counting any - subtree weight (as ``_has_any_weight`` does) would accept a stale checkpoint-only snapshot and then - fetch the root weights over un-killable Xet; diffusers is the one layout whose weights live in - subfolders. A VARIANT-named root weight (``model.fp16.safetensors``) and a PEFT adapter - (``adapter_model.*``) are excluded: a default base-model load reads neither, so a cache holding only - those is retried over HTTP rather than loaded (its base ``model.safetensors`` / ``pytorch_model.bin`` - is still missing). 
The request's ignore filter is applied to the ROOT weights, so an offline-fallback - partial holding only the format the load will NOT read (an ignored ``*.bin`` under a safetensors - request) does not count as a usable weight.""" + """Whether an UNPATTERNED model warm holds a weight a default load reads: a CANONICAL ROOT weight + (``model.safetensors`` / ``pytorch_model.bin``, single or numbered shard), or -- for a diffusers + pipeline (root ``model_index.json``) -- a component-subfolder weight. Counting any subtree weight (as + ``_has_any_weight`` does) would accept a stale checkpoint-only snapshot and then fetch the root + weights over un-killable Xet; diffusers is the one layout whose weights live in subfolders. Only the + canonical names are counted (``_CANONICAL_ROOT_MODEL_WEIGHT_RE``): a VARIANT-named root weight + (``model.fp16.safetensors``), a PEFT adapter (``adapter_model.*``), a gguf, and a NON-canonical root + weight (``consolidated.safetensors``) are excluded, since a default from_pretrained probes only the + canonical names, so a cache holding only something else is retried over HTTP rather than loaded (its + canonical weight is still missing). The request's ignore filter is applied to the ROOT weights, so an + offline-fallback partial holding only the format the load will NOT read (an ignored ``*.bin`` under a + safetensors request) does not count as a usable weight.""" try: is_diffusers = (snapshot_dir / "model_index.json").is_file() except OSError: @@ -1074,12 +1076,9 @@ def _root_model_has_weight(snapshot_dir: Path, *, ignore_patterns: Any = None) - try: for entry in snapshot_dir.iterdir(): name = entry.name - if not _is_default_load_weight_file(name): - continue # a default load reads safetensors / bin, not gguf / pt / ... 
- if name.startswith("adapter_"): - continue # a PEFT adapter (adapter_model.*) is not read by a default base-model load - if _CANONICAL_BASE_VARIANT_RE.match(name): - continue # a variant of a canonical weight is not read by a default (no-variant) load + if not _CANONICAL_ROOT_MODEL_WEIGHT_RE.match(name): + continue # only a canonical model.safetensors / pytorch_model.bin (single or shard) is + # read by a default load -- an adapter, variant, gguf, or consolidated.* is not try: if entry.is_file(): rels.append(name) From 6b5c3d613e9f6194c3896094be0a5ac07c11e2d9 Mon Sep 17 00:00:00 2001 From: danielhanchen Date: Wed, 1 Jul 2026 13:01:41 +0000 Subject: [PATCH 67/82] Reject non-canonical variant names, lone shards without an index, and undeclared diffusers components Three more post-download acceptance gaps where a stale/partial cache passed and the later in-process load would fetch the missing weights over un-killable Xet: - Variant presence counted any name carrying the variant token. A DEFAULT variant load reads the canonical model.{variant}.safetensors / a numbered variant shard transformers writes (variant on the base then sharded: model.{variant}-00001-of-00002.safetensors). A non-canonical sidecar (consolidated.fp16.safetensors) or a name transformers does not read masked the missing canonical variant weight. _root_has_variant_weight now requires _ROOT_MODEL_VARIANT_WEIGHT_RE (the variant analog of the canonical-name restriction already applied to the no-variant path). - A selected numbered shard without its index was accepted. For a patterned non-root request (allow=['adapter_model*'], ['unet/*']) a lone adapter_model-00001-of-00002 / unet/...-00001-of-00002 shard has a loadable-looking file but no index, so the load cannot enumerate the set and would fetch the index + remaining shards outside the watched child. 
_selected_shard_index_incomplete (and the diffusers-component check) now reject a selected numbered shard whose directory has no index of the read format; a complete indexed set is still accepted. - Diffusers weight presence counted undeclared subfolders. An unpatterned pipeline warm treated any component-subfolder weight as proof, so a stale cache holding only an UNDECLARED leftover (controlnet/ not in model_index.json) passed while the declared unet/vae weights were missing. _has_diffusers_component_weight now scopes to model_index.json-declared components (fail-open on a malformed index), matching the scoping already applied to the component shard check. Adds regression tests for each. Verified against the safety-invariant fuzz (0 false-accepts). Not taken: the P2 broken-symlink either-format alternative report -- it needs an exact ['model.safetensors','pytorch_model.bin'] request the loaders do not generate and is a false-reject, not a hang; making the shared broken-symlink utility either-format-group-aware is disproportionate for that narrow case. --- tests/test_hf_xet_fallback.py | 88 ++++++++++++++++++++++++++- unsloth_zoo/hf_cache_state.py | 106 ++++++++++++++++++++++++--------- unsloth_zoo/hf_xet_fallback.py | 46 ++++++++------ 3 files changed, 194 insertions(+), 46 deletions(-) diff --git a/tests/test_hf_xet_fallback.py b/tests/test_hf_xet_fallback.py index a48fc050e..ede61d63e 100644 --- a/tests/test_hf_xet_fallback.py +++ b/tests/test_hf_xet_fallback.py @@ -2373,14 +2373,25 @@ def test_post_download_rejects_patterned_canonical_only_for_variant(tmp_path): assert xf._download_result_usable( snap, repo_type = "model", allow_patterns = ["weights/*"], ignore_patterns = None, variant = "fp16") is True - # A sharded in-scope variant weight (dash infix) is likewise accepted. + # A COMPLETE sharded in-scope variant weight (dash infix + its variant index) is accepted. 
snap2, blob2 = _mk_snapshot(tmp_path, "subvarshard") sub2 = snap2 / "weights" sub2.mkdir() (sub2 / "model.fp16-00001-of-00002.safetensors").symlink_to(blob2) + (sub2 / "model.fp16-00002-of-00002.safetensors").symlink_to(blob2) + (sub2 / "model.safetensors.index.fp16.json").write_text(json.dumps( + {"weight_map": {"a": "model.fp16-00001-of-00002.safetensors", + "b": "model.fp16-00002-of-00002.safetensors"}})) assert xf._download_result_usable( snap2, repo_type = "model", allow_patterns = ["weights/*"], ignore_patterns = None, variant = "fp16") is True + # A LONE variant shard with no index is an incomplete set the load cannot enumerate -> rejected. + snap2b, blob2b = _mk_snapshot(tmp_path, "subvarshard_lone") + (snap2b / "weights").mkdir() + (snap2b / "weights" / "model.fp16-00001-of-00002.safetensors").symlink_to(blob2b) + assert xf._download_result_usable( + snap2b, repo_type = "model", allow_patterns = ["weights/*"], ignore_patterns = None, + variant = "fp16") is False # An out-of-scope variant weight does NOT satisfy an in-scope variant request. snap3, blob3 = _mk_snapshot(tmp_path, "subvaroos") (snap3 / "model.fp16.safetensors").symlink_to(blob3) # at root, but request scopes to weights/ @@ -2855,6 +2866,81 @@ def test_diffusers_component_check_scoped_to_declared_components(tmp_path): snap2, repo_type = "model", allow_patterns = None, ignore_patterns = None) is False +def test_post_download_variant_presence_requires_canonical_name(tmp_path): + """The unpatterned variant presence check counts only a CANONICAL model variant name a default + variant load reads (model..safetensors, model.-NNNNN-of-NNNNN.safetensors). 
A + non-canonical sidecar (consolidated.fp16.safetensors) or a non-transformers dot-infix shard name + (model-00001-of-00001.fp16.safetensors) must NOT satisfy the request, else the load fetches the + absent model.fp16.safetensors over un-killable Xet (Codex #829).""" + snap, blob = _mk_snapshot(tmp_path, "var_noncanonical") + (snap / "config.json").write_text("{}") + (snap / "consolidated.fp16.safetensors").symlink_to(blob) + assert xf._download_result_usable( + snap, repo_type = "model", allow_patterns = None, ignore_patterns = None, + variant = "fp16") is False + snap_dot, blob_dot = _mk_snapshot(tmp_path, "var_dotinfix") + (snap_dot / "config.json").write_text("{}") + (snap_dot / "model-00001-of-00001.fp16.safetensors").symlink_to(blob_dot) # not a name tf reads + assert xf._download_result_usable( + snap_dot, repo_type = "model", allow_patterns = None, ignore_patterns = None, + variant = "fp16") is False + # The canonical single variant weight -> accepted. + (snap / "model.fp16.safetensors").symlink_to(blob) + assert xf._download_result_usable( + snap, repo_type = "model", allow_patterns = None, ignore_patterns = None, + variant = "fp16") is True + + +def test_post_download_rejects_selected_shard_without_index(tmp_path): + """A SELECTED non-root numbered shard with NO index of the read format is an incomplete set the load + cannot enumerate (it needs the index to list the shards), so it is rejected and retried over HTTP, + else the adapter / component load fetches the index and remaining shards over un-killable Xet + (Codex #829). A complete indexed set is accepted.""" + # A sharded ADAPTER with a lone shard and no index. 
+ snap, blob = _mk_snapshot(tmp_path, "adapter_lone_shard") + (snap / "config.json").write_text("{}") + (snap / "adapter_config.json").write_text("{}") + (snap / "adapter_model-00001-of-00002.safetensors").symlink_to(blob) + assert xf._download_result_usable( + snap, repo_type = "model", allow_patterns = ["adapter_model*", "adapter_config.json"], + ignore_patterns = None) is False + # Complete it with the second shard + index -> accepted. + (snap / "adapter_model-00002-of-00002.safetensors").symlink_to(blob) + (snap / "adapter_model.safetensors.index.json").write_text(json.dumps( + {"weight_map": {"a": "adapter_model-00001-of-00002.safetensors", + "b": "adapter_model-00002-of-00002.safetensors"}})) + assert xf._download_result_usable( + snap, repo_type = "model", allow_patterns = ["adapter_model*", "adapter_config.json"], + ignore_patterns = None) is True + # A component subfolder lone shard (allow=['unet/*']) is likewise rejected. + snap2, blob2 = _mk_snapshot(tmp_path, "unet_lone_shard") + (snap2 / "unet").mkdir() + (snap2 / "unet" / "config.json").write_text("{}") + (snap2 / "unet" / "diffusion_pytorch_model-00001-of-00002.safetensors").symlink_to(blob2) + assert xf._download_result_usable( + snap2, repo_type = "model", allow_patterns = ["unet/*"], ignore_patterns = None) is False + + +def test_post_download_diffusers_presence_scoped_to_declared(tmp_path): + """An UNPATTERNED diffusers pipeline warm counts a component weight as proof only for a DECLARED + component. A stale cache holding only an UNDECLARED leftover (controlnet/ not in model_index.json) + must be rejected, else the pipeline fetches the declared unet/vae weights in-process over Xet + (Codex #829). 
The declared components present -> accepted.""" + snap, blob = _mk_snapshot(tmp_path, "diffusers_undeclared_only") + (snap / "model_index.json").write_text(json.dumps( + {"_class_name": "P", "unet": ["diffusers", "U"], "vae": ["diffusers", "V"]})) + (snap / "controlnet").mkdir() # UNDECLARED + (snap / "controlnet" / "diffusion_pytorch_model.safetensors").symlink_to(blob) + assert xf._download_result_usable( + snap, repo_type = "model", allow_patterns = None, ignore_patterns = None) is False + # The declared components present -> accepted. + for comp in ("unet", "vae"): + (snap / comp).mkdir() + (snap / comp / "diffusion_pytorch_model.safetensors").symlink_to(blob) + assert xf._download_result_usable( + snap, repo_type = "model", allow_patterns = None, ignore_patterns = None) is True + + def test_selected_readable_weight_complete_entry_point(tmp_path): """The weight-bearing acceptance check funnels through one helper enforcing two invariants: (A) a readable weight is present (ignore + scope applied), (B) its in-scope shard set is complete. diff --git a/unsloth_zoo/hf_cache_state.py b/unsloth_zoo/hf_cache_state.py index 4a172085c..1656a32e0 100644 --- a/unsloth_zoo/hf_cache_state.py +++ b/unsloth_zoo/hf_cache_state.py @@ -740,36 +740,64 @@ def _selected_shard_index_incomplete( - precedence: within a directory transformers reads safetensors before bin, so when both a safetensors and a bin index are selected only the safetensors set's completeness is required. - The ROOT canonical / variant MODEL index is skipped -- ``_has_incomplete_canonical_root_shards`` / + Also rejects a SELECTED numbered shard FILE (adapter_model-00001-of-00002.safetensors, + unet/diffusion_pytorch_model-00001-of-00002.safetensors) whose directory has NO index of the read + format: the load enumerates a sharded weight through its index, so a shard set without one is + incomplete and would fetch the index and remaining shards over Xet. 
+ + The ROOT canonical / variant MODEL shard set is skipped -- ``_has_incomplete_canonical_root_shards`` / ``_has_incomplete_variant_root_shards`` own it (with their own precedence handling).""" allow_patterns = _as_pattern_list(allow_patterns) ignore_patterns = _as_pattern_list(ignore_patterns) want_variant = variant or None + if want_variant is None: + shard_file_re = re.compile(r"^[^.]+-\d{5}-of-\d{5}\.(?:safetensors|bin)$") + else: + v = re.escape(want_variant) + shard_file_re = re.compile(rf"^[^.]+\.{v}-\d{{5}}-of-\d{{5}}\.(?:safetensors|bin)$") try: entries = list(snapshot_dir.rglob("*")) except OSError: return False - per_dir: dict = {} # dir_rel -> {"safetensors": [shard_rels, ...], "bin": [...]} + per_dir: dict = {} # dir_rel -> {"safetensors": [shard_rels, ...], "bin": [...]} (from indices) + index_fmts: dict = {} # dir_rel -> {fmt} an index of the read variant is present (non-root-model) + shard_fmts: dict = {} # dir_rel -> {fmt} a SELECTED numbered shard file is present (non-root-model) for entry in entries: name = entry.name - if not _is_weight_shard_index(name) or not _safe_is_file(entry): + if not _safe_is_file(entry): continue - if _index_variant_token(name) != want_variant: - continue # a wrong-variant index the load does not read try: rel = entry.relative_to(snapshot_dir).as_posix() except ValueError: continue dir_rel = rel.rsplit("/", 1)[0] if "/" in rel else "" - if dir_rel == "" and _ROOT_MODEL_SHARD_INDEX_RE.match(name): - continue # the ROOT model index -- owned by the canonical / variant root checks - shard_rels = _index_shard_rel_paths(entry, dir_rel) - if shard_rels is None: - return True # a malformed / non-string index -> defer to the watched child - if not _filter_paths(shard_rels, allow_patterns, ignore_patterns): - continue # the load does not read this set (out of scope / ignored format) - fmt = "safetensors" if ".safetensors.index." 
in name else "bin" - per_dir.setdefault(dir_rel, {}).setdefault(fmt, []).append(shard_rels) + if _is_weight_shard_index(name): + if _index_variant_token(name) != want_variant: + continue # a wrong-variant index the load does not read + if dir_rel == "" and _ROOT_MODEL_SHARD_INDEX_RE.match(name): + continue # the ROOT model index -- owned by the canonical / variant root checks + fmt = "safetensors" if ".safetensors.index." in name else "bin" + index_fmts.setdefault(dir_rel, set()).add(fmt) + shard_rels = _index_shard_rel_paths(entry, dir_rel) + if shard_rels is None: + return True # a malformed / non-string index -> defer to the watched child + if not _filter_paths(shard_rels, allow_patterns, ignore_patterns): + continue # the load does not read this set (out of scope / ignored format) + per_dir.setdefault(dir_rel, {}).setdefault(fmt, []).append(shard_rels) + elif shard_file_re.match(name): + # a numbered weight shard FILE of the read variant. Skip the ROOT model shard set (owned by + # the canonical / variant root-shard checks) and any training-checkpoint subtree. + if dir_rel == "" and ( + (want_variant is None and _CANONICAL_ROOT_SHARD_RE.match(name)) + or (want_variant is not None and _ROOT_MODEL_VARIANT_WEIGHT_RE.match(name)) + ): + continue + if any(_CHECKPOINT_DIR_RE.match(p) for p in rel.split("/")[:-1]): + continue + if not _filter_paths([rel], allow_patterns, ignore_patterns): + continue # the load does not read this shard (out of scope / ignored format) + fmt = "safetensors" if name.endswith(".safetensors") else "bin" + shard_fmts.setdefault(dir_rel, set()).add(fmt) for by_fmt in per_dir.values(): # safetensors read before bin: require only the preferred format present in this directory. 
for shard_rels in by_fmt.get("safetensors") or by_fmt.get("bin") or []: @@ -779,6 +807,12 @@ def _selected_shard_index_incomplete( return True except OSError: return True + for dir_rel, fmts in shard_fmts.items(): + # a numbered shard of the read (preferred) format with NO index in its directory: the load cannot + # enumerate the set and would fetch the index + remaining shards over Xet. + preferred = "safetensors" if "safetensors" in fmts else "bin" + if preferred not in index_fmts.get(dir_rel, set()): + return True return False @@ -825,41 +859,55 @@ def _diffusers_component_shards_incomplete( a ROOT index (owned by the canonical / variant root-model checks) and a training-checkpoint subtree (checkpoint-N/) are skipped, and the request's ignore filter selects the read format. Per directory, safetensors is read before bin, so only the preferred format's set must be complete. A plain load - reads canonical component indices (token None); a variant load reads variant ones. Positive-evidence: - a single-file component or a complete component shard set is not flagged, so a complete download passes.""" + reads canonical component indices (token None); a variant load reads variant ones. Also rejects a + component holding a numbered shard FILE with NO index of the read format (the pipeline cannot + enumerate the set and would fetch the index + remaining shards over Xet). 
Positive-evidence: a + single-file component or a complete component shard set is not flagged, so a complete download passes.""" want_variant = variant or None ignore_patterns = _as_pattern_list(ignore_patterns) declared = _diffusers_declared_components(snapshot_dir) + if want_variant is None: + shard_file_re = re.compile(r"^[^.]+-\d{5}-of-\d{5}\.(?:safetensors|bin)$") + else: + v = re.escape(want_variant) + shard_file_re = re.compile(rf"^[^.]+\.{v}-\d{{5}}-of-\d{{5}}\.(?:safetensors|bin)$") try: entries = list(snapshot_dir.rglob("*")) except OSError: return False per_dir: dict = {} + index_fmts: dict = {} # component dir_rel -> {fmt} an index of the read variant is present + shard_fmts: dict = {} # component dir_rel -> {fmt} a numbered shard file (ignore-kept) is present for entry in entries: name = entry.name - if not _is_weight_shard_index(name) or not _safe_is_file(entry): + if not _safe_is_file(entry): continue - if _index_variant_token(name) != want_variant: - continue # a wrong-variant index the load does not read try: rel = entry.relative_to(snapshot_dir).as_posix() except ValueError: continue parts = rel.split("/") if len(parts) < 2: - continue # a ROOT index -- owned by the canonical / variant root-model checks + continue # a ROOT file -- owned by the canonical / variant root-model checks if declared is not None and parts[0] not in declared: continue # an UNDECLARED subtree the DiffusionPipeline load does not read if any(_CHECKPOINT_DIR_RE.match(p) for p in parts[:-1]): continue # a training-checkpoint subtree, not a pipeline component dir_rel = rel.rsplit("/", 1)[0] - shard_rels = _index_shard_rel_paths(entry, dir_rel) - if shard_rels is None: - return True # a malformed / non-string index -> defer to the watched child - if not _filter_paths(shard_rels, None, ignore_patterns): - continue # the load does not read this set (ignored format) - fmt = "safetensors" if ".safetensors.index." 
in name else "bin" - per_dir.setdefault(dir_rel, {}).setdefault(fmt, []).append(shard_rels) + if _is_weight_shard_index(name): + if _index_variant_token(name) != want_variant: + continue # a wrong-variant index the load does not read + fmt = "safetensors" if ".safetensors.index." in name else "bin" + index_fmts.setdefault(dir_rel, set()).add(fmt) + shard_rels = _index_shard_rel_paths(entry, dir_rel) + if shard_rels is None: + return True # a malformed / non-string index -> defer to the watched child + if not _filter_paths(shard_rels, None, ignore_patterns): + continue # the load does not read this set (ignored format) + per_dir.setdefault(dir_rel, {}).setdefault(fmt, []).append(shard_rels) + elif shard_file_re.match(name) and _filter_paths([rel], None, ignore_patterns): + fmt = "safetensors" if name.endswith(".safetensors") else "bin" + shard_fmts.setdefault(dir_rel, set()).add(fmt) for by_fmt in per_dir.values(): for shard_rels in by_fmt.get("safetensors") or by_fmt.get("bin") or []: for shard in shard_rels: @@ -868,6 +916,10 @@ def _diffusers_component_shards_incomplete( return True except OSError: return True + for dir_rel, fmts in shard_fmts.items(): + preferred = "safetensors" if "safetensors" in fmts else "bin" + if preferred not in index_fmts.get(dir_rel, set()): + return True # a component numbered shard with no index of the read format return False diff --git a/unsloth_zoo/hf_xet_fallback.py b/unsloth_zoo/hf_xet_fallback.py index 043f59393..bede199b6 100644 --- a/unsloth_zoo/hf_xet_fallback.py +++ b/unsloth_zoo/hf_xet_fallback.py @@ -43,8 +43,10 @@ from unsloth_zoo.hf_cache_state import ( INCOMPLETE_SUFFIX, + _ROOT_MODEL_VARIANT_WEIGHT_RE, _as_pattern_list, _diffusers_component_shards_incomplete, + _diffusers_declared_components, _filter_paths, _has_glob, _has_incomplete_canonical_root_shards, @@ -1024,13 +1026,17 @@ def _is_default_load_weight_file(name: str) -> bool: def _has_diffusers_component_weight(snapshot_dir: Path, *, ignore_patterns: Any = None) 
-> bool: - """True if a diffusers pipeline COMPONENT weight (a loadable weight in a component SUBFOLDER: unet/, - vae/, text_encoder/, ...) that the ignore filter keeps is present. Excludes ROOT-level weights (an - adapter / merged file a ``DiffusionPipeline`` does not read as a component) and training-checkpoint - subtrees (checkpoint-N/), so a stale partial holding only those does not mask the missing component - weights the pipeline reads -- which the in-process load would then fetch over un-killable Xet. Stays - lenient on WHICH components are required (a pipeline's components can be optional): it only tells a - real component warm from a checkpoint-only / config-only stale snapshot.""" + """True if a DECLARED diffusers pipeline COMPONENT weight (a loadable weight in a component SUBFOLDER + the ``model_index.json`` declares: unet/, vae/, text_encoder/, ...) that the ignore filter keeps is + present. Scoped to declared components, so a stale partial holding only an UNDECLARED leftover subtree + (a controlnet/ dir not in ``model_index.json``) does not read as proof the pipeline is warm while the + declared unet / vae weights are still missing -- which the in-process load would then fetch over + un-killable Xet. Also excludes ROOT-level weights (an adapter / merged file a ``DiffusionPipeline`` + does not read as a component) and training-checkpoint subtrees (checkpoint-N/). A malformed / empty + ``model_index.json`` fails OPEN (any component subfolder counts). 
Stays lenient on WHICH declared + components are required (a pipeline's components can be optional): it only tells a real component warm + from an undeclared-leftover / checkpoint-only / config-only stale snapshot.""" + declared = _diffusers_declared_components(snapshot_dir) rels: list = [] try: for entry in snapshot_dir.rglob("*"): @@ -1045,6 +1051,8 @@ def _has_diffusers_component_weight(snapshot_dir: Path, *, ignore_patterns: Any parts = rel.split("/") if len(parts) < 2: continue # a ROOT-level weight is not a pipeline component + if declared is not None and parts[0] not in declared: + continue # an UNDECLARED subtree the DiffusionPipeline load does not read if any(_CHECKPOINT_DIR_RE.match(p) for p in parts[:-1]): continue # under a training-checkpoint subtree, not a component rels.append(rel) @@ -1092,24 +1100,26 @@ def _root_model_has_weight(snapshot_dir: Path, *, ignore_patterns: Any = None) - def _root_has_variant_weight( snapshot_dir: Path, variant: str, *, ignore_patterns: Any = None ) -> bool: - """True if a ROOT weight carrying the requested *variant* token, and kept by the ignore filter, is - present. transformers inserts the variant before the extension (a ``..`` infix: - ``model.fp16.safetensors``) or before a shard suffix (a ``.-`` infix: - ``model.fp16-00001-of-00002.safetensors``), so an offline-fallback partial that kept only the - canonical weight does not satisfy a variant request. The ignore filter is applied so a partial - holding only the ignored format (``model.fp16.bin`` under ``ignore=['*.bin']``) does not count.""" + """True if a CANONICAL ROOT model weight carrying the requested *variant* token, kept by the ignore + filter, is present. transformers writes the variant on the model base then shards it, so the names it + reads are ``model..safetensors`` (single) and ``model.-00001-of-00002.safetensors`` + (a ``.-`` shard infix) -- matched by ``_ROOT_MODEL_VARIANT_WEIGHT_RE`` plus the specific + variant infix. 
A non-canonical base (``consolidated..safetensors``), a PEFT adapter, or a + non-``model`` variant name a default variant load never reads is excluded, so a cache holding only + those is retried over HTTP rather than loaded (its ``model..*`` weight is still missing). The + ignore filter is applied so a partial holding only the ignored format (``model.fp16.bin`` under + ``ignore=['*.bin']``) does not count.""" infix_dot = f".{variant}." infix_dash = f".{variant}-" rels: list = [] try: for entry in snapshot_dir.iterdir(): name = entry.name - if not _is_default_load_weight_file(name): - continue # a default variant load reads safetensors / bin, not gguf / pt / ... - if name.startswith("adapter_"): - continue # a PEFT adapter variant is not read by a default base-model variant load if infix_dot not in name and infix_dash not in name: - continue + continue # not the requested variant token + if not _ROOT_MODEL_VARIANT_WEIGHT_RE.match(name): + continue # only a canonical model / pytorch_model variant weight is read by a default + # variant load -- an adapter, a consolidated.* sidecar, or a gguf is not try: if entry.is_file(): rels.append(name) From f823d37e88c3b45201b0765ef0665c96e063ff10 Mon Sep 17 00:00:00 2001 From: danielhanchen Date: Thu, 2 Jul 2026 05:19:57 +0000 Subject: [PATCH 68/82] Scope diffusers variant presence, order variant single-vs-index, skip root shards for diffusers, scope malformed indexes Five more post-download acceptance gaps from the cache-completeness classifier: two would let a stale/partial cache pass and the later in-process load fetch missing weights over un-killable Xet (hang), three would false-reject a genuinely complete download into a DownloadStallError. - Variant diffusers presence counted UNDECLARED components. 
An unpatterned variant pipeline warm treated any component-subfolder variant weight as proof, so a stale cache holding only an UNDECLARED leftover (controlnet/....fp16.safetensors not in model_index.json) passed while the declared unet/vae variant weights were missing. _has_diffusers_component_variant_weight now scopes to model_index.json-declared components (fail-open on a malformed index), the variant twin of the scoping already applied to the plain component-presence helper. - A single-file variant was masked by a stale variant index. transformers probes the single model.<variant>.safetensors BEFORE model.safetensors.index.<variant>.json (and the single .bin before the .bin variant index), so a complete single variant weight co-resident with a stale incomplete variant index is usable. _has_incomplete_variant_root_shards now judges the single-file variant before its index (with the .bin single tracked symmetrically), mirroring the canonical single-beats-index order; a stale index with no single weight is still breakage. - The root-model shard checks ran on diffusers snapshots. A DiffusionPipeline reads component subfolders, not root model shards, so a complete pipeline co-resident with a stale root model.safetensors.index(.<variant>).json was rejected. The canonical and variant root-shard checks in _readable_shard_set_incomplete are now gated on not is_diffusers; component sets are still checked (unpatterned via _diffusers_component_shards_incomplete, patterned via _selected_shard_index_incomplete). - A malformed shard index the request did not select forced a retry. _selected_shard_index_incomplete returned incomplete for any malformed non-root index before checking scope, so a base ['model*'] / subfolder warm with a complete weight and a co-resident stale malformed adapter index was failed. It now defers to the child only when the request actually selects that index; an in-scope malformed index is still breakage. Adds regression tests for each (176 -> 180). 
Verified against the safety-invariant fuzz (0 false-accepts). Not taken: the P2 broad-glob variant root-presence report -- base variant loads pass allow_patterns=None and are already root-scoped via the canonical-name-hardened _root_model_has_variant_weight; the only patterned variant requests the loaders generate (adapter / subfolder / gguf) do not select root variant weights, so requiring a root variant weight whenever a filter could select one would over-assume base-load intent and false-reject legitimate broad-glob adapter / diffusers variant downloads. --- tests/test_hf_xet_fallback.py | 120 +++++++++++++++++++++++++++++++++ unsloth_zoo/hf_cache_state.py | 33 ++++++--- unsloth_zoo/hf_xet_fallback.py | 60 +++++++++++------ 3 files changed, 182 insertions(+), 31 deletions(-) diff --git a/tests/test_hf_xet_fallback.py b/tests/test_hf_xet_fallback.py index ede61d63e..843377690 100644 --- a/tests/test_hf_xet_fallback.py +++ b/tests/test_hf_xet_fallback.py @@ -2941,6 +2941,126 @@ def test_post_download_diffusers_presence_scoped_to_declared(tmp_path): snap, repo_type = "model", allow_patterns = None, ignore_patterns = None) is True +def test_post_download_diffusers_variant_presence_scoped_to_declared(tmp_path): + """Variant twin of the declared-scope check: an UNPATTERNED diffusers VARIANT warm counts a component + variant weight as proof only for a DECLARED component. 
A stale cache holding only an UNDECLARED variant + leftover (controlnet/....fp16.safetensors not in model_index.json) must be rejected, else the pipeline + fetches the declared unet/vae variant weights in-process over un-killable Xet (Codex #829).""" + snap, blob = _mk_snapshot(tmp_path, "diffusers_variant_undeclared_only") + (snap / "model_index.json").write_text(json.dumps( + {"_class_name": "P", "unet": ["diffusers", "U"], "vae": ["diffusers", "V"]})) + (snap / "controlnet").mkdir() # UNDECLARED + (snap / "controlnet" / "diffusion_pytorch_model.fp16.safetensors").symlink_to(blob) + assert xf._download_result_usable( + snap, repo_type = "model", allow_patterns = None, ignore_patterns = None, + variant = "fp16") is False + # The declared component variant weights present -> accepted. + for comp in ("unet", "vae"): + (snap / comp).mkdir() + (snap / comp / "diffusion_pytorch_model.fp16.safetensors").symlink_to(blob) + assert xf._download_result_usable( + snap, repo_type = "model", allow_patterns = None, ignore_patterns = None, + variant = "fp16") is True + + +def test_post_download_single_variant_beats_stale_variant_index(tmp_path): + """Variant twin of single-beats-index: transformers probes single model..safetensors BEFORE + model.safetensors.index..json, so a complete single variant weight co-resident with a STALE + incomplete variant index is usable and must not be looped into a DownloadStallError (Codex #829). 
Same + for a single .bin variant vs a stale .bin variant index; a stale variant index with NO single weight is + still breakage.""" + snap, blob = _mk_snapshot(tmp_path, "single_variant_beats_index") + (snap / "config.json").write_text("{}") + (snap / "model.fp16.safetensors").symlink_to(blob) + (snap / "model.safetensors.index.fp16.json").write_text(json.dumps( + {"weight_map": {"a": "model.fp16-00001-of-00002.safetensors", + "b": "model.fp16-00002-of-00002.safetensors"}})) # shards absent + assert xf._download_result_usable( + snap, repo_type = "model", allow_patterns = None, ignore_patterns = None, + variant = "fp16") is True + # A single .bin variant co-resident with a stale .bin variant index (no safetensors) -> usable. + snapb, blobb = _mk_snapshot(tmp_path, "single_bin_variant_beats_index") + (snapb / "config.json").write_text("{}") + (snapb / "pytorch_model.fp16.bin").symlink_to(blobb) + (snapb / "pytorch_model.bin.index.fp16.json").write_text(json.dumps( + {"weight_map": {"a": "pytorch_model.fp16-00001-of-00002.bin"}})) + assert xf._download_result_usable( + snapb, repo_type = "model", allow_patterns = None, ignore_patterns = None, + variant = "fp16") is True + # A stale variant index with NO single variant weight -> the sharded load fetches missing shards. + snap2, _ = _mk_snapshot(tmp_path, "stale_variant_index_only") + (snap2 / "config.json").write_text("{}") + (snap2 / "model.safetensors.index.fp16.json").write_text(json.dumps( + {"weight_map": {"a": "model.fp16-00001-of-00002.safetensors", + "b": "model.fp16-00002-of-00002.safetensors"}})) + assert xf._download_result_usable( + snap2, repo_type = "model", allow_patterns = None, ignore_patterns = None, + variant = "fp16") is False + + +def test_post_download_diffusers_skips_root_model_shard_checks(tmp_path): + """A diffusers pipeline reads COMPONENT subfolders, not root model shards. 
A complete pipeline + (declared unet+vae present) co-resident with a STALE root model shard INDEX -- canonical or variant -- + must be ACCEPTED: the root-model shard check does not apply to a diffusers snapshot, else a valid + pipeline is looped into a DownloadStallError (Codex #829). Component completeness is still enforced.""" + # Plain: stale root model.safetensors.index.json alongside complete components. + snap, blob = _mk_snapshot(tmp_path, "diffusers_stale_root_index_plain") + (snap / "model_index.json").write_text(json.dumps( + {"_class_name": "P", "unet": ["diffusers", "U"], "vae": ["diffusers", "V"]})) + for comp in ("unet", "vae"): + (snap / comp).mkdir() + (snap / comp / "diffusion_pytorch_model.safetensors").symlink_to(blob) + (snap / "model.safetensors.index.json").write_text(json.dumps( + {"weight_map": {"a": "model-00001-of-00002.safetensors", + "b": "model-00002-of-00002.safetensors"}})) # shards absent (stale) + assert xf._download_result_usable( + snap, repo_type = "model", allow_patterns = None, ignore_patterns = None) is True + # Variant: stale root model.safetensors.index.fp16.json alongside complete variant components. + snapv, blobv = _mk_snapshot(tmp_path, "diffusers_stale_root_index_variant") + (snapv / "model_index.json").write_text(json.dumps( + {"_class_name": "P", "unet": ["diffusers", "U"], "vae": ["diffusers", "V"]})) + for comp in ("unet", "vae"): + (snapv / comp).mkdir() + (snapv / comp / "diffusion_pytorch_model.fp16.safetensors").symlink_to(blobv) + (snapv / "model.safetensors.index.fp16.json").write_text(json.dumps( + {"weight_map": {"a": "model.fp16-00001-of-00002.safetensors", + "b": "model.fp16-00002-of-00002.safetensors"}})) + assert xf._download_result_usable( + snapv, repo_type = "model", allow_patterns = None, ignore_patterns = None, + variant = "fp16") is True + # Hang protection kept: an incomplete DECLARED component is still rejected. 
+ (snapv / "unet" / "diffusion_pytorch_model.fp16.safetensors").unlink() + (snapv / "unet" / "diffusion_pytorch_model.fp16-00001-of-00002.safetensors").symlink_to(blobv) + (snapv / "unet" / "diffusion_pytorch_model.safetensors.index.fp16.json").write_text(json.dumps( + {"weight_map": {"a": "diffusion_pytorch_model.fp16-00001-of-00002.safetensors", + "b": "diffusion_pytorch_model.fp16-00002-of-00002.safetensors"}})) + assert xf._download_result_usable( + snapv, repo_type = "model", allow_patterns = None, ignore_patterns = None, + variant = "fp16") is False + + +def test_post_download_out_of_scope_malformed_index_not_rejected(tmp_path): + """A malformed shard index the REQUEST does not select is not read by the load, so it must not + false-reject a complete in-scope download into a DownloadStallError (Codex #829). A base ['model*'] warm + with a complete model.safetensors and a co-resident stale MALFORMED adapter index is accepted; an + IN-scope malformed index (an adapter warm) is still breakage.""" + snap, blob = _mk_snapshot(tmp_path, "malformed_out_of_scope") + (snap / "config.json").write_text("{}") + (snap / "model.safetensors").symlink_to(blob) + (snap / "adapter_model.safetensors.index.json").write_text("{ not valid json") # malformed, unselected + assert xf._download_result_usable( + snap, repo_type = "model", allow_patterns = ["model*"], ignore_patterns = None) is True + # An IN-scope malformed index (adapter warm selects adapter_model*) is still rejected. 
+ snap2, blob2 = _mk_snapshot(tmp_path, "malformed_in_scope") + (snap2 / "config.json").write_text("{}") + (snap2 / "adapter_config.json").write_text("{}") + (snap2 / "adapter_model-00001-of-00002.safetensors").symlink_to(blob2) + (snap2 / "adapter_model.safetensors.index.json").write_text("{ not valid json") + assert xf._download_result_usable( + snap2, repo_type = "model", allow_patterns = ["adapter_model*", "adapter_config.json"], + ignore_patterns = None) is False + + def test_selected_readable_weight_complete_entry_point(tmp_path): """The weight-bearing acceptance check funnels through one helper enforcing two invariants: (A) a readable weight is present (ignore + scope applied), (B) its in-scope shard set is complete. diff --git a/unsloth_zoo/hf_cache_state.py b/unsloth_zoo/hf_cache_state.py index 1656a32e0..9207cdf9f 100644 --- a/unsloth_zoo/hf_cache_state.py +++ b/unsloth_zoo/hf_cache_state.py @@ -626,7 +626,7 @@ def _format_kept(weight_name: str) -> bool: st_index_incomplete = None # tri-state: None absent, else present & (in)complete bin_index_incomplete = None has_st_shard = has_bin_shard = False - has_single_st = False + has_single_st = has_single_bin = False for entry in entries: name = entry.name # Restrict to the ROOT model index (model.safetensors.index..json / @@ -651,18 +651,25 @@ def _format_kept(weight_name: str) -> bool: else: has_bin_shard = True elif dot_infix in name and _ROOT_MODEL_VARIANT_WEIGHT_RE.match(name): - # a single-file ROOT model variant weight; only a safetensors single-file matters for - # precedence (a single-file bin variant is complete and handled by the fall-through). - if name.endswith(".safetensors") and _safe_is_file(entry) and _format_kept(name): - has_single_st = True - # transformers reads safetensors before bin: judge the safetensors variant first, and fall to bin - # only when no safetensors variant is present in any form. + # a single-file ROOT model variant weight (model..safetensors / .bin). 
+ if _safe_is_file(entry) and _format_kept(name): + if name.endswith(".safetensors"): + has_single_st = True + else: + has_single_bin = True + # transformers' local precedence, mirrored: a single-file model..safetensors is probed + # BEFORE the safetensors index, safetensors before .bin, and the single .bin before the .bin index. + # So a complete single-file variant is never masked by a co-resident stale index (that would force a + # spurious HTTP retry and DownloadStallError on a usable cache), and an incomplete PREFERRED + # (safetensors) index is still breakage a complete .bin must not mask. + if has_single_st: + return False # a complete single-file safetensors variant, probed before the index if st_index_incomplete is not None: return st_index_incomplete if has_st_shard: return True # variant safetensors shard files with no index -> incomplete - if has_single_st: - return False # a complete single-file safetensors variant + if has_single_bin: + return False # a complete single-file bin variant, probed before the .bin index if bin_index_incomplete is not None: return bin_index_incomplete if has_bin_shard: @@ -780,7 +787,13 @@ def _selected_shard_index_incomplete( index_fmts.setdefault(dir_rel, set()).add(fmt) shard_rels = _index_shard_rel_paths(entry, dir_rel) if shard_rels is None: - return True # a malformed / non-string index -> defer to the watched child + # A malformed / non-string index. Defer to the watched child only when the REQUEST + # selects this index (the load would read it to enumerate its shards); a co-resident + # stale malformed index the request does NOT select (a leftover adapter index under a + # base ['model*'] / subfolder warm) is not read, so it must not force a spurious retry. 
+ if _filter_paths([rel], allow_patterns, ignore_patterns): + return True + continue if not _filter_paths(shard_rels, allow_patterns, ignore_patterns): continue # the load does not read this set (out of scope / ignored format) per_dir.setdefault(dir_rel, {}).setdefault(fmt, []).append(shard_rels) diff --git a/unsloth_zoo/hf_xet_fallback.py b/unsloth_zoo/hf_xet_fallback.py index bede199b6..82d9b8f24 100644 --- a/unsloth_zoo/hf_xet_fallback.py +++ b/unsloth_zoo/hf_xet_fallback.py @@ -1133,12 +1133,19 @@ def _root_has_variant_weight( def _has_diffusers_component_variant_weight( snapshot_dir: Path, variant: str, *, ignore_patterns: Any = None ) -> bool: - """Variant analog of ``_has_diffusers_component_weight``: True if a diffusers pipeline COMPONENT - subfolder (unet/, vae/, text_encoder/, ...) holds a weight carrying the requested *variant* token - (``unet/diffusion_pytorch_model.fp16.safetensors``). A variant pipeline warm's weights are - component-scoped, not root ``model..*`` files, so a root-only variant check would - false-reject a complete diffusers variant download into a ``DownloadStallError``. Excludes ROOT-level - and training-checkpoint weights (as the plain component check does) and reads only safetensors / bin.""" + """Variant analog of ``_has_diffusers_component_weight``: True if a DECLARED diffusers pipeline + COMPONENT subfolder (unet/, vae/, text_encoder/, ... that ``model_index.json`` declares) holds a + weight carrying the requested *variant* token (``unet/diffusion_pytorch_model.fp16.safetensors``). A + variant pipeline warm's weights are component-scoped, not root ``model..*`` files, so a + root-only variant check would false-reject a complete diffusers variant download into a + ``DownloadStallError``. 
Scoped to declared components (as the plain component helper is), so a stale + partial holding only an UNDECLARED leftover variant weight (a ``controlnet/`` dir not in + ``model_index.json``) does not read as proof the pipeline is warm while the declared unet / vae + variant weights are still missing -- which ``DiffusionPipeline.from_pretrained(..., variant=...)`` + would then fetch over un-killable Xet. A malformed / empty ``model_index.json`` fails OPEN. Excludes + ROOT-level and training-checkpoint weights (as the plain component check does) and reads only + safetensors / bin.""" + declared = _diffusers_declared_components(snapshot_dir) infix_dot = f".{variant}." infix_dash = f".{variant}-" rels: list = [] @@ -1158,6 +1165,8 @@ def _has_diffusers_component_variant_weight( parts = rel.split("/") if len(parts) < 2: continue # a ROOT-level variant weight is not a pipeline component + if declared is not None and parts[0] not in declared: + continue # an UNDECLARED subtree the DiffusionPipeline load does not read if any(_CHECKPOINT_DIR_RE.match(p) for p in parts[:-1]): continue # under a training-checkpoint subtree, not a component rels.append(rel) @@ -1393,18 +1402,19 @@ def _readable_shard_set_incomplete( reads (a leftover root checkpoint under a subfolder/adapter/gguf request) does not false-reject a complete download: - - variant: the ROOT variant-shard check applies for an UNPATTERNED request, or a PATTERNED request - that selects a ROOT variant weight (a globbed ``['*.safetensors']``); a subfolder-scoped variant - request does not root-check. - - plain: the canonical-root-shard check applies for an UNPATTERNED request, or a GLOBBED request that - selects canonical root shards; an exact-named subset or an out-of-scope request does not. 
+ - variant: the ROOT variant-shard check applies (for a NON-diffusers snapshot) for an UNPATTERNED + request, or a PATTERNED request that selects a ROOT variant weight (a globbed ``['*.safetensors']``); + a subfolder-scoped variant request does not root-check. + - plain: the canonical-root-shard check applies (for a NON-diffusers snapshot) for an UNPATTERNED + request, or a GLOBBED request that selects canonical root shards; an exact-named subset or an + out-of-scope request does not. - non-root: a PATTERNED request additionally checks any SELECTED shard index the root-model checks do not cover (a sharded adapter under ``['adapter_model*']``, a component subfolder) via ``_selected_shard_index_incomplete``; an exact-named subset defers to the exact-file presence check. - - diffusers: an UNPATTERNED plain / variant warm of a pipeline (root ``model_index.json``) reads - COMPONENT subfolders (unet/, vae/, ...), so a component shard index missing a shard -- which the - root-model checks do not cover -- is caught via ``_diffusers_component_shards_incomplete``. A - non-diffusers unpatterned request reads only the root model weight, so it does not sub-check. + - diffusers: a pipeline (root ``model_index.json``) reads COMPONENT subfolders (unet/, vae/, ...), NOT + root model shards, so the root-model checks above are SKIPPED for it (a stale root index must not + reject a complete pipeline); an UNPATTERNED warm's component shard sets are checked via + ``_diffusers_component_shards_incomplete``, and a PATTERNED one via ``_selected_shard_index_incomplete``. 
The ignore filter is threaded through so completeness is judged for the FORMAT the load reads (a complete safetensors set does not mask an incomplete ``.bin`` under ``ignore=['*.safetensors']``).""" @@ -1413,9 +1423,13 @@ def _readable_shard_set_incomplete( except OSError: is_diffusers = False if variant: - if allow_patterns is None or _request_selects_root_variant_weight( - allow_patterns, ignore_patterns, variant + if not is_diffusers and ( + allow_patterns is None + or _request_selects_root_variant_weight(allow_patterns, ignore_patterns, variant) ): + # A diffusers pipeline reads component-subfolder variant weights, not root model. + # shards, so a stale root variant index must not reject a complete pipeline (handled below by + # the component check); only a non-diffusers root variant load runs the root-shard check. if _has_incomplete_variant_root_shards( snapshot_dir, variant, ignore_patterns = ignore_patterns ): @@ -1434,9 +1448,11 @@ def _readable_shard_set_incomplete( return True return False if allow_patterns is None: - if _has_incomplete_canonical_root_shards( + if not is_diffusers and _has_incomplete_canonical_root_shards( snapshot_dir, ignore_patterns = ignore_patterns ): + # a non-diffusers root model load; a diffusers pipeline reads component subfolders, not root + # model shards, so a stale root index there is handled by the component check below. 
return True if is_diffusers and _diffusers_component_shards_incomplete( snapshot_dir, variant = None, ignore_patterns = ignore_patterns @@ -1447,9 +1463,11 @@ def _readable_shard_set_incomplete( return False if _patterns_are_exact_names(allow_patterns): return False # an exact-named subset defers to the exact-file presence check - if _request_selects_canonical_root_shards(allow_patterns, ignore_patterns) and ( - _has_incomplete_canonical_root_shards(snapshot_dir, ignore_patterns = ignore_patterns) - ): + if not is_diffusers and _request_selects_canonical_root_shards( + allow_patterns, ignore_patterns + ) and _has_incomplete_canonical_root_shards(snapshot_dir, ignore_patterns = ignore_patterns): + # non-diffusers only: a diffusers pipeline never reads root model shards (its component sets are + # checked via _selected_shard_index_incomplete below), so a stale root index must not reject it. return True return _selected_shard_index_incomplete( snapshot_dir, allow_patterns = allow_patterns, From c64a868e402da938b8949c5f6d150b9593216a6b Mon Sep 17 00:00:00 2001 From: danielhanchen Date: Thu, 2 Jul 2026 05:49:42 +0000 Subject: [PATCH 69/82] Harden shard-path validation against Windows drive letters; stop query redaction eating trailing delimiters Two robustness fixes from a static review of the new modules: - Weight-index shard paths (weight_map values from a downloaded, possibly crafted index) were rejected only for a leading / or \ or a .. component. A Windows drive-letter value (C:\x or C:x) slipped through, and on Windows base / "C:\x" resolves OUTSIDE the snapshot, so a crafted index could make the completeness check probe an unrelated existing file and read it as a present shard. 
The two duplicated checks (_weight_shard_index_complete, _index_shard_rel_paths) are centralized into _is_unsafe_shard_ref, which rejects absolute, drive-letter, UNC, and parent-escaping references under BOTH PurePosixPath and PureWindowsPath semantics so a crafted index is rejected regardless of the OS running the check. A well-formed relative index is unaffected. - The signed-URL query redaction matched the query with a greedy [^\s]*, so a presigned URL embedded in structured text (a JSON body, a dict repr) with no surrounding whitespace had its trailing delimiter (the closing "} or )) swallowed and replaced by ***, corrupting the log line. The query now stops at whitespace or a structural delimiter (quote, bracket, brace, paren, angle, pipe); a genuine presigned query percent-encodes those chars, so its redaction is unchanged. Adds regression tests for both (180 -> 182). Verified against the safety-invariant fuzz (0 false-accepts). --- tests/test_hf_xet_fallback.py | 52 ++++++++++++++++++++++++++++++++++ unsloth_zoo/hf_cache_state.py | 36 +++++++++++++++++------ unsloth_zoo/hf_xet_fallback.py | 9 +++++- 3 files changed, 88 insertions(+), 9 deletions(-) diff --git a/tests/test_hf_xet_fallback.py b/tests/test_hf_xet_fallback.py index 843377690..27c463d7d 100644 --- a/tests/test_hf_xet_fallback.py +++ b/tests/test_hf_xet_fallback.py @@ -1171,6 +1171,25 @@ def test_scrub_redacts_presigned_url(): assert "download=true" in plain +def test_scrub_redaction_preserves_surrounding_delimiters(): + """A signed URL embedded in structured text (JSON / dict repr) has no surrounding whitespace, so the + query redaction must stop at the closing delimiter and not swallow it -- else the ``"}`` is replaced + by ``***`` and the log line's structure is corrupted (Gemini #829). 
The signed query is still fully + redacted.""" + embedded = ( + '{"error": "403", "url": ' + '"https://cas-bridge.xethub.hf.co/x/y?X-Amz-Signature=deadbeef&X-Amz-Expires=3600"}' + ) + out = xf._default_scrub_secrets(embedded) + assert "deadbeef" not in out # the signed query is redacted + assert "cas-bridge.xethub.hf.co/x/y?***" in out + assert out.endswith('"}') # the closing delimiters are preserved + # A signed URL wrapped in single quotes / parens keeps those delimiters too. + wrapped = "(https://s3.amazonaws.com/b/k?X-Amz-Signature=abc123) tail" + out2 = xf._default_scrub_secrets(wrapped) + assert "abc123" not in out2 and "?***)" in out2 and out2.endswith(") tail") + + def test_local_files_only_file_resolves_in_process(monkeypatch): """local_files_only resolves the single file from cache in-process and never spawns a network child (Hugging Face offline semantics).""" @@ -3149,6 +3168,39 @@ def test_gate_rejects_malformed_shard_index(tmp_path): assert hcs._weight_shard_index_complete(snap3 / "model.safetensors.index.json") is False +def test_shard_index_rejects_unsafe_path_refs(tmp_path): + """A weight-shard index is attacker-influenced (weight_map from a downloaded repo). An absolute, + Windows drive-letter, UNC, or parent-escaping shard value must be rejected so ``base / shard`` cannot + resolve to an existing file OUTSIDE the snapshot and read as "present" -- on Windows ``base / 'C:\\x'`` + escapes, which a startswith(('/', '\\\\')) check misses (Gemini #829). Both the completeness check and + the shard-path enumerator reject these, judged under POSIX and Windows semantics on any OS.""" + # Unit: the shared helper flags every escape variant and keeps legit relative names. 
+ for bad in ["/etc/passwd", r"C:\evil.safetensors", "C:evil.safetensors", r"\\srv\share\x", + "../../x.safetensors", r"..\x.safetensors", "a/../../b"]: + assert hcs._is_unsafe_shard_ref(bad) is True, bad + for ok in ["model-00001-of-00002.safetensors", "unet/diffusion_pytorch_model.safetensors", + "model.fp16.safetensors"]: + assert hcs._is_unsafe_shard_ref(ok) is False, ok + # A crafted index listing a drive-letter shard is not "complete" (never probes outside the snapshot). + snap, blob = _mk_snapshot(tmp_path, "unsafe_shard_idx") + (snap / "model.safetensors.index.json").write_text(json.dumps( + {"weight_map": {"a": r"C:\Windows\System32\x.safetensors", + "b": "model-00002-of-00002.safetensors"}})) + assert hcs._weight_shard_index_complete(snap / "model.safetensors.index.json") is False + # The enumerator returns None (defer to the child) rather than a path that escapes the snapshot. + assert hcs._index_shard_rel_paths(snap / "model.safetensors.index.json", "") is None + # A well-formed relative index still enumerates + validates normally. + snap2, blob2 = _mk_snapshot(tmp_path, "safe_shard_idx") + (snap2 / "model-00001-of-00002.safetensors").symlink_to(blob2) + (snap2 / "model-00002-of-00002.safetensors").symlink_to(blob2) + (snap2 / "model.safetensors.index.json").write_text(json.dumps( + {"weight_map": {"a": "model-00001-of-00002.safetensors", + "b": "model-00002-of-00002.safetensors"}})) + assert hcs._weight_shard_index_complete(snap2 / "model.safetensors.index.json") is True + assert set(hcs._index_shard_rel_paths(snap2 / "model.safetensors.index.json", "")) == { + "model-00001-of-00002.safetensors", "model-00002-of-00002.safetensors"} + + def test_gate_ignored_canonical_weight_does_not_prove_complete(tmp_path): """Finding 3 (over-accept): a stale canonical weight whose FORMAT the request ignores must not count as proof of completeness. 
ignore=['*.bin'] with only a pytorch_model.bin on disk (no diff --git a/unsloth_zoo/hf_cache_state.py b/unsloth_zoo/hf_cache_state.py index 9207cdf9f..1f859a15e 100644 --- a/unsloth_zoo/hf_cache_state.py +++ b/unsloth_zoo/hf_cache_state.py @@ -29,7 +29,7 @@ import fnmatch import re import sys -from pathlib import Path +from pathlib import Path, PurePosixPath, PureWindowsPath from typing import Iterator, Optional @@ -243,6 +243,25 @@ def _is_canonical_weight_shard_index(name: str) -> bool: return name in ("model.safetensors.index.json", "pytorch_model.bin.index.json") +def _is_unsafe_shard_ref(shard: str) -> bool: + """True if a weight-index ``weight_map`` value is NOT a safe relative path inside the snapshot: an + absolute path, a Windows drive-letter reference (``C:\\x`` / ``C:x``), a UNC path, or a + parent-escaping (``..``) reference. Judged under BOTH POSIX and Windows path semantics so a crafted / + malformed index is rejected regardless of the OS running the check -- on Windows ``base / "C:\\x"`` + resolves OUTSIDE the snapshot and would read as a present shard, and ``startswith(("/", "\\"))`` alone + misses a drive-letter value. A well-formed HF index lists a plain relative basename (or subfolder + path), so a legitimate index is never rejected.""" + if not shard or shard.startswith(("/", "\\")): + return True + win = PureWindowsPath(shard) + if win.is_absolute() or win.drive or ".." in win.parts: + return True + posix = PurePosixPath(shard) + if posix.is_absolute() or ".." in posix.parts: + return True + return False + + def _weight_shard_index_complete(index_path: Path) -> bool: """True only if every shard a HF weight index lists is present next to it. @@ -268,10 +287,10 @@ def _weight_shard_index_complete(index_path: Path) -> bool: shards = set(values) base = index_path.parent for shard in shards: - # A well-formed HF index lists a relative shard basename. 
Reject an absolute / parent-escaping - # value (a malformed or crafted index) rather than let ``base / shard`` resolve to an unrelated - # existing file OUTSIDE the snapshot and read as "present". - if shard.startswith(("/", "\\")) or ".." in shard.replace("\\", "/").split("/"): + # A well-formed HF index lists a relative shard basename. Reject an absolute / drive-letter / + # parent-escaping value (a malformed or crafted index) rather than let ``base / shard`` resolve + # to an unrelated existing file OUTSIDE the snapshot and read as "present". + if _is_unsafe_shard_ref(shard): return False try: if not (base / shard).exists(): @@ -707,8 +726,9 @@ def _index_variant_token(name: str) -> "Optional[str]": def _index_shard_rel_paths(index_path: Path, dir_rel: str) -> "Optional[list]": """The snapshot-relative posix paths of the shards a weight index lists, or None if the index is unreadable / malformed -- mirrors the fail-CLOSED rules of ``_weight_shard_index_complete`` (a - non-dict payload or ``weight_map``, an empty shard set, or a non-string / absolute / parent-escaping - shard value all return None). *dir_rel* is the index's snapshot-relative directory ("" at root), so a + non-dict payload or ``weight_map``, an empty shard set, or a non-string / absolute / drive-letter / + parent-escaping shard value all return None). *dir_rel* is the index's snapshot-relative dir ("" at + root), so a listed basename is joined back to a full repo-relative path for the request filter.""" import json @@ -726,7 +746,7 @@ def _index_shard_rel_paths(index_path: Path, dir_rel: str) -> "Optional[list]": prefix = f"{dir_rel}/" if dir_rel else "" out: list = [] for shard in set(values): - if shard.startswith(("/", "\\")) or ".." 
in shard.replace("\\", "/").split("/"): + if _is_unsafe_shard_ref(shard): return None out.append(f"{prefix}{shard}") return out diff --git a/unsloth_zoo/hf_xet_fallback.py b/unsloth_zoo/hf_xet_fallback.py index 82d9b8f24..d149201e5 100644 --- a/unsloth_zoo/hf_xet_fallback.py +++ b/unsloth_zoo/hf_xet_fallback.py @@ -165,7 +165,14 @@ def _redact_signed_query(match: "re.Match") -> str: return f"{base}?***" return match.group(0) - out = re.sub(r"(https?://[^\s?]+)\?([^\s]*)", _redact_signed_query, out) + # Match the query up to whitespace OR a structural delimiter (quote, bracket, brace, paren, angle, + # pipe): a signed URL embedded in JSON / a dict repr / other structured text has no surrounding + # whitespace, so a greedy [^\s]* would swallow the trailing "} / ") and replace it with ***, + # corrupting the log line. Real signed-query values percent-encode these chars, so the redaction of + # a genuine presigned URL is unaffected. + out = re.sub( + r"(https?://[^\s?]+)\?([^\s\"'()<>{}|[\]]*)", _redact_signed_query, out + ) return out From a4ef8b2c9024ef00e94f2ccde25460808fd5f5e4 Mon Sep 17 00:00:00 2001 From: danielhanchen Date: Thu, 2 Jul 2026 07:25:18 +0000 Subject: [PATCH 70/82] Scope a malformed shard index by its weight FORMAT, not the .json filename The malformed-index branch of the selected-shard and diffusers-component checks deferred to the child whenever the request's allow/ignore filter kept the index's .json name. But ignore is a WEIGHT-format filter: ignore=['*.bin'] does not match pytorch_model.bin.index.json, so a stale/truncated bin index was treated as read and a complete safetensors download (which never reads that bin index) was looped into a DownloadStallError. It is judged now on a representative shard of the index's OWN base name + format (_index_shard_probe) run through the same allow/ignore filter as the well-formed path, so an ignored-format index is correctly skipped while a selected, read-format malformed index still defers. 
The earlier out-of-scope guard (an adapter index under a base ['model*'] warm) is preserved, and an unrecognizable index still defers. Applied to both _selected_shard_index_incomplete and _diffusers_component_shards_incomplete. Adds a regression test (182 -> 183). Verified against the safety-invariant fuzz (0 false-accepts). Not taken (from the same review): - Require ALL declared Diffusers components before accepting: inferring which declared components are weight-bearing needs a static heuristic that false-rejects a complete download whenever a component is disabled at the repo level (safety_checker: [null, null]) or optional, breaking a working download; snapshot_download already guarantees manifest completeness on the normal path, and "at least one declared component weight present" already separates a real pipeline warm from a config-only / checkpoint-only / undeclared-leftover stale cache. - Don't skip the child for mutable revisions: the warm-cache fast path is an intentional offline-capable short-circuit for the common bare from_pretrained (allow=None); gating it on immutable commit hashes would force a per-load child spawn and an offline regression. A mutable revision that moved upstream is resolved by the caller's own load, matching snapshot_download's local-cache resolution semantics. - Match lone shards to their own index stem: the reachable model-vs-adapter co-residence at root is already handled (the root model index is excluded from the index-presence set), so the only gap is two DIFFERENT sharded stems in one selected subfolder (one indexed, one not) -- a layout the loaders do not produce. 
--- tests/test_hf_xet_fallback.py | 34 +++++++++++++++++++++++++++++ unsloth_zoo/hf_cache_state.py | 41 ++++++++++++++++++++++++++++++----- 2 files changed, 69 insertions(+), 6 deletions(-) diff --git a/tests/test_hf_xet_fallback.py b/tests/test_hf_xet_fallback.py index 27c463d7d..6576cbd6c 100644 --- a/tests/test_hf_xet_fallback.py +++ b/tests/test_hf_xet_fallback.py @@ -3201,6 +3201,40 @@ def test_shard_index_rejects_unsafe_path_refs(tmp_path): "model-00001-of-00002.safetensors", "model-00002-of-00002.safetensors"} +def test_malformed_index_scope_honors_ignored_format(tmp_path): + """A malformed shard index is judged by the WEIGHT the load reads (a representative shard of the + index's base + format), not the .json filename. So a stale/truncated index for an IGNORED format + (a *.bin index under ignore=['*.bin']) is skipped -- the load reads safetensors and never touches it, + so a complete safetensors download must not be looped into a DownloadStallError (Codex #829). A + malformed index of the READ format is still breakage.""" + # Patterned subfolder warm reading safetensors: a co-resident malformed bin index is ignored. + snap, blob = _mk_snapshot(tmp_path, "malformed_ignored_bin_idx") + (snap / "unet").mkdir() + (snap / "unet" / "config.json").write_text("{}") + (snap / "unet" / "diffusion_pytorch_model.safetensors").symlink_to(blob) + (snap / "unet" / "diffusion_pytorch_model.bin.index.json").write_text("{ truncated") + assert xf._download_result_usable( + snap, repo_type = "model", allow_patterns = ["unet/*"], ignore_patterns = ["*.bin"]) is True + # The malformed index of the READ format (safetensors, not ignored) is still breakage. 
+ snap2, _ = _mk_snapshot(tmp_path, "malformed_read_st_idx") + (snap2 / "unet").mkdir() + (snap2 / "unet" / "config.json").write_text("{}") + (snap2 / "unet" / "diffusion_pytorch_model.safetensors.index.json").write_text("{ truncated") + assert xf._download_result_usable( + snap2, repo_type = "model", allow_patterns = ["unet/*"], ignore_patterns = None) is False + # Diffusers pipeline: a malformed component bin index under ignore=['*.bin'] does not reject a + # complete safetensors pipeline. + snap3, blob3 = _mk_snapshot(tmp_path, "malformed_diffusers_bin_idx") + (snap3 / "model_index.json").write_text(json.dumps( + {"_class_name": "P", "unet": ["diffusers", "U"], "vae": ["diffusers", "V"]})) + for comp in ("unet", "vae"): + (snap3 / comp).mkdir() + (snap3 / comp / "diffusion_pytorch_model.safetensors").symlink_to(blob3) + (snap3 / "unet" / "diffusion_pytorch_model.bin.index.json").write_text("{ truncated") + assert xf._download_result_usable( + snap3, repo_type = "model", allow_patterns = None, ignore_patterns = ["*.bin"]) is True + + def test_gate_ignored_canonical_weight_does_not_prove_complete(tmp_path): """Finding 3 (over-accept): a stale canonical weight whose FORMAT the request ignores must not count as proof of completeness. ignore=['*.bin'] with only a pytorch_model.bin on disk (no diff --git a/unsloth_zoo/hf_cache_state.py b/unsloth_zoo/hf_cache_state.py index 1f859a15e..cbe742083 100644 --- a/unsloth_zoo/hf_cache_state.py +++ b/unsloth_zoo/hf_cache_state.py @@ -752,6 +752,24 @@ def _index_shard_rel_paths(index_path: Path, dir_rel: str) -> "Optional[list]": return out +def _index_shard_probe(index_name: str, dir_rel: str) -> "Optional[str]": + """A representative numbered-shard path for a weight-shard INDEX whose listed shards are unknown (a + malformed / truncated index): the index's own base name + format as a first shard, joined under + *dir_rel*. 
Lets the malformed-index scope check judge the request's allow / ignore filter on the + WEIGHT the load reads rather than on the ``.json`` index filename -- ``ignore=['*.bin']`` does not + match ``pytorch_model.bin.index.json`` but the load never reads that ignored-format index, so + filtering the filename would wrongly retry a complete other-format download. None when the name is not + a recognizable shard index.""" + for marker, ext in ((".safetensors.index.", "safetensors"), (".bin.index.", "bin")): + if marker in index_name: + base = index_name.split(marker, 1)[0] + if not base: + return None + prefix = f"{dir_rel}/" if dir_rel else "" + return f"{prefix}{base}-00001-of-00002.{ext}" + return None + + def _selected_shard_index_incomplete( snapshot_dir: Path, *, allow_patterns: "Optional[object]", ignore_patterns: "Optional[object]", variant: "Optional[str]", @@ -807,11 +825,15 @@ def _selected_shard_index_incomplete( index_fmts.setdefault(dir_rel, set()).add(fmt) shard_rels = _index_shard_rel_paths(entry, dir_rel) if shard_rels is None: - # A malformed / non-string index. Defer to the watched child only when the REQUEST - # selects this index (the load would read it to enumerate its shards); a co-resident - # stale malformed index the request does NOT select (a leftover adapter index under a - # base ['model*'] / subfolder warm) is not read, so it must not force a spurious retry. - if _filter_paths([rel], allow_patterns, ignore_patterns): + # A malformed / non-string index. Defer to the watched child only when the REQUEST reads + # this index's weight set -- judged on a representative shard of the index's OWN base + + # format (via the allow / ignore filter), not the .json filename. 
So a co-resident stale + # malformed index the request does NOT select (a leftover adapter index under a base + # ['model*'] / subfolder warm) OR one for an IGNORED format (a *.bin index under + # ignore=['*.bin']) is not read and must not force a spurious retry; an unrecognizable + # index defers to the child. + probe = _index_shard_probe(name, dir_rel) + if probe is None or _filter_paths([probe], allow_patterns, ignore_patterns): return True continue if not _filter_paths(shard_rels, allow_patterns, ignore_patterns): @@ -934,7 +956,14 @@ def _diffusers_component_shards_incomplete( index_fmts.setdefault(dir_rel, set()).add(fmt) shard_rels = _index_shard_rel_paths(entry, dir_rel) if shard_rels is None: - return True # a malformed / non-string index -> defer to the watched child + # A malformed / non-string index. Defer only when its FORMAT is read (a representative + # shard of the index's base + format survives the ignore filter); a stale malformed + # index for an IGNORED format (a *.bin component index under ignore=['*.bin']) is not + # read, so it must not force a spurious retry of a complete other-format pipeline. 
+ probe = _index_shard_probe(name, dir_rel) + if probe is None or _filter_paths([probe], None, ignore_patterns): + return True + continue if not _filter_paths(shard_rels, None, ignore_patterns): continue # the load does not read this set (ignored format) per_dir.setdefault(dir_rel, {}).setdefault(fmt, []).append(shard_rels) From 61a629a05085e80ad1b3d2b59775f5585021e61e Mon Sep 17 00:00:00 2001 From: danielhanchen Date: Thu, 2 Jul 2026 09:52:04 +0000 Subject: [PATCH 71/82] Retry a transient LocalEntryNotFoundError over the other transport huggingface_hub raises LocalEntryNotFoundError in two very different situations from a child attempt (local_files_only=False): a genuine offline / uncached miss (outgoing traffic disabled), and -- crucially -- as its WRAPPER around a transient HEAD connection error / timeout for an uncached file, whose message reads "... Please check your connection and try again." The error-classification helper matched the class NAME against the deterministic set before the transient-hint check, so the transient connection wrapper (which a raw ConnectionError with the same message would have retried) was surfaced immediately instead of taking the one HTTP retry that may recover. _is_retryable_download_error now retries a LocalEntryNotFoundError ONLY when its message carries a transient hint (connection / timeout / reset / 5xx ...); a true offline miss (no hint) still falls through to the deterministic set, stays non-retryable, and keeps its type reconstructed across the spawn boundary (LocalEntryNotFoundError remains in _DETERMINISTIC_ERROR_NAMES, so _resolve_exception_class still rebuilds it). Adds a regression test (183 -> 184). Watchdog / spawn / retry-loop mechanism untouched. 
Not taken (same review): - Require ALL declared Diffusers components: a static "weight-bearing component" heuristic false-rejects a complete download whenever a component is disabled at the repo level (safety_checker: [null, null]) or optional -- verified it rejects a real StableDiffusionPipeline that legitimately omits safety_checker weights. POST-download stays lenient; snapshot_download already guarantees manifest completeness on the normal path, and one declared component weight already separates a real pipeline warm from a config-only / checkpoint-only / undeclared-leftover stale cache. - Keep broad globs from satisfying root model warms: the loaders warm a bare model with allow_patterns=None (routed through the canonical-scoped _root_model_has_weight), never a broad ['*.safetensors'] glob; a broad glob's load intent is ambiguous (adapter / subfolder / diffusers are all valid), so requiring a canonical root weight would false-reject those legitimate broad-glob downloads. - License metadata: the AGPL-3.0-only headers on the new files are intentional (per repo convention) and the header already documents the split from the LGPL package; reconciling the package-level license expression is a maintainer decision, out of scope. --- tests/test_hf_xet_fallback.py | 29 +++++++++++++++++++++++++++++ unsloth_zoo/hf_xet_fallback.py | 9 +++++++++ 2 files changed, 38 insertions(+) diff --git a/tests/test_hf_xet_fallback.py b/tests/test_hf_xet_fallback.py index 6576cbd6c..00676f4bb 100644 --- a/tests/test_hf_xet_fallback.py +++ b/tests/test_hf_xet_fallback.py @@ -912,6 +912,35 @@ class _Resp404(Exception): assert f(ValueError("unexpected response payload")) is False +def test_local_entry_not_found_transient_is_retryable(): + """huggingface_hub wraps a TRANSIENT HEAD connection error / timeout for an uncached file as + LocalEntryNotFoundError ('... check your connection and try again'). 
That sub-case must be retryable + (the other transport may recover), while a genuine offline miss ('outgoing traffic has been + disabled') stays deterministic and keeps its reconstructed type across the spawn boundary + (Codex #829).""" + f = xf._is_retryable_download_error + + class LocalEntryNotFoundError(Exception): + pass + + # Transient connection wrapper -> retryable. + transient = LocalEntryNotFoundError( + "An error happened while trying to locate the file on the Hub and we cannot find the " + "requested files in the local cache. Please check your connection and try again." + ) + assert f(transient) is True + timed_out = LocalEntryNotFoundError("Read timed out while fetching metadata") + assert f(timed_out) is True + # Genuine offline miss (no transient hint) -> deterministic, and still type-preserved. + offline = LocalEntryNotFoundError( + "Cannot find the requested files in the disk cache and outgoing traffic has been disabled." + ) + assert f(offline) is False + assert "LocalEntryNotFoundError" in xf._DETERMINISTIC_ERROR_NAMES + cls = xf._resolve_exception_class("LocalEntryNotFoundError") + assert cls is not None and issubclass(cls, BaseException) + + def test_immediate_success_uses_xet_only(monkeypatch): prepared = [] monkeypatch.setattr(xf, "_default_prepare_for_http", lambda *a, **k: prepared.append(a)) diff --git a/unsloth_zoo/hf_xet_fallback.py b/unsloth_zoo/hf_xet_fallback.py index d149201e5..c8adb7325 100644 --- a/unsloth_zoo/hf_xet_fallback.py +++ b/unsloth_zoo/hf_xet_fallback.py @@ -608,6 +608,15 @@ def _is_retryable_download_error(exc: BaseException) -> bool: fail identically. Unknown errors are treated as deterministic, so a real repeatable failure is surfaced rather than looped between transports.""" name = type(exc).__name__ + # huggingface_hub raises LocalEntryNotFoundError BOTH for a genuine offline / uncached miss + # (deterministic) AND as its wrapper around a TRANSIENT HEAD connection error / timeout for an + # uncached file ("... 
Please check your connection and try again"). Retry the transient sub-case + # over the other transport; a true offline miss (no transient hint) falls through to the + # deterministic set below and keeps its reconstructed type. + if name == "LocalEntryNotFoundError" and any( + hint in f"{name}: {exc}".lower() for hint in _TRANSIENT_ERROR_HINTS + ): + return True if name in _DETERMINISTIC_ERROR_NAMES: return False # Disk full / quota: a different transport cannot help. From 26bd3d56f9067f7c91219337efb0610a9b0ccd13 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Fri, 3 Jul 2026 06:16:18 +0000 Subject: [PATCH 72/82] Defer a bin-only cache from the strict pre-download fast path when safetensors is preferred A default transformers load probes model.safetensors BEFORE pytorch_model.bin, but the pre-download fast path treated a cache holding only pytorch_model.bin as complete and handed it back without the killable child. For a repo that ALSO publishes model.safetensors (which the local cache cannot rule out), the in-process load then fetches the preferred model.safetensors over un-killable Xet -- the exact hang this fallback exists to prevent. The strict pre gate (_cache_can_skip_download) now passes prefer_safetensors=True to snapshot_dir_is_complete: a .bin weight satisfies the fast path only when safetensors is IGNORED (use_safetensors=False reads .bin), otherwise a bin-only cache defers to the child, which does the authoritative Hub check. A safetensors cache, a dual-format cache, and a use_safetensors=False bin cache all still fast-path, so the common load is unaffected. The flag is pre-only. The POST shard check (_has_incomplete_canonical_root_shards -> snapshot_dir_is_complete) leaves it False, so a finished bin-only download -- a genuinely bin-only repo -- is still accepted and never false-rejected into a DownloadStallError (POST stays lenient; a bin-only cache is not positive breakage evidence there). Adds a regression test (184 -> 185). 
Verified against the safety-invariant fuzz (0 false-accepts) and the stall->HTTP e2e sim. Not taken (same review): - Require ALL declared Diffusers components: a static weight-bearing-component heuristic false-rejects a complete download when a component is disabled at the repo level (safety_checker: [null, null]) or optional; POST stays lenient and one declared component weight already separates a real warm from a config-only / undeclared-leftover stale cache. - Reject noncanonical weights for broad patterned warms: a base load passes allow_patterns=None (routed through the canonical-scoped _root_model_has_weight, which this change makes safetensors-aware), never a broad ['*'] / ['*.safetensors'] glob; a broad glob's intent is ambiguous (adapter / subfolder / diffusers), so requiring a canonical root weight would false-reject those legitimate broad-glob downloads. --- tests/test_hf_xet_fallback.py | 39 ++++++++++++++++++++++++++++++++++ unsloth_zoo/hf_cache_state.py | 29 ++++++++++++++++++++++--- unsloth_zoo/hf_xet_fallback.py | 5 +++++ 3 files changed, 70 insertions(+), 3 deletions(-) diff --git a/tests/test_hf_xet_fallback.py b/tests/test_hf_xet_fallback.py index 00676f4bb..3e00dbffb 100644 --- a/tests/test_hf_xet_fallback.py +++ b/tests/test_hf_xet_fallback.py @@ -2338,6 +2338,45 @@ def test_pre_download_defers_variant_on_canonical_cache(tmp_path): variant = "fp16") is False +def test_pre_download_defers_bin_only_when_safetensors_preferred(tmp_path): + """A default transformers load probes model.safetensors BEFORE pytorch_model.bin. A cache holding + only pytorch_model.bin cannot prove the repo has no safetensors, so the STRICT pre-download gate must + NOT fast-path it -- else the in-process load fetches the preferred model.safetensors over un-killable + Xet (Codex #829). It still fast-paths when safetensors is IGNORED (use_safetensors=False reads bin), + or when safetensors is present. 
The lenient POST path still accepts a finished bin-only download (a + genuinely bin-only repo), so a good download is never false-rejected.""" + snap, blob = _mk_snapshot(tmp_path, "binonly") + (snap / "config.json").write_text("{}") + (snap / "pytorch_model.bin").symlink_to(blob) + # PRE: safetensors preferred (not ignored) + bin-only -> defer to the child. + assert xf._cache_can_skip_download( + snap, repo_type = "model", allow_patterns = None, ignore_patterns = None) is False + # PRE: use_safetensors=False (safetensors ignored) -> the bin cache fast-paths. + assert xf._cache_can_skip_download( + snap, repo_type = "model", allow_patterns = None, + ignore_patterns = ["*.safetensors", "*.safetensors.index.json"]) is True + # PRE: safetensors present -> fast-path (the common load is unaffected). + (snap / "model.safetensors").symlink_to(blob) + assert xf._cache_can_skip_download( + snap, repo_type = "model", allow_patterns = None, ignore_patterns = None) is True + # POST stays LENIENT: a finished bin-only download is a genuinely bin-only repo -> accepted, not + # looped into a DownloadStallError. + snap2, blob2 = _mk_snapshot(tmp_path, "binonly_post") + (snap2 / "config.json").write_text("{}") + (snap2 / "pytorch_model.bin").symlink_to(blob2) + assert xf._download_result_usable( + snap2, repo_type = "model", allow_patterns = None, ignore_patterns = None) is True + # POST: a sharded bin-only repo is likewise accepted (no false-reject). 
+ snap3, blob3 = _mk_snapshot(tmp_path, "binonly_sharded_post") + (snap3 / "pytorch_model-00001-of-00002.bin").symlink_to(blob3) + (snap3 / "pytorch_model-00002-of-00002.bin").symlink_to(blob3) + (snap3 / "pytorch_model.bin.index.json").write_text(json.dumps( + {"weight_map": {"a": "pytorch_model-00001-of-00002.bin", + "b": "pytorch_model-00002-of-00002.bin"}})) + assert xf._download_result_usable( + snap3, repo_type = "model", allow_patterns = None, ignore_patterns = None) is True + + def test_pre_download_does_not_skip_diffusers_but_post_accepts(tmp_path): """The pre/post asymmetry: a diffusers warm is NOT fast-pathed (spawn the child), but the same complete diffusers result IS accepted post-download (it has component weights), so a good diff --git a/unsloth_zoo/hf_cache_state.py b/unsloth_zoo/hf_cache_state.py index cbe742083..628affbd4 100644 --- a/unsloth_zoo/hf_cache_state.py +++ b/unsloth_zoo/hf_cache_state.py @@ -496,7 +496,8 @@ def request_can_include_weights( def _canonical_root_weights_complete( - snapshot_dir: Path, entries: list, ignore_patterns: "Optional[list]" = None + snapshot_dir: Path, entries: list, ignore_patterns: "Optional[list]" = None, + *, prefer_safetensors: bool = False, ) -> bool: """True iff the snapshot holds a complete canonical ROOT weight set: a root ``model.safetensors`` / ``pytorch_model.bin``, OR a root shard index whose every shard is present. @@ -504,7 +505,16 @@ def _canonical_root_weights_complete( A weight whose FORMAT the ignore filter drops does not count (a stale ``pytorch_model.bin`` under ``ignore=['*.bin']`` is not proof the requested safetensors are on disk). The format probe also - discards a ``pytorch_model.bin.index.json`` whose ``.json`` name would slip the raw filter.""" + discards a ``pytorch_model.bin.index.json`` whose ``.json`` name would slip the raw filter. 
+ + *prefer_safetensors* is set by the STRICT pre-download gate: a default transformers load probes + ``model.safetensors`` BEFORE ``pytorch_model.bin``, so when safetensors is a format the load would + read (not ignored) a bin-only cache cannot be proven complete -- the local cache cannot show the + preferred safetensors is absent remotely, and skipping the child would let the in-process load fetch + it over un-killable Xet. So a ``.bin`` weight then satisfies the gate only when safetensors is + IGNORED (``use_safetensors=False``); otherwise the bin-only cache defers to the child. The lenient + POST path leaves this False: a finished bin-only download is a genuinely bin-only repo and must not + be false-rejected into a ``DownloadStallError``.""" root_files: set = set() root_indices: list = [] for entry in entries: @@ -539,6 +549,11 @@ def _format_kept(weight_name: str) -> bool: return True if st_index is not None and _format_kept("model.safetensors"): return _weight_shard_index_complete(st_index) + if prefer_safetensors and _format_kept("model.safetensors"): + # STRICT pre-download gate: safetensors is preferred (not ignored) but absent from the cache, so + # a default load would fetch model.safetensors over un-killable Xet. A bin-only cache cannot + # prove safetensors is absent remotely -> defer to the watched child rather than fast-path. + return False if "pytorch_model.bin" in root_files and _format_kept("pytorch_model.bin"): return True if bin_index is not None and _format_kept("pytorch_model.bin"): @@ -552,6 +567,7 @@ def snapshot_dir_is_complete( allow_patterns: "Optional[object]" = None, ignore_patterns: "Optional[object]" = None, require_named_weights: bool = False, + prefer_safetensors: bool = False, ) -> bool: """Conservative fast-path gate: True only for an unambiguously complete canonical ROOT model cache, so an in-process load will not fetch a weight. 
True requires: an UNPATTERNED request @@ -560,6 +576,11 @@ def snapshot_dir_is_complete( True risks a silent Xet fetch; a false False only spawns the cheap child. *require_named_weights* is accepted for signature compatibility (a named-weight request is patterned, so never fast-pathed). + *prefer_safetensors* (set by the strict pre-download gate) rejects a bin-only cache when a default + load would prefer safetensors (not ignored): the local cache cannot prove the preferred file is + absent remotely, so fast-pathing it would let the in-process load fetch it over Xet. The POST caller + leaves it False so a genuinely bin-only download is still accepted. + *ignore_patterns* need no eligibility gate: the canonical-weight check below is what the load reads, so an ignore that dropped some format (the common ``*.bin`` / subdir prefetch ignores) cannot make an incomplete cache read complete -- keeping the common warm ``from_pretrained`` cache eligible.""" @@ -575,7 +596,9 @@ def snapshot_dir_is_complete( return False # diffusers needs component reasoning we do not fast-path if snapshot_dir_has_broken_symlinks(snapshot_dir): return False # interrupted blob - return _canonical_root_weights_complete(snapshot_dir, entries, ignore_patterns) + return _canonical_root_weights_complete( + snapshot_dir, entries, ignore_patterns, prefer_safetensors = prefer_safetensors + ) # A canonical numbered root shard: the index sits IMMEDIATELY before the extension (no variant token), diff --git a/unsloth_zoo/hf_xet_fallback.py b/unsloth_zoo/hf_xet_fallback.py index c8adb7325..82f3fd153 100644 --- a/unsloth_zoo/hf_xet_fallback.py +++ b/unsloth_zoo/hf_xet_fallback.py @@ -1367,8 +1367,13 @@ def _cache_can_skip_download( # the variant too). 
if variant: return False + # STRICT: a default load probes model.safetensors before pytorch_model.bin, so a bin-only cache + # for a repo that also publishes safetensors (which the local cache cannot rule out) would fetch + # the preferred safetensors in-process over Xet. prefer_safetensors defers such a cache to the + # child; a use_safetensors=False request (safetensors ignored) still fast-paths its bin cache. return snapshot_dir_is_complete( snapshot_dir, allow_patterns = allow_patterns, ignore_patterns = ignore_patterns, + prefer_safetensors = True, ) # Weightless / non-model: skip only for an intact exact-named subset. A None / glob request cannot # be proven complete from local files, so defer to the child for the manifest compare + resume. From 21fcb5580dad5e8684d6555390ae8e2d3bf7030f Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Fri, 3 Jul 2026 08:10:21 +0000 Subject: [PATCH 73/82] Exclude variant component weights from the plain diffusers post-download check _has_diffusers_component_weight counted ANY loadable weight in a declared component subfolder, including a variant one (unet/diffusion_pytorch_model.fp16.safetensors). So a PLAIN diffusers warm (variant=None) whose silently-stale snapshot_download returned a cache left by a prior variant='fp16' pull passed _download_result_usable, and the in-process DiffusionPipeline load -- which reads the NON-variant diffusion_pytorch_model.safetensors -- then fetched the missing plain weight over un-killable Xet, the hang this fallback exists to prevent. The plain component-presence check now counts only CANONICAL (non-variant) component weights via _CANONICAL_COMPONENT_WEIGHT_RE (a base with no intermediate dotted token, single or numbered shard, safetensors/bin). This mirrors the checks that were already variant-aware: the root check (_CANONICAL_ROOT_MODEL_WEIGHT_RE excludes model.fp16.safetensors) and the plain component SHARD regex in _diffusers_component_shards_incomplete. 
A complete plain pipeline, a pipeline shipping both plain and variant weights, and the variant='fp16' warm of the same cache are all still accepted -- the variant analog (_has_diffusers_component_variant_weight) is untouched, so no complete download is false-rejected into a DownloadStallError. Adds a regression test (185 -> 186). Safety-invariant fuzz stays at 0 false-accepts; stall->HTTP e2e sim stays green. Not taken (same review): - Ignore training checkpoints for patterned model warms (allow=['*'] / ['*.bin']): not loader-reachable -- a base load passes allow_patterns=None (root-scoped, checkpoint subtrees already excluded), never a broad ['*'] / ['*.bin']; the patterned warms are gguf / adapter / subfolder-scoped. - Preserve variant tokens in malformed-index probes (allow=['unet/*fp16*']): not loader-reachable -- a variant diffusers warm passes allow=None or the subfolder ['unet/*'] glob, both of which already flag a malformed variant component index (the probe survives an allow=None or 'unet/*' filter); no loader emits a variant-infix glob. - Validate exact sharded weight requests: not loader-reachable -- every loader warm carries a glob (*.py, ...), so _patterns_are_exact_names is never true for a loader; a direct caller that lists an index but only some of its shards is honoring exactly the files it named. 
--- tests/test_hf_xet_fallback.py | 41 ++++++++++++++++++++++++++++++++++ unsloth_zoo/hf_xet_fallback.py | 19 +++++++++++++++- 2 files changed, 59 insertions(+), 1 deletion(-) diff --git a/tests/test_hf_xet_fallback.py b/tests/test_hf_xet_fallback.py index 3e00dbffb..b35d6b4e8 100644 --- a/tests/test_hf_xet_fallback.py +++ b/tests/test_hf_xet_fallback.py @@ -2489,6 +2489,47 @@ def test_post_download_rejects_patterned_canonical_only_for_variant(tmp_path): variant = "fp16") is False +def test_post_download_rejects_variant_only_diffusers_for_plain_load(tmp_path): + """A PLAIN diffusers warm (variant=None) whose returned partial kept only a prior variant='fp16' + download's component weights (unet/diffusion_pytorch_model.fp16.safetensors) must be rejected: the + plain pipeline load reads the NON-variant name, so accepting it would fetch the missing + diffusion_pytorch_model.safetensors in-process over un-killable Xet (Codex #829). A complete plain + pipeline, and a pipeline shipping both plain + variant, are still accepted; the variant='fp16' warm + of the same variant-only cache stays accepted (the plain restriction does not touch the variant + check).""" + def _mi(**comps): + data = {"_class_name": "StableDiffusionPipeline", "_diffusers_version": "0.21.0"} + data.update(comps) + return json.dumps(data) + + snap, blob = _mk_snapshot(tmp_path, "plainvaronly") + (snap / "model_index.json").write_text( + _mi(unet = ["diffusers", "UNet2DConditionModel"], vae = ["diffusers", "AutoencoderKL"])) + for comp in ("unet", "vae"): + (snap / comp).mkdir() + (snap / comp / "diffusion_pytorch_model.fp16.safetensors").symlink_to(blob) + # plain load: variant-only components do not satisfy it -> retry over HTTP. + assert xf._download_result_usable( + snap, repo_type = "model", allow_patterns = None, ignore_patterns = None, variant = None) is False + # the SAME cache is a complete fp16 warm -> the variant load accepts it (no regression). 
+ assert xf._download_result_usable( + snap, repo_type = "model", allow_patterns = None, ignore_patterns = None, + variant = "fp16") is True + # a COMPLETE plain pipeline (non-variant component weights) is accepted (no false-reject). + snap2, blob2 = _mk_snapshot(tmp_path, "plaincomplete") + (snap2 / "model_index.json").write_text( + _mi(unet = ["diffusers", "UNet2DConditionModel"], vae = ["diffusers", "AutoencoderKL"])) + for comp in ("unet", "vae"): + (snap2 / comp).mkdir() + (snap2 / comp / "diffusion_pytorch_model.safetensors").symlink_to(blob2) + assert xf._download_result_usable( + snap2, repo_type = "model", allow_patterns = None, ignore_patterns = None, variant = None) is True + # a pipeline shipping BOTH plain + fp16 in a component is accepted for a plain load. + (snap2 / "unet" / "diffusion_pytorch_model.fp16.safetensors").symlink_to(blob2) + assert xf._download_result_usable( + snap2, repo_type = "model", allow_patterns = None, ignore_patterns = None, variant = None) is True + + def test_post_download_rejects_incomplete_sharded_glob(tmp_path): """A globbed weight request (allow=['*.safetensors']) whose returned partial has a canonical shard index but is missing a shard must be rejected -- globs get the same shard-completeness check as the diff --git a/unsloth_zoo/hf_xet_fallback.py b/unsloth_zoo/hf_xet_fallback.py index 82f3fd153..6ad2288a5 100644 --- a/unsloth_zoo/hf_xet_fallback.py +++ b/unsloth_zoo/hf_xet_fallback.py @@ -1036,6 +1036,17 @@ def _is_default_load_weight_file(name: str) -> bool: r"^(?:model|pytorch_model)(?:-\d{5}-of-\d{5})?\.(?:safetensors|bin)$" ) +# A CANONICAL (non-variant) diffusers component weight name -- what a PLAIN pipeline load reads inside a +# component subfolder: a base with no intermediate dotted token (diffusion_pytorch_model / model, single +# or numbered shard), safetensors or bin. 
A VARIANT weight (diffusion_pytorch_model.fp16.safetensors) +# carries an extra dotted token before the extension and is EXCLUDED here, so a stale cache left by a +# prior variant='fp16' download does not read as a warm PLAIN pipeline -- the in-process +# DiffusionPipeline load (reading the non-variant name) would otherwise fetch it over un-killable Xet. +# This mirrors the root check (_CANONICAL_ROOT_MODEL_WEIGHT_RE) and the plain component shard regex. +_CANONICAL_COMPONENT_WEIGHT_RE = re.compile( + r"^[^.]+(?:-\d{5}-of-\d{5})?\.(?:safetensors|bin)$" +) + # A training-checkpoint subdir (checkpoint-500/, checkpoint_7/): its weights are never read as diffusers # pipeline COMPONENTS, so they must not mask missing unet/vae/text-encoder weights. _CHECKPOINT_DIR_RE = re.compile(r"^checkpoint[-_]\d+$") @@ -1051,13 +1062,19 @@ def _has_diffusers_component_weight(snapshot_dir: Path, *, ignore_patterns: Any does not read as a component) and training-checkpoint subtrees (checkpoint-N/). A malformed / empty ``model_index.json`` fails OPEN (any component subfolder counts). Stays lenient on WHICH declared components are required (a pipeline's components can be optional): it only tells a real component warm - from an undeclared-leftover / checkpoint-only / config-only stale snapshot.""" + from an undeclared-leftover / checkpoint-only / config-only stale snapshot. 
Counts only CANONICAL + (non-variant) component weights (``_CANONICAL_COMPONENT_WEIGHT_RE``): a variant weight + (``unet/diffusion_pytorch_model.fp16.safetensors`` left by a prior ``variant='fp16'`` warm) is not + what a PLAIN pipeline load reads, so a variant-only stale cache is retried over HTTP rather than + loaded (its non-variant component weight is still missing).""" declared = _diffusers_declared_components(snapshot_dir) rels: list = [] try: for entry in snapshot_dir.rglob("*"): if not _is_default_load_weight_file(entry.name): continue + if not _CANONICAL_COMPONENT_WEIGHT_RE.match(entry.name): + continue # a VARIANT component weight -- a plain load reads the non-variant name try: if not entry.is_file(): continue From 930d02773dfdc9aeb1eec10a338e9d278c1d095f Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Fri, 3 Jul 2026 09:27:31 +0000 Subject: [PATCH 74/82] Accept from_tf/from_flax weights, fix the no-psutil resumed-partial watchdog, and honor exact variant shard requests Three fixes from a consolidated 15-reviewer pass (12x reviewer.py personas + 3 forked reviewers). 1. from_tf / from_flax POST false-reject (HIGH, loader-reachable). A base load with from_tf=True / from_flax=True prefetches with ignore=['*.safetensors','*.bin',...] and reads tf_model.h5 / flax_model.msgpack. _root_model_has_weight recognized a root weight only via _CANONICAL_ROOT_MODEL_WEIGHT_RE (safetensors/bin), so a COMPLETE h5/msgpack download read as "no weight" -> POST rejected it -> HTTP retry -> same h5-only result -> DownloadStallError on a working load. It now also counts a canonical TF/Flax root weight (_CANONICAL_ROOT_TF_FLAX_WEIGHT_RE), gated on _pytorch_root_weight_formats_ignored (BOTH model.safetensors and pytorch_model.bin dropped by the ignore filter -- the from_tf/from_flax signature). A normal load keeps a PyTorch format so the gate is False and behavior is unchanged; a stray leftover h5 under a normal load still does not count. 2. 
Single-file watchdog resumed-partial hang without pid ownership (Med, Windows/macOS without psutil). When a child_pid is supplied but its open files cannot be inspected (no psutil AND no /proc), _measure followed only post-baseline partials; a RESUMED partial reuses a baseline blob-hash name, so a frozen Xet resume was excluded forever and never tripped the stall -- the exact hang the fallback exists to prevent. That branch now falls back to the repo-wide measure (as snapshots use), so a resumed partial is watched. The child_pid=None path (no pid at all) keeps the post-baseline name filter unchanged; the precise open-fds path (psutil / /proc) is untouched. 3. Exact-named variant shard requests false-rejected (Low). The plain branch skipped shard-set completeness for an exact-name allow list, but the variant branch did not, so allow=['model.fp16-00001-of-00002.safetensors'] + variant='fp16' -- a satisfied exact request -- was rejected as an incomplete shard set. The _patterns_are_exact_names escape now sits above both branches, so an exact-named subset (variant or plain) defers to the exact-file presence check. Adds 3 regression tests (186 -> 189). Safety-invariant fuzz stays at 0 false-accepts (PRE stays conservative -- these are all POST leniency / watchdog fixes); stall->HTTP e2e sim stays green. Not taken (same review pass): - Injected prepare_for_http_fn missing owned_incomplete_blobs (flagged by several personas): by design and documented at the call site -- an injected (Studio) hook owns its own marker-based cache accounting and keeps the plain (repo_type, repo_id) signature; the default hook gets the owned set. - Missing HF kwargs (local_dir / max_workers / etag_timeout): the helper is a cache-based prefetch wrapper, not a full snapshot_download drop-in; local_dir bypasses the cache the completeness gate reasons about. Current callers pass only what the prefetch needs. 
- token dropped on the in-process cache probes: both are local_files_only=True (cache-only, no network), so no auth is needed there. - License metadata (LGPL vs AGPL): out of scope per maintainer. - Snapshot repo-wide stall masked by a concurrent same-repo pull in a separate process: documented tradeoff; per-pid snapshot scoping would risk false-stalling a progressing snapshot. --- tests/test_hf_xet_fallback.py | 70 ++++++++++++++++++++++++++++++++++ unsloth_zoo/hf_xet_fallback.py | 64 +++++++++++++++++++++++++------ 2 files changed, 123 insertions(+), 11 deletions(-) diff --git a/tests/test_hf_xet_fallback.py b/tests/test_hf_xet_fallback.py index b35d6b4e8..59252a759 100644 --- a/tests/test_hf_xet_fallback.py +++ b/tests/test_hf_xet_fallback.py @@ -328,6 +328,29 @@ def test_file_watchdog_detects_resumed_baseline_partial(hf_cache): child.wait(timeout = 5) +def test_file_watchdog_resumed_partial_fires_without_pid_ownership(hf_cache, monkeypatch): + """No psutil AND no /proc (native Windows / macOS without psutil): _child_open_incomplete_blobs + returns None, so per-child ownership is unknowable. Excluding baseline names would let a RESUMED + partial (which reuses a baseline blob-hash name) hang forever. 
The None fallback drops to the + repo-wide measure, so a frozen resumed baseline partial still trips the stall instead of hanging.""" + blobs = _blobs_dir(hf_cache) + (blobs / "resumed.incomplete").write_bytes(b"\0" * 4096) # leftover resume, constant (hung) + monkeypatch.setattr(xf, "_child_open_incomplete_blobs", lambda pid: None) # no psutil / no /proc + + calls: list[str] = [] + stop = xf.start_watchdog( + repo_ids = [REPO], on_stall = calls.append, interval = 0.05, stall_timeout = 0.3, + watch_new_partials_only = True, baseline_incomplete_blobs = {"resumed.incomplete"}, + child_pid = 4242, # non-None, but open-file inspection yields None -> repo-wide fallback + ) + try: + assert _wait(lambda: len(calls) >= 1, timeout = 3.0), ( + "watchdog never fired on a resumed baseline partial when pid ownership is unavailable" + ) + finally: + stop.set() + + def test_file_watchdog_pid_scope_ignores_unowned_sibling(hf_cache): """With pid scoping, a sibling partial this child does NOT hold open is ignored even if it grows, so the child's own constant partial still trips the stall.""" @@ -2655,6 +2678,53 @@ def test_post_download_accepts_exact_named_shard_subset(tmp_path): allow_patterns = ["model-00001-of-00002.safetensors"], ignore_patterns = None) is False +def test_post_download_accepts_from_tf_flax_weights(tmp_path): + """A from_tf / from_flax base load ignores BOTH PyTorch formats (ignore=['*.safetensors','*.bin', + ...]) and reads tf_model.h5 / flax_model.msgpack. A COMPLETE such download must be accepted, not + false-rejected into a DownloadStallError because the canonical safetensors/bin check found nothing + (#829 re-review). 
Gated on both PyTorch formats ignored, so a normal load and a stray leftover h5 do + not change.""" + ig = ["*.safetensors", "*.safetensors.index.json", "*.bin", "*.bin.index.json"] + for wt in ("tf_model.h5", "flax_model.msgpack"): + snap, blob = _mk_snapshot(tmp_path, f"tf_{wt}") + (snap / wt).symlink_to(blob) + (snap / "config.json").write_text("{}") + assert xf._download_result_usable( + snap, repo_type = "model", allow_patterns = None, ignore_patterns = ig) is True + # Both PyTorch formats ignored but NO h5/msgpack present -> still rejected (weight missing). + snap, _ = _mk_snapshot(tmp_path, "tf_none") + (snap / "config.json").write_text("{}") + assert xf._download_result_usable( + snap, repo_type = "model", allow_patterns = None, ignore_patterns = ig) is False + # A normal load (PyTorch format NOT ignored) is unchanged: a stray leftover h5 must NOT count as + # the readable weight, so a repo holding only tf_model.h5 is rejected for a default (non-tf) load. + snap, blob = _mk_snapshot(tmp_path, "stray_h5") + (snap / "tf_model.h5").symlink_to(blob) + assert xf._download_result_usable( + snap, repo_type = "model", allow_patterns = None, ignore_patterns = None) is False + + +def test_post_download_accepts_exact_named_variant_shard_subset(tmp_path): + """A caller naming an EXACT variant shard (allow=['model.fp16-00001-of-00002.safetensors'] + + variant='fp16') asked for precisely that file; once present the result is accepted even though its + index / sibling shard is absent. 
The exact-name escape applies to the VARIANT branch too, not only + the plain one, so a satisfied exact variant request is not failed into a DownloadStallError + (#829 re-review).""" + snap, blob = _mk_snapshot(tmp_path, "exact_var_shard") + (snap / "model.fp16-00001-of-00002.safetensors").symlink_to(blob) + assert xf._download_result_usable( + snap, repo_type = "model", + allow_patterns = ["model.fp16-00001-of-00002.safetensors"], ignore_patterns = None, + variant = "fp16") is True + # The exact-named variant shard absent -> still rejected. + snap2, _ = _mk_snapshot(tmp_path, "exact_var_absent") + (snap2 / "config.json").write_text("{}") + assert xf._download_result_usable( + snap2, repo_type = "model", + allow_patterns = ["model.fp16-00001-of-00002.safetensors"], ignore_patterns = None, + variant = "fp16") is False + + def test_post_download_rejects_patterned_incomplete_variant_shards(tmp_path): """A GLOBBED variant request (allow=['*.safetensors'] + variant='fp16') whose partial kept only a lone root variant shard without its index / remaining shards must be rejected too -- the diff --git a/unsloth_zoo/hf_xet_fallback.py b/unsloth_zoo/hf_xet_fallback.py index 6ad2288a5..e2a61a902 100644 --- a/unsloth_zoo/hf_xet_fallback.py +++ b/unsloth_zoo/hf_xet_fallback.py @@ -408,9 +408,19 @@ def _measure() -> Optional[tuple[int, bool]]: # name, excludes siblings). hf_xet holds the .incomplete fd continuously, so an EMPTY # set means the child owns no partial YET (connect / metadata phase), not a sibling's. owned = {name: n for name, n in sizes.items() if name in open_names} - else: - # No psutil / /proc: fall back to following only newly-created (post-baseline) partials. - owned = {name: n for name, n in sizes.items() if name not in baseline} + return (sum(owned.values()), len(owned) > 0) + if child_pid: + # A pid was given but its open files cannot be inspected (no psutil AND no /proc: native + # Windows / macOS without psutil). 
Post-baseline name filtering would EXCLUDE a resumed + # partial that reuses a baseline blob name forever, so a frozen Xet resume never trips the + # watchdog and the hang persists -- defeating the fallback. Fall back to the repo-wide + # measure (as the snapshot path uses): a resumed partial is then watched; a concurrent + # same-repo sibling's progress may mask this child's stall, the accepted snapshot tradeoff. + return get_hf_download_state( + [single_repo_id], repo_type = repo_type, cache_dir = cache_dir + ) + # No child pid at all: follow only newly-created (post-baseline) partials. + owned = {name: n for name, n in sizes.items() if name not in baseline} return (sum(owned.values()), len(owned) > 0) return get_hf_download_state(repo_ids, repo_type = repo_type, cache_dir = cache_dir) @@ -1047,6 +1057,24 @@ def _is_default_load_weight_file(name: str) -> bool: r"^[^.]+(?:-\d{5}-of-\d{5})?\.(?:safetensors|bin)$" ) +# A CANONICAL root TF / Flax weight name (transformers TF2_WEIGHTS_NAME / FLAX_WEIGHTS_NAME, single or +# numbered shard): what a from_tf / from_flax load reads instead of a PyTorch format. +_CANONICAL_ROOT_TF_FLAX_WEIGHT_RE = re.compile( + r"^(?:tf_model(?:-\d{5}-of-\d{5})?\.h5|flax_model(?:-\d{5}-of-\d{5})?\.msgpack)$" +) + + +def _pytorch_root_weight_formats_ignored(ignore_patterns: Any) -> bool: + """True when the request's ignore filter drops BOTH canonical PyTorch root weights + (``model.safetensors`` AND ``pytorch_model.bin``) -- the signature of a ``from_tf`` / ``from_flax`` + load, whose prefetch ignores ``*.safetensors`` + ``*.bin`` and keeps ``*.h5`` / ``*.msgpack``. Lets + the readable-weight check count the TF/Flax weight the load actually reads (``tf_model.h5`` / + ``flax_model.msgpack``) rather than false-reject a complete h5/msgpack download into a + ``DownloadStallError``. 
Never true for a normal load (which keeps at least one PyTorch format).""" + return not _filter_paths( + ["model.safetensors", "pytorch_model.bin"], None, ignore_patterns + ) + # A training-checkpoint subdir (checkpoint-500/, checkpoint_7/): its weights are never read as diffusers # pipeline COMPONENTS, so they must not mask missing unet/vae/text-encoder weights. _CHECKPOINT_DIR_RE = re.compile(r"^checkpoint[-_]\d+$") @@ -1114,20 +1142,31 @@ def _root_model_has_weight(snapshot_dir: Path, *, ignore_patterns: Any = None) - if is_diffusers: return _has_diffusers_component_weight(snapshot_dir, ignore_patterns = ignore_patterns) rels: list = [] + tf_flax_rels: list = [] try: for entry in snapshot_dir.iterdir(): name = entry.name - if not _CANONICAL_ROOT_MODEL_WEIGHT_RE.match(name): - continue # only a canonical model.safetensors / pytorch_model.bin (single or shard) is - # read by a default load -- an adapter, variant, gguf, or consolidated.* is not try: - if entry.is_file(): - rels.append(name) + if not entry.is_file(): + continue except OSError: continue + if _CANONICAL_ROOT_MODEL_WEIGHT_RE.match(name): + rels.append(name) # a canonical model.safetensors / pytorch_model.bin (single or shard) + elif _CANONICAL_ROOT_TF_FLAX_WEIGHT_RE.match(name): + tf_flax_rels.append(name) # a TF/Flax root weight (from_tf / from_flax) except OSError: return False - return bool(_filter_paths(rels, None, ignore_patterns)) + if _filter_paths(rels, None, ignore_patterns): + return True + # from_tf / from_flax: the ignore filter drops BOTH canonical PyTorch formats, so the load reads a + # TF (tf_model.h5) / Flax (flax_model.msgpack) root weight the safetensors/bin check above cannot + # see. Count that surviving root weight so a complete from_tf/from_flax download is not + # false-rejected into a DownloadStallError. Gated on "both PyTorch formats ignored", so a normal + # load (which keeps a PyTorch format) is unchanged and a stray leftover h5/msgpack never counts. 
+ if tf_flax_rels and _pytorch_root_weight_formats_ignored(ignore_patterns): + return bool(_filter_paths(tf_flax_rels, None, ignore_patterns)) + return False def _root_has_variant_weight( @@ -1460,6 +1499,11 @@ def _readable_shard_set_incomplete( is_diffusers = (snapshot_dir / "model_index.json").is_file() except OSError: is_diffusers = False + if _patterns_are_exact_names(allow_patterns): + # An exact-named subset (variant or plain) defers to the exact-file presence check: the load + # reads exactly the named shard(s), so a lone exact variant shard is not judged against its + # (unrequested) index -- else a valid exact request is false-rejected into a DownloadStallError. + return False if variant: if not is_diffusers and ( allow_patterns is None @@ -1499,8 +1543,6 @@ def _readable_shard_set_incomplete( # component shard index missing a shard is not covered by the canonical ROOT-shard check. return True return False - if _patterns_are_exact_names(allow_patterns): - return False # an exact-named subset defers to the exact-file presence check if not is_diffusers and _request_selects_canonical_root_shards( allow_patterns, ignore_patterns ) and _has_incomplete_canonical_root_shards(snapshot_dir, ignore_patterns = ignore_patterns): From b2089bfd95c33dcd4a67bad27ea6e98a6ea87571 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Fri, 3 Jul 2026 10:24:50 +0000 Subject: [PATCH 75/82] Complete the from_tf/from_flax fix for sharded weights, check explicit-checkpoint shard sets, harden the process-group kill, and drop dead code Four fixes from a second consolidated review pass (12x reviewer.py personas + 3 forked reviewers) on 930d0277. 1. Sharded TF/Flax POST false-accept -> hang (from the prior from_tf/from_flax fix). TF/Flax weights can be SHARDED (transformers TF2_WEIGHTS_INDEX_NAME=tf_model.h5.index.json, FLAX_WEIGHTS_INDEX_NAME=flax_model.msgpack.index.json). 
_CANONICAL_ROOT_TF_FLAX_WEIGHT_RE matched the shard form, so a LONE shard read as a present weight and _download_result_usable accepted an INCOMPLETE sharded set -> the in-process load then fetched the missing shard over un-killable Xet. The regex is now single-file only (tf_model.h5 / flax_model.msgpack); a sharded set is judged through its index via _weight_shard_index_complete (every listed shard present). Complete sharded and single-file downloads are accepted; an incomplete shard set or a lone shard without its index is retried over HTTP. 2. Explicit-checkpoint shard set accepted incomplete -> hang. _selected_shard_index_incomplete skipped EVERY training-checkpoint subtree, but subfolder=checkpoint-N (allow=['checkpoint-N/*']) is an explicit load of that checkpoint. A lone checkpoint-7/model-00001-of-00002.safetensors with no index was accepted, then the load fetched the missing index/shards over Xet. The skip is now conditional (_request_scopes_into_dir): a checkpoint the request explicitly targets is completeness-checked, while a leftover checkpoint a base/adapter/other-subfolder warm does not read is still skipped (no false-reject). 3. _terminate_process_group could signal the WRONG process group. The SIGTERM path called os.killpg(pid) whenever pid was non-null; before the child runs os.setsid() its pid is not yet a pgid, and a recycled pid could collide with an unrelated group. It now signals the group only when os.getpgid(pid) == pid (the child is confirmed its own leader), else falls back to a single-process signal -- matching the caution the post-reap SIGKILL branch already documents. 4. Removed dead _has_any_weight (defined, never called; the readable-weight checks supersede it) and the two docstrings that named it. Adds 2 regression tests (189 -> 191). Safety-invariant fuzz stays at 0 false-accepts; stall->HTTP e2e sim (which exercises the kill path) stays green. 
Not taken (same review pass): - Require ALL declared Diffusers components in POST: a static heuristic false-rejects a complete pipeline that declares weightless components (scheduler/tokenizer/feature_extractor, safety_checker:[null,null]); POST stays lenient and one declared component weight already separates a real warm from a config-only stale cache. - Injected prepare_for_http_fn missing owned_incomplete_blobs (5 personas): by design and documented at the call site -- an injected (Studio) hook owns its own marker-based cache accounting and keeps the plain (repo_type, repo_id) signature. - Transient-error / crashed-child HTTP retry uses the coarse mtime guard rather than an owned-partial set: the stall path captures owned partials from the child's open fds, but a crashed/errored child has no fds to inspect, so the coarse guard (spares actively-progressing siblings) is the only option there -- inherent, not an asymmetry that can be closed. - Broad ['*.safetensors'] base-model warm accepts an adapter-only cache: not loader-reachable (a base load passes allow_patterns=None, never a broad glob). - Missing HF kwargs (local_dir / max_workers / etag_timeout): the helper is a cache-based prefetch wrapper, not a full snapshot_download drop-in; local_dir bypasses the cache the completeness gate reasons about. Current callers pass only what the prefetch needs. 
--- tests/test_hf_xet_fallback.py | 63 +++++++++++++++++++++++++ unsloth_zoo/hf_cache_state.py | 18 ++++++- unsloth_zoo/hf_xet_fallback.py | 86 +++++++++++++++++----------------- 3 files changed, 123 insertions(+), 44 deletions(-) diff --git a/tests/test_hf_xet_fallback.py b/tests/test_hf_xet_fallback.py index 59252a759..047069a52 100644 --- a/tests/test_hf_xet_fallback.py +++ b/tests/test_hf_xet_fallback.py @@ -2704,6 +2704,69 @@ def test_post_download_accepts_from_tf_flax_weights(tmp_path): snap, repo_type = "model", allow_patterns = None, ignore_patterns = None) is False +def test_post_download_checks_sharded_tf_flax_completeness(tmp_path): + """TF / Flax weights can be SHARDED (tf_model.h5.index.json / flax_model.msgpack.index.json). A + COMPLETE sharded set (index + all shards) is accepted, but an INCOMPLETE one (a shard missing, or a + lone shard with no index) must be rejected: the single-file regex no longer matches a lone shard, so + an incomplete sharded from_tf/from_flax download is retried over HTTP instead of loaded over Xet + (#829 re-review, sharded-TF false-accept).""" + ig = ["*.safetensors", "*.safetensors.index.json", "*.bin", "*.bin.index.json"] + for base, ext in (("tf_model", "h5"), ("flax_model", "msgpack")): + idx = json.dumps({"weight_map": {"a": f"{base}-00001-of-00002.{ext}", + "b": f"{base}-00002-of-00002.{ext}"}}) + # Complete sharded set -> accepted. + snap, blob = _mk_snapshot(tmp_path, f"tfshard_ok_{base}") + (snap / f"{base}-00001-of-00002.{ext}").symlink_to(blob) + (snap / f"{base}-00002-of-00002.{ext}").symlink_to(blob) + (snap / f"{base}.{ext}.index.json").write_text(idx) + assert xf._download_result_usable( + snap, repo_type = "model", allow_patterns = None, ignore_patterns = ig) is True + # A shard listed by the index is missing -> rejected. 
+ snap2, blob2 = _mk_snapshot(tmp_path, f"tfshard_missing_{base}") + (snap2 / f"{base}-00001-of-00002.{ext}").symlink_to(blob2) + (snap2 / f"{base}.{ext}.index.json").write_text(idx) + assert xf._download_result_usable( + snap2, repo_type = "model", allow_patterns = None, ignore_patterns = ig) is False + # A lone shard with NO index -> rejected (the load cannot enumerate the set). + snap3, blob3 = _mk_snapshot(tmp_path, f"tfshard_lone_{base}") + (snap3 / f"{base}-00001-of-00002.{ext}").symlink_to(blob3) + assert xf._download_result_usable( + snap3, repo_type = "model", allow_patterns = None, ignore_patterns = ig) is False + + +def test_post_download_checks_explicit_checkpoint_shard_completeness(tmp_path): + """An EXPLICIT checkpoint load (subfolder=checkpoint-N -> allow=['checkpoint-N/*']) reads the + checkpoint's weights, so a lone numbered shard there with no index must be rejected, not skipped as a + 'leftover checkpoint subtree' (#829 re-review, checkpoint false-accept). A complete checkpoint shard + set is accepted; a leftover checkpoint the request does NOT target (subfolder=unet) is still ignored + so a complete in-scope download is not false-rejected.""" + # Lone checkpoint shard, no index, explicitly requested -> rejected. + snap, blob = _mk_snapshot(tmp_path, "ckpt_lone") + (snap / "checkpoint-7").mkdir() + (snap / "checkpoint-7" / "model-00001-of-00002.safetensors").symlink_to(blob) + assert xf._download_result_usable( + snap, repo_type = "model", allow_patterns = ["checkpoint-7/*"], ignore_patterns = None) is False + # Complete checkpoint shard set (index + all shards) -> accepted. 
+ snap2, blob2 = _mk_snapshot(tmp_path, "ckpt_complete") + (snap2 / "checkpoint-7").mkdir() + (snap2 / "checkpoint-7" / "model-00001-of-00002.safetensors").symlink_to(blob2) + (snap2 / "checkpoint-7" / "model-00002-of-00002.safetensors").symlink_to(blob2) + (snap2 / "checkpoint-7" / "model.safetensors.index.json").write_text(json.dumps( + {"weight_map": {"a": "model-00001-of-00002.safetensors", + "b": "model-00002-of-00002.safetensors"}})) + assert xf._download_result_usable( + snap2, repo_type = "model", allow_patterns = ["checkpoint-7/*"], ignore_patterns = None) is True + # A leftover checkpoint the request does NOT target (subfolder=unet) must not false-reject a complete + # in-scope download. + snap3, blob3 = _mk_snapshot(tmp_path, "ckpt_leftover") + (snap3 / "unet").mkdir() + (snap3 / "unet" / "diffusion_pytorch_model.safetensors").symlink_to(blob3) + (snap3 / "checkpoint-7").mkdir() + (snap3 / "checkpoint-7" / "model-00001-of-00002.safetensors").symlink_to(blob3) # lone, but not read + assert xf._download_result_usable( + snap3, repo_type = "model", allow_patterns = ["unet/*"], ignore_patterns = None) is True + + def test_post_download_accepts_exact_named_variant_shard_subset(tmp_path): """A caller naming an EXACT variant shard (allow=['model.fp16-00001-of-00002.safetensors'] + variant='fp16') asked for precisely that file; once present the result is accepted even though its diff --git a/unsloth_zoo/hf_cache_state.py b/unsloth_zoo/hf_cache_state.py index 628affbd4..6ac651b2b 100644 --- a/unsloth_zoo/hf_cache_state.py +++ b/unsloth_zoo/hf_cache_state.py @@ -793,6 +793,17 @@ def _index_shard_probe(index_name: str, dir_rel: str) -> "Optional[str]": return None +def _request_scopes_into_dir(allow_patterns: "Optional[list]", dir_name: str) -> bool: + """True when an allow pattern names *dir_name* as a literal leading path segment + (``subfolder=checkpoint-7`` -> ``allow=['checkpoint-7/*']``), i.e. the load reads INTO that + directory. 
Lets the shard-completeness check skip a leftover checkpoint subtree the request does not + target, while still validating a checkpoint the request explicitly loads from.""" + for p in allow_patterns or (): + if isinstance(p, str) and "/" in p and p.split("/", 1)[0] == dir_name: + return True + return False + + def _selected_shard_index_incomplete( snapshot_dir: Path, *, allow_patterns: "Optional[object]", ignore_patterns: "Optional[object]", variant: "Optional[str]", @@ -870,7 +881,12 @@ def _selected_shard_index_incomplete( or (want_variant is not None and _ROOT_MODEL_VARIANT_WEIGHT_RE.match(name)) ): continue - if any(_CHECKPOINT_DIR_RE.match(p) for p in rel.split("/")[:-1]): + ckpt_dirs = [p for p in rel.split("/")[:-1] if _CHECKPOINT_DIR_RE.match(p)] + if ckpt_dirs and not _request_scopes_into_dir(allow_patterns, ckpt_dirs[0]): + # a leftover training-checkpoint subtree the request does not explicitly target (a base / + # adapter / other-subfolder warm never reads it). But an EXPLICIT checkpoint load + # (subfolder=checkpoint-N -> allow=['checkpoint-N/*']) DOES read it, so its shard set must + # be checked for completeness rather than silently accepted as a lone shard. continue if not _filter_paths([rel], allow_patterns, ignore_patterns): continue # the load does not read this shard (out of scope / ignored format) diff --git a/unsloth_zoo/hf_xet_fallback.py b/unsloth_zoo/hf_xet_fallback.py index e2a61a902..7787d9e5e 100644 --- a/unsloth_zoo/hf_xet_fallback.py +++ b/unsloth_zoo/hf_xet_fallback.py @@ -53,6 +53,7 @@ _has_incomplete_variant_root_shards, _is_loadable_weight_file, _selected_shard_index_incomplete, + _weight_shard_index_complete, blob_bytes_present, has_active_incomplete_blobs, hf_cache_root, @@ -749,18 +750,24 @@ def _download_child_entry( def _terminate_process_group(proc: "mp.process.BaseProcess", grace_period: float) -> None: """Kill *proc* and its whole process group (Xet may spawn helpers). 
The child ``os.setsid()``s so - its pgid equals its pid; signal via ``os.killpg(pid, ...)`` -- NOT ``getpgid``, which before the - child is a group leader resolves to OUR group. SIGTERM, then SIGKILL after *grace_period*.""" + its pgid equals its pid; the group is signalled via ``os.killpg(pid, ...)`` only once the child is + confirmed its own leader (``os.getpgid(pid) == pid``). SIGTERM, then SIGKILL after *grace_period*.""" pid = proc.pid def _signal_group(sig: int) -> None: - if pid is not None and hasattr(os, "killpg"): + # Signal the whole GROUP only once the child is confirmed its own leader (setsid done): its pgid + # then equals its pid. BEFORE setsid the child is still in OUR group, and its freshly-allocated + # pid could collide with an unrelated recycled process group -- so ``getpgid(pid) != pid`` guards + # against ``killpg(pid)`` targeting the WRONG group; a reaped child raises here. Fall through to a + # single-process signal in all those cases (also Windows, which has no killpg / getpgid). + if pid is not None and hasattr(os, "killpg") and hasattr(os, "getpgid"): try: - os.killpg(pid, sig) - return + if os.getpgid(pid) == pid: + os.killpg(pid, sig) + return except (ProcessLookupError, PermissionError, OSError): pass - # Windows or pre-setsid: best effort on the single process. + # Windows, pre-setsid, or the child is not (yet) its own group leader: signal the single process. try: proc.terminate() if sig != getattr(signal, "SIGKILL", -9) else proc.kill() except Exception: @@ -1006,27 +1013,6 @@ def _intact_subset( ) -def _has_any_weight(snapshot_dir: Path, *, ignore_patterns: Any = None) -> bool: - """True if the snapshot holds at least one loadable weight anywhere (root or subfolder) that the - request's ignore filter keeps. Lenient: it only tells a real model warm from a config-only stale - snapshot, without classifying layout. 
The ignore filter matters for diffusers, whose component - weights live in subfolders -- a partial holding only the ignored format (``unet/*.bin`` under - ``ignore=['*.bin']``) is not a usable weight for a safetensors load.""" - rels: list = [] - try: - for entry in snapshot_dir.rglob("*"): - if not _is_loadable_weight_file(entry.name): - continue - try: - if entry.is_file(): - rels.append(entry.relative_to(snapshot_dir).as_posix()) - except (OSError, ValueError): - continue - except OSError: - return False - return bool(_filter_paths(rels, None, ignore_patterns)) - - def _is_default_load_weight_file(name: str) -> bool: """A weight in a format a DEFAULT ``from_pretrained`` reads: safetensors or bin only. Excludes gguf / pt / pth / onnx / msgpack / ... -- a default (non-format-specific) transformers / diffusers load does @@ -1057,11 +1043,14 @@ def _is_default_load_weight_file(name: str) -> bool: r"^[^.]+(?:-\d{5}-of-\d{5})?\.(?:safetensors|bin)$" ) -# A CANONICAL root TF / Flax weight name (transformers TF2_WEIGHTS_NAME / FLAX_WEIGHTS_NAME, single or -# numbered shard): what a from_tf / from_flax load reads instead of a PyTorch format. -_CANONICAL_ROOT_TF_FLAX_WEIGHT_RE = re.compile( - r"^(?:tf_model(?:-\d{5}-of-\d{5})?\.h5|flax_model(?:-\d{5}-of-\d{5})?\.msgpack)$" -) +# A SINGLE-FILE canonical root TF / Flax weight (transformers TF2_WEIGHTS_NAME / FLAX_WEIGHTS_NAME): +# what a from_tf / from_flax load reads instead of a PyTorch format. A SHARDED TF/Flax weight is judged +# through its index (tf_model.h5.index.json / flax_model.msgpack.index.json) instead -- a lone shard +# here must NOT read as a present weight, else an incomplete sharded set is loaded over Xet. +_CANONICAL_ROOT_TF_FLAX_WEIGHT_RE = re.compile(r"^(?:tf_model\.h5|flax_model\.msgpack)$") + +# The shard-index sidecars a sharded TF / Flax weight is enumerated through. 
+_TF_FLAX_WEIGHT_INDEX_NAMES = ("tf_model.h5.index.json", "flax_model.msgpack.index.json") def _pytorch_root_weight_formats_ignored(ignore_patterns: Any) -> bool: @@ -1125,9 +1114,9 @@ def _has_diffusers_component_weight(snapshot_dir: Path, *, ignore_patterns: Any def _root_model_has_weight(snapshot_dir: Path, *, ignore_patterns: Any = None) -> bool: """Whether an UNPATTERNED model warm holds a weight a default load reads: a CANONICAL ROOT weight (``model.safetensors`` / ``pytorch_model.bin``, single or numbered shard), or -- for a diffusers - pipeline (root ``model_index.json``) -- a component-subfolder weight. Counting any subtree weight (as - ``_has_any_weight`` does) would accept a stale checkpoint-only snapshot and then fetch the root - weights over un-killable Xet; diffusers is the one layout whose weights live in subfolders. Only the + pipeline (root ``model_index.json``) -- a component-subfolder weight. Counting ANY subtree weight + would accept a stale checkpoint-only snapshot and then fetch the root weights over un-killable Xet; + diffusers is the one layout whose weights live in subfolders. Only the canonical names are counted (``_CANONICAL_ROOT_MODEL_WEIGHT_RE``): a VARIANT-named root weight (``model.fp16.safetensors``), a PEFT adapter (``adapter_model.*``), a gguf, and a NON-canonical root weight (``consolidated.safetensors``) are excluded, since a default from_pretrained probes only the @@ -1161,11 +1150,22 @@ def _root_model_has_weight(snapshot_dir: Path, *, ignore_patterns: Any = None) - return True # from_tf / from_flax: the ignore filter drops BOTH canonical PyTorch formats, so the load reads a # TF (tf_model.h5) / Flax (flax_model.msgpack) root weight the safetensors/bin check above cannot - # see. Count that surviving root weight so a complete from_tf/from_flax download is not - # false-rejected into a DownloadStallError. 
Gated on "both PyTorch formats ignored", so a normal - # load (which keeps a PyTorch format) is unchanged and a stray leftover h5/msgpack never counts. - if tf_flax_rels and _pytorch_root_weight_formats_ignored(ignore_patterns): - return bool(_filter_paths(tf_flax_rels, None, ignore_patterns)) + # see. Count a SINGLE-FILE TF/Flax weight, or a COMPLETE sharded set (its index present with every + # listed shard present), so a complete from_tf/from_flax download is not false-rejected into a + # DownloadStallError -- while an INCOMPLETE sharded set (a lone shard, or an index missing a shard) + # is NOT counted, so it is retried over HTTP rather than loaded over un-killable Xet. Gated on "both + # PyTorch formats ignored", so a normal load is unchanged and a stray leftover h5/msgpack never + # counts. + if _pytorch_root_weight_formats_ignored(ignore_patterns): + if tf_flax_rels and _filter_paths(tf_flax_rels, None, ignore_patterns): + return True + for index_name in _TF_FLAX_WEIGHT_INDEX_NAMES: + index_path = snapshot_dir / index_name + try: + if index_path.is_file() and _weight_shard_index_complete(index_path): + return True + except OSError: + continue return False @@ -1322,9 +1322,9 @@ def _requested_exact_files_present_grouped( def _has_selected_weight( snapshot_dir: Path, *, allow_patterns: Any, ignore_patterns: Any, ) -> bool: - """True if a loadable weight the request SELECTS is present. Applies the allow / ignore filter (vs - ``_has_any_weight``), so a patterned request is not satisfied by an out-of-scope weight (a stale - ``.bin``, an unrequested checkpoint subfolder).""" + """True if a loadable weight the request SELECTS is present. 
Applies the allow / ignore filter, so a + patterned request is not satisfied by an out-of-scope weight (a stale ``.bin``, an unrequested + checkpoint subfolder).""" weights: list = [] try: for entry in snapshot_dir.rglob("*"): From fb80cf9a1699dd8c5614611d164c2518fe6fdc67 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Fri, 3 Jul 2026 11:01:13 +0000 Subject: [PATCH 76/82] Handle nested checkpoint subfolders and run the fallback test suite in CI Two fixes from a third consolidated review pass (12x reviewer.py personas + 3 forked reviewers) on b2089bfd. 1. Nested checkpoint subfolder POST false-accept -> hang. _request_scopes_into_dir matched the checkpoint dir only against the FIRST path segment, so a nested subfolder=foo/checkpoint-7 (allow=['foo/checkpoint-7/*']) did not register as an explicit checkpoint load: the checkpoint subtree was skipped, a lone incomplete shard there was accepted, and the in-process load then fetched the missing index/shards over un-killable Xet -- the same hang the flat-checkpoint fix closed, one directory deeper. The scope check now matches the target dir against ALL literal leading segments (up to the first glob), so an explicitly-loaded checkpoint at any depth is completeness-checked while a leftover checkpoint the request does not target is still skipped. 2. test_hf_xet_fallback.py now runs in CI. The 191-test fallback suite was only reached by `pytest --collect-only` (continue-on-error), so it could regress without failing CI. It is now in the behavioral HARD GATE alongside the other CPU-pure gated files, so a future change that breaks the watchdog / cache-completeness / Xet->HTTP fallback behavior blocks the merge. Adds a nested-checkpoint regression assertion. Full suite 191 passed; safety-invariant fuzz stays at 0 false-accepts. 
Known low-reachability edge (documented, not fixed): a from_tf/from_flax load with a non-root subfolder AND a SHARDED TF/Flax weight AND a silent-stale transient can POST-accept an incomplete shard set -- the sharded TF/Flax completeness check landed on the allow=None root path, and the patterned-subfolder path's index detection is safetensors/bin-only. The conjunction (from_tf is already niche + subfolder + sharding + a transient-error stale return) is essentially unreachable, and PRE still defers safely; extending the generic patterned-shard machinery to h5/msgpack is not worth its surface for that edge. Not taken (same review pass): - Retryable-error / crashed-child HTTP retry uses the coarse mtime guard, not an owned-partial set: the stall path captures owned partials from the LIVE child's open fds, but a child that returned a retryable error / crashed has already exited, so there are no fds to inspect -- the coarse guard (which spares an actively-progressing sibling by mtime) is the only option there. Inherent, not a closable asymmetry. - Injected prepare_for_http_fn missing owned_incomplete_blobs: by design and documented at the call site -- a Studio hook owns its own marker-based cache accounting. - Broad ['*.safetensors'] / ['model*','adapter_model*'] accepting one selected weight: not loader-reachable (a base load passes allow=None; no loader requests base+adapter together). - Exact ['config.json','tokenizer.json'] over-requiring optional files: not loader-reachable (every loader allow-list carries a glob, so the exact-file check is never the deciding gate). - import unsloth_zoo.hf_xet_fallback not import-light by default: intentional opt-in (UNSLOTH_ZOO_DISABLE_GPU_INIT=1); the download child sets it, Studio imports its own fork. 
--- .github/workflows/consolidated-tests-ci.yml | 5 ++++- tests/test_hf_xet_fallback.py | 9 +++++++++ unsloth_zoo/hf_cache_state.py | 19 +++++++++++++------ 3 files changed, 26 insertions(+), 7 deletions(-) diff --git a/.github/workflows/consolidated-tests-ci.yml b/.github/workflows/consolidated-tests-ci.yml index 1c70ce6f2..cfdf6c619 100644 --- a/.github/workflows/consolidated-tests-ci.yml +++ b/.github/workflows/consolidated-tests-ci.yml @@ -99,11 +99,14 @@ jobs: # invisible until #739 because no job executed it (collect-only plus # the macOS shim smoke). test_training_utils_use_cache pins the # use_cache disable/restore contract for gradient checkpointing - # (#715). Executing both here keeps fixture/source drift visible. + # (#715). test_hf_xet_fallback exercises the Xet->HTTP stall fallback + # + cache-completeness gate (CPU-pure; collect-only would let it + # regress silently). Executing them here keeps fixture/source drift visible. run: | python -m pytest \ tests/test_training_utils_use_cache.py \ tests/test_mlx_finetune_last_n_layers.py \ + tests/test_hf_xet_fallback.py \ -v - name: pytest tests/test_mlx_module_exports + zoo-specific CPU tests diff --git a/tests/test_hf_xet_fallback.py b/tests/test_hf_xet_fallback.py index 047069a52..a131163ac 100644 --- a/tests/test_hf_xet_fallback.py +++ b/tests/test_hf_xet_fallback.py @@ -2765,6 +2765,15 @@ def test_post_download_checks_explicit_checkpoint_shard_completeness(tmp_path): (snap3 / "checkpoint-7" / "model-00001-of-00002.safetensors").symlink_to(blob3) # lone, but not read assert xf._download_result_usable( snap3, repo_type = "model", allow_patterns = ["unet/*"], ignore_patterns = None) is True + # A NESTED checkpoint the request explicitly targets (subfolder=foo/checkpoint-7 -> + # allow=['foo/checkpoint-7/*']) is read INTO at depth, so its lone shard must still be rejected -- + # the scope check matches the checkpoint dir at ANY literal leading segment, not just the first. 
+ snap4, blob4 = _mk_snapshot(tmp_path, "ckpt_nested") + (snap4 / "foo" / "checkpoint-7").mkdir(parents = True) + (snap4 / "foo" / "checkpoint-7" / "model-00001-of-00002.safetensors").symlink_to(blob4) + assert xf._download_result_usable( + snap4, repo_type = "model", + allow_patterns = ["foo/checkpoint-7/*"], ignore_patterns = None) is False def test_post_download_accepts_exact_named_variant_shard_subset(tmp_path): diff --git a/unsloth_zoo/hf_cache_state.py b/unsloth_zoo/hf_cache_state.py index 6ac651b2b..433df8c40 100644 --- a/unsloth_zoo/hf_cache_state.py +++ b/unsloth_zoo/hf_cache_state.py @@ -794,13 +794,20 @@ def _index_shard_probe(index_name: str, dir_rel: str) -> "Optional[str]": def _request_scopes_into_dir(allow_patterns: "Optional[list]", dir_name: str) -> bool: - """True when an allow pattern names *dir_name* as a literal leading path segment - (``subfolder=checkpoint-7`` -> ``allow=['checkpoint-7/*']``), i.e. the load reads INTO that - directory. Lets the shard-completeness check skip a leftover checkpoint subtree the request does not - target, while still validating a checkpoint the request explicitly loads from.""" + """True when an allow pattern names *dir_name* among its LITERAL leading path segments + (``subfolder=checkpoint-7`` -> ``allow=['checkpoint-7/*']``; a NESTED ``subfolder=foo/checkpoint-7`` + -> ``allow=['foo/checkpoint-7/*']``), i.e. the load reads INTO that directory at any depth. Lets the + shard-completeness check skip a leftover checkpoint subtree the request does not target, while still + validating a checkpoint the request explicitly loads from. 
Segments are read only up to the first + glob (a wildcard segment could match anything, so it is not a literal directory target).""" for p in allow_patterns or (): - if isinstance(p, str) and "/" in p and p.split("/", 1)[0] == dir_name: - return True + if not isinstance(p, str) or "/" not in p: + continue + for seg in p.split("/"): + if _has_glob(seg): + break # a wildcard segment is not a literal directory target + if seg == dir_name: + return True return False From 346ae8681ef5edc94f198214296978cf1a7f94f4 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Fri, 3 Jul 2026 11:52:32 +0000 Subject: [PATCH 77/82] Revert version bump to 2026.6.7 The shared Xet->HTTP fallback does not require a version bump; keep the base version. --- unsloth_zoo/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth_zoo/__init__.py b/unsloth_zoo/__init__.py index 364c3afe3..6c0f1d93e 100644 --- a/unsloth_zoo/__init__.py +++ b/unsloth_zoo/__init__.py @@ -14,7 +14,7 @@ # You should have received a copy of the GNU Lesser General Public License # along with this program. If not, see . -__version__ = "2026.6.8" +__version__ = "2026.6.7" import os import warnings From 8ad533be7ebed65afff15661ecc6d940404d64bb Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Fri, 3 Jul 2026 12:27:28 +0000 Subject: [PATCH 78/82] Trim verbose comments in the shared Xet->HTTP fallback modules Condense docstrings and inline comments across hf_xet_fallback.py, hf_cache_state.py, the __init__.py opt-in block, and the test suite: drop restating-the-obvious comments, collapse multi-line blocks, cut small-helper docstrings to one line. Comments/docstrings only, no code changed (AST-verified). Preserves the load-bearing rationale tersely (STRICT pre-download / LENIENT post-download, the un-killable Xet hang reason, and the scoping / format / variant / from_tf / diffusers / checkpoint / redaction notes). Full suite 191 passed; safety fuzz 0 false-accepts. 
--- tests/test_hf_xet_fallback.py | 1025 +++++++++----------------------- unsloth_zoo/__init__.py | 19 +- unsloth_zoo/hf_cache_state.py | 418 ++++++------- unsloth_zoo/hf_xet_fallback.py | 858 ++++++++++++-------------- 4 files changed, 851 insertions(+), 1469 deletions(-) diff --git a/tests/test_hf_xet_fallback.py b/tests/test_hf_xet_fallback.py index a131163ac..b745c85ee 100644 --- a/tests/test_hf_xet_fallback.py +++ b/tests/test_hf_xet_fallback.py @@ -1,13 +1,12 @@ # SPDX-License-Identifier: AGPL-3.0-only # Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. -"""Tests for unsloth_zoo.hf_xet_fallback: the no-progress watchdog, the Xet->HTTP -transport policy, the per-file and whole-snapshot entrypoints, the UNSLOTH_DISABLE_XET -knob, and the HF_HUB_DISABLE_XET precondition the fallback rests on. +"""Tests for unsloth_zoo.hf_xet_fallback: watchdog, Xet->HTTP transport policy, +file/snapshot entrypoints, the UNSLOTH_DISABLE_XET knob, and HF_HUB_DISABLE_XET. -CPU-only, no network, no real subprocess (the per-attempt download seam is -monkeypatched). The two modules under test are loaded directly via importlib so the -tests do not import the full ``unsloth_zoo`` package (which pulls in torch + GPU init). +CPU-only, no network, no real subprocess (the download seam is monkeypatched). +The modules under test are loaded via importlib to avoid importing the full +``unsloth_zoo`` package (torch + GPU init). """ from __future__ import annotations @@ -39,11 +38,9 @@ def _load(name: str, filename: str): return module -# A package placeholder so ``from unsloth_zoo.hf_cache_state import ...`` inside hf_xet_fallback -# resolves to the file we load below, not the installed package. RESTORE sys.modules afterwards: a -# leftover placeholder would shadow the REAL unsloth_zoo (its __init__ never runs) and fail a later -# test that imports it. The two loaded modules keep their own bound references, so they work after -# the placeholder is removed. 
+# Package placeholder so intra-package imports in hf_xet_fallback resolve to the +# files loaded below. Restored afterwards: a leftover would shadow the real +# unsloth_zoo; the loaded modules keep their own references and work regardless. _saved_modules = { name: sys.modules.get(name) for name in ("unsloth_zoo", "unsloth_zoo.hf_cache_state", "unsloth_zoo.hf_xet_fallback") @@ -66,9 +63,7 @@ def _load(name: str, filename: str): _REAL_DEFAULT_PREPARE = xf._default_prepare_for_http -# --------------------------------------------------------------------------- # # Watchdog: fires only on a constant-size .incomplete, sparse-aware byte total. -# --------------------------------------------------------------------------- # REPO = "ztest/xet-watchdog" @@ -154,16 +149,13 @@ def test_no_incomplete_never_stalls(hf_cache): def test_transient_unmeasurable_tick_is_progress(hf_cache, monkeypatch): - """A tick whose cache state is momentarily unmeasurable (get_hf_download_state -> None on a - transient FS error) is treated as progress, so a run of None ticks cannot trip a false stall. 
- Once the state is readable again and confirms a frozen .incomplete, the real stall still fires -- - the None-handling must not permanently mask a genuine stall.""" + """An unmeasurable tick (state -> None) counts as progress, but a later frozen state still stalls.""" seq = {"n": 0} - frozen = (2048, True) # constant size + active .incomplete: would stall if measured every tick + frozen = (2048, True) # constant size + active .incomplete def fake_state(*args, **kwargs): seq["n"] += 1 - return None if seq["n"] <= 8 else frozen # first ~8 ticks unmeasurable, then measurable+frozen + return None if seq["n"] <= 8 else frozen # first ~8 ticks unmeasurable, then frozen monkeypatch.setattr(xf, "get_hf_download_state", fake_state) @@ -172,7 +164,7 @@ def fake_state(*args, **kwargs): repo_ids = [REPO], on_stall = calls.append, interval = 0.05, stall_timeout = 0.3, ) try: - time.sleep(0.3) # within the unmeasurable window: no false stall despite no measured progress + time.sleep(0.3) # within the unmeasurable window assert calls == [], "watchdog fired during a transient-unmeasurable window" assert _wait(lambda: len(calls) >= 1, timeout = 3.0), "stall never fired after state recovered" finally: @@ -196,12 +188,9 @@ def test_stall_fires_at_most_once(hf_cache): def test_file_watchdog_scopes_to_child_partial(hf_cache): - """A single-file download follows only its own child's partials. 
A concurrent sibling - download of a different file in the same repo (its partial already in flight, so in the - baseline) keeps growing, but must not keep resetting this file's stall timer -- the - constant child partial still fires.""" + """A single-file download follows only its own child's partials, so a growing baseline sibling does not mask its stalled child.""" blobs = _blobs_dir(hf_cache) - sibling = blobs / "sibling.incomplete" # already in flight -> captured in baseline + sibling = blobs / "sibling.incomplete" # in flight -> in baseline sibling.write_bytes(b"\0" * 1024) baseline = {"sibling.incomplete"} @@ -211,12 +200,12 @@ def _grow(): size = 1024 while not grow_stop.wait(0.05): size += 4096 - sibling.write_bytes(b"\0" * size) # healthy sibling keeps making progress + sibling.write_bytes(b"\0" * size) # healthy sibling progresses grower = threading.Thread(target = _grow, daemon = True) grower.start() - # This download's child writes its own constant (stalled) partial, not in the baseline. + # This download's child writes its own constant (stalled) partial, not in baseline. (blobs / "child.incomplete").write_bytes(b"\0" * 2048) calls: list[str] = [] @@ -234,9 +223,7 @@ def _grow(): def test_repo_wide_watchdog_is_masked_by_sibling(hf_cache): - """Contrast for the single-file scoping: the default repo-wide measurement sums every - blob, so a growing sibling resets the timer and a constant partial never trips. 
This is - correct for snapshots (all blobs are one pull) and is exactly what file-scoping avoids.""" + """Contrast for file-scoping: the default repo-wide measure sums every blob, so a growing sibling resets the timer.""" blobs = _blobs_dir(hf_cache) sibling = blobs / "sibling.incomplete" sibling.write_bytes(b"\0" * 1024) @@ -258,7 +245,7 @@ def _grow(): repo_ids = [REPO], on_stall = calls.append, interval = 0.05, stall_timeout = 0.5, ) try: - time.sleep(1.0) # well past stall_timeout, but repo-wide bytes keep growing + time.sleep(1.0) # past stall_timeout, but repo-wide bytes keep growing assert calls == [], "repo-wide watchdog should be reset by the growing sibling" finally: stop.set() @@ -266,9 +253,7 @@ def _grow(): def test_file_watchdog_ignores_baseline_only_partials(hf_cache): - """If the only active partial is a baseline sibling's (this child has not written one - yet), the file watchdog sees no owned progress and must not fire: there is nothing of - ours to stall on, so post-spawn metadata/connect time is never misread as our stall.""" + """When the only partial is a baseline sibling's, the file watchdog owns nothing and must not fire.""" blobs = _blobs_dir(hf_cache) (blobs / "sibling.incomplete").write_bytes(b"\0" * 4096) # constant baseline sibling @@ -285,8 +270,7 @@ def test_file_watchdog_ignores_baseline_only_partials(hf_cache): def _spawn_holding_open(path: Path) -> "subprocess.Popen": - """A real child process that opens *path* and holds it open without writing, modelling a - hung download. 
Prints 'ok' once the file is open so the caller can synchronize.""" + """Child process that opens *path* and holds it open without writing (a hung download); prints 'ok' when open.""" code = ( "import sys, time\n" "f = open(sys.argv[1], 'r+b')\n" @@ -301,13 +285,11 @@ def _spawn_holding_open(path: Path) -> "subprocess.Popen": def test_file_watchdog_detects_resumed_baseline_partial(hf_cache): - """A resumed single-file download reuses the prior blob-hash .incomplete, so it sits in - the baseline. Name-based exclusion would never flag a hung resume; scoping to the - partials the child process holds open detects it.""" + """A resumed download reuses a baseline .incomplete; pid-scoping (not name exclusion) still detects the hang.""" blobs = _blobs_dir(hf_cache) partial = blobs / "resumed.incomplete" partial.write_bytes(b"\0" * 4096) # leftover from a prior interrupted download - baseline = {"resumed.incomplete"} # present before the (resuming) child starts + baseline = {"resumed.incomplete"} # present before the resuming child starts child = _spawn_holding_open(partial) # hung resume: holds it open, never grows it try: @@ -329,10 +311,7 @@ def test_file_watchdog_detects_resumed_baseline_partial(hf_cache): def test_file_watchdog_resumed_partial_fires_without_pid_ownership(hf_cache, monkeypatch): - """No psutil AND no /proc (native Windows / macOS without psutil): _child_open_incomplete_blobs - returns None, so per-child ownership is unknowable. Excluding baseline names would let a RESUMED - partial (which reuses a baseline blob-hash name) hang forever. 
The None fallback drops to the - repo-wide measure, so a frozen resumed baseline partial still trips the stall instead of hanging.""" + """When ownership is unknowable (no psutil/proc -> None), the repo-wide fallback still stalls a frozen resumed baseline partial.""" blobs = _blobs_dir(hf_cache) (blobs / "resumed.incomplete").write_bytes(b"\0" * 4096) # leftover resume, constant (hung) monkeypatch.setattr(xf, "_child_open_incomplete_blobs", lambda pid: None) # no psutil / no /proc @@ -341,7 +320,7 @@ def test_file_watchdog_resumed_partial_fires_without_pid_ownership(hf_cache, mon stop = xf.start_watchdog( repo_ids = [REPO], on_stall = calls.append, interval = 0.05, stall_timeout = 0.3, watch_new_partials_only = True, baseline_incomplete_blobs = {"resumed.incomplete"}, - child_pid = 4242, # non-None, but open-file inspection yields None -> repo-wide fallback + child_pid = 4242, # non-None, but open-file inspection -> None -> repo-wide fallback ) try: assert _wait(lambda: len(calls) >= 1, timeout = 3.0), ( @@ -352,11 +331,10 @@ def test_file_watchdog_resumed_partial_fires_without_pid_ownership(hf_cache, mon def test_file_watchdog_pid_scope_ignores_unowned_sibling(hf_cache): - """With pid scoping, a sibling partial this child does NOT hold open is ignored even if - it grows, so the child's own constant partial still trips the stall.""" + """With pid scoping, a growing sibling the child does not hold open is ignored; the child's own constant partial still stalls.""" blobs = _blobs_dir(hf_cache) owned_partial = blobs / "owned.incomplete" - owned_partial.write_bytes(b"\0" * 2048) # the child holds this open, constant (hung) + owned_partial.write_bytes(b"\0" * 2048) # child holds this open, constant (hung) sibling = blobs / "sibling.incomplete" sibling.write_bytes(b"\0" * 1024) @@ -366,7 +344,7 @@ def _grow(): size = 1024 while not grow_stop.wait(0.05): size += 4096 - sibling.write_bytes(b"\0" * size) # an unrelated sibling making progress + sibling.write_bytes(b"\0" * 
size) # unrelated sibling progressing grower = threading.Thread(target = _grow, daemon = True) grower.start() @@ -392,13 +370,9 @@ def _grow(): def test_file_watchdog_empty_open_set_ignores_sibling(hf_cache, monkeypatch): - """hf_xet writes in-process and holds its .incomplete fd continuously, so an EMPTY child - open-set means the child owns no partial YET (the connect / metadata phase), NOT that a - helper process owns one. A concurrent sibling's post-baseline partial must therefore NOT be - attributed to a still-connecting child -- otherwise a stalled sibling would kill it and force - a needless HTTP retry. The precise empty-set branch owns nothing, so no stall fires.""" + """An EMPTY child open-set means the child owns no partial yet (connect/metadata phase), so a stalled sibling must not fire.""" blobs = _blobs_dir(hf_cache) - # A sibling partial created after baseline (not name-excluded), constant (stalled). + # Sibling partial created after baseline (not name-excluded), constant (stalled). (blobs / "sibling.incomplete").write_bytes(b"\0" * 4096) monkeypatch.setattr(xf, "_child_open_incomplete_blobs", lambda pid: set()) # child owns none @@ -425,7 +399,7 @@ def test_get_state_absent_cache_root(tmp_path, monkeypatch): def test_get_state_skips_local_paths(hf_cache): - # Filesystem paths are not HF repo IDs and must be ignored without error. + # Filesystem paths are not repo IDs and must be ignored without error. 
assert xf.get_hf_download_state( ["/abs/path", "./rel", "~user", "c:\\x", "c:/x"] ) == (0, False) @@ -445,9 +419,7 @@ def test_get_state_sparse_aware(hf_cache): def test_blob_bytes_present_zero_blocks_is_zero(tmp_path): - """A freshly truncated, fully-sparse .incomplete reports st_size > 0 with 0 - allocated blocks; it must count as 0 bytes present, not full size (a > 0 guard - would mis-read an empty partial as complete).""" + """A fully-sparse .incomplete (st_size > 0, 0 allocated blocks) counts as 0 bytes present, not full size.""" p = tmp_path / "sparse.incomplete" with open(p, "wb") as f: f.truncate(8 * 1024 * 1024) # apparent 8 MiB, nothing actually written @@ -460,8 +432,7 @@ def test_blob_bytes_present_zero_blocks_is_zero(tmp_path): def test_custom_cache_dir_is_watched_and_cleaned(tmp_path, monkeypatch): - """A stall under a caller-supplied snapshot ``cache_dir`` (not HF_HUB_CACHE) - must still be seen by the state probe, the watchdog, and the HTTP-prep purge.""" + """A stall under a caller-supplied cache_dir (not HF_HUB_CACHE) is seen by the probe, watchdog, and HTTP-prep purge.""" default_cache = tmp_path / "default" custom_cache = tmp_path / "custom" default_cache.mkdir() @@ -489,9 +460,8 @@ def test_custom_cache_dir_is_watched_and_cleaned(tmp_path, monkeypatch): finally: stop.set() - # The HTTP-prep purge removes the unsafe partial from the custom cache - # (call the real impl; the autouse fixture stubs the module attribute). Age it - # past the active-partial grace so it reads as a stalled, not in-flight, blob. + # HTTP-prep purge removes the unsafe partial from the custom cache (real impl; + # the autouse fixture stubs the attribute). Age past the grace so it reads stalled. 
old = time.time() - 600 os.utime(partial, (old, old)) _REAL_DEFAULT_PREPARE("model", REPO, cache_dir = str(custom_cache)) @@ -499,8 +469,7 @@ def test_custom_cache_dir_is_watched_and_cleaned(tmp_path, monkeypatch): def test_prepare_for_http_clears_broken_snapshot_symlink(tmp_path): - """A broken snapshot symlink is counted as active-incomplete state by the - detector, so HTTP prep must clear it too or the retry re-trips the watchdog.""" + """A broken snapshot symlink reads as incomplete, so HTTP prep must clear it too or the retry re-trips the watchdog.""" repo = "ztest/broken-symlink" repo_dir = tmp_path / f"models--{repo.replace('/', '--')}" snap = repo_dir / "snapshots" / "abc123" @@ -509,7 +478,6 @@ def test_prepare_for_http_clears_broken_snapshot_symlink(tmp_path): link.symlink_to(repo_dir / "blobs" / "missing-blob") # dangling assert link.is_symlink() and not link.exists() - # Detector treats the dangling link as active incomplete state. assert xf.get_hf_download_state([repo], cache_dir = str(tmp_path)) == (0, True) _REAL_DEFAULT_PREPARE("model", repo, cache_dir = str(tmp_path)) @@ -519,9 +487,7 @@ def test_prepare_for_http_clears_broken_snapshot_symlink(tmp_path): def test_prepare_for_http_spares_concurrent_sibling_active_symlink(tmp_path): - """HTTP prep must NOT delete a concurrent sibling's dangling snapshot symlink while that sibling is - still writing the target blob (a fresh .incomplete partner exists). 
Our own stale interrupted link - (no .incomplete partner) is still cleared in the same sweep.""" + """HTTP prep spares a sibling's dangling link while its .incomplete partner is being written, but clears our own stale link.""" repo = "ztest/concurrent" repo_dir = tmp_path / f"models--{repo.replace('/', '--')}" blobs = repo_dir / "blobs" @@ -529,7 +495,7 @@ def test_prepare_for_http_spares_concurrent_sibling_active_symlink(tmp_path): blobs.mkdir(parents = True) snap.mkdir(parents = True) - # Sibling mid-download: a dangling link to a blob whose .incomplete partner is being written now. + # Sibling mid-download: dangling link to a blob with an active .incomplete partner. active_partner = blobs / "activehash.incomplete" active_partner.write_bytes(b"active") sibling_link = snap / "active.safetensors" @@ -557,35 +523,27 @@ def test_snapshot_dir_has_broken_symlinks_unit(tmp_path): def test_broken_older_snapshot_detected_when_newer_is_clean(tmp_path): - """Detector must inspect every snapshot, not just the newest by mtime: an older - revision with a dangling symlink must read as incomplete even when a more - recently landed snapshot is fully present.""" + """The detector inspects every snapshot: an older broken revision reads incomplete even when a newer one is clean.""" repo = "ztest/two-snaps" repo_dir = tmp_path / f"models--{repo.replace('/', '--')}" old = repo_dir / "snapshots" / "oldsha" new = repo_dir / "snapshots" / "newsha" old.mkdir(parents = True) new.mkdir(parents = True) - # Broken (older) revision; clean (newer) revision. - (old / "model.safetensors").symlink_to(repo_dir / "blobs" / "missing") - (new / "config.json").write_text("{}") - # Make the clean snapshot the newest by mtime so a latest-only check would - # report the repo healthy. + (old / "model.safetensors").symlink_to(repo_dir / "blobs" / "missing") # broken older + (new / "config.json").write_text("{}") # clean newer + # Make the clean snapshot newest by mtime so a latest-only check would pass. 
os.utime(new, (time.time() + 10, time.time() + 10)) assert xf.get_hf_download_state([repo], cache_dir = str(tmp_path)) == (0, True) def test_snapshot_fast_path_rejects_broken_requested_revision(tmp_path, monkeypatch): - """snapshot_download(local_files_only=True) can hand back an older requested - revision whose snapshot is broken while the repo-wide scan is clean. The fast - path must validate the EXACT returned dir and complete in the killable child - rather than short-circuiting to a snapshot with missing files.""" + """The fast path validates the EXACT returned dir, so a broken requested revision defers to the killable child.""" snap = tmp_path / "snapshots" / "oldsha" snap.mkdir(parents = True) (snap / "model.safetensors").symlink_to(tmp_path / "blobs" / "missing") # dangling monkeypatch.setattr(huggingface_hub, "snapshot_download", lambda *a, **k: str(snap)) - # Repo-wide incomplete-blob scan sees nothing (empty cache root), so only the - # per-revision symlink check can catch the broken returned dir. + # Repo-wide scan sees nothing (empty cache root); only the per-revision check can catch it. 
monkeypatch.setattr(hf_constants, "HF_HUB_CACHE", str(tmp_path / "empty-cache")) fake = _install(monkeypatch, [("ok", "/cache/snap-fresh")]) out = xf.snapshot_download_with_xet_fallback(DL_REPO, token = None) @@ -594,8 +552,7 @@ def test_snapshot_fast_path_rejects_broken_requested_revision(tmp_path, monkeypa def test_prepare_for_http_clears_broken_symlink_in_older_snapshot(tmp_path): - """HTTP prep must clear dangling links across all snapshots, not just the - newest, so the incomplete detector reads clean afterwards.""" + """HTTP prep clears dangling links across all snapshots, not just the newest.""" repo = "ztest/old-broken" repo_dir = tmp_path / f"models--{repo.replace('/', '--')}" old = repo_dir / "snapshots" / "oldsha" @@ -614,8 +571,7 @@ def test_prepare_for_http_clears_broken_symlink_in_older_snapshot(tmp_path): def test_prepare_for_http_preserves_case_colliding_repo(tmp_path): - """On a case-sensitive filesystem, preparing HTTP for ``Org/Repo`` must purge - only its exact-case cache dir, never a case-colliding ``org/repo``.""" + """On a case-sensitive FS, HTTP prep for Org/Repo purges only its exact-case dir, never a case-colliding org/repo.""" upper = tmp_path / "models--Org--Repo" / "blobs" lower = tmp_path / "models--org--repo" / "blobs" upper.mkdir(parents = True) @@ -626,8 +582,7 @@ def test_prepare_for_http_preserves_case_colliding_repo(tmp_path): lower_partial = lower / "b.incomplete" upper_partial.write_bytes(b"x") lower_partial.write_bytes(b"y") - # Age both past the active-partial grace so the purge is exercised on stalled blobs - # (lower is preserved by repo attribution, not mtime). + # Age both past the grace; lower is spared by repo attribution, not mtime. 
old = time.time() - 600 os.utime(upper_partial, (old, old)) os.utime(lower_partial, (old, old)) @@ -639,8 +594,7 @@ def test_prepare_for_http_preserves_case_colliding_repo(tmp_path): def test_repo_type_none_resolves_model_cache(hf_cache): - """A caller forwarding repo_type=None (HF's default model) must still see the - real models-- partial, not look up a bogus Nones-- dir.""" + """repo_type=None (HF default model) resolves the models-- dir, not a bogus Nones-- dir.""" blobs = _blobs_dir(hf_cache) (blobs / "x.incomplete").write_bytes(b"abc") @@ -651,8 +605,7 @@ def test_repo_type_none_resolves_model_cache(hf_cache): def test_state_ignores_case_colliding_repo_partial(tmp_path, monkeypatch): - """The read/watchdog path attributes a partial only to an exact-case repo dir, - so a stale partial in a case-colliding repo cannot trip the watchdog.""" + """The read/watchdog path attributes a partial only to an exact-case repo dir.""" monkeypatch.setattr(hf_constants, "HF_HUB_CACHE", str(tmp_path)) exact = tmp_path / "models--Org--Repo" / "blobs" other = tmp_path / "models--org--repo" / "blobs" @@ -662,30 +615,22 @@ def test_state_ignores_case_colliding_repo_partial(tmp_path, monkeypatch): pytest.skip("case-insensitive filesystem; cannot collide cache dirs") (other / "stale.incomplete").write_bytes(b"x") # only the lowercase repo - # Org/Repo has no partial of its own; the lowercase repo's must not count. assert xf.get_hf_download_state(["Org/Repo"]) == (0, False) def test_single_folded_match_rejected_on_case_sensitive_fs(tmp_path, monkeypatch): - """A single folded-but-not-exact cache dir must not be attributed to a - differently-cased repo on a case-sensitive filesystem -- it is a different - repo, and charging its partial here could misread the watchdog or let HTTP-prep - delete it. 
Only an exact-case dir (or a folded dir the FS resolves to the same - entry on a case-insensitive FS) counts.""" + """On a case-sensitive FS a folded-but-not-exact dir is a different repo, so its partial is not attributed here.""" monkeypatch.setattr(hf_constants, "HF_HUB_CACHE", str(tmp_path)) lower = tmp_path / "models--org--repo" / "blobs" lower.mkdir(parents = True) if (tmp_path / "models--Org--Repo").exists(): pytest.skip("case-insensitive filesystem; the folded dir is the same entry") (lower / "stale.incomplete").write_bytes(b"x") # only the lowercase repo exists - # Request the exact-case repo, which has no dir of its own: the lowercase repo's - # partial must not be attributed to it. assert xf.get_hf_download_state(["Org/Repo"]) == (0, False) def test_cache_dir_is_expanded(tmp_path, monkeypatch): - """A custom cache_dir with ~ must be expanded (as HF does on write), else the - state probe scans the literal '~/...' path and misses the partial.""" + """A cache_dir with ~ is expanded (as HF does on write), else the probe scans the literal '~/...' path.""" monkeypatch.setenv("HOME", str(tmp_path)) monkeypatch.setenv("USERPROFILE", str(tmp_path)) # Windows home var blobs = tmp_path / "hfcache" / f"models--{REPO.replace('/', '--')}" / "blobs" @@ -697,8 +642,7 @@ def test_cache_dir_is_expanded(tmp_path, monkeypatch): def test_status_callback_failure_does_not_kill_watchdog(hf_cache): - """A raising on_heartbeat (e.g. 
a disconnected client) must not stop the - daemon watchdog from detecting a real stall and firing on_stall.""" + """A raising on_heartbeat must not stop the watchdog from detecting a stall and firing on_stall.""" blobs = _blobs_dir(hf_cache) (blobs / "x.incomplete").write_bytes(b"\0" * 1024) # constant size -> stalls @@ -718,25 +662,22 @@ def boom(_message): stop.set() -# --------------------------------------------------------------------------- # # Transport policy: cached short-circuit, cancel, error propagation, the single -# Xet->HTTP fallback, the injected prepare seam, and the UNSLOTH_DISABLE_XET knob. -# _run_download_attempt is faked, so no real spawn. -# --------------------------------------------------------------------------- # +# Xet->HTTP fallback, injected prepare seam, and UNSLOTH_DISABLE_XET. The download +# seam (_run_download_attempt) is faked, so no real spawn. DL_REPO, FILE = "ztest/xet-dl", "model-Q4_K_XL.gguf" @pytest.fixture(autouse = True) def _no_real_cache_hit(monkeypatch): - """Default: the file cached probe misses and the snapshot fast path misses, so - tests exercise the download seam unless they override these.""" + """Default: cache probe and snapshot fast path miss, so tests exercise the download seam.""" monkeypatch.setattr(huggingface_hub, "try_to_load_from_cache", lambda *a, **k: None) def _snap_miss(*a, **k): raise FileNotFoundError("not fully cached") monkeypatch.setattr(huggingface_hub, "snapshot_download", _snap_miss) - # Neutralize the generic cache purge by default; tests that care record it. + # Neutralize the generic cache purge; tests that care record it. monkeypatch.setattr(xf, "_default_prepare_for_http", lambda *a, **k: None) # No env knob unless a test sets it. monkeypatch.delenv("UNSLOTH_DISABLE_XET", raising = False) @@ -745,10 +686,7 @@ def _snap_miss(*a, **k): class _FakeAttempt: - """Records calls to the download seam and returns scripted results. 
- - Matches unsloth_zoo.hf_xet_fallback._run_download_attempt's signature. - """ + """Records download-seam calls and returns scripted results (matches _run_download_attempt's signature).""" def __init__(self, results): self._results = list(results) @@ -794,7 +732,7 @@ def test_cached_file_short_circuits(monkeypatch, tmp_path): cached = tmp_path / "cached.gguf" cached.write_bytes(b"\0" * 8) monkeypatch.setattr(huggingface_hub, "try_to_load_from_cache", lambda *a, **k: str(cached)) - fake = _install(monkeypatch, []) # must not be called + fake = _install(monkeypatch, []) out = xf.hf_hub_download_with_xet_fallback(DL_REPO, FILE, None) assert out == str(cached) @@ -811,14 +749,11 @@ def test_cancel_before_start_raises_no_attempt(monkeypatch): def test_cancel_honored_even_when_file_cached(monkeypatch, tmp_path): - """A cancel_event set before the call must raise even when the file is ALREADY cached: the - warm-cache short-circuit returns without reaching _download_with_xet_fallback (the other - cancel check), so it must honor cancellation first rather than hand back the cached path - (Codex #829).""" + """A pre-set cancel_event raises even when the file is already cached (the warm-cache short-circuit honors cancel first).""" cached = tmp_path / "cached.gguf" cached.write_bytes(b"\0" * 8) monkeypatch.setattr(huggingface_hub, "try_to_load_from_cache", lambda *a, **k: str(cached)) - fake = _install(monkeypatch, []) # the attempt must never run + fake = _install(monkeypatch, []) ev = threading.Event() ev.set() with pytest.raises(RuntimeError, match = "Cancelled"): @@ -827,13 +762,12 @@ def test_cancel_honored_even_when_file_cached(monkeypatch, tmp_path): def test_snapshot_cancel_honored_even_when_cached(monkeypatch, tmp_path): - """The snapshot wrapper must also honor a pre-set cancel before its warm-cache short-circuit, - so a cancelled request does not resolve a cached snapshot (Codex #829).""" + """The snapshot wrapper honors a pre-set cancel before its warm-cache 
short-circuit.""" snap = tmp_path / "snap" snap.mkdir() (snap / "model.safetensors").write_bytes(b"x") monkeypatch.setattr(huggingface_hub, "snapshot_download", lambda *a, **k: str(snap)) - fake = _install(monkeypatch, []) # the attempt must never run + fake = _install(monkeypatch, []) ev = threading.Event() ev.set() with pytest.raises(RuntimeError, match = "Cancelled"): @@ -843,9 +777,8 @@ def test_snapshot_cancel_honored_even_when_cached(monkeypatch, tmp_path): def test_nonstall_error_propagates_without_fallback(monkeypatch): fake = _install(monkeypatch, [("error", "RepositoryNotFoundError: 404 not found")]) - # A deterministic Hub error is re-raised with its ORIGINAL type preserved across the spawn - # boundary (not flattened to a bare RuntimeError), so a caller's typed except clause still - # matches (Codex #829). The parent reconstructs the class from the child's ": ..." prefix. + # Deterministic Hub error re-raised with its original type across the spawn boundary, + # reconstructed from the child's ": ..." prefix. expected_cls = xf._resolve_exception_class("RepositoryNotFoundError") assert expected_cls is not None and expected_cls is not RuntimeError with pytest.raises(expected_cls, match = "RepositoryNotFoundError"): @@ -855,9 +788,7 @@ def test_nonstall_error_propagates_without_fallback(monkeypatch): def test_crashed_child_retries_over_http(monkeypatch): - """A silent process-level crash (child exits without a result, e.g. 
a native hf_xet - abort) is not a deterministic error, so it retries over HTTP; a clean second result is - accepted.""" + """A silent process-level crash (child exits without a result) retries over HTTP; a clean second result is accepted.""" fake = _install(monkeypatch, [("crashed", "exited without a result"), ("ok", "/cache/x")]) out = xf.hf_hub_download_with_xet_fallback(DL_REPO, FILE, None) assert out == "/cache/x" @@ -873,8 +804,7 @@ def test_crashed_child_on_both_transports_raises(monkeypatch): def test_retryable_xet_error_retries_over_http(monkeypatch): - """A transient Xet transport failure (CAS timeout / 5xx) is not a deterministic Hub error, - so it retries over HTTP; a clean HTTP result is accepted (Codex #829).""" + """A transient Xet failure (CAS timeout / 5xx) retries over HTTP; a clean HTTP result is accepted.""" fake = _install(monkeypatch, [("retryable_error", "CasClientError: request timed out"), ("ok", "/cache/x")]) out = xf.hf_hub_download_with_xet_fallback(DL_REPO, FILE, None) assert out == "/cache/x" @@ -882,8 +812,7 @@ def test_retryable_xet_error_retries_over_http(monkeypatch): def test_retryable_xet_error_on_both_transports_raises(monkeypatch): - """A transient error on Xet AND on HTTP has no other transport left, so it surfaces after - both attempts rather than looping (Codex #829).""" + """A transient error on both transports surfaces after both attempts rather than looping.""" fake = _install(monkeypatch, [("retryable_error", "503 Server Error"), ("retryable_error", "503 Server Error")]) with pytest.raises(RuntimeError, match = "503"): xf.hf_hub_download_with_xet_fallback(DL_REPO, FILE, None) @@ -891,9 +820,7 @@ def test_retryable_xet_error_on_both_transports_raises(monkeypatch): def test_is_retryable_download_error_classification(): - """Transient transport failures (hf_xet / CAS, timeout, reset, HTTP 5xx / 429) are - retryable; deterministic Hub errors (not-found, gated, 4xx, disk full) and unknown errors - are not, so a real repeatable 
failure is surfaced rather than looped (Codex #829).""" + """Transient transport failures (hf_xet/CAS, timeout, reset, 5xx/429) are retryable; deterministic Hub/OS and unknown errors are not.""" f = xf._is_retryable_download_error # Transient transport failures -> retryable. @@ -916,11 +843,11 @@ class _Status408(Exception): assert f(_Status408("Request Timeout")) is True # 408 is transient - # Deterministic Hub errors -> not retryable (matched by class name or 4xx status). + # Deterministic Hub errors -> not retryable (class name or 4xx status). class _Status416(Exception): status_code = 416 - assert f(_Status416("Range Not Satisfiable")) is False # a retry would fail identically + assert f(_Status416("Range Not Satisfiable")) is False class RepositoryNotFoundError(Exception): pass @@ -931,16 +858,11 @@ class _Resp404(Exception): assert f(_Resp404("not found")) is False assert f(OSError(errno.ENOSPC, "No space left on device")) is False - # An unknown / generic error stays deterministic -> surfaced, not looped over transports. - assert f(ValueError("unexpected response payload")) is False + assert f(ValueError("unexpected response payload")) is False # unknown -> deterministic def test_local_entry_not_found_transient_is_retryable(): - """huggingface_hub wraps a TRANSIENT HEAD connection error / timeout for an uncached file as - LocalEntryNotFoundError ('... check your connection and try again'). 
That sub-case must be retryable - (the other transport may recover), while a genuine offline miss ('outgoing traffic has been - disabled') stays deterministic and keeps its reconstructed type across the spawn boundary - (Codex #829).""" + """A transient LocalEntryNotFoundError (HEAD connection error/timeout) is retryable; a genuine offline miss stays deterministic and type-preserved.""" f = xf._is_retryable_download_error class LocalEntryNotFoundError(Exception): @@ -954,7 +876,7 @@ class LocalEntryNotFoundError(Exception): assert f(transient) is True timed_out = LocalEntryNotFoundError("Read timed out while fetching metadata") assert f(timed_out) is True - # Genuine offline miss (no transient hint) -> deterministic, and still type-preserved. + # Genuine offline miss (no transient hint) -> deterministic, type-preserved. offline = LocalEntryNotFoundError( "Cannot find the requested files in the disk cache and outgoing traffic has been disabled." ) @@ -983,12 +905,12 @@ def test_stall_then_http_fallback_succeeds(monkeypatch): assert out == "/cache/model.gguf" assert len(fake.calls) == 2 assert fake.calls[0].disable_xet is False # Xet first - assert fake.calls[1].disable_xet is True # HTTP fallback + assert fake.calls[1].disable_xet is True # HTTP fallback assert prepared == [("model", DL_REPO)], "must prep cache for HTTP before the retry" def test_injected_prepare_for_http_used(monkeypatch): - """Studio injects its marker-aware prepare; the generic default must not run.""" + """An injected prepare_for_http_fn is used; the generic default must not run.""" monkeypatch.setattr( xf, "_default_prepare_for_http", lambda *a, **k: pytest.fail("generic prepare ran") ) @@ -1024,7 +946,7 @@ def test_per_file_independent_fallback(monkeypatch): def test_unsloth_disable_xet_forces_http_first(monkeypatch): - """UNSLOTH_DISABLE_XET=1 skips the Xet attempt: first (and only) attempt is HTTP.""" + """UNSLOTH_DISABLE_XET=1 skips Xet: the first (and only) attempt is HTTP.""" 
monkeypatch.setenv("UNSLOTH_DISABLE_XET", "1") fake = _install(monkeypatch, [("ok", "/http/model.gguf")]) out = xf.hf_hub_download_with_xet_fallback(DL_REPO, FILE, None) @@ -1042,18 +964,16 @@ def test_unsloth_disable_xet_stall_raises_no_retry(monkeypatch): def test_file_path_accepts_cache_dir(monkeypatch): - """The single-file wrapper accepts cache_dir (no TypeError) and threads it through.""" + """The single-file wrapper accepts cache_dir and threads it through.""" fake = _install(monkeypatch, [("ok", "/cache/model.gguf")]) out = xf.hf_hub_download_with_xet_fallback(DL_REPO, FILE, None, cache_dir = "/custom/cache") assert out == "/cache/model.gguf" assert fake.calls[0].cache_dir == "/custom/cache" -# --------------------------------------------------------------------------- # -# Spawn env-timing: the parent sets HF_HUB_DISABLE_XET before the child starts, -# so the child inherits it before re-importing huggingface_hub (whose constants -# cache the value at import). Uses a fake spawn context -- no real subprocess. -# --------------------------------------------------------------------------- # +# Spawn env-timing: the parent sets HF_HUB_DISABLE_XET before the child starts, so +# the child inherits it before re-importing huggingface_hub (constants cache it at +# import). Uses a fake spawn context -- no real subprocess. class _FakeProc: def __init__(self, recorder): self._rec = recorder @@ -1114,7 +1034,7 @@ def test_http_retry_sets_disable_xet_before_spawn(monkeypatch): # Child inherited HTTP transport env at spawn time. assert rec["disable_xet"] == "1" assert rec["hf_transfer"] == "0" - # Parent env is restored afterwards (was unset). + # Parent env restored afterwards (was unset). 
assert "HF_HUB_DISABLE_XET" not in os.environ @@ -1127,7 +1047,7 @@ def test_xet_attempt_does_not_force_disable_before_spawn(monkeypatch): repo_type = "model", disable_xet = False, cancel_event = None, stall_timeout = 0.2, interval = 0.05, grace_period = 0.2, on_status = None, ) - # On the Xet-first attempt we must NOT force-disable Xet for the child. + # On the Xet-first attempt, do not force-disable Xet for the child. assert rec["disable_xet"] is None @@ -1145,8 +1065,7 @@ def put(self, item): def test_run_attempt_no_result_is_crashed(monkeypatch): - """A child that exits without enqueuing a result maps to 'crashed' (a process-level - crash that HTTP may still recover), not a deterministic 'error'.""" + """A child that exits without enqueuing a result maps to 'crashed' (HTTP may recover), not a deterministic 'error'.""" rec: dict = {} class _Ctx: @@ -1166,9 +1085,7 @@ def Queue(self): def test_child_skips_gpu_init_env_set_before_spawn_and_restored(monkeypatch): - """The download child inherits UNSLOTH_ZOO_DISABLE_GPU_INIT=1 at spawn (so its - fresh unsloth_zoo import skips heavy torch/transformers init), and the parent's - env is restored afterwards.""" + """The child inherits UNSLOTH_ZOO_DISABLE_GPU_INIT=1 at spawn (skips heavy torch init); the parent env is restored after.""" monkeypatch.delenv("UNSLOTH_ZOO_DISABLE_GPU_INIT", raising = False) rec: dict = {} monkeypatch.setattr(xf, "_CTX", _FakeCtx(rec, {"ok": True, "path": "/cache/x"})) @@ -1183,9 +1100,7 @@ def test_child_skips_gpu_init_env_set_before_spawn_and_restored(monkeypatch): def test_spawn_repoints_main_file_and_restores(monkeypatch): - """For an unguarded top-level caller script, the spawn child must import this - side-effect-free module as __mp_main__ rather than re-execute the caller, so the - parent repoints __main__.__file__ here at spawn and restores it afterwards.""" + """The parent repoints __main__.__file__ to this module at spawn (so an unguarded caller is not re-executed) and restores 
it.""" main_mod = sys.modules["__main__"] monkeypatch.setattr(main_mod, "__file__", "/fake/user_script.py", raising = False) rec: dict = {} @@ -1201,15 +1116,13 @@ def test_spawn_repoints_main_file_and_restores(monkeypatch): def test_scrub_secrets_handles_boolean_token(): - """token=True ("use the cached token") must not crash the child error scrubber.""" + """token=True must not crash the child error scrubber.""" out = xf._default_scrub_secrets("auth failed for hf_abcdefghij", hf_token = True) assert "hf_abcdefghij" not in out and "***" in out def test_scrub_redacts_presigned_url(): - """A presigned S3/CAS blob URL in a child error carries temporary credentials in - its query string; the default scrubber must redact the query before it is - raised/logged in the parent, while leaving non-signed URLs intact.""" + """The scrubber redacts a presigned S3/CAS URL's credential query string, leaving non-signed URLs intact.""" url = ( "https://cas-bridge.xethub.hf.co/xet-bridge-us/abc/def" "?X-Amz-Signature=deadbeefcafe&X-Amz-Credential=AKIAEXAMPLE123" @@ -1224,18 +1137,15 @@ def test_scrub_redacts_presigned_url(): def test_scrub_redaction_preserves_surrounding_delimiters(): - """A signed URL embedded in structured text (JSON / dict repr) has no surrounding whitespace, so the - query redaction must stop at the closing delimiter and not swallow it -- else the ``"}`` is replaced - by ``***`` and the log line's structure is corrupted (Gemini #829). 
The signed query is still fully - redacted.""" + """Query redaction stops at the closing delimiter of an embedded signed URL (does not swallow the ``"}``).""" embedded = ( '{"error": "403", "url": ' '"https://cas-bridge.xethub.hf.co/x/y?X-Amz-Signature=deadbeef&X-Amz-Expires=3600"}' ) out = xf._default_scrub_secrets(embedded) - assert "deadbeef" not in out # the signed query is redacted + assert "deadbeef" not in out # signed query redacted assert "cas-bridge.xethub.hf.co/x/y?***" in out - assert out.endswith('"}') # the closing delimiters are preserved + assert out.endswith('"}') # closing delimiters preserved # A signed URL wrapped in single quotes / parens keeps those delimiters too. wrapped = "(https://s3.amazonaws.com/b/k?X-Amz-Signature=abc123) tail" out2 = xf._default_scrub_secrets(wrapped) @@ -1243,8 +1153,7 @@ def test_scrub_redaction_preserves_surrounding_delimiters(): def test_local_files_only_file_resolves_in_process(monkeypatch): - """local_files_only resolves the single file from cache in-process and never - spawns a network child (Hugging Face offline semantics).""" + """local_files_only resolves the file from cache in-process and never spawns a network child.""" seen = {} def _dl(*a, **k): @@ -1252,7 +1161,7 @@ def _dl(*a, **k): return "/cache/file.gguf" monkeypatch.setattr(huggingface_hub, "hf_hub_download", _dl) - fake = _install(monkeypatch, []) # the download seam must not be called + fake = _install(monkeypatch, []) out = xf.hf_hub_download_with_xet_fallback(DL_REPO, FILE, None, local_files_only = True) assert out == "/cache/file.gguf" assert seen.get("local_files_only") is True @@ -1275,43 +1184,35 @@ def _snap(*a, **k): def test_file_probe_uses_expanded_cache_dir(monkeypatch, tmp_path): - """The single-file cache probe must use the expanded cache_dir (HF expands ~ - before writing), or a finalized file under ~/hf-cache is missed and a child is - spawned for an already-cached file.""" + """The single-file probe uses the expanded cache_dir, 
else a finalized file under ~/hf-cache is missed.""" monkeypatch.setenv("HOME", str(tmp_path)) monkeypatch.setenv("USERPROFILE", str(tmp_path)) # Windows home var seen = {} def _probe(repo_id, filename, *, repo_type, revision, cache_dir): seen["cache_dir"] = cache_dir - return None # not cached -> falls through to the (faked) download seam + return None # not cached -> falls through to the faked seam monkeypatch.setattr(huggingface_hub, "try_to_load_from_cache", _probe) fake = _install(monkeypatch, [("ok", "/cache/x")]) xf.hf_hub_download_with_xet_fallback(DL_REPO, FILE, None, cache_dir = "~/hfcache") assert seen["cache_dir"] == str(tmp_path / "hfcache") - # The expanded cache_dir is also what the download attempt receives. assert fake.calls[0].cache_dir == str(tmp_path / "hfcache") def test_pathlib_cache_dir_is_expanded(monkeypatch, tmp_path): - """A pathlib.Path cache_dir with ~ must be normalized too (HF accepts Path), or - the child writes under the literal '~/...' while the watchdog watches $HOME/... - and the stall is never detected.""" + """A pathlib.Path cache_dir with ~ is normalized too, else the child writes under '~/...' and the stall is undetected.""" monkeypatch.setenv("HOME", str(tmp_path)) monkeypatch.setenv("USERPROFILE", str(tmp_path)) fake = _install(monkeypatch, [("ok", "/cache/snap")]) xf.snapshot_download_with_xet_fallback( DL_REPO, token = None, cache_dir = Path("~/hfcache") ) - # Normalized to an expanded string for the child attempt + probes. 
assert fake.calls[0].cache_dir == str(tmp_path / "hfcache") def test_subfolder_forwarded_to_file_download(monkeypatch): - """A single-file caller passing subfolder must not get a TypeError; subfolder - is forwarded into the download params (and the cache probe uses the combined - '/' path).""" + """subfolder is forwarded into the download params and the probe uses the combined '/' path.""" probed = {} def _probe(repo_id, filename, *, repo_type, revision, cache_dir): @@ -1336,17 +1237,14 @@ def test_file_download_defaults_token_to_none(monkeypatch): def test_unrelated_partial_does_not_block_clean_cached_snapshot(hf_cache, monkeypatch): - """A clean requested snapshot must short-circuit in-process even when the same - repo cache holds a stale .incomplete from another (unrelated) revision: the fast - path validates only the returned snapshot dir, not the whole repo, so a sibling - mid-download does not force a needless re-fetch of a snapshot that is complete.""" + """A clean requested snapshot short-circuits in-process even with a stale unrelated .incomplete: the fast path validates only the returned dir.""" blobs = _blobs_dir(hf_cache, DL_REPO) repo_dir = blobs.parent snap = repo_dir / "snapshots" / "goodsha" snap.mkdir(parents = True) good = blobs / "good" good.write_bytes(b"weights") - (snap / "model.safetensors").symlink_to(good) # resolves -> snapshot is clean + (snap / "model.safetensors").symlink_to(good) # resolves -> snapshot clean (blobs / "other.incomplete").write_bytes(b"abc") # unrelated stale partial monkeypatch.setattr(huggingface_hub, "snapshot_download", lambda *a, **k: str(snap)) fake = _install(monkeypatch, []) # must NOT spawn a child @@ -1368,30 +1266,25 @@ def boom(_message): def test_unclearable_partial_forces_clean_redownload(hf_cache, monkeypatch): - """When prep cannot clear an unsafe partial, the HTTP attempt forces a clean - re-download instead of an unsafe resume over the sparse partial.""" - # The autouse fixture makes 
_default_prepare_for_http a no-op (simulates a - # cleanup that left the partial in place). + """When prep cannot clear an unsafe partial, the HTTP attempt forces a clean re-download rather than resume over it.""" + # The autouse fixture makes _default_prepare_for_http a no-op (partial left in place). (_blobs_dir(hf_cache, DL_REPO) / "x.incomplete").write_bytes(b"abc") fake = _install(monkeypatch, [("stall", None), ("ok", "/cache/x")]) out = xf.snapshot_download_with_xet_fallback(DL_REPO, token = None) assert out == "/cache/x" assert fake.calls[0].force_download is False # Xet attempt: not forced - assert fake.calls[1].force_download is True # HTTP attempt: forced clean + assert fake.calls[1].force_download is True # HTTP attempt: forced clean -# --------------------------------------------------------------------------- # # Snapshot variant: in-process fast path on a warm cache, else watched download. -# --------------------------------------------------------------------------- # def test_snapshot_fast_path_no_child(hf_cache, monkeypatch): - """A fully cached repo (weights present) resolves in-process via local_files_only - -- no child attempt.""" + """A fully cached repo (weights present) resolves in-process via local_files_only, no child attempt.""" blobs = _blobs_dir(hf_cache, DL_REPO) snap = blobs.parent / "snapshots" / "sha" snap.mkdir(parents = True) weight = blobs / "w" weight.write_bytes(b"\0" * 16) - (snap / "model.safetensors").symlink_to(weight) # weights present -> complete + (snap / "model.safetensors").symlink_to(weight) # weight present -> complete (snap / "config.json").write_text("{}") seen = {} @@ -1408,8 +1301,7 @@ def _snap(*a, **k): def test_snapshot_dir_is_complete_unit(tmp_path): - """Weight presence drives completeness: a config-only snapshot is incomplete; one - with a resolvable weight file is complete.""" + """Weight presence drives completeness: config-only is incomplete, a resolvable weight is complete.""" snap = tmp_path / "snap" 
snap.mkdir() (snap / "config.json").write_text("{}") @@ -1451,13 +1343,7 @@ def test_snapshot_dir_is_complete_missing_shard(tmp_path): def test_snapshot_dir_is_complete_missing_shard_without_index(tmp_path): - """An interrupted multi-shard pull with NO index sidecar reads as incomplete. While the shards - are partial, the numbered shard name itself states the full set, so missing siblings are - detectable without a manifest. But even with EVERY shard on disk, a full warm is still - incomplete until model.safetensors.index.json is present: transformers' local from_pretrained - resolves a directory by probing model.safetensors then model.safetensors.index.json (never by - globbing model-*-of-*.safetensors), so a sharded checkpoint without its index raises rather than - loads, and the missing index would otherwise be fetched in-process over Xet.""" + """A multi-shard pull is incomplete until the index sidecar is present, even with every shard on disk (transformers needs the index to load).""" snap = tmp_path / "snap" snap.mkdir() blob = tmp_path / "blob" @@ -1467,7 +1353,7 @@ def test_snapshot_dir_is_complete_missing_shard_without_index(tmp_path): (snap / "model-00002-of-00003.safetensors").symlink_to(blob) assert hcs.snapshot_dir_is_complete(snap) is False # shard 3 still missing (snap / "model-00003-of-00003.safetensors").symlink_to(blob) - assert hcs.snapshot_dir_is_complete(snap) is False # all shards present but no index sidecar + assert hcs.snapshot_dir_is_complete(snap) is False # all shards present, no index (snap / "model.safetensors.index.json").write_text( json.dumps( { @@ -1483,9 +1369,7 @@ def test_snapshot_dir_is_complete_missing_shard_without_index(tmp_path): def test_snapshot_dir_is_complete_ignores_trainer_artifacts(tmp_path): - """Trainer / optimizer state files carry weight suffixes (.bin/.pt/.pth) but are not - loadable model weights. 
A checkpoint cache holding only those must read as incomplete - so the killable child still fetches the real weights.""" + """Trainer/optimizer state files (.bin/.pt/.pth) are not loadable weights, so a cache holding only those reads incomplete.""" snap = tmp_path / "snap" snap.mkdir() blob = tmp_path / "blob" @@ -1502,15 +1386,13 @@ def test_snapshot_dir_is_complete_ignores_trainer_artifacts(tmp_path): def test_snapshot_dir_is_complete_checkpoint_index_does_not_gate_root(tmp_path): - """A per-checkpoint shard index with missing shards must not fail an unpatterned root warm: - the root weights are what the load reads, so an incomplete checkpoint index is irrelevant to - root completeness (and a complete root weight set is enough).""" + """A per-checkpoint shard index with missing shards does not fail an unpatterned root warm (the root weights are what loads).""" snap = tmp_path / "snap" (snap / "checkpoint-7").mkdir(parents = True) blob = tmp_path / "blob" blob.write_bytes(b"x") (snap / "model.safetensors").symlink_to(blob) # complete root weight - # An incomplete checkpoint shard index (shard 2 missing) lives under checkpoint-7/. + # Incomplete checkpoint shard index (shard 2 missing) under checkpoint-7/. (snap / "checkpoint-7" / "model-00001-of-00002.safetensors").symlink_to(blob) (snap / "checkpoint-7" / "model.safetensors.index.json").write_text( json.dumps({"weight_map": {"a": "model-00001-of-00002.safetensors", @@ -1521,9 +1403,7 @@ def test_snapshot_dir_is_complete_checkpoint_index_does_not_gate_root(tmp_path): def test_fast_path_rejects_config_only_snapshot(hf_cache, monkeypatch): - """HF's local_files_only returns a config-only snapshot (e.g. left by an earlier - AutoConfig fetch) without checking weights. 
The fast path must reject it and complete - the download in the killable child rather than load with missing weights.""" + """The fast path rejects a config-only snapshot HF's local_files_only may return, deferring to the killable child.""" blobs = _blobs_dir(hf_cache, DL_REPO) snap = blobs.parent / "snapshots" / "sha" snap.mkdir(parents = True) @@ -1537,9 +1417,7 @@ def test_fast_path_rejects_config_only_snapshot(hf_cache, monkeypatch): def test_fast_path_requires_each_named_weight(hf_cache, monkeypatch): - """The pre-download cache short-circuit must not accept a stale snapshot holding only one - of several explicitly named weights: base + adapter requested, only the base cached -> the - guarded child still runs to fetch the rest (Codex #829).""" + """The pre-download short-circuit rejects a cache holding only one of several explicitly named weights (base+adapter, only base cached).""" blobs = _blobs_dir(hf_cache, DL_REPO) snap = blobs.parent / "snapshots" / "sha" snap.mkdir(parents = True) @@ -1556,9 +1434,7 @@ def test_fast_path_requires_each_named_weight(hf_cache, monkeypatch): def test_child_broken_snapshot_retries_over_http(monkeypatch, tmp_path): - """A real but broken child snapshot result (HF offline-fallback returning a dir with - dangling symlinks) is rejected on the Xet attempt and retried over HTTP; a clean - second result is accepted.""" + """A broken child snapshot (dangling symlinks) is rejected on Xet and retried over HTTP; a clean second result is accepted.""" broken = tmp_path / "broken" broken.mkdir() (broken / "model.safetensors").symlink_to(tmp_path / "missing") # dangling @@ -1574,8 +1450,7 @@ def test_child_broken_snapshot_retries_over_http(monkeypatch, tmp_path): def test_child_broken_snapshot_after_http_raises(monkeypatch, tmp_path): - """If even the HTTP attempt returns a broken snapshot, fail loudly rather than hand - missing files to the load.""" + """If even the HTTP attempt returns a broken snapshot, fail loudly rather than hand 
missing files to the load.""" broken = tmp_path / "broken" broken.mkdir() (broken / "model.safetensors").symlink_to(tmp_path / "missing") @@ -1586,10 +1461,7 @@ def test_child_broken_snapshot_after_http_raises(monkeypatch, tmp_path): def test_child_weight_incomplete_snapshot_retries_over_http(monkeypatch, tmp_path): - """A child result with no weight files (HF silently returning a stale config-only - snapshot on an offline / timed-out request) is rejected on the Xet attempt and retried - over HTTP; a complete second result is accepted. The helper warms model repos, so a - weight-less result means the download did not finish, not that the repo is weightless.""" + """A weight-less child result (stale config-only snapshot) is rejected on Xet and retried; a complete second result is accepted.""" cfg_only = tmp_path / "cfg" cfg_only.mkdir() (cfg_only / "config.json").write_text("{}") # no weights @@ -1605,12 +1477,10 @@ def test_child_weight_incomplete_snapshot_retries_over_http(monkeypatch, tmp_pat def test_patterned_snapshot_without_weights_is_accepted(monkeypatch, tmp_path): - """A patterned download (allow_patterns) legitimately returns only the requested files - (e.g. config / tokenizer, no model weights). 
The child result must be accepted as-is, - not rejected and retried for lacking weights.""" + """A patterned download that returns only the requested (weightless) files is accepted as-is, not retried for lacking weights.""" cfg_only = tmp_path / "cfg" cfg_only.mkdir() - (cfg_only / "config.json").write_text("{}") # exactly what was requested, no weights + (cfg_only / "config.json").write_text("{}") # exactly what was requested fake = _install(monkeypatch, [("ok", str(cfg_only))]) out = xf.snapshot_download_with_xet_fallback( DL_REPO, token = None, allow_patterns = ["config.json"] @@ -1619,8 +1489,7 @@ def test_patterned_snapshot_without_weights_is_accepted(monkeypatch, tmp_path): def test_dataset_snapshot_without_weights_is_accepted(monkeypatch, tmp_path): - """A non-model snapshot (repo_type='dataset') has no model weights by nature; its - child result must be accepted rather than retried/raised as 'incomplete'.""" + """A dataset snapshot has no weights by nature; its child result is accepted, not retried as 'incomplete'.""" files = tmp_path / "ds" files.mkdir() (files / "data.json").write_text("[]") @@ -1630,9 +1499,7 @@ def test_dataset_snapshot_without_weights_is_accepted(monkeypatch, tmp_path): def test_model_snapshot_with_weights_excluded_is_accepted(monkeypatch, tmp_path): - """A model repo fetched with ignore_patterns that drop every weight format (e.g. 
to - warm only config / tokenizer files) legitimately yields a weightless snapshot; the - result must be accepted, not rejected for lacking weights.""" + """A model repo whose ignore_patterns drop every weight format yields a weightless snapshot that is accepted, not retried.""" cfg_only = tmp_path / "cfg" cfg_only.mkdir() (cfg_only / "config.json").write_text("{}") @@ -1649,9 +1516,7 @@ def test_model_snapshot_with_weights_excluded_is_accepted(monkeypatch, tmp_path) def test_request_can_include_weights_unit(): - """Unsloth's default prefetch ignores (onnx/h5/msgpack/gguf, never safetensors) still - count as including weights, so model warmups keep requiring them; excluding every - weight format does not.""" + """Default prefetch ignores (onnx/h5/msgpack/gguf) still count as weight-including; excluding every weight format does not.""" assert hcs.request_can_include_weights(None, None) is True assert hcs.request_can_include_weights(None, ["*.onnx", "*.h5", "*.msgpack", "*.gguf"]) is True assert hcs.request_can_include_weights(["config.json"], None) is False @@ -1662,22 +1527,17 @@ def test_request_can_include_weights_unit(): def test_request_can_include_weights_index_json_only(): - """A metadata-only request that matches the shard *index* sidecars but no real weight - file must read as weightless: the index is JSON, not weights, so a JSON-only warmup - (allow_patterns=['*.json'] or ['*.index.json']) should not be forced to land shards.""" + """A request matching only shard *index* sidecars reads weightless (the index is JSON, not weights); a real weight pattern does not.""" assert hcs.request_can_include_weights(["*.json"], None) is False assert hcs.request_can_include_weights(["*.index.json"], None) is False assert hcs.request_can_include_weights( ["model.safetensors.index.json", "pytorch_model.bin.index.json"], None ) is False - # A real weight pattern still counts as including weights. 
assert hcs.request_can_include_weights(["*.safetensors"], None) is True def test_request_can_include_weights_path_qualified(): - """Path-qualified allow_patterns must be resolved inside their directory, and a bare - non-first shard recognized, so a subfolder / checkpoint / specific-shard weight request - is not misread as weightless (which would skip the killable child).""" + """Path-qualified allow_patterns resolve inside their directory, so a subfolder/checkpoint/shard weight request is not misread as weightless.""" # Concrete subfolder globs: weights live under the directory. assert hcs.request_can_include_weights(["checkpoint-10/*"], None) is True assert hcs.request_can_include_weights(["checkpoint-10/*.safetensors"], None) is True @@ -1685,22 +1545,20 @@ def test_request_can_include_weights_path_qualified(): # A specific (non-first) shard named verbatim. assert hcs.request_can_include_weights(["model-00002-of-00005.safetensors"], None) is True assert hcs.request_can_include_weights(["checkpoint-10/pytorch_model.bin"], None) is True - # Globbed parent dir, weight-targeting basename -> can include weights. + # Globbed parent dir, weight-targeting basename. assert hcs.request_can_include_weights(["checkpoint-*/*.safetensors"], None) is True - # Subfolder requests that target only non-weight files stay weightless. + # Subfolder requests targeting only non-weight files stay weightless. assert hcs.request_can_include_weights(["checkpoint-10/config.json"], None) is False assert hcs.request_can_include_weights(["checkpoint-10/*.json"], None) is False assert hcs.request_can_include_weights(["checkpoint-*/tokenizer.json"], None) is False - # The unsloth subfolder warmup shape: "/*" plus root aux files -> weights expected. + # The unsloth subfolder warmup shape: "/*" plus root aux files. 
assert hcs.request_can_include_weights( ["checkpoint-10/*", "config.json", "tokenizer.json"], None ) is True def test_request_can_include_weights_path_qualified_custom_globs(): - """A path-qualified custom weight glob (checkpoint-10/lora_*.safetensors, with a globbed - parent too) names a weight whose basename matches no canonical probe; it must read as - weight-including via a concretized self-probe, not weightless (Codex #829).""" + """A path-qualified custom weight glob (checkpoint-10/lora_*.safetensors) reads weight-including via a concretized self-probe.""" assert hcs.request_can_include_weights(["checkpoint-10/lora_*.safetensors"], None) is True assert hcs.request_can_include_weights(["checkpoint-*/lora_*.bin"], None) is True assert hcs.request_can_include_weights(["models/custom_*.pt"], None) is True @@ -1710,16 +1568,12 @@ def test_request_can_include_weights_path_qualified_custom_globs(): def test_request_can_include_weights_empty_allow_list(tmp_path): - """allow_patterns=[] is a real filter that selects NO objects (Hugging Face semantics), so - the request is weightless -- it must not collapse with None (an unfiltered warmup) and - reject a legitimately empty snapshot (Codex #829). ignore_patterns=[] ignores nothing, so - it stays weight-including.""" + """allow_patterns=[] selects nothing (weightless, distinct from None); ignore_patterns=[] ignores nothing (weight-including).""" assert hcs.request_can_include_weights([], None) is False assert hcs.request_can_include_weights(None, None) is True assert hcs.request_can_include_weights(None, []) is True assert hcs.request_can_include_weights([], []) is False - # snapshot_dir_is_complete agrees: allow=[] is a scoped (select-nothing) request, not a full - # warmup, so a snapshot carrying an unrelated weight is not read as complete for it. + # snapshot_dir_is_complete agrees: allow=[] is a select-nothing request, so an unrelated weight is not complete for it. 
snap = tmp_path / "snap" snap.mkdir() blob = tmp_path / "blob" @@ -1730,9 +1584,7 @@ def test_request_can_include_weights_empty_allow_list(tmp_path): def test_request_can_include_weights_string_form(): - """Hugging Face accepts allow / ignore patterns as a bare string; it must be treated as - one pattern, not iterated character by character (which would misclassify a subfolder - weight request as weightless).""" + """A bare-string allow/ignore pattern is treated as one pattern, not iterated char by char.""" assert hcs.request_can_include_weights("checkpoint-10/*", None) is True assert hcs.request_can_include_weights("*.safetensors", None) is True assert hcs.request_can_include_weights("config.json", None) is False @@ -1745,15 +1597,13 @@ def test_request_can_include_weights_string_form(): def test_prepare_for_http_spares_active_sibling_partial(hf_cache): - """The generic HTTP-prep purge must not unlink a concurrent download's still-active - .incomplete temp file: only stale (old-mtime) partials are removed, so a sibling - download of another file in the same repo keeps writing safely.""" + """The HTTP-prep purge removes only stale (old-mtime) partials, so a concurrent sibling's active .incomplete keeps writing.""" blobs = _blobs_dir(hf_cache, DL_REPO) stale = blobs / "stalled.incomplete" stale.write_bytes(b"\0" * 16) active = blobs / "sibling.incomplete" active.write_bytes(b"\0" * 16) - # Age the stalled partial well past the active-partial grace; leave the sibling current. + # Age the stalled partial past the grace; leave the sibling current. 
old = time.time() - 600 os.utime(stale, (old, old)) _REAL_DEFAULT_PREPARE("model", DL_REPO, cache_dir = str(hf_cache)) @@ -1773,8 +1623,7 @@ def test_snapshot_stall_then_http(monkeypatch): def test_force_download_skips_fast_path_and_threads(monkeypatch): - """force_download=True must bypass the warm-cache short-circuit and re-fetch in - the killable child, forwarding force_download into the download params.""" + """force_download=True bypasses the warm-cache short-circuit and threads force_download into the download params.""" def _snap(*a, **k): pytest.fail("force_download must not take the local_files_only fast path") @@ -1786,8 +1635,7 @@ def _snap(*a, **k): def test_force_download_file_skips_cache_probe(monkeypatch, tmp_path): - """The single-file path must also skip the cached-blob short-circuit and thread - force_download through when force_download=True.""" + """The single-file path also skips the cached-blob short-circuit and threads force_download through.""" cached = tmp_path / "cached.gguf" cached.write_bytes(b"\0" * 8) monkeypatch.setattr(huggingface_hub, "try_to_load_from_cache", lambda *a, **k: str(cached)) @@ -1797,10 +1645,8 @@ def test_force_download_file_skips_cache_probe(monkeypatch, tmp_path): assert len(fake.calls) == 1 and fake.calls[0].force_download is True -# --------------------------------------------------------------------------- # -# Precondition: HF_HUB_DISABLE_XET is read at import time, so assert its effect -# in a FRESH interpreter (huggingface/huggingface_hub#3266 once ignored it). -# --------------------------------------------------------------------------- # +# Precondition: HF_HUB_DISABLE_XET is read at import time, so assert its effect in a +# FRESH interpreter (huggingface/huggingface_hub#3266 once ignored it). 
def _safe_path() -> str: import os @@ -1841,13 +1687,9 @@ def test_default_leaves_xet_enabled(): ) -# --------------------------------------------------------------------------- # # Exported Xet knobs + child-leak safety + malformed-index resilience. -# --------------------------------------------------------------------------- # def test_xet_availability_and_disable_helpers(monkeypatch): - """The exported Xet knobs: child_should_disable_xet reads the per-worker config flag; - xet_force_disabled honors every documented env knob; is_hf_xet_available reflects the - importability of hf_xet and treats a probe error as 'unavailable'.""" + """child_should_disable_xet reads the per-worker flag; xet_force_disabled honors every env knob; is_hf_xet_available probes importability.""" assert xf.child_should_disable_xet({"disable_xet": True}) is True assert xf.child_should_disable_xet({"disable_xet": False}) is False assert xf.child_should_disable_xet({}) is False @@ -1868,18 +1710,16 @@ def _raise(name): raise ImportError("boom") monkeypatch.setattr(xf.importlib.util, "find_spec", _raise) - assert xf.is_hf_xet_available() is False # a probe exception -> treated as unavailable + assert xf.is_hf_xet_available() is False # a probe exception -> unavailable def test_run_attempt_terminates_child_if_watchdog_start_raises(monkeypatch): - """If start_watchdog raises (e.g. thread/FD exhaustion: 'can't start new thread') AFTER the - download child has already spawned, the child must STILL be reaped -- no leaked process. 
The - error then propagates (a watchdog-start failure is not a transport fault to retry over HTTP).""" + """If start_watchdog raises after the child spawned, the child is still reaped (no leak) and the error propagates.""" rec = {"terminated": False} class _AliveProc: def __init__(self): - self.pid = None # None -> _terminate_process_group skips killpg, uses terminate() + self.pid = None # None -> uses terminate(), not killpg self.exitcode = None self._alive = True @@ -1923,15 +1763,9 @@ def _boom(*a, **k): assert rec["terminated"] is True # child reaped despite the watchdog-start failure -# --------------------------------------------------------------------------- # # Codex review round: scoped completeness, weightless named files, type preservation. -# --------------------------------------------------------------------------- # - - def test_requested_named_files_present_exact_request(tmp_path): - """An EXACT-named weightless request (allow=['tokenizer.json'], no glob) requires its named file - on disk; a config-only snapshot must not pass. A glob list or no allow_patterns is best-effort - (Codex #829).""" + """An EXACT-named weightless request requires its named file on disk; a glob list or no allow_patterns is best-effort.""" snap = tmp_path / "snap" snap.mkdir() (snap / "config.json").write_text("{}") @@ -1942,15 +1776,14 @@ def test_requested_named_files_present_exact_request(tmp_path): assert hcs.requested_named_files_present(snap, allow_patterns = ["tokenizer*", "vocab.txt"]) is True # No allow_patterns -> trivially satisfied. assert hcs.requested_named_files_present(snap) is True - # An ignore-filtered name is not actually requested, so its absence does not fail. + # An ignore-filtered name is not requested, so its absence does not fail. 
assert hcs.requested_named_files_present( snap, allow_patterns = ["tokenizer.json", "absent.json"], ignore_patterns = ["absent.json"] ) is True def test_deterministic_oserror_type_preserved(monkeypatch): - """A deterministic disk-full OSError is re-raised as OSError (not flattened to RuntimeError), so a - caller's `except OSError` cleanup still runs across the spawn boundary (Codex #829).""" + """A deterministic disk-full OSError is re-raised as OSError across the spawn boundary, not flattened to RuntimeError.""" fake = _install(monkeypatch, [("error", "OSError: [Errno 28] No space left on device")]) with pytest.raises(OSError, match = "No space left"): xf.hf_hub_download_with_xet_fallback(DL_REPO, FILE, None) @@ -1958,8 +1791,7 @@ def test_deterministic_oserror_type_preserved(monkeypatch): def test_unknown_error_falls_back_to_runtimeerror(monkeypatch): - """An unrecognized error class name still surfaces (as RuntimeError, the prior behavior) without - a transport fallback -- only KNOWN deterministic Hub / OS types are reconstructed (Codex #829).""" + """An unrecognized error class name surfaces as RuntimeError without a fallback; only known Hub/OS types are reconstructed.""" fake = _install(monkeypatch, [("error", "SomeWeirdError: kaboom")]) with pytest.raises(RuntimeError, match = "kaboom"): xf.hf_hub_download_with_xet_fallback(DL_REPO, FILE, None) @@ -1967,8 +1799,7 @@ def test_unknown_error_falls_back_to_runtimeerror(monkeypatch): def test_resolve_exception_class_maps_known_names(): - """The reconstruction map resolves the documented deterministic Hub error names + OSError, and - returns None (-> RuntimeError) for an unknown name (Codex #829).""" + """The map resolves known deterministic Hub error names + OSError, and returns None for an unknown name.""" assert xf._resolve_exception_class("OSError") is OSError cls = xf._resolve_exception_class("RepositoryNotFoundError") assert cls is not None and issubclass(cls, BaseException) @@ -1976,10 +1807,7 @@ def 
test_resolve_exception_class_maps_known_names(): def test_error_type_preserved_when_constructor_needs_kwarg(monkeypatch): - """A Hub error class whose constructor rejects a lone positional string (newer huggingface_hub - makes HfHubHTTPError's `response` required / keyword-only) must STILL be re-raised with its type - preserved -- via an __init__-bypassing reconstruction -- not silently downgraded to RuntimeError - (Codex #829).""" + """A Hub error whose constructor rejects a lone positional string is still re-raised with its type via an __init__-bypassing reconstruction.""" class PickyHubError(Exception): def __init__(self, message, *, response): # response required + keyword-only super().__init__(message) @@ -1996,9 +1824,7 @@ def __init__(self, message, *, response): # response required + keyword-only def test_instantiate_preserving_type_paths(): - """Direct coverage of the layered reconstruction: a normal constructor is used when it accepts a - string; a keyword-only-required constructor falls through to the __new__ bypass; both yield an - instance of the requested type carrying the message (Codex #829).""" + """Layered reconstruction: a normal constructor is used when it accepts a string, else the __new__ bypass; both carry the message.""" class Plain(Exception): pass @@ -2012,11 +1838,7 @@ def __init__(self, message, *, response): assert "the message" in str(exc) -# --------------------------------------------------------------------------- # # Codex round: dir/ wildcard, logical-weight grouping post-download, errno preservation. 
-# --------------------------------------------------------------------------- # - - def test_parse_errno(): assert xf._parse_errno("OSError: [Errno 28] No space left on device") == 28 assert xf._parse_errno("OSError: [Errno 122] Disk quota exceeded") == 122 @@ -2024,8 +1846,7 @@ def test_parse_errno(): def test_oserror_errno_preserved(monkeypatch): - """A disk-full child OSError keeps its errno (ENOSPC) across the spawn boundary, so a caller's - `except OSError` cleanup can still branch on exc.errno -- not see errno=None (Codex #829).""" + """A disk-full child OSError keeps its errno (ENOSPC) across the spawn boundary, not errno=None.""" import errno as _errno fake = _install(monkeypatch, [("error", "OSError: [Errno 28] No space left on device")]) @@ -2036,10 +1857,7 @@ def test_oserror_errno_preserved(monkeypatch): def test_oserror_subclass_errno_preserved(monkeypatch): - """An OSError SUBCLASS (PermissionError from an unwritable cache) keeps BOTH its type AND its errno - across the spawn boundary, so a caller branching on exc.errno still matches (Gemini #829). Errno is - set as an attribute, so it survives even for a subclass whose constructor rejects (errno, strerror); - the message is not double-prefixed with the errno.""" + """An OSError subclass (PermissionError) keeps both its type and errno across the spawn boundary; the message is not double-prefixed.""" import errno as _errno fake = _install(monkeypatch, [("error", "PermissionError: [Errno 13] Permission denied")]) @@ -2051,17 +1869,13 @@ def test_oserror_subclass_errno_preserved(monkeypatch): def test_raise_child_error_errno_only_for_builtin_oserror(): - """errno is preserved only for a BUILTIN OSError type (a real OS errno), set via attribute so it - survives a builtin whose __init__ rejects the (errno, strerror) form. 
A NON-builtin OSError subclass - -- an HF HTTP error subclasses OSError via requests -> IOError -- with a bracketed [Errno N] in its - message must NOT get a spurious errno (#829 re-review).""" + """errno is preserved only for a BUILTIN OSError type; a non-builtin OSError subclass (e.g. HfHubHTTPError) with a bracketed [Errno N] gets no spurious errno.""" # Builtin OSError subclass -> errno preserved. with pytest.raises(FileNotFoundError) as excinfo: xf._raise_child_error("FileNotFoundError: [Errno 2] No such file or directory") assert excinfo.value.errno == 2 - # A non-builtin OSError subclass (simulating HfHubHTTPError) whose message merely contains a - # bracketed [Errno N] must NOT have it mistaken for a real OS errno. + # Non-builtin OSError subclass whose message merely contains a bracketed [Errno N]. class _FakeHubHTTPError(OSError): def __init__(self, message): # single-arg, like hf_hub's error types super().__init__(message) @@ -2078,16 +1892,9 @@ def __init__(self, message): # single-arg, like hf_hub's error types xf._resolve_exception_class = orig -# --------------------------------------------------------------------------- # # Spawn-safety regressions: failed-spawn queue cleanup + disable-Xet env-race lock. -# --------------------------------------------------------------------------- # - - def test_failed_spawn_closes_result_queue(monkeypatch): - """R2-2: if proc.start() raises (e.g. OSError "can't start new process" under fd / thread - exhaustion), the result_queue's OS pipe fds -- allocated before the spawn -- must be closed - rather than leaked. 
The lifecycle try/finally that closes them is only entered after a - successful start, so a dedicated except around the spawn must close the queue and re-raise.""" + """R2-2: if proc.start() raises, the result_queue's pipe fds (allocated before the spawn) are closed, not leaked.""" closed = {"cancel_join": False, "close": False} class _FakeQueue: @@ -2131,17 +1938,12 @@ def Process(self, *a, **k): def test_disable_xet_read_under_spawn_lock(monkeypatch): - """R2-1: _download_with_xet_fallback must read xet_force_disabled() while holding - _SPAWN_ENV_LOCK. A concurrent download briefly sets the child-only HF_HUB_DISABLE_XET=1 in the - parent os.environ around its spawn (under the same lock); reading the live env outside the lock - could observe that value and wrongly force THIS download onto HTTP from the first attempt.""" + """R2-1: xet_force_disabled() is read while holding _SPAWN_ENV_LOCK, so a concurrent spawn's transient HF_HUB_DISABLE_XET=1 is not observed.""" seen = {} real = xf.xet_force_disabled def _spy(): - # A plain (non-reentrant) Lock cannot be re-acquired by its owner, so a non-blocking acquire - # FAILS iff the read is happening inside `with _SPAWN_ENV_LOCK:`. If the read were outside the - # lock, the acquire would succeed. + # A non-reentrant Lock's non-blocking acquire fails iff the read is inside `with _SPAWN_ENV_LOCK:`. got = xf._SPAWN_ENV_LOCK.acquire(blocking = False) if got: xf._SPAWN_ENV_LOCK.release() @@ -2168,13 +1970,10 @@ def _spy(): assert seen.get("held") is True -# --------------------------------------------------------------------------- # -# Conservative fast-path gate + pre/post-download acceptance split (PR #829 trim). -# The completeness gate is intentionally narrow: it fast-paths ONLY the unambiguous -# canonical model cache, deferring everything else to the watched snapshot_download -# child. The pre-download (skip the child?) and post-download (accept the result?) 
-# checks are deliberately asymmetric -- strict pre, lenient post. -# --------------------------------------------------------------------------- # +# Conservative fast-path gate + pre/post-download acceptance split. The gate fast-paths +# ONLY the unambiguous canonical model cache, deferring everything else to the watched +# child. Pre-download (skip the child?) and post-download (accept the result?) are +# deliberately asymmetric: strict pre, lenient post. def _mk_snapshot(tmp_path, name): blob = tmp_path / "_blob" if not blob.exists(): @@ -2185,7 +1984,7 @@ def _mk_snapshot(tmp_path, name): def test_gate_fast_paths_canonical_single_file(tmp_path): - """A complete, unpatterned single-file model cache is fast-path eligible (skip the child).""" + """A complete, unpatterned single-file model cache is fast-path eligible.""" snap, blob = _mk_snapshot(tmp_path, "single") (snap / "model.safetensors").symlink_to(blob) (snap / "config.json").write_text("{}") @@ -2193,9 +1992,7 @@ def test_gate_fast_paths_canonical_single_file(tmp_path): def test_gate_fast_paths_canonical_sharded_with_index(tmp_path): - """A complete sharded model with its index sidecar is fast-path eligible; without the index, or - with a listed shard missing, it is not (transformers cannot load a local sharded checkpoint - without a complete index).""" + """A complete sharded model with its index is fast-path eligible; without the index or with a listed shard missing, it is not.""" snap, blob = _mk_snapshot(tmp_path, "shard") (snap / "model-00001-of-00002.safetensors").symlink_to(blob) (snap / "model-00002-of-00002.safetensors").symlink_to(blob) @@ -2213,9 +2010,7 @@ def test_gate_fast_paths_canonical_sharded_with_index(tmp_path): def test_shard_index_with_non_string_value_is_incomplete(tmp_path): - """A malformed shard index mapping a tensor to a non-string value (e.g. 
null) is NOT complete even - when the remaining string-mapped shard is present -- transformers cannot load it, so fail closed and - defer to the watched child rather than silently dropping the bad entry (Codex #829).""" + """A shard index mapping a tensor to a non-string value (null) is incomplete: fail closed and defer to the child.""" snap, blob = _mk_snapshot(tmp_path, "badindex") (snap / "model-00001-of-00002.safetensors").symlink_to(blob) (snap / "model.safetensors.index.json").write_text( @@ -2225,10 +2020,7 @@ def test_shard_index_with_non_string_value_is_incomplete(tmp_path): def test_gate_defers_incomplete_preferred_index_masked_by_complete_bin(tmp_path): - """A present-but-incomplete safetensors index must not be masked by a complete pytorch_model.bin: - transformers probes the safetensors index BEFORE the bin, so the load would fetch the missing - safetensors shards over un-killable Xet. The gate defers unless safetensors is explicitly ignored - (then the load reads the bin) (Codex #829).""" + """An incomplete safetensors index (probed before the bin) is not masked by a complete bin; the gate defers unless safetensors is ignored.""" snap, blob = _mk_snapshot(tmp_path, "prefidx") (snap / "model-00001-of-00002.safetensors").symlink_to(blob) # ST shard 2 absent -> incomplete index (snap / "model.safetensors.index.json").write_text( @@ -2236,18 +2028,15 @@ def test_gate_defers_incomplete_preferred_index_masked_by_complete_bin(tmp_path) "b": "model-00002-of-00002.safetensors"}})) (snap / "pytorch_model.bin").symlink_to(blob) # complete bin co-resident assert hcs.snapshot_dir_is_complete(snap) is False # load prefers the incomplete safetensors - # safetensors explicitly ignored -> the load reads the complete bin -> eligible. + # safetensors ignored -> the load reads the complete bin -> eligible. assert hcs.snapshot_dir_is_complete(snap, ignore_patterns = ["*.safetensors"]) is True - # A COMPLETE safetensors index alongside the bin is eligible. 
+ # A complete safetensors index alongside the bin is eligible. (snap / "model-00002-of-00002.safetensors").symlink_to(blob) assert hcs.snapshot_dir_is_complete(snap) is True def test_gate_rejects_sharded_adapter_only_root_cache(tmp_path): - """A complete sharded ADAPTER at the root (adapter_model.safetensors.index.json + its shards) is - NOT a canonical base model: only model/pytorch_model index names gate the fast path. A base+adapter - repo whose cache holds only the adapter must defer to the child, else a default from_pretrained - base load fetches the missing base weights over un-killable Xet.""" + """A complete sharded adapter at root is not a canonical base model, so an adapter-only cache defers to the child.""" assert hcs._is_canonical_weight_shard_index("adapter_model.safetensors.index.json") is False assert hcs._is_canonical_weight_shard_index("model.safetensors.index.json") is True assert hcs._is_canonical_weight_shard_index("pytorch_model.bin.index.json") is True @@ -2270,8 +2059,7 @@ def test_gate_rejects_config_only(tmp_path): def test_gate_rejects_diffusers_marker(tmp_path): - """A diffusers pipeline (root model_index.json) is never fast-pathed -> defer to the child, - even when a root-level weight happens to be present.""" + """A diffusers pipeline (root model_index.json) is never fast-pathed, even with a root-level weight present.""" snap, blob = _mk_snapshot(tmp_path, "diff") (snap / "model_index.json").write_text("{}") (snap / "model.safetensors").symlink_to(blob) @@ -2279,7 +2067,7 @@ def test_gate_rejects_diffusers_marker(tmp_path): def test_gate_rejects_any_allow_pattern(tmp_path): - """Any allow_patterns makes the request non-trivial -> defer to the child (no fast-path).""" + """Any allow_patterns makes the request non-trivial, so no fast-path.""" snap, blob = _mk_snapshot(tmp_path, "pat") (snap / "model.safetensors").symlink_to(blob) assert hcs.snapshot_dir_is_complete(snap, allow_patterns = ["*.safetensors"]) is False @@ -2287,17 +2075,12 
@@ def test_gate_rejects_any_allow_pattern(tmp_path): def test_gate_eligible_under_ignore_patterns(tmp_path): - """allow=None with ANY ignore patterns stays fast-path eligible: the canonical-weight presence - check verifies the surviving root weight is on disk, so ignores that drop other formats cannot - make an incomplete cache read complete. This covers the common bare from_pretrained warm, whose - real ignore list mixes root-level format excludes (*.onnx, *.gguf, *.pt, *.bin) with subdir - (*/*.safetensors) drops.""" + """allow=None with any ignore patterns stays eligible: ignores that drop other formats cannot make an incomplete cache read complete.""" snap, blob = _mk_snapshot(tmp_path, "ign") (snap / "model.safetensors").symlink_to(blob) assert hcs.snapshot_dir_is_complete(snap, ignore_patterns = ["*/*.safetensors", "*/*.bin"]) is True assert hcs.snapshot_dir_is_complete(snap, ignore_patterns = ["*.onnx"]) is True - # The actual unsloth bare-from_pretrained combined ignore list (root-level format excludes + - # subdir-weight drops) -- the warm root model.safetensors must still fast-path. + # The real unsloth bare-from_pretrained ignore list; the warm root model.safetensors must still fast-path. unsloth_ignore = [ "*.onnx", "*.h5", "*.msgpack", "*.tflite", "*.mlmodel", "*.gguf", "*.pt", "*.pth", "*.ckpt", "optimizer.*", "scheduler.*", "rng_state*", "trainer_state.json", @@ -2327,10 +2110,7 @@ def test_request_can_include_weights_trim_semantics(): def test_request_can_include_weights_partial_ignore_strip_is_weight_bearing(): - """An ignore-only request is weightless ONLY when it strips EVERY weight format. 
A partial strip - -- only the canonical model.safetensors/pytorch_model.bin names while a variant survives, or only - some suffixes while a .pt / .gguf weight survives -- must read as weight-bearing, so the fast path - requires a real weight and never skips the protective child on a config-only cache (the Xet hang).""" + """An ignore-only request is weightless only when it strips EVERY weight format; a partial strip stays weight-bearing.""" r = hcs.request_can_include_weights assert r(None, ["model.safetensors", "pytorch_model.bin"]) is True # variant / other-format survives assert r(None, ["*.safetensors", "*.bin"]) is True # .pt / .gguf / .pth / ... survive @@ -2348,10 +2128,7 @@ def test_pre_download_skips_complete_model(tmp_path): def test_pre_download_defers_variant_on_canonical_cache(tmp_path): - """A variant load (variant="fp16") reads model.<variant>.safetensors, which the canonical gate - does not check. A cache holding only the non-variant canonical weight must NOT fast-path when a - variant is requested -- else the in-process load fetches the missing variant over un-killable Xet. - Same cache, no variant, still fast-paths (the child is only spawned when actually needed).""" + """A variant load reads model.<variant>.safetensors, so a canonical-only cache does not fast-path a variant request (but does with no variant).""" snap, blob = _mk_snapshot(tmp_path, "var") (snap / "model.safetensors").symlink_to(blob) assert xf._cache_can_skip_download( @@ -2362,34 +2139,28 @@ def test_pre_download_defers_variant_on_canonical_cache(tmp_path): def test_pre_download_defers_bin_only_when_safetensors_preferred(tmp_path): - """A default transformers load probes model.safetensors BEFORE pytorch_model.bin. A cache holding - only pytorch_model.bin cannot prove the repo has no safetensors, so the STRICT pre-download gate must - NOT fast-path it -- else the in-process load fetches the preferred model.safetensors over un-killable - Xet (Codex #829).
It still fast-paths when safetensors is IGNORED (use_safetensors=False reads bin), - or when safetensors is present. The lenient POST path still accepts a finished bin-only download (a - genuinely bin-only repo), so a good download is never false-rejected.""" + """A default load probes model.safetensors before the bin, so the strict pre-gate defers a bin-only cache; the lenient post path accepts a finished bin-only download.""" snap, blob = _mk_snapshot(tmp_path, "binonly") (snap / "config.json").write_text("{}") (snap / "pytorch_model.bin").symlink_to(blob) - # PRE: safetensors preferred (not ignored) + bin-only -> defer to the child. + # PRE: safetensors preferred (not ignored) + bin-only -> defer. assert xf._cache_can_skip_download( snap, repo_type = "model", allow_patterns = None, ignore_patterns = None) is False - # PRE: use_safetensors=False (safetensors ignored) -> the bin cache fast-paths. + # PRE: safetensors ignored -> the bin cache fast-paths. assert xf._cache_can_skip_download( snap, repo_type = "model", allow_patterns = None, ignore_patterns = ["*.safetensors", "*.safetensors.index.json"]) is True - # PRE: safetensors present -> fast-path (the common load is unaffected). + # PRE: safetensors present -> fast-path. (snap / "model.safetensors").symlink_to(blob) assert xf._cache_can_skip_download( snap, repo_type = "model", allow_patterns = None, ignore_patterns = None) is True - # POST stays LENIENT: a finished bin-only download is a genuinely bin-only repo -> accepted, not - # looped into a DownloadStallError. + # POST is lenient: a finished bin-only download is accepted. snap2, blob2 = _mk_snapshot(tmp_path, "binonly_post") (snap2 / "config.json").write_text("{}") (snap2 / "pytorch_model.bin").symlink_to(blob2) assert xf._download_result_usable( snap2, repo_type = "model", allow_patterns = None, ignore_patterns = None) is True - # POST: a sharded bin-only repo is likewise accepted (no false-reject). 
+ # POST: a sharded bin-only repo is likewise accepted. snap3, blob3 = _mk_snapshot(tmp_path, "binonly_sharded_post") (snap3 / "pytorch_model-00001-of-00002.bin").symlink_to(blob3) (snap3 / "pytorch_model-00002-of-00002.bin").symlink_to(blob3) @@ -2401,9 +2172,7 @@ def test_pre_download_defers_bin_only_when_safetensors_preferred(tmp_path): def test_pre_download_does_not_skip_diffusers_but_post_accepts(tmp_path): - """The pre/post asymmetry: a diffusers warm is NOT fast-pathed (spawn the child), but the same - complete diffusers result IS accepted post-download (it has component weights), so a good - download is never re-looped into a stall error.""" + """Pre/post asymmetry: a diffusers warm is not fast-pathed, but the same complete result is accepted post-download.""" snap, blob = _mk_snapshot(tmp_path, "diff") (snap / "model_index.json").write_text("{}") comp = snap / "unet" @@ -2416,8 +2185,7 @@ def test_pre_download_does_not_skip_diffusers_but_post_accepts(tmp_path): def test_post_download_rejects_config_only_model(tmp_path): - """A model warm that came back with NO weight (HF handed back a stale config-only snapshot) is - rejected post-download and retried over HTTP.""" + """A model warm returning no weight (stale config-only snapshot) is rejected post-download and retried.""" snap, _ = _mk_snapshot(tmp_path, "cfg") (snap / "config.json").write_text("{}") assert xf._download_result_usable( @@ -2425,10 +2193,7 @@ def test_post_download_rejects_config_only_model(tmp_path): def test_post_download_rejects_ignored_only_format(tmp_path): - """snapshot_download silently returns a stale cache on a transient metadata error. 
A safetensors - load (ignore=['*.bin']) whose returned partial kept only the ignored pytorch_model.bin -- not the - requested model.safetensors -- must be rejected (the weight check applies the ignore filter) and - retried over HTTP, not loaded in-process (Codex #829).""" + """A safetensors load (ignore=['*.bin']) whose result kept only the ignored .bin is rejected (the weight check applies the ignore filter).""" snap, blob = _mk_snapshot(tmp_path, "ignfmt") (snap / "pytorch_model.bin").symlink_to(blob) assert xf._download_result_usable( @@ -2440,10 +2205,7 @@ def test_post_download_rejects_ignored_only_format(tmp_path): def test_post_download_rejects_canonical_only_for_variant(tmp_path): - """A variant load (variant='fp16') whose returned partial kept only the canonical model.safetensors - -- not model.fp16.safetensors -- must be rejected and retried, else the in-process load fetches the - missing variant over un-killable Xet (Codex #829). A present variant weight (single or sharded - infix) is accepted.""" + """A variant load whose result kept only the canonical (non-variant) weight is rejected; a present variant weight is accepted.""" snap, blob = _mk_snapshot(tmp_path, "varpost") (snap / "model.safetensors").symlink_to(blob) assert xf._download_result_usable( @@ -2453,8 +2215,7 @@ def test_post_download_rejects_canonical_only_for_variant(tmp_path): assert xf._download_result_usable( snap, repo_type = "model", allow_patterns = None, ignore_patterns = None, variant = "fp16") is True - # A COMPLETE sharded variant set (shards + index) is accepted; an incomplete one is retried - # (covered by test_post_download_rejects_incomplete_variant_shards). + # A complete sharded variant set (shards + index) is accepted. 
snap2, blob2 = _mk_snapshot(tmp_path, "varshard") (snap2 / "model.fp16-00001-of-00002.safetensors").symlink_to(blob2) (snap2 / "model.fp16-00002-of-00002.safetensors").symlink_to(blob2) @@ -2467,10 +2228,7 @@ def test_post_download_rejects_canonical_only_for_variant(tmp_path): def test_post_download_rejects_patterned_canonical_only_for_variant(tmp_path): - """A PATTERNED variant load (subfolder= + variant=, so allow=['weights/*']) whose returned partial - kept only the canonical weight in scope must be rejected -- the variant check applies to the - patterned branch too, not only allow=None (Codex #829). A present in-scope variant weight is - accepted, and a complete variant download is never false-rejected.""" + """The variant check applies to the patterned branch too: a subfolder variant request kept only the canonical weight is rejected.""" snap, blob = _mk_snapshot(tmp_path, "subvar") sub = snap / "weights" sub.mkdir() @@ -2478,12 +2236,12 @@ def test_post_download_rejects_patterned_canonical_only_for_variant(tmp_path): assert xf._download_result_usable( snap, repo_type = "model", allow_patterns = ["weights/*"], ignore_patterns = None, variant = "fp16") is False - # The in-scope variant weight present -> complete -> accepted (no false-reject). + # The in-scope variant weight present -> accepted. (sub / "model.fp16.safetensors").symlink_to(blob) assert xf._download_result_usable( snap, repo_type = "model", allow_patterns = ["weights/*"], ignore_patterns = None, variant = "fp16") is True - # A COMPLETE sharded in-scope variant weight (dash infix + its variant index) is accepted. + # A complete sharded in-scope variant weight (dash infix + variant index) is accepted. 
snap2, blob2 = _mk_snapshot(tmp_path, "subvarshard") sub2 = snap2 / "weights" sub2.mkdir() @@ -2495,14 +2253,14 @@ def test_post_download_rejects_patterned_canonical_only_for_variant(tmp_path): assert xf._download_result_usable( snap2, repo_type = "model", allow_patterns = ["weights/*"], ignore_patterns = None, variant = "fp16") is True - # A LONE variant shard with no index is an incomplete set the load cannot enumerate -> rejected. + # A lone variant shard with no index is an incomplete set -> rejected. snap2b, blob2b = _mk_snapshot(tmp_path, "subvarshard_lone") (snap2b / "weights").mkdir() (snap2b / "weights" / "model.fp16-00001-of-00002.safetensors").symlink_to(blob2b) assert xf._download_result_usable( snap2b, repo_type = "model", allow_patterns = ["weights/*"], ignore_patterns = None, variant = "fp16") is False - # An out-of-scope variant weight does NOT satisfy an in-scope variant request. + # An out-of-scope variant weight does not satisfy an in-scope variant request. snap3, blob3 = _mk_snapshot(tmp_path, "subvaroos") (snap3 / "model.fp16.safetensors").symlink_to(blob3) # at root, but request scopes to weights/ (snap3 / "weights").mkdir() @@ -2513,13 +2271,7 @@ def test_post_download_rejects_patterned_canonical_only_for_variant(tmp_path): def test_post_download_rejects_variant_only_diffusers_for_plain_load(tmp_path): - """A PLAIN diffusers warm (variant=None) whose returned partial kept only a prior variant='fp16' - download's component weights (unet/diffusion_pytorch_model.fp16.safetensors) must be rejected: the - plain pipeline load reads the NON-variant name, so accepting it would fetch the missing - diffusion_pytorch_model.safetensors in-process over un-killable Xet (Codex #829). 
A complete plain - pipeline, and a pipeline shipping both plain + variant, are still accepted; the variant='fp16' warm - of the same variant-only cache stays accepted (the plain restriction does not touch the variant - check).""" + """A plain diffusers warm (variant=None) whose result kept only variant component weights is rejected; complete plain / both-format pipelines are accepted.""" def _mi(**comps): data = {"_class_name": "StableDiffusionPipeline", "_diffusers_version": "0.21.0"} data.update(comps) @@ -2531,14 +2283,14 @@ def _mi(**comps): for comp in ("unet", "vae"): (snap / comp).mkdir() (snap / comp / "diffusion_pytorch_model.fp16.safetensors").symlink_to(blob) - # plain load: variant-only components do not satisfy it -> retry over HTTP. + # plain load: variant-only components do not satisfy it -> retry. assert xf._download_result_usable( snap, repo_type = "model", allow_patterns = None, ignore_patterns = None, variant = None) is False - # the SAME cache is a complete fp16 warm -> the variant load accepts it (no regression). + # the same cache is a complete fp16 warm -> the variant load accepts it. assert xf._download_result_usable( snap, repo_type = "model", allow_patterns = None, ignore_patterns = None, variant = "fp16") is True - # a COMPLETE plain pipeline (non-variant component weights) is accepted (no false-reject). + # a complete plain pipeline (non-variant component weights) is accepted. snap2, blob2 = _mk_snapshot(tmp_path, "plaincomplete") (snap2 / "model_index.json").write_text( _mi(unet = ["diffusers", "UNet2DConditionModel"], vae = ["diffusers", "AutoencoderKL"])) @@ -2547,16 +2299,14 @@ def _mi(**comps): (snap2 / comp / "diffusion_pytorch_model.safetensors").symlink_to(blob2) assert xf._download_result_usable( snap2, repo_type = "model", allow_patterns = None, ignore_patterns = None, variant = None) is True - # a pipeline shipping BOTH plain + fp16 in a component is accepted for a plain load. 
+ # a pipeline shipping both plain + fp16 in a component is accepted for a plain load. (snap2 / "unet" / "diffusion_pytorch_model.fp16.safetensors").symlink_to(blob2) assert xf._download_result_usable( snap2, repo_type = "model", allow_patterns = None, ignore_patterns = None, variant = None) is True def test_post_download_rejects_incomplete_sharded_glob(tmp_path): - """A globbed weight request (allow=['*.safetensors']) whose returned partial has a canonical shard - index but is missing a shard must be rejected -- globs get the same shard-completeness check as the - unpatterned root path -- so the load does not finish the missing shard over Xet (Codex #829).""" + """A globbed weight request (allow=['*.safetensors']) with a shard index missing a shard is rejected (globs get the same shard-completeness check).""" snap, blob = _mk_snapshot(tmp_path, "shardglob") (snap / "model-00001-of-00002.safetensors").symlink_to(blob) (snap / "model.safetensors.index.json").write_text( @@ -2571,18 +2321,14 @@ def test_post_download_rejects_incomplete_sharded_glob(tmp_path): def test_post_download_accepts_patterned_with_coresident_partial_canonical_shards(tmp_path): - """A COMPLETE patterned download (adapter / gguf / subfolder) whose selected weight the load reads is - present must be ACCEPTED even when an UNRELATED partial canonical base shard set is co-resident at - root (a leftover from a prior interrupted base pull). 
The canonical-shard gate is request-agnostic; - scoping it to requests that actually select canonical root shards avoids failing a working download - into a DownloadStallError (#829 re-review).""" + """A complete patterned download is accepted even with an unrelated partial canonical base shard set co-resident at root.""" def _partial_base_shards(snap, blob): (snap / "model-00001-of-00002.safetensors").symlink_to(blob) # shard 1 present (snap / "model-00002-of-00002.safetensors").symlink_to(snap / "MISSING") # dangling shard 2 (snap / "model.safetensors.index.json").write_text(json.dumps( {"weight_map": {"a": "model-00001-of-00002.safetensors", "b": "model-00002-of-00002.safetensors"}})) - # Adapter request completes; co-resident partial base shards must not force-reject it. + # Adapter request completes; co-resident partial base shards must not reject it. snap, blob = _mk_snapshot(tmp_path, "adapter_coresident") (snap / "adapter_model.safetensors").symlink_to(blob) (snap / "adapter_config.json").write_text("{}") @@ -2600,7 +2346,7 @@ def _partial_base_shards(snap, blob): snap2, repo_type = "model", allow_patterns = ["model.Q4_K_M.gguf", "config.json", "*.json"], ignore_patterns = None) is True - # A globbed weight request that DOES select canonical root shards still gets the completeness gate. + # A globbed weight request that DOES select canonical root shards still gets the gate. 
snap3, blob3 = _mk_snapshot(tmp_path, "glob_still_gated") _partial_base_shards(snap3, blob3) assert xf._download_result_usable( @@ -2608,10 +2354,7 @@ def _partial_base_shards(snap, blob): def test_post_download_rejects_incomplete_ignored_format_shards(tmp_path): - """An unpatterned load that ignores safetensors (so it reads .bin) whose returned partial has a - COMPLETE safetensors shard set but an INCOMPLETE .bin set must be rejected -- the shard gate applies - the ignore filter, so the complete safetensors does not mask the incomplete .bin the load actually - reads (else the in-process load finishes the missing .bin shard over Xet) (#829 re-review).""" + """A load ignoring safetensors (reads .bin) with a complete ST set but incomplete .bin set is rejected (the shard gate applies the ignore filter).""" snap, blob = _mk_snapshot(tmp_path, "ignored_format_shards") (snap / "model-00001-of-00002.safetensors").symlink_to(blob) (snap / "model-00002-of-00002.safetensors").symlink_to(blob) @@ -2628,10 +2371,7 @@ def test_post_download_rejects_incomplete_ignored_format_shards(tmp_path): def test_post_download_rejects_incomplete_variant_shards(tmp_path): - """An unpatterned variant load whose returned partial has a variant shard INDEX but is missing a - listed variant shard must be rejected, else the in-process load finishes the missing variant shard - over Xet (#829 re-review). 
Positive-evidence only: a COMPLETE variant shard set and a SINGLE-FILE - variant are both accepted (a complete variant download is never false-rejected).""" + """A variant load with a variant shard index missing a listed shard is rejected; a complete set and a single-file variant are accepted.""" snap, blob = _mk_snapshot(tmp_path, "variant_incomplete") (snap / "model.fp16-00001-of-00002.safetensors").symlink_to(blob) # shard 1; shard 2 absent (snap / "model.safetensors.index.fp16.json").write_text(json.dumps( @@ -2640,13 +2380,13 @@ def test_post_download_rejects_incomplete_variant_shards(tmp_path): assert xf._download_result_usable( snap, repo_type = "model", allow_patterns = None, ignore_patterns = None, variant = "fp16") is False - # A lone variant shard with NO index (a partial that never fetched the index) is also incomplete. + # A lone variant shard with no index is also incomplete. snap_noidx, blob_ni = _mk_snapshot(tmp_path, "variant_no_index") (snap_noidx / "model.fp16-00001-of-00002.safetensors").symlink_to(blob_ni) assert xf._download_result_usable( snap_noidx, repo_type = "model", allow_patterns = None, ignore_patterns = None, variant = "fp16") is False - # The missing variant shard present -> complete set -> accepted (no false-reject). + # The missing variant shard present -> complete set -> accepted. 
(snap / "model.fp16-00002-of-00002.safetensors").symlink_to(blob) assert xf._download_result_usable( snap, repo_type = "model", allow_patterns = None, ignore_patterns = None, @@ -2660,17 +2400,13 @@ def test_post_download_rejects_incomplete_variant_shards(tmp_path): def test_post_download_accepts_exact_named_shard_subset(tmp_path): - """A caller naming an EXACT shard file (allow=['model-00001-of-00002.safetensors']) asked for - precisely that file; once it is present the result is accepted, even though its sibling shard / index - is absent -- the whole-checkpoint completeness gate applies only to GLOBBED weight warms, not an - exact-named subset (else a satisfied request is failed into a DownloadStallError) (#829 re-review). - A named shard that is ABSENT is still rejected by the exact-files check.""" + """An exact-named shard request is accepted once that file is present (the whole-checkpoint gate applies only to globbed warms); an absent named shard is rejected.""" snap, blob = _mk_snapshot(tmp_path, "exact_shard_present") (snap / "model-00001-of-00002.safetensors").symlink_to(blob) assert xf._download_result_usable( snap, repo_type = "model", allow_patterns = ["model-00001-of-00002.safetensors"], ignore_patterns = None) is True - # The exact-named shard absent -> rejected (nothing to load). + # The exact-named shard absent -> rejected. snap2, _ = _mk_snapshot(tmp_path, "exact_shard_absent") (snap2 / "config.json").write_text("{}") assert xf._download_result_usable( @@ -2679,11 +2415,7 @@ def test_post_download_accepts_exact_named_shard_subset(tmp_path): def test_post_download_accepts_from_tf_flax_weights(tmp_path): - """A from_tf / from_flax base load ignores BOTH PyTorch formats (ignore=['*.safetensors','*.bin', - ...]) and reads tf_model.h5 / flax_model.msgpack. A COMPLETE such download must be accepted, not - false-rejected into a DownloadStallError because the canonical safetensors/bin check found nothing - (#829 re-review). 
Gated on both PyTorch formats ignored, so a normal load and a stray leftover h5 do - not change.""" + """A from_tf/from_flax load (both PyTorch formats ignored) reading tf_model.h5 / flax_model.msgpack is accepted when complete.""" ig = ["*.safetensors", "*.safetensors.index.json", "*.bin", "*.bin.index.json"] for wt in ("tf_model.h5", "flax_model.msgpack"): snap, blob = _mk_snapshot(tmp_path, f"tf_{wt}") @@ -2691,13 +2423,12 @@ def test_post_download_accepts_from_tf_flax_weights(tmp_path): (snap / "config.json").write_text("{}") assert xf._download_result_usable( snap, repo_type = "model", allow_patterns = None, ignore_patterns = ig) is True - # Both PyTorch formats ignored but NO h5/msgpack present -> still rejected (weight missing). + # Both PyTorch formats ignored but no h5/msgpack present -> still rejected. snap, _ = _mk_snapshot(tmp_path, "tf_none") (snap / "config.json").write_text("{}") assert xf._download_result_usable( snap, repo_type = "model", allow_patterns = None, ignore_patterns = ig) is False - # A normal load (PyTorch format NOT ignored) is unchanged: a stray leftover h5 must NOT count as - # the readable weight, so a repo holding only tf_model.h5 is rejected for a default (non-tf) load. + # A normal load (PyTorch format not ignored): a stray leftover h5 does not count, so an h5-only repo is rejected. snap, blob = _mk_snapshot(tmp_path, "stray_h5") (snap / "tf_model.h5").symlink_to(blob) assert xf._download_result_usable( @@ -2705,11 +2436,7 @@ def test_post_download_accepts_from_tf_flax_weights(tmp_path): def test_post_download_checks_sharded_tf_flax_completeness(tmp_path): - """TF / Flax weights can be SHARDED (tf_model.h5.index.json / flax_model.msgpack.index.json). 
A - COMPLETE sharded set (index + all shards) is accepted, but an INCOMPLETE one (a shard missing, or a - lone shard with no index) must be rejected: the single-file regex no longer matches a lone shard, so - an incomplete sharded from_tf/from_flax download is retried over HTTP instead of loaded over Xet - (#829 re-review, sharded-TF false-accept).""" + """Sharded TF/Flax weights: a complete set (index + all shards) is accepted, an incomplete one (missing shard or lone shard, no index) rejected.""" ig = ["*.safetensors", "*.safetensors.index.json", "*.bin", "*.bin.index.json"] for base, ext in (("tf_model", "h5"), ("flax_model", "msgpack")): idx = json.dumps({"weight_map": {"a": f"{base}-00001-of-00002.{ext}", @@ -2727,7 +2454,7 @@ def test_post_download_checks_sharded_tf_flax_completeness(tmp_path): (snap2 / f"{base}.{ext}.index.json").write_text(idx) assert xf._download_result_usable( snap2, repo_type = "model", allow_patterns = None, ignore_patterns = ig) is False - # A lone shard with NO index -> rejected (the load cannot enumerate the set). + # A lone shard with no index -> rejected. snap3, blob3 = _mk_snapshot(tmp_path, f"tfshard_lone_{base}") (snap3 / f"{base}-00001-of-00002.{ext}").symlink_to(blob3) assert xf._download_result_usable( @@ -2735,11 +2462,7 @@ def test_post_download_checks_sharded_tf_flax_completeness(tmp_path): def test_post_download_checks_explicit_checkpoint_shard_completeness(tmp_path): - """An EXPLICIT checkpoint load (subfolder=checkpoint-N -> allow=['checkpoint-N/*']) reads the - checkpoint's weights, so a lone numbered shard there with no index must be rejected, not skipped as a - 'leftover checkpoint subtree' (#829 re-review, checkpoint false-accept). 
A complete checkpoint shard - set is accepted; a leftover checkpoint the request does NOT target (subfolder=unet) is still ignored - so a complete in-scope download is not false-rejected.""" + """An explicit checkpoint load (allow=['checkpoint-N/*']) reads its weights, so a lone shard there with no index is rejected; a complete set / untargeted leftover is fine.""" # Lone checkpoint shard, no index, explicitly requested -> rejected. snap, blob = _mk_snapshot(tmp_path, "ckpt_lone") (snap / "checkpoint-7").mkdir() @@ -2756,8 +2479,7 @@ def test_post_download_checks_explicit_checkpoint_shard_completeness(tmp_path): "b": "model-00002-of-00002.safetensors"}})) assert xf._download_result_usable( snap2, repo_type = "model", allow_patterns = ["checkpoint-7/*"], ignore_patterns = None) is True - # A leftover checkpoint the request does NOT target (subfolder=unet) must not false-reject a complete - # in-scope download. + # A leftover checkpoint the request does not target (subfolder=unet) must not reject a complete in-scope download. snap3, blob3 = _mk_snapshot(tmp_path, "ckpt_leftover") (snap3 / "unet").mkdir() (snap3 / "unet" / "diffusion_pytorch_model.safetensors").symlink_to(blob3) @@ -2765,9 +2487,8 @@ def test_post_download_checks_explicit_checkpoint_shard_completeness(tmp_path): (snap3 / "checkpoint-7" / "model-00001-of-00002.safetensors").symlink_to(blob3) # lone, but not read assert xf._download_result_usable( snap3, repo_type = "model", allow_patterns = ["unet/*"], ignore_patterns = None) is True - # A NESTED checkpoint the request explicitly targets (subfolder=foo/checkpoint-7 -> - # allow=['foo/checkpoint-7/*']) is read INTO at depth, so its lone shard must still be rejected -- - # the scope check matches the checkpoint dir at ANY literal leading segment, not just the first. + # A nested checkpoint the request explicitly targets (allow=['foo/checkpoint-7/*']) still rejects its lone shard + # (the scope check matches a checkpoint dir at any leading segment). 
snap4, blob4 = _mk_snapshot(tmp_path, "ckpt_nested") (snap4 / "foo" / "checkpoint-7").mkdir(parents = True) (snap4 / "foo" / "checkpoint-7" / "model-00001-of-00002.safetensors").symlink_to(blob4) @@ -2777,11 +2498,7 @@ def test_post_download_checks_explicit_checkpoint_shard_completeness(tmp_path): def test_post_download_accepts_exact_named_variant_shard_subset(tmp_path): - """A caller naming an EXACT variant shard (allow=['model.fp16-00001-of-00002.safetensors'] + - variant='fp16') asked for precisely that file; once present the result is accepted even though its - index / sibling shard is absent. The exact-name escape applies to the VARIANT branch too, not only - the plain one, so a satisfied exact variant request is not failed into a DownloadStallError - (#829 re-review).""" + """An exact variant shard request is accepted once present (index/sibling absent); the exact-name escape applies to the variant branch too.""" snap, blob = _mk_snapshot(tmp_path, "exact_var_shard") (snap / "model.fp16-00001-of-00002.safetensors").symlink_to(blob) assert xf._download_result_usable( @@ -2798,10 +2515,7 @@ def test_post_download_accepts_exact_named_variant_shard_subset(tmp_path): def test_post_download_rejects_patterned_incomplete_variant_shards(tmp_path): - """A GLOBBED variant request (allow=['*.safetensors'] + variant='fp16') whose partial kept only a - lone root variant shard without its index / remaining shards must be rejected too -- the - variant-shard completeness check applies to the patterned variant branch, not only allow=None (Codex - #829). 
A complete root variant shard set in scope is accepted (no false-reject).""" + """A globbed variant request with a lone root variant shard (no index) is rejected too; a complete in-scope set is accepted.""" snap, blob = _mk_snapshot(tmp_path, "pat_var_incomplete") (snap / "model.fp16-00001-of-00002.safetensors").symlink_to(blob) # lone shard, no index assert xf._download_result_usable( @@ -2818,9 +2532,7 @@ def test_post_download_rejects_patterned_incomplete_variant_shards(tmp_path): def test_post_download_applies_ignore_to_diffusers_components(tmp_path): - """An unpatterned diffusers warm that ignores a format must not be satisfied by a component weight in - that ignored format: only unet/*.bin present under ignore=['*.bin'] (safetensors requested) is - rejected, else the load fetches the missing safetensors components over Xet (Codex #829).""" + """A diffusers warm ignoring a format is not satisfied by a component weight in that ignored format (only unet/*.bin under ignore=['*.bin'] is rejected).""" snap, blob = _mk_snapshot(tmp_path, "diff_ignore") (snap / "model_index.json").write_text("{}") (snap / "unet").mkdir() @@ -2834,10 +2546,7 @@ def test_post_download_applies_ignore_to_diffusers_components(tmp_path): def test_post_download_rejects_index_only_sharded_masked_by_bin(tmp_path): - """A safetensors index present with NONE of its shards (an index-only partial), co-resident with a - complete pytorch_model.bin, must be rejected: transformers probes the safetensors index before the - bin, so the load would fetch the missing shards over Xet (Codex #829). 
The shard-completeness gate - fires on a present index even before any shard file exists.""" + """A safetensors index with none of its shards (index-only), co-resident with a complete bin, is rejected (the index is probed before the bin).""" snap, blob = _mk_snapshot(tmp_path, "index_only") (snap / "model.safetensors.index.json").write_text(json.dumps( {"weight_map": {"a": "model-00001-of-00002.safetensors", @@ -2851,10 +2560,7 @@ def test_post_download_rejects_index_only_sharded_masked_by_bin(tmp_path): def test_post_download_patterned_shard_check_honors_ignore(tmp_path): - """A patterned request that ignores safetensors (allow=['*'], ignore=['*.safetensors']) selects the - complete .bin; a co-resident incomplete safetensors shard set must NOT force-reject it -- the - patterned shard-completeness check applies the ignore filter, so a satisfied request is not failed - into a DownloadStallError (Codex #829).""" + """A patterned request ignoring safetensors selects the complete .bin; a co-resident incomplete ST set does not reject it (the check applies the ignore filter).""" snap, blob = _mk_snapshot(tmp_path, "pat_ignore") (snap / "model-00001-of-00002.safetensors").symlink_to(blob) # incomplete ST (shard 2 absent) (snap / "model.safetensors.index.json").write_text(json.dumps( @@ -2866,9 +2572,7 @@ def test_post_download_patterned_shard_check_honors_ignore(tmp_path): def test_post_download_variant_root_shard_check_scoped_to_selection(tmp_path): - """A subfolder variant request (allow=['unet/*'] + variant) whose selected weight is complete must be - accepted even when a stale ROOT variant shard (out of scope) is co-resident -- the root variant-shard - check applies only when the request selects a root variant weight (Codex #829).""" + """A subfolder variant request with a complete selected weight is accepted even with a stale out-of-scope root variant shard co-resident.""" snap, blob = _mk_snapshot(tmp_path, "var_scope") (snap / "unet").mkdir() (snap / "unet" / 
"model.fp16.safetensors").symlink_to(blob) # complete in-scope variant @@ -2885,9 +2589,7 @@ def test_post_download_variant_root_shard_check_scoped_to_selection(tmp_path): def test_post_download_root_variant_weight_honors_ignore(tmp_path): - """An unpatterned variant load that ignores .bin must not be satisfied by a variant .bin: only - model.fp16.bin present under ignore=['*.bin'] is rejected, else the load fetches - model.fp16.safetensors over Xet (Codex #829).""" + """A variant load ignoring .bin is not satisfied by a variant .bin (only model.fp16.bin under ignore=['*.bin'] is rejected).""" snap, blob = _mk_snapshot(tmp_path, "var_ignore") (snap / "model.fp16.bin").symlink_to(blob) # only the ignored-format variant assert xf._download_result_usable( @@ -2901,27 +2603,21 @@ def test_post_download_root_variant_weight_honors_ignore(tmp_path): def test_post_download_variant_shard_check_honors_ignore(tmp_path): - """A variant load that ignores .bin must judge the variant shard set for the READ format only: a - complete model.fp16.safetensors co-resident with a stale IGNORED model.fp16-00001-of-00002.bin shard - (no index) is accepted, not force-rejected -- the variant shard-completeness check applies the ignore - filter, so a satisfied variant download is not failed into a DownloadStallError (Codex #829).""" + """A variant load ignoring .bin judges the variant shard set for the read format only: a complete ST variant beside a stale ignored .bin shard is accepted.""" snap, blob = _mk_snapshot(tmp_path, "var_shard_ignore") (snap / "model.fp16.safetensors").symlink_to(blob) # complete, the read format (snap / "model.fp16-00001-of-00002.bin").symlink_to(blob) # stale ignored .bin shard, no index assert xf._download_result_usable( snap, repo_type = "model", allow_patterns = None, ignore_patterns = ["*.bin"], variant = "fp16") is True - # Without the ignore the lone .bin variant shard IS an incomplete set (no index) -> rejected. 
+ # Without the ignore, the complete safetensors variant is preferred and read. assert xf._download_result_usable( snap, repo_type = "model", allow_patterns = None, ignore_patterns = None, variant = "fp16") is True # the complete safetensors variant is preferred and read def test_post_download_rejects_variant_index_only_masked_by_bin(tmp_path): - """A VARIANT safetensors index present with NONE of its shards (an index-only partial), co-resident - with a complete variant pytorch_model.fp16.bin, must be rejected: transformers probes the variant - safetensors index before the variant bin, so the load would fetch the missing variant shards over Xet - (the variant analog of the canonical index-only case) (Codex #829).""" + """The variant analog of index-only: a variant ST index with none of its shards, beside a complete variant bin, is rejected.""" snap, blob = _mk_snapshot(tmp_path, "var_index_only") (snap / "model.safetensors.index.fp16.json").write_text(json.dumps( {"weight_map": {"a": "model.fp16-00001-of-00002.safetensors", @@ -2930,18 +2626,14 @@ def test_post_download_rejects_variant_index_only_masked_by_bin(tmp_path): assert xf._download_result_usable( snap, repo_type = "model", allow_patterns = None, ignore_patterns = None, variant = "fp16") is False - # The variant safetensors explicitly ignored -> load reads the complete variant bin -> usable. + # The variant safetensors ignored -> load reads the complete variant bin -> usable. assert xf._download_result_usable( snap, repo_type = "model", allow_patterns = None, ignore_patterns = ["*.safetensors"], variant = "fp16") is True def test_post_download_rejects_incomplete_sharded_adapter(tmp_path): - """A PEFT adapter load (allow=['adapter_config.json', 'adapter_model*']) whose partial kept a sharded - adapter INDEX but is missing a listed adapter shard must be rejected, else the in-process load - finishes the missing adapter shard over Xet. 
The canonical/variant ROOT-model checks do not cover a - non-model 'adapter_model' index, so the selected-index check catches it (Codex #829). A complete - adapter shard set in scope is accepted (no false-reject).""" + """A PEFT adapter load whose partial has a sharded adapter index missing a shard is rejected (the selected-index check covers the non-model adapter index).""" snap, blob = _mk_snapshot(tmp_path, "adapter_incomplete") (snap / "adapter_config.json").write_text("{}") (snap / "adapter_model-00001-of-00002.safetensors").symlink_to(blob) # shard 1; shard 2 absent @@ -2951,16 +2643,14 @@ def test_post_download_rejects_incomplete_sharded_adapter(tmp_path): allow = ["adapter_config.json", "adapter_model*"] assert xf._download_result_usable( snap, repo_type = "model", allow_patterns = allow, ignore_patterns = None) is False - # The missing adapter shard present -> complete set -> accepted (no false-reject). + # The missing adapter shard present -> complete set -> accepted. (snap / "adapter_model-00002-of-00002.safetensors").symlink_to(blob) assert xf._download_result_usable( snap, repo_type = "model", allow_patterns = allow, ignore_patterns = None) is True def test_post_download_rejects_incomplete_component_subfolder_shards(tmp_path): - """A subfolder-scoped request (allow=['unet/*']) whose selected component has a shard INDEX missing a - listed shard must be rejected -- the selected-index check covers component subfolders the root-model - checks do not (Codex #829). 
A complete component shard set in scope is accepted.""" + """A subfolder request (allow=['unet/*']) whose component has a shard index missing a shard is rejected (the selected-index check covers component subfolders).""" snap, blob = _mk_snapshot(tmp_path, "component_incomplete") (snap / "unet").mkdir() (snap / "unet" / "diffusion_pytorch_model-00001-of-00002.safetensors").symlink_to(blob) @@ -2975,26 +2665,20 @@ def test_post_download_rejects_incomplete_component_subfolder_shards(tmp_path): def test_post_download_rejects_gguf_only_default_load(tmp_path): - """A DEFAULT (unpatterned) transformers warm reads model.safetensors / pytorch_model.bin, not a GGUF - file (only a GGUF-specific request does). A stale snapshot holding only model.Q4_K_M.gguf must be - rejected, else the in-process default load fetches the absent safetensors / bin over un-killable Xet - (Codex #829).""" + """A default warm reads safetensors/bin, not GGUF, so a snapshot holding only a .gguf is rejected.""" snap, blob = _mk_snapshot(tmp_path, "gguf_only") (snap / "model.Q4_K_M.gguf").symlink_to(blob) (snap / "config.json").write_text("{}") assert xf._download_result_usable( snap, repo_type = "model", allow_patterns = None, ignore_patterns = None) is False - # The safetensors weight present -> the default warm accepts (no false-reject), even beside the gguf. + # The safetensors weight present -> the default warm accepts, even beside the gguf. (snap / "model.safetensors").symlink_to(blob) assert xf._download_result_usable( snap, repo_type = "model", allow_patterns = None, ignore_patterns = None) is True def test_post_download_rejects_adapter_variant_for_default_variant_load(tmp_path): - """An unpatterned variant warm reads the ROOT model variant (model.fp16.safetensors), not a PEFT - adapter variant. 
A stale snapshot holding only adapter_model.fp16.safetensors must be rejected, else - the in-process base-model variant load fetches the absent model.fp16.safetensors over un-killable Xet - (Codex #829). The base model variant present -> accepted (no false-reject).""" + """A variant warm reads the root model variant, not an adapter variant, so an adapter-variant-only snapshot is rejected; the base variant present is accepted.""" snap, blob = _mk_snapshot(tmp_path, "adapter_variant_only") (snap / "adapter_model.fp16.safetensors").symlink_to(blob) (snap / "adapter_config.json").write_text("{}") @@ -3009,10 +2693,7 @@ def test_post_download_rejects_adapter_variant_for_default_variant_load(tmp_path def test_post_download_accepts_complete_diffusers_variant(tmp_path): - """A diffusers pipeline variant warm's weights are COMPONENT-scoped (unet/....fp16.safetensors), not - root model..* files. A complete diffusers variant download must be accepted -- the root-only - variant presence check would false-reject it into a spurious DownloadStallError (Codex #829). A - pipeline holding only the NON-variant component weight does not satisfy a variant load.""" + """A diffusers variant warm's weights are component-scoped, so a complete diffusers variant download is accepted; a non-variant-only pipeline does not satisfy a variant load.""" snap, blob = _mk_snapshot(tmp_path, "diffusers_variant") (snap / "model_index.json").write_text("{}") (snap / "unet").mkdir() @@ -3031,10 +2712,7 @@ def test_post_download_accepts_complete_diffusers_variant(tmp_path): def test_post_download_rejects_incomplete_diffusers_component_shards_unpatterned(tmp_path): - """An UNPATTERNED diffusers pipeline warm reads component subfolders (unet/, vae/, ...). A component - shard INDEX listing a shard that is absent -- which the canonical ROOT-shard check does not cover -- - must be rejected, else the in-process pipeline load fetches the missing shard over un-killable Xet - (Codex #829). 
Both the plain and the variant component index are covered; a complete set is accepted.""" + """An unpatterned diffusers warm rejects a component shard index listing an absent shard; plain and variant indexes are covered, a complete set accepted.""" snap, blob = _mk_snapshot(tmp_path, "diffusers_comp_incomplete") (snap / "model_index.json").write_text("{}") (snap / "unet").mkdir() @@ -3048,7 +2726,7 @@ def test_post_download_rejects_incomplete_diffusers_component_shards_unpatterned (snap / "unet" / "diffusion_pytorch_model-00002-of-00002.safetensors").symlink_to(blob) assert xf._download_result_usable( snap, repo_type = "model", allow_patterns = None, ignore_patterns = None) is True - # The same for a VARIANT component index (variant='fp16', unpatterned). + # Same for a variant component index (variant='fp16', unpatterned). snapv, blobv = _mk_snapshot(tmp_path, "diffusers_comp_variant_incomplete") (snapv / "model_index.json").write_text("{}") (snapv / "unet").mkdir() @@ -3066,9 +2744,7 @@ def test_post_download_rejects_incomplete_diffusers_component_shards_unpatterned def test_post_download_single_safetensors_beats_stale_index(tmp_path): - """transformers probes single model.safetensors BEFORE model.safetensors.index.json, so a complete - single weight co-resident with a STALE incomplete index is usable and must not be looped into a - DownloadStallError (Codex #829). 
A stale index with NO single weight is still breakage.""" + """A complete single model.safetensors (probed before the index) beside a stale incomplete index is usable; a stale index with no single weight is breakage.""" snap, blob = _mk_snapshot(tmp_path, "single_beats_index") (snap / "config.json").write_text("{}") (snap / "model.safetensors").symlink_to(blob) @@ -3077,8 +2753,8 @@ def test_post_download_single_safetensors_beats_stale_index(tmp_path): "b": "model-00002-of-00002.safetensors"}})) # shards absent assert xf._download_result_usable( snap, repo_type = "model", allow_patterns = None, ignore_patterns = None) is True - assert hcs.snapshot_dir_is_complete(snap) is True # the PRE gate agrees (offline warm short-circuit) - # No single weight, only the stale index -> the sharded-safetensors load would fetch missing shards. + assert hcs.snapshot_dir_is_complete(snap) is True # the PRE gate agrees + # No single weight, only the stale index -> incomplete. snap2, _ = _mk_snapshot(tmp_path, "stale_index_only") (snap2 / "config.json").write_text("{}") (snap2 / "model.safetensors.index.json").write_text(json.dumps( @@ -3089,10 +2765,7 @@ def test_post_download_single_safetensors_beats_stale_index(tmp_path): def test_post_download_rejects_noncanonical_root_weight_for_default_load(tmp_path): - """A DEFAULT load probes only the canonical model.safetensors / pytorch_model.bin (single or numbered - shard). A stale cache holding only a NON-canonical root weight (consolidated.safetensors) must be - rejected, else the default load fetches the absent canonical weight over un-killable Xet (Codex #829). 
- The canonical weight present -> accepted.""" + """A default load probes only canonical names, so a cache holding only a non-canonical root weight (consolidated.safetensors) is rejected; the canonical weight present is accepted.""" snap, blob = _mk_snapshot(tmp_path, "noncanonical") (snap / "config.json").write_text("{}") (snap / "consolidated.safetensors").symlink_to(blob) @@ -3104,10 +2777,7 @@ def test_post_download_rejects_noncanonical_root_weight_for_default_load(tmp_pat def test_diffusers_component_check_scoped_to_declared_components(tmp_path): - """The component shard check is scoped to the components model_index.json declares. A complete - pipeline (declared unet+vae present) co-resident with a STALE UNDECLARED subtree (a leftover - controlnet/ with an incomplete shard index the DiffusionPipeline load never reads) must still be - accepted (Codex #829); an incomplete DECLARED component is still rejected (hang protection kept).""" + """The component shard check is scoped to declared components: a stale undeclared subtree does not reject a complete pipeline; an incomplete declared component still does.""" snap, blob = _mk_snapshot(tmp_path, "declared_scope") (snap / "model_index.json").write_text(json.dumps( {"_class_name": "StableDiffusionPipeline", @@ -3123,7 +2793,7 @@ def test_diffusers_component_check_scoped_to_declared_components(tmp_path): (snap / "controlnet" / "diffusion_pytorch_model-00001-of-00002.safetensors").symlink_to(blob) assert xf._download_result_usable( snap, repo_type = "model", allow_patterns = None, ignore_patterns = None) is True - # An incomplete DECLARED component (unet index missing a shard) is still caught. + # An incomplete declared component (unet index missing a shard) is still caught. 
snap2, blob2 = _mk_snapshot(tmp_path, "declared_incomplete") (snap2 / "model_index.json").write_text(json.dumps( {"_class_name": "P", "unet": ["diffusers", "UNet2DConditionModel"]})) @@ -3137,11 +2807,7 @@ def test_diffusers_component_check_scoped_to_declared_components(tmp_path): def test_post_download_variant_presence_requires_canonical_name(tmp_path): - """The unpatterned variant presence check counts only a CANONICAL model variant name a default - variant load reads (model..safetensors, model.-NNNNN-of-NNNNN.safetensors). A - non-canonical sidecar (consolidated.fp16.safetensors) or a non-transformers dot-infix shard name - (model-00001-of-00001.fp16.safetensors) must NOT satisfy the request, else the load fetches the - absent model.fp16.safetensors over un-killable Xet (Codex #829).""" + """The variant presence check counts only a canonical variant name; a non-canonical sidecar or dot-infix shard does not satisfy the request.""" snap, blob = _mk_snapshot(tmp_path, "var_noncanonical") (snap / "config.json").write_text("{}") (snap / "consolidated.fp16.safetensors").symlink_to(blob) @@ -3162,11 +2828,8 @@ def test_post_download_variant_presence_requires_canonical_name(tmp_path): def test_post_download_rejects_selected_shard_without_index(tmp_path): - """A SELECTED non-root numbered shard with NO index of the read format is an incomplete set the load - cannot enumerate (it needs the index to list the shards), so it is rejected and retried over HTTP, - else the adapter / component load fetches the index and remaining shards over un-killable Xet - (Codex #829). A complete indexed set is accepted.""" - # A sharded ADAPTER with a lone shard and no index. + """A selected non-root numbered shard with no index is an incomplete set the load cannot enumerate, so it is rejected; a complete indexed set is accepted.""" + # A sharded adapter with a lone shard and no index. 
snap, blob = _mk_snapshot(tmp_path, "adapter_lone_shard") (snap / "config.json").write_text("{}") (snap / "adapter_config.json").write_text("{}") @@ -3192,10 +2855,7 @@ def test_post_download_rejects_selected_shard_without_index(tmp_path): def test_post_download_diffusers_presence_scoped_to_declared(tmp_path): - """An UNPATTERNED diffusers pipeline warm counts a component weight as proof only for a DECLARED - component. A stale cache holding only an UNDECLARED leftover (controlnet/ not in model_index.json) - must be rejected, else the pipeline fetches the declared unet/vae weights in-process over Xet - (Codex #829). The declared components present -> accepted.""" + """A diffusers warm counts a component weight only for a declared component, so an undeclared-leftover-only cache is rejected; declared components present are accepted.""" snap, blob = _mk_snapshot(tmp_path, "diffusers_undeclared_only") (snap / "model_index.json").write_text(json.dumps( {"_class_name": "P", "unet": ["diffusers", "U"], "vae": ["diffusers", "V"]})) @@ -3212,10 +2872,7 @@ def test_post_download_diffusers_presence_scoped_to_declared(tmp_path): def test_post_download_diffusers_variant_presence_scoped_to_declared(tmp_path): - """Variant twin of the declared-scope check: an UNPATTERNED diffusers VARIANT warm counts a component - variant weight as proof only for a DECLARED component. 
A stale cache holding only an UNDECLARED variant - leftover (controlnet/....fp16.safetensors not in model_index.json) must be rejected, else the pipeline - fetches the declared unet/vae variant weights in-process over un-killable Xet (Codex #829).""" + """Variant twin of the declared-scope check: a diffusers variant warm counts a component variant weight only for a declared component.""" snap, blob = _mk_snapshot(tmp_path, "diffusers_variant_undeclared_only") (snap / "model_index.json").write_text(json.dumps( {"_class_name": "P", "unet": ["diffusers", "U"], "vae": ["diffusers", "V"]})) @@ -3234,11 +2891,7 @@ def test_post_download_diffusers_variant_presence_scoped_to_declared(tmp_path): def test_post_download_single_variant_beats_stale_variant_index(tmp_path): - """Variant twin of single-beats-index: transformers probes single model..safetensors BEFORE - model.safetensors.index..json, so a complete single variant weight co-resident with a STALE - incomplete variant index is usable and must not be looped into a DownloadStallError (Codex #829). Same - for a single .bin variant vs a stale .bin variant index; a stale variant index with NO single weight is - still breakage.""" + """Variant twin of single-beats-index: a complete single variant weight beside a stale variant index is usable (ST and bin); a stale index with no single weight is breakage.""" snap, blob = _mk_snapshot(tmp_path, "single_variant_beats_index") (snap / "config.json").write_text("{}") (snap / "model.fp16.safetensors").symlink_to(blob) @@ -3248,7 +2901,7 @@ def test_post_download_single_variant_beats_stale_variant_index(tmp_path): assert xf._download_result_usable( snap, repo_type = "model", allow_patterns = None, ignore_patterns = None, variant = "fp16") is True - # A single .bin variant co-resident with a stale .bin variant index (no safetensors) -> usable. + # A single .bin variant beside a stale .bin variant index (no ST) -> usable. 
snapb, blobb = _mk_snapshot(tmp_path, "single_bin_variant_beats_index") (snapb / "config.json").write_text("{}") (snapb / "pytorch_model.fp16.bin").symlink_to(blobb) @@ -3257,7 +2910,7 @@ def test_post_download_single_variant_beats_stale_variant_index(tmp_path): assert xf._download_result_usable( snapb, repo_type = "model", allow_patterns = None, ignore_patterns = None, variant = "fp16") is True - # A stale variant index with NO single variant weight -> the sharded load fetches missing shards. + # A stale variant index with no single variant weight -> incomplete. snap2, _ = _mk_snapshot(tmp_path, "stale_variant_index_only") (snap2 / "config.json").write_text("{}") (snap2 / "model.safetensors.index.fp16.json").write_text(json.dumps( @@ -3269,10 +2922,7 @@ def test_post_download_single_variant_beats_stale_variant_index(tmp_path): def test_post_download_diffusers_skips_root_model_shard_checks(tmp_path): - """A diffusers pipeline reads COMPONENT subfolders, not root model shards. A complete pipeline - (declared unet+vae present) co-resident with a STALE root model shard INDEX -- canonical or variant -- - must be ACCEPTED: the root-model shard check does not apply to a diffusers snapshot, else a valid - pipeline is looped into a DownloadStallError (Codex #829). Component completeness is still enforced.""" + """A diffusers pipeline reads component subfolders, so a stale root model shard index (canonical or variant) is accepted; component completeness is still enforced.""" # Plain: stale root model.safetensors.index.json alongside complete components. 
snap, blob = _mk_snapshot(tmp_path, "diffusers_stale_root_index_plain") (snap / "model_index.json").write_text(json.dumps( @@ -3298,7 +2948,7 @@ def test_post_download_diffusers_skips_root_model_shard_checks(tmp_path): assert xf._download_result_usable( snapv, repo_type = "model", allow_patterns = None, ignore_patterns = None, variant = "fp16") is True - # Hang protection kept: an incomplete DECLARED component is still rejected. + # An incomplete declared component is still rejected. (snapv / "unet" / "diffusion_pytorch_model.fp16.safetensors").unlink() (snapv / "unet" / "diffusion_pytorch_model.fp16-00001-of-00002.safetensors").symlink_to(blobv) (snapv / "unet" / "diffusion_pytorch_model.safetensors.index.fp16.json").write_text(json.dumps( @@ -3310,17 +2960,14 @@ def test_post_download_diffusers_skips_root_model_shard_checks(tmp_path): def test_post_download_out_of_scope_malformed_index_not_rejected(tmp_path): - """A malformed shard index the REQUEST does not select is not read by the load, so it must not - false-reject a complete in-scope download into a DownloadStallError (Codex #829). A base ['model*'] warm - with a complete model.safetensors and a co-resident stale MALFORMED adapter index is accepted; an - IN-scope malformed index (an adapter warm) is still breakage.""" + """A malformed shard index the request does not select is not read, so it does not reject a complete in-scope download; an in-scope malformed index is breakage.""" snap, blob = _mk_snapshot(tmp_path, "malformed_out_of_scope") (snap / "config.json").write_text("{}") (snap / "model.safetensors").symlink_to(blob) (snap / "adapter_model.safetensors.index.json").write_text("{ not valid json") # malformed, unselected assert xf._download_result_usable( snap, repo_type = "model", allow_patterns = ["model*"], ignore_patterns = None) is True - # An IN-scope malformed index (adapter warm selects adapter_model*) is still rejected. 
+ # An in-scope malformed index (adapter warm selects adapter_model*) is still rejected. snap2, blob2 = _mk_snapshot(tmp_path, "malformed_in_scope") (snap2 / "config.json").write_text("{}") (snap2 / "adapter_config.json").write_text("{}") @@ -3332,9 +2979,7 @@ def test_post_download_out_of_scope_malformed_index_not_rejected(tmp_path): def test_selected_readable_weight_complete_entry_point(tmp_path): - """The weight-bearing acceptance check funnels through one helper enforcing two invariants: - (A) a readable weight is present (ignore + scope applied), (B) its in-scope shard set is complete. - Directly exercise the entry point for a present+complete, an absent, and an incomplete-shard case.""" + """The acceptance helper enforces (A) a readable weight present, (B) its shard set complete: exercise present+complete, absent, and incomplete-shard cases.""" # Present + complete single weight -> True. snap, blob = _mk_snapshot(tmp_path, "srwc_ok") (snap / "model.safetensors").symlink_to(blob) @@ -3363,9 +3008,7 @@ def test_post_download_accepts_dataset_without_weight(tmp_path): def test_post_download_accepts_either_format_single_present(tmp_path): - """An either-format named request (['pytorch_model.bin','model.safetensors']) against a repo that - ships only safetensors: the finished download has a weight, so it is accepted -- not re-looped - for the absent .bin the repo simply does not publish.""" + """An either-format named request against a safetensors-only repo is accepted (not retried for the .bin the repo does not publish).""" snap, blob = _mk_snapshot(tmp_path, "either") (snap / "model.safetensors").symlink_to(blob) assert xf._download_result_usable( @@ -3374,8 +3017,7 @@ def test_post_download_accepts_either_format_single_present(tmp_path): def test_pre_download_skips_intact_tokenizer_only(tmp_path): - """A tokenizer-only (weightless) warm short-circuits offline: an intact requested subset is - enough, no weight required.""" + """A tokenizer-only (weightless) warm 
short-circuits offline: an intact requested subset is enough, no weight required.""" snap, _ = _mk_snapshot(tmp_path, "tok") (snap / "tokenizer.json").write_text("{}") (snap / "config.json").write_text("{}") @@ -3385,10 +3027,7 @@ def test_pre_download_skips_intact_tokenizer_only(tmp_path): def test_pre_download_partial_ignore_does_not_skip_config_only(tmp_path): - """Over-accept guard (safety reviewer finding): an ignore-only request stripping only SOME weight - formats (ignore=['*.safetensors','*.bin']) on a config-only cache must NOT skip the child -- a - repo whose surviving weight is e.g. model.gguf / model.fp16.safetensors / a .pt checkpoint would - otherwise be fetched in-process over un-killable Xet (the hang).""" + """An ignore-only request stripping only some weight formats on a config-only cache must not skip the child (a surviving weight format would hang over Xet).""" snap, _ = _mk_snapshot(tmp_path, "cfgign") (snap / "config.json").write_text("{}") assert xf._cache_can_skip_download( @@ -3396,14 +3035,9 @@ def test_pre_download_partial_ignore_does_not_skip_config_only(tmp_path): ignore_patterns = ["*.safetensors", "*.bin"]) is False -# --------------------------------------------------------------------------- # Review-round regression guards (10-reviewer findings) -# --------------------------------------------------------------------------- - def test_gate_rejects_malformed_shard_index(tmp_path): - """Finding 2 (over-accept): a truncated / non-dict / empty weight-shard index must NOT read as - complete. 
_weight_shard_index_complete is fail-CLOSED so the fast path defers a malformed index - to the watched child rather than skipping it and failing the in-process load on the bad index.""" + """A truncated / non-dict / empty weight-shard index does not read as complete (_weight_shard_index_complete is fail-closed).""" snap, blob = _mk_snapshot(tmp_path, "malidx") (snap / "model-00001-of-00002.safetensors").symlink_to(blob) (snap / "model.safetensors.index.json").write_text("{not valid json") @@ -3413,34 +3047,30 @@ def test_gate_rejects_malformed_shard_index(tmp_path): (snap2 / "model-00001-of-00002.safetensors").symlink_to(blob2) (snap2 / "model.safetensors.index.json").write_text(json.dumps({"weight_map": {}})) assert hcs.snapshot_dir_is_complete(snap2) is False - # weight_map present but not a dict. + # weight_map not a dict. snap3, blob3 = _mk_snapshot(tmp_path, "listidx") (snap3 / "model.safetensors.index.json").write_text(json.dumps({"weight_map": ["a", "b"]})) assert hcs._weight_shard_index_complete(snap3 / "model.safetensors.index.json") is False def test_shard_index_rejects_unsafe_path_refs(tmp_path): - """A weight-shard index is attacker-influenced (weight_map from a downloaded repo). An absolute, - Windows drive-letter, UNC, or parent-escaping shard value must be rejected so ``base / shard`` cannot - resolve to an existing file OUTSIDE the snapshot and read as "present" -- on Windows ``base / 'C:\\x'`` - escapes, which a startswith(('/', '\\\\')) check misses (Gemini #829). Both the completeness check and - the shard-path enumerator reject these, judged under POSIX and Windows semantics on any OS.""" - # Unit: the shared helper flags every escape variant and keeps legit relative names. + """An attacker-influenced shard value (absolute, drive-letter, UNC, parent-escaping) is rejected so ``base / shard`` cannot probe outside the snapshot.""" + # Unit: the helper flags every escape variant and keeps legit relative names. 
for bad in ["/etc/passwd", r"C:\evil.safetensors", "C:evil.safetensors", r"\\srv\share\x", "../../x.safetensors", r"..\x.safetensors", "a/../../b"]: assert hcs._is_unsafe_shard_ref(bad) is True, bad for ok in ["model-00001-of-00002.safetensors", "unet/diffusion_pytorch_model.safetensors", "model.fp16.safetensors"]: assert hcs._is_unsafe_shard_ref(ok) is False, ok - # A crafted index listing a drive-letter shard is not "complete" (never probes outside the snapshot). + # A crafted index listing a drive-letter shard is not "complete". snap, blob = _mk_snapshot(tmp_path, "unsafe_shard_idx") (snap / "model.safetensors.index.json").write_text(json.dumps( {"weight_map": {"a": r"C:\Windows\System32\x.safetensors", "b": "model-00002-of-00002.safetensors"}})) assert hcs._weight_shard_index_complete(snap / "model.safetensors.index.json") is False - # The enumerator returns None (defer to the child) rather than a path that escapes the snapshot. + # The enumerator returns None (defer) rather than a path escaping the snapshot. assert hcs._index_shard_rel_paths(snap / "model.safetensors.index.json", "") is None - # A well-formed relative index still enumerates + validates normally. + # A well-formed relative index still enumerates + validates. snap2, blob2 = _mk_snapshot(tmp_path, "safe_shard_idx") (snap2 / "model-00001-of-00002.safetensors").symlink_to(blob2) (snap2 / "model-00002-of-00002.safetensors").symlink_to(blob2) @@ -3453,11 +3083,7 @@ def test_shard_index_rejects_unsafe_path_refs(tmp_path): def test_malformed_index_scope_honors_ignored_format(tmp_path): - """A malformed shard index is judged by the WEIGHT the load reads (a representative shard of the - index's base + format), not the .json filename. So a stale/truncated index for an IGNORED format - (a *.bin index under ignore=['*.bin']) is skipped -- the load reads safetensors and never touches it, - so a complete safetensors download must not be looped into a DownloadStallError (Codex #829). 
A - malformed index of the READ format is still breakage.""" + """A malformed index is judged by the weight the load reads, so a malformed index for an ignored format is skipped; a malformed index of the read format is breakage.""" # Patterned subfolder warm reading safetensors: a co-resident malformed bin index is ignored. snap, blob = _mk_snapshot(tmp_path, "malformed_ignored_bin_idx") (snap / "unet").mkdir() @@ -3473,8 +3099,7 @@ def test_malformed_index_scope_honors_ignored_format(tmp_path): (snap2 / "unet" / "diffusion_pytorch_model.safetensors.index.json").write_text("{ truncated") assert xf._download_result_usable( snap2, repo_type = "model", allow_patterns = ["unet/*"], ignore_patterns = None) is False - # Diffusers pipeline: a malformed component bin index under ignore=['*.bin'] does not reject a - # complete safetensors pipeline. + # Diffusers: a malformed component bin index under ignore=['*.bin'] does not reject a complete ST pipeline. snap3, blob3 = _mk_snapshot(tmp_path, "malformed_diffusers_bin_idx") (snap3 / "model_index.json").write_text(json.dumps( {"_class_name": "P", "unet": ["diffusers", "U"], "vae": ["diffusers", "V"]})) @@ -3487,24 +3112,21 @@ def test_malformed_index_scope_honors_ignored_format(tmp_path): def test_gate_ignored_canonical_weight_does_not_prove_complete(tmp_path): - """Finding 3 (over-accept): a stale canonical weight whose FORMAT the request ignores must not - count as proof of completeness. 
ignore=['*.bin'] with only a pytorch_model.bin on disk (no - safetensors) defers to the child, so a use_safetensors load cannot silently fetch over Xet.""" + """A canonical weight whose format the request ignores does not prove completeness (a bin-only cache under ignore=['*.bin'] defers to the child).""" snap, blob = _mk_snapshot(tmp_path, "ignbin") (snap / "config.json").write_text("{}") (snap / "pytorch_model.bin").symlink_to(blob) assert hcs.snapshot_dir_is_complete(snap, ignore_patterns = ["*.bin"]) is False # Without the ignore, the present .bin is what a default load reads -> complete. assert hcs.snapshot_dir_is_complete(snap) is True - # A .bin shard index is also discarded when *.bin is ignored (its .json sidecar would slip the - # raw name filter, but the format probe catches it). + # A .bin shard index is also discarded when *.bin is ignored (the format probe catches its .json sidecar). snap2, blob2 = _mk_snapshot(tmp_path, "ignbinshard") (snap2 / "pytorch_model-00001-of-00001.bin").symlink_to(blob2) (snap2 / "pytorch_model.bin.index.json").write_text( json.dumps({"weight_map": {"a": "pytorch_model-00001-of-00001.bin"}})) assert hcs.snapshot_dir_is_complete(snap2, ignore_patterns = ["*.bin"]) is False assert hcs.snapshot_dir_is_complete(snap2) is True - # A safetensors warm survives an *.bin ignore (the common bare from_pretrained case). + # A safetensors warm survives an *.bin ignore. snap3, blob3 = _mk_snapshot(tmp_path, "stignbin") (snap3 / "config.json").write_text("{}") (snap3 / "model.safetensors").symlink_to(blob3) @@ -3512,10 +3134,7 @@ def test_gate_ignored_canonical_weight_does_not_prove_complete(tmp_path): def test_post_download_accepts_weightless_patterned_result(tmp_path): - """Finding 1 (over-reject): a genuinely weightless PATTERNED result (e.g. allow=['tokenizer*']) - must be accepted post-download -- the caller scoped it, so 'no weight' is intended, not a stale - config-only snapshot. 
Rejecting it would loop into a spurious DownloadStallError on a good - download. The no-weight rejection stays in force for an UNPATTERNED model warm.""" + """A genuinely weightless patterned result (allow=['tokenizer*']) is accepted post-download; the no-weight rejection stays for an unpatterned model warm.""" snap, _ = _mk_snapshot(tmp_path, "tokglob") (snap / "tokenizer.json").write_text("{}") (snap / "tokenizer_config.json").write_text("{}") @@ -3529,11 +3148,7 @@ def test_post_download_accepts_weightless_patterned_result(tmp_path): def test_gate_rejects_variant_only_shard_index(tmp_path): - """codex :269 (over-accept): a variant-only shard index (model.safetensors.index.fp16.json) must - NOT satisfy the canonical allow=None fast path -- snapshot_dir_is_complete is variant-blind (a - default load probes the canonical index whose weights are absent). Only a canonical (non-variant) - index counts here; a variant REQUEST is deferred one level up in _cache_can_skip_download (see - test_pre_download_defers_variant_on_canonical_cache).""" + """A variant-only shard index does not satisfy the canonical allow=None fast path (snapshot_dir_is_complete is variant-blind); a canonical index does.""" snap, blob = _mk_snapshot(tmp_path, "variant") (snap / "config.json").write_text("{}") (snap / "model-00001-of-00001.fp16.safetensors").symlink_to(blob) @@ -3548,10 +3163,7 @@ def test_gate_rejects_variant_only_shard_index(tmp_path): def test_generic_hub_http_error_type_preserved_but_status_drives_retry(): - """codex :499: a deterministic 4xx surfaced as a bare HfHubHTTPError keeps its TYPE across the - spawn boundary (so `except HfHubHTTPError` still works) WITHOUT joining the retry-deterministic - name shortcut -- a transient 5xx bare HfHubHTTPError must still retry over HTTP via its status - code.""" + """A bare HfHubHTTPError keeps its type across the spawn boundary while retry stays status-driven: a 5xx retries, a 4xx does not.""" assert "HfHubHTTPError" not in 
xf._DETERMINISTIC_ERROR_NAMES # status-driven, not name-driven cls = xf._resolve_exception_class("HfHubHTTPError") assert cls is not None and issubclass(cls, BaseException) @@ -3567,9 +3179,7 @@ def __init__(self, code): self.status_code = code def test_hfvalidationerror_type_preserved_across_spawn(): - """Finding 4: a malformed repo id fails identically over either transport, so HFValidationError is - deterministic (not retried) and its TYPE is reconstructed across the spawn boundary instead of - degrading to a generic RuntimeError.""" + """A malformed repo id fails identically over either transport, so HFValidationError is deterministic and its type is reconstructed across the spawn boundary.""" assert "HFValidationError" in xf._DETERMINISTIC_ERROR_NAMES cls = xf._resolve_exception_class("HFValidationError") assert cls is not None and issubclass(cls, BaseException) @@ -3579,10 +3189,7 @@ def test_hfvalidationerror_type_preserved_across_spawn(): def test_oserror_subclass_type_preserved_across_spawn(): - """A deterministic builtin OSError subclass (PermissionError from an unwritable cache, - FileNotFoundError, ...) keeps its TYPE across the spawn boundary so a caller's `except OSError` / - `except PermissionError` still fires instead of seeing a generic RuntimeError. 
Non-OSError builtins - are not spuriously resolved (they fall through to the Hub-name lookup / None).""" + """A builtin OSError subclass keeps its type across the spawn boundary; a non-OSError builtin is not spuriously resolved.""" for name in ("PermissionError", "FileNotFoundError", "IsADirectoryError", "NotADirectoryError"): cls = xf._resolve_exception_class(name) assert cls is not None and issubclass(cls, OSError) and cls.__name__ == name @@ -3590,13 +3197,12 @@ def test_oserror_subclass_type_preserved_across_spawn(): perm = xf._instantiate_preserving_type(xf._resolve_exception_class("PermissionError"), "denied") assert isinstance(perm, PermissionError) assert xf._is_retryable_download_error(perm) is False - # An unrelated builtin (not an OSError subclass, not a Hub error name) is not resolved here. + # An unrelated builtin (not OSError, not a Hub error name) is not resolved. assert xf._resolve_exception_class("ValueError") is None def test_weight_pattern_selector_handles_globs(tmp_path): - """The weight-pattern selector reads tokenizer / config / json globs as weightless (keeps their - offline short-circuit) but classifies standard weight names and ? / [] globs as weight-bearing.""" + """The selector reads tokenizer/config/json globs as weightless but standard weight names and ?/[] globs as weight-bearing.""" weightless = ["tokenizer*", "*.json", "config.json", "tokenizer.model", "*.txt"] weighty = [ "model.safetensors", "*.safetensors", "model.?afetensors", "model.[sp]afetensors", @@ -3609,8 +3215,7 @@ def test_weight_pattern_selector_handles_globs(tmp_path): def test_post_download_rejects_config_only_for_explicit_weight_pattern(tmp_path): - """An explicit weight request (allow=['model.safetensors']) returning only config.json is a stale - config-only snapshot: reject and retry over HTTP. 
A weightless patterned request stays accepted.""" + """An explicit weight request returning only config.json is a stale config-only snapshot: reject and retry.""" snap, _ = _mk_snapshot(tmp_path, "patcfg") (snap / "config.json").write_text("{}") assert xf._download_result_usable( @@ -3621,10 +3226,7 @@ def test_post_download_rejects_config_only_for_explicit_weight_pattern(tmp_path) def test_post_download_rejects_incomplete_canonical_root_shards(tmp_path): - """An interrupted canonical sharded warm (loose model-00001-of-00002.safetensors, no index) has a - loadable file but a default load cannot read it and would fetch the rest over un-killable Xet, so - it is rejected. A complete sharded set is accepted; a variant-only shard layout does not satisfy a - default (no-variant) load, which reads only canonical names.""" + """An interrupted canonical sharded warm (loose shard, no index) is rejected; a complete set is accepted; a variant-only layout does not satisfy a default load.""" snap, blob = _mk_snapshot(tmp_path, "incshard") (snap / "config.json").write_text("{}") (snap / "model-00001-of-00002.safetensors").symlink_to(blob) @@ -3637,8 +3239,7 @@ def test_post_download_rejects_incomplete_canonical_root_shards(tmp_path): "b": "model-00002-of-00002.safetensors"}})) assert xf._download_result_usable( snap, repo_type = "model", allow_patterns = None, ignore_patterns = None) is True - # A variant-named shard is NOT a canonical weight a default load reads, so a variant-only cache is - # rejected (the default load would fetch the absent canonical model.safetensors over Xet). + # A variant-named shard is not a canonical weight a default load reads, so a variant-only cache is rejected. 
vsnap, vblob = _mk_snapshot(tmp_path, "vshard") (vsnap / "config.json").write_text("{}") (vsnap / "model-00001-of-00001.fp16.safetensors").symlink_to(vblob) @@ -3647,8 +3248,7 @@ def test_post_download_rejects_incomplete_canonical_root_shards(tmp_path): def test_local_token_not_found_error_type_preserved(): - """A missing required token fails identically over either transport, so LocalTokenNotFoundError is - deterministic and its type is reconstructed across the spawn boundary.""" + """A missing required token fails identically over either transport, so LocalTokenNotFoundError is deterministic and type-preserved.""" assert "LocalTokenNotFoundError" in xf._DETERMINISTIC_ERROR_NAMES cls = xf._resolve_exception_class("LocalTokenNotFoundError") assert cls is not None and issubclass(cls, BaseException) @@ -3657,8 +3257,7 @@ def test_local_token_not_found_error_type_preserved(): def test_metadata_directory_pattern_is_weightless(tmp_path): - """A trailing-slash metadata dir pattern (allow=['tokenizer/']) reads weightless, so a complete - tokenizer-only download is accepted. Component / checkpoint dir patterns stay weight-bearing.""" + """A trailing-slash metadata dir pattern (allow=['tokenizer/']) reads weightless; component/checkpoint dir patterns stay weight-bearing.""" assert hcs.request_can_include_weights(["tokenizer/"], None) is False assert hcs.request_can_include_weights(["processor/"], None) is False assert hcs.request_can_include_weights(["unet/"], None) is True @@ -3671,9 +3270,7 @@ def test_metadata_directory_pattern_is_weightless(tmp_path): def test_metadata_directory_glob_is_weightless(tmp_path): - """A metadata-dir GLOB (allow=['tokenizer/*'], 'processor/*.json') reads weightless like its - trailing-slash form, so a complete tokenizer-only download is accepted instead of looped into a - DownloadStallError. 
A component / checkpoint dir glob stays weight-bearing.""" + """A metadata-dir glob (allow=['tokenizer/*']) reads weightless like its trailing-slash form; a component/checkpoint dir glob stays weight-bearing.""" assert hcs.request_can_include_weights(["tokenizer/*"], None) is False assert hcs.request_can_include_weights(["tokenizer/*.json"], None) is False assert hcs.request_can_include_weights(["processor/*"], None) is False @@ -3687,20 +3284,16 @@ def test_metadata_directory_glob_is_weightless(tmp_path): def test_allow_star_with_all_weights_ignored_is_weightless(tmp_path): - """An allow that the ignore filter strips of every weight reads weightless, so a complete config-only - download is accepted, not looped into a DownloadStallError. This holds for a ROOT allow (allow=['*']) - AND a subdir-scoped allow (allow=['unet/*']) -- a subdir warm that ignores every weight suffix selects - only that subdir's metadata (Codex #829). A subdir allow whose weight suffixes SURVIVE the ignore - stays weight-bearing, as does a root allow whose weights survive.""" + """An allow the ignore filter strips of every weight reads weightless (root allow=['*'] and subdir allow=['unet/*']); surviving weight suffixes stay weight-bearing.""" all_weight_ignores = [ "*.safetensors", "*.bin", "*.pt", "*.pth", "*.gguf", "*.ckpt", "*.onnx", "*.msgpack", "*.h5", "*.pdparams", ] assert hcs.request_can_include_weights(["*"], all_weight_ignores) is False assert hcs.request_can_include_weights(["*"], None) is True - # A subdir allow that ignores every weight suffix is weightless too (only unet/ metadata selected)... + # A subdir allow that ignores every weight suffix is weightless too... assert hcs.request_can_include_weights(["unet/*"], all_weight_ignores) is False - # ...but one whose weight suffixes survive the ignore stays weight-bearing. + # ...but one whose weight suffixes survive stays weight-bearing. 
assert hcs.request_can_include_weights(["unet/*"], ["*.bin"]) is True assert hcs.request_can_include_weights(["*.safetensors"], ["*.bin"]) is True snap, _ = _mk_snapshot(tmp_path, "cfgonly") @@ -3711,9 +3304,7 @@ def test_allow_star_with_all_weights_ignored_is_weightless(tmp_path): def test_post_download_rejects_checkpoint_only_root_model(tmp_path): - """A stale snapshot whose only weight is under checkpoint-7/ is rejected for an unpatterned root - warm -- a default from_pretrained ignores checkpoint-*/ and would fetch the missing root weights - over un-killable Xet. The same checkpoint is accepted when explicitly scoped.""" + """A snapshot whose only weight is under checkpoint-7/ is rejected for an unpatterned root warm, but accepted when explicitly scoped.""" snap, blob = _mk_snapshot(tmp_path, "ckonly") (snap / "config.json").write_text("{}") (snap / "checkpoint-7").mkdir() @@ -3729,9 +3320,7 @@ def test_post_download_rejects_checkpoint_only_root_model(tmp_path): (dsnap / "unet" / "diffusion_pytorch_model.safetensors").symlink_to(dblob) assert xf._download_result_usable( dsnap, repo_type = "model", allow_patterns = None, ignore_patterns = None) is True - # A diffusers snapshot whose ONLY weight is under checkpoint-N/ (a training artifact, not a pipeline - # component) is rejected: DiffusionPipeline reads component subfolders (unet/vae/...), so the load - # would fetch the missing components over Xet (Codex #829). + # A diffusers snapshot whose only weight is under checkpoint-N/ (not a pipeline component) is rejected. dck, dckb = _mk_snapshot(tmp_path, "diff_ckpt") (dck / "model_index.json").write_text("{}") (dck / "checkpoint-7").mkdir() @@ -3741,32 +3330,25 @@ def test_post_download_rejects_checkpoint_only_root_model(tmp_path): def test_post_download_rejects_adapter_only_for_default_load(tmp_path): - """A DEFAULT (unpatterned) model warm reads the base model.safetensors / pytorch_model.bin, not a PEFT - adapter. 
A stale snapshot holding only adapter_model.safetensors must be rejected, else the in-process - base load fetches the absent base weight over un-killable Xet (Codex #829). An adapter-scoped request - (allow=['adapter_model*']) is unaffected: it reads the adapter and still accepts it.""" + """A default warm reads the base weight, not an adapter, so an adapter-only snapshot is rejected; an adapter-scoped request still accepts it.""" snap, blob = _mk_snapshot(tmp_path, "adapter_only_default") (snap / "adapter_model.safetensors").symlink_to(blob) (snap / "adapter_config.json").write_text("{}") (snap / "config.json").write_text("{}") assert xf._download_result_usable( snap, repo_type = "model", allow_patterns = None, ignore_patterns = None) is False - # An ADAPTER load (patterned) reads the adapter and accepts it (no regression to the PEFT path). + # An adapter load (patterned) reads the adapter and accepts it. assert xf._download_result_usable( snap, repo_type = "model", allow_patterns = ["adapter_model*", "adapter_config.json"], ignore_patterns = None) is True - # The base weight present -> the default warm accepts (no false-reject), even beside the adapter. + # The base weight present -> the default warm accepts, even beside the adapter. (snap / "model.safetensors").symlink_to(blob) assert xf._download_result_usable( snap, repo_type = "model", allow_patterns = None, ignore_patterns = None) is True def test_post_download_variant_root_check_ignores_adapter_index(tmp_path): - """An unpatterned variant load reads the ROOT model variant, not a PEFT adapter. A complete - model.fp16.safetensors co-resident with a STALE, incomplete adapter_model.safetensors.index.fp16.json - must still be accepted -- the root variant-shard check is restricted to model / pytorch_model variant - names, so the adapter index's incompleteness does not force a spurious DownloadStallError (Codex #829). 
- A genuinely incomplete ROOT model variant index is still rejected.""" + """A variant load reads the root model variant, so a complete model.fp16 beside a stale adapter variant index is accepted; an incomplete root variant index is still rejected.""" snap, blob = _mk_snapshot(tmp_path, "var_adapter_idx") (snap / "model.fp16.safetensors").symlink_to(blob) # complete root model variant (the read weight) (snap / "adapter_model.safetensors.index.fp16.json").write_text(json.dumps( @@ -3775,7 +3357,7 @@ def test_post_download_variant_root_check_ignores_adapter_index(tmp_path): assert xf._download_result_usable( snap, repo_type = "model", allow_patterns = None, ignore_patterns = None, variant = "fp16") is True - # An incomplete ROOT model variant index is still caught (the restriction did not disable the check). + # An incomplete root model variant index is still caught. snap2, blob2 = _mk_snapshot(tmp_path, "var_root_idx_incomplete") (snap2 / "model.safetensors.index.fp16.json").write_text(json.dumps( {"weight_map": {"a": "model.fp16-00001-of-00002.safetensors", @@ -3787,10 +3369,7 @@ def test_post_download_variant_root_check_ignores_adapter_index(tmp_path): def test_post_download_rejects_variant_only_root_for_default_load(tmp_path): - """A DEFAULT (no-variant) load reads the canonical model.safetensors / pytorch_model.bin, NOT a - variant-named model.fp16.safetensors. A stale snapshot holding only the variant weight must be - rejected, else the in-process default load fetches the absent canonical weight over un-killable Xet - (Codex #829). 
A single-file or sharded variant name is excluded; canonical names still pass.""" + """A default (no-variant) load reads the canonical name, not a variant name, so a variant-only snapshot is rejected; canonical names still pass.""" snap, blob = _mk_snapshot(tmp_path, "var_only") (snap / "model.fp16.safetensors").symlink_to(blob) # variant-named only (snap / "config.json").write_text("{}") @@ -3802,14 +3381,14 @@ def test_post_download_rejects_variant_only_root_for_default_load(tmp_path): (snap_sh / "config.json").write_text("{}") assert xf._download_result_usable( snap_sh, repo_type = "model", allow_patterns = None, ignore_patterns = None) is False - # The canonical weight present -> accepted (no false-reject), even beside the variant. + # The canonical weight present -> accepted, even beside the variant. (snap / "model.safetensors").symlink_to(blob) assert xf._download_result_usable( snap, repo_type = "model", allow_patterns = None, ignore_patterns = None) is True - # A variant load (variant='fp16') DOES read the variant weight -> accepted. + # A variant load DOES read the variant weight. 
assert xf._download_result_usable( snap_sh, repo_type = "model", allow_patterns = None, ignore_patterns = None, - variant = "fp16") is False # sharded variant with no index is still an incomplete set + variant = "fp16") is False # sharded variant with no index is still incomplete snap_v, blob_v = _mk_snapshot(tmp_path, "var_single") (snap_v / "model.fp16.safetensors").symlink_to(blob_v) (snap_v / "config.json").write_text("{}") @@ -3819,23 +3398,20 @@ def test_post_download_rejects_variant_only_root_for_default_load(tmp_path): def test_post_download_variant_either_format_exact_alternatives(tmp_path): - """An exact request listing both variant formats (allow=['model.fp16.safetensors', - 'pytorch_model.fp16.bin']) is an ALTERNATIVE over the repo, like the canonical either-format pair: - a repo publishing only the safetensors variant is complete and must not be failed into a - DownloadStallError (Codex #829). A distinct-variant / base+adapter request still requires each.""" + """An exact request listing both variant formats is an alternative (like the canonical either-format pair): a safetensors-variant-only repo is complete; distinct/base+adapter still requires each.""" snap, blob = _mk_snapshot(tmp_path, "var_either") (snap / "model.fp16.safetensors").symlink_to(blob) # only the safetensors variant present assert xf._download_result_usable( snap, repo_type = "model", allow_patterns = ["model.fp16.safetensors", "pytorch_model.fp16.bin"], ignore_patterns = None, variant = "fp16") is True - # The canonical either-format pair keeps working (regression). + # The canonical either-format pair keeps working. snap_c, blob_c = _mk_snapshot(tmp_path, "canon_either") (snap_c / "pytorch_model.bin").symlink_to(blob_c) assert xf._download_result_usable( snap_c, repo_type = "model", allow_patterns = ["model.safetensors", "pytorch_model.bin"], ignore_patterns = None) is True - # Base AND adapter are distinct groups: the adapter present but base absent -> rejected. 
+ # Base and adapter are distinct groups: adapter present, base absent -> rejected. snap_d, blob_d = _mk_snapshot(tmp_path, "base_and_adapter") (snap_d / "adapter_model.safetensors").symlink_to(blob_d) assert xf._download_result_usable( @@ -3845,8 +3421,7 @@ def test_post_download_variant_either_format_exact_alternatives(tmp_path): def test_post_download_validates_weightless_named_subset(tmp_path): - """An exact weightless request (allow=['tokenizer.json'], or a dataset file) returning a stale - snapshot missing the named file is rejected and retried. A glob allow list stays lenient.""" + """An exact weightless request missing its named file is rejected and retried; a glob allow list stays lenient.""" snap, _ = _mk_snapshot(tmp_path, "noname") (snap / "config.json").write_text("{}") assert xf._download_result_usable( @@ -3862,10 +3437,7 @@ def test_post_download_validates_weightless_named_subset(tmp_path): def test_post_download_rejects_missing_exact_weight_request(tmp_path): - """An exact weight request whose file is missing is rejected even when a different weight is present: - allow=['adapter_model.safetensors'] is NOT satisfied by a stale base model.safetensors, and - ['model.safetensors','adapter_model.safetensors'] needs both. The either-format - ['model.safetensors','pytorch_model.bin'] pair stays satisfied by one (equivalence).""" + """An exact weight request whose file is missing is rejected even with a different weight present; the either-format pair stays satisfied by one.""" base, blob = _mk_snapshot(tmp_path, "baseonly") (base / "model.safetensors").symlink_to(blob) assert xf._download_result_usable( @@ -3885,8 +3457,7 @@ def test_post_download_rejects_missing_exact_weight_request(tmp_path): def test_dataset_unpatterned_or_glob_partial_does_not_skip_child(tmp_path): - """A dataset/space snapshot whose completeness cannot be proven from local files (allow=None or a - glob) must defer to the watched child. 
An intact exact-named subset still short-circuits.""" + """A dataset/space snapshot whose completeness cannot be proven locally (allow=None or a glob) defers to the child; an intact exact-named subset short-circuits.""" snap, _ = _mk_snapshot(tmp_path, "dspart") (snap / "README.md").write_text("partial") assert xf._cache_can_skip_download( @@ -3898,9 +3469,7 @@ def test_dataset_unpatterned_or_glob_partial_does_not_skip_child(tmp_path): def test_http_prep_scopes_blob_cleanup_to_owned_partials(tmp_path): - """HTTP prep must purge only the stalled child's OWN partials, never a concurrent same-repo - sibling's blob (multi-rank). With ownership known, a sibling's aged partial and dangling link are - spared; with ownership unknown (None), the coarser mtime guard purges both.""" + """HTTP prep purges only the child's own partials: with ownership known a sibling's aged partial/link is spared; with ownership None the mtime guard purges both.""" repo = "ztest/concurrent-blobs" repo_dir = tmp_path / f"models--{repo.replace('/', '--')}" blobs = repo_dir / "blobs" @@ -3931,7 +3500,7 @@ def _seed(): assert not (snap / "our.safetensors").is_symlink() assert (snap / "sib.safetensors").is_symlink(), "sibling's dangling link must be spared" - # No ownership info -> coarse mtime guard purges both aged partials (prior behavior). + # No ownership info -> coarse mtime guard purges both aged partials. owned, sibling = _seed() _REAL_DEFAULT_PREPARE( "model", repo, cache_dir = str(tmp_path), active_grace = 180, owned_incomplete_blobs = None) diff --git a/unsloth_zoo/__init__.py b/unsloth_zoo/__init__.py index 6c0f1d93e..7f9b317ec 100644 --- a/unsloth_zoo/__init__.py +++ b/unsloth_zoo/__init__.py @@ -134,21 +134,16 @@ def has_429_exact_full_read(log_dir: str | Path) -> str: # Everything below this point is GPU-only. Use a flag to gate it. _SKIP_GPU_INIT = True else: - # Opt-in lightweight import. A short-lived helper subprocess that only needs - # the cache/download utilities (e.g. 
the unsloth_zoo.hf_xet_fallback download - # child) can set UNSLOTH_ZOO_DISABLE_GPU_INIT=1 to skip the heavy torch / - # transformers / device init it never uses. Off by default, so normal - # CUDA/CPU runs are byte-for-byte unchanged; the parent only sets it around - # spawning that child, never for a training/inference process. The - # unconditional HF cache redirect above still runs, so the child writes to the - # same cache as the parent. + # Opt-in: a download-only helper child (e.g. hf_xet_fallback) sets + # UNSLOTH_ZOO_DISABLE_GPU_INIT=1 to skip the heavy torch/transformers/device + # init it never uses. Off by default, so normal CUDA/CPU runs are unchanged. + # The HF cache redirect above still runs, so the child shares the parent's cache. _SKIP_GPU_INIT = os.environ.get("UNSLOTH_ZOO_DISABLE_GPU_INIT", "0") == "1" del _is_mlx_only, is_mlx_available -# Inject triton & bitsandbytes stubs whenever GPU init is skipped: Apple Silicon -# with MLX (torch/CUDA absent), or the opt-in light-import download child. unsloth's -# CUDA-only imports then resolve to a loud no-op stub instead of a hard ImportError; -# the stub is never touched by the cache/download-only child, so it is inert there. +# Inject triton & bitsandbytes stubs whenever GPU init is skipped (MLX host or the +# opt-in download child), so unsloth's CUDA-only imports resolve to a loud no-op stub +# instead of a hard ImportError. Inert in the download child, which never touches them. # On a normal CUDA/CPU run _SKIP_GPU_INIT is False and the real modules are untouched. if _SKIP_GPU_INIT: from .stubs.triton_stub import inject_into_sys_modules as _inject_triton diff --git a/unsloth_zoo/hf_cache_state.py b/unsloth_zoo/hf_cache_state.py index 433df8c40..e5c35b04c 100644 --- a/unsloth_zoo/hf_cache_state.py +++ b/unsloth_zoo/hf_cache_state.py @@ -10,18 +10,15 @@ """Sparse-aware introspection of the active Hugging Face hub cache. 
-These helpers report, for a repo's blobs under ``HF_HUB_CACHE``, how many bytes are actually on disk -(sparse-aware, so a partial Xet / ``hf_transfer`` ``.incomplete`` is not read as full progress) and -whether an ``.incomplete`` partial is present -- the two signals the no-progress watchdog runs on. +Reports on-disk bytes (sparse-aware, so a partial Xet / ``hf_transfer`` ``.incomplete`` is not read as +full progress) and whether an ``.incomplete`` partial is present -- the signals the no-progress +watchdog runs on. ``snapshot_dir_is_complete`` is a CONSERVATIVE fast-path gate, not an authoritative verifier: it -returns "complete" only for unambiguous canonical model layouts, and defers everything else -(diffusers, variants, patterns, datasets) to the watched ``snapshot_download`` child. A false -"complete" is the only dangerous error (an in-process load could then fetch a missing weight over -un-killable Xet); a false "not complete" only spawns the cheap child, so the gate errs that way. - -Only the active ``HF_HUB_CACHE`` root is scanned; multi-root / transport-marker logic is a -download-manager concern that lives in the consumer. +returns "complete" only for unambiguous canonical model layouts and defers everything else to the +watched ``snapshot_download`` child. A false "complete" is the only dangerous error (an in-process +load could then fetch a missing weight over un-killable Xet); a false "not complete" only spawns the +cheap child. Only the active ``HF_HUB_CACHE`` root is scanned. 
""" from __future__ import annotations @@ -53,7 +50,7 @@ def _safe_is_dir(path: Path) -> bool: - """``Path.is_dir()`` that returns False instead of raising on an unreadable path.""" + """``Path.is_dir()`` returning False instead of raising.""" try: return path.is_dir() except OSError: @@ -61,7 +58,7 @@ def _safe_is_dir(path: Path) -> bool: def _safe_is_file(path: Path) -> bool: - """``Path.is_file()`` that returns False instead of raising on an unreadable / dangling path.""" + """``Path.is_file()`` returning False instead of raising.""" try: return path.is_file() except OSError: @@ -69,11 +66,11 @@ def _safe_is_file(path: Path) -> bool: def hf_cache_root(*, create: bool = False, cache_dir: "Optional[str | Path]" = None) -> Optional[Path]: - """The hub cache root to scan, or None if unavailable. A given *cache_dir* is used verbatim; - otherwise ``HF_HUB_CACHE`` is read lazily so an import-time redirect is honored.""" + """The hub cache root to scan, or None if unavailable. *cache_dir* is used verbatim; otherwise + ``HF_HUB_CACHE`` is read lazily so an import-time redirect is honored.""" if cache_dir is not None: - # Match huggingface_hub, which expands ~ before writing. expanduser() raises if no home can - # be resolved (HOME unset in a container); fall back to the literal path rather than crash. + # Match huggingface_hub's ~ expansion; expanduser() raises with no resolvable HOME (container), + # so fall back to the literal path rather than crash. try: root = Path(cache_dir).expanduser() except (RuntimeError, OSError): @@ -98,7 +95,7 @@ def target_dir_name(repo_type: Optional[str], repo_id: str) -> str: def repo_cache_dir_name(repo_type: Optional[str], repo_id: str) -> str: - # repo_type=None is HF's default "model"; mirror that so None resolves models--, not Nones--. + # repo_type=None is HF's default "model", so None resolves models-- not Nones--. 
repo_type = repo_type or "model" return f"{repo_type}s--{repo_id.replace('/', '--')}" @@ -114,14 +111,14 @@ def _blob_dir_is_partial(blobs_dir: Path) -> bool: def blob_bytes_present(path: Path) -> int: - """Sparse-aware on-disk size: a Xet / ``hf_transfer`` ``.incomplete`` reports full ``st_size`` - while only some blocks are allocated, so prefer ``st_blocks``, falling back to ``st_size`` where - it is unreported (Windows, some network filesystems).""" + """Sparse-aware on-disk size: a sparse ``.incomplete`` reports full ``st_size`` while only some + blocks are allocated, so prefer ``st_blocks``, falling back to ``st_size`` where it is unreported + (Windows, some network filesystems).""" st = path.stat() blocks = getattr(st, "st_blocks", None) if blocks is not None: - # Trust st_blocks even when 0: a truncated sparse .incomplete reports full st_size but 0 - # blocks and must read as 0 bytes present (a > 0 guard would fall through to st_size). + # Trust st_blocks even at 0: a truncated sparse .incomplete reports full st_size but 0 blocks + # and must read as 0 bytes (a > 0 guard would fall through to st_size). 
return min(blocks * 512, st.st_size) if sys.platform == "win32": allocated = _windows_allocated_size(path) @@ -157,7 +154,7 @@ def _windows_allocated_size(path: Path) -> Optional[int]: def latest_snapshot_dir(repo_dir: Path) -> Optional[Path]: - """Newest child of ``repo_dir/snapshots`` by mtime (the signal from_pretrained resolves to), or None.""" + """Newest child of ``repo_dir/snapshots`` by mtime (what from_pretrained resolves to), or None.""" snapshots_dir = repo_dir / "snapshots" try: if not snapshots_dir.is_dir(): @@ -171,7 +168,7 @@ def latest_snapshot_dir(repo_dir: Path) -> Optional[Path]: def snapshot_dir_has_broken_symlinks(snapshot_dir: Path) -> bool: - """True if *snapshot_dir* holds a dangling symlink (a referenced blob that is missing or still + """True if *snapshot_dir* holds a dangling symlink (a referenced blob missing or still ``.incomplete``) -- an interrupted download. Validates one requested revision, not just the newest.""" try: for entry in snapshot_dir.rglob("*"): @@ -199,8 +196,8 @@ def snapshot_dir_has_broken_symlinks(snapshot_dir: Path) -> bool: ".pdparams", ) -# Trainer / optimizer state carries weight suffixes (.bin / .pt / .pth) but is NOT a loadable weight, -# so a cache holding only these is not a warm model cache. +# Trainer / optimizer state carries weight suffixes but is NOT a loadable weight, so a cache holding +# only these is not a warm model cache. 
_NON_WEIGHT_BASENAMES = frozenset({ "training_args.bin", "optimizer.bin", @@ -216,8 +213,7 @@ def snapshot_dir_has_broken_symlinks(snapshot_dir: Path) -> bool: def _is_loadable_weight_file(name: str) -> bool: - """True if *name* is a loadable model weight: a weight suffix that is not a trainer / optimizer - state artifact (training_args.bin, optimizer.pt, rng_state.pth, ...).""" + """True if *name* is a loadable weight: a weight suffix that is not trainer / optimizer state.""" if not name.endswith(_WEIGHT_FILE_SUFFIXES): return False lowered = name.lower() @@ -229,28 +225,24 @@ def _is_loadable_weight_file(name: str) -> bool: def _is_weight_shard_index(name: str) -> bool: - """True for a weight-shard index sidecar, canonical or variant (``model.safetensors.index.json`` - and ``model.safetensors.index.fp16.json``); a plain suffix test would miss the variant form.""" + """True for a weight-shard index sidecar, canonical or variant; a plain suffix test would miss the + variant form (``model.safetensors.index.fp16.json``).""" return name.endswith(".json") and (".safetensors.index." in name or ".bin.index." in name) def _is_canonical_weight_shard_index(name: str) -> bool: - """True only for the CANONICAL (non-variant) index a default load probes - (``model.safetensors.index.json`` / ``pytorch_model.bin.index.json``). Exact names only: an - ``adapter_model.safetensors.index.json`` (or a variant ``...index.fp16.json``) is rejected, so a + """True only for the CANONICAL (non-variant) index a default load probes. 
Exact names only, so a sharded-adapter-only / variant-only cache does not satisfy the canonical fast path (its base canonical weights are still missing -> the load would fetch them over un-killable Xet).""" return name in ("model.safetensors.index.json", "pytorch_model.bin.index.json") def _is_unsafe_shard_ref(shard: str) -> bool: - """True if a weight-index ``weight_map`` value is NOT a safe relative path inside the snapshot: an - absolute path, a Windows drive-letter reference (``C:\\x`` / ``C:x``), a UNC path, or a - parent-escaping (``..``) reference. Judged under BOTH POSIX and Windows path semantics so a crafted / - malformed index is rejected regardless of the OS running the check -- on Windows ``base / "C:\\x"`` - resolves OUTSIDE the snapshot and would read as a present shard, and ``startswith(("/", "\\"))`` alone - misses a drive-letter value. A well-formed HF index lists a plain relative basename (or subfolder - path), so a legitimate index is never rejected.""" + """True if a ``weight_map`` value is NOT a safe relative path inside the snapshot (absolute, Windows + drive-letter, UNC, or ``..``-escaping). Judged under BOTH POSIX and Windows semantics so a crafted + index is rejected regardless of the running OS (on Windows ``base / "C:\\x"`` resolves OUTSIDE the + snapshot; ``startswith(("/", "\\"))`` alone misses a drive letter). A well-formed relative basename + is never rejected.""" if not shard or shard.startswith(("/", "\\")): return True win = PureWindowsPath(shard) @@ -266,8 +258,8 @@ def _weight_shard_index_complete(index_path: Path) -> bool: """True only if every shard a HF weight index lists is present next to it. 
Fail-CLOSED: an unreadable / truncated index, a non-dict payload or ``weight_map``, or an empty - shard set return False, so a malformed index defers to the watched child rather than letting the - in-process load skip it and then fail (or fetch over Xet).""" + shard set return False, deferring a malformed index to the watched child rather than letting the + load skip it and then fail (or fetch over Xet).""" import json try: @@ -279,17 +271,13 @@ def _weight_shard_index_complete(index_path: Path) -> bool: if not isinstance(weight_map, dict): return False values = list(weight_map.values()) - # A non-string shard value is a malformed index transformers cannot load; fail CLOSED (defer to the - # watched child) rather than silently dropping the bad entry and reading the remaining shards as a - # complete set. + # A non-string shard value is malformed; fail CLOSED rather than drop it and read the rest as complete. if not values or not all(isinstance(s, str) for s in values): return False shards = set(values) base = index_path.parent for shard in shards: - # A well-formed HF index lists a relative shard basename. Reject an absolute / drive-letter / - # parent-escaping value (a malformed or crafted index) rather than let ``base / shard`` resolve - # to an unrelated existing file OUTSIDE the snapshot and read as "present". + # Reject an unsafe ref rather than let ``base / shard`` resolve to a file OUTSIDE the snapshot. if _is_unsafe_shard_ref(shard): return False try: @@ -308,14 +296,13 @@ def _weight_shard_index_complete(index_path: Path) -> bool: def _has_glob(text: str) -> bool: - # A trailing-slash dir pattern ("unet/") is not an exact filename: HF expands it like "unet/*", - # so treat it as a wildcard rather than look for a literal "unet/" entry. + # A trailing-slash dir pattern ("unet/") is a wildcard: HF expands it like "unet/*". 
return text.endswith("/") or any(ch in text for ch in _GLOB_CHARS) def _as_pattern_list(patterns: "Optional[object]") -> "Optional[list]": - """Normalize an allow / ignore argument to a list. HF accepts a bare ``str``; iterating it would - walk it character by character ("checkpoint-10/*" would never match).""" + """Normalize an allow / ignore argument to a list. HF accepts a bare ``str``, which would otherwise + be iterated character by character.""" if patterns is None: return None if isinstance(patterns, str): @@ -329,8 +316,8 @@ def _filter_paths( ignore_patterns: "Optional[list]" = None, ) -> list: """Filter repo-relative *paths* by HF allow / ignore patterns (as ``snapshot_download`` selects). - Fails OPEN (returns all paths) so a snapshot that does hold weights is never rejected on an - unevaluable filter.""" + Fails OPEN (returns all paths) so a snapshot holding weights is never rejected on an unevaluable + filter.""" try: from huggingface_hub.utils import filter_repo_objects @@ -344,8 +331,8 @@ def _filter_paths( def _broken_symlink_rel_paths(snapshot_dir: Path) -> list: - """Repo-relative posix paths of every dangling symlink in *snapshot_dir* (empty when none), so the - interrupted-download signal can be scoped to the files a request actually selects.""" + """Repo-relative posix paths of every dangling symlink in *snapshot_dir*, so the interrupted-download + signal can be scoped to the files a request selects.""" out: list = [] try: for entry in snapshot_dir.rglob("*"): @@ -369,9 +356,9 @@ def snapshot_has_requested_broken_symlinks( ignore_patterns: "Optional[object]" = None, repo_type: "Optional[str]" = "model", ) -> bool: - """True iff a dangling symlink in *snapshot_dir* is for a file the request actually SELECTS, so a - dangling root ``model.safetensors`` does not fail a weightless ``allow=["config.json"]`` request - whose config is on disk. 
(*repo_type* is kept for signature compatibility.)""" + """True iff a dangling symlink in *snapshot_dir* is for a file the request SELECTS, so a dangling root + ``model.safetensors`` does not fail a weightless ``allow=["config.json"]`` request whose config is on + disk. (*repo_type* kept for signature compatibility.)""" allow_patterns = _as_pattern_list(allow_patterns) ignore_patterns = _as_pattern_list(ignore_patterns) broken = _broken_symlink_rel_paths(snapshot_dir) @@ -389,9 +376,8 @@ def snapshot_has_requested_broken_symlinks( def _ignore_strips_all_weights(ignore_patterns: "list") -> bool: - """True iff the ignore set provably excludes EVERY weight format (a probe of each suffix matches a - pattern). A partial strip is NOT weightless -- a surviving variant / .gguf / .pt weight could - still be pulled, so the request stays weight-bearing (conservative).""" + """True iff the ignore set provably excludes EVERY weight format. A partial strip is NOT weightless + (a surviving weight could still be pulled), so the request stays weight-bearing (conservative).""" for suffix in _WEIGHT_FILE_SUFFIXES: probe = "weight" + suffix if not any(isinstance(p, str) and fnmatch.fnmatchcase(probe, p) for p in ignore_patterns): @@ -399,9 +385,9 @@ def _ignore_strips_all_weights(ignore_patterns: "list") -> bool: return True -# Representative weight names a glob allow pattern is probed against (via fnmatch): a glob matching one -# can select a weight; one matching none (``tokenizer*``, ``*.json``) is weightless. Covers canonical / -# variant / sharded / adapter / diffusers / consolidated and the non-safetensors formats. +# Representative weight names a glob allow pattern is probed against: a glob matching one can select a +# weight; one matching none (``tokenizer*``, ``*.json``) is weightless. Covers canonical / variant / +# sharded / adapter / diffusers / consolidated and the non-safetensors formats. 
_WEIGHT_PATTERN_PROBES = ( "model.safetensors", "model.fp16.safetensors", @@ -438,16 +424,14 @@ def _ignore_strips_all_weights(ignore_patterns: "list") -> bool: def _pattern_can_select_weight(pattern: "object") -> bool: """Whether a single allow pattern could select a weight. A weight-suffix basename or a non-metadata directory pattern is weight-bearing; a glob basename is weight-bearing only if it matches a - ``_WEIGHT_PATTERN_PROBES`` name (so ``tokenizer*`` / ``*.json`` stay weightless while - ``model.?afetensors`` / ``unet/*`` do not); a concrete non-weight name is weightless. A false - weight-bearing only spawns the cheap child; the probe set avoids a false weightless on real weights.""" + ``_WEIGHT_PATTERN_PROBES`` name; a concrete non-weight name is weightless. A false weight-bearing only + spawns the cheap child; the probe set avoids a false weightless on real weights.""" if not isinstance(pattern, str): return True if pattern.endswith("/"): dir_name = pattern.rstrip("/").rsplit("/", 1)[-1].lower() return dir_name not in _NON_WEIGHT_DIRS - # A pattern scoped under a metadata dir ("tokenizer/*", "processor/*.json") is weightless like the - # "tokenizer/" form, instead of letting a "*" basename match a weight probe. + # A pattern scoped under a metadata dir ("tokenizer/*") is weightless like the "tokenizer/" form. if "/" in pattern: parent = pattern.rsplit("/", 1)[0].rstrip("/").rsplit("/", 1)[-1].lower() if parent in _NON_WEIGHT_DIRS: @@ -464,9 +448,9 @@ def request_can_include_weights( allow_patterns: "Optional[object]" = None, ignore_patterns: "Optional[object]" = None ) -> bool: """Whether a request restricted by *allow_patterns* / *ignore_patterns* can still include a weight. 
- Conservative: True when uncertain, so the acceptance check requires a weight; False only for a - clearly weightless request (a tokenizer / config allow list, an ignore list dropping every weight - format, or an allow + ignore pair that strips them all), preserving the tokenizer-only short-circuit.""" + Conservative: True when uncertain; False only for a clearly weightless request (a tokenizer / config + allow list, an ignore list dropping every weight format, or an allow + ignore pair that strips them + all), preserving the tokenizer-only short-circuit.""" allow_patterns = _as_pattern_list(allow_patterns) ignore_patterns = _as_pattern_list(ignore_patterns) if allow_patterns is None and ignore_patterns is None: @@ -477,13 +461,11 @@ def request_can_include_weights( return False # allow=[] selects nothing if not any(_pattern_can_select_weight(pat) for pat in allow_patterns): return False - # An allow that can reach a weight can still be left weightless by the ignore filter: allow=["*"] + - # ignore=[every weight suffix], OR a subdir warm allow=["unet/*"] that ignores every weight suffix to - # fetch only that subdir's metadata / configs. Apply HF's allow-then-ignore semantics to representative - # weight probes at the ROOT and UNDER each subdir-scoped allow, so a genuinely weightless request is not - # required to hold a weight (which would false-reject a complete metadata-only subset after both - # transports). A subdir allow that keeps its weight suffixes still matches a subdir probe and stays - # weight-bearing. + # An allow that reaches a weight can still be left weightless by the ignore filter (allow=["*"] + + # ignore=[every weight suffix], or a subdir allow=["unet/*"] ignoring every weight suffix). 
Apply HF's + # allow-then-ignore semantics to weight probes at the ROOT and under each subdir-scoped allow, so a + # genuinely weightless request is not required to hold a weight (a subdir allow keeping its weight + # suffixes still matches a subdir probe and stays weight-bearing). if ignore_patterns: probes = list(_WEIGHT_PATTERN_PROBES) for pat in allow_patterns: @@ -504,17 +486,15 @@ def _canonical_root_weights_complete( Numbered shards without an index, or subfolder-only weights, do NOT count. A weight whose FORMAT the ignore filter drops does not count (a stale ``pytorch_model.bin`` under - ``ignore=['*.bin']`` is not proof the requested safetensors are on disk). The format probe also + ``ignore=['*.bin']`` is no proof the requested safetensors are on disk). The format probe also discards a ``pytorch_model.bin.index.json`` whose ``.json`` name would slip the raw filter. - *prefer_safetensors* is set by the STRICT pre-download gate: a default transformers load probes - ``model.safetensors`` BEFORE ``pytorch_model.bin``, so when safetensors is a format the load would - read (not ignored) a bin-only cache cannot be proven complete -- the local cache cannot show the - preferred safetensors is absent remotely, and skipping the child would let the in-process load fetch - it over un-killable Xet. So a ``.bin`` weight then satisfies the gate only when safetensors is - IGNORED (``use_safetensors=False``); otherwise the bin-only cache defers to the child. The lenient - POST path leaves this False: a finished bin-only download is a genuinely bin-only repo and must not - be false-rejected into a ``DownloadStallError``.""" + *prefer_safetensors* is set by the STRICT pre-download gate: a default load probes safetensors + BEFORE bin, so when safetensors is read (not ignored) a bin-only cache cannot be proven complete + (the local cache cannot show the preferred safetensors is absent remotely) and skipping the child + would fetch it over un-killable Xet. 
So ``.bin`` satisfies the gate only when safetensors is IGNORED. + The lenient POST path leaves this False: a finished bin-only download is a genuinely bin-only repo + and must not be false-rejected into a ``DownloadStallError``.""" root_files: set = set() root_indices: list = [] for entry in entries: @@ -531,28 +511,24 @@ def _canonical_root_weights_complete( root_files.add(entry.name) def _format_kept(weight_name: str) -> bool: - # The format a load reads from *weight_name* must survive the ignore filter, else the file is - # a stale artifact for an excluded format and proves nothing. + # The read format must survive the ignore filter, else the file is a stale excluded-format artifact. if not ignore_patterns: return True return bool(_filter_paths([weight_name], None, ignore_patterns)) st_index = next((e for e in root_indices if ".safetensors.index." in e.name), None) bin_index = next((e for e in root_indices if ".bin.index." in e.name), None) - # transformers' local weight-file precedence, mirrored exactly: a single model.safetensors is probed - # BEFORE the safetensors index, safetensors before the .bin single, and the .bin single before the - # .bin index. So a complete single weight is never masked by a co-resident stale index, and an - # incomplete PREFERRED (safetensors) index is breakage a complete .bin must not mask (transformers - # takes the safetensors-index branch and does not fall back to .bin). A format the ignore filter - # drops is skipped so the next format the load actually reads is judged. + # transformers' local precedence, mirrored: single safetensors before the safetensors index, + # safetensors before the .bin single, .bin single before the .bin index. So a complete single weight + # is never masked by a stale index, and an incomplete PREFERRED safetensors index is breakage a + # complete .bin must not mask. An ignore-dropped format is skipped so the next read format is judged. 
if "model.safetensors" in root_files and _format_kept("model.safetensors"): return True if st_index is not None and _format_kept("model.safetensors"): return _weight_shard_index_complete(st_index) if prefer_safetensors and _format_kept("model.safetensors"): - # STRICT pre-download gate: safetensors is preferred (not ignored) but absent from the cache, so - # a default load would fetch model.safetensors over un-killable Xet. A bin-only cache cannot - # prove safetensors is absent remotely -> defer to the watched child rather than fast-path. + # STRICT gate: safetensors is preferred but absent, and a bin-only cache cannot prove it absent + # remotely, so a default load would fetch it over Xet -> defer to the child. return False if "pytorch_model.bin" in root_files and _format_kept("pytorch_model.bin"): return True @@ -570,20 +546,18 @@ def snapshot_dir_is_complete( prefer_safetensors: bool = False, ) -> bool: """Conservative fast-path gate: True only for an unambiguously complete canonical ROOT model cache, - so an in-process load will not fetch a weight. True requires: an UNPATTERNED request - (``allow_patterns is None``), not a diffusers pipeline (no root ``model_index.json``), no dangling - symlink, and canonical root weights present. Everything else defers to the watched child. A false - True risks a silent Xet fetch; a false False only spawns the cheap child. *require_named_weights* - is accepted for signature compatibility (a named-weight request is patterned, so never fast-pathed). + so an in-process load will not fetch a weight. True requires: an UNPATTERNED request, not a diffusers + pipeline (no root ``model_index.json``), no dangling symlink, and canonical root weights present. + Everything else defers to the watched child. A false True risks a silent Xet fetch; a false False + only spawns the cheap child. *require_named_weights* is accepted for signature compatibility (a + named-weight request is patterned, so never fast-pathed). 
*prefer_safetensors* (set by the strict pre-download gate) rejects a bin-only cache when a default - load would prefer safetensors (not ignored): the local cache cannot prove the preferred file is - absent remotely, so fast-pathing it would let the in-process load fetch it over Xet. The POST caller - leaves it False so a genuinely bin-only download is still accepted. + load prefers safetensors: the cache cannot prove the preferred file absent remotely, so fast-pathing + would fetch it over Xet. The POST caller leaves it False so a genuinely bin-only download is accepted. *ignore_patterns* need no eligibility gate: the canonical-weight check below is what the load reads, - so an ignore that dropped some format (the common ``*.bin`` / subdir prefetch ignores) cannot make - an incomplete cache read complete -- keeping the common warm ``from_pretrained`` cache eligible.""" + so an ignore dropping some format cannot make an incomplete cache read complete.""" allow_patterns = _as_pattern_list(allow_patterns) ignore_patterns = _as_pattern_list(ignore_patterns) if allow_patterns is not None: @@ -601,8 +575,8 @@ def snapshot_dir_is_complete( ) -# A canonical numbered root shard: the index sits IMMEDIATELY before the extension (no variant token), -# so ``model-00001-of-00002.safetensors`` matches but ``model-00001-of-00002.fp16.safetensors`` does not. +# A canonical numbered root shard (no variant token): ``model-00001-of-00002.safetensors`` matches but +# ``model-00001-of-00002.fp16.safetensors`` does not. 
_CANONICAL_ROOT_SHARD_RE = re.compile( r"^(?:model|pytorch_model)-\d{5}-of-\d{5}\.(?:safetensors|bin)$" ) @@ -612,18 +586,16 @@ def _has_incomplete_canonical_root_shards( snapshot_dir: Path, *, ignore_patterns: "Optional[object]" = None ) -> bool: """True when the root holds canonical numbered shards but is NOT a complete canonical model (index - missing or a shard absent) for the format the request READS -- a stale interrupted download a - default load cannot read, so the post-download check rejects it and retries over HTTP. The request's - ignore filter is applied, so a complete safetensors set does not mask an incomplete ``.bin`` set the - load reads under ``ignore=['*.safetensors']``. Variant shards are excluded (their names carry a - ``.-`` infix), so a variant-only repo is not force-failed here.""" + missing or a shard absent) for the format the request READS -- a stale interrupted download the post + check rejects and retries over HTTP. The request's ignore filter is applied, so a complete + safetensors set does not mask an incomplete ``.bin`` set read under ``ignore=['*.safetensors']``. + Variant shards (``.-`` infix) are excluded, so a variant-only repo is not force-failed.""" try: names = [entry.name for entry in snapshot_dir.iterdir()] except OSError: return False - # Canonical shard evidence = a numbered shard FILE, or a canonical shard INDEX. An index-only - # partial (index present, no shards yet) is still an incomplete sharded checkpoint the load would - # finish over Xet, so it must be caught here even before any shard file exists. + # Shard evidence = a numbered shard FILE or a canonical shard INDEX. An index-only partial (index + # present, no shards yet) is still incomplete and must be caught before any shard file exists. 
has_shard_evidence = ( any(_CANONICAL_ROOT_SHARD_RE.match(name) for name in names) or any(_is_canonical_weight_shard_index(name) for name in names) @@ -636,27 +608,21 @@ def _has_incomplete_canonical_root_shards( def _has_incomplete_variant_root_shards( snapshot_dir: Path, variant: str, *, ignore_patterns: "Optional[object]" = None ) -> bool: - """True when the ROOT variant weight the load READS is an incomplete sharded set. transformers writes - a sharded variant weight with a ``.-`` shard infix and its index as - ``model.safetensors.index..json`` (a ``..`` infix before ``.json``); a single-file - variant is ``model..safetensors``. Incomplete means: a present variant shard INDEX whose - listed shards are not all present (an index-only partial with no shard files counts), OR variant shard - FILES with no complete index. - - The request's ignore filter is applied so a variant weight in an ignored format is not the read - format, and safetensors is treated as read BEFORE bin (transformers' probe order): a present-but- - incomplete variant safetensors index is breakage even with a complete variant bin. Positive-evidence: - a single-file variant or a complete variant shard set returns False, so a complete or single-file - variant download is never rejected. Only the ROOT ``model`` / ``pytorch_model`` variant weight is - considered: a co-resident stale ``adapter_model`` variant index / shard set (which a default variant - model load does not read) must not force-fail a complete model variant.""" - dot_infix = f".{variant}." # variant index (model.safetensors.index..json) or single file + """True when the ROOT variant weight the load READS is an incomplete sharded set. Incomplete means: + a present variant shard INDEX whose listed shards are not all present (an index-only partial counts), + OR variant shard FILES with no complete index. 
+ + The request's ignore filter is applied, and safetensors is read BEFORE bin (transformers' probe + order): an incomplete variant safetensors index is breakage even with a complete variant bin. + Positive-evidence: a single-file variant or a complete variant shard set returns False. Only the ROOT + ``model`` / ``pytorch_model`` variant weight is considered, so a stale ``adapter_model`` variant set + (which the default variant model load does not read) must not force-fail a complete model variant.""" + dot_infix = f".{variant}." # variant index or single file dash_infix = f".{variant}-" # a sharded variant weight (model.-00001-of-00002.safetensors) ignore_patterns = _as_pattern_list(ignore_patterns) def _format_kept(weight_name: str) -> bool: - # The format a load reads from *weight_name* must survive the ignore filter, else the file is a - # stale artifact for an excluded format the load does not read. + # The read format must survive the ignore filter, else the file is a stale excluded-format artifact. if not ignore_patterns: return True return bool(_filter_paths([weight_name], None, ignore_patterns)) @@ -671,9 +637,8 @@ def _format_kept(weight_name: str) -> bool: has_single_st = has_single_bin = False for entry in entries: name = entry.name - # Restrict to the ROOT model index (model.safetensors.index..json / - # pytorch_model.bin.index..json); an adapter_model / other non-model variant index the - # default load does not read is skipped so its incompleteness cannot force-fail the model variant. + # Restrict to the ROOT model index; an adapter_model / other non-model variant index the default + # load does not read is skipped so its incompleteness cannot force-fail the model variant. if dot_infix in name and _ROOT_MODEL_SHARD_INDEX_RE.match(name): is_safetensors = ".safetensors.index." 
in name fmt_probe = ( @@ -699,11 +664,10 @@ def _format_kept(weight_name: str) -> bool: has_single_st = True else: has_single_bin = True - # transformers' local precedence, mirrored: a single-file model..safetensors is probed - # BEFORE the safetensors index, safetensors before .bin, and the single .bin before the .bin index. - # So a complete single-file variant is never masked by a co-resident stale index (that would force a - # spurious HTTP retry and DownloadStallError on a usable cache), and an incomplete PREFERRED - # (safetensors) index is still breakage a complete .bin must not mask. + # transformers' local precedence, mirrored: single safetensors before the safetensors index, + # safetensors before .bin, single .bin before the .bin index. So a complete single-file variant is + # never masked by a stale index (which would force a spurious DownloadStallError), and an incomplete + # PREFERRED safetensors index is still breakage a complete .bin must not mask. if has_single_st: return False # a complete single-file safetensors variant, probed before the index if st_index_incomplete is not None: @@ -727,19 +691,17 @@ def _format_kept(weight_name: str) -> bool: r"^(?:model\.safetensors|pytorch_model\.bin)\.index(?:\.[^.]+)?\.json$" ) -# A ROOT model VARIANT weight (single or sharded): the variant token sits between the model / pytorch_model -# base and the extension / shard suffix (model.fp16.safetensors, pytorch_model.fp16-00001-of-00002.bin). -# Excludes a PEFT adapter (adapter_model..*) the default variant model load does not read. +# A ROOT model VARIANT weight, single or sharded (model.fp16.safetensors, +# pytorch_model.fp16-00001-of-00002.bin). Excludes a PEFT adapter the variant model load does not read. 
_ROOT_MODEL_VARIANT_WEIGHT_RE = re.compile( r"^(?:model|pytorch_model)\.[^.]+(?:-\d{5}-of-\d{5})?\.(?:safetensors|bin)$" ) def _index_variant_token(name: str) -> "Optional[str]": - """The variant token of a weight-shard INDEX basename, or None for the canonical (non-variant) form. - ``model.safetensors.index.json`` -> None; ``model.safetensors.index.fp16.json`` -> ``"fp16"``. Lets - the selected-index check read only the indices a load reads (a variant load reads variant indices, a - plain load reads canonical ones).""" + """The variant token of a weight-shard INDEX basename, or None for the canonical form + (``model.safetensors.index.fp16.json`` -> ``"fp16"``). Lets the selected-index check read only the + indices a load reads (variant load -> variant indices, plain load -> canonical).""" if name.endswith(".safetensors.index.json") or name.endswith(".bin.index.json"): return None m = _VARIANT_SHARD_INDEX_RE.search(name) @@ -747,12 +709,10 @@ def _index_variant_token(name: str) -> "Optional[str]": def _index_shard_rel_paths(index_path: Path, dir_rel: str) -> "Optional[list]": - """The snapshot-relative posix paths of the shards a weight index lists, or None if the index is - unreadable / malformed -- mirrors the fail-CLOSED rules of ``_weight_shard_index_complete`` (a - non-dict payload or ``weight_map``, an empty shard set, or a non-string / absolute / drive-letter / - parent-escaping shard value all return None). *dir_rel* is the index's snapshot-relative dir ("" at - root), so a - listed basename is joined back to a full repo-relative path for the request filter.""" + """Snapshot-relative posix paths of the shards a weight index lists, or None if the index is + unreadable / malformed (fail-CLOSED, mirroring ``_weight_shard_index_complete``). 
*dir_rel* is the + index's snapshot-relative dir ("" at root), so a listed basename is joined back to a full + repo-relative path for the request filter.""" import json try: @@ -776,13 +736,10 @@ def _index_shard_rel_paths(index_path: Path, dir_rel: str) -> "Optional[list]": def _index_shard_probe(index_name: str, dir_rel: str) -> "Optional[str]": - """A representative numbered-shard path for a weight-shard INDEX whose listed shards are unknown (a - malformed / truncated index): the index's own base name + format as a first shard, joined under - *dir_rel*. Lets the malformed-index scope check judge the request's allow / ignore filter on the - WEIGHT the load reads rather than on the ``.json`` index filename -- ``ignore=['*.bin']`` does not - match ``pytorch_model.bin.index.json`` but the load never reads that ignored-format index, so - filtering the filename would wrongly retry a complete other-format download. None when the name is not - a recognizable shard index.""" + """A representative numbered-shard path for a malformed weight-shard INDEX (index base + format as a + first shard, under *dir_rel*), so the scope check judges the request filter on the WEIGHT the load + reads, not the ``.json`` filename -- ``ignore=['*.bin']`` misses ``pytorch_model.bin.index.json`` but + the load never reads that ignored-format index. None when the name is not a recognizable shard index.""" for marker, ext in ((".safetensors.index.", "safetensors"), (".bin.index.", "bin")): if marker in index_name: base = index_name.split(marker, 1)[0] @@ -795,11 +752,9 @@ def _index_shard_probe(index_name: str, dir_rel: str) -> "Optional[str]": def _request_scopes_into_dir(allow_patterns: "Optional[list]", dir_name: str) -> bool: """True when an allow pattern names *dir_name* among its LITERAL leading path segments - (``subfolder=checkpoint-7`` -> ``allow=['checkpoint-7/*']``; a NESTED ``subfolder=foo/checkpoint-7`` - -> ``allow=['foo/checkpoint-7/*']``), i.e. 
the load reads INTO that directory at any depth. Lets the - shard-completeness check skip a leftover checkpoint subtree the request does not target, while still - validating a checkpoint the request explicitly loads from. Segments are read only up to the first - glob (a wildcard segment could match anything, so it is not a literal directory target).""" + (``allow=['checkpoint-7/*']`` -> True for ``checkpoint-7``), i.e. the load reads INTO that directory. + Lets the shard check skip a leftover checkpoint subtree the request does not target while still + validating one it explicitly loads from. Segments are read only up to the first glob.""" for p in allow_patterns or (): if not isinstance(p, str) or "/" not in p: continue @@ -815,24 +770,19 @@ def _selected_shard_index_incomplete( snapshot_dir: Path, *, allow_patterns: "Optional[object]", ignore_patterns: "Optional[object]", variant: "Optional[str]", ) -> bool: - """True when a weight-shard INDEX the in-process load READS -- a sharded ADAPTER or a component - SUBFOLDER set that the canonical / variant ROOT-model checks do not cover -- lists a shard that is - absent (or the index is malformed). Scoped to the request so a complete download is never - false-rejected: + """True when a weight-shard INDEX the load READS -- a sharded ADAPTER or a component SUBFOLDER set not + covered by the canonical / variant ROOT-model checks -- lists an absent shard (or is malformed). + Scoped to the request so a complete download is never false-rejected: - - variant: a variant load reads only variant indices (token == variant); a plain load reads only - canonical (token is None) indices. + - variant: a variant load reads only variant indices; a plain load only canonical ones. - allow / ignore: an index is read only when its listed shards survive the request filter. 
- - precedence: within a directory transformers reads safetensors before bin, so when both a - safetensors and a bin index are selected only the safetensors set's completeness is required. - - Also rejects a SELECTED numbered shard FILE (adapter_model-00001-of-00002.safetensors, - unet/diffusion_pytorch_model-00001-of-00002.safetensors) whose directory has NO index of the read - format: the load enumerates a sharded weight through its index, so a shard set without one is - incomplete and would fetch the index and remaining shards over Xet. + - precedence: safetensors read before bin, so when both are selected only the safetensors set's + completeness is required. - The ROOT canonical / variant MODEL shard set is skipped -- ``_has_incomplete_canonical_root_shards`` / - ``_has_incomplete_variant_root_shards`` own it (with their own precedence handling).""" + Also rejects a SELECTED numbered shard FILE whose directory has NO index of the read format: the load + enumerates a sharded weight through its index, so a shard set without one is incomplete and would + fetch the index + remaining shards over Xet. The ROOT canonical / variant MODEL shard set is skipped + (the root-shard checks own it).""" allow_patterns = _as_pattern_list(allow_patterns) ignore_patterns = _as_pattern_list(ignore_patterns) want_variant = variant or None @@ -866,13 +816,10 @@ def _selected_shard_index_incomplete( index_fmts.setdefault(dir_rel, set()).add(fmt) shard_rels = _index_shard_rel_paths(entry, dir_rel) if shard_rels is None: - # A malformed / non-string index. Defer to the watched child only when the REQUEST reads - # this index's weight set -- judged on a representative shard of the index's OWN base + - # format (via the allow / ignore filter), not the .json filename. 
So a co-resident stale - # malformed index the request does NOT select (a leftover adapter index under a base - # ['model*'] / subfolder warm) OR one for an IGNORED format (a *.bin index under - # ignore=['*.bin']) is not read and must not force a spurious retry; an unrecognizable - # index defers to the child. + # A malformed index. Defer only when the REQUEST reads its weight set, judged on a + # representative shard of the index's OWN base + format (not the .json filename). So a + # stale malformed index the request does NOT select, or one for an IGNORED format, must + # not force a spurious retry; an unrecognizable index defers to the child. probe = _index_shard_probe(name, dir_rel) if probe is None or _filter_paths([probe], allow_patterns, ignore_patterns): return True @@ -881,8 +828,8 @@ def _selected_shard_index_incomplete( continue # the load does not read this set (out of scope / ignored format) per_dir.setdefault(dir_rel, {}).setdefault(fmt, []).append(shard_rels) elif shard_file_re.match(name): - # a numbered weight shard FILE of the read variant. Skip the ROOT model shard set (owned by - # the canonical / variant root-shard checks) and any training-checkpoint subtree. + # a numbered shard FILE of the read variant. Skip the ROOT model shard set (root checks own + # it) and any training-checkpoint subtree. if dir_rel == "" and ( (want_variant is None and _CANONICAL_ROOT_SHARD_RE.match(name)) or (want_variant is not None and _ROOT_MODEL_VARIANT_WEIGHT_RE.match(name)) @@ -890,10 +837,8 @@ def _selected_shard_index_incomplete( continue ckpt_dirs = [p for p in rel.split("/")[:-1] if _CHECKPOINT_DIR_RE.match(p)] if ckpt_dirs and not _request_scopes_into_dir(allow_patterns, ckpt_dirs[0]): - # a leftover training-checkpoint subtree the request does not explicitly target (a base / - # adapter / other-subfolder warm never reads it). 
But an EXPLICIT checkpoint load - # (subfolder=checkpoint-N -> allow=['checkpoint-N/*']) DOES read it, so its shard set must - # be checked for completeness rather than silently accepted as a lone shard. + # a leftover checkpoint subtree the request does not target. An EXPLICIT checkpoint load + # (allow=['checkpoint-N/*']) DOES read it, so that set is checked rather than skipped. continue if not _filter_paths([rel], allow_patterns, ignore_patterns): continue # the load does not read this shard (out of scope / ignored format) @@ -909,26 +854,25 @@ def _selected_shard_index_incomplete( except OSError: return True for dir_rel, fmts in shard_fmts.items(): - # a numbered shard of the read (preferred) format with NO index in its directory: the load cannot - # enumerate the set and would fetch the index + remaining shards over Xet. + # a numbered shard of the read format with NO index in its dir: the load cannot enumerate the set + # and would fetch the index + remaining shards over Xet. preferred = "safetensors" if "safetensors" in fmts else "bin" if preferred not in index_fmts.get(dir_rel, set()): return True return False -# A training-checkpoint subdir (checkpoint-500/, checkpoint_7/): its weights are never read as diffusers -# pipeline COMPONENTS, so an incomplete shard index under it must not force-fail a complete pipeline. +# A training-checkpoint subdir (checkpoint-500/): never read as a diffusers pipeline COMPONENT, so an +# incomplete shard index under it must not force-fail a complete pipeline. _CHECKPOINT_DIR_RE = re.compile(r"^checkpoint[-_]\d+$") def _diffusers_declared_components(snapshot_dir: Path) -> "Optional[set]": """The component subfolder names a diffusers ``model_index.json`` declares (top-level keys mapping to - a ``[library, class]`` list; ``_``-prefixed metadata keys excluded). None when the file is absent / - unreadable / malformed, so the caller falls back to treating every subfolder as a component (fail - OPEN, preserving hang protection). 
Scopes the component shard check to what the pipeline actually - reads, so a co-resident stale UNDECLARED subtree (a leftover adapter / controlnet dir the - ``DiffusionPipeline`` load never reads) cannot force-fail a complete pipeline download.""" + a ``[library, class]`` list; ``_``-prefixed metadata excluded). None when absent / unreadable / + malformed, so the caller falls back to every subfolder (fail OPEN, preserving hang protection). + Scopes the component check to what the pipeline reads, so a stale UNDECLARED subtree cannot + force-fail a complete pipeline download.""" import json try: @@ -942,8 +886,8 @@ def _diffusers_declared_components(snapshot_dir: Path) -> "Optional[set]": key for key, value in data.items() if not key.startswith("_") and isinstance(value, (list, tuple)) } - # A real pipeline always declares components; an empty / all-metadata model_index.json is degenerate - # or malformed -> fail OPEN (None) so the caller checks every subfolder, preserving hang protection. + # An empty / all-metadata model_index.json is degenerate -> fail OPEN (None) so the caller checks + # every subfolder, preserving hang protection. return components or None @@ -951,19 +895,16 @@ def _diffusers_component_shards_incomplete( snapshot_dir: Path, *, variant: "Optional[str]" = None, ignore_patterns: "Optional[object]" = None, ) -> bool: - """True when a diffusers pipeline COMPONENT subfolder (unet/, vae/, text_encoder/, ...) holds a - weight-shard INDEX of the read variant that lists a shard that is absent (or the index is malformed) - -- an interrupted component pull the in-process pipeline load would finish over un-killable Xet. 
- - Scoped so a complete pipeline is never false-rejected: the check is limited to the components - ``model_index.json`` declares (a stale UNDECLARED subtree the pipeline load never reads is skipped), - a ROOT index (owned by the canonical / variant root-model checks) and a training-checkpoint subtree - (checkpoint-N/) are skipped, and the request's ignore filter selects the read format. Per directory, - safetensors is read before bin, so only the preferred format's set must be complete. A plain load - reads canonical component indices (token None); a variant load reads variant ones. Also rejects a - component holding a numbered shard FILE with NO index of the read format (the pipeline cannot - enumerate the set and would fetch the index + remaining shards over Xet). Positive-evidence: a - single-file component or a complete component shard set is not flagged, so a complete download passes.""" + """True when a diffusers pipeline COMPONENT subfolder (unet/, vae/, ...) holds a weight-shard INDEX of + the read variant listing an absent shard (or a malformed index) -- an interrupted component pull the + pipeline load would finish over un-killable Xet. + + Scoped so a complete pipeline is never false-rejected: limited to declared components (a stale + UNDECLARED subtree is skipped), ROOT indices (root checks own them) and checkpoint subtrees are + skipped, and the ignore filter selects the read format. Per directory safetensors is read before bin. + A plain load reads canonical component indices, a variant load variant ones. Also rejects a component + numbered shard FILE with NO index of the read format. 
Positive-evidence: a single-file or complete + component set passes.""" want_variant = variant or None ignore_patterns = _as_pattern_list(ignore_patterns) declared = _diffusers_declared_components(snapshot_dir) @@ -1002,10 +943,9 @@ def _diffusers_component_shards_incomplete( index_fmts.setdefault(dir_rel, set()).add(fmt) shard_rels = _index_shard_rel_paths(entry, dir_rel) if shard_rels is None: - # A malformed / non-string index. Defer only when its FORMAT is read (a representative - # shard of the index's base + format survives the ignore filter); a stale malformed - # index for an IGNORED format (a *.bin component index under ignore=['*.bin']) is not - # read, so it must not force a spurious retry of a complete other-format pipeline. + # A malformed index. Defer only when its FORMAT is read (a representative shard survives + # the ignore filter); a stale malformed index for an IGNORED format is not read and must + # not force a spurious retry of a complete other-format pipeline. probe = _index_shard_probe(name, dir_rel) if probe is None or _filter_paths([probe], None, ignore_patterns): return True @@ -1037,10 +977,10 @@ def requested_named_files_present( allow_patterns: "Optional[object]" = None, ignore_patterns: "Optional[object]" = None, ) -> bool: - """For a request naming EXACT files (every entry glob-free), True only when each named file the - ignore filter keeps is on disk -- ``snapshot_download(local_files_only=True)`` returns the revision - dir even when config-only, so a ``["tokenizer.json"]`` request needs its file present. A request - with ANY glob, or no allow list, is trivially satisfied (it cannot be turned into an exact manifest).""" + """For a request naming EXACT files (every entry glob-free), True only when each named file the ignore + filter keeps is on disk -- ``local_files_only`` returns the revision dir even when config-only, so a + ``["tokenizer.json"]`` request needs its file present. 
A request with ANY glob, or no allow list, is + trivially satisfied.""" allow_patterns = _as_pattern_list(allow_patterns) ignore_patterns = _as_pattern_list(ignore_patterns) if not allow_patterns or any(_has_glob(p) for p in allow_patterns): @@ -1081,8 +1021,7 @@ def _iter_snapshot_dirs(repo_dir: Path) -> Iterator[Path]: def _repo_dir_has_broken_snapshot_symlinks(repo_dir: Path) -> bool: - # Check every snapshot, not just the newest: a requested older revision may be broken while a - # newer one is clean, and a latest-only check would report the repo healthy. + # Check every snapshot, not just the newest: an older revision may be broken while a newer is clean. return any( snapshot_dir_has_broken_symlinks(snapshot) for snapshot in _iter_snapshot_dirs(repo_dir) @@ -1092,11 +1031,10 @@ def _repo_dir_has_broken_snapshot_symlinks(repo_dir: Path) -> bool: def _case_safe_repo_cache_dirs(root: Path, repo_type: Optional[str], repo_id: str) -> list: """Cache dirs safely attributable to this exact repo id. - The Hub case-folds the dir name, so a case-insensitive match is needed, but on a case-sensitive - filesystem ``models--Org--Repo`` and ``models--org--repo`` are distinct repos. Prefer an - exact-case match; otherwise accept a single folded match ONLY when the filesystem is - case-insensitive (the exact-case name resolves to it); on a 2+ way collision attribute to neither, - so a stale partial in one repo cannot make the watchdog kill / purge the other.""" + The Hub case-folds the dir name, so a case-insensitive match is needed, but on a case-sensitive fs + ``models--Org--Repo`` and ``models--org--repo`` are distinct repos. 
Prefer exact case; else accept a + single folded match ONLY on a case-insensitive fs; on a 2+ way collision attribute to neither, so a + stale partial in one repo cannot make the watchdog kill / purge the other.""" target = repo_cache_dir_name(repo_type, repo_id) folded_target = target.lower() try: @@ -1107,8 +1045,8 @@ def _case_safe_repo_cache_dirs(root: Path, repo_type: Optional[str], repo_id: st if exact: return exact if len(entries) == 1: - # Attribute a single folded-but-not-exact match only on a case-insensitive filesystem, where - # the exact-case path resolves to the same dir; on a case-sensitive fs it is a DIFFERENT repo. + # A single folded-but-not-exact match is the same dir only on a case-insensitive fs; on a + # case-sensitive fs it is a DIFFERENT repo. try: if (root / target).exists(): return entries @@ -1121,7 +1059,7 @@ def iter_active_repo_cache_dirs( repo_type: Optional[str], repo_id: str, *, cache_dir: "Optional[str | Path]" = None ) -> Iterator[Path]: """Yield the repo's cache dir(s) under *cache_dir* (or the active ``HF_HUB_CACHE``). Case-collision - safe, so the read / watchdog path and the destructive HTTP-prep path share one attribution rule.""" + safe, so the read / watchdog path and the destructive HTTP-prep path share one rule.""" root = hf_cache_root(cache_dir = cache_dir) if root is None: return diff --git a/unsloth_zoo/hf_xet_fallback.py b/unsloth_zoo/hf_xet_fallback.py index 7787d9e5e..c13d5aadd 100644 --- a/unsloth_zoo/hf_xet_fallback.py +++ b/unsloth_zoo/hf_xet_fallback.py @@ -9,19 +9,14 @@ """Xet-primary HF downloads with an automatic HTTP fallback on a no-progress stall. -Xet (``hf_xet``) is the fast default but can hang with no progress, no exception, and a native thread -that cannot be killed. Keep Xet primary and fall back to plain HTTP only when the parent observes a -stall. 
``HF_HUB_DISABLE_XET`` is read at import time, so the fallback runs in a fresh ``spawn`` child -(not a thread) that sets the env before importing ``huggingface_hub``. Cached files short-circuit with -no child; deterministic errors (401/403/404/disk-full) and cancellation propagate without a fallback. - -``hf_hub_download_with_xet_fallback`` does a single file; ``snapshot_download_with_xet_fallback`` does -a whole repo (the entrypoint Unsloth's ``from_pretrained`` uses to warm the cache in a killable child -before the in-process load). Studio cache / secret / process helpers are used best-effort (imported -only if present) or injected, so one body runs both inside Studio and in Unsloth. - -The spawn child sets ``UNSLOTH_ZOO_DISABLE_GPU_INIT=1`` before importing the package, selecting -``unsloth_zoo``'s lightweight import path (no torch / transformers) so each child stays fast. +Xet (``hf_xet``) is fast but can hang with no progress, no exception, and an un-killable native thread. +``HF_HUB_DISABLE_XET`` is read at import time, so the fallback runs in a fresh ``spawn`` child (not a +thread) that sets the env before importing ``huggingface_hub``. Cached files short-circuit with no +child; deterministic errors (401/403/404/disk-full) and cancellation propagate without a fallback. +``snapshot_download_with_xet_fallback`` warms a whole repo in a killable child before Unsloth's +in-process load; ``hf_hub_download_with_xet_fallback`` does a single file. Studio cache / secret / +process helpers are used best-effort (imported only if present) or injected. The child sets +``UNSLOTH_ZOO_DISABLE_GPU_INIT=1`` for unsloth_zoo's lightweight import path (no torch / transformers). """ from __future__ import annotations @@ -66,8 +61,7 @@ logger = logging.getLogger(__name__) -# Public surface (Studio imports from this module, including a `import *` re-export shim), so -# an explicit list keeps the stdlib imports (os, re, signal, errno, ...) out of `import *`. 
+# Explicit list keeps stdlib imports out of Studio's `import *` re-export shim. __all__ = [ "DownloadStallError", "hf_hub_download_with_xet_fallback", @@ -87,15 +81,14 @@ DEFAULT_GRACE_PERIOD = 10.0 _POLL_INTERVAL = 0.5 -# Serializes the brief parent-env (and __main__.__file__) mutation around a child -# spawn (below) so concurrent downloads cannot observe each other's transport env. +# Serializes the parent-env (and __main__.__file__) mutation around a child spawn so +# concurrent downloads cannot observe each other's transport env. _SPAWN_ENV_LOCK = threading.Lock() -# Sentinel: "__main__.__file__ was not touched for this spawn" (distinct from a -# real saved value of None, which means the attribute was absent). +# Sentinel: "__main__.__file__ untouched for this spawn" (distinct from a saved None). _UNSET = object() -# Hugging Face boolean env convention: 1 / ON / YES / TRUE, case-insensitive. +# HF boolean env convention, case-insensitive. _TRUTHY = {"1", "true", "yes", "on"} @@ -104,8 +97,8 @@ def _is_true(value: Optional[str]) -> bool: def _safe_status(callback: Optional[Callable[[str], None]], message: str) -> None: - """Invoke a status / heartbeat callback without letting it kill the daemon watchdog thread: a - disconnected Studio client can make on_status raise, which would stop stall detection.""" + """Invoke a status callback; swallow its exceptions so a disconnected client cannot kill the + daemon watchdog thread (stopping stall detection).""" if callback is None: return try: @@ -115,8 +108,7 @@ def _safe_status(callback: Optional[Callable[[str], None]], message: str) -> Non class DownloadStallError(RuntimeError): - """Raised when no download progress is observed for too long. Canonical home; Studio re-imports it - so all paths share one type.""" + """Raised when no download progress is observed for too long. 
Studio re-imports this canonical type.""" def is_hf_xet_available() -> bool: @@ -128,8 +120,8 @@ def is_hf_xet_available() -> bool: def xet_force_disabled() -> bool: - """Whether the user asked to skip Xet up front (force HTTP), via ``UNSLOTH_DISABLE_XET`` / - ``UNSLOTH_STABLE_DOWNLOADS`` or HF's own ``HF_HUB_DISABLE_XET``.""" + """Whether the user asked to force HTTP up front via ``UNSLOTH_DISABLE_XET`` / + ``UNSLOTH_STABLE_DOWNLOADS`` / ``HF_HUB_DISABLE_XET``.""" return ( _is_true(os.environ.get("UNSLOTH_DISABLE_XET")) or _is_true(os.environ.get("UNSLOTH_STABLE_DOWNLOADS")) @@ -147,16 +139,13 @@ def _default_scrub_secrets(text: str, hf_token: Optional[str] = None) -> str: if not text: return text out = text - # HF callers commonly pass token=True ("use the cached token"); only a real - # string token can be substring-redacted (str.replace(True, ...) raises). + # token=True ("use cached token") is common; only a real string token can be substring-redacted. if isinstance(hf_token, str) and hf_token: out = out.replace(hf_token, "***") out = re.sub(r"hf_[A-Za-z0-9]{8,}", "***", out) out = re.sub(r"([Bb]earer\s+)[A-Za-z0-9._\-]+", r"\1***", out) - # HF download errors can embed the presigned S3/CAS blob URL, whose query - # string carries temporary credentials (X-Amz-Signature, sig, token, ...). - # Redact the query of any URL that looks signed so it is not echoed back to - # the parent and logged. Non-signed URLs (e.g. ...?download=true) are kept. + # Redact the query of a presigned S3/CAS blob URL (temporary creds in the query string); keep + # non-signed URLs (e.g. ...?download=true). 
def _redact_signed_query(match: "re.Match") -> str: base, query = match.group(1), match.group(2) if re.search( @@ -166,11 +155,10 @@ def _redact_signed_query(match: "re.Match") -> str: return f"{base}?***" return match.group(0) - # Match the query up to whitespace OR a structural delimiter (quote, bracket, brace, paren, angle, - # pipe): a signed URL embedded in JSON / a dict repr / other structured text has no surrounding - # whitespace, so a greedy [^\s]* would swallow the trailing "} / ") and replace it with ***, - # corrupting the log line. Real signed-query values percent-encode these chars, so the redaction of - # a genuine presigned URL is unaffected. + # Stop the query at whitespace OR a structural delimiter (quote/bracket/brace/paren/angle/pipe): a + # URL embedded in JSON / a dict repr has no trailing whitespace, so a greedy [^\s]* would swallow the + # closing "} and corrupt the log line. Signed-query values percent-encode these chars, so a genuine + # presigned URL is still fully redacted. out = re.sub( r"(https?://[^\s?]+)\?([^\s\"'()<>{}|[\]]*)", _redact_signed_query, out ) @@ -178,11 +166,10 @@ def _redact_signed_query(match: "re.Match") -> str: def _broken_link_has_active_partner(link: Path, *, active_grace: float) -> bool: - """True if a dangling snapshot symlink should be SPARED because a concurrent sibling download is - still writing the blob it points at. The discriminator is a FRESH ``.incomplete`` partner of the - target blob, NOT the link's own mtime: our own killed child's ``.incomplete`` was static for the - full stall timeout and is purged first (no partner -> link cleared), while a sibling mid-download - still has a growing partner (link spared).""" + """SPARE a dangling snapshot symlink iff a sibling is still writing its target blob. 
Discriminator + is a FRESH ``.incomplete`` partner of the target, NOT the link mtime: our killed child's partner was + static for the full stall timeout and is purged first (no partner -> link cleared), while a sibling + mid-download still has a growing partner (link spared).""" try: target = Path(os.readlink(link)) if not target.is_absolute(): @@ -196,7 +183,7 @@ def _broken_link_has_active_partner(link: Path, *, active_grace: float) -> bool: def _link_incomplete_partner_name(link: Path) -> Optional[str]: - """The ``.incomplete`` basename for a dangling snapshot symlink's target blob, or None.""" + """The ``.incomplete`` basename for a dangling symlink's target blob, or None.""" try: target = Path(os.readlink(link)) return target.name + INCOMPLETE_SUFFIX @@ -215,12 +202,12 @@ def _default_prepare_for_http( """Make the partial safe for an HTTP resume: delete the repo's active ``*.incomplete`` blobs (an HTTP resume over a sparse Xet / hf_transfer partial silently corrupts the blob) and the broken snapshot symlinks the detector counts as active (else the retry inherits stale state and re-trips). - Studio injects its marker-aware version instead. ``iter_active_repo_cache_dirs`` is case-collision - safe, so this destructive purge only touches an unambiguous repo cache dir. + ``iter_active_repo_cache_dirs`` is case-collision safe, so this destructive purge only touches an + unambiguous repo cache dir. Studio injects its marker-aware version instead. - When *owned_incomplete_blobs* is given (the basenames the stalled child held open, captured before - it was killed), the purge is SCOPED to them, so a concurrent same-repo sibling writing a DIFFERENT - blob is never touched even if its partial aged past *active_grace*. None -> coarser mtime guard only. 
+ *owned_incomplete_blobs* (basenames the stalled child held open, captured before the kill) SCOPES the + purge so a same-repo sibling writing a DIFFERENT blob is spared even if aged past *active_grace*; + None -> coarser mtime guard only. """ try: for entry in iter_active_repo_cache_dirs(repo_type, repo_id, cache_dir = cache_dir): @@ -232,17 +219,15 @@ def _default_prepare_for_http( if owned_incomplete_blobs is not None and blob.name not in owned_incomplete_blobs: continue try: - # Spare a partial written within active_grace: a slower sibling that just - # has not written recently is not stalled. Our own killed partial has been - # static for the full stall timeout, so it is purged. + # Spare a partial touched within active_grace (a slow sibling, not stalled); + # our killed partial has been static for the full stall timeout so it purges. if time.time() - blob.stat().st_mtime < active_grace: continue blob.unlink() except OSError: - continue # a locked / permission-denied blob must not abort the rest - # A broken snapshot symlink also reads as active incomplete state; clear those too. Sweep - # EVERY snapshot (the detector inspects all), else a dangling link under an older revision - # keeps the repo marked incomplete and re-trips the watchdog. + continue # a locked / denied blob must not abort the rest + # Clear broken snapshot symlinks (also read as active incomplete state). Sweep EVERY snapshot, + # else a dangling link under an older revision keeps the repo incomplete and re-trips. snapshots_dir = entry / "snapshots" try: snapshot_dirs = [s for s in snapshots_dir.iterdir() if s.is_dir()] @@ -257,7 +242,7 @@ def _default_prepare_for_http( _link_incomplete_partner_name(link) not in owned_incomplete_blobs ): continue - # Spare a sibling's active link (target blob still has a fresh .incomplete). + # Spare a sibling's active link (target still has a fresh .incomplete). 
if _broken_link_has_active_partner(link, active_grace = active_grace): continue try: @@ -273,9 +258,9 @@ def _default_prepare_for_http( def _active_incomplete_blob_sizes( repo_type: Optional[str], repo_id: str, cache_dir: Optional[str] = None ) -> dict[str, int]: - """Map ``{blob_filename: bytes_present}`` (sparse-aware) for the repo's ``*.incomplete`` partials. - The single-file watchdog uses it to follow only its own child's partials, so a concurrent sibling - download of a different file cannot mask this file's stall with its own progress.""" + """Map ``{blob_filename: bytes_present}`` (sparse-aware) for the repo's ``*.incomplete`` partials, so + the single-file watchdog follows only its own child's partials and a sibling download of a different + file cannot mask this file's stall.""" sizes: dict[str, int] = {} try: for entry in iter_active_repo_cache_dirs(repo_type, repo_id, cache_dir = cache_dir): @@ -294,12 +279,11 @@ def _active_incomplete_blob_sizes( def _child_open_incomplete_blobs(pid: int) -> Optional[set]: - """Basenames of the ``*.incomplete`` blobs the download child *pid* currently has open -- exactly - the partial THIS child is writing (incl. a resumed partial that reuses a prior blob-hash name), - not a sibling's (held by a different pid). ``None`` when undeterminable (no ``psutil`` / ``/proc``, - or the process is gone) -> caller uses a coarser measure; an empty set means the child is not yet - writing a partial (connect / metadata phase).""" - # Cross-platform (Linux / macOS / Windows) when psutil is available. + """Basenames of the ``*.incomplete`` blobs child *pid* has open -- exactly the partial THIS child is + writing (incl. a resumed partial reusing a prior blob-hash name), not a sibling's (different pid). + ``None`` when undeterminable (no ``psutil`` / ``/proc``, or process gone) -> caller uses a coarser + measure; empty set -> child not yet writing (connect / metadata phase).""" + # psutil is cross-platform (Linux / macOS / Windows). 
try: import psutil # type: ignore except ImportError: @@ -310,12 +294,12 @@ def _child_open_incomplete_blobs(pid: int) -> Optional[set]: except Exception: return None return {os.path.basename(f.path) for f in files if f.path.endswith(INCOMPLETE_SUFFIX)} - # Linux fallback: read the open fds directly from /proc. + # Linux fallback: read open fds from /proc. fd_dir = f"/proc/{pid}/fd" try: entries = os.listdir(fd_dir) except OSError: - return None # no /proc (non-Linux) or the process is already gone + return None # no /proc (non-Linux) or process gone open_blobs: set = set() for fd in entries: try: @@ -333,10 +317,10 @@ def get_hf_download_state( repo_type: Optional[str] = "model", cache_dir: Optional[str] = None, ) -> Optional[tuple[int, bool]]: - """Return ``(total_on_disk_bytes, has_incomplete)`` for the HF cache being written (sparse-aware, - so a partial Xet / ``hf_transfer`` blob is not read as full progress). Scans *cache_dir* or the - active ``HF_HUB_CACHE``. A missing / empty cache reads as ``(0, False)``; ``None`` is returned only - on a probe exception (unmeasurable -> callers skip stall logic this tick).""" + """Return ``(total_on_disk_bytes, has_incomplete)`` for the HF cache being written (sparse-aware, so + a partial Xet / ``hf_transfer`` blob is not read as full progress). Scans *cache_dir* or the active + ``HF_HUB_CACHE``. Missing / empty cache -> ``(0, False)``; ``None`` only on a probe exception + (unmeasurable -> callers skip stall logic this tick).""" try: if hf_cache_root(cache_dir = cache_dir) is None: return (0, False) @@ -384,16 +368,15 @@ def start_watchdog( baseline_incomplete_blobs: Optional[set] = None, child_pid: Optional[int] = None, ) -> threading.Event: - """Start a daemon thread that fires ``on_stall(message)`` exactly once iff a ``*.incomplete`` is - present AND the on-disk size is unchanged for *stall_timeout* seconds. The timer resets while no - ``*.incomplete`` exists, so post-download init is not misread as a stall. 
Returns a stop event the - caller sets when the download phase ends. - - With *watch_new_partials_only* (single-file), progress is measured only over the child's own - partial, so a concurrent sibling pull of a different file cannot reset the timer and keep a hung - child alive. The child's partial is identified by the blobs *child_pid* has open (precise across a - resumed download), else by the partials not in *baseline_incomplete_blobs* (captured pre-spawn). - Snapshots keep the repo-wide measurement (every blob is part of the one pull).""" + """Start a daemon thread firing ``on_stall(message)`` once iff a ``*.incomplete`` is present AND the + on-disk size is unchanged for *stall_timeout* seconds. The timer resets while no ``*.incomplete`` + exists, so post-download init is not misread as a stall. Returns a stop event the caller sets when + the download phase ends. + + *watch_new_partials_only* (single-file) measures progress only over the child's own partial, so a + sibling pull of a different file cannot keep a hung child alive. That partial is identified by the + blobs *child_pid* has open (precise across a resume), else the partials not in + *baseline_incomplete_blobs* (captured pre-spawn). Snapshots keep the repo-wide measurement.""" stop = threading.Event() transport = "https" if xet_disabled else "xet" fired = False @@ -405,18 +388,17 @@ def _measure() -> Optional[tuple[int, bool]]: sizes = _active_incomplete_blob_sizes(repo_type, single_repo_id, cache_dir) open_names = _child_open_incomplete_blobs(child_pid) if child_pid else None if open_names is not None: - # Only the partials this child holds open (handles a resumed partial reusing a baseline - # name, excludes siblings). hf_xet holds the .incomplete fd continuously, so an EMPTY - # set means the child owns no partial YET (connect / metadata phase), not a sibling's. + # Only partials this child holds open (handles a resume reusing a baseline name, excludes + # siblings). 
hf_xet holds the .incomplete fd continuously, so an EMPTY set means the child + # owns no partial YET (connect / metadata phase), not a sibling's. owned = {name: n for name, n in sizes.items() if name in open_names} return (sum(owned.values()), len(owned) > 0) if child_pid: - # A pid was given but its open files cannot be inspected (no psutil AND no /proc: native - # Windows / macOS without psutil). Post-baseline name filtering would EXCLUDE a resumed - # partial that reuses a baseline blob name forever, so a frozen Xet resume never trips the - # watchdog and the hang persists -- defeating the fallback. Fall back to the repo-wide - # measure (as the snapshot path uses): a resumed partial is then watched; a concurrent - # same-repo sibling's progress may mask this child's stall, the accepted snapshot tradeoff. + # pid given but open files uninspectable (no psutil AND no /proc: native Windows / macOS + # without psutil). Post-baseline name filtering would forever EXCLUDE a resumed partial + # reusing a baseline name, so a frozen Xet resume never trips -- defeating the fallback. + # Fall back to the repo-wide measure (as snapshots do): the resume is watched, at the cost + # that a same-repo sibling's progress may mask this child's stall (accepted tradeoff). return get_hf_download_state( [single_repo_id], repo_type = repo_type, cache_dir = cache_dir ) @@ -437,7 +419,7 @@ def _beat() -> None: if state is None: # Unmeasurable this tick (transient FS error): treat as progress so the gap cannot - # trip a false stall once the state becomes readable again. + # trip a false stall once readable again. last_change = now _safe_status(on_heartbeat, f"Downloading ({transport} transport)...") continue @@ -447,8 +429,8 @@ def _beat() -> None: last_size = current_size last_change = now - # Reset unless .incomplete confirms an active download, so model init - # and lock waits are not counted as a stall. 
+ # Reset unless .incomplete confirms an active download, so model init and lock waits + # are not counted as a stall. if not has_incomplete: last_change = now elif now - last_change >= stall_timeout: @@ -467,8 +449,7 @@ def _beat() -> None: def _scrub_in_child(text: str, token: Optional[str]) -> str: - """Redact secrets from a child error string, preferring Studio's richer - patterns when running inside Studio, else the generic redaction.""" + """Redact secrets from a child error string, preferring Studio's patterns if present.""" try: from hub.utils.download_registry import scrub_secrets # type: ignore @@ -477,9 +458,8 @@ def _scrub_in_child(text: str, token: Optional[str]) -> str: return _default_scrub_secrets(text, hf_token = token) -# Deterministic Hub failures that recur identically over either transport, so switching from -# Xet to HTTP is pointless: surface them. Matched by exception class name so the parent need -# not import huggingface_hub's error classes. +# Deterministic Hub failures that recur identically over either transport, so retrying HTTP is +# pointless: surface them. Matched by class NAME so the parent need not import HF's error classes. _DETERMINISTIC_ERROR_NAMES = frozenset({ "RepositoryNotFoundError", "RevisionNotFoundError", @@ -487,20 +467,18 @@ def _scrub_in_child(text: str, token: Optional[str]) -> str: "GatedRepoError", "DisabledRepoError", "LocalEntryNotFoundError", - "LocalTokenNotFoundError", # a missing required token fails identically over either transport + "LocalTokenNotFoundError", # a missing required token fails identically either way "BadRequestError", "HFValidationError", # a malformed repo id / revision never reaches the network }) -# Names whose TYPE is reconstructed across the spawn boundary but which must NOT join the -# retry-deterministic set above: ``HfHubHTTPError`` is the base of both deterministic 4xx and transient -# 5xx / 429, so its retry decision stays status-code driven. 
Once classified deterministic and surfaced -# as ``"HfHubHTTPError: "``, the parent still re-raises the real type so ``except HfHubHTTPError`` -# keeps working instead of seeing ``RuntimeError``. +# TYPE reconstructed across the spawn but NOT retry-deterministic: ``HfHubHTTPError`` bases both +# deterministic 4xx and transient 5xx / 429, so its retry stays status-code driven while the parent +# still re-raises the real type (not RuntimeError) so ``except HfHubHTTPError`` keeps working. _TYPE_PRESERVE_ONLY_NAMES = frozenset({ "HfHubHTTPError", }) -# Substrings that mark a transient transport failure (hf_xet / CAS error, timeout, reset, -# HTTP 5xx / 429) that disabling Xet and retrying over HTTP may recover. +# Substrings marking a transient transport failure (hf_xet / CAS error, timeout, reset, 5xx / 429) +# that an HTTP retry may recover. _TRANSIENT_ERROR_HINTS = ( "xet", "casclient", "cas_", "timeout", "timed out", "connection", "reset by peer", "temporarily", "try again", "incompleteread", "protocolerror", "remotedisconnected", @@ -511,16 +489,14 @@ def _scrub_in_child(text: str, token: Optional[str]) -> str: def _resolve_exception_class(type_name: str) -> "Optional[type]": - """Map a deterministic Hub / OS error class NAME (as captured in the child) back to its class, - so the parent can re-raise the original type rather than a generic RuntimeError. Best-effort: an - unknown name returns None. Imports are local so the helper stays import-light when no error - occurs and never hard-depends on a specific huggingface_hub layout.""" + """Map a deterministic Hub / OS error class NAME back to its class so the parent re-raises the + original type, not RuntimeError. Best-effort (unknown -> None); local imports keep it import-light + and independent of the huggingface_hub layout.""" if type_name == "OSError": return OSError - # Preserve builtin OSError subclasses (PermissionError, FileNotFoundError, ...): these are - # deterministic filesystem failures (e.g. 
an unwritable custom cache) the child cannot retry away, - # so a caller's `except OSError` / `except PermissionError` must still fire rather than see the - # generic RuntimeError the resolver would otherwise fall through to. + # Preserve builtin OSError subclasses (PermissionError, FileNotFoundError, ...): deterministic FS + # failures the child cannot retry away, so a caller's `except OSError` / `except PermissionError` + # must still fire rather than see RuntimeError. builtin_cls = getattr(builtins, type_name, None) if isinstance(builtin_cls, type) and issubclass(builtin_cls, OSError): return builtin_cls @@ -538,10 +514,10 @@ def _resolve_exception_class(type_name: str) -> "Optional[type]": def _instantiate_preserving_type(exc_cls: type, message: str) -> "Optional[BaseException]": - """Build an *exc_cls* instance carrying *message*, robust to a finicky constructor: Hub error - classes subclass ``HfHubHTTPError`` whose ``response`` arg is keyword-only (required on some - versions), so ``exc_cls(message)`` can raise ``TypeError``. Try the normal constructors first, then - BYPASS ``__init__`` via ``__new__`` so the TYPE and message survive. None only if ``__new__`` fails.""" + """Build an *exc_cls* instance carrying *message*, robust to a finicky constructor: Hub errors + subclass ``HfHubHTTPError`` whose ``response`` arg can be required, so ``exc_cls(message)`` may raise + ``TypeError``. Try normal constructors, then BYPASS ``__init__`` via ``__new__`` so type + message + survive. None only if ``__new__`` fails.""" for build in ( lambda: exc_cls(message), lambda: exc_cls(message, response = None), @@ -559,9 +535,8 @@ def _instantiate_preserving_type(exc_cls: type, message: str) -> "Optional[BaseE def _parse_errno(message: str) -> "Optional[int]": - """Pull the errno out of a stringified OSError. 
CPython formats it as ``[Errno 28] ...``, so a - disk-full (ENOSPC) / quota (EDQUOT) error keeps its code across the spawn boundary when the - parent reconstructs the OSError, letting callers branch on ``exc.errno``.""" + """Pull the errno out of a stringified OSError (CPython formats it ``[Errno 28] ...``), so a + disk-full / quota error keeps its code across the spawn boundary for ``exc.errno`` branching.""" match = re.search(r"\[Errno (\d+)\]", message) if match is None: return None @@ -572,10 +547,9 @@ def _parse_errno(message: str) -> "Optional[int]": def _is_builtin_oserror(exc: BaseException) -> bool: - """True iff *exc*'s type is a BUILTIN ``OSError`` (or subclass): a genuine OS-level error whose - ``[Errno N]`` is a real errno. Excludes HF/requests HTTP errors, which subclass ``OSError`` via - ``requests -> IOError`` yet carry no OS errno, so a bracketed ``[Errno N]`` in their message is not - mistaken for one.""" + """True iff *exc* is a BUILTIN ``OSError`` (or subclass) with a real errno. Excludes HF/requests HTTP + errors (``OSError`` via ``requests -> IOError`` but no OS errno), so a bracketed ``[Errno N]`` in + their message is not mistaken for one.""" if not isinstance(exc, OSError): return False builtin = getattr(builtins, type(exc).__name__, None) @@ -583,10 +557,10 @@ def _is_builtin_oserror(exc: BaseException) -> bool: def _raise_child_error(message: str) -> None: - """Re-raise a deterministic child error preserving its original TYPE when it is a known Hub / OS - error, so callers catching ``RepositoryNotFoundError`` / ``GatedRepoError`` / ``OSError`` still - match across the spawn boundary. The child reports ``": "``; an unrecognized or - uninstantiable class falls back to ``RuntimeError``.""" + """Re-raise a deterministic child error preserving its original TYPE for a known Hub / OS error, so + callers' ``except RepositoryNotFoundError`` / ``GatedRepoError`` / ``OSError`` still match across the + spawn. 
Child reports ``": "``; an unrecognized / uninstantiable class -> + ``RuntimeError``.""" type_name = message.split(":", 1)[0].strip() if ":" in message else "" exc_cls = _resolve_exception_class(type_name) if exc_cls is None: @@ -595,14 +569,11 @@ def _raise_child_error(message: str) -> None: if exc is None: raise RuntimeError(message) if _is_builtin_oserror(exc) and getattr(exc, "errno", None) is None: - # Preserve errno (ENOSPC / EDQUOT ...) across the spawn boundary so a caller's `except OSError` - # cleanup can still branch on exc.errno -- for EVERY builtin OSError subclass (PermissionError, - # FileNotFoundError, ...), not just exact OSError. Restricted to BUILTIN OSError types: an HF - # HTTP error (HfHubHTTPError / RepositoryNotFoundError ...) is ALSO an OSError subclass (via - # requests -> IOError), and a bracketed "[Errno N]" in its message must not be mistaken for a - # real OS errno. Set it as an attribute rather than via the (errno, strerror) constructor: a - # subclass with a single-arg __init__ (hf_hub's LocalEntryNotFoundError) rejects the two-arg - # form, and this keeps the message clean (no doubled "[Errno N]" prefix). + # Preserve errno (ENOSPC / EDQUOT ...) across the spawn for `except OSError` cleanup, for EVERY + # builtin OSError subclass. Restricted to BUILTIN types (an HF HTTP error is also an OSError via + # requests -> IOError, but its "[Errno N]" is not a real errno). Set as an attribute, not via the + # two-arg constructor: a single-arg __init__ subclass (LocalEntryNotFoundError) rejects it, and + # this avoids a doubled "[Errno N]" prefix. 
errno_val = _parse_errno(message) if errno_val is not None: try: @@ -613,34 +584,29 @@ def _raise_child_error(message: str) -> None: def _is_retryable_download_error(exc: BaseException) -> bool: - """True when a captured download exception looks like a transient transport failure (an - ``hf_xet`` / CAS error, connection reset, timeout, HTTP 5xx / 429) that the OTHER transport - may recover, vs a deterministic Hub error (auth, not-found, gated, disk-full) that would - fail identically. Unknown errors are treated as deterministic, so a real repeatable failure - is surfaced rather than looped between transports.""" + """True when a captured exception looks like a transient transport failure (``hf_xet`` / CAS error, + reset, timeout, 5xx / 429) the OTHER transport may recover, vs a deterministic Hub error (auth, + not-found, gated, disk-full). Unknown errors count as deterministic, so a real repeatable failure is + surfaced rather than looped between transports.""" name = type(exc).__name__ - # huggingface_hub raises LocalEntryNotFoundError BOTH for a genuine offline / uncached miss - # (deterministic) AND as its wrapper around a TRANSIENT HEAD connection error / timeout for an - # uncached file ("... Please check your connection and try again"). Retry the transient sub-case - # over the other transport; a true offline miss (no transient hint) falls through to the - # deterministic set below and keeps its reconstructed type. + # LocalEntryNotFoundError wraps BOTH a genuine offline / uncached miss (deterministic) AND a + # TRANSIENT HEAD connection error / timeout for an uncached file. Retry the transient sub-case; a + # true offline miss (no transient hint) falls through to the deterministic set below. if name == "LocalEntryNotFoundError" and any( hint in f"{name}: {exc}".lower() for hint in _TRANSIENT_ERROR_HINTS ): return True if name in _DETERMINISTIC_ERROR_NAMES: return False - # Disk full / quota: a different transport cannot help. 
+ # Disk full / quota: another transport cannot help. if isinstance(exc, OSError) and getattr(exc, "errno", None) in (errno.ENOSPC, errno.EDQUOT): return False - # An HTTP status (HfHubHTTPError carries a requests / httpx response): 5xx and 429 are - # transient; other 4xx (401 / 403 / 404 / 416) are deterministic. + # HTTP status (HfHubHTTPError carries a requests / httpx response): 5xx / 429 / 408 transient, + # other 4xx (401 / 403 / 404 / 416) deterministic. status = getattr(getattr(exc, "response", None), "status_code", None) if not isinstance(status, int): status = getattr(exc, "status_code", None) if isinstance(status, int): - # 5xx server errors, 429 rate-limit, 408 request-timeout are transient; other 4xx - # (401 / 403 / 404 / 416) are deterministic and would fail identically over HTTP. return status >= 500 or status in (408, 429) text = f"{name}: {exc}".lower() return any(hint in text for hint in _TRANSIENT_ERROR_HINTS) @@ -685,10 +651,10 @@ def _download_child_entry( disable_xet: bool, result_queue: Any, ) -> None: - """Spawn-child entrypoint (top-level + picklable): set the Xet env BEFORE importing - huggingface_hub, form its own process group so the parent can kill the whole transfer, and never - log the token or signed URLs.""" - # Die with the parent on Linux under Studio (best-effort; the module is absent standalone). + """Spawn-child entrypoint (top-level + picklable): set the Xet env BEFORE importing huggingface_hub, + form its own process group so the parent can kill the whole transfer, never log token / signed + URLs.""" + # Die with the parent on Linux under Studio (best-effort; module absent standalone). try: from utils.process_lifetime import bind_current_process_to_parent_lifetime # type: ignore @@ -704,29 +670,26 @@ def _download_child_entry( if disable_xet: os.environ["HF_HUB_DISABLE_XET"] = "1" - # Keep the HTTP writer sequential and resumable (hf_transfer leaves sparse - # partials a sequential resume cannot safely continue). 
+ # Keep the HTTP writer sequential and resumable (hf_transfer's sparse partials cannot). os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "0" os.environ.setdefault("HF_HUB_DISABLE_PROGRESS_BARS", "1") repo_id = params["repo_id"] - # Test-only fault injection (never set in production): stall the Xet attempt so the watchdog + - # HTTP fallback can be exercised against a real repo. + # Test-only fault injection (never set in production): stall the Xet attempt to exercise the + # watchdog + HTTP fallback against a real repo. if not disable_xet and os.environ.get("UNSLOTH_HF_XET_FORCE_STALL") == "1": _stall_fh = None try: from huggingface_hub.constants import HF_HUB_CACHE - # Write the fake partial under the cache the watchdog scans, under the repo_type-correct - # dir, so the stall / HTTP fallback fires in tests. + # Write the fake partial under the repo_type-correct dir the watchdog scans. cache_root = params.get("cache_dir") or HF_HUB_CACHE repo_dir_name = f"{repo_type or 'model'}s--" + repo_id.replace("/", "--") blobs = os.path.join(cache_root, repo_dir_name, "blobs") os.makedirs(blobs, exist_ok = True) - # Hold the partial OPEN for the whole stall: the snapshot watchdog finds it by filename, but - # the single-file watchdog counts only partials this PID holds open (a closed file is - # ignored). The handle is bound to a local so it stays open across the sleep. + # Hold the partial OPEN for the whole stall (the single-file watchdog counts only partials + # this PID holds open); bound to a local so it stays open across the sleep. 
_stall_fh = open(os.path.join(blobs, "xet-force-stall.incomplete"), "wb") _stall_fh.write(b"\0" * 4096) _stall_fh.flush() @@ -739,8 +702,8 @@ def _download_child_entry( path = _child_download(kind = kind, params = params, token = token, repo_type = repo_type) result_queue.put({"ok": True, "path": path}) except BaseException as e: # noqa: BLE001 - report every failure to the parent - # Classify here, where the exception object (status, errno, type) is intact, so the parent can - # retry a transient failure over HTTP yet surface a deterministic error without a second attempt. + # Classify here where the exception (status, errno, type) is intact, so the parent retries a + # transient failure over HTTP yet surfaces a deterministic one without a second attempt. result_queue.put({ "ok": False, "error": _scrub_in_child(f"{type(e).__name__}: {e}", token), @@ -749,17 +712,16 @@ def _download_child_entry( def _terminate_process_group(proc: "mp.process.BaseProcess", grace_period: float) -> None: - """Kill *proc* and its whole process group (Xet may spawn helpers). The child ``os.setsid()``s so - its pgid equals its pid; the group is signalled via ``os.killpg(pid, ...)`` only once the child is - confirmed its own leader (``os.getpgid(pid) == pid``). SIGTERM, then SIGKILL after *grace_period*.""" + """Kill *proc* and its whole process group (Xet may spawn helpers). The child ``os.setsid()``s so its + pgid == pid; the group is signalled via ``os.killpg`` only once ``os.getpgid(pid) == pid`` confirms + it. SIGTERM, then SIGKILL after *grace_period*.""" pid = proc.pid def _signal_group(sig: int) -> None: - # Signal the whole GROUP only once the child is confirmed its own leader (setsid done): its pgid - # then equals its pid. BEFORE setsid the child is still in OUR group, and its freshly-allocated - # pid could collide with an unrelated recycled process group -- so ``getpgid(pid) != pid`` guards - # against ``killpg(pid)`` targeting the WRONG group; a reaped child raises here. 
Fall through to a - # single-process signal in all those cases (also Windows, which has no killpg / getpgid). + # Signal the GROUP only once the child is its own leader (pgid == pid after setsid). Before setsid + # it is still in OUR group and its pid could collide with a recycled group, so ``getpgid != pid`` + # guards ``killpg`` from the WRONG group (a reaped child raises here). Otherwise (also Windows: no + # killpg / getpgid) signal the single process. if pid is not None and hasattr(os, "killpg") and hasattr(os, "getpgid"): try: if os.getpgid(pid) == pid: @@ -767,7 +729,6 @@ def _signal_group(sig: int) -> None: return except (ProcessLookupError, PermissionError, OSError): pass - # Windows, pre-setsid, or the child is not (yet) its own group leader: signal the single process. try: proc.terminate() if sig != getattr(signal, "SIGKILL", -9) else proc.kill() except Exception: @@ -775,15 +736,12 @@ def _signal_group(sig: int) -> None: _signal_group(getattr(signal, "SIGTERM", signal.SIGINT)) proc.join(timeout = grace_period) - # SIGKILL only while the leader is alive, so its pid (== pgid after setsid) is a live target. Once - # join() reaps a leader that exited on SIGTERM, that pid is free and a busy host could recycle it - # into an unrelated group -- killpg(pid) would then signal the WRONG group. hf_xet 1.5.x spawns no - # helpers, so a reaped leader leaves nothing to clean up. + # SIGKILL only while alive, so the pid (== pgid) is a live target: once join() reaps a leader, a busy + # host could recycle its pid into an unrelated group and killpg would hit the WRONG one. hf_xet 1.5.x + # spawns no helpers, so a reaped leader leaves nothing to clean up. if proc.is_alive(): - # Match _signal_group's own SIGKILL sentinel (-9) so the force-kill branch (proc.kill()) is - # taken on Windows, where signal.SIGKILL is undefined. 
Functionally moot there (multiprocessing - # maps proc.kill() == proc.terminate() == TerminateProcess, a hard kill either way), but keeps - # the call site and the check consistent. + # -9 sentinel takes the force-kill branch on Windows (signal.SIGKILL undefined; moot there since + # proc.kill() == proc.terminate()), keeping the call site consistent. _signal_group(getattr(signal, "SIGKILL", -9)) proc.join(timeout = 5.0) @@ -802,12 +760,12 @@ def _run_download_attempt( grace_period: float, on_status: Optional[Callable[[str], None]], ) -> tuple[str, Optional[str]]: - """Run one download in a spawn child supervised by the no-progress watchdog. Returns ``("ok", - path)``, ``("stall", None)``, ``("cancelled", None)``, ``("crashed", message)`` (process crash, no - captured exception), ``("retryable_error", message)`` (transient, worth an HTTP retry), or - ``("error", message)`` (deterministic Hub error). The seam tests monkeypatch to avoid spawning.""" - # Single-file: capture the partials on disk BEFORE spawning so the watchdog ignores a sibling's - # in-flight partial and follows only the blob(s) this child writes. Snapshots stay repo-wide. + """Run one download in a spawn child under the no-progress watchdog. Returns ``("ok", path)``, + ``("stall", None)``, ``("cancelled", None)``, ``("crashed", message)`` (crash, no captured + exception), ``("retryable_error", message)`` (transient, worth an HTTP retry), or ``("error", + message)`` (deterministic Hub error). Tests monkeypatch this seam to avoid spawning.""" + # Single-file: snapshot the on-disk partials BEFORE spawning so the watchdog follows only the blob(s) + # this child writes, not a sibling's. Snapshots stay repo-wide. 
baseline_partials: Optional[set] = None if kind == "file": baseline_partials = set( @@ -826,13 +784,12 @@ def _run_download_attempt( ), daemon = True, ) - # Set the transport env in THIS process around the spawn so the child inherits it from creation: - # HF reads HF_HUB_DISABLE_XET into a constant at import time, and the child re-imports - # huggingface_hub before its body runs, so a child-side assignment would land too late. The child - # still sets it defensively. + # Set the transport env in THIS process around the spawn so the child inherits it from creation: HF + # caches HF_HUB_DISABLE_XET at import time and the child re-imports huggingface_hub before its body, + # so a child-side assignment lands too late. The child still sets it defensively. child_env = { "HF_HUB_DISABLE_PROGRESS_BARS": "1", - # Tell unsloth_zoo's __init__ to skip its heavy torch / transformers / device init in the child. + # Skip unsloth_zoo's heavy torch / transformers / device init in the child. "UNSLOTH_ZOO_DISABLE_GPU_INIT": "1", } if disable_xet: @@ -840,36 +797,35 @@ def _run_download_attempt( child_env["HF_HUB_ENABLE_HF_TRANSFER"] = "0" with _SPAWN_ENV_LOCK: # Cache Hub's transport constants in the PARENT from the REAL env NOW, before the child-only - # HF_HUB_DISABLE_XET=1 is briefly set below: a concurrent thread's FIRST `import huggingface_hub` - # in the spawn window would otherwise cache the disabled-Xet value and route later in-process - # downloads over HTTP. Once imported this is a no-op. + # HF_HUB_DISABLE_XET=1 is briefly set below: else a concurrent thread's FIRST `import + # huggingface_hub` in the spawn window caches the disabled value and routes later in-process + # downloads over HTTP. No-op once imported. try: import huggingface_hub.constants # noqa: F401 except Exception: pass saved_env = {k: os.environ.get(k) for k in child_env} - # 'spawn' reconstructs __main__ from __main__.__file__. 
A pseudo-path ('', a notebook) - # fails to start; a real but UNGUARDED caller script gets re-imported as __mp_main__, re-running - # the top-level from_pretrained and hitting the "start a process before bootstrapping" error -> - # the parent sees the child exit without a result. We only need the child to run - # _download_child_entry, so point __main__ at THIS side-effect-free module for the spawn. + # 'spawn' reconstructs __main__ from __main__.__file__: a pseudo-path ('', a notebook) + # fails to start, and a real UNGUARDED caller script re-imports as __mp_main__, re-running + # top-level from_pretrained -> "start a process before bootstrapping" -> the child exits without + # a result. Point __main__ at THIS side-effect-free module for the spawn. main_module = sys.modules.get("__main__") saved_main_file = _UNSET saved_main_spec = _UNSET if main_module is not None: saved_main_file = getattr(main_module, "__file__", _UNSET) main_module.__file__ = __file__ - # Launched as `python -m pkg`: spawn prefers __spec__.name and re-imports the module BY - # NAME (re-running its top-level code). Clearing __spec__ forces the __file__ path branch. + # `python -m pkg`: spawn prefers __spec__.name and re-imports BY NAME. Clearing __spec__ + # forces the __file__ path branch. saved_main_spec = getattr(main_module, "__spec__", _UNSET) main_module.__spec__ = None try: os.environ.update(child_env) proc.start() except BaseException: - # proc.start() can raise (OSError "can't start new process" under fd / thread exhaustion). - # The result_queue's pipe fds were allocated above but the lifecycle try/finally that - # closes them runs only after a successful start, so close the queue here to avoid an fd leak. + # proc.start() can raise (OSError under fd / thread exhaustion). The lifecycle try/finally + # that closes the queue's pipe fds runs only after a successful start, so close it here to + # avoid an fd leak. 
try: result_queue.cancel_join_thread() result_queue.close() @@ -931,17 +887,17 @@ def _run_download_attempt( _terminate_process_group(proc, grace_period) return ("cancelled", None) if stalled.is_set(): - # Prefer a result the child enqueued in the same window the watchdog fired in, so a - # download that just succeeded is not killed. The Queue's feeder thread may not have - # flushed a microseconds-earlier put, so use a short timeout, not get_nowait(). + # Prefer a result the child enqueued in the watchdog's fire window, so a just-succeeded + # download is not killed. Its feeder may not have flushed a microseconds-earlier put, so + # use a short timeout, not get_nowait(). try: result = result_queue.get(timeout = 1.0) break except queue.Empty: pass - # Capture the partials THIS child owns BEFORE killing it, so HTTP prep can scope its - # purge to them. Prefer the per-pid open-fd set; fall back to post-baseline partials - # when the child can't be inspected. None -> prep keeps its coarser mtime guard. + # Capture the partials THIS child owns BEFORE killing it, so HTTP prep scopes its purge to + # them. Prefer the per-pid open-fd set; else post-baseline partials; None -> coarser mtime + # guard. owned = _child_open_incomplete_blobs(proc.pid) if proc.pid else None if owned is None and baseline_partials is not None: current = set( @@ -957,8 +913,8 @@ def _run_download_attempt( except queue.Empty: continue else: - # Process exited; drain any result it enqueued. Short timeout, not get_nowait(): the child - # can exit just before its feeder flushes the pipe, which would spuriously look resultless. + # Process exited; drain any enqueued result. Short timeout, not get_nowait(): the child can + # exit just before its feeder flushes the pipe, which would spuriously look resultless. 
try: result = result_queue.get(timeout = 1.0) except queue.Empty: @@ -967,12 +923,12 @@ def _run_download_attempt( if stop_watchdog is not None: stop_watchdog.set() proc.join(timeout = grace_period) - # Any loop exit (completion, cancel/stall, KeyboardInterrupt) must not leak the child. - # _terminate_process_group is idempotent, so a redundant call here is a harmless no-op. + # No loop exit may leak the child; _terminate_process_group is idempotent so a redundant call + # is a no-op. if proc.is_alive(): _terminate_process_group(proc, grace_period) - # Release the queue's pipe fds deterministically rather than waiting for GC. The result is - # already extracted and a killed child has nothing to flush, so cancel the feeder before close. + # Release the queue's pipe fds now rather than at GC. The result is extracted and a killed child + # has nothing to flush, so cancel the feeder before close. try: result_queue.cancel_join_thread() result_queue.close() @@ -980,8 +936,8 @@ def _run_download_attempt( pass if result is None: - # The child exited without a result: a process-level crash (a native hf_xet abort / segfault), - # not a captured exception, so the other transport may still succeed -- report "crashed". + # Child exited resultless: a process-level crash (native hf_xet abort / segfault), not a captured + # exception, so the other transport may still succeed -- report "crashed". return ( "crashed", f"download process for '{repo_id}' exited " @@ -991,7 +947,6 @@ def _run_download_attempt( return ("ok", result["path"]) message = result.get("error") or "unknown download error" if result.get("retryable"): - # A transient transport failure the child flagged as worth another transport. 
return ("retryable_error", message) return ("error", message) @@ -999,9 +954,8 @@ def _run_download_attempt( def _intact_subset( snapshot_dir: Path, *, repo_type: str, allow_patterns: Any, ignore_patterns: Any, ) -> bool: - """No interrupted-download evidence for the SELECTED files: no dangling requested symlink, and - every EXACT-named requested file present. A dangling EXCLUDED weight does not reject a complete - subset.""" + """No interrupted-download evidence for the SELECTED files: no dangling requested symlink and every + EXACT-named requested file present. A dangling EXCLUDED weight does not reject the subset.""" return ( not snapshot_has_requested_broken_symlinks( snapshot_dir, allow_patterns = allow_patterns, ignore_patterns = ignore_patterns, @@ -1014,39 +968,33 @@ def _intact_subset( def _is_default_load_weight_file(name: str) -> bool: - """A weight in a format a DEFAULT ``from_pretrained`` reads: safetensors or bin only. Excludes gguf / - pt / pth / onnx / msgpack / ... -- a default (non-format-specific) transformers / diffusers load does - not read those, so a stale cache holding only e.g. ``model.Q4_K_M.gguf`` does not satisfy the load, - which would then fetch the missing ``model.safetensors`` / ``pytorch_model.bin`` over un-killable Xet. - Trainer / optimizer state (``optimizer.bin``, ...) is excluded by ``_is_loadable_weight_file``.""" + """A weight a DEFAULT ``from_pretrained`` reads: safetensors or bin only. Excludes gguf / pt / onnx / + msgpack / ... -- a stale cache holding only e.g. ``model.Q4_K_M.gguf`` must not satisfy the load, else + it fetches the missing ``model.safetensors`` over un-killable Xet. 
Optimizer state is already excluded + by ``_is_loadable_weight_file``.""" return _is_loadable_weight_file(name) and name.endswith((".safetensors", ".bin")) -# The CANONICAL root model weight a DEFAULT (no-variant) load reads: model.safetensors / -# pytorch_model.bin as a single file, or a numbered shard (model-00001-of-00002.safetensors -- a dash, -# not a dotted variant token). A PEFT adapter (adapter_model.*), a variant (model.fp16.safetensors), a -# gguf, and a non-canonical root weight (consolidated.safetensors, tf_model.h5) are NOT matched: a -# default from_pretrained probes only these canonical names, so a cache holding only something else does -# not satisfy the load, which would then fetch the missing canonical weight over un-killable Xet. +# CANONICAL root model weight a DEFAULT (no-variant) load reads: model.safetensors / pytorch_model.bin, +# single or numbered shard (dash infix, not a dotted variant token). A PEFT adapter, a variant +# (model.fp16.safetensors), a gguf, and a non-canonical root (consolidated.safetensors, tf_model.h5) are +# NOT matched -- a default from_pretrained probes only these names, so a cache holding only something +# else would fetch the missing canonical weight over un-killable Xet. _CANONICAL_ROOT_MODEL_WEIGHT_RE = re.compile( r"^(?:model|pytorch_model)(?:-\d{5}-of-\d{5})?\.(?:safetensors|bin)$" ) -# A CANONICAL (non-variant) diffusers component weight name -- what a PLAIN pipeline load reads inside a -# component subfolder: a base with no intermediate dotted token (diffusion_pytorch_model / model, single -# or numbered shard), safetensors or bin. A VARIANT weight (diffusion_pytorch_model.fp16.safetensors) -# carries an extra dotted token before the extension and is EXCLUDED here, so a stale cache left by a -# prior variant='fp16' download does not read as a warm PLAIN pipeline -- the in-process -# DiffusionPipeline load (reading the non-variant name) would otherwise fetch it over un-killable Xet. 
-# This mirrors the root check (_CANONICAL_ROOT_MODEL_WEIGHT_RE) and the plain component shard regex. +# CANONICAL (non-variant) diffusers component weight a PLAIN pipeline load reads in a component subfolder: +# a base with no intermediate dotted token, single or numbered shard, safetensors or bin. A VARIANT +# weight (diffusion_pytorch_model.fp16.safetensors) is EXCLUDED, so a variant='fp16' stale cache does not +# read as a warm PLAIN pipeline (which would fetch the non-variant name over un-killable Xet). _CANONICAL_COMPONENT_WEIGHT_RE = re.compile( r"^[^.]+(?:-\d{5}-of-\d{5})?\.(?:safetensors|bin)$" ) -# A SINGLE-FILE canonical root TF / Flax weight (transformers TF2_WEIGHTS_NAME / FLAX_WEIGHTS_NAME): -# what a from_tf / from_flax load reads instead of a PyTorch format. A SHARDED TF/Flax weight is judged -# through its index (tf_model.h5.index.json / flax_model.msgpack.index.json) instead -- a lone shard -# here must NOT read as a present weight, else an incomplete sharded set is loaded over Xet. +# SINGLE-FILE canonical root TF / Flax weight a from_tf / from_flax load reads instead of a PyTorch +# format. A SHARDED TF/Flax weight is judged through its index instead -- a lone shard here must NOT read +# as present, else an incomplete sharded set is loaded over Xet. _CANONICAL_ROOT_TF_FLAX_WEIGHT_RE = re.compile(r"^(?:tf_model\.h5|flax_model\.msgpack)$") # The shard-index sidecars a sharded TF / Flax weight is enumerated through. @@ -1054,36 +1002,29 @@ def _is_default_load_weight_file(name: str) -> bool: def _pytorch_root_weight_formats_ignored(ignore_patterns: Any) -> bool: - """True when the request's ignore filter drops BOTH canonical PyTorch root weights - (``model.safetensors`` AND ``pytorch_model.bin``) -- the signature of a ``from_tf`` / ``from_flax`` - load, whose prefetch ignores ``*.safetensors`` + ``*.bin`` and keeps ``*.h5`` / ``*.msgpack``. 
Lets - the readable-weight check count the TF/Flax weight the load actually reads (``tf_model.h5`` / - ``flax_model.msgpack``) rather than false-reject a complete h5/msgpack download into a - ``DownloadStallError``. Never true for a normal load (which keeps at least one PyTorch format).""" + """True when the ignore filter drops BOTH canonical PyTorch root weights (``model.safetensors`` AND + ``pytorch_model.bin``) -- the ``from_tf`` / ``from_flax`` signature. Lets the readable-weight check + count the TF/Flax weight the load actually reads rather than false-reject a complete h5/msgpack + download. Never true for a normal load (which keeps a PyTorch format).""" return not _filter_paths( ["model.safetensors", "pytorch_model.bin"], None, ignore_patterns ) -# A training-checkpoint subdir (checkpoint-500/, checkpoint_7/): its weights are never read as diffusers -# pipeline COMPONENTS, so they must not mask missing unet/vae/text-encoder weights. +# A training-checkpoint subdir (checkpoint-500/): its weights are never read as diffusers pipeline +# COMPONENTS, so they must not mask missing unet/vae/text-encoder weights. _CHECKPOINT_DIR_RE = re.compile(r"^checkpoint[-_]\d+$") def _has_diffusers_component_weight(snapshot_dir: Path, *, ignore_patterns: Any = None) -> bool: - """True if a DECLARED diffusers pipeline COMPONENT weight (a loadable weight in a component SUBFOLDER - the ``model_index.json`` declares: unet/, vae/, text_encoder/, ...) that the ignore filter keeps is - present. Scoped to declared components, so a stale partial holding only an UNDECLARED leftover subtree - (a controlnet/ dir not in ``model_index.json``) does not read as proof the pipeline is warm while the - declared unet / vae weights are still missing -- which the in-process load would then fetch over - un-killable Xet. Also excludes ROOT-level weights (an adapter / merged file a ``DiffusionPipeline`` - does not read as a component) and training-checkpoint subtrees (checkpoint-N/). 
A malformed / empty - ``model_index.json`` fails OPEN (any component subfolder counts). Stays lenient on WHICH declared - components are required (a pipeline's components can be optional): it only tells a real component warm - from an undeclared-leftover / checkpoint-only / config-only stale snapshot. Counts only CANONICAL - (non-variant) component weights (``_CANONICAL_COMPONENT_WEIGHT_RE``): a variant weight - (``unet/diffusion_pytorch_model.fp16.safetensors`` left by a prior ``variant='fp16'`` warm) is not - what a PLAIN pipeline load reads, so a variant-only stale cache is retried over HTTP rather than - loaded (its non-variant component weight is still missing).""" + """True if a DECLARED diffusers pipeline COMPONENT weight (in a ``model_index.json``-declared + subfolder: unet/, vae/, ...) that the ignore filter keeps is present. Scoped to declared components so + a stale UNDECLARED leftover subtree (a controlnet/ dir not declared) does not read as a warm pipeline + while the declared unet / vae weights are missing (which the load would fetch over un-killable Xet). + Excludes ROOT-level weights and training-checkpoint subtrees. A malformed / empty ``model_index.json`` + fails OPEN. Lenient on WHICH declared components are required (they can be optional) -- only + distinguishes a real component warm from an undeclared-leftover / checkpoint-only / config-only stale + snapshot. 
Counts only CANONICAL (non-variant) weights: a variant-only stale cache is retried over HTTP + (its non-variant component weight is still missing).""" declared = _diffusers_declared_components(snapshot_dir) rels: list = [] try: @@ -1091,7 +1032,7 @@ def _has_diffusers_component_weight(snapshot_dir: Path, *, ignore_patterns: Any if not _is_default_load_weight_file(entry.name): continue if not _CANONICAL_COMPONENT_WEIGHT_RE.match(entry.name): - continue # a VARIANT component weight -- a plain load reads the non-variant name + continue # a VARIANT weight -- a plain load reads the non-variant name try: if not entry.is_file(): continue @@ -1100,11 +1041,11 @@ def _has_diffusers_component_weight(snapshot_dir: Path, *, ignore_patterns: Any continue parts = rel.split("/") if len(parts) < 2: - continue # a ROOT-level weight is not a pipeline component + continue # a ROOT-level weight is not a component if declared is not None and parts[0] not in declared: - continue # an UNDECLARED subtree the DiffusionPipeline load does not read + continue # an UNDECLARED subtree the load does not read if any(_CHECKPOINT_DIR_RE.match(p) for p in parts[:-1]): - continue # under a training-checkpoint subtree, not a component + continue # a training-checkpoint subtree, not a component rels.append(rel) except OSError: return False @@ -1113,17 +1054,12 @@ def _has_diffusers_component_weight(snapshot_dir: Path, *, ignore_patterns: Any def _root_model_has_weight(snapshot_dir: Path, *, ignore_patterns: Any = None) -> bool: """Whether an UNPATTERNED model warm holds a weight a default load reads: a CANONICAL ROOT weight - (``model.safetensors`` / ``pytorch_model.bin``, single or numbered shard), or -- for a diffusers - pipeline (root ``model_index.json``) -- a component-subfolder weight. Counting ANY subtree weight - would accept a stale checkpoint-only snapshot and then fetch the root weights over un-killable Xet; - diffusers is the one layout whose weights live in subfolders. 
Only the - canonical names are counted (``_CANONICAL_ROOT_MODEL_WEIGHT_RE``): a VARIANT-named root weight - (``model.fp16.safetensors``), a PEFT adapter (``adapter_model.*``), a gguf, and a NON-canonical root - weight (``consolidated.safetensors``) are excluded, since a default from_pretrained probes only the - canonical names, so a cache holding only something else is retried over HTTP rather than loaded (its - canonical weight is still missing). The request's ignore filter is applied to the ROOT weights, so an - offline-fallback partial holding only the format the load will NOT read (an ignored ``*.bin`` under a - safetensors request) does not count as a usable weight.""" + (single or numbered shard), or -- for a diffusers pipeline (root ``model_index.json``) -- a + component-subfolder weight. Counting ANY subtree weight would accept a stale checkpoint-only snapshot + then fetch the root weights over un-killable Xet; diffusers is the one layout with weights in + subfolders. Only canonical names count (a VARIANT root, PEFT adapter, gguf, or NON-canonical + consolidated.* is retried over HTTP -- its canonical weight is still missing). 
The ignore filter is + applied so a partial holding only the format the load will NOT read does not count.""" try: is_diffusers = (snapshot_dir / "model_index.json").is_file() except OSError: @@ -1141,21 +1077,17 @@ def _root_model_has_weight(snapshot_dir: Path, *, ignore_patterns: Any = None) - except OSError: continue if _CANONICAL_ROOT_MODEL_WEIGHT_RE.match(name): - rels.append(name) # a canonical model.safetensors / pytorch_model.bin (single or shard) + rels.append(name) # canonical model / pytorch_model (single or shard) elif _CANONICAL_ROOT_TF_FLAX_WEIGHT_RE.match(name): - tf_flax_rels.append(name) # a TF/Flax root weight (from_tf / from_flax) + tf_flax_rels.append(name) # TF/Flax root weight (from_tf / from_flax) except OSError: return False if _filter_paths(rels, None, ignore_patterns): return True - # from_tf / from_flax: the ignore filter drops BOTH canonical PyTorch formats, so the load reads a - # TF (tf_model.h5) / Flax (flax_model.msgpack) root weight the safetensors/bin check above cannot - # see. Count a SINGLE-FILE TF/Flax weight, or a COMPLETE sharded set (its index present with every - # listed shard present), so a complete from_tf/from_flax download is not false-rejected into a - # DownloadStallError -- while an INCOMPLETE sharded set (a lone shard, or an index missing a shard) - # is NOT counted, so it is retried over HTTP rather than loaded over un-killable Xet. Gated on "both - # PyTorch formats ignored", so a normal load is unchanged and a stray leftover h5/msgpack never - # counts. + # from_tf / from_flax (both PyTorch formats ignored): count a SINGLE-FILE TF/Flax weight or a COMPLETE + # sharded set (index + every listed shard present), so a complete h5/msgpack download is not + # false-rejected, while an INCOMPLETE set is retried over HTTP. Gated so a normal load is unchanged + # and a stray leftover never counts. 
if _pytorch_root_weight_formats_ignored(ignore_patterns): if tf_flax_rels and _filter_paths(tf_flax_rels, None, ignore_patterns): return True @@ -1172,15 +1104,12 @@ def _root_model_has_weight(snapshot_dir: Path, *, ignore_patterns: Any = None) - def _root_has_variant_weight( snapshot_dir: Path, variant: str, *, ignore_patterns: Any = None ) -> bool: - """True if a CANONICAL ROOT model weight carrying the requested *variant* token, kept by the ignore - filter, is present. transformers writes the variant on the model base then shards it, so the names it - reads are ``model.<variant>.safetensors`` (single) and ``model.<variant>-00001-of-00002.safetensors`` - (a ``.<variant>-`` shard infix) -- matched by ``_ROOT_MODEL_VARIANT_WEIGHT_RE`` plus the specific - variant infix. A non-canonical base (``consolidated.<variant>.safetensors``), a PEFT adapter, or a - non-``model`` variant name a default variant load never reads is excluded, so a cache holding only - those is retried over HTTP rather than loaded (its ``model.<variant>.*`` weight is still missing). The - ignore filter is applied so a partial holding only the ignored format (``model.fp16.bin`` under - ``ignore=['*.bin']``) does not count.""" + """True if a CANONICAL ROOT model weight carrying the requested *variant* token (kept by the ignore + filter) is present. transformers reads ``model.<variant>.safetensors`` and + ``model.<variant>-00001-of-00002.safetensors`` (``.<variant>-`` shard infix), matched by + ``_ROOT_MODEL_VARIANT_WEIGHT_RE`` plus the variant infix. A non-canonical base, PEFT adapter, or + non-``model`` variant is excluded -> a cache holding only those is retried over HTTP. The ignore + filter is applied so an ignored-format partial does not count.""" infix_dot = f".{variant}." 
infix_dash = f".{variant}-" rels: list = [] @@ -1190,8 +1119,7 @@ def _root_has_variant_weight( if infix_dot not in name and infix_dash not in name: continue # not the requested variant token if not _ROOT_MODEL_VARIANT_WEIGHT_RE.match(name): - continue # only a canonical model / pytorch_model variant weight is read by a default - # variant load -- an adapter, a consolidated.* sidecar, or a gguf is not + continue # only a canonical model / pytorch_model variant weight, not adapter / gguf try: if entry.is_file(): rels.append(name) @@ -1205,17 +1133,11 @@ def _root_has_variant_weight( def _has_diffusers_component_variant_weight( snapshot_dir: Path, variant: str, *, ignore_patterns: Any = None ) -> bool: - """Variant analog of ``_has_diffusers_component_weight``: True if a DECLARED diffusers pipeline - COMPONENT subfolder (unet/, vae/, text_encoder/, ... that ``model_index.json`` declares) holds a - weight carrying the requested *variant* token (``unet/diffusion_pytorch_model.fp16.safetensors``). A - variant pipeline warm's weights are component-scoped, not root ``model.<variant>.*`` files, so a - root-only variant check would false-reject a complete diffusers variant download into a - ``DownloadStallError``. Scoped to declared components (as the plain component helper is), so a stale - partial holding only an UNDECLARED leftover variant weight (a ``controlnet/`` dir not in - ``model_index.json``) does not read as proof the pipeline is warm while the declared unet / vae - variant weights are still missing -- which ``DiffusionPipeline.from_pretrained(..., variant=...)`` - would then fetch over un-killable Xet. A malformed / empty ``model_index.json`` fails OPEN. Excludes - ROOT-level and training-checkpoint weights (as the plain component check does) and reads only + """Variant analog of ``_has_diffusers_component_weight``: True if a DECLARED component subfolder holds + a weight with the requested *variant* token. 
A variant pipeline's weights are component-scoped, not + root ``model.<variant>.*``, so a root-only check would false-reject a complete diffusers variant + download. Scoped to declared components (an undeclared leftover does not read as warm), fails OPEN on a + malformed ``model_index.json``, excludes ROOT-level / training-checkpoint weights, reads only safetensors / bin.""" declared = _diffusers_declared_components(snapshot_dir) infix_dot = f".{variant}." @@ -1236,11 +1158,11 @@ def _has_diffusers_component_variant_weight( continue parts = rel.split("/") if len(parts) < 2: - continue # a ROOT-level variant weight is not a pipeline component + continue # a ROOT-level variant weight is not a component if declared is not None and parts[0] not in declared: - continue # an UNDECLARED subtree the DiffusionPipeline load does not read + continue # an UNDECLARED subtree the load does not read if any(_CHECKPOINT_DIR_RE.match(p) for p in parts[:-1]): - continue # under a training-checkpoint subtree, not a component + continue # a training-checkpoint subtree, not a component rels.append(rel) except OSError: return False @@ -1250,10 +1172,9 @@ def _has_diffusers_component_variant_weight( def _root_model_has_variant_weight( snapshot_dir: Path, variant: str, *, ignore_patterns: Any = None ) -> bool: - """Whether an UNPATTERNED variant warm holds a variant weight a default load reads: a ROOT variant - weight, or -- for a diffusers pipeline (root ``model_index.json``) -- a component-subfolder variant - weight. 
Variant analog of ``_root_model_has_weight``: a diffusers variant's weights live in component - subfolders, not root ``model.<variant>.*`` files, so the root-only check would false-reject them.""" + """Variant analog of ``_root_model_has_weight``: an UNPATTERNED variant warm holds a ROOT variant + weight, or -- for a diffusers pipeline -- a component-subfolder variant weight (its weights live in + subfolders, not root ``model.<variant>.*``, so the root-only check would false-reject them).""" try: is_diffusers = (snapshot_dir / "model_index.json").is_file() except OSError: @@ -1266,20 +1187,19 @@ def _root_model_has_variant_weight( # Interchangeable exact weight names collapse to one equivalence group: the either-format pair -# ``["pytorch_model.bin", "model.safetensors"]`` is satisfied by ANY one -- and so is the variant pair -# ``["model.fp16.safetensors", "pytorch_model.fp16.bin"]`` (HF allow patterns are ALTERNATIVES over the -# repo, so a repo publishing only one format is complete). Distinct logical weights (base AND adapter, a -# different variant token) stay separate groups (each required). +# ``["pytorch_model.bin", "model.safetensors"]`` (and the variant pair) is satisfied by ANY one, since HF +# allow patterns are ALTERNATIVES over the repo. Distinct logical weights (base AND adapter, a different +# variant) stay separate groups (each required). _EITHER_FORMAT_WEIGHT_RE = re.compile( r"^(model|pytorch_model|adapter_model)(?:\.([^.]+))?\.(?:safetensors|bin)$" ) def _exact_weight_logical(base: str) -> Any: - """Equivalence key for an EXACT-named weight so the either-format alternatives share a group: - ``model.safetensors`` / ``pytorch_model.bin`` -> ``("root_model", None)``; the same variant token in - both formats shares ``("root_model", "<variant>")``; ``adapter_model.*`` -> ``("adapter_model", ...)``. 
- A non-weight (or sharded) name maps to itself, so each distinct file is still required.""" + """Equivalence key for an EXACT-named weight so either-format alternatives share a group: + ``model.safetensors`` / ``pytorch_model.bin`` -> ``("root_model", None)``; same variant in both + formats -> ``("root_model", "<variant>")``; ``adapter_model.*`` -> ``("adapter_model", ...)``. A + non-weight / sharded name maps to itself (still required individually).""" m = _EITHER_FORMAT_WEIGHT_RE.match(base) if m is None: return base @@ -1291,16 +1211,16 @@ def _exact_weight_logical(base: str) -> Any: def _requested_exact_files_present_grouped( snapshot_dir: Path, *, allow_patterns: Any, ignore_patterns: Any, ) -> bool: - """True unless an EXACT-named requested file is missing. Interchangeable weights - (``["pytorch_model.bin", "model.safetensors"]``) need any one; distinct logical files (base AND - adapter, a tokenizer file) each. A glob / unpatterned request is trivially satisfied here.""" + """True unless an EXACT-named requested file is missing. Interchangeable weights need any one; distinct + logical files (base AND adapter, a tokenizer file) each. A glob / unpatterned request is trivially + satisfied here.""" allow = _as_pattern_list(allow_patterns) ignore = _as_pattern_list(ignore_patterns) if not allow or any(not isinstance(p, str) or _has_glob(p) for p in allow): return True requested = _filter_paths(allow, None, ignore) if not requested: - return True # the ignore filter dropped every named file -> nothing to require + return True # ignore filter dropped every named file -> nothing to require try: present = { entry.relative_to(snapshot_dir).as_posix() @@ -1322,9 +1242,8 @@ def _requested_exact_files_present_grouped( def _has_selected_weight( snapshot_dir: Path, *, allow_patterns: Any, ignore_patterns: Any, ) -> bool: - """True if a loadable weight the request SELECTS is present. 
Applies the allow / ignore filter, so a - patterned request is not satisfied by an out-of-scope weight (a stale ``.bin``, an unrequested - checkpoint subfolder).""" + """True if a loadable weight the request SELECTS (allow / ignore filtered) is present, so a stale + out-of-scope weight does not satisfy a patterned request.""" weights: list = [] try: for entry in snapshot_dir.rglob("*"): @@ -1344,10 +1263,8 @@ def _has_selected_variant_weight( snapshot_dir: Path, *, allow_patterns: Any, ignore_patterns: Any, variant: str, ) -> bool: """True if a SELECTED loadable weight carrying the *variant* token is present. Combines the request's - allow / ignore scope (as ``_has_selected_weight``) with the variant infix check (as - ``_root_has_variant_weight``): a patterned variant load (e.g. ``subfolder=`` + ``variant=``) whose - offline-fallback partial kept only the canonical weight in scope is retried over HTTP rather than - loaded, else the in-process load fetches ``model.<variant>.safetensors`` over un-killable Xet.""" + allow / ignore scope with the variant infix check, so a patterned variant load whose partial kept + only the canonical weight is retried over HTTP (else it fetches the variant over un-killable Xet).""" infix_dot = f".{variant}." infix_dash = f".{variant}-" weights: list = [] @@ -1369,22 +1286,21 @@ def _has_selected_variant_weight( def _patterns_are_exact_names(patterns: Any) -> bool: - """True only for a non-empty allow list of EXACT filenames (no ``None`` / glob / trailing-slash - dir). Only such a request is locally provable complete; ``None`` / a glob needs the Hub manifest.""" + """True only for a non-empty allow list of EXACT filenames (no ``None`` / glob / trailing-slash dir). 
+ Only such a request is locally provable complete; ``None`` / a glob needs the Hub manifest.""" patterns = _as_pattern_list(patterns) if patterns is None: return False if not patterns: - return True # selects nothing -> trivially satisfied, nothing to fetch + return True # selects nothing -> nothing to fetch return all(isinstance(p, str) and not _has_glob(p) for p in patterns) def _request_selects_canonical_root_shards(allow_patterns: Any, ignore_patterns: Any) -> bool: - """Whether the request's allow / ignore filter keeps a canonical ROOT shard name. When False, an - incomplete canonical root shard set is OUT of the request's scope -- a co-resident leftover from a - prior interrupted base pull that a patterned load (adapter / gguf / subfolder) never reads -- so the - canonical-shard-completeness gate must NOT reject on it, else a genuinely complete patterned download - is failed into a DownloadStallError.""" + """Whether the allow / ignore filter keeps a canonical ROOT shard name. When False, an incomplete + canonical root shard set is OUT of scope (a leftover a patterned adapter / gguf / subfolder load never + reads), so the shard-completeness gate must NOT reject on it, else a complete patterned download is + failed into a DownloadStallError.""" probes = ["model-00001-of-00002.safetensors", "pytorch_model-00001-of-00002.bin"] return bool(_filter_paths(probes, allow_patterns, ignore_patterns)) @@ -1392,10 +1308,10 @@ def _request_selects_canonical_root_shards(allow_patterns: Any, ignore_patterns: def _request_selects_root_variant_weight( allow_patterns: Any, ignore_patterns: Any, variant: str, ) -> bool: - """Whether the request's allow / ignore filter keeps a ROOT variant weight name. When False, a stale - incomplete root variant shard set is OUT of the request's scope (e.g. 
a subfolder request - ``allow=['unet/*']`` whose variant weights live under ``unet/``), so the ROOT variant-shard gate must - not reject on it, else a complete in-scope variant download is failed.""" + """Whether the allow / ignore filter keeps a ROOT variant weight name. When False, a stale incomplete + root variant shard set is OUT of scope (e.g. a subfolder request whose variant weights live under + ``unet/``), so the ROOT variant-shard gate must not reject on it, else a complete in-scope variant + download is failed.""" probes = [ f"model.{variant}.safetensors", f"model.{variant}-00001-of-00002.safetensors", f"pytorch_model.{variant}.bin", f"pytorch_model.{variant}-00001-of-00002.bin", @@ -1407,32 +1323,26 @@ def _cache_can_skip_download( snapshot_dir: Path, *, repo_type: str, allow_patterns: Any, ignore_patterns: Any, variant: Optional[str] = None, ) -> bool: - """PRE-download: whether a cached snapshot is complete enough to skip the protective child. - - STRICT for a weight-bearing model request: only the conservative canonical gate - (``snapshot_dir_is_complete``) skips; anything uncertain (diffusers, variants, patterns, - sharded-without-index) spawns the child. A false True would let the load fetch a missing weight over - un-killable Xet (the hang). A weightless model or non-model (dataset) request has no weight to hang - on, but is locally provable complete only when it names EXACT files -- an unpatterned / glob request - defers to the child rather than hand back a partial cache. An intact exact-named subset still - short-circuits (offline tokenizer-only / named-file warm).""" + """PRE-download: whether a cached snapshot is complete enough to skip the protective child. STRICT -- + a false True lets the load fetch a missing weight over un-killable Xet (the hang). 
+ + Weight-bearing model request: only the conservative canonical gate (``snapshot_dir_is_complete``) + skips; anything uncertain (diffusers, variants, patterns, sharded-without-index) spawns the child. A + weightless / non-model request has no weight to hang on but is locally provable complete only when it + names EXACT files (an intact exact-named subset short-circuits; unpatterned / glob defers).""" if repo_type in (None, "model") and request_can_include_weights(allow_patterns, ignore_patterns): - # A variant load reads variant-named weights (model.<variant>.safetensors) that the canonical - # gate does not check: a cache holding only the canonical weight reads as complete, so the - # in-process load would fetch the variant over un-killable Xet. Defer to the child (it warms - # the variant too). + # A variant load reads variant-named weights the canonical gate does not check, so a + # canonical-only cache would fetch the variant over un-killable Xet. Defer to the child. if variant: return False - # STRICT: a default load probes model.safetensors before pytorch_model.bin, so a bin-only cache - # for a repo that also publishes safetensors (which the local cache cannot rule out) would fetch - # the preferred safetensors in-process over Xet. prefer_safetensors defers such a cache to the - # child; a use_safetensors=False request (safetensors ignored) still fast-paths its bin cache. + # A default load probes model.safetensors before pytorch_model.bin, so a bin-only cache for a + # repo that also publishes safetensors (unprovable locally) would fetch the safetensors over Xet. + # prefer_safetensors defers it; a use_safetensors=False request still fast-paths its bin cache. return snapshot_dir_is_complete( snapshot_dir, allow_patterns = allow_patterns, ignore_patterns = ignore_patterns, prefer_safetensors = True, ) - # Weightless / non-model: skip only for an intact exact-named subset. 
A None / glob request cannot - # be proven complete from local files, so defer to the child for the manifest compare + resume. + # Weightless / non-model: skip only for an intact exact-named subset; None / glob defers to the child. if not _patterns_are_exact_names(allow_patterns): return False return _intact_subset( @@ -1444,16 +1354,15 @@ def _cache_can_skip_download( def _has_readable_weight( snapshot_dir: Path, *, allow_patterns: Any, ignore_patterns: Any, variant: Optional[str], ) -> bool: - """Invariant A (presence): a weight the in-process load will READ is present on disk, with the - request's ignore filter ALWAYS applied and the scope matched to the request: + """Invariant A (presence): a weight the in-process load will READ is present, ignore filter ALWAYS + applied and scope matched to the request: - - variant + UNPATTERNED -> a ROOT variant weight (``model.<variant>.*``); - - variant + PATTERNED -> a SELECTED variant weight (within the allow scope); + - variant + UNPATTERNED -> a ROOT variant weight; + - variant + PATTERNED -> a SELECTED variant weight; - plain + UNPATTERNED -> a ROOT (or diffusers-component) weight, NOT a stray subfolder checkpoint; - - plain + PATTERNED -> a SELECTED weight (within the allow scope). + - plain + PATTERNED -> a SELECTED weight. 
- A partial that kept only the ignored format (an ``*.bin`` under ``ignore=['*.bin']``) does not count, - so the incomplete result is retried over HTTP rather than loaded in-process.""" + A partial that kept only the ignored format does not count -> retried over HTTP.""" if variant: if allow_patterns is None: return _root_model_has_variant_weight( @@ -1473,45 +1382,37 @@ def _has_readable_weight( def _readable_shard_set_incomplete( snapshot_dir: Path, *, allow_patterns: Any, ignore_patterns: Any, variant: Optional[str], ) -> bool: - """Invariant B (shard completeness): an IN-SCOPE shard set the load reads is incomplete (an index - present with a shard missing, or a lone numbered shard without its index) and must be retried. The - check is ALWAYS scoped to what the request selects, so a co-resident stale shard set the load never - reads (a leftover root checkpoint under a subfolder/adapter/gguf request) does not false-reject a - complete download: - - - variant: the ROOT variant-shard check applies (for a NON-diffusers snapshot) for an UNPATTERNED - request, or a PATTERNED request that selects a ROOT variant weight (a globbed ``['*.safetensors']``); - a subfolder-scoped variant request does not root-check. - - plain: the canonical-root-shard check applies (for a NON-diffusers snapshot) for an UNPATTERNED - request, or a GLOBBED request that selects canonical root shards; an exact-named subset or an - out-of-scope request does not. - - non-root: a PATTERNED request additionally checks any SELECTED shard index the root-model checks do - not cover (a sharded adapter under ``['adapter_model*']``, a component subfolder) via - ``_selected_shard_index_incomplete``; an exact-named subset defers to the exact-file presence check. 
- - diffusers: a pipeline (root ``model_index.json``) reads COMPONENT subfolders (unet/, vae/, ...), NOT - root model shards, so the root-model checks above are SKIPPED for it (a stale root index must not - reject a complete pipeline); an UNPATTERNED warm's component shard sets are checked via - ``_diffusers_component_shards_incomplete``, and a PATTERNED one via ``_selected_shard_index_incomplete``. - - The ignore filter is threaded through so completeness is judged for the FORMAT the load reads (a - complete safetensors set does not mask an incomplete ``.bin`` under ``ignore=['*.safetensors']``).""" + """Invariant B (shard completeness): an IN-SCOPE shard set the load reads is incomplete (index present + with a shard missing, or a lone numbered shard without its index) and must be retried. ALWAYS scoped + to what the request selects, so a co-resident stale shard set the load never reads does not + false-reject a complete download: + + - variant: ROOT variant-shard check (NON-diffusers) for UNPATTERNED, or a PATTERNED request selecting + a ROOT variant weight; a subfolder-scoped variant request does not root-check. + - plain: canonical-root-shard check (NON-diffusers) for UNPATTERNED, or a GLOBBED request selecting + canonical root shards; an exact-named subset / out-of-scope request does not. + - non-root: a PATTERNED request also checks any SELECTED shard index the root checks miss (a sharded + adapter, a component subfolder) via ``_selected_shard_index_incomplete``. + - diffusers: a pipeline reads COMPONENT subfolders, NOT root model shards, so root checks are SKIPPED; + its component shard sets go through ``_diffusers_component_shards_incomplete`` (unpatterned) / + ``_selected_shard_index_incomplete`` (patterned). 
+ + The ignore filter is threaded through so completeness is judged for the FORMAT the load reads.""" try: is_diffusers = (snapshot_dir / "model_index.json").is_file() except OSError: is_diffusers = False if _patterns_are_exact_names(allow_patterns): - # An exact-named subset (variant or plain) defers to the exact-file presence check: the load - # reads exactly the named shard(s), so a lone exact variant shard is not judged against its - # (unrequested) index -- else a valid exact request is false-rejected into a DownloadStallError. + # An exact-named subset defers to the exact-file presence check: the load reads exactly the named + # shard(s), so a lone exact shard is not judged against its unrequested index (else false-reject). return False if variant: if not is_diffusers and ( allow_patterns is None or _request_selects_root_variant_weight(allow_patterns, ignore_patterns, variant) ): - # A diffusers pipeline reads component-subfolder variant weights, not root model.<variant> - # shards, so a stale root variant index must not reject a complete pipeline (handled below by - # the component check); only a non-diffusers root variant load runs the root-shard check. + # Only a non-diffusers root variant load runs the root-shard check (a diffusers pipeline's + # variant weights are component-scoped, handled below). if _has_incomplete_variant_root_shards( snapshot_dir, variant, ignore_patterns = ignore_patterns ): @@ -1525,29 +1426,29 @@ def _readable_shard_set_incomplete( elif is_diffusers and _diffusers_component_shards_incomplete( snapshot_dir, variant = variant, ignore_patterns = ignore_patterns ): - # an UNPATTERNED variant diffusers warm: a component subfolder's variant shard index is - # incomplete (the root variant check above only covers root model.<variant> shards). + # UNPATTERNED variant diffusers warm: a component variant shard index is incomplete (the root + # variant check above only covers root model.<variant> shards). 
return True return False if allow_patterns is None: if not is_diffusers and _has_incomplete_canonical_root_shards( snapshot_dir, ignore_patterns = ignore_patterns ): - # a non-diffusers root model load; a diffusers pipeline reads component subfolders, not root - # model shards, so a stale root index there is handled by the component check below. + # non-diffusers root model load (a diffusers stale root index is handled by the component + # check below). return True if is_diffusers and _diffusers_component_shards_incomplete( snapshot_dir, variant = None, ignore_patterns = ignore_patterns ): - # an UNPATTERNED plain diffusers warm reads component subfolders (unet/, vae/, ...); a - # component shard index missing a shard is not covered by the canonical ROOT-shard check. + # UNPATTERNED plain diffusers warm: a component shard index missing a shard is not covered by + # the canonical ROOT-shard check. return True return False if not is_diffusers and _request_selects_canonical_root_shards( allow_patterns, ignore_patterns ) and _has_incomplete_canonical_root_shards(snapshot_dir, ignore_patterns = ignore_patterns): # non-diffusers only: a diffusers pipeline never reads root model shards (its component sets are - # checked via _selected_shard_index_incomplete below), so a stale root index must not reject it. + # checked below), so a stale root index must not reject it. return True return _selected_shard_index_incomplete( snapshot_dir, allow_patterns = allow_patterns, @@ -1558,11 +1459,10 @@ def _readable_shard_set_incomplete( def _selected_readable_weight_complete( snapshot_dir: Path, *, allow_patterns: Any, ignore_patterns: Any, variant: Optional[str], ) -> bool: - """Single entry point for the weight-bearing MODEL acceptance check: the weight the in-process load - will READ is present (Invariant A) AND its in-scope shard set is complete (Invariant B). 
Both - invariants apply the request's ignore filter and match its scope uniformly, so a co-resident - out-of-scope / ignored-format partial neither masks an incomplete readable weight (a silent Xet hang) - nor false-rejects a complete download (a spurious ``DownloadStallError``).""" + """Weight-bearing MODEL acceptance check: the weight the load will READ is present (Invariant A) AND + its in-scope shard set is complete (Invariant B). Both apply the ignore filter and match scope + uniformly, so a co-resident out-of-scope / ignored-format partial neither masks an incomplete weight + (a silent Xet hang) nor false-rejects a complete download (a spurious ``DownloadStallError``).""" if not _has_readable_weight( snapshot_dir, allow_patterns = allow_patterns, ignore_patterns = ignore_patterns, variant = variant, @@ -1580,22 +1480,17 @@ def _download_result_usable( snapshot_dir: Path, *, repo_type: str, allow_patterns: Any, ignore_patterns: Any, variant: Optional[str] = None, ) -> bool: - """POST-download: whether the child's result is usable, or should be retried over HTTP. - snapshot_download already did the authoritative manifest compare, so accept unless there is - POSITIVE breakage evidence; LENIENT otherwise (a finished diffusers / either-format download passes, - an optional missing file is not treated as broken) so a good download is never looped into a - ``DownloadStallError``. A transient connection error during the child's metadata call makes - ``snapshot_download`` silently return an existing (stale / partial) cache instead of fetching, so - the checks below apply the request's filters to the weight the load will actually read. Breakage: + """POST-download: whether the child's result is usable or should be retried over HTTP. + snapshot_download already did the authoritative manifest compare, so accept unless there is POSITIVE + breakage evidence; LENIENT otherwise so a good download is never looped into a ``DownloadStallError``. 
+ A transient connection error during the child's metadata call makes ``snapshot_download`` silently + return a stale / partial cache, so the checks apply the request's filters to the weight the load + reads. Breakage: - A dangling REQUESTED symlink (a missing / still-``.incomplete`` blob). - - A missing EXACT-named requested file (grouped by weight equivalence: the either-format pair needs - one; base AND adapter, or a ``["tokenizer.json"]`` request, each). Globs stay lenient. - - A weight-bearing MODEL request whose READABLE weight is absent or incomplete. Delegated to - ``_selected_readable_weight_complete``, which applies the request's ignore filter and scope - uniformly: the weight the load reads (variant vs canonical, root vs in-scope) must be present, and - its in-scope shard set complete. A co-resident out-of-scope / ignored-format partial neither masks - an incomplete readable weight nor false-rejects a complete download.""" + - A missing EXACT-named requested file (grouped by weight equivalence; globs stay lenient). + - A weight-bearing MODEL request whose READABLE weight is absent or incomplete (delegated to + ``_selected_readable_weight_complete``).""" if snapshot_has_requested_broken_symlinks( snapshot_dir, allow_patterns = allow_patterns, ignore_patterns = ignore_patterns, repo_type = repo_type, @@ -1620,7 +1515,7 @@ def _snapshot_payload_incomplete( ) -> bool: """True when a snapshot download returned a real directory not usable for the request (see ``_download_result_usable``). 
Guarded to an existing dir, so a mocked / non-path payload (tests) is - trusted rather than rejected; in production the child always returns a real snapshot dir.""" + trusted; production always returns a real snapshot dir.""" try: path = Path(payload) except (TypeError, ValueError, OSError): @@ -1657,20 +1552,17 @@ def _download_with_xet_fallback( raise RuntimeError("Cancelled") cache_dir = params.get("cache_dir") - # The Unsloth/HF knobs can force HTTP from the very first attempt. xet_force_disabled() reads - # os.environ["HF_HUB_DISABLE_XET"] live, and a CONCURRENT download briefly sets that var in the - # parent env around its spawn (under _SPAWN_ENV_LOCK) so its child inherits it. Read under the - # same lock so this download cannot observe the other's child-only value and wrongly force itself - # onto HTTP from the start. + # Read xet_force_disabled() under the lock a CONCURRENT download briefly sets HF_HUB_DISABLE_XET + # under (around its spawn), so this download cannot observe the other's child-only value and wrongly + # force itself onto HTTP from the start. with _SPAWN_ENV_LOCK: disable_xet = xet_force_disabled() for attempt in range(2): if disable_xet: - # Purge a non-HTTP partial first: an HTTP resume over a sparse Xet/hf_transfer partial - # silently corrupts the blob. Scope the purge to the partials the stalled child owned, so - # a concurrent same-repo sibling's partial is spared. An injected (Studio) hook owns its - # own cache accounting, so it keeps the plain (repo_type, repo_id) signature. + # Purge a non-HTTP partial first (an HTTP resume over a sparse Xet/hf_transfer partial + # silently corrupts the blob), scoped to the stalled child's own partials so a same-repo + # sibling is spared. An injected (Studio) hook keeps the plain (repo_type, repo_id) signature. 
owned_incomplete = params.pop("_owned_incomplete_blobs", None) try: if prepare_for_http_fn is None: @@ -1682,8 +1574,8 @@ def _download_with_xet_fallback( prepare_for_http_fn(repo_type, repo_id) except Exception as e: logger.debug("prepare_for_http failed for %s: %s", repo_id, e) - # An unsafe partial that could not be cleared (locked file, permission error) would - # corrupt the blob on an HTTP resume: force a clean re-download instead. + # An unsafe partial that could not be cleared (locked / permission) would corrupt the blob on + # an HTTP resume: force a clean re-download instead. if has_active_incomplete_blobs(repo_type, repo_id, cache_dir = cache_dir): logger.warning( "Unsafe partial for '%s' could not be cleared; forcing a clean " @@ -1713,10 +1605,10 @@ def _download_with_xet_fallback( ignore_patterns = params.get("ignore_patterns"), variant = variant, ): - # HF can hand back an existing incomplete snapshot dir (offline / timed-out request) - # instead of fetching the missing files. Never load that in-process: retry over HTTP, - # then fail loudly rather than load a broken cache. (Patterned / non-model requests are - # judged by their own subset, so a valid weightless snapshot is not rejected.) + # HF can hand back an existing incomplete snapshot dir (offline / timed-out) instead of + # fetching: never load it in-process. Retry over HTTP, then fail loudly. (Patterned / + # non-model requests judge their own subset, so a valid weightless snapshot is not + # rejected.) if not disable_xet: logger.warning( "Download for '%s' returned an incomplete snapshot -- " @@ -1734,12 +1626,10 @@ def _download_with_xet_fallback( raise RuntimeError("Cancelled") if kind_result == "error": # Deterministic failure (auth / not-found / gated / disk-full): the other transport fails - # identically, so do not retry. _raise_child_error preserves the original exception type - # across the spawn boundary so callers' typed except clauses still match. + # identically. 
_raise_child_error preserves the original type across the spawn. _raise_child_error(payload) if kind_result == "retryable_error": - # Transient transport failure (hf_xet CAS timeout, 5xx, reset): HTTP may recover, so retry - # once before surfacing it; if HTTP also failed there is no transport left -> raise. + # Transient transport failure (hf_xet CAS timeout, 5xx, reset): retry HTTP once, else raise. if not disable_xet: logger.warning( "Download for '%s' hit a transient Xet transport error -- retrying " @@ -1765,8 +1655,7 @@ def _download_with_xet_fallback( logger.warning( "Download stalled for '%s' -- retrying with HF_HUB_DISABLE_XET=1", label ) - # _safe_status: a raising status hook (disconnected client) must not abort the retry - # before disable_xet is set, turning a recoverable stall into a failed download. + # _safe_status: a raising status hook must not abort the retry before disable_xet is set. _safe_status(on_status, f"{label}: Xet stalled, retrying over HTTP") disable_xet = True continue @@ -1797,24 +1686,20 @@ def hf_hub_download_with_xet_fallback( on_status: Optional[Callable[[str], None]] = None, prepare_for_http_fn: Optional[Callable[[str, str], None]] = None, ) -> str: - """Download a single file with Xet primary and HTTP as a stall-only fallback. + """Download a single file with Xet primary and HTTP as a stall-only fallback; return the local path. - Returns the local cache path. Raises ``RuntimeError("Cancelled")`` if *cancel_event* is set, - re-raises a deterministic child error unchanged (no fallback), and raises ``DownloadStallError`` - only if BOTH transports stall. ``force_download=True`` re-fetches even if cached; - ``local_files_only=True`` resolves from cache in-process with no child (HF offline semantics); - ``subfolder`` is forwarded to ``hf_hub_download``. + Raises ``RuntimeError("Cancelled")`` if *cancel_event* is set, re-raises a deterministic child error + unchanged, and raises ``DownloadStallError`` only if BOTH transports stall. 
``local_files_only=True`` + resolves from cache in-process with no child (HF offline semantics). """ repo_type = repo_type or "model" # HF treats None as the default model repo. - # Expand ~ (and normalize Path) as huggingface_hub does, so the probe and the child resolve to - # the same on-disk location (else a warm cache is missed and we spawn a child for a cached file). + # Expand ~ as huggingface_hub does, so the probe and the child resolve to the same location. if isinstance(cache_dir, (str, os.PathLike)): cache_dir = os.path.expanduser(os.fspath(cache_dir)) - # Honor an already-set cancellation before any probe: the short-circuits below return without - # reaching _download_with_xet_fallback (which holds the only other cancel check). + # Honor cancellation before any probe (the short-circuits below bypass the fallback's cancel check). if cancel_event is not None and cancel_event.is_set(): raise RuntimeError("Cancelled") - # Offline: resolve purely from cache. HF raises LocalEntryNotFoundError if uncached; let it propagate. + # Offline: resolve from cache. HF raises LocalEntryNotFoundError if uncached; let it propagate. if local_files_only: from huggingface_hub import hf_hub_download @@ -1828,8 +1713,8 @@ def hf_hub_download_with_xet_fallback( cache_dir = cache_dir, local_files_only = True, ) - # Finalized blob already cached: return it with no child and no network (skipped under - # force_download). The cache stores a subfolder file under "/". + # Finalized blob already cached: return it with no child (skipped under force_download). A subfolder + # file is cached under "/". 
if not force_download: try: from huggingface_hub import try_to_load_from_cache @@ -1885,25 +1770,23 @@ def snapshot_download_with_xet_fallback( on_status: Optional[Callable[[str], None]] = None, prepare_for_http_fn: Optional[Callable[[str, str], None]] = None, ) -> str: - """Download a whole repo snapshot with Xet primary and HTTP as a stall-only fallback, returning - the local snapshot dir. + """Download a whole repo snapshot with Xet primary and HTTP as a stall-only fallback; return the + local snapshot dir. Used by Unsloth's ``from_pretrained`` to warm the cache in a killable child BEFORE the in-process - model load (which then hits a warm cache and cannot hang on a native Xet thread). A fully cached - repo short-circuits in-process via ``local_files_only`` with no child. ``force_download=True`` - re-fetches in the killable child even if cached; ``local_files_only=True`` resolves from cache - in-process with no child (HF offline semantics). ``variant`` (e.g. "fp16") forces the child even - on a warm canonical cache, since the canonical gate cannot prove the variant-named weights present. + load (which then hits a warm cache and cannot hang on a native Xet thread). A fully cached repo + short-circuits in-process with no child. ``local_files_only=True`` resolves from cache in-process + (HF offline semantics). ``variant`` forces the child even on a warm canonical cache, since the + canonical gate cannot prove the variant-named weights present. """ repo_type = repo_type or "model" # HF treats None as the default model repo. - # Expand ~ as huggingface_hub does, so the probe and the child resolve to the same cache location. + # Expand ~ as huggingface_hub does, so the probe and the child resolve to the same location. 
if isinstance(cache_dir, (str, os.PathLike)): cache_dir = os.path.expanduser(os.fspath(cache_dir)) - # Honor an already-set cancellation before any probe: the short-circuits below return without - # reaching _download_with_xet_fallback (which holds the only other cancel check). + # Honor cancellation before any probe (the short-circuits below bypass the fallback's cancel check). if cancel_event is not None and cancel_event.is_set(): raise RuntimeError("Cancelled") - # Offline: resolve purely from cache. HF raises if uncached; let it propagate. + # Offline: resolve from cache. HF raises if uncached; let it propagate. if local_files_only: from huggingface_hub import snapshot_download @@ -1916,8 +1799,7 @@ def snapshot_download_with_xet_fallback( ignore_patterns = ignore_patterns, local_files_only = True, ) - # Fast path: everything already on disk -> resolve in-process (no Xet, no hang). Skipped under - # force_download. + # Fast path: everything on disk -> resolve in-process (no Xet, no hang). Skipped under force_download. if not force_download: try: from huggingface_hub import snapshot_download @@ -1931,12 +1813,10 @@ def snapshot_download_with_xet_fallback( ignore_patterns = ignore_patterns, local_files_only = True, ) - # local_files_only returns a snapshot dir whenever refs/ + snapshots/ exist, - # even one left by a prior interrupted or patterned download (config-only, partial shards). - # Validate the EXACT returned revision dir: a full model warmup skips the child only when - # its canonical weights are provably complete; a patterned / non-model request only needs - # its referenced files. Scope to this snapshot, NOT the whole repo, so an unrelated - # revision mid-download elsewhere in the repo cache does not force a needless re-fetch. + # local_files_only returns a snapshot dir whenever refs/ + snapshots/ exist, even + # one left by a prior interrupted / patterned download. 
Validate the EXACT returned revision + # dir (scoped to this snapshot, not the whole repo, so an unrelated revision mid-download does + # not force a needless re-fetch). if _cache_can_skip_download( Path(cached_dir), repo_type = repo_type, From 94d30ff00ebe03d47c0407e69a6a1f34c1f85df6 Mon Sep 17 00:00:00 2001 From: danielhanchen Date: Fri, 3 Jul 2026 13:39:52 +0000 Subject: [PATCH 79/82] Reject a diffusers pipeline warm missing a declared component Consolidated review pass (12x reviewer.py personas + 3 forked reviewers). 8 of 12 reviewers independently reproduced a post-download false-accept: _has_diffusers_component_weight returned True as soon as ANY declared component held a weight, so a stale partial pipeline (model_index.json declaring unet + vae with only unet present, or vae present with its config but no weight) was accepted -- then the in-process DiffusionPipeline load fetched the missing component over un-killable Xet, the exact hang this fallback prevents. Diffusion is a real loader consumer (diffusion.py warms with allow_patterns=None), so it is reachable. An unpatterned diffusers warm now requires that every DECLARED ACTIVE component (a [library, class] spec with both non-null) is materialised as a non-empty subfolder AND that each model-style component (one carrying config.json) holds a readable weight of the read format. A [null, null] disabled component (e.g. safety_checker) and weightless components (scheduler / tokenizer / feature_extractor, which carry *_config.json, not config.json) are not required, so a complete or safety-checker-disabled pipeline is not false-rejected into a DownloadStallError. A variant load accepts a component's canonical weight as diffusers' per-component fallback, so a mixed variant/canonical pipeline still passes. A malformed / empty model_index.json fails OPEN to the prior lenient any-weight check. 
Adds a regression over component-absent, config-only, complete-with-null-safety- checker, variant-absent, and mixed-variant cases. Full suite 192 passed; safety- invariant fuzz stays at 0 false-accepts. Not taken (re-verified, no new evidence): - Injected prepare_for_http_fn dropping owned_incomplete_blobs: by design; a Studio hook owns its marker-based cache accounting (documented at the call site). - Retryable / crashed-child HTTP prep using the coarse mtime guard: an exited child has no open fds to scope owned partials; inherent, not closable. - Missing hf_hub kwargs / non-import-light default: a purpose-built prefetch helper (local_dir breaks the cache model) and an intentional opt-in (UNSLOTH_ZOO_DISABLE_GPU_INIT), not a general hf_hub_download replacement. - Patterned TF/Flax sharded-subfolder edge: documented low-reachability; PRE never accepts tf/flax, so it can only ever cause a safe retry. --- tests/test_hf_xet_fallback.py | 84 ++++++++++++++++++++ unsloth_zoo/hf_cache_state.py | 26 +++--- unsloth_zoo/hf_xet_fallback.py | 141 ++++++++++++++++++++------------- 3 files changed, 187 insertions(+), 64 deletions(-) diff --git a/tests/test_hf_xet_fallback.py b/tests/test_hf_xet_fallback.py index b745c85ee..0f338a24a 100644 --- a/tests/test_hf_xet_fallback.py +++ b/tests/test_hf_xet_fallback.py @@ -2890,6 +2890,90 @@ def test_post_download_diffusers_variant_presence_scoped_to_declared(tmp_path): variant = "fp16") is True +def test_post_download_rejects_diffusers_missing_declared_component(tmp_path): + """A declared weight-bearing component absent (or holding only its config) is retried over HTTP, not + accepted -- else the in-process pipeline load fetches the missing component over un-killable Xet. 
A + ``[null, null]`` (disabled) component and weightless components (scheduler / tokenizer) are not + required.""" + def _mi(): + return json.dumps({ + "_class_name": "StableDiffusionPipeline", + "unet": ["diffusers", "UNet2DConditionModel"], + "vae": ["diffusers", "AutoencoderKL"], + "text_encoder": ["transformers", "CLIPTextModel"], + "scheduler": ["diffusers", "PNDMScheduler"], + "tokenizer": ["transformers", "CLIPTokenizer"], + "safety_checker": [None, None], + }) + + def _model_comp(root, name, blob, *, weight = True, variant = None): + d = root / name + d.mkdir() + (d / "config.json").write_text("{}") + if weight: + w = "diffusion_pytorch_model.safetensors" if variant is None \ + else f"diffusion_pytorch_model.{variant}.safetensors" + (d / w).symlink_to(blob) + + def _weightless(root, name): + d = root / name + d.mkdir() + # scheduler_config.json / tokenizer_config.json -- NOT a plain config.json, so no weight required + (d / f"{name}_config.json").write_text("{}") + + # unet present, vae ABSENT (text_encoder present) -> reject. + snap, blob = _mk_snapshot(tmp_path, "diff_missing_vae") + (snap / "model_index.json").write_text(_mi()) + _model_comp(snap, "unet", blob) + _model_comp(snap, "text_encoder", blob) + _weightless(snap, "scheduler") + _weightless(snap, "tokenizer") + assert xf._download_result_usable( + snap, repo_type = "model", allow_patterns = None, ignore_patterns = None) is False + + # vae present with config ONLY (weight missing) -> reject. 
+ snap2, blob2 = _mk_snapshot(tmp_path, "diff_vae_config_only") + (snap2 / "model_index.json").write_text(_mi()) + _model_comp(snap2, "unet", blob2) + _model_comp(snap2, "text_encoder", blob2) + _model_comp(snap2, "vae", blob2, weight = False) # config, no weight + _weightless(snap2, "scheduler") + _weightless(snap2, "tokenizer") + assert xf._download_result_usable( + snap2, repo_type = "model", allow_patterns = None, ignore_patterns = None) is False + + # Every weight-bearing component complete; safety_checker [null,null] absent -> accept (no false-reject). + snap3, blob3 = _mk_snapshot(tmp_path, "diff_complete") + (snap3 / "model_index.json").write_text(_mi()) + for c in ("unet", "vae", "text_encoder"): + _model_comp(snap3, c, blob3) + _weightless(snap3, "scheduler") + _weightless(snap3, "tokenizer") + assert xf._download_result_usable( + snap3, repo_type = "model", allow_patterns = None, ignore_patterns = None) is True + + # Variant: vae absent -> reject; a mixed pipeline (unet fp16, vae/text_encoder canonical fallback) -> accept. 
+ snap4, blob4 = _mk_snapshot(tmp_path, "diff_variant_missing_vae") + (snap4 / "model_index.json").write_text(_mi()) + _model_comp(snap4, "unet", blob4, variant = "fp16") + _model_comp(snap4, "text_encoder", blob4, variant = "fp16") + _weightless(snap4, "scheduler") + _weightless(snap4, "tokenizer") + assert xf._download_result_usable( + snap4, repo_type = "model", allow_patterns = None, ignore_patterns = None, + variant = "fp16") is False + snap5, blob5 = _mk_snapshot(tmp_path, "diff_variant_mixed") + (snap5 / "model_index.json").write_text(_mi()) + _model_comp(snap5, "unet", blob5, variant = "fp16") + _model_comp(snap5, "vae", blob5) # canonical fallback for this component + _model_comp(snap5, "text_encoder", blob5) # canonical fallback + _weightless(snap5, "scheduler") + _weightless(snap5, "tokenizer") + assert xf._download_result_usable( + snap5, repo_type = "model", allow_patterns = None, ignore_patterns = None, + variant = "fp16") is True + + def test_post_download_single_variant_beats_stale_variant_index(tmp_path): """Variant twin of single-beats-index: a complete single variant weight beside a stale variant index is usable (ST and bin); a stale index with no single weight is breakage.""" snap, blob = _mk_snapshot(tmp_path, "single_variant_beats_index") diff --git a/unsloth_zoo/hf_cache_state.py b/unsloth_zoo/hf_cache_state.py index e5c35b04c..548068442 100644 --- a/unsloth_zoo/hf_cache_state.py +++ b/unsloth_zoo/hf_cache_state.py @@ -867,12 +867,9 @@ def _selected_shard_index_incomplete( _CHECKPOINT_DIR_RE = re.compile(r"^checkpoint[-_]\d+$") -def _diffusers_declared_components(snapshot_dir: Path) -> "Optional[set]": - """The component subfolder names a diffusers ``model_index.json`` declares (top-level keys mapping to - a ``[library, class]`` list; ``_``-prefixed metadata excluded). None when absent / unreadable / - malformed, so the caller falls back to every subfolder (fail OPEN, preserving hang protection). 
- Scopes the component check to what the pipeline reads, so a stale UNDECLARED subtree cannot - force-fail a complete pipeline download.""" +def _diffusers_declared_component_specs(snapshot_dir: Path) -> "Optional[dict]": + """name -> declared ``[library, class]`` spec from a diffusers ``model_index.json`` (``_``-prefixed + metadata excluded). None when absent / unreadable / malformed / empty, so the caller fails OPEN.""" import json try: @@ -882,13 +879,20 @@ def _diffusers_declared_components(snapshot_dir: Path) -> "Optional[set]": return None if not isinstance(data, dict): return None - components = { - key for key, value in data.items() + specs = { + key: value for key, value in data.items() if not key.startswith("_") and isinstance(value, (list, tuple)) } - # An empty / all-metadata model_index.json is degenerate -> fail OPEN (None) so the caller checks - # every subfolder, preserving hang protection. - return components or None + # An empty / all-metadata model_index.json is degenerate -> fail OPEN (None). + return specs or None + + +def _diffusers_declared_components(snapshot_dir: Path) -> "Optional[set]": + """The component subfolder names a diffusers ``model_index.json`` declares. None when absent / + unreadable / malformed so the caller falls back to every subfolder (fail OPEN). 
Scopes the component + check to what the pipeline reads, so a stale UNDECLARED subtree cannot force-fail a complete download.""" + specs = _diffusers_declared_component_specs(snapshot_dir) + return set(specs) if specs else None def _diffusers_component_shards_incomplete( diff --git a/unsloth_zoo/hf_xet_fallback.py b/unsloth_zoo/hf_xet_fallback.py index c13d5aadd..8decbb166 100644 --- a/unsloth_zoo/hf_xet_fallback.py +++ b/unsloth_zoo/hf_xet_fallback.py @@ -41,7 +41,7 @@ _ROOT_MODEL_VARIANT_WEIGHT_RE, _as_pattern_list, _diffusers_component_shards_incomplete, - _diffusers_declared_components, + _diffusers_declared_component_specs, _filter_paths, _has_glob, _has_incomplete_canonical_root_shards, @@ -1015,24 +1015,44 @@ def _pytorch_root_weight_formats_ignored(ignore_patterns: Any) -> bool: _CHECKPOINT_DIR_RE = re.compile(r"^checkpoint[-_]\d+$") -def _has_diffusers_component_weight(snapshot_dir: Path, *, ignore_patterns: Any = None) -> bool: - """True if a DECLARED diffusers pipeline COMPONENT weight (in a ``model_index.json``-declared - subfolder: unet/, vae/, ...) that the ignore filter keeps is present. Scoped to declared components so - a stale UNDECLARED leftover subtree (a controlnet/ dir not declared) does not read as a warm pipeline - while the declared unet / vae weights are missing (which the load would fetch over un-killable Xet). - Excludes ROOT-level weights and training-checkpoint subtrees. A malformed / empty ``model_index.json`` - fails OPEN. Lenient on WHICH declared components are required (they can be optional) -- only - distinguishes a real component warm from an undeclared-leftover / checkpoint-only / config-only stale - snapshot. 
Counts only CANONICAL (non-variant) weights: a variant-only stale cache is retried over HTTP - (its non-variant component weight is still missing).""" - declared = _diffusers_declared_components(snapshot_dir) - rels: list = [] +def _diffusers_active_component_dirs(specs: dict) -> set: + """Declared components a pipeline actually loads: spec is a ``[library, class]`` pair with both + non-null. A ``[null, null]`` (a disabled / optional component such as safety_checker) is excluded -- + the load skips it, so it is not required to be present.""" + active: set = set() + for name, spec in specs.items(): + if ( + isinstance(spec, (list, tuple)) and len(spec) >= 2 + and spec[0] is not None and spec[1] is not None + ): + active.add(name) + return active + + +def _diffusers_component_weights_complete( + snapshot_dir: Path, *, variant: Optional[str], ignore_patterns: Any = None, +) -> bool: + """True when a diffusers pipeline warm holds every weight a plain / variant load reads. Beyond "some + declared component weight is present" it requires each DECLARED ACTIVE component to be materialised (a + non-empty subfolder) AND each model-style component (one carrying ``config.json``) to hold a readable + weight of the read format -- so a stale partial missing a whole component (unet present, vae absent) or + holding a component's config without its weight is retried over un-killable Xet, not loaded. Excludes + ROOT-level weights and training-checkpoint subtrees; applies the ignore filter (format the load reads). + Fails OPEN on a malformed / empty ``model_index.json`` to the lenient any-component-weight check, + preserving hang protection without false-rejecting. A variant load accepts a component's canonical + weight as diffusers' per-component fallback.""" + specs = _diffusers_declared_component_specs(snapshot_dir) + declared = set(specs) if specs else None + active = _diffusers_active_component_dirs(specs) if specs else None + infix_dot = f".{variant}." 
if variant else "" + infix_dash = f".{variant}-" if variant else "" + per_comp_canon: dict = {} + per_comp_variant: dict = {} try: for entry in snapshot_dir.rglob("*"): - if not _is_default_load_weight_file(entry.name): + name = entry.name + if not _is_default_load_weight_file(name): continue - if not _CANONICAL_COMPONENT_WEIGHT_RE.match(entry.name): - continue # a VARIANT weight -- a plain load reads the non-variant name try: if not entry.is_file(): continue @@ -1042,14 +1062,56 @@ def _has_diffusers_component_weight(snapshot_dir: Path, *, ignore_patterns: Any parts = rel.split("/") if len(parts) < 2: continue # a ROOT-level weight is not a component - if declared is not None and parts[0] not in declared: + comp = parts[0] + if declared is not None and comp not in declared: continue # an UNDECLARED subtree the load does not read if any(_CHECKPOINT_DIR_RE.match(p) for p in parts[:-1]): continue # a training-checkpoint subtree, not a component - rels.append(rel) + if variant and (infix_dot in name or infix_dash in name): + per_comp_variant.setdefault(comp, []).append(rel) + elif _CANONICAL_COMPONENT_WEIGHT_RE.match(name): + per_comp_canon.setdefault(comp, []).append(rel) except OSError: return False - return bool(_filter_paths(rels, None, ignore_patterns)) + + def _has_canon(comp: str) -> bool: + return bool(_filter_paths(per_comp_canon.get(comp, []), None, ignore_patterns)) + + def _has_variant(comp: str) -> bool: + return bool(_filter_paths(per_comp_variant.get(comp, []), None, ignore_patterns)) + + def _has_read_weight(comp: str) -> bool: + # variant load falls back to a component's canonical weight when it ships no variant file + return _has_variant(comp) or _has_canon(comp) if variant else _has_canon(comp) + + if active is not None: + for comp in active: + comp_dir = snapshot_dir / comp + try: + present = comp_dir.is_dir() and any(comp_dir.iterdir()) + except OSError: + present = False + if not present: + return False # a declared active component was never 
materialised + try: + has_config = (comp_dir / "config.json").is_file() + except OSError: + has_config = False + if has_config and not _has_read_weight(comp): + return False # a model-style component holds its config but no readable weight + # Floor: at least one component holds a weight of the READ format -- rejects a variant-only-for-plain, + # config-only, checkpoint-only, or undeclared-leftover-only stale snapshot. + if variant: + return any(_has_variant(c) for c in (declared or per_comp_variant)) + return any(_has_canon(c) for c in (declared or per_comp_canon)) + + +def _has_diffusers_component_weight(snapshot_dir: Path, *, ignore_patterns: Any = None) -> bool: + """Whether a PLAIN (non-variant) diffusers pipeline warm is complete for the weights a load reads + (see ``_diffusers_component_weights_complete``).""" + return _diffusers_component_weights_complete( + snapshot_dir, variant = None, ignore_patterns = ignore_patterns + ) def _root_model_has_weight(snapshot_dir: Path, *, ignore_patterns: Any = None) -> bool: @@ -1133,40 +1195,13 @@ def _root_has_variant_weight( def _has_diffusers_component_variant_weight( snapshot_dir: Path, variant: str, *, ignore_patterns: Any = None ) -> bool: - """Variant analog of ``_has_diffusers_component_weight``: True if a DECLARED component subfolder holds - a weight with the requested *variant* token. A variant pipeline's weights are component-scoped, not - root ``model..*``, so a root-only check would false-reject a complete diffusers variant - download. Scoped to declared components (an undeclared leftover does not read as warm), fails OPEN on a - malformed ``model_index.json``, excludes ROOT-level / training-checkpoint weights, reads only - safetensors / bin.""" - declared = _diffusers_declared_components(snapshot_dir) - infix_dot = f".{variant}." 
- infix_dash = f".{variant}-" - rels: list = [] - try: - for entry in snapshot_dir.rglob("*"): - name = entry.name - if not _is_default_load_weight_file(name): - continue - if infix_dot not in name and infix_dash not in name: - continue - try: - if not entry.is_file(): - continue - rel = entry.relative_to(snapshot_dir).as_posix() - except (OSError, ValueError): - continue - parts = rel.split("/") - if len(parts) < 2: - continue # a ROOT-level variant weight is not a component - if declared is not None and parts[0] not in declared: - continue # an UNDECLARED subtree the load does not read - if any(_CHECKPOINT_DIR_RE.match(p) for p in parts[:-1]): - continue # a training-checkpoint subtree, not a component - rels.append(rel) - except OSError: - return False - return bool(_filter_paths(rels, None, ignore_patterns)) + """Variant analog of ``_has_diffusers_component_weight``: whether a diffusers *variant* pipeline warm + is complete (a pipeline's variant weights are component-scoped, not root ``model.<variant>.*``, so a + root-only check would false-reject a complete variant download). See + ``_diffusers_component_weights_complete``.""" + return _diffusers_component_weights_complete( + snapshot_dir, variant = variant, ignore_patterns = ignore_patterns + ) def _root_model_has_variant_weight( From 8accd3bb0f654b0704c70bb5e0a05cc325e416c7 Mon Sep 17 00:00:00 2001 From: danielhanchen Date: Sat, 4 Jul 2026 01:25:17 +0000 Subject: [PATCH 80/82] Reject a diffusers component dir holding neither config nor weight Follow-up review pass (12x reviewer.py personas + 3 forks) on the diffusers component-completeness gate.
A residual false-accept remained: an active component dir that is non-empty but carries NEITHER a top-level config.json NOR a readable weight NOR a weightless sidecar (only a stray file or a nested training-checkpoint) passed the presence rule and skipped the config-gated weight rule, so a stale partial mid-download component was accepted and the in-process pipeline load then fetched its weight over un-killable Xet. A component now satisfies the warm only if it holds a readable weight of the read format, OR is WEIGHTLESS-shaped (a scheduler / tokenizer / feature_extractor sidecar: a *_config.json / tokenizer file and no config.json). A dir with neither is an incomplete model component and is retried over HTTP. Complete pipelines (including SDXL dual text encoders, sharded components, fp16 mixed / canonical fallback) still accept; the safety-invariant fuzz stays at 0 false-accepts and the suite is 193 passed. Not changed (re-verified this pass): the variant floor requiring at least one variant weight is correct -- diffusers itself raises on a zero-variant repo, so a canonical-only request for a variant is not a load diffusers would complete. The recurring reviewer.py "stale-branch revert" flag is a harness artifact (GitHub reports the PR mergeable / clean). 
--- tests/test_hf_xet_fallback.py | 52 ++++++++++++++++++++++++++++++++++ unsloth_zoo/hf_xet_fallback.py | 41 +++++++++++++++++++-------- 2 files changed, 81 insertions(+), 12 deletions(-) diff --git a/tests/test_hf_xet_fallback.py b/tests/test_hf_xet_fallback.py index 0f338a24a..a6cad0b9a 100644 --- a/tests/test_hf_xet_fallback.py +++ b/tests/test_hf_xet_fallback.py @@ -2974,6 +2974,58 @@ def _weightless(root, name): variant = "fp16") is True +def test_post_download_rejects_diffusers_component_dir_without_weight_or_config(tmp_path): + """A declared active component dir that is non-empty but holds NEITHER a config.json NOR a readable + weight NOR a weightless sidecar (only a stray file or a nested checkpoint) is an incomplete model + component -> retried over HTTP, not accepted. A weightless component (its *_config.json sidecar) still + passes without a weight.""" + def _mi(): + return json.dumps({ + "_class_name": "StableDiffusionPipeline", + "unet": ["diffusers", "UNet2DConditionModel"], + "vae": ["diffusers", "AutoencoderKL"], + "scheduler": ["diffusers", "PNDMScheduler"], + "tokenizer": ["transformers", "CLIPTokenizer"], + }) + + def _base(root, name, blob): + (root / "model_index.json").write_text(_mi()) + d = root / "unet" + d.mkdir() + (d / "config.json").write_text("{}") + (d / "diffusion_pytorch_model.safetensors").symlink_to(blob) + for n in ("scheduler", "tokenizer"): + wl = root / n + wl.mkdir() + (wl / f"{n}_config.json").write_text("{}") + + # vae dir holds only a nested training-checkpoint weight (no top-level config / weight) -> reject. + snap, blob = _mk_snapshot(tmp_path, "diff_comp_nested_ckpt") + _base(snap, "unet", blob) + ckpt = snap / "vae" / "checkpoint-500" + ckpt.mkdir(parents = True) + (ckpt / "diffusion_pytorch_model.safetensors").symlink_to(blob) + assert xf._download_result_usable( + snap, repo_type = "model", allow_patterns = None, ignore_patterns = None) is False + + # vae dir holds only a stray non-weight file -> reject. 
+ snap2, blob2 = _mk_snapshot(tmp_path, "diff_comp_stray") + _base(snap2, "unet", blob2) + (snap2 / "vae").mkdir() + (snap2 / "vae" / "README.md").write_text("x") + assert xf._download_result_usable( + snap2, repo_type = "model", allow_patterns = None, ignore_patterns = None) is False + + # Complete: vae holds config + weight -> accept (weightless scheduler/tokenizer need no weight). + snap3, blob3 = _mk_snapshot(tmp_path, "diff_comp_complete") + _base(snap3, "unet", blob3) + (snap3 / "vae").mkdir() + (snap3 / "vae" / "config.json").write_text("{}") + (snap3 / "vae" / "diffusion_pytorch_model.safetensors").symlink_to(blob3) + assert xf._download_result_usable( + snap3, repo_type = "model", allow_patterns = None, ignore_patterns = None) is True + + def test_post_download_single_variant_beats_stale_variant_index(tmp_path): """Variant twin of single-beats-index: a complete single variant weight beside a stale variant index is usable (ST and bin); a stale index with no single weight is breakage.""" snap, blob = _mk_snapshot(tmp_path, "single_variant_beats_index") diff --git a/unsloth_zoo/hf_xet_fallback.py b/unsloth_zoo/hf_xet_fallback.py index 8decbb166..594816119 100644 --- a/unsloth_zoo/hf_xet_fallback.py +++ b/unsloth_zoo/hf_xet_fallback.py @@ -1029,14 +1029,29 @@ def _diffusers_active_component_dirs(specs: dict) -> set: return active +def _is_weightless_component(names: list) -> bool: + """Whether a component dir (given its top-level file names) is a WEIGHTLESS component -- + scheduler / tokenizer / feature_extractor, which ship a ``*_config.json`` / tokenizer sidecar and no + model weight. 
Lets such a component satisfy the warm without a weight, while a dir carrying neither a + ``config.json`` nor a readable weight nor a recognised sidecar reads as an incomplete model component.""" + return any( + n.endswith("_config.json") + or n.startswith("tokenizer") + or n in ("preprocessor_config.json", "special_tokens_map.json", "merges.txt", "vocab.json") + for n in names + ) + + def _diffusers_component_weights_complete( snapshot_dir: Path, *, variant: Optional[str], ignore_patterns: Any = None, ) -> bool: """True when a diffusers pipeline warm holds every weight a plain / variant load reads. Beyond "some declared component weight is present" it requires each DECLARED ACTIVE component to be materialised (a - non-empty subfolder) AND each model-style component (one carrying ``config.json``) to hold a readable - weight of the read format -- so a stale partial missing a whole component (unet present, vae absent) or - holding a component's config without its weight is retried over un-killable Xet, not loaded. Excludes + non-empty subfolder) AND a model-style component to hold a readable weight of the read format -- a + component is model-style unless it is WEIGHTLESS-shaped (a scheduler / tokenizer / feature_extractor + sidecar, no config.json + no weight). So a stale partial missing a whole component (unet present, vae + absent), holding a component's config without its weight, or holding only a stray / nested-checkpoint + file in a component dir is retried over un-killable Xet, not loaded. Excludes ROOT-level weights and training-checkpoint subtrees; applies the ignore filter (format the load reads). Fails OPEN on a malformed / empty ``model_index.json`` to the lenient any-component-weight check, preserving hang protection without false-rejecting. 
A variant load accepts a component's canonical @@ -1088,17 +1103,19 @@ def _has_read_weight(comp: str) -> bool: for comp in active: comp_dir = snapshot_dir / comp try: - present = comp_dir.is_dir() and any(comp_dir.iterdir()) + names = [e.name for e in comp_dir.iterdir()] if comp_dir.is_dir() else [] except OSError: - present = False - if not present: + names = [] + if not names: return False # a declared active component was never materialised - try: - has_config = (comp_dir / "config.json").is_file() - except OSError: - has_config = False - if has_config and not _has_read_weight(comp): - return False # a model-style component holds its config but no readable weight + if "config.json" in names: # a model-style component MUST hold a readable weight + if not _has_read_weight(comp): + return False + elif not _has_read_weight(comp) and not _is_weightless_component(names): + # No config.json and no readable weight: OK only when the dir is weightless-SHAPED + # (scheduler / tokenizer / feature_extractor sidecars). Otherwise it is an incomplete + # model-style component whose weight the load would fetch over un-killable Xet. + return False # Floor: at least one component holds a weight of the READ format -- rejects a variant-only-for-plain, # config-only, checkpoint-only, or undeclared-leftover-only stale snapshot. if variant: From 0beffb3676dda1c9667d31aaa9e3f4133f5269f9 Mon Sep 17 00:00:00 2001 From: danielhanchen Date: Sun, 5 Jul 2026 12:47:15 +0000 Subject: [PATCH 81/82] Tighten diffusers component weights and let a subfolder single beat a stale index Two post-download completeness fixes from a re-review of the Codex feedback. 1. A declared diffusers component was treated as warm by ANY no-dot safetensors/bin basename, including a sidecar such as unet/adapter_model.safetensors. 
The pipeline still loads the component's canonical weight (unet/diffusion_pytorch_model.*), so a stale cache holding only the sidecar was accepted and the missing base weight was fetched in-process over un-killable Xet. The component-weight regex now matches only the canonical basenames diffusers / transformers components load (diffusion_pytorch_model / model / pytorch_model, single or numbered shard). 2. A selected subfolder holding a single canonical weight beside a stale same-format shard index (e.g. encoder/model.safetensors + a leftover encoder/model.safetensors.index.json missing shards) was rejected, then retried into a DownloadStallError even though transformers reads the single file before the index. _selected_shard_index_incomplete now records single weights per dir and lets a single of a format beat that format's index there, mirroring the root precedence. Adds regressions for both. Full suite 195 passed; safety-invariant fuzz stays at 0 false-accepts and the e2e stall/HTTP recovery stays green. Not changed (re-verified, no new evidence): the dashed-only adapter glob (adapter_model-*) and pytorch_model* root-format edges are not loader-reachable (loaders pass adapter_model* / allow=None, both correctly classified); the TF/Flax patterned-shard gap is the documented low-reachability edge; injected prepare_for_http_fn, crashed-child coarse cleanup, and import-lightness remain by-design. 
--- tests/test_hf_xet_fallback.py | 45 ++++++++++++++++++++++++++++++++++ unsloth_zoo/hf_cache_state.py | 24 +++++++++++++++--- unsloth_zoo/hf_xet_fallback.py | 2 +- 3 files changed, 67 insertions(+), 4 deletions(-) diff --git a/tests/test_hf_xet_fallback.py b/tests/test_hf_xet_fallback.py index a6cad0b9a..c0097c546 100644 --- a/tests/test_hf_xet_fallback.py +++ b/tests/test_hf_xet_fallback.py @@ -3026,6 +3026,51 @@ def _base(root, name, blob): snap3, repo_type = "model", allow_patterns = None, ignore_patterns = None) is True +def test_post_download_rejects_diffusers_component_with_only_sidecar_weight(tmp_path): + """A declared component whose only weight is a non-canonical sidecar (unet/adapter_model.safetensors) + is NOT a warm component: the pipeline still loads unet/diffusion_pytorch_model.* and would fetch it + over Xet. Only the canonical component weight names count.""" + snap, blob = _mk_snapshot(tmp_path, "comp_sidecar_only") + (snap / "model_index.json").write_text(json.dumps( + {"_class_name": "P", "unet": ["diffusers", "UNet2DConditionModel"]})) + unet = snap / "unet" + unet.mkdir() + (unet / "config.json").write_text("{}") + (unet / "adapter_model.safetensors").symlink_to(blob) # sidecar, not the base weight + assert xf._download_result_usable( + snap, repo_type = "model", allow_patterns = None, ignore_patterns = None) is False + # The canonical base weight present -> accepted. 
+ (unet / "diffusion_pytorch_model.safetensors").symlink_to(blob) + assert xf._download_result_usable( + snap, repo_type = "model", allow_patterns = None, ignore_patterns = None) is True + + +def test_post_download_subfolder_single_weight_beats_stale_index(tmp_path): + """In a selected subfolder, a single canonical weight is read before a same-format shard index + (transformers precedence), so a stale co-resident index must not false-reject the warm; a shard index + with no single is still required complete.""" + # encoder/model.safetensors present beside a stale encoder/model.safetensors.index.json -> accepted. + snap, blob = _mk_snapshot(tmp_path, "subdir_single_beats_index") + enc = snap / "encoder" + enc.mkdir() + (enc / "model.safetensors").symlink_to(blob) + (enc / "model.safetensors.index.json").write_text(json.dumps( + {"weight_map": {"a": "model-00001-of-00002.safetensors", + "b": "model-00002-of-00002.safetensors"}})) # shards absent (stale) + assert xf._download_result_usable( + snap, repo_type = "model", allow_patterns = ["encoder/*"], ignore_patterns = None) is True + # No single, index missing shards -> still incomplete. 
+ snap2, blob2 = _mk_snapshot(tmp_path, "subdir_index_only") + enc2 = snap2 / "encoder" + enc2.mkdir() + (enc2 / "model.safetensors.index.json").write_text(json.dumps( + {"weight_map": {"a": "model-00001-of-00002.safetensors", + "b": "model-00002-of-00002.safetensors"}})) + (enc2 / "model-00001-of-00002.safetensors").symlink_to(blob2) # one shard, second absent + assert xf._download_result_usable( + snap2, repo_type = "model", allow_patterns = ["encoder/*"], ignore_patterns = None) is False + + def test_post_download_single_variant_beats_stale_variant_index(tmp_path): """Variant twin of single-beats-index: a complete single variant weight beside a stale variant index is usable (ST and bin); a stale index with no single weight is breakage.""" snap, blob = _mk_snapshot(tmp_path, "single_variant_beats_index") diff --git a/unsloth_zoo/hf_cache_state.py b/unsloth_zoo/hf_cache_state.py index 548068442..da07f9858 100644 --- a/unsloth_zoo/hf_cache_state.py +++ b/unsloth_zoo/hf_cache_state.py @@ -786,11 +786,14 @@ def _selected_shard_index_incomplete( allow_patterns = _as_pattern_list(allow_patterns) ignore_patterns = _as_pattern_list(ignore_patterns) want_variant = variant or None + _canon = r"(?:diffusion_pytorch_model|model|pytorch_model)" if want_variant is None: shard_file_re = re.compile(r"^[^.]+-\d{5}-of-\d{5}\.(?:safetensors|bin)$") + single_file_re = re.compile(rf"^{_canon}\.(?:safetensors|bin)$") else: v = re.escape(want_variant) shard_file_re = re.compile(rf"^[^.]+\.{v}-\d{{5}}-of-\d{{5}}\.(?:safetensors|bin)$") + single_file_re = re.compile(rf"^{_canon}\.{v}\.(?:safetensors|bin)$") try: entries = list(snapshot_dir.rglob("*")) except OSError: @@ -798,6 +801,7 @@ def _selected_shard_index_incomplete( per_dir: dict = {} # dir_rel -> {"safetensors": [shard_rels, ...], "bin": [...]} (from indices) index_fmts: dict = {} # dir_rel -> {fmt} an index of the read variant is present (non-root-model) shard_fmts: dict = {} # dir_rel -> {fmt} a SELECTED numbered shard file is 
present (non-root-model) + single_fmts: dict = {} # dir_rel -> {fmt} a SELECTED single weight is present (beats that dir's index) for entry in entries: name = entry.name if not _safe_is_file(entry): @@ -844,9 +848,23 @@ def _selected_shard_index_incomplete( continue # the load does not read this shard (out of scope / ignored format) fmt = "safetensors" if name.endswith(".safetensors") else "bin" shard_fmts.setdefault(dir_rel, set()).add(fmt) - for by_fmt in per_dir.values(): - # safetensors read before bin: require only the preferred format present in this directory. - for shard_rels in by_fmt.get("safetensors") or by_fmt.get("bin") or []: + elif single_file_re.match(name): + # a SINGLE canonical weight of the read variant: transformers/diffusers read it before a + # same-format shard index in the SAME dir, so a stale co-resident index must not reject it. + if _filter_paths([rel], allow_patterns, ignore_patterns): + fmt = "safetensors" if name.endswith(".safetensors") else "bin" + single_fmts.setdefault(dir_rel, set()).add(fmt) + for dir_rel, by_fmt in per_dir.items(): + singles = single_fmts.get(dir_rel, set()) + # safetensors read before bin, and a single weight of a format beats that format's stale index. + if "safetensors" in singles: + continue + required = by_fmt.get("safetensors") + if required is None: + if "bin" in singles: + continue + required = by_fmt.get("bin") or [] + for shard_rels in required: for shard in shard_rels: try: if not (snapshot_dir / shard).exists(): diff --git a/unsloth_zoo/hf_xet_fallback.py b/unsloth_zoo/hf_xet_fallback.py index 594816119..73011abd1 100644 --- a/unsloth_zoo/hf_xet_fallback.py +++ b/unsloth_zoo/hf_xet_fallback.py @@ -989,7 +989,7 @@ def _is_default_load_weight_file(name: str) -> bool: # weight (diffusion_pytorch_model.fp16.safetensors) is EXCLUDED, so a variant='fp16' stale cache does not # read as a warm PLAIN pipeline (which would fetch the non-variant name over un-killable Xet). 
_CANONICAL_COMPONENT_WEIGHT_RE = re.compile( - r"^[^.]+(?:-\d{5}-of-\d{5})?\.(?:safetensors|bin)$" + r"^(?:diffusion_pytorch_model|model|pytorch_model)(?:-\d{5}-of-\d{5})?\.(?:safetensors|bin)$" ) # SINGLE-FILE canonical root TF / Flax weight a from_tf / from_flax load reads instead of a PyTorch From 296fdc87864bbfd31427399d2ad87a1449796b37 Mon Sep 17 00:00:00 2001 From: danielhanchen Date: Sun, 5 Jul 2026 13:43:30 +0000 Subject: [PATCH 82/82] Close cross-path gaps in the diffusers/adapter cache-completeness checks Follow-ups from a Codex re-review of the previous tightening; several fixes applied to the plain / root path had not been mirrored onto the parallel variant / component / adapter paths. hf_xet_fallback.py: - Variant diffusers components now filter to CANONICAL component bases (diffusion_pytorch_model / model / pytorch_model + variant token), matching the plain branch. A variant-named sidecar (unet/adapter_model.fp16.safetensors) no longer counts as a warm component weight the pipeline never reads, which would have fetched the real diffusion_pytorch_model.fp16.* over un-killable Xet. - A model component holding neither config.json nor a weight is decided against its DECLARED class from model_index.json, not just sidecar filenames: a weight-bearing class (UNet2DConditionModel, AutoencoderKL, ...) always requires a weight, so a stray tokenizer_config.json under unet/ can no longer pass the component off as weightless. Falls back to the filename heuristic only when the class is unavailable. - A stall that fires before the child opens any partial (connect / metadata phase) yields an EMPTY ownership set; that would scope the HTTP-prep purge to nothing and strand a pre-existing stale blob / link for the retry to inherit. Normalise an empty set to None (unscoped) so the mtime + active-partner guards still clear genuinely-stale state while sparing a live sibling. 
hf_cache_state.py: - _diffusers_component_shards_incomplete now records same-dir SINGLE component weights and lets a single beat that dir's stale shard index (the single-before- index precedence already used for root and patterned weights), so a usable unet/diffusion_pytorch_model.safetensors beside a stale index is no longer false-rejected into a raised DownloadStallError. - The single-file precedence in _selected_shard_index_incomplete now includes adapter_model, so a PEFT adapter load with a single adapter_model.safetensors beside a stale adapter_model.safetensors.index.json is accepted. Adds five regression tests; full suite green and the layout safety fuzz stays at zero false-accepts. --- tests/test_hf_xet_fallback.py | 145 +++++++++++++++++++++++++++++++++ unsloth_zoo/hf_cache_state.py | 30 ++++++- unsloth_zoo/hf_xet_fallback.py | 61 ++++++++++++-- 3 files changed, 224 insertions(+), 12 deletions(-) diff --git a/tests/test_hf_xet_fallback.py b/tests/test_hf_xet_fallback.py index c0097c546..d779370a0 100644 --- a/tests/test_hf_xet_fallback.py +++ b/tests/test_hf_xet_fallback.py @@ -3102,6 +3102,151 @@ def test_post_download_single_variant_beats_stale_variant_index(tmp_path): variant = "fp16") is False +def test_post_download_variant_component_sidecar_is_not_warm(tmp_path): + """A variant diffusers load must count only CANONICAL component variant weights: a variant-named + SIDECAR (unet/adapter_model.fp16.safetensors) is not the weight the pipeline reads + (unet/diffusion_pytorch_model.fp16.*), so a component holding only the sidecar is retried over Xet, + not accepted. 
Mirrors the plain-branch canonical-basename filter.""" + snap, blob = _mk_snapshot(tmp_path, "variant_comp_sidecar_only") + (snap / "model_index.json").write_text(json.dumps( + {"_class_name": "P", "unet": ["diffusers", "UNet2DConditionModel"]})) + unet = snap / "unet" + unet.mkdir() + (unet / "config.json").write_text("{}") + (unet / "adapter_model.fp16.safetensors").symlink_to(blob) # variant sidecar, not the base weight + assert xf._download_result_usable( + snap, repo_type = "model", allow_patterns = None, ignore_patterns = None, + variant = "fp16") is False + # The canonical component variant weight present -> accepted. + (unet / "diffusion_pytorch_model.fp16.safetensors").symlink_to(blob) + assert xf._download_result_usable( + snap, repo_type = "model", allow_patterns = None, ignore_patterns = None, + variant = "fp16") is True + + +def test_post_download_diffusers_component_single_beats_stale_index(tmp_path): + """A diffusers pipeline component reads a single canonical weight before a same-format shard index in + its own subfolder, so a stale co-resident index must not false-reject a complete component; a component + holding only a stale index (no single) is still incomplete.""" + # unet single diffusion_pytorch_model.safetensors beside a stale same-dir index -> accepted. 
+ snap, blob = _mk_snapshot(tmp_path, "diff_comp_single_beats_index") + (snap / "model_index.json").write_text(json.dumps( + {"_class_name": "P", "unet": ["diffusers", "UNet2DConditionModel"]})) + unet = snap / "unet" + unet.mkdir() + (unet / "config.json").write_text("{}") + (unet / "diffusion_pytorch_model.safetensors").symlink_to(blob) + (unet / "diffusion_pytorch_model.safetensors.index.json").write_text(json.dumps( + {"weight_map": {"a": "diffusion_pytorch_model-00001-of-00002.safetensors", + "b": "diffusion_pytorch_model-00002-of-00002.safetensors"}})) # shards absent + assert xf._download_result_usable( + snap, repo_type = "model", allow_patterns = None, ignore_patterns = None) is True + # No single, index lists a missing shard -> still incomplete. + snap2, blob2 = _mk_snapshot(tmp_path, "diff_comp_index_only") + (snap2 / "model_index.json").write_text(json.dumps( + {"_class_name": "P", "unet": ["diffusers", "UNet2DConditionModel"]})) + unet2 = snap2 / "unet" + unet2.mkdir() + (unet2 / "config.json").write_text("{}") + (unet2 / "diffusion_pytorch_model.safetensors.index.json").write_text(json.dumps( + {"weight_map": {"a": "diffusion_pytorch_model-00001-of-00002.safetensors", + "b": "diffusion_pytorch_model-00002-of-00002.safetensors"}})) + (unet2 / "diffusion_pytorch_model-00001-of-00002.safetensors").symlink_to(blob2) # 1 shard, 2nd absent + assert xf._download_result_usable( + snap2, repo_type = "model", allow_patterns = None, ignore_patterns = None) is False + + +def test_post_download_adapter_single_beats_stale_index(tmp_path): + """A PEFT adapter load (allow=['adapter_model*', 'adapter_config.json']) reads a single + adapter_model.safetensors before a shard index, so a stale co-resident adapter_model.safetensors.index.json + must not false-reject the warm; a sharded adapter missing a shard is still incomplete.""" + snap, blob = _mk_snapshot(tmp_path, "adapter_single_beats_index") + (snap / "adapter_config.json").write_text("{}") + (snap / 
"adapter_model.safetensors").symlink_to(blob) + (snap / "adapter_model.safetensors.index.json").write_text(json.dumps( + {"weight_map": {"a": "adapter_model-00001-of-00002.safetensors", + "b": "adapter_model-00002-of-00002.safetensors"}})) # shards absent (stale) + assert xf._download_result_usable( + snap, repo_type = "model", allow_patterns = ["adapter_model*", "adapter_config.json"], + ignore_patterns = None) is True + # Sharded adapter, one shard absent, no single -> incomplete. + snap2, blob2 = _mk_snapshot(tmp_path, "adapter_sharded_incomplete") + (snap2 / "adapter_config.json").write_text("{}") + (snap2 / "adapter_model.safetensors.index.json").write_text(json.dumps( + {"weight_map": {"a": "adapter_model-00001-of-00002.safetensors", + "b": "adapter_model-00002-of-00002.safetensors"}})) + (snap2 / "adapter_model-00001-of-00002.safetensors").symlink_to(blob2) # 1 shard, 2nd absent + assert xf._download_result_usable( + snap2, repo_type = "model", allow_patterns = ["adapter_model*", "adapter_config.json"], + ignore_patterns = None) is False + + +def test_post_download_model_component_stray_sidecar_rejected(tmp_path): + """A declared MODEL component (unet) holding only a stray weightless sidecar (a lone + tokenizer_config.json), no config.json and no weight, is an incomplete component -> rejected: its + declared class (UNet2DConditionModel) is weight-bearing, so a stray sidecar cannot pass it as + weightless. A genuinely weightless component (PNDMScheduler) needs no weight.""" + snap, blob = _mk_snapshot(tmp_path, "model_comp_stray_sidecar") + (snap / "model_index.json").write_text(json.dumps({ + "_class_name": "StableDiffusionPipeline", + "unet": ["diffusers", "UNet2DConditionModel"], + "vae": ["diffusers", "AutoencoderKL"], + "scheduler": ["diffusers", "PNDMScheduler"], + })) + # vae complete (satisfies the floor); scheduler weightless; unet holds ONLY a stray sidecar. 
+ vae = snap / "vae"; vae.mkdir() + (vae / "config.json").write_text("{}") + (vae / "diffusion_pytorch_model.safetensors").symlink_to(blob) + sched = snap / "scheduler"; sched.mkdir() + (sched / "scheduler_config.json").write_text("{}") + unet = snap / "unet"; unet.mkdir() + (unet / "tokenizer_config.json").write_text("{}") # stray weightless-shaped sidecar, no weight/config + assert xf._download_result_usable( + snap, repo_type = "model", allow_patterns = None, ignore_patterns = None) is False + # unet now holds its config + canonical weight -> accepted (scheduler still needs no weight). + (unet / "config.json").write_text("{}") + (unet / "diffusion_pytorch_model.safetensors").symlink_to(blob) + assert xf._download_result_usable( + snap, repo_type = "model", allow_patterns = None, ignore_patterns = None) is True + + +def test_http_prep_empty_owned_set_scopes_to_nothing(tmp_path): + """An EMPTY owned-set scopes the purge to nothing (own nothing -> delete nothing). A stall that fires + before the child opens any partial must therefore pass ownership=None (unscoped), not the empty set, or + an aged pre-existing stale blob/link is stranded for the retry to inherit. Documents the semantic the + stall path relies on when it normalises an empty ownership set to None.""" + repo = "ztest/empty-owned" + repo_dir = tmp_path / f"models--{repo.replace('/', '--')}" + blobs = repo_dir / "blobs" + snap = repo_dir / "snapshots" / "sha" + blobs.mkdir(parents = True) + snap.mkdir(parents = True) + old = time.time() - 600 + + def _seed(): + for p in list(snap.iterdir()): + p.unlink() + stale_blob = blobs / "stalehash.incomplete" + stale_blob.write_bytes(b"x") + os.utime(stale_blob, (old, old)) + (snap / "stale.safetensors").symlink_to(blobs / "stalehash") + return stale_blob + + # Empty owned-set: scope to nothing -> the aged stale blob/link SURVIVE (why the empty set is unsafe). 
+ stale_blob = _seed() + _REAL_DEFAULT_PREPARE("model", repo, cache_dir = str(tmp_path), active_grace = 180, + owned_incomplete_blobs = set()) + assert stale_blob.exists(), "empty owned-set must scope to nothing (delete nothing)" + assert (snap / "stale.safetensors").is_symlink() + + # None (what the stall path sends when ownership is empty): the mtime guard clears the aged stale state. + stale_blob = _seed() + _REAL_DEFAULT_PREPARE("model", repo, cache_dir = str(tmp_path), active_grace = 180, + owned_incomplete_blobs = None) + assert not stale_blob.exists(), "unscoped prep must purge the aged stale partial" + assert not (snap / "stale.safetensors").is_symlink() + + def test_post_download_diffusers_skips_root_model_shard_checks(tmp_path): """A diffusers pipeline reads component subfolders, so a stale root model shard index (canonical or variant) is accepted; component completeness is still enforced.""" # Plain: stale root model.safetensors.index.json alongside complete components. diff --git a/unsloth_zoo/hf_cache_state.py b/unsloth_zoo/hf_cache_state.py index da07f9858..a8917ada8 100644 --- a/unsloth_zoo/hf_cache_state.py +++ b/unsloth_zoo/hf_cache_state.py @@ -786,7 +786,10 @@ def _selected_shard_index_incomplete( allow_patterns = _as_pattern_list(allow_patterns) ignore_patterns = _as_pattern_list(ignore_patterns) want_variant = variant or None - _canon = r"(?:diffusion_pytorch_model|model|pytorch_model)" + # Canonical single-weight bases whose presence beats a stale same-dir shard index. Includes + # adapter_model so a single adapter_model.safetensors (a PEFT adapter load, allow=['adapter_model*']) + # is not false-rejected by a co-resident stale adapter_model.safetensors.index.json. 
+ _canon = r"(?:diffusion_pytorch_model|model|pytorch_model|adapter_model)" if want_variant is None: shard_file_re = re.compile(r"^[^.]+-\d{5}-of-\d{5}\.(?:safetensors|bin)$") single_file_re = re.compile(rf"^{_canon}\.(?:safetensors|bin)$") @@ -932,9 +935,15 @@ def _diffusers_component_shards_incomplete( declared = _diffusers_declared_components(snapshot_dir) if want_variant is None: shard_file_re = re.compile(r"^[^.]+-\d{5}-of-\d{5}\.(?:safetensors|bin)$") + single_file_re = re.compile( + r"^(?:diffusion_pytorch_model|model|pytorch_model)\.(?:safetensors|bin)$" + ) else: v = re.escape(want_variant) shard_file_re = re.compile(rf"^[^.]+\.{v}-\d{{5}}-of-\d{{5}}\.(?:safetensors|bin)$") + single_file_re = re.compile( + rf"^(?:diffusion_pytorch_model|model|pytorch_model)\.{v}\.(?:safetensors|bin)$" + ) try: entries = list(snapshot_dir.rglob("*")) except OSError: @@ -942,6 +951,7 @@ def _diffusers_component_shards_incomplete( per_dir: dict = {} index_fmts: dict = {} # component dir_rel -> {fmt} an index of the read variant is present shard_fmts: dict = {} # component dir_rel -> {fmt} a numbered shard file (ignore-kept) is present + single_fmts: dict = {} # component dir_rel -> {fmt} a SINGLE component weight (beats that dir's index) for entry in entries: name = entry.name if not _safe_is_file(entry): @@ -978,8 +988,22 @@ def _diffusers_component_shards_incomplete( elif shard_file_re.match(name) and _filter_paths([rel], None, ignore_patterns): fmt = "safetensors" if name.endswith(".safetensors") else "bin" shard_fmts.setdefault(dir_rel, set()).add(fmt) - for by_fmt in per_dir.values(): - for shard_rels in by_fmt.get("safetensors") or by_fmt.get("bin") or []: + elif single_file_re.match(name) and _filter_paths([rel], None, ignore_patterns): + # a SINGLE canonical component weight: the pipeline reads it before a same-format shard index + # in the SAME component dir, so a stale co-resident index must not reject a complete component. 
+ fmt = "safetensors" if name.endswith(".safetensors") else "bin" + single_fmts.setdefault(dir_rel, set()).add(fmt) + for dir_rel, by_fmt in per_dir.items(): + singles = single_fmts.get(dir_rel, set()) + # safetensors is read before bin, and a single weight of a format beats that format's stale index. + if "safetensors" in singles: + continue + required = by_fmt.get("safetensors") + if required is None: + if "bin" in singles: + continue + required = by_fmt.get("bin") or [] + for shard_rels in required: for shard in shard_rels: try: if not (snapshot_dir / shard).exists(): diff --git a/unsloth_zoo/hf_xet_fallback.py b/unsloth_zoo/hf_xet_fallback.py index 73011abd1..bf8530faa 100644 --- a/unsloth_zoo/hf_xet_fallback.py +++ b/unsloth_zoo/hf_xet_fallback.py @@ -899,12 +899,18 @@ def _run_download_attempt( # them. Prefer the per-pid open-fd set; else post-baseline partials; None -> coarser mtime # guard. owned = _child_open_incomplete_blobs(proc.pid) if proc.pid else None - if owned is None and baseline_partials is not None: + if not owned and baseline_partials is not None: + # None (no psutil / proc) OR an empty set (the child stalled in the connect / metadata + # phase before opening any partial): try the post-baseline diff before giving up on scope. current = set( _active_incomplete_blob_sizes(repo_type, repo_id, params.get("cache_dir")) ) owned = current - baseline_partials - params["_owned_incomplete_blobs"] = owned + # An empty ownership set would scope the HTTP-prep purge to NOTHING, leaving a pre-existing + # stale *.incomplete blob / dangling link for the retry to inherit and re-trip on. Fall back + # to None (unscoped) so the mtime + active-partner guards still clear genuinely-stale state + # while sparing a live sibling. 
+ params["_owned_incomplete_blobs"] = owned or None _terminate_process_group(proc, grace_period) return ("stall", None) try: @@ -1042,6 +1048,32 @@ def _is_weightless_component(names: list) -> bool: ) +# Diffusers component classes that ship NO model weight (a config / tokenizer / preprocessor sidecar +# only). Matched as a SUBSTRING so *Scheduler / *Tokenizer / *TokenizerFast / *FeatureExtractor / +# *ImageProcessor / *Processor all register; no weight-bearing pipeline component class (a *Model / +# *Transformer2DModel / *AutoencoderKL / *ControlNetModel / *UNet*) contains one of these tokens. +_WEIGHTLESS_COMPONENT_CLASS_RE = re.compile( + r"Scheduler|Tokenizer|FeatureExtractor|ImageProcessor|Processor" +) + + +def _component_requires_weight(spec: Any, names: list) -> bool: + """Whether a component holding NEITHER a ``config.json`` NOR a readable weight is an INCOMPLETE model + component (True -> reject) rather than a genuinely WEIGHTLESS one (False). Prefers the DECLARED class + from ``model_index.json`` (a ``[library, class]`` pair): a scheduler / tokenizer / feature_extractor + class needs no weight, everything else does -- so a model component holding only a stray sidecar + (a lone ``tokenizer_config.json`` under ``unet/``) cannot masquerade as weightless. 
Falls back to the + sidecar-name heuristic only when the class is unavailable (a malformed / non-string spec).""" + cls = ( + spec[1] + if isinstance(spec, (list, tuple)) and len(spec) >= 2 and isinstance(spec[1], str) + else None + ) + if cls is not None: + return not bool(_WEIGHTLESS_COMPONENT_CLASS_RE.search(cls)) + return not _is_weightless_component(names) + + def _diffusers_component_weights_complete( snapshot_dir: Path, *, variant: Optional[str], ignore_patterns: Any = None, ) -> bool: @@ -1059,8 +1091,18 @@ def _diffusers_component_weights_complete( specs = _diffusers_declared_component_specs(snapshot_dir) declared = set(specs) if specs else None active = _diffusers_active_component_dirs(specs) if specs else None - infix_dot = f".{variant}." if variant else "" - infix_dash = f".{variant}-" if variant else "" + # A VARIANT component weight is a CANONICAL component base carrying the variant token, single or + # numbered shard: diffusion_pytorch_model.fp16.safetensors / ...fp16-00001-of-00002.safetensors. + # Filtering to canonical bases (like the plain branch below) stops a variant-named SIDECAR + # (unet/adapter_model.fp16.safetensors) from counting as a warm component weight the pipeline, which + # reads unet/diffusion_pytorch_model.fp16.*, would otherwise fetch over un-killable Xet. 
+ variant_weight_re = ( + re.compile( + rf"^(?:diffusion_pytorch_model|model|pytorch_model)\.{re.escape(variant)}" + rf"(?:-\d{{5}}-of-\d{{5}})?\.(?:safetensors|bin)$" + ) + if variant else None + ) per_comp_canon: dict = {} per_comp_variant: dict = {} try: @@ -1082,7 +1124,7 @@ def _diffusers_component_weights_complete( continue # an UNDECLARED subtree the load does not read if any(_CHECKPOINT_DIR_RE.match(p) for p in parts[:-1]): continue # a training-checkpoint subtree, not a component - if variant and (infix_dot in name or infix_dash in name): + if variant_weight_re is not None and variant_weight_re.match(name): per_comp_variant.setdefault(comp, []).append(rel) elif _CANONICAL_COMPONENT_WEIGHT_RE.match(name): per_comp_canon.setdefault(comp, []).append(rel) @@ -1111,10 +1153,11 @@ def _has_read_weight(comp: str) -> bool: if "config.json" in names: # a model-style component MUST hold a readable weight if not _has_read_weight(comp): return False - elif not _has_read_weight(comp) and not _is_weightless_component(names): - # No config.json and no readable weight: OK only when the dir is weightless-SHAPED - # (scheduler / tokenizer / feature_extractor sidecars). Otherwise it is an incomplete - # model-style component whose weight the load would fetch over un-killable Xet. + elif not _has_read_weight(comp) and _component_requires_weight(specs.get(comp), names): + # No config.json and no readable weight: an incomplete model-style component whose weight + # the load would fetch over un-killable Xet. A genuinely weightless component (scheduler / + # tokenizer / feature_extractor) is exempt -- decided by its DECLARED class when known, so + # a model component holding only a stray sidecar cannot pass as weightless. return False # Floor: at least one component holds a weight of the READ format -- rejects a variant-only-for-plain, # config-only, checkpoint-only, or undeclared-leftover-only stale snapshot.