Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
100 changes: 96 additions & 4 deletions headroom/proxy/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,10 +34,11 @@
import sys
import threading
import time
from collections.abc import Callable
from dataclasses import fields, is_dataclass, replace
from datetime import datetime, timezone
from pathlib import Path
from typing import TYPE_CHECKING, Any, Literal, cast
from typing import TYPE_CHECKING, Any, Literal, TypedDict, cast

if TYPE_CHECKING:
from ..backends.base import Backend
Expand Down Expand Up @@ -446,6 +447,20 @@ def _agent_row(agent_key: str, label: str, source: str) -> dict[str, Any]:
)
logger = logging.getLogger("headroom.proxy")

LoopExceptionHandler = Callable[[asyncio.AbstractEventLoop, dict[str, Any]], object]


class LoopFailureDetails(TypedDict):
message: Any | None
exception: str | None


class LoopHealthState(TypedDict):
status: str
known_failures: int
last_known_failure: LoopFailureDetails | None


_MULTI_WORKER_CONFIG_ENV = "HEADROOM_PROXY_CONFIG_JSON"

# Env var that opts out of the Rust core deployment smoke test (Hotfix-A0).
Expand Down Expand Up @@ -1922,6 +1937,20 @@ def _request_is_loopback(request: Request) -> bool:
return is_loopback_host(client_host) and is_loopback_host_header(host_header)


def _is_known_websocket_callback_failure(context: dict[str, Any]) -> bool:
"""Return True iff this exact websockets callback failure shape is observed."""
if (
context.get("message")
!= "Exception in callback Connection.connection_lost(ConnectionResetError())"
):
return False
exception = context.get("exception")
return (
isinstance(exception, AttributeError)
and str(exception) == "'ClientConnection' object has no attribute 'recv_messages'"
)


def create_app(config: ProxyConfig | None = None) -> FastAPI:
"""Create FastAPI application."""
if not FASTAPI_AVAILABLE:
Expand Down Expand Up @@ -2051,6 +2080,7 @@ async def lifespan(app: FastAPI): # type: ignore[no-untyped-def]

try:
try:
previous_handler = _install_loop_exception_handler()
# Startup
await proxy.startup()
if config.periodic_toin_stats_enabled:
Expand Down Expand Up @@ -2079,6 +2109,17 @@ async def lifespan(app: FastAPI): # type: ignore[no-untyped-def]
app.state.startup_error = str(exc)
raise
finally:
loop: asyncio.AbstractEventLoop | None
previous: LoopExceptionHandler | None
try:
loop = asyncio.get_running_loop()
previous = previous_handler
except RuntimeError:
loop = None
previous = app.state.previous_loop_exception_handler
if loop is not None:
loop.set_exception_handler(previous)

app.state.ready = False
# Shutdown
if _cc_reconciler is not None:
Expand All @@ -2104,10 +2145,18 @@ async def lifespan(app: FastAPI): # type: ignore[no-untyped-def]
version=__version__,
lifespan=lifespan,
)
loop_health_state: LoopHealthState = {
"status": "healthy",
"known_failures": 0,
"last_known_failure": None,
}
app.state.proxy = proxy
app.state.started_at = None
app.state.ready = False
app.state.startup_error = None
app.state.loop_callback_health = loop_health_state
app.state.loop_exception_handler = None
app.state.previous_loop_exception_handler = None
# Set by the lifespan startup smoke test (`_check_rust_core`). Default
# "missing" means lifespan hasn't run yet — anything reading /health
# before startup completes (rare; lifespan runs before the first
Expand Down Expand Up @@ -2240,6 +2289,46 @@ def _runtime_payload() -> dict[str, Any]:
},
}

def _loop_callback_payload() -> LoopHealthState:
return {
"status": loop_health_state["status"],
"known_failures": loop_health_state["known_failures"],
"last_known_failure": loop_health_state["last_known_failure"],
}

def _install_loop_exception_handler() -> LoopExceptionHandler | None:
try:
loop = asyncio.get_running_loop()
except RuntimeError:
return None

previous_handler = loop.get_exception_handler()

