Files
rag/common/eino/retriever.go

180 lines
4.7 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package eino
import (
"context"
"errors"
"rag/dao"
"sort"
"github.com/cloudwego/eino/callbacks"
"github.com/cloudwego/eino/components/embedding"
"github.com/cloudwego/eino/components/retriever"
"github.com/cloudwego/eino/schema"
"github.com/gogf/gf/v2/util/gconv"
"github.com/pgvector/pgvector-go"
)
// PGVectorRetrieverConfig holds the construction-time settings for
// NewPGVectorRetriever. The values become the retriever's defaults, which
// per-call retriever options may override.
type PGVectorRetrieverConfig struct {
// Embedder converts the query text into a vector for the pgvector search.
// Required: NewPGVectorRetriever returns an error when it is nil.
Embedder embedding.Embedder
// DefaultTopK is the maximum number of documents returned when no TopK
// option is given. NewPGVectorRetriever replaces values <= 0 with 5.
DefaultTopK int
// DefaultIndex seeds the Index retriever option.
// NOTE(review): nothing in this file reads the resolved Index — confirm it is still needed.
DefaultIndex string
// DSLInfo seeds the DSLInfo retriever option. Its "dataset_ids" entry is
// converted with gconv.Int64s and passed to both DAO searches (presumably as
// a dataset filter — verify against the DAO).
DSLInfo map[string]any
}
// PGVectorRetriever is a hybrid retriever. It combines a pgvector similarity
// search with a keyword (Meilisearch) search and returns the merged results.
// Construct it with NewPGVectorRetriever. Its fields hold the defaults that
// per-call retriever options can override.
type PGVectorRetriever struct {
// embedder is the default Embedding option used to vectorize the query.
embedder embedding.Embedder
// topK is the default maximum number of documents returned.
topK int
// index is the default Index option.
index string
// dslInfo is the default DSLInfo option (read for its "dataset_ids" key).
dslInfo map[string]any
}
// NewPGVectorRetriever builds a PGVectorRetriever from config.
//
// config.Embedder is required. A non-positive config.DefaultTopK is replaced
// with 5 in the returned retriever. The caller's config is never modified.
//
// It returns an error if config is nil or config.Embedder is nil.
func NewPGVectorRetriever(config *PGVectorRetrieverConfig) (*PGVectorRetriever, error) {
	const defaultTopK = 5

	if config == nil {
		return nil, errors.New("config is required")
	}
	if config.Embedder == nil {
		return nil, errors.New("embedder is required")
	}

	// Apply the default to a local instead of writing back into config,
	// which belongs to the caller.
	topK := config.DefaultTopK
	if topK <= 0 {
		topK = defaultTopK
	}

	return &PGVectorRetriever{
		embedder: config.Embedder,
		topK:     topK,
		index:    config.DefaultIndex,
		dslInfo:  config.DSLInfo,
	}, nil
}
// Retrieve runs a hybrid search for query and returns at most topK documents.
//
// Two searches run one after the other:
//   - a pgvector similarity search (doRetrieveVector);
//   - a keyword search (doRetrieveMeilisearch).
//
// Their results are merged and de-duplicated by document ID. They are then
// sorted by ascending MetaData["distance"] and truncated to topK.
//
// Per-call options (retriever.WithTopK, WithIndex, WithDSLInfo, WithEmbedding)
// override the defaults stored on r. A TopK option <= 0 is ignored and r's
// default is used instead. A negative limit would otherwise reach the
// backends and make the final truncation panic.
//
// If either search fails, the error is reported through callbacks.OnError and
// returned unchanged.
func (r *PGVectorRetriever) Retrieve(ctx context.Context, query string, opts ...retriever.Option) ([]*schema.Document, error) {
	// Seed the options with copies of the defaults so the resolved options
	// never point at r's own fields.
	defIndex, defTopK := r.index, r.topK
	options := retriever.GetCommonOptions(&retriever.Options{
		Index:     &defIndex,
		TopK:      &defTopK,
		DSLInfo:   r.dslInfo,
		Embedding: r.embedder,
	}, opts...)

	limit := r.topK
	if options.TopK != nil && *options.TopK > 0 {
		limit = *options.TopK
	}
	// Both searches read *options.TopK, so store the validated limit there.
	options.TopK = &limit

	ctx = callbacks.OnStart(ctx, &retriever.CallbackInput{
		Query: query,
		TopK:  limit,
	})

	// Hybrid retrieval: vector similarity plus full-text keywords.
	vectorDocs, err := r.doRetrieveVector(ctx, query, options)
	if err != nil {
		callbacks.OnError(ctx, err)
		return nil, err
	}
	fulltextDocs, err := r.doRetrieveMeilisearch(ctx, query, options)
	if err != nil {
		callbacks.OnError(ctx, err)
		return nil, err
	}

	// Merge the two result sets, keeping one document per ID.
	docs := mergeAndDeduplicate(vectorDocs, fulltextDocs)

	// Smaller distance ranks first. SliceStable keeps ties in the order
	// mergeAndDeduplicate returned them instead of reordering them
	// arbitrarily. For example, all full-text hits share one synthetic
	// distance.
	sort.SliceStable(docs, func(i, j int) bool {
		return gconv.Float64(docs[i].MetaData["distance"]) < gconv.Float64(docs[j].MetaData["distance"])
	})

	// Keep at most limit documents.
	if len(docs) > limit {
		docs = docs[:limit]
	}

	callbacks.OnEnd(ctx, &retriever.CallbackOutput{Docs: docs})
	return docs, nil
}
// ==========================================
// 1. Vector retrieval (PostgreSQL / pgvector)
// ==========================================

