feat: 支持多租户多模型对话及文档去重优化
This commit is contained in:
@@ -43,17 +43,29 @@ func (d *documentVectorDao) Update(ctx context.Context, req *dto.UpdateDocumentV
|
||||
return r.RowsAffected()
|
||||
}
|
||||
|
||||
func (d *documentVectorDao) Delete(ctx context.Context, req *dto.DeleteDocumentVectorReq) (rows int64, err error) {
|
||||
result, err := gfdb.DB(ctx, public.DbNameVector).Model(ctx, public.TableNameDocumentVector).OmitEmpty().
|
||||
Where(entity.DocumentVectorCol.Id, req.Id).
|
||||
Where(entity.DocumentVectorCol.DocumentId, req.DocumentId).
|
||||
Delete()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
return result.RowsAffected()
|
||||
}
|
||||
|
||||
// List 文件块列表
|
||||
func (d *documentVectorDao) List(ctx context.Context, req *dto.ListDocumentVectorReq, fields ...string) (res []*entity.DocumentVector, total int, err error) {
|
||||
model := gfdb.DB(ctx, public.DbNameVector).Model(ctx, public.TableNameDocumentVector).Fields(fields).OmitEmpty().
|
||||
Where(entity.DocumentVectorCol.DatasetId, req.DatasetId).
|
||||
Where(entity.DocumentVectorCol.DocumentId, req.DocumentId).
|
||||
Where(entity.DocumentVectorCol.Status, req.Status).
|
||||
Where(entity.DocumentVectorCol.VectorStatus, req.VectorStatus)
|
||||
Where(entity.DocumentVectorCol.VectorStatus, req.VectorStatus).
|
||||
WhereIn(entity.DocumentVectorCol.DocumentId, req.DocumentIds)
|
||||
if !g.IsEmpty(req.Keyword) {
|
||||
model.WhereLike(entity.DocumentVectorCol.Content, "%"+req.Keyword+"%")
|
||||
}
|
||||
model.OrderDesc(entity.DocumentVectorCol.CreatedAt)
|
||||
model.OrderAsc(entity.DocumentVectorCol.ChunkIndex)
|
||||
if req.Page != nil {
|
||||
model.Page(int(req.Page.PageNum), int(req.Page.PageSize))
|
||||
}
|
||||
@@ -65,28 +77,32 @@ func (d *documentVectorDao) List(ctx context.Context, req *dto.ListDocumentVecto
|
||||
return
|
||||
}
|
||||
|
||||
func (d *documentVectorDao) GetAllByVector(ctx context.Context, datasetIds []int64, vector pgvector.Vector, topK int) (list gdb.List, err error) {
|
||||
//result, err := gfdb.DB(ctx, public.DbNameVector).Model(ctx, public.TableNameDocumentVector).
|
||||
// Fields("id, content, dataset_id, document_id, vector <=> ? AS distance").
|
||||
// WhereIn(entity.DocumentVectorCol.DatasetId, datasetIds).
|
||||
// WhereNotNull(entity.DocumentVectorCol.Vector).
|
||||
// OrderAsc("distance").
|
||||
// Limit(topK).
|
||||
// All()
|
||||
//if err != nil {
|
||||
// return nil, err
|
||||
//}
|
||||
func (d *documentVectorDao) GetAllByVector(ctx context.Context, datasetIds, documentIds []int64, vector pgvector.Vector, topK int) (list gdb.List, err error) {
|
||||
// 动态拼接 WHERE 条件
|
||||
var whereCondition string
|
||||
var queryParams []interface{}
|
||||
|
||||
// 优先使用 documentIds 查询
|
||||
if len(documentIds) > 0 {
|
||||
whereCondition = fmt.Sprintf(" AND %s IN (?) ", entity.DocumentVectorCol.DocumentId)
|
||||
queryParams = append(queryParams, documentIds)
|
||||
}
|
||||
if len(datasetIds) > 0 {
|
||||
whereCondition = fmt.Sprintf(" AND %s IN (?) ", entity.DocumentVectorCol.DatasetId)
|
||||
queryParams = append(queryParams, datasetIds)
|
||||
}
|
||||
|
||||
// 完整 SQL
|
||||
sql := `
|
||||
SELECT id, content, dataset_id, document_id,
|
||||
vector <=> ? AS distance
|
||||
SELECT id, content, dataset_id, document_id,vector <=> ? AS distance
|
||||
FROM rag_vector_document_vector
|
||||
WHERE dataset_id IN (?)
|
||||
AND vector IS NOT NULL
|
||||
ORDER BY distance ASC
|
||||
LIMIT ?
|
||||
`
|
||||
// 顺序:vector, dataset_id, topK
|
||||
result, err := gfdb.DB(ctx, public.DbNameVector).GetAll(ctx, sql, vector, datasetIds, topK)
|
||||
WHERE 1=1 ` + whereCondition + ` AND vector IS NOT NULL ORDER BY distance ASC LIMIT ?`
|
||||
// 拼接参数:vector + 条件参数 + topK
|
||||
queryParams = append([]interface{}{vector}, queryParams...)
|
||||
queryParams = append(queryParams, topK)
|
||||
|
||||
// 执行查询
|
||||
result, err := gfdb.DB(ctx, public.DbNameVector).GetAll(ctx, sql, queryParams...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -94,7 +110,7 @@ func (d *documentVectorDao) GetAllByVector(ctx context.Context, datasetIds []int
|
||||
}
|
||||
|
||||
// SearchByKeywords 通过关键词全文检索文档块
|
||||
func (d *documentVectorDao) SearchByKeywords(ctx context.Context, query string, datasetIds []int64, topK int) (list gdb.List, err error) {
|
||||
func (d *documentVectorDao) SearchByKeywords(ctx context.Context, query string, datasetIds, documentIds []int64, topK int) (list gdb.List, err error) {
|
||||
// 构建 meilisearch 查询参数
|
||||
searchParams := &meilisearch.SearchParams{
|
||||
Query: query,
|
||||
@@ -102,14 +118,22 @@ func (d *documentVectorDao) SearchByKeywords(ctx context.Context, query string,
|
||||
ShowRankingScore: true,
|
||||
}
|
||||
|
||||
// 构建 datasetIds 过滤条件
|
||||
if len(datasetIds) > 0 {
|
||||
datasetIdStrs := gconv.Strings(datasetIds)
|
||||
quotedIds := make([]string, len(datasetIdStrs))
|
||||
for i, id := range datasetIdStrs {
|
||||
quotedIds[i] = fmt.Sprintf("%s", id)
|
||||
}
|
||||
searchParams.Filter = fmt.Sprintf("dataset_id IN [%s]", gstr.Implode(", ", quotedIds))
|
||||
searchParams.Filter = fmt.Sprintf("%s IN [%s]", entity.DocumentVectorCol.DatasetId, gstr.Implode(", ", quotedIds))
|
||||
}
|
||||
|
||||
if len(documentIds) > 0 {
|
||||
documentIdStrs := gconv.Strings(documentIds)
|
||||
quotedIds := make([]string, len(documentIdStrs))
|
||||
for i, id := range documentIdStrs {
|
||||
quotedIds[i] = fmt.Sprintf("%s", id)
|
||||
}
|
||||
searchParams.Filter = fmt.Sprintf("%s IN [%s]", entity.DocumentVectorCol.DocumentId, gstr.Implode(", ", quotedIds))
|
||||
}
|
||||
|
||||
// 执行搜索
|
||||
@@ -124,6 +148,5 @@ func (d *documentVectorDao) SearchByKeywords(ctx context.Context, query string,
|
||||
for _, hit := range hits {
|
||||
resultList = append(resultList, hit)
|
||||
}
|
||||
|
||||
return resultList, nil
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user