feat: 支持多租户多模型对话及文档去重优化

This commit is contained in:
2026-04-16 15:47:37 +08:00
parent 4ead3f82cf
commit 27b1dd3c27
34 changed files with 2188 additions and 315 deletions

243
common/eino/chat.go Normal file
View File

@@ -0,0 +1,243 @@
package eino
import (
"context"
"fmt"
"rag/consts/model"
"rag/dao"
"rag/model/dto"
"rag/model/entity"
"gitea.com/red-future/common/jaeger"
"gitea.com/red-future/common/utils"
"github.com/cloudwego/eino-ext/components/model/ark"
"github.com/cloudwego/eino-ext/components/model/arkbot"
"github.com/cloudwego/eino-ext/components/model/claude"
"github.com/cloudwego/eino-ext/components/model/deepseek"
"github.com/cloudwego/eino-ext/components/model/ollama"
"github.com/cloudwego/eino-ext/components/model/openai"
"github.com/cloudwego/eino-ext/components/model/qianfan"
"github.com/cloudwego/eino-ext/components/model/qwen"
modelChat "github.com/cloudwego/eino/components/model"
"github.com/gogf/gf/v2/frame/g"
"github.com/gogf/gf/v2/util/gconv"
)
type ChatModelSet struct {
Ark *ark.ChatModel
ArkBot *arkbot.ChatModel
Claude *claude.ChatModel
DeepSeek *deepseek.ChatModel
Ollama *ollama.ChatModel
OpenAI *openai.ChatModel
Qianfan *qianfan.ChatModel
Qwen *qwen.ChatModel
}
// 全局租户容器key=tenantIdvalue=该租户的对话模型
var tenantChatModels = make(map[uint64]*ChatModelSet)
func init() {
ctx := context.Background()
ctx, span := jaeger.NewSpan(ctx, "InitAllChat")
defer span.End()
InitAllChat(ctx)
return
}
// ===================== 1. 服务启动时:初始化所有租户对话模型 =====================
func InitAllChat(ctx context.Context) {
list, err := dao.Model.GetNoTenantId(ctx, &dto.GetModelReq{
ModelType: model.ModelTypeChat.Code(),
})
if err != nil {
g.Log().Errorf(ctx, "获取所有租户对话模型失败: %v", err)
return
}
for _, l := range list {
err = InitChat(ctx, l)
if err != nil {
g.Log().Errorf(ctx, "初始化租户[%v]的对话模型失败: %v", l.TenantId, err)
continue
}
}
}
func InitChat(ctx context.Context, modelDO *entity.Model) (err error) {
set := &ChatModelSet{}
switch *modelDO.ConfigType {
case *model.ModelConfigTypeChatArk.Code():
var cfg entity.ChatModelConfigArk
if err = gconv.Struct(modelDO.ConfigContent, &cfg); err != nil {
return fmt.Errorf("解析Ark配置失败: %v", err)
}
set.Ark, err = ark.NewChatModel(ctx, &ark.ChatModelConfig{
APIKey: cfg.APIKey,
Model: cfg.Model,
Temperature: gconv.PtrFloat32(0.7),
MaxTokens: gconv.PtrInt(1024),
TopP: gconv.PtrFloat32(1.0),
})
case *model.ModelConfigTypeChatArkBot.Code():
var cfg entity.ChatModelConfigArkBot
if err = gconv.Struct(modelDO.ConfigContent, &cfg); err != nil {
return fmt.Errorf("解析ArkBot配置失败: %v", err)
}
set.ArkBot, err = arkbot.NewChatModel(ctx, &arkbot.Config{
APIKey: cfg.APIKey,
Model: cfg.Model,
Temperature: gconv.PtrFloat32(0.7),
MaxTokens: gconv.PtrInt(1024),
TopP: gconv.PtrFloat32(1.0),
})
case *model.ModelConfigTypeChatClaude.Code():
var cfg entity.ChatModelConfigClaude
if err = gconv.Struct(modelDO.ConfigContent, &cfg); err != nil {
return fmt.Errorf("解析Claude配置失败: %v", err)
}
claudeCfg := claude.Config{
APIKey: cfg.APIKey,
BaseURL: gconv.PtrString(cfg.BaseURL),
Model: cfg.Model,
Temperature: gconv.PtrFloat32(0.7),
MaxTokens: gconv.Int(1024),
TopP: gconv.PtrFloat32(1.0),
ByBedrock: cfg.ByBedrock,
AccessKey: cfg.AccessKey,
SecretAccessKey: cfg.SecretAccessKey,
Region: cfg.Region,
}
set.Claude, err = claude.NewChatModel(ctx, &claudeCfg)
case *model.ModelConfigTypeChatDeepSeek.Code():
var cfg entity.ChatModelConfigDeepSeek
if err = gconv.Struct(modelDO.ConfigContent, &cfg); err != nil {
return fmt.Errorf("解析DeepSeek配置失败: %v", err)
}
set.DeepSeek, err = deepseek.NewChatModel(ctx, &deepseek.ChatModelConfig{
APIKey: cfg.APIKey,
Model: cfg.Model,
BaseURL: cfg.BaseURL,
Temperature: gconv.Float32(0.7),
MaxTokens: gconv.Int(1024),
TopP: gconv.Float32(1.0),
})
case *model.ModelConfigTypeChatOllama.Code():
var cfg entity.ChatModelConfigOllama
if err = gconv.Struct(modelDO.ConfigContent, &cfg); err != nil {
return fmt.Errorf("解析Ollama配置失败: %v", err)
}
set.Ollama, err = ollama.NewChatModel(ctx, &ollama.ChatModelConfig{
BaseURL: cfg.BaseURL,
Model: cfg.Model,
})
case *model.ModelConfigTypeChatOpenAI.Code():
var cfg entity.ChatModelConfigOpenAI
if err = gconv.Struct(modelDO.ConfigContent, &cfg); err != nil {
return fmt.Errorf("解析OpenAI配置失败: %v", err)
}
openAiCfg := openai.ChatModelConfig{
APIKey: cfg.APIKey,
Model: cfg.Model,
ByAzure: cfg.ByAzure,
BaseURL: cfg.BaseURL,
APIVersion: cfg.APIVersion,
Temperature: gconv.PtrFloat32(0.7),
MaxCompletionTokens: gconv.PtrInt(1024),
TopP: gconv.PtrFloat32(1.0),
}
set.OpenAI, err = openai.NewChatModel(ctx, &openAiCfg)
case *model.ModelConfigTypeChatQianfan.Code():
var cfg entity.ChatModelConfigQianfan
if err = gconv.Struct(modelDO.ConfigContent, &cfg); err != nil {
return fmt.Errorf("解析千帆配置失败: %v", err)
}
qcfg := qianfan.GetQianfanSingletonConfig()
qcfg.AccessKey = cfg.AccessKey
qcfg.SecretKey = cfg.SecretKey
set.Qianfan, err = qianfan.NewChatModel(ctx, &qianfan.ChatModelConfig{
Model: cfg.Model,
Temperature: gconv.PtrFloat32(0.7),
MaxCompletionTokens: gconv.PtrInt(1024),
TopP: gconv.PtrFloat32(1.0),
})
case *model.ModelConfigTypeChatQwen.Code():
var cfg entity.ChatModelConfigQwen
if err = gconv.Struct(modelDO.ConfigContent, &cfg); err != nil {
return fmt.Errorf("解析Qwen配置失败: %v", err)
}
set.Qwen, err = qwen.NewChatModel(ctx, &qwen.ChatModelConfig{
APIKey: cfg.APIKey,
Model: cfg.Model,
BaseURL: cfg.BaseURL,
Temperature: gconv.PtrFloat32(0.7),
MaxTokens: gconv.PtrInt(1024),
TopP: gconv.PtrFloat32(1.0),
})
default:
return fmt.Errorf("不支持的对话模型类型: %v", *modelDO.ConfigType)
}
if err != nil {
return fmt.Errorf("初始化对话模型失败: %v", err)
}
// 无锁存入租户 map
tenantChatModels[modelDO.TenantId] = set
g.Log().Infof(ctx, "租户[%v]对话模型[%v]初始化成功", modelDO.TenantId, *modelDO.ConfigType)
return
}
func GetTenantChatModel(tenantId uint64) (*ChatModelSet, error) {
set := tenantChatModels[tenantId]
if set == nil {
return nil, fmt.Errorf("租户[%v]对话模型未初始化", tenantId)
}
return set, nil
}
func GetTenantChatModelByType(ctx context.Context, configType model.ModelConfigType) (modelChat.BaseChatModel, error) {
userInfo, err := utils.GetUserInfo(ctx)
if err != nil {
return nil, err
}
set, err := GetTenantChatModel(userInfo.TenantId)
if set == nil {
return nil, err
}
switch *configType {
case *model.ModelConfigTypeChatArk.Code():
return set.Ark, nil
case *model.ModelConfigTypeChatArkBot.Code():
return set.ArkBot, nil
case *model.ModelConfigTypeChatClaude.Code():
return set.Claude, nil
case *model.ModelConfigTypeChatDeepSeek.Code():
return set.DeepSeek, nil
case *model.ModelConfigTypeChatOllama.Code():
return set.Ollama, nil
case *model.ModelConfigTypeChatOpenAI.Code():
return set.OpenAI, nil
case *model.ModelConfigTypeChatQianfan.Code():
return set.Qianfan, nil
case *model.ModelConfigTypeChatQwen.Code():
return set.Qwen, nil
default:
return nil, fmt.Errorf("不支持的对话模型类型: %v", configType)
}
}
func RefreshTenantChatModel(ctx context.Context, modelDO *entity.Model) error {
delete(tenantChatModels, modelDO.TenantId)
return InitChat(ctx, modelDO)
}