// doRetrieveVector embeds query with opts.Embedding. It passes the resulting
// vector to dao.DocumentChunk.GetAllByVector, together with *opts.TopK and
// the dataset IDs taken from opts.DSLInfo["dataset_ids"].
//
// Each row becomes a schema.Document whose MetaData holds dataset_id,
// document_id, the row's distance, and retrieve_by = "vector". Result order
// is whatever the DAO returns.
//
// It returns an error in these cases:
//   - no embedder is configured;
//   - embedding fails;
//   - the embedder yields no vector, or an empty one;
//   - the DAO query fails.
func (r *PGVectorRetriever) doRetrieveVector(ctx context.Context, query string, opts *retriever.Options) ([]*schema.Document, error) {
	// WithEmbedding(nil) or a zero-value retriever would otherwise cause a
	// nil interface call here.
	if opts.Embedding == nil {
		return nil, errors.New("embedder is required")
	}

	vectors, err := opts.Embedding.EmbedStrings(ctx, []string{query})
	if err != nil {
		return nil, err
	}
	// An empty embedding cannot be compared against stored vectors. Fail
	// early with a clear error instead of sending it to the database.
	if len(vectors) == 0 || len(vectors[0]) == 0 {
		return nil, errors.New("empty query vector")
	}

	// pgvector.NewVector takes []float32. Convert directly rather than
	// through reflection-based gconv.
	components := make([]float32, len(vectors[0]))
	for i, v := range vectors[0] {
		components[i] = float32(v)
	}
	queryVec := pgvector.NewVector(components)

	topK := *opts.TopK
	datasetIds := gconv.Int64s(opts.DSLInfo["dataset_ids"])

	rows, err := dao.DocumentChunk.GetAllByVector(ctx, datasetIds, queryVec, topK)
	if err != nil {
		return nil, err
	}

	docs := make([]*schema.Document, 0, len(rows))
	for _, row := range rows {
		docs = append(docs, &schema.Document{
			ID:      gconv.String(row["id"]),
			Content: gconv.String(row["content"]),
			MetaData: map[string]any{
				"dataset_id":  gconv.Int64(row["dataset_id"]),
				"document_id": gconv.Int64(row["document_id"]),
				"distance":    gconv.Float64(row["distance"]),
				"retrieve_by": "vector",
			},
		})
	}
	return docs, nil
}
// ==========================================
// 2. Full-text retrieval (Meilisearch)
// ==========================================

// doRetrieveMeilisearch performs the keyword half of the hybrid search. It
// calls the existing Meilisearch-backed dao.DocumentChunk.SearchByKeywords
// with query, the dataset IDs from opts.DSLInfo["dataset_ids"], and the
// *opts.TopK limit.
//
// Keyword hits carry no distance of their own. Every document is therefore
// tagged with the same fixed distance (0.1), plus retrieve_by = "fulltext".
// The value is deliberately small, so full-text hits count as strong matches
// when results are sorted by ascending distance.
func (r *PGVectorRetriever) doRetrieveMeilisearch(ctx context.Context, query string, opts *retriever.Options) ([]*schema.Document, error) {
	// Synthetic distance shared by all keyword hits ("give full-text results
	// a high score").
	const fulltextDistance = 0.1

	hits, err := dao.DocumentChunk.SearchByKeywords(
		ctx,
		query,
		gconv.Int64s(opts.DSLInfo["dataset_ids"]),
		*opts.TopK,
	)
	if err != nil {
		return nil, err
	}

	result := make([]*schema.Document, 0, len(hits))
	for _, hit := range hits {
		meta := map[string]any{
			"dataset_id":  gconv.Int64(hit["dataset_id"]),
			"document_id": gconv.Int64(hit["document_id"]),
			"distance":    fulltextDistance,
			"retrieve_by": "fulltext",
		}
		result = append(result, &schema.Document{
			ID:       gconv.String(hit["id"]),
			Content:  gconv.String(hit["content"]),
			MetaData: meta,
		})
	}
	return result, nil
}
// ==========================================
// Merge and de-duplicate
// ==========================================

// mergeAndDeduplicate combines vector and full-text results into one slice,
// keeping at most one document per ID.
//
// Output order is deterministic:
//   - vecDocs, in their original order;
//   - then each fullDocs entry whose ID was not already seen, also in
//     original order.
//
// When an ID occurs more than once, the first occurrence wins. A vector hit
// therefore always beats a full-text hit with the same ID, keeping the
// vector hit's real distance.
func mergeAndDeduplicate(vecDocs, fullDocs []*schema.Document) []*schema.Document {
	total := len(vecDocs) + len(fullDocs)
	seen := make(map[string]struct{}, total)
	merged := make([]*schema.Document, 0, total)

	// Walk a slice instead of ranging over a map: map iteration order is
	// random in Go, which would make the final ranking unstable.
	for _, group := range [][]*schema.Document{vecDocs, fullDocs} {
		for _, d := range group {
			if _, dup := seen[d.ID]; dup {
				continue
			}
			seen[d.ID] = struct{}{}
			merged = append(merged, d)
		}
	}
	return merged
}