diff --git a/.github/workflows/consolidated-tests-ci.yml b/.github/workflows/consolidated-tests-ci.yml index 1c70ce6f2..cfdf6c619 100644 --- a/.github/workflows/consolidated-tests-ci.yml +++ b/.github/workflows/consolidated-tests-ci.yml @@ -99,11 +99,14 @@ jobs: # invisible until #739 because no job executed it (collect-only plus # the macOS shim smoke). test_training_utils_use_cache pins the # use_cache disable/restore contract for gradient checkpointing - # (#715). Executing both here keeps fixture/source drift visible. + # (#715). test_hf_xet_fallback exercises the Xet->HTTP stall fallback + # + cache-completeness gate (CPU-pure; collect-only would let it + # regress silently). Executing them here keeps fixture/source drift visible. run: | python -m pytest \ tests/test_training_utils_use_cache.py \ tests/test_mlx_finetune_last_n_layers.py \ + tests/test_hf_xet_fallback.py \ -v - name: pytest tests/test_mlx_module_exports + zoo-specific CPU tests diff --git a/tests/test_hf_xet_fallback.py b/tests/test_hf_xet_fallback.py new file mode 100644 index 000000000..a6cad0b9a --- /dev/null +++ b/tests/test_hf_xet_fallback.py @@ -0,0 +1,3643 @@ +# SPDX-License-Identifier: AGPL-3.0-only +# Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. + +"""Tests for unsloth_zoo.hf_xet_fallback: watchdog, Xet->HTTP transport policy, +file/snapshot entrypoints, the UNSLOTH_DISABLE_XET knob, and HF_HUB_DISABLE_XET. + +CPU-only, no network, no real subprocess (the download seam is monkeypatched). +The modules under test are loaded via importlib to avoid importing the full +``unsloth_zoo`` package (torch + GPU init). 
+""" + +from __future__ import annotations + +import errno +import importlib.util +import json +import os +import subprocess +import sys +import threading +import time +import types as _types +from pathlib import Path + +import pytest + +import huggingface_hub +from huggingface_hub import constants as hf_constants + +_ZOO_DIR = Path(__file__).resolve().parents[1] / "unsloth_zoo" + + +def _load(name: str, filename: str): + spec = importlib.util.spec_from_file_location(name, _ZOO_DIR / filename) + module = importlib.util.module_from_spec(spec) + sys.modules[name] = module + spec.loader.exec_module(module) + return module + + +# Package placeholder so intra-package imports in hf_xet_fallback resolve to the +# files loaded below. Restored afterwards: a leftover would shadow the real +# unsloth_zoo; the loaded modules keep their own references and work regardless. +_saved_modules = { + name: sys.modules.get(name) + for name in ("unsloth_zoo", "unsloth_zoo.hf_cache_state", "unsloth_zoo.hf_xet_fallback") +} +if "unsloth_zoo" not in sys.modules: + _pkg = _types.ModuleType("unsloth_zoo") + _pkg.__path__ = [str(_ZOO_DIR)] + sys.modules["unsloth_zoo"] = _pkg + +hcs = _load("unsloth_zoo.hf_cache_state", "hf_cache_state.py") +xf = _load("unsloth_zoo.hf_xet_fallback", "hf_xet_fallback.py") + +for _name, _mod in _saved_modules.items(): + if _mod is None: + sys.modules.pop(_name, None) + else: + sys.modules[_name] = _mod + +# Real prep impl, captured before the autouse fixture stubs the module attribute. +_REAL_DEFAULT_PREPARE = xf._default_prepare_for_http + + +# Watchdog: fires only on a constant-size .incomplete, sparse-aware byte total. 
+REPO = "ztest/xet-watchdog" + + +@pytest.fixture +def hf_cache(tmp_path, monkeypatch): + monkeypatch.setattr(hf_constants, "HF_HUB_CACHE", str(tmp_path)) + return tmp_path + + +def _blobs_dir(root: Path, repo_id: str = REPO) -> Path: + d = root / f"models--{repo_id.replace('/', '--')}" / "blobs" + d.mkdir(parents = True, exist_ok = True) + return d + + +def _wait(predicate, timeout: float = 2.0, step: float = 0.02) -> bool: + deadline = time.monotonic() + timeout + while time.monotonic() < deadline: + if predicate(): + return True + time.sleep(step) + return predicate() + + +def test_constant_incomplete_fires_stall(hf_cache): + blobs = _blobs_dir(hf_cache) + (blobs / "deadbeef.incomplete").write_bytes(b"\0" * 1024) # never grows + + calls: list[str] = [] + stop = xf.start_watchdog( + repo_ids = [REPO], on_stall = calls.append, interval = 0.05, stall_timeout = 0.3 + ) + try: + assert _wait( + lambda: len(calls) >= 1, timeout = 3.0 + ), "watchdog never fired on a constant-size .incomplete" + finally: + stop.set() + assert "stalled" in calls[0].lower() + + +def test_growing_incomplete_never_stalls(hf_cache): + blobs = _blobs_dir(hf_cache) + part = blobs / "growing.incomplete" + part.write_bytes(b"\0" * 1024) + + grow_stop = threading.Event() + + def _grow(): + size = 1024 + while not grow_stop.wait(0.05): + size += 4096 + part.write_bytes(b"\0" * size) + + grower = threading.Thread(target = _grow, daemon = True) + grower.start() + + calls: list[str] = [] + stop = xf.start_watchdog( + repo_ids = [REPO], on_stall = calls.append, interval = 0.05, stall_timeout = 0.5 + ) + try: + time.sleep(1.0) # well past stall_timeout, but bytes keep growing + assert calls == [], "watchdog fired despite continuous progress" + finally: + stop.set() + grow_stop.set() + + +def test_no_incomplete_never_stalls(hf_cache): + blobs = _blobs_dir(hf_cache) + (blobs / "finalized_blob").write_bytes(b"\0" * 4096) # no .incomplete + + calls: list[str] = [] + stop = xf.start_watchdog( + repo_ids = 
[REPO], on_stall = calls.append, interval = 0.05, stall_timeout = 0.5 + ) + try: + time.sleep(0.8) + assert calls == [], "watchdog fired with no active .incomplete" + finally: + stop.set() + + +def test_transient_unmeasurable_tick_is_progress(hf_cache, monkeypatch): + """An unmeasurable tick (state -> None) counts as progress, but a later frozen state still stalls.""" + seq = {"n": 0} + frozen = (2048, True) # constant size + active .incomplete + + def fake_state(*args, **kwargs): + seq["n"] += 1 + return None if seq["n"] <= 8 else frozen # first ~8 ticks unmeasurable, then frozen + + monkeypatch.setattr(xf, "get_hf_download_state", fake_state) + + calls: list[str] = [] + stop = xf.start_watchdog( + repo_ids = [REPO], on_stall = calls.append, interval = 0.05, stall_timeout = 0.3, + ) + try: + time.sleep(0.3) # within the unmeasurable window + assert calls == [], "watchdog fired during a transient-unmeasurable window" + assert _wait(lambda: len(calls) >= 1, timeout = 3.0), "stall never fired after state recovered" + finally: + stop.set() + + +def test_stall_fires_at_most_once(hf_cache): + blobs = _blobs_dir(hf_cache) + (blobs / "frozen.incomplete").write_bytes(b"\0" * 2048) + + calls: list[str] = [] + stop = xf.start_watchdog( + repo_ids = [REPO], on_stall = calls.append, interval = 0.05, stall_timeout = 0.2 + ) + try: + assert _wait(lambda: len(calls) >= 1, timeout = 3.0) + time.sleep(0.6) # keep ticking; must not fire again + assert len(calls) == 1, f"on_stall fired {len(calls)} times, expected exactly 1" + finally: + stop.set() + + +def test_file_watchdog_scopes_to_child_partial(hf_cache): + """A single-file download follows only its own child's partials, so a growing baseline sibling does not mask its stalled child.""" + blobs = _blobs_dir(hf_cache) + sibling = blobs / "sibling.incomplete" # in flight -> in baseline + sibling.write_bytes(b"\0" * 1024) + baseline = {"sibling.incomplete"} + + grow_stop = threading.Event() + + def _grow(): + size = 1024 + while not 
grow_stop.wait(0.05): + size += 4096 + sibling.write_bytes(b"\0" * size) # healthy sibling progresses + + grower = threading.Thread(target = _grow, daemon = True) + grower.start() + + # This download's child writes its own constant (stalled) partial, not in baseline. + (blobs / "child.incomplete").write_bytes(b"\0" * 2048) + + calls: list[str] = [] + stop = xf.start_watchdog( + repo_ids = [REPO], on_stall = calls.append, interval = 0.05, stall_timeout = 0.3, + watch_new_partials_only = True, baseline_incomplete_blobs = baseline, + ) + try: + assert _wait(lambda: len(calls) >= 1, timeout = 3.0), ( + "file watchdog never fired: a growing sibling partial masked the stalled child" + ) + finally: + stop.set() + grow_stop.set() + + +def test_repo_wide_watchdog_is_masked_by_sibling(hf_cache): + """Contrast for file-scoping: the default repo-wide measure sums every blob, so a growing sibling resets the timer.""" + blobs = _blobs_dir(hf_cache) + sibling = blobs / "sibling.incomplete" + sibling.write_bytes(b"\0" * 1024) + (blobs / "child.incomplete").write_bytes(b"\0" * 2048) # constant + + grow_stop = threading.Event() + + def _grow(): + size = 1024 + while not grow_stop.wait(0.05): + size += 4096 + sibling.write_bytes(b"\0" * size) + + grower = threading.Thread(target = _grow, daemon = True) + grower.start() + + calls: list[str] = [] + stop = xf.start_watchdog( # default: repo-wide (watch_new_partials_only = False) + repo_ids = [REPO], on_stall = calls.append, interval = 0.05, stall_timeout = 0.5, + ) + try: + time.sleep(1.0) # past stall_timeout, but repo-wide bytes keep growing + assert calls == [], "repo-wide watchdog should be reset by the growing sibling" + finally: + stop.set() + grow_stop.set() + + +def test_file_watchdog_ignores_baseline_only_partials(hf_cache): + """When the only partial is a baseline sibling's, the file watchdog owns nothing and must not fire.""" + blobs = _blobs_dir(hf_cache) + (blobs / "sibling.incomplete").write_bytes(b"\0" * 4096) # constant 
baseline sibling + + calls: list[str] = [] + stop = xf.start_watchdog( + repo_ids = [REPO], on_stall = calls.append, interval = 0.05, stall_timeout = 0.5, + watch_new_partials_only = True, baseline_incomplete_blobs = {"sibling.incomplete"}, + ) + try: + time.sleep(0.8) + assert calls == [], "file watchdog fired on a baseline sibling partial it must ignore" + finally: + stop.set() + + +def _spawn_holding_open(path: Path) -> "subprocess.Popen": + """Child process that opens *path* and holds it open without writing (a hung download); prints 'ok' when open.""" + code = ( + "import sys, time\n" + "f = open(sys.argv[1], 'r+b')\n" + "sys.stdout.write('ok'); sys.stdout.flush()\n" + "time.sleep(30)\n" + ) + proc = subprocess.Popen( + [sys.executable, "-c", code, str(path)], stdout = subprocess.PIPE + ) + assert proc.stdout.read(2) == b"ok" # wait until the child holds the file open + return proc + + +def test_file_watchdog_detects_resumed_baseline_partial(hf_cache): + """A resumed download reuses a baseline .incomplete; pid-scoping (not name exclusion) still detects the hang.""" + blobs = _blobs_dir(hf_cache) + partial = blobs / "resumed.incomplete" + partial.write_bytes(b"\0" * 4096) # leftover from a prior interrupted download + baseline = {"resumed.incomplete"} # present before the resuming child starts + + child = _spawn_holding_open(partial) # hung resume: holds it open, never grows it + try: + calls: list[str] = [] + stop = xf.start_watchdog( + repo_ids = [REPO], on_stall = calls.append, interval = 0.05, stall_timeout = 0.3, + watch_new_partials_only = True, baseline_incomplete_blobs = baseline, + child_pid = child.pid, + ) + try: + assert _wait(lambda: len(calls) >= 1, timeout = 3.0), ( + "watchdog did not fire on a hung resume of a baseline partial" + ) + finally: + stop.set() + finally: + child.terminate() + child.wait(timeout = 5) + + +def test_file_watchdog_resumed_partial_fires_without_pid_ownership(hf_cache, monkeypatch): + """When ownership is unknowable (no 
psutil/proc -> None), the repo-wide fallback still stalls a frozen resumed baseline partial.""" + blobs = _blobs_dir(hf_cache) + (blobs / "resumed.incomplete").write_bytes(b"\0" * 4096) # leftover resume, constant (hung) + monkeypatch.setattr(xf, "_child_open_incomplete_blobs", lambda pid: None) # no psutil / no /proc + + calls: list[str] = [] + stop = xf.start_watchdog( + repo_ids = [REPO], on_stall = calls.append, interval = 0.05, stall_timeout = 0.3, + watch_new_partials_only = True, baseline_incomplete_blobs = {"resumed.incomplete"}, + child_pid = 4242, # non-None, but open-file inspection -> None -> repo-wide fallback + ) + try: + assert _wait(lambda: len(calls) >= 1, timeout = 3.0), ( + "watchdog never fired on a resumed baseline partial when pid ownership is unavailable" + ) + finally: + stop.set() + + +def test_file_watchdog_pid_scope_ignores_unowned_sibling(hf_cache): + """With pid scoping, a growing sibling the child does not hold open is ignored; the child's own constant partial still stalls.""" + blobs = _blobs_dir(hf_cache) + owned_partial = blobs / "owned.incomplete" + owned_partial.write_bytes(b"\0" * 2048) # child holds this open, constant (hung) + sibling = blobs / "sibling.incomplete" + sibling.write_bytes(b"\0" * 1024) + + grow_stop = threading.Event() + + def _grow(): + size = 1024 + while not grow_stop.wait(0.05): + size += 4096 + sibling.write_bytes(b"\0" * size) # unrelated sibling progressing + + grower = threading.Thread(target = _grow, daemon = True) + grower.start() + + child = _spawn_holding_open(owned_partial) + try: + calls: list[str] = [] + stop = xf.start_watchdog( + repo_ids = [REPO], on_stall = calls.append, interval = 0.05, stall_timeout = 0.3, + watch_new_partials_only = True, baseline_incomplete_blobs = set(), + child_pid = child.pid, + ) + try: + assert _wait(lambda: len(calls) >= 1, timeout = 3.0), ( + "pid-scoped watchdog never fired: an unowned growing sibling masked the stall" + ) + finally: + stop.set() + finally: + 
grow_stop.set() + child.terminate() + child.wait(timeout = 5) + + +def test_file_watchdog_empty_open_set_ignores_sibling(hf_cache, monkeypatch): + """An EMPTY child open-set means the child owns no partial yet (connect/metadata phase), so a stalled sibling must not fire.""" + blobs = _blobs_dir(hf_cache) + # Sibling partial created after baseline (not name-excluded), constant (stalled). + (blobs / "sibling.incomplete").write_bytes(b"\0" * 4096) + monkeypatch.setattr(xf, "_child_open_incomplete_blobs", lambda pid: set()) # child owns none + + calls: list[str] = [] + stop = xf.start_watchdog( + repo_ids = [REPO], on_stall = calls.append, interval = 0.05, stall_timeout = 0.5, + watch_new_partials_only = True, baseline_incomplete_blobs = set(), + child_pid = 4242, # non-None so the precise child-open path is taken + ) + try: + time.sleep(0.8) + assert calls == [], "watchdog falsely fired on a sibling partial the child does not own" + finally: + stop.set() + + +def test_get_state_empty_cache(hf_cache): + assert xf.get_hf_download_state([REPO]) == (0, False) + + +def test_get_state_absent_cache_root(tmp_path, monkeypatch): + monkeypatch.setattr(hf_constants, "HF_HUB_CACHE", str(tmp_path / "no-such-cache")) + assert xf.get_hf_download_state([REPO]) == (0, False) + + +def test_get_state_skips_local_paths(hf_cache): + # Filesystem paths are not repo IDs and must be ignored without error. 
+ assert xf.get_hf_download_state( + ["/abs/path", "./rel", "~user", "c:\\x", "c:/x"] + ) == (0, False) + + +def test_get_state_sparse_aware(hf_cache): + blobs = _blobs_dir(hf_cache) + sparse = blobs / "sparse.incomplete" + with open(sparse, "wb") as f: + f.truncate(64 * 1024 * 1024) # large apparent size, few allocated blocks + st = sparse.stat() + if getattr(st, "st_blocks", 0) == 0: + pytest.skip("filesystem does not report st_blocks; sparse accounting unavailable") + total, has_incomplete = xf.get_hf_download_state([REPO]) + assert has_incomplete is True + assert total < st.st_size, "sparse partial counted at apparent size, not allocated blocks" + + +def test_blob_bytes_present_zero_blocks_is_zero(tmp_path): + """A fully-sparse .incomplete (st_size > 0, 0 allocated blocks) counts as 0 bytes present, not full size.""" + p = tmp_path / "sparse.incomplete" + with open(p, "wb") as f: + f.truncate(8 * 1024 * 1024) # apparent 8 MiB, nothing actually written + st = p.stat() + if getattr(st, "st_blocks", None) is None: + pytest.skip("st_blocks not reported on this platform") + if st.st_blocks != 0: + pytest.skip("filesystem pre-allocated blocks for the sparse file") + assert hcs.blob_bytes_present(p) == 0 + + +def test_custom_cache_dir_is_watched_and_cleaned(tmp_path, monkeypatch): + """A stall under a caller-supplied cache_dir (not HF_HUB_CACHE) is seen by the probe, watchdog, and HTTP-prep purge.""" + default_cache = tmp_path / "default" + custom_cache = tmp_path / "custom" + default_cache.mkdir() + custom_cache.mkdir() + monkeypatch.setattr(hf_constants, "HF_HUB_CACHE", str(default_cache)) + + blobs = custom_cache / f"models--{REPO.replace('/', '--')}" / "blobs" + blobs.mkdir(parents = True) + partial = blobs / "stalled.incomplete" + partial.write_bytes(b"partial-bytes") + + # Default cache sees nothing; the custom cache sees the active partial. 
+ assert xf.get_hf_download_state([REPO]) == (0, False) + total, has_incomplete = xf.get_hf_download_state([REPO], cache_dir = str(custom_cache)) + assert has_incomplete is True and total > 0 + + # The watchdog fires for the custom cache, not the (empty) default one. + calls: list[str] = [] + stop = xf.start_watchdog( + repo_ids = [REPO], on_stall = calls.append, cache_dir = str(custom_cache), + interval = 0.05, stall_timeout = 0.3, + ) + try: + assert _wait(lambda: len(calls) >= 1, timeout = 3.0), "watchdog ignored the custom cache_dir" + finally: + stop.set() + + # HTTP-prep purge removes the unsafe partial from the custom cache (real impl; + # the autouse fixture stubs the attribute). Age past the grace so it reads stalled. + old = time.time() - 600 + os.utime(partial, (old, old)) + _REAL_DEFAULT_PREPARE("model", REPO, cache_dir = str(custom_cache)) + assert not partial.exists() + + +def test_prepare_for_http_clears_broken_snapshot_symlink(tmp_path): + """A broken snapshot symlink reads as incomplete, so HTTP prep must clear it too or the retry re-trips the watchdog.""" + repo = "ztest/broken-symlink" + repo_dir = tmp_path / f"models--{repo.replace('/', '--')}" + snap = repo_dir / "snapshots" / "abc123" + snap.mkdir(parents = True) + link = snap / "model.safetensors" + link.symlink_to(repo_dir / "blobs" / "missing-blob") # dangling + assert link.is_symlink() and not link.exists() + + assert xf.get_hf_download_state([repo], cache_dir = str(tmp_path)) == (0, True) + + _REAL_DEFAULT_PREPARE("model", repo, cache_dir = str(tmp_path)) + + assert not link.is_symlink(), "broken snapshot symlink not cleared by HTTP prep" + assert xf.get_hf_download_state([repo], cache_dir = str(tmp_path)) == (0, False) + + +def test_prepare_for_http_spares_concurrent_sibling_active_symlink(tmp_path): + """HTTP prep spares a sibling's dangling link while its .incomplete partner is being written, but clears our own stale link.""" + repo = "ztest/concurrent" + repo_dir = tmp_path / 
f"models--{repo.replace('/', '--')}" + blobs = repo_dir / "blobs" + snap = repo_dir / "snapshots" / "sha" + blobs.mkdir(parents = True) + snap.mkdir(parents = True) + + # Sibling mid-download: dangling link to a blob with an active .incomplete partner. + active_partner = blobs / "activehash.incomplete" + active_partner.write_bytes(b"active") + sibling_link = snap / "active.safetensors" + sibling_link.symlink_to(blobs / "activehash") + + # Our own stale interrupted link: target blob has no .incomplete partner. + stale_link = snap / "stale.safetensors" + stale_link.symlink_to(blobs / "stalehash") + + _REAL_DEFAULT_PREPARE("model", repo, cache_dir = str(tmp_path), active_grace = 180) + + assert sibling_link.is_symlink(), "active sibling's dangling symlink must be preserved" + assert not stale_link.is_symlink(), "our own stale dangling symlink must still be cleared" + + +def test_snapshot_dir_has_broken_symlinks_unit(tmp_path): + """The new per-snapshot primitive flags a dangling link and is clean otherwise.""" + snap = tmp_path / "snapshots" / "sha" + snap.mkdir(parents = True) + good = snap / "config.json" + good.write_text("{}") + assert hcs.snapshot_dir_has_broken_symlinks(snap) is False + (snap / "model.safetensors").symlink_to(tmp_path / "blobs" / "missing") + assert hcs.snapshot_dir_has_broken_symlinks(snap) is True + + +def test_broken_older_snapshot_detected_when_newer_is_clean(tmp_path): + """The detector inspects every snapshot: an older broken revision reads incomplete even when a newer one is clean.""" + repo = "ztest/two-snaps" + repo_dir = tmp_path / f"models--{repo.replace('/', '--')}" + old = repo_dir / "snapshots" / "oldsha" + new = repo_dir / "snapshots" / "newsha" + old.mkdir(parents = True) + new.mkdir(parents = True) + (old / "model.safetensors").symlink_to(repo_dir / "blobs" / "missing") # broken older + (new / "config.json").write_text("{}") # clean newer + # Make the clean snapshot newest by mtime so a latest-only check would pass. 
+ os.utime(new, (time.time() + 10, time.time() + 10)) + assert xf.get_hf_download_state([repo], cache_dir = str(tmp_path)) == (0, True) + + +def test_snapshot_fast_path_rejects_broken_requested_revision(tmp_path, monkeypatch): + """The fast path validates the EXACT returned dir, so a broken requested revision defers to the killable child.""" + snap = tmp_path / "snapshots" / "oldsha" + snap.mkdir(parents = True) + (snap / "model.safetensors").symlink_to(tmp_path / "blobs" / "missing") # dangling + monkeypatch.setattr(huggingface_hub, "snapshot_download", lambda *a, **k: str(snap)) + # Repo-wide scan sees nothing (empty cache root); only the per-revision check can catch it. + monkeypatch.setattr(hf_constants, "HF_HUB_CACHE", str(tmp_path / "empty-cache")) + fake = _install(monkeypatch, [("ok", "/cache/snap-fresh")]) + out = xf.snapshot_download_with_xet_fallback(DL_REPO, token = None) + assert out == "/cache/snap-fresh", "fast path returned a broken requested revision" + assert len(fake.calls) == 1 + + +def test_prepare_for_http_clears_broken_symlink_in_older_snapshot(tmp_path): + """HTTP prep clears dangling links across all snapshots, not just the newest.""" + repo = "ztest/old-broken" + repo_dir = tmp_path / f"models--{repo.replace('/', '--')}" + old = repo_dir / "snapshots" / "oldsha" + new = repo_dir / "snapshots" / "newsha" + old.mkdir(parents = True) + new.mkdir(parents = True) + link = old / "model.safetensors" + link.symlink_to(repo_dir / "blobs" / "missing") # dangling, older snapshot + (new / "config.json").write_text("{}") + os.utime(new, (time.time() + 10, time.time() + 10)) # newer snapshot is clean + + assert xf.get_hf_download_state([repo], cache_dir = str(tmp_path)) == (0, True) + _REAL_DEFAULT_PREPARE("model", repo, cache_dir = str(tmp_path)) + assert not link.is_symlink(), "broken symlink in older snapshot not cleared" + assert xf.get_hf_download_state([repo], cache_dir = str(tmp_path)) == (0, False) + + +def 
test_prepare_for_http_preserves_case_colliding_repo(tmp_path): + """On a case-sensitive FS, HTTP prep for Org/Repo purges only its exact-case dir, never a case-colliding org/repo.""" + upper = tmp_path / "models--Org--Repo" / "blobs" + lower = tmp_path / "models--org--repo" / "blobs" + upper.mkdir(parents = True) + lower.mkdir(parents = True) + if upper.parent.resolve() == lower.parent.resolve(): + pytest.skip("case-insensitive filesystem; cannot collide cache dirs") + upper_partial = upper / "a.incomplete" + lower_partial = lower / "b.incomplete" + upper_partial.write_bytes(b"x") + lower_partial.write_bytes(b"y") + # Age both past the grace; lower is spared by repo attribution, not mtime. + old = time.time() - 600 + os.utime(upper_partial, (old, old)) + os.utime(lower_partial, (old, old)) + + _REAL_DEFAULT_PREPARE("model", "Org/Repo", cache_dir = str(tmp_path)) + + assert not upper_partial.exists(), "exact-case partial should be purged" + assert lower_partial.exists(), "case-colliding repo's partial must be preserved" + + +def test_repo_type_none_resolves_model_cache(hf_cache): + """repo_type=None (HF default model) resolves the models-- dir, not a bogus Nones-- dir.""" + blobs = _blobs_dir(hf_cache) + (blobs / "x.incomplete").write_bytes(b"abc") + + model_state = xf.get_hf_download_state([REPO], repo_type = "model") + none_state = xf.get_hf_download_state([REPO], repo_type = None) + assert model_state == none_state + assert none_state[1] is True and none_state[0] > 0 + + +def test_state_ignores_case_colliding_repo_partial(tmp_path, monkeypatch): + """The read/watchdog path attributes a partial only to an exact-case repo dir.""" + monkeypatch.setattr(hf_constants, "HF_HUB_CACHE", str(tmp_path)) + exact = tmp_path / "models--Org--Repo" / "blobs" + other = tmp_path / "models--org--repo" / "blobs" + exact.mkdir(parents = True) + other.mkdir(parents = True) + if exact.parent.resolve() == other.parent.resolve(): + pytest.skip("case-insensitive filesystem; cannot 
collide cache dirs") + (other / "stale.incomplete").write_bytes(b"x") # only the lowercase repo + + assert xf.get_hf_download_state(["Org/Repo"]) == (0, False) + + +def test_single_folded_match_rejected_on_case_sensitive_fs(tmp_path, monkeypatch): + """On a case-sensitive FS a folded-but-not-exact dir is a different repo, so its partial is not attributed here.""" + monkeypatch.setattr(hf_constants, "HF_HUB_CACHE", str(tmp_path)) + lower = tmp_path / "models--org--repo" / "blobs" + lower.mkdir(parents = True) + if (tmp_path / "models--Org--Repo").exists(): + pytest.skip("case-insensitive filesystem; the folded dir is the same entry") + (lower / "stale.incomplete").write_bytes(b"x") # only the lowercase repo exists + assert xf.get_hf_download_state(["Org/Repo"]) == (0, False) + + +def test_cache_dir_is_expanded(tmp_path, monkeypatch): + """A cache_dir with ~ is expanded (as HF does on write), else the probe scans the literal '~/...' path.""" + monkeypatch.setenv("HOME", str(tmp_path)) + monkeypatch.setenv("USERPROFILE", str(tmp_path)) # Windows home var + blobs = tmp_path / "hfcache" / f"models--{REPO.replace('/', '--')}" / "blobs" + blobs.mkdir(parents = True) + (blobs / "p.incomplete").write_bytes(b"abc") + + total, has_incomplete = xf.get_hf_download_state([REPO], cache_dir = "~/hfcache") + assert has_incomplete is True and total > 0 + + +def test_status_callback_failure_does_not_kill_watchdog(hf_cache): + """A raising on_heartbeat must not stop the watchdog from detecting a stall and firing on_stall.""" + blobs = _blobs_dir(hf_cache) + (blobs / "x.incomplete").write_bytes(b"\0" * 1024) # constant size -> stalls + + def boom(_message): + raise RuntimeError("client disconnected") + + calls: list[str] = [] + stop = xf.start_watchdog( + repo_ids = [REPO], on_stall = calls.append, on_heartbeat = boom, + interval = 0.05, stall_timeout = 0.3, + ) + try: + assert _wait( + lambda: len(calls) >= 1, timeout = 3.0 + ), "a raising on_heartbeat killed stall detection" + 
finally: + stop.set() + + +# Transport policy: cached short-circuit, cancel, error propagation, the single +# Xet->HTTP fallback, injected prepare seam, and UNSLOTH_DISABLE_XET. The download +# seam (_run_download_attempt) is faked, so no real spawn. +DL_REPO, FILE = "ztest/xet-dl", "model-Q4_K_XL.gguf" + + +@pytest.fixture(autouse = True) +def _no_real_cache_hit(monkeypatch): + """Default: cache probe and snapshot fast path miss, so tests exercise the download seam.""" + monkeypatch.setattr(huggingface_hub, "try_to_load_from_cache", lambda *a, **k: None) + + def _snap_miss(*a, **k): + raise FileNotFoundError("not fully cached") + + monkeypatch.setattr(huggingface_hub, "snapshot_download", _snap_miss) + # Neutralize the generic cache purge; tests that care record it. + monkeypatch.setattr(xf, "_default_prepare_for_http", lambda *a, **k: None) + # No env knob unless a test sets it. + monkeypatch.delenv("UNSLOTH_DISABLE_XET", raising = False) + monkeypatch.delenv("UNSLOTH_STABLE_DOWNLOADS", raising = False) + monkeypatch.delenv("HF_HUB_DISABLE_XET", raising = False) + + +class _FakeAttempt: + """Records download-seam calls and returns scripted results (matches _run_download_attempt's signature).""" + + def __init__(self, results): + self._results = list(results) + self.calls = [] + + def __call__( + self, + repo_id, + *, + kind, + params, + token, + repo_type, + disable_xet, + cancel_event, + stall_timeout, + interval, + grace_period, + on_status, + ): + self.calls.append( + _types.SimpleNamespace( + repo_id = repo_id, + kind = kind, + target = params.get("filename", repo_id), + cache_dir = params.get("cache_dir"), + subfolder = params.get("subfolder"), + force_download = params.get("force_download"), + disable_xet = disable_xet, + repo_type = repo_type, + ) + ) + return self._results[len(self.calls) - 1] + + +def _install(monkeypatch, results): + fake = _FakeAttempt(results) + monkeypatch.setattr(xf, "_run_download_attempt", fake) + return fake + + +def 
test_cached_file_short_circuits(monkeypatch, tmp_path): + cached = tmp_path / "cached.gguf" + cached.write_bytes(b"\0" * 8) + monkeypatch.setattr(huggingface_hub, "try_to_load_from_cache", lambda *a, **k: str(cached)) + fake = _install(monkeypatch, []) + + out = xf.hf_hub_download_with_xet_fallback(DL_REPO, FILE, None) + assert out == str(cached) + assert fake.calls == [], "spawned a download for an already-cached file" + + +def test_cancel_before_start_raises_no_attempt(monkeypatch): + fake = _install(monkeypatch, []) + ev = threading.Event() + ev.set() + with pytest.raises(RuntimeError, match = "Cancelled"): + xf.hf_hub_download_with_xet_fallback(DL_REPO, FILE, None, cancel_event = ev) + assert fake.calls == [] + + +def test_cancel_honored_even_when_file_cached(monkeypatch, tmp_path): + """A pre-set cancel_event raises even when the file is already cached (the warm-cache short-circuit honors cancel first).""" + cached = tmp_path / "cached.gguf" + cached.write_bytes(b"\0" * 8) + monkeypatch.setattr(huggingface_hub, "try_to_load_from_cache", lambda *a, **k: str(cached)) + fake = _install(monkeypatch, []) + ev = threading.Event() + ev.set() + with pytest.raises(RuntimeError, match = "Cancelled"): + xf.hf_hub_download_with_xet_fallback(DL_REPO, FILE, None, cancel_event = ev) + assert fake.calls == [] + + +def test_snapshot_cancel_honored_even_when_cached(monkeypatch, tmp_path): + """The snapshot wrapper honors a pre-set cancel before its warm-cache short-circuit.""" + snap = tmp_path / "snap" + snap.mkdir() + (snap / "model.safetensors").write_bytes(b"x") + monkeypatch.setattr(huggingface_hub, "snapshot_download", lambda *a, **k: str(snap)) + fake = _install(monkeypatch, []) + ev = threading.Event() + ev.set() + with pytest.raises(RuntimeError, match = "Cancelled"): + xf.snapshot_download_with_xet_fallback(DL_REPO, token = None, cancel_event = ev) + assert fake.calls == [] + + +def test_nonstall_error_propagates_without_fallback(monkeypatch): + fake = 
_install(monkeypatch, [("error", "RepositoryNotFoundError: 404 not found")]) + # Deterministic Hub error re-raised with its original type across the spawn boundary, + # reconstructed from the child's "ExceptionClassName: ..." prefix. + expected_cls = xf._resolve_exception_class("RepositoryNotFoundError") + assert expected_cls is not None and expected_cls is not RuntimeError + with pytest.raises(expected_cls, match = "RepositoryNotFoundError"): + xf.hf_hub_download_with_xet_fallback(DL_REPO, FILE, None) + assert len(fake.calls) == 1, "deterministic error must not trigger an HTTP fallback" + assert fake.calls[0].disable_xet is False + + +def test_crashed_child_retries_over_http(monkeypatch): + """A silent process-level crash (child exits without a result) retries over HTTP; a clean second result is accepted.""" + fake = _install(monkeypatch, [("crashed", "exited without a result"), ("ok", "/cache/x")]) + out = xf.hf_hub_download_with_xet_fallback(DL_REPO, FILE, None) + assert out == "/cache/x" + assert [c.disable_xet for c in fake.calls] == [False, True] + + +def test_crashed_child_on_both_transports_raises(monkeypatch): + """If the child crashes on Xet AND on HTTP, surface a hard error after both attempts.""" + fake = _install(monkeypatch, [("crashed", "boom"), ("crashed", "boom")]) + with pytest.raises(RuntimeError, match = "boom"): + xf.hf_hub_download_with_xet_fallback(DL_REPO, FILE, None) + assert [c.disable_xet for c in fake.calls] == [False, True] + + +def test_retryable_xet_error_retries_over_http(monkeypatch): + """A transient Xet failure (CAS timeout / 5xx) retries over HTTP; a clean HTTP result is accepted.""" + fake = _install(monkeypatch, [("retryable_error", "CasClientError: request timed out"), ("ok", "/cache/x")]) + out = xf.hf_hub_download_with_xet_fallback(DL_REPO, FILE, None) + assert out == "/cache/x" + assert [c.disable_xet for c in fake.calls] == [False, True] + + +def test_retryable_xet_error_on_both_transports_raises(monkeypatch): + """A transient error on both 
transports surfaces after both attempts rather than looping.""" + fake = _install(monkeypatch, [("retryable_error", "503 Server Error"), ("retryable_error", "503 Server Error")]) + with pytest.raises(RuntimeError, match = "503"): + xf.hf_hub_download_with_xet_fallback(DL_REPO, FILE, None) + assert [c.disable_xet for c in fake.calls] == [False, True] + + +def test_is_retryable_download_error_classification(): + """Transient transport failures (hf_xet/CAS, timeout, reset, 5xx/429) are retryable; deterministic Hub/OS and unknown errors are not.""" + f = xf._is_retryable_download_error + + # Transient transport failures -> retryable. + assert f(Exception("hf_xet download failed: data processing error")) is True + assert f(TimeoutError("connection reset by peer")) is True + assert f(Exception("CasClientError: deadline exceeded")) is True + + # Status carried on a nested ``response`` object (response.status_code). + class _Resp503(Exception): + response = type("R", (), {"status_code": 503})() + + assert f(_Resp503("server error")) is True + + # Status carried directly on the exception (status_code attribute). + class _Status429(Exception): + status_code = 429 + + assert f(_Status429("Too Many Requests")) is True + + class _Status408(Exception): + status_code = 408 + + assert f(_Status408("Request Timeout")) is True # 408 is transient + + # Deterministic Hub errors -> not retryable (class name or 4xx status).
+ class _Status416(Exception): + status_code = 416 + + assert f(_Status416("Range Not Satisfiable")) is False + # Local stand-in that only shares the Hub error's class name; it carries no status attribute. + class RepositoryNotFoundError(Exception): + pass + + assert f(RepositoryNotFoundError("404 Client Error")) is False + + class _Resp404(Exception): + response = type("R", (), {"status_code": 404})() + + assert f(_Resp404("not found")) is False + assert f(OSError(errno.ENOSPC, "No space left on device")) is False + assert f(ValueError("unexpected response payload")) is False # unknown -> deterministic + + +def test_local_entry_not_found_transient_is_retryable(): + """A transient LocalEntryNotFoundError (HEAD connection error/timeout) is retryable; a genuine offline miss stays deterministic and type-preserved.""" + f = xf._is_retryable_download_error + + # Local stand-in named like the Hub error; it has no status attribute, so only the name and message differ between cases. + class LocalEntryNotFoundError(Exception): + pass + + # Transient connection wrapper -> retryable. + transient = LocalEntryNotFoundError( + "An error happened while trying to locate the file on the Hub and we cannot find the " + "requested files in the local cache. Please check your connection and try again." + ) + assert f(transient) is True + timed_out = LocalEntryNotFoundError("Read timed out while fetching metadata") + assert f(timed_out) is True + # Genuine offline miss (no transient hint) -> deterministic, type-preserved. + offline = LocalEntryNotFoundError( + "Cannot find the requested files in the disk cache and outgoing traffic has been disabled."
+ ) + assert f(offline) is False + assert "LocalEntryNotFoundError" in xf._DETERMINISTIC_ERROR_NAMES + cls = xf._resolve_exception_class("LocalEntryNotFoundError") + assert cls is not None and issubclass(cls, BaseException) + + +def test_immediate_success_uses_xet_only(monkeypatch): + """A first-try success makes exactly one attempt with Xet enabled and runs no HTTP cache prep.""" + prepared = [] + monkeypatch.setattr(xf, "_default_prepare_for_http", lambda *a, **k: prepared.append(a)) + fake = _install(monkeypatch, [("ok", "/cache/model.gguf")]) + out = xf.hf_hub_download_with_xet_fallback(DL_REPO, FILE, None) + assert out == "/cache/model.gguf" + assert len(fake.calls) == 1 and fake.calls[0].disable_xet is False + assert prepared == [], "no cache prep should run when Xet succeeds first try" + + +def test_stall_then_http_fallback_succeeds(monkeypatch): + """A stalled first attempt is followed by one HTTP attempt (disable_xet=True), with the cache-prep hook called for the repo.""" + prepared = [] + monkeypatch.setattr(xf, "_default_prepare_for_http", lambda repo_type, repo_id, cache_dir = None, **k: prepared.append((repo_type, repo_id))) + fake = _install(monkeypatch, [("stall", None), ("ok", "/cache/model.gguf")]) + + out = xf.hf_hub_download_with_xet_fallback(DL_REPO, FILE, None) + assert out == "/cache/model.gguf" + assert len(fake.calls) == 2 + assert fake.calls[0].disable_xet is False # Xet first + assert fake.calls[1].disable_xet is True # HTTP fallback + assert prepared == [("model", DL_REPO)], "must prep cache for HTTP before the retry" + + +def test_injected_prepare_for_http_used(monkeypatch): + """An injected prepare_for_http_fn is used; the generic default must not run.""" + monkeypatch.setattr( + xf, "_default_prepare_for_http", lambda *a, **k: pytest.fail("generic prepare ran") + ) + injected = [] + _install(monkeypatch, [("stall", None), ("ok", "/cache/model.gguf")]) + out = xf.hf_hub_download_with_xet_fallback( + DL_REPO, FILE, None, prepare_for_http_fn = lambda rt, rid: injected.append((rt, rid)) + ) + assert out == "/cache/model.gguf" + assert injected == [("model", DL_REPO)] + + +def test_second_stall_raises_download_stall_error(monkeypatch): + """Stalls on both attempts raise DownloadStallError after exactly two attempts.""" + fake = _install(monkeypatch,
[("stall", None), ("stall", None)]) + with pytest.raises(xf.DownloadStallError): + xf.hf_hub_download_with_xet_fallback(DL_REPO, FILE, None) + assert len(fake.calls) == 2 + + +def test_cancelled_midattempt_raises_no_fallback(monkeypatch): + """A cancelled attempt raises RuntimeError matching 'Cancelled' and is not retried (one attempt only).""" + fake = _install(monkeypatch, [("cancelled", None)]) + with pytest.raises(RuntimeError, match = "Cancelled"): + xf.hf_hub_download_with_xet_fallback(DL_REPO, FILE, None) + assert len(fake.calls) == 1 + + +def test_per_file_independent_fallback(monkeypatch): + """A stalled shard falls back; a sibling shard that succeeds does not.""" + fake = _install(monkeypatch, [("ok", "/a"), ("stall", None), ("ok", "/b")]) + assert xf.hf_hub_download_with_xet_fallback(DL_REPO, "shardA.gguf", None) == "/a" + assert xf.hf_hub_download_with_xet_fallback(DL_REPO, "shardB.gguf", None) == "/b" + assert [c.disable_xet for c in fake.calls] == [False, False, True] + + +def test_unsloth_disable_xet_forces_http_first(monkeypatch): + """UNSLOTH_DISABLE_XET=1 skips Xet: the first (and only) attempt is HTTP.""" + monkeypatch.setenv("UNSLOTH_DISABLE_XET", "1") + fake = _install(monkeypatch, [("ok", "/http/model.gguf")]) + out = xf.hf_hub_download_with_xet_fallback(DL_REPO, FILE, None) + assert out == "/http/model.gguf" + assert len(fake.calls) == 1 and fake.calls[0].disable_xet is True + + +def test_unsloth_disable_xet_stall_raises_no_retry(monkeypatch): + """With the knob set, a stall on the (already HTTP) attempt does not retry.""" + monkeypatch.setenv("UNSLOTH_DISABLE_XET", "1") + fake = _install(monkeypatch, [("stall", None)]) + with pytest.raises(xf.DownloadStallError): + xf.hf_hub_download_with_xet_fallback(DL_REPO, FILE, None) + assert len(fake.calls) == 1 + + +def test_file_path_accepts_cache_dir(monkeypatch): + """The single-file wrapper accepts cache_dir and threads it through.""" + fake = _install(monkeypatch, [("ok", "/cache/model.gguf")]) + out = xf.hf_hub_download_with_xet_fallback(DL_REPO, FILE, None, cache_dir = "/custom/cache") + assert out
== "/cache/model.gguf" + assert fake.calls[0].cache_dir == "/custom/cache" + + +# Spawn env-timing: the parent sets HF_HUB_DISABLE_XET before the child starts, so +# the child inherits it before re-importing huggingface_hub (constants cache it at +# import). Uses a fake spawn context -- no real subprocess. +class _FakeProc: + """Process stand-in: start() records the parent's env vars and __main__.__file__ into the recorder; no target runs and it reports as already exited (exitcode 0).""" + def __init__(self, recorder): + self._rec = recorder + self.pid = 4242 + self.exitcode = 0 + + def start(self): + self._rec["disable_xet"] = os.environ.get("HF_HUB_DISABLE_XET") + self._rec["hf_transfer"] = os.environ.get("HF_HUB_ENABLE_HF_TRANSFER") + self._rec["skip_gpu_init"] = os.environ.get("UNSLOTH_ZOO_DISABLE_GPU_INIT") + self._rec["main_file"] = getattr(sys.modules.get("__main__"), "__file__", None) + + def is_alive(self): + return False + + def join(self, timeout = None): + pass + + +class _FakeQueue: + """Queue stand-in whose get()/get_nowait() always return the preset result; put() is a no-op.""" + def __init__(self, result): + self._result = result + + def get(self, timeout = None): + return self._result + + def get_nowait(self): + return self._result + + def put(self, item): + pass + + +class _FakeCtx: + """Spawn-context stand-in that hands out _FakeProc / _FakeQueue in place of a real process and queue.""" + def __init__(self, recorder, result): + self._rec = recorder + self._result = result + + def Process(self, *, target = None, kwargs = None, daemon = None): + return _FakeProc(self._rec) + + def Queue(self): + return _FakeQueue(self._result) + + +def test_http_retry_sets_disable_xet_before_spawn(monkeypatch): + """With disable_xet=True the child sees HF_HUB_DISABLE_XET=1 and HF_HUB_ENABLE_HF_TRANSFER=0 at start(), and the parent's HF_HUB_DISABLE_XET is unset again afterwards.""" + monkeypatch.delenv("HF_HUB_DISABLE_XET", raising = False) + monkeypatch.delenv("HF_HUB_ENABLE_HF_TRANSFER", raising = False) + rec: dict = {} + monkeypatch.setattr(xf, "_CTX", _FakeCtx(rec, {"ok": True, "path": "/cache/x"})) + + kind_result, payload = xf._run_download_attempt( + DL_REPO, kind = "snapshot", params = {"repo_id": DL_REPO}, token = None, + repo_type = "model", disable_xet = True, cancel_event = None, + stall_timeout = 0.2, interval = 0.05, grace_period = 0.2, on_status = None, + ) + assert (kind_result, payload) == ("ok", "/cache/x") + # Child inherited HTTP transport env at spawn time.
+ assert rec["disable_xet"] == "1" + assert rec["hf_transfer"] == "0" + # Parent env restored afterwards (was unset). + assert "HF_HUB_DISABLE_XET" not in os.environ + + +def test_xet_attempt_does_not_force_disable_before_spawn(monkeypatch): + """With disable_xet=False and HF_HUB_DISABLE_XET unset in the parent, the variable is still unset when the child starts.""" + monkeypatch.delenv("HF_HUB_DISABLE_XET", raising = False) + rec: dict = {} + monkeypatch.setattr(xf, "_CTX", _FakeCtx(rec, {"ok": True, "path": "/cache/x"})) + xf._run_download_attempt( + DL_REPO, kind = "snapshot", params = {"repo_id": DL_REPO}, token = None, + repo_type = "model", disable_xet = False, cancel_event = None, + stall_timeout = 0.2, interval = 0.05, grace_period = 0.2, on_status = None, + ) + # On the Xet-first attempt, do not force-disable Xet for the child. + assert rec["disable_xet"] is None + + +class _EmptyQueue: + """Queue stand-in that never holds a result: get()/get_nowait() raise queue.Empty; put() is a no-op.""" + def get(self, timeout = None): + import queue as _queue + raise _queue.Empty + + def get_nowait(self): + import queue as _queue + raise _queue.Empty + + def put(self, item): + pass + + +def test_run_attempt_no_result_is_crashed(monkeypatch): + """A child that exits without enqueuing a result maps to 'crashed' (HTTP may recover), not a deterministic 'error'.""" + rec: dict = {} + + class _Ctx: + def Process(self, *, target = None, kwargs = None, daemon = None): + return _FakeProc(rec) + + def Queue(self): + return _EmptyQueue() + + monkeypatch.setattr(xf, "_CTX", _Ctx()) + kind_result, _ = xf._run_download_attempt( + DL_REPO, kind = "snapshot", params = {"repo_id": DL_REPO}, token = None, + repo_type = "model", disable_xet = False, cancel_event = None, + stall_timeout = 0.2, interval = 0.05, grace_period = 0.2, on_status = None, + ) + assert kind_result == "crashed" + + +def test_child_skips_gpu_init_env_set_before_spawn_and_restored(monkeypatch): + """The child inherits UNSLOTH_ZOO_DISABLE_GPU_INIT=1 at spawn (skips heavy torch init); the parent env is restored after.""" + monkeypatch.delenv("UNSLOTH_ZOO_DISABLE_GPU_INIT", raising = False) + rec: dict = {} + monkeypatch.setattr(xf, "_CTX",
_FakeCtx(rec, {"ok": True, "path": "/cache/x"})) + + xf._run_download_attempt( + DL_REPO, kind = "snapshot", params = {"repo_id": DL_REPO}, token = None, + repo_type = "model", disable_xet = False, cancel_event = None, + stall_timeout = 0.2, interval = 0.05, grace_period = 0.2, on_status = None, + ) + assert rec["skip_gpu_init"] == "1" # set in the parent before proc.start() + assert "UNSLOTH_ZOO_DISABLE_GPU_INIT" not in os.environ # restored after + + +def test_spawn_repoints_main_file_and_restores(monkeypatch): + """The parent repoints __main__.__file__ to the hf_xet_fallback module's own file (xf.__file__) at spawn (so an unguarded caller is not re-executed) and restores it.""" + main_mod = sys.modules["__main__"] + # raising = False: do not fail if __main__ has no __file__ attribute to replace. + monkeypatch.setattr(main_mod, "__file__", "/fake/user_script.py", raising = False) + rec: dict = {} + monkeypatch.setattr(xf, "_CTX", _FakeCtx(rec, {"ok": True, "path": "/cache/x"})) + + xf._run_download_attempt( + DL_REPO, kind = "snapshot", params = {"repo_id": DL_REPO}, token = None, + repo_type = "model", disable_xet = False, cancel_event = None, + stall_timeout = 0.2, interval = 0.05, grace_period = 0.2, on_status = None, + ) + assert rec["main_file"] == xf.__file__ # child imports the helper, not the script + assert main_mod.__file__ == "/fake/user_script.py" # restored in the parent + + +def test_scrub_secrets_handles_boolean_token(): + """token=True must not crash the child error scrubber.""" + out = xf._default_scrub_secrets("auth failed for hf_abcdefghij", hf_token = True) + assert "hf_abcdefghij" not in out and "***" in out + + +def test_scrub_redacts_presigned_url(): + """The scrubber redacts a presigned S3/CAS URL's credential query string, leaving non-signed URLs intact.""" + url = ( + "https://cas-bridge.xethub.hf.co/xet-bridge-us/abc/def" + "?X-Amz-Signature=deadbeefcafe&X-Amz-Credential=AKIAEXAMPLE123" + ) + out = xf._default_scrub_secrets(f"403 Client Error for url: {url}") + assert "X-Amz-Signature" not in out + assert "deadbeefcafe" not in out and "AKIAEXAMPLE123" not
in out + assert "cas-bridge.xethub.hf.co/xet-bridge-us/abc/def?***" in out + # A non-signed URL keeps its (harmless) query string. + plain = xf._default_scrub_secrets("see https://huggingface.co/org/repo?download=true now") + assert "download=true" in plain + + +def test_scrub_redaction_preserves_surrounding_delimiters(): + """Query redaction stops at the closing delimiter of an embedded signed URL (does not swallow the ``"}``).""" + embedded = ( + '{"error": "403", "url": ' + '"https://cas-bridge.xethub.hf.co/x/y?X-Amz-Signature=deadbeef&X-Amz-Expires=3600"}' + ) + out = xf._default_scrub_secrets(embedded) + assert "deadbeef" not in out # signed query redacted + assert "cas-bridge.xethub.hf.co/x/y?***" in out + assert out.endswith('"}') # closing delimiters preserved + # A signed URL wrapped in parentheses keeps the closing paren and the trailing text too. + wrapped = "(https://s3.amazonaws.com/b/k?X-Amz-Signature=abc123) tail" + out2 = xf._default_scrub_secrets(wrapped) + assert "abc123" not in out2 and "?***)" in out2 and out2.endswith(") tail") + + +def test_local_files_only_file_resolves_in_process(monkeypatch): + """local_files_only resolves the file from cache in-process and never spawns a network child.""" + seen = {} + + def _dl(*a, **k): + seen.update(k) + return "/cache/file.gguf" + + monkeypatch.setattr(huggingface_hub, "hf_hub_download", _dl) + fake = _install(monkeypatch, []) + out = xf.hf_hub_download_with_xet_fallback(DL_REPO, FILE, None, local_files_only = True) + assert out == "/cache/file.gguf" + assert seen.get("local_files_only") is True + assert fake.calls == [], "local_files_only must not spawn a download child" + + +def test_local_files_only_snapshot_resolves_in_process(monkeypatch): + """Snapshot variant: local_files_only resolves through huggingface_hub.snapshot_download in-process and never spawns a download child.""" + seen = {} + + def _snap(*a, **k): + seen.update(k) + return "/cache/snap" + + monkeypatch.setattr(huggingface_hub, "snapshot_download", _snap) + fake = _install(monkeypatch, []) + out = xf.snapshot_download_with_xet_fallback(DL_REPO, token = None, local_files_only = True)
+ assert out == "/cache/snap" + assert seen.get("local_files_only") is True + assert fake.calls == [], "local_files_only must not spawn a download child" + + +def test_file_probe_uses_expanded_cache_dir(monkeypatch, tmp_path): + """The single-file probe uses the expanded cache_dir, else a finalized file under ~/hfcache is missed.""" + monkeypatch.setenv("HOME", str(tmp_path)) + monkeypatch.setenv("USERPROFILE", str(tmp_path)) # Windows home var + seen = {} + + def _probe(repo_id, filename, *, repo_type, revision, cache_dir): + seen["cache_dir"] = cache_dir + return None # not cached -> falls through to the faked seam + + monkeypatch.setattr(huggingface_hub, "try_to_load_from_cache", _probe) + fake = _install(monkeypatch, [("ok", "/cache/x")]) + xf.hf_hub_download_with_xet_fallback(DL_REPO, FILE, None, cache_dir = "~/hfcache") + assert seen["cache_dir"] == str(tmp_path / "hfcache") + assert fake.calls[0].cache_dir == str(tmp_path / "hfcache") + + +def test_pathlib_cache_dir_is_expanded(monkeypatch, tmp_path): + """A pathlib.Path cache_dir with ~ is normalized too, else the child writes under '~/...'
and the stall is undetected.""" + monkeypatch.setenv("HOME", str(tmp_path)) + monkeypatch.setenv("USERPROFILE", str(tmp_path)) + fake = _install(monkeypatch, [("ok", "/cache/snap")]) + xf.snapshot_download_with_xet_fallback( + DL_REPO, token = None, cache_dir = Path("~/hfcache") + ) + assert fake.calls[0].cache_dir == str(tmp_path / "hfcache") + + +def test_subfolder_forwarded_to_file_download(monkeypatch): + """subfolder is forwarded into the download params and the probe uses the combined 'subfolder/filename' path.""" + probed = {} + + def _probe(repo_id, filename, *, repo_type, revision, cache_dir): + probed["filename"] = filename + return None # not cached -> falls through to the faked attempt + + monkeypatch.setattr(huggingface_hub, "try_to_load_from_cache", _probe) + fake = _install(monkeypatch, [("ok", "/cache/x")]) + out = xf.hf_hub_download_with_xet_fallback( + DL_REPO, FILE, None, subfolder = "checkpoint-10" + ) + assert out == "/cache/x" + assert probed["filename"] == f"checkpoint-10/{FILE}" # probe uses combined path + assert fake.calls[0].subfolder == "checkpoint-10" # forwarded to the child + + +def test_file_download_defaults_token_to_none(monkeypatch): + """The single-file helper accepts no token (parity with hf_hub_download).""" + fake = _install(monkeypatch, [("ok", "/cache/x")]) + out = xf.hf_hub_download_with_xet_fallback(DL_REPO, FILE) # no token arg + assert out == "/cache/x" and len(fake.calls) == 1 + + +def test_unrelated_partial_does_not_block_clean_cached_snapshot(hf_cache, monkeypatch): + """A clean requested snapshot short-circuits in-process even with a stale unrelated .incomplete: the fast path validates only the returned dir.""" + blobs = _blobs_dir(hf_cache, DL_REPO) + repo_dir = blobs.parent + snap = repo_dir / "snapshots" / "goodsha" + snap.mkdir(parents = True) + good = blobs / "good" + good.write_bytes(b"weights") + (snap / "model.safetensors").symlink_to(good) # resolves -> snapshot clean + (blobs / "other.incomplete").write_bytes(b"abc") #
unrelated stale partial + # Stub HF's snapshot_download so the in-process lookup returns the prepared snapshot dir. + monkeypatch.setattr(huggingface_hub, "snapshot_download", lambda *a, **k: str(snap)) + fake = _install(monkeypatch, []) # must NOT spawn a child + out = xf.snapshot_download_with_xet_fallback(DL_REPO, token = None) + assert out == str(snap), "clean cached snapshot rejected by an unrelated partial" + assert fake.calls == [], "spawned a download despite a clean requested snapshot" + + +def test_retry_status_failure_does_not_abort_fallback(monkeypatch): + """A raising on_status during the Xet->HTTP retry must not abort the fallback.""" + fake = _install(monkeypatch, [("stall", None), ("ok", "/cache/x")]) + + # Status callback that always raises, standing in for a client that went away. + def boom(_message): + raise RuntimeError("client gone") + + out = xf.snapshot_download_with_xet_fallback(DL_REPO, token = None, on_status = boom) + assert out == "/cache/x" + assert [c.disable_xet for c in fake.calls] == [False, True] + + +def test_unclearable_partial_forces_clean_redownload(hf_cache, monkeypatch): + """When prep cannot clear an unsafe partial, the HTTP attempt forces a clean re-download rather than resume over it.""" + # The autouse fixture makes _default_prepare_for_http a no-op (partial left in place). + (_blobs_dir(hf_cache, DL_REPO) / "x.incomplete").write_bytes(b"abc") + fake = _install(monkeypatch, [("stall", None), ("ok", "/cache/x")]) + out = xf.snapshot_download_with_xet_fallback(DL_REPO, token = None) + assert out == "/cache/x" + assert fake.calls[0].force_download is False # Xet attempt: not forced + assert fake.calls[1].force_download is True # HTTP attempt: forced clean + + +# Snapshot variant: in-process fast path on a warm cache, else watched download.
+def test_snapshot_fast_path_no_child(hf_cache, monkeypatch): + """A fully cached repo (weights present) resolves in-process via local_files_only, no child attempt.""" + blobs = _blobs_dir(hf_cache, DL_REPO) + snap = blobs.parent / "snapshots" / "sha" + snap.mkdir(parents = True) + weight = blobs / "w" + weight.write_bytes(b"\0" * 16) + (snap / "model.safetensors").symlink_to(weight) # weight present -> complete + (snap / "config.json").write_text("{}") + seen = {} + + # Capture the local_files_only flag the wrapper passes down to HF. + def _snap(*a, **k): + seen["local_files_only"] = k.get("local_files_only") + return str(snap) + + monkeypatch.setattr(huggingface_hub, "snapshot_download", _snap) + fake = _install(monkeypatch, []) # must not be called + out = xf.snapshot_download_with_xet_fallback(DL_REPO, token = None) + assert out == str(snap) + assert seen["local_files_only"] is True + assert fake.calls == [], "spawned a download for an already-cached snapshot" + + +def test_snapshot_dir_is_complete_unit(tmp_path): + """Weight presence drives completeness: config-only is incomplete, a resolvable weight is complete.""" + snap = tmp_path / "snap" + snap.mkdir() + (snap / "config.json").write_text("{}") + assert hcs.snapshot_dir_is_complete(snap) is False # no weights + blob = tmp_path / "blob" + blob.write_bytes(b"weights") + (snap / "model.safetensors").symlink_to(blob) + assert hcs.snapshot_dir_is_complete(snap) is True + + +def test_snapshot_dir_is_complete_broken_symlink(tmp_path): + """A dangling weight symlink reads as incomplete.""" + snap = tmp_path / "snap" + snap.mkdir() + (snap / "model.safetensors").symlink_to(tmp_path / "missing") + assert hcs.snapshot_dir_is_complete(snap) is False + + +def test_snapshot_dir_is_complete_missing_shard(tmp_path): + """A shard index whose shards are not all on disk reads as incomplete until they are.""" + snap = tmp_path / "snap" + snap.mkdir() + blob = tmp_path / "blob" + blob.write_bytes(b"x") + # Only shard 1 of 2 is linked before the index is written. + (snap / "model-00001-of-00002.safetensors").symlink_to(blob) + (snap /
"model.safetensors.index.json").write_text( + json.dumps( + { + "weight_map": { + "a": "model-00001-of-00002.safetensors", + "b": "model-00002-of-00002.safetensors", + } + } + ) + ) + assert hcs.snapshot_dir_is_complete(snap) is False # shard 2 missing + (snap / "model-00002-of-00002.safetensors").symlink_to(blob) + assert hcs.snapshot_dir_is_complete(snap) is True + + +def test_snapshot_dir_is_complete_missing_shard_without_index(tmp_path): + """A multi-shard pull is incomplete until the index sidecar is present, even with every shard on disk (transformers needs the index to load).""" + snap = tmp_path / "snap" + snap.mkdir() + blob = tmp_path / "blob" + blob.write_bytes(b"x") + (snap / "model-00001-of-00003.safetensors").symlink_to(blob) + assert hcs.snapshot_dir_is_complete(snap) is False # shards 2 and 3 missing + (snap / "model-00002-of-00003.safetensors").symlink_to(blob) + assert hcs.snapshot_dir_is_complete(snap) is False # shard 3 still missing + (snap / "model-00003-of-00003.safetensors").symlink_to(blob) + assert hcs.snapshot_dir_is_complete(snap) is False # all shards present, no index + (snap / "model.safetensors.index.json").write_text( + json.dumps( + { + "weight_map": { + "a": "model-00001-of-00003.safetensors", + "b": "model-00002-of-00003.safetensors", + "c": "model-00003-of-00003.safetensors", + } + } + ) + ) + assert hcs.snapshot_dir_is_complete(snap) is True # index present -> loadable + + +def test_snapshot_dir_is_complete_ignores_trainer_artifacts(tmp_path): + """Trainer/optimizer state files (.bin/.pt/.pth) are not loadable weights, so a cache holding only those reads incomplete.""" + snap = tmp_path / "snap" + snap.mkdir() + blob = tmp_path / "blob" + blob.write_bytes(b"x") + # Trainer/optimizer side files using .bin/.pt/.pth extensions. + for junk in ( + "training_args.bin", "optimizer.pt", "scheduler.pt", + "rng_state.pth", "rng_state_0.pth", "scaler.pt", + ): + (snap / junk).symlink_to(blob) + (snap / "config.json").write_text("{}") + assert hcs.snapshot_dir_is_complete(snap) is False # only trainer
state, no weights + (snap / "model.safetensors").symlink_to(blob) + assert hcs.snapshot_dir_is_complete(snap) is True # real weight present + + +def test_snapshot_dir_is_complete_checkpoint_index_does_not_gate_root(tmp_path): + """A per-checkpoint shard index with missing shards does not fail an unpatterned root warm (the root weights are what loads).""" + snap = tmp_path / "snap" + (snap / "checkpoint-7").mkdir(parents = True) + blob = tmp_path / "blob" + blob.write_bytes(b"x") + (snap / "model.safetensors").symlink_to(blob) # complete root weight + # Incomplete checkpoint shard index (shard 2 missing) under checkpoint-7/. + (snap / "checkpoint-7" / "model-00001-of-00002.safetensors").symlink_to(blob) + (snap / "checkpoint-7" / "model.safetensors.index.json").write_text( + json.dumps({"weight_map": {"a": "model-00001-of-00002.safetensors", + "b": "model-00002-of-00002.safetensors"}}) + ) + # Root warm is complete: the checkpoint index is skipped, the root weight suffices. + assert hcs.snapshot_dir_is_complete(snap) is True + + +def test_fast_path_rejects_config_only_snapshot(hf_cache, monkeypatch): + """The fast path rejects a config-only snapshot HF's local_files_only may return, deferring to the killable child.""" + blobs = _blobs_dir(hf_cache, DL_REPO) + snap = blobs.parent / "snapshots" / "sha" + snap.mkdir(parents = True) + cfg_blob = blobs / "cfg" + cfg_blob.write_text("{}") + (snap / "config.json").symlink_to(cfg_blob) # only config, no weights + # The stubbed HF lookup hands back the config-only dir; the accepted result must come from the child attempt instead. + monkeypatch.setattr(huggingface_hub, "snapshot_download", lambda *a, **k: str(snap)) + fake = _install(monkeypatch, [("ok", "/cache/snap-fresh")]) + out = xf.snapshot_download_with_xet_fallback(DL_REPO, token = None) + assert out == "/cache/snap-fresh" and len(fake.calls) == 1 + + +def test_fast_path_requires_each_named_weight(hf_cache, monkeypatch): + """The pre-download short-circuit rejects a cache holding only one of several explicitly named weights (base+adapter, only base cached).""" + blobs =
_blobs_dir(hf_cache, DL_REPO) + snap = blobs.parent / "snapshots" / "sha" + snap.mkdir(parents = True) + base_blob = blobs / "w" + base_blob.write_bytes(b"x") + (snap / "model.safetensors").symlink_to(base_blob) # base only; adapter missing + monkeypatch.setattr(huggingface_hub, "snapshot_download", lambda *a, **k: str(snap)) + fake = _install(monkeypatch, [("ok", "/cache/snap-fresh")]) + out = xf.snapshot_download_with_xet_fallback( + DL_REPO, token = None, + allow_patterns = ["model.safetensors", "adapter_model.safetensors"], + ) + assert out == "/cache/snap-fresh" and len(fake.calls) == 1 + + +def test_child_broken_snapshot_retries_over_http(monkeypatch, tmp_path): + """A broken child snapshot (dangling symlinks) is rejected on Xet and retried over HTTP; a clean second result is accepted.""" + broken = tmp_path / "broken" + broken.mkdir() + (broken / "model.safetensors").symlink_to(tmp_path / "missing") # dangling + clean = tmp_path / "clean" + clean.mkdir() + blob = tmp_path / "b" + blob.write_bytes(b"x") + (clean / "model.safetensors").symlink_to(blob) + # Scripted results: first attempt reports the broken dir, second the clean one. + fake = _install(monkeypatch, [("ok", str(broken)), ("ok", str(clean))]) + out = xf.snapshot_download_with_xet_fallback(DL_REPO, token = None) + assert out == str(clean) + assert [c.disable_xet for c in fake.calls] == [False, True] + + +def test_child_broken_snapshot_after_http_raises(monkeypatch, tmp_path): + """If even the HTTP attempt returns a broken snapshot, fail loudly rather than hand missing files to the load.""" + broken = tmp_path / "broken" + broken.mkdir() + (broken / "model.safetensors").symlink_to(tmp_path / "missing") # dangling (target never created) + fake = _install(monkeypatch, [("ok", str(broken)), ("ok", str(broken))]) + with pytest.raises(xf.DownloadStallError, match = "incomplete snapshot"): + xf.snapshot_download_with_xet_fallback(DL_REPO, token = None) + assert [c.disable_xet for c in fake.calls] == [False, True] + + +def test_child_weight_incomplete_snapshot_retries_over_http(monkeypatch, tmp_path): + """A weight-less
child result (stale config-only snapshot) is rejected on Xet and retried; a complete second result is accepted.""" + cfg_only = tmp_path / "cfg" + cfg_only.mkdir() + (cfg_only / "config.json").write_text("{}") # no weights + complete = tmp_path / "complete" + complete.mkdir() + blob = tmp_path / "b" + blob.write_bytes(b"x") + (complete / "model.safetensors").symlink_to(blob) + # Scripted results: first attempt reports the weightless dir, second the complete one. + fake = _install(monkeypatch, [("ok", str(cfg_only)), ("ok", str(complete))]) + out = xf.snapshot_download_with_xet_fallback(DL_REPO, token = None) + assert out == str(complete) + assert [c.disable_xet for c in fake.calls] == [False, True] + + +def test_patterned_snapshot_without_weights_is_accepted(monkeypatch, tmp_path): + """A patterned download that returns only the requested (weightless) files is accepted as-is, not retried for lacking weights.""" + cfg_only = tmp_path / "cfg" + cfg_only.mkdir() + (cfg_only / "config.json").write_text("{}") # exactly what was requested + fake = _install(monkeypatch, [("ok", str(cfg_only))]) + out = xf.snapshot_download_with_xet_fallback( + DL_REPO, token = None, allow_patterns = ["config.json"] + ) + assert out == str(cfg_only) and len(fake.calls) == 1 + + +def test_dataset_snapshot_without_weights_is_accepted(monkeypatch, tmp_path): + """A dataset snapshot has no weights by nature; its child result is accepted, not retried as 'incomplete'.""" + files = tmp_path / "ds" + files.mkdir() + (files / "data.json").write_text("[]") # data file only, no weight files + fake = _install(monkeypatch, [("ok", str(files))]) + out = xf.snapshot_download_with_xet_fallback(DL_REPO, token = None, repo_type = "dataset") + assert out == str(files) and len(fake.calls) == 1 # accepted after a single attempt + + +def test_model_snapshot_with_weights_excluded_is_accepted(monkeypatch, tmp_path): + """A model repo whose ignore_patterns drop every weight format yields a weightless snapshot that is accepted, not retried.""" + cfg_only = tmp_path / "cfg" + cfg_only.mkdir() + (cfg_only / "config.json").write_text("{}") + fake = _install(monkeypatch,
[("ok", str(cfg_only))]) + out = xf.snapshot_download_with_xet_fallback( + DL_REPO, + token = None, + ignore_patterns = [ + "*.safetensors", "*.bin", "*.h5", "*.msgpack", "*.gguf", + "*.pt", "*.pth", "*.ckpt", "*.onnx", "*.pdparams", "*.index.json", + ], + ) + assert out == str(cfg_only) and len(fake.calls) == 1 + + +def test_request_can_include_weights_unit(): + """Default prefetch ignores (onnx/h5/msgpack/gguf) still count as weight-including; excluding every weight format does not.""" + # Arguments are passed positionally as (allow_patterns, ignore_patterns). + assert hcs.request_can_include_weights(None, None) is True + assert hcs.request_can_include_weights(None, ["*.onnx", "*.h5", "*.msgpack", "*.gguf"]) is True + assert hcs.request_can_include_weights(["config.json"], None) is False + assert hcs.request_can_include_weights( + None, ["*.safetensors", "*.bin", "*.h5", "*.msgpack", "*.gguf", + "*.pt", "*.pth", "*.ckpt", "*.onnx", "*.pdparams", "*.index.json"] + ) is False + + +def test_request_can_include_weights_index_json_only(): + """A request matching only shard *index* sidecars reads weightless (the index is JSON, not weights); a real weight pattern does not.""" + assert hcs.request_can_include_weights(["*.json"], None) is False + assert hcs.request_can_include_weights(["*.index.json"], None) is False + assert hcs.request_can_include_weights( + ["model.safetensors.index.json", "pytorch_model.bin.index.json"], None + ) is False + assert hcs.request_can_include_weights(["*.safetensors"], None) is True + + +def test_request_can_include_weights_path_qualified(): + """Path-qualified allow_patterns resolve inside their directory, so a subfolder/checkpoint/shard weight request is not misread as weightless.""" + # Concrete subfolder globs: weights live under the directory.
+ assert hcs.request_can_include_weights(["checkpoint-10/*"], None) is True + assert hcs.request_can_include_weights(["checkpoint-10/*.safetensors"], None) is True + assert hcs.request_can_include_weights(["models/*.bin"], None) is True + # A specific (non-first) shard named verbatim. + assert hcs.request_can_include_weights(["model-00002-of-00005.safetensors"], None) is True + assert hcs.request_can_include_weights(["checkpoint-10/pytorch_model.bin"], None) is True + # Globbed parent dir, weight-targeting basename. + assert hcs.request_can_include_weights(["checkpoint-*/*.safetensors"], None) is True + # Subfolder requests targeting only non-weight files stay weightless. + assert hcs.request_can_include_weights(["checkpoint-10/config.json"], None) is False + assert hcs.request_can_include_weights(["checkpoint-10/*.json"], None) is False + assert hcs.request_can_include_weights(["checkpoint-*/tokenizer.json"], None) is False + # The unsloth subfolder warmup shape: "subfolder/*" plus root aux files. + assert hcs.request_can_include_weights( + ["checkpoint-10/*", "config.json", "tokenizer.json"], None + ) is True + + +def test_request_can_include_weights_path_qualified_custom_globs(): + """A path-qualified custom weight glob (checkpoint-10/lora_*.safetensors) reads weight-including via a concretized self-probe.""" + assert hcs.request_can_include_weights(["checkpoint-10/lora_*.safetensors"], None) is True + assert hcs.request_can_include_weights(["checkpoint-*/lora_*.bin"], None) is True + assert hcs.request_can_include_weights(["models/custom_*.pt"], None) is True + assert hcs.request_can_include_weights(["checkpoint-10/model-[0-9].safetensors"], None) is True + # A non-weight basename under a subfolder stays weightless.
+ assert hcs.request_can_include_weights(["checkpoint-10/tokenizer.json"], None) is False + + +def test_request_can_include_weights_empty_allow_list(tmp_path): + """allow_patterns=[] selects nothing (weightless, distinct from None); ignore_patterns=[] ignores nothing (weight-including).""" + assert hcs.request_can_include_weights([], None) is False + assert hcs.request_can_include_weights(None, None) is True + assert hcs.request_can_include_weights(None, []) is True + assert hcs.request_can_include_weights([], []) is False + # snapshot_dir_is_complete agrees: allow=[] is a select-nothing request, so an unrelated weight is not complete for it. + snap = tmp_path / "snap" + snap.mkdir() + blob = tmp_path / "blob" + blob.write_bytes(b"x") + (snap / "model.safetensors").symlink_to(blob) + assert hcs.snapshot_dir_is_complete(snap, allow_patterns = []) is False + assert hcs.snapshot_dir_is_complete(snap, allow_patterns = None) is True + + +def test_request_can_include_weights_string_form(): + """A bare-string allow/ignore pattern is treated as one pattern, not iterated char by char.""" + assert hcs.request_can_include_weights("checkpoint-10/*", None) is True + assert hcs.request_can_include_weights("*.safetensors", None) is True + assert hcs.request_can_include_weights("config.json", None) is False + assert hcs.request_can_include_weights(None, "*.safetensors") is True # ignore as str + # A string ignore that drops every weight format leaves the request weightless. 
+ assert hcs.request_can_include_weights( + "config.json", ["*.safetensors", "*.bin", "*.pt", "*.pth", "*.gguf", + "*.h5", "*.msgpack", "*.ckpt", "*.onnx", "*.pdparams"] + ) is False + + +def test_prepare_for_http_spares_active_sibling_partial(hf_cache): + """The HTTP-prep purge removes only stale (old-mtime) partials, so a concurrent sibling's active .incomplete keeps writing.""" + blobs = _blobs_dir(hf_cache, DL_REPO) + stale = blobs / "stalled.incomplete" + stale.write_bytes(b"\0" * 16) + active = blobs / "sibling.incomplete" + active.write_bytes(b"\0" * 16) + # Age the stalled partial past the grace; leave the sibling current. + old = time.time() - 600 + os.utime(stale, (old, old)) + _REAL_DEFAULT_PREPARE("model", DL_REPO, cache_dir = str(hf_cache)) + assert not stale.exists(), "stale partial should be purged for the HTTP resume" + assert active.exists(), "an actively-written sibling partial must be preserved" + + +def test_snapshot_stall_then_http(monkeypatch): + prepared = [] + monkeypatch.setattr(xf, "_default_prepare_for_http", lambda rt, rid, cache_dir = None, **k: prepared.append((rt, rid))) + fake = _install(monkeypatch, [("stall", None), ("ok", "/cache/snap-dir")]) + out = xf.snapshot_download_with_xet_fallback(DL_REPO, token = None) + assert out == "/cache/snap-dir" + assert [c.kind for c in fake.calls] == ["snapshot", "snapshot"] + assert [c.disable_xet for c in fake.calls] == [False, True] + assert prepared == [("model", DL_REPO)] + + +def test_force_download_skips_fast_path_and_threads(monkeypatch): + """force_download=True bypasses the warm-cache short-circuit and threads force_download into the download params.""" + def _snap(*a, **k): + pytest.fail("force_download must not take the local_files_only fast path") + + monkeypatch.setattr(huggingface_hub, "snapshot_download", _snap) + fake = _install(monkeypatch, [("ok", "/cache/snap-dir")]) + out = xf.snapshot_download_with_xet_fallback(DL_REPO, token = None, force_download = True) + assert out == 
"/cache/snap-dir" + assert len(fake.calls) == 1 and fake.calls[0].force_download is True + + +def test_force_download_file_skips_cache_probe(monkeypatch, tmp_path): + """The single-file path also skips the cached-blob short-circuit and threads force_download through.""" + cached = tmp_path / "cached.gguf" + cached.write_bytes(b"\0" * 8) + monkeypatch.setattr(huggingface_hub, "try_to_load_from_cache", lambda *a, **k: str(cached)) + fake = _install(monkeypatch, [("ok", "/cache/x")]) + out = xf.hf_hub_download_with_xet_fallback(DL_REPO, FILE, None, force_download = True) + assert out == "/cache/x" + assert len(fake.calls) == 1 and fake.calls[0].force_download is True + + +# Precondition: HF_HUB_DISABLE_XET is read at import time, so assert its effect in a +# FRESH interpreter (huggingface/huggingface_hub#3266 once ignored it). +def _safe_path() -> str: + import os + + return os.environ.get("PATH", "") + + +def test_disable_xet_constant_set_in_fresh_interpreter(): + code = ( + "from huggingface_hub import constants as c; " + "import sys; sys.exit(0 if c.HF_HUB_DISABLE_XET is True else 17)" + ) + proc = subprocess.run( + [sys.executable, "-c", code], + env = {"HF_HUB_DISABLE_XET": "1", "PATH": _safe_path()}, + capture_output = True, + text = True, + ) + assert proc.returncode == 0, ( + f"HF_HUB_DISABLE_XET=1 did not set constants.HF_HUB_DISABLE_XET=True " + f"(rc={proc.returncode}): {proc.stderr}" + ) + + +def test_default_leaves_xet_enabled(): + code = ( + "from huggingface_hub import constants as c; " + "import sys; sys.exit(0 if c.HF_HUB_DISABLE_XET is False else 17)" + ) + proc = subprocess.run( + [sys.executable, "-c", code], + env = {"PATH": _safe_path()}, # no HF_HUB_DISABLE_XET + capture_output = True, + text = True, + ) + assert proc.returncode == 0, ( + f"without the env var, constants.HF_HUB_DISABLE_XET was not False " + f"(rc={proc.returncode}): {proc.stderr}" + ) + + +# Exported Xet knobs + child-leak safety + malformed-index resilience. 
def test_xet_availability_and_disable_helpers(monkeypatch):
    """child_should_disable_xet reads the per-worker flag; xet_force_disabled honors every env knob; is_hf_xet_available probes importability."""
    assert xf.child_should_disable_xet({"disable_xet": True}) is True
    assert xf.child_should_disable_xet({"disable_xet": False}) is False
    assert xf.child_should_disable_xet({}) is False

    knobs = ("UNSLOTH_DISABLE_XET", "UNSLOTH_STABLE_DOWNLOADS", "HF_HUB_DISABLE_XET")
    for knob in knobs:
        # Clear every knob first so each round checks exactly one knob in isolation.
        for other in knobs:
            monkeypatch.delenv(other, raising = False)
        assert xf.xet_force_disabled() is False
        monkeypatch.setenv(knob, "1")
        assert xf.xet_force_disabled() is True, knob

    # The stubs keep the real find_spec(name, package=None) call shape: the patched
    # attribute is reached through xf.importlib.util (presumably the shared stdlib
    # module), so any other caller hitting it mid-test must not get a TypeError.
    monkeypatch.setattr(xf.importlib.util, "find_spec", lambda name, package = None: object())
    assert xf.is_hf_xet_available() is True
    monkeypatch.setattr(xf.importlib.util, "find_spec", lambda name, package = None: None)
    assert xf.is_hf_xet_available() is False

    def _raise(name, package = None):
        raise ImportError("boom")

    monkeypatch.setattr(xf.importlib.util, "find_spec", _raise)
    assert xf.is_hf_xet_available() is False  # a probe exception -> unavailable


def test_run_attempt_terminates_child_if_watchdog_start_raises(monkeypatch):
    """If start_watchdog raises after the child spawned, the child is still reaped (no leak) and the error propagates."""
    rec = {"terminated": False}

    class _AliveProc:
        """Stand-in child process that stays alive until terminate()/kill() is called."""

        def __init__(self):
            self.pid = None  # None -> uses terminate(), not killpg
            self.exitcode = None
            self._alive = True

        def start(self):
            pass

        def is_alive(self):
            return self._alive

        def terminate(self):
            rec["terminated"] = True
            self._alive = False

        def kill(self):
            rec["terminated"] = True
            self._alive = False

        def join(self, timeout = None):
            pass

    class _Ctx:
        """Stand-in multiprocessing context handing out the fake process and queue."""

        def Process(self, *, target = None, kwargs = None, daemon = None):
            return _AliveProc()

        def Queue(self):
            return _FakeQueue({"ok": True, "path": "/cache/x"})

    monkeypatch.setattr(xf, "_CTX", _Ctx())

    def _boom(*a, **k):
        raise RuntimeError("can't start new thread")

    monkeypatch.setattr(xf, "start_watchdog", _boom)

    with pytest.raises(RuntimeError, match = "can't start new thread"):
        xf._run_download_attempt(
            DL_REPO, kind = "snapshot", params = {"repo_id": DL_REPO}, token = None,
            repo_type = "model", disable_xet = False, cancel_event = None,
            stall_timeout = 0.2, interval = 0.05, grace_period = 0.05, on_status = None,
        )
    assert rec["terminated"] is True  # child reaped despite the watchdog-start failure


# Codex review round: scoped completeness, weightless named files, type preservation.
def test_requested_named_files_present_exact_request(tmp_path):
    """An EXACT-named weightless request requires its named file on disk; a glob list or no allow_patterns is best-effort."""
    snap = tmp_path / "snap"
    snap.mkdir()
    (snap / "config.json").write_text("{}")
    assert hcs.requested_named_files_present(snap, allow_patterns = ["tokenizer.json"]) is False
    (snap / "tokenizer.json").write_text("{}")
    assert hcs.requested_named_files_present(snap, allow_patterns = ["tokenizer.json"]) is True
    # A glob list is best-effort: a missing optional file does not fail it.
    assert hcs.requested_named_files_present(snap, allow_patterns = ["tokenizer*", "vocab.txt"]) is True
    # No allow_patterns -> trivially satisfied.
    assert hcs.requested_named_files_present(snap) is True
    # An ignore-filtered name is not requested, so its absence does not fail.
    assert hcs.requested_named_files_present(
        snap, allow_patterns = ["tokenizer.json", "absent.json"], ignore_patterns = ["absent.json"]
    ) is True


def test_deterministic_oserror_type_preserved(monkeypatch):
    """A deterministic disk-full OSError is re-raised as OSError across the spawn boundary, not flattened to RuntimeError."""
    fake = _install(monkeypatch, [("error", "OSError: [Errno 28] No space left on device")])
    with pytest.raises(OSError, match = "No space left"):
        xf.hf_hub_download_with_xet_fallback(DL_REPO, FILE, None)
    assert len(fake.calls) == 1, "a deterministic error must not trigger an HTTP fallback"


def test_unknown_error_falls_back_to_runtimeerror(monkeypatch):
    """An unrecognized error class name surfaces as RuntimeError without a fallback; only known Hub/OS types are reconstructed."""
    fake = _install(monkeypatch, [("error", "SomeWeirdError: kaboom")])
    with pytest.raises(RuntimeError, match = "kaboom"):
        xf.hf_hub_download_with_xet_fallback(DL_REPO, FILE, None)
    assert len(fake.calls) == 1


def test_resolve_exception_class_maps_known_names():
    """The map resolves known deterministic Hub error names + OSError, and returns None for an unknown name."""
    assert xf._resolve_exception_class("OSError") is OSError
    cls = xf._resolve_exception_class("RepositoryNotFoundError")
    assert cls is not None and issubclass(cls, BaseException)
    assert xf._resolve_exception_class("NotARealErrorType") is None


def test_error_type_preserved_when_constructor_needs_kwarg(monkeypatch):
    """A Hub error whose constructor rejects a lone positional string is still re-raised with its type via an __init__-bypassing reconstruction."""
    class PickyHubError(Exception):
        def __init__(self, message, *, response):  # response required + keyword-only
            super().__init__(message)
            self.response = response

    monkeypatch.setattr(
        xf, "_resolve_exception_class",
        lambda name: PickyHubError if name == "PickyHubError" else None,
    )
    fake = _install(monkeypatch, [("error", "PickyHubError: kaboom")])
    with pytest.raises(PickyHubError, match = "kaboom"):
        xf.hf_hub_download_with_xet_fallback(DL_REPO, FILE, None)
    assert len(fake.calls) == 1, "a deterministic error must not trigger an HTTP fallback"


def test_instantiate_preserving_type_paths():
    """Layered reconstruction: a normal constructor is used when it accepts a string, else the __new__ bypass; both carry the message."""
    class Plain(Exception):
        pass

    class Picky(Exception):
        def __init__(self, message, *, response):
            super().__init__(message)

    for cls in (Plain, Picky):
        exc = xf._instantiate_preserving_type(cls, "the message")
        assert isinstance(exc, cls)
        assert "the message" in str(exc)


# Codex round: dir/ wildcard, logical-weight grouping post-download, errno preservation.
def test_parse_errno():
    """_parse_errno extracts N from a bracketed "[Errno N]" and returns None when the message carries none."""
    assert xf._parse_errno("OSError: [Errno 28] No space left on device") == 28
    assert xf._parse_errno("OSError: [Errno 122] Disk quota exceeded") == 122
    assert xf._parse_errno("OSError: some message with no errno") is None


def test_oserror_errno_preserved(monkeypatch):
    """A disk-full child OSError keeps its errno (ENOSPC) across the spawn boundary, not errno=None."""
    fake = _install(monkeypatch, [("error", "OSError: [Errno 28] No space left on device")])
    with pytest.raises(OSError) as excinfo:
        xf.hf_hub_download_with_xet_fallback(DL_REPO, FILE, None)
    assert excinfo.value.errno == errno.ENOSPC
    assert len(fake.calls) == 1, "a deterministic error must not trigger an HTTP fallback"


def test_oserror_subclass_errno_preserved(monkeypatch):
    """An OSError subclass (PermissionError) keeps both its type and errno across the spawn boundary; the message is not double-prefixed."""
    fake = _install(monkeypatch, [("error", "PermissionError: [Errno 13] Permission denied")])
    with pytest.raises(PermissionError) as excinfo:
        xf.hf_hub_download_with_xet_fallback(DL_REPO, FILE, None)
    assert excinfo.value.errno == errno.EACCES
    assert "[Errno 13] [Errno 13]" not in str(excinfo.value)
    assert len(fake.calls) == 1, "a deterministic error must not trigger an HTTP fallback"


def test_raise_child_error_errno_only_for_builtin_oserror(monkeypatch):
    """errno is preserved only for a BUILTIN OSError type; a non-builtin OSError subclass (e.g. HfHubHTTPError) with a bracketed [Errno N] gets no spurious errno."""
    # Builtin OSError subclass -> errno preserved.
    with pytest.raises(FileNotFoundError) as excinfo:
        xf._raise_child_error("FileNotFoundError: [Errno 2] No such file or directory")
    assert excinfo.value.errno == errno.ENOENT

    # Non-builtin OSError subclass whose message merely contains a bracketed [Errno N].
    class _FakeHubHTTPError(OSError):
        def __init__(self, message):  # single-arg, like hf_hub's error types
            super().__init__(message)

    # Route only the "HfHubHTTPError" name to the fake type; every other name
    # still resolves through the real map. monkeypatch restores the attribute at
    # teardown, matching how the sibling tests patch this resolver.
    orig = xf._resolve_exception_class
    monkeypatch.setattr(
        xf, "_resolve_exception_class",
        lambda name: _FakeHubHTTPError if name == "HfHubHTTPError" else orig(name),
    )
    with pytest.raises(_FakeHubHTTPError) as excinfo2:
        xf._raise_child_error("HfHubHTTPError: 500 Server Error [Errno 111] for url https://x")
    assert excinfo2.value.errno is None


# Spawn-safety regressions: failed-spawn queue cleanup + disable-Xet env-race lock.
def test_failed_spawn_closes_result_queue(monkeypatch):
    """R2-2: if proc.start() raises, the result_queue's pipe fds (allocated before the spawn) are closed, not leaked."""
    closed = {"cancel_join": False, "close": False}

    # Local fake: records which cleanup calls the code under test makes on the queue.
    class _FakeQueue:
        def cancel_join_thread(self):
            closed["cancel_join"] = True

        def close(self):
            closed["close"] = True

    # Local fake: a process whose start() always fails with EAGAIN.
    class _FakeProc:
        def __init__(self, *a, **k):
            self.pid = None

        def start(self):
            raise OSError(errno.EAGAIN, "Resource temporarily unavailable")

    class _FakeCtx:
        def Queue(self):
            return _FakeQueue()

        def Process(self, *a, **k):
            return _FakeProc()

    monkeypatch.setattr(xf, "_CTX", _FakeCtx())
    with pytest.raises(OSError):
        xf._run_download_attempt(
            "owner/repo",
            kind = "snapshot",
            params = {},
            token = None,
            repo_type = "model",
            disable_xet = False,
            cancel_event = None,
            stall_timeout = 1.0,
            interval = 0.1,
            grace_period = 0.1,
            on_status = None,
        )
    assert closed["close"] is True
    assert closed["cancel_join"] is True


def test_disable_xet_read_under_spawn_lock(monkeypatch):
    """R2-1: xet_force_disabled() is read while holding _SPAWN_ENV_LOCK, so a concurrent spawn's transient HF_HUB_DISABLE_XET=1 is not observed."""
    seen = {}
    real = xf.xet_force_disabled

    def _spy():
        # A non-reentrant Lock's non-blocking acquire fails iff the read is inside `with _SPAWN_ENV_LOCK:`.
        got = xf._SPAWN_ENV_LOCK.acquire(blocking = False)
        if got:
            xf._SPAWN_ENV_LOCK.release()
        seen["held"] = not got
        return real()

    monkeypatch.setattr(xf, "xet_force_disabled", _spy)
    # Short-circuit the actual attempt: only the lock discipline is under test.
    monkeypatch.setattr(xf, "_run_download_attempt", lambda *a, **k: ("ok", "/tmp/warm"))
    out = xf._download_with_xet_fallback(
        repo_id = "owner/repo",
        label = "test",
        kind = "snapshot",
        params = {},
        token = None,
        repo_type = "model",
        cancel_event = None,
        stall_timeout = 1.0,
        interval = 0.1,
        grace_period = 0.1,
        on_status = None,
        prepare_for_http_fn = None,
    )
    assert out == "/tmp/warm"
    assert seen.get("held") is True


# Conservative fast-path gate + pre/post-download acceptance split. The gate fast-paths
# ONLY the unambiguous canonical model cache, deferring everything else to the watched
# child. Pre-download (skip the child?) and post-download (accept the result?) are
# deliberately asymmetric: strict pre, lenient post.
def _mk_snapshot(tmp_path, name):
    """Create the empty snapshot dir tmp_path/name and return (snap, blob); blob is the one-byte file tmp_path/_blob, written only if absent."""
    blob = tmp_path / "_blob"
    if not blob.exists():
        blob.write_bytes(b"w")
    snap = tmp_path / name
    snap.mkdir()
    return snap, blob


def test_gate_fast_paths_canonical_single_file(tmp_path):
    """A complete, unpatterned single-file model cache is fast-path eligible."""
    snap, blob = _mk_snapshot(tmp_path, "single")
    (snap / "model.safetensors").symlink_to(blob)
    (snap / "config.json").write_text("{}")
    assert hcs.snapshot_dir_is_complete(snap) is True


def test_gate_fast_paths_canonical_sharded_with_index(tmp_path):
    """A complete sharded model with its index is fast-path eligible; without the index or with a listed shard missing, it is not."""
    snap, blob = _mk_snapshot(tmp_path, "shard")
    (snap / "model-00001-of-00002.safetensors").symlink_to(blob)
    (snap / "model-00002-of-00002.safetensors").symlink_to(blob)
    assert hcs.snapshot_dir_is_complete(snap) is False  # numbered shards, no index
    (snap / "model.safetensors.index.json").write_text(
        json.dumps({"weight_map": {"a": "model-00001-of-00002.safetensors",
                                   "b": "model-00002-of-00002.safetensors"}}))
    assert hcs.snapshot_dir_is_complete(snap) is True
    snap2, _ = _mk_snapshot(tmp_path, "shard2")
    (snap2 / "model-00001-of-00002.safetensors").symlink_to(blob)  # shard 2 absent
    (snap2 / "model.safetensors.index.json").write_text(
        json.dumps({"weight_map": {"a": "model-00001-of-00002.safetensors",
                                   "b": "model-00002-of-00002.safetensors"}}))
    assert hcs.snapshot_dir_is_complete(snap2) is False


def test_shard_index_with_non_string_value_is_incomplete(tmp_path):
    """A shard index mapping a tensor to a non-string value (null) is incomplete: fail closed and defer to the child."""
    snap, blob = _mk_snapshot(tmp_path, "badindex")
    (snap / "model-00001-of-00002.safetensors").symlink_to(blob)
    (snap / "model.safetensors.index.json").write_text(
        json.dumps({"weight_map": {"a": "model-00001-of-00002.safetensors", "b": None}}))
    assert hcs._weight_shard_index_complete(snap / "model.safetensors.index.json") is False
    assert hcs.snapshot_dir_is_complete(snap) is False


def test_gate_defers_incomplete_preferred_index_masked_by_complete_bin(tmp_path):
    """An incomplete safetensors index (probed before the bin) is not masked by a complete bin; the gate defers unless safetensors is ignored."""
    snap, blob = _mk_snapshot(tmp_path, "prefidx")
    (snap / "model-00001-of-00002.safetensors").symlink_to(blob)  # ST shard 2 absent -> incomplete index
    (snap / "model.safetensors.index.json").write_text(
        json.dumps({"weight_map": {"a": "model-00001-of-00002.safetensors",
                                   "b": "model-00002-of-00002.safetensors"}}))
    (snap / "pytorch_model.bin").symlink_to(blob)  # complete bin co-resident
    assert hcs.snapshot_dir_is_complete(snap) is False  # load prefers the incomplete safetensors
    # safetensors ignored -> the load reads the complete bin -> eligible.
    assert hcs.snapshot_dir_is_complete(snap, ignore_patterns = ["*.safetensors"]) is True
    # A complete safetensors index alongside the bin is eligible.
    (snap / "model-00002-of-00002.safetensors").symlink_to(blob)
    assert hcs.snapshot_dir_is_complete(snap) is True


def test_gate_rejects_sharded_adapter_only_root_cache(tmp_path):
    """A complete sharded adapter at root is not a canonical base model, so an adapter-only cache defers to the child."""
    assert hcs._is_canonical_weight_shard_index("adapter_model.safetensors.index.json") is False
    assert hcs._is_canonical_weight_shard_index("model.safetensors.index.json") is True
    assert hcs._is_canonical_weight_shard_index("pytorch_model.bin.index.json") is True
    snap, blob = _mk_snapshot(tmp_path, "adapteronly")
    (snap / "config.json").write_text("{}")
    (snap / "adapter_model-00001-of-00002.safetensors").symlink_to(blob)
    (snap / "adapter_model-00002-of-00002.safetensors").symlink_to(blob)
    (snap / "adapter_model.safetensors.index.json").write_text(
        json.dumps({"weight_map": {"a": "adapter_model-00001-of-00002.safetensors",
                                   "b": "adapter_model-00002-of-00002.safetensors"}}))
    assert hcs.snapshot_dir_is_complete(snap) is False
    assert xf._cache_can_skip_download(
        snap, repo_type = "model", allow_patterns = None, ignore_patterns = None) is False


def test_gate_rejects_config_only(tmp_path):
    """A snapshot holding only config.json (no weight file) is not complete."""
    snap, _ = _mk_snapshot(tmp_path, "cfg")
    (snap / "config.json").write_text("{}")
    assert hcs.snapshot_dir_is_complete(snap) is False


def test_gate_rejects_diffusers_marker(tmp_path):
    """A diffusers pipeline (root model_index.json) is never fast-pathed, even with a root-level weight present."""
    snap, blob = _mk_snapshot(tmp_path, "diff")
    (snap / "model_index.json").write_text("{}")
    (snap / "model.safetensors").symlink_to(blob)
    assert hcs.snapshot_dir_is_complete(snap) is False


def test_gate_rejects_any_allow_pattern(tmp_path):
    """Any allow_patterns makes the request non-trivial, so no fast-path."""
    snap, blob = _mk_snapshot(tmp_path, "pat")
    (snap / "model.safetensors").symlink_to(blob)
    assert hcs.snapshot_dir_is_complete(snap, allow_patterns = ["*.safetensors"]) is False
    assert hcs.snapshot_dir_is_complete(snap, allow_patterns = ["model.safetensors"]) is False


def test_gate_eligible_under_ignore_patterns(tmp_path):
    """allow=None with any ignore patterns stays eligible: ignores that drop other formats cannot make an incomplete cache read complete."""
    snap, blob = _mk_snapshot(tmp_path, "ign")
    (snap / "model.safetensors").symlink_to(blob)
    assert hcs.snapshot_dir_is_complete(snap, ignore_patterns = ["*/*.safetensors", "*/*.bin"]) is True
    assert hcs.snapshot_dir_is_complete(snap, ignore_patterns = ["*.onnx"]) is True
    # The real unsloth bare-from_pretrained ignore list; the warm root model.safetensors must still fast-path.
    unsloth_ignore = [
        "*.onnx", "*.h5", "*.msgpack", "*.tflite", "*.mlmodel", "*.gguf", "*.pt", "*.pth",
        "*.ckpt", "optimizer.*", "scheduler.*", "rng_state*", "trainer_state.json",
        "events.out.tfevents*", "*.bin",
        "*/*.safetensors", "*/*.bin", "*/*.h5", "*/*.msgpack", "*/*.pt", "*/*.pth",
    ]
    assert hcs.snapshot_dir_is_complete(snap, ignore_patterns = unsloth_ignore) is True


def test_gate_rejects_broken_symlink(tmp_path):
    """A weight symlink whose target file does not exist does not make the snapshot complete."""
    snap, _ = _mk_snapshot(tmp_path, "broken")
    (snap / "model.safetensors").symlink_to(tmp_path / "_missing")
    assert hcs.snapshot_dir_is_complete(snap) is False


def test_request_can_include_weights_trim_semantics():
    """Table of allow/ignore request shapes and whether each one can include weights (see the per-line notes)."""
    r = hcs.request_can_include_weights
    assert r(None, None) is True  # bare unpatterned
    assert r(None, ["*/*.safetensors", "*/*.bin"]) is True  # subdir ignore (common bare)
    assert r(["*.safetensors"], None) is True  # weight glob
    assert r(["model.fp16.safetensors"], None) is True  # variant exact weight
    assert r(["unet/*"], None) is True  # subfolder weight glob
    assert r(["model.gguf"], None) is True  # gguf is a weight
    assert r(["config.json", "tokenizer.json", "*.py"], None) is False  # tokenizer-only
    assert r(["adapter_model*", "adapter_config.json"], None) is True  # adapter
    assert r([], None) is False  # empty allow selects nothing


def test_request_can_include_weights_partial_ignore_strip_is_weight_bearing():
    """An ignore-only request is weightless only when it strips EVERY weight format; a partial strip stays weight-bearing."""
    r = hcs.request_can_include_weights
    assert r(None, ["model.safetensors", "pytorch_model.bin"]) is True  # variant / other-format survives
    assert r(None, ["*.safetensors", "*.bin"]) is True  # .pt / .gguf / .pth / ... survive
    # Only stripping EVERY weight format is weightless.
    all_formats = ["*.safetensors", "*.bin", "*.pt", "*.pth", "*.gguf", "*.ckpt",
                   "*.onnx", "*.msgpack", "*.h5", "*.pdparams"]
    assert r(None, all_formats) is False


def test_pre_download_skips_complete_model(tmp_path):
    """An unpatterned model request over a snapshot containing model.safetensors may skip the download."""
    snap, blob = _mk_snapshot(tmp_path, "m")
    (snap / "model.safetensors").symlink_to(blob)
    assert xf._cache_can_skip_download(
        snap, repo_type = "model", allow_patterns = None, ignore_patterns = None) is True


def test_pre_download_defers_variant_on_canonical_cache(tmp_path):
    """A variant load reads model.{variant}.safetensors, so a canonical-only cache does not fast-path a variant request (but does with no variant)."""
    snap, blob = _mk_snapshot(tmp_path, "var")
    (snap / "model.safetensors").symlink_to(blob)
    assert xf._cache_can_skip_download(
        snap, repo_type = "model", allow_patterns = None, ignore_patterns = None) is True
    assert xf._cache_can_skip_download(
        snap, repo_type = "model", allow_patterns = None, ignore_patterns = None,
        variant = "fp16") is False


def test_pre_download_defers_bin_only_when_safetensors_preferred(tmp_path):
    """A default load probes model.safetensors before the bin, so the strict pre-gate defers a bin-only cache; the lenient post path accepts a finished bin-only download."""
    snap, blob = _mk_snapshot(tmp_path, "binonly")
    (snap / "config.json").write_text("{}")
    (snap / "pytorch_model.bin").symlink_to(blob)
    # PRE: safetensors preferred (not ignored) + bin-only -> defer.
    assert xf._cache_can_skip_download(
        snap, repo_type = "model", allow_patterns = None, ignore_patterns = None) is False
    # PRE: safetensors ignored -> the bin cache fast-paths.
    assert xf._cache_can_skip_download(
        snap, repo_type = "model", allow_patterns = None,
        ignore_patterns = ["*.safetensors", "*.safetensors.index.json"]) is True
    # PRE: safetensors present -> fast-path.
    (snap / "model.safetensors").symlink_to(blob)
    assert xf._cache_can_skip_download(
        snap, repo_type = "model", allow_patterns = None, ignore_patterns = None) is True
    # POST is lenient: a finished bin-only download is accepted.
    snap2, blob2 = _mk_snapshot(tmp_path, "binonly_post")
    (snap2 / "config.json").write_text("{}")
    (snap2 / "pytorch_model.bin").symlink_to(blob2)
    assert xf._download_result_usable(
        snap2, repo_type = "model", allow_patterns = None, ignore_patterns = None) is True
    # POST: a sharded bin-only repo is likewise accepted.
    snap3, blob3 = _mk_snapshot(tmp_path, "binonly_sharded_post")
    (snap3 / "pytorch_model-00001-of-00002.bin").symlink_to(blob3)
    (snap3 / "pytorch_model-00002-of-00002.bin").symlink_to(blob3)
    (snap3 / "pytorch_model.bin.index.json").write_text(json.dumps(
        {"weight_map": {"a": "pytorch_model-00001-of-00002.bin",
                        "b": "pytorch_model-00002-of-00002.bin"}}))
    assert xf._download_result_usable(
        snap3, repo_type = "model", allow_patterns = None, ignore_patterns = None) is True


def test_pre_download_does_not_skip_diffusers_but_post_accepts(tmp_path):
    """Pre/post asymmetry: a diffusers warm is not fast-pathed, but the same complete result is accepted post-download."""
    snap, blob = _mk_snapshot(tmp_path, "diff")
    (snap / "model_index.json").write_text("{}")
    comp = snap / "unet"
    comp.mkdir()
    (comp / "diffusion_pytorch_model.safetensors").symlink_to(blob)
    assert xf._cache_can_skip_download(
        snap, repo_type = "model", allow_patterns = None, ignore_patterns = None) is False
    assert xf._download_result_usable(
        snap, repo_type = "model", allow_patterns = None, ignore_patterns = None) is True


def test_post_download_rejects_config_only_model(tmp_path):
    """A model warm returning no weight (stale config-only snapshot) is rejected post-download and retried."""
    snap, _ = _mk_snapshot(tmp_path, "cfg")
    (snap / "config.json").write_text("{}")
    assert xf._download_result_usable(
        snap, repo_type = "model", allow_patterns = None, ignore_patterns = None) is False


def test_post_download_rejects_ignored_only_format(tmp_path):
    """A safetensors load (ignore=['*.bin']) whose result kept only the ignored .bin is rejected (the weight check applies the ignore filter)."""
    snap, blob = _mk_snapshot(tmp_path, "ignfmt")
    (snap / "pytorch_model.bin").symlink_to(blob)
    assert xf._download_result_usable(
        snap, repo_type = "model", allow_patterns = None, ignore_patterns = ["*.bin"]) is False
    # The requested safetensors present -> accepted.
    (snap / "model.safetensors").symlink_to(blob)
    assert xf._download_result_usable(
        snap, repo_type = "model", allow_patterns = None, ignore_patterns = ["*.bin"]) is True


def test_post_download_rejects_canonical_only_for_variant(tmp_path):
    """A variant load whose result kept only the canonical (non-variant) weight is rejected; a present variant weight is accepted."""
    snap, blob = _mk_snapshot(tmp_path, "varpost")
    (snap / "model.safetensors").symlink_to(blob)
    assert xf._download_result_usable(
        snap, repo_type = "model", allow_patterns = None, ignore_patterns = None,
        variant = "fp16") is False
    (snap / "model.fp16.safetensors").symlink_to(blob)
    assert xf._download_result_usable(
        snap, repo_type = "model", allow_patterns = None, ignore_patterns = None,
        variant = "fp16") is True
    # A complete sharded variant set (shards + index) is accepted.
    snap2, blob2 = _mk_snapshot(tmp_path, "varshard")
    (snap2 / "model.fp16-00001-of-00002.safetensors").symlink_to(blob2)
    (snap2 / "model.fp16-00002-of-00002.safetensors").symlink_to(blob2)
    (snap2 / "model.safetensors.index.fp16.json").write_text(json.dumps(
        {"weight_map": {"a": "model.fp16-00001-of-00002.safetensors",
                        "b": "model.fp16-00002-of-00002.safetensors"}}))
    assert xf._download_result_usable(
        snap2, repo_type = "model", allow_patterns = None, ignore_patterns = None,
        variant = "fp16") is True


def test_post_download_rejects_patterned_canonical_only_for_variant(tmp_path):
    """The variant check applies to the patterned branch too: a subfolder variant request kept only the canonical weight is rejected."""
    snap, blob = _mk_snapshot(tmp_path, "subvar")
    sub = snap / "weights"
    sub.mkdir()
    (sub / "model.safetensors").symlink_to(blob)  # canonical only, no variant
    assert xf._download_result_usable(
        snap, repo_type = "model", allow_patterns = ["weights/*"], ignore_patterns = None,
        variant = "fp16") is False
    # The in-scope variant weight present -> accepted.
    (sub / "model.fp16.safetensors").symlink_to(blob)
    assert xf._download_result_usable(
        snap, repo_type = "model", allow_patterns = ["weights/*"], ignore_patterns = None,
        variant = "fp16") is True
    # A complete sharded in-scope variant weight (dash infix + variant index) is accepted.
    snap2, blob2 = _mk_snapshot(tmp_path, "subvarshard")
    sub2 = snap2 / "weights"
    sub2.mkdir()
    (sub2 / "model.fp16-00001-of-00002.safetensors").symlink_to(blob2)
    (sub2 / "model.fp16-00002-of-00002.safetensors").symlink_to(blob2)
    (sub2 / "model.safetensors.index.fp16.json").write_text(json.dumps(
        {"weight_map": {"a": "model.fp16-00001-of-00002.safetensors",
                        "b": "model.fp16-00002-of-00002.safetensors"}}))
    assert xf._download_result_usable(
        snap2, repo_type = "model", allow_patterns = ["weights/*"], ignore_patterns = None,
        variant = "fp16") is True
    # A lone variant shard with no index is an incomplete set -> rejected.
    snap2b, blob2b = _mk_snapshot(tmp_path, "subvarshard_lone")
    (snap2b / "weights").mkdir()
    (snap2b / "weights" / "model.fp16-00001-of-00002.safetensors").symlink_to(blob2b)
    assert xf._download_result_usable(
        snap2b, repo_type = "model", allow_patterns = ["weights/*"], ignore_patterns = None,
        variant = "fp16") is False
    # An out-of-scope variant weight does not satisfy an in-scope variant request.
    snap3, blob3 = _mk_snapshot(tmp_path, "subvaroos")
    (snap3 / "model.fp16.safetensors").symlink_to(blob3)  # at root, but request scopes to weights/
    (snap3 / "weights").mkdir()
    (snap3 / "weights" / "model.safetensors").symlink_to(blob3)
    assert xf._download_result_usable(
        snap3, repo_type = "model", allow_patterns = ["weights/*"], ignore_patterns = None,
        variant = "fp16") is False


def test_post_download_rejects_variant_only_diffusers_for_plain_load(tmp_path):
    """A plain diffusers warm (variant=None) whose result kept only variant component weights is rejected; complete plain / both-format pipelines are accepted."""
    def _mi(**comps):
        # Serialize a model_index.json payload: two fixed pipeline keys plus the given components.
        data = {"_class_name": "StableDiffusionPipeline", "_diffusers_version": "0.21.0"}
        data.update(comps)
        return json.dumps(data)

    snap, blob = _mk_snapshot(tmp_path, "plainvaronly")
    (snap / "model_index.json").write_text(
        _mi(unet = ["diffusers", "UNet2DConditionModel"], vae = ["diffusers", "AutoencoderKL"]))
    for comp in ("unet", "vae"):
        (snap / comp).mkdir()
        (snap / comp / "diffusion_pytorch_model.fp16.safetensors").symlink_to(blob)
    # plain load: variant-only components do not satisfy it -> retry.
    assert xf._download_result_usable(
        snap, repo_type = "model", allow_patterns = None, ignore_patterns = None, variant = None) is False
    # the same cache is a complete fp16 warm -> the variant load accepts it.
    assert xf._download_result_usable(
        snap, repo_type = "model", allow_patterns = None, ignore_patterns = None,
        variant = "fp16") is True
    # a complete plain pipeline (non-variant component weights) is accepted.
    snap2, blob2 = _mk_snapshot(tmp_path, "plaincomplete")
    (snap2 / "model_index.json").write_text(
        _mi(unet = ["diffusers", "UNet2DConditionModel"], vae = ["diffusers", "AutoencoderKL"]))
    for comp in ("unet", "vae"):
        (snap2 / comp).mkdir()
        (snap2 / comp / "diffusion_pytorch_model.safetensors").symlink_to(blob2)
    assert xf._download_result_usable(
        snap2, repo_type = "model", allow_patterns = None, ignore_patterns = None, variant = None) is True
    # a pipeline shipping both plain + fp16 in a component is accepted for a plain load.
    (snap2 / "unet" / "diffusion_pytorch_model.fp16.safetensors").symlink_to(blob2)
    assert xf._download_result_usable(
        snap2, repo_type = "model", allow_patterns = None, ignore_patterns = None, variant = None) is True


def test_post_download_rejects_incomplete_sharded_glob(tmp_path):
    """A globbed weight request (allow=['*.safetensors']) with a shard index missing a shard is rejected (globs get the same shard-completeness check)."""
    snap, blob = _mk_snapshot(tmp_path, "shardglob")
    (snap / "model-00001-of-00002.safetensors").symlink_to(blob)
    (snap / "model.safetensors.index.json").write_text(
        json.dumps({"weight_map": {"a": "model-00001-of-00002.safetensors",
                                   "b": "model-00002-of-00002.safetensors"}}))
    assert xf._download_result_usable(
        snap, repo_type = "model", allow_patterns = ["*.safetensors"], ignore_patterns = None) is False
    # The missing shard present -> complete -> accepted.
+ (snap / "model-00002-of-00002.safetensors").symlink_to(blob) + assert xf._download_result_usable( + snap, repo_type = "model", allow_patterns = ["*.safetensors"], ignore_patterns = None) is True + + +def test_post_download_accepts_patterned_with_coresident_partial_canonical_shards(tmp_path): + """A complete patterned download is accepted even with an unrelated partial canonical base shard set co-resident at root.""" + def _partial_base_shards(snap, blob): + (snap / "model-00001-of-00002.safetensors").symlink_to(blob) # shard 1 present + (snap / "model-00002-of-00002.safetensors").symlink_to(snap / "MISSING") # dangling shard 2 + (snap / "model.safetensors.index.json").write_text(json.dumps( + {"weight_map": {"a": "model-00001-of-00002.safetensors", + "b": "model-00002-of-00002.safetensors"}})) + # Adapter request completes; co-resident partial base shards must not reject it. + snap, blob = _mk_snapshot(tmp_path, "adapter_coresident") + (snap / "adapter_model.safetensors").symlink_to(blob) + (snap / "adapter_config.json").write_text("{}") + _partial_base_shards(snap, blob) + assert xf._download_result_usable( + snap, repo_type = "model", + allow_patterns = ["adapter_model.safetensors", "adapter_config.json", "*.json"], + ignore_patterns = None) is True + # gguf request completes; same co-resident partial base shards. + snap2, blob2 = _mk_snapshot(tmp_path, "gguf_coresident") + (snap2 / "model.Q4_K_M.gguf").symlink_to(blob2) + (snap2 / "config.json").write_text("{}") + _partial_base_shards(snap2, blob2) + assert xf._download_result_usable( + snap2, repo_type = "model", + allow_patterns = ["model.Q4_K_M.gguf", "config.json", "*.json"], + ignore_patterns = None) is True + # A globbed weight request that DOES select canonical root shards still gets the gate. 
+ snap3, blob3 = _mk_snapshot(tmp_path, "glob_still_gated") + _partial_base_shards(snap3, blob3) + assert xf._download_result_usable( + snap3, repo_type = "model", allow_patterns = ["*.safetensors"], ignore_patterns = None) is False + + +def test_post_download_rejects_incomplete_ignored_format_shards(tmp_path): + """A load ignoring safetensors (reads .bin) with a complete ST set but incomplete .bin set is rejected (the shard gate applies the ignore filter).""" + snap, blob = _mk_snapshot(tmp_path, "ignored_format_shards") + (snap / "model-00001-of-00002.safetensors").symlink_to(blob) + (snap / "model-00002-of-00002.safetensors").symlink_to(blob) + (snap / "model.safetensors.index.json").write_text(json.dumps( + {"weight_map": {"a": "model-00001-of-00002.safetensors", + "b": "model-00002-of-00002.safetensors"}})) + (snap / "pytorch_model-00001-of-00002.bin").symlink_to(blob) # bin shard 1 only, no index/shard 2 + assert xf._download_result_usable( + snap, repo_type = "model", allow_patterns = None, + ignore_patterns = ["*.safetensors"]) is False + # Ignoring the .bin instead (load reads the complete safetensors) -> accepted. 
+ assert xf._download_result_usable( + snap, repo_type = "model", allow_patterns = None, ignore_patterns = ["*.bin"]) is True + + +def test_post_download_rejects_incomplete_variant_shards(tmp_path): + """A variant load with a variant shard index missing a listed shard is rejected; a complete set and a single-file variant are accepted.""" + snap, blob = _mk_snapshot(tmp_path, "variant_incomplete") + (snap / "model.fp16-00001-of-00002.safetensors").symlink_to(blob) # shard 1; shard 2 absent + (snap / "model.safetensors.index.fp16.json").write_text(json.dumps( + {"weight_map": {"a": "model.fp16-00001-of-00002.safetensors", + "b": "model.fp16-00002-of-00002.safetensors"}})) + assert xf._download_result_usable( + snap, repo_type = "model", allow_patterns = None, ignore_patterns = None, + variant = "fp16") is False + # A lone variant shard with no index is also incomplete. + snap_noidx, blob_ni = _mk_snapshot(tmp_path, "variant_no_index") + (snap_noidx / "model.fp16-00001-of-00002.safetensors").symlink_to(blob_ni) + assert xf._download_result_usable( + snap_noidx, repo_type = "model", allow_patterns = None, ignore_patterns = None, + variant = "fp16") is False + # The missing variant shard present -> complete set -> accepted. + (snap / "model.fp16-00002-of-00002.safetensors").symlink_to(blob) + assert xf._download_result_usable( + snap, repo_type = "model", allow_patterns = None, ignore_patterns = None, + variant = "fp16") is True + # A single-file variant (no index) is accepted. 
+ snap2, blob2 = _mk_snapshot(tmp_path, "variant_single") + (snap2 / "model.fp16.safetensors").symlink_to(blob2) + assert xf._download_result_usable( + snap2, repo_type = "model", allow_patterns = None, ignore_patterns = None, + variant = "fp16") is True + + +def test_post_download_accepts_exact_named_shard_subset(tmp_path): + """An exact-named shard request is accepted once that file is present (the whole-checkpoint gate applies only to globbed warms); an absent named shard is rejected.""" + snap, blob = _mk_snapshot(tmp_path, "exact_shard_present") + (snap / "model-00001-of-00002.safetensors").symlink_to(blob) + assert xf._download_result_usable( + snap, repo_type = "model", + allow_patterns = ["model-00001-of-00002.safetensors"], ignore_patterns = None) is True + # The exact-named shard absent -> rejected. + snap2, _ = _mk_snapshot(tmp_path, "exact_shard_absent") + (snap2 / "config.json").write_text("{}") + assert xf._download_result_usable( + snap2, repo_type = "model", + allow_patterns = ["model-00001-of-00002.safetensors"], ignore_patterns = None) is False + + +def test_post_download_accepts_from_tf_flax_weights(tmp_path): + """A from_tf/from_flax load (both PyTorch formats ignored) reading tf_model.h5 / flax_model.msgpack is accepted when complete.""" + ig = ["*.safetensors", "*.safetensors.index.json", "*.bin", "*.bin.index.json"] + for wt in ("tf_model.h5", "flax_model.msgpack"): + snap, blob = _mk_snapshot(tmp_path, f"tf_{wt}") + (snap / wt).symlink_to(blob) + (snap / "config.json").write_text("{}") + assert xf._download_result_usable( + snap, repo_type = "model", allow_patterns = None, ignore_patterns = ig) is True + # Both PyTorch formats ignored but no h5/msgpack present -> still rejected. 
+ snap, _ = _mk_snapshot(tmp_path, "tf_none") + (snap / "config.json").write_text("{}") + assert xf._download_result_usable( + snap, repo_type = "model", allow_patterns = None, ignore_patterns = ig) is False + # A normal load (PyTorch format not ignored): a stray leftover h5 does not count, so an h5-only repo is rejected. + snap, blob = _mk_snapshot(tmp_path, "stray_h5") + (snap / "tf_model.h5").symlink_to(blob) + assert xf._download_result_usable( + snap, repo_type = "model", allow_patterns = None, ignore_patterns = None) is False + + +def test_post_download_checks_sharded_tf_flax_completeness(tmp_path): + """Sharded TF/Flax weights: a complete set (index + all shards) is accepted, an incomplete one (missing shard or lone shard, no index) rejected.""" + ig = ["*.safetensors", "*.safetensors.index.json", "*.bin", "*.bin.index.json"] + for base, ext in (("tf_model", "h5"), ("flax_model", "msgpack")): + idx = json.dumps({"weight_map": {"a": f"{base}-00001-of-00002.{ext}", + "b": f"{base}-00002-of-00002.{ext}"}}) + # Complete sharded set -> accepted. + snap, blob = _mk_snapshot(tmp_path, f"tfshard_ok_{base}") + (snap / f"{base}-00001-of-00002.{ext}").symlink_to(blob) + (snap / f"{base}-00002-of-00002.{ext}").symlink_to(blob) + (snap / f"{base}.{ext}.index.json").write_text(idx) + assert xf._download_result_usable( + snap, repo_type = "model", allow_patterns = None, ignore_patterns = ig) is True + # A shard listed by the index is missing -> rejected. + snap2, blob2 = _mk_snapshot(tmp_path, f"tfshard_missing_{base}") + (snap2 / f"{base}-00001-of-00002.{ext}").symlink_to(blob2) + (snap2 / f"{base}.{ext}.index.json").write_text(idx) + assert xf._download_result_usable( + snap2, repo_type = "model", allow_patterns = None, ignore_patterns = ig) is False + # A lone shard with no index -> rejected. 
+ snap3, blob3 = _mk_snapshot(tmp_path, f"tfshard_lone_{base}") + (snap3 / f"{base}-00001-of-00002.{ext}").symlink_to(blob3) + assert xf._download_result_usable( + snap3, repo_type = "model", allow_patterns = None, ignore_patterns = ig) is False + + +def test_post_download_checks_explicit_checkpoint_shard_completeness(tmp_path): + """An explicit checkpoint load (allow=['checkpoint-N/*']) reads its weights, so a lone shard there with no index is rejected; a complete set / untargeted leftover is fine.""" + # Lone checkpoint shard, no index, explicitly requested -> rejected. + snap, blob = _mk_snapshot(tmp_path, "ckpt_lone") + (snap / "checkpoint-7").mkdir() + (snap / "checkpoint-7" / "model-00001-of-00002.safetensors").symlink_to(blob) + assert xf._download_result_usable( + snap, repo_type = "model", allow_patterns = ["checkpoint-7/*"], ignore_patterns = None) is False + # Complete checkpoint shard set (index + all shards) -> accepted. + snap2, blob2 = _mk_snapshot(tmp_path, "ckpt_complete") + (snap2 / "checkpoint-7").mkdir() + (snap2 / "checkpoint-7" / "model-00001-of-00002.safetensors").symlink_to(blob2) + (snap2 / "checkpoint-7" / "model-00002-of-00002.safetensors").symlink_to(blob2) + (snap2 / "checkpoint-7" / "model.safetensors.index.json").write_text(json.dumps( + {"weight_map": {"a": "model-00001-of-00002.safetensors", + "b": "model-00002-of-00002.safetensors"}})) + assert xf._download_result_usable( + snap2, repo_type = "model", allow_patterns = ["checkpoint-7/*"], ignore_patterns = None) is True + # A leftover checkpoint the request does not target (subfolder=unet) must not reject a complete in-scope download. 
+ snap3, blob3 = _mk_snapshot(tmp_path, "ckpt_leftover") + (snap3 / "unet").mkdir() + (snap3 / "unet" / "diffusion_pytorch_model.safetensors").symlink_to(blob3) + (snap3 / "checkpoint-7").mkdir() + (snap3 / "checkpoint-7" / "model-00001-of-00002.safetensors").symlink_to(blob3) # lone, but not read + assert xf._download_result_usable( + snap3, repo_type = "model", allow_patterns = ["unet/*"], ignore_patterns = None) is True + # A nested checkpoint the request explicitly targets (allow=['foo/checkpoint-7/*']) still rejects its lone shard + # (the scope check matches a checkpoint dir at any leading segment). + snap4, blob4 = _mk_snapshot(tmp_path, "ckpt_nested") + (snap4 / "foo" / "checkpoint-7").mkdir(parents = True) + (snap4 / "foo" / "checkpoint-7" / "model-00001-of-00002.safetensors").symlink_to(blob4) + assert xf._download_result_usable( + snap4, repo_type = "model", + allow_patterns = ["foo/checkpoint-7/*"], ignore_patterns = None) is False + + +def test_post_download_accepts_exact_named_variant_shard_subset(tmp_path): + """An exact variant shard request is accepted once present (index/sibling absent); the exact-name escape applies to the variant branch too.""" + snap, blob = _mk_snapshot(tmp_path, "exact_var_shard") + (snap / "model.fp16-00001-of-00002.safetensors").symlink_to(blob) + assert xf._download_result_usable( + snap, repo_type = "model", + allow_patterns = ["model.fp16-00001-of-00002.safetensors"], ignore_patterns = None, + variant = "fp16") is True + # The exact-named variant shard absent -> still rejected. 
+ snap2, _ = _mk_snapshot(tmp_path, "exact_var_absent") + (snap2 / "config.json").write_text("{}") + assert xf._download_result_usable( + snap2, repo_type = "model", + allow_patterns = ["model.fp16-00001-of-00002.safetensors"], ignore_patterns = None, + variant = "fp16") is False + + +def test_post_download_rejects_patterned_incomplete_variant_shards(tmp_path): + """A globbed variant request with a lone root variant shard (no index) is rejected too; a complete in-scope set is accepted.""" + snap, blob = _mk_snapshot(tmp_path, "pat_var_incomplete") + (snap / "model.fp16-00001-of-00002.safetensors").symlink_to(blob) # lone shard, no index + assert xf._download_result_usable( + snap, repo_type = "model", allow_patterns = ["*.safetensors"], ignore_patterns = None, + variant = "fp16") is False + # Complete variant shard set -> accepted. + (snap / "model.fp16-00002-of-00002.safetensors").symlink_to(blob) + (snap / "model.safetensors.index.fp16.json").write_text(json.dumps( + {"weight_map": {"a": "model.fp16-00001-of-00002.safetensors", + "b": "model.fp16-00002-of-00002.safetensors"}})) + assert xf._download_result_usable( + snap, repo_type = "model", allow_patterns = ["*.safetensors"], ignore_patterns = None, + variant = "fp16") is True + + +def test_post_download_applies_ignore_to_diffusers_components(tmp_path): + """A diffusers warm ignoring a format is not satisfied by a component weight in that ignored format (a snapshot holding only unet/*.bin is rejected under ignore=['*.bin']).""" + snap, blob = _mk_snapshot(tmp_path, "diff_ignore") + (snap / "model_index.json").write_text("{}") + (snap / "unet").mkdir() + (snap / "unet" / "diffusion_pytorch_model.bin").symlink_to(blob) + assert xf._download_result_usable( + snap, repo_type = "model", allow_patterns = None, ignore_patterns = ["*.bin"]) is False + # The safetensors component present -> usable. 
+ (snap / "unet" / "diffusion_pytorch_model.safetensors").symlink_to(blob) + assert xf._download_result_usable( + snap, repo_type = "model", allow_patterns = None, ignore_patterns = ["*.bin"]) is True + + +def test_post_download_rejects_index_only_sharded_masked_by_bin(tmp_path): + """A safetensors index with none of its shards (index-only), co-resident with a complete bin, is rejected (the index is probed before the bin).""" + snap, blob = _mk_snapshot(tmp_path, "index_only") + (snap / "model.safetensors.index.json").write_text(json.dumps( + {"weight_map": {"a": "model-00001-of-00002.safetensors", + "b": "model-00002-of-00002.safetensors"}})) + (snap / "pytorch_model.bin").symlink_to(blob) # complete bin, no ST shards at all + assert xf._download_result_usable( + snap, repo_type = "model", allow_patterns = None, ignore_patterns = None) is False + # safetensors explicitly ignored -> load reads the complete bin -> usable. + assert xf._download_result_usable( + snap, repo_type = "model", allow_patterns = None, ignore_patterns = ["*.safetensors"]) is True + + +def test_post_download_patterned_shard_check_honors_ignore(tmp_path): + """A patterned request ignoring safetensors selects the complete .bin; a co-resident incomplete ST set does not reject it (the check applies the ignore filter).""" + snap, blob = _mk_snapshot(tmp_path, "pat_ignore") + (snap / "model-00001-of-00002.safetensors").symlink_to(blob) # incomplete ST (shard 2 absent) + (snap / "model.safetensors.index.json").write_text(json.dumps( + {"weight_map": {"a": "model-00001-of-00002.safetensors", + "b": "model-00002-of-00002.safetensors"}})) + (snap / "pytorch_model.bin").symlink_to(blob) # complete bin, the selected format + assert xf._download_result_usable( + snap, repo_type = "model", allow_patterns = ["*"], ignore_patterns = ["*.safetensors"]) is True + + +def test_post_download_variant_root_shard_check_scoped_to_selection(tmp_path): + """A subfolder variant request with a complete selected weight is 
accepted even with a stale out-of-scope root variant shard co-resident.""" + snap, blob = _mk_snapshot(tmp_path, "var_scope") + (snap / "unet").mkdir() + (snap / "unet" / "model.fp16.safetensors").symlink_to(blob) # complete in-scope variant + (snap / "model.fp16-00001-of-00002.safetensors").symlink_to(blob) # stale ROOT variant shard, out of scope + assert xf._download_result_usable( + snap, repo_type = "model", allow_patterns = ["unet/*"], ignore_patterns = None, + variant = "fp16") is True + # A GLOBBED root variant request DOES get the root variant-shard check. + snap2, blob2 = _mk_snapshot(tmp_path, "var_scope_glob") + (snap2 / "model.fp16-00001-of-00002.safetensors").symlink_to(blob2) # lone root variant shard + assert xf._download_result_usable( + snap2, repo_type = "model", allow_patterns = ["*.safetensors"], ignore_patterns = None, + variant = "fp16") is False + + +def test_post_download_root_variant_weight_honors_ignore(tmp_path): + """A variant load ignoring .bin is not satisfied by a variant .bin (a snapshot holding only model.fp16.bin is rejected under ignore=['*.bin']).""" + snap, blob = _mk_snapshot(tmp_path, "var_ignore") + (snap / "model.fp16.bin").symlink_to(blob) # only the ignored-format variant + assert xf._download_result_usable( + snap, repo_type = "model", allow_patterns = None, ignore_patterns = ["*.bin"], + variant = "fp16") is False + # The safetensors variant present -> usable. 
+ (snap / "model.fp16.safetensors").symlink_to(blob) + assert xf._download_result_usable( + snap, repo_type = "model", allow_patterns = None, ignore_patterns = ["*.bin"], + variant = "fp16") is True + + +def test_post_download_variant_shard_check_honors_ignore(tmp_path): + """A variant load ignoring .bin judges the variant shard set for the read format only: a complete ST variant beside a stale ignored .bin shard is accepted.""" + snap, blob = _mk_snapshot(tmp_path, "var_shard_ignore") + (snap / "model.fp16.safetensors").symlink_to(blob) # complete, the read format + (snap / "model.fp16-00001-of-00002.bin").symlink_to(blob) # stale ignored .bin shard, no index + assert xf._download_result_usable( + snap, repo_type = "model", allow_patterns = None, ignore_patterns = ["*.bin"], + variant = "fp16") is True + # Without the ignore, the complete safetensors variant is preferred and read. + assert xf._download_result_usable( + snap, repo_type = "model", allow_patterns = None, ignore_patterns = None, + variant = "fp16") is True + + +def test_post_download_rejects_variant_index_only_masked_by_bin(tmp_path): + """The variant analog of index-only: a variant ST index with none of its shards, beside a complete variant bin, is rejected.""" + snap, blob = _mk_snapshot(tmp_path, "var_index_only") + (snap / "model.safetensors.index.fp16.json").write_text(json.dumps( + {"weight_map": {"a": "model.fp16-00001-of-00002.safetensors", + "b": "model.fp16-00002-of-00002.safetensors"}})) + (snap / "pytorch_model.fp16.bin").symlink_to(blob) # complete bin, no ST variant shards at all + assert xf._download_result_usable( + snap, repo_type = "model", allow_patterns = None, ignore_patterns = None, + variant = "fp16") is False + # The variant safetensors ignored -> load reads the complete variant bin -> usable. 
+ assert xf._download_result_usable( + snap, repo_type = "model", allow_patterns = None, ignore_patterns = ["*.safetensors"], + variant = "fp16") is True + + +def test_post_download_rejects_incomplete_sharded_adapter(tmp_path): + """A PEFT adapter load whose partial has a sharded adapter index missing a shard is rejected (the selected-index check covers the non-model adapter index).""" + snap, blob = _mk_snapshot(tmp_path, "adapter_incomplete") + (snap / "adapter_config.json").write_text("{}") + (snap / "adapter_model-00001-of-00002.safetensors").symlink_to(blob) # shard 1; shard 2 absent + (snap / "adapter_model.safetensors.index.json").write_text(json.dumps( + {"weight_map": {"a": "adapter_model-00001-of-00002.safetensors", + "b": "adapter_model-00002-of-00002.safetensors"}})) + allow = ["adapter_config.json", "adapter_model*"] + assert xf._download_result_usable( + snap, repo_type = "model", allow_patterns = allow, ignore_patterns = None) is False + # The missing adapter shard present -> complete set -> accepted. 
+ (snap / "adapter_model-00002-of-00002.safetensors").symlink_to(blob) + assert xf._download_result_usable( + snap, repo_type = "model", allow_patterns = allow, ignore_patterns = None) is True + + +def test_post_download_rejects_incomplete_component_subfolder_shards(tmp_path): + """A subfolder request (allow=['unet/*']) whose component has a shard index missing a shard is rejected (the selected-index check covers component subfolders).""" + snap, blob = _mk_snapshot(tmp_path, "component_incomplete") + (snap / "unet").mkdir() + (snap / "unet" / "diffusion_pytorch_model-00001-of-00002.safetensors").symlink_to(blob) + (snap / "unet" / "diffusion_pytorch_model.safetensors.index.json").write_text(json.dumps( + {"weight_map": {"a": "diffusion_pytorch_model-00001-of-00002.safetensors", + "b": "diffusion_pytorch_model-00002-of-00002.safetensors"}})) + assert xf._download_result_usable( + snap, repo_type = "model", allow_patterns = ["unet/*"], ignore_patterns = None) is False + (snap / "unet" / "diffusion_pytorch_model-00002-of-00002.safetensors").symlink_to(blob) + assert xf._download_result_usable( + snap, repo_type = "model", allow_patterns = ["unet/*"], ignore_patterns = None) is True + + +def test_post_download_rejects_gguf_only_default_load(tmp_path): + """A default warm reads safetensors/bin, not GGUF, so a snapshot holding only a .gguf is rejected.""" + snap, blob = _mk_snapshot(tmp_path, "gguf_only") + (snap / "model.Q4_K_M.gguf").symlink_to(blob) + (snap / "config.json").write_text("{}") + assert xf._download_result_usable( + snap, repo_type = "model", allow_patterns = None, ignore_patterns = None) is False + # The safetensors weight present -> the default warm accepts, even beside the gguf. 
+ (snap / "model.safetensors").symlink_to(blob) + assert xf._download_result_usable( + snap, repo_type = "model", allow_patterns = None, ignore_patterns = None) is True + + +def test_post_download_rejects_adapter_variant_for_default_variant_load(tmp_path): + """A variant warm reads the root model variant, not an adapter variant, so an adapter-variant-only snapshot is rejected; the base variant present is accepted.""" + snap, blob = _mk_snapshot(tmp_path, "adapter_variant_only") + (snap / "adapter_model.fp16.safetensors").symlink_to(blob) + (snap / "adapter_config.json").write_text("{}") + (snap / "config.json").write_text("{}") + assert xf._download_result_usable( + snap, repo_type = "model", allow_patterns = None, ignore_patterns = None, + variant = "fp16") is False + (snap / "model.fp16.safetensors").symlink_to(blob) + assert xf._download_result_usable( + snap, repo_type = "model", allow_patterns = None, ignore_patterns = None, + variant = "fp16") is True + + +def test_post_download_accepts_complete_diffusers_variant(tmp_path): + """A diffusers variant warm's weights are component-scoped, so a complete diffusers variant download is accepted; a non-variant-only pipeline does not satisfy a variant load.""" + snap, blob = _mk_snapshot(tmp_path, "diffusers_variant") + (snap / "model_index.json").write_text("{}") + (snap / "unet").mkdir() + (snap / "unet" / "config.json").write_text("{}") + (snap / "unet" / "diffusion_pytorch_model.fp16.safetensors").symlink_to(blob) + assert xf._download_result_usable( + snap, repo_type = "model", allow_patterns = None, ignore_patterns = None, + variant = "fp16") is True + snap2, blob2 = _mk_snapshot(tmp_path, "diffusers_variant_missing") + (snap2 / "model_index.json").write_text("{}") + (snap2 / "unet").mkdir() + (snap2 / "unet" / "diffusion_pytorch_model.safetensors").symlink_to(blob2) # non-variant only + assert xf._download_result_usable( + snap2, repo_type = "model", allow_patterns = None, ignore_patterns = None, + variant = 
"fp16") is False + + +def test_post_download_rejects_incomplete_diffusers_component_shards_unpatterned(tmp_path): + """An unpatterned diffusers warm rejects a component shard index listing an absent shard; plain and variant indexes are covered, a complete set accepted.""" + snap, blob = _mk_snapshot(tmp_path, "diffusers_comp_incomplete") + (snap / "model_index.json").write_text("{}") + (snap / "unet").mkdir() + (snap / "unet" / "config.json").write_text("{}") + (snap / "unet" / "diffusion_pytorch_model-00001-of-00002.safetensors").symlink_to(blob) + (snap / "unet" / "diffusion_pytorch_model.safetensors.index.json").write_text(json.dumps( + {"weight_map": {"a": "diffusion_pytorch_model-00001-of-00002.safetensors", + "b": "diffusion_pytorch_model-00002-of-00002.safetensors"}})) + assert xf._download_result_usable( + snap, repo_type = "model", allow_patterns = None, ignore_patterns = None) is False + (snap / "unet" / "diffusion_pytorch_model-00002-of-00002.safetensors").symlink_to(blob) + assert xf._download_result_usable( + snap, repo_type = "model", allow_patterns = None, ignore_patterns = None) is True + # Same for a variant component index (variant='fp16', unpatterned). 
+ snapv, blobv = _mk_snapshot(tmp_path, "diffusers_comp_variant_incomplete") + (snapv / "model_index.json").write_text("{}") + (snapv / "unet").mkdir() + (snapv / "unet" / "diffusion_pytorch_model.fp16-00001-of-00002.safetensors").symlink_to(blobv) + (snapv / "unet" / "diffusion_pytorch_model.safetensors.index.fp16.json").write_text(json.dumps( + {"weight_map": {"a": "diffusion_pytorch_model.fp16-00001-of-00002.safetensors", + "b": "diffusion_pytorch_model.fp16-00002-of-00002.safetensors"}})) + assert xf._download_result_usable( + snapv, repo_type = "model", allow_patterns = None, ignore_patterns = None, + variant = "fp16") is False + (snapv / "unet" / "diffusion_pytorch_model.fp16-00002-of-00002.safetensors").symlink_to(blobv) + assert xf._download_result_usable( + snapv, repo_type = "model", allow_patterns = None, ignore_patterns = None, + variant = "fp16") is True + + +def test_post_download_single_safetensors_beats_stale_index(tmp_path): + """A complete single model.safetensors (probed before the index) beside a stale incomplete index is usable; a stale index with no single weight is rejected as incomplete.""" + snap, blob = _mk_snapshot(tmp_path, "single_beats_index") + (snap / "config.json").write_text("{}") + (snap / "model.safetensors").symlink_to(blob) + (snap / "model.safetensors.index.json").write_text(json.dumps( + {"weight_map": {"a": "model-00001-of-00002.safetensors", + "b": "model-00002-of-00002.safetensors"}})) # shards absent + assert xf._download_result_usable( + snap, repo_type = "model", allow_patterns = None, ignore_patterns = None) is True + assert hcs.snapshot_dir_is_complete(snap) is True # the PRE gate agrees + # No single weight, only the stale index -> incomplete. 
+ snap2, _ = _mk_snapshot(tmp_path, "stale_index_only") + (snap2 / "config.json").write_text("{}") + (snap2 / "model.safetensors.index.json").write_text(json.dumps( + {"weight_map": {"a": "model-00001-of-00002.safetensors", + "b": "model-00002-of-00002.safetensors"}})) + assert xf._download_result_usable( + snap2, repo_type = "model", allow_patterns = None, ignore_patterns = None) is False + + +def test_post_download_rejects_noncanonical_root_weight_for_default_load(tmp_path): + """A default load probes only canonical names, so a cache holding only a non-canonical root weight (consolidated.safetensors) is rejected; the canonical weight present is accepted.""" + snap, blob = _mk_snapshot(tmp_path, "noncanonical") + (snap / "config.json").write_text("{}") + (snap / "consolidated.safetensors").symlink_to(blob) + assert xf._download_result_usable( + snap, repo_type = "model", allow_patterns = None, ignore_patterns = None) is False + (snap / "model.safetensors").symlink_to(blob) + assert xf._download_result_usable( + snap, repo_type = "model", allow_patterns = None, ignore_patterns = None) is True + + +def test_diffusers_component_check_scoped_to_declared_components(tmp_path): + """The component shard check is scoped to declared components: a stale undeclared subtree does not reject a complete pipeline; an incomplete declared component still does.""" + snap, blob = _mk_snapshot(tmp_path, "declared_scope") + (snap / "model_index.json").write_text(json.dumps( + {"_class_name": "StableDiffusionPipeline", + "unet": ["diffusers", "UNet2DConditionModel"], "vae": ["diffusers", "AutoencoderKL"]})) + for comp in ("unet", "vae"): + (snap / comp).mkdir() + (snap / comp / "config.json").write_text("{}") + (snap / comp / "diffusion_pytorch_model.safetensors").symlink_to(blob) + (snap / "controlnet").mkdir() # UNDECLARED leftover + (snap / "controlnet" / "diffusion_pytorch_model.safetensors.index.json").write_text(json.dumps( + {"weight_map": {"a": 
"diffusion_pytorch_model-00001-of-00002.safetensors", + "b": "diffusion_pytorch_model-00002-of-00002.safetensors"}})) + (snap / "controlnet" / "diffusion_pytorch_model-00001-of-00002.safetensors").symlink_to(blob) + assert xf._download_result_usable( + snap, repo_type = "model", allow_patterns = None, ignore_patterns = None) is True + # An incomplete declared component (unet index missing a shard) is still caught. + snap2, blob2 = _mk_snapshot(tmp_path, "declared_incomplete") + (snap2 / "model_index.json").write_text(json.dumps( + {"_class_name": "P", "unet": ["diffusers", "UNet2DConditionModel"]})) + (snap2 / "unet").mkdir() + (snap2 / "unet" / "diffusion_pytorch_model.safetensors.index.json").write_text(json.dumps( + {"weight_map": {"a": "diffusion_pytorch_model-00001-of-00002.safetensors", + "b": "diffusion_pytorch_model-00002-of-00002.safetensors"}})) + (snap2 / "unet" / "diffusion_pytorch_model-00001-of-00002.safetensors").symlink_to(blob2) + assert xf._download_result_usable( + snap2, repo_type = "model", allow_patterns = None, ignore_patterns = None) is False + + +def test_post_download_variant_presence_requires_canonical_name(tmp_path): + """The variant presence check counts only a canonical variant name; a non-canonical sidecar or dot-infix shard does not satisfy the request.""" + snap, blob = _mk_snapshot(tmp_path, "var_noncanonical") + (snap / "config.json").write_text("{}") + (snap / "consolidated.fp16.safetensors").symlink_to(blob) + assert xf._download_result_usable( + snap, repo_type = "model", allow_patterns = None, ignore_patterns = None, + variant = "fp16") is False + snap_dot, blob_dot = _mk_snapshot(tmp_path, "var_dotinfix") + (snap_dot / "config.json").write_text("{}") + (snap_dot / "model-00001-of-00001.fp16.safetensors").symlink_to(blob_dot) # not a name tf reads + assert xf._download_result_usable( + snap_dot, repo_type = "model", allow_patterns = None, ignore_patterns = None, + variant = "fp16") is False + # The canonical single variant 
weight -> accepted. + (snap / "model.fp16.safetensors").symlink_to(blob) + assert xf._download_result_usable( + snap, repo_type = "model", allow_patterns = None, ignore_patterns = None, + variant = "fp16") is True + + +def test_post_download_rejects_selected_shard_without_index(tmp_path): + """A selected non-root numbered shard with no index is an incomplete set the load cannot enumerate, so it is rejected; a complete indexed set is accepted.""" + # A sharded adapter with a lone shard and no index. + snap, blob = _mk_snapshot(tmp_path, "adapter_lone_shard") + (snap / "config.json").write_text("{}") + (snap / "adapter_config.json").write_text("{}") + (snap / "adapter_model-00001-of-00002.safetensors").symlink_to(blob) + assert xf._download_result_usable( + snap, repo_type = "model", allow_patterns = ["adapter_model*", "adapter_config.json"], + ignore_patterns = None) is False + # Complete it with the second shard + index -> accepted. + (snap / "adapter_model-00002-of-00002.safetensors").symlink_to(blob) + (snap / "adapter_model.safetensors.index.json").write_text(json.dumps( + {"weight_map": {"a": "adapter_model-00001-of-00002.safetensors", + "b": "adapter_model-00002-of-00002.safetensors"}})) + assert xf._download_result_usable( + snap, repo_type = "model", allow_patterns = ["adapter_model*", "adapter_config.json"], + ignore_patterns = None) is True + # A component subfolder lone shard (allow=['unet/*']) is likewise rejected. 
+ snap2, blob2 = _mk_snapshot(tmp_path, "unet_lone_shard") + (snap2 / "unet").mkdir() + (snap2 / "unet" / "config.json").write_text("{}") + (snap2 / "unet" / "diffusion_pytorch_model-00001-of-00002.safetensors").symlink_to(blob2) + assert xf._download_result_usable( + snap2, repo_type = "model", allow_patterns = ["unet/*"], ignore_patterns = None) is False + + +def test_post_download_diffusers_presence_scoped_to_declared(tmp_path): + """A diffusers warm counts a component weight only for a declared component, so an undeclared-leftover-only cache is rejected; declared components present are accepted.""" + snap, blob = _mk_snapshot(tmp_path, "diffusers_undeclared_only") + (snap / "model_index.json").write_text(json.dumps( + {"_class_name": "P", "unet": ["diffusers", "U"], "vae": ["diffusers", "V"]})) + (snap / "controlnet").mkdir() # UNDECLARED + (snap / "controlnet" / "diffusion_pytorch_model.safetensors").symlink_to(blob) + assert xf._download_result_usable( + snap, repo_type = "model", allow_patterns = None, ignore_patterns = None) is False + # The declared components present -> accepted. 
+ for comp in ("unet", "vae"): + (snap / comp).mkdir() + (snap / comp / "diffusion_pytorch_model.safetensors").symlink_to(blob) + assert xf._download_result_usable( + snap, repo_type = "model", allow_patterns = None, ignore_patterns = None) is True + + +def test_post_download_diffusers_variant_presence_scoped_to_declared(tmp_path): + """Variant twin of the declared-scope check: a diffusers variant warm counts a component variant weight only for a declared component.""" + snap, blob = _mk_snapshot(tmp_path, "diffusers_variant_undeclared_only") + (snap / "model_index.json").write_text(json.dumps( + {"_class_name": "P", "unet": ["diffusers", "U"], "vae": ["diffusers", "V"]})) + (snap / "controlnet").mkdir() # UNDECLARED + (snap / "controlnet" / "diffusion_pytorch_model.fp16.safetensors").symlink_to(blob) + assert xf._download_result_usable( + snap, repo_type = "model", allow_patterns = None, ignore_patterns = None, + variant = "fp16") is False + # The declared component variant weights present -> accepted. + for comp in ("unet", "vae"): + (snap / comp).mkdir() + (snap / comp / "diffusion_pytorch_model.fp16.safetensors").symlink_to(blob) + assert xf._download_result_usable( + snap, repo_type = "model", allow_patterns = None, ignore_patterns = None, + variant = "fp16") is True + + +def test_post_download_rejects_diffusers_missing_declared_component(tmp_path): + """A declared weight-bearing component absent (or holding only its config) is retried over HTTP, not + accepted -- else the in-process pipeline load fetches the missing component over un-killable Xet. 
A + ``[null, null]`` (disabled) component and weightless components (scheduler / tokenizer) are not + required.""" + def _mi(): + return json.dumps({ + "_class_name": "StableDiffusionPipeline", + "unet": ["diffusers", "UNet2DConditionModel"], + "vae": ["diffusers", "AutoencoderKL"], + "text_encoder": ["transformers", "CLIPTextModel"], + "scheduler": ["diffusers", "PNDMScheduler"], + "tokenizer": ["transformers", "CLIPTokenizer"], + "safety_checker": [None, None], + }) + + def _model_comp(root, name, blob, *, weight = True, variant = None): + d = root / name + d.mkdir() + (d / "config.json").write_text("{}") + if weight: + w = "diffusion_pytorch_model.safetensors" if variant is None \ + else f"diffusion_pytorch_model.{variant}.safetensors" + (d / w).symlink_to(blob) + + def _weightless(root, name): + d = root / name + d.mkdir() + # scheduler_config.json / tokenizer_config.json -- NOT a plain config.json, so no weight required + (d / f"{name}_config.json").write_text("{}") + + # unet present, vae ABSENT (text_encoder present) -> reject. + snap, blob = _mk_snapshot(tmp_path, "diff_missing_vae") + (snap / "model_index.json").write_text(_mi()) + _model_comp(snap, "unet", blob) + _model_comp(snap, "text_encoder", blob) + _weightless(snap, "scheduler") + _weightless(snap, "tokenizer") + assert xf._download_result_usable( + snap, repo_type = "model", allow_patterns = None, ignore_patterns = None) is False + + # vae present with config ONLY (weight missing) -> reject. 
+ snap2, blob2 = _mk_snapshot(tmp_path, "diff_vae_config_only") + (snap2 / "model_index.json").write_text(_mi()) + _model_comp(snap2, "unet", blob2) + _model_comp(snap2, "text_encoder", blob2) + _model_comp(snap2, "vae", blob2, weight = False) # config, no weight + _weightless(snap2, "scheduler") + _weightless(snap2, "tokenizer") + assert xf._download_result_usable( + snap2, repo_type = "model", allow_patterns = None, ignore_patterns = None) is False + + # Every weight-bearing component complete; safety_checker [null,null] absent -> accept (no false-reject). + snap3, blob3 = _mk_snapshot(tmp_path, "diff_complete") + (snap3 / "model_index.json").write_text(_mi()) + for c in ("unet", "vae", "text_encoder"): + _model_comp(snap3, c, blob3) + _weightless(snap3, "scheduler") + _weightless(snap3, "tokenizer") + assert xf._download_result_usable( + snap3, repo_type = "model", allow_patterns = None, ignore_patterns = None) is True + + # Variant: vae absent -> reject; a mixed pipeline (unet fp16, vae/text_encoder canonical fallback) -> accept. 
+ snap4, blob4 = _mk_snapshot(tmp_path, "diff_variant_missing_vae") + (snap4 / "model_index.json").write_text(_mi()) + _model_comp(snap4, "unet", blob4, variant = "fp16") + _model_comp(snap4, "text_encoder", blob4, variant = "fp16") + _weightless(snap4, "scheduler") + _weightless(snap4, "tokenizer") + assert xf._download_result_usable( + snap4, repo_type = "model", allow_patterns = None, ignore_patterns = None, + variant = "fp16") is False + snap5, blob5 = _mk_snapshot(tmp_path, "diff_variant_mixed") + (snap5 / "model_index.json").write_text(_mi()) + _model_comp(snap5, "unet", blob5, variant = "fp16") + _model_comp(snap5, "vae", blob5) # canonical fallback for this component + _model_comp(snap5, "text_encoder", blob5) # canonical fallback + _weightless(snap5, "scheduler") + _weightless(snap5, "tokenizer") + assert xf._download_result_usable( + snap5, repo_type = "model", allow_patterns = None, ignore_patterns = None, + variant = "fp16") is True + + +def test_post_download_rejects_diffusers_component_dir_without_weight_or_config(tmp_path): + """A declared active component dir that is non-empty but holds NEITHER a config.json NOR a readable + weight NOR a weightless sidecar (only a stray file or a nested checkpoint) is an incomplete model + component -> retried over HTTP, not accepted. 
A weightless component (its *_config.json sidecar) still + passes without a weight.""" + def _mi(): + return json.dumps({ + "_class_name": "StableDiffusionPipeline", + "unet": ["diffusers", "UNet2DConditionModel"], + "vae": ["diffusers", "AutoencoderKL"], + "scheduler": ["diffusers", "PNDMScheduler"], + "tokenizer": ["transformers", "CLIPTokenizer"], + }) + + def _base(root, name, blob): + (root / "model_index.json").write_text(_mi()) + d = root / "unet" + d.mkdir() + (d / "config.json").write_text("{}") + (d / "diffusion_pytorch_model.safetensors").symlink_to(blob) + for n in ("scheduler", "tokenizer"): + wl = root / n + wl.mkdir() + (wl / f"{n}_config.json").write_text("{}") + + # vae dir holds only a nested training-checkpoint weight (no top-level config / weight) -> reject. + snap, blob = _mk_snapshot(tmp_path, "diff_comp_nested_ckpt") + _base(snap, "unet", blob) + ckpt = snap / "vae" / "checkpoint-500" + ckpt.mkdir(parents = True) + (ckpt / "diffusion_pytorch_model.safetensors").symlink_to(blob) + assert xf._download_result_usable( + snap, repo_type = "model", allow_patterns = None, ignore_patterns = None) is False + + # vae dir holds only a stray non-weight file -> reject. + snap2, blob2 = _mk_snapshot(tmp_path, "diff_comp_stray") + _base(snap2, "unet", blob2) + (snap2 / "vae").mkdir() + (snap2 / "vae" / "README.md").write_text("x") + assert xf._download_result_usable( + snap2, repo_type = "model", allow_patterns = None, ignore_patterns = None) is False + + # Complete: vae holds config + weight -> accept (weightless scheduler/tokenizer need no weight). 
+ snap3, blob3 = _mk_snapshot(tmp_path, "diff_comp_complete") + _base(snap3, "unet", blob3) + (snap3 / "vae").mkdir() + (snap3 / "vae" / "config.json").write_text("{}") + (snap3 / "vae" / "diffusion_pytorch_model.safetensors").symlink_to(blob3) + assert xf._download_result_usable( + snap3, repo_type = "model", allow_patterns = None, ignore_patterns = None) is True + + +def test_post_download_single_variant_beats_stale_variant_index(tmp_path): + """Variant twin of single-beats-index: a complete single variant weight beside a stale variant index is usable (ST and bin); a stale index with no single weight is breakage.""" + snap, blob = _mk_snapshot(tmp_path, "single_variant_beats_index") + (snap / "config.json").write_text("{}") + (snap / "model.fp16.safetensors").symlink_to(blob) + (snap / "model.safetensors.index.fp16.json").write_text(json.dumps( + {"weight_map": {"a": "model.fp16-00001-of-00002.safetensors", + "b": "model.fp16-00002-of-00002.safetensors"}})) # shards absent + assert xf._download_result_usable( + snap, repo_type = "model", allow_patterns = None, ignore_patterns = None, + variant = "fp16") is True + # A single .bin variant beside a stale .bin variant index (no ST) -> usable. + snapb, blobb = _mk_snapshot(tmp_path, "single_bin_variant_beats_index") + (snapb / "config.json").write_text("{}") + (snapb / "pytorch_model.fp16.bin").symlink_to(blobb) + (snapb / "pytorch_model.bin.index.fp16.json").write_text(json.dumps( + {"weight_map": {"a": "pytorch_model.fp16-00001-of-00002.bin"}})) + assert xf._download_result_usable( + snapb, repo_type = "model", allow_patterns = None, ignore_patterns = None, + variant = "fp16") is True + # A stale variant index with no single variant weight -> incomplete. 
+ snap2, _ = _mk_snapshot(tmp_path, "stale_variant_index_only") + (snap2 / "config.json").write_text("{}") + (snap2 / "model.safetensors.index.fp16.json").write_text(json.dumps( + {"weight_map": {"a": "model.fp16-00001-of-00002.safetensors", + "b": "model.fp16-00002-of-00002.safetensors"}})) + assert xf._download_result_usable( + snap2, repo_type = "model", allow_patterns = None, ignore_patterns = None, + variant = "fp16") is False + + +def test_post_download_diffusers_skips_root_model_shard_checks(tmp_path): + """A diffusers pipeline reads component subfolders, so a stale root model shard index (canonical or variant) is accepted; component completeness is still enforced.""" + # Plain: stale root model.safetensors.index.json alongside complete components. + snap, blob = _mk_snapshot(tmp_path, "diffusers_stale_root_index_plain") + (snap / "model_index.json").write_text(json.dumps( + {"_class_name": "P", "unet": ["diffusers", "U"], "vae": ["diffusers", "V"]})) + for comp in ("unet", "vae"): + (snap / comp).mkdir() + (snap / comp / "diffusion_pytorch_model.safetensors").symlink_to(blob) + (snap / "model.safetensors.index.json").write_text(json.dumps( + {"weight_map": {"a": "model-00001-of-00002.safetensors", + "b": "model-00002-of-00002.safetensors"}})) # shards absent (stale) + assert xf._download_result_usable( + snap, repo_type = "model", allow_patterns = None, ignore_patterns = None) is True + # Variant: stale root model.safetensors.index.fp16.json alongside complete variant components. 
+ snapv, blobv = _mk_snapshot(tmp_path, "diffusers_stale_root_index_variant") + (snapv / "model_index.json").write_text(json.dumps( + {"_class_name": "P", "unet": ["diffusers", "U"], "vae": ["diffusers", "V"]})) + for comp in ("unet", "vae"): + (snapv / comp).mkdir() + (snapv / comp / "diffusion_pytorch_model.fp16.safetensors").symlink_to(blobv) + (snapv / "model.safetensors.index.fp16.json").write_text(json.dumps( + {"weight_map": {"a": "model.fp16-00001-of-00002.safetensors", + "b": "model.fp16-00002-of-00002.safetensors"}})) + assert xf._download_result_usable( + snapv, repo_type = "model", allow_patterns = None, ignore_patterns = None, + variant = "fp16") is True + # An incomplete declared component is still rejected. + (snapv / "unet" / "diffusion_pytorch_model.fp16.safetensors").unlink() + (snapv / "unet" / "diffusion_pytorch_model.fp16-00001-of-00002.safetensors").symlink_to(blobv) + (snapv / "unet" / "diffusion_pytorch_model.safetensors.index.fp16.json").write_text(json.dumps( + {"weight_map": {"a": "diffusion_pytorch_model.fp16-00001-of-00002.safetensors", + "b": "diffusion_pytorch_model.fp16-00002-of-00002.safetensors"}})) + assert xf._download_result_usable( + snapv, repo_type = "model", allow_patterns = None, ignore_patterns = None, + variant = "fp16") is False + + +def test_post_download_out_of_scope_malformed_index_not_rejected(tmp_path): + """A malformed shard index the request does not select is not read, so it does not reject a complete in-scope download; an in-scope malformed index is breakage.""" + snap, blob = _mk_snapshot(tmp_path, "malformed_out_of_scope") + (snap / "config.json").write_text("{}") + (snap / "model.safetensors").symlink_to(blob) + (snap / "adapter_model.safetensors.index.json").write_text("{ not valid json") # malformed, unselected + assert xf._download_result_usable( + snap, repo_type = "model", allow_patterns = ["model*"], ignore_patterns = None) is True + # An in-scope malformed index (adapter warm selects adapter_model*) is 
still rejected. + snap2, blob2 = _mk_snapshot(tmp_path, "malformed_in_scope") + (snap2 / "config.json").write_text("{}") + (snap2 / "adapter_config.json").write_text("{}") + (snap2 / "adapter_model-00001-of-00002.safetensors").symlink_to(blob2) + (snap2 / "adapter_model.safetensors.index.json").write_text("{ not valid json") + assert xf._download_result_usable( + snap2, repo_type = "model", allow_patterns = ["adapter_model*", "adapter_config.json"], + ignore_patterns = None) is False + + +def test_selected_readable_weight_complete_entry_point(tmp_path): + """The acceptance helper enforces (A) a readable weight present, (B) its shard set complete: exercise present+complete, absent, and incomplete-shard cases.""" + # Present + complete single weight -> True. + snap, blob = _mk_snapshot(tmp_path, "srwc_ok") + (snap / "model.safetensors").symlink_to(blob) + assert xf._selected_readable_weight_complete( + snap, allow_patterns = None, ignore_patterns = None, variant = None) is True + # Invariant A fails: only an ignored-format weight present -> False. + snap2, blob2 = _mk_snapshot(tmp_path, "srwc_ignored") + (snap2 / "pytorch_model.bin").symlink_to(blob2) + assert xf._selected_readable_weight_complete( + snap2, allow_patterns = None, ignore_patterns = ["*.bin"], variant = None) is False + # Invariant B fails: readable weight present but its shard set incomplete -> False. 
+ snap3, blob3 = _mk_snapshot(tmp_path, "srwc_incomplete") + (snap3 / "model-00001-of-00002.safetensors").symlink_to(blob3) + (snap3 / "model.safetensors.index.json").write_text(json.dumps( + {"weight_map": {"a": "model-00001-of-00002.safetensors", + "b": "model-00002-of-00002.safetensors"}})) + assert xf._selected_readable_weight_complete( + snap3, allow_patterns = None, ignore_patterns = None, variant = None) is False + + +def test_post_download_accepts_dataset_without_weight(tmp_path): + snap, blob = _mk_snapshot(tmp_path, "ds") + (snap / "data.parquet").symlink_to(blob) + assert xf._download_result_usable( + snap, repo_type = "dataset", allow_patterns = None, ignore_patterns = None) is True + + +def test_post_download_accepts_either_format_single_present(tmp_path): + """An either-format named request against a safetensors-only repo is accepted (not retried for the .bin the repo does not publish).""" + snap, blob = _mk_snapshot(tmp_path, "either") + (snap / "model.safetensors").symlink_to(blob) + assert xf._download_result_usable( + snap, repo_type = "model", + allow_patterns = ["pytorch_model.bin", "model.safetensors"], ignore_patterns = None) is True + + +def test_pre_download_skips_intact_tokenizer_only(tmp_path): + """A tokenizer-only (weightless) warm short-circuits offline: an intact requested subset is enough, no weight required.""" + snap, _ = _mk_snapshot(tmp_path, "tok") + (snap / "tokenizer.json").write_text("{}") + (snap / "config.json").write_text("{}") + assert xf._cache_can_skip_download( + snap, repo_type = "model", + allow_patterns = ["tokenizer.json", "config.json"], ignore_patterns = None) is True + + +def test_pre_download_partial_ignore_does_not_skip_config_only(tmp_path): + """An ignore-only request stripping only some weight formats on a config-only cache must not skip the child (a surviving weight format would hang over Xet).""" + snap, _ = _mk_snapshot(tmp_path, "cfgign") + (snap / "config.json").write_text("{}") + assert 
xf._cache_can_skip_download( + snap, repo_type = "model", allow_patterns = None, + ignore_patterns = ["*.safetensors", "*.bin"]) is False + + +# Review-round regression guards (10-reviewer findings) +def test_gate_rejects_malformed_shard_index(tmp_path): + """A truncated / non-dict / empty weight-shard index does not read as complete (_weight_shard_index_complete is fail-closed).""" + snap, blob = _mk_snapshot(tmp_path, "malidx") + (snap / "model-00001-of-00002.safetensors").symlink_to(blob) + (snap / "model.safetensors.index.json").write_text("{not valid json") + assert hcs.snapshot_dir_is_complete(snap) is False + # Empty weight_map proves nothing. + snap2, blob2 = _mk_snapshot(tmp_path, "emptyidx") + (snap2 / "model-00001-of-00002.safetensors").symlink_to(blob2) + (snap2 / "model.safetensors.index.json").write_text(json.dumps({"weight_map": {}})) + assert hcs.snapshot_dir_is_complete(snap2) is False + # weight_map not a dict. + snap3, blob3 = _mk_snapshot(tmp_path, "listidx") + (snap3 / "model.safetensors.index.json").write_text(json.dumps({"weight_map": ["a", "b"]})) + assert hcs._weight_shard_index_complete(snap3 / "model.safetensors.index.json") is False + + +def test_shard_index_rejects_unsafe_path_refs(tmp_path): + """An attacker-influenced shard value (absolute, drive-letter, UNC, parent-escaping) is rejected so ``base / shard`` cannot probe outside the snapshot.""" + # Unit: the helper flags every escape variant and keeps legit relative names. + for bad in ["/etc/passwd", r"C:\evil.safetensors", "C:evil.safetensors", r"\\srv\share\x", + "../../x.safetensors", r"..\x.safetensors", "a/../../b"]: + assert hcs._is_unsafe_shard_ref(bad) is True, bad + for ok in ["model-00001-of-00002.safetensors", "unet/diffusion_pytorch_model.safetensors", + "model.fp16.safetensors"]: + assert hcs._is_unsafe_shard_ref(ok) is False, ok + # A crafted index listing a drive-letter shard is not "complete". 
+ snap, blob = _mk_snapshot(tmp_path, "unsafe_shard_idx") + (snap / "model.safetensors.index.json").write_text(json.dumps( + {"weight_map": {"a": r"C:\Windows\System32\x.safetensors", + "b": "model-00002-of-00002.safetensors"}})) + assert hcs._weight_shard_index_complete(snap / "model.safetensors.index.json") is False + # The enumerator returns None (defer) rather than a path escaping the snapshot. + assert hcs._index_shard_rel_paths(snap / "model.safetensors.index.json", "") is None + # A well-formed relative index still enumerates + validates. + snap2, blob2 = _mk_snapshot(tmp_path, "safe_shard_idx") + (snap2 / "model-00001-of-00002.safetensors").symlink_to(blob2) + (snap2 / "model-00002-of-00002.safetensors").symlink_to(blob2) + (snap2 / "model.safetensors.index.json").write_text(json.dumps( + {"weight_map": {"a": "model-00001-of-00002.safetensors", + "b": "model-00002-of-00002.safetensors"}})) + assert hcs._weight_shard_index_complete(snap2 / "model.safetensors.index.json") is True + assert set(hcs._index_shard_rel_paths(snap2 / "model.safetensors.index.json", "")) == { + "model-00001-of-00002.safetensors", "model-00002-of-00002.safetensors"} + + +def test_malformed_index_scope_honors_ignored_format(tmp_path): + """A malformed index is judged by the weight the load reads, so a malformed index for an ignored format is skipped; a malformed index of the read format is breakage.""" + # Patterned subfolder warm reading safetensors: a co-resident malformed bin index is ignored. 
+ snap, blob = _mk_snapshot(tmp_path, "malformed_ignored_bin_idx") + (snap / "unet").mkdir() + (snap / "unet" / "config.json").write_text("{}") + (snap / "unet" / "diffusion_pytorch_model.safetensors").symlink_to(blob) + (snap / "unet" / "diffusion_pytorch_model.bin.index.json").write_text("{ truncated") + assert xf._download_result_usable( + snap, repo_type = "model", allow_patterns = ["unet/*"], ignore_patterns = ["*.bin"]) is True + # The malformed index of the READ format (safetensors, not ignored) is still breakage. + snap2, _ = _mk_snapshot(tmp_path, "malformed_read_st_idx") + (snap2 / "unet").mkdir() + (snap2 / "unet" / "config.json").write_text("{}") + (snap2 / "unet" / "diffusion_pytorch_model.safetensors.index.json").write_text("{ truncated") + assert xf._download_result_usable( + snap2, repo_type = "model", allow_patterns = ["unet/*"], ignore_patterns = None) is False + # Diffusers: a malformed component bin index under ignore=['*.bin'] does not reject a complete ST pipeline. 
+ snap3, blob3 = _mk_snapshot(tmp_path, "malformed_diffusers_bin_idx") + (snap3 / "model_index.json").write_text(json.dumps( + {"_class_name": "P", "unet": ["diffusers", "U"], "vae": ["diffusers", "V"]})) + for comp in ("unet", "vae"): + (snap3 / comp).mkdir() + (snap3 / comp / "diffusion_pytorch_model.safetensors").symlink_to(blob3) + (snap3 / "unet" / "diffusion_pytorch_model.bin.index.json").write_text("{ truncated") + assert xf._download_result_usable( + snap3, repo_type = "model", allow_patterns = None, ignore_patterns = ["*.bin"]) is True + + +def test_gate_ignored_canonical_weight_does_not_prove_complete(tmp_path): + """A canonical weight whose format the request ignores does not prove completeness (a bin-only cache under ignore=['*.bin'] defers to the child).""" + snap, blob = _mk_snapshot(tmp_path, "ignbin") + (snap / "config.json").write_text("{}") + (snap / "pytorch_model.bin").symlink_to(blob) + assert hcs.snapshot_dir_is_complete(snap, ignore_patterns = ["*.bin"]) is False + # Without the ignore, the present .bin is what a default load reads -> complete. + assert hcs.snapshot_dir_is_complete(snap) is True + # A .bin shard index is also discarded when *.bin is ignored (the format probe catches its .json sidecar). + snap2, blob2 = _mk_snapshot(tmp_path, "ignbinshard") + (snap2 / "pytorch_model-00001-of-00001.bin").symlink_to(blob2) + (snap2 / "pytorch_model.bin.index.json").write_text( + json.dumps({"weight_map": {"a": "pytorch_model-00001-of-00001.bin"}})) + assert hcs.snapshot_dir_is_complete(snap2, ignore_patterns = ["*.bin"]) is False + assert hcs.snapshot_dir_is_complete(snap2) is True + # A safetensors warm survives an *.bin ignore. 
+ snap3, blob3 = _mk_snapshot(tmp_path, "stignbin") + (snap3 / "config.json").write_text("{}") + (snap3 / "model.safetensors").symlink_to(blob3) + assert hcs.snapshot_dir_is_complete(snap3, ignore_patterns = ["*.bin"]) is True + + +def test_post_download_accepts_weightless_patterned_result(tmp_path): + """A genuinely weightless patterned result (allow=['tokenizer*']) is accepted post-download; the no-weight rejection stays for an unpatterned model warm.""" + snap, _ = _mk_snapshot(tmp_path, "tokglob") + (snap / "tokenizer.json").write_text("{}") + (snap / "tokenizer_config.json").write_text("{}") + assert xf._download_result_usable( + snap, repo_type = "model", allow_patterns = ["tokenizer*"], ignore_patterns = None) is True + # An unpatterned model warm with no weight is still rejected (stale config-only snapshot). + cfg, _ = _mk_snapshot(tmp_path, "cfgonly") + (cfg / "config.json").write_text("{}") + assert xf._download_result_usable( + cfg, repo_type = "model", allow_patterns = None, ignore_patterns = None) is False + + +def test_gate_rejects_variant_only_shard_index(tmp_path): + """A variant-only shard index does not satisfy the canonical allow=None fast path (snapshot_dir_is_complete is variant-blind); a canonical index does.""" + snap, blob = _mk_snapshot(tmp_path, "variant") + (snap / "config.json").write_text("{}") + (snap / "model-00001-of-00001.fp16.safetensors").symlink_to(blob) + (snap / "model.safetensors.index.fp16.json").write_text( + json.dumps({"weight_map": {"a": "model-00001-of-00001.fp16.safetensors"}})) + assert hcs.snapshot_dir_is_complete(snap) is False + # The canonical index for the same model still fast-paths. 
+ (snap / "model-00001-of-00001.safetensors").symlink_to(blob) + (snap / "model.safetensors.index.json").write_text( + json.dumps({"weight_map": {"a": "model-00001-of-00001.safetensors"}})) + assert hcs.snapshot_dir_is_complete(snap) is True + + +def test_generic_hub_http_error_type_preserved_but_status_drives_retry(): + """A bare HfHubHTTPError keeps its type across the spawn boundary while retry stays status-driven: a 5xx retries, a 4xx does not.""" + assert "HfHubHTTPError" not in xf._DETERMINISTIC_ERROR_NAMES # status-driven, not name-driven + cls = xf._resolve_exception_class("HfHubHTTPError") + assert cls is not None and issubclass(cls, BaseException) + + class _Resp: + def __init__(self, code): self.status_code = code + e503 = xf._instantiate_preserving_type(cls, "HfHubHTTPError: 503 service unavailable") + e503.response = _Resp(503) + assert xf._is_retryable_download_error(e503) is True # 5xx still retryable + e403 = xf._instantiate_preserving_type(cls, "HfHubHTTPError: 403 forbidden") + e403.response = _Resp(403) + assert xf._is_retryable_download_error(e403) is False # 4xx deterministic + + +def test_hfvalidationerror_type_preserved_across_spawn(): + """A malformed repo id fails identically over either transport, so HFValidationError is deterministic and its type is reconstructed across the spawn boundary.""" + assert "HFValidationError" in xf._DETERMINISTIC_ERROR_NAMES + cls = xf._resolve_exception_class("HFValidationError") + assert cls is not None and issubclass(cls, BaseException) + inst = xf._instantiate_preserving_type(cls, "HFValidationError: bad repo id") + assert type(inst).__name__ == "HFValidationError" + assert xf._is_retryable_download_error(inst) is False + + +def test_oserror_subclass_type_preserved_across_spawn(): + """A builtin OSError subclass keeps its type across the spawn boundary; a non-OSError builtin is not spuriously resolved.""" + for name in ("PermissionError", "FileNotFoundError", "IsADirectoryError", "NotADirectoryError"): + 
cls = xf._resolve_exception_class(name) + assert cls is not None and issubclass(cls, OSError) and cls.__name__ == name + # A deterministic PermissionError is reconstructed as a real PermissionError and not retried. + perm = xf._instantiate_preserving_type(xf._resolve_exception_class("PermissionError"), "denied") + assert isinstance(perm, PermissionError) + assert xf._is_retryable_download_error(perm) is False + # An unrelated builtin (not OSError, not a Hub error name) is not resolved. + assert xf._resolve_exception_class("ValueError") is None + + +def test_weight_pattern_selector_handles_globs(tmp_path): + """The selector reads tokenizer/config/json globs as weightless but standard weight names and ?/[] globs as weight-bearing.""" + weightless = ["tokenizer*", "*.json", "config.json", "tokenizer.model", "*.txt"] + weighty = [ + "model.safetensors", "*.safetensors", "model.?afetensors", "model.[sp]afetensors", + "checkpoint-*/model.?afetensors", "unet/*", "adapter_model*", "consolidated*", "model.gguf", + ] + for pat in weightless: + assert hcs.request_can_include_weights([pat], None) is False, pat + for pat in weighty: + assert hcs.request_can_include_weights([pat], None) is True, pat + + +def test_post_download_rejects_config_only_for_explicit_weight_pattern(tmp_path): + """An explicit weight request returning only config.json is a stale config-only snapshot: reject and retry.""" + snap, _ = _mk_snapshot(tmp_path, "patcfg") + (snap / "config.json").write_text("{}") + assert xf._download_result_usable( + snap, repo_type = "model", allow_patterns = ["model.safetensors"], ignore_patterns = None) is False + assert xf._download_result_usable( + snap, repo_type = "model", allow_patterns = ["adapter_model.safetensors"], + ignore_patterns = None) is False + + +def test_post_download_rejects_incomplete_canonical_root_shards(tmp_path): + """An interrupted canonical sharded warm (loose shard, no index) is rejected; a complete set is accepted; a variant-only layout does not 
satisfy a default load.""" + snap, blob = _mk_snapshot(tmp_path, "incshard") + (snap / "config.json").write_text("{}") + (snap / "model-00001-of-00002.safetensors").symlink_to(blob) + assert xf._download_result_usable( + snap, repo_type = "model", allow_patterns = None, ignore_patterns = None) is False + # Complete the set with its index -> accepted. + (snap / "model-00002-of-00002.safetensors").symlink_to(blob) + (snap / "model.safetensors.index.json").write_text( + json.dumps({"weight_map": {"a": "model-00001-of-00002.safetensors", + "b": "model-00002-of-00002.safetensors"}})) + assert xf._download_result_usable( + snap, repo_type = "model", allow_patterns = None, ignore_patterns = None) is True + # A variant-named shard is not a canonical weight a default load reads, so a variant-only cache is rejected. + vsnap, vblob = _mk_snapshot(tmp_path, "vshard") + (vsnap / "config.json").write_text("{}") + (vsnap / "model-00001-of-00001.fp16.safetensors").symlink_to(vblob) + assert xf._download_result_usable( + vsnap, repo_type = "model", allow_patterns = None, ignore_patterns = None) is False + + +def test_local_token_not_found_error_type_preserved(): + """A missing required token fails identically over either transport, so LocalTokenNotFoundError is deterministic and type-preserved.""" + assert "LocalTokenNotFoundError" in xf._DETERMINISTIC_ERROR_NAMES + cls = xf._resolve_exception_class("LocalTokenNotFoundError") + assert cls is not None and issubclass(cls, BaseException) + assert xf._is_retryable_download_error( + xf._instantiate_preserving_type(cls, "LocalTokenNotFoundError: token required")) is False + + +def test_metadata_directory_pattern_is_weightless(tmp_path): + """A trailing-slash metadata dir pattern (allow=['tokenizer/']) reads weightless; component/checkpoint dir patterns stay weight-bearing.""" + assert hcs.request_can_include_weights(["tokenizer/"], None) is False + assert hcs.request_can_include_weights(["processor/"], None) is False + assert 
hcs.request_can_include_weights(["unet/"], None) is True + assert hcs.request_can_include_weights(["checkpoint-10/"], None) is True + snap, _ = _mk_snapshot(tmp_path, "tokdir") + (snap / "tokenizer").mkdir() + (snap / "tokenizer" / "tokenizer.json").write_text("{}") + assert xf._download_result_usable( + snap, repo_type = "model", allow_patterns = ["tokenizer/"], ignore_patterns = None) is True + + +def test_metadata_directory_glob_is_weightless(tmp_path): + """A metadata-dir glob (allow=['tokenizer/*']) reads weightless like its trailing-slash form; a component/checkpoint dir glob stays weight-bearing.""" + assert hcs.request_can_include_weights(["tokenizer/*"], None) is False + assert hcs.request_can_include_weights(["tokenizer/*.json"], None) is False + assert hcs.request_can_include_weights(["processor/*"], None) is False + assert hcs.request_can_include_weights(["unet/*"], None) is True + assert hcs.request_can_include_weights(["checkpoint-10/*"], None) is True + snap, _ = _mk_snapshot(tmp_path, "tokglob") + (snap / "tokenizer").mkdir() + (snap / "tokenizer" / "tokenizer.json").write_text("{}") + assert xf._download_result_usable( + snap, repo_type = "model", allow_patterns = ["tokenizer/*"], ignore_patterns = None) is True + + +def test_allow_star_with_all_weights_ignored_is_weightless(tmp_path): + """An allow the ignore filter strips of every weight reads weightless (root allow=['*'] and subdir allow=['unet/*']); surviving weight suffixes stay weight-bearing.""" + all_weight_ignores = [ + "*.safetensors", "*.bin", "*.pt", "*.pth", "*.gguf", + "*.ckpt", "*.onnx", "*.msgpack", "*.h5", "*.pdparams", + ] + assert hcs.request_can_include_weights(["*"], all_weight_ignores) is False + assert hcs.request_can_include_weights(["*"], None) is True + # A subdir allow that ignores every weight suffix is weightless too... + assert hcs.request_can_include_weights(["unet/*"], all_weight_ignores) is False + # ...but one whose weight suffixes survive stays weight-bearing. 
def test_post_download_rejects_checkpoint_only_root_model(tmp_path):
    """A weight that lives only under ``checkpoint-7/`` fails an unpatterned request but passes one
    scoped to ``checkpoint-7/*``; with a root ``model_index.json`` a ``unet/`` weight passes while a
    ``checkpoint-7/`` weight is still rejected."""

    def _usable(snapshot, allow = None):
        return xf._download_result_usable(
            snapshot, repo_type = "model", allow_patterns = allow, ignore_patterns = None)

    def _link_weight(snapshot, blob, subdir, filename):
        (snapshot / subdir).mkdir()
        (snapshot / subdir / filename).symlink_to(blob)

    # Plain transformers-style repo: the only weight sits in a checkpoint subfolder.
    ckpt_snap, ckpt_blob = _mk_snapshot(tmp_path, "ckonly")
    (ckpt_snap / "config.json").write_text("{}")
    _link_weight(ckpt_snap, ckpt_blob, "checkpoint-7", "model.safetensors")
    assert _usable(ckpt_snap) is False
    assert _usable(ckpt_snap, ["checkpoint-7/*"]) is True

    # Diffusers pipeline (root model_index.json): a component subfolder weight counts.
    pipe_snap, pipe_blob = _mk_snapshot(tmp_path, "diff")
    (pipe_snap / "model_index.json").write_text("{}")
    _link_weight(pipe_snap, pipe_blob, "unet", "diffusion_pytorch_model.safetensors")
    assert _usable(pipe_snap) is True

    # Diffusers pipeline whose only weight is under checkpoint-N/: rejected.
    pipe_ckpt_snap, pipe_ckpt_blob = _mk_snapshot(tmp_path, "diff_ckpt")
    (pipe_ckpt_snap / "model_index.json").write_text("{}")
    _link_weight(pipe_ckpt_snap, pipe_ckpt_blob, "checkpoint-7", "diffusion_pytorch_model.safetensors")
    assert _usable(pipe_ckpt_snap) is False
def test_post_download_variant_root_check_ignores_adapter_index(tmp_path):
    """With ``variant="fp16"``, a present ``model.fp16.safetensors`` is accepted even beside an
    adapter variant index whose shards are absent; a root model variant index missing one of its
    shards is rejected."""

    def _two_shard_index(stem):
        return json.dumps({"weight_map": {
            "a": f"{stem}.fp16-00001-of-00002.safetensors",
            "b": f"{stem}.fp16-00002-of-00002.safetensors"}})

    def _usable_fp16(snapshot):
        return xf._download_result_usable(
            snapshot, repo_type = "model", allow_patterns = None, ignore_patterns = None,
            variant = "fp16")

    # Root variant weight present; the adapter index lists shards that are not on disk.
    good_snap, good_blob = _mk_snapshot(tmp_path, "var_adapter_idx")
    (good_snap / "model.fp16.safetensors").symlink_to(good_blob)
    (good_snap / "adapter_model.safetensors.index.fp16.json").write_text(
        _two_shard_index("adapter_model"))
    assert _usable_fp16(good_snap) is True

    # Root variant index lists two shards but only the first one is linked.
    bad_snap, bad_blob = _mk_snapshot(tmp_path, "var_root_idx_incomplete")
    (bad_snap / "model.safetensors.index.fp16.json").write_text(_two_shard_index("model"))
    (bad_snap / "model.fp16-00001-of-00002.safetensors").symlink_to(bad_blob)
    assert _usable_fp16(bad_snap) is False
def test_post_download_variant_either_format_exact_alternatives(tmp_path):
    """An exact allow list naming both variant formats is satisfied by the safetensors variant alone,
    the canonical safetensors/bin pair is satisfied by the bin alone, and a base + adapter allow list
    is rejected when only the adapter is present."""

    def _usable(snapshot, allow, **extra):
        return xf._download_result_usable(
            snapshot, repo_type = "model", allow_patterns = allow, ignore_patterns = None, **extra)

    variant_snap, variant_blob = _mk_snapshot(tmp_path, "var_either")
    (variant_snap / "model.fp16.safetensors").symlink_to(variant_blob)
    assert _usable(
        variant_snap, ["model.fp16.safetensors", "pytorch_model.fp16.bin"], variant = "fp16") is True

    canon_snap, canon_blob = _mk_snapshot(tmp_path, "canon_either")
    (canon_snap / "pytorch_model.bin").symlink_to(canon_blob)
    assert _usable(canon_snap, ["model.safetensors", "pytorch_model.bin"]) is True

    split_snap, split_blob = _mk_snapshot(tmp_path, "base_and_adapter")
    (split_snap / "adapter_model.safetensors").symlink_to(split_blob)
    assert _usable(split_snap, ["model.safetensors", "adapter_model.safetensors"]) is False


def test_post_download_validates_weightless_named_subset(tmp_path):
    """An exact weightless allow list is rejected while its named file is missing (model and dataset
    repo types) and accepted once the file exists; a ``*.json`` glob allow list is accepted."""

    def _usable(snapshot, repo_type, allow):
        return xf._download_result_usable(
            snapshot, repo_type = repo_type, allow_patterns = allow, ignore_patterns = None)

    snapshot, _blob = _mk_snapshot(tmp_path, "noname")
    (snapshot / "config.json").write_text("{}")
    assert _usable(snapshot, "model", ["tokenizer.json"]) is False
    assert _usable(snapshot, "dataset", ["data.parquet"]) is False

    (snapshot / "tokenizer.json").write_text("{}")
    assert _usable(snapshot, "model", ["tokenizer.json"]) is True
    assert _usable(snapshot, "model", ["*.json"]) is True
def test_dataset_unpatterned_or_glob_partial_does_not_skip_child(tmp_path):
    """For a dataset snapshot holding only ``README.md``, the skip-download check is False for an
    unpatterned request and for a ``*.parquet`` glob, and True for the exact ``README.md`` subset."""
    snapshot, _blob = _mk_snapshot(tmp_path, "dspart")
    (snapshot / "README.md").write_text("partial")
    cases = (
        (None, False),
        (["*.parquet"], False),
        (["README.md"], True),
    )
    for allow, can_skip in cases:
        verdict = xf._cache_can_skip_download(
            snapshot, repo_type = "dataset", allow_patterns = allow, ignore_patterns = None)
        assert verdict is can_skip
str(tmp_path), active_grace = 180, + owned_incomplete_blobs = {"ownedhash.incomplete"}) + assert not owned.exists(), "our own stalled partial must be purged" + assert sibling.exists(), "a concurrent sibling's partial must be spared" + assert not (snap / "our.safetensors").is_symlink() + assert (snap / "sib.safetensors").is_symlink(), "sibling's dangling link must be spared" + + # No ownership info -> coarse mtime guard purges both aged partials. + owned, sibling = _seed() + _REAL_DEFAULT_PREPARE( + "model", repo, cache_dir = str(tmp_path), active_grace = 180, owned_incomplete_blobs = None) + assert not owned.exists() and not sibling.exists() diff --git a/unsloth_zoo/__init__.py b/unsloth_zoo/__init__.py index 821173deb..7f9b317ec 100644 --- a/unsloth_zoo/__init__.py +++ b/unsloth_zoo/__init__.py @@ -134,12 +134,17 @@ def has_429_exact_full_read(log_dir: str | Path) -> str: # Everything below this point is GPU-only. Use a flag to gate it. _SKIP_GPU_INIT = True else: - _SKIP_GPU_INIT = False + # Opt-in: a download-only helper child (e.g. hf_xet_fallback) sets + # UNSLOTH_ZOO_DISABLE_GPU_INIT=1 to skip the heavy torch/transformers/device + # init it never uses. Off by default, so normal CUDA/CPU runs are unchanged. + # The HF cache redirect above still runs, so the child shares the parent's cache. + _SKIP_GPU_INIT = os.environ.get("UNSLOTH_ZOO_DISABLE_GPU_INIT", "0") == "1" del _is_mlx_only, is_mlx_available -# Inject triton & bitsandbytes stubs on Apple Silicon with MLX so unsloth's -# CUDA-only imports don't error at startup. _SKIP_GPU_INIT is True only on -# Darwin/arm64 with mlx installed (the exact case stubs are needed). +# Inject triton & bitsandbytes stubs whenever GPU init is skipped (MLX host or the +# opt-in download child), so unsloth's CUDA-only imports resolve to a loud no-op stub +# instead of a hard ImportError. Inert in the download child, which never touches them. +# On a normal CUDA/CPU run _SKIP_GPU_INIT is False and the real modules are untouched. 
if _SKIP_GPU_INIT: from .stubs.triton_stub import inject_into_sys_modules as _inject_triton _inject_triton() diff --git a/unsloth_zoo/hf_cache_state.py b/unsloth_zoo/hf_cache_state.py new file mode 100644 index 000000000..548068442 --- /dev/null +++ b/unsloth_zoo/hf_cache_state.py @@ -0,0 +1,1086 @@ +# SPDX-License-Identifier: AGPL-3.0-only +# Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. +# +# This file is licensed under the GNU Affero General Public License v3.0 only +# (AGPL-3.0-only), unlike the rest of unsloth_zoo which is LGPL-3.0-or-later. It +# is the single shared home for the sparse-aware Hugging Face cache primitives +# used by the Xet -> HTTP stall fallback (unsloth_zoo.hf_xet_fallback) and by +# Unsloth Studio's download manager, which re-exports the names below. +# See . + +"""Sparse-aware introspection of the active Hugging Face hub cache. + +Reports on-disk bytes (sparse-aware, so a partial Xet / ``hf_transfer`` ``.incomplete`` is not read as +full progress) and whether an ``.incomplete`` partial is present -- the signals the no-progress +watchdog runs on. + +``snapshot_dir_is_complete`` is a CONSERVATIVE fast-path gate, not an authoritative verifier: it +returns "complete" only for unambiguous canonical model layouts and defers everything else to the +watched ``snapshot_download`` child. A false "complete" is the only dangerous error (an in-process +load could then fetch a missing weight over un-killable Xet); a false "not complete" only spawns the +cheap child. Only the active ``HF_HUB_CACHE`` root is scanned. 
INCOMPLETE_SUFFIX = ".incomplete"


def _safe_is_dir(path: Path) -> bool:
    """Return ``path.is_dir()``, mapping an ``OSError`` from the check to False."""
    try:
        answer = path.is_dir()
    except OSError:
        answer = False
    return answer


def _safe_is_file(path: Path) -> bool:
    """Return ``path.is_file()``, mapping an ``OSError`` from the check to False."""
    try:
        answer = path.is_file()
    except OSError:
        answer = False
    return answer


def hf_cache_root(*, create: bool = False, cache_dir: "Optional[str | Path]" = None) -> Optional[Path]:
    """Resolve the hub cache root to scan, or None when it is unavailable.

    An explicit *cache_dir* is taken as given (after ``~`` expansion); otherwise ``HF_HUB_CACHE`` is
    read from ``huggingface_hub.constants`` at call time, and None is returned if that import fails.
    With *create* the directory is created (parents included) and returned, or None if creation
    raises ``OSError``. Without *create* the root is returned only if it is an existing directory."""
    if cache_dir is None:
        try:
            from huggingface_hub import constants as hf_constants
        except ImportError:
            return None
        root = Path(hf_constants.HF_HUB_CACHE)
    else:
        # expanduser() raises when no home directory can be resolved; keep the literal path then.
        try:
            root = Path(cache_dir).expanduser()
        except (RuntimeError, OSError):
            root = Path(cache_dir)
    if not create:
        return root if _safe_is_dir(root) else None
    try:
        root.mkdir(parents = True, exist_ok = True)
    except OSError:
        return None
    return root


def target_dir_name(repo_type: Optional[str], repo_id: str) -> str:
    """Lower-cased form of ``repo_cache_dir_name(repo_type, repo_id)``."""
    cache_name = repo_cache_dir_name(repo_type, repo_id)
    return cache_name.lower()


def repo_cache_dir_name(repo_type: Optional[str], repo_id: str) -> str:
    """Build the ``<type>s--<owner>--<name>`` cache directory name; a falsy *repo_type* means
    ``"model"`` (so None yields ``models--`` rather than ``Nones--``)."""
    kind = repo_type if repo_type else "model"
    flattened_id = repo_id.replace("/", "--")
    return f"{kind}s--{flattened_id}"


def _blob_dir_is_partial(blobs_dir: Path) -> bool:
    """True if *blobs_dir* directly contains a regular file whose name ends with
    ``INCOMPLETE_SUFFIX``; an ``OSError`` while scanning reads as False."""
    try:
        return any(
            child.is_file() and child.name.endswith(INCOMPLETE_SUFFIX)
            for child in blobs_dir.iterdir()
        )
    except OSError:
        return False
def _windows_allocated_size(path: Path) -> Optional[int]:
    """Best-effort allocated-byte count for *path* on Windows via ``GetCompressedFileSizeW``.

    Returns None off Windows, when the call reports failure (low word ``0xFFFFFFFF`` together with a
    non-zero last-error), or when anything raises."""
    if sys.platform != "win32":
        return None
    try:
        import ctypes
        from ctypes import wintypes

        query_size = ctypes.WinDLL("kernel32", use_last_error = True).GetCompressedFileSizeW
        query_size.argtypes = [
            wintypes.LPCWSTR,
            ctypes.POINTER(wintypes.DWORD),
        ]
        query_size.restype = wintypes.DWORD

        high_word = wintypes.DWORD(0)
        ctypes.set_last_error(0)
        low_word = query_size(str(path), ctypes.byref(high_word))
        call_failed = low_word == 0xFFFFFFFF and ctypes.get_last_error() != 0
        if call_failed:
            return None
        return (int(high_word.value) << 32) + int(low_word)
    except Exception:
        return None


def latest_snapshot_dir(repo_dir: Path) -> Optional[Path]:
    """Return the sub-directory of ``repo_dir/snapshots`` with the greatest mtime, or None when
    ``snapshots`` is missing, holds no directory, or the scan raises ``OSError``."""
    snapshots_root = repo_dir / "snapshots"
    try:
        if not snapshots_root.is_dir():
            return None
        candidates = (child for child in snapshots_root.iterdir() if child.is_dir())
        return max(candidates, key = lambda child: child.stat().st_mtime, default = None)
    except OSError:
        return None


def snapshot_dir_has_broken_symlinks(snapshot_dir: Path) -> bool:
    """True if any entry under *snapshot_dir* (recursively) is a symlink whose target does not
    exist. An ``OSError`` during the walk reads as False."""
    try:
        return any(
            entry.is_symlink() and not entry.exists()
            for entry in snapshot_dir.rglob("*")
        )
    except OSError:
        return False


# ---------------------------------------------------------------------------
# Weight-file recognition
# ---------------------------------------------------------------------------

_WEIGHT_FILE_SUFFIXES = (
    ".safetensors",
    ".bin",
    ".pt",
    ".pth",
    ".gguf",
    ".ckpt",
    ".onnx",
    ".msgpack",
    ".h5",
    ".pdparams",
)

# Trainer / optimizer state carries weight suffixes but is NOT a loadable weight, so a cache holding
# only these is not a warm model cache.
_NON_WEIGHT_BASENAMES = frozenset({
    "training_args.bin",
    "optimizer.bin",
    "optimizer.pt",
    "scheduler.bin",
    "scheduler.pt",
    "scaler.pt",
    "rng_state.pt",
    "rng_state.pth",
})
# Distributed runs shard the RNG state as rng_state_0.pth, rng_state_1.pth, ...
_NON_WEIGHT_BASENAME_PREFIXES = ("rng_state_",)


def _is_loadable_weight_file(name: str) -> bool:
    """True if *name* ends with a weight suffix and, lower-cased, is neither a known trainer /
    optimizer state basename nor prefixed like a sharded RNG state file."""
    if not name.endswith(_WEIGHT_FILE_SUFFIXES):
        return False
    folded = name.lower()
    is_trainer_state = folded in _NON_WEIGHT_BASENAMES
    return not is_trainer_state and not folded.startswith(_NON_WEIGHT_BASENAME_PREFIXES)


def _is_weight_shard_index(name: str) -> bool:
    """True for a ``.json`` name containing ``.safetensors.index.`` or ``.bin.index.``, which covers
    both the canonical and the variant form (``model.safetensors.index.fp16.json``)."""
    if not name.endswith(".json"):
        return False
    return any(marker in name for marker in (".safetensors.index.", ".bin.index."))
def _is_unsafe_shard_ref(shard: str) -> bool:
    """True if a ``weight_map`` value is not a safe relative path: empty, starting with ``/`` or a
    backslash, absolute, carrying a Windows drive, or containing a ``..`` component. The value is
    judged under both Windows and POSIX path rules, whatever the running OS."""
    if not shard or shard[0] in ("/", "\\"):
        return True
    as_windows = PureWindowsPath(shard)
    as_posix = PurePosixPath(shard)
    unsafe_on_windows = (
        as_windows.is_absolute() or bool(as_windows.drive) or ".." in as_windows.parts
    )
    unsafe_on_posix = as_posix.is_absolute() or ".." in as_posix.parts
    return unsafe_on_windows or unsafe_on_posix


def _weight_shard_index_complete(index_path: Path) -> bool:
    """True only if every shard named in the index's ``weight_map`` exists beside the index.

    Returns False for an unreadable or unparsable index, a non-dict payload or ``weight_map``, an
    empty map, a non-string shard value, an unsafe shard reference, a missing shard, or an
    ``OSError`` while probing a shard."""
    import json

    try:
        with open(index_path, "r", encoding = "utf-8") as handle:
            payload = json.load(handle)
    except (OSError, ValueError):
        return False
    if not isinstance(payload, dict):
        return False
    weight_map = payload.get("weight_map")
    if not isinstance(weight_map, dict):
        return False
    shard_refs = list(weight_map.values())
    if not shard_refs:
        return False
    # Any non-string value makes the whole index malformed; checked before touching the disk.
    if any(not isinstance(ref, str) for ref in shard_refs):
        return False
    index_dir = index_path.parent
    for ref in set(shard_refs):
        # An unsafe reference could resolve outside the snapshot, so it is rejected outright.
        if _is_unsafe_shard_ref(ref):
            return False
        try:
            shard_present = (index_dir / ref).exists()
        except OSError:
            return False
        if not shard_present:
            return False
    return True


# ---------------------------------------------------------------------------
# Pattern helpers (normalization + glob detection + HF filtering)
# ---------------------------------------------------------------------------

_GLOB_CHARS = ("*", "?", "[")


def _has_glob(text: str) -> bool:
    """True if *text* ends with ``/`` (a directory pattern, treated as a wildcard) or contains any
    of ``_GLOB_CHARS``."""
    if text.endswith("/"):
        return True
    return any(marker in text for marker in _GLOB_CHARS)


def _as_pattern_list(patterns: "Optional[object]") -> "Optional[list]":
    """Normalize an allow / ignore argument: None stays None, a bare ``str`` becomes a one-item
    list (instead of being iterated per character), any other iterable is copied into a list."""
    if patterns is None:
        return None
    return [patterns] if isinstance(patterns, str) else list(patterns)


def _filter_paths(
    paths: list,
    allow_patterns: "Optional[list]" = None,
    ignore_patterns: "Optional[list]" = None,
) -> list:
    """Filter *paths* through ``huggingface_hub.utils.filter_repo_objects`` with the given allow /
    ignore patterns. If the import or the filtering raises, every path is returned unfiltered."""
    try:
        from huggingface_hub.utils import filter_repo_objects

        selected = filter_repo_objects(
            paths, allow_patterns = allow_patterns, ignore_patterns = ignore_patterns
        )
        return list(selected)
    except Exception:
        return list(paths)


def _broken_symlink_rel_paths(snapshot_dir: Path) -> list:
    """Collect the posix path, relative to *snapshot_dir*, of every dangling symlink beneath it
    (falling back to the bare name if the relative path cannot be computed). An entry that raises
    ``OSError`` is skipped; an ``OSError`` from the walk itself ends the scan with what was found."""
    dangling: list = []
    try:
        for candidate in snapshot_dir.rglob("*"):
            try:
                if not (candidate.is_symlink() and not candidate.exists()):
                    continue
                try:
                    rel_path = candidate.relative_to(snapshot_dir).as_posix()
                except ValueError:
                    rel_path = candidate.name
                dangling.append(rel_path)
            except OSError:
                continue
    except OSError:
        pass
    return dangling


def snapshot_has_requested_broken_symlinks(
    snapshot_dir: Path,
    *,
    allow_patterns: "Optional[object]" = None,
    ignore_patterns: "Optional[object]" = None,
    repo_type: "Optional[str]" = "model",
) -> bool:
    """True iff at least one dangling symlink under *snapshot_dir* survives the request's allow /
    ignore filter, i.e. belongs to a file the request selects. *repo_type* is accepted but unused."""
    allow = _as_pattern_list(allow_patterns)
    ignore = _as_pattern_list(ignore_patterns)
    dangling = _broken_symlink_rel_paths(snapshot_dir)
    return bool(dangling) and bool(_filter_paths(dangling, allow, ignore))


# ---------------------------------------------------------------------------
# The conservative fast-path completeness gate
# ---------------------------------------------------------------------------

# Canonical root weight filenames a default load reads (the single file or its shard index proves warm).
+_CANONICAL_SINGLE_WEIGHTS = ("model.safetensors", "pytorch_model.bin") + + +def _ignore_strips_all_weights(ignore_patterns: "list") -> bool: + """True iff the ignore set provably excludes EVERY weight format. A partial strip is NOT weightless + (a surviving weight could still be pulled), so the request stays weight-bearing (conservative).""" + for suffix in _WEIGHT_FILE_SUFFIXES: + probe = "weight" + suffix + if not any(isinstance(p, str) and fnmatch.fnmatchcase(probe, p) for p in ignore_patterns): + return False + return True + + +# Representative weight names a glob allow pattern is probed against: a glob matching one can select a +# weight; one matching none (``tokenizer*``, ``*.json``) is weightless. Covers canonical / variant / +# sharded / adapter / diffusers / consolidated and the non-safetensors formats. +_WEIGHT_PATTERN_PROBES = ( + "model.safetensors", + "model.fp16.safetensors", + "model-00001-of-00002.safetensors", + "pytorch_model.bin", + "pytorch_model-00001-of-00002.bin", + "adapter_model.safetensors", + "adapter_model.bin", + "consolidated.safetensors", + "consolidated.00.pth", + "diffusion_pytorch_model.safetensors", + "model.gguf", + "model.pt", + "model.pth", + "model.h5", + "model.msgpack", + "tf_model.h5", + "flax_model.msgpack", +) + +# Subdirs that hold only metadata / config, so a ``dir/`` pattern scoped to one is weightless. Any +# other dir pattern stays weight-bearing (a component dir or a checkpoint dir can hold a weight). +_NON_WEIGHT_DIRS = frozenset({ + "tokenizer", + "processor", + "preprocessor", + "feature_extractor", + "image_processor", + "video_processor", + "scheduler", +}) + + +def _pattern_can_select_weight(pattern: "object") -> bool: + """Whether a single allow pattern could select a weight. A weight-suffix basename or a non-metadata + directory pattern is weight-bearing; a glob basename is weight-bearing only if it matches a + ``_WEIGHT_PATTERN_PROBES`` name; a concrete non-weight name is weightless. 
A false weight-bearing only + spawns the cheap child; the probe set avoids a false weightless on real weights.""" + if not isinstance(pattern, str): + return True + if pattern.endswith("/"): + dir_name = pattern.rstrip("/").rsplit("/", 1)[-1].lower() + return dir_name not in _NON_WEIGHT_DIRS + # A pattern scoped under a metadata dir ("tokenizer/*") is weightless like the "tokenizer/" form. + if "/" in pattern: + parent = pattern.rsplit("/", 1)[0].rstrip("/").rsplit("/", 1)[-1].lower() + if parent in _NON_WEIGHT_DIRS: + return False + base = pattern.rsplit("/", 1)[-1] + if base.endswith(_WEIGHT_FILE_SUFFIXES): + return True + if any(ch in base for ch in _GLOB_CHARS): + return any(fnmatch.fnmatchcase(probe, base) for probe in _WEIGHT_PATTERN_PROBES) + return False + + +def request_can_include_weights( + allow_patterns: "Optional[object]" = None, ignore_patterns: "Optional[object]" = None +) -> bool: + """Whether a request restricted by *allow_patterns* / *ignore_patterns* can still include a weight. + Conservative: True when uncertain; False only for a clearly weightless request (a tokenizer / config + allow list, an ignore list dropping every weight format, or an allow + ignore pair that strips them + all), preserving the tokenizer-only short-circuit.""" + allow_patterns = _as_pattern_list(allow_patterns) + ignore_patterns = _as_pattern_list(ignore_patterns) + if allow_patterns is None and ignore_patterns is None: + return True + if allow_patterns is None: + return not _ignore_strips_all_weights(ignore_patterns or []) + if not allow_patterns: + return False # allow=[] selects nothing + if not any(_pattern_can_select_weight(pat) for pat in allow_patterns): + return False + # An allow that reaches a weight can still be left weightless by the ignore filter (allow=["*"] + + # ignore=[every weight suffix], or a subdir allow=["unet/*"] ignoring every weight suffix). 
def _canonical_root_weights_complete(
    snapshot_dir: Path, entries: list, ignore_patterns: "Optional[list]" = None,
    *, prefer_safetensors: bool = False,
) -> bool:
    """True iff the snapshot holds a complete canonical ROOT weight set: a root
    ``model.safetensors`` / ``pytorch_model.bin``, OR a root shard index whose every shard is present.
    Numbered shards without an index, or subfolder-only weights, do NOT count.

    A weight whose FORMAT the ignore filter drops does not count (a stale ``pytorch_model.bin`` under
    ``ignore=['*.bin']`` is no proof the requested safetensors are on disk). The format probe also
    discards a ``pytorch_model.bin.index.json`` whose ``.json`` name would slip the raw filter.

    *prefer_safetensors* is set by the STRICT pre-download gate: a default load probes safetensors
    BEFORE bin, so when safetensors is read (not ignored) a bin-only cache cannot be proven complete
    (the local cache cannot show the preferred safetensors is absent remotely) and skipping the child
    would fetch it over un-killable Xet. So ``.bin`` satisfies the gate only when safetensors is IGNORED.
    The lenient POST path leaves this False: a finished bin-only download is a genuinely bin-only repo
    and must not be false-rejected into a ``DownloadStallError``.

    *entries* is the caller's enumeration of paths to classify. Only an entry whose path relative to
    *snapshot_dir* contains no ``/`` is considered; an entry that cannot be made relative falls back
    to its basename and is therefore treated as a root file. *ignore_patterns* is handed to
    ``_filter_paths`` unchanged (no normalisation happens here), so a falsy value disables the
    format filter."""
    root_files: set = set()
    root_indices: list = []
    # Split the root entries into shard indices and plain files; both must be real files.
    for entry in entries:
        try:
            rel = entry.relative_to(snapshot_dir).as_posix()
        except ValueError:
            rel = entry.name
        if "/" in rel:
            continue  # ROOT files only
        if _is_canonical_weight_shard_index(entry.name):
            if _safe_is_file(entry):
                root_indices.append(entry)
        elif _safe_is_file(entry):
            root_files.add(entry.name)

    def _format_kept(weight_name: str) -> bool:
        # The read format must survive the ignore filter, else the file is a stale excluded-format artifact.
        if not ignore_patterns:
            return True
        return bool(_filter_paths([weight_name], None, ignore_patterns))

    # First root index of each format, or None when that format has no index.
    st_index = next((e for e in root_indices if ".safetensors.index." in e.name), None)
    bin_index = next((e for e in root_indices if ".bin.index." in e.name), None)
    # transformers' local precedence, mirrored: single safetensors before the safetensors index,
    # safetensors before the .bin single, .bin single before the .bin index. So a complete single weight
    # is never masked by a stale index, and an incomplete PREFERRED safetensors index is breakage a
    # complete .bin must not mask. An ignore-dropped format is skipped so the next read format is judged.
    if "model.safetensors" in root_files and _format_kept("model.safetensors"):
        return True
    if st_index is not None and _format_kept("model.safetensors"):
        return _weight_shard_index_complete(st_index)
    if prefer_safetensors and _format_kept("model.safetensors"):
        # STRICT gate: safetensors is preferred but absent, and a bin-only cache cannot prove it absent
        # remotely, so a default load would fetch it over Xet -> defer to the child.
        return False
    if "pytorch_model.bin" in root_files and _format_kept("pytorch_model.bin"):
        return True
    if bin_index is not None and _format_kept("pytorch_model.bin"):
        return _weight_shard_index_complete(bin_index)
    return False
# A canonical numbered root shard (no variant token): ``model-00001-of-00002.safetensors`` matches but
# ``model-00001-of-00002.fp16.safetensors`` does not.
_CANONICAL_ROOT_SHARD_RE = re.compile(
    r"^(?:model|pytorch_model)-\d{5}-of-\d{5}\.(?:safetensors|bin)$"
)


def _has_incomplete_canonical_root_shards(
    snapshot_dir: Path, *, ignore_patterns: "Optional[object]" = None
) -> bool:
    """True when the root holds canonical numbered shards but is NOT a complete canonical model (index
    missing or a shard absent) for the format the request READS -- a stale interrupted download the post
    check rejects and retries over HTTP. The request's ignore filter is applied, so a complete
    safetensors set does not mask an incomplete ``.bin`` set read under ``ignore=['*.safetensors']``.
    Variant shards (``.<variant>-`` infix) are excluded, so a variant-only repo is not force-failed.

    Returns False when the directory cannot be listed, or when the root shows no shard evidence at
    all; otherwise the answer is the negation of ``snapshot_dir_is_complete`` for the same ignore
    filter."""
    try:
        names = [entry.name for entry in snapshot_dir.iterdir()]
    except OSError:
        return False
    # Shard evidence = a numbered shard FILE or a canonical shard INDEX. An index-only partial (index
    # present, no shards yet) is still incomplete and must be caught before any shard file exists.
    has_shard_evidence = (
        any(_CANONICAL_ROOT_SHARD_RE.match(name) for name in names)
        or any(_is_canonical_weight_shard_index(name) for name in names)
    )
    if not has_shard_evidence:
        return False
    return not snapshot_dir_is_complete(snapshot_dir, ignore_patterns = ignore_patterns)


def _has_incomplete_variant_root_shards(
    snapshot_dir: Path, variant: str, *, ignore_patterns: "Optional[object]" = None
) -> bool:
    """True when the ROOT variant weight the load READS is an incomplete sharded set. Incomplete means:
    a present variant shard INDEX whose listed shards are not all present (an index-only partial counts),
    OR variant shard FILES with no complete index.

    The request's ignore filter is applied, and safetensors is read BEFORE bin (transformers' probe
    order): an incomplete variant safetensors index is breakage even with a complete variant bin.
    Positive-evidence: a single-file variant or a complete variant shard set returns False. Only the ROOT
    ``model`` / ``pytorch_model`` variant weight is considered, so a stale ``adapter_model`` variant set
    (which the default variant model load does not read) must not force-fail a complete model variant.

    Only the direct children of *snapshot_dir* are inspected; an unlistable directory returns False."""
    dot_infix = f".{variant}."  # variant index or single file
    dash_infix = f".{variant}-"  # a sharded variant weight (model.<variant>-00001-of-00002.safetensors)
    ignore_patterns = _as_pattern_list(ignore_patterns)

    def _format_kept(weight_name: str) -> bool:
        # The read format must survive the ignore filter, else the file is a stale excluded-format artifact.
        if not ignore_patterns:
            return True
        return bool(_filter_paths([weight_name], None, ignore_patterns))

    try:
        entries = list(snapshot_dir.iterdir())
    except OSError:
        return False
    st_index_incomplete = None  # tri-state: None absent, else present & (in)complete
    bin_index_incomplete = None
    has_st_shard = has_bin_shard = False
    has_single_st = has_single_bin = False
    for entry in entries:
        name = entry.name
        # Restrict to the ROOT model index; an adapter_model / other non-model variant index the default
        # load does not read is skipped so its incompleteness cannot force-fail the model variant.
        if dot_infix in name and _ROOT_MODEL_SHARD_INDEX_RE.match(name):
            is_safetensors = ".safetensors.index." in name
            # Judge the ignore filter on the weight the index stands for, not on the index's .json name.
            fmt_probe = (
                f"model.{variant}.safetensors" if is_safetensors else f"pytorch_model.{variant}.bin"
            )
            if not _format_kept(fmt_probe):
                continue  # this format is ignored -> the load does not read it
            incomplete = not (_safe_is_file(entry) and _weight_shard_index_complete(entry))
            if is_safetensors:
                st_index_incomplete = incomplete
            else:
                bin_index_incomplete = incomplete
        elif dash_infix in name and _ROOT_MODEL_VARIANT_WEIGHT_RE.match(name):
            if _safe_is_file(entry) and _format_kept(name):
                if name.endswith(".safetensors"):
                    has_st_shard = True
                else:
                    has_bin_shard = True
        elif dot_infix in name and _ROOT_MODEL_VARIANT_WEIGHT_RE.match(name):
            # a single-file ROOT model variant weight (model.<variant>.safetensors / .bin).
            if _safe_is_file(entry) and _format_kept(name):
                if name.endswith(".safetensors"):
                    has_single_st = True
                else:
                    has_single_bin = True
    # transformers' local precedence, mirrored: single safetensors before the safetensors index,
    # safetensors before .bin, single .bin before the .bin index. So a complete single-file variant is
    # never masked by a stale index (which would force a spurious DownloadStallError), and an incomplete
    # PREFERRED safetensors index is still breakage a complete .bin must not mask.
    if has_single_st:
        return False  # a complete single-file safetensors variant, probed before the index
    if st_index_incomplete is not None:
        return st_index_incomplete
    if has_st_shard:
        return True  # variant safetensors shard files with no index -> incomplete
    if has_single_bin:
        return False  # a complete single-file bin variant, probed before the .bin index
    if bin_index_incomplete is not None:
        return bin_index_incomplete
    if has_bin_shard:
        return True  # variant bin shard files with no index -> incomplete
    return False


# A VARIANT weight-shard index basename; group 1 captures the variant token
# (``model.safetensors.index.fp16.json`` -> ``fp16``).
_VARIANT_SHARD_INDEX_RE = re.compile(r"\.(?:safetensors|bin)\.index\.([^.]+)\.json$")

# The ROOT canonical / variant MODEL shard index (owned by the canonical / variant root checks):
# model.safetensors.index.json, pytorch_model.bin.index.json, and their variant forms.
_ROOT_MODEL_SHARD_INDEX_RE = re.compile(
    r"^(?:model\.safetensors|pytorch_model\.bin)\.index(?:\.[^.]+)?\.json$"
)

# A ROOT model VARIANT weight, single or sharded (model.fp16.safetensors,
# pytorch_model.fp16-00001-of-00002.bin). Excludes a PEFT adapter the variant model load does not read.
_ROOT_MODEL_VARIANT_WEIGHT_RE = re.compile(
    r"^(?:model|pytorch_model)\.[^.]+(?:-\d{5}-of-\d{5})?\.(?:safetensors|bin)$"
)
Lets the selected-index check read only the + indices a load reads (variant load -> variant indices, plain load -> canonical).""" + if name.endswith(".safetensors.index.json") or name.endswith(".bin.index.json"): + return None + m = _VARIANT_SHARD_INDEX_RE.search(name) + return m.group(1) if m else None + + +def _index_shard_rel_paths(index_path: Path, dir_rel: str) -> "Optional[list]": + """Snapshot-relative posix paths of the shards a weight index lists, or None if the index is + unreadable / malformed (fail-CLOSED, mirroring ``_weight_shard_index_complete``). *dir_rel* is the + index's snapshot-relative dir ("" at root), so a listed basename is joined back to a full + repo-relative path for the request filter.""" + import json + + try: + with open(index_path, "r", encoding = "utf-8") as f: + data = json.load(f) + except (OSError, ValueError): + return None + weight_map = data.get("weight_map") if isinstance(data, dict) else None + if not isinstance(weight_map, dict): + return None + values = list(weight_map.values()) + if not values or not all(isinstance(s, str) for s in values): + return None + prefix = f"{dir_rel}/" if dir_rel else "" + out: list = [] + for shard in set(values): + if _is_unsafe_shard_ref(shard): + return None + out.append(f"{prefix}{shard}") + return out + + +def _index_shard_probe(index_name: str, dir_rel: str) -> "Optional[str]": + """A representative numbered-shard path for a malformed weight-shard INDEX (index base + format as a + first shard, under *dir_rel*), so the scope check judges the request filter on the WEIGHT the load + reads, not the ``.json`` filename -- ``ignore=['*.bin']`` misses ``pytorch_model.bin.index.json`` but + the load never reads that ignored-format index. 
None when the name is not a recognizable shard index.""" + for marker, ext in ((".safetensors.index.", "safetensors"), (".bin.index.", "bin")): + if marker in index_name: + base = index_name.split(marker, 1)[0] + if not base: + return None + prefix = f"{dir_rel}/" if dir_rel else "" + return f"{prefix}{base}-00001-of-00002.{ext}" + return None + + +def _request_scopes_into_dir(allow_patterns: "Optional[list]", dir_name: str) -> bool: + """True when an allow pattern names *dir_name* among its LITERAL leading path segments + (``allow=['checkpoint-7/*']`` -> True for ``checkpoint-7``), i.e. the load reads INTO that directory. + Lets the shard check skip a leftover checkpoint subtree the request does not target while still + validating one it explicitly loads from. Segments are read only up to the first glob.""" + for p in allow_patterns or (): + if not isinstance(p, str) or "/" not in p: + continue + for seg in p.split("/"): + if _has_glob(seg): + break # a wildcard segment is not a literal directory target + if seg == dir_name: + return True + return False + + +def _selected_shard_index_incomplete( + snapshot_dir: Path, *, allow_patterns: "Optional[object]", ignore_patterns: "Optional[object]", + variant: "Optional[str]", +) -> bool: + """True when a weight-shard INDEX the load READS -- a sharded ADAPTER or a component SUBFOLDER set not + covered by the canonical / variant ROOT-model checks -- lists an absent shard (or is malformed). + Scoped to the request so a complete download is never false-rejected: + + - variant: a variant load reads only variant indices; a plain load only canonical ones. + - allow / ignore: an index is read only when its listed shards survive the request filter. + - precedence: safetensors read before bin, so when both are selected only the safetensors set's + completeness is required. 
+ + Also rejects a SELECTED numbered shard FILE whose directory has NO index of the read format: the load + enumerates a sharded weight through its index, so a shard set without one is incomplete and would + fetch the index + remaining shards over Xet. The ROOT canonical / variant MODEL shard set is skipped + (the root-shard checks own it).""" + allow_patterns = _as_pattern_list(allow_patterns) + ignore_patterns = _as_pattern_list(ignore_patterns) + want_variant = variant or None + if want_variant is None: + shard_file_re = re.compile(r"^[^.]+-\d{5}-of-\d{5}\.(?:safetensors|bin)$") + else: + v = re.escape(want_variant) + shard_file_re = re.compile(rf"^[^.]+\.{v}-\d{{5}}-of-\d{{5}}\.(?:safetensors|bin)$") + try: + entries = list(snapshot_dir.rglob("*")) + except OSError: + return False + per_dir: dict = {} # dir_rel -> {"safetensors": [shard_rels, ...], "bin": [...]} (from indices) + index_fmts: dict = {} # dir_rel -> {fmt} an index of the read variant is present (non-root-model) + shard_fmts: dict = {} # dir_rel -> {fmt} a SELECTED numbered shard file is present (non-root-model) + for entry in entries: + name = entry.name + if not _safe_is_file(entry): + continue + try: + rel = entry.relative_to(snapshot_dir).as_posix() + except ValueError: + continue + dir_rel = rel.rsplit("/", 1)[0] if "/" in rel else "" + if _is_weight_shard_index(name): + if _index_variant_token(name) != want_variant: + continue # a wrong-variant index the load does not read + if dir_rel == "" and _ROOT_MODEL_SHARD_INDEX_RE.match(name): + continue # the ROOT model index -- owned by the canonical / variant root checks + fmt = "safetensors" if ".safetensors.index." in name else "bin" + index_fmts.setdefault(dir_rel, set()).add(fmt) + shard_rels = _index_shard_rel_paths(entry, dir_rel) + if shard_rels is None: + # A malformed index. Defer only when the REQUEST reads its weight set, judged on a + # representative shard of the index's OWN base + format (not the .json filename). 
So a + # stale malformed index the request does NOT select, or one for an IGNORED format, must + # not force a spurious retry; an unrecognizable index defers to the child. + probe = _index_shard_probe(name, dir_rel) + if probe is None or _filter_paths([probe], allow_patterns, ignore_patterns): + return True + continue + if not _filter_paths(shard_rels, allow_patterns, ignore_patterns): + continue # the load does not read this set (out of scope / ignored format) + per_dir.setdefault(dir_rel, {}).setdefault(fmt, []).append(shard_rels) + elif shard_file_re.match(name): + # a numbered shard FILE of the read variant. Skip the ROOT model shard set (root checks own + # it) and any training-checkpoint subtree. + if dir_rel == "" and ( + (want_variant is None and _CANONICAL_ROOT_SHARD_RE.match(name)) + or (want_variant is not None and _ROOT_MODEL_VARIANT_WEIGHT_RE.match(name)) + ): + continue + ckpt_dirs = [p for p in rel.split("/")[:-1] if _CHECKPOINT_DIR_RE.match(p)] + if ckpt_dirs and not _request_scopes_into_dir(allow_patterns, ckpt_dirs[0]): + # a leftover checkpoint subtree the request does not target. An EXPLICIT checkpoint load + # (allow=['checkpoint-N/*']) DOES read it, so that set is checked rather than skipped. + continue + if not _filter_paths([rel], allow_patterns, ignore_patterns): + continue # the load does not read this shard (out of scope / ignored format) + fmt = "safetensors" if name.endswith(".safetensors") else "bin" + shard_fmts.setdefault(dir_rel, set()).add(fmt) + for by_fmt in per_dir.values(): + # safetensors read before bin: require only the preferred format present in this directory. 
+ for shard_rels in by_fmt.get("safetensors") or by_fmt.get("bin") or []: + for shard in shard_rels: + try: + if not (snapshot_dir / shard).exists(): + return True + except OSError: + return True + for dir_rel, fmts in shard_fmts.items(): + # a numbered shard of the read format with NO index in its dir: the load cannot enumerate the set + # and would fetch the index + remaining shards over Xet. + preferred = "safetensors" if "safetensors" in fmts else "bin" + if preferred not in index_fmts.get(dir_rel, set()): + return True + return False + + +# A training-checkpoint subdir (checkpoint-500/): never read as a diffusers pipeline COMPONENT, so an +# incomplete shard index under it must not force-fail a complete pipeline. +_CHECKPOINT_DIR_RE = re.compile(r"^checkpoint[-_]\d+$") + + +def _diffusers_declared_component_specs(snapshot_dir: Path) -> "Optional[dict]": + """name -> declared ``[library, class]`` spec from a diffusers ``model_index.json`` (``_``-prefixed + metadata excluded). None when absent / unreadable / malformed / empty, so the caller fails OPEN.""" + import json + + try: + with open(snapshot_dir / "model_index.json", "r", encoding = "utf-8") as f: + data = json.load(f) + except (OSError, ValueError): + return None + if not isinstance(data, dict): + return None + specs = { + key: value for key, value in data.items() + if not key.startswith("_") and isinstance(value, (list, tuple)) + } + # An empty / all-metadata model_index.json is degenerate -> fail OPEN (None). + return specs or None + + +def _diffusers_declared_components(snapshot_dir: Path) -> "Optional[set]": + """The component subfolder names a diffusers ``model_index.json`` declares. None when absent / + unreadable / malformed so the caller falls back to every subfolder (fail OPEN). 
def _diffusers_component_shards_incomplete(
    snapshot_dir: Path, *, variant: "Optional[str]" = None,
    ignore_patterns: "Optional[object]" = None,
) -> bool:
    """True when a diffusers pipeline COMPONENT subfolder (unet/, vae/, ...) holds a weight-shard INDEX of
    the read variant listing an absent shard (or a malformed index) -- an interrupted component pull the
    pipeline load would finish over un-killable Xet.

    Scoped so a complete pipeline is never false-rejected: limited to declared components (a stale
    UNDECLARED subtree is skipped), ROOT indices (root checks own them) and checkpoint subtrees are
    skipped, and the ignore filter selects the read format. Per directory safetensors is read before bin.
    A plain load reads canonical component indices, a variant load variant ones. Also rejects a component
    numbered shard FILE with NO index of the read format. Positive-evidence: a single-file or complete
    component set passes.

    An empty *variant* is treated the same as None (a plain load). Returns False when the snapshot
    tree cannot be enumerated."""
    want_variant = variant or None
    ignore_patterns = _as_pattern_list(ignore_patterns)
    declared = _diffusers_declared_components(snapshot_dir)
    # Numbered shard basenames of the read variant: ``<base>-NNNNN-of-NNNNN.<ext>`` for a plain load,
    # ``<base>.<variant>-NNNNN-of-NNNNN.<ext>`` for a variant load.
    if want_variant is None:
        shard_file_re = re.compile(r"^[^.]+-\d{5}-of-\d{5}\.(?:safetensors|bin)$")
    else:
        v = re.escape(want_variant)
        shard_file_re = re.compile(rf"^[^.]+\.{v}-\d{{5}}-of-\d{{5}}\.(?:safetensors|bin)$")
    try:
        entries = list(snapshot_dir.rglob("*"))
    except OSError:
        return False
    per_dir: dict = {}  # component dir_rel -> {fmt: [shard_rels, ...]} listed by the read indices
    index_fmts: dict = {}  # component dir_rel -> {fmt} an index of the read variant is present
    shard_fmts: dict = {}  # component dir_rel -> {fmt} a numbered shard file (ignore-kept) is present
    for entry in entries:
        name = entry.name
        if not _safe_is_file(entry):
            continue
        try:
            rel = entry.relative_to(snapshot_dir).as_posix()
        except ValueError:
            continue
        parts = rel.split("/")
        if len(parts) < 2:
            continue  # a ROOT file -- owned by the canonical / variant root-model checks
        if declared is not None and parts[0] not in declared:
            continue  # an UNDECLARED subtree the DiffusionPipeline load does not read
        if any(_CHECKPOINT_DIR_RE.match(p) for p in parts[:-1]):
            continue  # a training-checkpoint subtree, not a pipeline component
        dir_rel = rel.rsplit("/", 1)[0]
        if _is_weight_shard_index(name):
            if _index_variant_token(name) != want_variant:
                continue  # a wrong-variant index the load does not read
            fmt = "safetensors" if ".safetensors.index." in name else "bin"
            # Recorded before the ignore filter is consulted: this tracks index PRESENCE per format.
            index_fmts.setdefault(dir_rel, set()).add(fmt)
            shard_rels = _index_shard_rel_paths(entry, dir_rel)
            if shard_rels is None:
                # A malformed index. Defer only when its FORMAT is read (a representative shard survives
                # the ignore filter); a stale malformed index for an IGNORED format is not read and must
                # not force a spurious retry of a complete other-format pipeline.
                probe = _index_shard_probe(name, dir_rel)
                if probe is None or _filter_paths([probe], None, ignore_patterns):
                    return True
                continue
            if not _filter_paths(shard_rels, None, ignore_patterns):
                continue  # the load does not read this set (ignored format)
            per_dir.setdefault(dir_rel, {}).setdefault(fmt, []).append(shard_rels)
        elif shard_file_re.match(name) and _filter_paths([rel], None, ignore_patterns):
            fmt = "safetensors" if name.endswith(".safetensors") else "bin"
            shard_fmts.setdefault(dir_rel, set()).add(fmt)
    # Per directory, only the preferred format's listed shards must exist (safetensors before bin).
    for by_fmt in per_dir.values():
        for shard_rels in by_fmt.get("safetensors") or by_fmt.get("bin") or []:
            for shard in shard_rels:
                try:
                    if not (snapshot_dir / shard).exists():
                        return True
                except OSError:
                    return True
    for dir_rel, fmts in shard_fmts.items():
        preferred = "safetensors" if "safetensors" in fmts else "bin"
        if preferred not in index_fmts.get(dir_rel, set()):
            return True  # a component numbered shard with no index of the read format
    return False
A request with ANY glob, or no allow list, is + trivially satisfied.""" + allow_patterns = _as_pattern_list(allow_patterns) + ignore_patterns = _as_pattern_list(ignore_patterns) + if not allow_patterns or any(_has_glob(p) for p in allow_patterns): + return True + try: + entries = list(snapshot_dir.rglob("*")) + except OSError: + return True + present = set() + for entry in entries: + if _safe_is_file(entry): + try: + present.add(entry.relative_to(snapshot_dir).as_posix()) + except ValueError: + present.add(entry.name) + for pat in allow_patterns: + # A named file the ignore filter drops is not actually requested. + if ignore_patterns and not _filter_paths([pat], None, ignore_patterns): + continue + if pat not in present: + return False + return True + + +# --------------------------------------------------------------------------- +# Active-cache enumeration primitives (download-manager / watchdog support) +# --------------------------------------------------------------------------- + +def _iter_snapshot_dirs(repo_dir: Path) -> Iterator[Path]: + snapshots_dir = repo_dir / "snapshots" + try: + if not snapshots_dir.is_dir(): + return + children = [entry for entry in snapshots_dir.iterdir() if entry.is_dir()] + except OSError: + return + yield from children + + +def _repo_dir_has_broken_snapshot_symlinks(repo_dir: Path) -> bool: + # Check every snapshot, not just the newest: an older revision may be broken while a newer is clean. + return any( + snapshot_dir_has_broken_symlinks(snapshot) + for snapshot in _iter_snapshot_dirs(repo_dir) + ) + + +def _case_safe_repo_cache_dirs(root: Path, repo_type: Optional[str], repo_id: str) -> list: + """Cache dirs safely attributable to this exact repo id. + + The Hub case-folds the dir name, so a case-insensitive match is needed, but on a case-sensitive fs + ``models--Org--Repo`` and ``models--org--repo`` are distinct repos. 
Prefer exact case; else accept a + single folded match ONLY on a case-insensitive fs; on a 2+ way collision attribute to neither, so a + stale partial in one repo cannot make the watchdog kill / purge the other.""" + target = repo_cache_dir_name(repo_type, repo_id) + folded_target = target.lower() + try: + entries = [entry for entry in root.iterdir() if entry.name.lower() == folded_target] + except OSError: + return [] + exact = [entry for entry in entries if entry.name == target] + if exact: + return exact + if len(entries) == 1: + # A single folded-but-not-exact match is the same dir only on a case-insensitive fs; on a + # case-sensitive fs it is a DIFFERENT repo. + try: + if (root / target).exists(): + return entries + except OSError: + return [] + return [] + + +def iter_active_repo_cache_dirs( + repo_type: Optional[str], repo_id: str, *, cache_dir: "Optional[str | Path]" = None +) -> Iterator[Path]: + """Yield the repo's cache dir(s) under *cache_dir* (or the active ``HF_HUB_CACHE``). 
def repo_cache_dir_has_incomplete_blobs(repo_dir: Path) -> bool:
    """True when *repo_dir* shows incomplete download state: its ``blobs`` directory exists and
    ``_blob_dir_is_partial`` flags it, or any of its snapshots has a broken symlink."""
    blobs_dir = repo_dir / "blobs"
    if blobs_dir.is_dir():
        partial = _blob_dir_is_partial(blobs_dir)
        if partial:
            return partial
    return _repo_dir_has_broken_snapshot_symlinks(repo_dir)


def has_active_incomplete_blobs(
    repo_type: "Optional[str]", repo_id: str, *, cache_dir: "Optional[str | Path]" = None
) -> bool:
    """True when any cache dir attributed to the repo under *cache_dir* (as resolved by
    ``iter_active_repo_cache_dirs``) has incomplete blobs."""
    return any(
        repo_cache_dir_has_incomplete_blobs(entry)
        for entry in iter_active_repo_cache_dirs(repo_type, repo_id, cache_dir = cache_dir)
    )
+``snapshot_download_with_xet_fallback`` warms a whole repo in a killable child before Unsloth's +in-process load; ``hf_hub_download_with_xet_fallback`` does a single file. Studio cache / secret / +process helpers are used best-effort (imported only if present) or injected. The child sets +``UNSLOTH_ZOO_DISABLE_GPU_INIT=1`` for unsloth_zoo's lightweight import path (no torch / transformers). +""" + +from __future__ import annotations + +import builtins +import errno +import importlib.util +import multiprocessing as mp +import logging +import os +import queue +import re +import signal +import sys +import threading +import time +from pathlib import Path +from typing import Any, Callable, Optional + +from unsloth_zoo.hf_cache_state import ( + INCOMPLETE_SUFFIX, + _ROOT_MODEL_VARIANT_WEIGHT_RE, + _as_pattern_list, + _diffusers_component_shards_incomplete, + _diffusers_declared_component_specs, + _filter_paths, + _has_glob, + _has_incomplete_canonical_root_shards, + _has_incomplete_variant_root_shards, + _is_loadable_weight_file, + _selected_shard_index_incomplete, + _weight_shard_index_complete, + blob_bytes_present, + has_active_incomplete_blobs, + hf_cache_root, + iter_active_repo_cache_dirs, + request_can_include_weights, + requested_named_files_present, + snapshot_dir_is_complete, + snapshot_has_requested_broken_symlinks, +) + +logger = logging.getLogger(__name__) + +# Explicit list keeps stdlib imports out of Studio's `import *` re-export shim. +__all__ = [ + "DownloadStallError", + "hf_hub_download_with_xet_fallback", + "snapshot_download_with_xet_fallback", + "start_watchdog", + "get_hf_download_state", + "is_hf_xet_available", + "xet_force_disabled", + "child_should_disable_xet", +] + +_CTX = mp.get_context("spawn") + +# Defaults match the existing Studio inference watchdog and hub shutdown deadline. 
+DEFAULT_HEARTBEAT_INTERVAL = 30.0 +DEFAULT_STALL_TIMEOUT = 180.0 +DEFAULT_GRACE_PERIOD = 10.0 +_POLL_INTERVAL = 0.5 + +# Serializes the parent-env (and __main__.__file__) mutation around a child spawn so +# concurrent downloads cannot observe each other's transport env. +_SPAWN_ENV_LOCK = threading.Lock() + +# Sentinel: "__main__.__file__ untouched for this spawn" (distinct from a saved None). +_UNSET = object() + +# HF boolean env convention, case-insensitive. +_TRUTHY = {"1", "true", "yes", "on"} + + +def _is_true(value: Optional[str]) -> bool: + return value is not None and str(value).strip().lower() in _TRUTHY + + +def _safe_status(callback: Optional[Callable[[str], None]], message: str) -> None: + """Invoke a status callback; swallow its exceptions so a disconnected client cannot kill the + daemon watchdog thread (stopping stall detection).""" + if callback is None: + return + try: + callback(message) + except Exception as e: + logger.debug("watchdog status callback raised (ignored): %s", e) + + +class DownloadStallError(RuntimeError): + """Raised when no download progress is observed for too long. 
Studio re-imports this canonical type.""" + + +def is_hf_xet_available() -> bool: + """True iff the ``hf_xet`` extra is importable (Hub uses it automatically).""" + try: + return importlib.util.find_spec("hf_xet") is not None + except Exception: + return False + + +def xet_force_disabled() -> bool: + """Whether the user asked to force HTTP up front via ``UNSLOTH_DISABLE_XET`` / + ``UNSLOTH_STABLE_DOWNLOADS`` / ``HF_HUB_DISABLE_XET``.""" + return ( + _is_true(os.environ.get("UNSLOTH_DISABLE_XET")) + or _is_true(os.environ.get("UNSLOTH_STABLE_DOWNLOADS")) + or _is_true(os.environ.get("HF_HUB_DISABLE_XET")) + ) + + +def child_should_disable_xet(config: dict) -> bool: + """Single source of truth for the per-worker Xet env flip.""" + return bool(config.get("disable_xet")) + + +def _default_scrub_secrets(text: str, hf_token: Optional[str] = None) -> str: + """Best-effort redaction of a token / bearer credential from an error string.""" + if not text: + return text + out = text + # token=True ("use cached token") is common; only a real string token can be substring-redacted. + if isinstance(hf_token, str) and hf_token: + out = out.replace(hf_token, "***") + out = re.sub(r"hf_[A-Za-z0-9]{8,}", "***", out) + out = re.sub(r"([Bb]earer\s+)[A-Za-z0-9._\-]+", r"\1***", out) + # Redact the query of a presigned S3/CAS blob URL (temporary creds in the query string); keep + # non-signed URLs (e.g. ...?download=true). + def _redact_signed_query(match: "re.Match") -> str: + base, query = match.group(1), match.group(2) + if re.search( + r"(X-Amz-|[Ss]ignature|(?:^|&)(?:sig|token|key|Expires|Policy|Key-Pair-Id)=)", + query, + ): + return f"{base}?***" + return match.group(0) + + # Stop the query at whitespace OR a structural delimiter (quote/bracket/brace/paren/angle/pipe): a + # URL embedded in JSON / a dict repr has no trailing whitespace, so a greedy [^\s]* would swallow the + # closing "} and corrupt the log line. 
Signed-query values percent-encode these chars, so a genuine + # presigned URL is still fully redacted. + out = re.sub( + r"(https?://[^\s?]+)\?([^\s\"'()<>{}|[\]]*)", _redact_signed_query, out + ) + return out + + +def _broken_link_has_active_partner(link: Path, *, active_grace: float) -> bool: + """SPARE a dangling snapshot symlink iff a sibling is still writing its target blob. Discriminator + is a FRESH ``.incomplete`` partner of the target, NOT the link mtime: our killed child's partner was + static for the full stall timeout and is purged first (no partner -> link cleared), while a sibling + mid-download still has a growing partner (link spared).""" + try: + target = Path(os.readlink(link)) + if not target.is_absolute(): + target = link.parent / target + incomplete_partner = target.with_name(target.name + INCOMPLETE_SUFFIX) + if incomplete_partner.is_file(): + return time.time() - incomplete_partner.stat().st_mtime < active_grace + except OSError: + return False + return False + + +def _link_incomplete_partner_name(link: Path) -> Optional[str]: + """The ``.incomplete`` basename for a dangling symlink's target blob, or None.""" + try: + target = Path(os.readlink(link)) + return target.name + INCOMPLETE_SUFFIX + except OSError: + return None + + +def _default_prepare_for_http( + repo_type: str, + repo_id: str, + *, + cache_dir: Optional[str] = None, + active_grace: float = DEFAULT_STALL_TIMEOUT, + owned_incomplete_blobs: Optional[set] = None, +) -> None: + """Make the partial safe for an HTTP resume: delete the repo's active ``*.incomplete`` blobs (an + HTTP resume over a sparse Xet / hf_transfer partial silently corrupts the blob) and the broken + snapshot symlinks the detector counts as active (else the retry inherits stale state and re-trips). + ``iter_active_repo_cache_dirs`` is case-collision safe, so this destructive purge only touches an + unambiguous repo cache dir. Studio injects its marker-aware version instead. 
+ + *owned_incomplete_blobs* (basenames the stalled child held open, captured before the kill) SCOPES the + purge so a same-repo sibling writing a DIFFERENT blob is spared even if aged past *active_grace*; + None -> coarser mtime guard only. + """ + try: + for entry in iter_active_repo_cache_dirs(repo_type, repo_id, cache_dir = cache_dir): + blobs_dir = entry / "blobs" + if blobs_dir.is_dir(): + for blob in blobs_dir.iterdir(): + if blob.is_file() and blob.name.endswith(INCOMPLETE_SUFFIX): + # Scope to our own partials when known: never delete a sibling's blob. + if owned_incomplete_blobs is not None and blob.name not in owned_incomplete_blobs: + continue + try: + # Spare a partial touched within active_grace (a slow sibling, not stalled); + # our killed partial has been static for the full stall timeout so it purges. + if time.time() - blob.stat().st_mtime < active_grace: + continue + blob.unlink() + except OSError: + continue # a locked / denied blob must not abort the rest + # Clear broken snapshot symlinks (also read as active incomplete state). Sweep EVERY snapshot, + # else a dangling link under an older revision keeps the repo incomplete and re-trips. + snapshots_dir = entry / "snapshots" + try: + snapshot_dirs = [s for s in snapshots_dir.iterdir() if s.is_dir()] + except OSError: + snapshot_dirs = [] + for snapshot in snapshot_dirs: + try: + for link in snapshot.rglob("*"): + if link.is_symlink() and not link.exists(): + # Scope to our own partials when known; a link to a sibling's blob is theirs. + if owned_incomplete_blobs is not None and ( + _link_incomplete_partner_name(link) not in owned_incomplete_blobs + ): + continue + # Spare a sibling's active link (target still has a fresh .incomplete). 
+ if _broken_link_has_active_partner(link, active_grace = active_grace): + continue + try: + link.unlink() + except OSError: + continue + except OSError: + continue + except Exception as e: + logger.debug("default prepare_for_http failed for %s: %s", repo_id, e) + + +def _active_incomplete_blob_sizes( + repo_type: Optional[str], repo_id: str, cache_dir: Optional[str] = None +) -> dict[str, int]: + """Map ``{blob_filename: bytes_present}`` (sparse-aware) for the repo's ``*.incomplete`` partials, so + the single-file watchdog follows only its own child's partials and a sibling download of a different + file cannot mask this file's stall.""" + sizes: dict[str, int] = {} + try: + for entry in iter_active_repo_cache_dirs(repo_type, repo_id, cache_dir = cache_dir): + blobs_dir = entry / "blobs" + if not blobs_dir.is_dir(): + continue + for blob in blobs_dir.iterdir(): + try: + if blob.is_file() and blob.name.endswith(INCOMPLETE_SUFFIX): + sizes[blob.name] = blob_bytes_present(blob) + except OSError: + pass + except Exception: + pass + return sizes + + +def _child_open_incomplete_blobs(pid: int) -> Optional[set]: + """Basenames of the ``*.incomplete`` blobs child *pid* has open -- exactly the partial THIS child is + writing (incl. a resumed partial reusing a prior blob-hash name), not a sibling's (different pid). + ``None`` when undeterminable (no ``psutil`` / ``/proc``, or process gone) -> caller uses a coarser + measure; empty set -> child not yet writing (connect / metadata phase).""" + # psutil is cross-platform (Linux / macOS / Windows). + try: + import psutil # type: ignore + except ImportError: + psutil = None # type: ignore + if psutil is not None: + try: + files = psutil.Process(pid).open_files() + except Exception: + return None + return {os.path.basename(f.path) for f in files if f.path.endswith(INCOMPLETE_SUFFIX)} + # Linux fallback: read open fds from /proc. 
+ fd_dir = f"/proc/{pid}/fd" + try: + entries = os.listdir(fd_dir) + except OSError: + return None # no /proc (non-Linux) or process gone + open_blobs: set = set() + for fd in entries: + try: + target = os.readlink(os.path.join(fd_dir, fd)) + except OSError: + continue + if target.endswith(INCOMPLETE_SUFFIX): + open_blobs.add(os.path.basename(target)) + return open_blobs + + +def get_hf_download_state( + repo_ids: Optional[list[str]] = None, + *, + repo_type: Optional[str] = "model", + cache_dir: Optional[str] = None, +) -> Optional[tuple[int, bool]]: + """Return ``(total_on_disk_bytes, has_incomplete)`` for the HF cache being written (sparse-aware, so + a partial Xet / ``hf_transfer`` blob is not read as full progress). Scans *cache_dir* or the active + ``HF_HUB_CACHE``. Missing / empty cache -> ``(0, False)``; ``None`` only on a probe exception + (unmeasurable -> callers skip stall logic this tick).""" + try: + if hf_cache_root(cache_dir = cache_dir) is None: + return (0, False) + + total = 0 + has_incomplete = False + for repo_id in repo_ids or []: + # Skip local paths: HF IDs never start with / . ~, contain "\", or a drive-letter ":". 
+ if ( + not repo_id + or repo_id.startswith(("/", ".", "~")) + or "\\" in repo_id + or ":" in repo_id + ): + continue + for entry in iter_active_repo_cache_dirs(repo_type, repo_id, cache_dir = cache_dir): + blobs_dir = entry / "blobs" + if not blobs_dir.is_dir(): + continue + for blob in blobs_dir.iterdir(): + try: + if blob.is_file(): + total += blob_bytes_present(blob) + except OSError: + pass + if has_active_incomplete_blobs(repo_type, repo_id, cache_dir = cache_dir): + has_incomplete = True + return (total, has_incomplete) + except Exception as e: + logger.debug("Failed to determine HF download state: %s", e) + return None + + +def start_watchdog( + *, + repo_ids: list[str], + on_stall: Callable[[str], None], + repo_type: Optional[str] = "model", + cache_dir: Optional[str] = None, + interval: float = DEFAULT_HEARTBEAT_INTERVAL, + stall_timeout: float = DEFAULT_STALL_TIMEOUT, + xet_disabled: bool = False, + on_heartbeat: Optional[Callable[[str], None]] = None, + watch_new_partials_only: bool = False, + baseline_incomplete_blobs: Optional[set] = None, + child_pid: Optional[int] = None, +) -> threading.Event: + """Start a daemon thread firing ``on_stall(message)`` once iff a ``*.incomplete`` is present AND the + on-disk size is unchanged for *stall_timeout* seconds. The timer resets while no ``*.incomplete`` + exists, so post-download init is not misread as a stall. Returns a stop event the caller sets when + the download phase ends. + + *watch_new_partials_only* (single-file) measures progress only over the child's own partial, so a + sibling pull of a different file cannot keep a hung child alive. That partial is identified by the + blobs *child_pid* has open (precise across a resume), else the partials not in + *baseline_incomplete_blobs* (captured pre-spawn). 
Snapshots keep the repo-wide measurement.""" + stop = threading.Event() + transport = "https" if xet_disabled else "xet" + fired = False + baseline = set(baseline_incomplete_blobs or ()) + single_repo_id = repo_ids[0] if repo_ids else "" + + def _measure() -> Optional[tuple[int, bool]]: + if watch_new_partials_only: + sizes = _active_incomplete_blob_sizes(repo_type, single_repo_id, cache_dir) + open_names = _child_open_incomplete_blobs(child_pid) if child_pid else None + if open_names is not None: + # Only partials this child holds open (handles a resume reusing a baseline name, excludes + # siblings). hf_xet holds the .incomplete fd continuously, so an EMPTY set means the child + # owns no partial YET (connect / metadata phase), not a sibling's. + owned = {name: n for name, n in sizes.items() if name in open_names} + return (sum(owned.values()), len(owned) > 0) + if child_pid: + # pid given but open files uninspectable (no psutil AND no /proc: native Windows / macOS + # without psutil). Post-baseline name filtering would forever EXCLUDE a resumed partial + # reusing a baseline name, so a frozen Xet resume never trips -- defeating the fallback. + # Fall back to the repo-wide measure (as snapshots do): the resume is watched, at the cost + # that a same-repo sibling's progress may mask this child's stall (accepted tradeoff). + return get_hf_download_state( + [single_repo_id], repo_type = repo_type, cache_dir = cache_dir + ) + # No child pid at all: follow only newly-created (post-baseline) partials. 
+ owned = {name: n for name, n in sizes.items() if name not in baseline} + return (sum(owned.values()), len(owned) > 0) + return get_hf_download_state(repo_ids, repo_type = repo_type, cache_dir = cache_dir) + + def _beat() -> None: + nonlocal fired + state = _measure() + last_size = state[0] if state is not None else 0 + last_change = time.monotonic() + + while not stop.wait(interval): + state = _measure() + now = time.monotonic() + + if state is None: + # Unmeasurable this tick (transient FS error): treat as progress so the gap cannot + # trip a false stall once readable again. + last_change = now + _safe_status(on_heartbeat, f"Downloading ({transport} transport)...") + continue + + current_size, has_incomplete = state + if current_size != last_size: + last_size = current_size + last_change = now + + # Reset unless .incomplete confirms an active download, so model init and lock waits + # are not counted as a stall. + if not has_incomplete: + last_change = now + elif now - last_change >= stall_timeout: + if not fired: + fired = True + on_stall( + f"Download appears stalled ({transport} transport) " + f"-- no progress for {int(now - last_change)}s" + ) + return + + _safe_status(on_heartbeat, f"Downloading ({transport} transport)...") + + threading.Thread(target = _beat, daemon = True, name = "hf-xet-watchdog").start() + return stop + + +def _scrub_in_child(text: str, token: Optional[str]) -> str: + """Redact secrets from a child error string, preferring Studio's patterns if present.""" + try: + from hub.utils.download_registry import scrub_secrets # type: ignore + + return scrub_secrets(text, hf_token = token) + except Exception: + return _default_scrub_secrets(text, hf_token = token) + + +# Deterministic Hub failures that recur identically over either transport, so retrying HTTP is +# pointless: surface them. Matched by class NAME so the parent need not import HF's error classes. 
+_DETERMINISTIC_ERROR_NAMES = frozenset({ + "RepositoryNotFoundError", + "RevisionNotFoundError", + "EntryNotFoundError", + "GatedRepoError", + "DisabledRepoError", + "LocalEntryNotFoundError", + "LocalTokenNotFoundError", # a missing required token fails identically either way + "BadRequestError", + "HFValidationError", # a malformed repo id / revision never reaches the network +}) +# TYPE reconstructed across the spawn but NOT retry-deterministic: ``HfHubHTTPError`` bases both +# deterministic 4xx and transient 5xx / 429, so its retry stays status-code driven while the parent +# still re-raises the real type (not RuntimeError) so ``except HfHubHTTPError`` keeps working. +_TYPE_PRESERVE_ONLY_NAMES = frozenset({ + "HfHubHTTPError", +}) +# Substrings marking a transient transport failure (hf_xet / CAS error, timeout, reset, 5xx / 429) +# that an HTTP retry may recover. +_TRANSIENT_ERROR_HINTS = ( + "xet", "casclient", "cas_", "timeout", "timed out", "connection", "reset by peer", + "temporarily", "try again", "incompleteread", "protocolerror", "remotedisconnected", + "broken pipe", "ssl", "eof occurred", "502", "503", "504", "500 server", "429", + "too many requests", "service unavailable", "bad gateway", "gateway time", + "connection aborted", +) + + +def _resolve_exception_class(type_name: str) -> "Optional[type]": + """Map a deterministic Hub / OS error class NAME back to its class so the parent re-raises the + original type, not RuntimeError. Best-effort (unknown -> None); local imports keep it import-light + and independent of the huggingface_hub layout.""" + if type_name == "OSError": + return OSError + # Preserve builtin OSError subclasses (PermissionError, FileNotFoundError, ...): deterministic FS + # failures the child cannot retry away, so a caller's `except OSError` / `except PermissionError` + # must still fire rather than see RuntimeError. 
+ builtin_cls = getattr(builtins, type_name, None) + if isinstance(builtin_cls, type) and issubclass(builtin_cls, OSError): + return builtin_cls + if type_name not in _DETERMINISTIC_ERROR_NAMES and type_name not in _TYPE_PRESERVE_ONLY_NAMES: + return None + for module_name in ("huggingface_hub.errors", "huggingface_hub.utils"): + try: + module = importlib.import_module(module_name) + except Exception: + continue + cls = getattr(module, type_name, None) + if isinstance(cls, type) and issubclass(cls, BaseException): + return cls + return None + + +def _instantiate_preserving_type(exc_cls: type, message: str) -> "Optional[BaseException]": + """Build an *exc_cls* instance carrying *message*, robust to a finicky constructor: Hub errors + subclass ``HfHubHTTPError`` whose ``response`` arg can be required, so ``exc_cls(message)`` may raise + ``TypeError``. Try normal constructors, then BYPASS ``__init__`` via ``__new__`` so type + message + survive. None only if ``__new__`` fails.""" + for build in ( + lambda: exc_cls(message), + lambda: exc_cls(message, response = None), + ): + try: + return build() + except Exception: + continue + try: + exc = exc_cls.__new__(exc_cls) + BaseException.__init__(exc, message) + return exc + except Exception: + return None + + +def _parse_errno(message: str) -> "Optional[int]": + """Pull the errno out of a stringified OSError (CPython formats it ``[Errno 28] ...``), so a + disk-full / quota error keeps its code across the spawn boundary for ``exc.errno`` branching.""" + match = re.search(r"\[Errno (\d+)\]", message) + if match is None: + return None + try: + return int(match.group(1)) + except ValueError: + return None + + +def _is_builtin_oserror(exc: BaseException) -> bool: + """True iff *exc* is a BUILTIN ``OSError`` (or subclass) with a real errno. 
Excludes HF/requests HTTP + errors (``OSError`` via ``requests -> IOError`` but no OS errno), so a bracketed ``[Errno N]`` in + their message is not mistaken for one.""" + if not isinstance(exc, OSError): + return False + builtin = getattr(builtins, type(exc).__name__, None) + return isinstance(builtin, type) and issubclass(builtin, OSError) and isinstance(exc, builtin) + + +def _raise_child_error(message: str) -> None: + """Re-raise a deterministic child error preserving its original TYPE for a known Hub / OS error, so + callers' ``except RepositoryNotFoundError`` / ``GatedRepoError`` / ``OSError`` still match across the + spawn. Child reports ``": "``; an unrecognized / uninstantiable class -> + ``RuntimeError``.""" + type_name = message.split(":", 1)[0].strip() if ":" in message else "" + exc_cls = _resolve_exception_class(type_name) + if exc_cls is None: + raise RuntimeError(message) + exc = _instantiate_preserving_type(exc_cls, message) + if exc is None: + raise RuntimeError(message) + if _is_builtin_oserror(exc) and getattr(exc, "errno", None) is None: + # Preserve errno (ENOSPC / EDQUOT ...) across the spawn for `except OSError` cleanup, for EVERY + # builtin OSError subclass. Restricted to BUILTIN types (an HF HTTP error is also an OSError via + # requests -> IOError, but its "[Errno N]" is not a real errno). Set as an attribute, not via the + # two-arg constructor: a single-arg __init__ subclass (LocalEntryNotFoundError) rejects it, and + # this avoids a doubled "[Errno N]" prefix. + errno_val = _parse_errno(message) + if errno_val is not None: + try: + exc.errno = errno_val + except Exception: + pass + raise exc + + +def _is_retryable_download_error(exc: BaseException) -> bool: + """True when a captured exception looks like a transient transport failure (``hf_xet`` / CAS error, + reset, timeout, 5xx / 429) the OTHER transport may recover, vs a deterministic Hub error (auth, + not-found, gated, disk-full). 
Unknown errors count as deterministic, so a real repeatable failure is + surfaced rather than looped between transports.""" + name = type(exc).__name__ + # LocalEntryNotFoundError wraps BOTH a genuine offline / uncached miss (deterministic) AND a + # TRANSIENT HEAD connection error / timeout for an uncached file. Retry the transient sub-case; a + # true offline miss (no transient hint) falls through to the deterministic set below. + if name == "LocalEntryNotFoundError" and any( + hint in f"{name}: {exc}".lower() for hint in _TRANSIENT_ERROR_HINTS + ): + return True + if name in _DETERMINISTIC_ERROR_NAMES: + return False + # Disk full / quota: another transport cannot help. + if isinstance(exc, OSError) and getattr(exc, "errno", None) in (errno.ENOSPC, errno.EDQUOT): + return False + # HTTP status (HfHubHTTPError carries a requests / httpx response): 5xx / 429 / 408 transient, + # other 4xx (401 / 403 / 404 / 416) deterministic. + status = getattr(getattr(exc, "response", None), "status_code", None) + if not isinstance(status, int): + status = getattr(exc, "status_code", None) + if isinstance(status, int): + return status >= 500 or status in (408, 429) + text = f"{name}: {exc}".lower() + return any(hint in text for hint in _TRANSIENT_ERROR_HINTS) + + +def _child_download(*, kind: str, params: dict, token: Optional[str], repo_type: str) -> str: + """Run the actual HF download for one attempt inside the spawn child.""" + if kind == "snapshot": + from huggingface_hub import snapshot_download + + return snapshot_download( + repo_id = params["repo_id"], + repo_type = repo_type, + token = token, + revision = params.get("revision"), + cache_dir = params.get("cache_dir"), + allow_patterns = params.get("allow_patterns"), + ignore_patterns = params.get("ignore_patterns"), + force_download = params.get("force_download", False), + ) + + from huggingface_hub import hf_hub_download + + return hf_hub_download( + repo_id = params["repo_id"], + filename = params["filename"], + 
subfolder = params.get("subfolder"), + repo_type = repo_type, + token = token, + revision = params.get("revision"), + cache_dir = params.get("cache_dir"), + force_download = params.get("force_download", False), + ) + + +def _download_child_entry( + *, + kind: str, + params: dict, + token: Optional[str], + repo_type: str, + disable_xet: bool, + result_queue: Any, +) -> None: + """Spawn-child entrypoint (top-level + picklable): set the Xet env BEFORE importing huggingface_hub, + form its own process group so the parent can kill the whole transfer, never log token / signed + URLs.""" + # Die with the parent on Linux under Studio (best-effort; module absent standalone). + try: + from utils.process_lifetime import bind_current_process_to_parent_lifetime # type: ignore + + bind_current_process_to_parent_lifetime() + except Exception: + pass + + if hasattr(os, "setsid"): + try: + os.setsid() + except OSError: + pass + + if disable_xet: + os.environ["HF_HUB_DISABLE_XET"] = "1" + # Keep the HTTP writer sequential and resumable (hf_transfer's sparse partials cannot). + os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "0" + os.environ.setdefault("HF_HUB_DISABLE_PROGRESS_BARS", "1") + + repo_id = params["repo_id"] + + # Test-only fault injection (never set in production): stall the Xet attempt to exercise the + # watchdog + HTTP fallback against a real repo. + if not disable_xet and os.environ.get("UNSLOTH_HF_XET_FORCE_STALL") == "1": + _stall_fh = None + try: + from huggingface_hub.constants import HF_HUB_CACHE + + # Write the fake partial under the repo_type-correct dir the watchdog scans. + cache_root = params.get("cache_dir") or HF_HUB_CACHE + repo_dir_name = f"{repo_type or 'model'}s--" + repo_id.replace("/", "--") + blobs = os.path.join(cache_root, repo_dir_name, "blobs") + os.makedirs(blobs, exist_ok = True) + # Hold the partial OPEN for the whole stall (the single-file watchdog counts only partials + # this PID holds open); bound to a local so it stays open across the sleep. 
+ _stall_fh = open(os.path.join(blobs, "xet-force-stall.incomplete"), "wb") + _stall_fh.write(b"\0" * 4096) + _stall_fh.flush() + except OSError: + pass + while True: + time.sleep(3600) + + try: + path = _child_download(kind = kind, params = params, token = token, repo_type = repo_type) + result_queue.put({"ok": True, "path": path}) + except BaseException as e: # noqa: BLE001 - report every failure to the parent + # Classify here where the exception (status, errno, type) is intact, so the parent retries a + # transient failure over HTTP yet surfaces a deterministic one without a second attempt. + result_queue.put({ + "ok": False, + "error": _scrub_in_child(f"{type(e).__name__}: {e}", token), + "retryable": _is_retryable_download_error(e), + }) + + +def _terminate_process_group(proc: "mp.process.BaseProcess", grace_period: float) -> None: + """Kill *proc* and its whole process group (Xet may spawn helpers). The child ``os.setsid()``s so its + pgid == pid; the group is signalled via ``os.killpg`` only once ``os.getpgid(pid) == pid`` confirms + it. SIGTERM, then SIGKILL after *grace_period*.""" + pid = proc.pid + + def _signal_group(sig: int) -> None: + # Signal the GROUP only once the child is its own leader (pgid == pid after setsid). Before setsid + # it is still in OUR group and its pid could collide with a recycled group, so ``getpgid != pid`` + # guards ``killpg`` from the WRONG group (a reaped child raises here). Otherwise (also Windows: no + # killpg / getpgid) signal the single process. 
+ if pid is not None and hasattr(os, "killpg") and hasattr(os, "getpgid"): + try: + if os.getpgid(pid) == pid: + os.killpg(pid, sig) + return + except (ProcessLookupError, PermissionError, OSError): + pass + try: + proc.terminate() if sig != getattr(signal, "SIGKILL", -9) else proc.kill() + except Exception: + pass + + _signal_group(getattr(signal, "SIGTERM", signal.SIGINT)) + proc.join(timeout = grace_period) + # SIGKILL only while alive, so the pid (== pgid) is a live target: once join() reaps a leader, a busy + # host could recycle its pid into an unrelated group and killpg would hit the WRONG one. hf_xet 1.5.x + # spawns no helpers, so a reaped leader leaves nothing to clean up. + if proc.is_alive(): + # -9 sentinel takes the force-kill branch on Windows (signal.SIGKILL undefined; moot there since + # proc.kill() == proc.terminate()), keeping the call site consistent. + _signal_group(getattr(signal, "SIGKILL", -9)) + proc.join(timeout = 5.0) + + +def _run_download_attempt( + repo_id: str, + *, + kind: str, + params: dict, + token: Optional[str], + repo_type: str, + disable_xet: bool, + cancel_event: Optional[threading.Event], + stall_timeout: float, + interval: float, + grace_period: float, + on_status: Optional[Callable[[str], None]], +) -> tuple[str, Optional[str]]: + """Run one download in a spawn child under the no-progress watchdog. Returns ``("ok", path)``, + ``("stall", None)``, ``("cancelled", None)``, ``("crashed", message)`` (crash, no captured + exception), ``("retryable_error", message)`` (transient, worth an HTTP retry), or ``("error", + message)`` (deterministic Hub error). Tests monkeypatch this seam to avoid spawning.""" + # Single-file: snapshot the on-disk partials BEFORE spawning so the watchdog follows only the blob(s) + # this child writes, not a sibling's. Snapshots stay repo-wide. 
+ baseline_partials: Optional[set] = None + if kind == "file": + baseline_partials = set( + _active_incomplete_blob_sizes(repo_type, repo_id, params.get("cache_dir")) + ) + result_queue: Any = _CTX.Queue() + proc = _CTX.Process( + target = _download_child_entry, + kwargs = dict( + kind = kind, + params = params, + token = token, + repo_type = repo_type, + disable_xet = disable_xet, + result_queue = result_queue, + ), + daemon = True, + ) + # Set the transport env in THIS process around the spawn so the child inherits it from creation: HF + # caches HF_HUB_DISABLE_XET at import time and the child re-imports huggingface_hub before its body, + # so a child-side assignment lands too late. The child still sets it defensively. + child_env = { + "HF_HUB_DISABLE_PROGRESS_BARS": "1", + # Skip unsloth_zoo's heavy torch / transformers / device init in the child. + "UNSLOTH_ZOO_DISABLE_GPU_INIT": "1", + } + if disable_xet: + child_env["HF_HUB_DISABLE_XET"] = "1" + child_env["HF_HUB_ENABLE_HF_TRANSFER"] = "0" + with _SPAWN_ENV_LOCK: + # Cache Hub's transport constants in the PARENT from the REAL env NOW, before the child-only + # HF_HUB_DISABLE_XET=1 is briefly set below: else a concurrent thread's FIRST `import + # huggingface_hub` in the spawn window caches the disabled value and routes later in-process + # downloads over HTTP. No-op once imported. + try: + import huggingface_hub.constants # noqa: F401 + except Exception: + pass + saved_env = {k: os.environ.get(k) for k in child_env} + # 'spawn' reconstructs __main__ from __main__.__file__: a pseudo-path ('', a notebook) + # fails to start, and a real UNGUARDED caller script re-imports as __mp_main__, re-running + # top-level from_pretrained -> "start a process before bootstrapping" -> the child exits without + # a result. Point __main__ at THIS side-effect-free module for the spawn. 
+ main_module = sys.modules.get("__main__") + saved_main_file = _UNSET + saved_main_spec = _UNSET + if main_module is not None: + saved_main_file = getattr(main_module, "__file__", _UNSET) + main_module.__file__ = __file__ + # `python -m pkg`: spawn prefers __spec__.name and re-imports BY NAME. Clearing __spec__ + # forces the __file__ path branch. + saved_main_spec = getattr(main_module, "__spec__", _UNSET) + main_module.__spec__ = None + try: + os.environ.update(child_env) + proc.start() + except BaseException: + # proc.start() can raise (OSError under fd / thread exhaustion). The lifecycle try/finally + # that closes the queue's pipe fds runs only after a successful start, so close it here to + # avoid an fd leak. + try: + result_queue.cancel_join_thread() + result_queue.close() + except Exception: + pass + raise + finally: + for k, v in saved_env.items(): + if v is None: + os.environ.pop(k, None) + else: + os.environ[k] = v + if main_module is not None: + if saved_main_file is _UNSET: + # __file__ was absent before; remove the one we added. + try: + delattr(main_module, "__file__") + except AttributeError: + pass + else: + main_module.__file__ = saved_main_file + if saved_main_spec is _UNSET: + try: + delattr(main_module, "__spec__") + except AttributeError: + pass + else: + main_module.__spec__ = saved_main_spec + + # Bind the child to the parent lifetime when running under Studio (best-effort). + try: + from utils.process_lifetime import adopt_pid # type: ignore + + adopt_pid(proc.pid) + except Exception: + pass + + stalled = threading.Event() + # If start_watchdog raises ("can't start new thread"), the already-started child must STILL be + # reaped, so it runs inside the try whose finally reaps it; stop_watchdog stays None until it works. 
+ stop_watchdog = None + result: Optional[dict] = None + try: + stop_watchdog = start_watchdog( + repo_ids = [repo_id], + on_stall = lambda msg: stalled.set(), + repo_type = repo_type, + cache_dir = params.get("cache_dir"), + interval = interval, + stall_timeout = stall_timeout, + xet_disabled = disable_xet, + on_heartbeat = on_status, + watch_new_partials_only = (kind == "file"), + baseline_incomplete_blobs = baseline_partials, + child_pid = proc.pid, + ) + while proc.is_alive(): + if cancel_event is not None and cancel_event.is_set(): + _terminate_process_group(proc, grace_period) + return ("cancelled", None) + if stalled.is_set(): + # Prefer a result the child enqueued in the watchdog's fire window, so a just-succeeded + # download is not killed. Its feeder may not have flushed a microseconds-earlier put, so + # use a short timeout, not get_nowait(). + try: + result = result_queue.get(timeout = 1.0) + break + except queue.Empty: + pass + # Capture the partials THIS child owns BEFORE killing it, so HTTP prep scopes its purge to + # them. Prefer the per-pid open-fd set; else post-baseline partials; None -> coarser mtime + # guard. + owned = _child_open_incomplete_blobs(proc.pid) if proc.pid else None + if owned is None and baseline_partials is not None: + current = set( + _active_incomplete_blob_sizes(repo_type, repo_id, params.get("cache_dir")) + ) + owned = current - baseline_partials + params["_owned_incomplete_blobs"] = owned + _terminate_process_group(proc, grace_period) + return ("stall", None) + try: + result = result_queue.get(timeout = _POLL_INTERVAL) + break + except queue.Empty: + continue + else: + # Process exited; drain any enqueued result. Short timeout, not get_nowait(): the child can + # exit just before its feeder flushes the pipe, which would spuriously look resultless. 
+ try: + result = result_queue.get(timeout = 1.0) + except queue.Empty: + result = None + finally: + if stop_watchdog is not None: + stop_watchdog.set() + proc.join(timeout = grace_period) + # No loop exit may leak the child; _terminate_process_group is idempotent so a redundant call + # is a no-op. + if proc.is_alive(): + _terminate_process_group(proc, grace_period) + # Release the queue's pipe fds now rather than at GC. The result is extracted and a killed child + # has nothing to flush, so cancel the feeder before close. + try: + result_queue.cancel_join_thread() + result_queue.close() + except Exception: + pass + + if result is None: + # Child exited resultless: a process-level crash (native hf_xet abort / segfault), not a captured + # exception, so the other transport may still succeed -- report "crashed". + return ( + "crashed", + f"download process for '{repo_id}' exited " + f"(code={proc.exitcode}) without a result", + ) + if result.get("ok"): + return ("ok", result["path"]) + message = result.get("error") or "unknown download error" + if result.get("retryable"): + return ("retryable_error", message) + return ("error", message) + + +def _intact_subset( + snapshot_dir: Path, *, repo_type: str, allow_patterns: Any, ignore_patterns: Any, +) -> bool: + """No interrupted-download evidence for the SELECTED files: no dangling requested symlink and every + EXACT-named requested file present. A dangling EXCLUDED weight does not reject the subset.""" + return ( + not snapshot_has_requested_broken_symlinks( + snapshot_dir, allow_patterns = allow_patterns, ignore_patterns = ignore_patterns, + repo_type = repo_type, + ) + and requested_named_files_present( + snapshot_dir, allow_patterns = allow_patterns, ignore_patterns = ignore_patterns + ) + ) + + +def _is_default_load_weight_file(name: str) -> bool: + """A weight a DEFAULT ``from_pretrained`` reads: safetensors or bin only. Excludes gguf / pt / onnx / + msgpack / ... -- a stale cache holding only e.g. 
``model.Q4_K_M.gguf`` must not satisfy the load, else + it fetches the missing ``model.safetensors`` over un-killable Xet. Optimizer state is already excluded + by ``_is_loadable_weight_file``.""" + return _is_loadable_weight_file(name) and name.endswith((".safetensors", ".bin")) + + +# CANONICAL root model weight a DEFAULT (no-variant) load reads: model.safetensors / pytorch_model.bin, +# single or numbered shard (dash infix, not a dotted variant token). A PEFT adapter, a variant +# (model.fp16.safetensors), a gguf, and a non-canonical root (consolidated.safetensors, tf_model.h5) are +# NOT matched -- a default from_pretrained probes only these names, so a cache holding only something +# else would fetch the missing canonical weight over un-killable Xet. +_CANONICAL_ROOT_MODEL_WEIGHT_RE = re.compile( + r"^(?:model|pytorch_model)(?:-\d{5}-of-\d{5})?\.(?:safetensors|bin)$" +) + +# CANONICAL (non-variant) diffusers component weight a PLAIN pipeline load reads in a component subfolder: +# a base with no intermediate dotted token, single or numbered shard, safetensors or bin. A VARIANT +# weight (diffusion_pytorch_model.fp16.safetensors) is EXCLUDED, so a variant='fp16' stale cache does not +# read as a warm PLAIN pipeline (which would fetch the non-variant name over un-killable Xet). +_CANONICAL_COMPONENT_WEIGHT_RE = re.compile( + r"^[^.]+(?:-\d{5}-of-\d{5})?\.(?:safetensors|bin)$" +) + +# SINGLE-FILE canonical root TF / Flax weight a from_tf / from_flax load reads instead of a PyTorch +# format. A SHARDED TF/Flax weight is judged through its index instead -- a lone shard here must NOT read +# as present, else an incomplete sharded set is loaded over Xet. +_CANONICAL_ROOT_TF_FLAX_WEIGHT_RE = re.compile(r"^(?:tf_model\.h5|flax_model\.msgpack)$") + +# The shard-index sidecars a sharded TF / Flax weight is enumerated through. 
def _diffusers_active_component_dirs(specs: dict) -> set:
    """Return the names of the declared components a pipeline actually loads.

    A component counts as active when its spec is a list / tuple of at least two items whose first
    two entries (``[library, class]``) are both non-null. A ``[null, null]`` spec -- a disabled or
    optional component such as safety_checker -- is left out: the load skips it, so it is not
    required to be present. Any non-sequence value is left out as well.
    """

    def _is_loaded(spec: Any) -> bool:
        # Shape check first so the indexing below cannot fail.
        if not isinstance(spec, (list, tuple)) or len(spec) < 2:
            return False
        library, class_name = spec[0], spec[1]
        return library is not None and class_name is not None

    return {name for name, spec in specs.items() if _is_loaded(spec)}
def _diffusers_component_weights_complete(
    snapshot_dir: Path, *, variant: Optional[str], ignore_patterns: Any = None,
) -> bool:
    """Decide whether a diffusers pipeline warm holds every weight a plain / variant load reads.

    Args:
        snapshot_dir: Snapshot directory to inspect.
        variant: Requested weight variant token (e.g. ``"fp16"``), or a falsy value for a plain load.
        ignore_patterns: Ignore filter of the request; applied (via ``_filter_paths``) to every
            collected weight so only the FORMAT the load reads counts.

    Returns:
        ``True`` when the snapshot passes both checks below; ``False`` otherwise, including when
        walking the snapshot raises ``OSError``.

    Collection: every file under the snapshot accepted by ``_is_default_load_weight_file`` is
    attributed to its top-level directory (weights nested deeper still belong to that component).
    Skipped are ROOT-level weights, subtrees not declared by the pipeline specs (when specs exist), and
    any path containing a training-checkpoint directory. A name carrying the ``.<variant>.`` /
    ``.<variant>-`` infix is bucketed as a variant weight; otherwise it must match
    ``_CANONICAL_COMPONENT_WEIGHT_RE`` to be bucketed as canonical.

    Check 1 (per active component, only when specs were found): each DECLARED ACTIVE component must be
    materialised as a non-empty subfolder. One holding ``config.json`` is model-style and MUST hold a
    readable weight; one without ``config.json`` and without a readable weight passes only if it is
    WEIGHTLESS-shaped (scheduler / tokenizer / feature_extractor sidecars). So a stale partial missing
    a whole component, holding a component's config without its weight, or holding only a stray /
    nested-checkpoint file is retried over HTTP rather than loaded over un-killable Xet. For a variant
    load a component's canonical weight is accepted as the per-component fallback.

    Check 2 (floor): at least one component holds a weight of the READ format -- a variant weight for
    a variant load, a canonical one for a plain load.

    When the specs are empty / falsy (a malformed or empty ``model_index.json``), check 1 is skipped
    and only the lenient floor applies, so hang protection is kept without false-rejecting.
    """
    specs = _diffusers_declared_component_specs(snapshot_dir)
    # Falsy specs -> both stay None, which disables the declared-subtree filter and check 1.
    declared = set(specs) if specs else None
    active = _diffusers_active_component_dirs(specs) if specs else None
    infix_dot = f".{variant}." if variant else ""
    infix_dash = f".{variant}-" if variant else ""
    # component name -> snapshot-relative posix paths of its canonical / variant weights
    per_comp_canon: dict = {}
    per_comp_variant: dict = {}
    try:
        for entry in snapshot_dir.rglob("*"):
            name = entry.name
            if not _is_default_load_weight_file(name):
                continue
            try:
                if not entry.is_file():
                    continue
                rel = entry.relative_to(snapshot_dir).as_posix()
            except (OSError, ValueError):
                continue
            parts = rel.split("/")
            if len(parts) < 2:
                continue  # a ROOT-level weight is not a component
            comp = parts[0]
            if declared is not None and comp not in declared:
                continue  # an UNDECLARED subtree the load does not read
            if any(_CHECKPOINT_DIR_RE.match(p) for p in parts[:-1]):
                continue  # a training-checkpoint subtree, not a component
            if variant and (infix_dot in name or infix_dash in name):
                per_comp_variant.setdefault(comp, []).append(rel)
            elif _CANONICAL_COMPONENT_WEIGHT_RE.match(name):
                per_comp_canon.setdefault(comp, []).append(rel)
    except OSError:
        # The walk itself failed: treat the snapshot as not provably complete.
        return False

    def _has_canon(comp: str) -> bool:
        return bool(_filter_paths(per_comp_canon.get(comp, []), None, ignore_patterns))

    def _has_variant(comp: str) -> bool:
        return bool(_filter_paths(per_comp_variant.get(comp, []), None, ignore_patterns))

    def _has_read_weight(comp: str) -> bool:
        # variant load falls back to a component's canonical weight when it ships no variant file.
        # Precedence: parsed as ``(variant-or-canon) if variant else canon``.
        return _has_variant(comp) or _has_canon(comp) if variant else _has_canon(comp)

    if active is not None:
        for comp in active:
            comp_dir = snapshot_dir / comp
            try:
                names = [e.name for e in comp_dir.iterdir()] if comp_dir.is_dir() else []
            except OSError:
                names = []
            if not names:
                return False  # a declared active component was never materialised
            if "config.json" in names:  # a model-style component MUST hold a readable weight
                if not _has_read_weight(comp):
                    return False
            elif not _has_read_weight(comp) and not _is_weightless_component(names):
                # No config.json and no readable weight: OK only when the dir is weightless-SHAPED
                # (scheduler / tokenizer / feature_extractor sidecars). Otherwise it is an incomplete
                # model-style component whose weight the load would fetch over un-killable Xet.
                return False
    # Floor: at least one component holds a weight of the READ format -- rejects a variant-only-for-plain,
    # config-only, checkpoint-only, or undeclared-leftover-only stale snapshot.
    if variant:
        return any(_has_variant(c) for c in (declared or per_comp_variant))
    return any(_has_canon(c) for c in (declared or per_comp_canon))
def _root_has_variant_weight(
    snapshot_dir: Path, variant: str, *, ignore_patterns: Any = None
) -> bool:
    """Report whether a CANONICAL ROOT model weight carrying *variant* survives the ignore filter.

    transformers reads ``model.<variant>.safetensors`` and, for shards,
    ``model.<variant>-00001-of-00002.safetensors`` (a ``.<variant>-`` infix). A root entry qualifies
    when its name contains one of those two infixes, matches ``_ROOT_MODEL_VARIANT_WEIGHT_RE`` and is
    a regular file. A non-canonical base, a PEFT adapter or a non-``model`` variant is excluded, so a
    cache holding only those is retried over HTTP. The ignore filter is applied last so a partial in
    an ignored format does not count. An unreadable snapshot directory yields ``False``.
    """
    variant_tokens = (f".{variant}.", f".{variant}-")

    def _is_regular_file(entry: Path) -> bool:
        # An entry that cannot be stat'ed is skipped, not treated as fatal.
        try:
            return entry.is_file()
        except OSError:
            return False

    try:
        matches = [
            entry.name
            for entry in snapshot_dir.iterdir()
            if any(token in entry.name for token in variant_tokens)
            and _ROOT_MODEL_VARIANT_WEIGHT_RE.match(entry.name)
            and _is_regular_file(entry)
        ]
    except OSError:
        return False
    return bool(_filter_paths(matches, None, ignore_patterns))
See + ``_diffusers_component_weights_complete``.""" + return _diffusers_component_weights_complete( + snapshot_dir, variant = variant, ignore_patterns = ignore_patterns + ) + + +def _root_model_has_variant_weight( + snapshot_dir: Path, variant: str, *, ignore_patterns: Any = None +) -> bool: + """Variant analog of ``_root_model_has_weight``: an UNPATTERNED variant warm holds a ROOT variant + weight, or -- for a diffusers pipeline -- a component-subfolder variant weight (its weights live in + subfolders, not root ``model..*``, so the root-only check would false-reject them).""" + try: + is_diffusers = (snapshot_dir / "model_index.json").is_file() + except OSError: + is_diffusers = False + if is_diffusers: + return _has_diffusers_component_variant_weight( + snapshot_dir, variant, ignore_patterns = ignore_patterns + ) + return _root_has_variant_weight(snapshot_dir, variant, ignore_patterns = ignore_patterns) + + +# Interchangeable exact weight names collapse to one equivalence group: the either-format pair +# ``["pytorch_model.bin", "model.safetensors"]`` (and the variant pair) is satisfied by ANY one, since HF +# allow patterns are ALTERNATIVES over the repo. Distinct logical weights (base AND adapter, a different +# variant) stay separate groups (each required). +_EITHER_FORMAT_WEIGHT_RE = re.compile( + r"^(model|pytorch_model|adapter_model)(?:\.([^.]+))?\.(?:safetensors|bin)$" +) + + +def _exact_weight_logical(base: str) -> Any: + """Equivalence key for an EXACT-named weight so either-format alternatives share a group: + ``model.safetensors`` / ``pytorch_model.bin`` -> ``("root_model", None)``; same variant in both + formats -> ``("root_model", "")``; ``adapter_model.*`` -> ``("adapter_model", ...)``. 
A + non-weight / sharded name maps to itself (still required individually).""" + m = _EITHER_FORMAT_WEIGHT_RE.match(base) + if m is None: + return base + stem, variant = m.group(1), m.group(2) + logical = "adapter_model" if stem == "adapter_model" else "root_model" + return (logical, variant) + + +def _requested_exact_files_present_grouped( + snapshot_dir: Path, *, allow_patterns: Any, ignore_patterns: Any, +) -> bool: + """True unless an EXACT-named requested file is missing. Interchangeable weights need any one; distinct + logical files (base AND adapter, a tokenizer file) each. A glob / unpatterned request is trivially + satisfied here.""" + allow = _as_pattern_list(allow_patterns) + ignore = _as_pattern_list(ignore_patterns) + if not allow or any(not isinstance(p, str) or _has_glob(p) for p in allow): + return True + requested = _filter_paths(allow, None, ignore) + if not requested: + return True # ignore filter dropped every named file -> nothing to require + try: + present = { + entry.relative_to(snapshot_dir).as_posix() + for entry in snapshot_dir.rglob("*") + if entry.is_file() + } + except OSError: + return True # cannot enumerate -> do not reject on an unreadable dir + groups: "dict[tuple[str, Any], list[str]]" = {} + for rel in requested: + parent, base = rel.rsplit("/", 1) if "/" in rel else ("", rel) + logical = _exact_weight_logical(base) + groups.setdefault((parent, logical), []).append(rel) + return all( + any(candidate in present for candidate in candidates) for candidates in groups.values() + ) + + +def _has_selected_weight( + snapshot_dir: Path, *, allow_patterns: Any, ignore_patterns: Any, +) -> bool: + """True if a loadable weight the request SELECTS (allow / ignore filtered) is present, so a stale + out-of-scope weight does not satisfy a patterned request.""" + weights: list = [] + try: + for entry in snapshot_dir.rglob("*"): + if not _is_loadable_weight_file(entry.name): + continue + try: + if entry.is_file(): + 
def _patterns_are_exact_names(patterns: Any) -> bool:
    """Report whether an allow list consists solely of EXACT filenames.

    - ``None`` (unpatterned) -> ``False``: completeness would need the Hub manifest.
    - an empty list -> ``True``: it selects nothing, so there is nothing to fetch.
    - otherwise ``True`` only when every entry is a string that ``_has_glob`` does not flag; a single
      glob or non-string entry makes the request not locally provable.
    """
    names = _as_pattern_list(patterns)
    if names is None:
        return False
    for candidate in names:
        if not isinstance(candidate, str) or _has_glob(candidate):
            return False
    return True
def _request_selects_root_variant_weight(
    allow_patterns: Any, ignore_patterns: Any, variant: str,
) -> bool:
    """Report whether the allow / ignore filter keeps at least one ROOT variant weight name.

    Four representative root names are probed (single-file and first-shard, safetensors and bin).
    A ``False`` result means a stale incomplete root variant shard set is OUT of scope -- e.g. a
    subfolder request whose variant weights live under ``unet/`` -- so the ROOT variant-shard gate
    must not reject on it, else a complete in-scope variant download is failed.
    """
    root_variant_names = [
        f"model.{variant}.safetensors",
        f"model.{variant}-00001-of-00002.safetensors",
        f"pytorch_model.{variant}.bin",
        f"pytorch_model.{variant}-00001-of-00002.bin",
    ]
    kept = _filter_paths(root_variant_names, allow_patterns, ignore_patterns)
    return bool(kept)
def _has_readable_weight(
    snapshot_dir: Path, *, allow_patterns: Any, ignore_patterns: Any, variant: Optional[str],
) -> bool:
    """Invariant A (presence): a weight the in-process load will READ is on disk.

    The ignore filter is ALWAYS applied and the scope follows the request:

    ============  ==========================================================================
    request       weight that must be present
    ============  ==========================================================================
    unpatterned   variant: a ROOT variant weight (or diffusers-component variant weight);
                  plain: a ROOT (or diffusers-component) weight, NOT a stray subfolder checkpoint
    patterned     variant: a SELECTED variant weight; plain: a SELECTED weight
    ============  ==========================================================================

    A partial that kept only the ignored format does not count, so it is retried over HTTP.
    """
    if allow_patterns is None:
        # Unpatterned request: judge the root (or pipeline component) weights.
        if variant:
            return _root_model_has_variant_weight(
                snapshot_dir, variant, ignore_patterns = ignore_patterns
            )
        return _root_model_has_weight(snapshot_dir, ignore_patterns = ignore_patterns)
    # Patterned request: judge only what the allow / ignore filter selects.
    if variant:
        return _has_selected_variant_weight(
            snapshot_dir, allow_patterns = allow_patterns,
            ignore_patterns = ignore_patterns, variant = variant,
        )
    return _has_selected_weight(
        snapshot_dir, allow_patterns = allow_patterns, ignore_patterns = ignore_patterns
    )
+ + The ignore filter is threaded through so completeness is judged for the FORMAT the load reads.""" + try: + is_diffusers = (snapshot_dir / "model_index.json").is_file() + except OSError: + is_diffusers = False + if _patterns_are_exact_names(allow_patterns): + # An exact-named subset defers to the exact-file presence check: the load reads exactly the named + # shard(s), so a lone exact shard is not judged against its unrequested index (else false-reject). + return False + if variant: + if not is_diffusers and ( + allow_patterns is None + or _request_selects_root_variant_weight(allow_patterns, ignore_patterns, variant) + ): + # Only a non-diffusers root variant load runs the root-shard check (a diffusers pipeline's + # variant weights are component-scoped, handled below). + if _has_incomplete_variant_root_shards( + snapshot_dir, variant, ignore_patterns = ignore_patterns + ): + return True + if allow_patterns is not None: + if _selected_shard_index_incomplete( + snapshot_dir, allow_patterns = allow_patterns, + ignore_patterns = ignore_patterns, variant = variant, + ): + return True + elif is_diffusers and _diffusers_component_shards_incomplete( + snapshot_dir, variant = variant, ignore_patterns = ignore_patterns + ): + # UNPATTERNED variant diffusers warm: a component variant shard index is incomplete (the root + # variant check above only covers root model. shards). + return True + return False + if allow_patterns is None: + if not is_diffusers and _has_incomplete_canonical_root_shards( + snapshot_dir, ignore_patterns = ignore_patterns + ): + # non-diffusers root model load (a diffusers stale root index is handled by the component + # check below). + return True + if is_diffusers and _diffusers_component_shards_incomplete( + snapshot_dir, variant = None, ignore_patterns = ignore_patterns + ): + # UNPATTERNED plain diffusers warm: a component shard index missing a shard is not covered by + # the canonical ROOT-shard check. 
def _selected_readable_weight_complete(
    snapshot_dir: Path, *, allow_patterns: Any, ignore_patterns: Any, variant: Optional[str],
) -> bool:
    """Weight-bearing MODEL acceptance check combining both invariants.

    Accepts only when the weight the load will READ is present (Invariant A, ``_has_readable_weight``)
    AND its in-scope shard set is not incomplete (Invariant B, ``_readable_shard_set_incomplete``).
    Both receive the same allow / ignore / variant scope, so a co-resident out-of-scope or
    ignored-format partial neither masks an incomplete weight (a silent Xet hang) nor false-rejects a
    complete download (a spurious ``DownloadStallError``).
    """
    scope = {
        "allow_patterns": allow_patterns,
        "ignore_patterns": ignore_patterns,
        "variant": variant,
    }
    # Presence is checked first; shard completeness is only evaluated for a present weight.
    if not _has_readable_weight(snapshot_dir, **scope):
        return False
    return not _readable_shard_set_incomplete(snapshot_dir, **scope)
def _snapshot_payload_incomplete(
    payload: Any, *, repo_type: str, allow_patterns: Any, ignore_patterns: Any,
    variant: Optional[str] = None,
) -> bool:
    """Report whether a snapshot download returned a real directory that is NOT usable.

    The usability verdict comes from ``_download_result_usable``. The check is guarded to an existing
    directory: a payload that cannot be turned into a path, or that is not a directory (e.g. a test
    sentinel), is trusted and reported as not incomplete.
    """
    try:
        candidate = Path(payload)
    except (TypeError, ValueError, OSError):
        return False  # non-path payload (test sentinel) or invalid path -> trust it
    try:
        is_real_dir = candidate.is_dir()
    except OSError:
        is_real_dir = False
    if not is_real_dir:
        return False
    usable = _download_result_usable(
        candidate, repo_type = repo_type, allow_patterns = allow_patterns,
        ignore_patterns = ignore_patterns, variant = variant,
    )
    return not usable
+ owned_incomplete = params.pop("_owned_incomplete_blobs", None) + try: + if prepare_for_http_fn is None: + _default_prepare_for_http( + repo_type, repo_id, cache_dir = cache_dir, active_grace = stall_timeout, + owned_incomplete_blobs = owned_incomplete, + ) + else: + prepare_for_http_fn(repo_type, repo_id) + except Exception as e: + logger.debug("prepare_for_http failed for %s: %s", repo_id, e) + # An unsafe partial that could not be cleared (locked / permission) would corrupt the blob on + # an HTTP resume: force a clean re-download instead. + if has_active_incomplete_blobs(repo_type, repo_id, cache_dir = cache_dir): + logger.warning( + "Unsafe partial for '%s' could not be cleared; forcing a clean " + "HTTP re-download instead of an unsafe resume.", label + ) + params = {**params, "force_download": True} + + kind_result, payload = _run_download_attempt( + repo_id, + kind = kind, + params = params, + token = token, + repo_type = repo_type, + disable_xet = disable_xet, + cancel_event = cancel_event, + stall_timeout = stall_timeout, + interval = interval, + grace_period = grace_period, + on_status = on_status, + ) + + if kind_result == "ok": + if kind == "snapshot" and _snapshot_payload_incomplete( + payload, + repo_type = repo_type, + allow_patterns = params.get("allow_patterns"), + ignore_patterns = params.get("ignore_patterns"), + variant = variant, + ): + # HF can hand back an existing incomplete snapshot dir (offline / timed-out) instead of + # fetching: never load it in-process. Retry over HTTP, then fail loudly. (Patterned / + # non-model requests judge their own subset, so a valid weightless snapshot is not + # rejected.) 
+ if not disable_xet: + logger.warning( + "Download for '%s' returned an incomplete snapshot -- " + "retrying with HF_HUB_DISABLE_XET=1", label + ) + _safe_status(on_status, f"{label}: incomplete snapshot, retrying over HTTP") + disable_xet = True + continue + raise DownloadStallError( + f"Download for '{label}' returned an incomplete snapshot even with " + f"HF_HUB_DISABLE_XET=1 -- missing files, check your network connection" + ) + return payload # type: ignore[return-value] + if kind_result == "cancelled": + raise RuntimeError("Cancelled") + if kind_result == "error": + # Deterministic failure (auth / not-found / gated / disk-full): the other transport fails + # identically. _raise_child_error preserves the original type across the spawn. + _raise_child_error(payload) + if kind_result == "retryable_error": + # Transient transport failure (hf_xet CAS timeout, 5xx, reset): retry HTTP once, else raise. + if not disable_xet: + logger.warning( + "Download for '%s' hit a transient Xet transport error -- retrying " + "with HF_HUB_DISABLE_XET=1: %s", label, payload + ) + _safe_status(on_status, f"{label}: transient Xet error, retrying over HTTP") + disable_xet = True + continue + raise RuntimeError(payload) + if kind_result == "crashed": + # Process-level crash with no captured exception: HTTP may still succeed, so retry once. + if not disable_xet: + logger.warning( + "Download process for '%s' crashed without a result -- " + "retrying with HF_HUB_DISABLE_XET=1", label + ) + _safe_status(on_status, f"{label}: download crashed, retrying over HTTP") + disable_xet = True + continue + raise RuntimeError(payload) + # kind_result == "stall" + if not disable_xet: + logger.warning( + "Download stalled for '%s' -- retrying with HF_HUB_DISABLE_XET=1", label + ) + # _safe_status: a raising status hook must not abort the retry before disable_xet is set. 
def hf_hub_download_with_xet_fallback(
    repo_id: str,
    filename: str,
    token: Optional[str] = None,
    *,
    cancel_event: Optional[threading.Event] = None,
    repo_type: Optional[str] = "model",
    revision: Optional[str] = None,
    cache_dir: Optional[str] = None,
    subfolder: Optional[str] = None,
    force_download: bool = False,
    local_files_only: bool = False,
    stall_timeout: float = DEFAULT_STALL_TIMEOUT,
    interval: float = DEFAULT_HEARTBEAT_INTERVAL,
    grace_period: float = DEFAULT_GRACE_PERIOD,
    on_status: Optional[Callable[[str], None]] = None,
    prepare_for_http_fn: Optional[Callable[[str, str], None]] = None,
) -> str:
    """Download a single file with Xet primary and one HTTP retry; return the local path.

    Resolution order:

    1. *cancel_event* already set -> ``RuntimeError("Cancelled")`` before any probe.
    2. ``local_files_only=True`` -> resolved from cache in-process by ``hf_hub_download`` with no
       child (HF offline semantics); HF's own error propagates when the file is not cached.
    3. Unless *force_download*: ``try_to_load_from_cache`` is probed and an existing cached path is
       returned directly. A failing probe is logged at debug level and ignored.
    4. Otherwise the shared ``_download_with_xet_fallback`` loop runs with ``kind="file"``: the Xet
       attempt is retried once over HTTP when it stalls, hits a transient transport error, or its
       download process crashes without a result.

    Args:
        repo_id / filename / subfolder / revision / cache_dir / force_download: forwarded to the
            download as ``params``. *cache_dir* has ``~`` expanded first.
        token: forwarded to the download attempt.
        repo_type: ``None`` is treated as ``"model"``.
        cancel_event, stall_timeout, interval, grace_period, on_status, prepare_for_http_fn:
            forwarded unchanged to ``_download_with_xet_fallback``.

    Raises:
        RuntimeError: ``"Cancelled"`` when *cancel_event* is set; also when the HTTP attempt itself
            ends in a transient transport error or a crash.
        DownloadStallError: when the HTTP attempt stalls (i.e. no transport is left to retry).
        Exception: a deterministic child error is re-raised through ``_raise_child_error``.
    """
    repo_type = repo_type or "model"  # HF treats None as the default model repo.
    # Expand ~ as huggingface_hub does, so the probe and the child resolve to the same location.
    if isinstance(cache_dir, (str, os.PathLike)):
        cache_dir = os.path.expanduser(os.fspath(cache_dir))
    # Honor cancellation before any probe (the short-circuits below bypass the fallback's cancel check).
    if cancel_event is not None and cancel_event.is_set():
        raise RuntimeError("Cancelled")
    # Offline: resolve from cache. HF raises LocalEntryNotFoundError if uncached; let it propagate.
    if local_files_only:
        from huggingface_hub import hf_hub_download

        return hf_hub_download(
            repo_id = repo_id,
            filename = filename,
            subfolder = subfolder,
            token = token,
            repo_type = repo_type,
            revision = revision,
            cache_dir = cache_dir,
            local_files_only = True,
        )
    # Finalized blob already cached: return it with no child (skipped under force_download). A subfolder
    # file is cached under "<subfolder>/<filename>".
    if not force_download:
        try:
            from huggingface_hub import try_to_load_from_cache

            probe_filename = f"{subfolder}/{filename}" if subfolder else filename
            cached = try_to_load_from_cache(
                repo_id, probe_filename, repo_type = repo_type, revision = revision, cache_dir = cache_dir
            )
            # Only a real string path that still exists counts; any other probe result falls through.
            if isinstance(cached, str) and os.path.exists(cached):
                return cached
        except Exception as e:
            # Best-effort probe: a failure here must not block the real download below.
            logger.debug("Cached probe failed for %s/%s: %s", repo_id, filename, e)

    return _download_with_xet_fallback(
        repo_id = repo_id,
        label = f"{repo_id}/{filename}",
        kind = "file",
        params = {
            "repo_id": repo_id,
            "filename": filename,
            "subfolder": subfolder,
            "revision": revision,
            "cache_dir": cache_dir,
            "force_download": force_download,
        },
        token = token,
        repo_type = repo_type,
        cancel_event = cancel_event,
        stall_timeout = stall_timeout,
        interval = interval,
        grace_period = grace_period,
        on_status = on_status,
        prepare_for_http_fn = prepare_for_http_fn,
    )
Optional[Callable[[str], None]] = None, + prepare_for_http_fn: Optional[Callable[[str, str], None]] = None, +) -> str: + """Download a whole repo snapshot with Xet primary and HTTP as a stall-only fallback; return the + local snapshot dir. + + Used by Unsloth's ``from_pretrained`` to warm the cache in a killable child BEFORE the in-process + load (which then hits a warm cache and cannot hang on a native Xet thread). A fully cached repo + short-circuits in-process with no child. ``local_files_only=True`` resolves from cache in-process + (HF offline semantics). ``variant`` forces the child even on a warm canonical cache, since the + canonical gate cannot prove the variant-named weights present. + """ + repo_type = repo_type or "model" # HF treats None as the default model repo. + # Expand ~ as huggingface_hub does, so the probe and the child resolve to the same location. + if isinstance(cache_dir, (str, os.PathLike)): + cache_dir = os.path.expanduser(os.fspath(cache_dir)) + # Honor cancellation before any probe (the short-circuits below bypass the fallback's cancel check). + if cancel_event is not None and cancel_event.is_set(): + raise RuntimeError("Cancelled") + # Offline: resolve from cache. HF raises if uncached; let it propagate. + if local_files_only: + from huggingface_hub import snapshot_download + + return snapshot_download( + repo_id = repo_id, + repo_type = repo_type, + revision = revision, + cache_dir = cache_dir, + allow_patterns = allow_patterns, + ignore_patterns = ignore_patterns, + local_files_only = True, + ) + # Fast path: everything on disk -> resolve in-process (no Xet, no hang). Skipped under force_download. 
+ if not force_download: + try: + from huggingface_hub import snapshot_download + + cached_dir = snapshot_download( + repo_id = repo_id, + repo_type = repo_type, + revision = revision, + cache_dir = cache_dir, + allow_patterns = allow_patterns, + ignore_patterns = ignore_patterns, + local_files_only = True, + ) + # local_files_only returns a snapshot dir whenever refs/ + snapshots/ exist, even + # one left by a prior interrupted / patterned download. Validate the EXACT returned revision + # dir (scoped to this snapshot, not the whole repo, so an unrelated revision mid-download does + # not force a needless re-fetch). + if _cache_can_skip_download( + Path(cached_dir), + repo_type = repo_type, + allow_patterns = allow_patterns, + ignore_patterns = ignore_patterns, + variant = variant, + ): + return cached_dir + logger.debug("Cached snapshot for %s is incomplete; downloading.", repo_id) + except Exception as e: + logger.debug("Snapshot not fully cached for %s (%s); downloading.", repo_id, e) + + return _download_with_xet_fallback( + repo_id = repo_id, + label = repo_id, + kind = "snapshot", + params = { + "repo_id": repo_id, + "revision": revision, + "cache_dir": cache_dir, + "allow_patterns": allow_patterns, + "ignore_patterns": ignore_patterns, + "force_download": force_download, + }, + token = token, + repo_type = repo_type, + cancel_event = cancel_event, + stall_timeout = stall_timeout, + interval = interval, + grace_period = grace_period, + on_status = on_status, + prepare_for_http_fn = prepare_for_http_fn, + variant = variant, + )