6bed393c12
Backend Tests / backend-unit-test (push) Has been cancelled
Backend Tests / benchmark-test (push) Has been cancelled
CI@main / Node.js v22 (ubuntu-latest) (push) Has been cancelled
Thrift Syntax Validation / validate-thrift (push) Has been cancelled
License Check / License Check (push) Has been cancelled
729 lines
22 KiB
Go
729 lines
22 KiB
Go
/*
|
|
* Copyright 2025 coze-dev Authors
|
|
*
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
*/
|
|
|
|
package repository
|
|
|
|
import (
	"context"
	"encoding/json"
	"fmt"
	"runtime/debug"
	"sort"
	"strconv"
	"strings"

	"gorm.io/gorm"

	"github.com/coze-dev/coze-studio/backend/api/model/app/bot_common"
	pluginCommon "github.com/coze-dev/coze-studio/backend/api/model/plugin_develop/common"
	"github.com/coze-dev/coze-studio/backend/crossdomain/plugin/consts"
	"github.com/coze-dev/coze-studio/backend/crossdomain/plugin/convert/api"
	"github.com/coze-dev/coze-studio/backend/crossdomain/plugin/model"
	pluginConf "github.com/coze-dev/coze-studio/backend/domain/plugin/conf"
	"github.com/coze-dev/coze-studio/backend/domain/plugin/dto"
	"github.com/coze-dev/coze-studio/backend/domain/plugin/entity"
	"github.com/coze-dev/coze-studio/backend/domain/plugin/internal/dal"
	"github.com/coze-dev/coze-studio/backend/domain/plugin/internal/dal/query"
	"github.com/coze-dev/coze-studio/backend/infra/idgen"
	"github.com/coze-dev/coze-studio/backend/pkg/errorx"
	"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
	"github.com/coze-dev/coze-studio/backend/pkg/lang/slices"
	"github.com/coze-dev/coze-studio/backend/pkg/logs"
	"github.com/coze-dev/coze-studio/backend/pkg/saasapi"
)
|
|
|
|
type toolRepoImpl struct {
|
|
query *query.Query
|
|
|
|
pluginDraftDAO *dal.PluginDraftDAO
|
|
|
|
toolDraftDAO *dal.ToolDraftDAO
|
|
toolDAO *dal.ToolDAO
|
|
toolVersionDAO *dal.ToolVersionDAO
|
|
agentToolDraftDAO *dal.AgentToolDraftDAO
|
|
agentToolVersionDAO *dal.AgentToolVersionDAO
|
|
}
|
|
|
|
type ToolRepoComponents struct {
|
|
IDGen idgen.IDGenerator
|
|
DB *gorm.DB
|
|
}
|
|
|
|
func NewToolRepo(components *ToolRepoComponents) ToolRepository {
|
|
return &toolRepoImpl{
|
|
query: query.Use(components.DB),
|
|
pluginDraftDAO: dal.NewPluginDraftDAO(components.DB, components.IDGen),
|
|
toolDraftDAO: dal.NewToolDraftDAO(components.DB, components.IDGen),
|
|
toolDAO: dal.NewToolDAO(components.DB, components.IDGen),
|
|
toolVersionDAO: dal.NewToolVersionDAO(components.DB, components.IDGen),
|
|
agentToolDraftDAO: dal.NewAgentToolDraftDAO(components.DB, components.IDGen),
|
|
agentToolVersionDAO: dal.NewAgentToolVersionDAO(components.DB, components.IDGen),
|
|
}
|
|
}
|
|
|
|
func (t *toolRepoImpl) CreateDraftTool(ctx context.Context, tool *entity.ToolInfo) (toolID int64, err error) {
|
|
return t.toolDraftDAO.Create(ctx, tool)
|
|
}
|
|
|
|
func (t *toolRepoImpl) UpsertDraftTools(ctx context.Context, pluginID int64, tools []*entity.ToolInfo) (err error) {
|
|
apis := slices.Transform(tools, func(tool *entity.ToolInfo) dto.UniqueToolAPI {
|
|
return dto.UniqueToolAPI{
|
|
SubURL: tool.GetSubURL(),
|
|
Method: tool.GetMethod(),
|
|
}
|
|
})
|
|
|
|
existTools, err := t.toolDraftDAO.MGetWithAPIs(ctx, pluginID, apis, nil)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
tx := t.query.Begin()
|
|
if tx.Error != nil {
|
|
return tx.Error
|
|
}
|
|
|
|
defer func() {
|
|
if r := recover(); r != nil {
|
|
if e := tx.Rollback(); e != nil {
|
|
logs.CtxErrorf(ctx, "rollback failed, err=%v", e)
|
|
}
|
|
err = fmt.Errorf("catch panic: %v\nstack=%s", r, string(debug.Stack()))
|
|
return
|
|
}
|
|
if err != nil {
|
|
if e := tx.Rollback(); e != nil {
|
|
logs.CtxErrorf(ctx, "rollback failed, err=%v", e)
|
|
}
|
|
}
|
|
}()
|
|
|
|
createdTools := make([]*entity.ToolInfo, 0, len(tools))
|
|
updatedTools := make([]*entity.ToolInfo, 0, len(existTools))
|
|
|
|
for _, tool := range tools {
|
|
existTool, exist := existTools[dto.UniqueToolAPI{
|
|
SubURL: tool.GetSubURL(),
|
|
Method: tool.GetMethod(),
|
|
}]
|
|
if !exist {
|
|
createdTools = append(createdTools, tool)
|
|
continue
|
|
}
|
|
|
|
tool.ID = existTool.ID
|
|
|
|
updatedTools = append(updatedTools, tool)
|
|
}
|
|
|
|
if len(createdTools) > 0 {
|
|
_, err = t.toolDraftDAO.BatchCreateWithTX(ctx, tx, createdTools)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
if len(updatedTools) > 0 {
|
|
err = t.toolDraftDAO.BatchUpdateWithTX(ctx, tx, updatedTools)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
return tx.Commit()
|
|
}
|
|
|
|
func (t *toolRepoImpl) UpdateDraftTool(ctx context.Context, tool *entity.ToolInfo) (err error) {
|
|
return t.toolDraftDAO.Update(ctx, tool)
|
|
}
|
|
|
|
func (t *toolRepoImpl) GetDraftTool(ctx context.Context, toolID int64) (tool *entity.ToolInfo, exist bool, err error) {
|
|
return t.toolDraftDAO.Get(ctx, toolID)
|
|
}
|
|
|
|
func (t *toolRepoImpl) MGetDraftTools(ctx context.Context, toolIDs []int64, opts ...ToolSelectedOptions) (tools []*entity.ToolInfo, err error) {
|
|
var opt *dal.ToolSelectedOption
|
|
if len(opts) > 0 {
|
|
opt = &dal.ToolSelectedOption{}
|
|
for _, o := range opts {
|
|
o(opt)
|
|
}
|
|
}
|
|
return t.toolDraftDAO.MGet(ctx, toolIDs, opt)
|
|
}
|
|
|
|
func (t *toolRepoImpl) GetPluginAllDraftTools(ctx context.Context, pluginID int64, opts ...ToolSelectedOptions) (tools []*entity.ToolInfo, err error) {
|
|
var opt *dal.ToolSelectedOption
|
|
if len(opts) > 0 {
|
|
opt = &dal.ToolSelectedOption{}
|
|
for _, o := range opts {
|
|
o(opt)
|
|
}
|
|
}
|
|
return t.toolDraftDAO.GetAll(ctx, pluginID, opt)
|
|
}
|
|
|
|
func (t *toolRepoImpl) GetPluginAllOnlineTools(ctx context.Context, pluginID int64) (tools []*entity.ToolInfo, err error) {
|
|
pi, exist := pluginConf.GetPluginProduct(pluginID)
|
|
if exist {
|
|
tis := pi.GetPluginAllTools()
|
|
tools = slices.Transform(tis, func(ti *pluginConf.ToolInfo) *entity.ToolInfo {
|
|
return ti.Info
|
|
})
|
|
|
|
return tools, nil
|
|
}
|
|
|
|
tools, err = t.toolDAO.GetAll(ctx, pluginID)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return tools, nil
|
|
}
|
|
|
|
func (t *toolRepoImpl) ListPluginDraftTools(ctx context.Context, pluginID int64, pageInfo dto.PageInfo) (tools []*entity.ToolInfo, total int64, err error) {
|
|
return t.toolDraftDAO.List(ctx, pluginID, pageInfo)
|
|
}
|
|
|
|
func (t *toolRepoImpl) GetDraftToolWithAPI(ctx context.Context, pluginID int64, api dto.UniqueToolAPI) (tool *entity.ToolInfo, exist bool, err error) {
|
|
return t.toolDraftDAO.GetWithAPI(ctx, pluginID, api)
|
|
}
|
|
|
|
func (t *toolRepoImpl) MGetDraftToolWithAPI(ctx context.Context, pluginID int64, apis []dto.UniqueToolAPI, opts ...ToolSelectedOptions) (tools map[dto.UniqueToolAPI]*entity.ToolInfo, err error) {
|
|
var opt *dal.ToolSelectedOption
|
|
if len(opts) > 0 {
|
|
opt = &dal.ToolSelectedOption{}
|
|
for _, o := range opts {
|
|
o(opt)
|
|
}
|
|
}
|
|
return t.toolDraftDAO.MGetWithAPIs(ctx, pluginID, apis, opt)
|
|
}
|
|
|
|
func (t *toolRepoImpl) DeleteDraftTool(ctx context.Context, toolID int64) (err error) {
|
|
return t.toolDraftDAO.Delete(ctx, toolID)
|
|
}
|
|
|
|
func (t *toolRepoImpl) GetOnlineTool(ctx context.Context, toolID int64) (tool *entity.ToolInfo, exist bool, err error) {
|
|
ti, exist := pluginConf.GetToolProduct(toolID)
|
|
if exist {
|
|
return ti.Info, true, nil
|
|
}
|
|
|
|
return t.toolDAO.Get(ctx, toolID)
|
|
}
|
|
|
|
func (t *toolRepoImpl) MGetOnlineTools(ctx context.Context, toolIDs []int64, opts ...ToolSelectedOptions) (tools []*entity.ToolInfo, err error) {
|
|
toolProducts := pluginConf.MGetToolProducts(toolIDs)
|
|
|
|
tools = slices.Transform(toolProducts, func(tool *pluginConf.ToolInfo) *entity.ToolInfo {
|
|
return tool.Info
|
|
})
|
|
productToolIDs := slices.ToMap(toolProducts, func(tool *pluginConf.ToolInfo) (int64, bool) {
|
|
return tool.Info.ID, true
|
|
})
|
|
|
|
customToolIDs := make([]int64, 0, len(toolIDs))
|
|
for _, id := range toolIDs {
|
|
_, ok := productToolIDs[id]
|
|
if ok {
|
|
continue
|
|
}
|
|
customToolIDs = append(customToolIDs, id)
|
|
}
|
|
|
|
var opt *dal.ToolSelectedOption
|
|
if len(opts) > 0 {
|
|
opt = &dal.ToolSelectedOption{}
|
|
for _, o := range opts {
|
|
o(opt)
|
|
}
|
|
}
|
|
|
|
customTools, err := t.toolDAO.MGet(ctx, customToolIDs, opt)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
tools = append(tools, customTools...)
|
|
|
|
return tools, nil
|
|
}
|
|
|
|
func (t *toolRepoImpl) GetVersionTool(ctx context.Context, vTool model.VersionTool) (tool *entity.ToolInfo, exist bool, err error) {
|
|
ti, exist := pluginConf.GetToolProduct(vTool.ToolID)
|
|
if exist {
|
|
return ti.Info, true, nil
|
|
}
|
|
|
|
return t.toolVersionDAO.Get(ctx, vTool)
|
|
}
|
|
|
|
func (t *toolRepoImpl) MGetVersionTools(ctx context.Context, versionTools []model.VersionTool) (tools []*entity.ToolInfo, err error) {
|
|
tools, err = t.toolVersionDAO.MGet(ctx, versionTools)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return tools, nil
|
|
}
|
|
|
|
func (t *toolRepoImpl) BindDraftAgentTools(ctx context.Context, agentID int64, bindTools []*model.BindToolInfo) (err error) {
|
|
opt := &dal.ToolSelectedOption{
|
|
ToolID: true,
|
|
}
|
|
draftAgentTools, err := t.agentToolDraftDAO.GetAll(ctx, agentID, opt)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
var allToolIDs []int64
|
|
var localToolIDs []int64
|
|
var saasToolIDs []int64
|
|
var saasToolPluginIDs []int64
|
|
allToolIDs = slices.Transform(bindTools, func(tool *model.BindToolInfo) int64 {
|
|
return tool.ToolID
|
|
})
|
|
|
|
for _, tool := range bindTools {
|
|
if ptr.From(tool.Source) == bot_common.PluginFrom_FromSaas {
|
|
saasToolIDs = append(saasToolIDs, tool.ToolID)
|
|
saasToolPluginIDs = append(saasToolPluginIDs, tool.PluginID)
|
|
} else {
|
|
localToolIDs = append(localToolIDs, tool.ToolID)
|
|
|
|
}
|
|
}
|
|
|
|
draftAgentToolIDMap := slices.ToMap(draftAgentTools, func(tool *entity.ToolInfo) (int64, bool) {
|
|
return tool.ID, true
|
|
})
|
|
|
|
bindToolIDMap := slices.ToMap(allToolIDs, func(toolID int64) (int64, bool) {
|
|
return toolID, true
|
|
})
|
|
|
|
newLocalBindToolIDs := make([]int64, 0, len(allToolIDs))
|
|
newSaasBindToolIDs := make([]int64, 0, len(allToolIDs))
|
|
for _, toolID := range allToolIDs {
|
|
_, ok := draftAgentToolIDMap[toolID]
|
|
if ok {
|
|
continue
|
|
}
|
|
if slices.Contains(saasToolIDs, toolID) {
|
|
newSaasBindToolIDs = append(newSaasBindToolIDs, toolID)
|
|
} else {
|
|
newLocalBindToolIDs = append(newLocalBindToolIDs, toolID)
|
|
}
|
|
}
|
|
|
|
removeToolIDs := make([]int64, 0, len(draftAgentTools))
|
|
for toolID := range draftAgentToolIDMap {
|
|
_, ok := bindToolIDMap[toolID]
|
|
if ok {
|
|
continue
|
|
}
|
|
removeToolIDs = append(removeToolIDs, toolID)
|
|
}
|
|
|
|
var onlineTools []*entity.ToolInfo
|
|
if len(newSaasBindToolIDs) > 0 {
|
|
saasPluginTools, _, err := t.BatchGetSaasPluginToolsInfo(ctx, saasToolPluginIDs)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
for _, toolIDs := range saasPluginTools {
|
|
for _, toolInfo := range toolIDs {
|
|
if slices.Contains(saasToolIDs, toolInfo.ID) {
|
|
onlineTools = append(onlineTools, toolInfo)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
tx := t.query.Begin()
|
|
if tx.Error != nil {
|
|
return tx.Error
|
|
}
|
|
|
|
defer func() {
|
|
if r := recover(); r != nil {
|
|
if e := tx.Rollback(); e != nil {
|
|
logs.CtxErrorf(ctx, "rollback failed, err=%v", e)
|
|
}
|
|
err = fmt.Errorf("catch panic: %v\nstack=%s", r, string(debug.Stack()))
|
|
return
|
|
}
|
|
if err != nil {
|
|
if e := tx.Rollback(); e != nil {
|
|
logs.CtxErrorf(ctx, "rollback failed, err=%v", e)
|
|
}
|
|
}
|
|
}()
|
|
|
|
if len(newLocalBindToolIDs) > 0 {
|
|
localTools, err := t.MGetOnlineTools(ctx, newLocalBindToolIDs)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
onlineTools = append(onlineTools, localTools...)
|
|
}
|
|
|
|
if len(onlineTools) > 0 {
|
|
err = t.agentToolDraftDAO.BatchCreateIgnoreConflictWithTX(ctx, tx, agentID, onlineTools)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
err = t.agentToolDraftDAO.DeleteWithTX(ctx, tx, agentID, removeToolIDs)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
err = tx.Commit()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (t *toolRepoImpl) GetAgentPluginIDs(ctx context.Context, agentID int64) (pluginIDs []int64, err error) {
|
|
return t.agentToolDraftDAO.GetAllPluginIDs(ctx, agentID)
|
|
}
|
|
|
|
func (t *toolRepoImpl) DuplicateDraftAgentTools(ctx context.Context, fromAgentID, toAgentID int64) (err error) {
|
|
tools, err := t.agentToolDraftDAO.GetAll(ctx, fromAgentID, nil)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
if len(tools) == 0 {
|
|
return nil
|
|
}
|
|
|
|
tx := t.query.Begin()
|
|
if tx.Error != nil {
|
|
return tx.Error
|
|
}
|
|
|
|
defer func() {
|
|
if r := recover(); r != nil {
|
|
if e := tx.Rollback(); e != nil {
|
|
logs.CtxErrorf(ctx, "rollback failed, err=%v", e)
|
|
}
|
|
err = fmt.Errorf("catch panic: %v\nstack=%s", r, string(debug.Stack()))
|
|
return
|
|
}
|
|
if err != nil {
|
|
if e := tx.Rollback(); e != nil {
|
|
logs.CtxErrorf(ctx, "rollback failed, err=%v", e)
|
|
}
|
|
}
|
|
}()
|
|
|
|
err = t.agentToolDraftDAO.BatchCreateWithTX(ctx, tx, toAgentID, tools)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
return tx.Commit()
|
|
}
|
|
|
|
func (t *toolRepoImpl) GetDraftAgentTool(ctx context.Context, agentID, toolID int64) (tool *entity.ToolInfo, exist bool, err error) {
|
|
return t.agentToolDraftDAO.Get(ctx, agentID, toolID)
|
|
}
|
|
|
|
func (t *toolRepoImpl) GetDraftAgentToolWithToolName(ctx context.Context, agentID int64, toolName string) (tool *entity.ToolInfo, exist bool, err error) {
|
|
return t.agentToolDraftDAO.GetWithToolName(ctx, agentID, toolName)
|
|
}
|
|
|
|
func (t *toolRepoImpl) MGetDraftAgentTools(ctx context.Context, agentID int64, toolIDs []int64) (tools []*entity.ToolInfo, err error) {
|
|
return t.agentToolDraftDAO.MGet(ctx, agentID, toolIDs)
|
|
}
|
|
|
|
func (t *toolRepoImpl) UpdateDraftAgentTool(ctx context.Context, req *UpdateDraftAgentToolRequest) (err error) {
|
|
return t.agentToolDraftDAO.UpdateWithToolName(ctx, req.AgentID, req.ToolName, req.Tool)
|
|
}
|
|
|
|
func (t *toolRepoImpl) GetSpaceAllDraftAgentTools(ctx context.Context, agentID int64) (tools []*entity.ToolInfo, err error) {
|
|
return t.agentToolDraftDAO.GetAll(ctx, agentID, nil)
|
|
}
|
|
|
|
func (t *toolRepoImpl) GetVersionAgentTool(ctx context.Context, agentID int64, vAgentTool model.VersionAgentTool) (tool *entity.ToolInfo, exist bool, err error) {
|
|
return t.agentToolVersionDAO.Get(ctx, agentID, vAgentTool)
|
|
}
|
|
|
|
func (t *toolRepoImpl) GetVersionAgentToolWithToolName(ctx context.Context, req *GetVersionAgentToolWithToolNameRequest) (tool *entity.ToolInfo, exist bool, err error) {
|
|
return t.agentToolVersionDAO.GetWithToolName(ctx, req.AgentID, req.ToolName, req.AgentVersion)
|
|
}
|
|
|
|
func (t *toolRepoImpl) MGetVersionAgentTool(ctx context.Context, agentID int64, vAgentTools []model.VersionAgentTool) (tools []*entity.ToolInfo, err error) {
|
|
return t.agentToolVersionDAO.MGet(ctx, agentID, vAgentTools)
|
|
}
|
|
|
|
func (t *toolRepoImpl) BatchCreateVersionAgentTools(ctx context.Context, agentID int64, agentVersion string, tools []*entity.ToolInfo) (err error) {
|
|
return t.agentToolVersionDAO.BatchCreate(ctx, agentID, agentVersion, tools)
|
|
}
|
|
|
|
// BatchGetSaasPluginToolsInfo retrieves tools information for SaaS plugins
|
|
func (t *toolRepoImpl) BatchGetSaasPluginToolsInfo(ctx context.Context, pluginIDs []int64) (tools map[int64][]*entity.ToolInfo, plugins map[int64]*entity.PluginInfo, err error) {
|
|
if len(pluginIDs) == 0 {
|
|
return make(map[int64][]*entity.ToolInfo), make(map[int64]*entity.PluginInfo), nil
|
|
}
|
|
|
|
client := saasapi.NewCozeAPIClient()
|
|
|
|
var idStrings []string
|
|
for _, id := range pluginIDs {
|
|
idStrings = append(idStrings, strconv.FormatInt(id, 10))
|
|
}
|
|
idsStr := strings.Join(idStrings, ",")
|
|
|
|
queryParams := map[string]interface{}{
|
|
"ids": idsStr,
|
|
}
|
|
|
|
resp, err := client.GetWithQuery(ctx, "/v1/plugins/mget", queryParams)
|
|
if err != nil {
|
|
return nil, nil, errorx.Wrapf(err, "failed to call coze.cn /v1/plugins/mget API")
|
|
}
|
|
|
|
var apiResp dto.SaasPluginToolsListResponse
|
|
|
|
if err := json.Unmarshal(resp.Data, &apiResp); err != nil {
|
|
return nil, nil, errorx.Wrapf(err, "failed to parse coze.cn API response")
|
|
}
|
|
|
|
result := make(map[int64][]*entity.ToolInfo)
|
|
plugins = make(map[int64]*entity.PluginInfo)
|
|
|
|
for _, plugin := range apiResp.Items {
|
|
|
|
pluginID, err := strconv.ParseInt(plugin.PluginID, 10, 64)
|
|
if err != nil {
|
|
return nil, nil, errorx.Wrapf(err, "failed to parse plugin ID %s", plugin.PluginID)
|
|
}
|
|
|
|
pluginInfo := convertCozePluginToEntity(&plugin)
|
|
plugins[pluginID] = pluginInfo
|
|
|
|
toolInfos := make([]*entity.ToolInfo, 0, len(plugin.Tools))
|
|
|
|
for _, tool := range plugin.Tools {
|
|
|
|
openapi3Operation, err := api.APIParamsToOpenapiOperation(convertFromJsonSchema(tool.InputSchema), convertFromJsonSchema(tool.OutputSchema))
|
|
if err != nil {
|
|
return nil, nil, errorx.Wrapf(err, "failed to convert input schema to openapi operation parameters")
|
|
}
|
|
openapi3Operation.OperationID = tool.Name
|
|
openapi3Operation.Summary = tool.Description
|
|
operation := &model.Openapi3Operation{
|
|
Operation: openapi3Operation,
|
|
}
|
|
id, err := strconv.ParseInt(tool.ToolID, 10, 64)
|
|
if err != nil {
|
|
return nil, nil, errorx.Wrapf(err, "failed to parse tool ID %s", tool.ToolID)
|
|
}
|
|
toolInfo := &entity.ToolInfo{
|
|
ID: id,
|
|
PluginID: pluginID,
|
|
Operation: operation,
|
|
Source: ptr.Of(bot_common.PluginFrom_FromSaas),
|
|
Version: ptr.Of("0"),
|
|
Method: ptr.Of("POST"),
|
|
SubURL: ptr.Of(convertSaasToolSubUrl(pluginID)),
|
|
}
|
|
|
|
toolInfos = append(toolInfos, toolInfo)
|
|
}
|
|
|
|
result[pluginID] = toolInfos
|
|
}
|
|
|
|
return result, plugins, nil
|
|
}
|
|
|
|
func convertCozePluginToEntity(cozePlugin *dto.SaasPluginToolsList) *entity.PluginInfo {
|
|
var pluginID int64
|
|
if id, err := strconv.ParseInt(cozePlugin.PluginID, 10, 64); err == nil {
|
|
pluginID = id
|
|
}
|
|
|
|
manifest := &model.PluginManifest{
|
|
SchemaVersion: "v1",
|
|
NameForModel: cozePlugin.Name,
|
|
NameForHuman: cozePlugin.Name,
|
|
DescriptionForModel: cozePlugin.Description,
|
|
DescriptionForHuman: cozePlugin.Description,
|
|
LogoURL: cozePlugin.IconURL,
|
|
Auth: &model.AuthV2{
|
|
Type: func() consts.AuthzType {
|
|
if !cozePlugin.IsCallAvailable {
|
|
return consts.AuthTypeOfSaasInstalled
|
|
}
|
|
return consts.AuthzTypeOfNone
|
|
}(),
|
|
},
|
|
API: model.APIDesc{
|
|
Type: "openapi",
|
|
},
|
|
}
|
|
|
|
pluginInfo := &model.PluginInfo{
|
|
ID: pluginID,
|
|
PluginType: pluginCommon.PluginType_PLUGIN,
|
|
SpaceID: 0,
|
|
DeveloperID: 0,
|
|
APPID: nil,
|
|
IconURL: &cozePlugin.IconURL,
|
|
ServerURL: ptr.Of("https://api.coze.cn"),
|
|
CreatedAt: cozePlugin.CreatedAt,
|
|
UpdatedAt: cozePlugin.UpdatedAt,
|
|
Manifest: manifest,
|
|
Source: ptr.Of(bot_common.PluginFrom_FromSaas),
|
|
}
|
|
|
|
return entity.NewPluginInfo(pluginInfo)
|
|
}
|
|
|
|
// convertSaasToolSubUrl returns the sub URL used for tools of the SaaS plugin
// with the given ID: "/v1/plugins/<pluginID>/tools/call".
func convertSaasToolSubUrl(pluginID int64) string {
	const subURLFormat = "/v1/plugins/%d/tools/call"

	subURL := fmt.Sprintf(subURLFormat, pluginID)
	return subURL
}
|
|
|
|
// convertFromJsonSchema converts JSON schema to API parameters
|
|
func convertFromJsonSchema(schema *dto.JsonSchema) []*pluginCommon.APIParameter {
|
|
if schema == nil {
|
|
return []*pluginCommon.APIParameter{}
|
|
}
|
|
|
|
return convertJsonSchemaToParameters(schema, pluginCommon.ParameterLocation_Body)
|
|
}
|
|
|
|
// ConvertFromJsonSchemaForTest is an exported version for testing purposes
|
|
func ConvertFromJsonSchemaForTest(schema *dto.JsonSchema) []*pluginCommon.APIParameter {
|
|
return convertFromJsonSchema(schema)
|
|
}
|
|
|
|
// convertJsonSchemaToParameters recursively converts JSON schema to API parameters
|
|
func convertJsonSchemaToParameters(schema *dto.JsonSchema, location pluginCommon.ParameterLocation) []*pluginCommon.APIParameter {
|
|
if schema == nil {
|
|
return []*pluginCommon.APIParameter{}
|
|
}
|
|
|
|
var parameters []*pluginCommon.APIParameter
|
|
|
|
// Handle object type with properties
|
|
if schema.Type == dto.JsonSchemaType_OBJECT && len(schema.Properties) > 0 {
|
|
// Create a set of required fields for quick lookup
|
|
requiredFields := make(map[string]bool)
|
|
for _, field := range schema.Required {
|
|
requiredFields[field] = true
|
|
}
|
|
|
|
// Convert each property to a parameter
|
|
for name, propSchema := range schema.Properties {
|
|
if propSchema == nil {
|
|
continue
|
|
}
|
|
|
|
param := &pluginCommon.APIParameter{
|
|
Name: name,
|
|
Desc: propSchema.Description,
|
|
IsRequired: requiredFields[name],
|
|
Type: mapJsonSchemaTypeToParameterType(propSchema.Type),
|
|
Location: location,
|
|
}
|
|
|
|
// Handle nested object properties
|
|
if propSchema.Type == dto.JsonSchemaType_OBJECT && len(propSchema.Properties) > 0 {
|
|
param.SubParameters = convertJsonSchemaToParameters(propSchema, location)
|
|
}
|
|
|
|
// Handle array properties
|
|
if propSchema.Type == dto.JsonSchemaType_ARRAY && propSchema.Items != nil {
|
|
// Create a parameter for the array item
|
|
arrayItemParam := &pluginCommon.APIParameter{
|
|
Name: "[Array Item]",
|
|
Desc: propSchema.Items.Description,
|
|
IsRequired: true, // Array items are typically required
|
|
Type: mapJsonSchemaTypeToParameterType(propSchema.Items.Type),
|
|
Location: location,
|
|
}
|
|
|
|
// If array item is an object, recursively convert its properties
|
|
if propSchema.Items.Type == dto.JsonSchemaType_OBJECT && len(propSchema.Items.Properties) > 0 {
|
|
arrayItemParam.SubParameters = convertJsonSchemaToParameters(propSchema.Items, location)
|
|
}
|
|
|
|
// If array item is also an array, handle nested arrays
|
|
if propSchema.Items.Type == dto.JsonSchemaType_ARRAY && propSchema.Items.Items != nil {
|
|
// For nested arrays, create sub-parameters recursively
|
|
nestedArraySchema := &dto.JsonSchema{
|
|
Type: propSchema.Items.Type,
|
|
Items: propSchema.Items.Items,
|
|
Properties: propSchema.Items.Properties,
|
|
Required: propSchema.Items.Required,
|
|
}
|
|
arrayItemParam.SubParameters = convertJsonSchemaToParameters(nestedArraySchema, location)
|
|
}
|
|
|
|
param.SubParameters = []*pluginCommon.APIParameter{arrayItemParam}
|
|
}
|
|
|
|
parameters = append(parameters, param)
|
|
}
|
|
} else if schema.Type == dto.JsonSchemaType_ARRAY && schema.Items != nil {
|
|
// Handle top-level array (though this is less common for API parameters)
|
|
arrayItemParam := &pluginCommon.APIParameter{
|
|
Name: "[Array Item]",
|
|
Desc: schema.Items.Description,
|
|
IsRequired: true,
|
|
Type: mapJsonSchemaTypeToParameterType(schema.Items.Type),
|
|
Location: location,
|
|
}
|
|
|
|
if schema.Items.Type == dto.JsonSchemaType_OBJECT && len(schema.Items.Properties) > 0 {
|
|
arrayItemParam.SubParameters = convertJsonSchemaToParameters(schema.Items, location)
|
|
}
|
|
|
|
parameters = append(parameters, arrayItemParam)
|
|
}
|
|
|
|
return parameters
|
|
}
|
|
|
|
// mapJsonSchemaTypeToParameterType maps JSON schema types to parameter types
|
|
func mapJsonSchemaTypeToParameterType(schemaType dto.JsonSchemaType) pluginCommon.ParameterType {
|
|
switch schemaType {
|
|
case dto.JsonSchemaType_STRING:
|
|
return pluginCommon.ParameterType_String
|
|
case dto.JsonSchemaType_NUMBER:
|
|
return pluginCommon.ParameterType_Number
|
|
case dto.JsonSchemaType_INTEGER:
|
|
return pluginCommon.ParameterType_Integer
|
|
case dto.JsonSchemaType_BOOLEAN:
|
|
return pluginCommon.ParameterType_Bool
|
|
case dto.JsonSchemaType_OBJECT:
|
|
return pluginCommon.ParameterType_Object
|
|
case dto.JsonSchemaType_ARRAY:
|
|
return pluginCommon.ParameterType_Array
|
|
default:
|
|
return pluginCommon.ParameterType_String
|
|
}
|
|
}
|