Files
rag/service/document.go

484 lines
13 KiB
Go
Raw Normal View History

2026-04-03 09:16:53 +08:00
package service
import (
"context"
"fmt"
"rag/consts/document"
"rag/consts/public"
"rag/dao"
"rag/model/dto"
"rag/model/entity"
"strings"
"sync"
"time"
"gitea.com/red-future/common/beans"
"gitea.com/red-future/common/db/gfdb"
"gitea.com/red-future/common/full-text-search/meilisearch"
"gitea.com/red-future/common/http"
"gitea.com/red-future/common/rag/eino"
"gitea.com/red-future/common/rag/gse"
"gitea.com/red-future/common/utils"
gmq "github.com/bjang03/gmq/core/gmq"
"github.com/bjang03/gmq/mq"
"github.com/bjang03/gmq/types"
"github.com/cloudwego/eino/schema"
"github.com/gogf/gf/v2/container/gvar"
"github.com/gogf/gf/v2/crypto/gmd5"
"github.com/gogf/gf/v2/database/gdb"
"github.com/gogf/gf/v2/database/gredis"
"github.com/gogf/gf/v2/frame/g"
"github.com/gogf/gf/v2/util/gconv"
)
// documentService groups the document-related service operations. It carries
// no state of its own.
type documentService struct{}

