feat: 优化RAG检索与聊天模型支持历史对话

实现双路检索并行优化,使用EINO官方模板重构聊天逻辑,增加多轮对话历史记录管理及相关性过滤,并修复数据库唯一索引。
This commit is contained in:
2026-04-09 13:57:46 +08:00
parent 14a429f4ae
commit 2ced0a43e5
9 changed files with 310 additions and 147 deletions

View File

@@ -12,10 +12,10 @@ import (
"rag/model/dto"
"rag/model/entity"
"strings"
"time"
"gitea.com/red-future/common/db/gfdb"
"gitea.com/red-future/common/full-text-search/meilisearch"
"gitea.com/red-future/common/http"
"gitea.com/red-future/common/utils"
gmq "github.com/bjang03/gmq/core/gmq"
"github.com/bjang03/gmq/mq"
@@ -159,12 +159,16 @@ func (s *documentService) Process(ctx context.Context, req *dto.ProcessDocumentR
if err != nil {
return
}
user, err := utils.GetUserInfo(ctx)
if err != nil {
return err
}
// ======================
// 核心grpool + g.Try 最佳实践
// ======================
taskCtx, cancel := context.WithCancel(ctx)
// 使用带超时的background context避免HTTP请求完成后context被取消
taskCtx, cancel := context.WithTimeout(context.Background(), 30*time.Minute)
taskCtx = context.WithValue(taskCtx, "user", user)
// 任务1: SQL 切分文档
grpool.Add(taskCtx, func(ctx context.Context) {
g.TryCatch(ctx, func(ctx context.Context) {
@@ -655,23 +659,12 @@ func (s *documentService) getHistoryData(ctx context.Context, doc *entity.Docume
// getHistoryDataFromHttp 通过 HTTP 接口查询历史数据
func (s *documentService) getHistoryDataFromHttp(ctx context.Context, doc *entity.Document) (dictData []*dto.DocumentChunkRPC, err error) {
headers := make(map[string]string)
if r := g.RequestFromCtx(ctx); r != nil {
for k, v := range r.Request.Header {
if len(v) > 0 {
headers[k] = v[0]
}
}
}
// 调用接口获取数据
d := &dto.ListDocumentChunkRPC{}
if err = http.Get(ctx, "rag-vector/document/chunk/listDocumentChunk", headers, &d,
"datasetId", gconv.String(doc.DatasetId),
"status", 1); err != nil {
return
}
dictData = d.List
res, _, err := dao.DocumentChunk.List(ctx, &dto.ListDocumentChunkReq{
DatasetId: doc.DatasetId,
Status: gconv.PtrInt8(1),
})
err = gconv.Struct(res, &dictData)
return
}

View File

@@ -8,6 +8,7 @@ import (
"rag/model/dto"
"rag/model/entity"
"gitea.com/red-future/common/beans"
"github.com/cloudwego/eino/components/indexer"
"github.com/cloudwego/eino/schema"
"github.com/gogf/gf/v2/frame/g"
@@ -48,7 +49,10 @@ func (s *documentChunkService) DocsChunkMsg(ctx context.Context, msg any) (err e
g.Log().Error(ctx, "DocsChunkMsg err:", "msg is empty")
return
}
ctx = context.WithValue(ctx, "user", &beans.User{
TenantId: gconv.Uint64(docs[0].MetaData[entity.DocumentChunkCol.TenantId]),
UserName: gconv.String(docs[0].MetaData[entity.DocumentChunkCol.Creator]),
})
idx := eino.NewPGVectorIndexer(&eino.PGVectorIndexerOptions{
BatchSize: 10,
})

View File

@@ -7,7 +7,9 @@ import (
"rag/model/dto"
"github.com/cloudwego/eino/components/retriever"
"github.com/cloudwego/eino/schema"
"github.com/gogf/gf/v2/os/glog"
"github.com/gogf/gf/v2/util/gconv"
)
var RAGQuery = new(ragQueryService)
@@ -39,14 +41,20 @@ func (s *ragQueryService) Query(ctx context.Context, req *dto.RAGQueryReq) (*dto
return nil, fmt.Errorf("向量检索失败: %w", err)
}
replyMsg, sources, err := eino.NewChatModel(ctx, req.Content, docs)
messages := make([]*schema.Message, 0)
err = gconv.Struct(req.History, &messages)
if err != nil {
glog.Errorf(ctx, "转换历史消息失败: %v", err)
return nil, fmt.Errorf("转换历史消息失败: %w", err)
}
replyMsg, err := eino.NewChatModel(ctx, req.Content, docs, messages)
if err != nil {
glog.Errorf(ctx, "向量检索失败: %v", err)
return nil, fmt.Errorf("向量检索失败: %w", err)
}
return &dto.RAGQueryRes{
Answer: replyMsg.Content,
Sources: sources,
Answer: replyMsg.Content,
}, nil
}

View File

@@ -7,6 +7,8 @@ import (
"rag/common/task"
"gitea.com/red-future/common/db/gfdb"
"github.com/gogf/gf/v2/database/gdb"
"github.com/gogf/gf/v2/frame/g"
"github.com/gogf/gf/v2/util/gconv"
)
@@ -24,17 +26,20 @@ func (s *taskService) WriteTaskProgress(ctx context.Context, req *dto.WriteTaskP
g.Log().Errorf(ctx, "查询任务失败: %v", err)
return err
}
taskVO := make([]dto.TaskVO, 0, total)
err = gconv.Struct(t, taskVO)
if err != nil {
g.Log().Errorf(ctx, "转换任务失败: %v", err)
return err
completed := false
if total != 0 {
taskVO := make([]*dto.TaskVO, 0, total)
err = gconv.Struct(t, &taskVO)
if err != nil {
g.Log().Errorf(ctx, "转换任务失败: %v", err)
return err
}
taskVO = append(taskVO, &dto.TaskVO{
TaskType: req.TaskType,
Status: req.Status,
})
completed = IsAllSubTasksCompleted(taskVO)
}
taskVO = append(taskVO, dto.TaskVO{
TaskType: req.TaskType,
Status: req.Status,
})
completed := IsAllSubTasksCompleted(taskVO)
// 1. 查询是否已存在该文档的该类型任务
existTask, _, err := dao.Task.Get(ctx, &dto.GetTaskReq{
@@ -45,44 +50,48 @@ func (s *taskService) WriteTaskProgress(ctx context.Context, req *dto.WriteTaskP
g.Log().Errorf(ctx, "查询任务失败: %v", err)
return err
}
err = gfdb.DB(ctx).Transaction(ctx, func(ctx context.Context, tx gdb.TX) error {
// 2. 如果不存在,则创建新任务
if g.IsEmpty(existTask) {
createReq := &dto.CreateTaskReq{
TaskId: req.TaskId,
TaskType: req.TaskType,
Status: req.Status,
Remark: req.Remark,
}
_, err = dao.Task.Insert(ctx, createReq)
} else {
// 3. 如果已存在,则更新任务
updateReq := &dto.UpdateTaskReq{
Id: existTask[0].Id,
Status: req.Status,
Remark: req.Remark,
}
_, err = dao.Task.Update(ctx, updateReq)
if err != nil {
g.Log().Errorf(ctx, "更新任务失败: %v", err)
return err
}
}
// 2. 如果不存在,则创建新任务
if g.IsEmpty(existTask) {
createReq := &dto.CreateTaskReq{
TaskId: req.TaskId,
TaskType: req.TaskType,
Status: req.Status,
Remark: req.Remark,
if completed {
// 3. 如果已存在,则更新任务
_, err = dao.Task.Update(ctx, &dto.UpdateTaskReq{
TaskId: req.TaskId,
Status: task.TaskStatusCompleted,
Remark: "文档解析完成",
})
}
_, err = dao.Task.Insert(ctx, createReq)
} else {
// 3. 如果已存在,则更新任务
updateReq := &dto.UpdateTaskReq{
Id: existTask[0].Id,
Status: req.Status,
Remark: req.Remark,
}
_, err = dao.Task.Update(ctx, updateReq)
if err != nil {
g.Log().Errorf(ctx, "更新任务失败: %v", err)
return err
}
}
if completed {
// 3. 如果已存在,则更新任务
_, err = dao.Task.Update(ctx, &dto.UpdateTaskReq{
TaskId: req.TaskId,
Status: task.TaskStatusCompleted,
})
}
return nil
})
return
}
// IsAllSubTasksCompleted 判断三个子任务是否全部完成
// 参数:传入当前文档的所有子任务列表
func IsAllSubTasksCompleted(subTasks []dto.TaskVO) bool {
func IsAllSubTasksCompleted(subTasks []*dto.TaskVO) bool {
// 必须包含 3 种任务类型
hasKeywords := false
hasVector := false