feat: rag初始版
This commit is contained in:
87
service/dataset.go
Normal file
87
service/dataset.go
Normal file
@@ -0,0 +1,87 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"rag/dao"
|
||||
"rag/model/dto"
|
||||
|
||||
"github.com/gogf/gf/v2/util/gconv"
|
||||
)
|
||||
|
||||
// datasetService groups the dataset operations exposed by this package.
// It carries no state.
type datasetService struct{}

// Dataset is the shared, package-level datasetService instance.
var Dataset = &datasetService{}
|
||||
|
||||
// Create 创建数据集
|
||||
func (s *datasetService) Create(ctx context.Context, req *dto.CreateDatasetReq) (res *dto.CreateDatasetRes, err error) {
|
||||
id, err := dao.Dataset.Insert(ctx, req)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
return &dto.CreateDatasetRes{Id: id}, nil
|
||||
}
|
||||
|
||||
// Update 更新数据集
|
||||
func (s *datasetService) Update(ctx context.Context, req *dto.UpdateDatasetReq) (err error) {
|
||||
_, err = dao.Dataset.Update(ctx, req)
|
||||
return
|
||||
}
|
||||
|
||||
// Delete 删除数据集
|
||||
func (s *datasetService) Delete(ctx context.Context, req *dto.DeleteDatasetReq) (err error) {
|
||||
_, err = dao.Dataset.Delete(ctx, req)
|
||||
return
|
||||
}
|
||||
|
||||
// List 数据集列表
|
||||
func (s *datasetService) List(ctx context.Context, req *dto.ListDatasetReq) (res *dto.ListDatasetRes, err error) {
|
||||
list, total, err := dao.Dataset.List(ctx, req)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
res = &dto.ListDatasetRes{
|
||||
Total: total,
|
||||
}
|
||||
err = gconv.Struct(list, &res.List)
|
||||
return
|
||||
}
|
||||
|
||||
//// Search 搜索(示例,实际需要调用向量库)
|
||||
//func (s *datasetService) Search(ctx context.Context, req *dto.SearchReq) (res *dto.SearchRes, err error) {
|
||||
// // 1. 获取数据集信息
|
||||
// kb, err := dao.Dataset.GetByID(ctx, req)
|
||||
// if err != nil {
|
||||
// return nil, err
|
||||
// }
|
||||
//
|
||||
// // 2. 获取文件块
|
||||
// chunks, err := dao.Chunk.FindChunksByKBIDWithLimit(ctx, req.KBID, 0, req.TopK)
|
||||
// if err != nil {
|
||||
// return nil, err
|
||||
// }
|
||||
//
|
||||
// // 3. TODO: 使用向量检索(需要集成向量库)
|
||||
// // 暂时使用简单的关键词匹配
|
||||
// results := make([]dto.SearchResult, 0)
|
||||
// for _, chunk := range chunks {
|
||||
// results = append(results, dto.SearchResult{
|
||||
// Content: chunk.Content,
|
||||
// Score: 0.8, // TODO: 计算实际向量相似度
|
||||
// DocumentID: chunk.DocumentID,
|
||||
// ChunkIndex: chunk.Index,
|
||||
// })
|
||||
// }
|
||||
//
|
||||
// g.Log().Infof(ctx, "数据集[%s]搜索完成,查询:%s,结果数:%d", kb.Name, req.Query, len(results))
|
||||
//
|
||||
// return &dto.SearchRes{Results: results}, nil
|
||||
//}
|
||||
//
|
||||
//// formatChunks 格式化文件块为上下文
|
||||
//func (s *datasetService) formatChunks(chunks []*entity.DocumentChunk) string {
|
||||
// var sb strings.Builder
|
||||
// for i, chunk := range chunks {
|
||||
// sb.WriteString(fmt.Sprintf("[%d] %s\n\n", i+1, chunk.Content))
|
||||
// }
|
||||
// return sb.String()
|
||||
//}
|
||||
5
service/dataset_index.go
Normal file
5
service/dataset_index.go
Normal file
@@ -0,0 +1,5 @@
|
||||
package service
|
||||
|
||||
// datasetIndexService is the (currently method-less) service type for dataset
// indexes. It carries no state.
type datasetIndexService struct{}

