6bed393c12
Backend Tests / backend-unit-test (push) Has been cancelled
Backend Tests / benchmark-test (push) Has been cancelled
CI@main / Node.js v22 (ubuntu-latest) (push) Has been cancelled
Thrift Syntax Validation / validate-thrift (push) Has been cancelled
License Check / License Check (push) Has been cancelled
210 lines
6.7 KiB
Go
210 lines
6.7 KiB
Go
/*
|
|
* Copyright 2025 coze-dev Authors
|
|
*
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
*/
|
|
package impl
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"os"
|
|
"time"
|
|
|
|
"github.com/milvus-io/milvus/client/v2/milvusclient"
|
|
|
|
"github.com/coze-dev/coze-studio/backend/api/model/admin/config"
|
|
"github.com/coze-dev/coze-studio/backend/infra/document/searchstore"
|
|
"github.com/coze-dev/coze-studio/backend/infra/document/searchstore/impl/elasticsearch"
|
|
"github.com/coze-dev/coze-studio/backend/infra/document/searchstore/impl/milvus"
|
|
searchstoreOceanbase "github.com/coze-dev/coze-studio/backend/infra/document/searchstore/impl/oceanbase"
|
|
"github.com/coze-dev/coze-studio/backend/infra/document/searchstore/impl/vikingdb"
|
|
"github.com/coze-dev/coze-studio/backend/infra/embedding"
|
|
"github.com/coze-dev/coze-studio/backend/infra/embedding/impl"
|
|
"github.com/coze-dev/coze-studio/backend/infra/es/impl/es"
|
|
"github.com/coze-dev/coze-studio/backend/infra/oceanbase"
|
|
"github.com/coze-dev/coze-studio/backend/pkg/envkey"
|
|
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
|
|
)
|
|
|
|
// Manager is an alias for searchstore.Manager; New returns its managers as
// values of this type.
type Manager = searchstore.Manager
|
|
|
|
func New(ctx context.Context, conf *config.KnowledgeConfig, es es.Client) ([]Manager, error) {
|
|
// es full text search
|
|
esSearchstoreManager := elasticsearch.NewManager(&elasticsearch.ManagerConfig{Client: es})
|
|
|
|
// vector search
|
|
mgr, err := getVectorStore(ctx, conf)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("init vector store failed, err=%w", err)
|
|
}
|
|
|
|
return []searchstore.Manager{esSearchstoreManager, mgr}, nil
|
|
}
|
|
|
|
func getVectorStore(ctx context.Context, conf *config.KnowledgeConfig) (searchstore.Manager, error) {
|
|
vsType := os.Getenv("VECTOR_STORE_TYPE")
|
|
|
|
switch vsType {
|
|
case "milvus":
|
|
ctx, cancel := context.WithTimeout(ctx, time.Second*5)
|
|
defer cancel()
|
|
|
|
var (
|
|
milvusAddr = os.Getenv("MILVUS_ADDR")
|
|
user = os.Getenv("MILVUS_USER")
|
|
password = os.Getenv("MILVUS_PASSWORD")
|
|
milvusToken = os.Getenv("MILVUS_TOKEN")
|
|
)
|
|
mc, err := milvusclient.New(ctx, &milvusclient.ClientConfig{
|
|
Address: milvusAddr,
|
|
Username: user,
|
|
Password: password,
|
|
APIKey: milvusToken,
|
|
})
|
|
if err != nil {
|
|
return nil, fmt.Errorf("init milvus client failed, err=%w", err)
|
|
}
|
|
|
|
emb, err := impl.GetEmbedding(ctx, conf.EmbeddingConfig)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("init milvus embedding failed, err=%w", err)
|
|
}
|
|
|
|
mgr, err := milvus.NewManager(&milvus.ManagerConfig{
|
|
Client: mc,
|
|
Embedding: emb,
|
|
EnableHybrid: ptr.Of(true),
|
|
})
|
|
if err != nil {
|
|
return nil, fmt.Errorf("init milvus vector store failed, err=%w", err)
|
|
}
|
|
|
|
return mgr, nil
|
|
case "vikingdb":
|
|
var (
|
|
host = os.Getenv("VIKING_DB_HOST")
|
|
region = os.Getenv("VIKING_DB_REGION")
|
|
ak = os.Getenv("VIKING_DB_AK")
|
|
sk = os.Getenv("VIKING_DB_SK")
|
|
scheme = os.Getenv("VIKING_DB_SCHEME")
|
|
modelName = os.Getenv("VIKING_DB_MODEL_NAME")
|
|
)
|
|
if ak == "" || sk == "" {
|
|
return nil, fmt.Errorf("invalid vikingdb ak / sk")
|
|
}
|
|
if host == "" {
|
|
host = "api-vikingdb.volces.com"
|
|
}
|
|
if region == "" {
|
|
region = "cn-beijing"
|
|
}
|
|
if scheme == "" {
|
|
scheme = "https"
|
|
}
|
|
|
|
var embConfig *vikingdb.VikingEmbeddingConfig
|
|
if modelName != "" {
|
|
embName := vikingdb.VikingEmbeddingModelName(modelName)
|
|
if embName.Dimensions() == 0 {
|
|
return nil, fmt.Errorf("embedding model not support, model_name=%s", modelName)
|
|
}
|
|
embConfig = &vikingdb.VikingEmbeddingConfig{
|
|
UseVikingEmbedding: true,
|
|
EnableHybrid: embName.SupportStatus() == embedding.SupportDenseAndSparse,
|
|
ModelName: embName,
|
|
ModelVersion: embName.ModelVersion(),
|
|
DenseWeight: ptr.Of(0.2),
|
|
BuiltinEmbedding: nil,
|
|
}
|
|
} else {
|
|
builtinEmbedding, err := impl.GetEmbedding(ctx, conf.EmbeddingConfig)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("builtint embedding init failed, err=%w", err)
|
|
}
|
|
|
|
embConfig = &vikingdb.VikingEmbeddingConfig{
|
|
UseVikingEmbedding: false,
|
|
EnableHybrid: false,
|
|
BuiltinEmbedding: builtinEmbedding,
|
|
}
|
|
}
|
|
|
|
svc := vikingdb.NewVikingDBService(host, region, ak, sk, scheme)
|
|
mgr, err := vikingdb.NewManager(&vikingdb.ManagerConfig{
|
|
Service: svc,
|
|
IndexingConfig: nil, // use default config
|
|
EmbeddingConfig: embConfig,
|
|
})
|
|
if err != nil {
|
|
return nil, fmt.Errorf("init vikingdb manager failed, err=%w", err)
|
|
}
|
|
|
|
return mgr, nil
|
|
|
|
case "oceanbase":
|
|
emb, err := impl.GetEmbedding(ctx, conf.EmbeddingConfig)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("init oceanbase embedding failed, err=%w", err)
|
|
}
|
|
|
|
var (
|
|
host = os.Getenv("OCEANBASE_HOST")
|
|
port = os.Getenv("OCEANBASE_PORT")
|
|
user = os.Getenv("OCEANBASE_USER")
|
|
password = os.Getenv("OCEANBASE_PASSWORD")
|
|
database = os.Getenv("OCEANBASE_DATABASE")
|
|
)
|
|
if host == "" || port == "" || user == "" || password == "" || database == "" {
|
|
return nil, fmt.Errorf("invalid oceanbase configuration: host, port, user, password, database are required")
|
|
}
|
|
|
|
dsn := fmt.Sprintf("%s:%s@tcp(%s:%s)/%s?charset=utf8mb4&parseTime=True&loc=Local",
|
|
user, password, host, port, database)
|
|
|
|
client, err := oceanbase.NewOceanBaseClient(dsn)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("init oceanbase client failed, err=%w", err)
|
|
}
|
|
|
|
if err = client.InitDatabase(ctx); err != nil {
|
|
return nil, fmt.Errorf("init oceanbase database failed, err=%w", err)
|
|
}
|
|
|
|
// Get configuration from environment variables with defaults
|
|
batchSize := envkey.GetIntD("OCEANBASE_BATCH_SIZE", 100)
|
|
enableCache := envkey.GetBoolD("OCEANBASE_ENABLE_CACHE", true)
|
|
cacheTTL := time.Duration(envkey.GetI32D("OCEANBASE_CACHE_TTL", 300)) * time.Second
|
|
maxConnections := envkey.GetIntD("OCEANBASE_MAX_CONNECTIONS", 100)
|
|
connTimeout := time.Duration(envkey.GetI32D("OCEANBASE_CONN_TIMEOUT", 30)) * time.Second
|
|
|
|
managerConfig := &searchstoreOceanbase.ManagerConfig{
|
|
Client: client,
|
|
Embedding: emb,
|
|
BatchSize: batchSize,
|
|
EnableCache: enableCache,
|
|
CacheTTL: cacheTTL,
|
|
MaxConnections: maxConnections,
|
|
ConnTimeout: connTimeout,
|
|
}
|
|
mgr, err := searchstoreOceanbase.NewManager(managerConfig)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("init oceanbase vector store failed, err=%w", err)
|
|
}
|
|
return mgr, nil
|
|
|
|
default:
|
|
return nil, fmt.Errorf("unexpected vector store type, type=%s", vsType)
|
|
}
|
|
}
|