package dao import ( "context" "rag/consts/public" "rag/model/dto" "rag/model/entity" "gitea.com/red-future/common/db/gfdb" "github.com/gogf/gf/v2/util/gconv" ) var DocumentChunk = new(documentChunkDao) type documentChunkDao struct{} // BatchInsert 批量插入文件块 func (d *documentChunkDao) BatchInsert(ctx context.Context, req []*dto.VectorDocumentChunkMsg) (rows int64, err error) { var res []*entity.DocumentChunk if err = gconv.Structs(req, &res); err != nil { return } r, err := gfdb.DB(ctx).Model(ctx, public.TableNameDocumentChunk).Data(&res).Insert() if err != nil { return } return r.RowsAffected() } // Update 更新文件块 func (d *documentChunkDao) Update(ctx context.Context, req *dto.UpdateDocumentChunkReq) (rows int64, err error) { model := gfdb.DB(ctx).Model(ctx, public.TableNameDocumentChunk) r, err := model.Data(&req).Where(entity.DocumentChunkCol.Id, req.Id).Update() if err != nil { return } return r.RowsAffected() } // List 文件块列表 func (d *documentChunkDao) List(ctx context.Context, req *dto.ListDocumentChunkReq, fields ...string) (res []*entity.DocumentChunk, total int, err error) { model := gfdb.DB(ctx).Model(ctx, public.TableNameDocumentChunk).Fields(fields).OmitEmpty(). Where(entity.DocumentChunkCol.DatasetId, req.DatasetId). Where(entity.DocumentChunkCol.DocumentId, req.DocumentId). Where(entity.DocumentChunkCol.Status, req.Status). Where(entity.DocumentChunkCol.VectorStatus, req.VectorStatus). OrderDesc(entity.DocumentChunkCol.CreatedAt) if req.Page != nil { model.Page(int(req.Page.PageNum), int(req.Page.PageSize)) } r, total, err := model.AllAndCount(false) if err != nil { return } err = r.Structs(&res) return } //// Insert 插入向量文档 //func (d *vectorDocumentDao) Insert(ctx context.Context, docs []*entity.DocumentChunk) (ids []interface{}, err error) { // if len(docs) == 0 { // return // } // interfaces := make([]interface{}, len(docs)) // for i := range docs { // interfaces[i] = docs[i] // } // return mongoDB.Insert(ctx, interfaces, CollectionVectorDoc) //} // //// DeleteByIDs 根据ID删除向量文档 //func (d *vectorDocumentDao) DeleteByIDs(ctx context.Context, ids []string) (err error) { // if len(ids) == 0 { // return // } // objectIDs := make([]bson.ObjectID, len(ids)) // for i, id := range ids { // objectIDs[i], err = bson.ObjectIDFromHex(id) // if err != nil { // return err // } // } // filter := bson.M{"_id": bson.M{"$in": objectIDs}} // _, err = mongoDB.Delete(ctx, filter, CollectionVectorDoc) // return //} // //// GetByIndexID 根据索引ID获取向量文档 //func (d *vectorDocumentDao) GetByIndexID(ctx context.Context, indexID string, limit int) (result []*entity.DocumentChunk, err error) { // filter := bson.M{"indexId": indexID} // page := &beans.Page{PageNum: 1, PageSize: int64(limit)} // _, err = mongoDB.Find(ctx, filter, &result, CollectionVectorDoc, page, nil) // return //} // //// GetByVectorIDs 根据向量ID获取向量文档 //func (d *vectorDocumentDao) GetByVectorIDs(ctx context.Context, vectorIDs []string) (result []*entity.DocumentChunk, err error) { // if len(vectorIDs) == 0 { // return // } // filter := bson.M{"vectorId": bson.M{"$in": vectorIDs}} // _, err = mongoDB.Find(ctx, filter, &result, CollectionVectorDoc, &beans.Page{PageSize: -1}, nil) // return //}