feat: 支持多租户多模型对话及文档去重优化
This commit is contained in:
243
common/eino/chat.go
Normal file
243
common/eino/chat.go
Normal file
@@ -0,0 +1,243 @@
|
||||
package eino
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"rag/consts/model"
|
||||
"rag/dao"
|
||||
"rag/model/dto"
|
||||
"rag/model/entity"
|
||||
|
||||
"gitea.com/red-future/common/jaeger"
|
||||
"gitea.com/red-future/common/utils"
|
||||
"github.com/cloudwego/eino-ext/components/model/ark"
|
||||
"github.com/cloudwego/eino-ext/components/model/arkbot"
|
||||
"github.com/cloudwego/eino-ext/components/model/claude"
|
||||
"github.com/cloudwego/eino-ext/components/model/deepseek"
|
||||
"github.com/cloudwego/eino-ext/components/model/ollama"
|
||||
"github.com/cloudwego/eino-ext/components/model/openai"
|
||||
"github.com/cloudwego/eino-ext/components/model/qianfan"
|
||||
"github.com/cloudwego/eino-ext/components/model/qwen"
|
||||
modelChat "github.com/cloudwego/eino/components/model"
|
||||
"github.com/gogf/gf/v2/frame/g"
|
||||
"github.com/gogf/gf/v2/util/gconv"
|
||||
)
|
||||
|
||||
type ChatModelSet struct {
|
||||
Ark *ark.ChatModel
|
||||
ArkBot *arkbot.ChatModel
|
||||
Claude *claude.ChatModel
|
||||
DeepSeek *deepseek.ChatModel
|
||||
Ollama *ollama.ChatModel
|
||||
OpenAI *openai.ChatModel
|
||||
Qianfan *qianfan.ChatModel
|
||||
Qwen *qwen.ChatModel
|
||||
}
|
||||
|
||||
// 全局租户容器:key=tenantId,value=该租户的对话模型
|
||||
var tenantChatModels = make(map[uint64]*ChatModelSet)
|
||||
|
||||
func init() {
|
||||
ctx := context.Background()
|
||||
ctx, span := jaeger.NewSpan(ctx, "InitAllChat")
|
||||
defer span.End()
|
||||
InitAllChat(ctx)
|
||||
return
|
||||
}
|
||||
|
||||
// ===================== 1. 服务启动时:初始化所有租户对话模型 =====================
|
||||
func InitAllChat(ctx context.Context) {
|
||||
list, err := dao.Model.GetNoTenantId(ctx, &dto.GetModelReq{
|
||||
ModelType: model.ModelTypeChat.Code(),
|
||||
})
|
||||
if err != nil {
|
||||
g.Log().Errorf(ctx, "获取所有租户对话模型失败: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
for _, l := range list {
|
||||
err = InitChat(ctx, l)
|
||||
if err != nil {
|
||||
g.Log().Errorf(ctx, "初始化租户[%v]的对话模型失败: %v", l.TenantId, err)
|
||||
continue
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func InitChat(ctx context.Context, modelDO *entity.Model) (err error) {
|
||||
set := &ChatModelSet{}
|
||||
switch *modelDO.ConfigType {
|
||||
case *model.ModelConfigTypeChatArk.Code():
|
||||
var cfg entity.ChatModelConfigArk
|
||||
if err = gconv.Struct(modelDO.ConfigContent, &cfg); err != nil {
|
||||
return fmt.Errorf("解析Ark配置失败: %v", err)
|
||||
}
|
||||
set.Ark, err = ark.NewChatModel(ctx, &ark.ChatModelConfig{
|
||||
APIKey: cfg.APIKey,
|
||||
Model: cfg.Model,
|
||||
Temperature: gconv.PtrFloat32(0.7),
|
||||
MaxTokens: gconv.PtrInt(1024),
|
||||
TopP: gconv.PtrFloat32(1.0),
|
||||
})
|
||||
|
||||
case *model.ModelConfigTypeChatArkBot.Code():
|
||||
var cfg entity.ChatModelConfigArkBot
|
||||
if err = gconv.Struct(modelDO.ConfigContent, &cfg); err != nil {
|
||||
return fmt.Errorf("解析ArkBot配置失败: %v", err)
|
||||
}
|
||||
set.ArkBot, err = arkbot.NewChatModel(ctx, &arkbot.Config{
|
||||
APIKey: cfg.APIKey,
|
||||
Model: cfg.Model,
|
||||
Temperature: gconv.PtrFloat32(0.7),
|
||||
MaxTokens: gconv.PtrInt(1024),
|
||||
TopP: gconv.PtrFloat32(1.0),
|
||||
})
|
||||
|
||||
case *model.ModelConfigTypeChatClaude.Code():
|
||||
var cfg entity.ChatModelConfigClaude
|
||||
if err = gconv.Struct(modelDO.ConfigContent, &cfg); err != nil {
|
||||
return fmt.Errorf("解析Claude配置失败: %v", err)
|
||||
}
|
||||
claudeCfg := claude.Config{
|
||||
APIKey: cfg.APIKey,
|
||||
BaseURL: gconv.PtrString(cfg.BaseURL),
|
||||
Model: cfg.Model,
|
||||
Temperature: gconv.PtrFloat32(0.7),
|
||||
MaxTokens: gconv.Int(1024),
|
||||
TopP: gconv.PtrFloat32(1.0),
|
||||
ByBedrock: cfg.ByBedrock,
|
||||
AccessKey: cfg.AccessKey,
|
||||
SecretAccessKey: cfg.SecretAccessKey,
|
||||
Region: cfg.Region,
|
||||
}
|
||||
set.Claude, err = claude.NewChatModel(ctx, &claudeCfg)
|
||||
|
||||
case *model.ModelConfigTypeChatDeepSeek.Code():
|
||||
var cfg entity.ChatModelConfigDeepSeek
|
||||
if err = gconv.Struct(modelDO.ConfigContent, &cfg); err != nil {
|
||||
return fmt.Errorf("解析DeepSeek配置失败: %v", err)
|
||||
}
|
||||
set.DeepSeek, err = deepseek.NewChatModel(ctx, &deepseek.ChatModelConfig{
|
||||
APIKey: cfg.APIKey,
|
||||
Model: cfg.Model,
|
||||
BaseURL: cfg.BaseURL,
|
||||
Temperature: gconv.Float32(0.7),
|
||||
MaxTokens: gconv.Int(1024),
|
||||
TopP: gconv.Float32(1.0),
|
||||
})
|
||||
|
||||
case *model.ModelConfigTypeChatOllama.Code():
|
||||
var cfg entity.ChatModelConfigOllama
|
||||
if err = gconv.Struct(modelDO.ConfigContent, &cfg); err != nil {
|
||||
return fmt.Errorf("解析Ollama配置失败: %v", err)
|
||||
}
|
||||
set.Ollama, err = ollama.NewChatModel(ctx, &ollama.ChatModelConfig{
|
||||
BaseURL: cfg.BaseURL,
|
||||
Model: cfg.Model,
|
||||
})
|
||||
|
||||
case *model.ModelConfigTypeChatOpenAI.Code():
|
||||
var cfg entity.ChatModelConfigOpenAI
|
||||
if err = gconv.Struct(modelDO.ConfigContent, &cfg); err != nil {
|
||||
return fmt.Errorf("解析OpenAI配置失败: %v", err)
|
||||
}
|
||||
openAiCfg := openai.ChatModelConfig{
|
||||
APIKey: cfg.APIKey,
|
||||
Model: cfg.Model,
|
||||
ByAzure: cfg.ByAzure,
|
||||
BaseURL: cfg.BaseURL,
|
||||
APIVersion: cfg.APIVersion,
|
||||
Temperature: gconv.PtrFloat32(0.7),
|
||||
MaxCompletionTokens: gconv.PtrInt(1024),
|
||||
TopP: gconv.PtrFloat32(1.0),
|
||||
}
|
||||
set.OpenAI, err = openai.NewChatModel(ctx, &openAiCfg)
|
||||
|
||||
case *model.ModelConfigTypeChatQianfan.Code():
|
||||
var cfg entity.ChatModelConfigQianfan
|
||||
if err = gconv.Struct(modelDO.ConfigContent, &cfg); err != nil {
|
||||
return fmt.Errorf("解析千帆配置失败: %v", err)
|
||||
}
|
||||
qcfg := qianfan.GetQianfanSingletonConfig()
|
||||
qcfg.AccessKey = cfg.AccessKey
|
||||
qcfg.SecretKey = cfg.SecretKey
|
||||
set.Qianfan, err = qianfan.NewChatModel(ctx, &qianfan.ChatModelConfig{
|
||||
Model: cfg.Model,
|
||||
Temperature: gconv.PtrFloat32(0.7),
|
||||
MaxCompletionTokens: gconv.PtrInt(1024),
|
||||
TopP: gconv.PtrFloat32(1.0),
|
||||
})
|
||||
|
||||
case *model.ModelConfigTypeChatQwen.Code():
|
||||
var cfg entity.ChatModelConfigQwen
|
||||
if err = gconv.Struct(modelDO.ConfigContent, &cfg); err != nil {
|
||||
return fmt.Errorf("解析Qwen配置失败: %v", err)
|
||||
}
|
||||
set.Qwen, err = qwen.NewChatModel(ctx, &qwen.ChatModelConfig{
|
||||
APIKey: cfg.APIKey,
|
||||
Model: cfg.Model,
|
||||
BaseURL: cfg.BaseURL,
|
||||
Temperature: gconv.PtrFloat32(0.7),
|
||||
MaxTokens: gconv.PtrInt(1024),
|
||||
TopP: gconv.PtrFloat32(1.0),
|
||||
})
|
||||
|
||||
default:
|
||||
return fmt.Errorf("不支持的对话模型类型: %v", *modelDO.ConfigType)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("初始化对话模型失败: %v", err)
|
||||
}
|
||||
|
||||
// 无锁存入租户 map
|
||||
tenantChatModels[modelDO.TenantId] = set
|
||||
g.Log().Infof(ctx, "租户[%v]对话模型[%v]初始化成功", modelDO.TenantId, *modelDO.ConfigType)
|
||||
return
|
||||
}
|
||||
|
||||
func GetTenantChatModel(tenantId uint64) (*ChatModelSet, error) {
|
||||
set := tenantChatModels[tenantId]
|
||||
if set == nil {
|
||||
return nil, fmt.Errorf("租户[%v]对话模型未初始化", tenantId)
|
||||
}
|
||||
return set, nil
|
||||
}
|
||||
|
||||
func GetTenantChatModelByType(ctx context.Context, configType model.ModelConfigType) (modelChat.BaseChatModel, error) {
|
||||
userInfo, err := utils.GetUserInfo(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
set, err := GetTenantChatModel(userInfo.TenantId)
|
||||
if set == nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
switch *configType {
|
||||
case *model.ModelConfigTypeChatArk.Code():
|
||||
return set.Ark, nil
|
||||
case *model.ModelConfigTypeChatArkBot.Code():
|
||||
return set.ArkBot, nil
|
||||
case *model.ModelConfigTypeChatClaude.Code():
|
||||
return set.Claude, nil
|
||||
case *model.ModelConfigTypeChatDeepSeek.Code():
|
||||
return set.DeepSeek, nil
|
||||
case *model.ModelConfigTypeChatOllama.Code():
|
||||
return set.Ollama, nil
|
||||
case *model.ModelConfigTypeChatOpenAI.Code():
|
||||
return set.OpenAI, nil
|
||||
case *model.ModelConfigTypeChatQianfan.Code():
|
||||
return set.Qianfan, nil
|
||||
case *model.ModelConfigTypeChatQwen.Code():
|
||||
return set.Qwen, nil
|
||||
default:
|
||||
return nil, fmt.Errorf("不支持的对话模型类型: %v", configType)
|
||||
}
|
||||
}
|
||||
|
||||
func RefreshTenantChatModel(ctx context.Context, modelDO *entity.Model) error {
|
||||
delete(tenantChatModels, modelDO.TenantId)
|
||||
return InitChat(ctx, modelDO)
|
||||
}
|
||||
@@ -5,13 +5,10 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"rag/consts/model"
|
||||
|
||||
"github.com/cloudwego/eino-ext/components/model/qwen"
|
||||
"github.com/cloudwego/eino/components/prompt"
|
||||
"github.com/cloudwego/eino/schema"
|
||||
"github.com/gogf/gf/v2/frame/g"
|
||||
"github.com/gogf/gf/v2/os/glog"
|
||||
"github.com/gogf/gf/v2/util/gconv"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -19,48 +16,15 @@ const (
|
||||
)
|
||||
|
||||
var (
|
||||
globalChatModel *qwen.ChatModel
|
||||
ragPromptTemplate prompt.ChatTemplate // EINO 官方模板
|
||||
)
|
||||
|
||||
func init() {
|
||||
ctx := context.Background()
|
||||
// 初始化大模型
|
||||
if err := initChatModel(ctx); err != nil {
|
||||
glog.Errorf(ctx, "初始化大模型失败: %v", err)
|
||||
}
|
||||
// 初始化 EINO 提示词模板
|
||||
initRAGPromptTemplate()
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// 初始化通义千问
|
||||
func initChatModel(ctx context.Context) error {
|
||||
if globalChatModel != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
apiKey := g.Cfg().MustGet(ctx, "eino.chatmodel.apiKey").String()
|
||||
model := g.Cfg().MustGet(ctx, "eino.chatmodel.model").String()
|
||||
|
||||
cm, err := qwen.NewChatModel(ctx, &qwen.ChatModelConfig{
|
||||
APIKey: apiKey,
|
||||
Model: model,
|
||||
BaseURL: "https://dashscope.aliyuncs.com/compatible-mode/v1",
|
||||
Timeout: 60 * 1e9,
|
||||
Temperature: gconv.PtrFloat32(0.7), // 客服最佳
|
||||
MaxTokens: gconv.PtrInt(1024), // 最长回答
|
||||
TopP: gconv.PtrFloat32(1.0),
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
globalChatModel = cm
|
||||
return nil
|
||||
}
|
||||
|
||||
// 初始化 EINO 官方提示词模板(最关键!)
|
||||
func initRAGPromptTemplate() {
|
||||
ragPromptTemplate = prompt.FromMessages(
|
||||
@@ -69,7 +33,7 @@ func initRAGPromptTemplate() {
|
||||
&schema.Message{
|
||||
Role: schema.System,
|
||||
Content: `你是专业客服,语气友好简洁。
|
||||
请严格依据参考知识回答,不知道就说:抱歉,我暂时无法回答这个问题。
|
||||
请依据参考知识回答,不知道就说:抱歉,我暂时无法回答这个问题。
|
||||
|
||||
参考知识:
|
||||
{knowledge}`,
|
||||
@@ -83,7 +47,7 @@ func initRAGPromptTemplate() {
|
||||
}
|
||||
|
||||
// NewChatModel 只处理逻辑,不复用创建模型
|
||||
func NewChatModel(ctx context.Context, question string, docs []*schema.Document, history []*schema.Message) (replyMsg *schema.Message, err error) {
|
||||
func NewChatModel(ctx context.Context, question string, docs []*schema.Document, history []*schema.Message, chatModel model.ModelConfigType) (replyMsg *schema.Message, err error) {
|
||||
// 1. 构建参考知识
|
||||
knowledge := buildKnowledgeAndSources(docs)
|
||||
// 2. 历史精简
|
||||
@@ -101,7 +65,7 @@ func NewChatModel(ctx context.Context, question string, docs []*schema.Document,
|
||||
msgs = append(msgs[:1], append(history, msgs[1:]...)...)
|
||||
}
|
||||
// 5. 🔥 直接使用全局单例,不重复创建
|
||||
replyMsg, err = streamGenerateAnswer(ctx, globalChatModel, msgs)
|
||||
replyMsg, err = streamGenerateAnswer(ctx, msgs, chatModel)
|
||||
|
||||
return
|
||||
}
|
||||
@@ -133,9 +97,14 @@ func buildKnowledgeAndSources(docs []*schema.Document) string {
|
||||
}
|
||||
|
||||
// streamGenerateAnswer 流式生成
|
||||
func streamGenerateAnswer(ctx context.Context, chatModel *qwen.ChatModel, msgs []*schema.Message) (reply *schema.Message, err error) {
|
||||
func streamGenerateAnswer(ctx context.Context, msgs []*schema.Message, chatModel model.ModelConfigType) (reply *schema.Message, err error) {
|
||||
|
||||
sr, err := chatModel.Stream(ctx, msgs)
|
||||
cm, err := GetTenantChatModelByType(ctx, chatModel)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
sr, err := cm.Stream(ctx, msgs)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("stream failed: %w", err)
|
||||
}
|
||||
|
||||
@@ -1,8 +0,0 @@
|
||||
package eino
|
||||
|
||||
const (
|
||||
providerArk = "ark"
|
||||
providerOpenai = "openai"
|
||||
providerQianfan = "qianfan"
|
||||
providerDashscope = "dashscope"
|
||||
)
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"fmt"
|
||||
|
||||
"gitea.com/red-future/common/utils"
|
||||
"github.com/cloudwego/eino-ext/components/document/loader/file"
|
||||
"github.com/cloudwego/eino-ext/components/document/loader/url"
|
||||
"github.com/cloudwego/eino-ext/components/document/parser/docx"
|
||||
"github.com/cloudwego/eino-ext/components/document/parser/pdf"
|
||||
@@ -15,7 +16,7 @@ import (
|
||||
)
|
||||
|
||||
// LoadDocument 业务函数:加载文件
|
||||
func LoadDocument(ctx context.Context, filePath, fileFormat string) (docs []*schema.Document, err error) {
|
||||
func a(ctx context.Context, filePath, fileFormat string) (docs []*schema.Document, err error) {
|
||||
p, err := docsParser(ctx, fileFormat)
|
||||
if err != nil {
|
||||
return
|
||||
@@ -27,12 +28,34 @@ func LoadDocument(ctx context.Context, filePath, fileFormat string) (docs []*sch
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
docs, err = loader.Load(context.Background(), document.Source{
|
||||
docs, err = loader.Load(ctx, document.Source{
|
||||
URI: fmt.Sprintf("%s%s", imageUrl, filePath),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
func LoadDocument(ctx context.Context, filePath, fileFormat string) (docs []*schema.Document, err error) {
|
||||
p, err := docsParser(ctx, fileFormat)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
// 1. 创建文件加载器
|
||||
loader, err := file.NewFileLoader(ctx, &file.FileLoaderConfig{
|
||||
UseNameAsID: false, // 使用文件名作为文档ID
|
||||
Parser: p,
|
||||
})
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// 2. 加载本地文件
|
||||
docs, err = loader.Load(ctx, document.Source{
|
||||
URI: "C:\\Users\\AI\\Desktop\\手机发展史.txt",
|
||||
})
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func docsParser(ctx context.Context, fileFormat string) (p parser.Parser, err error) {
|
||||
switch fileFormat {
|
||||
case "docx":
|
||||
|
||||
@@ -2,6 +2,7 @@ package eino
|
||||
|
||||
import (
|
||||
"context"
|
||||
"rag/consts/model"
|
||||
|
||||
"github.com/cloudwego/eino-ext/components/document/transformer/splitter/recursive"
|
||||
"github.com/cloudwego/eino-ext/components/document/transformer/splitter/semantic"
|
||||
@@ -10,7 +11,7 @@ import (
|
||||
)
|
||||
|
||||
// SemanticSplitDocument 语义分割文档
|
||||
func SemanticSplitDocument(ctx context.Context, docs []*schema.Document) (res []*schema.Document, err error) {
|
||||
func SemanticSplitDocument(ctx context.Context, docs []*schema.Document, vectorModel model.ModelConfigType) (res []*schema.Document, err error) {
|
||||
// 默认分隔符(支持中英文)
|
||||
separators := []string{"\n\n", "\n", "。", "!", "?", ";", ".", "!", "?", ";"}
|
||||
// 读取配置,使用合理的默认值
|
||||
@@ -18,24 +19,14 @@ func SemanticSplitDocument(ctx context.Context, docs []*schema.Document) (res []
|
||||
minChunkSize := g.Cfg().MustGet(ctx, "eino.splitter.minChunkSize").Int()
|
||||
percentile := g.Cfg().MustGet(ctx, "eino.splitter.percentile").Float64()
|
||||
batchSize := g.Cfg().MustGet(ctx, "eino.splitter.batchSize").Int()
|
||||
if batchSize <= 0 {
|
||||
batchSize = 10 // doubao-embedding-vision 限制每批最多 10 个
|
||||
}
|
||||
|
||||
// 使用批量包装器
|
||||
var batchEmbedder *BatchEmbedder
|
||||
provider := g.Cfg().MustGet(ctx, "eino.embedding.provider").String()
|
||||
switch provider {
|
||||
case providerArk:
|
||||
batchEmbedder = NewBatchEmbedder(EmbedderArk, batchSize)
|
||||
case providerOpenai:
|
||||
batchEmbedder = NewBatchEmbedder(EmbedderOpenAI, batchSize)
|
||||
case providerDashscope:
|
||||
batchEmbedder = NewBatchEmbedder(EmbedderDashscope, batchSize)
|
||||
embedder, err := GetTenantEmbedderByType(ctx, vectorModel)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
splitter, err := semantic.NewSplitter(ctx, &semantic.Config{
|
||||
Embedding: batchEmbedder,
|
||||
Embedding: NewBatchEmbedder(embedder, batchSize),
|
||||
BufferSize: bufferSize,
|
||||
MinChunkSize: minChunkSize,
|
||||
Percentile: percentile,
|
||||
|
||||
@@ -3,67 +3,211 @@ package eino
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"rag/consts/model"
|
||||
"rag/model/entity"
|
||||
|
||||
"gitea.com/red-future/common/jaeger"
|
||||
"gitea.com/red-future/common/utils"
|
||||
"github.com/cloudwego/eino-ext/components/embedding/ark"
|
||||
"github.com/cloudwego/eino-ext/components/embedding/dashscope"
|
||||
"github.com/cloudwego/eino-ext/components/embedding/ollama"
|
||||
"github.com/cloudwego/eino-ext/components/embedding/openai"
|
||||
"github.com/cloudwego/eino-ext/components/embedding/qianfan"
|
||||
"github.com/cloudwego/eino-ext/components/embedding/tencentcloud"
|
||||
"github.com/cloudwego/eino/components/embedding"
|
||||
"github.com/gogf/gf/v2/frame/g"
|
||||
"github.com/golang/glog"
|
||||
"github.com/gogf/gf/v2/util/gconv"
|
||||
)
|
||||
|
||||
// 全局只初始化一次
|
||||
var (
|
||||
EmbedderArk *ark.Embedder
|
||||
EmbedderDashscope *dashscope.Embedder
|
||||
EmbedderOpenAI *openai.Embedder
|
||||
)
|
||||
type EmbedderSet struct {
|
||||
Ark *ark.Embedder
|
||||
Ollama *ollama.Embedder
|
||||
OpenAI *openai.Embedder
|
||||
Qianfan *qianfan.Embedder
|
||||
TencentCloud *tencentcloud.Embedder
|
||||
DashScope *dashscope.Embedder
|
||||
}
|
||||
|
||||
// 全局租户容器:key=tenantId,value=该租户的向量模型
|
||||
var tenantEmbedders = make(map[uint64]*EmbedderSet)
|
||||
|
||||
func init() {
|
||||
ctx := context.Background()
|
||||
if !g.Cfg().MustGet(ctx, "eino.embedding").IsEmpty() {
|
||||
var err error
|
||||
provider := g.Cfg().MustGet(ctx, "eino.embedding.provider").String()
|
||||
switch provider {
|
||||
case providerArk:
|
||||
cfg := &ark.EmbeddingConfig{
|
||||
APIKey: g.Cfg().MustGet(ctx, "eino.embedding.apiKey").String(),
|
||||
Model: g.Cfg().MustGet(ctx, "eino.embedding.model").String(),
|
||||
}
|
||||
if apiType := g.Cfg().MustGet(ctx, "eino.embedding.apiType").String(); apiType != "" {
|
||||
apiTypeVal := ark.APIType(apiType)
|
||||
cfg.APIType = &apiTypeVal
|
||||
}
|
||||
EmbedderArk, err = ark.NewEmbedder(ctx, cfg)
|
||||
case providerOpenai:
|
||||
chatModelConfig := &openai.EmbeddingConfig{
|
||||
APIKey: g.Cfg().MustGet(ctx, "eino.embedding.apiKey").String(),
|
||||
Model: g.Cfg().MustGet(ctx, "eino.embedding.model").String(),
|
||||
}
|
||||
EmbedderOpenAI, err = openai.NewEmbedder(ctx, chatModelConfig)
|
||||
case providerDashscope:
|
||||
cfg := &dashscope.EmbeddingConfig{
|
||||
APIKey: g.Cfg().MustGet(ctx, "eino.embedding.apiKey").String(),
|
||||
Model: g.Cfg().MustGet(ctx, "eino.embedding.model").String(),
|
||||
}
|
||||
EmbedderDashscope, err = dashscope.NewEmbedder(ctx, cfg)
|
||||
}
|
||||
if err != nil {
|
||||
glog.Fatalf("NewEmbedder of %v error: %v", provider, err)
|
||||
}
|
||||
}
|
||||
|
||||
ctx, span := jaeger.NewSpan(ctx, "InitAllVector")
|
||||
defer span.End()
|
||||
InitAllVector(ctx)
|
||||
return
|
||||
}
|
||||
|
||||
func EmbedStrings(ctx context.Context, texts []string) (embeddings [][]float64, err error) {
|
||||
provider := g.Cfg().MustGet(ctx, "eino.embedding.provider").String()
|
||||
switch provider {
|
||||
case providerArk:
|
||||
return EmbedderArk.EmbedStrings(ctx, texts)
|
||||
case providerOpenai:
|
||||
return EmbedderOpenAI.EmbedStrings(ctx, texts)
|
||||
case providerDashscope:
|
||||
return EmbedderDashscope.EmbedStrings(ctx, texts)
|
||||
// ===================== 1. 服务启动时调用:初始化所有租户 =====================
|
||||
func InitAllVector(ctx context.Context) {
|
||||
//list, err := dao.Model.GetNoTenantId(ctx, &dto.GetModelReq{
|
||||
// ModelType: model.ModelTypeVector.Code(),
|
||||
//})
|
||||
//if err != nil {
|
||||
// g.Log().Errorf(ctx, "获取所有租户ID失败: %v", err)
|
||||
// return
|
||||
//}
|
||||
//
|
||||
//for _, l := range list {
|
||||
// err = InitVector(ctx, l)
|
||||
// if err != nil {
|
||||
// g.Log().Errorf(ctx, "初始化租户[%v]的向量模型失败: %v", l.TenantId, err)
|
||||
// continue
|
||||
// }
|
||||
//}
|
||||
modelDO := new(entity.Model)
|
||||
modelDO.TenantId = 1
|
||||
modelDO.ConfigType = model.ModelConfigTypeVectorDashScope.Code()
|
||||
var cfg entity.VectorModelConfigDashScope
|
||||
cfg.APIKey = "sk-4a8b82770bf74bc490eb3e4c5a8e2be9"
|
||||
cfg.Model = "text-embedding-v3"
|
||||
modelDO.ConfigContent = gconv.Map(&cfg)
|
||||
err := InitVector(ctx, modelDO)
|
||||
if err != nil {
|
||||
g.Log().Errorf(ctx, "初始化向量模型失败: %v", err)
|
||||
return
|
||||
}
|
||||
return nil, fmt.Errorf("unsupported provider: %v", provider)
|
||||
}
|
||||
|
||||
func InitVector(ctx context.Context, modelDO *entity.Model) (err error) {
|
||||
set := &EmbedderSet{}
|
||||
switch *modelDO.ConfigType {
|
||||
case *model.ModelConfigTypeVectorArk.Code():
|
||||
// 解析 Ark 向量配置
|
||||
var cfg entity.VectorModelConfigArk
|
||||
err = gconv.Struct(modelDO.ConfigContent, &cfg)
|
||||
if err != nil {
|
||||
return fmt.Errorf("解析Ark向量配置失败: %v", err)
|
||||
}
|
||||
arkCfg := &ark.EmbeddingConfig{
|
||||
APIKey: cfg.APIKey,
|
||||
Model: cfg.Model,
|
||||
}
|
||||
if !g.IsEmpty(cfg.APIType) {
|
||||
arkCfg.APIType = new(ark.APIType(cfg.APIType))
|
||||
}
|
||||
set.Ark, err = ark.NewEmbedder(ctx, arkCfg)
|
||||
|
||||
case *model.ModelConfigTypeVectorOllama.Code():
|
||||
// 解析 Ollama 向量配置
|
||||
var cfg entity.VectorModelConfigOllama
|
||||
err = gconv.Struct(modelDO.ConfigContent, &cfg)
|
||||
if err != nil {
|
||||
return fmt.Errorf("解析Ollama向量配置失败: %v", err)
|
||||
}
|
||||
set.Ollama, err = ollama.NewEmbedder(ctx, &ollama.EmbeddingConfig{
|
||||
BaseURL: cfg.BaseURL,
|
||||
Model: cfg.Model,
|
||||
})
|
||||
|
||||
case *model.ModelConfigTypeVectorOpenAI.Code():
|
||||
// 解析 OpenAI 向量配置
|
||||
var cfg entity.VectorModelConfigOpenAI
|
||||
err = gconv.Struct(modelDO.ConfigContent, &cfg)
|
||||
if err != nil {
|
||||
return fmt.Errorf("解析OpenAI向量配置失败: %v", err)
|
||||
}
|
||||
openaiCfg := &openai.EmbeddingConfig{
|
||||
APIKey: cfg.APIKey,
|
||||
Model: cfg.Model,
|
||||
ByAzure: cfg.ByAzure,
|
||||
BaseURL: cfg.BaseURL,
|
||||
APIVersion: cfg.APIVersion,
|
||||
}
|
||||
set.OpenAI, err = openai.NewEmbedder(ctx, openaiCfg)
|
||||
|
||||
case *model.ModelConfigTypeVectorQianfan.Code():
|
||||
// 解析 千帆 向量配置
|
||||
var cfg entity.VectorModelConfigQianfan
|
||||
err = gconv.Struct(modelDO.ConfigContent, &cfg)
|
||||
if err != nil {
|
||||
return fmt.Errorf("解析千帆向量配置失败: %v", err)
|
||||
}
|
||||
qcfg := qianfan.GetQianfanSingletonConfig()
|
||||
qcfg.AccessKey = cfg.AccessKey
|
||||
qcfg.SecretKey = cfg.SecretKey
|
||||
set.Qianfan, err = qianfan.NewEmbedder(ctx, &qianfan.EmbeddingConfig{
|
||||
Model: cfg.Model,
|
||||
})
|
||||
|
||||
case *model.ModelConfigTypeVectorTencentCloud.Code():
|
||||
// 解析 腾讯云 向量配置
|
||||
var cfg entity.VectorModelConfigTencentCloud
|
||||
err = gconv.Struct(modelDO.ConfigContent, &cfg)
|
||||
if err != nil {
|
||||
return fmt.Errorf("解析腾讯云向量配置失败: %v", err)
|
||||
}
|
||||
set.TencentCloud, err = tencentcloud.NewEmbedder(ctx, &tencentcloud.EmbeddingConfig{
|
||||
SecretID: cfg.SecretID,
|
||||
SecretKey: cfg.SecretKey,
|
||||
Region: cfg.Region,
|
||||
})
|
||||
|
||||
case *model.ModelConfigTypeVectorDashScope.Code():
|
||||
// 解析 阿里 dashscope 向量配置
|
||||
var cfg entity.VectorModelConfigDashScope
|
||||
err = gconv.Struct(modelDO.ConfigContent, &cfg)
|
||||
if err != nil {
|
||||
return fmt.Errorf("解析阿里dashscope向量配置失败: %v", err)
|
||||
}
|
||||
set.DashScope, err = dashscope.NewEmbedder(ctx, &dashscope.EmbeddingConfig{
|
||||
APIKey: cfg.APIKey,
|
||||
Model: cfg.Model,
|
||||
})
|
||||
|
||||
default:
|
||||
return fmt.Errorf("不支持的向量模型配置类型: %v", *modelDO.ConfigType)
|
||||
}
|
||||
|
||||
// 统一错误处理
|
||||
if err != nil {
|
||||
return fmt.Errorf("初始化向量模型失败: %v", err)
|
||||
}
|
||||
// 直接存入 map(无锁,重复初始化会直接覆盖)
|
||||
tenantEmbedders[modelDO.TenantId] = set
|
||||
g.Log().Infof(ctx, "向量模型[%v]初始化成功", modelDO.ConfigType)
|
||||
return
|
||||
}
|
||||
|
||||
func GetTenantEmbedder(tenantId uint64) (*EmbedderSet, error) {
|
||||
set := tenantEmbedders[tenantId]
|
||||
if set == nil {
|
||||
return nil, fmt.Errorf("租户[%v]的向量模型未初始化", tenantId)
|
||||
}
|
||||
return set, nil
|
||||
}
|
||||
|
||||
func GetTenantEmbedderByType(ctx context.Context, configType model.ModelConfigType) (embedding.Embedder, error) {
|
||||
userInfo, err := utils.GetUserInfo(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
set, err := GetTenantEmbedder(userInfo.TenantId)
|
||||
if set == nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
switch *configType {
|
||||
case *model.ModelConfigTypeVectorArk.Code():
|
||||
return set.Ark, nil
|
||||
case *model.ModelConfigTypeVectorOllama.Code():
|
||||
return set.Ollama, nil
|
||||
case *model.ModelConfigTypeVectorOpenAI.Code():
|
||||
return set.OpenAI, nil
|
||||
case *model.ModelConfigTypeVectorQianfan.Code():
|
||||
return set.Qianfan, nil
|
||||
case *model.ModelConfigTypeVectorTencentCloud.Code():
|
||||
return set.TencentCloud, nil
|
||||
case *model.ModelConfigTypeVectorDashScope.Code():
|
||||
return set.DashScope, nil
|
||||
default:
|
||||
return nil, fmt.Errorf("不支持的向量模型配置类型: %v", *configType)
|
||||
}
|
||||
}
|
||||
|
||||
func RefreshTenantEmbedder(ctx context.Context, modelDO *entity.Model) error {
|
||||
delete(tenantEmbedders, modelDO.TenantId)
|
||||
return InitVector(ctx, modelDO)
|
||||
}
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
"rag/consts/model"
|
||||
"rag/dao"
|
||||
"rag/model/dto"
|
||||
"rag/model/entity"
|
||||
@@ -34,7 +35,14 @@ func NewPGVectorIndexer(opts *PGVectorIndexerOptions) *PGVectorIndexer {
|
||||
return &PGVectorIndexer{opts: opts}
|
||||
}
|
||||
|
||||
func (i *PGVectorIndexer) Store(ctx context.Context, docs []*schema.Document, opts ...indexer.Option) (rows int64, err error) {
|
||||
func (i *PGVectorIndexer) Store(ctx context.Context, docs []*schema.Document, configType model.ModelConfigType, opts ...indexer.Option) (rows int64, err error) {
|
||||
|
||||
embedderByType, err := GetTenantEmbedderByType(ctx, configType)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
indexer.WithEmbedding(embedderByType)
|
||||
|
||||
commonOpts := indexer.GetCommonOptions(&indexer.Options{}, opts...)
|
||||
|
||||
if commonOpts.Embedding == nil {
|
||||
|
||||
116
common/eino/rerank.go
Normal file
116
common/eino/rerank.go
Normal file
@@ -0,0 +1,116 @@
|
||||
package eino
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/cloudwego/eino/schema"
|
||||
"github.com/gogf/gf/v2/frame/g"
|
||||
)
|
||||
|
||||
// DashScopeReranker 通义百炼 Rerank 精排(Cross-Encoder)
|
||||
type DashScopeReranker struct {
|
||||
httpClient *http.Client
|
||||
}
|
||||
|
||||
func NewDashScopeReranker() *DashScopeReranker {
|
||||
return &DashScopeReranker{
|
||||
httpClient: &http.Client{
|
||||
Timeout: 10 * time.Second,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// Rerank 对文档进行精排(Cross-Encoder 核心)
|
||||
func (d *DashScopeReranker) Rerank(ctx context.Context, query string, docs []*schema.Document) ([]*schema.Document, error) {
|
||||
if len(docs) == 0 {
|
||||
return docs, nil
|
||||
}
|
||||
|
||||
// 官方必过 URL
|
||||
url := "https://dashscope.aliyuncs.com/api/v1/services/rerank/text-rerank/text-rerank"
|
||||
apiKey := g.Cfg().MustGet(ctx, "eino.rerank.apiKey").String()
|
||||
model := g.Cfg().MustGet(ctx, "eino.rerank.model").String()
|
||||
|
||||
documents := make([]string, len(docs))
|
||||
for i, doc := range docs {
|
||||
documents[i] = doc.Content
|
||||
}
|
||||
|
||||
reqBody := map[string]any{
|
||||
"model": model,
|
||||
"input": map[string]any{
|
||||
"query": query,
|
||||
"documents": documents,
|
||||
},
|
||||
"parameters": map[string]any{
|
||||
"top_n": len(docs),
|
||||
},
|
||||
}
|
||||
bs, _ := json.Marshal(reqBody)
|
||||
|
||||
req, _ := http.NewRequestWithContext(ctx, "POST", url, bytes.NewBuffer(bs))
|
||||
req.Header.Set("Authorization", "Bearer "+apiKey)
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, err := d.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
return nil, fmt.Errorf("rerank api error: status=%d, body=%s", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
// 解析结果
|
||||
var result struct {
|
||||
Output struct {
|
||||
Results []struct {
|
||||
Index int `json:"index"`
|
||||
RelevanceScore float64 `json:"relevance_score"`
|
||||
} `json:"results"`
|
||||
} `json:"output"`
|
||||
}
|
||||
|
||||
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 按分数排序
|
||||
type scoredDoc struct {
|
||||
doc *schema.Document
|
||||
score float64
|
||||
}
|
||||
scored := make([]scoredDoc, len(docs))
|
||||
for i, doc := range docs {
|
||||
scored[i] = scoredDoc{doc: doc, score: 0}
|
||||
}
|
||||
for _, res := range result.Output.Results {
|
||||
scored[res.Index].score = res.RelevanceScore
|
||||
}
|
||||
|
||||
// 分数从高到低排序
|
||||
for i := 0; i < len(scored); i++ {
|
||||
for j := i + 1; j < len(scored); j++ {
|
||||
if scored[j].score > scored[i].score {
|
||||
scored[i], scored[j] = scored[j], scored[i]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 输出最终排好的文档
|
||||
ranked := make([]*schema.Document, 0, len(scored))
|
||||
for _, s := range scored {
|
||||
s.doc.MetaData["rerank_score"] = s.score
|
||||
ranked = append(ranked, s.doc)
|
||||
}
|
||||
|
||||
return ranked, nil
|
||||
}
|
||||
@@ -3,6 +3,8 @@ package eino
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"rag/consts/model"
|
||||
"rag/dao"
|
||||
"sort"
|
||||
"time"
|
||||
@@ -29,21 +31,25 @@ type PGVectorRetriever struct {
|
||||
topK int
|
||||
index string
|
||||
dslInfo map[string]any
|
||||
reranker *DashScopeReranker // 通义精排
|
||||
}
|
||||
|
||||
func NewPGVectorRetriever(config *PGVectorRetrieverConfig) (*PGVectorRetriever, error) {
|
||||
if config.Embedder == nil {
|
||||
return nil, errors.New("embedder is required")
|
||||
}
|
||||
func NewPGVectorRetriever(ctx context.Context, config *PGVectorRetrieverConfig, configType model.ModelConfigType) (*PGVectorRetriever, error) {
|
||||
if config.DefaultTopK <= 0 {
|
||||
config.DefaultTopK = 5
|
||||
}
|
||||
|
||||
e, err := GetTenantEmbedderByType(ctx, configType)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &PGVectorRetriever{
|
||||
embedder: config.Embedder,
|
||||
embedder: e,
|
||||
topK: config.DefaultTopK,
|
||||
index: config.DefaultIndex,
|
||||
dslInfo: config.DSLInfo,
|
||||
//reranker: NewDashScopeReranker(), // 👈 直接初始化你的精排
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -138,48 +144,37 @@ func (r *PGVectorRetriever) Retrieve(ctx context.Context, query string, opts ...
|
||||
}
|
||||
|
||||
// 合并 + 智能去重(保留最优分数)
|
||||
docs := mergeAndDeduplicate(docsVector, docsFulltext)
|
||||
mergedDocs := mergeAndDeduplicate(docsVector, docsFulltext)
|
||||
|
||||
// 排序:向量优先,同类型按距离升序
|
||||
sort.Slice(docs, func(i, j int) bool {
|
||||
//byI, okI := docs[i].MetaData["retrieve_by"].(string)
|
||||
//byJ, okJ := docs[j].MetaData["retrieve_by"].(string)
|
||||
//
|
||||
//// 有类型标记的优先
|
||||
//if okI && !okJ {
|
||||
// return true
|
||||
//}
|
||||
//if !okI && okJ {
|
||||
// return false
|
||||
//}
|
||||
//
|
||||
//// 向量永远排前面
|
||||
//if byI == "vector" && byJ == "fulltext" {
|
||||
// return true
|
||||
//}
|
||||
//if byI == "fulltext" && byJ == "vector" {
|
||||
// return false
|
||||
//}
|
||||
|
||||
// 同类型按 distance 升序(越小越相似)
|
||||
d1 := gconv.Float64(docs[i].MetaData["distance"])
|
||||
d2 := gconv.Float64(docs[j].MetaData["distance"])
|
||||
return d1 < d2
|
||||
})
|
||||
|
||||
// 在Retrieve方法末尾,增加相关性校验
|
||||
validDocs := make([]*schema.Document, 0)
|
||||
for i, d := range docs {
|
||||
// 过滤distance过大的垃圾结果(比如distance>0.8的直接丢弃)
|
||||
if gconv.Float64(docs[i].MetaData["distance"]) < 0.8 {
|
||||
validDocs = append(validDocs, d)
|
||||
// =========================
|
||||
// 🔥 Cross-Encoder 精排
|
||||
// =========================
|
||||
var finalDocs []*schema.Document
|
||||
if r.reranker != nil {
|
||||
ranked, err := r.reranker.Rerank(ctx, query, mergedDocs)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("rerank failed: %w", err)
|
||||
}
|
||||
finalDocs = ranked
|
||||
} else {
|
||||
sort.Slice(mergedDocs, func(i, j int) bool {
|
||||
d1 := gconv.Float64(mergedDocs[i].MetaData["distance"])
|
||||
d2 := gconv.Float64(mergedDocs[j].MetaData["distance"])
|
||||
return d1 < d2
|
||||
})
|
||||
finalDocs = mergedDocs
|
||||
}
|
||||
|
||||
// 如果没有有效结果,返回空,让LLM回答「暂无相关信息」
|
||||
if len(validDocs) == 0 {
|
||||
callbacks.OnEnd(ctx, &retriever.CallbackOutput{Docs: validDocs})
|
||||
return validDocs, nil
|
||||
// =========================
|
||||
// 过滤无效文档
|
||||
// =========================
|
||||
const maxDistance = 0.8
|
||||
validDocs := make([]*schema.Document, 0, len(finalDocs))
|
||||
for _, doc := range finalDocs {
|
||||
dist := gconv.Float64(doc.MetaData["distance"])
|
||||
if dist <= maxDistance {
|
||||
validDocs = append(validDocs, doc)
|
||||
}
|
||||
}
|
||||
|
||||
// 最多保留 topK
|
||||
@@ -208,9 +203,15 @@ func (r *PGVectorRetriever) doRetrieveVector(ctx context.Context, query string,
|
||||
if opts.TopK != nil {
|
||||
topK = *opts.TopK
|
||||
}
|
||||
datasetIds := gconv.Int64s(opts.DSLInfo["dataset_ids"])
|
||||
var datasetIds, documentIds []int64
|
||||
if g.IsEmpty(opts.DSLInfo["dataset_ids"]) {
|
||||
datasetIds = gconv.Int64s(opts.DSLInfo["dataset_ids"])
|
||||
}
|
||||
if g.IsEmpty(opts.DSLInfo["document_ids"]) {
|
||||
documentIds = gconv.Int64s(opts.DSLInfo["document_ids"])
|
||||
}
|
||||
|
||||
rows, err := dao.DocumentVector.GetAllByVector(ctx, datasetIds, queryVec, topK)
|
||||
rows, err := dao.DocumentVector.GetAllByVector(ctx, datasetIds, documentIds, queryVec, topK)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -236,10 +237,17 @@ func (r *PGVectorRetriever) doRetrieveVector(ctx context.Context, query string,
|
||||
// ==========================================
|
||||
func (r *PGVectorRetriever) doRetrieveMeilisearch(ctx context.Context, query string, opts *retriever.Options) ([]*schema.Document, error) {
|
||||
topK := *opts.TopK
|
||||
datasetIds := gconv.Int64s(opts.DSLInfo["dataset_ids"])
|
||||
|
||||
var datasetIds, documentIds []int64
|
||||
if g.IsEmpty(opts.DSLInfo["dataset_ids"]) {
|
||||
datasetIds = gconv.Int64s(opts.DSLInfo["dataset_ids"])
|
||||
}
|
||||
if g.IsEmpty(opts.DSLInfo["document_ids"]) {
|
||||
documentIds = gconv.Int64s(opts.DSLInfo["document_ids"])
|
||||
}
|
||||
|
||||
// 调用你已有的 Meilisearch DAO
|
||||
rows, err := dao.DocumentVector.SearchByKeywords(ctx, query, datasetIds, topK)
|
||||
rows, err := dao.DocumentVector.SearchByKeywords(ctx, query, datasetIds, documentIds, topK)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user