feat: 新增关键词类型及优化查询逻辑
支持关键词类型区分，优化文件向量查询 SQL 及 DAO 更新逻辑，移除冗余配置和注释代码。
This commit is contained in:
@@ -10,6 +10,7 @@ import (
|
||||
"gitea.com/red-future/common/db/gfdb"
|
||||
"gitea.com/red-future/common/full-text-search/meilisearch"
|
||||
"github.com/gogf/gf/v2/database/gdb"
|
||||
"github.com/gogf/gf/v2/frame/g"
|
||||
"github.com/gogf/gf/v2/text/gstr"
|
||||
"github.com/gogf/gf/v2/util/gconv"
|
||||
"github.com/pgvector/pgvector-go"
|
||||
@@ -34,7 +35,7 @@ func (d *documentVectorDao) BatchInsert(ctx context.Context, req []*dto.VectorDo
|
||||
|
||||
// Update 更新文件块
|
||||
func (d *documentVectorDao) Update(ctx context.Context, req *dto.UpdateDocumentVectorReq) (rows int64, err error) {
|
||||
model := gfdb.DB(ctx, public.DbNameVector).Model(ctx, public.TableNameDocumentVector)
|
||||
model := gfdb.DB(ctx, public.DbNameVector).Model(ctx, public.TableNameDocumentVector).OmitEmpty()
|
||||
r, err := model.Data(&req).Where(entity.DocumentVectorCol.Id, req.Id).Update()
|
||||
if err != nil {
|
||||
return
|
||||
@@ -48,8 +49,11 @@ func (d *documentVectorDao) List(ctx context.Context, req *dto.ListDocumentVecto
|
||||
Where(entity.DocumentVectorCol.DatasetId, req.DatasetId).
|
||||
Where(entity.DocumentVectorCol.DocumentId, req.DocumentId).
|
||||
Where(entity.DocumentVectorCol.Status, req.Status).
|
||||
Where(entity.DocumentVectorCol.VectorStatus, req.VectorStatus).
|
||||
OrderDesc(entity.DocumentVectorCol.CreatedAt)
|
||||
Where(entity.DocumentVectorCol.VectorStatus, req.VectorStatus)
|
||||
if !g.IsEmpty(req.Keyword) {
|
||||
model.WhereLike(entity.DocumentVectorCol.Content, "%"+req.Keyword+"%")
|
||||
}
|
||||
model.OrderDesc(entity.DocumentVectorCol.CreatedAt)
|
||||
if req.Page != nil {
|
||||
model.Page(int(req.Page.PageNum), int(req.Page.PageSize))
|
||||
}
|
||||
@@ -62,13 +66,27 @@ func (d *documentVectorDao) List(ctx context.Context, req *dto.ListDocumentVecto
|
||||
}
|
||||
|
||||
func (d *documentVectorDao) GetAllByVector(ctx context.Context, datasetIds []int64, vector pgvector.Vector, topK int) (list gdb.List, err error) {
|
||||
result, err := gfdb.DB(ctx, public.DbNameVector).Model(ctx, public.TableNameDocumentVector).
|
||||
Fields("id, content, dataset_id, document_id, vector <=> ? AS distance", vector).
|
||||
WhereIn(entity.DocumentVectorCol.DatasetId, datasetIds).
|
||||
WhereNotNull(entity.DocumentVectorCol.Vector).
|
||||
OrderAsc("distance").
|
||||
Limit(topK).
|
||||
All()
|
||||
//result, err := gfdb.DB(ctx, public.DbNameVector).Model(ctx, public.TableNameDocumentVector).
|
||||
// Fields("id, content, dataset_id, document_id, vector <=> ? AS distance").
|
||||
// WhereIn(entity.DocumentVectorCol.DatasetId, datasetIds).
|
||||
// WhereNotNull(entity.DocumentVectorCol.Vector).
|
||||
// OrderAsc("distance").
|
||||
// Limit(topK).
|
||||
// All()
|
||||
//if err != nil {
|
||||
// return nil, err
|
||||
//}
|
||||
sql := `
|
||||
SELECT id, content, dataset_id, document_id,
|
||||
vector <=> ? AS distance
|
||||
FROM rag_vector_document_vector
|
||||
WHERE dataset_id IN (?)
|
||||
AND vector IS NOT NULL
|
||||
ORDER BY distance ASC
|
||||
LIMIT ?
|
||||
`
|
||||
// 顺序:vector, dataset_id, topK
|
||||
result, err := gfdb.DB(ctx, public.DbNameVector).GetAll(ctx, sql, vector, datasetIds, topK)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user