Files
coze-studio/backend/bizpkg/config/modelmgr/model_get.go
T
zgene 6bed393c12
Backend Tests / backend-unit-test (push) Has been cancelled
Backend Tests / benchmark-test (push) Has been cancelled
CI@main / Node.js v22 (ubuntu-latest) (push) Has been cancelled
Thrift Syntax Validation / validate-thrift (push) Has been cancelled
License Check / License Check (push) Has been cancelled
first commit
2026-05-14 13:29:56 +08:00

251 lines
6.8 KiB
Go

/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package modelmgr
import (
"context"
"encoding/json"
"fmt"
"strings"
config "github.com/coze-dev/coze-studio/backend/api/model/admin/config"
"github.com/coze-dev/coze-studio/backend/api/model/app/developer_api"
"github.com/coze-dev/coze-studio/backend/bizpkg/config/modelmgr/internal/model"
"github.com/coze-dev/coze-studio/backend/bizpkg/config/modelmgr/internal/query"
"github.com/coze-dev/coze-studio/backend/pkg/lang/ternary"
"github.com/coze-dev/coze-studio/backend/pkg/logs"
)
// GetProviderModelList returns one entry per provider reported by
// getModelProviderList, each carrying the non-deleted model instances whose
// provider model class matches that provider.
//
// The listed models are trimmed for display: Capability, Provider and
// Parameters are cleared, and a non-empty connection API key is masked.
// Keys of up to 4 bytes are replaced entirely by asterisks, keys of up to
// 8 bytes keep their first and last 2 bytes, and longer keys keep their first
// and last 4 bytes. A failure to resolve a provider icon URL is only logged.
func (c *ModelConfig) GetProviderModelList(ctx context.Context) ([]*config.ProviderModelList, error) {
	providers := getModelProviderList()
	out := make([]*config.ProviderModelList, 0, len(providers))

	instances, err := query.ModelInstance.WithContext(ctx).
		Where(query.ModelInstance.DeletedAt.IsNull()).Find()
	if err != nil {
		return nil, err
	}

	// Bucket the trimmed models by the model class of their provider.
	byClass := make(map[developer_api.ModelClass][]*config.Model)
	for _, inst := range instances {
		item := c.toModel(ctx, inst)
		item.Capability = nil
		item.Provider = nil
		item.Parameters = nil

		class := inst.Provider.ModelClass
		byClass[class] = append(byClass[class], item.Model)

		if item.Connection == nil || item.Connection.BaseConnInfo == nil {
			continue
		}

		// The bucket stores a pointer to the same config.Model, so masking
		// the key here is visible in the returned list as well.
		info := item.Connection.BaseConnInfo
		key := info.APIKey
		switch size := len(key); {
		case size == 0:
			// Nothing to mask.
		case size <= 4:
			info.APIKey = strings.Repeat("*", size)
		case size <= 8:
			info.APIKey = fmt.Sprintf("%s***%s", key[:2], key[size-2:])
		default:
			info.APIKey = fmt.Sprintf("%s***%s", key[:4], key[size-4:])
		}
	}

	for _, p := range providers {
		if p.IconURI != "" {
			iconURL, urlErr := c.oss.GetObjectUrl(ctx, p.IconURI)
			if urlErr == nil {
				p.IconURL = iconURL
			} else {
				logs.CtxWarnf(ctx, "get model icon url failed, err: %v", urlErr)
			}
		}

		entry := &config.ProviderModelList{
			Provider:  p,
			ModelList: byClass[p.ModelClass],
		}
		out = append(out, entry)
	}

	return out, nil
}
// GetAllModelList returns the model list with deleted model instances
// included. When the old model configuration is in use, the in-memory
// oldModels list is returned instead.
func (c *ModelConfig) GetAllModelList(ctx context.Context) ([]*Model, error) {
	const includeDeleted = true

	return c.getModelList(ctx, includeDeleted)
}
// GetOnlineModelList returns the model list restricted to model instances
// that have not been deleted. When the old model configuration is in use, the
// in-memory oldModels list is returned instead.
func (c *ModelConfig) GetOnlineModelList(ctx context.Context) ([]*Model, error) {
	const includeDeleted = false

	return c.getModelList(ctx, includeDeleted)
}
// getModelList loads model instances and converts each one with toModel.
//
// If UseOldModelConf reports that the old configuration is active, the shared
// oldModels slice is returned as-is and storage is not queried. Otherwise
// includeDeleteModel selects between an unscoped query (deleted rows
// included) and a query limited to rows whose DeletedAt is NULL.
func (c *ModelConfig) getModelList(ctx context.Context, includeDeleteModel bool) ([]*Model, error) {
	legacy, err := c.UseOldModelConf(ctx)
	if err != nil {
		return nil, fmt.Errorf("get use old model conf failed, err: %w", err)
	}
	if legacy {
		return oldModels, nil
	}

	var instances []*model.ModelInstance
	switch {
	case includeDeleteModel:
		instances, err = query.ModelInstance.WithContext(ctx).Unscoped().Find()
	default:
		instances, err = query.ModelInstance.WithContext(ctx).
			Where(query.ModelInstance.DeletedAt.IsNull()).Find()
	}
	if err != nil {
		return nil, err
	}

	converted := make([]*Model, len(instances))
	for i := range instances {
		converted[i] = c.toModel(ctx, instances[i])
	}
	return converted, nil
}
// GetOnlineModelListWithLimit returns at most limit non-deleted models.
//
// When UseOldModelConf reports that the old configuration is active, the
// result is a prefix of the in-memory oldModels list. Otherwise the model
// instances are read from storage and converted with toModel.
//
// A negative limit is treated as "no limit" in the old-configuration path.
// NOTE(review): this is meant to mirror the storage path, where GORM is
// documented to drop the LIMIT clause for a negative value — confirm against
// the GORM version in use.
func (c *ModelConfig) GetOnlineModelListWithLimit(ctx context.Context, limit int) ([]*Model, error) {
	useOldModel, err := c.UseOldModelConf(ctx)
	if err != nil {
		return nil, fmt.Errorf("get use old model conf failed, err: %w", err)
	}

	if useOldModel {
		// Clamp both ends: a negative limit previously caused a
		// slice-bounds panic on oldModels[:limit].
		if limit < 0 || limit > len(oldModels) {
			limit = len(oldModels)
		}
		// Cap the capacity so that an append by the caller cannot overwrite
		// entries of the shared oldModels slice.
		return oldModels[:limit:limit], nil
	}

	// Filter out deleted rows explicitly, matching the "online" query used by
	// getModelList.
	instances, err := query.ModelInstance.WithContext(ctx).
		Where(query.ModelInstance.DeletedAt.IsNull()).
		Limit(limit).Find()
	if err != nil {
		return nil, err
	}

	modelList := make([]*Model, 0, len(instances))
	for _, inst := range instances {
		modelList = append(modelList, c.toModel(ctx, inst))
	}
	return modelList, nil
}
// MGetModelByID looks up several models by ID.
//
// With the old model configuration active, each requested ID is matched
// against oldModels in request order; the first match per ID is kept and IDs
// without a match are skipped. Otherwise the rows are fetched with a single
// unscoped IN query and converted with toModel, in the order the query
// returns them.
func (c *ModelConfig) MGetModelByID(ctx context.Context, ids []int64) ([]*Model, error) {
	legacy, err := c.UseOldModelConf(ctx)
	if err != nil {
		return nil, fmt.Errorf("get use old model conf failed, err: %w", err)
	}

	found := make([]*Model, 0, len(ids))

	if legacy {
	nextID:
		for _, wanted := range ids {
			for _, candidate := range oldModels {
				if candidate.ID == wanted {
					found = append(found, candidate)
					continue nextID
				}
			}
		}
		return found, nil
	}

	byIDs := query.ModelInstance.ID.In(ids...)
	rows, err := query.ModelInstance.WithContext(ctx).Unscoped().Where(byIDs).Find()
	if err != nil {
		return nil, err
	}

	for _, row := range rows {
		found = append(found, c.toModel(ctx, row))
	}
	return found, nil
}
// GetModelByID returns the model with the given ID.
//
// With the old model configuration active, the model is searched for in
// oldModels and an error is returned when no entry has that ID. Otherwise the
// lookup is delegated to getModelByID.
func (c *ModelConfig) GetModelByID(ctx context.Context, modelID int64) (*Model, error) {
	legacy, err := c.UseOldModelConf(ctx)
	if err != nil {
		return nil, fmt.Errorf("get use old model conf failed, err: %w", err)
	}

	if !legacy {
		return c.getModelByID(ctx, modelID)
	}

	for i := range oldModels {
		if candidate := oldModels[i]; candidate.ID == modelID {
			return candidate, nil
		}
	}
	return nil, fmt.Errorf("model %d not found", modelID)
}
// getModelByID fetches the first model instance with the given ID and converts
// it with toModel. The query is unscoped so that deleted rows can still be
// retrieved. Any query error is returned unchanged.
func (c *ModelConfig) getModelByID(ctx context.Context, modelID int64) (*Model, error) {
	byID := query.ModelInstance.ID.Eq(modelID)

	// Unscoped: deleted data must remain retrievable by ID.
	row, err := query.ModelInstance.WithContext(ctx).Unscoped().Where(byID).First()
	if err != nil {
		return nil, err
	}

	converted := c.toModel(ctx, row)
	return converted, nil
}
// toModel converts a stored model instance into the API-facing Model.
//
// The conversion is best-effort. A failure to resolve the provider icon URL,
// to decrypt the connection, or to decode the extra JSON is logged as a
// warning and the conversion continues with whatever the failing step
// produced. The resolved icon URL is written to q.Provider.IconURL, so the
// passed instance is modified in place.
//
// Status is StatusDeleted when q.DeletedAt carries a non-zero time and
// StatusInUse otherwise. DeleteAtMs is the deletion time in Unix milliseconds
// for deleted models and 0 for models that are still in use.
func (c *ModelConfig) toModel(ctx context.Context, q *model.ModelInstance) *Model {
	// Guard the dereference: a row without provider data would otherwise
	// panic here and abort the caller's whole listing.
	if q.Provider != nil && q.Provider.IconURI != "" {
		url, err := c.oss.GetObjectUrl(ctx, q.Provider.IconURI)
		if err != nil {
			logs.CtxWarnf(ctx, "get model icon url failed, err: %v", err)
		} else {
			q.Provider.IconURL = url
		}
	}

	conn, err := decryptConn(ctx, q.Connection)
	if err != nil {
		logs.CtxWarnf(ctx, "decrypt model connection failed, err: %v", err)
	}

	// An empty extra is not valid JSON; decoding it would always fail and log
	// a warning for every such row, so keep the zero-value ModelExtra instead.
	extra := &ModelExtra{}
	if len(q.Extra) > 0 {
		if err := json.Unmarshal([]byte(q.Extra), extra); err != nil {
			logs.CtxWarnf(ctx, "unmarshal model extra (%s) failed, err: %v", q.Extra, err)
		}
	}

	// UnixMilli on the zero time.Time is a large negative number
	// (-62135596800000), so only convert the timestamp when it is set.
	deleted := !q.DeletedAt.Time.IsZero()
	var deleteAtMs int64
	if deleted {
		deleteAtMs = q.DeletedAt.Time.UnixMilli()
	}

	m := &Model{
		Model: &config.Model{
			ID:              q.ID,
			Provider:        q.Provider,
			DisplayInfo:     q.DisplayInfo,
			Connection:      conn,
			Type:            config.ModelType(q.Type),
			Capability:      q.Capability,
			Parameters:      q.Parameters,
			EnableBase64URL: extra.EnableBase64URL,
			DeleteAtMs:      deleteAtMs,
		},
	}
	m.Status = ternary.IFElse(deleted, config.ModelStatus_StatusDeleted,
		config.ModelStatus_StatusInUse)
	return m
}