package eino

import (
	"context"
	"errors"
	"sort"

	"github.com/cloudwego/eino/callbacks"
	"github.com/cloudwego/eino/components/embedding"
	"github.com/cloudwego/eino/components/retriever"
	"github.com/cloudwego/eino/schema"
	"github.com/gogf/gf/v2/util/gconv"
	"github.com/pgvector/pgvector-go"

	"rag/dao"
)

// PGVectorRetrieverConfig holds the settings used to build a PGVectorRetriever.
type PGVectorRetrieverConfig struct {
	// Embedder turns the query text into a vector. It is required.
	Embedder embedding.Embedder
	// DefaultTopK is the number of documents returned when the caller does not
	// override it per call. Non-positive values are replaced by 5.
	DefaultTopK int
	// DefaultIndex is passed through as the default Index retriever option.
	DefaultIndex string
	// DSLInfo is passed through as the default DSLInfo retriever option. The
	// retriever reads the "dataset_ids" key from it to scope both searches.
	DSLInfo map[string]any
}

// PGVectorRetriever retrieves document chunks by combining a vector search
// with a full-text search and merging the two result sets.
type PGVectorRetriever struct {
	embedder embedding.Embedder
	topK     int
	index    string
	dslInfo  map[string]any
}

// NewPGVectorRetriever builds a PGVectorRetriever from config.
//
// It returns an error when config is nil or when no embedder is configured.
// The config passed in is never modified.
func NewPGVectorRetriever(config *PGVectorRetrieverConfig) (*PGVectorRetriever, error) {
	if config == nil {
		return nil, errors.New("config is required")
	}
	if config.Embedder == nil {
		return nil, errors.New("embedder is required")
	}

	// Apply the default on a local copy so the caller's config stays untouched.
	topK := config.DefaultTopK
	if topK <= 0 {
		topK = 5
	}

	return &PGVectorRetriever{
		embedder: config.Embedder,
		topK:     topK,
		index:    config.DefaultIndex,
		dslInfo:  config.DSLInfo,
	}, nil
}

// Retrieve runs the vector search and the full-text search for query, merges
// the results (dropping duplicate IDs), orders them by ascending "distance"
// metadata and returns at most TopK documents.
//
// Per-call opts override the retriever defaults. A non-positive TopK falls
// back to the retriever default. Errors from either search are reported via
// callbacks.OnError and returned unchanged.
func (r *PGVectorRetriever) Retrieve(ctx context.Context, query string, opts ...retriever.Option) ([]*schema.Document, error) {
	options := retriever.GetCommonOptions(&retriever.Options{
		Index:     &r.index,
		TopK:      &r.topK,
		DSLInfo:   r.dslInfo,
		Embedding: r.embedder,
	}, opts...)

	// Guard against a missing or non-positive TopK: a negative value would
	// otherwise panic when the result slice is truncated below.
	if options.TopK == nil || *options.TopK <= 0 {
		fallback := r.topK
		options.TopK = &fallback
	}
	topK := *options.TopK

	ctx = callbacks.OnStart(ctx, &retriever.CallbackInput{
		Query: query,
		TopK:  topK,
	})

	// Two-way retrieval: vector search plus full-text search.
	docsVector, err := r.doRetrieveVector(ctx, query, options)
	if err != nil {
		callbacks.OnError(ctx, err)
		return nil, err
	}

	docsFulltext, err := r.doRetrieveMeilisearch(ctx, query, options)
	if err != nil {
		callbacks.OnError(ctx, err)
		return nil, err
	}

	// Merge and de-duplicate; vector hits take precedence on ID collisions.
	docs := mergeAndDeduplicate(docsVector, docsFulltext)

	// Smaller distance ranks first. The sort is stable so that documents with
	// equal distance (all full-text hits share one value) keep their merge
	// order and the result is deterministic.
	sort.SliceStable(docs, func(i, j int) bool {
		return gconv.Float64(docs[i].MetaData["distance"]) < gconv.Float64(docs[j].MetaData["distance"])
	})

	// Keep at most topK documents.
	if len(docs) > topK {
		docs = docs[:topK]
	}

	callbacks.OnEnd(ctx, &retriever.CallbackOutput{Docs: docs})
	return docs, nil
}

