Files
rag/service/document_vector.go
2026-06-10 16:36:46 +08:00

222 lines
6.5 KiB
Go
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package service
import (
"context"
"fmt"
"rag/common/eino"
"rag/consts/model"
"rag/consts/task"
"rag/dao"
"rag/model/dto"
"rag/model/entity"
"gitea.redpowerfuture.com/red-future/common/beans"
"github.com/cloudwego/eino/components/retriever"
"github.com/cloudwego/eino/schema"
"github.com/gogf/gf/v2/frame/g"
"github.com/gogf/gf/v2/util/gconv"
"github.com/pgvector/pgvector-go"
)
// documentVectorService implements RAG querying and document-chunk (vector)
// management. It carries no fields; the package exposes a single shared
// instance, DocumentVector.
type documentVectorService struct{}

// DocumentVector is the shared documentVectorService instance.
var DocumentVector = &documentVectorService{}
// Query executes a RAG query.
//
// Steps:
//  1. Look up the chat model configuration.
//  2. Build a PGVector retriever limited to req.TopK.
//  3. Retrieve chunks for req.Content, filtered by req.DatasetIds and
//     req.DocumentIds.
//  4. Convert req.History into chat messages.
//  5. Ask the chat model to answer req.Content using the retrieved chunks.
//
// Every failure is logged and returned wrapped with a step-specific message.
func (s *documentVectorService) Query(ctx context.Context, req *dto.RAGQueryReq) (*dto.RAGQueryRes, error) {
	// 1. Resolve the chat model configuration.
	modelInfo, err := dao.Model.Get(ctx, &dto.GetModelReq{
		ModelType: model.ModelTypeChat.Code(),
	})
	if err != nil {
		g.Log().Errorf(ctx, "获取模型失败: %v", err)
		return nil, fmt.Errorf("获取模型失败: %w", err)
	}
	if modelInfo == nil {
		// err is nil on this path; wrapping it with %w would render "%!w(<nil>)",
		// so report the model type that could not be found instead.
		g.Log().Errorf(ctx, "模型不存在: %v", model.ModelTypeChat.Code())
		return nil, fmt.Errorf("模型不存在: %v", model.ModelTypeChat.Code())
	}

	// 2. Build the vector retriever.
	// TODO: replace the DashScope embedding config with a local model later.
	r, err := eino.NewPGVectorRetriever(ctx, &eino.PGVectorRetrieverConfig{
		DefaultTopK: req.TopK,
	}, model.ModelConfigTypeVectorDashScope.Code())
	if err != nil {
		g.Log().Errorf(ctx, "初始化向量检索器失败: %v", err)
		return nil, fmt.Errorf("初始化向量检索器失败: %w", err)
	}

	// 3. Run the vector search, optionally scoped to datasets/documents.
	docs, err := r.Retrieve(ctx, req.Content, retriever.WithDSLInfo(map[string]any{
		"dataset_ids":  req.DatasetIds,
		"document_ids": req.DocumentIds,
	}))
	if err != nil {
		g.Log().Errorf(ctx, "向量检索失败: %v", err)
		return nil, fmt.Errorf("向量检索失败: %w", err)
	}

	// 4. Convert the conversation history into chat messages.
	messages := make([]*schema.Message, 0)
	if err = gconv.Struct(req.History, &messages); err != nil {
		g.Log().Errorf(ctx, "转换历史消息失败: %v", err)
		return nil, fmt.Errorf("转换历史消息失败: %w", err)
	}

	// 5. Generate the answer. This step calls the chat model, not the
	// retriever, so it must not report "向量检索失败" (vector retrieval failed).
	replyMsg, err := eino.NewChatModel(ctx, req.Content, docs, messages, modelInfo.ConfigType)
	if err != nil {
		g.Log().Errorf(ctx, "生成回答失败: %v", err)
		return nil, fmt.Errorf("生成回答失败: %w", err)
	}
	return &dto.RAGQueryRes{
		Answer: replyMsg.Content,
	}, nil
}
// Update persists changes to a document chunk through
// dao.DocumentVector.Update. The DAO's first return value is discarded; only
// its error is propagated.
func (s *documentVectorService) Update(ctx context.Context, req *dto.UpdateDocumentVectorReq) (err error) {
	if _, err = dao.DocumentVector.Update(ctx, req); err != nil {
		return err
	}
	return nil
}
// List returns the document chunks matching req together with the total match
// count reported by dao.DocumentVector.List.
//
// On any error the returned result is nil. A partially converted list is never
// returned alongside an error.
func (s *documentVectorService) List(ctx context.Context, req *dto.ListDocumentVectorReq) (res *dto.ListDocumentVectorRes, err error) {
	list, total, err := dao.DocumentVector.List(ctx, req)
	if err != nil {
		return nil, err
	}
	out := &dto.ListDocumentVectorRes{
		Total: total,
	}
	if err = gconv.Struct(list, &out.List); err != nil {
		return nil, fmt.Errorf("转换文件块列表失败: %w", err)
	}
	return out, nil
}
// DocsChunkMsg consumes a document-chunk message and persists its vectors.
//
// msg is converted with gconv.Map and its "data" field is decoded into
// []*schema.Document. The tenant id, creator and document id are read from the
// first document's metadata. A beans.User built from the tenant id and creator
// is placed on ctx under the "user" key before storage runs.
//
// Documents whose metadata has isNew=true are stored through the PGVector
// indexer. All other documents are decoded into dto.VectorDocumentVectorMsg
// rows and batch-inserted. Each such row reuses the vector already stored for
// the same content hash, when one exists.
//
// Outcome:
//   - a decode failure of msg is logged and returned;
//   - an empty document list is logged and nil is returned;
//   - a metadata decode failure, a storage error, or a write of zero rows marks
//     the TaskTypeGenerateVector task for documentId as failed. The returned
//     error is then only the error (if any) from writing that task progress;
//   - a failure looking up existing vectors is returned directly;
//   - otherwise the task is marked completed.
func (s *documentVectorService) DocsChunkMsg(ctx context.Context, msg any) (err error) {
	docs := make([]*schema.Document, 0)
	msgMap := gconv.Map(msg)
	if err = gconv.Structs(msgMap["data"], &docs); err != nil {
		g.Log().Error(ctx, "DocsChunkMsg err:", err)
		return err
	}
	if len(docs) == 0 {
		// Nothing to persist; the empty message is logged and acknowledged.
		g.Log().Error(ctx, "DocsChunkMsg err:", "msg is empty")
		return nil
	}

	// NOTE(review): string context key kept as-is; other code presumably reads
	// the user back with the same "user" key.
	ctx = context.WithValue(ctx, "user", &beans.User{
		TenantId: gconv.Uint64(docs[0].MetaData[entity.DocumentVectorCol.TenantId]),
		UserName: gconv.String(docs[0].MetaData[entity.DocumentVectorCol.Creator]),
	})
	documentId := gconv.Int64(docs[0].MetaData[entity.DocumentVectorCol.DocumentId])

	// Split the chunks: new ones go through the indexer, the rest become rows.
	var (
		docsStore  []*schema.Document
		docsInsert []*dto.VectorDocumentVectorMsg
	)
	for _, doc := range docs {
		if gconv.Bool(doc.MetaData["isNew"]) {
			docsStore = append(docsStore, doc)
			continue
		}
		ck := new(dto.VectorDocumentVectorMsg)
		if convErr := gconv.Struct(doc.MetaData, ck); convErr != nil {
			// Previously ignored, which let a half-decoded row be inserted.
			g.Log().Errorf(ctx, "DocsChunkMsg convert metadata err: %v", convErr)
			return s.failVectorTask(ctx, documentId, "文档块元数据解析失败: "+convErr.Error())
		}
		ck.Content = doc.Content
		ck.VectorStatus = gconv.PtrInt8(1)
		ck.Status = gconv.PtrInt8(1)
		docsInsert = append(docsInsert, ck)
	}

	if len(docsStore) > 0 {
		idx := eino.NewPGVectorIndexer(&eino.PGVectorIndexerOptions{
			BatchSize: 10,
		})
		// TODO: replace the DashScope embedding config with a local model later.
		rows, storeErr := idx.Store(ctx, docsStore, model.ModelConfigTypeVectorDashScope.Code())
		if storeErr != nil || rows == 0 {
			return s.failVectorStore(ctx, documentId, rows, storeErr)
		}
	}

	if len(docsInsert) > 0 {
		hashes := make([]string, 0, len(docsInsert))
		for _, d := range docsInsert {
			hashes = append(hashes, d.ContentHash)
		}
		vectors, lookupErr := s.existingVectorsByHash(ctx, hashes)
		if lookupErr != nil {
			return fmt.Errorf("查询已存在向量失败: %w", lookupErr)
		}
		// Backfill vectors already stored for identical content. Rows are not
		// filtered out here; every chunk in docsInsert is still inserted.
		// NOTE(review): chunks with no stored vector are inserted without one
		// while VectorStatus is already 1. Confirm this is intended.
		for _, d := range docsInsert {
			if vec, ok := vectors[d.ContentHash]; ok {
				d.Vector = vec
			}
		}
		rows, insertErr := dao.DocumentVector.BatchInsert(ctx, docsInsert)
		if insertErr != nil || rows == 0 {
			return s.failVectorStore(ctx, documentId, rows, insertErr)
		}
	}

	// Everything persisted: mark the vector-generation task as completed.
	return Task.WriteTaskProgress(ctx, &dto.WriteTaskProgressReq{
		TaskId:   documentId,
		TaskType: task.TaskTypeGenerateVector,
		Status:   task.TaskStatusCompleted,
		Remark:   "向量生成完成",
	})
}

