Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 68 additions & 15 deletions headroom/proxy/handlers/anthropic.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,48 @@
class AnthropicHandlerMixin:
"""Mixin providing Anthropic API handler methods for HeadroomProxy."""

async def _count_tokens_offloaded(self, model, messages):  # noqa: ANN001, ANN201
    """Count the tokens in ``messages`` without blocking the event loop.

    Tokenizer lookup and the count itself are bundled into one callable and
    handed to ``self._run_compression_in_executor`` with a timeout of
    ``COMPRESSION_TIMEOUT_SECONDS``. First-use tokenizer resolution can be
    slow (HuggingFace backends may download vocab files) and counting a
    full conversation is CPU-bound; running either on the loop froze the
    whole server (GH #1701).

    The call fails open: any ``Exception`` (which includes
    ``asyncio.TimeoutError``) is swallowed, a warning is logged the first
    time a given model degrades, and the count is produced by an
    ``EstimatingTokenCounter`` instead.

    Args:
        model: Model identifier passed to ``get_tokenizer``.
        messages: Messages passed to ``count_messages``.

    Returns:
        Tuple of ``(tokenizer, token_count)``. On the fallback path the
        first element is the ``EstimatingTokenCounter`` that was used.
    """
    from headroom.proxy.helpers import COMPRESSION_TIMEOUT_SECONDS
    from headroom.tokenizers import EstimatingTokenCounter, get_tokenizer

    def _resolve_and_count():  # noqa: ANN202
        # Runs on the executor: resolve first, then count with the result.
        resolved = get_tokenizer(model)
        counted = resolved.count_messages(messages)
        return resolved, counted

    budget = float(COMPRESSION_TIMEOUT_SECONDS)
    try:
        return await self._run_compression_in_executor(
            _resolve_and_count,
            timeout=budget,
        )
    except Exception as e:  # fail open — includes asyncio.TimeoutError
        # Log the downgrade once per model, not per request. The set of
        # already-reported models is created lazily on the instance.
        already_warned = getattr(self, "_token_count_fallback_models", None)
        if already_warned is None:
            already_warned = self._token_count_fallback_models = set()
        if model not in already_warned:
            already_warned.add(model)
            logger.warning(
                f"Token counting for model {model} failed or timed out "
                f"({e.__class__.__name__}); falling back to estimation"
            )
        fallback = EstimatingTokenCounter()
        return fallback, fallback.count_messages(messages)

@staticmethod
def _resolve_ccr_workspace(
request: Any,
Expand Down Expand Up @@ -469,7 +511,6 @@ async def handle_anthropic_messages(
read_request_json_with_bytes,
)
from headroom.proxy.modes import is_cache_mode, is_token_mode
from headroom.tokenizers import get_tokenizer
from headroom.utils import extract_user_query

start_time = time.time()
Expand Down Expand Up @@ -899,9 +940,10 @@ async def _finalize_pre_upstream() -> None:
media_type="application/json",
)

# Count original tokens
tokenizer = get_tokenizer(model)
original_tokens = tokenizer.count_messages(messages)
# Count original tokens off the event loop: first-use tokenizer
# resolution may hit the network (HF download) and counting a full
# conversation is CPU-bound — on-loop it froze the server (#1701).
tokenizer, original_tokens = await self._count_tokens_offloaded(model, messages)

