2026-04-03 11:14:44 +08:00
|
|
|
|
package eino
|
|
|
|
|
|
|
|
|
|
|
|
import (
|
|
|
|
|
|
"context"
|
|
|
|
|
|
"database/sql"
|
|
|
|
|
|
"errors"
|
|
|
|
|
|
"fmt"
|
|
|
|
|
|
"rag/dao"
|
|
|
|
|
|
"rag/model/dto"
|
|
|
|
|
|
"rag/model/entity"
|
|
|
|
|
|
|
|
|
|
|
|
"gitea.com/red-future/common/beans"
|
|
|
|
|
|
"github.com/cloudwego/eino/callbacks"
|
|
|
|
|
|
"github.com/cloudwego/eino/components/indexer"
|
|
|
|
|
|
"github.com/cloudwego/eino/schema"
|
|
|
|
|
|
"github.com/gogf/gf/v2/os/glog"
|
|
|
|
|
|
"github.com/gogf/gf/v2/util/gconv"
|
|
|
|
|
|
"github.com/pgvector/pgvector-go"
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
// PGVectorIndexerOptions configures a PGVectorIndexer.
type PGVectorIndexerOptions struct {
	// BatchSize is the number of documents processed per batch.
	BatchSize int
}
|
|
|
|
|
|
|
|
|
|
|
|
type PGVectorIndexer struct {
|
|
|
|
|
|
opts *PGVectorIndexerOptions
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
func NewPGVectorIndexer(opts *PGVectorIndexerOptions) *PGVectorIndexer {
|
|
|
|
|
|
// 默认值
|
|
|
|
|
|
if opts.BatchSize <= 0 {
|
|
|
|
|
|
opts.BatchSize = 5
|
|
|
|
|
|
}
|
|
|
|
|
|
return &PGVectorIndexer{opts: opts}
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
func (i *PGVectorIndexer) Store(ctx context.Context, docs []*schema.Document, opts ...indexer.Option) (rows int64, err error) {
|
|
|
|
|
|
commonOpts := indexer.GetCommonOptions(&indexer.Options{}, opts...)
|
|
|
|
|
|
|
|
|
|
|
|
if commonOpts.Embedding == nil {
|
|
|
|
|
|
return 0, errors.New("embedding model not set")
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// 回调
|
|
|
|
|
|
ctx = callbacks.OnStart(ctx, &indexer.CallbackInput{Docs: docs})
|
|
|
|
|
|
|
2026-04-03 17:59:05 +08:00
|
|
|
|
rows, err = i.bulkStore(ctx, docs, commonOpts)
|
2026-04-03 11:14:44 +08:00
|
|
|
|
if err != nil {
|
|
|
|
|
|
callbacks.OnError(ctx, err)
|
|
|
|
|
|
return
|
|
|
|
|
|
}
|
|
|
|
|
|
|
2026-04-03 17:59:05 +08:00
|
|
|
|
callbacks.OnEnd(ctx, &indexer.CallbackOutput{IDs: gconv.Strings(rows)})
|
2026-04-03 11:14:44 +08:00
|
|
|
|
return
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
func (i *PGVectorIndexer) bulkStore(ctx context.Context, docs []*schema.Document, opts *indexer.Options) (rows int64, err error) {
|
|
|
|
|
|
var batchDocs []*schema.Document
|
|
|
|
|
|
|
|
|
|
|
|
// 官方ES同款逻辑:满 BatchSize 就处理一批
|
|
|
|
|
|
for _, doc := range docs {
|
|
|
|
|
|
batchDocs = append(batchDocs, doc)
|
|
|
|
|
|
|
|
|
|
|
|
// 满了 → 处理
|
|
|
|
|
|
if len(batchDocs) >= i.opts.BatchSize {
|
|
|
|
|
|
var r int64
|
|
|
|
|
|
r, err = i.doStore(ctx, batchDocs, opts)
|
|
|
|
|
|
if err != nil {
|
|
|
|
|
|
return
|
|
|
|
|
|
}
|
|
|
|
|
|
rows = rows + r
|
|
|
|
|
|
batchDocs = nil
|
|
|
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// 最后一批
|
|
|
|
|
|
if len(batchDocs) > 0 {
|
|
|
|
|
|
var r int64
|
|
|
|
|
|
r, err = i.doStore(ctx, batchDocs, opts)
|
|
|
|
|
|
if err != nil {
|
|
|
|
|
|
return
|
|
|
|
|
|
}
|
|
|
|
|
|
rows = rows + r
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
return
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
func (i *PGVectorIndexer) doStore(ctx context.Context, docs []*schema.Document, opts *indexer.Options) (rows int64, err error) {
|
|
|
|
|
|
|
|
|
|
|
|
texts := make([]string, len(docs))
|
|
|
|
|
|
for i, d := range docs {
|
|
|
|
|
|
texts[i] = d.Content
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// 向量化(官方ES也没有重试!)
|
|
|
|
|
|
vectors, err := opts.Embedding.EmbedStrings(ctx, texts)
|
|
|
|
|
|
if err != nil {
|
|
|
|
|
|
return
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// 转成业务实体
|
2026-04-10 13:12:19 +08:00
|
|
|
|
var chunks []*dto.VectorDocumentVectorMsg
|
2026-04-03 11:14:44 +08:00
|
|
|
|
for idx, doc := range docs {
|
2026-04-10 13:12:19 +08:00
|
|
|
|
ck := new(dto.VectorDocumentVectorMsg)
|
2026-04-03 11:14:44 +08:00
|
|
|
|
err = gconv.Struct(doc.MetaData, ck)
|
|
|
|
|
|
if err != nil {
|
|
|
|
|
|
glog.Errorf(ctx, "doStore err: %v", err)
|
|
|
|
|
|
continue
|
|
|
|
|
|
}
|
|
|
|
|
|
ck.Content = doc.Content
|
|
|
|
|
|
ck.Vector = pgvector.NewVector(gconv.Float32s(vectors[idx]))
|
|
|
|
|
|
ck.VectorStatus = gconv.PtrInt8(1)
|
|
|
|
|
|
ck.Status = gconv.PtrInt8(1)
|
|
|
|
|
|
chunks = append(chunks, ck)
|
|
|
|
|
|
}
|
|
|
|
|
|
if len(chunks) == 0 {
|
|
|
|
|
|
return
|
|
|
|
|
|
}
|
|
|
|
|
|
ctx = context.WithValue(ctx, "user", &beans.User{
|
|
|
|
|
|
TenantId: chunks[0].TenantId,
|
|
|
|
|
|
UserName: chunks[0].Creator,
|
|
|
|
|
|
})
|
|
|
|
|
|
// 创建索引
|
|
|
|
|
|
if err = i.createOrUpdateDatasetIndex(ctx, chunks[0].DatasetId, len(vectors[0]), int64(len(chunks))); err != nil {
|
|
|
|
|
|
return
|
|
|
|
|
|
}
|
|
|
|
|
|
// 入库
|
2026-04-10 13:12:19 +08:00
|
|
|
|
rows, err = dao.DocumentVector.BatchInsert(ctx, chunks)
|
2026-04-03 11:14:44 +08:00
|
|
|
|
|
|
|
|
|
|
return
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
func (i *PGVectorIndexer) createOrUpdateDatasetIndex(ctx context.Context, datasetId int64, dimension int, vectorCount int64) error {
|
|
|
|
|
|
exist, err := dao.DatasetIndex.GetByDatasetId(ctx, datasetId)
|
|
|
|
|
|
if err != nil && !errors.Is(err, sql.ErrNoRows) {
|
|
|
|
|
|
return err
|
|
|
|
|
|
}
|
|
|
|
|
|
if exist != nil {
|
|
|
|
|
|
_ = dao.DatasetIndex.IncVectorCount(ctx, exist.Id, vectorCount)
|
|
|
|
|
|
return nil
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
indexName := fmt.Sprintf("idx_dataset_%d_vector", datasetId)
|
|
|
|
|
|
idx := &entity.DatasetIndex{
|
|
|
|
|
|
DatasetId: datasetId,
|
|
|
|
|
|
Name: indexName,
|
|
|
|
|
|
Dimension: dimension,
|
|
|
|
|
|
FieldType: "float",
|
|
|
|
|
|
MetricType: "COSINE",
|
|
|
|
|
|
Status: gconv.PtrInt8(1),
|
|
|
|
|
|
VectorCount: vectorCount,
|
|
|
|
|
|
Description: fmt.Sprintf("数据集%d向量索引", datasetId),
|
|
|
|
|
|
}
|
|
|
|
|
|
_, err = dao.DatasetIndex.Insert(ctx, idx)
|
|
|
|
|
|
if err != nil {
|
|
|
|
|
|
return err
|
|
|
|
|
|
}
|
|
|
|
|
|
return i.createRealPGVectorIndex(ctx, indexName)
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
func (i *PGVectorIndexer) createRealPGVectorIndex(ctx context.Context, indexName string) error {
|
|
|
|
|
|
if err := dao.DatasetIndex.InsertIndex(ctx, indexName); err != nil {
|
|
|
|
|
|
glog.Errorf(ctx, "create vector index failed: %v", err)
|
|
|
|
|
|
return err
|
|
|
|
|
|
}
|
|
|
|
|
|
glog.Infof(ctx, "created pgvector index: %s", indexName)
|
|
|
|
|
|
return nil
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
func (i *PGVectorIndexer) GetType() string {
|
|
|
|
|
|
return "pgvector_indexer"
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
func (i *PGVectorIndexer) IsCallbacksEnabled() bool {
|
|
|
|
|
|
return true
|
|
|
|
|
|
}
|