package service import ( "context" "fmt" "rag/consts/document" "rag/consts/public" "rag/dao" "rag/model/dto" "rag/model/entity" "strings" "sync" "time" "gitea.com/red-future/common/beans" "gitea.com/red-future/common/db/gfdb" "gitea.com/red-future/common/full-text-search/meilisearch" "gitea.com/red-future/common/http" "gitea.com/red-future/common/rag/eino" "gitea.com/red-future/common/rag/gse" "gitea.com/red-future/common/utils" gmq "github.com/bjang03/gmq/core/gmq" "github.com/bjang03/gmq/mq" "github.com/bjang03/gmq/types" "github.com/cloudwego/eino/schema" "github.com/gogf/gf/v2/container/gvar" "github.com/gogf/gf/v2/crypto/gmd5" "github.com/gogf/gf/v2/database/gdb" "github.com/gogf/gf/v2/database/gredis" "github.com/gogf/gf/v2/frame/g" "github.com/gogf/gf/v2/util/gconv" ) var Document = new(documentService) type documentService struct{} // Create 创建文件 func (s *documentService) Create(ctx context.Context, req *dto.CreateDocumentReq) (res *dto.CreateDocumentRes, err error) { err = gfdb.DB(ctx).Transaction(ctx, func(ctx context.Context, tx gdb.TX) (err error) { var id int64 id, err = dao.Document.Insert(ctx, req) if err != nil { return } datasetReq := &dto.UpdateDatasetReq{ Id: req.DatasetId, DocumentCount: 1, DocumentSize: req.FileSize, } _, err = dao.Dataset.Update(ctx, datasetReq) if err != nil { return } res = &dto.CreateDocumentRes{Id: id} return }) return } // Update 更新文件 func (s *documentService) Update(ctx context.Context, req *dto.UpdateDocumentReq) (err error) { _, err = dao.Document.Update(ctx, req) return } // Delete 删除文件 func (s *documentService) Delete(ctx context.Context, req *dto.DeleteDocumentReq) (err error) { docs, err := dao.Document.GetByID(ctx, &dto.GetDocumentReq{Id: req.Id}) if err != nil { return } err = gfdb.DB(ctx).Transaction(ctx, func(ctx context.Context, tx gdb.TX) (err error) { datasetReq := &dto.UpdateDatasetReq{ Id: docs.DatasetId, DocumentCount: -1, DocumentSize: -docs.FileSize, } _, err = dao.Dataset.Update(ctx, datasetReq) if err != nil { return } _, err = dao.Document.Delete(ctx, req) return }) return } // Get 获取文件详情 func (s *documentService) Get(ctx context.Context, req *dto.GetDocumentReq) (res *dto.DocumentVO, err error) { r, err := dao.Document.GetByID(ctx, req) err = gconv.Struct(r, &res) return } // List 文件列表 func (s *documentService) List(ctx context.Context, req *dto.ListDocumentReq) (res *dto.ListDocumentRes, err error) { list, total, err := dao.Document.List(ctx, req) if err != nil { return nil, err } res = &dto.ListDocumentRes{ Total: total, } err = gconv.Struct(list, &res.List) //eino.TestIndexer() //eino.TestRetriever() return } // Process 处理文件(使用eino框架切分和向量化) func (s *documentService) Process(ctx context.Context, req *dto.ProcessDocumentReq) (res *dto.ProcessDocumentRes, err error) { startTime := time.Now() // 1. 查询文件信息 documentReq := dto.GetDocumentReq{Id: req.Id} doc, err := dao.Document.GetByID(ctx, &documentReq) if err != nil { return nil, err } // 2. 使用eino框架进行文件切分(并发执行) var vectorDocsCount, chunks int64 // 用 gopool 或者简单的错误等待,绝对不用裸 goroutine var err1, err2, err3 error var wg sync.WaitGroup wg.Add(3) // 任务1 go func() { defer wg.Done() vectorDocsCount, chunks, err1 = s.sqlSplitDocument(ctx, doc) }() // 任务2 go func() { defer wg.Done() err2 = s.esSplitDocument(ctx, doc) }() // 任务3 go func() { defer wg.Done() err3 = s.extractDocument(ctx, doc) }() // 直接等待,不使用通道,避免泄漏 wg.Wait() updateDocumentReq := new(dto.UpdateDocumentReq) updateDocumentReq.Id = req.Id // 统一判断错误 if err1 != nil || err2 != nil || err3 != nil { // 更新文档状态 updateDocumentReq.VectorStatus = document.VectorStatusFailed.Code() if _, err = dao.Document.Update(ctx, updateDocumentReq); err != nil { return nil, err } if err1 != nil { return nil, err1 } if err2 != nil { return nil, err2 } return nil, err3 } // 4. 更新文件状态为处理中和切分数量 if vectorDocsCount > 0 { updateDocumentReq.VectorStatus = document.VectorStatusProcessing.Code() } else { updateDocumentReq.VectorStatus = document.VectorStatusCompleted.Code() } updateDocumentReq.ChunkCount = chunks if _, err = dao.Document.Update(ctx, updateDocumentReq); err != nil { return } costTime := time.Since(startTime).Milliseconds() return &dto.ProcessDocumentRes{ ChunkCount: chunks, CostTime: costTime, }, nil } func (s *documentService) extractDocument(ctx context.Context, doc *entity.Document) (err error) { // 1. 加载文件 docs, err := s.loadDocument(ctx, doc) if err != nil { return } var words []gse.Keyword if len(docs[0].Content) < 500 { words = gse.GseTool.Extract(docs[0].Content, 4) } else if len(docs[0].Content) < 2000 { words = gse.GseTool.Extract(docs[0].Content, 8) } else if len(docs[0].Content) < 5000 { words = gse.GseTool.Extract(docs[0].Content, 13) } else { var docsSplit []*schema.Document docsSplit, err = eino.RecursiveSplitDocument(ctx, docs) if err != nil { return } for _, t := range docsSplit { words = append(words, gse.GseTool.Extract(t.Content, 6)...) } } var keywordReqs = make([]*dto.CreateKeywordReq, 0) for _, word := range words { keywordReqs = append(keywordReqs, &dto.CreateKeywordReq{ DatasetId: doc.DatasetId, DocumentId: doc.Id, Word: word.Word, Weight: gconv.Int16(word.Score), }) } if len(keywordReqs) > 0 { _, err = dao.Keyword.BatchSaveOrUpdate(ctx, keywordReqs) if err != nil { return } } return } func (s *documentService) sqlSplitDocument(ctx context.Context, doc *entity.Document) (vectorDocsCount, docsSplitCount int64, err error) { // 1. 加载文件 docs, err := s.loadDocument(ctx, doc) if err != nil { return } // 2. 语义切分文件 docsSplit, err := eino.SemanticSplitDocument(ctx, docs) if err != nil { return } docsSplitCount = gconv.Int64(len(docsSplit)) // 2. 获取历史数据 err = s.getHistoryData(ctx, doc, public.KnowledgeLockSqlKey, public.KnowledgeContentHashSqlKey) if err != nil { return } // 3. 组装向量文档 var vectorDocs = make([]dto.VectorDocumentChunkMsg, 0) for i, t := range docsSplit { contentHash := gmd5.MustEncryptString(t.Content) // 检查是否重复 var success bool success, err = s.checkRepeat(ctx, public.KnowledgeContentHashSqlKey, contentHash) if err != nil { return } if !success { continue } vectorDocs = append(vectorDocs, dto.VectorDocumentChunkMsg{ TenantId: doc.TenantId, Creator: doc.Creator, DatasetId: doc.DatasetId, DocumentId: doc.Id, Content: t.Content, ContentHash: contentHash, ChunkIndex: gconv.Int64(i), }) } // 4. 发送消息到队列 if len(vectorDocs) > 0 { err = gmq.GetGmq("primary").GmqPublish(ctx, &mq.RedisPubMessage{ PubMessage: types.PubMessage{ Topic: public.KnowledgeDocumentChunkTopic, Data: vectorDocs, }, }) } vectorDocsCount = gconv.Int64(len(vectorDocs)) return } func (s *documentService) esSplitDocument(ctx context.Context, doc *entity.Document) (err error) { // 1. 加载文件 docs, err := s.loadDocument(ctx, doc) if err != nil { return } // 2. 递归切分文件 docsSplit, err := eino.RecursiveSplitDocument(ctx, docs) if err != nil { return } // 2. 获取历史数据 err = s.getHistoryData(ctx, doc, public.KnowledgeLockEsKey, public.KnowledgeContentHashEsKey) if err != nil { return } // 3. 组装向量文档并同时构建meilisearch文档 var meiliDocs = make([]interface{}, 0) for i, t := range docsSplit { contentHash := gmd5.MustEncryptString(t.Content) // 检查是否重复 var success bool success, err = s.checkRepeat(ctx, public.KnowledgeContentHashEsKey, contentHash) if err != nil { return } if !success { continue } // 构建Meilisearch文档 meiliDocs = append(meiliDocs, map[string]interface{}{ "id": contentHash, "datasetId": doc.DatasetId, "documentId": doc.Id, "content": t.Content, "contentHash": contentHash, "chunkIndex": i, }) } // 4. 写入到meilisearch数据库中 if len(meiliDocs) > 0 { if _, err = meilisearch.DB().InsertMany(ctx, meiliDocs, public.IndexNameDocumentChunk); err != nil { g.Log().Errorf(ctx, "写入meilisearch失败: %v", err) return } } return } // loadDocument 加载文件 func (s *documentService) loadDocument(ctx context.Context, doc *entity.Document) (docs []*schema.Document, err error) { return eino.LoadDocument(ctx, doc.FilePath, doc.Format) } // getHistoryData 获取历史数据 func (s *documentService) getHistoryData(ctx context.Context, doc *entity.Document, lockKey, contentKey string) (err error) { docsLockKey := fmt.Sprintf(lockKey, doc.DatasetId) success, err := utils.Lock(ctx, docsLockKey, int64(60), func(ctx context.Context) error { // 1. 扫描 Redis 中所有 前缀为 rag:knowledge:xxx:contentHash 的 key pattern := fmt.Sprintf(contentKey, "*") keys, err := g.Redis().Keys(ctx, pattern) if err != nil { return err } // 2. Redis 有数据:只刷新过期时间,不查库 if len(keys) > 0 { // 批量刷新过期时间为 60s for _, key := range keys { _, err = g.Redis().Expire(ctx, key, 600) if err != nil { return err } } return nil } // 3. Redis 无数据:根据 contentKey 类型选择查询方式 var dictData = make([]*dto.DocumentChunkRPC, 0) if public.KnowledgeContentHashSqlKey == contentKey { // SQL 方式:调用 HTTP 接口查询 dictData, err = s.getHistoryDataFromHttp(ctx, doc) } else { // ES 方式:查询 meilisearch dictData, err = s.getHistoryDataFromMeilisearch(ctx, doc) } if err != nil { return err } // 4. 把查询到的数据写入 Redis(600s过期) for _, item := range dictData { // 去除可能的 JSON 引号 contentHash := strings.Trim(item.ContentHash, `"`) key := fmt.Sprintf(contentKey, contentHash) _, err = g.Redis().Set(ctx, key, true, gredis.SetOption{ TTLOption: gredis.TTLOption{ EX: gconv.PtrInt64(600), }, NX: true, }) if err != nil { return err } } return nil }) if err != nil && !success { return } return } // getHistoryDataFromHttp 通过 HTTP 接口查询历史数据 func (s *documentService) getHistoryDataFromHttp(ctx context.Context, doc *entity.Document) (dictData []*dto.DocumentChunkRPC, err error) { headers := make(map[string]string) if r := g.RequestFromCtx(ctx); r != nil { for k, v := range r.Request.Header { if len(v) > 0 { headers[k] = v[0] } } } // 调用接口获取数据 d := &dto.ListDocumentChunkRPC{} if err = http.Get(ctx, "rag-vector/document/chunk/listDocumentChunk", headers, &d, "datasetId", gconv.String(doc.DatasetId), "status", 1); err != nil { return } dictData = d.List return } // getHistoryDataFromMeilisearch 通过 meilisearch 查询历史数据 func (s *documentService) getHistoryDataFromMeilisearch(ctx context.Context, doc *entity.Document) (dictData []*dto.DocumentChunkRPC, err error) { // 构建 meilisearch 查询参数 searchParams := &meilisearch.SearchParams{ Filter: fmt.Sprintf("datasetId = %d", doc.DatasetId), Limit: 10000, } // 执行搜索 var hits []map[string]interface{} _, err = meilisearch.DB().Search(ctx, searchParams, public.IndexNameDocumentChunk, &hits) if err != nil { return } // 转换查询结果 dictData = make([]*dto.DocumentChunkRPC, 0) for _, hit := range hits { item := &dto.DocumentChunkRPC{} if err = gconv.Struct(hit, item); err != nil { return } dictData = append(dictData, item) } return } // checkRepeat 检查是否重复 func (s *documentService) checkRepeat(ctx context.Context, contentKey, contentHash string) (success bool, err error) { var val *gvar.Var if val, err = g.Redis().Set(ctx, fmt.Sprintf(contentKey, contentHash), true, gredis.SetOption{ TTLOption: gredis.TTLOption{ EX: gconv.PtrInt64(600), }, NX: true, }); err != nil { return } success = val.Bool() return } func (s *documentService) DocsVectorStatusMsg(ctx context.Context, msg any) (err error) { var req = new(dto.KnowledgeDocumentMsg) if err = gconv.Struct(msg, &req); err != nil { g.Log().Error(ctx, "DocsVectorStatusMsg err:", err) return } ctx = context.WithValue(ctx, "user", &beans.User{ TenantId: req.TenantId, UserName: req.Creator, }) _, err = dao.Document.Update(ctx, &dto.UpdateDocumentReq{ Id: req.Id, VectorStatus: req.VectorStatus, }) return }