refactor: 重构文档处理流程和任务管理
This commit is contained in:
@@ -5,17 +5,14 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"rag/common/eino"
|
||||
"rag/common/gse"
|
||||
"rag/common/task"
|
||||
"rag/consts/document"
|
||||
"rag/consts/public"
|
||||
"rag/dao"
|
||||
"rag/model/dto"
|
||||
"rag/model/entity"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"gitea.com/red-future/common/beans"
|
||||
"gitea.com/red-future/common/db/gfdb"
|
||||
"gitea.com/red-future/common/full-text-search/meilisearch"
|
||||
"gitea.com/red-future/common/http"
|
||||
@@ -29,6 +26,7 @@ import (
|
||||
"github.com/gogf/gf/v2/database/gdb"
|
||||
"github.com/gogf/gf/v2/database/gredis"
|
||||
"github.com/gogf/gf/v2/frame/g"
|
||||
"github.com/gogf/gf/v2/os/grpool"
|
||||
"github.com/gogf/gf/v2/util/gconv"
|
||||
)
|
||||
|
||||
@@ -54,7 +52,13 @@ func (s *documentService) Create(ctx context.Context, req *dto.CreateDocumentReq
|
||||
return
|
||||
}
|
||||
res = &dto.CreateDocumentRes{Id: id}
|
||||
|
||||
// 写入任务进度待处理 任务类型为文档解析
|
||||
err = Task.WriteTaskProgress(ctx, &dto.WriteTaskProgressReq{
|
||||
TaskId: id,
|
||||
TaskType: task.TaskTypeDocParse,
|
||||
Status: task.TaskStatusPending,
|
||||
Remark: "文档上传成功待解析: " + req.Title,
|
||||
})
|
||||
return
|
||||
})
|
||||
|
||||
@@ -79,11 +83,20 @@ func (s *documentService) Delete(ctx context.Context, req *dto.DeleteDocumentReq
|
||||
DocumentCount: -1,
|
||||
DocumentSize: -docs.FileSize,
|
||||
}
|
||||
_, err = dao.Dataset.Update(ctx, datasetReq)
|
||||
if err != nil {
|
||||
if _, err = dao.Dataset.Update(ctx, datasetReq); err != nil {
|
||||
return
|
||||
}
|
||||
_, err = dao.Document.Delete(ctx, req)
|
||||
|
||||
if _, err = dao.Document.Delete(ctx, req); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if _, err = dao.Task.DeleteByTaskId(ctx, &dto.DeleteTaskByTaskIdReq{
|
||||
TaskId: docs.Id,
|
||||
}); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
return
|
||||
})
|
||||
|
||||
@@ -107,118 +120,159 @@ func (s *documentService) List(ctx context.Context, req *dto.ListDocumentReq) (r
|
||||
Total: total,
|
||||
}
|
||||
err = gconv.Struct(list, &res.List)
|
||||
|
||||
//eino.TestIndexer()
|
||||
//eino.TestRetriever()
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// Process 处理文件(使用eino框架切分和向量化)
|
||||
func (s *documentService) Process(ctx context.Context, req *dto.ProcessDocumentReq) (res *dto.ProcessDocumentRes, err error) {
|
||||
startTime := time.Now()
|
||||
|
||||
func (s *documentService) Process(ctx context.Context, req *dto.ProcessDocumentReq) (err error) {
|
||||
// 1. 查询文件信息
|
||||
documentReq := dto.GetDocumentReq{Id: req.Id}
|
||||
doc, err := dao.Document.GetByID(ctx, &documentReq)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return err
|
||||
}
|
||||
if g.IsEmpty(doc) {
|
||||
return nil, errors.New("document not found")
|
||||
return errors.New("document not found")
|
||||
}
|
||||
|
||||
// 2. 使用eino框架进行文件切分(并发执行)
|
||||
var vectorDocsCount, chunks int64
|
||||
// 用 gopool 或者简单的错误等待,绝对不用裸 goroutine
|
||||
var err1, err2, err3 error
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(3)
|
||||
|
||||
// 任务1
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
vectorDocsCount, chunks, err1 = s.sqlSplitDocument(ctx, doc)
|
||||
}()
|
||||
|
||||
// 任务2
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
err2 = s.esSplitDocument(ctx, doc)
|
||||
}()
|
||||
|
||||
// 任务3
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
err3 = s.extractDocument(ctx, doc)
|
||||
}()
|
||||
|
||||
// 直接等待,不使用通道,避免泄漏
|
||||
wg.Wait()
|
||||
|
||||
// 2. 更新文档状态为处理中
|
||||
updateDocumentReq := new(dto.UpdateDocumentReq)
|
||||
updateDocumentReq.Id = req.Id
|
||||
|
||||
// 统一判断错误
|
||||
if err1 != nil || err2 != nil || err3 != nil {
|
||||
// 更新文档状态
|
||||
updateDocumentReq.VectorStatus = document.VectorStatusFailed.Code()
|
||||
if _, err = dao.Document.Update(ctx, updateDocumentReq); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err1 != nil {
|
||||
return nil, err1
|
||||
}
|
||||
if err2 != nil {
|
||||
return nil, err2
|
||||
}
|
||||
return nil, err3
|
||||
}
|
||||
|
||||
// 4. 更新文件状态为处理中和切分数量
|
||||
if vectorDocsCount > 0 {
|
||||
updateDocumentReq.VectorStatus = document.VectorStatusProcessing.Code()
|
||||
} else {
|
||||
updateDocumentReq.VectorStatus = document.VectorStatusCompleted.Code()
|
||||
}
|
||||
updateDocumentReq.ChunkCount = chunks
|
||||
updateDocumentReq.VectorStatus = document.VectorStatusProcessing.Code()
|
||||
if _, err = dao.Document.Update(ctx, updateDocumentReq); err != nil {
|
||||
// 写入任务进度失败 任务类型为文档解析
|
||||
err = Task.WriteTaskProgress(ctx, &dto.WriteTaskProgressReq{
|
||||
TaskId: req.Id,
|
||||
TaskType: task.TaskTypeDocParse,
|
||||
Status: task.TaskStatusFailed,
|
||||
Remark: "更新文档状态失败: " + err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
costTime := time.Since(startTime).Milliseconds()
|
||||
|
||||
return &dto.ProcessDocumentRes{
|
||||
ChunkCount: chunks,
|
||||
CostTime: costTime,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *documentService) extractDocument(ctx context.Context, doc *entity.Document) (err error) {
|
||||
// 1. 加载文件
|
||||
docs, err := s.loadDocument(ctx, doc)
|
||||
// 写入任务进度进行中 任务类型为文档解析
|
||||
err = Task.WriteTaskProgress(ctx, &dto.WriteTaskProgressReq{
|
||||
TaskId: req.Id,
|
||||
TaskType: task.TaskTypeDocParse,
|
||||
Status: task.TaskStatusRunning,
|
||||
Remark: "文档解析开始",
|
||||
})
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
var words []gse.Keyword
|
||||
// ======================
|
||||
// 核心:grpool + g.Try 最佳实践
|
||||
// ======================
|
||||
taskCtx, cancel := context.WithCancel(ctx)
|
||||
|
||||
// 任务1: SQL 切分文档
|
||||
grpool.Add(taskCtx, func(ctx context.Context) {
|
||||
g.TryCatch(ctx, func(ctx context.Context) {
|
||||
if innerErr := s.sqlSplitDocument(ctx, doc); innerErr != nil {
|
||||
cancel()
|
||||
}
|
||||
}, func(ctx context.Context, err error) {
|
||||
cancel()
|
||||
})
|
||||
})
|
||||
|
||||
// 任务2: ES 切分文档
|
||||
grpool.Add(taskCtx, func(ctx context.Context) {
|
||||
g.TryCatch(ctx, func(ctx context.Context) {
|
||||
if innerErr := s.esSplitDocument(ctx, doc); innerErr != nil {
|
||||
cancel()
|
||||
}
|
||||
}, func(ctx context.Context, err error) {
|
||||
cancel()
|
||||
})
|
||||
})
|
||||
|
||||
// 任务3: 提取文档
|
||||
grpool.Add(taskCtx, func(ctx context.Context) {
|
||||
g.TryCatch(ctx, func(ctx context.Context) {
|
||||
if innerErr := s.extractDocument(ctx, doc); innerErr != nil {
|
||||
cancel()
|
||||
}
|
||||
}, func(ctx context.Context, err error) {
|
||||
cancel()
|
||||
})
|
||||
})
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// extractDocument 关键词提取(支持取消)
|
||||
func (s *documentService) extractDocument(ctx context.Context, doc *entity.Document) (err error) {
|
||||
// ========== 取消检查 1:方法入口 ==========
|
||||
if ctx.Err() != nil {
|
||||
// 写入任务进度失败 任务类型为关键字存储
|
||||
err = Task.WriteTaskProgress(ctx, &dto.WriteTaskProgressReq{
|
||||
TaskId: doc.Id,
|
||||
TaskType: task.TaskTypeExtractKeywords,
|
||||
Status: task.TaskStatusFailed,
|
||||
Remark: "ctx取消: " + ctx.Err().Error(),
|
||||
})
|
||||
return ctx.Err()
|
||||
}
|
||||
|
||||
// 1. 加载文件
|
||||
docs, err := s.loadDocument(ctx, doc)
|
||||
if err != nil {
|
||||
// 写入任务进度失败 任务类型为关键字存储
|
||||
err = Task.WriteTaskProgress(ctx, &dto.WriteTaskProgressReq{
|
||||
TaskId: doc.Id,
|
||||
TaskType: task.TaskTypeExtractKeywords,
|
||||
Status: task.TaskStatusFailed,
|
||||
Remark: "加载文件失败: " + err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
var words []utils.Keyword
|
||||
if len(docs[0].Content) < 500 {
|
||||
words = gse.GseTool.Extract(docs[0].Content, 4)
|
||||
words = utils.GseTool.Extract(docs[0].Content, 4)
|
||||
} else if len(docs[0].Content) < 2000 {
|
||||
words = gse.GseTool.Extract(docs[0].Content, 8)
|
||||
words = utils.GseTool.Extract(docs[0].Content, 8)
|
||||
} else if len(docs[0].Content) < 5000 {
|
||||
words = gse.GseTool.Extract(docs[0].Content, 13)
|
||||
words = utils.GseTool.Extract(docs[0].Content, 13)
|
||||
} else {
|
||||
var docsSplit []*schema.Document
|
||||
docsSplit, err = eino.RecursiveSplitDocument(ctx, docs)
|
||||
if err != nil {
|
||||
err = Task.WriteTaskProgress(ctx, &dto.WriteTaskProgressReq{
|
||||
TaskId: doc.Id,
|
||||
TaskType: task.TaskTypeExtractKeywords,
|
||||
Status: task.TaskStatusFailed,
|
||||
Remark: "递归分割文档失败: " + err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
// ========== 取消检查 2:循环内部 ==========
|
||||
for _, t := range docsSplit {
|
||||
words = append(words, gse.GseTool.Extract(t.Content, 6)...)
|
||||
if ctx.Err() != nil {
|
||||
err = Task.WriteTaskProgress(ctx, &dto.WriteTaskProgressReq{
|
||||
TaskId: doc.Id,
|
||||
TaskType: task.TaskTypeExtractKeywords,
|
||||
Status: task.TaskStatusFailed,
|
||||
Remark: "ctx取消: " + ctx.Err().Error(),
|
||||
})
|
||||
return ctx.Err()
|
||||
}
|
||||
words = append(words, utils.GseTool.Extract(t.Content, 6)...)
|
||||
}
|
||||
}
|
||||
|
||||
// ========== 取消检查 3:批量操作前 ==========
|
||||
if ctx.Err() != nil {
|
||||
err = Task.WriteTaskProgress(ctx, &dto.WriteTaskProgressReq{
|
||||
TaskId: doc.Id,
|
||||
TaskType: task.TaskTypeExtractKeywords,
|
||||
Status: task.TaskStatusFailed,
|
||||
Remark: "ctx取消: " + ctx.Err().Error(),
|
||||
})
|
||||
return ctx.Err()
|
||||
}
|
||||
|
||||
var keywordReqs = make([]*dto.CreateKeywordReq, 0)
|
||||
for _, word := range words {
|
||||
keywordReqs = append(keywordReqs, &dto.CreateKeywordReq{
|
||||
@@ -231,37 +285,111 @@ func (s *documentService) extractDocument(ctx context.Context, doc *entity.Docum
|
||||
if len(keywordReqs) > 0 {
|
||||
_, err = dao.Keyword.BatchSaveOrUpdate(ctx, keywordReqs)
|
||||
if err != nil {
|
||||
// 写入任务进度失败 任务类型为关键字存储
|
||||
err = Task.WriteTaskProgress(ctx, &dto.WriteTaskProgressReq{
|
||||
TaskId: doc.Id,
|
||||
TaskType: task.TaskTypeExtractKeywords,
|
||||
Status: task.TaskStatusFailed,
|
||||
Remark: "关键字存储失败: " + err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
// 写入任务进度已完成 任务类型为关键字存储
|
||||
err = Task.WriteTaskProgress(ctx, &dto.WriteTaskProgressReq{
|
||||
TaskId: doc.Id,
|
||||
TaskType: task.TaskTypeExtractKeywords,
|
||||
Status: task.TaskStatusCompleted,
|
||||
Remark: "关键字提取完成",
|
||||
})
|
||||
} else {
|
||||
// 写入任务进度已完成 任务类型为关键字存储
|
||||
err = Task.WriteTaskProgress(ctx, &dto.WriteTaskProgressReq{
|
||||
TaskId: doc.Id,
|
||||
TaskType: task.TaskTypeExtractKeywords,
|
||||
Status: task.TaskStatusCompleted,
|
||||
Remark: "没有提取到关键词,关键字提取完成",
|
||||
})
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (s *documentService) sqlSplitDocument(ctx context.Context, doc *entity.Document) (vectorDocsCount, docsSplitCount int64, err error) {
|
||||
// sqlSplitDocument SQL切分(支持取消)
|
||||
func (s *documentService) sqlSplitDocument(ctx context.Context, doc *entity.Document) (err error) {
|
||||
// ========== 取消检查 1:方法入口 ==========
|
||||
if ctx.Err() != nil {
|
||||
err = Task.WriteTaskProgress(ctx, &dto.WriteTaskProgressReq{
|
||||
TaskId: doc.Id,
|
||||
TaskType: task.TaskTypeGenerateVector,
|
||||
Status: task.TaskStatusFailed,
|
||||
Remark: "ctx取消: " + ctx.Err().Error(),
|
||||
})
|
||||
return ctx.Err()
|
||||
}
|
||||
|
||||
// 1. 加载文件
|
||||
docs, err := s.loadDocument(ctx, doc)
|
||||
if err != nil {
|
||||
// 写入任务进度失败 任务类型为sql存储
|
||||
err = Task.WriteTaskProgress(ctx, &dto.WriteTaskProgressReq{
|
||||
TaskId: doc.Id,
|
||||
TaskType: task.TaskTypeGenerateVector,
|
||||
Status: task.TaskStatusFailed,
|
||||
Remark: "加载文件失败: " + err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// 2. 语义切分文件
|
||||
docsSplit, err := eino.SemanticSplitDocument(ctx, docs)
|
||||
if err != nil {
|
||||
// 写入任务进度失败 任务类型为sql存储
|
||||
err = Task.WriteTaskProgress(ctx, &dto.WriteTaskProgressReq{
|
||||
TaskId: doc.Id,
|
||||
TaskType: task.TaskTypeGenerateVector,
|
||||
Status: task.TaskStatusFailed,
|
||||
Remark: "文档切分失败: " + err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
docsSplitCount = gconv.Int64(len(docsSplit))
|
||||
|
||||
// 2. 获取历史数据
|
||||
err = s.getHistoryData(ctx, doc, public.KnowledgeLockSqlKey, public.KnowledgeContentHashSqlKey)
|
||||
if err != nil {
|
||||
// 写入任务进度失败 任务类型为sql存储
|
||||
err = Task.WriteTaskProgress(ctx, &dto.WriteTaskProgressReq{
|
||||
TaskId: doc.Id,
|
||||
TaskType: task.TaskTypeGenerateVector,
|
||||
Status: task.TaskStatusFailed,
|
||||
Remark: "获取历史数据失败: " + err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// 3. 组装向量文档
|
||||
var docsChunk = make([]*schema.Document, 0)
|
||||
for i, t := range docsSplit {
|
||||
// ========== 取消检查 2:循环内部 ==========
|
||||
if ctx.Err() != nil {
|
||||
err = Task.WriteTaskProgress(ctx, &dto.WriteTaskProgressReq{
|
||||
TaskId: doc.Id,
|
||||
TaskType: task.TaskTypeGenerateVector,
|
||||
Status: task.TaskStatusFailed,
|
||||
Remark: "ctx取消: " + ctx.Err().Error(),
|
||||
})
|
||||
return ctx.Err()
|
||||
}
|
||||
|
||||
contentHash := gmd5.MustEncryptString(t.Content)
|
||||
// 检查是否重复
|
||||
var success bool
|
||||
success, err = s.checkRepeat(ctx, public.KnowledgeContentHashSqlKey, contentHash)
|
||||
if err != nil {
|
||||
// 写入任务进度失败 任务类型为sql存储
|
||||
err = Task.WriteTaskProgress(ctx, &dto.WriteTaskProgressReq{
|
||||
TaskId: doc.Id,
|
||||
TaskType: task.TaskTypeGenerateVector,
|
||||
Status: task.TaskStatusFailed,
|
||||
Remark: "检查重复数据失败: " + err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
if !success {
|
||||
@@ -277,6 +405,18 @@ func (s *documentService) sqlSplitDocument(ctx context.Context, doc *entity.Docu
|
||||
t.MetaData = metaData
|
||||
docsChunk = append(docsChunk, t)
|
||||
}
|
||||
|
||||
// ========== 取消检查 3:批量发送前 ==========
|
||||
if ctx.Err() != nil {
|
||||
err = Task.WriteTaskProgress(ctx, &dto.WriteTaskProgressReq{
|
||||
TaskId: doc.Id,
|
||||
TaskType: task.TaskTypeGenerateVector,
|
||||
Status: task.TaskStatusFailed,
|
||||
Remark: "ctx取消: " + ctx.Err().Error(),
|
||||
})
|
||||
return ctx.Err()
|
||||
}
|
||||
|
||||
// 4. 发送消息到队列
|
||||
if len(docsChunk) > 0 {
|
||||
err = gmq.GetGmq("primary").GmqPublish(ctx, &mq.RedisPubMessage{
|
||||
@@ -285,41 +425,117 @@ func (s *documentService) sqlSplitDocument(ctx context.Context, doc *entity.Docu
|
||||
Data: docsChunk,
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
// 写入任务进度失败 任务类型为sql存储
|
||||
err = Task.WriteTaskProgress(ctx, &dto.WriteTaskProgressReq{
|
||||
TaskId: doc.Id,
|
||||
TaskType: task.TaskTypeGenerateVector,
|
||||
Status: task.TaskStatusFailed,
|
||||
Remark: "发送消息到队列失败: " + err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
// 写入任务进度进行中 任务类型为sql存储
|
||||
err = Task.WriteTaskProgress(ctx, &dto.WriteTaskProgressReq{
|
||||
TaskId: doc.Id,
|
||||
TaskType: task.TaskTypeGenerateVector,
|
||||
Status: task.TaskStatusRunning,
|
||||
Remark: "向量生成任务已提交到队列",
|
||||
})
|
||||
} else {
|
||||
// 写入任务进度已完成 任务类型为sql存储
|
||||
err = Task.WriteTaskProgress(ctx, &dto.WriteTaskProgressReq{
|
||||
TaskId: doc.Id,
|
||||
TaskType: task.TaskTypeGenerateVector,
|
||||
Status: task.TaskStatusCompleted,
|
||||
Remark: "无需生成向量,任务完成",
|
||||
})
|
||||
}
|
||||
vectorDocsCount = gconv.Int64(len(docsChunk))
|
||||
return
|
||||
}
|
||||
|
||||
// esSplitDocument ES切分(支持取消)
|
||||
func (s *documentService) esSplitDocument(ctx context.Context, doc *entity.Document) (err error) {
|
||||
// ========== 取消检查 1:方法入口 ==========
|
||||
if ctx.Err() != nil {
|
||||
err = Task.WriteTaskProgress(ctx, &dto.WriteTaskProgressReq{
|
||||
TaskId: doc.Id,
|
||||
TaskType: task.TaskTypeFullTextSearch,
|
||||
Status: task.TaskStatusFailed,
|
||||
Remark: "ctx取消: " + ctx.Err().Error(),
|
||||
})
|
||||
return ctx.Err()
|
||||
}
|
||||
|
||||
// 1. 加载文件
|
||||
docs, err := s.loadDocument(ctx, doc)
|
||||
if err != nil {
|
||||
// 写入任务进度失败 任务类型为es存储
|
||||
err = Task.WriteTaskProgress(ctx, &dto.WriteTaskProgressReq{
|
||||
TaskId: doc.Id,
|
||||
TaskType: task.TaskTypeFullTextSearch,
|
||||
Status: task.TaskStatusFailed,
|
||||
Remark: "加载文件失败: " + err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// 2. 递归切分文件
|
||||
docsSplit, err := eino.RecursiveSplitDocument(ctx, docs)
|
||||
if err != nil {
|
||||
// 写入任务进度失败 任务类型为es存储
|
||||
err = Task.WriteTaskProgress(ctx, &dto.WriteTaskProgressReq{
|
||||
TaskId: doc.Id,
|
||||
TaskType: task.TaskTypeFullTextSearch,
|
||||
Status: task.TaskStatusFailed,
|
||||
Remark: "文档切分失败: " + err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// 2. 获取历史数据
|
||||
err = s.getHistoryData(ctx, doc, public.KnowledgeLockEsKey, public.KnowledgeContentHashEsKey)
|
||||
if err != nil {
|
||||
// 写入任务进度失败 任务类型为es存储
|
||||
err = Task.WriteTaskProgress(ctx, &dto.WriteTaskProgressReq{
|
||||
TaskId: doc.Id,
|
||||
TaskType: task.TaskTypeFullTextSearch,
|
||||
Status: task.TaskStatusFailed,
|
||||
Remark: "获取历史数据失败: " + err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// 3. 组装向量文档并同时构建meilisearch文档
|
||||
var meiliDocs = make([]interface{}, 0)
|
||||
for i, t := range docsSplit {
|
||||
// ========== 取消检查 2:循环内部 ==========
|
||||
if ctx.Err() != nil {
|
||||
err = Task.WriteTaskProgress(ctx, &dto.WriteTaskProgressReq{
|
||||
TaskId: doc.Id,
|
||||
TaskType: task.TaskTypeFullTextSearch,
|
||||
Status: task.TaskStatusFailed,
|
||||
Remark: "ctx取消: " + ctx.Err().Error(),
|
||||
})
|
||||
return ctx.Err()
|
||||
}
|
||||
|
||||
contentHash := gmd5.MustEncryptString(t.Content)
|
||||
// 检查是否重复
|
||||
var success bool
|
||||
success, err = s.checkRepeat(ctx, public.KnowledgeContentHashEsKey, contentHash)
|
||||
if err != nil {
|
||||
// 写入任务进度失败 任务类型为es存储
|
||||
err = Task.WriteTaskProgress(ctx, &dto.WriteTaskProgressReq{
|
||||
TaskId: doc.Id,
|
||||
TaskType: task.TaskTypeFullTextSearch,
|
||||
Status: task.TaskStatusFailed,
|
||||
Remark: "检查重复数据失败: " + err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
if !success {
|
||||
continue
|
||||
}
|
||||
// 构建Meilisearch文档
|
||||
meiliDocs = append(meiliDocs, map[string]interface{}{
|
||||
entity.DocumentChunkCol.Id: contentHash,
|
||||
entity.DocumentChunkCol.DatasetId: doc.DatasetId,
|
||||
@@ -329,12 +545,45 @@ func (s *documentService) esSplitDocument(ctx context.Context, doc *entity.Docum
|
||||
entity.DocumentChunkCol.ChunkIndex: i,
|
||||
})
|
||||
}
|
||||
|
||||
// ========== 取消检查 3:批量写入前 ==========
|
||||
if ctx.Err() != nil {
|
||||
err = Task.WriteTaskProgress(ctx, &dto.WriteTaskProgressReq{
|
||||
TaskId: doc.Id,
|
||||
TaskType: task.TaskTypeFullTextSearch,
|
||||
Status: task.TaskStatusFailed,
|
||||
Remark: "ctx取消: " + ctx.Err().Error(),
|
||||
})
|
||||
return ctx.Err()
|
||||
}
|
||||
|
||||
// 4. 写入到meilisearch数据库中
|
||||
if len(meiliDocs) > 0 {
|
||||
if _, err = meilisearch.DB().InsertMany(ctx, meiliDocs, public.IndexNameDocumentChunk); err != nil {
|
||||
g.Log().Errorf(ctx, "写入meilisearch失败: %v", err)
|
||||
// 写入任务进度失败 任务类型为meilisearch存储
|
||||
err = Task.WriteTaskProgress(ctx, &dto.WriteTaskProgressReq{
|
||||
TaskId: doc.Id,
|
||||
TaskType: task.TaskTypeFullTextSearch,
|
||||
Status: task.TaskStatusFailed,
|
||||
Remark: "写入meilisearch失败: " + err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
// 写入任务进度已完成 任务类型为meilisearch存储
|
||||
err = Task.WriteTaskProgress(ctx, &dto.WriteTaskProgressReq{
|
||||
TaskId: doc.Id,
|
||||
TaskType: task.TaskTypeFullTextSearch,
|
||||
Status: task.TaskStatusCompleted,
|
||||
Remark: "全文检索数据写入完成",
|
||||
})
|
||||
} else {
|
||||
// 写入任务进度已完成 任务类型为meilisearch存储
|
||||
err = Task.WriteTaskProgress(ctx, &dto.WriteTaskProgressReq{
|
||||
TaskId: doc.Id,
|
||||
TaskType: task.TaskTypeFullTextSearch,
|
||||
Status: task.TaskStatusCompleted,
|
||||
Remark: "无需生成全文检索数据,任务完成",
|
||||
})
|
||||
}
|
||||
return
|
||||
}
|
||||
@@ -467,20 +716,3 @@ func (s *documentService) checkRepeat(ctx context.Context, contentKey, contentHa
|
||||
success = val.Bool()
|
||||
return
|
||||
}
|
||||
|
||||
func (s *documentService) DocsVectorStatusMsg(ctx context.Context, msg any) (err error) {
|
||||
var req = new(dto.KnowledgeDocumentMsg)
|
||||
if err = gconv.Struct(msg, &req); err != nil {
|
||||
g.Log().Error(ctx, "DocsVectorStatusMsg err:", err)
|
||||
return
|
||||
}
|
||||
ctx = context.WithValue(ctx, "user", &beans.User{
|
||||
TenantId: req.TenantId,
|
||||
UserName: req.Creator,
|
||||
})
|
||||
_, err = dao.Document.Update(ctx, &dto.UpdateDocumentReq{
|
||||
Id: req.Id,
|
||||
VectorStatus: req.VectorStatus,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
@@ -3,15 +3,11 @@ package service
|
||||
import (
|
||||
"context"
|
||||
"rag/common/eino"
|
||||
"rag/consts/document"
|
||||
"rag/consts/public"
|
||||
"rag/common/task"
|
||||
"rag/dao"
|
||||
"rag/model/dto"
|
||||
"rag/model/entity"
|
||||
|
||||
gmq "github.com/bjang03/gmq/core/gmq"
|
||||
"github.com/bjang03/gmq/mq"
|
||||
"github.com/bjang03/gmq/types"
|
||||
"github.com/cloudwego/eino/components/indexer"
|
||||
"github.com/cloudwego/eino/schema"
|
||||
"github.com/gogf/gf/v2/frame/g"
|
||||
@@ -22,10 +18,6 @@ var DocumentChunk = new(documentChunkService)
|
||||
|
||||
type documentChunkService struct{}
|
||||
|
||||
const (
|
||||
DatasetIndexStatusReady = "ready"
|
||||
)
|
||||
|
||||
// Update 更新文件块
|
||||
func (s *documentChunkService) Update(ctx context.Context, req *dto.UpdateDocumentChunkReq) (err error) {
|
||||
_, err = dao.DocumentChunk.Update(ctx, req)
|
||||
@@ -60,32 +52,29 @@ func (s *documentChunkService) DocsChunkMsg(ctx context.Context, msg any) (err e
|
||||
idx := eino.NewPGVectorIndexer(&eino.PGVectorIndexerOptions{
|
||||
BatchSize: 10,
|
||||
})
|
||||
documentId := gconv.Int64(docs[0].MetaData[entity.DocumentChunkCol.DocumentId])
|
||||
rows, err := idx.Store(ctx, docs, indexer.WithEmbedding(eino.EmbedderDashscope))
|
||||
if err != nil || rows == 0 {
|
||||
g.Log().Error(ctx, "DocsChunkMsg rows: , err:", rows, err)
|
||||
// 写入任务进度失败 任务类型为sql存储
|
||||
remark := " 向量存储数量: " + gconv.String(rows)
|
||||
if err != nil {
|
||||
remark = "向量存储失败: " + err.Error()
|
||||
}
|
||||
err = Task.WriteTaskProgress(ctx, &dto.WriteTaskProgressReq{
|
||||
TaskId: documentId,
|
||||
TaskType: task.TaskTypeGenerateVector,
|
||||
Status: task.TaskStatusFailed,
|
||||
Remark: remark,
|
||||
})
|
||||
return
|
||||
}
|
||||
tenantId := gconv.Uint64(docs[0].MetaData[entity.DocumentChunkCol.TenantId])
|
||||
creator := gconv.String(docs[0].MetaData[entity.DocumentChunkCol.Creator])
|
||||
documentId := gconv.Int64(docs[0].MetaData[entity.DocumentChunkCol.DocumentId])
|
||||
err = s.publishKnowledgeDocumentMsg(ctx, tenantId, creator, documentId, document.VectorStatusCompleted.Code())
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// publishKnowledgeDocumentMsg 发布消息
|
||||
func (s *documentChunkService) publishKnowledgeDocumentMsg(ctx context.Context, tenantId uint64, creator string, documentId int64, vectorStatus document.VectorStatus) (err error) {
|
||||
knowledgeDocumentMsg := dto.KnowledgeDocumentMsg{
|
||||
TenantId: tenantId,
|
||||
Creator: creator,
|
||||
Id: documentId,
|
||||
VectorStatus: vectorStatus,
|
||||
}
|
||||
err = gmq.GetGmq("primary").GmqPublish(ctx, &mq.RedisPubMessage{
|
||||
PubMessage: types.PubMessage{
|
||||
Topic: public.KnowledgeDocumentVectorStatusTopic,
|
||||
Data: knowledgeDocumentMsg,
|
||||
},
|
||||
// 写入任务进度成功 任务类型为sql存储
|
||||
err = Task.WriteTaskProgress(ctx, &dto.WriteTaskProgressReq{
|
||||
TaskId: documentId,
|
||||
TaskType: task.TaskTypeGenerateVector,
|
||||
Status: task.TaskStatusCompleted,
|
||||
Remark: "向量生成完成",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
52
service/rag_query.go
Normal file
52
service/rag_query.go
Normal file
@@ -0,0 +1,52 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"rag/common/eino"
|
||||
"rag/model/dto"
|
||||
|
||||
"github.com/cloudwego/eino/components/retriever"
|
||||
"github.com/gogf/gf/v2/os/glog"
|
||||
)
|
||||
|
||||
var RAGQuery = new(ragQueryService)
|
||||
|
||||
type ragQueryService struct{}
|
||||
|
||||
// Query 执行RAG查询
|
||||
func (s *ragQueryService) Query(ctx context.Context, req *dto.RAGQueryReq) (*dto.RAGQueryRes, error) {
|
||||
if req.TopK <= 0 {
|
||||
req.TopK = 5
|
||||
}
|
||||
|
||||
// 4. 使用向量检索器进行查询
|
||||
r, err := eino.NewPGVectorRetriever(&eino.PGVectorRetrieverConfig{
|
||||
Embedder: eino.EmbedderDashscope,
|
||||
DefaultTopK: req.TopK,
|
||||
})
|
||||
if err != nil {
|
||||
glog.Errorf(ctx, "初始化向量检索器失败: %v", err)
|
||||
return nil, fmt.Errorf("初始化向量检索器失败: %w", err)
|
||||
}
|
||||
|
||||
// 5. 执行向量检索
|
||||
docs, err := r.Retrieve(ctx, req.Content, retriever.WithEmbedding(eino.EmbedderDashscope), retriever.WithDSLInfo(map[string]any{
|
||||
"dataset_ids": req.DatasetIds,
|
||||
}))
|
||||
if err != nil {
|
||||
glog.Errorf(ctx, "向量检索失败: %v", err)
|
||||
return nil, fmt.Errorf("向量检索失败: %w", err)
|
||||
}
|
||||
|
||||
replyMsg, sources, err := eino.NewChatModel(ctx, req.Content, docs)
|
||||
if err != nil {
|
||||
glog.Errorf(ctx, "向量检索失败: %v", err)
|
||||
return nil, fmt.Errorf("向量检索失败: %w", err)
|
||||
}
|
||||
|
||||
return &dto.RAGQueryRes{
|
||||
Answer: replyMsg.Content,
|
||||
Sources: sources,
|
||||
}, nil
|
||||
}
|
||||
107
service/task.go
Normal file
107
service/task.go
Normal file
@@ -0,0 +1,107 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"rag/dao"
|
||||
"rag/model/dto"
|
||||
|
||||
"rag/common/task"
|
||||
|
||||
"github.com/gogf/gf/v2/frame/g"
|
||||
"github.com/gogf/gf/v2/util/gconv"
|
||||
)
|
||||
|
||||
var Task = new(taskService)
|
||||
|
||||
type taskService struct{}
|
||||
|
||||
// WriteTaskProgress 写入任务进度(核心方法)
|
||||
func (s *taskService) WriteTaskProgress(ctx context.Context, req *dto.WriteTaskProgressReq) (err error) {
|
||||
t, total, err := dao.Task.Get(ctx, &dto.GetTaskReq{
|
||||
TaskId: req.TaskId,
|
||||
})
|
||||
if err != nil {
|
||||
g.Log().Errorf(ctx, "查询任务失败: %v", err)
|
||||
return err
|
||||
}
|
||||
taskVO := make([]dto.TaskVO, 0, total)
|
||||
err = gconv.Struct(t, taskVO)
|
||||
if err != nil {
|
||||
g.Log().Errorf(ctx, "转换任务失败: %v", err)
|
||||
return err
|
||||
}
|
||||
taskVO = append(taskVO, dto.TaskVO{
|
||||
TaskType: req.TaskType,
|
||||
Status: req.Status,
|
||||
})
|
||||
completed := IsAllSubTasksCompleted(taskVO)
|
||||
|
||||
// 1. 查询是否已存在该文档的该类型任务
|
||||
existTask, _, err := dao.Task.Get(ctx, &dto.GetTaskReq{
|
||||
TaskId: req.TaskId,
|
||||
TaskType: req.TaskType,
|
||||
})
|
||||
if err != nil {
|
||||
g.Log().Errorf(ctx, "查询任务失败: %v", err)
|
||||
return err
|
||||
}
|
||||
|
||||
// 2. 如果不存在,则创建新任务
|
||||
if g.IsEmpty(existTask) {
|
||||
createReq := &dto.CreateTaskReq{
|
||||
TaskId: req.TaskId,
|
||||
TaskType: req.TaskType,
|
||||
Status: req.Status,
|
||||
Remark: req.Remark,
|
||||
}
|
||||
_, err = dao.Task.Insert(ctx, createReq)
|
||||
} else {
|
||||
// 3. 如果已存在,则更新任务
|
||||
updateReq := &dto.UpdateTaskReq{
|
||||
Id: existTask[0].Id,
|
||||
Status: req.Status,
|
||||
Remark: req.Remark,
|
||||
}
|
||||
_, err = dao.Task.Update(ctx, updateReq)
|
||||
if err != nil {
|
||||
g.Log().Errorf(ctx, "更新任务失败: %v", err)
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if completed {
|
||||
// 3. 如果已存在,则更新任务
|
||||
_, err = dao.Task.Update(ctx, &dto.UpdateTaskReq{
|
||||
TaskId: req.TaskId,
|
||||
Status: task.TaskStatusCompleted,
|
||||
})
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// IsAllSubTasksCompleted 判断三个子任务是否全部完成
|
||||
// 参数:传入当前文档的所有子任务列表
|
||||
func IsAllSubTasksCompleted(subTasks []dto.TaskVO) bool {
|
||||
// 必须包含 3 种任务类型
|
||||
hasKeywords := false
|
||||
hasVector := false
|
||||
hasFullText := false
|
||||
|
||||
for _, t := range subTasks {
|
||||
// 子任务必须是【已完成】状态才计数
|
||||
if t.Status == task.TaskStatusCompleted {
|
||||
switch t.TaskType {
|
||||
case task.TaskTypeExtractKeywords:
|
||||
hasKeywords = true
|
||||
case task.TaskTypeGenerateVector:
|
||||
hasVector = true
|
||||
case task.TaskTypeFullTextSearch:
|
||||
hasFullText = true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 三个任务全部完成 → 返回true
|
||||
return hasKeywords && hasVector && hasFullText
|
||||
}
|
||||
Reference in New Issue
Block a user