package service

import (
	"context"
	"fmt"

	"rag/common/eino"
	"rag/consts/model"
	"rag/consts/task"
	"rag/dao"
	"rag/model/dto"
	"rag/model/entity"

	"gitea.com/red-future/common/beans"
	"github.com/cloudwego/eino/components/retriever"
	"github.com/cloudwego/eino/schema"
	"github.com/gogf/gf/v2/frame/g"
	"github.com/gogf/gf/v2/util/gconv"
	"github.com/pgvector/pgvector-go"
)

// DocumentVector is the package-level instance of the document vector service.
var DocumentVector = new(documentVectorService)

// documentVectorService groups the RAG query and document-chunk vector
// operations. It carries no state.
type documentVectorService struct{}

// Query runs a RAG query.
//
// It looks up the chat model, retrieves documents for req.Content (restricted
// by req.DatasetIds and req.DocumentIds), converts req.History into chat
// messages and asks the chat model for a reply. On success the reply content
// is returned as the answer; on failure the error of the failing step is
// logged and returned wrapped.
func (s *documentVectorService) Query(ctx context.Context, req *dto.RAGQueryReq) (*dto.RAGQueryRes, error) {
	modelInfo, err := dao.Model.Get(ctx, &dto.GetModelReq{
		ModelType: model.ModelTypeChat.Code(),
	})
	if err != nil {
		g.Log().Errorf(ctx, "获取模型失败: %v", err)
		return nil, fmt.Errorf("获取模型失败: %w", err)
	}
	if modelInfo == nil {
		// err is nil on this path, so it must not be wrapped with %w;
		// report which model type was missing instead.
		g.Log().Errorf(ctx, "模型不存在: %v", model.ModelTypeChat.Code())
		return nil, fmt.Errorf("模型不存在: %v", model.ModelTypeChat.Code())
	}

	// Build the vector retriever.
	// TODO: switch to a local model later.
	r, err := eino.NewPGVectorRetriever(ctx, &eino.PGVectorRetrieverConfig{
		DefaultTopK: req.TopK,
	}, model.ModelConfigTypeVectorDashScope.Code())
	if err != nil {
		g.Log().Errorf(ctx, "初始化向量检索器失败: %v", err)
		return nil, fmt.Errorf("初始化向量检索器失败: %w", err)
	}

	// Run the vector retrieval, passing the dataset/document filters as DSL info.
	docs, err := r.Retrieve(ctx, req.Content, retriever.WithDSLInfo(map[string]any{
		"dataset_ids":  req.DatasetIds,
		"document_ids": req.DocumentIds,
	}))
	if err != nil {
		g.Log().Errorf(ctx, "向量检索失败: %v", err)
		return nil, fmt.Errorf("向量检索失败: %w", err)
	}

	// Convert the request history into chat messages.
	messages := make([]*schema.Message, 0)
	if err = gconv.Struct(req.History, &messages); err != nil {
		g.Log().Errorf(ctx, "转换历史消息失败: %v", err)
		return nil, fmt.Errorf("转换历史消息失败: %w", err)
	}

	replyMsg, err := eino.NewChatModel(ctx, req.Content, docs, messages, modelInfo.ConfigType)
	if err != nil {
		// This is the chat step, not the retrieval step: use a distinct message
		// so the two failures can be told apart.
		g.Log().Errorf(ctx, "生成回答失败: %v", err)
		return nil, fmt.Errorf("生成回答失败: %w", err)
	}

	return &dto.RAGQueryRes{
		Answer: replyMsg.Content,
	}, nil
}

// Update updates a document chunk by delegating to dao.DocumentVector.Update
// and returns that call's error.
func (s *documentVectorService) Update(ctx context.Context, req *dto.UpdateDocumentVectorReq) (err error) {
	_, err = dao.DocumentVector.Update(ctx, req)
	return err
}

// List returns the document chunks matching req together with the total count.
//
// If the DAO query fails, res is nil. If converting the rows into res.List
// fails, the partially built res is returned along with the error.
func (s *documentVectorService) List(ctx context.Context, req *dto.ListDocumentVectorReq) (res *dto.ListDocumentVectorRes, err error) {
	list, total, err := dao.DocumentVector.List(ctx, req)
	if err != nil {
		return nil, err
	}
	res = &dto.ListDocumentVectorRes{
		Total: total,
	}
	err = gconv.Struct(list, &res.List)
	return res, err
}

