6bed393c12
Backend Tests / backend-unit-test (push) Has been cancelled
Backend Tests / benchmark-test (push) Has been cancelled
CI@main / Node.js v22 (ubuntu-latest) (push) Has been cancelled
Thrift Syntax Validation / validate-thrift (push) Has been cancelled
License Check / License Check (push) Has been cancelled
195 lines
5.4 KiB
Go
195 lines
5.4 KiB
Go
/*
|
|
* Copyright 2025 coze-dev Authors
|
|
*
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
*/
|
|
|
|
package knowledge
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"fmt"
|
|
"os"
|
|
"strings"
|
|
|
|
"gorm.io/gorm"
|
|
|
|
"github.com/coze-dev/coze-studio/backend/api/model/admin/config"
|
|
"github.com/coze-dev/coze-studio/backend/pkg/envkey"
|
|
"github.com/coze-dev/coze-studio/backend/pkg/kvstore"
|
|
"github.com/coze-dev/coze-studio/backend/types/consts"
|
|
)
|
|
|
|
// knowledgeConfigKey is the key under which the knowledge configuration is
// read from and saved to the key-value store.
const knowledgeConfigKey = "knowledge_config"
|
|
|
|
type KnowledgeConfig struct {
|
|
knowledge *kvstore.KVStore[config.KnowledgeConfig]
|
|
}
|
|
|
|
func NewKnowledgeConfig(db *gorm.DB) *KnowledgeConfig {
|
|
return &KnowledgeConfig{
|
|
knowledge: kvstore.New[config.KnowledgeConfig](db),
|
|
}
|
|
}
|
|
|
|
func (c *KnowledgeConfig) GetKnowledgeConfig(ctx context.Context) (*config.KnowledgeConfig, error) {
|
|
conf, err := c.knowledge.Get(ctx, consts.KnowledgeConfigSpace, knowledgeConfigKey)
|
|
if err != nil {
|
|
if errors.Is(err, kvstore.ErrKeyNotFound) {
|
|
return getKnowledgeConfigurationFromOldConfig(), nil
|
|
}
|
|
|
|
return nil, err
|
|
}
|
|
|
|
return conf, nil
|
|
}
|
|
|
|
func getKnowledgeConfigurationFromOldConfig() *config.KnowledgeConfig {
|
|
embeddingTypeStr := strings.ToUpper(os.Getenv("EMBEDDING_TYPE"))
|
|
baseURLKey := fmt.Sprintf("%s_EMBEDDING_BASE_URL", embeddingTypeStr)
|
|
apiKeyKey := fmt.Sprintf("%s_EMBEDDING_API_KEY", embeddingTypeStr)
|
|
modelKey := fmt.Sprintf("%s_EMBEDDING_MODEL", embeddingTypeStr)
|
|
dimsKey := fmt.Sprintf("%s_EMBEDDING_DIMS", embeddingTypeStr)
|
|
|
|
conf := &config.KnowledgeConfig{
|
|
EmbeddingConfig: &config.EmbeddingConfig{
|
|
Type: getEmbeddingType(),
|
|
MaxBatchSize: envkey.GetI32D("EMBEDDING_MAX_BATCH_SIZE", 100),
|
|
Connection: &config.EmbeddingConnection{
|
|
BaseConnInfo: &config.BaseConnectionInfo{
|
|
BaseURL: envkey.GetString(baseURLKey),
|
|
APIKey: envkey.GetString(apiKeyKey),
|
|
Model: envkey.GetString(modelKey),
|
|
},
|
|
EmbeddingInfo: &config.EmbeddingInfo{
|
|
Dims: envkey.GetI32D(dimsKey, 1024),
|
|
},
|
|
Ark: &config.ArkConnInfo{
|
|
APIType: envkey.GetStringD("ARK_EMBEDDING_API_TYPE", "text_api"),
|
|
},
|
|
Openai: &config.OpenAIConnInfo{
|
|
ByAzure: envkey.GetBoolD("OPENAI_EMBEDDING_BY_AZURE", false),
|
|
APIVersion: envkey.GetString("OPENAI_EMBEDDING_API_VERSION"),
|
|
},
|
|
|
|
Gemini: &config.GeminiConnInfo{
|
|
Backend: envkey.GetI32D("GEMINI_EMBEDDING_BACKEND", 1),
|
|
Project: envkey.GetString("GEMINI_EMBEDDING_PROJECT"),
|
|
Location: envkey.GetString("GEMINI_EMBEDDING_LOCATION"),
|
|
},
|
|
HTTP: &config.HttpConnection{
|
|
Address: envkey.GetString("HTTP_EMBEDDING_ADDR"),
|
|
},
|
|
},
|
|
},
|
|
RerankConfig: &config.RerankConfig{
|
|
Type: getRerankType(),
|
|
VikingdbConfig: &config.VikingDBConfig{
|
|
Ak: envkey.GetString("VIKINGDB_RERANK_AK"),
|
|
Sk: envkey.GetString("VIKINGDB_RERANK_SK"),
|
|
Host: envkey.GetString("VIKINGDB_RERANK_HOST"),
|
|
Region: envkey.GetString("VIKINGDB_RERANK_REGION"),
|
|
Model: envkey.GetString("VIKINGDB_RERANK_MODEL"),
|
|
},
|
|
},
|
|
OcrConfig: &config.OCRConfig{
|
|
Type: getOCRType(),
|
|
VolcengineAk: envkey.GetString("VE_OCR_AK"),
|
|
VolcengineSk: envkey.GetString("VE_OCR_SK"),
|
|
PaddleocrAPIURL: envkey.GetString("PADDLEOCR_OCR_API_URL"),
|
|
},
|
|
ParserConfig: &config.ParserConfig{
|
|
Type: getParserType(),
|
|
PaddleocrStructureAPIURL: envkey.GetString("PADDLEOCR_STRUCTURE_API_URL"),
|
|
},
|
|
BuiltinModelID: 0,
|
|
}
|
|
|
|
if conf.EmbeddingConfig.Type == config.EmbeddingType_Ark {
|
|
conf.EmbeddingConfig.Connection.BaseConnInfo.APIKey = getArkEmbeddingAPIKey()
|
|
}
|
|
|
|
return conf
|
|
}
|
|
|
|
func getArkEmbeddingAPIKey() string {
|
|
if len(envkey.GetString("ARK_EMBEDDING_API_KEY")) > 0 {
|
|
return envkey.GetString("ARK_EMBEDDING_API_KEY")
|
|
}
|
|
return envkey.GetString("ARK_EMBEDDING_AK")
|
|
}
|
|
|
|
func (c *KnowledgeConfig) SaveKnowledgeConfig(ctx context.Context, v *config.KnowledgeConfig) error {
|
|
return c.knowledge.Save(ctx, consts.KnowledgeConfigSpace, knowledgeConfigKey, v)
|
|
}
|
|
|
|
func getEmbeddingType() config.EmbeddingType {
|
|
embeddingTypeStr := os.Getenv("EMBEDDING_TYPE")
|
|
|
|
switch embeddingTypeStr {
|
|
case "openai":
|
|
return config.EmbeddingType_OpenAI
|
|
case "ark":
|
|
return config.EmbeddingType_Ark
|
|
case "ollama":
|
|
return config.EmbeddingType_Ollama
|
|
case "gemini":
|
|
return config.EmbeddingType_Gemini
|
|
case "http":
|
|
return config.EmbeddingType_HTTP
|
|
}
|
|
|
|
return config.EmbeddingType_Ark
|
|
}
|
|
|
|
func getRerankType() config.RerankType {
|
|
embeddingTypeStr := os.Getenv("RERANK_TYPE")
|
|
|
|
switch embeddingTypeStr {
|
|
case "rrf":
|
|
return config.RerankType_RRF
|
|
case "vikingdb":
|
|
return config.RerankType_VikingDB
|
|
}
|
|
|
|
return config.RerankType_RRF
|
|
}
|
|
|
|
func getOCRType() config.OCRType {
|
|
ocrTypeStr := os.Getenv("OCR_TYPE")
|
|
|
|
switch ocrTypeStr {
|
|
case "ve":
|
|
return config.OCRType_Volcengine
|
|
case "paddleocr":
|
|
return config.OCRType_Paddleocr
|
|
}
|
|
|
|
return config.OCRType_Volcengine
|
|
}
|
|
|
|
func getParserType() config.ParserType {
|
|
parserTypeStr := os.Getenv("PARSER_TYPE")
|
|
|
|
switch parserTypeStr {
|
|
case "paddleocr":
|
|
return config.ParserType_Paddleocr
|
|
}
|
|
|
|
return config.ParserType_builtin
|
|
}
|