// DatasetIndex is the shared, package-level datasetIndexService instance.
var DatasetIndex = &datasetIndexService{}
|
||||
483
service/document.go
Normal file
483
service/document.go
Normal file
@@ -0,0 +1,483 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"rag/consts/document"
|
||||
"rag/consts/public"
|
||||
"rag/dao"
|
||||
"rag/model/dto"
|
||||
"rag/model/entity"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"gitea.com/red-future/common/beans"
|
||||
"gitea.com/red-future/common/db/gfdb"
|
||||
"gitea.com/red-future/common/full-text-search/meilisearch"
|
||||
"gitea.com/red-future/common/http"
|
||||
"gitea.com/red-future/common/rag/eino"
|
||||
"gitea.com/red-future/common/rag/gse"
|
||||
"gitea.com/red-future/common/utils"
|
||||
gmq "github.com/bjang03/gmq/core/gmq"
|
||||
"github.com/bjang03/gmq/mq"
|
||||
"github.com/bjang03/gmq/types"
|
||||
"github.com/cloudwego/eino/schema"
|
||||
"github.com/gogf/gf/v2/container/gvar"
|
||||
"github.com/gogf/gf/v2/crypto/gmd5"
|
||||
"github.com/gogf/gf/v2/database/gdb"
|
||||
"github.com/gogf/gf/v2/database/gredis"
|
||||
"github.com/gogf/gf/v2/frame/g"
|
||||
"github.com/gogf/gf/v2/util/gconv"
|
||||
)
|
||||
|
||||
// documentService groups the document operations exposed by this package.
// It carries no state.
type documentService struct{}