// DocsChunkMsg handles a document-chunk message.
//
// msg is converted to a map and its "data" entry is decoded into a list of
// documents. Documents whose metadata flag "isNew" is true are stored through
// the PGVector indexer; the others are converted into insert records, have
// their vector backfilled from existing rows with the same content hash, and
// are batch-inserted. The outcome is recorded as progress of the
// vector-generation task whose id is the first document's document id.
//
// Returned error: a decode/convert/list error is returned directly; after a
// storage failure the returned error is that of the task-progress write (the
// storage failure itself is recorded in the task remark). An empty message is
// logged and yields nil.
func (s *documentVectorService) DocsChunkMsg(ctx context.Context, msg any) (err error) {
	// Page size used when looking up existing vectors by content hash.
	const existPageSize = 1000

	docs := make([]*schema.Document, 0)
	msgMap := gconv.Map(msg)
	if err = gconv.Structs(msgMap["data"], &docs); err != nil {
		g.Log().Error(ctx, "DocsChunkMsg err:", err)
		return err
	}
	if len(docs) == 0 {
		g.Log().Error(ctx, "DocsChunkMsg err:", "msg is empty")
		return nil
	}

	// Attach tenant and creator from the first document's metadata to the
	// context under the "user" key.
	// NOTE(review): presumably downstream DAO/task code reads this key — confirm.
	ctx = context.WithValue(ctx, "user", &beans.User{
		TenantId: gconv.Uint64(docs[0].MetaData[entity.DocumentVectorCol.TenantId]),
		UserName: gconv.String(docs[0].MetaData[entity.DocumentVectorCol.Creator]),
	})
	documentId := gconv.Int64(docs[0].MetaData[entity.DocumentVectorCol.DocumentId])

	// Split the documents: "isNew" ones go to the indexer, the rest are
	// inserted as records.
	docsStore := make([]*schema.Document, 0)
	docsInsert := make([]*dto.VectorDocumentVectorMsg, 0)
	for _, doc := range docs {
		if gconv.Bool(doc.MetaData["isNew"]) {
			docsStore = append(docsStore, doc)
			continue
		}
		ck := new(dto.VectorDocumentVectorMsg)
		if err = gconv.Struct(doc.MetaData, ck); err != nil {
			// Do not insert a half-populated record.
			g.Log().Error(ctx, "DocsChunkMsg err:", err)
			return err
		}
		ck.Content = doc.Content
		ck.VectorStatus = gconv.PtrInt8(1)
		ck.Status = gconv.PtrInt8(1)
		docsInsert = append(docsInsert, ck)
	}

	if !g.IsEmpty(docsStore) {
		idx := eino.NewPGVectorIndexer(&eino.PGVectorIndexerOptions{
			BatchSize: 10,
		})
		// TODO: switch to a local model later.
		rows, storeErr := idx.Store(ctx, docsStore, model.ModelConfigTypeVectorDashScope.Code())
		if storeErr != nil || rows == 0 {
			return s.reportVectorFailure(ctx, documentId, rows, storeErr)
		}
	}

	if !g.IsEmpty(docsInsert) {
		// 1. Collect every content hash.
		contentHashs := make([]string, 0, len(docsInsert))
		for _, d := range docsInsert {
			contentHashs = append(contentHashs, d.ContentHash)
		}

		// 2. Load the existing vectors page by page to avoid one large query.
		var existVectors []*entity.DocumentVector
		for page := 1; ; page++ {
			res, total, listErr := dao.DocumentVector.List(ctx, &dto.ListDocumentVectorReq{
				Page:         &beans.Page{PageSize: existPageSize, PageNum: int64(page)},
				ContentHashs: contentHashs,
			})
			if listErr != nil {
				return listErr
			}
			if len(res) == 0 {
				break
			}
			existVectors = append(existVectors, res...)
			if len(existVectors) >= total {
				break
			}
		}

		// 3. Index the existing vectors by content hash for constant-time lookup.
		vectorMap := make(map[string]pgvector.Vector, len(existVectors))
		for _, v := range existVectors {
			vectorMap[v.ContentHash] = v.Vector
		}

		// 4. Backfill the vector of every record whose hash already exists.
		// No record is filtered out here: all of docsInsert is inserted.
		for _, d := range docsInsert {
			if vec, ok := vectorMap[d.ContentHash]; ok {
				d.Vector = vec
			}
		}

		rows, insertErr := dao.DocumentVector.BatchInsert(ctx, docsInsert)
		if insertErr != nil || rows == 0 {
			return s.reportVectorFailure(ctx, documentId, rows, insertErr)
		}
	}

	// Everything was written: mark the vector-generation task as completed.
	return Task.WriteTaskProgress(ctx, &dto.WriteTaskProgressReq{
		TaskId:   documentId,
		TaskType: task.TaskTypeGenerateVector,
		Status:   task.TaskStatusCompleted,
		Remark:   "向量生成完成",
	})
}

// reportVectorFailure logs a failed vector write and marks the
// vector-generation task documentId as failed. The remark carries cause's
// message when cause is non-nil, otherwise the number of rows written.
// It returns the error of the task-progress write, not cause.
func (s *documentVectorService) reportVectorFailure(ctx context.Context, documentId int64, rows int64, cause error) error {
	g.Log().Error(ctx, "DocsChunkMsg rows: , err:", rows, cause)
	remark := " 向量存储数量: " + gconv.String(rows)
	if cause != nil {
		remark = "向量存储失败: " + cause.Error()
	}
	return Task.WriteTaskProgress(ctx, &dto.WriteTaskProgressReq{
		TaskId:   documentId,
		TaskType: task.TaskTypeGenerateVector,
		Status:   task.TaskStatusFailed,
		Remark:   remark,
	})
}