Skip to content
15 changes: 0 additions & 15 deletions coderd/ai_provider_canonical.go

This file was deleted.

2 changes: 1 addition & 1 deletion coderd/ai_providers.go
Original file line number Diff line number Diff line change
Expand Up @@ -311,7 +311,7 @@ func (api *API) aiProvidersUpdate(rw http.ResponseWriter, r *http.Request) {
if req.Settings != nil {
existing = mergeAIProviderSettings(existing, *req.Settings)
}
targetType := canonicalDatabaseAIProviderType(old.Type, existing)
targetType := database.AIProviderType(codersdk.CanonicalAIProviderType(codersdk.AIProviderType(old.Type), existing))
targetBaseURL := ptr.NilToDefault(req.BaseURL, old.BaseUrl)
// Bedrock settings are only meaningful for Bedrock providers;
// rejecting the mismatch keeps a misconfiguration from sitting
Expand Down
12 changes: 7 additions & 5 deletions coderd/ai_providers_migrate.go
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,8 @@ func SeedAIProvidersFromEnv(
if err != nil {
return xerrors.Errorf("decode existing settings for %q: %w", existing.Name, err)
}
if canonicalDatabaseAIProviderType(existing.Type, existingSettings) == database.AiProviderTypeBedrock {
existingType := database.AIProviderType(codersdk.CanonicalAIProviderType(codersdk.AIProviderType(existing.Type), existingSettings))
if existingType == database.AiProviderTypeBedrock {
logger.Warn(sysCtx, "skipping legacy Anthropic env seed because an existing Anthropic-named row contains Bedrock settings",
slog.F("name", dp.Name),
)
Expand All @@ -122,7 +123,8 @@ func SeedAIProvidersFromEnv(
if err != nil {
return xerrors.Errorf("decode existing settings for %q: %w", candidate.Name, err)
}
if canonicalDatabaseAIProviderType(candidate.Type, candidateSettings) == database.AiProviderTypeBedrock {
candidateType := database.AIProviderType(codersdk.CanonicalAIProviderType(codersdk.AIProviderType(candidate.Type), candidateSettings))
if candidateType == database.AiProviderTypeBedrock {
existing = candidate
found = true
}
Expand All @@ -139,20 +141,20 @@ func SeedAIProvidersFromEnv(
case found:
existingSettings, err := db2sdk.AIProviderSettings(existing.Settings)
if err != nil {
return xerrors.Errorf("decode existing settings for %q: %w", dp.Name, err)
return xerrors.Errorf("decode existing settings for %q: %w", existing.Name, err)
}
// Load existing bearer keys so the canonical hash
// includes credentials for comparison.
existingKeyRows, err := tx.GetAIProviderKeysByProviderID(sysCtx, existing.ID)
if err != nil {
return xerrors.Errorf("load existing keys for %q: %w", dp.Name, err)
return xerrors.Errorf("load existing keys for %q: %w", existing.Name, err)
}
existingKeys := make([]string, 0, len(existingKeyRows))
for _, k := range existingKeyRows {
existingKeys = append(existingKeys, k.APIKey)
}
existingDP := desiredAIProvider{
Type: canonicalDatabaseAIProviderType(existing.Type, existingSettings),
Type: database.AIProviderType(codersdk.CanonicalAIProviderType(codersdk.AIProviderType(existing.Type), existingSettings)),
BaseURL: existing.BaseUrl,
Bedrock: existingSettings.Bedrock,
Keys: existingKeys,
Expand Down
4 changes: 2 additions & 2 deletions coderd/database/db2sdk/db2sdk.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,12 +44,12 @@ func APIAllowListTarget(entry rbac.AllowListElement) codersdk.APIAllowListTarget
}

// CanonicalAIProviderType returns the runtime provider type for a database row.
func CanonicalAIProviderType(row database.AIProvider) (database.AIProviderType, error) {
func CanonicalAIProviderType(row database.AIProvider) (codersdk.AIProviderType, error) {
settings, err := AIProviderSettings(row.Settings)
if err != nil {
return "", xerrors.Errorf("decode settings: %w", err)
}
return database.AIProviderType(codersdk.CanonicalAIProviderType(codersdk.AIProviderType(row.Type), settings)), nil
return codersdk.CanonicalAIProviderType(codersdk.AIProviderType(row.Type), settings), nil
}

// AIProvider converts a database row plus its API keys into the
Expand Down
2 changes: 1 addition & 1 deletion coderd/database/db2sdk/db2sdk_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ func TestAIProviderCanonicalTypeAndDisplayName(t *testing.T) {

gotType, err := db2sdk.CanonicalAIProviderType(tt.row)
require.NoError(t, err)
require.Equal(t, database.AIProviderType(tt.wantType), gotType)
require.Equal(t, tt.wantType, gotType)
require.Equal(t, tt.wantDisplay, db2sdk.AIProviderDisplayName(tt.row, tt.wantType))

got, err := db2sdk.AIProvider(tt.row, nil)
Expand Down
21 changes: 10 additions & 11 deletions coderd/exp_chats.go
Original file line number Diff line number Diff line change
Expand Up @@ -6532,16 +6532,15 @@ func parseUserAIProviderID(r *http.Request) (uuid.UUID, error) {
}

func convertAIProviderSummary(provider database.AIProvider) (codersdk.AIProviderSummary, error) {
providerType, err := canonicalAIProviderTypeForRow(provider)
providerType, err := db2sdk.CanonicalAIProviderType(provider)
if err != nil {
return codersdk.AIProviderSummary{}, err
}
sdkProviderType := codersdk.AIProviderType(providerType)
return codersdk.AIProviderSummary{
ID: provider.ID,
Type: sdkProviderType,
Type: providerType,
Name: provider.Name,
DisplayName: db2sdk.AIProviderDisplayName(provider, sdkProviderType),
DisplayName: db2sdk.AIProviderDisplayName(provider, providerType),
Enabled: provider.Enabled,
Deleted: provider.Deleted,
}, nil
Expand Down Expand Up @@ -6741,7 +6740,7 @@ func (api *API) configuredProvidersFromAIProviders(ctx context.Context, provider
}

func (api *API) configuredProviderFromAIProviderKeys(ctx context.Context, provider database.AIProvider, keys []database.AIProviderKey) (chatprovider.ConfiguredProvider, error) {
providerType, err := canonicalAIProviderTypeForRow(provider)
providerType, err := db2sdk.CanonicalAIProviderType(provider)
if err != nil {
api.Logger.Error(ctx, "failed to decode AI provider settings", slog.F("provider_id", provider.ID), slog.Error(err))
return chatprovider.ConfiguredProvider{}, err
Expand Down Expand Up @@ -6837,7 +6836,7 @@ func (api *API) listChatModelConfigs(rw http.ResponseWriter, r *http.Request) {
return
}
for _, provider := range providers {
providerType, err := canonicalAIProviderTypeForRow(provider)
providerType, err := db2sdk.CanonicalAIProviderType(provider)
if err != nil {
api.Logger.Error(ctx, "failed to decode AI provider settings", slog.F("provider_id", provider.ID), slog.Error(err))
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Expand Down Expand Up @@ -6911,7 +6910,7 @@ func (api *API) createChatModelConfig(rw http.ResponseWriter, r *http.Request) {
httpapi.Write(ctx, rw, http.StatusPreconditionFailed, codersdk.Response{Message: "AI provider is disabled."})
return
}
providerType, err := canonicalAIProviderTypeForRow(aiProvider)
providerType, err := db2sdk.CanonicalAIProviderType(aiProvider)
if err != nil {
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Failed to decode AI provider settings.",
Expand Down Expand Up @@ -7001,7 +7000,7 @@ func (api *API) createChatModelConfig(rw http.ResponseWriter, r *http.Request) {
if !lockedAIProvider.Enabled {
return errChatProviderNotConfigured
}
lockedProviderType, err := canonicalAIProviderTypeForRow(lockedAIProvider)
lockedProviderType, err := db2sdk.CanonicalAIProviderType(lockedAIProvider)
if err != nil {
return xerrors.Errorf("canonicalize provider type for %q: %w", lockedAIProvider.Name, err)
}
Expand Down Expand Up @@ -7124,7 +7123,7 @@ func (api *API) updateChatModelConfig(rw http.ResponseWriter, r *http.Request) {
return
}
} else {
providerType, err := canonicalAIProviderTypeForRow(aiProvider)
providerType, err := db2sdk.CanonicalAIProviderType(aiProvider)
if err != nil {
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Failed to decode AI provider settings.",
Expand Down Expand Up @@ -7165,7 +7164,7 @@ func (api *API) updateChatModelConfig(rw http.ResponseWriter, r *http.Request) {
httpapi.Write(ctx, rw, http.StatusPreconditionFailed, codersdk.Response{Message: "AI provider is disabled."})
return
}
providerType, err := canonicalAIProviderTypeForRow(aiProvider)
providerType, err := db2sdk.CanonicalAIProviderType(aiProvider)
if err != nil {
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Failed to decode AI provider settings.",
Expand Down Expand Up @@ -7262,7 +7261,7 @@ func (api *API) updateChatModelConfig(rw http.ResponseWriter, r *http.Request) {
if !aiProvider.Enabled {
return errChatProviderNotConfigured
}
providerType, err := canonicalAIProviderTypeForRow(aiProvider)
providerType, err := db2sdk.CanonicalAIProviderType(aiProvider)
if err != nil {
return xerrors.Errorf("canonicalize provider type for %q: %w", aiProvider.Name, err)
}
Expand Down
37 changes: 5 additions & 32 deletions coderd/x/chatd/ai_provider_canonical.go
Original file line number Diff line number Diff line change
@@ -1,53 +1,26 @@
package chatd

import (
"context"

"cdr.dev/slog/v3"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/database/db2sdk"
"github.com/coder/coder/v2/coderd/x/chatd/chatprovider"
"github.com/coder/coder/v2/codersdk"
)

func canonicalAIProviderType(provider database.AIProvider) (database.AIProviderType, error) {
return db2sdk.CanonicalAIProviderType(provider)
Comment thread
johnstcn marked this conversation as resolved.
}

func canonicalAIProviderTypeString(provider database.AIProvider) (string, error) {
providerType, err := canonicalAIProviderType(provider)
if err != nil {
return "", err
}
return string(providerType), nil
}

func bestEffortCanonicalAIProviderType(ctx context.Context, logger slog.Logger, provider database.AIProvider) database.AIProviderType {
providerType, err := canonicalAIProviderType(provider)
if err != nil {
logger.Warn(ctx, "parse AI provider settings", slog.F("provider_id", provider.ID), slog.Error(err))
return provider.Type
}
return providerType
}

func bestEffortCanonicalAIProviderTypeString(ctx context.Context, logger slog.Logger, provider database.AIProvider) string {
return string(bestEffortCanonicalAIProviderType(ctx, logger, provider))
}

func aiProviderTypeCanSatisfyRequest(candidateProviderType string, requestedProviderType string) bool {
if candidateProviderType == requestedProviderType {
return true
}
return requestedProviderType == string(database.AiProviderTypeAnthropic) &&
candidateProviderType == string(database.AiProviderTypeBedrock)
return requestedProviderType == string(codersdk.AIProviderTypeAnthropic) &&
candidateProviderType == string(codersdk.AIProviderTypeBedrock)
}

func aiProviderMatchesCanonicalType(provider database.AIProvider, normalizedProviderType string) (bool, error) {
providerType, err := canonicalAIProviderTypeString(provider)
providerType, err := db2sdk.CanonicalAIProviderType(provider)
if err != nil {
return false, err
}
return aiProviderTypeCanSatisfyRequest(chatprovider.NormalizeProvider(providerType), normalizedProviderType), nil
return aiProviderTypeCanSatisfyRequest(chatprovider.NormalizeProvider(string(providerType)), normalizedProviderType), nil
}

func aiProviderMatchesRawType(provider database.AIProvider, normalizedProviderType string) bool {
Expand Down
15 changes: 10 additions & 5 deletions coderd/x/chatd/chatd.go
Original file line number Diff line number Diff line change
Expand Up @@ -8658,9 +8658,14 @@ func (p *Server) aiProviderConfigFromKeys(ctx context.Context, provider database
break
}
}
canonicalType, err := db2sdk.CanonicalAIProviderType(provider)
Comment thread
johnstcn marked this conversation as resolved.
if err != nil {
p.logger.Warn(ctx, "parse AI provider settings for provider config", slog.F("provider_id", provider.ID), slog.Error(err))
canonicalType = codersdk.AIProviderType(provider.Type)
}
return chatprovider.ConfiguredProvider{
ProviderID: provider.ID,
Provider: bestEffortCanonicalAIProviderTypeString(ctx, p.logger, provider),
Provider: string(canonicalType),
APIKey: apiKey,
BaseURL: provider.BaseUrl,
CentralAPIKeyEnabled: true,
Expand Down Expand Up @@ -8774,12 +8779,12 @@ func (p *Server) resolveUserProviderAPIKeysAndProviderForProviderType(
return chatprovider.ProviderAPIKeys{}, nil, nil
}
for _, provider := range providers {
canonicalProviderType, err := canonicalAIProviderTypeString(provider)
canonicalProviderType, err := db2sdk.CanonicalAIProviderType(provider)
if err != nil {
p.logger.Warn(ctx, "parse AI provider settings", slog.F("provider_id", provider.ID), slog.Error(err))
p.logger.Warn(ctx, "parse AI provider settings during key resolution", slog.F("provider_id", provider.ID), slog.Error(err))
continue
}
providerKeysType := chatprovider.NormalizeProvider(canonicalProviderType)
providerKeysType := chatprovider.NormalizeProvider(string(canonicalProviderType))
if !aiProviderTypeCanSatisfyRequest(providerKeysType, normalizedProviderType) {
continue
}
Expand All @@ -8792,7 +8797,7 @@ func (p *Server) resolveUserProviderAPIKeysAndProviderForProviderType(
if !aiProviderMatchesRawType(provider, normalizedProviderType) {
continue
}
providerKeysType := chatprovider.NormalizeProvider(bestEffortCanonicalAIProviderTypeString(ctx, p.logger, provider))
providerKeysType := chatprovider.NormalizeProvider(string(provider.Type))
keys, matchedProvider, err := keysForProvider(provider, providerKeysType)
if err != nil || matchedProvider != nil {
return keys, matchedProvider, err
Expand Down
84 changes: 84 additions & 0 deletions coderd/x/chatd/chatd_internal_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -307,6 +307,90 @@ func TestResolveDirectModelRouteForProviderTypeFallsBackToRawProviderType(t *tes
require.Equal(t, "test-key", route.directProviderKeys().APIKey("openai"))
}

func TestResolveAIGatewayModelRouteForProviderTypeMatchesCanonicalBedrockForAnthropicRequest(t *testing.T) {
t.Parallel()

ctx := testutil.Context(t, testutil.WaitShort)
ctrl := gomock.NewController(t)
db := dbmock.NewMockStore(ctrl)
ownerID := uuid.New()
providerID := uuid.New()

rawSettings, err := json.Marshal(codersdk.AIProviderSettings{
Bedrock: &codersdk.AIProviderBedrockSettings{Region: "us-east-1"},
})
require.NoError(t, err)
provider := database.AIProvider{
ID: providerID,
Type: database.AiProviderTypeAnthropic,
Enabled: true,
Settings: sql.NullString{String: string(rawSettings), Valid: true},
}

db.EXPECT().GetAIProviders(gomock.Any(), database.GetAIProvidersParams{}).Return([]database.AIProvider{provider}, nil)

server := &Server{db: db, logger: slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})}
route, err := server.resolveAIGatewayModelRouteForProviderType(
ctx,
ownerID,
chattool.ComputerUseProviderAnthropic,
)
require.NoError(t, err)

require.Equal(t, modelRouteKindAIGateway, route.kind)
require.Equal(t, providerID, route.aiGateway.Provider.Row.ID)
providerHint, err := route.providerHint()
require.NoError(t, err)
require.Equal(t, "bedrock", providerHint)
}

func TestDirectProviderHintAndProviderForConfigErrorsOnMalformedSettings(t *testing.T) {
t.Parallel()

ctx := testutil.Context(t, testutil.WaitShort)
ctrl := gomock.NewController(t)
db := dbmock.NewMockStore(ctrl)
providerID := uuid.New()

provider := database.AIProvider{
ID: providerID,
Type: database.AiProviderTypeAnthropic,
Enabled: true,
Settings: sql.NullString{String: "{", Valid: true},
}
db.EXPECT().GetAIProviderByID(gomock.Any(), providerID).Return(provider, nil)

server := &Server{db: db, logger: slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})}
_, _, err := server.directProviderHintAndProviderForConfig(ctx, database.ChatModelConfig{
AIProviderID: uuid.NullUUID{UUID: providerID, Valid: true},
})
require.ErrorContains(t, err, "canonicalize")
}

func TestAIProviderConfigFromKeysFallsBackToRawTypeOnMalformedSettings(t *testing.T) {
t.Parallel()

ctx := testutil.Context(t, testutil.WaitShort)

// Corrupt valid Bedrock JSON to verify the fallback returns the raw type rather than "bedrock".
bedrockSettings, err := json.Marshal(codersdk.AIProviderSettings{
Bedrock: &codersdk.AIProviderBedrockSettings{Region: "us-east-1"},
})
require.NoError(t, err)

provider := database.AIProvider{
ID: uuid.New(),
Type: database.AiProviderTypeAnthropic,
Enabled: true,
Settings: sql.NullString{String: string(bedrockSettings[:len(bedrockSettings)-1]), Valid: true},
}

server := &Server{logger: slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})}
cfg, err := server.aiProviderConfigFromKeys(ctx, provider, nil)
require.NoError(t, err)
require.Equal(t, string(database.AiProviderTypeAnthropic), cfg.Provider)
}

func TestResolveModelRouteForProviderTypeAIGatewayRequiresProvider(t *testing.T) {
t.Parallel()

Expand Down
5 changes: 3 additions & 2 deletions coderd/x/chatd/model_routing.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (

"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/x/chatd/chatprovider"
"github.com/coder/coder/v2/codersdk"
)

type modelClientRequest struct {
Expand Down Expand Up @@ -57,7 +58,7 @@ func (r resolvedModelRoute) providerHint() (string, error) {
case modelRouteKindDirect:
return r.direct.ProviderHint, nil
case modelRouteKindAIGateway:
return r.aiGateway.ModelProviderHint, nil
return string(r.aiGateway.Provider.CanonicalType), nil
default:
return "", xerrors.New("model route is not configured")
}
Expand All @@ -68,7 +69,7 @@ func (r resolvedModelRoute) withProviderHint(providerHint string) resolvedModelR
case modelRouteKindDirect:
r.direct.ProviderHint = providerHint
case modelRouteKindAIGateway:
r.aiGateway.ModelProviderHint = providerHint
r.aiGateway.Provider.CanonicalType = codersdk.AIProviderType(providerHint)
}
return r
}
Expand Down
Loading
Loading