refactor: 重构文档处理流程和任务管理

This commit is contained in:
2026-04-09 09:11:43 +08:00
parent b6896f3fb4
commit 7f894745e9
34 changed files with 1216 additions and 1056 deletions

View File

@@ -5,17 +5,14 @@ import (
"errors"
"fmt"
"rag/common/eino"
"rag/common/gse"
"rag/common/task"
"rag/consts/document"
"rag/consts/public"
"rag/dao"
"rag/model/dto"
"rag/model/entity"
"strings"
"sync"
"time"
"gitea.com/red-future/common/beans"
"gitea.com/red-future/common/db/gfdb"
"gitea.com/red-future/common/full-text-search/meilisearch"
"gitea.com/red-future/common/http"
@@ -29,6 +26,7 @@ import (
"github.com/gogf/gf/v2/database/gdb"
"github.com/gogf/gf/v2/database/gredis"
"github.com/gogf/gf/v2/frame/g"
"github.com/gogf/gf/v2/os/grpool"
"github.com/gogf/gf/v2/util/gconv"
)
@@ -54,7 +52,13 @@ func (s *documentService) Create(ctx context.Context, req *dto.CreateDocumentReq
return
}
res = &dto.CreateDocumentRes{Id: id}
// 写入任务进度待处理 任务类型为文档解析
err = Task.WriteTaskProgress(ctx, &dto.WriteTaskProgressReq{
TaskId: id,
TaskType: task.TaskTypeDocParse,
Status: task.TaskStatusPending,
Remark: "文档上传成功待解析: " + req.Title,
})
return
})
@@ -79,11 +83,20 @@ func (s *documentService) Delete(ctx context.Context, req *dto.DeleteDocumentReq
DocumentCount: -1,
DocumentSize: -docs.FileSize,
}
_, err = dao.Dataset.Update(ctx, datasetReq)
if err != nil {
if _, err = dao.Dataset.Update(ctx, datasetReq); err != nil {
return
}
_, err = dao.Document.Delete(ctx, req)
if _, err = dao.Document.Delete(ctx, req); err != nil {
return
}
if _, err = dao.Task.DeleteByTaskId(ctx, &dto.DeleteTaskByTaskIdReq{
TaskId: docs.Id,
}); err != nil {
return
}
return
})
@@ -107,118 +120,159 @@ func (s *documentService) List(ctx context.Context, req *dto.ListDocumentReq) (r
Total: total,
}
err = gconv.Struct(list, &res.List)
//eino.TestIndexer()
//eino.TestRetriever()
return
}
// Process 处理文件(使用eino框架切分和向量化)
func (s *documentService) Process(ctx context.Context, req *dto.ProcessDocumentReq) (res *dto.ProcessDocumentRes, err error) {
startTime := time.Now()
func (s *documentService) Process(ctx context.Context, req *dto.ProcessDocumentReq) (err error) {
// 1. 查询文件信息
documentReq := dto.GetDocumentReq{Id: req.Id}
doc, err := dao.Document.GetByID(ctx, &documentReq)
if err != nil {
return nil, err
return err
}
if g.IsEmpty(doc) {
return nil, errors.New("document not found")
return errors.New("document not found")
}
// 2. 使用eino框架进行文件切分并发执行
var vectorDocsCount, chunks int64
// 用 gopool 或者简单的错误等待,绝对不用裸 goroutine
var err1, err2, err3 error
var wg sync.WaitGroup
wg.Add(3)
// 任务1
go func() {
defer wg.Done()
vectorDocsCount, chunks, err1 = s.sqlSplitDocument(ctx, doc)
}()
// 任务2
go func() {
defer wg.Done()
err2 = s.esSplitDocument(ctx, doc)
}()
// 任务3
go func() {
defer wg.Done()
err3 = s.extractDocument(ctx, doc)
}()
// 直接等待,不使用通道,避免泄漏
wg.Wait()
// 2. 更新文档状态为处理中
updateDocumentReq := new(dto.UpdateDocumentReq)
updateDocumentReq.Id = req.Id
// 统一判断错误
if err1 != nil || err2 != nil || err3 != nil {
// 更新文档状态
updateDocumentReq.VectorStatus = document.VectorStatusFailed.Code()
if _, err = dao.Document.Update(ctx, updateDocumentReq); err != nil {
return nil, err
}
if err1 != nil {
return nil, err1
}
if err2 != nil {
return nil, err2
}
return nil, err3
}
// 4. 更新文件状态为处理中和切分数量
if vectorDocsCount > 0 {
updateDocumentReq.VectorStatus = document.VectorStatusProcessing.Code()
} else {
updateDocumentReq.VectorStatus = document.VectorStatusCompleted.Code()
}
updateDocumentReq.ChunkCount = chunks
updateDocumentReq.VectorStatus = document.VectorStatusProcessing.Code()
if _, err = dao.Document.Update(ctx, updateDocumentReq); err != nil {
// 写入任务进度失败 任务类型为文档解析
err = Task.WriteTaskProgress(ctx, &dto.WriteTaskProgressReq{
TaskId: req.Id,
TaskType: task.TaskTypeDocParse,
Status: task.TaskStatusFailed,
Remark: "更新文档状态失败: " + err.Error(),
})
return
}
costTime := time.Since(startTime).Milliseconds()
return &dto.ProcessDocumentRes{
ChunkCount: chunks,
CostTime: costTime,
}, nil
}
func (s *documentService) extractDocument(ctx context.Context, doc *entity.Document) (err error) {
// 1. 加载文件
docs, err := s.loadDocument(ctx, doc)
// 写入任务进度进行中 任务类型为文档解析
err = Task.WriteTaskProgress(ctx, &dto.WriteTaskProgressReq{
TaskId: req.Id,
TaskType: task.TaskTypeDocParse,
Status: task.TaskStatusRunning,
Remark: "文档解析开始",
})
if err != nil {
return
}
var words []gse.Keyword
// ======================
// 核心grpool + g.Try 最佳实践
// ======================
taskCtx, cancel := context.WithCancel(ctx)
// 任务1: SQL 切分文档
grpool.Add(taskCtx, func(ctx context.Context) {
g.TryCatch(ctx, func(ctx context.Context) {
if innerErr := s.sqlSplitDocument(ctx, doc); innerErr != nil {
cancel()
}
}, func(ctx context.Context, err error) {
cancel()
})
})
// 任务2: ES 切分文档
grpool.Add(taskCtx, func(ctx context.Context) {
g.TryCatch(ctx, func(ctx context.Context) {
if innerErr := s.esSplitDocument(ctx, doc); innerErr != nil {
cancel()
}
}, func(ctx context.Context, err error) {
cancel()
})
})
// 任务3: 提取文档
grpool.Add(taskCtx, func(ctx context.Context) {
g.TryCatch(ctx, func(ctx context.Context) {
if innerErr := s.extractDocument(ctx, doc); innerErr != nil {
cancel()
}
}, func(ctx context.Context, err error) {
cancel()
})
})
return nil
}
// extractDocument 关键词提取(支持取消)
func (s *documentService) extractDocument(ctx context.Context, doc *entity.Document) (err error) {
// ========== 取消检查 1方法入口 ==========
if ctx.Err() != nil {
// 写入任务进度失败 任务类型为关键字存储
err = Task.WriteTaskProgress(ctx, &dto.WriteTaskProgressReq{
TaskId: doc.Id,
TaskType: task.TaskTypeExtractKeywords,
Status: task.TaskStatusFailed,
Remark: "ctx取消: " + ctx.Err().Error(),
})
return ctx.Err()
}
// 1. 加载文件
docs, err := s.loadDocument(ctx, doc)
if err != nil {
// 写入任务进度失败 任务类型为关键字存储
err = Task.WriteTaskProgress(ctx, &dto.WriteTaskProgressReq{
TaskId: doc.Id,
TaskType: task.TaskTypeExtractKeywords,
Status: task.TaskStatusFailed,
Remark: "加载文件失败: " + err.Error(),
})
return
}
var words []utils.Keyword
if len(docs[0].Content) < 500 {
words = gse.GseTool.Extract(docs[0].Content, 4)
words = utils.GseTool.Extract(docs[0].Content, 4)
} else if len(docs[0].Content) < 2000 {
words = gse.GseTool.Extract(docs[0].Content, 8)
words = utils.GseTool.Extract(docs[0].Content, 8)
} else if len(docs[0].Content) < 5000 {
words = gse.GseTool.Extract(docs[0].Content, 13)
words = utils.GseTool.Extract(docs[0].Content, 13)
} else {
var docsSplit []*schema.Document
docsSplit, err = eino.RecursiveSplitDocument(ctx, docs)
if err != nil {
err = Task.WriteTaskProgress(ctx, &dto.WriteTaskProgressReq{
TaskId: doc.Id,
TaskType: task.TaskTypeExtractKeywords,
Status: task.TaskStatusFailed,
Remark: "递归分割文档失败: " + err.Error(),
})
return
}
// ========== 取消检查 2循环内部 ==========
for _, t := range docsSplit {
words = append(words, gse.GseTool.Extract(t.Content, 6)...)
if ctx.Err() != nil {
err = Task.WriteTaskProgress(ctx, &dto.WriteTaskProgressReq{
TaskId: doc.Id,
TaskType: task.TaskTypeExtractKeywords,
Status: task.TaskStatusFailed,
Remark: "ctx取消: " + ctx.Err().Error(),
})
return ctx.Err()
}
words = append(words, utils.GseTool.Extract(t.Content, 6)...)
}
}
// ========== 取消检查 3批量操作前 ==========
if ctx.Err() != nil {
err = Task.WriteTaskProgress(ctx, &dto.WriteTaskProgressReq{
TaskId: doc.Id,
TaskType: task.TaskTypeExtractKeywords,
Status: task.TaskStatusFailed,
Remark: "ctx取消: " + ctx.Err().Error(),
})
return ctx.Err()
}
var keywordReqs = make([]*dto.CreateKeywordReq, 0)
for _, word := range words {
keywordReqs = append(keywordReqs, &dto.CreateKeywordReq{
@@ -231,37 +285,111 @@ func (s *documentService) extractDocument(ctx context.Context, doc *entity.Docum
if len(keywordReqs) > 0 {
_, err = dao.Keyword.BatchSaveOrUpdate(ctx, keywordReqs)
if err != nil {
// 写入任务进度失败 任务类型为关键字存储
err = Task.WriteTaskProgress(ctx, &dto.WriteTaskProgressReq{
TaskId: doc.Id,
TaskType: task.TaskTypeExtractKeywords,
Status: task.TaskStatusFailed,
Remark: "关键字存储失败: " + err.Error(),
})
return
}
// 写入任务进度已完成 任务类型为关键字存储
err = Task.WriteTaskProgress(ctx, &dto.WriteTaskProgressReq{
TaskId: doc.Id,
TaskType: task.TaskTypeExtractKeywords,
Status: task.TaskStatusCompleted,
Remark: "关键字提取完成",
})
} else {
// 写入任务进度已完成 任务类型为关键字存储
err = Task.WriteTaskProgress(ctx, &dto.WriteTaskProgressReq{
TaskId: doc.Id,
TaskType: task.TaskTypeExtractKeywords,
Status: task.TaskStatusCompleted,
Remark: "没有提取到关键词,关键字提取完成",
})
}
return
}
func (s *documentService) sqlSplitDocument(ctx context.Context, doc *entity.Document) (vectorDocsCount, docsSplitCount int64, err error) {
// sqlSplitDocument SQL切分支持取消
func (s *documentService) sqlSplitDocument(ctx context.Context, doc *entity.Document) (err error) {
// ========== 取消检查 1方法入口 ==========
if ctx.Err() != nil {
err = Task.WriteTaskProgress(ctx, &dto.WriteTaskProgressReq{
TaskId: doc.Id,
TaskType: task.TaskTypeGenerateVector,
Status: task.TaskStatusFailed,
Remark: "ctx取消: " + ctx.Err().Error(),
})
return ctx.Err()
}
// 1. 加载文件
docs, err := s.loadDocument(ctx, doc)
if err != nil {
// 写入任务进度失败 任务类型为sql存储
err = Task.WriteTaskProgress(ctx, &dto.WriteTaskProgressReq{
TaskId: doc.Id,
TaskType: task.TaskTypeGenerateVector,
Status: task.TaskStatusFailed,
Remark: "加载文件失败: " + err.Error(),
})
return
}
// 2. 语义切分文件
docsSplit, err := eino.SemanticSplitDocument(ctx, docs)
if err != nil {
// 写入任务进度失败 任务类型为sql存储
err = Task.WriteTaskProgress(ctx, &dto.WriteTaskProgressReq{
TaskId: doc.Id,
TaskType: task.TaskTypeGenerateVector,
Status: task.TaskStatusFailed,
Remark: "文档切分失败: " + err.Error(),
})
return
}
docsSplitCount = gconv.Int64(len(docsSplit))
// 2. 获取历史数据
err = s.getHistoryData(ctx, doc, public.KnowledgeLockSqlKey, public.KnowledgeContentHashSqlKey)
if err != nil {
// 写入任务进度失败 任务类型为sql存储
err = Task.WriteTaskProgress(ctx, &dto.WriteTaskProgressReq{
TaskId: doc.Id,
TaskType: task.TaskTypeGenerateVector,
Status: task.TaskStatusFailed,
Remark: "获取历史数据失败: " + err.Error(),
})
return
}
// 3. 组装向量文档
var docsChunk = make([]*schema.Document, 0)
for i, t := range docsSplit {
// ========== 取消检查 2循环内部 ==========
if ctx.Err() != nil {
err = Task.WriteTaskProgress(ctx, &dto.WriteTaskProgressReq{
TaskId: doc.Id,
TaskType: task.TaskTypeGenerateVector,
Status: task.TaskStatusFailed,
Remark: "ctx取消: " + ctx.Err().Error(),
})
return ctx.Err()
}
contentHash := gmd5.MustEncryptString(t.Content)
// 检查是否重复
var success bool
success, err = s.checkRepeat(ctx, public.KnowledgeContentHashSqlKey, contentHash)
if err != nil {
// 写入任务进度失败 任务类型为sql存储
err = Task.WriteTaskProgress(ctx, &dto.WriteTaskProgressReq{
TaskId: doc.Id,
TaskType: task.TaskTypeGenerateVector,
Status: task.TaskStatusFailed,
Remark: "检查重复数据失败: " + err.Error(),
})
return
}
if !success {
@@ -277,6 +405,18 @@ func (s *documentService) sqlSplitDocument(ctx context.Context, doc *entity.Docu
t.MetaData = metaData
docsChunk = append(docsChunk, t)
}
// ========== 取消检查 3批量发送前 ==========
if ctx.Err() != nil {
err = Task.WriteTaskProgress(ctx, &dto.WriteTaskProgressReq{
TaskId: doc.Id,
TaskType: task.TaskTypeGenerateVector,
Status: task.TaskStatusFailed,
Remark: "ctx取消: " + ctx.Err().Error(),
})
return ctx.Err()
}
// 4. 发送消息到队列
if len(docsChunk) > 0 {
err = gmq.GetGmq("primary").GmqPublish(ctx, &mq.RedisPubMessage{
@@ -285,41 +425,117 @@ func (s *documentService) sqlSplitDocument(ctx context.Context, doc *entity.Docu
Data: docsChunk,
},
})
if err != nil {
// 写入任务进度失败 任务类型为sql存储
err = Task.WriteTaskProgress(ctx, &dto.WriteTaskProgressReq{
TaskId: doc.Id,
TaskType: task.TaskTypeGenerateVector,
Status: task.TaskStatusFailed,
Remark: "发送消息到队列失败: " + err.Error(),
})
return
}
// 写入任务进度进行中 任务类型为sql存储
err = Task.WriteTaskProgress(ctx, &dto.WriteTaskProgressReq{
TaskId: doc.Id,
TaskType: task.TaskTypeGenerateVector,
Status: task.TaskStatusRunning,
Remark: "向量生成任务已提交到队列",
})
} else {
// 写入任务进度已完成 任务类型为sql存储
err = Task.WriteTaskProgress(ctx, &dto.WriteTaskProgressReq{
TaskId: doc.Id,
TaskType: task.TaskTypeGenerateVector,
Status: task.TaskStatusCompleted,
Remark: "无需生成向量,任务完成",
})
}
vectorDocsCount = gconv.Int64(len(docsChunk))
return
}
// esSplitDocument ES切分支持取消
func (s *documentService) esSplitDocument(ctx context.Context, doc *entity.Document) (err error) {
// ========== 取消检查 1方法入口 ==========
if ctx.Err() != nil {
err = Task.WriteTaskProgress(ctx, &dto.WriteTaskProgressReq{
TaskId: doc.Id,
TaskType: task.TaskTypeFullTextSearch,
Status: task.TaskStatusFailed,
Remark: "ctx取消: " + ctx.Err().Error(),
})
return ctx.Err()
}
// 1. 加载文件
docs, err := s.loadDocument(ctx, doc)
if err != nil {
// 写入任务进度失败 任务类型为es存储
err = Task.WriteTaskProgress(ctx, &dto.WriteTaskProgressReq{
TaskId: doc.Id,
TaskType: task.TaskTypeFullTextSearch,
Status: task.TaskStatusFailed,
Remark: "加载文件失败: " + err.Error(),
})
return
}
// 2. 递归切分文件
docsSplit, err := eino.RecursiveSplitDocument(ctx, docs)
if err != nil {
// 写入任务进度失败 任务类型为es存储
err = Task.WriteTaskProgress(ctx, &dto.WriteTaskProgressReq{
TaskId: doc.Id,
TaskType: task.TaskTypeFullTextSearch,
Status: task.TaskStatusFailed,
Remark: "文档切分失败: " + err.Error(),
})
return
}
// 2. 获取历史数据
err = s.getHistoryData(ctx, doc, public.KnowledgeLockEsKey, public.KnowledgeContentHashEsKey)
if err != nil {
// 写入任务进度失败 任务类型为es存储
err = Task.WriteTaskProgress(ctx, &dto.WriteTaskProgressReq{
TaskId: doc.Id,
TaskType: task.TaskTypeFullTextSearch,
Status: task.TaskStatusFailed,
Remark: "获取历史数据失败: " + err.Error(),
})
return
}
// 3. 组装向量文档并同时构建meilisearch文档
var meiliDocs = make([]interface{}, 0)
for i, t := range docsSplit {
// ========== 取消检查 2循环内部 ==========
if ctx.Err() != nil {
err = Task.WriteTaskProgress(ctx, &dto.WriteTaskProgressReq{
TaskId: doc.Id,
TaskType: task.TaskTypeFullTextSearch,
Status: task.TaskStatusFailed,
Remark: "ctx取消: " + ctx.Err().Error(),
})
return ctx.Err()
}
contentHash := gmd5.MustEncryptString(t.Content)
// 检查是否重复
var success bool
success, err = s.checkRepeat(ctx, public.KnowledgeContentHashEsKey, contentHash)
if err != nil {
// 写入任务进度失败 任务类型为es存储
err = Task.WriteTaskProgress(ctx, &dto.WriteTaskProgressReq{
TaskId: doc.Id,
TaskType: task.TaskTypeFullTextSearch,
Status: task.TaskStatusFailed,
Remark: "检查重复数据失败: " + err.Error(),
})
return
}
if !success {
continue
}
// 构建Meilisearch文档
meiliDocs = append(meiliDocs, map[string]interface{}{
entity.DocumentChunkCol.Id: contentHash,
entity.DocumentChunkCol.DatasetId: doc.DatasetId,
@@ -329,12 +545,45 @@ func (s *documentService) esSplitDocument(ctx context.Context, doc *entity.Docum
entity.DocumentChunkCol.ChunkIndex: i,
})
}
// ========== 取消检查 3批量写入前 ==========
if ctx.Err() != nil {
err = Task.WriteTaskProgress(ctx, &dto.WriteTaskProgressReq{
TaskId: doc.Id,
TaskType: task.TaskTypeFullTextSearch,
Status: task.TaskStatusFailed,
Remark: "ctx取消: " + ctx.Err().Error(),
})
return ctx.Err()
}
// 4. 写入到meilisearch数据库中
if len(meiliDocs) > 0 {
if _, err = meilisearch.DB().InsertMany(ctx, meiliDocs, public.IndexNameDocumentChunk); err != nil {
g.Log().Errorf(ctx, "写入meilisearch失败: %v", err)
// 写入任务进度失败 任务类型为meilisearch存储
err = Task.WriteTaskProgress(ctx, &dto.WriteTaskProgressReq{
TaskId: doc.Id,
TaskType: task.TaskTypeFullTextSearch,
Status: task.TaskStatusFailed,
Remark: "写入meilisearch失败: " + err.Error(),
})
return
}
// 写入任务进度已完成 任务类型为meilisearch存储
err = Task.WriteTaskProgress(ctx, &dto.WriteTaskProgressReq{
TaskId: doc.Id,
TaskType: task.TaskTypeFullTextSearch,
Status: task.TaskStatusCompleted,
Remark: "全文检索数据写入完成",
})
} else {
// 写入任务进度已完成 任务类型为meilisearch存储
err = Task.WriteTaskProgress(ctx, &dto.WriteTaskProgressReq{
TaskId: doc.Id,
TaskType: task.TaskTypeFullTextSearch,
Status: task.TaskStatusCompleted,
Remark: "无需生成全文检索数据,任务完成",
})
}
return
}
@@ -467,20 +716,3 @@ func (s *documentService) checkRepeat(ctx context.Context, contentKey, contentHa
success = val.Bool()
return
}
func (s *documentService) DocsVectorStatusMsg(ctx context.Context, msg any) (err error) {
var req = new(dto.KnowledgeDocumentMsg)
if err = gconv.Struct(msg, &req); err != nil {
g.Log().Error(ctx, "DocsVectorStatusMsg err:", err)
return
}
ctx = context.WithValue(ctx, "user", &beans.User{
TenantId: req.TenantId,
UserName: req.Creator,
})
_, err = dao.Document.Update(ctx, &dto.UpdateDocumentReq{
Id: req.Id,
VectorStatus: req.VectorStatus,
})
return
}

View File

@@ -3,15 +3,11 @@ package service
import (
"context"
"rag/common/eino"
"rag/consts/document"
"rag/consts/public"
"rag/common/task"
"rag/dao"
"rag/model/dto"
"rag/model/entity"
gmq "github.com/bjang03/gmq/core/gmq"
"github.com/bjang03/gmq/mq"
"github.com/bjang03/gmq/types"
"github.com/cloudwego/eino/components/indexer"
"github.com/cloudwego/eino/schema"
"github.com/gogf/gf/v2/frame/g"
@@ -22,10 +18,6 @@ var DocumentChunk = new(documentChunkService)
type documentChunkService struct{}
const (
DatasetIndexStatusReady = "ready"
)
// Update 更新文件块
func (s *documentChunkService) Update(ctx context.Context, req *dto.UpdateDocumentChunkReq) (err error) {
_, err = dao.DocumentChunk.Update(ctx, req)
@@ -60,32 +52,29 @@ func (s *documentChunkService) DocsChunkMsg(ctx context.Context, msg any) (err e
idx := eino.NewPGVectorIndexer(&eino.PGVectorIndexerOptions{
BatchSize: 10,
})
documentId := gconv.Int64(docs[0].MetaData[entity.DocumentChunkCol.DocumentId])
rows, err := idx.Store(ctx, docs, indexer.WithEmbedding(eino.EmbedderDashscope))
if err != nil || rows == 0 {
g.Log().Error(ctx, "DocsChunkMsg rows: , err:", rows, err)
// 写入任务进度失败 任务类型为sql存储
remark := " 向量存储数量: " + gconv.String(rows)
if err != nil {
remark = "向量存储失败: " + err.Error()
}
err = Task.WriteTaskProgress(ctx, &dto.WriteTaskProgressReq{
TaskId: documentId,
TaskType: task.TaskTypeGenerateVector,
Status: task.TaskStatusFailed,
Remark: remark,
})
return
}
tenantId := gconv.Uint64(docs[0].MetaData[entity.DocumentChunkCol.TenantId])
creator := gconv.String(docs[0].MetaData[entity.DocumentChunkCol.Creator])
documentId := gconv.Int64(docs[0].MetaData[entity.DocumentChunkCol.DocumentId])
err = s.publishKnowledgeDocumentMsg(ctx, tenantId, creator, documentId, document.VectorStatusCompleted.Code())
return
}
// publishKnowledgeDocumentMsg 发布消息
func (s *documentChunkService) publishKnowledgeDocumentMsg(ctx context.Context, tenantId uint64, creator string, documentId int64, vectorStatus document.VectorStatus) (err error) {
knowledgeDocumentMsg := dto.KnowledgeDocumentMsg{
TenantId: tenantId,
Creator: creator,
Id: documentId,
VectorStatus: vectorStatus,
}
err = gmq.GetGmq("primary").GmqPublish(ctx, &mq.RedisPubMessage{
PubMessage: types.PubMessage{
Topic: public.KnowledgeDocumentVectorStatusTopic,
Data: knowledgeDocumentMsg,
},
// 写入任务进度成功 任务类型为sql存储
err = Task.WriteTaskProgress(ctx, &dto.WriteTaskProgressReq{
TaskId: documentId,
TaskType: task.TaskTypeGenerateVector,
Status: task.TaskStatusCompleted,
Remark: "向量生成完成",
})
return
}

52
service/rag_query.go Normal file
View File

@@ -0,0 +1,52 @@
package service
import (
"context"
"fmt"
"rag/common/eino"
"rag/model/dto"
"github.com/cloudwego/eino/components/retriever"
"github.com/gogf/gf/v2/os/glog"
)
var RAGQuery = new(ragQueryService)
type ragQueryService struct{}
// Query runs a RAG (retrieval-augmented generation) request.
//
// It retrieves documents for req.Content through a PGVector retriever
// restricted to req.DatasetIds, then hands the question and the retrieved
// documents to eino.NewChatModel and returns the reply content together with
// the sources that call reports.
//
// When req.TopK is not positive it is set to 5 on the request itself, so the
// caller observes the defaulted value.
//
// Each failing stage is logged and returned wrapped with a message naming that
// stage.
func (s *ragQueryService) Query(ctx context.Context, req *dto.RAGQueryReq) (*dto.RAGQueryRes, error) {
	// Default the number of retrieved documents.
	if req.TopK <= 0 {
		req.TopK = 5
	}

	// 1. Build the vector retriever.
	r, err := eino.NewPGVectorRetriever(&eino.PGVectorRetrieverConfig{
		Embedder:    eino.EmbedderDashscope,
		DefaultTopK: req.TopK,
	})
	if err != nil {
		glog.Errorf(ctx, "初始化向量检索器失败: %v", err)
		return nil, fmt.Errorf("初始化向量检索器失败: %w", err)
	}

	// 2. Run the vector retrieval; the dataset ids are passed as DSL info.
	docs, err := r.Retrieve(ctx, req.Content, retriever.WithEmbedding(eino.EmbedderDashscope), retriever.WithDSLInfo(map[string]any{
		"dataset_ids": req.DatasetIds,
	}))
	if err != nil {
		glog.Errorf(ctx, "向量检索失败: %v", err)
		return nil, fmt.Errorf("向量检索失败: %w", err)
	}

	// 3. Produce the reply from the question and the retrieved documents.
	// This stage has its own message so that a chat-model failure is not
	// misreported as a retrieval failure.
	replyMsg, sources, err := eino.NewChatModel(ctx, req.Content, docs)
	if err != nil {
		glog.Errorf(ctx, "生成回答失败: %v", err)
		return nil, fmt.Errorf("生成回答失败: %w", err)
	}

	return &dto.RAGQueryRes{
		Answer:  replyMsg.Content,
		Sources: sources,
	}, nil
}

107
service/task.go Normal file
View File

@@ -0,0 +1,107 @@
package service
import (
"context"
"rag/dao"
"rag/model/dto"
"rag/common/task"
"github.com/gogf/gf/v2/frame/g"
"github.com/gogf/gf/v2/util/gconv"
)
var Task = new(taskService)
type taskService struct{}
// WriteTaskProgress records the latest status of one sub-task and, once all
// required sub-tasks are completed, issues an update keyed by req.TaskId that
// sets the status to completed.
//
// Steps:
//  1. Load every task row stored for req.TaskId and overlay the incoming
//     (TaskType, Status) on top of it to evaluate IsAllSubTasksCompleted.
//  2. Insert a row for (req.TaskId, req.TaskType) when none exists, otherwise
//     update the existing row's status and remark.
//  3. When step 1 reported completion, update by req.TaskId with the
//     completed status.
//
// Any DAO or conversion error is logged and returned; later steps are skipped
// once a step fails.
func (s *taskService) WriteTaskProgress(ctx context.Context, req *dto.WriteTaskProgressReq) (err error) {
	rows, _, err := dao.Task.Get(ctx, &dto.GetTaskReq{
		TaskId: req.TaskId,
	})
	if err != nil {
		g.Log().Errorf(ctx, "查询任务失败: %v", err)
		return err
	}

	// The destination must be passed by pointer, otherwise the converted rows
	// never reach subTasks and stored statuses are ignored.
	var subTasks []dto.TaskVO
	if err = gconv.Struct(rows, &subTasks); err != nil {
		g.Log().Errorf(ctx, "转换任务失败: %v", err)
		return err
	}

	// Overlay the incoming status on any stored row of the same type, so a
	// stale "completed" row cannot outvote a newer failed/running status.
	overlaid := false
	for i := range subTasks {
		if subTasks[i].TaskType == req.TaskType {
			subTasks[i].Status = req.Status
			overlaid = true
		}
	}
	if !overlaid {
		subTasks = append(subTasks, dto.TaskVO{
			TaskType: req.TaskType,
			Status:   req.Status,
		})
	}
	completed := IsAllSubTasksCompleted(subTasks)

	// 1. Look up the existing row for this task id and task type.
	existTask, _, err := dao.Task.Get(ctx, &dto.GetTaskReq{
		TaskId:   req.TaskId,
		TaskType: req.TaskType,
	})
	if err != nil {
		g.Log().Errorf(ctx, "查询任务失败: %v", err)
		return err
	}

	if g.IsEmpty(existTask) {
		// 2. No row yet: create it.
		if _, err = dao.Task.Insert(ctx, &dto.CreateTaskReq{
			TaskId:   req.TaskId,
			TaskType: req.TaskType,
			Status:   req.Status,
			Remark:   req.Remark,
		}); err != nil {
			g.Log().Errorf(ctx, "创建任务失败: %v", err)
			return err
		}
	} else {
		// 3. Row exists: refresh its status and remark.
		if _, err = dao.Task.Update(ctx, &dto.UpdateTaskReq{
			Id:     existTask[0].Id,
			Status: req.Status,
			Remark: req.Remark,
		}); err != nil {
			g.Log().Errorf(ctx, "更新任务失败: %v", err)
			return err
		}
	}

	if completed {
		// 4. All sub-tasks are done: update by task id with the completed
		// status. NOTE(review): which rows an update keyed only by TaskId
		// touches is decided by dao.Task.Update — confirm it targets the
		// intended (parent) task row.
		_, err = dao.Task.Update(ctx, &dto.UpdateTaskReq{
			TaskId: req.TaskId,
			Status: task.TaskStatusCompleted,
		})
	}
	return err
}
// IsAllSubTasksCompleted reports whether the given sub-task list contains a
// completed entry for each of the three required task types: keyword
// extraction, vector generation and full-text search.
//
// Entries whose status is not completed, and entries of any other task type,
// are ignored. An empty list yields false.
func IsAllSubTasksCompleted(subTasks []dto.TaskVO) bool {
	var keywordsDone, vectorDone, fullTextDone bool

	for i := range subTasks {
		sub := subTasks[i]
		// Only completed sub-tasks count toward the result.
		if sub.Status != task.TaskStatusCompleted {
			continue
		}
		switch sub.TaskType {
		case task.TaskTypeExtractKeywords:
			keywordsDone = true
		case task.TaskTypeGenerateVector:
			vectorDone = true
		case task.TaskTypeFullTextSearch:
			fullTextDone = true
		}
	}

	// True only when all three task types were seen as completed.
	if !keywordsDone || !vectorDone {
		return false
	}
	return fullTextDone
}