// existingVectorsByHash loads stored vectors for the given content hashes and
// returns them keyed by hash. Duplicate hashes are queried once. Results are
// fetched 1000 rows per page to avoid a single large query. When several rows
// share a hash, the last one returned wins.
func (s *documentVectorService) existingVectorsByHash(ctx context.Context, hashes []string) (map[string]pgvector.Vector, error) {
	seen := make(map[string]struct{}, len(hashes))
	unique := make([]string, 0, len(hashes))
	for _, h := range hashes {
		if _, dup := seen[h]; dup {
			continue
		}
		seen[h] = struct{}{}
		unique = append(unique, h)
	}

	vectors := make(map[string]pgvector.Vector, len(unique))
	fetched := 0
	for page := 1; ; page++ {
		list, total, err := dao.DocumentVector.List(ctx, &dto.ListDocumentVectorReq{
			Page:         &beans.Page{PageSize: 1000, PageNum: int64(page)},
			ContentHashs: unique,
		})
		if err != nil {
			return nil, err
		}
		if len(list) == 0 {
			break
		}
		for _, v := range list {
			vectors[v.ContentHash] = v.Vector
		}
		fetched += len(list)
		if fetched >= total {
			break
		}
	}
	return vectors, nil
}

// failVectorStore logs a failed or empty vector write and marks the task as
// failed. The remark is the storage error when there is one, otherwise the
// (zero) row count. It returns only the error from writing the task progress.
func (s *documentVectorService) failVectorStore(ctx context.Context, documentId int64, rows int64, storeErr error) error {
	g.Log().Errorf(ctx, "DocsChunkMsg rows: %d, err: %v", rows, storeErr)
	remark := " 向量存储数量: " + gconv.String(rows)
	if storeErr != nil {
		remark = "向量存储失败: " + storeErr.Error()
	}
	return s.failVectorTask(ctx, documentId, remark)
}

// failVectorTask records the TaskTypeGenerateVector task for documentId as
// failed with the given remark, returning the progress-write error.
func (s *documentVectorService) failVectorTask(ctx context.Context, documentId int64, remark string) error {
	return Task.WriteTaskProgress(ctx, &dto.WriteTaskProgressReq{
		TaskId:   documentId,
		TaskType: task.TaskTypeGenerateVector,
		Status:   task.TaskStatusFailed,
		Remark:   remark,
	})
}