// doRetrieveVector embeds query and looks up the nearest chunks through
// dao.DocumentChunk.GetAllByVector, limited to the "dataset_ids" found in
// opts.DSLInfo. Each returned document is tagged with retrieve_by = "vector"
// and carries the distance reported by the DAO.
func (r *PGVectorRetriever) doRetrieveVector(ctx context.Context, query string, opts *retriever.Options) ([]*schema.Document, error) {
	if opts.Embedding == nil {
		return nil, errors.New("embedder is required")
	}

	vectors, err := opts.Embedding.EmbedStrings(ctx, []string{query})
	if err != nil {
		return nil, err
	}
	if len(vectors) == 0 {
		return nil, errors.New("empty query vector")
	}

	queryVec := pgvector.NewVector(gconv.Float32s(vectors[0]))
	topK := *opts.TopK
	datasetIds := gconv.Int64s(opts.DSLInfo["dataset_ids"])

	rows, err := dao.DocumentChunk.GetAllByVector(ctx, datasetIds, queryVec, topK)
	if err != nil {
		return nil, err
	}

	docs := make([]*schema.Document, 0, len(rows))
	for _, row := range rows {
		docs = append(docs, &schema.Document{
			ID:      gconv.String(row["id"]),
			Content: gconv.String(row["content"]),
			MetaData: map[string]any{
				"dataset_id":  gconv.Int64(row["dataset_id"]),
				"document_id": gconv.Int64(row["document_id"]),
				"distance":    gconv.Float64(row["distance"]),
				"retrieve_by": "vector",
			},
		})
	}
	return docs, nil
}

// doRetrieveMeilisearch runs the keyword search through
// dao.DocumentChunk.SearchByKeywords, limited to the "dataset_ids" found in
// opts.DSLInfo. Each returned document is tagged with
// retrieve_by = "fulltext" and given a fixed distance of 0.1.
func (r *PGVectorRetriever) doRetrieveMeilisearch(ctx context.Context, query string, opts *retriever.Options) ([]*schema.Document, error) {
	topK := *opts.TopK
	datasetIds := gconv.Int64s(opts.DSLInfo["dataset_ids"])

	rows, err := dao.DocumentChunk.SearchByKeywords(ctx, query, datasetIds, topK)
	if err != nil {
		return nil, err
	}

	docs := make([]*schema.Document, 0, len(rows))
	for _, row := range rows {
		docs = append(docs, &schema.Document{
			ID:      gconv.String(row["id"]),
			Content: gconv.String(row["content"]),
			MetaData: map[string]any{
				"dataset_id":  gconv.Int64(row["dataset_id"]),
				"document_id": gconv.Int64(row["document_id"]),
				// Full-text hits get a fixed small distance so they rank high
				// once merged with the vector hits.
				"distance":    0.1,
				"retrieve_by": "fulltext",
			},
		})
	}
	return docs, nil
}

// mergeAndDeduplicate concatenates vecDocs and fullDocs, keeping only the
// first document seen for each ID. Vector results are visited first, so they
// win when both searches return the same ID. Input order is preserved, which
// keeps the output deterministic. Nil entries are skipped.
func mergeAndDeduplicate(vecDocs, fullDocs []*schema.Document) []*schema.Document {
	total := len(vecDocs) + len(fullDocs)
	seen := make(map[string]struct{}, total)
	merged := make([]*schema.Document, 0, total)

	for _, group := range [][]*schema.Document{vecDocs, fullDocs} {
		for _, d := range group {
			if d == nil {
				continue
			}
			if _, dup := seen[d.ID]; dup {
				continue
			}
			seen[d.ID] = struct{}{}
			merged = append(merged, d)
		}
	}
	return merged
}