def _loop_exception_handler(
_loop: asyncio.AbstractEventLoop, context: dict[str, Any]
) -> None:
if _is_known_websocket_callback_failure(context):
loop_health_state["status"] = "unhealthy"
loop_health_state["known_failures"] += 1
loop_health_state["last_known_failure"] = {
"message": context.get("message"),
"exception": str(context.get("exception"))
if context.get("exception")
else None,
}
return

delegate_handler = app.state.previous_loop_exception_handler
if delegate_handler is not None:
delegate_handler(_loop, context)
return
_loop.default_exception_handler(context)

loop.set_exception_handler(_loop_exception_handler)
app.state.loop_exception_handler = _loop_exception_handler
app.state.previous_loop_exception_handler = previous_handler
return previous_handler

def _health_payload(*, include_config: bool) -> dict[str, Any]:
checks = _health_checks()
ready = all(check["ready"] for check in checks.values())
Expand Down Expand Up @@ -2630,15 +2719,18 @@ async def _security_gate(request, call_next):
# Health & Metrics
@app.get("/livez")
async def livez():
callback_state = _loop_callback_payload()
healthy = callback_state["status"] == "healthy"
return JSONResponse(
status_code=200,
status_code=200 if healthy else 503,
content={
"service": "headroom-proxy",
"status": "healthy",
"alive": True,
"status": "healthy" if healthy else "unhealthy",
"alive": healthy,
"version": __version__,
"timestamp": _iso_utc_now(),
"uptime_seconds": _uptime_seconds(),
"loop_health": callback_state,
},
)

Expand Down
96 changes: 96 additions & 0 deletions tests/test_proxy_loop_exception_health.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
"""Unit: event-loop callback handling for Codex WS disconnect regressions."""

from __future__ import annotations

import asyncio
from unittest.mock import MagicMock

import pytest
from fastapi.testclient import TestClient

from headroom.proxy.server import ProxyConfig, create_app

pytest.importorskip("fastapi")
pytest.importorskip("httpx")


def _known_loop_callback_context() -> dict[str, object]:
return {
"message": "Exception in callback Connection.connection_lost(ConnectionResetError())",
"exception": AttributeError("'ClientConnection' object has no attribute 'recv_messages'"),
}


def _make_client(app):
return TestClient(app, base_url="http://127.0.0.1", client=("127.0.0.1", 12345))


def test_livez_reports_known_websockets_callback_degradation():
config = ProxyConfig(
optimize=False,
cache_enabled=False,
rate_limit_enabled=False,
cost_tracking_enabled=False,
)
app = create_app(config)

with _make_client(app) as client:
before = client.get("/livez")
assert before.status_code == 200
assert before.json()["status"] == "healthy"
assert before.json()["alive"] is True

assert app.state.loop_exception_handler is not None
mock_loop = MagicMock(spec=asyncio.AbstractEventLoop)
app.state.loop_exception_handler(mock_loop, _known_loop_callback_context())

after = client.get("/livez")
assert after.status_code == 503
payload = after.json()
assert payload["status"] == "unhealthy"
assert payload["alive"] is False
loop_health = payload["loop_health"]
assert loop_health["status"] == "unhealthy"
assert loop_health["known_failures"] == 1
assert (
loop_health["last_known_failure"]["exception"]
== "'ClientConnection' object has no attribute 'recv_messages'"
)


def test_unrelated_loop_callback_is_delegated_to_previous_handler():
delegate_calls: list[dict[str, object]] = []
config = ProxyConfig(
optimize=False,
cache_enabled=False,
rate_limit_enabled=False,
cost_tracking_enabled=False,
)
app = create_app(config)

with _make_client(app) as client:
client.get("/livez")
assert app.state.loop_exception_handler is not None

def _previous(_loop: object, context: dict[str, object]) -> None:
delegate_calls.append(dict(context))

app.state.previous_loop_exception_handler = _previous

mock_loop = MagicMock(spec=asyncio.AbstractEventLoop)
app.state.loop_exception_handler(
mock_loop,
{
"message": "random callback failed",
"exception": RuntimeError("not known failure"),
},
)

assert len(delegate_calls) == 1
assert delegate_calls[0]["message"] == "random callback failed"
assert app.state.loop_callback_health["status"] == "healthy"
assert app.state.loop_callback_health["known_failures"] == 0

health = client.get("/livez").json()
assert health["status"] == "healthy"
assert health["alive"] is True
Loading