6bed393c12
Backend Tests / backend-unit-test (push) Has been cancelled
Backend Tests / benchmark-test (push) Has been cancelled
CI@main / Node.js v22 (ubuntu-latest) (push) Has been cancelled
Thrift Syntax Validation / validate-thrift (push) Has been cancelled
License Check / License Check (push) Has been cancelled
291 lines
8.3 KiB
Go
291 lines
8.3 KiB
Go
// Code generated by hertz generator.
|
|
|
|
package coze
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"net"
|
|
"net/url"
|
|
"strconv"
|
|
"strings"
|
|
|
|
"github.com/cloudwego/eino/schema"
|
|
"github.com/cloudwego/hertz/pkg/app"
|
|
"github.com/cloudwego/hertz/pkg/protocol/consts"
|
|
|
|
"github.com/coze-dev/coze-studio/backend/api/model/admin/config"
|
|
bizConf "github.com/coze-dev/coze-studio/backend/bizpkg/config"
|
|
"github.com/coze-dev/coze-studio/backend/bizpkg/config/modelmgr"
|
|
"github.com/coze-dev/coze-studio/backend/bizpkg/llm/modelbuilder"
|
|
"github.com/coze-dev/coze-studio/backend/infra/embedding/impl"
|
|
"github.com/coze-dev/coze-studio/backend/pkg/lang/conv"
|
|
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
|
|
"github.com/coze-dev/coze-studio/backend/pkg/logs"
|
|
)
|
|
|
|
// GetBasicConfiguration .
|
|
// @router /api/admin/config/basic/get [GET]
|
|
func GetBasicConfiguration(ctx context.Context, c *app.RequestContext) {
|
|
baseConfig, err := bizConf.Base().GetBaseConfig(ctx)
|
|
if err != nil {
|
|
c.String(consts.StatusInternalServerError, err.Error())
|
|
return
|
|
}
|
|
|
|
resp := new(config.GetBasicConfigurationResp)
|
|
resp.Configuration = baseConfig
|
|
|
|
c.JSON(consts.StatusOK, resp)
|
|
}
|
|
|
|
// SaveBasicConfiguration .
|
|
// @router /api/admin/config/basic/save [POST]
|
|
func SaveBasicConfiguration(ctx context.Context, c *app.RequestContext) {
|
|
var err error
|
|
var req config.SaveBasicConfigurationReq
|
|
err = c.BindAndValidate(&req)
|
|
if err != nil {
|
|
c.String(consts.StatusBadRequest, err.Error())
|
|
return
|
|
}
|
|
|
|
if req.Configuration == nil {
|
|
invalidParamRequestResponse(c, "Configuration is nil")
|
|
return
|
|
}
|
|
|
|
// TODO: check coze api token
|
|
|
|
// Validate ServerHost: allow http/https URLs, or hostname:port
|
|
if req.Configuration.ServerHost == "" {
|
|
invalidParamRequestResponse(c, "ServerHost is empty")
|
|
return
|
|
}
|
|
|
|
host := req.Configuration.ServerHost
|
|
if strings.Contains(host, "://") {
|
|
u, parseErr := url.Parse(host)
|
|
if parseErr != nil || u.Host == "" || (u.Scheme != "http" && u.Scheme != "https") {
|
|
invalidParamRequestResponse(c, "ServerHost is invalid URL, require http/https")
|
|
return
|
|
}
|
|
} else {
|
|
// Expect hostname:port format
|
|
h, p, splitErr := net.SplitHostPort(host)
|
|
if splitErr != nil || h == "" {
|
|
invalidParamRequestResponse(c, "ServerHost must be hostname:port or http(s) URL")
|
|
return
|
|
}
|
|
port, portErr := strconv.Atoi(p)
|
|
if portErr != nil || port <= 0 || port > 65535 {
|
|
invalidParamRequestResponse(c, "ServerHost port is invalid")
|
|
return
|
|
}
|
|
}
|
|
|
|
logs.Infof("server host is valid %s", req.Configuration.ServerHost)
|
|
|
|
if req.Configuration.CodeRunnerType.String() == "<UNSET>" {
|
|
invalidParamRequestResponse(c, "CodeRunnerType is invalid")
|
|
return
|
|
}
|
|
|
|
err = bizConf.Base().SaveBaseConfig(ctx, req.Configuration)
|
|
if err != nil {
|
|
internalServerErrorResponse(ctx, c, fmt.Errorf("save basic config failed: %w", err))
|
|
return
|
|
}
|
|
|
|
resp := new(config.SaveBasicConfigurationResp)
|
|
c.JSON(consts.StatusOK, resp)
|
|
}
|
|
|
|
// GetKnowledgeConfig .
|
|
// @router /api/admin/config/knowledge/get [GET]
|
|
func GetKnowledgeConfig(ctx context.Context, c *app.RequestContext) {
|
|
knowledgeConfig, err := bizConf.Knowledge().GetKnowledgeConfig(ctx)
|
|
if err != nil {
|
|
internalServerErrorResponse(ctx, c, fmt.Errorf("get knowledge config failed: %w", err))
|
|
return
|
|
}
|
|
|
|
resp := new(config.GetKnowledgeConfigResp)
|
|
resp.KnowledgeConfig = knowledgeConfig
|
|
|
|
c.JSON(consts.StatusOK, resp)
|
|
}
|
|
|
|
// UpdateKnowledgeConfig .
|
|
// @router /api/admin/config/knowledge/save [POST]
|
|
func UpdateKnowledgeConfig(ctx context.Context, c *app.RequestContext) {
|
|
var err error
|
|
var req config.UpdateKnowledgeConfigReq
|
|
err = c.BindAndValidate(&req)
|
|
if err != nil {
|
|
c.String(consts.StatusBadRequest, err.Error())
|
|
return
|
|
}
|
|
|
|
if req.KnowledgeConfig == nil {
|
|
invalidParamRequestResponse(c, "KnowledgeConfig is nil")
|
|
return
|
|
}
|
|
|
|
if req.KnowledgeConfig.EmbeddingConfig == nil {
|
|
invalidParamRequestResponse(c, "EmbeddingConfig is nil")
|
|
return
|
|
}
|
|
|
|
if req.KnowledgeConfig.EmbeddingConfig.Connection == nil {
|
|
invalidParamRequestResponse(c, "Connection is nil")
|
|
return
|
|
}
|
|
|
|
if req.KnowledgeConfig.EmbeddingConfig.Connection.EmbeddingInfo == nil {
|
|
invalidParamRequestResponse(c, "EmbeddingInfo is nil")
|
|
return
|
|
}
|
|
|
|
embedding, err := impl.GetEmbedding(ctx, req.KnowledgeConfig.EmbeddingConfig)
|
|
if err != nil {
|
|
invalidParamRequestResponse(c, fmt.Sprintf("get embedding failed: %v", err))
|
|
return
|
|
}
|
|
|
|
if req.KnowledgeConfig.EmbeddingConfig.Connection.EmbeddingInfo.Dims <= 0 {
|
|
req.KnowledgeConfig.EmbeddingConfig.Connection.EmbeddingInfo.Dims = int32(embedding.Dimensions())
|
|
|
|
embedding, err = impl.GetEmbedding(ctx, req.KnowledgeConfig.EmbeddingConfig)
|
|
if err != nil {
|
|
invalidParamRequestResponse(c, fmt.Sprintf("get embedding failed: %v", err))
|
|
return
|
|
}
|
|
}
|
|
|
|
denseEmbeddings, err := embedding.EmbedStrings(ctx, []string{"test"})
|
|
if err != nil {
|
|
invalidParamRequestResponse(c, fmt.Sprintf("embed test string failed: %v", err))
|
|
return
|
|
}
|
|
|
|
if len(denseEmbeddings) == 0 {
|
|
invalidParamRequestResponse(c, fmt.Sprintf("embed test string failed: %v", err))
|
|
return
|
|
}
|
|
|
|
logs.CtxDebugf(ctx, "embed test string result: %d, expect %d",
|
|
len(denseEmbeddings[0]), req.KnowledgeConfig.EmbeddingConfig.Connection.EmbeddingInfo.Dims)
|
|
if len(denseEmbeddings[0]) != int(req.KnowledgeConfig.EmbeddingConfig.Connection.EmbeddingInfo.Dims) {
|
|
invalidParamRequestResponse(c, fmt.Sprintf("embed test string failed: dims not match, expect %d, got %d",
|
|
req.KnowledgeConfig.EmbeddingConfig.Connection.EmbeddingInfo.Dims, len(denseEmbeddings[0])))
|
|
return
|
|
}
|
|
|
|
err = bizConf.Knowledge().SaveKnowledgeConfig(ctx, req.KnowledgeConfig)
|
|
if err != nil {
|
|
internalServerErrorResponse(ctx, c, fmt.Errorf("save knowledge config failed: %w", err))
|
|
return
|
|
}
|
|
|
|
resp := new(config.UpdateKnowledgeConfigResp)
|
|
|
|
c.JSON(consts.StatusOK, resp)
|
|
}
|
|
|
|
// GetModelList .
|
|
// @router /api/admin/config/model/list [GET]
|
|
func GetModelList(ctx context.Context, c *app.RequestContext) {
|
|
var err error
|
|
var req config.GetModelListReq
|
|
err = c.BindAndValidate(&req)
|
|
if err != nil {
|
|
c.String(consts.StatusBadRequest, err.Error())
|
|
return
|
|
}
|
|
|
|
modelList, err := bizConf.ModelConf().GetProviderModelList(ctx)
|
|
if err != nil {
|
|
internalServerErrorResponse(ctx, c, fmt.Errorf("get model list failed: %w", err))
|
|
return
|
|
}
|
|
|
|
resp := new(config.GetModelListResp)
|
|
resp.ProviderModelList = modelList
|
|
|
|
c.JSON(consts.StatusOK, resp)
|
|
}
|
|
|
|
// CreateModel .
|
|
// @router /api/admin/config/model/create [POST]
|
|
func CreateModel(ctx context.Context, c *app.RequestContext) {
|
|
var err error
|
|
var req config.CreateModelReq
|
|
err = c.BindAndValidate(&req)
|
|
if err != nil {
|
|
c.String(consts.StatusBadRequest, err.Error())
|
|
return
|
|
}
|
|
|
|
modelBuilder, err := modelbuilder.NewModelBuilder(req.ModelClass, &config.Model{
|
|
EnableBase64URL: req.EnableBase64URL,
|
|
Connection: req.Connection,
|
|
})
|
|
if err != nil {
|
|
invalidParamRequestResponse(c, fmt.Sprintf("create model builder failed: %v", err))
|
|
return
|
|
}
|
|
|
|
logs.CtxDebugf(ctx, "create model req: %s, conn: %s", conv.DebugJsonToStr(req), conv.DebugJsonToStr(req.Connection.BaseConnInfo))
|
|
|
|
chatModel, err := modelBuilder.Build(ctx, &modelbuilder.LLMParams{EnableThinking: ptr.Of(false)})
|
|
if err != nil {
|
|
invalidParamRequestResponse(c, fmt.Sprintf("build model failed: %v", err))
|
|
return
|
|
}
|
|
|
|
respMsgs, err := chatModel.Generate(ctx, []*schema.Message{
|
|
schema.SystemMessage("1+1=?,Just answer with a number, no explanation.")})
|
|
if err != nil {
|
|
invalidParamRequestResponse(c, fmt.Sprintf("generate model failed: %v", err))
|
|
return
|
|
}
|
|
|
|
logs.CtxDebugf(ctx, "chatModel.Generate resp : %s", conv.DebugJsonToStr(respMsgs))
|
|
|
|
id, err := bizConf.ModelConf().CreateModel(ctx, req.ModelClass, req.ModelName, req.Connection, &modelmgr.ModelExtra{
|
|
EnableBase64URL: req.EnableBase64URL,
|
|
})
|
|
if err != nil {
|
|
invalidParamRequestResponse(c, fmt.Sprintf("create model failed: %v", err))
|
|
return
|
|
}
|
|
|
|
resp := new(config.CreateModelResp)
|
|
resp.ID = id
|
|
|
|
c.JSON(consts.StatusOK, resp)
|
|
}
|
|
|
|
// DeleteModel .
|
|
// @router /api/admin/config/model/delete [POST]
|
|
func DeleteModel(ctx context.Context, c *app.RequestContext) {
|
|
var err error
|
|
var req config.DeleteModelReq
|
|
err = c.BindAndValidate(&req)
|
|
if err != nil {
|
|
c.String(consts.StatusBadRequest, err.Error())
|
|
return
|
|
}
|
|
|
|
err = bizConf.ModelConf().DeleteModel(ctx, req.ID)
|
|
if err != nil {
|
|
internalServerErrorResponse(ctx, c, fmt.Errorf("delete model failed: %w", err))
|
|
return
|
|
}
|
|
|
|
resp := new(config.DeleteModelResp)
|
|
|
|
c.JSON(consts.StatusOK, resp)
|
|
}
|