diff --git a/headroom/proxy/handlers/anthropic.py b/headroom/proxy/handlers/anthropic.py index 116b6a01e..1e98c0552 100644 --- a/headroom/proxy/handlers/anthropic.py +++ b/headroom/proxy/handlers/anthropic.py @@ -41,6 +41,48 @@ class AnthropicHandlerMixin: """Mixin providing Anthropic API handler methods for HeadroomProxy.""" + async def _count_tokens_offloaded(self, model, messages): # noqa: ANN001, ANN201 + """Resolve a tokenizer and count messages off the event loop. + + Tokenizer resolution can be expensive on first use (HuggingFace + backends may download vocab files) and counting a full Claude Code + conversation is CPU-bound, so both run on the compression executor + bounded by ``COMPRESSION_TIMEOUT_SECONDS`` (GH #1701: an unbounded + on-loop load froze the whole server). On timeout or error this + fails open to character-based estimation. + + Returns: + Tuple of ``(tokenizer, token_count)``. The tokenizer is fully + initialized, so later ``count_messages`` calls on it are pure + CPU work. + """ + from headroom.proxy.helpers import COMPRESSION_TIMEOUT_SECONDS + from headroom.tokenizers import EstimatingTokenCounter, get_tokenizer + + def _resolve_and_count(): # noqa: ANN202 + tokenizer = get_tokenizer(model) + return tokenizer, tokenizer.count_messages(messages) + + try: + return await self._run_compression_in_executor( + _resolve_and_count, + timeout=float(COMPRESSION_TIMEOUT_SECONDS), + ) + except Exception as e: # fail open — includes asyncio.TimeoutError + # Log the downgrade once per model, not per request. + fallback_models = getattr(self, "_token_count_fallback_models", None) + if fallback_models is None: + fallback_models = set() + self._token_count_fallback_models = fallback_models + if model not in fallback_models: + fallback_models.add(model) + logger.warning( + f"Token counting for model {model} failed or timed out " + f"({e.__class__.__name__}); falling back to estimation" + ) + estimator = EstimatingTokenCounter() + return estimator, estimator.count_messages(messages) + @staticmethod def _resolve_ccr_workspace( request: Any, @@ -469,7 +511,6 @@ async def handle_anthropic_messages( read_request_json_with_bytes, ) from headroom.proxy.modes import is_cache_mode, is_token_mode - from headroom.tokenizers import get_tokenizer from headroom.utils import extract_user_query start_time = time.time() @@ -899,9 +940,10 @@ async def _finalize_pre_upstream() -> None: media_type="application/json", ) - # Count original tokens - tokenizer = get_tokenizer(model) - original_tokens = tokenizer.count_messages(messages) + # Count original tokens off the event loop: first-use tokenizer + # resolution may hit the network (HF download) and counting a full + # conversation is CPU-bound — on-loop it froze the server (#1701). + tokenizer, original_tokens = await self._count_tokens_offloaded(model, messages) # Enterprise Security: scan request before compression _security_ctx = None @@ -1164,7 +1206,9 @@ def should_skip_ccr_request_compression( ) if skip_ccr_request_compression: optimized_messages = messages - optimized_tokens = tokenizer.count_messages(optimized_messages) + _, optimized_tokens = await self._count_tokens_offloaded( + model, optimized_messages + ) else: # Zone 1: Swap cached compressed versions into working copy working_messages = comp_cache.apply_cached(messages) @@ -2985,7 +3029,6 @@ async def handle_anthropic_batch_create( from headroom.ccr import CCRToolInjector from headroom.proxy.helpers import MAX_REQUEST_BODY_SIZE, _read_request_json from headroom.proxy.modes import is_cache_mode - from headroom.tokenizers import get_tokenizer from headroom.utils import extract_user_query start_time = time.time() @@ -3093,17 +3136,27 @@ async def handle_anthropic_batch_create( ) if is_cache_mode(self.config.mode): optimized_messages = messages - original_tokens = get_tokenizer(model).count_messages(messages) + _, original_tokens = await self._count_tokens_offloaded(model, messages) optimized_tokens = original_tokens else: - result = self.anthropic_pipeline.apply( - messages=messages, - model=model, - model_limit=context_limit, - context=extract_user_query(messages), - frozen_message_count=frozen_message_count, - request_id=request_id, - **proxy_pipeline_kwargs(self.config), + from headroom.proxy.helpers import COMPRESSION_TIMEOUT_SECONDS + + # Offload off the event loop (#1701): an inline apply() + # blocks every other request for the duration; a timeout + # here is caught below and passes the item through. + result = await self._run_compression_in_executor( + lambda messages=messages, model=model, context_limit=context_limit, frozen_message_count=frozen_message_count: ( + self.anthropic_pipeline.apply( + messages=messages, + model=model, + model_limit=context_limit, + context=extract_user_query(messages), + frozen_message_count=frozen_message_count, + request_id=request_id, + **proxy_pipeline_kwargs(self.config), + ) + ), + timeout=COMPRESSION_TIMEOUT_SECONDS, ) optimized_messages = result.messages diff --git a/headroom/proxy/handlers/batch.py b/headroom/proxy/handlers/batch.py index febd8cba9..1888a2efa 100644 --- a/headroom/proxy/handlers/batch.py +++ b/headroom/proxy/handlers/batch.py @@ -15,7 +15,7 @@ from fastapi.responses import Response from headroom.proxy.auth_mode import classify_client -from headroom.proxy.helpers import extract_tags +from headroom.proxy.helpers import COMPRESSION_TIMEOUT_SECONDS, extract_tags from headroom.proxy.outcome import RequestOutcome logger = logging.getLogger("headroom.proxy") @@ -161,11 +161,18 @@ async def handle_google_batch_create( ) # Use OpenAI pipeline (similar message format after conversion) - result = self.openai_pipeline.apply( - messages=messages, - model=model, - model_limit=context_limit, - context=extract_user_query(messages), + # Offload off the event loop (#1701): inline apply() blocks + # every other request; timeouts fall to the except below. + result = await self._run_compression_in_executor( + lambda messages=messages, model=model, context_limit=context_limit: ( + self.openai_pipeline.apply( + messages=messages, + model=model, + model_limit=context_limit, + context=extract_user_query(messages), + ) + ), + timeout=COMPRESSION_TIMEOUT_SECONDS, ) optimized_messages = result.messages @@ -1078,11 +1085,18 @@ async def _compress_batch_jsonl(self, content: str, request_id: str) -> tuple[li if self.config.optimize: try: context_limit = self.openai_provider.get_context_limit(model) - result = self.openai_pipeline.apply( - messages=messages, - model=model, - model_limit=context_limit, - context=extract_user_query(messages), + # Offload off the event loop (#1701); timeouts fall to + # the except below and pass the line through. + result = await self._run_compression_in_executor( + lambda messages=messages, model=model, context_limit=context_limit: ( + self.openai_pipeline.apply( + messages=messages, + model=model, + model_limit=context_limit, + context=extract_user_query(messages), + ) + ), + timeout=COMPRESSION_TIMEOUT_SECONDS, ) compressed_messages = result.messages # Use pipeline's token counts for consistency with pipeline logs diff --git a/headroom/tokenizers/huggingface.py b/headroom/tokenizers/huggingface.py index 62e6e9b56..718ee940a 100644 --- a/headroom/tokenizers/huggingface.py +++ b/headroom/tokenizers/huggingface.py @@ -7,6 +7,8 @@ from __future__ import annotations import logging +import os +import threading from functools import lru_cache from typing import Any @@ -103,10 +105,32 @@ } +# Bound the first (network) load of a HuggingFace tokenizer. Without a bound, +# huggingface_hub download retries can block for many minutes (GH #1701: 610s +# on a restricted Windows network). 0 disables network loads entirely. +_LOAD_TIMEOUT_ENV = "HEADROOM_HF_TOKENIZER_LOAD_TIMEOUT_SECS" +_LOAD_TIMEOUT_DEFAULT = 10.0 + + +def _load_timeout_secs() -> float: + try: + return float(os.environ.get(_LOAD_TIMEOUT_ENV, _LOAD_TIMEOUT_DEFAULT)) + except (TypeError, ValueError): + return _LOAD_TIMEOUT_DEFAULT + + @lru_cache(maxsize=16) def _load_tokenizer(tokenizer_name: str): """Load and cache HuggingFace tokenizer. + The first attempt is cache-only (``local_files_only=True``) so a warm + HF cache never touches the network. A cache miss falls through to a + network download bounded by ``HEADROOM_HF_TOKENIZER_LOAD_TIMEOUT_SECS`` + (default 10s) on a daemon thread — the download itself cannot be + cancelled, but the caller unblocks and falls back to estimation. + Failures are cached by ``lru_cache`` (returns ``None``), so a slow or + offline hub is probed at most once per process per tokenizer. + Args: tokenizer_name: HuggingFace model/tokenizer name. @@ -119,10 +143,50 @@ def _load_tokenizer(tokenizer_name: str): return AutoTokenizer.from_pretrained( tokenizer_name, trust_remote_code=True, + local_files_only=True, ) - except Exception as e: - logger.warning(f"Failed to load tokenizer {tokenizer_name}: {e}") + except Exception: + pass # Not in the local cache — try the network below, bounded. + + timeout = _load_timeout_secs() + if timeout <= 0: + logger.warning( + f"Tokenizer {tokenizer_name} not in local HF cache and network " + f"loading is disabled ({_LOAD_TIMEOUT_ENV}=0); using estimation" + ) + return None + + result: list[Any] = [] + error: list[BaseException] = [] + + def _download() -> None: + try: + result.append( + AutoTokenizer.from_pretrained( + tokenizer_name, + trust_remote_code=True, + ) + ) + except BaseException as e: # noqa: BLE001 — report any failure to the waiter + error.append(e) + + thread = threading.Thread( + target=_download, + name=f"headroom-hf-tokenizer-load-{tokenizer_name}", + daemon=True, + ) + thread.start() + thread.join(timeout) + if thread.is_alive(): + logger.warning( + f"Timed out loading tokenizer {tokenizer_name} after {timeout}s " + f"(set {_LOAD_TIMEOUT_ENV} to adjust); using estimation" + ) + return None + if error: + logger.warning(f"Failed to load tokenizer {tokenizer_name}: {error[0]}") return None + return result[0] if result else None def get_tokenizer_name(model: str) -> str: diff --git a/tests/test_huggingface_tokenizer_timeout.py b/tests/test_huggingface_tokenizer_timeout.py new file mode 100644 index 000000000..8e9243ea7 --- /dev/null +++ b/tests/test_huggingface_tokenizer_timeout.py @@ -0,0 +1,114 @@ +"""HF tokenizer loading must be bounded (GH #1701): AutoTokenizer.from_pretrained +performs unbounded network downloads/retries; called lazily from the proxy's request +path it blocked the event loop for ~10 minutes and zombified the server. The fix +tries the local HF cache first (local_files_only=True), bounds the network attempt +with HEADROOM_HF_TOKENIZER_LOAD_TIMEOUT_SECS on a daemon thread, and fails open to +estimation — caching the failure so the hub is probed at most once per process. +""" + +from __future__ import annotations + +import sys +import time +import types +from typing import Any + +import pytest + +from headroom.tokenizers import huggingface as hf_mod +from headroom.tokenizers.huggingface import HuggingFaceTokenizer, _load_tokenizer + + +@pytest.fixture(autouse=True) +def _fresh_cache(): + _load_tokenizer.cache_clear() + yield + _load_tokenizer.cache_clear() + + +def _install_fake_transformers(monkeypatch: pytest.MonkeyPatch, from_pretrained) -> None: + fake = types.ModuleType("transformers") + fake.AutoTokenizer = type( + "AutoTokenizer", (), {"from_pretrained": staticmethod(from_pretrained)} + ) + monkeypatch.setitem(sys.modules, "transformers", fake) + + +def test_local_cache_tried_before_network(monkeypatch: pytest.MonkeyPatch) -> None: + calls: list[dict[str, Any]] = [] + + def fake_from_pretrained(name: str, **kwargs: Any): + calls.append(kwargs) + if kwargs.get("local_files_only"): + raise OSError("not in cache") + return "network-tokenizer" + + _install_fake_transformers(monkeypatch, fake_from_pretrained) + monkeypatch.setenv("HEADROOM_HF_TOKENIZER_LOAD_TIMEOUT_SECS", "5") + + assert _load_tokenizer("some/model") == "network-tokenizer" + assert calls[0].get("local_files_only") is True, "first attempt must be cache-only" + assert not calls[1].get("local_files_only") + + +def test_cache_hit_never_touches_network(monkeypatch: pytest.MonkeyPatch) -> None: + calls: list[dict[str, Any]] = [] + + def fake_from_pretrained(name: str, **kwargs: Any): + calls.append(kwargs) + return "cached-tokenizer" + + _install_fake_transformers(monkeypatch, fake_from_pretrained) + + assert _load_tokenizer("some/model") == "cached-tokenizer" + assert len(calls) == 1 + assert calls[0].get("local_files_only") is True + + +def test_slow_network_load_times_out_and_fails_open(monkeypatch: pytest.MonkeyPatch) -> None: + def fake_from_pretrained(name: str, **kwargs: Any): + if kwargs.get("local_files_only"): + raise OSError("not in cache") + time.sleep(60) # simulates hung huggingface_hub download + return "never" + + _install_fake_transformers(monkeypatch, fake_from_pretrained) + monkeypatch.setenv("HEADROOM_HF_TOKENIZER_LOAD_TIMEOUT_SECS", "0.2") + + start = time.monotonic() + assert _load_tokenizer("slow/model") is None + assert time.monotonic() - start < 5, "load must unblock at the timeout, not the download" + + # Failure is cached (lru_cache) — the second call must not re-probe the hub. + start = time.monotonic() + assert _load_tokenizer("slow/model") is None + assert time.monotonic() - start < 0.05 + + +def test_timeout_zero_disables_network_loading(monkeypatch: pytest.MonkeyPatch) -> None: + def fake_from_pretrained(name: str, **kwargs: Any): + if kwargs.get("local_files_only"): + raise OSError("not in cache") + raise AssertionError("network load attempted despite timeout=0") + + _install_fake_transformers(monkeypatch, fake_from_pretrained) + monkeypatch.setenv("HEADROOM_HF_TOKENIZER_LOAD_TIMEOUT_SECS", "0") + + assert _load_tokenizer("offline/model") is None + + +def test_count_messages_fails_open_to_estimation(monkeypatch: pytest.MonkeyPatch) -> None: + def fake_from_pretrained(name: str, **kwargs: Any): + raise OSError("unavailable") + + _install_fake_transformers(monkeypatch, fake_from_pretrained) + monkeypatch.setenv("HEADROOM_HF_TOKENIZER_LOAD_TIMEOUT_SECS", "0.2") + + counter = HuggingFaceTokenizer("deepseek-chat") + tokens = counter.count_messages([{"role": "user", "content": "hello world" * 50}]) + assert tokens > 0 # estimation fallback, no exception, no hang + + +def test_invalid_timeout_env_falls_back_to_default(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setenv("HEADROOM_HF_TOKENIZER_LOAD_TIMEOUT_SECS", "not-a-number") + assert hf_mod._load_timeout_secs() == hf_mod._LOAD_TIMEOUT_DEFAULT diff --git a/tests/test_proxy_handlers_batch.py b/tests/test_proxy_handlers_batch.py index bfe6788d6..43812848d 100644 --- a/tests/test_proxy_handlers_batch.py +++ b/tests/test_proxy_handlers_batch.py @@ -113,6 +113,12 @@ def _extract_tags(self, headers: dict) -> dict[str, str]: async def handle_passthrough(self, request, base_url): # noqa: ANN001, ANN201 return {"request": request, "base_url": base_url} + async def _run_compression_in_executor(self, fn, *, timeout): # noqa: ANN001, ANN201 + # Mirror of HeadroomProxy._run_compression_in_executor: batch handlers + # offload pipeline.apply() off the event loop (#1701). Inline is fine + # for tests — only the call contract matters here. + return fn() + async def _retry_request(self, method, url, headers, body, **kwargs): # noqa: ANN001, ANN201 return self._retry_response diff --git a/tests/test_tokenizer_count_offload.py b/tests/test_tokenizer_count_offload.py new file mode 100644 index 000000000..239d5e1ef --- /dev/null +++ b/tests/test_tokenizer_count_offload.py @@ -0,0 +1,126 @@ +"""Token counting must run off the event loop (GH #1701): the Anthropic messages +handler resolved the tokenizer and counted the conversation inline in the async +handler. For HF-backed models (e.g. deepseek-*) first use triggers an unbounded +network download, freezing the whole server (610s request, then /livez, /readyz +and /health hang until kill). The fix routes resolution + counting through +HeadroomProxy._count_tokens_offloaded (compression executor, bounded by +COMPRESSION_TIMEOUT_SECONDS, fail-open to estimation), and offloads the inline +batch pipeline.apply() calls the same way. +""" + +from __future__ import annotations + +import asyncio +import inspect +import threading +import time + +from headroom.proxy.handlers.anthropic import AnthropicHandlerMixin +from headroom.proxy.handlers.batch import BatchHandlerMixin +from headroom.proxy.server import ProxyConfig, create_app +from headroom.tokenizers import EstimatingTokenCounter + + +def _make_proxy(): # noqa: ANN202 — returns the internal HeadroomProxy + app = create_app( + ProxyConfig( + optimize=True, + cache_enabled=False, + rate_limit_enabled=False, + cost_tracking_enabled=False, + ) + ) + return app.state.proxy + + +def test_handlers_offload_token_counting_and_batch_apply() -> None: + """Wiring guard: the request paths must use the offloaded helpers, not inline + get_tokenizer/count_messages or pipeline.apply on the event loop.""" + fn = AnthropicHandlerMixin.handle_anthropic_messages + assert inspect.iscoroutinefunction(fn) + src = inspect.getsource(fn) + assert "_count_tokens_offloaded(" in src, "token counting not offloaded" + assert "tokenizer = get_tokenizer(" not in src, "tokenizer resolved inline on the loop" + + for mixin, method in ( + (AnthropicHandlerMixin, "handle_anthropic_batch_create"), + (BatchHandlerMixin, "handle_google_batch_create"), + (BatchHandlerMixin, "_compress_batch_jsonl"), + ): + fn = getattr(mixin, method) + assert inspect.iscoroutinefunction(fn), f"{method} must be async" + src = inspect.getsource(fn) + if "pipeline.apply(" in src: + assert "_run_compression_in_executor(" in src, f"{method}: apply() not offloaded" + assert "COMPRESSION_TIMEOUT_SECONDS" in src, f"{method}: offload missing timeout" + + helper_src = inspect.getsource(AnthropicHandlerMixin._count_tokens_offloaded) + assert "COMPRESSION_TIMEOUT_SECONDS" in helper_src + assert "EstimatingTokenCounter" in helper_src, "helper must fail open to estimation" + + +async def test_count_tokens_offloaded_runs_on_worker_thread(monkeypatch) -> None: # noqa: ANN001 + proxy = _make_proxy() + loop_thread = threading.current_thread().name + seen: dict[str, str] = {} + + class _SpyTokenizer(EstimatingTokenCounter): + def count_messages(self, messages): # noqa: ANN001, ANN201 + seen["thread"] = threading.current_thread().name + return super().count_messages(messages) + + monkeypatch.setattr("headroom.tokenizers.get_tokenizer", lambda *a, **k: _SpyTokenizer()) + + _, tokens = await proxy._count_tokens_offloaded("gpt-4", [{"role": "user", "content": "hi"}]) + + assert tokens > 0 + assert seen["thread"].startswith("headroom-compress") + assert seen["thread"] != loop_thread + + +async def test_count_tokens_offloaded_keeps_loop_responsive(monkeypatch) -> None: # noqa: ANN001 + """A slow tokenizer (stand-in for an HF network load) must not starve the loop — + the pre-fix inline call yielded ~0 ticks here.""" + proxy = _make_proxy() + ticks = 0 + + async def _ticker() -> None: + nonlocal ticks + while True: + await asyncio.sleep(0.01) + ticks += 1 + + class _SlowTokenizer(EstimatingTokenCounter): + def count_messages(self, messages): # noqa: ANN001, ANN201 + time.sleep(0.3) + return super().count_messages(messages) + + monkeypatch.setattr("headroom.tokenizers.get_tokenizer", lambda *a, **k: _SlowTokenizer()) + + tick_task = asyncio.create_task(_ticker()) + try: + _, tokens = await proxy._count_tokens_offloaded("m", [{"role": "user", "content": "hi"}]) + finally: + tick_task.cancel() + + assert tokens > 0 + assert ticks >= 5 + + +async def test_count_tokens_offloaded_fails_open(monkeypatch) -> None: # noqa: ANN001 + """Resolution errors and timeouts downgrade to estimation instead of raising.""" + proxy = _make_proxy() + + def _boom(*a, **k): # noqa: ANN002, ANN003, ANN202 + raise RuntimeError("tokenizer backend exploded") + + monkeypatch.setattr("headroom.tokenizers.get_tokenizer", _boom) + + tokenizer, tokens = await proxy._count_tokens_offloaded( + "deepseek-chat", [{"role": "user", "content": "hello world"}] + ) + + assert isinstance(tokenizer, EstimatingTokenCounter) + assert tokens > 0 + # Logged-once bookkeeping records the downgraded model. + assert "deepseek-chat" in proxy._token_count_fallback_models