feat: 支持多租户多模型对话及文档去重优化
This commit is contained in:
@@ -3,67 +3,211 @@ package eino
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"rag/consts/model"
|
||||
"rag/model/entity"
|
||||
|
||||
"gitea.com/red-future/common/jaeger"
|
||||
"gitea.com/red-future/common/utils"
|
||||
"github.com/cloudwego/eino-ext/components/embedding/ark"
|
||||
"github.com/cloudwego/eino-ext/components/embedding/dashscope"
|
||||
"github.com/cloudwego/eino-ext/components/embedding/ollama"
|
||||
"github.com/cloudwego/eino-ext/components/embedding/openai"
|
||||
"github.com/cloudwego/eino-ext/components/embedding/qianfan"
|
||||
"github.com/cloudwego/eino-ext/components/embedding/tencentcloud"
|
||||
"github.com/cloudwego/eino/components/embedding"
|
||||
"github.com/gogf/gf/v2/frame/g"
|
||||
"github.com/golang/glog"
|
||||
"github.com/gogf/gf/v2/util/gconv"
|
||||
)
|
||||
|
||||
// 全局只初始化一次
|
||||
var (
|
||||
EmbedderArk *ark.Embedder
|
||||
EmbedderDashscope *dashscope.Embedder
|
||||
EmbedderOpenAI *openai.Embedder
|
||||
)
|
||||
type EmbedderSet struct {
|
||||
Ark *ark.Embedder
|
||||
Ollama *ollama.Embedder
|
||||
OpenAI *openai.Embedder
|
||||
Qianfan *qianfan.Embedder
|
||||
TencentCloud *tencentcloud.Embedder
|
||||
DashScope *dashscope.Embedder
|
||||
}
|
||||
|
||||
// 全局租户容器:key=tenantId,value=该租户的向量模型
|
||||
var tenantEmbedders = make(map[uint64]*EmbedderSet)
|
||||
|
||||
func init() {
|
||||
ctx := context.Background()
|
||||
if !g.Cfg().MustGet(ctx, "eino.embedding").IsEmpty() {
|
||||
var err error
|
||||
provider := g.Cfg().MustGet(ctx, "eino.embedding.provider").String()
|
||||
switch provider {
|
||||
case providerArk:
|
||||
cfg := &ark.EmbeddingConfig{
|
||||
APIKey: g.Cfg().MustGet(ctx, "eino.embedding.apiKey").String(),
|
||||
Model: g.Cfg().MustGet(ctx, "eino.embedding.model").String(),
|
||||
}
|
||||
if apiType := g.Cfg().MustGet(ctx, "eino.embedding.apiType").String(); apiType != "" {
|
||||
apiTypeVal := ark.APIType(apiType)
|
||||
cfg.APIType = &apiTypeVal
|
||||
}
|
||||
EmbedderArk, err = ark.NewEmbedder(ctx, cfg)
|
||||
case providerOpenai:
|
||||
chatModelConfig := &openai.EmbeddingConfig{
|
||||
APIKey: g.Cfg().MustGet(ctx, "eino.embedding.apiKey").String(),
|
||||
Model: g.Cfg().MustGet(ctx, "eino.embedding.model").String(),
|
||||
}
|
||||
EmbedderOpenAI, err = openai.NewEmbedder(ctx, chatModelConfig)
|
||||
case providerDashscope:
|
||||
cfg := &dashscope.EmbeddingConfig{
|
||||
APIKey: g.Cfg().MustGet(ctx, "eino.embedding.apiKey").String(),
|
||||
Model: g.Cfg().MustGet(ctx, "eino.embedding.model").String(),
|
||||
}
|
||||
EmbedderDashscope, err = dashscope.NewEmbedder(ctx, cfg)
|
||||
}
|
||||
if err != nil {
|
||||
glog.Fatalf("NewEmbedder of %v error: %v", provider, err)
|
||||
}
|
||||
}
|
||||
|
||||
ctx, span := jaeger.NewSpan(ctx, "InitAllVector")
|
||||
defer span.End()
|
||||
InitAllVector(ctx)
|
||||
return
|
||||
}
|
||||
|
||||
func EmbedStrings(ctx context.Context, texts []string) (embeddings [][]float64, err error) {
|
||||
provider := g.Cfg().MustGet(ctx, "eino.embedding.provider").String()
|
||||
switch provider {
|
||||
case providerArk:
|
||||
return EmbedderArk.EmbedStrings(ctx, texts)
|
||||
case providerOpenai:
|
||||
return EmbedderOpenAI.EmbedStrings(ctx, texts)
|
||||
case providerDashscope:
|
||||
return EmbedderDashscope.EmbedStrings(ctx, texts)
|
||||
// ===================== 1. 服务启动时调用:初始化所有租户 =====================
|
||||
func InitAllVector(ctx context.Context) {
|
||||
//list, err := dao.Model.GetNoTenantId(ctx, &dto.GetModelReq{
|
||||
// ModelType: model.ModelTypeVector.Code(),
|
||||
//})
|
||||
//if err != nil {
|
||||
// g.Log().Errorf(ctx, "获取所有租户ID失败: %v", err)
|
||||
// return
|
||||
//}
|
||||
//
|
||||
//for _, l := range list {
|
||||
// err = InitVector(ctx, l)
|
||||
// if err != nil {
|
||||
// g.Log().Errorf(ctx, "初始化租户[%v]的向量模型失败: %v", l.TenantId, err)
|
||||
// continue
|
||||
// }
|
||||
//}
|
||||
modelDO := new(entity.Model)
|
||||
modelDO.TenantId = 1
|
||||
modelDO.ConfigType = model.ModelConfigTypeVectorDashScope.Code()
|
||||
var cfg entity.VectorModelConfigDashScope
|
||||
cfg.APIKey = "sk-4a8b82770bf74bc490eb3e4c5a8e2be9"
|
||||
cfg.Model = "text-embedding-v3"
|
||||
modelDO.ConfigContent = gconv.Map(&cfg)
|
||||
err := InitVector(ctx, modelDO)
|
||||
if err != nil {
|
||||
g.Log().Errorf(ctx, "初始化向量模型失败: %v", err)
|
||||
return
|
||||
}
|
||||
return nil, fmt.Errorf("unsupported provider: %v", provider)
|
||||
}
|
||||
|
||||
func InitVector(ctx context.Context, modelDO *entity.Model) (err error) {
|
||||
set := &EmbedderSet{}
|
||||
switch *modelDO.ConfigType {
|
||||
case *model.ModelConfigTypeVectorArk.Code():
|
||||
// 解析 Ark 向量配置
|
||||
var cfg entity.VectorModelConfigArk
|
||||
err = gconv.Struct(modelDO.ConfigContent, &cfg)
|
||||
if err != nil {
|
||||
return fmt.Errorf("解析Ark向量配置失败: %v", err)
|
||||
}
|
||||
arkCfg := &ark.EmbeddingConfig{
|
||||
APIKey: cfg.APIKey,
|
||||
Model: cfg.Model,
|
||||
}
|
||||
if !g.IsEmpty(cfg.APIType) {
|
||||
arkCfg.APIType = new(ark.APIType(cfg.APIType))
|
||||
}
|
||||
set.Ark, err = ark.NewEmbedder(ctx, arkCfg)
|
||||
|
||||
case *model.ModelConfigTypeVectorOllama.Code():
|
||||
// 解析 Ollama 向量配置
|
||||
var cfg entity.VectorModelConfigOllama
|
||||
err = gconv.Struct(modelDO.ConfigContent, &cfg)
|
||||
if err != nil {
|
||||
return fmt.Errorf("解析Ollama向量配置失败: %v", err)
|
||||
}
|
||||
set.Ollama, err = ollama.NewEmbedder(ctx, &ollama.EmbeddingConfig{
|
||||
BaseURL: cfg.BaseURL,
|
||||
Model: cfg.Model,
|
||||
})
|
||||
|
||||
case *model.ModelConfigTypeVectorOpenAI.Code():
|
||||
// 解析 OpenAI 向量配置
|
||||
var cfg entity.VectorModelConfigOpenAI
|
||||
err = gconv.Struct(modelDO.ConfigContent, &cfg)
|
||||
if err != nil {
|
||||
return fmt.Errorf("解析OpenAI向量配置失败: %v", err)
|
||||
}
|
||||
openaiCfg := &openai.EmbeddingConfig{
|
||||
APIKey: cfg.APIKey,
|
||||
Model: cfg.Model,
|
||||
ByAzure: cfg.ByAzure,
|
||||
BaseURL: cfg.BaseURL,
|
||||
APIVersion: cfg.APIVersion,
|
||||
}
|
||||
set.OpenAI, err = openai.NewEmbedder(ctx, openaiCfg)
|
||||
|
||||
case *model.ModelConfigTypeVectorQianfan.Code():
|
||||
// 解析 千帆 向量配置
|
||||
var cfg entity.VectorModelConfigQianfan
|
||||
err = gconv.Struct(modelDO.ConfigContent, &cfg)
|
||||
if err != nil {
|
||||
return fmt.Errorf("解析千帆向量配置失败: %v", err)
|
||||
}
|
||||
qcfg := qianfan.GetQianfanSingletonConfig()
|
||||
qcfg.AccessKey = cfg.AccessKey
|
||||
qcfg.SecretKey = cfg.SecretKey
|
||||
set.Qianfan, err = qianfan.NewEmbedder(ctx, &qianfan.EmbeddingConfig{
|
||||
Model: cfg.Model,
|
||||
})
|
||||
|
||||
case *model.ModelConfigTypeVectorTencentCloud.Code():
|
||||
// 解析 腾讯云 向量配置
|
||||
var cfg entity.VectorModelConfigTencentCloud
|
||||
err = gconv.Struct(modelDO.ConfigContent, &cfg)
|
||||
if err != nil {
|
||||
return fmt.Errorf("解析腾讯云向量配置失败: %v", err)
|
||||
}
|
||||
set.TencentCloud, err = tencentcloud.NewEmbedder(ctx, &tencentcloud.EmbeddingConfig{
|
||||
SecretID: cfg.SecretID,
|
||||
SecretKey: cfg.SecretKey,
|
||||
Region: cfg.Region,
|
||||
})
|
||||
|
||||
case *model.ModelConfigTypeVectorDashScope.Code():
|
||||
// 解析 阿里 dashscope 向量配置
|
||||
var cfg entity.VectorModelConfigDashScope
|
||||
err = gconv.Struct(modelDO.ConfigContent, &cfg)
|
||||
if err != nil {
|
||||
return fmt.Errorf("解析阿里dashscope向量配置失败: %v", err)
|
||||
}
|
||||
set.DashScope, err = dashscope.NewEmbedder(ctx, &dashscope.EmbeddingConfig{
|
||||
APIKey: cfg.APIKey,
|
||||
Model: cfg.Model,
|
||||
})
|
||||
|
||||
default:
|
||||
return fmt.Errorf("不支持的向量模型配置类型: %v", *modelDO.ConfigType)
|
||||
}
|
||||
|
||||
// 统一错误处理
|
||||
if err != nil {
|
||||
return fmt.Errorf("初始化向量模型失败: %v", err)
|
||||
}
|
||||
// 直接存入 map(无锁,重复初始化会直接覆盖)
|
||||
tenantEmbedders[modelDO.TenantId] = set
|
||||
g.Log().Infof(ctx, "向量模型[%v]初始化成功", modelDO.ConfigType)
|
||||
return
|
||||
}
|
||||
|
||||
func GetTenantEmbedder(tenantId uint64) (*EmbedderSet, error) {
|
||||
set := tenantEmbedders[tenantId]
|
||||
if set == nil {
|
||||
return nil, fmt.Errorf("租户[%v]的向量模型未初始化", tenantId)
|
||||
}
|
||||
return set, nil
|
||||
}
|
||||
|
||||
func GetTenantEmbedderByType(ctx context.Context, configType model.ModelConfigType) (embedding.Embedder, error) {
|
||||
userInfo, err := utils.GetUserInfo(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
set, err := GetTenantEmbedder(userInfo.TenantId)
|
||||
if set == nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
switch *configType {
|
||||
case *model.ModelConfigTypeVectorArk.Code():
|
||||
return set.Ark, nil
|
||||
case *model.ModelConfigTypeVectorOllama.Code():
|
||||
return set.Ollama, nil
|
||||
case *model.ModelConfigTypeVectorOpenAI.Code():
|
||||
return set.OpenAI, nil
|
||||
case *model.ModelConfigTypeVectorQianfan.Code():
|
||||
return set.Qianfan, nil
|
||||
case *model.ModelConfigTypeVectorTencentCloud.Code():
|
||||
return set.TencentCloud, nil
|
||||
case *model.ModelConfigTypeVectorDashScope.Code():
|
||||
return set.DashScope, nil
|
||||
default:
|
||||
return nil, fmt.Errorf("不支持的向量模型配置类型: %v", *configType)
|
||||
}
|
||||
}
|
||||
|
||||
func RefreshTenantEmbedder(ctx context.Context, modelDO *entity.Model) error {
|
||||
delete(tenantEmbedders, modelDO.TenantId)
|
||||
return InitVector(ctx, modelDO)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user