Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 6 additions & 19 deletions pkg/cli/model_costs.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,9 +81,9 @@ func initModelPrices() {
func findModelPricing(provider, model string) (map[string]float64, bool) {
initModelPrices()

normalizedProvider := normalizeCatalogProvider(provider)
normalizedProvider := modelsdev.NormalizeProvider(provider)
normalizedModel := strings.ToLower(strings.TrimSpace(model))
comparableModel := normalizeComparableModelID(normalizedModel)
comparableModel := modelsdev.NormalizeComparableModelID(normalizedModel)
if normalizedModel == "" { //nolint:tolowerequalfold
return nil, false
}
Expand All @@ -92,10 +92,10 @@ func findModelPricing(provider, model string) (map[string]float64, bool) {
if !strings.Contains(fullID, "/") && normalizedProvider != "" {
fullID = normalizedProvider + "/" + normalizedModel
}
comparableFullID := normalizeComparableModelID(fullID)
comparableFullID := modelsdev.NormalizeComparableModelID(fullID)

for _, record := range modelPriceRecords {
if (fullID != "" && record.id == fullID) || (comparableFullID != "" && normalizeComparableModelID(record.id) == comparableFullID) {
if (fullID != "" && record.id == fullID) || (comparableFullID != "" && modelsdev.NormalizeComparableModelID(record.id) == comparableFullID) {
modelCostsLog.Printf("Exact pricing match: provider=%s, model=%s -> %s", provider, model, record.id)
return record.pricing, true
}
Expand All @@ -107,7 +107,7 @@ func findModelPricing(provider, model string) (map[string]float64, bool) {
bestGenericLen := -1

for _, record := range modelPriceRecords {
comparableRecordModel := normalizeComparableModelID(record.model)
comparableRecordModel := modelsdev.NormalizeComparableModelID(record.model)
if record.model == normalizedModel || comparableRecordModel == comparableModel {
if normalizedProvider != "" && record.provider == normalizedProvider {
return record.pricing, true
Expand Down Expand Up @@ -142,19 +142,6 @@ func findModelPricing(provider, model string) (map[string]float64, bool) {
return nil, false
}

func normalizeCatalogProvider(provider string) string {
switch strings.ToLower(strings.TrimSpace(provider)) {
case "github", "copilot", "github_models":
return "github-copilot"
default:
return strings.ToLower(strings.TrimSpace(provider))
}
}

func normalizeComparableModelID(value string) string {
return strings.NewReplacer(".", "-", "_", "-").Replace(strings.ToLower(strings.TrimSpace(value)))
}

func usdToAIC(usd float64) float64 {
return usd / 0.01
}
Expand All @@ -167,7 +154,7 @@ func computeModelInferenceCostUSD(provider, model string, inputTokens, outputTok

input := inputTokens
cacheRead := cacheReadTokens
if cacheRead > 0 && providerIncludesCacheReadsInInput(normalizeCatalogProvider(provider)) {
if cacheRead > 0 && providerIncludesCacheReadsInInput(modelsdev.NormalizeProvider(provider)) {
input = max(inputTokens-cacheReadTokens, 0)
}

Expand Down
6 changes: 4 additions & 2 deletions pkg/cli/model_costs_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ import (

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

"github.com/github/gh-aw/pkg/modelsdev"
)

func TestFindModelPricing(t *testing.T) {
Expand All @@ -21,7 +23,7 @@ func TestComputeModelInferenceAIC(t *testing.T) {
assert.InDelta(t, 0.54825, aic, 1e-9)
}

func TestNormalizeCatalogProvider(t *testing.T) {
func TestNormalizeProvider(t *testing.T) {
tests := []struct {
input string
want string
Expand All @@ -40,7 +42,7 @@ func TestNormalizeCatalogProvider(t *testing.T) {
name = "<empty>"
}
t.Run(name, func(t *testing.T) {
got := normalizeCatalogProvider(tt.input)
got := modelsdev.NormalizeProvider(tt.input)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[/codebase-design] TestNormalizeCatalogProvider now tests modelsdev.NormalizeProvider, making its name stale — and its cases are a subset of the new TestNormalizeProvider in catalog_test.go, so the coverage is duplicated.

💡 Suggested fix

Either remove this test entirely (canonical coverage already lives in pkg/modelsdev/catalog_test.go), or rename it and replace the cases with an integration-level test that goes through findModelPricing with a provider alias to verify the wiring end-to-end. Keeping it as-is means a failure reports TestNormalizeCatalogProvider when the function named in the message is NormalizeProvider, which is confusing.

@copilot please address this.

assert.Equal(t, tt.want, got)
Comment on lines 44 to 46
})
}
Expand Down
24 changes: 16 additions & 8 deletions pkg/modelsdev/catalog.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,10 @@ const (
// catalogURL is a variable so tests can override it with a local HTTP server.
var catalogURL = "https://models.dev/catalog.json"

// modelIDReplacer normalizes separator characters in model IDs so that IDs
// differing only in ".", "_", or "-" compare equal.
var modelIDReplacer = strings.NewReplacer(".", "-", "_", "-")

var log = logger.New("modelsdev:catalog")

// rawCatalog mirrors the top-level models.dev catalog JSON structure.
Expand Down Expand Up @@ -59,13 +63,13 @@ func FindPricing(ctx context.Context, provider, model string) (map[string]float6
return nil, false
}

normalizedProvider := normalizeProvider(provider)
normalizedProvider := NormalizeProvider(provider)
trimmedModel := strings.TrimSpace(model)
if trimmedModel == "" {
return nil, false
}
normalizedModel := strings.ToLower(trimmedModel)
comparableModel := normalizeComparableModelID(normalizedModel)
comparableModel := NormalizeComparableModelID(normalizedModel)

log.Printf("FindPricing: looking up provider=%q model=%q", normalizedProvider, normalizedModel)

Expand All @@ -78,7 +82,7 @@ func FindPricing(ctx context.Context, provider, model string) (map[string]float6
}
// Comparable (dot/underscore-normalized) model ID match.
for mn, pricing := range providerModels {
if normalizeComparableModelID(mn) == comparableModel {
if NormalizeComparableModelID(mn) == comparableModel {
log.Printf("FindPricing: provider-scoped comparable match %q for %q", mn, normalizedModel)
return pricing, true
}
Expand All @@ -93,7 +97,7 @@ func FindPricing(ctx context.Context, provider, model string) (map[string]float6
return pricing, true
}
for mn, pricing := range providerModels {
if normalizeComparableModelID(mn) == comparableModel {
if NormalizeComparableModelID(mn) == comparableModel {
log.Printf("FindPricing: cross-provider comparable match %q for %q", mn, normalizedModel)
return pricing, true
}
Expand Down Expand Up @@ -163,7 +167,7 @@ func parseCatalog(data []byte) (pricingCache, error) {

parsed := make(pricingCache)
for providerName, provider := range raw.Providers {
normalizedProvider := normalizeProvider(providerName)
normalizedProvider := NormalizeProvider(providerName)
if normalizedProvider == "" {
continue
}
Expand Down Expand Up @@ -213,7 +217,9 @@ func parseCostMap(raw map[string]json.RawMessage) map[string]float64 {
return result
}

func normalizeProvider(provider string) string {
// NormalizeProvider maps provider aliases (e.g. "github", "copilot", "github_models")
// to their canonical form ("github-copilot") and lower-cases all other values.
func NormalizeProvider(provider string) string {
switch strings.ToLower(strings.TrimSpace(provider)) {
case "github", "copilot", "github_models":
return "github-copilot"
Expand All @@ -222,6 +228,8 @@ func normalizeProvider(provider string) string {
}
}

func normalizeComparableModelID(value string) string {
return strings.NewReplacer(".", "-", "_", "-").Replace(strings.ToLower(strings.TrimSpace(value)))
// NormalizeComparableModelID lower-cases the value and replaces "." and "_" with "-"
// so that model IDs differing only in those separators compare equal.
func NormalizeComparableModelID(value string) string {
return modelIDReplacer.Replace(strings.ToLower(strings.TrimSpace(value)))
}
33 changes: 33 additions & 0 deletions pkg/modelsdev/catalog_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -146,3 +146,36 @@ func TestFindPricing(t *testing.T) {
assert.Nil(t, pricing)
})
}

func TestNormalizeProvider(t *testing.T) {
cases := []struct{ input, want string }{
{"github", "github-copilot"},
{"copilot", "github-copilot"},
{"github_models", "github-copilot"},
{"GITHUB_MODELS", "github-copilot"},
{"anthropic", "anthropic"},

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[/tdd] TestNormalizeProvider has no case for "github_models" with mixed capitalisation (e.g. "GitHub_Models"). The switch only lower-cases before comparing, so capitalised variants work, but a test case would lock that in as a specification.

💡 Suggested additional case
{"GitHub_Models", "github-copilot"},

Adding it turns the test suite into a living specification for alias matching, and prevents a future refactor from accidentally dropping case-insensitive handling.

@copilot please address this.

{"OpenAI", "openai"},
{" Anthropic ", "anthropic"},
{"", ""},
}
for _, tc := range cases {
t.Run(tc.input+"->"+tc.want, func(t *testing.T) {

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unaddressable sub-test name: when tc.input is "" the sub-test name becomes "->", which cannot be targeted with -run and silently conflates with any other "->" case in test output.

💡 Suggested fix

Apply the same guard used in the CLI test (model_costs_test.go:41):

for _, tc := range cases {
    name := tc.input + "->" + tc.want
    if tc.input == "" {
        name = "<empty>->" + tc.want
    }
    t.Run(name, func(t *testing.T) {
        assert.Equal(t, tc.want, NormalizeProvider(tc.input))
    })
}

Without this, go test -run TestNormalizeProvider/-> silently matches nothing, and CI logs show "->" with no way to identify which empty case failed. The same issue exists in TestNormalizeComparableModelID (line 177).

assert.Equal(t, tc.want, NormalizeProvider(tc.input))
})
}
}

func TestNormalizeComparableModelID(t *testing.T) {
cases := []struct{ input, want string }{
{"claude-sonnet-4.6", "claude-sonnet-4-6"},
{"gpt_4o", "gpt-4o"},
{"GPT-4O", "gpt-4o"},
{" claude.3 ", "claude-3"},
{"", ""},
}
for _, tc := range cases {
t.Run(tc.input+"->"+tc.want, func(t *testing.T) {
assert.Equal(t, tc.want, NormalizeComparableModelID(tc.input))

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[/tdd] TestNormalizeComparableModelID is missing a case for a string that is already in canonical form (e.g. "gpt-4o""gpt-4o"). Without an idempotency case, a future change that over-normalises (e.g. double-stripping) would not be caught by this suite.

💡 Suggested additional case
{"gpt-4o", "gpt-4o"}, // already canonical — idempotency check

@copilot please address this.

})
}
}
Loading