Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions headroom/proxy/handlers/anthropic.py
Original file line number Diff line number Diff line change
Expand Up @@ -2379,6 +2379,10 @@ async def api_call_fn(
if _auth_header.startswith("Bearer ") and not _auth_header.startswith(
"Bearer sk-ant-api"
):
from headroom.proxy.savings_tracker import (
_estimate_cache_savings_usd,
_estimate_compression_savings_usd,
)
from headroom.subscription.tracker import (
get_subscription_tracker as _get_sub_tracker,
)
Expand All @@ -2389,6 +2393,10 @@ async def api_call_fn(
tokens_submitted=optimized_tokens,
tokens_saved_compression=tokens_saved,
tokens_saved_cache_reads=cr_tokens,
compression_savings_usd=_estimate_compression_savings_usd(
model, tokens_saved
),
cache_savings_usd=_estimate_cache_savings_usd(model, cr_tokens),
)

# The pre-refactor PERF emit (above) read raw usage
Expand Down
22 changes: 22 additions & 0 deletions headroom/proxy/savings_tracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,28 @@ def _estimate_compression_savings_usd(model: str, tokens_saved: int) -> float:
return 0.0


def _estimate_cache_savings_usd(model: str, cache_read_tokens: int) -> float:
    """Return the estimated USD saved by reading input tokens from cache.

    The per-token saving is the model's regular input price minus its
    cache-read price, both taken from ``litellm.model_cost``. When the
    pricing entry has no ``cache_read_input_token_cost`` key, the regular
    input price is used in its place, which yields no saving.

    Args:
        model: Model name; passed through ``_resolve_litellm_model`` before
            the pricing lookup.
        cache_read_tokens: Number of input tokens served from cache.

    Returns:
        ``cache_read_tokens`` multiplied by the per-token saving, or ``0.0``
        when the token count is not positive, litellm is unavailable, the
        input price is missing or zero, the saving is not positive, or the
        lookup raises.
    """
    pricing_source = _get_litellm_module()
    if cache_read_tokens <= 0:
        return 0.0
    if pricing_source is None:
        return 0.0

    try:
        model_key = _resolve_litellm_model(model)
        entry = pricing_source.model_cost.get(model_key, {})

        base_price = entry.get("input_cost_per_token")
        if not base_price:
            return 0.0

        cached_price = entry.get("cache_read_input_token_cost", base_price)
        per_token_delta = float(base_price) - float(cached_price)
        return 0.0 if per_token_delta <= 0 else float(cache_read_tokens) * per_token_delta
    except Exception:
        # Best-effort estimate: any lookup or conversion failure counts as no saving.
        return 0.0


def _estimate_input_cost_usd(
model: str,
input_tokens: int,
Expand Down
35 changes: 35 additions & 0 deletions tests/test_proxy_cache_savings_usd.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
from __future__ import annotations

from types import SimpleNamespace

import pytest

from headroom.proxy import savings_tracker as savings_tracker_module


def test_estimate_cache_savings_usd_uses_discounted_cache_read_price(monkeypatch) -> None:
    """Savings are the token count times (input price - cache-read price)."""
    pricing_table = {
        "gpt-4o": {
            "input_cost_per_token": 0.002,
            "cache_read_input_token_cost": 0.001,
        }
    }
    stub_litellm = SimpleNamespace(model_cost=pricing_table)
    monkeypatch.setattr(savings_tracker_module, "LITELLM_AVAILABLE", True)
    monkeypatch.setattr(savings_tracker_module, "litellm", stub_litellm)

    savings = savings_tracker_module._estimate_cache_savings_usd("gpt-4o", 100)

    # 100 tokens * (0.002 - 0.001) USD per token
    assert savings == pytest.approx(0.1)


def test_estimate_cache_savings_usd_handles_missing_pricing(monkeypatch) -> None:
    """The estimate is 0.0 with an empty pricing table and with litellm flagged off."""
    stub_litellm = SimpleNamespace(model_cost={})
    monkeypatch.setattr(savings_tracker_module, "LITELLM_AVAILABLE", True)
    monkeypatch.setattr(savings_tracker_module, "litellm", stub_litellm)

    # No pricing entry for the model.
    with_empty_pricing = savings_tracker_module._estimate_cache_savings_usd("gpt-4o", 100)
    assert with_empty_pricing == 0.0

    # Availability flag switched off.
    monkeypatch.setattr(savings_tracker_module, "LITELLM_AVAILABLE", False)
    with_litellm_disabled = savings_tracker_module._estimate_cache_savings_usd("gpt-4o", 100)
    assert with_litellm_disabled == 0.0
Loading