feat: 支持多租户多模型对话及文档去重优化
This commit is contained in:
@@ -7,6 +7,7 @@ import (
|
||||
"rag/common/eino"
|
||||
"rag/consts/document"
|
||||
"rag/consts/keyword"
|
||||
"rag/consts/model"
|
||||
"rag/consts/public"
|
||||
"rag/consts/task"
|
||||
"rag/dao"
|
||||
@@ -22,10 +23,8 @@ import (
|
||||
"github.com/bjang03/gmq/mq"
|
||||
"github.com/bjang03/gmq/types"
|
||||
"github.com/cloudwego/eino/schema"
|
||||
"github.com/gogf/gf/v2/container/gvar"
|
||||
"github.com/gogf/gf/v2/crypto/gmd5"
|
||||
"github.com/gogf/gf/v2/database/gdb"
|
||||
"github.com/gogf/gf/v2/database/gredis"
|
||||
"github.com/gogf/gf/v2/frame/g"
|
||||
"github.com/gogf/gf/v2/os/grpool"
|
||||
"github.com/gogf/gf/v2/util/gconv"
|
||||
@@ -37,7 +36,35 @@ type documentService struct{}
|
||||
|
||||
// Create 创建文件
|
||||
func (s *documentService) Create(ctx context.Context, req *dto.CreateDocumentReq) (res *dto.CreateDocumentRes, err error) {
|
||||
err = gfdb.DB(ctx).Transaction(ctx, func(ctx context.Context, tx gdb.TX) (err error) {
|
||||
err = gfdb.DB(ctx, public.DbNameKnowledge).Transaction(ctx, func(ctx context.Context, tx gdb.TX) (err error) {
|
||||
doc, err := dao.Document.Get(ctx, &dto.GetDocumentReq{
|
||||
DatasetId: req.DatasetId,
|
||||
Title: req.Title,
|
||||
})
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if !g.IsEmpty(doc) && doc.Id > 0 {
|
||||
_, err = dao.Keyword.Delete(ctx, &dto.DeleteKeywordReq{
|
||||
DocumentId: doc.Id,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = dao.DocumentVector.Delete(ctx, &dto.DeleteDocumentVectorReq{
|
||||
DocumentId: doc.Id,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = dao.Document.Delete(ctx, &dto.DeleteDocumentReq{
|
||||
Id: doc.Id,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
var id int64
|
||||
id, err = dao.Document.Insert(ctx, req)
|
||||
if err != nil {
|
||||
@@ -74,11 +101,11 @@ func (s *documentService) Update(ctx context.Context, req *dto.UpdateDocumentReq
|
||||
|
||||
// Delete 删除文件
|
||||
func (s *documentService) Delete(ctx context.Context, req *dto.DeleteDocumentReq) (err error) {
|
||||
docs, err := dao.Document.GetByID(ctx, &dto.GetDocumentReq{Id: req.Id})
|
||||
docs, err := dao.Document.Get(ctx, &dto.GetDocumentReq{Id: req.Id})
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
err = gfdb.DB(ctx).Transaction(ctx, func(ctx context.Context, tx gdb.TX) (err error) {
|
||||
err = gfdb.DB(ctx, public.DbNameKnowledge).Transaction(ctx, func(ctx context.Context, tx gdb.TX) (err error) {
|
||||
datasetReq := &dto.UpdateDatasetReq{
|
||||
Id: docs.DatasetId,
|
||||
DocumentCount: -1,
|
||||
@@ -92,6 +119,18 @@ func (s *documentService) Delete(ctx context.Context, req *dto.DeleteDocumentReq
|
||||
return
|
||||
}
|
||||
|
||||
if _, err = dao.Keyword.Delete(ctx, &dto.DeleteKeywordReq{
|
||||
DocumentId: docs.Id,
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if _, err = dao.DocumentVector.Delete(ctx, &dto.DeleteDocumentVectorReq{
|
||||
DocumentId: docs.Id,
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if _, err = dao.Task.DeleteByTaskId(ctx, &dto.DeleteTaskByTaskIdReq{
|
||||
TaskId: docs.Id,
|
||||
}); err != nil {
|
||||
@@ -106,7 +145,7 @@ func (s *documentService) Delete(ctx context.Context, req *dto.DeleteDocumentReq
|
||||
|
||||
// Get 获取文件详情
|
||||
func (s *documentService) Get(ctx context.Context, req *dto.GetDocumentReq) (res *dto.GetDocumentRes, err error) {
|
||||
r, err := dao.Document.GetByID(ctx, req)
|
||||
r, err := dao.Document.Get(ctx, req)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
@@ -136,7 +175,7 @@ func (s *documentService) List(ctx context.Context, req *dto.ListDocumentReq) (r
|
||||
func (s *documentService) Vector(ctx context.Context, req *dto.DocumentVectorReq) (err error) {
|
||||
// 1. 查询文件信息
|
||||
documentReq := dto.GetDocumentReq{Id: req.Id}
|
||||
doc, err := dao.Document.GetByID(ctx, &documentReq)
|
||||
doc, err := dao.Document.Get(ctx, &documentReq)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -172,16 +211,13 @@ func (s *documentService) Vector(ctx context.Context, req *dto.DocumentVectorReq
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// ======================
|
||||
// 核心:grpool + g.Try 最佳实践
|
||||
// ======================
|
||||
// 使用带超时的background context,避免HTTP请求完成后context被取消
|
||||
taskCtx, cancel := context.WithTimeout(context.Background(), 30*time.Minute)
|
||||
taskCtx = context.WithValue(taskCtx, "user", user)
|
||||
// 任务1: SQL 切分文档
|
||||
// 任务1: 语义 切分文档
|
||||
grpool.Add(taskCtx, func(ctx context.Context) {
|
||||
g.TryCatch(ctx, func(ctx context.Context) {
|
||||
if innerErr := s.sqlSplitDocument(ctx, doc); innerErr != nil {
|
||||
if innerErr := s.semanticSplitDocument(ctx, doc); innerErr != nil {
|
||||
cancel()
|
||||
}
|
||||
}, func(ctx context.Context, err error) {
|
||||
@@ -189,10 +225,10 @@ func (s *documentService) Vector(ctx context.Context, req *dto.DocumentVectorReq
|
||||
})
|
||||
})
|
||||
|
||||
// 任务2: ES 切分文档
|
||||
// 任务2: 递归 切分文档
|
||||
grpool.Add(taskCtx, func(ctx context.Context) {
|
||||
g.TryCatch(ctx, func(ctx context.Context) {
|
||||
if innerErr := s.esSplitDocument(ctx, doc); innerErr != nil {
|
||||
if innerErr := s.recursiveSplitDocument(ctx, doc); innerErr != nil {
|
||||
cancel()
|
||||
}
|
||||
}, func(ctx context.Context, err error) {
|
||||
@@ -327,8 +363,8 @@ func (s *documentService) extractDocument(ctx context.Context, doc *entity.Docum
|
||||
return
|
||||
}
|
||||
|
||||
// sqlSplitDocument SQL切分(支持取消)
|
||||
func (s *documentService) sqlSplitDocument(ctx context.Context, doc *entity.Document) (err error) {
|
||||
// semanticSplitDocument 语义切分
|
||||
func (s *documentService) semanticSplitDocument(ctx context.Context, doc *entity.Document) (err error) {
|
||||
// ========== 取消检查 1:方法入口 ==========
|
||||
if ctx.Err() != nil {
|
||||
err = Task.WriteTaskProgress(ctx, &dto.WriteTaskProgressReq{
|
||||
@@ -354,7 +390,7 @@ func (s *documentService) sqlSplitDocument(ctx context.Context, doc *entity.Docu
|
||||
}
|
||||
|
||||
// 2. 语义切分文件
|
||||
docsSplit, err := eino.SemanticSplitDocument(ctx, docs)
|
||||
docsSplit, err := eino.SemanticSplitDocument(ctx, docs, model.ModelConfigTypeVectorDashScope.Code()) //TODO 后续替换成本地模型
|
||||
if err != nil {
|
||||
// 写入任务进度失败 任务类型为sql存储
|
||||
err = Task.WriteTaskProgress(ctx, &dto.WriteTaskProgressReq{
|
||||
@@ -394,8 +430,8 @@ func (s *documentService) sqlSplitDocument(ctx context.Context, doc *entity.Docu
|
||||
}
|
||||
|
||||
contentHash := gmd5.MustEncryptString(t.Content)
|
||||
var success bool
|
||||
success, err = s.checkRepeat(ctx, public.KnowledgeContentHashSqlKey, contentHash)
|
||||
var isNew, needCopy bool
|
||||
isNew, needCopy, err = s.checkRepeatWithDocId(ctx, public.KnowledgeContentHashSqlKey, contentHash, doc.Id)
|
||||
if err != nil {
|
||||
// 写入任务进度失败 任务类型为sql存储
|
||||
err = Task.WriteTaskProgress(ctx, &dto.WriteTaskProgressReq{
|
||||
@@ -406,7 +442,7 @@ func (s *documentService) sqlSplitDocument(ctx context.Context, doc *entity.Docu
|
||||
})
|
||||
return
|
||||
}
|
||||
if !success {
|
||||
if !isNew && !needCopy {
|
||||
continue
|
||||
}
|
||||
var metaData = make(map[string]any)
|
||||
@@ -415,7 +451,13 @@ func (s *documentService) sqlSplitDocument(ctx context.Context, doc *entity.Docu
|
||||
metaData[entity.DocumentCol.DatasetId] = doc.DatasetId
|
||||
metaData[entity.DocumentVectorCol.DocumentId] = doc.Id
|
||||
metaData[entity.DocumentVectorCol.ContentHash] = contentHash
|
||||
metaData[entity.DocumentVectorCol.ChunkIndex] = gconv.Int64(i)
|
||||
metaData[entity.DocumentVectorCol.ChunkIndex] = gconv.Int64(i + 1)
|
||||
if isNew {
|
||||
metaData["isNew"] = true
|
||||
}
|
||||
if needCopy {
|
||||
metaData["isNew"] = false
|
||||
}
|
||||
t.MetaData = metaData
|
||||
docsChunk = append(docsChunk, t)
|
||||
}
|
||||
@@ -468,8 +510,8 @@ func (s *documentService) sqlSplitDocument(ctx context.Context, doc *entity.Docu
|
||||
return
|
||||
}
|
||||
|
||||
// esSplitDocument ES切分(支持取消)
|
||||
func (s *documentService) esSplitDocument(ctx context.Context, doc *entity.Document) (err error) {
|
||||
// recursiveSplitDocument 递归切分
|
||||
func (s *documentService) recursiveSplitDocument(ctx context.Context, doc *entity.Document) (err error) {
|
||||
// ========== 取消检查 1:方法入口 ==========
|
||||
if ctx.Err() != nil {
|
||||
err = Task.WriteTaskProgress(ctx, &dto.WriteTaskProgressReq{
|
||||
@@ -535,8 +577,8 @@ func (s *documentService) esSplitDocument(ctx context.Context, doc *entity.Docum
|
||||
}
|
||||
|
||||
contentHash := gmd5.MustEncryptString(t.Content)
|
||||
var success bool
|
||||
success, err = s.checkRepeat(ctx, public.KnowledgeContentHashEsKey, contentHash)
|
||||
var isNew, needCopy bool
|
||||
isNew, needCopy, err = s.checkRepeatWithDocId(ctx, public.KnowledgeContentHashEsKey, contentHash, doc.Id)
|
||||
if err != nil {
|
||||
// 写入任务进度失败 任务类型为es存储
|
||||
err = Task.WriteTaskProgress(ctx, &dto.WriteTaskProgressReq{
|
||||
@@ -547,7 +589,7 @@ func (s *documentService) esSplitDocument(ctx context.Context, doc *entity.Docum
|
||||
})
|
||||
return
|
||||
}
|
||||
if !success {
|
||||
if !isNew && !needCopy {
|
||||
continue
|
||||
}
|
||||
meiliDocs = append(meiliDocs, map[string]interface{}{
|
||||
@@ -556,7 +598,7 @@ func (s *documentService) esSplitDocument(ctx context.Context, doc *entity.Docum
|
||||
entity.DocumentVectorCol.DocumentId: doc.Id,
|
||||
entity.DocumentVectorCol.Content: t.Content,
|
||||
entity.DocumentVectorCol.ContentHash: contentHash,
|
||||
entity.DocumentVectorCol.ChunkIndex: i,
|
||||
entity.DocumentVectorCol.ChunkIndex: i + 1,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -632,6 +674,7 @@ func (s *documentService) getHistoryData(ctx context.Context, doc *entity.Docume
|
||||
|
||||
// 3. Redis 无数据:根据 contentKey 类型选择查询方式
|
||||
var dictData = make([]*dto.DocumentVectorRPC, 0)
|
||||
|
||||
if public.KnowledgeContentHashSqlKey == contentKey {
|
||||
// SQL 方式:调用 HTTP 接口查询
|
||||
dictData, err = s.getHistoryDataFromHttp(ctx, doc)
|
||||
@@ -643,20 +686,16 @@ func (s *documentService) getHistoryData(ctx context.Context, doc *entity.Docume
|
||||
return err
|
||||
}
|
||||
|
||||
// 4. 把查询到的数据写入 Redis(600s过期)
|
||||
for _, item := range dictData {
|
||||
// 去除可能的 JSON 引号
|
||||
contentHash := strings.Trim(item.ContentHash, `"`)
|
||||
key := fmt.Sprintf(contentKey, contentHash)
|
||||
_, err = g.Redis().Set(ctx, key, true, gredis.SetOption{
|
||||
TTLOption: gredis.TTLOption{
|
||||
EX: gconv.PtrInt64(600),
|
||||
},
|
||||
NX: true,
|
||||
})
|
||||
// SAdd:把文档ID加入集合(自动去重,可存多个)
|
||||
_, err = g.Redis().SAdd(ctx, key, item.DocumentId)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// 设置过期时间
|
||||
_, _ = g.Redis().Expire(ctx, key, 600)
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -672,8 +711,10 @@ func (s *documentService) getHistoryDataFromHttp(ctx context.Context, doc *entit
|
||||
// 调用接口获取数据
|
||||
res, _, err := dao.DocumentVector.List(ctx, &dto.ListDocumentVectorReq{
|
||||
DatasetId: doc.DatasetId,
|
||||
Status: gconv.PtrInt8(1),
|
||||
})
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
err = gconv.Struct(res, &dictData)
|
||||
return
|
||||
}
|
||||
@@ -705,17 +746,39 @@ func (s *documentService) getHistoryDataFromMeilisearch(ctx context.Context, doc
|
||||
return
|
||||
}
|
||||
|
||||
// checkRepeat 检查是否重复
|
||||
func (s *documentService) checkRepeat(ctx context.Context, contentKey, contentHash string) (success bool, err error) {
|
||||
var val *gvar.Var
|
||||
if val, err = g.Redis().Set(ctx, fmt.Sprintf(contentKey, contentHash), true, gredis.SetOption{
|
||||
TTLOption: gredis.TTLOption{
|
||||
EX: gconv.PtrInt64(600),
|
||||
},
|
||||
NX: true,
|
||||
}); err != nil {
|
||||
return
|
||||
// checkRepeatWithDocId 正确版:检查当前文档是否已存在该分片
|
||||
// 返回:isNew(是否需要生成向量)、isCrossDoc(是否跨文档需拷贝)、err
|
||||
func (s *documentService) checkRepeatWithDocId(ctx context.Context, contentKey string, contentHash string, currentDocId int64) (isNew bool, needCopy bool, err error) {
|
||||
key := fmt.Sprintf(contentKey, contentHash)
|
||||
|
||||
// 1. 检查当前文档ID是否在集合中
|
||||
exists, err := g.Redis().SIsMember(ctx, key, currentDocId)
|
||||
if err != nil {
|
||||
return false, false, err
|
||||
}
|
||||
success = val.Bool()
|
||||
return
|
||||
|
||||
// 情况1:当前文档已存在 → 完全跳过,不生成、不拷贝
|
||||
if !g.IsEmpty(exists) {
|
||||
return false, false, nil
|
||||
}
|
||||
|
||||
// 2. 检查 key 是否存在(是否有任何文档拥有该分片)
|
||||
keyExists, err := g.Redis().Exists(ctx, key)
|
||||
if err != nil {
|
||||
return false, false, err
|
||||
}
|
||||
|
||||
// 情况2:key 不存在 = 全新数据 → 需要生成向量
|
||||
if g.IsEmpty(keyExists) {
|
||||
// 把当前文档ID加入集合
|
||||
_, err = g.Redis().SAdd(ctx, key, currentDocId)
|
||||
_, _ = g.Redis().Expire(ctx, key, 600)
|
||||
return true, false, err
|
||||
}
|
||||
|
||||
// 情况3:key 存在,但当前文档不在集合中 = 跨文档重复 → 不生成,需拷贝
|
||||
// 把当前文档ID加入集合(记录归属关系)
|
||||
_, err = g.Redis().SAdd(ctx, key, currentDocId)
|
||||
_, _ = g.Redis().Expire(ctx, key, 600)
|
||||
return false, true, err
|
||||
}
|
||||
|
||||
@@ -4,17 +4,18 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
"rag/common/eino"
|
||||
"rag/consts/model"
|
||||
"rag/consts/task"
|
||||
"rag/dao"
|
||||
"rag/model/dto"
|
||||
"rag/model/entity"
|
||||
|
||||
"gitea.com/red-future/common/beans"
|
||||
"github.com/cloudwego/eino/components/indexer"
|
||||
"github.com/cloudwego/eino/components/retriever"
|
||||
"github.com/cloudwego/eino/schema"
|
||||
"github.com/gogf/gf/v2/frame/g"
|
||||
"github.com/gogf/gf/v2/util/gconv"
|
||||
"github.com/pgvector/pgvector-go"
|
||||
)
|
||||
|
||||
var DocumentVector = new(documentVectorService)
|
||||
@@ -23,23 +24,32 @@ type documentVectorService struct{}
|
||||
|
||||
// Query 执行RAG查询
|
||||
func (s *documentVectorService) Query(ctx context.Context, req *dto.RAGQueryReq) (*dto.RAGQueryRes, error) {
|
||||
if req.TopK <= 0 {
|
||||
req.TopK = 5
|
||||
|
||||
modelInfo, err := dao.Model.Get(ctx, &dto.GetModelReq{
|
||||
ModelType: model.ModelTypeChat.Code(),
|
||||
})
|
||||
if err != nil {
|
||||
g.Log().Errorf(ctx, "获取模型失败: %v", err)
|
||||
return nil, fmt.Errorf("获取模型失败: %w", err)
|
||||
}
|
||||
if modelInfo == nil {
|
||||
g.Log().Errorf(ctx, "模型不存在: %v", model.ModelTypeChat.Code())
|
||||
return nil, fmt.Errorf("模型不存在: %w", err)
|
||||
}
|
||||
|
||||
// 4. 使用向量检索器进行查询
|
||||
r, err := eino.NewPGVectorRetriever(&eino.PGVectorRetrieverConfig{
|
||||
Embedder: eino.EmbedderDashscope,
|
||||
r, err := eino.NewPGVectorRetriever(ctx, &eino.PGVectorRetrieverConfig{
|
||||
DefaultTopK: req.TopK,
|
||||
})
|
||||
}, model.ModelConfigTypeVectorDashScope.Code()) //TODO 后续替换成本地模型
|
||||
if err != nil {
|
||||
g.Log().Errorf(ctx, "初始化向量检索器失败: %v", err)
|
||||
return nil, fmt.Errorf("初始化向量检索器失败: %w", err)
|
||||
}
|
||||
|
||||
// 5. 执行向量检索
|
||||
docs, err := r.Retrieve(ctx, req.Content, retriever.WithEmbedding(eino.EmbedderDashscope), retriever.WithDSLInfo(map[string]any{
|
||||
"dataset_ids": req.DatasetIds,
|
||||
docs, err := r.Retrieve(ctx, req.Content, retriever.WithDSLInfo(map[string]any{
|
||||
"dataset_ids": req.DatasetIds,
|
||||
"document_ids": req.DocumentIds,
|
||||
}))
|
||||
if err != nil {
|
||||
g.Log().Errorf(ctx, "向量检索失败: %v", err)
|
||||
@@ -53,7 +63,7 @@ func (s *documentVectorService) Query(ctx context.Context, req *dto.RAGQueryReq)
|
||||
return nil, fmt.Errorf("转换历史消息失败: %w", err)
|
||||
}
|
||||
|
||||
replyMsg, err := eino.NewChatModel(ctx, req.Content, docs, messages)
|
||||
replyMsg, err := eino.NewChatModel(ctx, req.Content, docs, messages, modelInfo.ConfigType)
|
||||
if err != nil {
|
||||
g.Log().Errorf(ctx, "向量检索失败: %v", err)
|
||||
return nil, fmt.Errorf("向量检索失败: %w", err)
|
||||
@@ -98,26 +108,108 @@ func (s *documentVectorService) DocsChunkMsg(ctx context.Context, msg any) (err
|
||||
TenantId: gconv.Uint64(docs[0].MetaData[entity.DocumentVectorCol.TenantId]),
|
||||
UserName: gconv.String(docs[0].MetaData[entity.DocumentVectorCol.Creator]),
|
||||
})
|
||||
idx := eino.NewPGVectorIndexer(&eino.PGVectorIndexerOptions{
|
||||
BatchSize: 10,
|
||||
})
|
||||
|
||||
documentId := gconv.Int64(docs[0].MetaData[entity.DocumentVectorCol.DocumentId])
|
||||
rows, err := idx.Store(ctx, docs, indexer.WithEmbedding(eino.EmbedderDashscope))
|
||||
if err != nil || rows == 0 {
|
||||
g.Log().Error(ctx, "DocsChunkMsg rows: , err:", rows, err)
|
||||
// 写入任务进度失败 任务类型为sql存储
|
||||
remark := " 向量存储数量: " + gconv.String(rows)
|
||||
if err != nil {
|
||||
remark = "向量存储失败: " + err.Error()
|
||||
|
||||
var docsStore = make([]*schema.Document, 0)
|
||||
var docsInsert = make([]*dto.VectorDocumentVectorMsg, 0)
|
||||
for _, doc := range docs {
|
||||
if gconv.Bool(doc.MetaData["isNew"]) {
|
||||
docsStore = append(docsStore, doc)
|
||||
} else {
|
||||
ck := new(dto.VectorDocumentVectorMsg)
|
||||
err = gconv.Struct(doc.MetaData, ck)
|
||||
ck.Content = doc.Content
|
||||
ck.VectorStatus = gconv.PtrInt8(1)
|
||||
ck.Status = gconv.PtrInt8(1)
|
||||
docsInsert = append(docsInsert, ck)
|
||||
}
|
||||
err = Task.WriteTaskProgress(ctx, &dto.WriteTaskProgressReq{
|
||||
TaskId: documentId,
|
||||
TaskType: task.TaskTypeGenerateVector,
|
||||
Status: task.TaskStatusFailed,
|
||||
Remark: remark,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
if !g.IsEmpty(docsStore) {
|
||||
idx := eino.NewPGVectorIndexer(&eino.PGVectorIndexerOptions{
|
||||
BatchSize: 10,
|
||||
})
|
||||
var rows int64
|
||||
rows, err = idx.Store(ctx, docsStore, model.ModelConfigTypeVectorDashScope.Code()) //TODO 后续替换成本地模型
|
||||
|
||||
if err != nil || rows == 0 {
|
||||
g.Log().Error(ctx, "DocsChunkMsg rows: , err:", rows, err)
|
||||
// 写入任务进度失败 任务类型为sql存储
|
||||
remark := " 向量存储数量: " + gconv.String(rows)
|
||||
if err != nil {
|
||||
remark = "向量存储失败: " + err.Error()
|
||||
}
|
||||
err = Task.WriteTaskProgress(ctx, &dto.WriteTaskProgressReq{
|
||||
TaskId: documentId,
|
||||
TaskType: task.TaskTypeGenerateVector,
|
||||
Status: task.TaskStatusFailed,
|
||||
Remark: remark,
|
||||
})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if !g.IsEmpty(docsInsert) {
|
||||
// 1. 提取所有 contentHash
|
||||
contentHashs := make([]string, 0, len(docsInsert))
|
||||
for _, d := range docsInsert {
|
||||
contentHashs = append(contentHashs, d.ContentHash)
|
||||
}
|
||||
|
||||
// 2. 分页查询已存在的向量(一页1000,避免大查询)
|
||||
var existVectors []*entity.DocumentVector
|
||||
for page := 1; ; page++ {
|
||||
res, total, err := dao.DocumentVector.List(ctx, &dto.ListDocumentVectorReq{
|
||||
Page: &beans.Page{PageSize: 1000, PageNum: int64(page)},
|
||||
ContentHashs: contentHashs,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if len(res) == 0 {
|
||||
break
|
||||
}
|
||||
existVectors = append(existVectors, res...)
|
||||
if len(existVectors) >= total {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// 3. 构建哈希 -> 向量 的映射表(O(1) 查找,性能提升巨大)
|
||||
vectorMap := make(map[string]pgvector.Vector, len(existVectors))
|
||||
for _, v := range existVectors {
|
||||
vectorMap[v.ContentHash] = v.Vector
|
||||
}
|
||||
|
||||
// 4. 回填向量 + 过滤掉数据库已存在的数据(避免重复插入)
|
||||
for _, d := range docsInsert {
|
||||
// 回填已有向量
|
||||
if vec, ok := vectorMap[d.ContentHash]; ok {
|
||||
d.Vector = vec
|
||||
}
|
||||
}
|
||||
|
||||
var rows int64
|
||||
rows, err = dao.DocumentVector.BatchInsert(ctx, docsInsert)
|
||||
|
||||
if err != nil || rows == 0 {
|
||||
g.Log().Error(ctx, "DocsChunkMsg rows: , err:", rows, err)
|
||||
// 写入任务进度失败 任务类型为sql存储
|
||||
remark := " 向量存储数量: " + gconv.String(rows)
|
||||
if err != nil {
|
||||
remark = "向量存储失败: " + err.Error()
|
||||
}
|
||||
err = Task.WriteTaskProgress(ctx, &dto.WriteTaskProgressReq{
|
||||
TaskId: documentId,
|
||||
TaskType: task.TaskTypeGenerateVector,
|
||||
Status: task.TaskStatusFailed,
|
||||
Remark: remark,
|
||||
})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// 写入任务进度成功 任务类型为sql存储
|
||||
err = Task.WriteTaskProgress(ctx, &dto.WriteTaskProgressReq{
|
||||
TaskId: documentId,
|
||||
|
||||
299
service/model.go
Normal file
299
service/model.go
Normal file
@@ -0,0 +1,299 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"rag/common/eino"
|
||||
"rag/consts/model"
|
||||
"rag/consts/task"
|
||||
"rag/dao"
|
||||
"rag/model/dto"
|
||||
"rag/model/entity"
|
||||
|
||||
"github.com/gogf/gf/v2/errors/gerror"
|
||||
"github.com/gogf/gf/v2/frame/g"
|
||||
"github.com/gogf/gf/v2/util/gconv"
|
||||
)
|
||||
|
||||
var ModelService = new(modelService)
|
||||
|
||||
type modelService struct{}
|
||||
|
||||
// GetModelAllEnums 获取模型全量枚举(模型类型 + 配置类型 合并)
|
||||
func (s *modelService) GetModelAllEnums(ctx context.Context, req *dto.GetModelAllEnumsReq) (res *dto.GetModelEnumRes, err error) {
|
||||
_, _ = ctx, req
|
||||
res = new(dto.GetModelEnumRes)
|
||||
|
||||
// 获取所有模型类型
|
||||
modelTypeRes := model.GetAllModelTypeEnums()
|
||||
|
||||
var options []dto.ModelEnumOption
|
||||
for _, mt := range modelTypeRes.Options {
|
||||
// 构造 modelType
|
||||
modelTypeStr := gconv.String(mt.Key)
|
||||
modelType := model.ModelType(gconv.PtrString(modelTypeStr))
|
||||
|
||||
// 获取对应配置类型
|
||||
configRes := model.GetAllModelConfigTypeEnums(modelType)
|
||||
|
||||
// 把 configRes.Options 转成目标类型
|
||||
var configList []dto.ModelKeyValue
|
||||
err = gconv.Structs(configRes.Options, &configList)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
options = append(options, dto.ModelEnumOption{
|
||||
Key: mt.Key,
|
||||
Value: mt.Value,
|
||||
ConfigTypes: configList,
|
||||
})
|
||||
}
|
||||
|
||||
res.Options = options
|
||||
return
|
||||
}
|
||||
|
||||
func (s *modelService) GetModelConfigFormFields(ctx context.Context, req *dto.GetModelConfigFormFieldsReq) (*dto.GetModelConfigFormFieldsRes, error) {
|
||||
_ = ctx
|
||||
|
||||
fields := make([]map[string]interface{}, 0)
|
||||
|
||||
// ===================== 固定基础字段(CreateModelReq 前4个)=====================
|
||||
// 1. 模型类型:固定只读字段
|
||||
fields = append(fields, map[string]interface{}{
|
||||
"name": "modelType",
|
||||
"label": "模型类型",
|
||||
"type": "text",
|
||||
"disabled": true,
|
||||
"required": true,
|
||||
"value": model.GetModelTypeDescByCode(req.ModelType),
|
||||
})
|
||||
|
||||
var configTypeValue = "未知类型"
|
||||
if *req.ModelType == *model.ModelTypeVector.Code() {
|
||||
configTypeValue = model.GetVectorDescByCode(req.ConfigType)
|
||||
} else if *req.ModelType == *model.ModelTypeChat.Code() {
|
||||
configTypeValue = model.GetChatDescByCode(req.ConfigType)
|
||||
}
|
||||
|
||||
// 2. 配置类型:固定只读字段
|
||||
fields = append(fields, map[string]interface{}{
|
||||
"name": "configType",
|
||||
"label": "配置类型",
|
||||
"type": "text",
|
||||
"disabled": true,
|
||||
"required": true,
|
||||
"value": configTypeValue,
|
||||
})
|
||||
|
||||
// 3. 基础信息
|
||||
fields = append(fields, []map[string]interface{}{
|
||||
{
|
||||
"name": "modelName",
|
||||
"label": "模型名称",
|
||||
"type": "input",
|
||||
"required": true,
|
||||
"placeholder": "例如:DeepSeek 对话模型",
|
||||
},
|
||||
{
|
||||
"name": "modelDesc",
|
||||
"label": "模型描述",
|
||||
"type": "textarea",
|
||||
"required": false,
|
||||
},
|
||||
}...)
|
||||
|
||||
// 4. 通用模型名称字段
|
||||
fields = append(fields, map[string]interface{}{
|
||||
"name": "model",
|
||||
"label": "模型类型",
|
||||
"type": "input",
|
||||
"required": true,
|
||||
"placeholder": "例如:deepseek-chat / text-embedding-3-small",
|
||||
})
|
||||
|
||||
// ===================== 动态配置内容 ConfigContent =====================
|
||||
|
||||
// 根据模型类型 + 配置类型生成动态字段
|
||||
switch *req.ModelType {
|
||||
case *model.ModelTypeChat.Code():
|
||||
switch *req.ConfigType {
|
||||
case *model.ModelConfigTypeChatArk.Code():
|
||||
fields = append(fields, map[string]interface{}{"name": "api_key", "label": "API Key", "type": "input", "required": true})
|
||||
|
||||
case *model.ModelConfigTypeChatArkBot.Code():
|
||||
fields = append(fields, map[string]interface{}{"name": "api_key", "label": "API Key", "type": "input", "required": true})
|
||||
|
||||
case *model.ModelConfigTypeChatClaude.Code():
|
||||
fields = append(fields, []map[string]interface{}{
|
||||
{"name": "by_bedrock", "label": "使用 AWS Bedrock", "type": "switch", "default": true},
|
||||
{"name": "access_key", "label": "Access Key", "type": "input"},
|
||||
{"name": "secret_access_key", "label": "Secret Access Key", "type": "input"},
|
||||
{"name": "region", "label": "Region", "type": "input"},
|
||||
{"name": "api_key", "label": "API Key", "type": "input"},
|
||||
{"name": "base_url", "label": "Base URL", "type": "input"},
|
||||
}...)
|
||||
|
||||
case *model.ModelConfigTypeChatDeepSeek.Code():
|
||||
fields = append(fields, []map[string]interface{}{
|
||||
{"name": "api_key", "label": "API Key", "type": "input", "required": true},
|
||||
{"name": "base_url", "label": "Base URL", "type": "input", "default": "https://api.deepseek.com"},
|
||||
}...)
|
||||
|
||||
case *model.ModelConfigTypeChatOllama.Code():
|
||||
fields = append(fields, map[string]interface{}{"name": "base_url", "label": "Base URL", "type": "input", "required": true, "default": "http://127.0.0.1:11434"})
|
||||
|
||||
case *model.ModelConfigTypeChatOpenAI.Code():
|
||||
fields = append(fields, []map[string]interface{}{
|
||||
{"name": "api_key", "label": "API Key", "type": "input", "required": true},
|
||||
{"name": "by_azure", "label": "使用 Azure", "type": "switch", "default": true},
|
||||
{"name": "base_url", "label": "Base URL", "type": "input"},
|
||||
{"name": "api_version", "label": "API Version", "type": "input"},
|
||||
}...)
|
||||
|
||||
case *model.ModelConfigTypeChatQianfan.Code():
|
||||
fields = append(fields, []map[string]interface{}{
|
||||
{"name": "access_key", "label": "Access Key", "type": "input", "required": true},
|
||||
{"name": "secret_key", "label": "Secret Key", "type": "input", "required": true},
|
||||
}...)
|
||||
|
||||
case *model.ModelConfigTypeChatQwen.Code():
|
||||
fields = append(fields, []map[string]interface{}{
|
||||
{"name": "api_key", "label": "API Key", "type": "input", "required": true},
|
||||
{"name": "base_url", "label": "Base URL", "type": "input"},
|
||||
}...)
|
||||
}
|
||||
|
||||
case *model.ModelTypeVector.Code():
|
||||
switch *req.ConfigType {
|
||||
case *model.ModelConfigTypeVectorArk.Code():
|
||||
fields = append(fields, []map[string]interface{}{
|
||||
{"name": "api_key", "label": "API Key", "type": "input", "required": true},
|
||||
{"name": "api_type", "label": "API Type", "type": "input"},
|
||||
}...)
|
||||
|
||||
case *model.ModelConfigTypeVectorOllama.Code():
|
||||
fields = append(fields, map[string]interface{}{"name": "base_url", "label": "Base URL", "type": "input", "required": true, "default": "http://127.0.0.1:11434"})
|
||||
|
||||
case *model.ModelConfigTypeVectorOpenAI.Code():
|
||||
fields = append(fields, []map[string]interface{}{
|
||||
{"name": "api_key", "label": "API Key", "type": "input", "required": true},
|
||||
{"name": "by_azure", "label": "使用 Azure", "type": "switch", "default": true},
|
||||
{"name": "base_url", "label": "Base URL", "type": "input"},
|
||||
{"name": "api_version", "label": "API Version", "type": "input"},
|
||||
}...)
|
||||
|
||||
case *model.ModelConfigTypeVectorQianfan.Code():
|
||||
fields = append(fields, []map[string]interface{}{
|
||||
{"name": "access_key", "label": "Access Key", "type": "input", "required": true},
|
||||
{"name": "secret_key", "label": "Secret Key", "type": "input", "required": true},
|
||||
}...)
|
||||
|
||||
case *model.ModelConfigTypeVectorTencentCloud.Code():
|
||||
fields = append(fields, []map[string]interface{}{
|
||||
{"name": "secret_id", "label": "Secret ID", "type": "input", "required": true},
|
||||
{"name": "secret_key", "label": "Secret Key", "type": "input", "required": true},
|
||||
{"name": "region", "label": "Region", "type": "input", "required": true, "default": "ap-beijing"},
|
||||
}...)
|
||||
case *model.ModelConfigTypeVectorDashScope.Code():
|
||||
fields = append(fields, map[string]interface{}{"name": "api_key", "label": "API Key", "type": "input", "required": true})
|
||||
}
|
||||
}
|
||||
|
||||
return &dto.GetModelConfigFormFieldsRes{
|
||||
ModelType: req.ModelType,
|
||||
ConfigType: req.ConfigType,
|
||||
Fields: fields,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *modelService) Create(ctx context.Context, req *dto.CreateModelReq) (res *dto.CreateModelRes, err error) {
|
||||
count, err := dao.Model.Count(ctx, &dto.GetModelReq{
|
||||
ModelType: req.ModelType,
|
||||
})
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if count > 0 {
|
||||
err = gerror.New("模型配置已存在")
|
||||
return
|
||||
}
|
||||
var id int64
|
||||
id, err = dao.Model.Insert(ctx, req)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
res = &dto.CreateModelRes{Id: id}
|
||||
err = s.refresh(ctx, id)
|
||||
return
|
||||
}
|
||||
|
||||
func (s *modelService) Update(ctx context.Context, req *dto.UpdateModelReq) (err error) {
|
||||
count, err := dao.Task.Count(ctx, &dto.GetTaskReq{
|
||||
TaskStatus: task.TaskStatusRunning,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !g.IsEmpty(count) {
|
||||
err = gerror.New("任务正在执行中,模型配置暂时不可修改,请稍后再试")
|
||||
return
|
||||
}
|
||||
var updateCount int64
|
||||
updateCount, err = dao.Model.Update(ctx, req)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if !g.IsEmpty(updateCount) {
|
||||
err = s.refresh(ctx, req.Id)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (s *modelService) refresh(ctx context.Context, id int64) (err error) {
|
||||
var modelDO *entity.Model
|
||||
modelDO, err = dao.Model.Get(ctx, &dto.GetModelReq{
|
||||
Id: id,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if *modelDO.ModelType == *model.ModelTypeChat.Code() {
|
||||
if err = eino.RefreshTenantChatModel(ctx, modelDO); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if *modelDO.ModelType == *model.ModelTypeVector.Code() {
|
||||
if err = eino.RefreshTenantEmbedder(ctx, modelDO); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (s *modelService) Delete(ctx context.Context, req *dto.DeleteModelReq) (err error) {
|
||||
_, err = dao.Model.Delete(ctx, req)
|
||||
return
|
||||
}
|
||||
|
||||
func (s *modelService) Get(ctx context.Context, req *dto.GetModelReq) (res *dto.ModelVO, err error) {
|
||||
r, err := dao.Model.Get(ctx, req)
|
||||
err = gconv.Struct(r, &res)
|
||||
return
|
||||
}
|
||||
|
||||
func (s *modelService) List(ctx context.Context, req *dto.ListModelReq) (res *dto.ListModelRes, err error) {
|
||||
list, total, err := dao.Model.List(ctx, req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
res = &dto.ListModelRes{
|
||||
Total: total,
|
||||
}
|
||||
err = gconv.Struct(list, &res.List)
|
||||
return
|
||||
}
|
||||
@@ -2,6 +2,8 @@ package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"rag/consts/document"
|
||||
"rag/consts/public"
|
||||
"rag/consts/task"
|
||||
"rag/dao"
|
||||
"rag/model/dto"
|
||||
@@ -37,7 +39,7 @@ func (s *taskService) WriteTaskProgress(ctx context.Context, req *dto.WriteTaskP
|
||||
TaskType: req.TaskType,
|
||||
Status: req.Status,
|
||||
})
|
||||
completed = IsAllSubTasksCompleted(taskVO)
|
||||
completed = IsAllSubTasks(taskVO, task.TaskStatusCompleted)
|
||||
}
|
||||
|
||||
// 1. 查询是否已存在该文档的该类型任务
|
||||
@@ -49,7 +51,7 @@ func (s *taskService) WriteTaskProgress(ctx context.Context, req *dto.WriteTaskP
|
||||
g.Log().Errorf(ctx, "查询任务失败: %v", err)
|
||||
return err
|
||||
}
|
||||
err = gfdb.DB(ctx).Transaction(ctx, func(ctx context.Context, tx gdb.TX) error {
|
||||
err = gfdb.DB(ctx, public.DbNameKnowledge).Transaction(ctx, func(ctx context.Context, tx gdb.TX) error {
|
||||
// 2. 如果不存在,则创建新任务
|
||||
if g.IsEmpty(existTask) {
|
||||
createReq := &dto.CreateTaskReq{
|
||||
@@ -80,17 +82,36 @@ func (s *taskService) WriteTaskProgress(ctx context.Context, req *dto.WriteTaskP
|
||||
Status: task.TaskStatusCompleted,
|
||||
Remark: "文档解析完成",
|
||||
})
|
||||
if err != nil {
|
||||
g.Log().Errorf(ctx, "更新任务失败: %v", err)
|
||||
return err
|
||||
}
|
||||
_, err = dao.Document.Update(ctx, &dto.UpdateDocumentReq{
|
||||
Id: req.TaskId,
|
||||
VectorStatus: document.VectorStatusCompleted.Code(),
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
if task.TaskStatusFailed == req.Status {
|
||||
_, err = dao.Document.Update(ctx, &dto.UpdateDocumentReq{
|
||||
Id: req.TaskId,
|
||||
VectorStatus: document.VectorStatusFailed.Code(),
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// IsAllSubTasksCompleted 判断三个子任务是否全部完成
|
||||
// 参数:传入当前文档的所有子任务列表
|
||||
func IsAllSubTasksCompleted(subTasks []*dto.TaskVO) bool {
|
||||
// IsAllSubTasks 判断三个子任务
|
||||
func IsAllSubTasks(subTasks []*dto.TaskVO, taskStatus task.TaskStatus) bool {
|
||||
// 必须包含 3 种任务类型
|
||||
hasKeywords := false
|
||||
hasVector := false
|
||||
@@ -98,7 +119,7 @@ func IsAllSubTasksCompleted(subTasks []*dto.TaskVO) bool {
|
||||
|
||||
for _, t := range subTasks {
|
||||
// 子任务必须是【已完成】状态才计数
|
||||
if t.Status == task.TaskStatusCompleted {
|
||||
if t.Status == taskStatus {
|
||||
switch t.TaskType {
|
||||
case task.TaskTypeExtractKeywords:
|
||||
hasKeywords = true
|
||||
@@ -113,3 +134,15 @@ func IsAllSubTasksCompleted(subTasks []*dto.TaskVO) bool {
|
||||
// 三个任务全部完成 → 返回true
|
||||
return hasKeywords && hasVector && hasFullText
|
||||
}
|
||||
|
||||
func (s *taskService) Get(ctx context.Context, req *dto.GetTaskReq) (res *dto.ListTaskRes, err error) {
|
||||
list, total, err := dao.Task.Get(ctx, req)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
res = &dto.ListTaskRes{
|
||||
Total: total,
|
||||
}
|
||||
err = gconv.Struct(list, &res.List)
|
||||
return
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user