feat: 支持多租户多模型对话及文档去重优化

This commit is contained in:
2026-04-16 15:47:37 +08:00
parent 4ead3f82cf
commit 27b1dd3c27
34 changed files with 2188 additions and 315 deletions

View File

@@ -3,67 +3,211 @@ package eino
import (
"context"
"fmt"
"rag/consts/model"
"rag/model/entity"
"gitea.com/red-future/common/jaeger"
"gitea.com/red-future/common/utils"
"github.com/cloudwego/eino-ext/components/embedding/ark"
"github.com/cloudwego/eino-ext/components/embedding/dashscope"
"github.com/cloudwego/eino-ext/components/embedding/ollama"
"github.com/cloudwego/eino-ext/components/embedding/openai"
"github.com/cloudwego/eino-ext/components/embedding/qianfan"
"github.com/cloudwego/eino-ext/components/embedding/tencentcloud"
"github.com/cloudwego/eino/components/embedding"
"github.com/gogf/gf/v2/frame/g"
"github.com/golang/glog"
"github.com/gogf/gf/v2/util/gconv"
)
// 全局只初始化一次
var (
EmbedderArk *ark.Embedder
EmbedderDashscope *dashscope.Embedder
EmbedderOpenAI *openai.Embedder
)
type EmbedderSet struct {
Ark *ark.Embedder
Ollama *ollama.Embedder
OpenAI *openai.Embedder
Qianfan *qianfan.Embedder
TencentCloud *tencentcloud.Embedder
DashScope *dashscope.Embedder
}
// 全局租户容器key=tenantIdvalue=该租户的向量模型
var tenantEmbedders = make(map[uint64]*EmbedderSet)
func init() {
ctx := context.Background()
if !g.Cfg().MustGet(ctx, "eino.embedding").IsEmpty() {
var err error
provider := g.Cfg().MustGet(ctx, "eino.embedding.provider").String()
switch provider {
case providerArk:
cfg := &ark.EmbeddingConfig{
APIKey: g.Cfg().MustGet(ctx, "eino.embedding.apiKey").String(),
Model: g.Cfg().MustGet(ctx, "eino.embedding.model").String(),
}
if apiType := g.Cfg().MustGet(ctx, "eino.embedding.apiType").String(); apiType != "" {
apiTypeVal := ark.APIType(apiType)
cfg.APIType = &apiTypeVal
}
EmbedderArk, err = ark.NewEmbedder(ctx, cfg)
case providerOpenai:
chatModelConfig := &openai.EmbeddingConfig{
APIKey: g.Cfg().MustGet(ctx, "eino.embedding.apiKey").String(),
Model: g.Cfg().MustGet(ctx, "eino.embedding.model").String(),
}
EmbedderOpenAI, err = openai.NewEmbedder(ctx, chatModelConfig)
case providerDashscope:
cfg := &dashscope.EmbeddingConfig{
APIKey: g.Cfg().MustGet(ctx, "eino.embedding.apiKey").String(),
Model: g.Cfg().MustGet(ctx, "eino.embedding.model").String(),
}
EmbedderDashscope, err = dashscope.NewEmbedder(ctx, cfg)
}
if err != nil {
glog.Fatalf("NewEmbedder of %v error: %v", provider, err)
}
}
ctx, span := jaeger.NewSpan(ctx, "InitAllVector")
defer span.End()
InitAllVector(ctx)
return
}
func EmbedStrings(ctx context.Context, texts []string) (embeddings [][]float64, err error) {
provider := g.Cfg().MustGet(ctx, "eino.embedding.provider").String()
switch provider {
case providerArk:
return EmbedderArk.EmbedStrings(ctx, texts)
case providerOpenai:
return EmbedderOpenAI.EmbedStrings(ctx, texts)
case providerDashscope:
return EmbedderDashscope.EmbedStrings(ctx, texts)
// ===================== 1. 服务启动时调用:初始化所有租户 =====================
func InitAllVector(ctx context.Context) {
//list, err := dao.Model.GetNoTenantId(ctx, &dto.GetModelReq{
// ModelType: model.ModelTypeVector.Code(),
//})
//if err != nil {
// g.Log().Errorf(ctx, "获取所有租户ID失败: %v", err)
// return
//}
//
//for _, l := range list {
// err = InitVector(ctx, l)
// if err != nil {
// g.Log().Errorf(ctx, "初始化租户[%v]的向量模型失败: %v", l.TenantId, err)
// continue
// }
//}
modelDO := new(entity.Model)
modelDO.TenantId = 1
modelDO.ConfigType = model.ModelConfigTypeVectorDashScope.Code()
var cfg entity.VectorModelConfigDashScope
cfg.APIKey = "sk-4a8b82770bf74bc490eb3e4c5a8e2be9"
cfg.Model = "text-embedding-v3"
modelDO.ConfigContent = gconv.Map(&cfg)
err := InitVector(ctx, modelDO)
if err != nil {
g.Log().Errorf(ctx, "初始化向量模型失败: %v", err)
return
}
return nil, fmt.Errorf("unsupported provider: %v", provider)
}
func InitVector(ctx context.Context, modelDO *entity.Model) (err error) {
set := &EmbedderSet{}
switch *modelDO.ConfigType {
case *model.ModelConfigTypeVectorArk.Code():
// 解析 Ark 向量配置
var cfg entity.VectorModelConfigArk
err = gconv.Struct(modelDO.ConfigContent, &cfg)
if err != nil {
return fmt.Errorf("解析Ark向量配置失败: %v", err)
}
arkCfg := &ark.EmbeddingConfig{
APIKey: cfg.APIKey,
Model: cfg.Model,
}
if !g.IsEmpty(cfg.APIType) {
arkCfg.APIType = new(ark.APIType(cfg.APIType))
}
set.Ark, err = ark.NewEmbedder(ctx, arkCfg)
case *model.ModelConfigTypeVectorOllama.Code():
// 解析 Ollama 向量配置
var cfg entity.VectorModelConfigOllama
err = gconv.Struct(modelDO.ConfigContent, &cfg)
if err != nil {
return fmt.Errorf("解析Ollama向量配置失败: %v", err)
}
set.Ollama, err = ollama.NewEmbedder(ctx, &ollama.EmbeddingConfig{
BaseURL: cfg.BaseURL,
Model: cfg.Model,
})
case *model.ModelConfigTypeVectorOpenAI.Code():
// 解析 OpenAI 向量配置
var cfg entity.VectorModelConfigOpenAI
err = gconv.Struct(modelDO.ConfigContent, &cfg)
if err != nil {
return fmt.Errorf("解析OpenAI向量配置失败: %v", err)
}
openaiCfg := &openai.EmbeddingConfig{
APIKey: cfg.APIKey,
Model: cfg.Model,
ByAzure: cfg.ByAzure,
BaseURL: cfg.BaseURL,
APIVersion: cfg.APIVersion,
}
set.OpenAI, err = openai.NewEmbedder(ctx, openaiCfg)
case *model.ModelConfigTypeVectorQianfan.Code():
// 解析 千帆 向量配置
var cfg entity.VectorModelConfigQianfan
err = gconv.Struct(modelDO.ConfigContent, &cfg)
if err != nil {
return fmt.Errorf("解析千帆向量配置失败: %v", err)
}
qcfg := qianfan.GetQianfanSingletonConfig()
qcfg.AccessKey = cfg.AccessKey
qcfg.SecretKey = cfg.SecretKey
set.Qianfan, err = qianfan.NewEmbedder(ctx, &qianfan.EmbeddingConfig{
Model: cfg.Model,
})
case *model.ModelConfigTypeVectorTencentCloud.Code():
// 解析 腾讯云 向量配置
var cfg entity.VectorModelConfigTencentCloud
err = gconv.Struct(modelDO.ConfigContent, &cfg)
if err != nil {
return fmt.Errorf("解析腾讯云向量配置失败: %v", err)
}
set.TencentCloud, err = tencentcloud.NewEmbedder(ctx, &tencentcloud.EmbeddingConfig{
SecretID: cfg.SecretID,
SecretKey: cfg.SecretKey,
Region: cfg.Region,
})
case *model.ModelConfigTypeVectorDashScope.Code():
// 解析 阿里 dashscope 向量配置
var cfg entity.VectorModelConfigDashScope
err = gconv.Struct(modelDO.ConfigContent, &cfg)
if err != nil {
return fmt.Errorf("解析阿里dashscope向量配置失败: %v", err)
}
set.DashScope, err = dashscope.NewEmbedder(ctx, &dashscope.EmbeddingConfig{
APIKey: cfg.APIKey,
Model: cfg.Model,
})
default:
return fmt.Errorf("不支持的向量模型配置类型: %v", *modelDO.ConfigType)
}
// 统一错误处理
if err != nil {
return fmt.Errorf("初始化向量模型失败: %v", err)
}
// 直接存入 map无锁重复初始化会直接覆盖
tenantEmbedders[modelDO.TenantId] = set
g.Log().Infof(ctx, "向量模型[%v]初始化成功", modelDO.ConfigType)
return
}
func GetTenantEmbedder(tenantId uint64) (*EmbedderSet, error) {
set := tenantEmbedders[tenantId]
if set == nil {
return nil, fmt.Errorf("租户[%v]的向量模型未初始化", tenantId)
}
return set, nil
}
func GetTenantEmbedderByType(ctx context.Context, configType model.ModelConfigType) (embedding.Embedder, error) {
userInfo, err := utils.GetUserInfo(ctx)
if err != nil {
return nil, err
}
set, err := GetTenantEmbedder(userInfo.TenantId)
if set == nil {
return nil, err
}
switch *configType {
case *model.ModelConfigTypeVectorArk.Code():
return set.Ark, nil
case *model.ModelConfigTypeVectorOllama.Code():
return set.Ollama, nil
case *model.ModelConfigTypeVectorOpenAI.Code():
return set.OpenAI, nil
case *model.ModelConfigTypeVectorQianfan.Code():
return set.Qianfan, nil
case *model.ModelConfigTypeVectorTencentCloud.Code():
return set.TencentCloud, nil
case *model.ModelConfigTypeVectorDashScope.Code():
return set.DashScope, nil
default:
return nil, fmt.Errorf("不支持的向量模型配置类型: %v", *configType)
}
}
func RefreshTenantEmbedder(ctx context.Context, modelDO *entity.Model) error {
delete(tenantEmbedders, modelDO.TenantId)
return InitVector(ctx, modelDO)
}