diff --git a/coderd/ai_provider_canonical.go b/coderd/ai_provider_canonical.go deleted file mode 100644 index e0a43992c365d..0000000000000 --- a/coderd/ai_provider_canonical.go +++ /dev/null @@ -1,15 +0,0 @@ -package coderd - -import ( - "github.com/coder/coder/v2/coderd/database" - "github.com/coder/coder/v2/coderd/database/db2sdk" - "github.com/coder/coder/v2/codersdk" -) - -func canonicalDatabaseAIProviderType(providerType database.AIProviderType, settings codersdk.AIProviderSettings) database.AIProviderType { - return database.AIProviderType(codersdk.CanonicalAIProviderType(codersdk.AIProviderType(providerType), settings)) -} - -func canonicalAIProviderTypeForRow(provider database.AIProvider) (database.AIProviderType, error) { - return db2sdk.CanonicalAIProviderType(provider) -} diff --git a/coderd/ai_providers.go b/coderd/ai_providers.go index 5e1565f4e9155..f1efe24dbc465 100644 --- a/coderd/ai_providers.go +++ b/coderd/ai_providers.go @@ -311,7 +311,7 @@ func (api *API) aiProvidersUpdate(rw http.ResponseWriter, r *http.Request) { if req.Settings != nil { existing = mergeAIProviderSettings(existing, *req.Settings) } - targetType := canonicalDatabaseAIProviderType(old.Type, existing) + targetType := database.AIProviderType(codersdk.CanonicalAIProviderType(codersdk.AIProviderType(old.Type), existing)) targetBaseURL := ptr.NilToDefault(req.BaseURL, old.BaseUrl) // Bedrock settings are only meaningful for Bedrock providers; // rejecting the mismatch keeps a misconfiguration from sitting diff --git a/coderd/ai_providers_migrate.go b/coderd/ai_providers_migrate.go index 93a383e7b8ed5..bc4001b3dfc34 100644 --- a/coderd/ai_providers_migrate.go +++ b/coderd/ai_providers_migrate.go @@ -108,7 +108,8 @@ func SeedAIProvidersFromEnv( if err != nil { return xerrors.Errorf("decode existing settings for %q: %w", existing.Name, err) } - if canonicalDatabaseAIProviderType(existing.Type, existingSettings) == database.AiProviderTypeBedrock { + existingType := database.AIProviderType(codersdk.CanonicalAIProviderType(codersdk.AIProviderType(existing.Type), existingSettings)) + if existingType == database.AiProviderTypeBedrock { logger.Warn(sysCtx, "skipping legacy Anthropic env seed because an existing Anthropic-named row contains Bedrock settings", slog.F("name", dp.Name), ) @@ -122,7 +123,8 @@ func SeedAIProvidersFromEnv( if err != nil { return xerrors.Errorf("decode existing settings for %q: %w", candidate.Name, err) } - if canonicalDatabaseAIProviderType(candidate.Type, candidateSettings) == database.AiProviderTypeBedrock { + candidateType := database.AIProviderType(codersdk.CanonicalAIProviderType(codersdk.AIProviderType(candidate.Type), candidateSettings)) + if candidateType == database.AiProviderTypeBedrock { existing = candidate found = true } @@ -139,20 +141,20 @@ func SeedAIProvidersFromEnv( case found: existingSettings, err := db2sdk.AIProviderSettings(existing.Settings) if err != nil { - return xerrors.Errorf("decode existing settings for %q: %w", dp.Name, err) + return xerrors.Errorf("decode existing settings for %q: %w", existing.Name, err) } // Load existing bearer keys so the canonical hash // includes credentials for comparison. existingKeyRows, err := tx.GetAIProviderKeysByProviderID(sysCtx, existing.ID) if err != nil { - return xerrors.Errorf("load existing keys for %q: %w", dp.Name, err) + return xerrors.Errorf("load existing keys for %q: %w", existing.Name, err) } existingKeys := make([]string, 0, len(existingKeyRows)) for _, k := range existingKeyRows { existingKeys = append(existingKeys, k.APIKey) } existingDP := desiredAIProvider{ - Type: canonicalDatabaseAIProviderType(existing.Type, existingSettings), + Type: database.AIProviderType(codersdk.CanonicalAIProviderType(codersdk.AIProviderType(existing.Type), existingSettings)), BaseURL: existing.BaseUrl, Bedrock: existingSettings.Bedrock, Keys: existingKeys, diff --git a/coderd/database/db2sdk/db2sdk.go b/coderd/database/db2sdk/db2sdk.go index 0ac0de75be6d7..69259ab03ac92 100644 --- a/coderd/database/db2sdk/db2sdk.go +++ b/coderd/database/db2sdk/db2sdk.go @@ -44,12 +44,12 @@ func APIAllowListTarget(entry rbac.AllowListElement) codersdk.APIAllowListTarget } // CanonicalAIProviderType returns the runtime provider type for a database row. -func CanonicalAIProviderType(row database.AIProvider) (database.AIProviderType, error) { +func CanonicalAIProviderType(row database.AIProvider) (codersdk.AIProviderType, error) { settings, err := AIProviderSettings(row.Settings) if err != nil { return "", xerrors.Errorf("decode settings: %w", err) } - return database.AIProviderType(codersdk.CanonicalAIProviderType(codersdk.AIProviderType(row.Type), settings)), nil + return codersdk.CanonicalAIProviderType(codersdk.AIProviderType(row.Type), settings), nil } // AIProvider converts a database row plus its API keys into the diff --git a/coderd/database/db2sdk/db2sdk_test.go b/coderd/database/db2sdk/db2sdk_test.go index fa10566d38dea..e034fbe4a0770 100644 --- a/coderd/database/db2sdk/db2sdk_test.go +++ b/coderd/database/db2sdk/db2sdk_test.go @@ -95,7 +95,7 @@ func TestAIProviderCanonicalTypeAndDisplayName(t *testing.T) { gotType, err := db2sdk.CanonicalAIProviderType(tt.row) require.NoError(t, err) - require.Equal(t, database.AIProviderType(tt.wantType), gotType) + require.Equal(t, tt.wantType, gotType) require.Equal(t, tt.wantDisplay, db2sdk.AIProviderDisplayName(tt.row, tt.wantType)) got, err := db2sdk.AIProvider(tt.row, nil) diff --git a/coderd/exp_chats.go b/coderd/exp_chats.go index 4b87484fee6bc..44cabe47a801a 100644 --- a/coderd/exp_chats.go +++ b/coderd/exp_chats.go @@ -6532,16 +6532,15 @@ func parseUserAIProviderID(r *http.Request) (uuid.UUID, error) { } func convertAIProviderSummary(provider database.AIProvider) (codersdk.AIProviderSummary, error) { - providerType, err := canonicalAIProviderTypeForRow(provider) + providerType, err := db2sdk.CanonicalAIProviderType(provider) if err != nil { return codersdk.AIProviderSummary{}, err } - sdkProviderType := codersdk.AIProviderType(providerType) return codersdk.AIProviderSummary{ ID: provider.ID, - Type: sdkProviderType, + Type: providerType, Name: provider.Name, - DisplayName: db2sdk.AIProviderDisplayName(provider, sdkProviderType), + DisplayName: db2sdk.AIProviderDisplayName(provider, providerType), Enabled: provider.Enabled, Deleted: provider.Deleted, }, nil @@ -6741,7 +6740,7 @@ func (api *API) configuredProvidersFromAIProviders(ctx context.Context, provider } func (api *API) configuredProviderFromAIProviderKeys(ctx context.Context, provider database.AIProvider, keys []database.AIProviderKey) (chatprovider.ConfiguredProvider, error) { - providerType, err := canonicalAIProviderTypeForRow(provider) + providerType, err := db2sdk.CanonicalAIProviderType(provider) if err != nil { api.Logger.Error(ctx, "failed to decode AI provider settings", slog.F("provider_id", provider.ID), slog.Error(err)) return chatprovider.ConfiguredProvider{}, err @@ -6837,7 +6836,7 @@ func (api *API) listChatModelConfigs(rw http.ResponseWriter, r *http.Request) { return } for _, provider := range providers { - providerType, err := canonicalAIProviderTypeForRow(provider) + providerType, err := db2sdk.CanonicalAIProviderType(provider) if err != nil { api.Logger.Error(ctx, "failed to decode AI provider settings", slog.F("provider_id", provider.ID), slog.Error(err)) httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{ @@ -6911,7 +6910,7 @@ func (api *API) createChatModelConfig(rw http.ResponseWriter, r *http.Request) { httpapi.Write(ctx, rw, http.StatusPreconditionFailed, codersdk.Response{Message: "AI provider is disabled."}) return } - providerType, err := canonicalAIProviderTypeForRow(aiProvider) + providerType, err := db2sdk.CanonicalAIProviderType(aiProvider) if err != nil { httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{ Message: "Failed to decode AI provider settings.", @@ -7001,7 +7000,7 @@ func (api *API) createChatModelConfig(rw http.ResponseWriter, r *http.Request) { if !lockedAIProvider.Enabled { return errChatProviderNotConfigured } - lockedProviderType, err := canonicalAIProviderTypeForRow(lockedAIProvider) + lockedProviderType, err := db2sdk.CanonicalAIProviderType(lockedAIProvider) if err != nil { return xerrors.Errorf("canonicalize provider type for %q: %w", lockedAIProvider.Name, err) } @@ -7124,7 +7123,7 @@ func (api *API) updateChatModelConfig(rw http.ResponseWriter, r *http.Request) { return } } else { - providerType, err := canonicalAIProviderTypeForRow(aiProvider) + providerType, err := db2sdk.CanonicalAIProviderType(aiProvider) if err != nil { httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{ Message: "Failed to decode AI provider settings.", @@ -7165,7 +7164,7 @@ func (api *API) updateChatModelConfig(rw http.ResponseWriter, r *http.Request) { httpapi.Write(ctx, rw, http.StatusPreconditionFailed, codersdk.Response{Message: "AI provider is disabled."}) return } - providerType, err := canonicalAIProviderTypeForRow(aiProvider) + providerType, err := db2sdk.CanonicalAIProviderType(aiProvider) if err != nil { httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{ Message: "Failed to decode AI provider settings.", @@ -7262,7 +7261,7 @@ func (api *API) updateChatModelConfig(rw http.ResponseWriter, r *http.Request) { if !aiProvider.Enabled { return errChatProviderNotConfigured } - providerType, err := canonicalAIProviderTypeForRow(aiProvider) + providerType, err := db2sdk.CanonicalAIProviderType(aiProvider) if err != nil { return xerrors.Errorf("canonicalize provider type for %q: %w", aiProvider.Name, err) } diff --git a/coderd/x/chatd/ai_provider_canonical.go b/coderd/x/chatd/ai_provider_canonical.go index fdf8b991649b1..7ef9b38bf74ff 100644 --- a/coderd/x/chatd/ai_provider_canonical.go +++ b/coderd/x/chatd/ai_provider_canonical.go @@ -1,53 +1,26 @@ package chatd import ( - "context" - - "cdr.dev/slog/v3" "github.com/coder/coder/v2/coderd/database" "github.com/coder/coder/v2/coderd/database/db2sdk" "github.com/coder/coder/v2/coderd/x/chatd/chatprovider" + "github.com/coder/coder/v2/codersdk" ) -func canonicalAIProviderType(provider database.AIProvider) (database.AIProviderType, error) { - return db2sdk.CanonicalAIProviderType(provider) -} - -func canonicalAIProviderTypeString(provider database.AIProvider) (string, error) { - providerType, err := canonicalAIProviderType(provider) - if err != nil { - return "", err - } - return string(providerType), nil -} - -func bestEffortCanonicalAIProviderType(ctx context.Context, logger slog.Logger, provider database.AIProvider) database.AIProviderType { - providerType, err := canonicalAIProviderType(provider) - if err != nil { - logger.Warn(ctx, "parse AI provider settings", slog.F("provider_id", provider.ID), slog.Error(err)) - return provider.Type - } - return providerType -} - -func bestEffortCanonicalAIProviderTypeString(ctx context.Context, logger slog.Logger, provider database.AIProvider) string { - return string(bestEffortCanonicalAIProviderType(ctx, logger, provider)) -} - func aiProviderTypeCanSatisfyRequest(candidateProviderType string, requestedProviderType string) bool { if candidateProviderType == requestedProviderType { return true } - return requestedProviderType == string(database.AiProviderTypeAnthropic) && - candidateProviderType == string(database.AiProviderTypeBedrock) + return requestedProviderType == string(codersdk.AIProviderTypeAnthropic) && + candidateProviderType == string(codersdk.AIProviderTypeBedrock) } func aiProviderMatchesCanonicalType(provider database.AIProvider, normalizedProviderType string) (bool, error) { - providerType, err := canonicalAIProviderTypeString(provider) + providerType, err := db2sdk.CanonicalAIProviderType(provider) if err != nil { return false, err } - return aiProviderTypeCanSatisfyRequest(chatprovider.NormalizeProvider(providerType), normalizedProviderType), nil + return aiProviderTypeCanSatisfyRequest(chatprovider.NormalizeProvider(string(providerType)), normalizedProviderType), nil } func aiProviderMatchesRawType(provider database.AIProvider, normalizedProviderType string) bool { diff --git a/coderd/x/chatd/chatd.go b/coderd/x/chatd/chatd.go index fd78e0338893e..df5b865732993 100644 --- a/coderd/x/chatd/chatd.go +++ b/coderd/x/chatd/chatd.go @@ -8658,9 +8658,14 @@ func (p *Server) aiProviderConfigFromKeys(ctx context.Context, provider database break } } + canonicalType, err := db2sdk.CanonicalAIProviderType(provider) + if err != nil { + p.logger.Warn(ctx, "parse AI provider settings for provider config", slog.F("provider_id", provider.ID), slog.Error(err)) + canonicalType = codersdk.AIProviderType(provider.Type) + } return chatprovider.ConfiguredProvider{ ProviderID: provider.ID, - Provider: bestEffortCanonicalAIProviderTypeString(ctx, p.logger, provider), + Provider: string(canonicalType), APIKey: apiKey, BaseURL: provider.BaseUrl, CentralAPIKeyEnabled: true, @@ -8774,12 +8779,12 @@ func (p *Server) resolveUserProviderAPIKeysAndProviderForProviderType( return chatprovider.ProviderAPIKeys{}, nil, nil } for _, provider := range providers { - canonicalProviderType, err := canonicalAIProviderTypeString(provider) + canonicalProviderType, err := db2sdk.CanonicalAIProviderType(provider) if err != nil { - p.logger.Warn(ctx, "parse AI provider settings", slog.F("provider_id", provider.ID), slog.Error(err)) + p.logger.Warn(ctx, "parse AI provider settings during key resolution", slog.F("provider_id", provider.ID), slog.Error(err)) continue } - providerKeysType := chatprovider.NormalizeProvider(canonicalProviderType) + providerKeysType := chatprovider.NormalizeProvider(string(canonicalProviderType)) if !aiProviderTypeCanSatisfyRequest(providerKeysType, normalizedProviderType) { continue } @@ -8792,7 +8797,7 @@ func (p *Server) resolveUserProviderAPIKeysAndProviderForProviderType( if !aiProviderMatchesRawType(provider, normalizedProviderType) { continue } - providerKeysType := chatprovider.NormalizeProvider(bestEffortCanonicalAIProviderTypeString(ctx, p.logger, provider)) + providerKeysType := chatprovider.NormalizeProvider(string(provider.Type)) keys, matchedProvider, err := keysForProvider(provider, providerKeysType) if err != nil || matchedProvider != nil { return keys, matchedProvider, err diff --git a/coderd/x/chatd/chatd_internal_test.go b/coderd/x/chatd/chatd_internal_test.go index 0286cf3394a9c..364668b62c047 100644 --- a/coderd/x/chatd/chatd_internal_test.go +++ b/coderd/x/chatd/chatd_internal_test.go @@ -307,6 +307,90 @@ func TestResolveDirectModelRouteForProviderTypeFallsBackToRawProviderType(t *tes require.Equal(t, "test-key", route.directProviderKeys().APIKey("openai")) } +func TestResolveAIGatewayModelRouteForProviderTypeMatchesCanonicalBedrockForAnthropicRequest(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitShort) + ctrl := gomock.NewController(t) + db := dbmock.NewMockStore(ctrl) + ownerID := uuid.New() + providerID := uuid.New() + + rawSettings, err := json.Marshal(codersdk.AIProviderSettings{ + Bedrock: &codersdk.AIProviderBedrockSettings{Region: "us-east-1"}, + }) + require.NoError(t, err) + provider := database.AIProvider{ + ID: providerID, + Type: database.AiProviderTypeAnthropic, + Enabled: true, + Settings: sql.NullString{String: string(rawSettings), Valid: true}, + } + + db.EXPECT().GetAIProviders(gomock.Any(), database.GetAIProvidersParams{}).Return([]database.AIProvider{provider}, nil) + + server := &Server{db: db, logger: slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})} + route, err := server.resolveAIGatewayModelRouteForProviderType( + ctx, + ownerID, + chattool.ComputerUseProviderAnthropic, + ) + require.NoError(t, err) + + require.Equal(t, modelRouteKindAIGateway, route.kind) + require.Equal(t, providerID, route.aiGateway.Provider.Row.ID) + providerHint, err := route.providerHint() + require.NoError(t, err) + require.Equal(t, "bedrock", providerHint) +} + +func TestDirectProviderHintAndProviderForConfigErrorsOnMalformedSettings(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitShort) + ctrl := gomock.NewController(t) + db := dbmock.NewMockStore(ctrl) + providerID := uuid.New() + + provider := database.AIProvider{ + ID: providerID, + Type: database.AiProviderTypeAnthropic, + Enabled: true, + Settings: sql.NullString{String: "{", Valid: true}, + } + db.EXPECT().GetAIProviderByID(gomock.Any(), providerID).Return(provider, nil) + + server := &Server{db: db, logger: slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})} + _, _, err := server.directProviderHintAndProviderForConfig(ctx, database.ChatModelConfig{ + AIProviderID: uuid.NullUUID{UUID: providerID, Valid: true}, + }) + require.ErrorContains(t, err, "canonicalize") +} + +func TestAIProviderConfigFromKeysFallsBackToRawTypeOnMalformedSettings(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitShort) + + // Corrupt valid Bedrock JSON to verify the fallback returns the raw type rather than "bedrock". + bedrockSettings, err := json.Marshal(codersdk.AIProviderSettings{ + Bedrock: &codersdk.AIProviderBedrockSettings{Region: "us-east-1"}, + }) + require.NoError(t, err) + + provider := database.AIProvider{ + ID: uuid.New(), + Type: database.AiProviderTypeAnthropic, + Enabled: true, + Settings: sql.NullString{String: string(bedrockSettings[:len(bedrockSettings)-1]), Valid: true}, + } + + server := &Server{logger: slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})} + cfg, err := server.aiProviderConfigFromKeys(ctx, provider, nil) + require.NoError(t, err) + require.Equal(t, string(database.AiProviderTypeAnthropic), cfg.Provider) +} + func TestResolveModelRouteForProviderTypeAIGatewayRequiresProvider(t *testing.T) { t.Parallel() diff --git a/coderd/x/chatd/model_routing.go b/coderd/x/chatd/model_routing.go index c5fa7129db7d4..0a7dac8c6e8c0 100644 --- a/coderd/x/chatd/model_routing.go +++ b/coderd/x/chatd/model_routing.go @@ -10,6 +10,7 @@ import ( "github.com/coder/coder/v2/coderd/database" "github.com/coder/coder/v2/coderd/x/chatd/chatprovider" + "github.com/coder/coder/v2/codersdk" ) type modelClientRequest struct { @@ -57,7 +58,7 @@ func (r resolvedModelRoute) providerHint() (string, error) { case modelRouteKindDirect: return r.direct.ProviderHint, nil case modelRouteKindAIGateway: - return r.aiGateway.ModelProviderHint, nil + return string(r.aiGateway.Provider.CanonicalType), nil default: return "", xerrors.New("model route is not configured") } @@ -68,7 +69,7 @@ func (r resolvedModelRoute) withProviderHint(providerHint string) resolvedModelR case modelRouteKindDirect: r.direct.ProviderHint = providerHint case modelRouteKindAIGateway: - r.aiGateway.ModelProviderHint = providerHint + r.aiGateway.Provider.CanonicalType = codersdk.AIProviderType(providerHint) } return r } diff --git a/coderd/x/chatd/model_routing_aibridge.go b/coderd/x/chatd/model_routing_aibridge.go index 077f072083340..2baeff87d4cc8 100644 --- a/coderd/x/chatd/model_routing_aibridge.go +++ b/coderd/x/chatd/model_routing_aibridge.go @@ -16,6 +16,7 @@ import ( "cdr.dev/slog/v3" "github.com/coder/coder/v2/coderd/aibridge" "github.com/coder/coder/v2/coderd/database" + "github.com/coder/coder/v2/coderd/database/db2sdk" "github.com/coder/coder/v2/coderd/x/chatd/chatdebug" "github.com/coder/coder/v2/coderd/x/chatd/chaterror" "github.com/coder/coder/v2/coderd/x/chatd/chatprovider" @@ -30,23 +31,28 @@ const ( aibridgeDelegatedBYOKMarker = "delegated" ) +// canonicalizedProvider pairs a database row with its pre-computed canonical +// provider type, so the type is derived once at the routing boundary and +// carried through rather than re-derived from the row at each use site. +type canonicalizedProvider struct { + Row database.AIProvider + CanonicalType codersdk.AIProviderType +} + type aiGatewayModelRoute struct { - Provider database.AIProvider - ModelProviderHint string - ProviderAuth aiGatewayProviderAuth + Provider canonicalizedProvider + ProviderAuth aiGatewayProviderAuth } func newAIGatewayModelRoute( - provider database.AIProvider, - modelProviderHint string, + provider canonicalizedProvider, auth aiGatewayProviderAuth, ) resolvedModelRoute { return resolvedModelRoute{ kind: modelRouteKindAIGateway, aiGateway: aiGatewayModelRoute{ - Provider: provider, - ModelProviderHint: modelProviderHint, - ProviderAuth: auth, + Provider: provider, + ProviderAuth: auth, }, } } @@ -120,10 +126,10 @@ func (p *Server) newAIGatewayModel( route aiGatewayModelRoute, opts modelBuildOptions, ) (fantasy.LanguageModel, error) { - if route.Provider.ID == uuid.Nil { + if route.Provider.Row.ID == uuid.Nil { return nil, xerrors.New("AI Gateway routing requires a concrete AI provider") } - if route.Provider.Name == "" { + if route.Provider.Row.Name == "" { return nil, xerrors.New("AI Gateway routing requires an AI provider name") } if opts.ActiveAPIKeyID == "" { @@ -137,7 +143,7 @@ func (p *Server) newAIGatewayModel( ) } - if err := ValidateAIGatewayProviderModel(route.Provider, req.ModelName); err != nil { + if err := ValidateAIGatewayProviderModel(route.Provider.Row, req.ModelName); err != nil { return nil, chaterror.WithClassification( err, chaterror.ClassifiedError{ @@ -156,7 +162,7 @@ func (p *Server) newAIGatewayModel( if factory == nil || *factory == nil { return nil, xerrors.New("AI Gateway transport factory is not configured") } - rt, err := (*factory).TransportFor(route.Provider.Name, aibridge.SourceAgents) + rt, err := (*factory).TransportFor(route.Provider.Row.Name, aibridge.SourceAgents) if err != nil { return nil, xerrors.Errorf("create AI Gateway transport: %w", err) } @@ -169,11 +175,7 @@ func (p *Server) newAIGatewayModel( baseRT = &chatdebug.RecordingTransport{Base: baseRT} } - providerType, err := canonicalAIProviderType(route.Provider) - if err != nil { - return nil, xerrors.Errorf("canonicalize provider type for %q: %w", route.Provider.Name, err) - } - config := fantasyConfigForAIBridge(providerType) + config := fantasyConfigForAIBridge(route.Provider.CanonicalType) return newLanguageModel( config.ProviderHint, req.ModelName, @@ -189,14 +191,14 @@ type aibridgeFantasyConfig struct { Keys chatprovider.ProviderAPIKeys } -func fantasyConfigForAIBridge(providerType database.AIProviderType) aibridgeFantasyConfig { +func fantasyConfigForAIBridge(providerType codersdk.AIProviderType) aibridgeFantasyConfig { var fantasyProvider string baseURL := aibridgeLocalBaseURL + "/v1" switch providerType { - case database.AiProviderTypeAnthropic, database.AiProviderTypeBedrock: + case codersdk.AIProviderTypeAnthropic, codersdk.AIProviderTypeBedrock: fantasyProvider = fantasyanthropic.Name baseURL = aibridgeLocalBaseURL - case database.AiProviderTypeOpenai: + case codersdk.AIProviderTypeOpenAI: fantasyProvider = fantasyopenai.Name default: fantasyProvider = fantasyopenaicompat.Name @@ -214,15 +216,6 @@ func fantasyConfigForAIBridge(providerType database.AIProviderType) aibridgeFant } } -func aiGatewayRequestFormatForProviderType(providerType database.AIProviderType) aiGatewayRequestFormat { - switch providerType { - case database.AiProviderTypeAnthropic, database.AiProviderTypeBedrock: - return aiGatewayRequestFormatAnthropic - default: - return aiGatewayRequestFormatOpenAI - } -} - func (p *Server) aiGatewayProviderAuthForUser( ctx context.Context, ownerID uuid.UUID, @@ -263,20 +256,19 @@ func (p *Server) resolveAIGatewayRoute( provider database.AIProvider, modelProviderHint string, ) (resolvedModelRoute, error) { - providerType, err := canonicalAIProviderType(provider) - if err != nil { - return resolvedModelRoute{}, xerrors.Errorf("canonicalize provider type for %q: %w", provider.Name, err) + format := aiGatewayRequestFormatOpenAI + switch codersdk.AIProviderType(modelProviderHint) { + case codersdk.AIProviderTypeAnthropic, codersdk.AIProviderTypeBedrock: + format = aiGatewayRequestFormatAnthropic } - auth, err := p.aiGatewayProviderAuthForUser( - ctx, - ownerID, - provider, - aiGatewayRequestFormatForProviderType(providerType), - ) + auth, err := p.aiGatewayProviderAuthForUser(ctx, ownerID, provider, format) if err != nil { return resolvedModelRoute{}, xerrors.Errorf("resolve AI Gateway provider auth: %w", err) } - return newAIGatewayModelRoute(provider, modelProviderHint, auth), nil + return newAIGatewayModelRoute(canonicalizedProvider{ + Row: provider, + CanonicalType: codersdk.AIProviderType(modelProviderHint), + }, auth), nil } func (p *Server) resolveAIGatewayModelRouteForConfig( @@ -288,7 +280,7 @@ func (p *Server) resolveAIGatewayModelRouteForConfig( if err != nil { return resolvedModelRoute{}, err } - providerType, err := canonicalAIProviderType(provider) + providerType, err := db2sdk.CanonicalAIProviderType(provider) if err != nil { return resolvedModelRoute{}, xerrors.Errorf("canonicalize provider type for %q: %w", provider.Name, err) } @@ -304,12 +296,12 @@ func (p *Server) resolveAIGatewayModelRouteForProviderType( if err != nil { return resolvedModelRoute{}, err } - return p.resolveAIGatewayRoute( - ctx, - ownerID, - provider, - bestEffortCanonicalAIProviderTypeString(ctx, p.logger, provider), - ) + canonicalType, err := db2sdk.CanonicalAIProviderType(provider) + if err != nil { + p.logger.Warn(ctx, "parse AI provider settings for AI gateway route", slog.F("provider_id", provider.ID), slog.Error(err)) + canonicalType = codersdk.AIProviderType(provider.Type) + } + return p.resolveAIGatewayRoute(ctx, ownerID, provider, string(canonicalType)) } func (p *Server) gatewayProviderForConfig( @@ -341,7 +333,7 @@ func (p *Server) aiProviderForProviderType( } matches, err := aiProviderMatchesCanonicalType(provider, normalizedProviderType) if err != nil { - p.logger.Warn(ctx, "parse AI provider settings", slog.F("provider_id", provider.ID), slog.Error(err)) + p.logger.Warn(ctx, "parse AI provider settings during provider type lookup", slog.F("provider_id", provider.ID), slog.Error(err)) continue } if !matches { diff --git a/coderd/x/chatd/model_routing_direct.go b/coderd/x/chatd/model_routing_direct.go index 52792d3ef5bc8..5e020914fcb65 100644 --- a/coderd/x/chatd/model_routing_direct.go +++ b/coderd/x/chatd/model_routing_direct.go @@ -8,9 +8,12 @@ import ( "github.com/google/uuid" "golang.org/x/xerrors" + "cdr.dev/slog/v3" "github.com/coder/coder/v2/coderd/database" + "github.com/coder/coder/v2/coderd/database/db2sdk" "github.com/coder/coder/v2/coderd/x/chatd/chatdebug" "github.com/coder/coder/v2/coderd/x/chatd/chatprovider" + "github.com/coder/coder/v2/codersdk" ) type directModelRoute struct { @@ -77,7 +80,12 @@ func (p *Server) resolveDirectModelRouteForProviderType( } providerHint := normalizedProviderType if provider != nil { - providerHint = chatprovider.NormalizeProvider(bestEffortCanonicalAIProviderTypeString(ctx, p.logger, *provider)) + canonicalType, err := db2sdk.CanonicalAIProviderType(*provider) + if err != nil { + p.logger.Warn(ctx, "parse AI provider settings for direct route", slog.F("provider_id", provider.ID), slog.Error(err)) + canonicalType = codersdk.AIProviderType(provider.Type) + } + providerHint = chatprovider.NormalizeProvider(string(canonicalType)) } return newDirectModelRoute(providerHint, keys), nil } @@ -93,5 +101,9 @@ func (p *Server) directProviderHintAndProviderForConfig( if err != nil { return "", nil, err } - return bestEffortCanonicalAIProviderTypeString(ctx, p.logger, provider), &provider, nil + canonicalType, err := db2sdk.CanonicalAIProviderType(provider) + if err != nil { + return "", nil, xerrors.Errorf("canonicalize provider type for %q: %w", provider.Name, err) + } + return string(canonicalType), &provider, nil } diff --git a/coderd/x/chatd/model_routing_internal_test.go b/coderd/x/chatd/model_routing_internal_test.go index 76ede361deb5a..539be6aafd2c3 100644 --- a/coderd/x/chatd/model_routing_internal_test.go +++ b/coderd/x/chatd/model_routing_internal_test.go @@ -65,7 +65,10 @@ func aibridgeTestAIProvider(providerID uuid.UUID, providerName string, providerT } func aibridgeTestRoute(aiProvider database.AIProvider) resolvedModelRoute { - return newAIGatewayModelRoute(aiProvider, string(aiProvider.Type), aiGatewayProviderAuth{}) + return newAIGatewayModelRoute(canonicalizedProvider{ + Row: aiProvider, + CanonicalType: codersdk.AIProviderType(aiProvider.Type), + }, aiGatewayProviderAuth{}) } func aibridgeTestRequest(chat database.Chat, model string) modelClientRequest { @@ -81,14 +84,14 @@ func TestAIBridgeProviderFormatMapping(t *testing.T) { tests := []struct { name string - providerType database.AIProviderType + providerType codersdk.AIProviderType wantProvider string wantBaseURL string }{ - {name: "OpenAI", providerType: database.AiProviderTypeOpenai, wantProvider: "openai", wantBaseURL: "http://coder-aibridge/v1"}, - {name: "Anthropic", providerType: database.AiProviderTypeAnthropic, wantProvider: "anthropic", wantBaseURL: "http://coder-aibridge"}, - {name: "Bedrock", providerType: database.AiProviderTypeBedrock, wantProvider: "anthropic", wantBaseURL: "http://coder-aibridge"}, - {name: "Google", providerType: database.AiProviderTypeGoogle, wantProvider: "openai-compat", wantBaseURL: "http://coder-aibridge/v1"}, + {name: "OpenAI", providerType: codersdk.AIProviderTypeOpenAI, wantProvider: "openai", wantBaseURL: "http://coder-aibridge/v1"}, + {name: "Anthropic", providerType: codersdk.AIProviderTypeAnthropic, wantProvider: "anthropic", wantBaseURL: "http://coder-aibridge"}, + {name: "Bedrock", providerType: codersdk.AIProviderTypeBedrock, wantProvider: "anthropic", wantBaseURL: "http://coder-aibridge"}, + {name: "Google", providerType: codersdk.AIProviderTypeGoogle, wantProvider: "openai-compat", wantBaseURL: "http://coder-aibridge/v1"}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { @@ -310,7 +313,10 @@ func TestAIGatewayModelForwardsProviderAuth(t *testing.T) { aiGatewayRoutingEnabled: true, aibridgeTransportFactory: aibridgeTestFactoryPointer(factory), } - route := newAIGatewayModelRoute(provider, string(provider.Type), auth) + route := newAIGatewayModelRoute(canonicalizedProvider{ + Row: provider, + CanonicalType: codersdk.AIProviderType(provider.Type), + }, auth) return server, route } @@ -668,7 +674,7 @@ func TestAIBridgeRoutingFailClosed(t *testing.T) { t.Run("StaticModel", func(t *testing.T) { t.Parallel() server := &Server{aiGatewayRoutingEnabled: true} - _, err := server.newModel(t.Context(), aibridgeTestRequest(chat, "gpt-4"), newAIGatewayModelRoute(database.AIProvider{}, "", aiGatewayProviderAuth{}), modelBuildOptions{ActiveAPIKeyID: uuid.NewString()}) + _, err := server.newModel(t.Context(), aibridgeTestRequest(chat, "gpt-4"), newAIGatewayModelRoute(canonicalizedProvider{}, aiGatewayProviderAuth{}), modelBuildOptions{ActiveAPIKeyID: uuid.NewString()}) require.ErrorContains(t, err, "concrete AI provider") }) } diff --git a/coderd/x/chatd/subagent.go b/coderd/x/chatd/subagent.go index 5c9c54ab9ccc4..6aeaf0888ad95 100644 --- a/coderd/x/chatd/subagent.go +++ b/coderd/x/chatd/subagent.go @@ -19,6 +19,7 @@ import ( "cdr.dev/slog/v3" "github.com/coder/coder/v2/coderd/aibridge" "github.com/coder/coder/v2/coderd/database" + "github.com/coder/coder/v2/coderd/database/db2sdk" "github.com/coder/coder/v2/coderd/database/dbauthz" coderdpubsub "github.com/coder/coder/v2/coderd/pubsub" "github.com/coder/coder/v2/coderd/x/chatd/chatprompt" @@ -164,7 +165,7 @@ func enabledProviderContainsName( for _, provider := range providers { matches, err := aiProviderMatchesCanonicalType(provider, normalizedProviderName) if err != nil { - logger.Warn(ctx, "parse AI provider settings", slog.F("provider_id", provider.ID), slog.Error(err)) + logger.Warn(ctx, "parse AI provider settings during provider name check", slog.F("provider_id", provider.ID), slog.Error(err)) continue } if matches { @@ -518,7 +519,12 @@ func (p *Server) resolveModelConfigAndNormalizedProvider( if !provider.Enabled { return database.ChatModelConfig{}, "", sql.ErrNoRows } - providerName := chatprovider.NormalizeProvider(bestEffortCanonicalAIProviderTypeString(ctx, p.logger, provider)) + canonicalType, err := db2sdk.CanonicalAIProviderType(provider) + if err != nil { + p.logger.Warn(ctx, "parse AI provider settings for model config", slog.F("provider_id", provider.ID), slog.Error(err)) + canonicalType = codersdk.AIProviderType(provider.Type) + } + providerName := chatprovider.NormalizeProvider(string(canonicalType)) if providerName == "" { return database.ChatModelConfig{}, "", errInvalidModelOverrideMetadata } diff --git a/coderd/x/chatd/subagent_internal_test.go b/coderd/x/chatd/subagent_internal_test.go index ce860f124929c..cbf222620caef 100644 --- a/coderd/x/chatd/subagent_internal_test.go +++ b/coderd/x/chatd/subagent_internal_test.go @@ -3600,3 +3600,109 @@ func TestAwaitSubagentCompletion(t *testing.T) { assert.Equal(t, "zero timeout ok", report) }) } + +func TestEnabledProviderContainsName(t *testing.T) { + t.Parallel() + + bedrockSettings, err := json.Marshal(codersdk.AIProviderSettings{ + Bedrock: &codersdk.AIProviderBedrockSettings{Region: "us-east-1"}, + }) + require.NoError(t, err) + + cases := []struct { + name string + providers []database.AIProvider + query string + want bool + }{ + { + name: "canonical match", + providers: []database.AIProvider{ + {Type: database.AiProviderTypeAnthropic}, + }, + query: "anthropic", + want: true, + }, + { + name: "legacy bedrock canonical match for anthropic request", + providers: []database.AIProvider{ + { + Type: database.AiProviderTypeAnthropic, + Settings: sql.NullString{String: string(bedrockSettings), Valid: true}, + }, + }, + query: "anthropic", + want: true, + }, + { + // First pass errors (malformed JSON), second pass matches via raw type. + name: "malformed settings falls back to raw type", + providers: []database.AIProvider{ + { + Type: database.AiProviderTypeOpenai, + Settings: sql.NullString{String: "{", Valid: true}, + }, + }, + query: "openai", + want: true, + }, + { + // First pass errors, second pass matches bedrock satisfying anthropic request. + name: "malformed bedrock satisfies anthropic via raw type", + providers: []database.AIProvider{ + { + Type: database.AiProviderTypeBedrock, + Settings: sql.NullString{String: "{", Valid: true}, + }, + }, + query: "anthropic", + want: true, + }, + { + name: "empty providers", + providers: []database.AIProvider{}, + query: "anthropic", + want: false, + }, + { + // Verifies the first-pass loop continues on error rather than aborting. + name: "first provider fails canonicalization, second matches canonically", + providers: []database.AIProvider{ + { + Type: database.AiProviderTypeAnthropic, + Settings: sql.NullString{String: "{", Valid: true}, + }, + {Type: database.AiProviderTypeAnthropic}, + }, + query: "anthropic", + want: true, + }, + { + // enabledProviderContainsName has no Enabled guard; callers pre-filter. + name: "disabled provider still matches", + providers: []database.AIProvider{ + {Type: database.AiProviderTypeAnthropic, Enabled: false}, + }, + query: "anthropic", + want: true, + }, + { + name: "no match", + providers: []database.AIProvider{ + {Type: database.AiProviderTypeOpenai}, + }, + query: "anthropic", + want: false, + }, + } + + for _, tt := range cases { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + ctx := testutil.Context(t, testutil.WaitShort) + logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}) + got := enabledProviderContainsName(ctx, logger, tt.providers, tt.query) + require.Equal(t, tt.want, got) + }) + } +}