2026-04-03 17:59:05 +08:00
|
|
|
|
package eino
|
|
|
|
|
|
|
|
|
|
|
|
import (
|
|
|
|
|
|
"context"
|
|
|
|
|
|
"errors"
|
2026-04-09 09:11:43 +08:00
|
|
|
|
"rag/dao"
|
|
|
|
|
|
"sort"
|
2026-04-09 13:57:46 +08:00
|
|
|
|
"time"
|
2026-04-03 17:59:05 +08:00
|
|
|
|
|
|
|
|
|
|
"github.com/cloudwego/eino/callbacks"
|
|
|
|
|
|
"github.com/cloudwego/eino/components/embedding"
|
|
|
|
|
|
"github.com/cloudwego/eino/components/retriever"
|
|
|
|
|
|
"github.com/cloudwego/eino/schema"
|
2026-04-09 13:57:46 +08:00
|
|
|
|
"github.com/gogf/gf/v2/frame/g"
|
|
|
|
|
|
"github.com/gogf/gf/v2/os/grpool"
|
2026-04-03 17:59:05 +08:00
|
|
|
|
"github.com/gogf/gf/v2/util/gconv"
|
|
|
|
|
|
"github.com/pgvector/pgvector-go"
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
type PGVectorRetrieverConfig struct {
|
|
|
|
|
|
Embedder embedding.Embedder
|
|
|
|
|
|
DefaultTopK int
|
|
|
|
|
|
DefaultIndex string
|
2026-04-09 09:11:43 +08:00
|
|
|
|
DSLInfo map[string]any
|
2026-04-03 17:59:05 +08:00
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
type PGVectorRetriever struct {
|
|
|
|
|
|
embedder embedding.Embedder
|
|
|
|
|
|
topK int
|
|
|
|
|
|
index string
|
2026-04-09 09:11:43 +08:00
|
|
|
|
dslInfo map[string]any
|
2026-04-03 17:59:05 +08:00
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
func NewPGVectorRetriever(config *PGVectorRetrieverConfig) (*PGVectorRetriever, error) {
|
|
|
|
|
|
if config.Embedder == nil {
|
|
|
|
|
|
return nil, errors.New("embedder is required")
|
|
|
|
|
|
}
|
|
|
|
|
|
if config.DefaultTopK <= 0 {
|
|
|
|
|
|
config.DefaultTopK = 5
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
return &PGVectorRetriever{
|
|
|
|
|
|
embedder: config.Embedder,
|
|
|
|
|
|
topK: config.DefaultTopK,
|
|
|
|
|
|
index: config.DefaultIndex,
|
2026-04-09 09:11:43 +08:00
|
|
|
|
dslInfo: config.DSLInfo,
|
2026-04-03 17:59:05 +08:00
|
|
|
|
}, nil
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
func (r *PGVectorRetriever) Retrieve(ctx context.Context, query string, opts ...retriever.Option) ([]*schema.Document, error) {
|
|
|
|
|
|
options := &retriever.Options{
|
|
|
|
|
|
Index: &r.index,
|
|
|
|
|
|
TopK: &r.topK,
|
2026-04-09 09:11:43 +08:00
|
|
|
|
DSLInfo: r.dslInfo,
|
2026-04-03 17:59:05 +08:00
|
|
|
|
Embedding: r.embedder,
|
|
|
|
|
|
}
|
|
|
|
|
|
options = retriever.GetCommonOptions(options, opts...)
|
|
|
|
|
|
|
2026-04-09 13:57:46 +08:00
|
|
|
|
// 安全保护:防止 nil 指针 panic
|
|
|
|
|
|
topK := 10
|
|
|
|
|
|
if options.TopK != nil {
|
|
|
|
|
|
topK = *options.TopK
|
|
|
|
|
|
}
|
|
|
|
|
|
|
2026-04-03 17:59:05 +08:00
|
|
|
|
ctx = callbacks.OnStart(ctx, &retriever.CallbackInput{
|
|
|
|
|
|
Query: query,
|
|
|
|
|
|
TopK: *options.TopK,
|
|
|
|
|
|
})
|
|
|
|
|
|
|
2026-04-09 09:11:43 +08:00
|
|
|
|
// ==========================================
|
2026-04-09 13:57:46 +08:00
|
|
|
|
// 🔥 优化版:grpool 并行双路检索(安全、健壮、无泄漏)
|
2026-04-09 09:11:43 +08:00
|
|
|
|
// ==========================================
|
2026-04-09 13:57:46 +08:00
|
|
|
|
var (
|
|
|
|
|
|
docsVector []*schema.Document
|
|
|
|
|
|
docsFulltext []*schema.Document
|
|
|
|
|
|
errVector error
|
|
|
|
|
|
errFulltext error
|
|
|
|
|
|
|
|
|
|
|
|
// 缓冲通道=2,确保无死锁等待
|
|
|
|
|
|
done = make(chan struct{}, 2)
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
// 上下文:超时 + 可取消双保障(建议5s超时,根据业务调整)
|
|
|
|
|
|
taskCtx, cancel := context.WithTimeout(ctx, 5*time.Second)
|
|
|
|
|
|
defer cancel()
|
|
|
|
|
|
|
|
|
|
|
|
// 封装并行任务函数,消除重复代码
|
|
|
|
|
|
runTask := func(task func() error, errTarget *error) {
|
|
|
|
|
|
defer func() {
|
|
|
|
|
|
// 任务结束必发信号,确保通道不阻塞
|
|
|
|
|
|
done <- struct{}{}
|
|
|
|
|
|
}()
|
|
|
|
|
|
|
|
|
|
|
|
// 捕获 panic + 执行业务逻辑
|
|
|
|
|
|
g.TryCatch(taskCtx, func(ctx context.Context) {
|
|
|
|
|
|
*errTarget = task()
|
|
|
|
|
|
}, func(ctx context.Context, panicErr error) {
|
|
|
|
|
|
*errTarget = panicErr
|
|
|
|
|
|
})
|
|
|
|
|
|
|
|
|
|
|
|
// 任务失败:立即取消另一个任务(快速失败)
|
|
|
|
|
|
if *errTarget != nil {
|
|
|
|
|
|
cancel()
|
|
|
|
|
|
}
|
2026-04-09 09:11:43 +08:00
|
|
|
|
}
|
|
|
|
|
|
|
2026-04-09 13:57:46 +08:00
|
|
|
|
// ----------------------
|
|
|
|
|
|
// 并行提交两个检索任务
|
|
|
|
|
|
// ----------------------
|
|
|
|
|
|
// 任务1:向量检索
|
|
|
|
|
|
grpool.Add(taskCtx, func(ctx context.Context) {
|
|
|
|
|
|
runTask(func() error {
|
|
|
|
|
|
docsVector, errVector = r.doRetrieveVector(ctx, query, options)
|
|
|
|
|
|
return errVector
|
|
|
|
|
|
}, &errVector)
|
|
|
|
|
|
})
|
|
|
|
|
|
|
|
|
|
|
|
// 任务2:全文检索
|
|
|
|
|
|
grpool.Add(taskCtx, func(ctx context.Context) {
|
|
|
|
|
|
runTask(func() error {
|
|
|
|
|
|
docsFulltext, errFulltext = r.doRetrieveMeilisearch(ctx, query, options)
|
|
|
|
|
|
return errFulltext
|
|
|
|
|
|
}, &errFulltext)
|
|
|
|
|
|
})
|
|
|
|
|
|
|
|
|
|
|
|
// ----------------------
|
|
|
|
|
|
// 安全等待所有任务完成
|
|
|
|
|
|
// ----------------------
|
|
|
|
|
|
<-done
|
|
|
|
|
|
<-done
|
|
|
|
|
|
|
|
|
|
|
|
// ----------------------
|
|
|
|
|
|
// 统一错误处理
|
|
|
|
|
|
// ----------------------
|
|
|
|
|
|
// 用 errors.Join 合并所有错误,不丢失信息
|
|
|
|
|
|
if err := errors.Join(errVector, errFulltext); err != nil {
|
2026-04-03 17:59:05 +08:00
|
|
|
|
return nil, err
|
|
|
|
|
|
}
|
|
|
|
|
|
|
2026-04-09 13:57:46 +08:00
|
|
|
|
// 合并 + 智能去重(保留最优分数)
|
2026-04-09 09:11:43 +08:00
|
|
|
|
docs := mergeAndDeduplicate(docsVector, docsFulltext)
|
|
|
|
|
|
|
2026-04-09 13:57:46 +08:00
|
|
|
|
// 排序:向量优先,同类型按距离升序
|
2026-04-09 09:11:43 +08:00
|
|
|
|
sort.Slice(docs, func(i, j int) bool {
|
2026-04-09 13:57:46 +08:00
|
|
|
|
//byI, okI := docs[i].MetaData["retrieve_by"].(string)
|
|
|
|
|
|
//byJ, okJ := docs[j].MetaData["retrieve_by"].(string)
|
|
|
|
|
|
//
|
|
|
|
|
|
//// 有类型标记的优先
|
|
|
|
|
|
//if okI && !okJ {
|
|
|
|
|
|
// return true
|
|
|
|
|
|
//}
|
|
|
|
|
|
//if !okI && okJ {
|
|
|
|
|
|
// return false
|
|
|
|
|
|
//}
|
|
|
|
|
|
//
|
|
|
|
|
|
//// 向量永远排前面
|
|
|
|
|
|
//if byI == "vector" && byJ == "fulltext" {
|
|
|
|
|
|
// return true
|
|
|
|
|
|
//}
|
|
|
|
|
|
//if byI == "fulltext" && byJ == "vector" {
|
|
|
|
|
|
// return false
|
|
|
|
|
|
//}
|
|
|
|
|
|
|
|
|
|
|
|
// 同类型按 distance 升序(越小越相似)
|
2026-04-09 09:11:43 +08:00
|
|
|
|
d1 := gconv.Float64(docs[i].MetaData["distance"])
|
|
|
|
|
|
d2 := gconv.Float64(docs[j].MetaData["distance"])
|
|
|
|
|
|
return d1 < d2
|
2026-04-03 17:59:05 +08:00
|
|
|
|
})
|
|
|
|
|
|
|
2026-04-09 13:57:46 +08:00
|
|
|
|
// 在Retrieve方法末尾,增加相关性校验
|
|
|
|
|
|
validDocs := make([]*schema.Document, 0)
|
|
|
|
|
|
for i, d := range docs {
|
|
|
|
|
|
// 过滤distance过大的垃圾结果(比如distance>0.8的直接丢弃)
|
|
|
|
|
|
if gconv.Float64(docs[i].MetaData["distance"]) < 0.8 {
|
|
|
|
|
|
validDocs = append(validDocs, d)
|
|
|
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// 如果没有有效结果,返回空,让LLM回答「暂无相关信息」
|
|
|
|
|
|
if len(validDocs) == 0 {
|
|
|
|
|
|
callbacks.OnEnd(ctx, &retriever.CallbackOutput{Docs: validDocs})
|
|
|
|
|
|
return validDocs, nil
|
|
|
|
|
|
}
|
|
|
|
|
|
|
2026-04-09 09:11:43 +08:00
|
|
|
|
// 最多保留 topK
|
2026-04-09 13:57:46 +08:00
|
|
|
|
if len(validDocs) > topK {
|
|
|
|
|
|
validDocs = validDocs[:topK]
|
2026-04-09 09:11:43 +08:00
|
|
|
|
}
|
|
|
|
|
|
|
2026-04-09 13:57:46 +08:00
|
|
|
|
callbacks.OnEnd(ctx, &retriever.CallbackOutput{Docs: validDocs})
|
|
|
|
|
|
return validDocs, nil
|
2026-04-03 17:59:05 +08:00
|
|
|
|
}
|
|
|
|
|
|
|
2026-04-09 09:11:43 +08:00
|
|
|
|
// ==========================================
|
|
|
|
|
|
// 1. 向量检索(PG)
|
|
|
|
|
|
// ==========================================
|
|
|
|
|
|
func (r *PGVectorRetriever) doRetrieveVector(ctx context.Context, query string, opts *retriever.Options) ([]*schema.Document, error) {
|
2026-04-03 17:59:05 +08:00
|
|
|
|
vectors, err := opts.Embedding.EmbedStrings(ctx, []string{query})
|
|
|
|
|
|
if err != nil {
|
|
|
|
|
|
return nil, err
|
|
|
|
|
|
}
|
|
|
|
|
|
if len(vectors) == 0 {
|
|
|
|
|
|
return nil, errors.New("empty query vector")
|
|
|
|
|
|
}
|
|
|
|
|
|
|
2026-04-09 09:11:43 +08:00
|
|
|
|
queryVec := pgvector.NewVector(gconv.Float32s(vectors[0]))
|
2026-04-09 13:57:46 +08:00
|
|
|
|
topK := 10
|
|
|
|
|
|
if opts.TopK != nil {
|
|
|
|
|
|
topK = *opts.TopK
|
|
|
|
|
|
}
|
2026-04-09 09:11:43 +08:00
|
|
|
|
datasetIds := gconv.Int64s(opts.DSLInfo["dataset_ids"])
|
2026-04-03 17:59:05 +08:00
|
|
|
|
|
2026-04-09 09:11:43 +08:00
|
|
|
|
rows, err := dao.DocumentChunk.GetAllByVector(ctx, datasetIds, queryVec, topK)
|
2026-04-03 17:59:05 +08:00
|
|
|
|
if err != nil {
|
|
|
|
|
|
return nil, err
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
docs := make([]*schema.Document, 0, len(rows))
|
|
|
|
|
|
for _, row := range rows {
|
|
|
|
|
|
docs = append(docs, &schema.Document{
|
|
|
|
|
|
ID: gconv.String(row["id"]),
|
|
|
|
|
|
Content: gconv.String(row["content"]),
|
2026-04-09 09:11:43 +08:00
|
|
|
|
MetaData: map[string]any{
|
|
|
|
|
|
"dataset_id": gconv.Int64(row["dataset_id"]),
|
|
|
|
|
|
"document_id": gconv.Int64(row["document_id"]),
|
|
|
|
|
|
"distance": gconv.Float64(row["distance"]),
|
|
|
|
|
|
"retrieve_by": "vector",
|
2026-04-03 17:59:05 +08:00
|
|
|
|
},
|
|
|
|
|
|
})
|
|
|
|
|
|
}
|
2026-04-09 09:11:43 +08:00
|
|
|
|
return docs, nil
|
|
|
|
|
|
}
|
2026-04-03 17:59:05 +08:00
|
|
|
|
|
2026-04-09 09:11:43 +08:00
|
|
|
|
// ==========================================
|
|
|
|
|
|
// 2. 全文检索(Meilisearch)🔥 新增
|
|
|
|
|
|
// ==========================================
|
|
|
|
|
|
func (r *PGVectorRetriever) doRetrieveMeilisearch(ctx context.Context, query string, opts *retriever.Options) ([]*schema.Document, error) {
|
|
|
|
|
|
topK := *opts.TopK
|
|
|
|
|
|
datasetIds := gconv.Int64s(opts.DSLInfo["dataset_ids"])
|
|
|
|
|
|
|
|
|
|
|
|
// 调用你已有的 Meilisearch DAO
|
|
|
|
|
|
rows, err := dao.DocumentChunk.SearchByKeywords(ctx, query, datasetIds, topK)
|
|
|
|
|
|
if err != nil {
|
|
|
|
|
|
return nil, err
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
docs := make([]*schema.Document, 0, len(rows))
|
|
|
|
|
|
for _, row := range rows {
|
2026-04-09 13:57:46 +08:00
|
|
|
|
score := gconv.Float64(row["_rankingScore"])
|
|
|
|
|
|
distance := score
|
|
|
|
|
|
|
2026-04-09 09:11:43 +08:00
|
|
|
|
docs = append(docs, &schema.Document{
|
|
|
|
|
|
ID: gconv.String(row["id"]),
|
|
|
|
|
|
Content: gconv.String(row["content"]),
|
|
|
|
|
|
MetaData: map[string]any{
|
|
|
|
|
|
"dataset_id": gconv.Int64(row["dataset_id"]),
|
|
|
|
|
|
"document_id": gconv.Int64(row["document_id"]),
|
2026-04-09 13:57:46 +08:00
|
|
|
|
"distance": distance,
|
2026-04-09 09:11:43 +08:00
|
|
|
|
"retrieve_by": "fulltext",
|
|
|
|
|
|
},
|
|
|
|
|
|
})
|
|
|
|
|
|
}
|
2026-04-03 17:59:05 +08:00
|
|
|
|
return docs, nil
|
|
|
|
|
|
}
|
2026-04-09 09:11:43 +08:00
|
|
|
|
|
|
|
|
|
|
// ==========================================
|
2026-04-09 13:57:46 +08:00
|
|
|
|
// 合并去重(智能版:两路都命中时,保留向量结果 + 全文标记)
|
2026-04-09 09:11:43 +08:00
|
|
|
|
// ==========================================
|
|
|
|
|
|
func mergeAndDeduplicate(vecDocs, fullDocs []*schema.Document) []*schema.Document {
|
|
|
|
|
|
idMap := make(map[string]*schema.Document)
|
2026-04-09 13:57:46 +08:00
|
|
|
|
|
|
|
|
|
|
// 先存入向量结果
|
2026-04-09 09:11:43 +08:00
|
|
|
|
for _, d := range vecDocs {
|
|
|
|
|
|
idMap[d.ID] = d
|
|
|
|
|
|
}
|
2026-04-09 13:57:46 +08:00
|
|
|
|
|
|
|
|
|
|
// 再处理全文:不存在则添加;存在则标记“双路命中”,不覆盖向量分数
|
2026-04-09 09:11:43 +08:00
|
|
|
|
for _, d := range fullDocs {
|
2026-04-09 13:57:46 +08:00
|
|
|
|
if existDoc, ok := idMap[d.ID]; ok {
|
|
|
|
|
|
// 标记同时被向量和全文检索到
|
|
|
|
|
|
existDoc.MetaData["retrieve_by"] = "both"
|
|
|
|
|
|
} else {
|
2026-04-09 09:11:43 +08:00
|
|
|
|
idMap[d.ID] = d
|
|
|
|
|
|
}
|
|
|
|
|
|
}
|
2026-04-09 13:57:46 +08:00
|
|
|
|
|
2026-04-09 09:11:43 +08:00
|
|
|
|
merged := make([]*schema.Document, 0, len(idMap))
|
|
|
|
|
|
for _, d := range idMap {
|
|
|
|
|
|
merged = append(merged, d)
|
|
|
|
|
|
}
|
|
|
|
|
|
return merged
|
|
|
|
|
|
}
|