package dao import ( "context" "model-gateway/consts/public" "model-gateway/model/entity" "gitea.redpowerfuture.com/red-future/common/db/gfdb" "github.com/gogf/gf/v2/database/gdb" "github.com/gogf/gf/v2/os/gtime" "github.com/gogf/gf/v2/util/gconv" ) var Task = &taskDao{} type taskDao struct{} // Insert 插入 func (d *taskDao) Insert(ctx context.Context, req *entity.AsynchTask) (id int64, err error) { m := new(entity.AsynchTask) err = gconv.Struct(req, &m) if err != nil { return } r, err := gfdb.DB(ctx).Model(ctx, public.TableNameTask). Insert(m) if err != nil { return } return r.LastInsertId() } // Get 获取 func (d *taskDao) Get(ctx context.Context, req *entity.AsynchTask, fields ...string) (m *entity.AsynchTask, err error) { r, err := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameTask). OmitEmpty(). Where(entity.AsynchTaskCol.TaskID, req.TaskID). Fields(fields).One() if err != nil { return } err = r.Struct(&m) return } // ListByTaskIDs 批量查询任务(会受 gfdb 的租户 Hook 影响,只返回当前租户数据) func (d *taskDao) ListByTaskIDs(ctx context.Context, taskIDs []string) (m []*entity.AsynchTask, err error) { if len(taskIDs) == 0 { return nil, nil } r, err := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameTask). OmitEmpty(). WhereIn(entity.AsynchTaskCol.TaskID, taskIDs). All() if err != nil { return nil, err } err = r.Structs(&m) return } // MarkDownloadedByID 将成功任务标记为已下载(state=4),并写入过期时间 func (d *taskDao) MarkDownloadedByID(ctx context.Context, id int64, expireAt *gtime.Time) error { data := gdb.Map{ entity.AsynchTaskCol.State: 4, entity.AsynchTaskCol.ExpireAt: expireAt, entity.AsynchTaskCol.Updater: "", } _, err := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameTask). Where(entity.AsynchTaskCol.Id, id). Where(entity.AsynchTaskCol.State, 2). Data(data). Update() return err } // List 任务分页查询(受 gfdb 租户 Hook 影响) func (d *taskDao) List(ctx context.Context, pageNum, pageSize int, modelNameLike, taskIDLike string, state *int) (list []*entity.AsynchTask, total int64, err error) { m := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameTask).Where("deleted_at IS NULL") if modelNameLike != "" { m = m.WhereLike(entity.AsynchTaskCol.ModelName, "%"+modelNameLike+"%") } if taskIDLike != "" { m = m.WhereLike(entity.AsynchTaskCol.TaskID, "%"+taskIDLike+"%") } if state != nil { m = m.Where(entity.AsynchTaskCol.State, *state) } m = m.OrderDesc(entity.AsynchTaskCol.CreatedAt) if pageNum > 0 && pageSize > 0 { m = m.Page(pageNum, pageSize) } r, totalInt, err := m.AllAndCount(false) if err != nil { return nil, 0, err } total = gconv.Int64(totalInt) err = r.Structs(&list) return } // GetPendingAsyncTasks 获取进行中的异步任务 func (d *taskDao) GetPendingAsyncTasks(ctx context.Context, limit int) ([]*entity.AsynchTask, error) { var tasks []*entity.AsynchTask err := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameTask). Where("state", 1). Where("deleted_at IS NULL"). Limit(limit). Scan(&tasks) return tasks, err }