118 lines
2.6 KiB
Go
118 lines
2.6 KiB
Go
|
|
package eino
|
|||
|
|
|
|||
|
|
import (
	"context"
	"errors"
	"fmt"

	"github.com/cloudwego/eino/callbacks"
	"github.com/cloudwego/eino/components/embedding"
	"github.com/cloudwego/eino/components/retriever"
	"github.com/cloudwego/eino/schema"
	"github.com/gogf/gf/v2/util/gconv"
	"github.com/pgvector/pgvector-go"
)
|
|||
|
|
|
|||
|
|
type PGVectorRetrieverConfig struct {
|
|||
|
|
Embedder embedding.Embedder
|
|||
|
|
DefaultTopK int
|
|||
|
|
DefaultIndex string
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
type PGVectorRetriever struct {
|
|||
|
|
embedder embedding.Embedder
|
|||
|
|
topK int
|
|||
|
|
index string
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// NewPGVectorRetriever builds a PGVectorRetriever from config.
//
// config must be non-nil and must carry an Embedder; otherwise an error is
// returned. A non-positive DefaultTopK is replaced with 5. The caller's
// config is read but never modified.
func NewPGVectorRetriever(config *PGVectorRetrieverConfig) (*PGVectorRetriever, error) {
	if config == nil {
		return nil, errors.New("config is required")
	}
	if config.Embedder == nil {
		return nil, errors.New("embedder is required")
	}

	// Resolve the default locally instead of writing back into config, which
	// the caller may share or reuse.
	topK := config.DefaultTopK
	if topK <= 0 {
		topK = 5
	}

	return &PGVectorRetriever{
		embedder: config.Embedder,
		topK:     topK,
		index:    config.DefaultIndex,
	}, nil
}
|
|||
|
|
|
|||
|
|
// Retrieve returns the document chunks nearest to query.
//
// Common retriever options (retriever.WithTopK, WithIndex, WithEmbedding, …)
// override the defaults captured by NewPGVectorRetriever. The component
// fires its own callbacks: OnStart before the search, then OnEnd with the
// documents or OnError with the failure. IsCallbacksEnabled therefore
// reports true.
func (r *PGVectorRetriever) Retrieve(ctx context.Context, query string, opts ...retriever.Option) ([]*schema.Document, error) {
	// Seed the options from copies so nothing downstream can write through
	// to the retriever's own fields.
	index, topK := r.index, r.topK
	options := retriever.GetCommonOptions(&retriever.Options{
		Index:     &index,
		TopK:      &topK,
		Embedding: r.embedder,
	}, opts...)

	ctx = callbacks.OnStart(ctx, &retriever.CallbackInput{
		Query: query,
		TopK:  *options.TopK,
	})

	docs, err := r.doRetrieve(ctx, query, options)
	if err != nil {
		callbacks.OnError(ctx, err)
		return nil, err
	}

	callbacks.OnEnd(ctx, &retriever.CallbackOutput{Docs: docs})
	return docs, nil
}

// GetType names this retriever implementation for callback RunInfo.
func (r *PGVectorRetriever) GetType() string {
	return "PGVector"
}

// IsCallbacksEnabled reports that Retrieve triggers its own callbacks. Without
// this method, Eino's compose layer would wrap the component with a second
// OnStart/OnEnd pair, and handlers would see every call twice.
func (r *PGVectorRetriever) IsCallbacksEnabled() bool {
	return true
}
|
|||
|
|
|
|||
|
|
func (r *PGVectorRetriever) doRetrieve(ctx context.Context, query string, opts *retriever.Options) ([]*schema.Document, error) {
|
|||
|
|
|
|||
|
|
// 1. 生成向量
|
|||
|
|
vectors, err := opts.Embedding.EmbedStrings(ctx, []string{query})
|
|||
|
|
if err != nil {
|
|||
|
|
return nil, err
|
|||
|
|
}
|
|||
|
|
if len(vectors) == 0 {
|
|||
|
|
return nil, errors.New("empty query vector")
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
queryVec := pgvector.NewVector(vectors[0])
|
|||
|
|
topK := *opts.TopK
|
|||
|
|
|
|||
|
|
// 2. PG 向量相似度检索 SQL
|
|||
|
|
sql := `
|
|||
|
|
SELECT id, content, dataset_id, document_id,
|
|||
|
|
vector <-> ? AS distance
|
|||
|
|
FROM document_chunk
|
|||
|
|
ORDER BY distance ASC
|
|||
|
|
LIMIT ?
|
|||
|
|
`
|
|||
|
|
|
|||
|
|
// 3. 查询
|
|||
|
|
rows, err := dao.DocumentChunk.GetDB().GetAll(ctx, sql, queryVec, topK)
|
|||
|
|
if err != nil {
|
|||
|
|
return nil, err
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// 4. 转为 Eino Document
|
|||
|
|
docs := make([]*schema.Document, 0, len(rows))
|
|||
|
|
for _, row := range rows {
|
|||
|
|
docs = append(docs, &schema.Document{
|
|||
|
|
ID: gconv.String(row["id"]),
|
|||
|
|
Content: gconv.String(row["content"]),
|
|||
|
|
Metadata: map[string]any{
|
|||
|
|
"dataset_id": row["dataset_id"],
|
|||
|
|
"document_id": row["document_id"],
|
|||
|
|
"distance": row["distance"],
|
|||
|
|
},
|
|||
|
|
})
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
return docs, nil
|
|||
|
|
}
|