package dao

import (
	"context"
	"fmt"

	"rag/consts/public"
	"rag/model/dto"
	"rag/model/entity"

	"gitea.com/red-future/common/db/gfdb"
	"gitea.com/red-future/common/full-text-search/meilisearch"
	"github.com/gogf/gf/v2/database/gdb"
	"github.com/gogf/gf/v2/frame/g"
	"github.com/gogf/gf/v2/text/gstr"
	"github.com/gogf/gf/v2/util/gconv"
	"github.com/pgvector/pgvector-go"
)

// DocumentVector is the package-level accessor for document chunk persistence
// and retrieval.
var DocumentVector = new(documentVectorDao)

// documentVectorDao groups the data-access methods for document chunks. It is
// stateless; every method obtains its database handle from the context.
type documentVectorDao struct{}

// BatchInsert inserts a batch of document chunks and returns the number of
// affected rows.
//
// The incoming messages are converted to entity.DocumentVector records before
// being written to the document vector table.
func (d *documentVectorDao) BatchInsert(ctx context.Context, req []*dto.VectorDocumentVectorMsg) (rows int64, err error) {
	// Skip the conversion and the database round trip when there is nothing
	// to insert.
	if len(req) == 0 {
		return 0, nil
	}

	var records []*entity.DocumentVector
	if err = gconv.Structs(req, &records); err != nil {
		return 0, err
	}

	result, err := gfdb.DB(ctx, public.DbNameVector).
		Model(ctx, public.TableNameDocumentVector).
		Data(&records).
		Insert()
	if err != nil {
		return 0, err
	}
	return result.RowsAffected()
}

// Update updates the document chunk identified by req.Id and returns the
// number of affected rows. Empty fields of req are omitted from the statement
// (OmitEmpty).
func (d *documentVectorDao) Update(ctx context.Context, req *dto.UpdateDocumentVectorReq) (rows int64, err error) {
	// req is already a pointer to the request struct, so it is passed as-is
	// rather than as a pointer to a pointer.
	result, err := gfdb.DB(ctx, public.DbNameVector).
		Model(ctx, public.TableNameDocumentVector).
		OmitEmpty().
		Data(req).
		Where(entity.DocumentVectorCol.Id, req.Id).
		Update()
	if err != nil {
		return 0, err
	}
	return result.RowsAffected()
}

// List returns the document chunks matching req together with the total
// number of matching rows.
//
// Filters on dataset id, document id, status and vector status are applied
// with OmitEmpty set on the model; a non-empty req.Keyword adds a LIKE filter
// on the content column. Results are ordered by creation time, newest first,
// and paginated when req.Page is set. fields optionally restricts the selected
// columns.
func (d *documentVectorDao) List(ctx context.Context, req *dto.ListDocumentVectorReq, fields ...string) (res []*entity.DocumentVector, total int, err error) {
	model := gfdb.DB(ctx, public.DbNameVector).
		Model(ctx, public.TableNameDocumentVector).
		Fields(fields).
		OmitEmpty().
		Where(entity.DocumentVectorCol.DatasetId, req.DatasetId).
		Where(entity.DocumentVectorCol.DocumentId, req.DocumentId).
		Where(entity.DocumentVectorCol.Status, req.Status).
		Where(entity.DocumentVectorCol.VectorStatus, req.VectorStatus)

	// Each builder call's result is assigned back so the conditions are kept
	// whether the model mutates in place or returns a modified copy.
	if !g.IsEmpty(req.Keyword) {
		model = model.WhereLike(entity.DocumentVectorCol.Content, "%"+req.Keyword+"%")
	}
	model = model.OrderDesc(entity.DocumentVectorCol.CreatedAt)
	if req.Page != nil {
		model = model.Page(int(req.Page.PageNum), int(req.Page.PageSize))
	}

	records, total, err := model.AllAndCount(false)
	if err != nil {
		return nil, total, err
	}
	err = records.Structs(&res)
	return res, total, err
}

// GetAllByVector returns up to topK document chunks from the given datasets,
// ordered by ascending distance between the stored vector and the supplied
// query vector (the `<=>` operator). Rows without a stored vector are skipped.
//
// Each returned row carries id, content, dataset_id, document_id and the
// computed distance.
func (d *documentVectorDao) GetAllByVector(ctx context.Context, datasetIds []int64, vector pgvector.Vector, topK int) (list gdb.List, err error) {
	// NOTE(review): the table name is hard-coded here, whereas the other
	// methods use public.TableNameDocumentVector — confirm they stay in sync.
	const query = `
		SELECT id, content, dataset_id, document_id, vector <=> ? AS distance
		FROM rag_vector_document_vector
		WHERE dataset_id IN (?) AND vector IS NOT NULL
		ORDER BY distance ASC
		LIMIT ?
	`

	// Placeholder order: vector, dataset ids, topK.
	result, err := gfdb.DB(ctx, public.DbNameVector).GetAll(ctx, query, vector, datasetIds, topK)
	if err != nil {
		return nil, err
	}
	return result.List(), nil
}

// SearchByKeywords runs a full-text search for query over the document chunk
// index and returns up to topK hits, with ranking scores requested.
//
// When datasetIds is non-empty the search is restricted to those datasets;
// when it is empty no dataset filter is applied.
func (d *documentVectorDao) SearchByKeywords(ctx context.Context, query string, datasetIds []int64, topK int) (list gdb.List, err error) {
	searchParams := &meilisearch.SearchParams{
		Query:            query,
		Limit:            int64(topK),
		ShowRankingScore: true,
	}

	// Restrict the search to the requested datasets. The ids are numeric, so
	// they are joined as-is without quoting.
	if len(datasetIds) > 0 {
		ids := gstr.Implode(", ", gconv.Strings(datasetIds))
		searchParams.Filter = fmt.Sprintf("dataset_id IN [%s]", ids)
	}

	var hits []map[string]interface{}
	if _, err = meilisearch.DB().Search(ctx, searchParams, public.IndexNameDocumentChunk, &hits); err != nil {
		return nil, err
	}

	// Copy the hits into a gdb.List; the result is non-nil even when there
	// are no hits.
	resultList := make(gdb.List, 0, len(hits))
	for _, hit := range hits {
		resultList = append(resultList, hit)
	}
	return resultList, nil
}