From 26a5e7bdb236b617ab5366286bfc1cb61de7e959 Mon Sep 17 00:00:00 2001 From: Cian Johnston Date: Fri, 5 Jun 2026 17:11:01 +0100 Subject: [PATCH 01/11] chore: simplify canonicalization --- coderd/ai_provider_canonical.go | 15 ----------- coderd/ai_providers.go | 2 +- coderd/ai_providers_migrate.go | 8 +++--- coderd/exp_chats.go | 16 ++++++------ coderd/x/chatd/ai_provider_canonical.go | 32 ++---------------------- coderd/x/chatd/chatd.go | 18 ++++++++++--- coderd/x/chatd/model_routing_aibridge.go | 19 +++++++------- coderd/x/chatd/model_routing_direct.go | 16 ++++++++++-- coderd/x/chatd/subagent.go | 8 +++++- 9 files changed, 61 insertions(+), 73 deletions(-) delete mode 100644 coderd/ai_provider_canonical.go diff --git a/coderd/ai_provider_canonical.go b/coderd/ai_provider_canonical.go deleted file mode 100644 index e0a43992c365d..0000000000000 --- a/coderd/ai_provider_canonical.go +++ /dev/null @@ -1,15 +0,0 @@ -package coderd - -import ( - "github.com/coder/coder/v2/coderd/database" - "github.com/coder/coder/v2/coderd/database/db2sdk" - "github.com/coder/coder/v2/codersdk" -) - -func canonicalDatabaseAIProviderType(providerType database.AIProviderType, settings codersdk.AIProviderSettings) database.AIProviderType { - return database.AIProviderType(codersdk.CanonicalAIProviderType(codersdk.AIProviderType(providerType), settings)) -} - -func canonicalAIProviderTypeForRow(provider database.AIProvider) (database.AIProviderType, error) { - return db2sdk.CanonicalAIProviderType(provider) -} diff --git a/coderd/ai_providers.go b/coderd/ai_providers.go index 5e1565f4e9155..f1efe24dbc465 100644 --- a/coderd/ai_providers.go +++ b/coderd/ai_providers.go @@ -311,7 +311,7 @@ func (api *API) aiProvidersUpdate(rw http.ResponseWriter, r *http.Request) { if req.Settings != nil { existing = mergeAIProviderSettings(existing, *req.Settings) } - targetType := canonicalDatabaseAIProviderType(old.Type, existing) + targetType := database.AIProviderType(codersdk.CanonicalAIProviderType(codersdk.AIProviderType(old.Type), existing)) targetBaseURL := ptr.NilToDefault(req.BaseURL, old.BaseUrl) // Bedrock settings are only meaningful for Bedrock providers; // rejecting the mismatch keeps a misconfiguration from sitting diff --git a/coderd/ai_providers_migrate.go b/coderd/ai_providers_migrate.go index 93a383e7b8ed5..f711cd0fe3a2b 100644 --- a/coderd/ai_providers_migrate.go +++ b/coderd/ai_providers_migrate.go @@ -108,7 +108,8 @@ func SeedAIProvidersFromEnv( if err != nil { return xerrors.Errorf("decode existing settings for %q: %w", existing.Name, err) } - if canonicalDatabaseAIProviderType(existing.Type, existingSettings) == database.AiProviderTypeBedrock { + existingDType := database.AIProviderType(codersdk.CanonicalAIProviderType(codersdk.AIProviderType(existing.Type), existingSettings)) + if existingDType == database.AiProviderTypeBedrock { logger.Warn(sysCtx, "skipping legacy Anthropic env seed because an existing Anthropic-named row contains Bedrock settings", slog.F("name", dp.Name), ) @@ -122,7 +123,8 @@ func SeedAIProvidersFromEnv( if err != nil { return xerrors.Errorf("decode existing settings for %q: %w", candidate.Name, err) } - if canonicalDatabaseAIProviderType(candidate.Type, candidateSettings) == database.AiProviderTypeBedrock { + candidateDType := database.AIProviderType(codersdk.CanonicalAIProviderType(codersdk.AIProviderType(candidate.Type), candidateSettings)) + if candidateDType == database.AiProviderTypeBedrock { existing = candidate found = true } @@ -152,7 +154,7 @@ func SeedAIProvidersFromEnv( existingKeys = append(existingKeys, k.APIKey) } existingDP := desiredAIProvider{ - Type: canonicalDatabaseAIProviderType(existing.Type, existingSettings), + Type: database.AIProviderType(codersdk.CanonicalAIProviderType(codersdk.AIProviderType(existing.Type), existingSettings)), BaseURL: existing.BaseUrl, Bedrock: existingSettings.Bedrock, Keys: existingKeys, diff --git a/coderd/exp_chats.go b/coderd/exp_chats.go index 4b87484fee6bc..9b05930c44d81 100644 --- a/coderd/exp_chats.go +++ b/coderd/exp_chats.go @@ -6532,7 +6532,7 @@ func parseUserAIProviderID(r *http.Request) (uuid.UUID, error) { } func convertAIProviderSummary(provider database.AIProvider) (codersdk.AIProviderSummary, error) { - providerType, err := canonicalAIProviderTypeForRow(provider) + providerType, err := db2sdk.CanonicalAIProviderType(provider) if err != nil { return codersdk.AIProviderSummary{}, err } @@ -6741,7 +6741,7 @@ func (api *API) configuredProvidersFromAIProviders(ctx context.Context, provider } func (api *API) configuredProviderFromAIProviderKeys(ctx context.Context, provider database.AIProvider, keys []database.AIProviderKey) (chatprovider.ConfiguredProvider, error) { - providerType, err := canonicalAIProviderTypeForRow(provider) + providerType, err := db2sdk.CanonicalAIProviderType(provider) if err != nil { api.Logger.Error(ctx, "failed to decode AI provider settings", slog.F("provider_id", provider.ID), slog.Error(err)) return chatprovider.ConfiguredProvider{}, err @@ -6837,7 +6837,7 @@ func (api *API) listChatModelConfigs(rw http.ResponseWriter, r *http.Request) { return } for _, provider := range providers { - providerType, err := canonicalAIProviderTypeForRow(provider) + providerType, err := db2sdk.CanonicalAIProviderType(provider) if err != nil { api.Logger.Error(ctx, "failed to decode AI provider settings", slog.F("provider_id", provider.ID), slog.Error(err)) httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{ @@ -6911,7 +6911,7 @@ func (api *API) createChatModelConfig(rw http.ResponseWriter, r *http.Request) { httpapi.Write(ctx, rw, http.StatusPreconditionFailed, codersdk.Response{Message: "AI provider is disabled."}) return } - providerType, err := canonicalAIProviderTypeForRow(aiProvider) + providerType, err := db2sdk.CanonicalAIProviderType(aiProvider) if err != nil { httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{ Message: "Failed to decode AI provider settings.", @@ -7001,7 +7001,7 @@ func (api *API) createChatModelConfig(rw http.ResponseWriter, r *http.Request) { if !lockedAIProvider.Enabled { return errChatProviderNotConfigured } - lockedProviderType, err := canonicalAIProviderTypeForRow(lockedAIProvider) + lockedProviderType, err := db2sdk.CanonicalAIProviderType(lockedAIProvider) if err != nil { return xerrors.Errorf("canonicalize provider type for %q: %w", lockedAIProvider.Name, err) } @@ -7124,7 +7124,7 @@ func (api *API) updateChatModelConfig(rw http.ResponseWriter, r *http.Request) { return } } else { - providerType, err := canonicalAIProviderTypeForRow(aiProvider) + providerType, err := db2sdk.CanonicalAIProviderType(aiProvider) if err != nil { httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{ Message: "Failed to decode AI provider settings.", @@ -7165,7 +7165,7 @@ func (api *API) updateChatModelConfig(rw http.ResponseWriter, r *http.Request) { httpapi.Write(ctx, rw, http.StatusPreconditionFailed, codersdk.Response{Message: "AI provider is disabled."}) return } - providerType, err := canonicalAIProviderTypeForRow(aiProvider) + providerType, err := db2sdk.CanonicalAIProviderType(aiProvider) if err != nil { httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{ Message: "Failed to decode AI provider settings.", @@ -7262,7 +7262,7 @@ func (api *API) updateChatModelConfig(rw http.ResponseWriter, r *http.Request) { if !aiProvider.Enabled { return errChatProviderNotConfigured } - providerType, err := canonicalAIProviderTypeForRow(aiProvider) + providerType, err := db2sdk.CanonicalAIProviderType(aiProvider) if err != nil { return xerrors.Errorf("canonicalize provider type for %q: %w", aiProvider.Name, err) } diff --git a/coderd/x/chatd/ai_provider_canonical.go b/coderd/x/chatd/ai_provider_canonical.go index fdf8b991649b1..61c50416912ec 100644 --- a/coderd/x/chatd/ai_provider_canonical.go +++ b/coderd/x/chatd/ai_provider_canonical.go @@ -1,39 +1,11 @@ package chatd import ( - "context" - - "cdr.dev/slog/v3" "github.com/coder/coder/v2/coderd/database" "github.com/coder/coder/v2/coderd/database/db2sdk" "github.com/coder/coder/v2/coderd/x/chatd/chatprovider" ) -func canonicalAIProviderType(provider database.AIProvider) (database.AIProviderType, error) { - return db2sdk.CanonicalAIProviderType(provider) -} - -func canonicalAIProviderTypeString(provider database.AIProvider) (string, error) { - providerType, err := canonicalAIProviderType(provider) - if err != nil { - return "", err - } - return string(providerType), nil -} - -func bestEffortCanonicalAIProviderType(ctx context.Context, logger slog.Logger, provider database.AIProvider) database.AIProviderType { - providerType, err := canonicalAIProviderType(provider) - if err != nil { - logger.Warn(ctx, "parse AI provider settings", slog.F("provider_id", provider.ID), slog.Error(err)) - return provider.Type - } - return providerType -} - -func bestEffortCanonicalAIProviderTypeString(ctx context.Context, logger slog.Logger, provider database.AIProvider) string { - return string(bestEffortCanonicalAIProviderType(ctx, logger, provider)) -} - func aiProviderTypeCanSatisfyRequest(candidateProviderType string, requestedProviderType string) bool { if candidateProviderType == requestedProviderType { return true @@ -43,11 +15,11 @@ func aiProviderTypeCanSatisfyRequest(candidateProviderType string, requestedProv } func aiProviderMatchesCanonicalType(provider database.AIProvider, normalizedProviderType string) (bool, error) { - providerType, err := canonicalAIProviderTypeString(provider) + providerType, err := db2sdk.CanonicalAIProviderType(provider) if err != nil { return false, err } - return aiProviderTypeCanSatisfyRequest(chatprovider.NormalizeProvider(providerType), normalizedProviderType), nil + return aiProviderTypeCanSatisfyRequest(chatprovider.NormalizeProvider(string(providerType)), normalizedProviderType), nil } func aiProviderMatchesRawType(provider database.AIProvider, normalizedProviderType string) bool { diff --git a/coderd/x/chatd/chatd.go b/coderd/x/chatd/chatd.go index fd78e0338893e..9529d8bb55ce4 100644 --- a/coderd/x/chatd/chatd.go +++ b/coderd/x/chatd/chatd.go @@ -8658,9 +8658,14 @@ func (p *Server) aiProviderConfigFromKeys(ctx context.Context, provider database break } } + canonicalType, err := db2sdk.CanonicalAIProviderType(provider) + if err != nil { + p.logger.Warn(ctx, "parse AI provider settings", slog.F("provider_id", provider.ID), slog.Error(err)) + canonicalType = provider.Type + } return chatprovider.ConfiguredProvider{ ProviderID: provider.ID, - Provider: bestEffortCanonicalAIProviderTypeString(ctx, p.logger, provider), + Provider: string(canonicalType), APIKey: apiKey, BaseURL: provider.BaseUrl, CentralAPIKeyEnabled: true, @@ -8774,12 +8779,12 @@ func (p *Server) resolveUserProviderAPIKeysAndProviderForProviderType( return chatprovider.ProviderAPIKeys{}, nil, nil } for _, provider := range providers { - canonicalProviderType, err := canonicalAIProviderTypeString(provider) + canonicalProviderType, err := db2sdk.CanonicalAIProviderType(provider) if err != nil { p.logger.Warn(ctx, "parse AI provider settings", slog.F("provider_id", provider.ID), slog.Error(err)) continue } - providerKeysType := chatprovider.NormalizeProvider(canonicalProviderType) + providerKeysType := chatprovider.NormalizeProvider(string(canonicalProviderType)) if !aiProviderTypeCanSatisfyRequest(providerKeysType, normalizedProviderType) { continue } @@ -8792,7 +8797,12 @@ func (p *Server) resolveUserProviderAPIKeysAndProviderForProviderType( if !aiProviderMatchesRawType(provider, normalizedProviderType) { continue } - providerKeysType := chatprovider.NormalizeProvider(bestEffortCanonicalAIProviderTypeString(ctx, p.logger, provider)) + canonicalType, err := db2sdk.CanonicalAIProviderType(provider) + if err != nil { + p.logger.Warn(ctx, "parse AI provider settings", slog.F("provider_id", provider.ID), slog.Error(err)) + canonicalType = provider.Type + } + providerKeysType := chatprovider.NormalizeProvider(string(canonicalType)) keys, matchedProvider, err := keysForProvider(provider, providerKeysType) if err != nil || matchedProvider != nil { return keys, matchedProvider, err diff --git a/coderd/x/chatd/model_routing_aibridge.go b/coderd/x/chatd/model_routing_aibridge.go index 077f072083340..a2fd9c28e3133 100644 --- a/coderd/x/chatd/model_routing_aibridge.go +++ b/coderd/x/chatd/model_routing_aibridge.go @@ -16,6 +16,7 @@ import ( "cdr.dev/slog/v3" "github.com/coder/coder/v2/coderd/aibridge" "github.com/coder/coder/v2/coderd/database" + "github.com/coder/coder/v2/coderd/database/db2sdk" "github.com/coder/coder/v2/coderd/x/chatd/chatdebug" "github.com/coder/coder/v2/coderd/x/chatd/chaterror" "github.com/coder/coder/v2/coderd/x/chatd/chatprovider" @@ -169,7 +170,7 @@ func (p *Server) newAIGatewayModel( baseRT = &chatdebug.RecordingTransport{Base: baseRT} } - providerType, err := canonicalAIProviderType(route.Provider) + providerType, err := db2sdk.CanonicalAIProviderType(route.Provider) if err != nil { return nil, xerrors.Errorf("canonicalize provider type for %q: %w", route.Provider.Name, err) } @@ -263,7 +264,7 @@ func (p *Server) resolveAIGatewayRoute( provider database.AIProvider, modelProviderHint string, ) (resolvedModelRoute, error) { - providerType, err := canonicalAIProviderType(provider) + providerType, err := db2sdk.CanonicalAIProviderType(provider) if err != nil { return resolvedModelRoute{}, xerrors.Errorf("canonicalize provider type for %q: %w", provider.Name, err) } @@ -288,7 +289,7 @@ func (p *Server) resolveAIGatewayModelRouteForConfig( if err != nil { return resolvedModelRoute{}, err } - providerType, err := canonicalAIProviderType(provider) + providerType, err := db2sdk.CanonicalAIProviderType(provider) if err != nil { return resolvedModelRoute{}, xerrors.Errorf("canonicalize provider type for %q: %w", provider.Name, err) } @@ -304,12 +305,12 @@ func (p *Server) resolveAIGatewayModelRouteForProviderType( if err != nil { return resolvedModelRoute{}, err } - return p.resolveAIGatewayRoute( - ctx, - ownerID, - provider, - bestEffortCanonicalAIProviderTypeString(ctx, p.logger, provider), - ) + canonicalType, err := db2sdk.CanonicalAIProviderType(provider) + if err != nil { + p.logger.Warn(ctx, "parse AI provider settings", slog.F("provider_id", provider.ID), slog.Error(err)) + canonicalType = provider.Type + } + return p.resolveAIGatewayRoute(ctx, ownerID, provider, string(canonicalType)) } func (p *Server) gatewayProviderForConfig( diff --git a/coderd/x/chatd/model_routing_direct.go b/coderd/x/chatd/model_routing_direct.go index 52792d3ef5bc8..e0d6505709715 100644 --- a/coderd/x/chatd/model_routing_direct.go +++ b/coderd/x/chatd/model_routing_direct.go @@ -8,7 +8,9 @@ import ( "github.com/google/uuid" "golang.org/x/xerrors" + "cdr.dev/slog/v3" "github.com/coder/coder/v2/coderd/database" + "github.com/coder/coder/v2/coderd/database/db2sdk" "github.com/coder/coder/v2/coderd/x/chatd/chatdebug" "github.com/coder/coder/v2/coderd/x/chatd/chatprovider" ) @@ -77,7 +79,12 @@ func (p *Server) resolveDirectModelRouteForProviderType( } providerHint := normalizedProviderType if provider != nil { - providerHint = chatprovider.NormalizeProvider(bestEffortCanonicalAIProviderTypeString(ctx, p.logger, *provider)) + canonicalType, err := db2sdk.CanonicalAIProviderType(*provider) + if err != nil { + p.logger.Warn(ctx, "parse AI provider settings", slog.F("provider_id", provider.ID), slog.Error(err)) + canonicalType = provider.Type + } + providerHint = chatprovider.NormalizeProvider(string(canonicalType)) } return newDirectModelRoute(providerHint, keys), nil } @@ -93,5 +100,10 @@ func (p *Server) directProviderHintAndProviderForConfig( if err != nil { return "", nil, err } - return bestEffortCanonicalAIProviderTypeString(ctx, p.logger, provider), &provider, nil + canonicalType, err := db2sdk.CanonicalAIProviderType(provider) + if err != nil { + p.logger.Warn(ctx, "parse AI provider settings", slog.F("provider_id", provider.ID), slog.Error(err)) + canonicalType = provider.Type + } + return string(canonicalType), &provider, nil } diff --git a/coderd/x/chatd/subagent.go b/coderd/x/chatd/subagent.go index 5c9c54ab9ccc4..e7a613c276431 100644 --- a/coderd/x/chatd/subagent.go +++ b/coderd/x/chatd/subagent.go @@ -19,6 +19,7 @@ import ( "cdr.dev/slog/v3" "github.com/coder/coder/v2/coderd/aibridge" "github.com/coder/coder/v2/coderd/database" + "github.com/coder/coder/v2/coderd/database/db2sdk" "github.com/coder/coder/v2/coderd/database/dbauthz" coderdpubsub "github.com/coder/coder/v2/coderd/pubsub" "github.com/coder/coder/v2/coderd/x/chatd/chatprompt" @@ -518,7 +519,12 @@ func (p *Server) resolveModelConfigAndNormalizedProvider( if !provider.Enabled { return database.ChatModelConfig{}, "", sql.ErrNoRows } - providerName := chatprovider.NormalizeProvider(bestEffortCanonicalAIProviderTypeString(ctx, p.logger, provider)) + canonicalType, err := db2sdk.CanonicalAIProviderType(provider) + if err != nil { + p.logger.Warn(ctx, "parse AI provider settings", slog.F("provider_id", provider.ID), slog.Error(err)) + canonicalType = provider.Type + } + providerName := chatprovider.NormalizeProvider(string(canonicalType)) if providerName == "" { return database.ChatModelConfig{}, "", errInvalidModelOverrideMetadata } From 17fbf11485f1212ded2a6ae4b08b047f17ab6228 Mon Sep 17 00:00:00 2001 From: Cian Johnston Date: Fri, 5 Jun 2026 18:14:45 +0100 Subject: [PATCH 02/11] refactor: change db2sdk.CanonicalAIProviderType to sdk type and remove unnecessary casts --- coderd/database/db2sdk/db2sdk.go | 4 ++-- coderd/database/db2sdk/db2sdk_test.go | 2 +- coderd/exp_chats.go | 5 ++--- coderd/x/chatd/chatd.go | 4 ++-- coderd/x/chatd/model_routing_aibridge.go | 18 +++++++----------- coderd/x/chatd/model_routing_direct.go | 5 +++-- coderd/x/chatd/model_routing_internal_test.go | 10 +++++----- coderd/x/chatd/subagent.go | 2 +- 8 files changed, 23 insertions(+), 27 deletions(-) diff --git a/coderd/database/db2sdk/db2sdk.go b/coderd/database/db2sdk/db2sdk.go index 0ac0de75be6d7..69259ab03ac92 100644 --- a/coderd/database/db2sdk/db2sdk.go +++ b/coderd/database/db2sdk/db2sdk.go @@ -44,12 +44,12 @@ func APIAllowListTarget(entry rbac.AllowListElement) codersdk.APIAllowListTarget } // CanonicalAIProviderType returns the runtime provider type for a database row. -func CanonicalAIProviderType(row database.AIProvider) (database.AIProviderType, error) { +func CanonicalAIProviderType(row database.AIProvider) (codersdk.AIProviderType, error) { settings, err := AIProviderSettings(row.Settings) if err != nil { return "", xerrors.Errorf("decode settings: %w", err) } - return database.AIProviderType(codersdk.CanonicalAIProviderType(codersdk.AIProviderType(row.Type), settings)), nil + return codersdk.CanonicalAIProviderType(codersdk.AIProviderType(row.Type), settings), nil } // AIProvider converts a database row plus its API keys into the diff --git a/coderd/database/db2sdk/db2sdk_test.go b/coderd/database/db2sdk/db2sdk_test.go index fa10566d38dea..e034fbe4a0770 100644 --- a/coderd/database/db2sdk/db2sdk_test.go +++ b/coderd/database/db2sdk/db2sdk_test.go @@ -95,7 +95,7 @@ func TestAIProviderCanonicalTypeAndDisplayName(t *testing.T) { gotType, err := db2sdk.CanonicalAIProviderType(tt.row) require.NoError(t, err) - require.Equal(t, database.AIProviderType(tt.wantType), gotType) + require.Equal(t, tt.wantType, gotType) require.Equal(t, tt.wantDisplay, db2sdk.AIProviderDisplayName(tt.row, tt.wantType)) got, err := db2sdk.AIProvider(tt.row, nil) diff --git a/coderd/exp_chats.go b/coderd/exp_chats.go index 9b05930c44d81..44cabe47a801a 100644 --- a/coderd/exp_chats.go +++ b/coderd/exp_chats.go @@ -6536,12 +6536,11 @@ func convertAIProviderSummary(provider database.AIProvider) (codersdk.AIProvider if err != nil { return codersdk.AIProviderSummary{}, err } - sdkProviderType := codersdk.AIProviderType(providerType) return codersdk.AIProviderSummary{ ID: provider.ID, - Type: sdkProviderType, + Type: providerType, Name: provider.Name, - DisplayName: db2sdk.AIProviderDisplayName(provider, sdkProviderType), + DisplayName: db2sdk.AIProviderDisplayName(provider, providerType), Enabled: provider.Enabled, Deleted: provider.Deleted, }, nil diff --git a/coderd/x/chatd/chatd.go b/coderd/x/chatd/chatd.go index 9529d8bb55ce4..b4e93f57d712b 100644 --- a/coderd/x/chatd/chatd.go +++ b/coderd/x/chatd/chatd.go @@ -8661,7 +8661,7 @@ func (p *Server) aiProviderConfigFromKeys(ctx context.Context, provider database canonicalType, err := db2sdk.CanonicalAIProviderType(provider) if err != nil { p.logger.Warn(ctx, "parse AI provider settings", slog.F("provider_id", provider.ID), slog.Error(err)) - canonicalType = provider.Type + canonicalType = codersdk.AIProviderType(provider.Type) } return chatprovider.ConfiguredProvider{ ProviderID: provider.ID, @@ -8800,7 +8800,7 @@ func (p *Server) resolveUserProviderAPIKeysAndProviderForProviderType( canonicalType, err := db2sdk.CanonicalAIProviderType(provider) if err != nil { p.logger.Warn(ctx, "parse AI provider settings", slog.F("provider_id", provider.ID), slog.Error(err)) - canonicalType = provider.Type + canonicalType = codersdk.AIProviderType(provider.Type) } providerKeysType := chatprovider.NormalizeProvider(string(canonicalType)) keys, matchedProvider, err := keysForProvider(provider, providerKeysType) diff --git a/coderd/x/chatd/model_routing_aibridge.go b/coderd/x/chatd/model_routing_aibridge.go index a2fd9c28e3133..c056eeefb0989 100644 --- a/coderd/x/chatd/model_routing_aibridge.go +++ b/coderd/x/chatd/model_routing_aibridge.go @@ -190,14 +190,14 @@ type aibridgeFantasyConfig struct { Keys chatprovider.ProviderAPIKeys } -func fantasyConfigForAIBridge(providerType database.AIProviderType) aibridgeFantasyConfig { +func fantasyConfigForAIBridge(providerType codersdk.AIProviderType) aibridgeFantasyConfig { var fantasyProvider string baseURL := aibridgeLocalBaseURL + "/v1" switch providerType { - case database.AiProviderTypeAnthropic, database.AiProviderTypeBedrock: + case codersdk.AIProviderTypeAnthropic, codersdk.AIProviderTypeBedrock: fantasyProvider = fantasyanthropic.Name baseURL = aibridgeLocalBaseURL - case database.AiProviderTypeOpenai: + case codersdk.AIProviderTypeOpenAI: fantasyProvider = fantasyopenai.Name default: fantasyProvider = fantasyopenaicompat.Name @@ -215,9 +215,9 @@ func fantasyConfigForAIBridge(providerType database.AIProviderType) aibridgeFant } } -func aiGatewayRequestFormatForProviderType(providerType database.AIProviderType) aiGatewayRequestFormat { +func aiGatewayRequestFormatForProviderType(providerType codersdk.AIProviderType) aiGatewayRequestFormat { switch providerType { - case database.AiProviderTypeAnthropic, database.AiProviderTypeBedrock: + case codersdk.AIProviderTypeAnthropic, codersdk.AIProviderTypeBedrock: return aiGatewayRequestFormatAnthropic default: return aiGatewayRequestFormatOpenAI @@ -264,15 +264,11 @@ func (p *Server) resolveAIGatewayRoute( provider database.AIProvider, modelProviderHint string, ) (resolvedModelRoute, error) { - providerType, err := db2sdk.CanonicalAIProviderType(provider) - if err != nil { - return resolvedModelRoute{}, xerrors.Errorf("canonicalize provider type for %q: %w", provider.Name, err) - } auth, err := p.aiGatewayProviderAuthForUser( ctx, ownerID, provider, - aiGatewayRequestFormatForProviderType(providerType), + aiGatewayRequestFormatForProviderType(codersdk.AIProviderType(modelProviderHint)), ) if err != nil { return resolvedModelRoute{}, xerrors.Errorf("resolve AI Gateway provider auth: %w", err) @@ -308,7 +304,7 @@ func (p *Server) resolveAIGatewayModelRouteForProviderType( canonicalType, err := db2sdk.CanonicalAIProviderType(provider) if err != nil { p.logger.Warn(ctx, "parse AI provider settings", slog.F("provider_id", provider.ID), slog.Error(err)) - canonicalType = provider.Type + canonicalType = codersdk.AIProviderType(provider.Type) } return p.resolveAIGatewayRoute(ctx, ownerID, provider, string(canonicalType)) } diff --git a/coderd/x/chatd/model_routing_direct.go b/coderd/x/chatd/model_routing_direct.go index e0d6505709715..a11dd05923dbc 100644 --- a/coderd/x/chatd/model_routing_direct.go +++ b/coderd/x/chatd/model_routing_direct.go @@ -13,6 +13,7 @@ import ( "github.com/coder/coder/v2/coderd/database/db2sdk" "github.com/coder/coder/v2/coderd/x/chatd/chatdebug" "github.com/coder/coder/v2/coderd/x/chatd/chatprovider" + "github.com/coder/coder/v2/codersdk" ) type directModelRoute struct { @@ -82,7 +83,7 @@ func (p *Server) resolveDirectModelRouteForProviderType( canonicalType, err := db2sdk.CanonicalAIProviderType(*provider) if err != nil { p.logger.Warn(ctx, "parse AI provider settings", slog.F("provider_id", provider.ID), slog.Error(err)) - canonicalType = provider.Type + canonicalType = codersdk.AIProviderType(provider.Type) } providerHint = chatprovider.NormalizeProvider(string(canonicalType)) } @@ -103,7 +104,7 @@ func (p *Server) directProviderHintAndProviderForConfig( canonicalType, err := db2sdk.CanonicalAIProviderType(provider) if err != nil { p.logger.Warn(ctx, "parse AI provider settings", slog.F("provider_id", provider.ID), slog.Error(err)) - canonicalType = provider.Type + canonicalType = codersdk.AIProviderType(provider.Type) } return string(canonicalType), &provider, nil } diff --git a/coderd/x/chatd/model_routing_internal_test.go b/coderd/x/chatd/model_routing_internal_test.go index 76ede361deb5a..dce5e7597d3b0 100644 --- a/coderd/x/chatd/model_routing_internal_test.go +++ b/coderd/x/chatd/model_routing_internal_test.go @@ -81,14 +81,14 @@ func TestAIBridgeProviderFormatMapping(t *testing.T) { tests := []struct { name string - providerType database.AIProviderType + providerType codersdk.AIProviderType wantProvider string wantBaseURL string }{ - {name: "OpenAI", providerType: database.AiProviderTypeOpenai, wantProvider: "openai", wantBaseURL: "http://coder-aibridge/v1"}, - {name: "Anthropic", providerType: database.AiProviderTypeAnthropic, wantProvider: "anthropic", wantBaseURL: "http://coder-aibridge"}, - {name: "Bedrock", providerType: database.AiProviderTypeBedrock, wantProvider: "anthropic", wantBaseURL: "http://coder-aibridge"}, - {name: "Google", providerType: database.AiProviderTypeGoogle, wantProvider: "openai-compat", wantBaseURL: "http://coder-aibridge/v1"}, + {name: "OpenAI", providerType: codersdk.AIProviderTypeOpenAI, wantProvider: "openai", wantBaseURL: "http://coder-aibridge/v1"}, + {name: "Anthropic", providerType: codersdk.AIProviderTypeAnthropic, wantProvider: "anthropic", wantBaseURL: "http://coder-aibridge"}, + {name: "Bedrock", providerType: codersdk.AIProviderTypeBedrock, wantProvider: "anthropic", wantBaseURL: "http://coder-aibridge"}, + {name: "Google", providerType: codersdk.AIProviderTypeGoogle, wantProvider: "openai-compat", wantBaseURL: "http://coder-aibridge/v1"}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { diff --git a/coderd/x/chatd/subagent.go b/coderd/x/chatd/subagent.go index e7a613c276431..923aa2f8ea927 100644 --- a/coderd/x/chatd/subagent.go +++ b/coderd/x/chatd/subagent.go @@ -522,7 +522,7 @@ func (p *Server) resolveModelConfigAndNormalizedProvider( canonicalType, err := db2sdk.CanonicalAIProviderType(provider) if err != nil { p.logger.Warn(ctx, "parse AI provider settings", slog.F("provider_id", provider.ID), slog.Error(err)) - canonicalType = provider.Type + canonicalType = codersdk.AIProviderType(provider.Type) } providerName := chatprovider.NormalizeProvider(string(canonicalType)) if providerName == "" { From 66c55444084c1fd1cde7225155bb7b1af96c2722 Mon Sep 17 00:00:00 2001 From: Cian Johnston Date: Fri, 5 Jun 2026 22:00:59 +0100 Subject: [PATCH 03/11] add test coverage for direct provider hints in case of invalid config --- coderd/x/chatd/chatd_internal_test.go | 77 ++++++++++++++++++++++++++ coderd/x/chatd/model_routing_direct.go | 3 +- 2 files changed, 78 insertions(+), 2 deletions(-) diff --git a/coderd/x/chatd/chatd_internal_test.go b/coderd/x/chatd/chatd_internal_test.go index 0286cf3394a9c..90197b0e270e5 100644 --- a/coderd/x/chatd/chatd_internal_test.go +++ b/coderd/x/chatd/chatd_internal_test.go @@ -307,6 +307,83 @@ func TestResolveDirectModelRouteForProviderTypeFallsBackToRawProviderType(t *tes require.Equal(t, "test-key", route.directProviderKeys().APIKey("openai")) } +func TestResolveAIGatewayModelRouteForProviderTypeMatchesCanonicalBedrockForAnthropicRequest(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitShort) + ctrl := gomock.NewController(t) + db := dbmock.NewMockStore(ctrl) + ownerID := uuid.New() + providerID := uuid.New() + + rawSettings, err := json.Marshal(codersdk.AIProviderSettings{ + Bedrock: &codersdk.AIProviderBedrockSettings{Region: "us-east-1"}, + }) + require.NoError(t, err) + provider := database.AIProvider{ + ID: providerID, + Type: database.AiProviderTypeAnthropic, + Enabled: true, + Settings: sql.NullString{String: string(rawSettings), Valid: true}, + } + + db.EXPECT().GetAIProviders(gomock.Any(), database.GetAIProvidersParams{}).Return([]database.AIProvider{provider}, nil) + + server := &Server{db: db, logger: slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})} + route, err := server.resolveAIGatewayModelRouteForProviderType( + ctx, + ownerID, + chattool.ComputerUseProviderAnthropic, + ) + require.NoError(t, err) + + require.Equal(t, modelRouteKindAIGateway, route.kind) + providerHint, err := route.providerHint() + require.NoError(t, err) + require.Equal(t, "bedrock", providerHint) +} + +func TestDirectProviderHintAndProviderForConfigErrorsOnMalformedSettings(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitShort) + ctrl := gomock.NewController(t) + db := dbmock.NewMockStore(ctrl) + providerID := uuid.New() + + provider := database.AIProvider{ + ID: providerID, + Type: database.AiProviderTypeAnthropic, + Enabled: true, + Settings: sql.NullString{String: "{", Valid: true}, + } + db.EXPECT().GetAIProviderByID(gomock.Any(), providerID).Return(provider, nil) + + server := &Server{db: db, logger: slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})} + _, _, err := server.directProviderHintAndProviderForConfig(ctx, database.ChatModelConfig{ + AIProviderID: uuid.NullUUID{UUID: providerID, Valid: true}, + }) + require.Error(t, err) +} + +func TestAIProviderConfigFromKeysFallsBackToRawTypeOnMalformedSettings(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitShort) + + provider := database.AIProvider{ + ID: uuid.New(), + Type: database.AiProviderTypeAnthropic, + Enabled: true, + Settings: sql.NullString{String: "{", Valid: true}, + } + + server := &Server{logger: slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})} + cfg, err := server.aiProviderConfigFromKeys(ctx, provider, nil) + require.NoError(t, err) + require.Equal(t, string(database.AiProviderTypeAnthropic), cfg.Provider) +} + func TestResolveModelRouteForProviderTypeAIGatewayRequiresProvider(t *testing.T) { t.Parallel() diff --git a/coderd/x/chatd/model_routing_direct.go b/coderd/x/chatd/model_routing_direct.go index a11dd05923dbc..6bb2b3e886b15 100644 --- a/coderd/x/chatd/model_routing_direct.go +++ b/coderd/x/chatd/model_routing_direct.go @@ -103,8 +103,7 @@ func (p *Server) directProviderHintAndProviderForConfig( } canonicalType, err := db2sdk.CanonicalAIProviderType(provider) if err != nil { - p.logger.Warn(ctx, "parse AI provider settings", slog.F("provider_id", provider.ID), slog.Error(err)) - canonicalType = codersdk.AIProviderType(provider.Type) + return "", nil, xerrors.Errorf("canonicalize provider type for %q: %w", provider.Name, err) } return string(canonicalType), &provider, nil } From 5237912959614301480338581ee85d922cc9adcb Mon Sep 17 00:00:00 2001 From: Cian Johnston Date: Fri, 5 Jun 2026 22:38:47 +0100 Subject: [PATCH 04/11] add coverage for enabledProviderContainsName --- coderd/x/chatd/subagent_internal_test.go | 78 ++++++++++++++++++++++++ 1 file changed, 78 insertions(+) diff --git a/coderd/x/chatd/subagent_internal_test.go b/coderd/x/chatd/subagent_internal_test.go index ce860f124929c..ecc103477d3d0 100644 --- a/coderd/x/chatd/subagent_internal_test.go +++ b/coderd/x/chatd/subagent_internal_test.go @@ -3600,3 +3600,81 @@ func TestAwaitSubagentCompletion(t *testing.T) { assert.Equal(t, "zero timeout ok", report) }) } + +func TestEnabledProviderContainsName(t *testing.T) { + t.Parallel() + + bedrockSettings, err := json.Marshal(codersdk.AIProviderSettings{ + Bedrock: &codersdk.AIProviderBedrockSettings{Region: "us-east-1"}, + }) + require.NoError(t, err) + + cases := []struct { + name string + providers []database.AIProvider + query string + want bool + }{ + { + name: "canonical match", + providers: []database.AIProvider{ + {Type: database.AiProviderTypeAnthropic}, + }, + query: "anthropic", + want: true, + }, + { + name: "legacy bedrock canonical match for anthropic request", + providers: []database.AIProvider{ + { + Type: database.AiProviderTypeAnthropic, + Settings: sql.NullString{String: string(bedrockSettings), Valid: true}, + }, + }, + query: "anthropic", + want: true, + }, + { + // First pass errors (malformed JSON), second pass matches via raw type. + name: "malformed settings falls back to raw type", + providers: []database.AIProvider{ + { + Type: database.AiProviderTypeOpenai, + Settings: sql.NullString{String: "{", Valid: true}, + }, + }, + query: "openai", + want: true, + }, + { + // First pass errors, second pass matches bedrock satisfying anthropic request. + name: "malformed bedrock satisfies anthropic via raw type", + providers: []database.AIProvider{ + { + Type: database.AiProviderTypeBedrock, + Settings: sql.NullString{String: "{", Valid: true}, + }, + }, + query: "anthropic", + want: true, + }, + { + name: "no match", + providers: []database.AIProvider{ + {Type: database.AiProviderTypeOpenai}, + }, + query: "anthropic", + want: false, + }, + } + + for _, tt := range cases { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + ctx := testutil.Context(t, testutil.WaitShort) + logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}) + got := enabledProviderContainsName(ctx, logger, tt.providers, tt.query) + require.Equal(t, tt.want, got) + }) + } +} From 8ca2ed14d647c841503555d7a1ee5ad19f221e4f Mon Sep 17 00:00:00 2001 From: Cian Johnston Date: Fri, 5 Jun 2026 22:51:15 +0100 Subject: [PATCH 05/11] fix: skip unnecessary provider type recanonicalization --- coderd/x/chatd/chatd.go | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/coderd/x/chatd/chatd.go b/coderd/x/chatd/chatd.go index b4e93f57d712b..5f34d890a131d 100644 --- a/coderd/x/chatd/chatd.go +++ b/coderd/x/chatd/chatd.go @@ -8797,12 +8797,7 @@ func (p *Server) resolveUserProviderAPIKeysAndProviderForProviderType( if !aiProviderMatchesRawType(provider, normalizedProviderType) { continue } - canonicalType, err := db2sdk.CanonicalAIProviderType(provider) - if err != nil { - p.logger.Warn(ctx, "parse AI provider settings", slog.F("provider_id", provider.ID), slog.Error(err)) - canonicalType = codersdk.AIProviderType(provider.Type) - } - providerKeysType := chatprovider.NormalizeProvider(string(canonicalType)) + providerKeysType := chatprovider.NormalizeProvider(string(provider.Type)) keys, matchedProvider, err := keysForProvider(provider, providerKeysType) if err != nil || matchedProvider != nil { return keys, matchedProvider, err From 4d76f03cb71c37d9d8c4ef5fcb0c32ef2ebe6c30 Mon Sep 17 00:00:00 2001 From: Cian Johnston Date: Fri, 5 Jun 2026 22:54:52 +0100 Subject: [PATCH 06/11] fix incorrect log --- coderd/ai_providers_migrate.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/coderd/ai_providers_migrate.go b/coderd/ai_providers_migrate.go index f711cd0fe3a2b..45112d78b5f42 100644 --- a/coderd/ai_providers_migrate.go +++ b/coderd/ai_providers_migrate.go @@ -141,7 +141,7 @@ func SeedAIProvidersFromEnv( case found: existingSettings, err := db2sdk.AIProviderSettings(existing.Settings) if err != nil { - return xerrors.Errorf("decode existing settings for %q: %w", dp.Name, err) + return xerrors.Errorf("decode existing settings for %q: %w", existing.Name, err) } // Load existing bearer keys so the canonical hash // includes credentials for comparison. From a3bb9f52eb1427bf95139c786cb8957eef2470b4 Mon Sep 17 00:00:00 2001 From: Cian Johnston Date: Fri, 5 Jun 2026 23:08:36 +0100 Subject: [PATCH 07/11] second round of deep-review fixes --- coderd/x/chatd/chatd_internal_test.go | 11 ++++++++- coderd/x/chatd/model_routing_aibridge.go | 21 +++++------------ coderd/x/chatd/subagent_internal_test.go | 29 ++++++++++++++++++++++++ 3 files changed, 45 insertions(+), 16 deletions(-) diff --git a/coderd/x/chatd/chatd_internal_test.go b/coderd/x/chatd/chatd_internal_test.go index 90197b0e270e5..c0aa764fc9936 100644 --- a/coderd/x/chatd/chatd_internal_test.go +++ b/coderd/x/chatd/chatd_internal_test.go @@ -338,6 +338,7 @@ func TestResolveAIGatewayModelRouteForProviderTypeMatchesCanonicalBedrockForAnth require.NoError(t, err) require.Equal(t, modelRouteKindAIGateway, route.kind) + require.Equal(t, providerID, route.aiGateway.Provider.ID) providerHint, err := route.providerHint() require.NoError(t, err) require.Equal(t, "bedrock", providerHint) @@ -371,11 +372,19 @@ func TestAIProviderConfigFromKeysFallsBackToRawTypeOnMalformedSettings(t *testin ctx := testutil.Context(t, testutil.WaitShort) + // Build valid Bedrock settings then corrupt the JSON. If parsing succeeded, + // CanonicalAIProviderType would return "bedrock". The fallback must return + // the raw provider.Type ("anthropic"), confirming the fallback path fired. + bedrockSettings, err := json.Marshal(codersdk.AIProviderSettings{ + Bedrock: &codersdk.AIProviderBedrockSettings{Region: "us-east-1"}, + }) + require.NoError(t, err) + provider := database.AIProvider{ ID: uuid.New(), Type: database.AiProviderTypeAnthropic, Enabled: true, - Settings: sql.NullString{String: "{", Valid: true}, + Settings: sql.NullString{String: string(bedrockSettings[:len(bedrockSettings)-1]), Valid: true}, } server := &Server{logger: slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})} diff --git a/coderd/x/chatd/model_routing_aibridge.go b/coderd/x/chatd/model_routing_aibridge.go index c056eeefb0989..262094419f03b 100644 --- a/coderd/x/chatd/model_routing_aibridge.go +++ b/coderd/x/chatd/model_routing_aibridge.go @@ -215,15 +215,6 @@ func fantasyConfigForAIBridge(providerType codersdk.AIProviderType) aibridgeFant } } -func aiGatewayRequestFormatForProviderType(providerType codersdk.AIProviderType) aiGatewayRequestFormat { - switch providerType { - case codersdk.AIProviderTypeAnthropic, codersdk.AIProviderTypeBedrock: - return aiGatewayRequestFormatAnthropic - default: - return aiGatewayRequestFormatOpenAI - } -} - func (p *Server) aiGatewayProviderAuthForUser( ctx context.Context, ownerID uuid.UUID, @@ -264,12 +255,12 @@ func (p *Server) resolveAIGatewayRoute( provider database.AIProvider, modelProviderHint string, ) (resolvedModelRoute, error) { - auth, err := p.aiGatewayProviderAuthForUser( - ctx, - ownerID, - provider, - aiGatewayRequestFormatForProviderType(codersdk.AIProviderType(modelProviderHint)), - ) + format := aiGatewayRequestFormatOpenAI + switch codersdk.AIProviderType(modelProviderHint) { + case codersdk.AIProviderTypeAnthropic, codersdk.AIProviderTypeBedrock: + format = aiGatewayRequestFormatAnthropic + } + auth, err := p.aiGatewayProviderAuthForUser(ctx, ownerID, provider, format) if err != nil { return resolvedModelRoute{}, xerrors.Errorf("resolve AI Gateway provider auth: %w", err) } diff --git a/coderd/x/chatd/subagent_internal_test.go b/coderd/x/chatd/subagent_internal_test.go index ecc103477d3d0..d0b334581ed98 100644 --- a/coderd/x/chatd/subagent_internal_test.go +++ b/coderd/x/chatd/subagent_internal_test.go @@ -3658,6 +3658,35 @@ func TestEnabledProviderContainsName(t *testing.T) { query: "anthropic", want: true, }, + { + name: "empty providers", + providers: []database.AIProvider{}, + query: "anthropic", + want: false, + }, + { + // First provider fails canonicalization; second matches canonically. + // Verifies the first-pass loop continues on error rather than aborting. + name: "first provider fails canonicalization, second matches canonically", + providers: []database.AIProvider{ + { + Type: database.AiProviderTypeAnthropic, + Settings: sql.NullString{String: "{", Valid: true}, + }, + {Type: database.AiProviderTypeAnthropic}, + }, + query: "anthropic", + want: true, + }, + { + // enabledProviderContainsName has no Enabled guard; callers pre-filter. + name: "disabled provider still matches", + providers: []database.AIProvider{ + {Type: database.AiProviderTypeAnthropic, Enabled: false}, + }, + query: "anthropic", + want: true, + }, { name: "no match", providers: []database.AIProvider{ From 97c75d5825cbf204573e47cd08af679860442bef Mon Sep 17 00:00:00 2001 From: Cian Johnston Date: Mon, 8 Jun 2026 10:25:47 +0100 Subject: [PATCH 08/11] chore: address R1 findings --- coderd/ai_providers_migrate.go | 10 +++++----- coderd/x/chatd/ai_provider_canonical.go | 5 +++-- coderd/x/chatd/chatd.go | 4 ++-- coderd/x/chatd/chatd_internal_test.go | 6 ++---- coderd/x/chatd/model_routing_aibridge.go | 4 ++-- coderd/x/chatd/model_routing_direct.go | 2 +- coderd/x/chatd/subagent.go | 4 ++-- coderd/x/chatd/subagent_internal_test.go | 1 - 8 files changed, 17 insertions(+), 19 deletions(-) diff --git a/coderd/ai_providers_migrate.go b/coderd/ai_providers_migrate.go index 45112d78b5f42..bc4001b3dfc34 100644 --- a/coderd/ai_providers_migrate.go +++ b/coderd/ai_providers_migrate.go @@ -108,8 +108,8 @@ func SeedAIProvidersFromEnv( if err != nil { return xerrors.Errorf("decode existing settings for %q: %w", existing.Name, err) } - existingDType := database.AIProviderType(codersdk.CanonicalAIProviderType(codersdk.AIProviderType(existing.Type), existingSettings)) - if existingDType == database.AiProviderTypeBedrock { + existingType := database.AIProviderType(codersdk.CanonicalAIProviderType(codersdk.AIProviderType(existing.Type), existingSettings)) + if existingType == database.AiProviderTypeBedrock { logger.Warn(sysCtx, "skipping legacy Anthropic env seed because an existing Anthropic-named row contains Bedrock settings", slog.F("name", dp.Name), ) @@ -123,8 +123,8 @@ func SeedAIProvidersFromEnv( if err != nil { return xerrors.Errorf("decode existing settings for %q: %w", candidate.Name, err) } - candidateDType := database.AIProviderType(codersdk.CanonicalAIProviderType(codersdk.AIProviderType(candidate.Type), candidateSettings)) - if candidateDType == database.AiProviderTypeBedrock { + candidateType := database.AIProviderType(codersdk.CanonicalAIProviderType(codersdk.AIProviderType(candidate.Type), candidateSettings)) + if candidateType == database.AiProviderTypeBedrock { existing = candidate found = true } @@ -147,7 +147,7 @@ func SeedAIProvidersFromEnv( // includes credentials for comparison. existingKeyRows, err := tx.GetAIProviderKeysByProviderID(sysCtx, existing.ID) if err != nil { - return xerrors.Errorf("load existing keys for %q: %w", dp.Name, err) + return xerrors.Errorf("load existing keys for %q: %w", existing.Name, err) } existingKeys := make([]string, 0, len(existingKeyRows)) for _, k := range existingKeyRows { diff --git a/coderd/x/chatd/ai_provider_canonical.go b/coderd/x/chatd/ai_provider_canonical.go index 61c50416912ec..7ef9b38bf74ff 100644 --- a/coderd/x/chatd/ai_provider_canonical.go +++ b/coderd/x/chatd/ai_provider_canonical.go @@ -4,14 +4,15 @@ import ( "github.com/coder/coder/v2/coderd/database" "github.com/coder/coder/v2/coderd/database/db2sdk" "github.com/coder/coder/v2/coderd/x/chatd/chatprovider" + "github.com/coder/coder/v2/codersdk" ) func aiProviderTypeCanSatisfyRequest(candidateProviderType string, requestedProviderType string) bool { if candidateProviderType == requestedProviderType { return true } - return requestedProviderType == string(database.AiProviderTypeAnthropic) && - candidateProviderType == string(database.AiProviderTypeBedrock) + return requestedProviderType == string(codersdk.AIProviderTypeAnthropic) && + candidateProviderType == string(codersdk.AIProviderTypeBedrock) } func aiProviderMatchesCanonicalType(provider database.AIProvider, normalizedProviderType string) (bool, error) { diff --git a/coderd/x/chatd/chatd.go b/coderd/x/chatd/chatd.go index 5f34d890a131d..df5b865732993 100644 --- a/coderd/x/chatd/chatd.go +++ b/coderd/x/chatd/chatd.go @@ -8660,7 +8660,7 @@ func (p *Server) aiProviderConfigFromKeys(ctx context.Context, provider database } canonicalType, err := db2sdk.CanonicalAIProviderType(provider) if err != nil { - p.logger.Warn(ctx, "parse AI provider settings", slog.F("provider_id", provider.ID), slog.Error(err)) + p.logger.Warn(ctx, "parse AI provider settings for provider config", slog.F("provider_id", provider.ID), slog.Error(err)) canonicalType = codersdk.AIProviderType(provider.Type) } return chatprovider.ConfiguredProvider{ @@ -8781,7 +8781,7 @@ func (p *Server) resolveUserProviderAPIKeysAndProviderForProviderType( for _, provider := range providers { canonicalProviderType, err := db2sdk.CanonicalAIProviderType(provider) if err != nil { - p.logger.Warn(ctx, "parse AI provider settings", slog.F("provider_id", provider.ID), slog.Error(err)) + p.logger.Warn(ctx, "parse AI provider settings during key resolution", slog.F("provider_id", provider.ID), slog.Error(err)) continue } providerKeysType := chatprovider.NormalizeProvider(string(canonicalProviderType)) diff --git a/coderd/x/chatd/chatd_internal_test.go b/coderd/x/chatd/chatd_internal_test.go index c0aa764fc9936..cd2eb530bb9cf 100644 --- a/coderd/x/chatd/chatd_internal_test.go +++ b/coderd/x/chatd/chatd_internal_test.go @@ -364,7 +364,7 @@ func TestDirectProviderHintAndProviderForConfigErrorsOnMalformedSettings(t *test _, _, err := server.directProviderHintAndProviderForConfig(ctx, database.ChatModelConfig{ AIProviderID: uuid.NullUUID{UUID: providerID, Valid: true}, }) - require.Error(t, err) + require.ErrorContains(t, err, "canonicalize") } func TestAIProviderConfigFromKeysFallsBackToRawTypeOnMalformedSettings(t *testing.T) { @@ -372,9 +372,7 @@ func TestAIProviderConfigFromKeysFallsBackToRawTypeOnMalformedSettings(t *testin ctx := testutil.Context(t, testutil.WaitShort) - // Build valid Bedrock settings then corrupt the JSON. If parsing succeeded, - // CanonicalAIProviderType would return "bedrock". The fallback must return - // the raw provider.Type ("anthropic"), confirming the fallback path fired. + // Corrupt valid Bedrock JSON to verify the fallback returns the raw type rather than "bedrock". bedrockSettings, err := json.Marshal(codersdk.AIProviderSettings{ Bedrock: &codersdk.AIProviderBedrockSettings{Region: "us-east-1"}, }) diff --git a/coderd/x/chatd/model_routing_aibridge.go b/coderd/x/chatd/model_routing_aibridge.go index 262094419f03b..6c39a21453582 100644 --- a/coderd/x/chatd/model_routing_aibridge.go +++ b/coderd/x/chatd/model_routing_aibridge.go @@ -294,7 +294,7 @@ func (p *Server) resolveAIGatewayModelRouteForProviderType( } canonicalType, err := db2sdk.CanonicalAIProviderType(provider) if err != nil { - p.logger.Warn(ctx, "parse AI provider settings", slog.F("provider_id", provider.ID), slog.Error(err)) + p.logger.Warn(ctx, "parse AI provider settings for AI gateway route", slog.F("provider_id", provider.ID), slog.Error(err)) canonicalType = codersdk.AIProviderType(provider.Type) } return p.resolveAIGatewayRoute(ctx, ownerID, provider, string(canonicalType)) @@ -329,7 +329,7 @@ func (p *Server) aiProviderForProviderType( } matches, err := aiProviderMatchesCanonicalType(provider, normalizedProviderType) if err != nil { - p.logger.Warn(ctx, "parse AI provider settings", slog.F("provider_id", provider.ID), slog.Error(err)) + p.logger.Warn(ctx, "parse AI provider settings during provider type lookup", slog.F("provider_id", provider.ID), slog.Error(err)) continue } if !matches { diff --git a/coderd/x/chatd/model_routing_direct.go b/coderd/x/chatd/model_routing_direct.go index 6bb2b3e886b15..5e020914fcb65 100644 --- a/coderd/x/chatd/model_routing_direct.go +++ b/coderd/x/chatd/model_routing_direct.go @@ -82,7 +82,7 @@ func (p *Server) resolveDirectModelRouteForProviderType( if provider != nil { canonicalType, err := db2sdk.CanonicalAIProviderType(*provider) if err != nil { - p.logger.Warn(ctx, "parse AI provider settings", slog.F("provider_id", provider.ID), slog.Error(err)) + p.logger.Warn(ctx, "parse AI provider settings for direct route", slog.F("provider_id", provider.ID), slog.Error(err)) canonicalType = codersdk.AIProviderType(provider.Type) } providerHint = chatprovider.NormalizeProvider(string(canonicalType)) diff --git a/coderd/x/chatd/subagent.go b/coderd/x/chatd/subagent.go index 923aa2f8ea927..6aeaf0888ad95 100644 --- a/coderd/x/chatd/subagent.go +++ b/coderd/x/chatd/subagent.go @@ -165,7 +165,7 @@ func enabledProviderContainsName( for _, provider := range providers { matches, err := aiProviderMatchesCanonicalType(provider, normalizedProviderName) if err != nil { - logger.Warn(ctx, "parse AI provider settings", slog.F("provider_id", provider.ID), slog.Error(err)) + logger.Warn(ctx, "parse AI provider settings during provider name check", slog.F("provider_id", provider.ID), slog.Error(err)) continue } if matches { @@ -521,7 +521,7 @@ func (p *Server) resolveModelConfigAndNormalizedProvider( } canonicalType, err := db2sdk.CanonicalAIProviderType(provider) if err != nil { - p.logger.Warn(ctx, "parse AI provider settings", slog.F("provider_id", provider.ID), slog.Error(err)) + p.logger.Warn(ctx, "parse AI provider settings for model config", slog.F("provider_id", provider.ID), slog.Error(err)) canonicalType = codersdk.AIProviderType(provider.Type) } providerName := chatprovider.NormalizeProvider(string(canonicalType)) diff --git a/coderd/x/chatd/subagent_internal_test.go b/coderd/x/chatd/subagent_internal_test.go index d0b334581ed98..cbf222620caef 100644 --- a/coderd/x/chatd/subagent_internal_test.go +++ b/coderd/x/chatd/subagent_internal_test.go @@ -3665,7 +3665,6 @@ func TestEnabledProviderContainsName(t *testing.T) { want: false, }, { - // First provider fails canonicalization; second matches canonically. // Verifies the first-pass loop continues on error rather than aborting. name: "first provider fails canonicalization, second matches canonically", providers: []database.AIProvider{ From 40c0d0be2c436a503a11af69aaef09d68d43772b Mon Sep 17 00:00:00 2001 From: Cian Johnston Date: Mon, 8 Jun 2026 11:33:27 +0100 Subject: [PATCH 09/11] address R2 P3 --- coderd/x/chatd/model_routing_aibridge.go | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/coderd/x/chatd/model_routing_aibridge.go b/coderd/x/chatd/model_routing_aibridge.go index 6c39a21453582..553b8f7d81c47 100644 --- a/coderd/x/chatd/model_routing_aibridge.go +++ b/coderd/x/chatd/model_routing_aibridge.go @@ -170,11 +170,7 @@ func (p *Server) newAIGatewayModel( baseRT = &chatdebug.RecordingTransport{Base: baseRT} } - providerType, err := db2sdk.CanonicalAIProviderType(route.Provider) - if err != nil { - return nil, xerrors.Errorf("canonicalize provider type for %q: %w", route.Provider.Name, err) - } - config := fantasyConfigForAIBridge(providerType) + config := fantasyConfigForAIBridge(codersdk.AIProviderType(route.ModelProviderHint)) return newLanguageModel( config.ProviderHint, req.ModelName, From 06fdfc0fbc3bdf0f6d1327acac76a09eca079561 Mon Sep 17 00:00:00 2001 From: Cian Johnston Date: Mon, 8 Jun 2026 11:49:35 +0100 Subject: [PATCH 10/11] make canonicalized provider its own type --- coderd/x/chatd/chatd_internal_test.go | 2 +- coderd/x/chatd/model_routing.go | 5 +-- coderd/x/chatd/model_routing_aibridge.go | 36 +++++++++++-------- coderd/x/chatd/model_routing_internal_test.go | 12 +++++-- 4 files changed, 35 insertions(+), 20 deletions(-) diff --git a/coderd/x/chatd/chatd_internal_test.go b/coderd/x/chatd/chatd_internal_test.go index cd2eb530bb9cf..364668b62c047 100644 --- a/coderd/x/chatd/chatd_internal_test.go +++ b/coderd/x/chatd/chatd_internal_test.go @@ -338,7 +338,7 @@ func TestResolveAIGatewayModelRouteForProviderTypeMatchesCanonicalBedrockForAnth require.NoError(t, err) require.Equal(t, modelRouteKindAIGateway, route.kind) - require.Equal(t, providerID, route.aiGateway.Provider.ID) + require.Equal(t, providerID, route.aiGateway.Provider.Row.ID) providerHint, err := route.providerHint() require.NoError(t, err) require.Equal(t, "bedrock", providerHint) diff --git a/coderd/x/chatd/model_routing.go b/coderd/x/chatd/model_routing.go index c5fa7129db7d4..0a7dac8c6e8c0 100644 --- a/coderd/x/chatd/model_routing.go +++ b/coderd/x/chatd/model_routing.go @@ -10,6 +10,7 @@ import ( "github.com/coder/coder/v2/coderd/database" "github.com/coder/coder/v2/coderd/x/chatd/chatprovider" + "github.com/coder/coder/v2/codersdk" ) type modelClientRequest struct { @@ -57,7 +58,7 @@ func (r resolvedModelRoute) providerHint() (string, error) { case modelRouteKindDirect: return r.direct.ProviderHint, nil case modelRouteKindAIGateway: - return r.aiGateway.ModelProviderHint, nil + return string(r.aiGateway.Provider.CanonicalType), nil default: return "", xerrors.New("model route is not configured") } @@ -68,7 +69,7 @@ func (r resolvedModelRoute) withProviderHint(providerHint string) resolvedModelR case modelRouteKindDirect: r.direct.ProviderHint = providerHint case modelRouteKindAIGateway: - r.aiGateway.ModelProviderHint = providerHint + r.aiGateway.Provider.CanonicalType = codersdk.AIProviderType(providerHint) } return r } diff --git a/coderd/x/chatd/model_routing_aibridge.go b/coderd/x/chatd/model_routing_aibridge.go index 553b8f7d81c47..a6a98d9d11485 100644 --- a/coderd/x/chatd/model_routing_aibridge.go +++ b/coderd/x/chatd/model_routing_aibridge.go @@ -31,23 +31,28 @@ const ( aibridgeDelegatedBYOKMarker = "delegated" ) +// canonicalizedProvider pairs a database row with its pre-computed canonical +// provider type. Keeping them together ensures the type is derived once and +// used consistently; callers cannot re-canonicalize and risk a different result. +type canonicalizedProvider struct { + Row database.AIProvider + CanonicalType codersdk.AIProviderType +} + type aiGatewayModelRoute struct { - Provider database.AIProvider - ModelProviderHint string - ProviderAuth aiGatewayProviderAuth + Provider canonicalizedProvider + ProviderAuth aiGatewayProviderAuth } func newAIGatewayModelRoute( - provider database.AIProvider, - modelProviderHint string, + provider canonicalizedProvider, auth aiGatewayProviderAuth, ) resolvedModelRoute { return resolvedModelRoute{ kind: modelRouteKindAIGateway, aiGateway: aiGatewayModelRoute{ - Provider: provider, - ModelProviderHint: modelProviderHint, - ProviderAuth: auth, + Provider: provider, + ProviderAuth: auth, }, } } @@ -121,10 +126,10 @@ func (p *Server) newAIGatewayModel( route aiGatewayModelRoute, opts modelBuildOptions, ) (fantasy.LanguageModel, error) { - if route.Provider.ID == uuid.Nil { + if route.Provider.Row.ID == uuid.Nil { return nil, xerrors.New("AI Gateway routing requires a concrete AI provider") } - if route.Provider.Name == "" { + if route.Provider.Row.Name == "" { return nil, xerrors.New("AI Gateway routing requires an AI provider name") } if opts.ActiveAPIKeyID == "" { @@ -138,7 +143,7 @@ func (p *Server) newAIGatewayModel( ) } - if err := ValidateAIGatewayProviderModel(route.Provider, req.ModelName); err != nil { + if err := ValidateAIGatewayProviderModel(route.Provider.Row, req.ModelName); err != nil { return nil, chaterror.WithClassification( err, chaterror.ClassifiedError{ @@ -157,7 +162,7 @@ func (p *Server) newAIGatewayModel( if factory == nil || *factory == nil { return nil, xerrors.New("AI Gateway transport factory is not configured") } - rt, err := (*factory).TransportFor(route.Provider.Name, aibridge.SourceAgents) + rt, err := (*factory).TransportFor(route.Provider.Row.Name, aibridge.SourceAgents) if err != nil { return nil, xerrors.Errorf("create AI Gateway transport: %w", err) } @@ -170,7 +175,7 @@ func (p *Server) newAIGatewayModel( baseRT = &chatdebug.RecordingTransport{Base: baseRT} } - config := fantasyConfigForAIBridge(codersdk.AIProviderType(route.ModelProviderHint)) + config := fantasyConfigForAIBridge(route.Provider.CanonicalType) return newLanguageModel( config.ProviderHint, req.ModelName, @@ -260,7 +265,10 @@ func (p *Server) resolveAIGatewayRoute( if err != nil { return resolvedModelRoute{}, xerrors.Errorf("resolve AI Gateway provider auth: %w", err) } - return newAIGatewayModelRoute(provider, modelProviderHint, auth), nil + return newAIGatewayModelRoute(canonicalizedProvider{ + Row: provider, + CanonicalType: codersdk.AIProviderType(modelProviderHint), + }, auth), nil } func (p *Server) resolveAIGatewayModelRouteForConfig( diff --git a/coderd/x/chatd/model_routing_internal_test.go b/coderd/x/chatd/model_routing_internal_test.go index dce5e7597d3b0..539be6aafd2c3 100644 --- a/coderd/x/chatd/model_routing_internal_test.go +++ b/coderd/x/chatd/model_routing_internal_test.go @@ -65,7 +65,10 @@ func aibridgeTestAIProvider(providerID uuid.UUID, providerName string, providerT } func aibridgeTestRoute(aiProvider database.AIProvider) resolvedModelRoute { - return newAIGatewayModelRoute(aiProvider, string(aiProvider.Type), aiGatewayProviderAuth{}) + return newAIGatewayModelRoute(canonicalizedProvider{ + Row: aiProvider, + CanonicalType: codersdk.AIProviderType(aiProvider.Type), + }, aiGatewayProviderAuth{}) } func aibridgeTestRequest(chat database.Chat, model string) modelClientRequest { @@ -310,7 +313,10 @@ func TestAIGatewayModelForwardsProviderAuth(t *testing.T) { aiGatewayRoutingEnabled: true, aibridgeTransportFactory: aibridgeTestFactoryPointer(factory), } - route := newAIGatewayModelRoute(provider, string(provider.Type), auth) + route := newAIGatewayModelRoute(canonicalizedProvider{ + Row: provider, + CanonicalType: codersdk.AIProviderType(provider.Type), + }, auth) return server, route } @@ -668,7 +674,7 @@ func TestAIBridgeRoutingFailClosed(t *testing.T) { t.Run("StaticModel", func(t *testing.T) { t.Parallel() server := &Server{aiGatewayRoutingEnabled: true} - _, err := server.newModel(t.Context(), aibridgeTestRequest(chat, "gpt-4"), newAIGatewayModelRoute(database.AIProvider{}, "", aiGatewayProviderAuth{}), modelBuildOptions{ActiveAPIKeyID: uuid.NewString()}) + _, err := server.newModel(t.Context(), aibridgeTestRequest(chat, "gpt-4"), newAIGatewayModelRoute(canonicalizedProvider{}, aiGatewayProviderAuth{}), modelBuildOptions{ActiveAPIKeyID: uuid.NewString()}) require.ErrorContains(t, err, "concrete AI provider") }) } From ffedf25b79cc87ba223ee0bf8812c4a29f39b7a7 Mon Sep 17 00:00:00 2001 From: Cian Johnston Date: Mon, 8 Jun 2026 15:05:48 +0100 Subject: [PATCH 11/11] address CRF-10 --- coderd/x/chatd/model_routing_aibridge.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/coderd/x/chatd/model_routing_aibridge.go b/coderd/x/chatd/model_routing_aibridge.go index a6a98d9d11485..2baeff87d4cc8 100644 --- a/coderd/x/chatd/model_routing_aibridge.go +++ b/coderd/x/chatd/model_routing_aibridge.go @@ -32,8 +32,8 @@ const ( ) // canonicalizedProvider pairs a database row with its pre-computed canonical -// provider type. Keeping them together ensures the type is derived once and -// used consistently; callers cannot re-canonicalize and risk a different result. +// provider type, so the type is derived once at the routing boundary and +// carried through rather than re-derived from the row at each use site. type canonicalizedProvider struct { Row database.AIProvider CanonicalType codersdk.AIProviderType