Files
rag/dao/document_vector.go
2026-06-10 16:36:46 +08:00

153 lines
5.1 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package dao
import (
"context"
"fmt"
"rag/consts/public"
"rag/model/dto"
"rag/model/entity"
"gitea.redpowerfuture.com/red-future/common/db/gfdb"
"gitea.redpowerfuture.com/red-future/common/full-text-search/meilisearch"
"github.com/gogf/gf/v2/database/gdb"
"github.com/gogf/gf/v2/frame/g"
"github.com/gogf/gf/v2/text/gstr"
"github.com/gogf/gf/v2/util/gconv"
"github.com/pgvector/pgvector-go"
)
// documentVectorDao groups the data-access methods for document chunk
// vectors. It carries no state.
type documentVectorDao struct{}

// DocumentVector is the package-level instance through which the
// documentVectorDao methods are called.
var DocumentVector = &documentVectorDao{}
// BatchInsert converts the incoming chunk messages into DocumentVector
// entities and writes them to the document-vector table in one insert.
//
// It returns the affected-row count reported by the insert result, or the
// error raised while converting or inserting.
func (d *documentVectorDao) BatchInsert(ctx context.Context, req []*dto.VectorDocumentVectorMsg) (rows int64, err error) {
	var entities []*entity.DocumentVector
	if convErr := gconv.Structs(req, &entities); convErr != nil {
		return 0, convErr
	}

	insertModel := gfdb.DB(ctx, public.DbNameVector).
		Model(ctx, public.TableNameDocumentVector).
		Data(&entities)
	result, insertErr := insertModel.Insert()
	if insertErr != nil {
		return 0, insertErr
	}
	return result.RowsAffected()
}
// Update writes the fields of req to the document chunk row whose id column
// equals req.Id. OmitEmpty is enabled on the model before the data and the
// filter are attached.
//
// It returns the affected-row count reported by the update result.
func (d *documentVectorDao) Update(ctx context.Context, req *dto.UpdateDocumentVectorReq) (rows int64, err error) {
	updateModel := gfdb.DB(ctx, public.DbNameVector).
		Model(ctx, public.TableNameDocumentVector).
		OmitEmpty().
		Data(&req).
		Where(entity.DocumentVectorCol.Id, req.Id)

	result, updateErr := updateModel.Update()
	if updateErr != nil {
		return 0, updateErr
	}
	return result.RowsAffected()
}
// Delete removes document chunk rows filtered by req.Id and req.DocumentId.
// OmitEmpty is enabled on the model before the two filters are attached.
//
// It returns the affected-row count reported by the delete result.
func (d *documentVectorDao) Delete(ctx context.Context, req *dto.DeleteDocumentVectorReq) (rows int64, err error) {
	filtered := gfdb.DB(ctx, public.DbNameVector).
		Model(ctx, public.TableNameDocumentVector).
		OmitEmpty().
		Where(entity.DocumentVectorCol.Id, req.Id).
		Where(entity.DocumentVectorCol.DocumentId, req.DocumentId)

	outcome, deleteErr := filtered.Delete()
	if deleteErr != nil {
		return 0, deleteErr
	}
	return outcome.RowsAffected()
}
// List returns document chunks matching req, ordered by chunk index
// ascending, together with the total count returned by AllAndCount.
//
// Filters on dataset id, document id, status, vector status and the
// DocumentIds set are attached with OmitEmpty enabled. When req.Keyword is
// non-empty the content column must contain it (LIKE %keyword%). When
// req.Page is non-nil the result is paginated by PageNum/PageSize.
// fields optionally restricts the selected columns.
func (d *documentVectorDao) List(ctx context.Context, req *dto.ListDocumentVectorReq, fields ...string) (res []*entity.DocumentVector, total int, err error) {
	model := gfdb.DB(ctx, public.DbNameVector).Model(ctx, public.TableNameDocumentVector).Fields(fields).OmitEmpty().
		Where(entity.DocumentVectorCol.DatasetId, req.DatasetId).
		Where(entity.DocumentVectorCol.DocumentId, req.DocumentId).
		Where(entity.DocumentVectorCol.Status, req.Status).
		Where(entity.DocumentVectorCol.VectorStatus, req.VectorStatus).
		WhereIn(entity.DocumentVectorCol.DocumentId, req.DocumentIds)

	// Always keep the model returned by each chaining call: relying on
	// in-place mutation silently drops the condition if the model is ever
	// switched to safe (copy-on-chain) mode.
	if !g.IsEmpty(req.Keyword) {
		// NOTE(review): '%' and '_' inside Keyword are not escaped and act as
		// LIKE wildcards — confirm this is acceptable for callers.
		model = model.WhereLike(entity.DocumentVectorCol.Content, "%"+req.Keyword+"%")
	}
	model = model.OrderAsc(entity.DocumentVectorCol.ChunkIndex)
	if req.Page != nil {
		model = model.Page(int(req.Page.PageNum), int(req.Page.PageSize))
	}

	records, total, err := model.AllAndCount(false)
	if err != nil {
		return nil, total, err
	}
	err = records.Structs(&res)
	return res, total, err
}
// GetAllByVector returns up to topK chunks ordered by ascending distance
// between their stored vector and the given query vector (pgvector `<=>`
// operator). Rows without a vector are skipped.
//
// The search scope is restricted by at most one ID filter: documentIds take
// precedence, and datasetIds are used only when documentIds is empty. With
// both empty the whole table is searched.
//
// Each returned map carries id, content, dataset_id, document_id and distance.
func (d *documentVectorDao) GetAllByVector(ctx context.Context, datasetIds, documentIds []int64, vector pgvector.Vector, topK int) (list gdb.List, err error) {
	// The query vector always binds to the first placeholder (the distance
	// expression in the SELECT list).
	queryParams := []interface{}{vector}

	// Exactly one placeholder is added per bound ID slice. The previous code
	// overwrote the condition but still appended both slices, leaving the
	// statement with more arguments than placeholders when both were given.
	var whereCondition string
	switch {
	case len(documentIds) > 0:
		whereCondition = fmt.Sprintf(" AND %s IN (?) ", entity.DocumentVectorCol.DocumentId)
		queryParams = append(queryParams, documentIds)
	case len(datasetIds) > 0:
		whereCondition = fmt.Sprintf(" AND %s IN (?) ", entity.DocumentVectorCol.DatasetId)
		queryParams = append(queryParams, datasetIds)
	}

	// Full statement; topK binds to the trailing LIMIT placeholder.
	sql := `
SELECT id, content, dataset_id, document_id,vector <=> ? AS distance
FROM rag_vector_document_vector
WHERE 1=1 ` + whereCondition + ` AND vector IS NOT NULL ORDER BY distance ASC LIMIT ?`
	queryParams = append(queryParams, topK)

	result, err := gfdb.DB(ctx, public.DbNameVector).GetAll(ctx, sql, queryParams...)
	if err != nil {
		return nil, err
	}
	return result.List(), nil
}
// SearchByKeywords runs a full-text search for query against the document
// chunk index and returns at most topK hits, with ranking scores requested.
//
// At most one ID filter is applied: documentIds take precedence, and
// datasetIds are used only when documentIds is empty. With both empty the
// search is unfiltered.
//
// Hits are returned as a gdb.List, one map per hit, exactly as decoded from
// the search response.
func (d *documentVectorDao) SearchByKeywords(ctx context.Context, query string, datasetIds, documentIds []int64, topK int) (list gdb.List, err error) {
	// Build the search request.
	searchParams := &meilisearch.SearchParams{
		Query:            query,
		Limit:            int64(topK),
		ShowRankingScore: true,
	}

	// IDs are numeric, so they are written into the filter unquoted.
	switch {
	case len(documentIds) > 0:
		searchParams.Filter = fmt.Sprintf("%s IN [%s]", entity.DocumentVectorCol.DocumentId, gstr.Implode(", ", gconv.Strings(documentIds)))
	case len(datasetIds) > 0:
		searchParams.Filter = fmt.Sprintf("%s IN [%s]", entity.DocumentVectorCol.DatasetId, gstr.Implode(", ", gconv.Strings(datasetIds)))
	}

	// Execute the search.
	var hits []map[string]interface{}
	if _, err = meilisearch.DB().Search(ctx, searchParams, public.IndexNameDocumentChunk, &hits); err != nil {
		return nil, err
	}

	// Convert the hits into a gdb.List (non-nil even when there are no hits).
	resultList := make(gdb.List, 0, len(hits))
	for _, hit := range hits {
		resultList = append(resultList, hit)
	}
	return resultList, nil
}