View File

@@ -5,13 +5,10 @@ import (
"errors"
"fmt"
"io"
"rag/consts/model"
"github.com/cloudwego/eino-ext/components/model/qwen"
"github.com/cloudwego/eino/components/prompt"
"github.com/cloudwego/eino/schema"
"github.com/gogf/gf/v2/frame/g"
"github.com/gogf/gf/v2/os/glog"
"github.com/gogf/gf/v2/util/gconv"
)
const (
@@ -19,48 +16,15 @@ const (
)
var (
globalChatModel *qwen.ChatModel
ragPromptTemplate prompt.ChatTemplate // EINO 官方模板
)
func init() {
ctx := context.Background()
// 初始化大模型
if err := initChatModel(ctx); err != nil {
glog.Errorf(ctx, "初始化大模型失败: %v", err)
}
// 初始化 EINO 提示词模板
initRAGPromptTemplate()
return
}
// 初始化通义千问
func initChatModel(ctx context.Context) error {
if globalChatModel != nil {
return nil
}
apiKey := g.Cfg().MustGet(ctx, "eino.chatmodel.apiKey").String()
model := g.Cfg().MustGet(ctx, "eino.chatmodel.model").String()
cm, err := qwen.NewChatModel(ctx, &qwen.ChatModelConfig{
APIKey: apiKey,
Model: model,
BaseURL: "https://dashscope.aliyuncs.com/compatible-mode/v1",
Timeout: 60 * 1e9,
Temperature: gconv.PtrFloat32(0.7), // 客服最佳
MaxTokens: gconv.PtrInt(1024), // 最长回答
TopP: gconv.PtrFloat32(1.0),
})
if err != nil {
return err
}
globalChatModel = cm
return nil
}
// 初始化 EINO 官方提示词模板(最关键!)
func initRAGPromptTemplate() {
ragPromptTemplate = prompt.FromMessages(
@@ -69,7 +33,7 @@ func initRAGPromptTemplate() {
&schema.Message{
Role: schema.System,
Content: `你是专业客服,语气友好简洁。
严格依据参考知识回答,不知道就说:抱歉,我暂时无法回答这个问题。
请依据参考知识回答,不知道就说:抱歉,我暂时无法回答这个问题。
参考知识:
{knowledge}`,
@@ -83,7 +47,7 @@ func initRAGPromptTemplate() {
}
// NewChatModel 只处理逻辑,不复用创建模型
func NewChatModel(ctx context.Context, question string, docs []*schema.Document, history []*schema.Message) (replyMsg *schema.Message, err error) {
func NewChatModel(ctx context.Context, question string, docs []*schema.Document, history []*schema.Message, chatModel model.ModelConfigType) (replyMsg *schema.Message, err error) {
// 1. 构建参考知识
knowledge := buildKnowledgeAndSources(docs)
// 2. 历史精简
@@ -101,7 +65,7 @@ func NewChatModel(ctx context.Context, question string, docs []*schema.Document,
msgs = append(msgs[:1], append(history, msgs[1:]...)...)
}
// 5. 🔥 直接使用全局单例,不重复创建
replyMsg, err = streamGenerateAnswer(ctx, globalChatModel, msgs)
replyMsg, err = streamGenerateAnswer(ctx, msgs, chatModel)
return
}
@@ -133,9 +97,14 @@ func buildKnowledgeAndSources(docs []*schema.Document) string {
}
// streamGenerateAnswer 流式生成
func streamGenerateAnswer(ctx context.Context, chatModel *qwen.ChatModel, msgs []*schema.Message) (reply *schema.Message, err error) {
func streamGenerateAnswer(ctx context.Context, msgs []*schema.Message, chatModel model.ModelConfigType) (reply *schema.Message, err error) {
sr, err := chatModel.Stream(ctx, msgs)
cm, err := GetTenantChatModelByType(ctx, chatModel)
if err != nil {
return nil, err
}
sr, err := cm.Stream(ctx, msgs)
if err != nil {
return nil, fmt.Errorf("stream failed: %w", err)
}

View File

@@ -1,8 +0,0 @@
package eino
const (
providerArk = "ark"
providerOpenai = "openai"
providerQianfan = "qianfan"
providerDashscope = "dashscope"
)

View File

@@ -5,6 +5,7 @@ import (
"fmt"
"gitea.com/red-future/common/utils"
"github.com/cloudwego/eino-ext/components/document/loader/file"
"github.com/cloudwego/eino-ext/components/document/loader/url"
"github.com/cloudwego/eino-ext/components/document/parser/docx"
"github.com/cloudwego/eino-ext/components/document/parser/pdf"
@@ -15,7 +16,7 @@ import (
)
// LoadDocument 业务函数:加载文件
func LoadDocument(ctx context.Context, filePath, fileFormat string) (docs []*schema.Document, err error) {
func a(ctx context.Context, filePath, fileFormat string) (docs []*schema.Document, err error) {
p, err := docsParser(ctx, fileFormat)
if err != nil {
return
@@ -27,12 +28,34 @@ func LoadDocument(ctx context.Context, filePath, fileFormat string) (docs []*sch
if err != nil {
return
}
docs, err = loader.Load(context.Background(), document.Source{
docs, err = loader.Load(ctx, document.Source{
URI: fmt.Sprintf("%s%s", imageUrl, filePath),
})
return
}
func LoadDocument(ctx context.Context, filePath, fileFormat string) (docs []*schema.Document, err error) {
p, err := docsParser(ctx, fileFormat)
if err != nil {
return
}
// 1. 创建文件加载器
loader, err := file.NewFileLoader(ctx, &file.FileLoaderConfig{
UseNameAsID: false, // 使用文件名作为文档ID
Parser: p,
})
if err != nil {
return
}
// 2. 加载本地文件
docs, err = loader.Load(ctx, document.Source{
URI: "C:\\Users\\AI\\Desktop\\手机发展史.txt",
})
return
}
func docsParser(ctx context.Context, fileFormat string) (p parser.Parser, err error) {
switch fileFormat {
case "docx":

View File

@@ -2,6 +2,7 @@ package eino
import (
"context"
"rag/consts/model"
"github.com/cloudwego/eino-ext/components/document/transformer/splitter/recursive"
"github.com/cloudwego/eino-ext/components/document/transformer/splitter/semantic"
@@ -10,7 +11,7 @@ import (
)
// SemanticSplitDocument 语义分割文档
func SemanticSplitDocument(ctx context.Context, docs []*schema.Document) (res []*schema.Document, err error) {
func SemanticSplitDocument(ctx context.Context, docs []*schema.Document, vectorModel model.ModelConfigType) (res []*schema.Document, err error) {
// 默认分隔符(支持中英文)
separators := []string{"\n\n", "\n", "。", "", "", "", ".", "!", "?", ";"}
// 读取配置,使用合理的默认值
@@ -18,24 +19,14 @@ func SemanticSplitDocument(ctx context.Context, docs []*schema.Document) (res []
minChunkSize := g.Cfg().MustGet(ctx, "eino.splitter.minChunkSize").Int()
percentile := g.Cfg().MustGet(ctx, "eino.splitter.percentile").Float64()
batchSize := g.Cfg().MustGet(ctx, "eino.splitter.batchSize").Int()
if batchSize <= 0 {
batchSize = 10 // doubao-embedding-vision 限制每批最多 10 个
}
// 使用批量包装器
var batchEmbedder *BatchEmbedder
provider := g.Cfg().MustGet(ctx, "eino.embedding.provider").String()
switch provider {
case providerArk:
batchEmbedder = NewBatchEmbedder(EmbedderArk, batchSize)
case providerOpenai:
batchEmbedder = NewBatchEmbedder(EmbedderOpenAI, batchSize)
case providerDashscope:
batchEmbedder = NewBatchEmbedder(EmbedderDashscope, batchSize)
embedder, err := GetTenantEmbedderByType(ctx, vectorModel)
if err != nil {
return nil, err
}
splitter, err := semantic.NewSplitter(ctx, &semantic.Config{
Embedding: batchEmbedder,
Embedding: NewBatchEmbedder(embedder, batchSize),
BufferSize: bufferSize,
MinChunkSize: minChunkSize,
Percentile: percentile,

View File

@@ -3,67 +3,211 @@ package eino
import (
"context"
"fmt"
"rag/consts/model"
"rag/model/entity"
"gitea.com/red-future/common/jaeger"
"gitea.com/red-future/common/utils"
"github.com/cloudwego/eino-ext/components/embedding/ark"
"github.com/cloudwego/eino-ext/components/embedding/dashscope"
"github.com/cloudwego/eino-ext/components/embedding/ollama"
"github.com/cloudwego/eino-ext/components/embedding/openai"
"github.com/cloudwego/eino-ext/components/embedding/qianfan"
"github.com/cloudwego/eino-ext/components/embedding/tencentcloud"
"github.com/cloudwego/eino/components/embedding"
"github.com/gogf/gf/v2/frame/g"
"github.com/golang/glog"
"github.com/gogf/gf/v2/util/gconv"
)
// 全局只初始化一次
var (
EmbedderArk *ark.Embedder
EmbedderDashscope *dashscope.Embedder
EmbedderOpenAI *openai.Embedder
)
type EmbedderSet struct {
Ark *ark.Embedder
Ollama *ollama.Embedder
OpenAI *openai.Embedder
Qianfan *qianfan.Embedder
TencentCloud *tencentcloud.Embedder
DashScope *dashscope.Embedder
}
// 全局租户容器key=tenantIdvalue=该租户的向量模型
var tenantEmbedders = make(map[uint64]*EmbedderSet)
func init() {
ctx := context.Background()
if !g.Cfg().MustGet(ctx, "eino.embedding").IsEmpty() {
var err error
provider := g.Cfg().MustGet(ctx, "eino.embedding.provider").String()
switch provider {
case providerArk:
cfg := &ark.EmbeddingConfig{
APIKey: g.Cfg().MustGet(ctx, "eino.embedding.apiKey").String(),
Model: g.Cfg().MustGet(ctx, "eino.embedding.model").String(),
}
if apiType := g.Cfg().MustGet(ctx, "eino.embedding.apiType").String(); apiType != "" {
apiTypeVal := ark.APIType(apiType)
cfg.APIType = &apiTypeVal
}
EmbedderArk, err = ark.NewEmbedder(ctx, cfg)
case providerOpenai:
chatModelConfig := &openai.EmbeddingConfig{
APIKey: g.Cfg().MustGet(ctx, "eino.embedding.apiKey").String(),
Model: g.Cfg().MustGet(ctx, "eino.embedding.model").String(),
}
EmbedderOpenAI, err = openai.NewEmbedder(ctx, chatModelConfig)
case providerDashscope:
cfg := &dashscope.EmbeddingConfig{
APIKey: g.Cfg().MustGet(ctx, "eino.embedding.apiKey").String(),
Model: g.Cfg().MustGet(ctx, "eino.embedding.model").String(),
}
EmbedderDashscope, err = dashscope.NewEmbedder(ctx, cfg)
}
if err != nil {
glog.Fatalf("NewEmbedder of %v error: %v", provider, err)
}
}
ctx, span := jaeger.NewSpan(ctx, "InitAllVector")
defer span.End()
InitAllVector(ctx)
return
}
func EmbedStrings(ctx context.Context, texts []string) (embeddings [][]float64, err error) {
provider := g.Cfg().MustGet(ctx, "eino.embedding.provider").String()
switch provider {
case providerArk:
return EmbedderArk.EmbedStrings(ctx, texts)
case providerOpenai:
return EmbedderOpenAI.EmbedStrings(ctx, texts)
case providerDashscope:
return EmbedderDashscope.EmbedStrings(ctx, texts)
// ===================== 1. 服务启动时调用:初始化所有租户 =====================
func InitAllVector(ctx context.Context) {
//list, err := dao.Model.GetNoTenantId(ctx, &dto.GetModelReq{
// ModelType: model.ModelTypeVector.Code(),
//})
//if err != nil {
// g.Log().Errorf(ctx, "获取所有租户ID失败: %v", err)
// return
//}
//
//for _, l := range list {
// err = InitVector(ctx, l)
// if err != nil {
// g.Log().Errorf(ctx, "初始化租户[%v]的向量模型失败: %v", l.TenantId, err)
// continue
// }
//}
modelDO := new(entity.Model)
modelDO.TenantId = 1
modelDO.ConfigType = model.ModelConfigTypeVectorDashScope.Code()
var cfg entity.VectorModelConfigDashScope
cfg.APIKey = "sk-4a8b82770bf74bc490eb3e4c5a8e2be9"
cfg.Model = "text-embedding-v3"
modelDO.ConfigContent = gconv.Map(&cfg)
err := InitVector(ctx, modelDO)
if err != nil {
g.Log().Errorf(ctx, "初始化向量模型失败: %v", err)
return
}
return nil, fmt.Errorf("unsupported provider: %v", provider)
}
func InitVector(ctx context.Context, modelDO *entity.Model) (err error) {
set := &EmbedderSet{}
switch *modelDO.ConfigType {
case *model.ModelConfigTypeVectorArk.Code():
// 解析 Ark 向量配置
var cfg entity.VectorModelConfigArk
err = gconv.Struct(modelDO.ConfigContent, &cfg)
if err != nil {
return fmt.Errorf("解析Ark向量配置失败: %v", err)
}
arkCfg := &ark.EmbeddingConfig{
APIKey: cfg.APIKey,
Model: cfg.Model,
}
if !g.IsEmpty(cfg.APIType) {
arkCfg.APIType = new(ark.APIType(cfg.APIType))
}
set.Ark, err = ark.NewEmbedder(ctx, arkCfg)
case *model.ModelConfigTypeVectorOllama.Code():
// 解析 Ollama 向量配置
var cfg entity.VectorModelConfigOllama
err = gconv.Struct(modelDO.ConfigContent, &cfg)
if err != nil {
return fmt.Errorf("解析Ollama向量配置失败: %v", err)
}
set.Ollama, err = ollama.NewEmbedder(ctx, &ollama.EmbeddingConfig{
BaseURL: cfg.BaseURL,
Model: cfg.Model,
})
case *model.ModelConfigTypeVectorOpenAI.Code():
// 解析 OpenAI 向量配置
var cfg entity.VectorModelConfigOpenAI
err = gconv.Struct(modelDO.ConfigContent, &cfg)
if err != nil {
return fmt.Errorf("解析OpenAI向量配置失败: %v", err)
}
openaiCfg := &openai.EmbeddingConfig{
APIKey: cfg.APIKey,
Model: cfg.Model,
ByAzure: cfg.ByAzure,
BaseURL: cfg.BaseURL,
APIVersion: cfg.APIVersion,
}
set.OpenAI, err = openai.NewEmbedder(ctx, openaiCfg)
case *model.ModelConfigTypeVectorQianfan.Code():
// 解析 千帆 向量配置
var cfg entity.VectorModelConfigQianfan
err = gconv.Struct(modelDO.ConfigContent, &cfg)
if err != nil {
return fmt.Errorf("解析千帆向量配置失败: %v", err)
}
qcfg := qianfan.GetQianfanSingletonConfig()
qcfg.AccessKey = cfg.AccessKey
qcfg.SecretKey = cfg.SecretKey
set.Qianfan, err = qianfan.NewEmbedder(ctx, &qianfan.EmbeddingConfig{
Model: cfg.Model,
})
case *model.ModelConfigTypeVectorTencentCloud.Code():
// 解析 腾讯云 向量配置
var cfg entity.VectorModelConfigTencentCloud
err = gconv.Struct(modelDO.ConfigContent, &cfg)
if err != nil {
return fmt.Errorf("解析腾讯云向量配置失败: %v", err)
}
set.TencentCloud, err = tencentcloud.NewEmbedder(ctx, &tencentcloud.EmbeddingConfig{
SecretID: cfg.SecretID,
SecretKey: cfg.SecretKey,
Region: cfg.Region,
})
case *model.ModelConfigTypeVectorDashScope.Code():
// 解析 阿里 dashscope 向量配置
var cfg entity.VectorModelConfigDashScope
err = gconv.Struct(modelDO.ConfigContent, &cfg)
if err != nil {
return fmt.Errorf("解析阿里dashscope向量配置失败: %v", err)
}
set.DashScope, err = dashscope.NewEmbedder(ctx, &dashscope.EmbeddingConfig{
APIKey: cfg.APIKey,
Model: cfg.Model,
})
default:
return fmt.Errorf("不支持的向量模型配置类型: %v", *modelDO.ConfigType)
}
// 统一错误处理
if err != nil {
return fmt.Errorf("初始化向量模型失败: %v", err)
}
// 直接存入 map无锁重复初始化会直接覆盖
tenantEmbedders[modelDO.TenantId] = set
g.Log().Infof(ctx, "向量模型[%v]初始化成功", modelDO.ConfigType)
return
}
func GetTenantEmbedder(tenantId uint64) (*EmbedderSet, error) {
set := tenantEmbedders[tenantId]
if set == nil {
return nil, fmt.Errorf("租户[%v]的向量模型未初始化", tenantId)
}
return set, nil
}
func GetTenantEmbedderByType(ctx context.Context, configType model.ModelConfigType) (embedding.Embedder, error) {
userInfo, err := utils.GetUserInfo(ctx)
if err != nil {
return nil, err
}
set, err := GetTenantEmbedder(userInfo.TenantId)
if set == nil {
return nil, err
}
switch *configType {
case *model.ModelConfigTypeVectorArk.Code():
return set.Ark, nil
case *model.ModelConfigTypeVectorOllama.Code():
return set.Ollama, nil
case *model.ModelConfigTypeVectorOpenAI.Code():
return set.OpenAI, nil
case *model.ModelConfigTypeVectorQianfan.Code():
return set.Qianfan, nil
case *model.ModelConfigTypeVectorTencentCloud.Code():
return set.TencentCloud, nil
case *model.ModelConfigTypeVectorDashScope.Code():
return set.DashScope, nil
default:
return nil, fmt.Errorf("不支持的向量模型配置类型: %v", *configType)
}
}
func RefreshTenantEmbedder(ctx context.Context, modelDO *entity.Model) error {
delete(tenantEmbedders, modelDO.TenantId)
return InitVector(ctx, modelDO)
}

View File

@@ -5,6 +5,7 @@ import (
"database/sql"
"errors"
"fmt"
"rag/consts/model"
"rag/dao"
"rag/model/dto"
"rag/model/entity"
@@ -34,7 +35,14 @@ func NewPGVectorIndexer(opts *PGVectorIndexerOptions) *PGVectorIndexer {
return &PGVectorIndexer{opts: opts}
}
func (i *PGVectorIndexer) Store(ctx context.Context, docs []*schema.Document, opts ...indexer.Option) (rows int64, err error) {
func (i *PGVectorIndexer) Store(ctx context.Context, docs []*schema.Document, configType model.ModelConfigType, opts ...indexer.Option) (rows int64, err error) {
embedderByType, err := GetTenantEmbedderByType(ctx, configType)
if err != nil {
return
}
indexer.WithEmbedding(embedderByType)
commonOpts := indexer.GetCommonOptions(&indexer.Options{}, opts...)
if commonOpts.Embedding == nil {

116
common/eino/rerank.go Normal file
View File

@@ -0,0 +1,116 @@
package eino
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"time"
"github.com/cloudwego/eino/schema"
"github.com/gogf/gf/v2/frame/g"
)
// DashScopeReranker 通义百炼 Rerank 精排Cross-Encoder
type DashScopeReranker struct {
httpClient *http.Client
}
func NewDashScopeReranker() *DashScopeReranker {
return &DashScopeReranker{
httpClient: &http.Client{
Timeout: 10 * time.Second,
},
}
}
// Rerank 对文档进行精排Cross-Encoder 核心)
func (d *DashScopeReranker) Rerank(ctx context.Context, query string, docs []*schema.Document) ([]*schema.Document, error) {
if len(docs) == 0 {
return docs, nil
}
// 官方必过 URL
url := "https://dashscope.aliyuncs.com/api/v1/services/rerank/text-rerank/text-rerank"
apiKey := g.Cfg().MustGet(ctx, "eino.rerank.apiKey").String()
model := g.Cfg().MustGet(ctx, "eino.rerank.model").String()
documents := make([]string, len(docs))
for i, doc := range docs {
documents[i] = doc.Content
}
reqBody := map[string]any{
"model": model,
"input": map[string]any{
"query": query,
"documents": documents,
},
"parameters": map[string]any{
"top_n": len(docs),
},
}
bs, _ := json.Marshal(reqBody)
req, _ := http.NewRequestWithContext(ctx, "POST", url, bytes.NewBuffer(bs))
req.Header.Set("Authorization", "Bearer "+apiKey)
req.Header.Set("Content-Type", "application/json")
resp, err := d.httpClient.Do(req)
if err != nil {
return nil, err
}
defer resp.Body.Close()
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
body, _ := io.ReadAll(resp.Body)
return nil, fmt.Errorf("rerank api error: status=%d, body=%s", resp.StatusCode, string(body))
}
// 解析结果
var result struct {
Output struct {
Results []struct {
Index int `json:"index"`
RelevanceScore float64 `json:"relevance_score"`
} `json:"results"`
} `json:"output"`
}
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
return nil, err
}
// 按分数排序
type scoredDoc struct {
doc *schema.Document
score float64
}
scored := make([]scoredDoc, len(docs))
for i, doc := range docs {
scored[i] = scoredDoc{doc: doc, score: 0}
}
for _, res := range result.Output.Results {
scored[res.Index].score = res.RelevanceScore
}
// 分数从高到低排序
for i := 0; i < len(scored); i++ {
for j := i + 1; j < len(scored); j++ {
if scored[j].score > scored[i].score {
scored[i], scored[j] = scored[j], scored[i]
}
}
}
// 输出最终排好的文档
ranked := make([]*schema.Document, 0, len(scored))
for _, s := range scored {
s.doc.MetaData["rerank_score"] = s.score
ranked = append(ranked, s.doc)
}
return ranked, nil
}

View File

@@ -3,6 +3,8 @@ package eino
import (
"context"
"errors"
"fmt"
"rag/consts/model"
"rag/dao"
"sort"
"time"
@@ -29,21 +31,25 @@ type PGVectorRetriever struct {
topK int
index string
dslInfo map[string]any
reranker *DashScopeReranker // 通义精排
}
func NewPGVectorRetriever(config *PGVectorRetrieverConfig) (*PGVectorRetriever, error) {
if config.Embedder == nil {
return nil, errors.New("embedder is required")
}
func NewPGVectorRetriever(ctx context.Context, config *PGVectorRetrieverConfig, configType model.ModelConfigType) (*PGVectorRetriever, error) {
if config.DefaultTopK <= 0 {
config.DefaultTopK = 5
}
e, err := GetTenantEmbedderByType(ctx, configType)
if err != nil {
return nil, err
}
return &PGVectorRetriever{
embedder: config.Embedder,
embedder: e,
topK: config.DefaultTopK,
index: config.DefaultIndex,
dslInfo: config.DSLInfo,
//reranker: NewDashScopeReranker(), // 👈 直接初始化你的精排
}, nil
}
@@ -138,48 +144,37 @@ func (r *PGVectorRetriever) Retrieve(ctx context.Context, query string, opts ...
}
// 合并 + 智能去重(保留最优分数)
docs := mergeAndDeduplicate(docsVector, docsFulltext)
mergedDocs := mergeAndDeduplicate(docsVector, docsFulltext)
// 排序:向量优先,同类型按距离升序
sort.Slice(docs, func(i, j int) bool {
//byI, okI := docs[i].MetaData["retrieve_by"].(string)
//byJ, okJ := docs[j].MetaData["retrieve_by"].(string)
//
//// 有类型标记的优先
//if okI && !okJ {
// return true
//}
//if !okI && okJ {
// return false
//}
//
//// 向量永远排前面
//if byI == "vector" && byJ == "fulltext" {
// return true
//}
//if byI == "fulltext" && byJ == "vector" {
// return false
//}
// 同类型按 distance 升序(越小越相似)
d1 := gconv.Float64(docs[i].MetaData["distance"])
d2 := gconv.Float64(docs[j].MetaData["distance"])
return d1 < d2
})
// 在Retrieve方法末尾增加相关性校验
validDocs := make([]*schema.Document, 0)
for i, d := range docs {
// 过滤distance过大的垃圾结果比如distance>0.8的直接丢弃)
if gconv.Float64(docs[i].MetaData["distance"]) < 0.8 {
validDocs = append(validDocs, d)
// =========================
// 🔥 Cross-Encoder 精排
// =========================
var finalDocs []*schema.Document
if r.reranker != nil {
ranked, err := r.reranker.Rerank(ctx, query, mergedDocs)
if err != nil {
return nil, fmt.Errorf("rerank failed: %w", err)
}
finalDocs = ranked
} else {
sort.Slice(mergedDocs, func(i, j int) bool {
d1 := gconv.Float64(mergedDocs[i].MetaData["distance"])
d2 := gconv.Float64(mergedDocs[j].MetaData["distance"])
return d1 < d2
})
finalDocs = mergedDocs
}
// 如果没有有效结果返回空让LLM回答「暂无相关信息」
if len(validDocs) == 0 {
callbacks.OnEnd(ctx, &retriever.CallbackOutput{Docs: validDocs})
return validDocs, nil
// =========================
// 过滤无效文档
// =========================
const maxDistance = 0.8
validDocs := make([]*schema.Document, 0, len(finalDocs))
for _, doc := range finalDocs {
dist := gconv.Float64(doc.MetaData["distance"])
if dist <= maxDistance {
validDocs = append(validDocs, doc)
}
}
// 最多保留 topK
@@ -208,9 +203,15 @@ func (r *PGVectorRetriever) doRetrieveVector(ctx context.Context, query string,
if opts.TopK != nil {
topK = *opts.TopK
}
datasetIds := gconv.Int64s(opts.DSLInfo["dataset_ids"])
var datasetIds, documentIds []int64
if g.IsEmpty(opts.DSLInfo["dataset_ids"]) {
datasetIds = gconv.Int64s(opts.DSLInfo["dataset_ids"])
}
if g.IsEmpty(opts.DSLInfo["document_ids"]) {
documentIds = gconv.Int64s(opts.DSLInfo["document_ids"])
}
rows, err := dao.DocumentVector.GetAllByVector(ctx, datasetIds, queryVec, topK)
rows, err := dao.DocumentVector.GetAllByVector(ctx, datasetIds, documentIds, queryVec, topK)
if err != nil {
return nil, err
}
@@ -236,10 +237,17 @@ func (r *PGVectorRetriever) doRetrieveVector(ctx context.Context, query string,
// ==========================================
func (r *PGVectorRetriever) doRetrieveMeilisearch(ctx context.Context, query string, opts *retriever.Options) ([]*schema.Document, error) {
topK := *opts.TopK
datasetIds := gconv.Int64s(opts.DSLInfo["dataset_ids"])
var datasetIds, documentIds []int64
if g.IsEmpty(opts.DSLInfo["dataset_ids"]) {
datasetIds = gconv.Int64s(opts.DSLInfo["dataset_ids"])
}
if g.IsEmpty(opts.DSLInfo["document_ids"]) {
documentIds = gconv.Int64s(opts.DSLInfo["document_ids"])
}
// 调用你已有的 Meilisearch DAO
rows, err := dao.DocumentVector.SearchByKeywords(ctx, query, datasetIds, topK)
rows, err := dao.DocumentVector.SearchByKeywords(ctx, query, datasetIds, documentIds, topK)
if err != nil {
return nil, err
}