// Document is the package-level documentService instance used by callers.
var Document = &documentService{}
// Create inserts a document record and updates the owning dataset inside a
// single database transaction.
//
// The dataset update is sent with DocumentCount 1 and DocumentSize set to the
// uploaded file size (presumably applied as deltas by dao.Dataset.Update —
// Delete sends the negated values; confirm against the DAO).
//
// On success the returned response carries the id reported by the insert; on
// any error the transaction callback returns it and res stays nil.
func (s *documentService) Create(ctx context.Context, req *dto.CreateDocumentReq) (res *dto.CreateDocumentRes, err error) {
	err = gfdb.DB(ctx).Transaction(ctx, func(ctx context.Context, _ gdb.TX) error {
		newID, insertErr := dao.Document.Insert(ctx, req)
		if insertErr != nil {
			return insertErr
		}

		if _, updateErr := dao.Dataset.Update(ctx, &dto.UpdateDatasetReq{
			Id:            req.DatasetId,
			DocumentCount: 1,
			DocumentSize:  req.FileSize,
		}); updateErr != nil {
			return updateErr
		}

		res = &dto.CreateDocumentRes{Id: newID}
		return nil
	})
	return res, err
}
// Update forwards req to dao.Document.Update and returns the DAO's error.
// The DAO's first return value is discarded.
func (s *documentService) Update(ctx context.Context, req *dto.UpdateDocumentReq) (err error) {
	if _, err = dao.Document.Update(ctx, req); err != nil {
		return err
	}
	return nil
}
// Delete removes a document and updates the owning dataset inside a single
// database transaction.
//
// The document is looked up first so that its dataset id and file size can be
// sent (negated) to dao.Dataset.Update. If the lookup fails, or yields no
// record, an error is returned and nothing is modified.
func (s *documentService) Delete(ctx context.Context, req *dto.DeleteDocumentReq) (err error) {
	doc, err := dao.Document.GetByID(ctx, &dto.GetDocumentReq{Id: req.Id})
	if err != nil {
		return err
	}
	// Guard against a nil record with a nil error; dereferencing it below
	// would panic.
	if doc == nil {
		return fmt.Errorf("document %v not found", req.Id)
	}

	return gfdb.DB(ctx).Transaction(ctx, func(ctx context.Context, _ gdb.TX) error {
		if _, updateErr := dao.Dataset.Update(ctx, &dto.UpdateDatasetReq{
			Id:            doc.DatasetId,
			DocumentCount: -1,
			DocumentSize:  -doc.FileSize,
		}); updateErr != nil {
			return updateErr
		}

		_, deleteErr := dao.Document.Delete(ctx, req)
		return deleteErr
	})
}
// Get loads a document by id and converts it into a DocumentVO.
//
// It returns the DAO error if the lookup fails, or the conversion error if
// gconv.Struct cannot map the record onto the view object. In both cases the
// returned view object is nil.
func (s *documentService) Get(ctx context.Context, req *dto.GetDocumentReq) (res *dto.DocumentVO, err error) {
	record, err := dao.Document.GetByID(ctx, req)
	// The lookup error must be checked before converting; otherwise the
	// conversion result would overwrite (and hide) a DAO failure.
	if err != nil {
		return nil, err
	}
	if err = gconv.Struct(record, &res); err != nil {
		return nil, err
	}
	return res, nil
}
// List queries documents through dao.Document.List and returns them together
// with the total reported by the DAO.
//
// A DAO error yields (nil, err). A conversion error from gconv.Struct is
// returned alongside the response that already has Total populated.
func (s *documentService) List(ctx context.Context, req *dto.ListDocumentReq) (res *dto.ListDocumentRes, err error) {
	rows, total, err := dao.Document.List(ctx, req)
	if err != nil {
		return nil, err
	}

	res = &dto.ListDocumentRes{Total: total}
	if err = gconv.Struct(rows, &res.List); err != nil {
		return res, err
	}
	return res, nil
}
// Process runs the ingestion pipeline for one document.
//
// After loading the document record it runs three tasks concurrently and
// waits for all of them:
//   - sqlSplitDocument: semantic split and publishing of new chunks to the queue
//   - esSplitDocument:  recursive split and indexing into meilisearch
//   - extractDocument:  keyword extraction
//
// If any task fails, the document is marked VectorStatusFailed and the first
// task error (in the order listed above) is returned; remaining task errors
// and a failure to persist the failed status are logged rather than allowed
// to replace the root cause. A panic inside a task is converted into that
// task's error so it cannot crash the process from a background goroutine.
//
// On success the document is marked VectorStatusProcessing when at least one
// chunk was published for vectorization, otherwise VectorStatusCompleted, and
// the chunk count is stored. The response carries the chunk count and the
// elapsed time in milliseconds.
func (s *documentService) Process(ctx context.Context, req *dto.ProcessDocumentReq) (res *dto.ProcessDocumentRes, err error) {
	startTime := time.Now()

	// 1. Load the document record.
	documentReq := dto.GetDocumentReq{Id: req.Id}
	doc, err := dao.Document.GetByID(ctx, &documentReq)
	if err != nil {
		return nil, err
	}
	// A nil record would be dereferenced inside the worker goroutines below.
	if doc == nil {
		return nil, fmt.Errorf("document %v not found", req.Id)
	}

	// 2. Run the three tasks concurrently.
	var vectorDocsCount, chunks int64
	var sqlErr, esErr, extractErr error
	var wg sync.WaitGroup

	// run executes task, stores its outcome in *dst and always releases the
	// WaitGroup. The recover handler is deferred last so it runs before
	// wg.Done, guaranteeing *dst is written before Wait returns.
	run := func(dst *error, task func() error) {
		defer wg.Done()
		defer func() {
			if r := recover(); r != nil {
				*dst = fmt.Errorf("document %v: processing task panicked: %v", req.Id, r)
			}
		}()
		*dst = task()
	}

	wg.Add(3)
	go run(&sqlErr, func() error {
		var taskErr error
		vectorDocsCount, chunks, taskErr = s.sqlSplitDocument(ctx, doc)
		return taskErr
	})
	go run(&esErr, func() error {
		return s.esSplitDocument(ctx, doc)
	})
	go run(&extractErr, func() error {
		return s.extractDocument(ctx, doc)
	})
	wg.Wait()

	updateDocumentReq := new(dto.UpdateDocumentReq)
	updateDocumentReq.Id = req.Id

	// 3. Pick the first failure as the returned error; log the others.
	var firstErr error
	for _, taskErr := range []error{sqlErr, esErr, extractErr} {
		if taskErr == nil {
			continue
		}
		if firstErr == nil {
			firstErr = taskErr
			continue
		}
		g.Log().Errorf(ctx, "document %v: additional processing error: %v", req.Id, taskErr)
	}
	if firstErr != nil {
		updateDocumentReq.VectorStatus = document.VectorStatusFailed.Code()
		if _, updateErr := dao.Document.Update(ctx, updateDocumentReq); updateErr != nil {
			// Keep the processing error as the returned cause.
			g.Log().Errorf(ctx, "document %v: marking vector status failed: %v", req.Id, updateErr)
		}
		return nil, firstErr
	}

	// 4. Persist the resulting status and the chunk count.
	if vectorDocsCount > 0 {
		updateDocumentReq.VectorStatus = document.VectorStatusProcessing.Code()
	} else {
		updateDocumentReq.VectorStatus = document.VectorStatusCompleted.Code()
	}
	updateDocumentReq.ChunkCount = chunks
	if _, err = dao.Document.Update(ctx, updateDocumentReq); err != nil {
		return nil, err
	}

	return &dto.ProcessDocumentRes{
		ChunkCount: chunks,
		CostTime:   time.Since(startTime).Milliseconds(),
	}, nil
}
// extractDocument extracts keywords from the document and stores them via
// dao.Keyword.BatchSaveOrUpdate.
//
// The number of keywords requested depends on the byte length of the first
// loaded document's content:
//   - under 500 bytes:  4 keywords
//   - under 2000 bytes: 8 keywords
//   - under 5000 bytes: 13 keywords
//   - otherwise the loaded documents are split recursively and 6 keywords are
//     requested per chunk
//
// If the loader returns nothing usable there is nothing to extract and nil is
// returned. Nothing is written when no keywords were produced.
func (s *documentService) extractDocument(ctx context.Context, doc *entity.Document) (err error) {
	// 1. Load the file.
	docs, err := s.loadDocument(ctx, doc)
	if err != nil {
		return err
	}
	// Indexing docs[0] on an empty result would panic.
	if len(docs) == 0 || docs[0] == nil {
		return nil
	}

	// 2. Extract keywords; len counts bytes, not runes.
	content := docs[0].Content
	var words []gse.Keyword
	switch size := len(content); {
	case size < 500:
		words = gse.GseTool.Extract(content, 4)
	case size < 2000:
		words = gse.GseTool.Extract(content, 8)
	case size < 5000:
		words = gse.GseTool.Extract(content, 13)
	default:
		docsSplit, splitErr := eino.RecursiveSplitDocument(ctx, docs)
		if splitErr != nil {
			return splitErr
		}
		for _, piece := range docsSplit {
			words = append(words, gse.GseTool.Extract(piece.Content, 6)...)
		}
	}
	if len(words) == 0 {
		return nil
	}

	// 3. Persist the keywords for this dataset/document pair.
	keywordReqs := make([]*dto.CreateKeywordReq, 0, len(words))
	for _, word := range words {
		keywordReqs = append(keywordReqs, &dto.CreateKeywordReq{
			DatasetId:  doc.DatasetId,
			DocumentId: doc.Id,
			Word:       word.Word,
			Weight:     gconv.Int16(word.Score),
		})
	}
	_, err = dao.Keyword.BatchSaveOrUpdate(ctx, keywordReqs)
	return err
}
// sqlSplitDocument splits the document semantically and publishes the chunks
// that are not yet known to the vectorization queue.
//
// Steps: load the file, split it with eino.SemanticSplitDocument, call
// getHistoryData for the SQL hash keys, then for every chunk compute the MD5
// of its content and call checkRepeat; chunks for which checkRepeat reports
// false are skipped. The remaining chunks are published in one message to
// public.KnowledgeDocumentChunkTopic.
//
// Returns the number of chunks selected for publishing, the total number of
// chunks produced by the split, and the first error encountered.
func (s *documentService) sqlSplitDocument(ctx context.Context, doc *entity.Document) (vectorDocsCount, docsSplitCount int64, err error) {
	// 1. Load the file.
	loaded, err := s.loadDocument(ctx, doc)
	if err != nil {
		return 0, 0, err
	}

	// 2. Semantic split.
	pieces, err := eino.SemanticSplitDocument(ctx, loaded)
	if err != nil {
		return 0, 0, err
	}
	docsSplitCount = int64(len(pieces))

	// 3. Prepare the known content hashes.
	if err = s.getHistoryData(ctx, doc, public.KnowledgeLockSqlKey, public.KnowledgeContentHashSqlKey); err != nil {
		return 0, docsSplitCount, err
	}

	// 4. Collect the chunks that checkRepeat accepts.
	pending := make([]dto.VectorDocumentChunkMsg, 0)
	for idx, piece := range pieces {
		hash := gmd5.MustEncryptString(piece.Content)

		accepted, checkErr := s.checkRepeat(ctx, public.KnowledgeContentHashSqlKey, hash)
		if checkErr != nil {
			return 0, docsSplitCount, checkErr
		}
		if accepted {
			pending = append(pending, dto.VectorDocumentChunkMsg{
				TenantId:    doc.TenantId,
				Creator:     doc.Creator,
				DatasetId:   doc.DatasetId,
				DocumentId:  doc.Id,
				Content:     piece.Content,
				ContentHash: hash,
				ChunkIndex:  int64(idx),
			})
		}
	}

	// 5. Publish the batch; an empty batch is not sent.
	if len(pending) > 0 {
		message := &mq.RedisPubMessage{
			PubMessage: types.PubMessage{
				Topic: public.KnowledgeDocumentChunkTopic,
				Data:  pending,
			},
		}
		err = gmq.GetGmq("primary").GmqPublish(ctx, message)
	}
	return int64(len(pending)), docsSplitCount, err
}
// esSplitDocument splits the document recursively and writes the chunks that
// are not yet known into the meilisearch index
// public.IndexNameDocumentChunk.
//
// Steps: load the file, split it with eino.RecursiveSplitDocument, call
// getHistoryData for the ES hash keys, then for every chunk compute the MD5
// of its content and call checkRepeat; chunks for which checkRepeat reports
// false are skipped. The content hash doubles as the meilisearch document id.
func (s *documentService) esSplitDocument(ctx context.Context, doc *entity.Document) (err error) {
	// 1. Load the file.
	loaded, err := s.loadDocument(ctx, doc)
	if err != nil {
		return err
	}

	// 2. Recursive split.
	pieces, err := eino.RecursiveSplitDocument(ctx, loaded)
	if err != nil {
		return err
	}

	// 3. Prepare the known content hashes.
	if err = s.getHistoryData(ctx, doc, public.KnowledgeLockEsKey, public.KnowledgeContentHashEsKey); err != nil {
		return err
	}

	// 4. Build the meilisearch documents for the accepted chunks.
	indexDocs := make([]any, 0)
	for idx, piece := range pieces {
		hash := gmd5.MustEncryptString(piece.Content)

		accepted, checkErr := s.checkRepeat(ctx, public.KnowledgeContentHashEsKey, hash)
		if checkErr != nil {
			return checkErr
		}
		if !accepted {
			continue
		}

		indexDocs = append(indexDocs, map[string]any{
			"id":          hash,
			"datasetId":   doc.DatasetId,
			"documentId":  doc.Id,
			"content":     piece.Content,
			"contentHash": hash,
			"chunkIndex":  idx,
		})
	}
	if len(indexDocs) == 0 {
		return nil
	}

	// 5. Write the batch to meilisearch.
	_, err = meilisearch.DB().InsertMany(ctx, indexDocs, public.IndexNameDocumentChunk)
	if err != nil {
		g.Log().Errorf(ctx, "写入meilisearch失败: %v", err)
		return err
	}
	return nil
}
// loadDocument loads the document through eino.LoadDocument using the file
// path and format stored on the document record.
func (s *documentService) loadDocument(ctx context.Context, doc *entity.Document) (docs []*schema.Document, err error) {
	docs, err = eino.LoadDocument(ctx, doc.FilePath, doc.Format)
	return docs, err
}
// getHistoryData makes sure the content-hash keys are present in Redis
// before chunks are checked for duplicates.
//
// The work runs inside utils.Lock, keyed by lockKey formatted with the
// dataset id and called with 60 as its third argument:
//   - If Redis already holds keys matching contentKey formatted with "*",
//     their expiry is reset to 600 and nothing else is done.
//   - Otherwise the existing chunks are fetched (over HTTP when contentKey is
//     public.KnowledgeContentHashSqlKey, from meilisearch otherwise) and one
//     key per content hash is written with NX and a 600 expiry.
//
// Only the error returned by utils.Lock is propagated.
// NOTE(review): the boolean returned by utils.Lock is not acted upon, so a
// lock that was not acquired without an error is treated as success —
// confirm this is intended.
// NOTE(review): the key pattern substitutes only the hash placeholder, so it
// appears to match hash keys regardless of dataset — verify against the key
// constants.
func (s *documentService) getHistoryData(ctx context.Context, doc *entity.Document, lockKey, contentKey string) (err error) {
	datasetLockKey := fmt.Sprintf(lockKey, doc.DatasetId)

	warmUp := func(ctx context.Context) error {
		// 1. Look for existing content-hash keys in Redis.
		existing, err := g.Redis().Keys(ctx, fmt.Sprintf(contentKey, "*"))
		if err != nil {
			return err
		}

		// 2. Keys exist: only refresh their expiry (600), skip the lookup.
		if len(existing) > 0 {
			for _, redisKey := range existing {
				if _, err = g.Redis().Expire(ctx, redisKey, 600); err != nil {
					return err
				}
			}
			return nil
		}

		// 3. No keys: fetch the stored chunks from the matching backend.
		var history []*dto.DocumentChunkRPC
		switch contentKey {
		case public.KnowledgeContentHashSqlKey:
			history, err = s.getHistoryDataFromHttp(ctx, doc)
		default:
			history, err = s.getHistoryDataFromMeilisearch(ctx, doc)
		}
		if err != nil {
			return err
		}

		// 4. Write one key per content hash with a 600 expiry.
		for _, chunk := range history {
			// Strip surrounding JSON quotes, if any.
			hash := strings.Trim(chunk.ContentHash, `"`)
			_, err = g.Redis().Set(ctx, fmt.Sprintf(contentKey, hash), true, gredis.SetOption{
				TTLOption: gredis.TTLOption{
					EX: gconv.PtrInt64(600),
				},
				NX: true,
			})
			if err != nil {
				return err
			}
		}
		return nil
	}

	_, err = utils.Lock(ctx, datasetLockKey, int64(60), warmUp)
	return err
}
// getHistoryDataFromHttp fetches the dataset's existing chunks from the
// rag-vector service's listDocumentChunk endpoint.
//
// When ctx carries an incoming request, the first value of each of its
// headers is forwarded. The call is made with datasetId set to the document's
// dataset id and status set to 1. Returns the List field of the decoded
// response, or the request error.
func (s *documentService) getHistoryDataFromHttp(ctx context.Context, doc *entity.Document) (dictData []*dto.DocumentChunkRPC, err error) {
	// Forward the first value of every incoming request header.
	forwarded := make(map[string]string)
	if r := g.RequestFromCtx(ctx); r != nil {
		for name, values := range r.Request.Header {
			if len(values) == 0 {
				continue
			}
			forwarded[name] = values[0]
		}
	}

	// Call the remote endpoint.
	page := &dto.ListDocumentChunkRPC{}
	err = http.Get(ctx, "rag-vector/document/chunk/listDocumentChunk", forwarded, &page,
		"datasetId", gconv.String(doc.DatasetId),
		"status", 1)
	if err != nil {
		return nil, err
	}
	return page.List, nil
}
// getHistoryDataFromMeilisearch fetches the dataset's existing chunks from
// the meilisearch index public.IndexNameDocumentChunk.
//
// The search filters on the document's dataset id with a limit of 10000 and
// converts every hit into a dto.DocumentChunkRPC. If a hit cannot be
// converted, the items converted so far are returned together with the error.
func (s *documentService) getHistoryDataFromMeilisearch(ctx context.Context, doc *entity.Document) (dictData []*dto.DocumentChunkRPC, err error) {
	// Build the search parameters.
	query := &meilisearch.SearchParams{
		Filter: fmt.Sprintf("datasetId = %d", doc.DatasetId),
		Limit:  10000,
	}

	// Run the search.
	var hits []map[string]any
	if _, err = meilisearch.DB().Search(ctx, query, public.IndexNameDocumentChunk, &hits); err != nil {
		return nil, err
	}

	// Convert the hits.
	dictData = make([]*dto.DocumentChunkRPC, 0, len(hits))
	for _, hit := range hits {
		chunk := &dto.DocumentChunkRPC{}
		if err = gconv.Struct(hit, chunk); err != nil {
			return dictData, err
		}
		dictData = append(dictData, chunk)
	}
	return dictData, nil
}
// checkRepeat tries to claim a content hash in Redis.
//
// It issues a Set with NX and a 600 expiry on contentKey formatted with
// contentHash, and returns the boolean value of the reply. Callers in this
// file skip the chunk when the result is false.
func (s *documentService) checkRepeat(ctx context.Context, contentKey, contentHash string) (success bool, err error) {
	redisKey := fmt.Sprintf(contentKey, contentHash)
	option := gredis.SetOption{
		TTLOption: gredis.TTLOption{
			EX: gconv.PtrInt64(600),
		},
		NX: true,
	}

	var reply *gvar.Var
	reply, err = g.Redis().Set(ctx, redisKey, true, option)
	if err != nil {
		return false, err
	}
	return reply.Bool(), nil
}
// DocsVectorStatusMsg applies a vector-status message to a document.
//
// msg is converted into a dto.KnowledgeDocumentMsg; a conversion failure is
// logged and returned. The tenant id and creator from the message are placed
// on the context under the "user" key as a beans.User before the document's
// VectorStatus is updated through dao.Document.Update.
func (s *documentService) DocsVectorStatusMsg(ctx context.Context, msg any) (err error) {
	payload := new(dto.KnowledgeDocumentMsg)
	if err = gconv.Struct(msg, &payload); err != nil {
		g.Log().Error(ctx, "DocsVectorStatusMsg err:", err)
		return err
	}

	// Attach the message's tenant and creator to the context for the DAO call.
	operator := &beans.User{
		TenantId: payload.TenantId,
		UserName: payload.Creator,
	}
	ctx = context.WithValue(ctx, "user", operator)

	statusReq := &dto.UpdateDocumentReq{
		Id:           payload.Id,
		VectorStatus: payload.VectorStatus,
	}
	_, err = dao.Document.Update(ctx, statusReq)
	return err
}