# Enterprise Security: scan request before compression
_security_ctx = None
Expand Down Expand Up @@ -1164,7 +1206,9 @@ def should_skip_ccr_request_compression(
)
if skip_ccr_request_compression:
optimized_messages = messages
optimized_tokens = tokenizer.count_messages(optimized_messages)
_, optimized_tokens = await self._count_tokens_offloaded(
model, optimized_messages
)
else:
# Zone 1: Swap cached compressed versions into working copy
working_messages = comp_cache.apply_cached(messages)
Expand Down Expand Up @@ -2985,7 +3029,6 @@ async def handle_anthropic_batch_create(
from headroom.ccr import CCRToolInjector
from headroom.proxy.helpers import MAX_REQUEST_BODY_SIZE, _read_request_json
from headroom.proxy.modes import is_cache_mode
from headroom.tokenizers import get_tokenizer
from headroom.utils import extract_user_query

start_time = time.time()
Expand Down Expand Up @@ -3093,17 +3136,27 @@ async def handle_anthropic_batch_create(
)
if is_cache_mode(self.config.mode):
optimized_messages = messages
original_tokens = get_tokenizer(model).count_messages(messages)
_, original_tokens = await self._count_tokens_offloaded(model, messages)
optimized_tokens = original_tokens
else:
result = self.anthropic_pipeline.apply(
messages=messages,
model=model,
model_limit=context_limit,
context=extract_user_query(messages),
frozen_message_count=frozen_message_count,
request_id=request_id,
**proxy_pipeline_kwargs(self.config),
from headroom.proxy.helpers import COMPRESSION_TIMEOUT_SECONDS

# Offload off the event loop (#1701): an inline apply()
# blocks every other request for the duration; a timeout
# here is caught below and passes the item through.
result = await self._run_compression_in_executor(
lambda messages=messages, model=model, context_limit=context_limit, frozen_message_count=frozen_message_count: (
self.anthropic_pipeline.apply(
messages=messages,
model=model,
model_limit=context_limit,
context=extract_user_query(messages),
frozen_message_count=frozen_message_count,
request_id=request_id,
**proxy_pipeline_kwargs(self.config),
)
),
timeout=COMPRESSION_TIMEOUT_SECONDS,
)

optimized_messages = result.messages
Expand Down
36 changes: 25 additions & 11 deletions headroom/proxy/handlers/batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from fastapi.responses import Response

from headroom.proxy.auth_mode import classify_client
from headroom.proxy.helpers import extract_tags
from headroom.proxy.helpers import COMPRESSION_TIMEOUT_SECONDS, extract_tags
from headroom.proxy.outcome import RequestOutcome

logger = logging.getLogger("headroom.proxy")
Expand Down Expand Up @@ -161,11 +161,18 @@ async def handle_google_batch_create(
)

# Use OpenAI pipeline (similar message format after conversion)
result = self.openai_pipeline.apply(
messages=messages,
model=model,
model_limit=context_limit,
context=extract_user_query(messages),
# Offload off the event loop (#1701): inline apply() blocks
# every other request; timeouts fall to the except below.
result = await self._run_compression_in_executor(
lambda messages=messages, model=model, context_limit=context_limit: (
self.openai_pipeline.apply(
messages=messages,
model=model,
model_limit=context_limit,
context=extract_user_query(messages),
)
),
timeout=COMPRESSION_TIMEOUT_SECONDS,
)

optimized_messages = result.messages
Expand Down Expand Up @@ -1078,11 +1085,18 @@ async def _compress_batch_jsonl(self, content: str, request_id: str) -> tuple[li
if self.config.optimize:
try:
context_limit = self.openai_provider.get_context_limit(model)
result = self.openai_pipeline.apply(
messages=messages,
model=model,
model_limit=context_limit,
context=extract_user_query(messages),
# Offload off the event loop (#1701); timeouts fall to
# the except below and pass the line through.
result = await self._run_compression_in_executor(
lambda messages=messages, model=model, context_limit=context_limit: (
self.openai_pipeline.apply(
messages=messages,
model=model,
model_limit=context_limit,
context=extract_user_query(messages),
)
),
timeout=COMPRESSION_TIMEOUT_SECONDS,
)
compressed_messages = result.messages
# Use pipeline's token counts for consistency with pipeline logs
Expand Down
68 changes: 66 additions & 2 deletions headroom/tokenizers/huggingface.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
from __future__ import annotations

import logging
import os
import threading
from functools import lru_cache
from typing import Any

Expand Down Expand Up @@ -103,10 +105,32 @@
}


# Bound the first (network) load of a HuggingFace tokenizer. Without a bound,
# huggingface_hub download retries can block for many minutes (GH #1701: 610s
# on a restricted Windows network). 0 disables network loads entirely.
_LOAD_TIMEOUT_ENV = "HEADROOM_HF_TOKENIZER_LOAD_TIMEOUT_SECS"
_LOAD_TIMEOUT_DEFAULT = 10.0


def _load_timeout_secs() -> float:
    """Return the configured tokenizer load timeout in seconds.

    Reads ``HEADROOM_HF_TOKENIZER_LOAD_TIMEOUT_SECS`` from the environment.
    The default (``_LOAD_TIMEOUT_DEFAULT``) is returned when the variable is
    unset, cannot be parsed as a float, or parses to a non-finite value
    (``nan``, ``inf``, ``-inf``).

    Non-finite values are rejected because the result is used as a thread
    join timeout and compared against zero by the caller: ``nan`` slips past
    a ``<= 0`` check and neither ``nan`` nor ``inf`` is a usable timeout.

    Returns:
        The timeout in seconds. Zero or a negative number is passed through
        unchanged so the caller can treat it as "network loading disabled".
    """
    import math  # local import: only needed for the finiteness check

    raw = os.environ.get(_LOAD_TIMEOUT_ENV)
    if raw is None:
        return _LOAD_TIMEOUT_DEFAULT
    try:
        timeout = float(raw)
    except (TypeError, ValueError):
        return _LOAD_TIMEOUT_DEFAULT
    if not math.isfinite(timeout):
        return _LOAD_TIMEOUT_DEFAULT
    return timeout


@lru_cache(maxsize=16)
def _load_tokenizer(tokenizer_name: str):
"""Load and cache HuggingFace tokenizer.

The first attempt is cache-only (``local_files_only=True``) so a warm
HF cache never touches the network. A cache miss falls through to a
network download bounded by ``HEADROOM_HF_TOKENIZER_LOAD_TIMEOUT_SECS``
(default 10s) on a daemon thread — the download itself cannot be
cancelled, but the caller unblocks and falls back to estimation.
Failures are cached by ``lru_cache`` (returns ``None``), so a slow or
offline hub is probed at most once per process per tokenizer.

Args:
tokenizer_name: HuggingFace model/tokenizer name.

Expand All @@ -119,10 +143,50 @@ def _load_tokenizer(tokenizer_name: str):
return AutoTokenizer.from_pretrained(
tokenizer_name,
trust_remote_code=True,
local_files_only=True,
)
except Exception as e:
logger.warning(f"Failed to load tokenizer {tokenizer_name}: {e}")
except Exception:
pass # Not in the local cache — try the network below, bounded.

timeout = _load_timeout_secs()
if timeout <= 0:
logger.warning(
f"Tokenizer {tokenizer_name} not in local HF cache and network "
f"loading is disabled ({_LOAD_TIMEOUT_ENV}=0); using estimation"
)
return None

result: list[Any] = []
error: list[BaseException] = []

def _download() -> None:
try:
result.append(
AutoTokenizer.from_pretrained(
tokenizer_name,
trust_remote_code=True,
)
)
except BaseException as e: # noqa: BLE001 — report any failure to the waiter
error.append(e)

thread = threading.Thread(
target=_download,
name=f"headroom-hf-tokenizer-load-{tokenizer_name}",
daemon=True,
)
thread.start()
thread.join(timeout)
if thread.is_alive():
logger.warning(
f"Timed out loading tokenizer {tokenizer_name} after {timeout}s "
f"(set {_LOAD_TIMEOUT_ENV} to adjust); using estimation"
)
return None
if error:
logger.warning(f"Failed to load tokenizer {tokenizer_name}: {error[0]}")
return None
return result[0] if result else None


def get_tokenizer_name(model: str) -> str:
Expand Down
114 changes: 114 additions & 0 deletions tests/test_huggingface_tokenizer_timeout.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
"""HF tokenizer loading must be bounded (GH #1701): AutoTokenizer.from_pretrained
performs unbounded network downloads/retries; called lazily from the proxy's request
path it blocked the event loop for ~10 minutes and zombified the server. The fix
tries the local HF cache first (local_files_only=True), bounds the network attempt
with HEADROOM_HF_TOKENIZER_LOAD_TIMEOUT_SECS on a daemon thread, and fails open to
estimation — caching the failure so the hub is probed at most once per process.
"""

from __future__ import annotations

import sys
import time
import types
from typing import Any

import pytest

from headroom.tokenizers import huggingface as hf_mod
from headroom.tokenizers.huggingface import HuggingFaceTokenizer, _load_tokenizer


@pytest.fixture(autouse=True)
def _fresh_cache():
    """Clear the ``_load_tokenizer`` cache before and after every test.

    ``_load_tokenizer`` memoises its results (failures included), so without
    this reset one test's outcome would leak into the next.
    """
    reset = _load_tokenizer.cache_clear
    reset()
    yield
    reset()


def _install_fake_transformers(monkeypatch: pytest.MonkeyPatch, from_pretrained) -> None:
    """Register a stub ``transformers`` module in ``sys.modules``.

    The stub exposes a single ``AutoTokenizer`` class whose
    ``from_pretrained`` static method is the supplied callable. The entry is
    installed through ``monkeypatch`` so it is undone at test teardown.
    """
    namespace = {"from_pretrained": staticmethod(from_pretrained)}
    stub_tokenizer_cls = type("AutoTokenizer", (), namespace)

    stub_module = types.ModuleType("transformers")
    stub_module.AutoTokenizer = stub_tokenizer_cls

    monkeypatch.setitem(sys.modules, "transformers", stub_module)


def test_local_cache_tried_before_network(monkeypatch: pytest.MonkeyPatch) -> None:
    """On a cache miss the loader tries cache-only first, then the network."""
    seen_kwargs: list[dict[str, Any]] = []

    def fake_from_pretrained(name: str, **kwargs: Any):
        # Record every attempt; only the cache-only attempt fails.
        seen_kwargs.append(kwargs)
        if not kwargs.get("local_files_only"):
            return "network-tokenizer"
        raise OSError("not in cache")

    _install_fake_transformers(monkeypatch, fake_from_pretrained)
    monkeypatch.setenv("HEADROOM_HF_TOKENIZER_LOAD_TIMEOUT_SECS", "5")

    loaded = _load_tokenizer("some/model")
    assert loaded == "network-tokenizer"

    first_attempt = seen_kwargs[0]
    assert first_attempt.get("local_files_only") is True, "first attempt must be cache-only"

    second_attempt = seen_kwargs[1]
    assert not second_attempt.get("local_files_only")


def test_cache_hit_never_touches_network(monkeypatch: pytest.MonkeyPatch) -> None:
    """When the cache-only attempt succeeds, no second attempt is made."""
    recorded: list[dict[str, Any]] = []

    def fake_from_pretrained(name: str, **kwargs: Any):
        # Always succeeds, so the first (cache-only) call is the only call.
        recorded.append(kwargs)
        return "cached-tokenizer"

    _install_fake_transformers(monkeypatch, fake_from_pretrained)

    loaded = _load_tokenizer("some/model")
    assert loaded == "cached-tokenizer"
    assert len(recorded) == 1

    only_attempt = recorded[0]
    assert only_attempt.get("local_files_only") is True


def test_slow_network_load_times_out_and_fails_open(monkeypatch: pytest.MonkeyPatch) -> None:
    """A hung network load times out, returns ``None``, and is cached.

    The fake download blocks on an event instead of sleeping, so the
    background thread can be released when the test finishes rather than
    lingering for a minute. Caching is verified by counting calls to the
    fake loader instead of timing the second call, which keeps the check
    deterministic on slow or heavily loaded machines.
    """
    import threading  # local import: this module's import block lacks it

    download_started = threading.Event()
    release_download = threading.Event()
    calls: list[dict[str, Any]] = []

    def fake_from_pretrained(name: str, **kwargs: Any):
        calls.append(kwargs)
        if kwargs.get("local_files_only"):
            raise OSError("not in cache")
        download_started.set()
        release_download.wait(60)  # simulates hung huggingface_hub download
        return "never"

    _install_fake_transformers(monkeypatch, fake_from_pretrained)
    monkeypatch.setenv("HEADROOM_HF_TOKENIZER_LOAD_TIMEOUT_SECS", "0.2")

    try:
        start = time.monotonic()
        assert _load_tokenizer("slow/model") is None
        assert time.monotonic() - start < 5, "load must unblock at the timeout, not the download"

        # The network attempt runs on another thread; make sure it has been
        # recorded before snapshotting the call count.
        assert download_started.wait(5), "network attempt never started"
        attempts_after_first_load = len(calls)

        # Failure is cached (lru_cache) — the second call must not re-probe the hub.
        assert _load_tokenizer("slow/model") is None
        assert len(calls) == attempts_after_first_load, "cached failure must not re-probe the hub"
    finally:
        # Unblock the simulated download so its thread exits promptly.
        release_download.set()


def test_timeout_zero_disables_network_loading(monkeypatch: pytest.MonkeyPatch) -> None:
    """With the timeout set to ``0`` a cache miss must not reach the network."""

    def fake_from_pretrained(name: str, **kwargs: Any):
        # Any call that is not cache-only counts as a forbidden network load.
        if not kwargs.get("local_files_only"):
            raise AssertionError("network load attempted despite timeout=0")
        raise OSError("not in cache")

    _install_fake_transformers(monkeypatch, fake_from_pretrained)
    monkeypatch.setenv("HEADROOM_HF_TOKENIZER_LOAD_TIMEOUT_SECS", "0")

    loaded = _load_tokenizer("offline/model")
    assert loaded is None


def test_count_messages_fails_open_to_estimation(monkeypatch: pytest.MonkeyPatch) -> None:
    """``count_messages`` still yields a positive count when loading fails."""

    def fake_from_pretrained(name: str, **kwargs: Any):
        # Both the cache-only and the network attempt fail.
        raise OSError("unavailable")

    _install_fake_transformers(monkeypatch, fake_from_pretrained)
    monkeypatch.setenv("HEADROOM_HF_TOKENIZER_LOAD_TIMEOUT_SECS", "0.2")

    counter = HuggingFaceTokenizer("deepseek-chat")
    conversation = [{"role": "user", "content": "hello world" * 50}]
    tokens = counter.count_messages(conversation)
    assert tokens > 0  # estimation fallback, no exception, no hang


def test_invalid_timeout_env_falls_back_to_default(monkeypatch: pytest.MonkeyPatch) -> None:
    """An unparseable timeout value resolves to the module default."""
    monkeypatch.setenv("HEADROOM_HF_TOKENIZER_LOAD_TIMEOUT_SECS", "not-a-number")

    resolved = hf_mod._load_timeout_secs()
    expected = hf_mod._LOAD_TIMEOUT_DEFAULT
    assert resolved == expected
Loading
Loading