feat: 支持多租户多模型对话及文档去重优化

This commit is contained in:
2026-04-16 15:47:37 +08:00
parent 4ead3f82cf
commit 27b1dd3c27
34 changed files with 2188 additions and 315 deletions

View File

@@ -3,6 +3,8 @@ package eino
import (
"context"
"errors"
"fmt"
"rag/consts/model"
"rag/dao"
"sort"
"time"
@@ -29,21 +31,25 @@ type PGVectorRetriever struct {
topK int
index string
dslInfo map[string]any
reranker *DashScopeReranker // 通义精排
}
func NewPGVectorRetriever(config *PGVectorRetrieverConfig) (*PGVectorRetriever, error) {
if config.Embedder == nil {
return nil, errors.New("embedder is required")
}
func NewPGVectorRetriever(ctx context.Context, config *PGVectorRetrieverConfig, configType model.ModelConfigType) (*PGVectorRetriever, error) {
if config.DefaultTopK <= 0 {
config.DefaultTopK = 5
}
e, err := GetTenantEmbedderByType(ctx, configType)
if err != nil {
return nil, err
}
return &PGVectorRetriever{
embedder: config.Embedder,
embedder: e,
topK: config.DefaultTopK,
index: config.DefaultIndex,
dslInfo: config.DSLInfo,
//reranker: NewDashScopeReranker(), // 👈 直接初始化你的精排
}, nil
}
@@ -138,48 +144,37 @@ func (r *PGVectorRetriever) Retrieve(ctx context.Context, query string, opts ...
}
// 合并 + 智能去重(保留最优分数)
docs := mergeAndDeduplicate(docsVector, docsFulltext)
mergedDocs := mergeAndDeduplicate(docsVector, docsFulltext)
// 排序:向量优先,同类型按距离升序
sort.Slice(docs, func(i, j int) bool {
//byI, okI := docs[i].MetaData["retrieve_by"].(string)
//byJ, okJ := docs[j].MetaData["retrieve_by"].(string)
//
//// 有类型标记的优先
//if okI && !okJ {
// return true
//}
//if !okI && okJ {
// return false
//}
//
//// 向量永远排前面
//if byI == "vector" && byJ == "fulltext" {
// return true
//}
//if byI == "fulltext" && byJ == "vector" {
// return false
//}
// 同类型按 distance 升序(越小越相似)
d1 := gconv.Float64(docs[i].MetaData["distance"])
d2 := gconv.Float64(docs[j].MetaData["distance"])
return d1 < d2
})
// 在Retrieve方法末尾增加相关性校验
validDocs := make([]*schema.Document, 0)
for i, d := range docs {
// 过滤distance过大的垃圾结果比如distance>0.8的直接丢弃)
if gconv.Float64(docs[i].MetaData["distance"]) < 0.8 {
validDocs = append(validDocs, d)
// =========================
// 🔥 Cross-Encoder 精排
// =========================
var finalDocs []*schema.Document
if r.reranker != nil {
ranked, err := r.reranker.Rerank(ctx, query, mergedDocs)
if err != nil {
return nil, fmt.Errorf("rerank failed: %w", err)
}
finalDocs = ranked
} else {
sort.Slice(mergedDocs, func(i, j int) bool {
d1 := gconv.Float64(mergedDocs[i].MetaData["distance"])
d2 := gconv.Float64(mergedDocs[j].MetaData["distance"])
return d1 < d2
})
finalDocs = mergedDocs
}
// 如果没有有效结果返回空让LLM回答「暂无相关信息」
if len(validDocs) == 0 {
callbacks.OnEnd(ctx, &retriever.CallbackOutput{Docs: validDocs})
return validDocs, nil
// =========================
// 过滤无效文档
// =========================
const maxDistance = 0.8
validDocs := make([]*schema.Document, 0, len(finalDocs))
for _, doc := range finalDocs {
dist := gconv.Float64(doc.MetaData["distance"])
if dist <= maxDistance {
validDocs = append(validDocs, doc)
}
}
// 最多保留 topK
@@ -208,9 +203,15 @@ func (r *PGVectorRetriever) doRetrieveVector(ctx context.Context, query string,
if opts.TopK != nil {
topK = *opts.TopK
}
datasetIds := gconv.Int64s(opts.DSLInfo["dataset_ids"])
var datasetIds, documentIds []int64
if g.IsEmpty(opts.DSLInfo["dataset_ids"]) {
datasetIds = gconv.Int64s(opts.DSLInfo["dataset_ids"])
}
if g.IsEmpty(opts.DSLInfo["document_ids"]) {
documentIds = gconv.Int64s(opts.DSLInfo["document_ids"])
}
rows, err := dao.DocumentVector.GetAllByVector(ctx, datasetIds, queryVec, topK)
rows, err := dao.DocumentVector.GetAllByVector(ctx, datasetIds, documentIds, queryVec, topK)
if err != nil {
return nil, err
}
@@ -236,10 +237,17 @@ func (r *PGVectorRetriever) doRetrieveVector(ctx context.Context, query string,
// ==========================================
func (r *PGVectorRetriever) doRetrieveMeilisearch(ctx context.Context, query string, opts *retriever.Options) ([]*schema.Document, error) {
topK := *opts.TopK
datasetIds := gconv.Int64s(opts.DSLInfo["dataset_ids"])
var datasetIds, documentIds []int64
if g.IsEmpty(opts.DSLInfo["dataset_ids"]) {
datasetIds = gconv.Int64s(opts.DSLInfo["dataset_ids"])
}
if g.IsEmpty(opts.DSLInfo["document_ids"]) {
documentIds = gconv.Int64s(opts.DSLInfo["document_ids"])
}
// 调用你已有的 Meilisearch DAO
rows, err := dao.DocumentVector.SearchByKeywords(ctx, query, datasetIds, topK)
rows, err := dao.DocumentVector.SearchByKeywords(ctx, query, datasetIds, documentIds, topK)
if err != nil {
return nil, err
}