Files
rag/common/eino/indexer.go

178 lines
4.3 KiB
Go
Raw Normal View History

2026-04-03 11:14:44 +08:00
package eino
import (
"context"
"database/sql"
"errors"
"fmt"
"rag/dao"
"rag/model/dto"
"rag/model/entity"
"gitea.com/red-future/common/beans"
"github.com/cloudwego/eino/callbacks"
"github.com/cloudwego/eino/components/indexer"
"github.com/cloudwego/eino/schema"
"github.com/gogf/gf/v2/os/glog"
"github.com/gogf/gf/v2/util/gconv"
"github.com/pgvector/pgvector-go"
)
// PGVectorIndexerOptions configures a PGVectorIndexer.
type PGVectorIndexerOptions struct {
	// BatchSize is the number of documents processed per batch.
	BatchSize int
}

// PGVectorIndexer is an indexer whose behaviour is driven by a
// PGVectorIndexerOptions value.
type PGVectorIndexer struct {
	opts *PGVectorIndexerOptions
}

// NewPGVectorIndexer returns a PGVectorIndexer that uses opts.
//
// A nil opts is treated as an empty options value instead of causing a nil
// pointer dereference. A non-positive BatchSize is replaced by the default of
// 5; the default is written into opts itself because the indexer keeps the
// caller's pointer rather than a copy.
func NewPGVectorIndexer(opts *PGVectorIndexerOptions) *PGVectorIndexer {
	const defaultBatchSize = 5

	if opts == nil {
		opts = &PGVectorIndexerOptions{}
	}
	if opts.BatchSize <= 0 {
		opts.BatchSize = defaultBatchSize
	}
	return &PGVectorIndexer{opts: opts}
}
func (i *PGVectorIndexer) Store(ctx context.Context, docs []*schema.Document, opts ...indexer.Option) (rows int64, err error) {
commonOpts := indexer.GetCommonOptions(&indexer.Options{}, opts...)
if commonOpts.Embedding == nil {
return 0, errors.New("embedding model not set")
}
// 回调
ctx = callbacks.OnStart(ctx, &indexer.CallbackInput{Docs: docs})
rows, err = i.bulkStore(ctx, docs, commonOpts)
2026-04-03 11:14:44 +08:00
if err != nil {
callbacks.OnError(ctx, err)
return
}
callbacks.OnEnd(ctx, &indexer.CallbackOutput{IDs: gconv.Strings(rows)})
2026-04-03 11:14:44 +08:00
return
}
// bulkStore passes docs to doStore in consecutive batches of at most
// i.opts.BatchSize documents and returns the sum of the row counts doStore
// reported.
//
// Processing stops at the first failing batch; rows then holds the total of
// the batches that completed before it. The original author's comment
// describes this "process a batch once it is full" scheme as the same logic
// the official ES indexer uses.
func (i *PGVectorIndexer) bulkStore(ctx context.Context, docs []*schema.Document, opts *indexer.Options) (rows int64, err error) {
	total := len(docs)
	if total == 0 {
		return 0, nil
	}

	// A non-positive batch size degenerates to one document per batch.
	size := i.opts.BatchSize
	if size < 1 {
		size = 1
	}

	for begin := 0; begin < total; begin += size {
		end := begin + size
		if end > total {
			// Trailing partial batch.
			end = total
		}

		stored, storeErr := i.doStore(ctx, docs[begin:end], opts)
		if storeErr != nil {
			return rows, storeErr
		}
		rows += stored
	}
	return rows, nil
}
func (i *PGVectorIndexer) doStore(ctx context.Context, docs []*schema.Document, opts *indexer.Options) (rows int64, err error) {
texts := make([]string, len(docs))
for i, d := range docs {
texts[i] = d.Content
}
// 向量化官方ES也没有重试
vectors, err := opts.Embedding.EmbedStrings(ctx, texts)
if err != nil {
return
}
// 转成业务实体
var chunks []*dto.VectorDocumentVectorMsg
2026-04-03 11:14:44 +08:00
for idx, doc := range docs {
ck := new(dto.VectorDocumentVectorMsg)
2026-04-03 11:14:44 +08:00
err = gconv.Struct(doc.MetaData, ck)
if err != nil {
glog.Errorf(ctx, "doStore err: %v", err)
continue
}
ck.Content = doc.Content
ck.Vector = pgvector.NewVector(gconv.Float32s(vectors[idx]))
ck.VectorStatus = gconv.PtrInt8(1)
ck.Status = gconv.PtrInt8(1)
chunks = append(chunks, ck)
}
if len(chunks) == 0 {
return
}
ctx = context.WithValue(ctx, "user", &beans.User{
TenantId: chunks[0].TenantId,
UserName: chunks[0].Creator,
})
// 创建索引
if err = i.createOrUpdateDatasetIndex(ctx, chunks[0].DatasetId, len(vectors[0]), int64(len(chunks))); err != nil {
return
}
// 入库
rows, err = dao.DocumentVector.BatchInsert(ctx, chunks)
2026-04-03 11:14:44 +08:00
return
}
// createOrUpdateDatasetIndex makes sure a dataset index record exists for
// datasetId.
//
// When a record already exists, its vector count is increased by vectorCount
// and nil is returned; a failure of that increment is logged but does not fail
// the call, so the counter update stays best-effort. When no record exists
// (the lookup returns no record, with either a nil error or sql.ErrNoRows), a
// new record named "idx_dataset_<datasetId>_vector" is inserted with the given
// dimension and vectorCount, and the index itself is then created via
// createRealPGVectorIndex.
//
// It returns any lookup error other than sql.ErrNoRows, the insert error, or
// the error from creating the index.
//
// NOTE(review): if the record insert succeeds but index creation fails, later
// calls find the record and return early without retrying the index creation —
// confirm whether that is acceptable.
func (i *PGVectorIndexer) createOrUpdateDatasetIndex(ctx context.Context, datasetId int64, dimension int, vectorCount int64) error {
	exist, err := dao.DatasetIndex.GetByDatasetId(ctx, datasetId)
	if err != nil && !errors.Is(err, sql.ErrNoRows) {
		return err
	}

	if exist != nil {
		// Best-effort counter update: report the failure instead of silently
		// discarding it, but do not abort the store because of it.
		if incErr := dao.DatasetIndex.IncVectorCount(ctx, exist.Id, vectorCount); incErr != nil {
			glog.Errorf(ctx, "increase vector count of dataset index %v failed: %v", exist.Id, incErr)
		}
		return nil
	}

	indexName := fmt.Sprintf("idx_dataset_%d_vector", datasetId)
	idx := &entity.DatasetIndex{
		DatasetId:   datasetId,
		Name:        indexName,
		Dimension:   dimension,
		FieldType:   "float",
		MetricType:  "COSINE",
		Status:      gconv.PtrInt8(1),
		VectorCount: vectorCount,
		Description: fmt.Sprintf("数据集%d向量索引", datasetId),
	}
	if _, err = dao.DatasetIndex.Insert(ctx, idx); err != nil {
		return err
	}
	return i.createRealPGVectorIndex(ctx, indexName)
}
// createRealPGVectorIndex asks dao.DatasetIndex.InsertIndex to create the
// index called indexName and logs the outcome: an info line on success, an
// error line on failure. The error from InsertIndex is returned unchanged.
func (i *PGVectorIndexer) createRealPGVectorIndex(ctx context.Context, indexName string) error {
	err := dao.DatasetIndex.InsertIndex(ctx, indexName)
	if err == nil {
		glog.Infof(ctx, "created pgvector index: %s", indexName)
		return nil
	}

	glog.Errorf(ctx, "create vector index failed: %v", err)
	return err
}
// GetType returns the fixed type name of this indexer, "pgvector_indexer".
func (i *PGVectorIndexer) GetType() string {
	const typeName = "pgvector_indexer"
	return typeName
}
// IsCallbacksEnabled always reports true for this indexer.
func (i *PGVectorIndexer) IsCallbacksEnabled() bool {
	const enabled = true
	return enabled
}