// Document is the shared, package-level documentService instance.
var Document = &documentService{}
|
||||
|
||||
// Create 创建文件
|
||||
func (s *documentService) Create(ctx context.Context, req *dto.CreateDocumentReq) (res *dto.CreateDocumentRes, err error) {
|
||||
err = gfdb.DB(ctx).Transaction(ctx, func(ctx context.Context, tx gdb.TX) (err error) {
|
||||
var id int64
|
||||
id, err = dao.Document.Insert(ctx, req)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
datasetReq := &dto.UpdateDatasetReq{
|
||||
Id: req.DatasetId,
|
||||
DocumentCount: 1,
|
||||
DocumentSize: req.FileSize,
|
||||
}
|
||||
_, err = dao.Dataset.Update(ctx, datasetReq)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
res = &dto.CreateDocumentRes{Id: id}
|
||||
|
||||
return
|
||||
})
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// Update 更新文件
|
||||
func (s *documentService) Update(ctx context.Context, req *dto.UpdateDocumentReq) (err error) {
|
||||
_, err = dao.Document.Update(ctx, req)
|
||||
return
|
||||
}
|
||||
|
||||
// Delete 删除文件
|
||||
func (s *documentService) Delete(ctx context.Context, req *dto.DeleteDocumentReq) (err error) {
|
||||
docs, err := dao.Document.GetByID(ctx, &dto.GetDocumentReq{Id: req.Id})
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
err = gfdb.DB(ctx).Transaction(ctx, func(ctx context.Context, tx gdb.TX) (err error) {
|
||||
datasetReq := &dto.UpdateDatasetReq{
|
||||
Id: docs.DatasetId,
|
||||
DocumentCount: -1,
|
||||
DocumentSize: -docs.FileSize,
|
||||
}
|
||||
_, err = dao.Dataset.Update(ctx, datasetReq)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
_, err = dao.Document.Delete(ctx, req)
|
||||
return
|
||||
})
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// Get 获取文件详情
|
||||
func (s *documentService) Get(ctx context.Context, req *dto.GetDocumentReq) (res *dto.DocumentVO, err error) {
|
||||
r, err := dao.Document.GetByID(ctx, req)
|
||||
err = gconv.Struct(r, &res)
|
||||
return
|
||||
}
|
||||
|
||||
// List 文件列表
|
||||
func (s *documentService) List(ctx context.Context, req *dto.ListDocumentReq) (res *dto.ListDocumentRes, err error) {
|
||||
list, total, err := dao.Document.List(ctx, req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
res = &dto.ListDocumentRes{
|
||||
Total: total,
|
||||
}
|
||||
err = gconv.Struct(list, &res.List)
|
||||
|
||||
//eino.TestIndexer()
|
||||
//eino.TestRetriever()
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// Process 处理文件(使用eino框架切分和向量化)
|
||||
func (s *documentService) Process(ctx context.Context, req *dto.ProcessDocumentReq) (res *dto.ProcessDocumentRes, err error) {
|
||||
startTime := time.Now()
|
||||
|
||||
// 1. 查询文件信息
|
||||
documentReq := dto.GetDocumentReq{Id: req.Id}
|
||||
doc, err := dao.Document.GetByID(ctx, &documentReq)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 2. 使用eino框架进行文件切分(并发执行)
|
||||
var vectorDocsCount, chunks int64
|
||||
// 用 gopool 或者简单的错误等待,绝对不用裸 goroutine
|
||||
var err1, err2, err3 error
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(3)
|
||||
|
||||
// 任务1
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
vectorDocsCount, chunks, err1 = s.sqlSplitDocument(ctx, doc)
|
||||
}()
|
||||
|
||||
// 任务2
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
err2 = s.esSplitDocument(ctx, doc)
|
||||
}()
|
||||
|
||||
// 任务3
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
err3 = s.extractDocument(ctx, doc)
|
||||
}()
|
||||
|
||||
// 直接等待,不使用通道,避免泄漏
|
||||
wg.Wait()
|
||||
|
||||
updateDocumentReq := new(dto.UpdateDocumentReq)
|
||||
updateDocumentReq.Id = req.Id
|
||||
|
||||
// 统一判断错误
|
||||
if err1 != nil || err2 != nil || err3 != nil {
|
||||
// 更新文档状态
|
||||
updateDocumentReq.VectorStatus = document.VectorStatusFailed.Code()
|
||||
if _, err = dao.Document.Update(ctx, updateDocumentReq); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err1 != nil {
|
||||
return nil, err1
|
||||
}
|
||||
if err2 != nil {
|
||||
return nil, err2
|
||||
}
|
||||
return nil, err3
|
||||
}
|
||||
|
||||
// 4. 更新文件状态为处理中和切分数量
|
||||
if vectorDocsCount > 0 {
|
||||
updateDocumentReq.VectorStatus = document.VectorStatusProcessing.Code()
|
||||
} else {
|
||||
updateDocumentReq.VectorStatus = document.VectorStatusCompleted.Code()
|
||||
}
|
||||
updateDocumentReq.ChunkCount = chunks
|
||||
if _, err = dao.Document.Update(ctx, updateDocumentReq); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
costTime := time.Since(startTime).Milliseconds()
|
||||
|
||||
return &dto.ProcessDocumentRes{
|
||||
ChunkCount: chunks,
|
||||
CostTime: costTime,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *documentService) extractDocument(ctx context.Context, doc *entity.Document) (err error) {
|
||||
// 1. 加载文件
|
||||
docs, err := s.loadDocument(ctx, doc)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
var words []gse.Keyword
|
||||
if len(docs[0].Content) < 500 {
|
||||
words = gse.GseTool.Extract(docs[0].Content, 4)
|
||||
} else if len(docs[0].Content) < 2000 {
|
||||
words = gse.GseTool.Extract(docs[0].Content, 8)
|
||||
} else if len(docs[0].Content) < 5000 {
|
||||
words = gse.GseTool.Extract(docs[0].Content, 13)
|
||||
} else {
|
||||
var docsSplit []*schema.Document
|
||||
docsSplit, err = eino.RecursiveSplitDocument(ctx, docs)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
for _, t := range docsSplit {
|
||||
words = append(words, gse.GseTool.Extract(t.Content, 6)...)
|
||||
}
|
||||
}
|
||||
|
||||
var keywordReqs = make([]*dto.CreateKeywordReq, 0)
|
||||
for _, word := range words {
|
||||
keywordReqs = append(keywordReqs, &dto.CreateKeywordReq{
|
||||
DatasetId: doc.DatasetId,
|
||||
DocumentId: doc.Id,
|
||||
Word: word.Word,
|
||||
Weight: gconv.Int16(word.Score),
|
||||
})
|
||||
}
|
||||
if len(keywordReqs) > 0 {
|
||||
_, err = dao.Keyword.BatchSaveOrUpdate(ctx, keywordReqs)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (s *documentService) sqlSplitDocument(ctx context.Context, doc *entity.Document) (vectorDocsCount, docsSplitCount int64, err error) {
|
||||
// 1. 加载文件
|
||||
docs, err := s.loadDocument(ctx, doc)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
// 2. 语义切分文件
|
||||
docsSplit, err := eino.SemanticSplitDocument(ctx, docs)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
docsSplitCount = gconv.Int64(len(docsSplit))
|
||||
// 2. 获取历史数据
|
||||
err = s.getHistoryData(ctx, doc, public.KnowledgeLockSqlKey, public.KnowledgeContentHashSqlKey)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
// 3. 组装向量文档
|
||||
var vectorDocs = make([]dto.VectorDocumentChunkMsg, 0)
|
||||
for i, t := range docsSplit {
|
||||
contentHash := gmd5.MustEncryptString(t.Content)
|
||||
// 检查是否重复
|
||||
var success bool
|
||||
success, err = s.checkRepeat(ctx, public.KnowledgeContentHashSqlKey, contentHash)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if !success {
|
||||
continue
|
||||
}
|
||||
vectorDocs = append(vectorDocs, dto.VectorDocumentChunkMsg{
|
||||
TenantId: doc.TenantId,
|
||||
Creator: doc.Creator,
|
||||
DatasetId: doc.DatasetId,
|
||||
DocumentId: doc.Id,
|
||||
Content: t.Content,
|
||||
ContentHash: contentHash,
|
||||
ChunkIndex: gconv.Int64(i),
|
||||
})
|
||||
|
||||
}
|
||||
// 4. 发送消息到队列
|
||||
if len(vectorDocs) > 0 {
|
||||
err = gmq.GetGmq("primary").GmqPublish(ctx, &mq.RedisPubMessage{
|
||||
PubMessage: types.PubMessage{
|
||||
Topic: public.KnowledgeDocumentChunkTopic,
|
||||
Data: vectorDocs,
|
||||
},
|
||||
})
|
||||
}
|
||||
vectorDocsCount = gconv.Int64(len(vectorDocs))
|
||||
return
|
||||
}
|
||||
|
||||
func (s *documentService) esSplitDocument(ctx context.Context, doc *entity.Document) (err error) {
|
||||
// 1. 加载文件
|
||||
docs, err := s.loadDocument(ctx, doc)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
// 2. 递归切分文件
|
||||
docsSplit, err := eino.RecursiveSplitDocument(ctx, docs)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
// 2. 获取历史数据
|
||||
err = s.getHistoryData(ctx, doc, public.KnowledgeLockEsKey, public.KnowledgeContentHashEsKey)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
// 3. 组装向量文档并同时构建meilisearch文档
|
||||
var meiliDocs = make([]interface{}, 0)
|
||||
for i, t := range docsSplit {
|
||||
contentHash := gmd5.MustEncryptString(t.Content)
|
||||
// 检查是否重复
|
||||
var success bool
|
||||
success, err = s.checkRepeat(ctx, public.KnowledgeContentHashEsKey, contentHash)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if !success {
|
||||
continue
|
||||
}
|
||||
// 构建Meilisearch文档
|
||||
meiliDocs = append(meiliDocs, map[string]interface{}{
|
||||
"id": contentHash,
|
||||
"datasetId": doc.DatasetId,
|
||||
"documentId": doc.Id,
|
||||
"content": t.Content,
|
||||
"contentHash": contentHash,
|
||||
"chunkIndex": i,
|
||||
})
|
||||
}
|
||||
// 4. 写入到meilisearch数据库中
|
||||
if len(meiliDocs) > 0 {
|
||||
if _, err = meilisearch.DB().InsertMany(ctx, meiliDocs, public.IndexNameDocumentChunk); err != nil {
|
||||
g.Log().Errorf(ctx, "写入meilisearch失败: %v", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// loadDocument 加载文件
|
||||
func (s *documentService) loadDocument(ctx context.Context, doc *entity.Document) (docs []*schema.Document, err error) {
|
||||
return eino.LoadDocument(ctx, doc.FilePath, doc.Format)
|
||||
}
|
||||
|
||||
// getHistoryData 获取历史数据
|
||||
func (s *documentService) getHistoryData(ctx context.Context, doc *entity.Document, lockKey, contentKey string) (err error) {
|
||||
docsLockKey := fmt.Sprintf(lockKey, doc.DatasetId)
|
||||
success, err := utils.Lock(ctx, docsLockKey, int64(60), func(ctx context.Context) error {
|
||||
// 1. 扫描 Redis 中所有 前缀为 rag:knowledge:xxx:contentHash 的 key
|
||||
pattern := fmt.Sprintf(contentKey, "*")
|
||||
keys, err := g.Redis().Keys(ctx, pattern)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 2. Redis 有数据:只刷新过期时间,不查库
|
||||
if len(keys) > 0 {
|
||||
// 批量刷新过期时间为 60s
|
||||
for _, key := range keys {
|
||||
_, err = g.Redis().Expire(ctx, key, 600)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// 3. Redis 无数据:根据 contentKey 类型选择查询方式
|
||||
var dictData = make([]*dto.DocumentChunkRPC, 0)
|
||||
if public.KnowledgeContentHashSqlKey == contentKey {
|
||||
// SQL 方式:调用 HTTP 接口查询
|
||||
dictData, err = s.getHistoryDataFromHttp(ctx, doc)
|
||||
} else {
|
||||
// ES 方式:查询 meilisearch
|
||||
dictData, err = s.getHistoryDataFromMeilisearch(ctx, doc)
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 4. 把查询到的数据写入 Redis(600s过期)
|
||||
for _, item := range dictData {
|
||||
// 去除可能的 JSON 引号
|
||||
contentHash := strings.Trim(item.ContentHash, `"`)
|
||||
key := fmt.Sprintf(contentKey, contentHash)
|
||||
_, err = g.Redis().Set(ctx, key, true, gredis.SetOption{
|
||||
TTLOption: gredis.TTLOption{
|
||||
EX: gconv.PtrInt64(600),
|
||||
},
|
||||
NX: true,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
if err != nil && !success {
|
||||
return
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// getHistoryDataFromHttp 通过 HTTP 接口查询历史数据
|
||||
func (s *documentService) getHistoryDataFromHttp(ctx context.Context, doc *entity.Document) (dictData []*dto.DocumentChunkRPC, err error) {
|
||||
headers := make(map[string]string)
|
||||
if r := g.RequestFromCtx(ctx); r != nil {
|
||||
for k, v := range r.Request.Header {
|
||||
if len(v) > 0 {
|
||||
headers[k] = v[0]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 调用接口获取数据
|
||||
d := &dto.ListDocumentChunkRPC{}
|
||||
if err = http.Get(ctx, "rag-vector/document/chunk/listDocumentChunk", headers, &d,
|
||||
"datasetId", gconv.String(doc.DatasetId),
|
||||
"status", 1); err != nil {
|
||||
return
|
||||
}
|
||||
dictData = d.List
|
||||
return
|
||||
}
|
||||
|
||||
// getHistoryDataFromMeilisearch 通过 meilisearch 查询历史数据
|
||||
func (s *documentService) getHistoryDataFromMeilisearch(ctx context.Context, doc *entity.Document) (dictData []*dto.DocumentChunkRPC, err error) {
|
||||
// 构建 meilisearch 查询参数
|
||||
searchParams := &meilisearch.SearchParams{
|
||||
Filter: fmt.Sprintf("datasetId = %d", doc.DatasetId),
|
||||
Limit: 10000,
|
||||
}
|
||||
|
||||
// 执行搜索
|
||||
var hits []map[string]interface{}
|
||||
_, err = meilisearch.DB().Search(ctx, searchParams, public.IndexNameDocumentChunk, &hits)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// 转换查询结果
|
||||
dictData = make([]*dto.DocumentChunkRPC, 0)
|
||||
for _, hit := range hits {
|
||||
item := &dto.DocumentChunkRPC{}
|
||||
if err = gconv.Struct(hit, item); err != nil {
|
||||
return
|
||||
}
|
||||
dictData = append(dictData, item)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// checkRepeat 检查是否重复
|
||||
func (s *documentService) checkRepeat(ctx context.Context, contentKey, contentHash string) (success bool, err error) {
|
||||
var val *gvar.Var
|
||||
if val, err = g.Redis().Set(ctx, fmt.Sprintf(contentKey, contentHash), true, gredis.SetOption{
|
||||
TTLOption: gredis.TTLOption{
|
||||
EX: gconv.PtrInt64(600),
|
||||
},
|
||||
NX: true,
|
||||
}); err != nil {
|
||||
return
|
||||
}
|
||||
success = val.Bool()
|
||||
return
|
||||
}
|
||||
|
||||
func (s *documentService) DocsVectorStatusMsg(ctx context.Context, msg any) (err error) {
|
||||
var req = new(dto.KnowledgeDocumentMsg)
|
||||
if err = gconv.Struct(msg, &req); err != nil {
|
||||
g.Log().Error(ctx, "DocsVectorStatusMsg err:", err)
|
||||
return
|
||||
}
|
||||
ctx = context.WithValue(ctx, "user", &beans.User{
|
||||
TenantId: req.TenantId,
|
||||
UserName: req.Creator,
|
||||
})
|
||||
_, err = dao.Document.Update(ctx, &dto.UpdateDocumentReq{
|
||||
Id: req.Id,
|
||||
VectorStatus: req.VectorStatus,
|
||||
})
|
||||
return
|
||||
}
|
||||
176
service/document_chunk.go
Normal file
176
service/document_chunk.go
Normal file
@@ -0,0 +1,176 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
"rag/consts/document"
|
||||
"rag/consts/public"
|
||||
"rag/dao"
|
||||
"rag/model/dto"
|
||||
"rag/model/entity"
|
||||
|
||||
"gitea.com/red-future/common/beans"
|
||||
"gitea.com/red-future/common/rag/eino"
|
||||
gmq "github.com/bjang03/gmq/core/gmq"
|
||||
"github.com/bjang03/gmq/mq"
|
||||
"github.com/bjang03/gmq/types"
|
||||
"github.com/gogf/gf/v2/frame/g"
|
||||
"github.com/gogf/gf/v2/util/gconv"
|
||||
"github.com/pgvector/pgvector-go"
|
||||
)
|
||||
|
||||
// documentChunkService groups the document-chunk operations exposed by this
// package. It carries no state.
type documentChunkService struct{}

// DocumentChunk is the shared, package-level documentChunkService instance.
var DocumentChunk = &documentChunkService{}

// DatasetIndexStatusReady is the status string "ready" for dataset indexes.
const DatasetIndexStatusReady = "ready"
|
||||
|
||||
// Update 更新文件块
|
||||
func (s *documentChunkService) Update(ctx context.Context, req *dto.UpdateDocumentChunkReq) (err error) {
|
||||
_, err = dao.DocumentChunk.Update(ctx, req)
|
||||
return
|
||||
}
|
||||
|
||||
// List 获取文件块列表
|
||||
func (s *documentChunkService) List(ctx context.Context, req *dto.ListDocumentChunkReq) (res *dto.ListDocumentChunkRes, err error) {
|
||||
list, total, err := dao.DocumentChunk.List(ctx, req)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
res = &dto.ListDocumentChunkRes{
|
||||
Total: total,
|
||||
}
|
||||
err = gconv.Struct(list, &res.List)
|
||||
return
|
||||
}
|
||||
|
||||
func (s *documentChunkService) DocsChunkMsg(ctx context.Context, msg any) (err error) {
|
||||
var req = make([]*dto.VectorDocumentChunkMsg, 0)
|
||||
msgMap := gconv.Map(msg)
|
||||
if err = gconv.Structs(msgMap["data"], &req); err != nil {
|
||||
g.Log().Error(ctx, "DocsChunkMsg err:", err)
|
||||
return
|
||||
}
|
||||
if len(req) == 0 {
|
||||
g.Log().Error(ctx, "DocsChunkMsg err:", "msg is empty")
|
||||
return
|
||||
}
|
||||
|
||||
ctx = context.WithValue(ctx, "user", &beans.User{
|
||||
TenantId: req[0].TenantId,
|
||||
UserName: req[0].Creator,
|
||||
})
|
||||
|
||||
// 调用eino接口获取向量
|
||||
var vectorDocsStr = make([]string, 0, len(req))
|
||||
for _, t := range req {
|
||||
vectorDocsStr = append(vectorDocsStr, t.Content)
|
||||
}
|
||||
embeddings, err := eino.EmbedStrings(ctx, vectorDocsStr)
|
||||
if err != nil {
|
||||
g.Log().Error(ctx, "DocsChunkMsg err:", err)
|
||||
err = s.publishKnowledgeDocumentMsg(ctx, req[0].TenantId, req[0].Creator, req[0].DocumentId, document.VectorStatusFailed.Code())
|
||||
return
|
||||
}
|
||||
|
||||
// 获取向量维度
|
||||
dimension := 0
|
||||
if len(embeddings) > 0 {
|
||||
dimension = len(embeddings[0])
|
||||
}
|
||||
|
||||
// 创建或更新DatasetIndex
|
||||
err = s.createOrUpdateDatasetIndex(ctx, req[0].DatasetId, dimension, int64(len(req)))
|
||||
if err != nil {
|
||||
g.Log().Error(ctx, "CreateOrUpdateDatasetIndex err:", err)
|
||||
err = s.publishKnowledgeDocumentMsg(ctx, req[0].TenantId, req[0].Creator, req[0].DocumentId, document.VectorStatusFailed.Code())
|
||||
return
|
||||
}
|
||||
|
||||
// 更新向量文档
|
||||
for i, embedding := range embeddings {
|
||||
req[i].Vector = pgvector.NewVector(gconv.Float32s(embedding))
|
||||
req[i].VectorStatus = document.VectorStatusCompleted.Code()
|
||||
req[i].Status = document.StatusEnable.Code()
|
||||
}
|
||||
_, err = dao.DocumentChunk.BatchInsert(ctx, req)
|
||||
if err != nil {
|
||||
g.Log().Error(ctx, "DocsChunkMsg err:", err)
|
||||
err = s.publishKnowledgeDocumentMsg(ctx, req[0].TenantId, req[0].Creator, req[0].DocumentId, document.VectorStatusFailed.Code())
|
||||
return
|
||||
}
|
||||
|
||||
err = s.publishKnowledgeDocumentMsg(ctx, req[0].TenantId, req[0].Creator, req[0].DocumentId, document.VectorStatusCompleted.Code())
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// createOrUpdateDatasetIndex 创建或更新数据集索引
|
||||
func (s *documentChunkService) createOrUpdateDatasetIndex(ctx context.Context, datasetId int64, dimension int, vectorCount int64) (err error) {
|
||||
// 查询数据集是否已有索引
|
||||
existIndex, err := dao.DatasetIndex.GetByDatasetId(ctx, datasetId)
|
||||
if err != nil && !errors.Is(err, sql.ErrNoRows) {
|
||||
return err
|
||||
}
|
||||
|
||||
// 已有索引 → 只更新数量
|
||||
if existIndex != nil {
|
||||
_ = dao.DatasetIndex.IncVectorCount(ctx, existIndex.Id, vectorCount)
|
||||
return nil
|
||||
}
|
||||
|
||||
// ====================== 创建新索引 ======================
|
||||
indexName := fmt.Sprintf("idx_dataset_%d_vector", datasetId) // 真实PG索引名
|
||||
// 1. 插入索引配置
|
||||
index := &entity.DatasetIndex{
|
||||
DatasetId: datasetId,
|
||||
Name: indexName,
|
||||
Dimension: dimension,
|
||||
FieldType: "float",
|
||||
MetricType: "COSINE",
|
||||
Status: gconv.PtrInt8(1),
|
||||
VectorCount: vectorCount,
|
||||
Description: fmt.Sprintf("数据集%d向量索引", datasetId),
|
||||
}
|
||||
_, err = dao.DatasetIndex.Insert(ctx, index)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 2. 真正创建 PGVector 索引(唯一真实索引!)
|
||||
err = s.createRealPGVectorIndex(ctx, indexName)
|
||||
return err
|
||||
}
|
||||
|
||||
// createRealPGVectorIndex 真正在PostgreSQL创建向量索引(真实可用)
|
||||
func (s *documentChunkService) createRealPGVectorIndex(ctx context.Context, indexName string) error {
|
||||
// 执行真实建索引语句
|
||||
err := dao.DatasetIndex.InsertIndex(ctx, indexName)
|
||||
if err != nil {
|
||||
g.Log().Error(ctx, "创建向量索引失败:", err)
|
||||
return err
|
||||
}
|
||||
g.Log().Info(ctx, "PGVector真实索引创建成功:"+indexName)
|
||||
return nil
|
||||
}
|
||||
|
||||
// publishKnowledgeDocumentMsg 发布消息
|
||||
func (s *documentChunkService) publishKnowledgeDocumentMsg(ctx context.Context, tenantId uint64, creator string, documentId int64, vectorStatus document.VectorStatus) (err error) {
|
||||
knowledgeDocumentMsg := dto.KnowledgeDocumentMsg{
|
||||
TenantId: tenantId,
|
||||
Creator: creator,
|
||||
Id: documentId,
|
||||
VectorStatus: vectorStatus,
|
||||
}
|
||||
err = gmq.GetGmq("primary").GmqPublish(ctx, &mq.RedisPubMessage{
|
||||
PubMessage: types.PubMessage{
|
||||
Topic: public.KnowledgeDocumentVectorStatusTopic,
|
||||
Data: knowledgeDocumentMsg,
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
65
service/keyword.go
Normal file
65
service/keyword.go
Normal file
@@ -0,0 +1,65 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"rag/dao"
|
||||
"rag/model/dto"
|
||||
|
||||
"github.com/gogf/gf/v2/errors/gerror"
|
||||
"github.com/gogf/gf/v2/util/gconv"
|
||||
)
|
||||
|
||||
// keywordService groups the keyword operations exposed by this package.
// It carries no state.
type keywordService struct{}

// Keyword is the shared, package-level keywordService instance.
var Keyword = &keywordService{}
|
||||
|
||||
func (s *keywordService) Create(ctx context.Context, req *dto.CreateKeywordReq) (res *dto.CreateKeywordRes, err error) {
|
||||
count, err := dao.Keyword.Count(ctx, &dto.ListKeywordReq{
|
||||
DatasetId: req.DatasetId,
|
||||
DocumentId: req.DocumentId,
|
||||
Word: req.Word,
|
||||
})
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if count > 0 {
|
||||
err = gerror.New("关键词已存在")
|
||||
return
|
||||
}
|
||||
var id int64
|
||||
id, err = dao.Keyword.Insert(ctx, req)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
res = &dto.CreateKeywordRes{Id: id}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (s *keywordService) Update(ctx context.Context, req *dto.UpdateKeywordReq) (err error) {
|
||||
_, err = dao.Keyword.Update(ctx, req)
|
||||
return
|
||||
}
|
||||
|
||||
func (s *keywordService) Delete(ctx context.Context, req *dto.DeleteKeywordReq) (err error) {
|
||||
_, err = dao.Keyword.Delete(ctx, req)
|
||||
return
|
||||
}
|
||||
|
||||
func (s *keywordService) Get(ctx context.Context, req *dto.GetKeywordReq) (res *dto.KeywordVO, err error) {
|
||||
r, err := dao.Keyword.GetByID(ctx, req)
|
||||
err = gconv.Struct(r, &res)
|
||||
return
|
||||
}
|
||||
|
||||
func (s *keywordService) List(ctx context.Context, req *dto.ListKeywordReq) (res *dto.ListKeywordRes, err error) {
|
||||
list, total, err := dao.Keyword.List(ctx, req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
res = &dto.ListKeywordRes{
|
||||
Total: total,
|
||||
}
|
||||
err = gconv.Struct(list, &res.List)
|
||||
return
|
||||
}
|
||||
Reference in New Issue
Block a user