package dao

import (
	"context"
	"fmt"
	"time"

	"model-gateway/consts/public"
	"model-gateway/model/entity"

	"gitea.redpowerfuture.com/red-future/common/db/gfdb"
	"github.com/gogf/gf/v2/database/gdb"
	"github.com/gogf/gf/v2/os/gtime"
)

// ======================== Query helpers ========================

// taskColumns is the column list selected when claiming tasks.
const taskColumns = `id, tenant_id, creator, model_name, task_id, biz_name, callback_url, model_key, retry_count, input_ref, request_payload, phase, tmp_file`

// ======================== Generic CRUD ========================

// UpdateFields updates the given columns of the task row identified by id and
// returns the number of affected rows. The values are passed as a map, which
// is meant for cases where zero values must be written as well.
func (d *taskDao) UpdateFields(ctx context.Context, id int64, data map[string]any) (int64, error) {
	r, err := gfdb.DB(ctx, public.DbNameModelGateway).
		Model(ctx, public.TableNameTask).
		Data(data).
		Where(entity.AsynchTaskCol.Id, id).
		Update()
	if err != nil {
		return 0, err
	}
	return r.RowsAffected()
}

// execUpdate runs a raw write statement against the model-gateway database and
// returns only the error; the result (affected rows) is discarded.
//
// It forwards the statement unchanged: it does NOT add updated_at by itself,
// so every caller has to set updated_at in its own SQL.
func execUpdate(ctx context.Context, sql string, args ...any) error {
	_, err := gfdb.DB(ctx, public.DbNameModelGateway).Exec(ctx, sql, args...)
	return err
}

// ======================== Transactional claiming ========================

// claimTasks locks at most one pending (state=0, not soft-deleted) task that
// matches the extra condition and marks it as claimed (state=1) inside a
// single transaction.
//
// where is spliced verbatim into the SELECT after the base conditions, so it
// must be a trusted SQL fragment (e.g. "AND task_id = ?"); its placeholders
// are bound from args. The row is selected with FOR UPDATE SKIP LOCKED, so
// rows locked by other transactions are skipped rather than waited on.
//
// It returns an empty result with a nil error when nothing matches. On any
// error, including a failed commit, no task is returned.
func claimTasks(ctx context.Context, where string, args ...any) ([]*entity.AsynchTask, error) {
	var tasks []*entity.AsynchTask
	err := gfdb.DB(ctx, public.DbNameModelGateway).Transaction(ctx, func(_ context.Context, tx gdb.TX) error {
		selectSQL := fmt.Sprintf(
			`SELECT %s FROM %s
			WHERE deleted_at IS NULL AND state = 0 %s
			LIMIT 1
			FOR UPDATE SKIP LOCKED`,
			taskColumns, public.TableNameTask, where,
		)
		r, err := tx.GetOne(selectSQL, args...)
		if err != nil {
			return err
		}
		if r.IsEmpty() {
			return nil
		}

		var task entity.AsynchTask
		if err := r.Struct(&task); err != nil {
			return err
		}

		now := time.Now()
		_, err = tx.Exec(
			fmt.Sprintf(`UPDATE %s SET state=1, started_at=?, updated_at=?
			WHERE id=?`, public.TableNameTask),
			now, now, task.Id,
		)
		if err != nil {
			return err
		}

		tasks = []*entity.AsynchTask{&task}
		return nil
	})
	if err != nil {
		// The transaction did not commit; never hand out a task whose claim
		// was rolled back.
		return nil, err
	}
	return tasks, nil
}

// ClaimPendingGlobal claims up to batchSize pending tasks in one transaction,
// oldest enqueue_at first, and marks each of them as claimed (state=1).
//
// A batchSize <= 0 is treated as 1. Rows locked by other transactions are
// skipped (FOR UPDATE SKIP LOCKED). It returns an empty result with a nil
// error when no pending task is available. On any error, including a failed
// commit, no task is returned.
func (d *taskDao) ClaimPendingGlobal(ctx context.Context, batchSize int) ([]*entity.AsynchTask, error) {
	if batchSize <= 0 {
		batchSize = 1
	}

	var tasks []*entity.AsynchTask
	err := gfdb.DB(ctx, public.DbNameModelGateway).Transaction(ctx, func(_ context.Context, tx gdb.TX) error {
		selectSQL := fmt.Sprintf(
			`SELECT %s FROM %s
			WHERE deleted_at IS NULL AND state = 0
			ORDER BY enqueue_at ASC
			LIMIT %d
			FOR UPDATE SKIP LOCKED`,
			taskColumns, public.TableNameTask, batchSize,
		)
		r, err := tx.GetAll(selectSQL)
		if err != nil {
			return err
		}
		if r.IsEmpty() {
			return nil
		}

		// Scan into a closure-local slice so that a failure further down
		// cannot leak half-claimed tasks to the caller.
		var locked []*entity.AsynchTask
		if err := r.Structs(&locked); err != nil {
			return err
		}

		// The statement text is the same for every row; build it once.
		updateSQL := fmt.Sprintf(`UPDATE %s SET state=1, started_at=?, updated_at=? WHERE id=?`, public.TableNameTask)
		now := time.Now()
		for _, t := range locked {
			if _, err := tx.Exec(updateSQL, now, now, t.Id); err != nil {
				return err
			}
		}

		tasks = locked
		return nil
	})
	if err != nil {
		// The transaction did not commit; the claims were rolled back.
		return nil, err
	}
	return tasks, nil
}

// ClaimPendingByTaskIDGlobal claims the pending task with the given task_id.
// It returns (nil, nil) when taskID is empty or when no pending task matches.
func (d *taskDao) ClaimPendingByTaskIDGlobal(ctx context.Context, taskID string) (*entity.AsynchTask, error) {
	if taskID == "" {
		return nil, nil
	}
	tasks, err := claimTasks(ctx, "AND task_id = ?", taskID)
	if err != nil || len(tasks) == 0 {
		return nil, err
	}
	return tasks[0], nil
}

// ======================== Business updates ========================

// RollbackToPendingGlobal puts a claimed task back into the pending state
// (state=0) and refreshes enqueue_at.
//
// Raw SQL is used because state=0 could be dropped by OmitEmpty in the ORM
// path; the extra "state=1" condition guards against concurrent changes. A
// statement that matches no row is not reported as an error.
func (d *taskDao) RollbackToPendingGlobal(ctx context.Context, id int64) error {
	return execUpdate(ctx,
		fmt.Sprintf(`UPDATE %s SET state=0, enqueue_at=NOW(), updated_at=NOW()
		WHERE id=? AND state=1`, public.TableNameTask),
		id,
	)
}

// IncRetryCountGlobal increments retry_count of the task by one.
func (d *taskDao) IncRetryCountGlobal(ctx context.Context, id int64) error {
	return execUpdate(ctx,
		fmt.Sprintf(`UPDATE %s SET retry_count=retry_count+1, updated_at=NOW() WHERE id=?`, public.TableNameTask),
		id,
	)
}

// RequeueForRetryGlobal moves a failed task (state=3, not soft-deleted) back
// to pending, increments retry_count and sets enqueue_at to enqueueAt. A
// statement that matches no row is not reported as an error.
func (d *taskDao) RequeueForRetryGlobal(ctx context.Context, id int64, enqueueAt time.Time) error {
	return execUpdate(ctx,
		fmt.Sprintf(`UPDATE %s SET state=0, retry_count=retry_count+1, enqueue_at=?, updated_at=NOW()
		WHERE id=? AND state=3 AND deleted_at IS NULL`, public.TableNameTask),
		enqueueAt, id,
	)
}

// ======================== List queries ========================

// ListExpiredDownloadedGlobal lists downloaded tasks (state=4) whose expire_at
// is set and already in the past. A limit <= 0 falls back to 200.
func (d *taskDao) ListExpiredDownloadedGlobal(ctx context.Context, limit int) ([]*entity.AsynchTask, error) {
	return queryTasks(ctx,
		fmt.Sprintf(`SELECT * FROM %s
		WHERE deleted_at IS NULL AND state=4 AND expire_at IS NOT NULL AND expire_at < ?
		LIMIT ?`, public.TableNameTask),
		gtime.Now(), clampLimit(limit, 200),
	)
}

// ListFailedRetryableGlobal lists failed tasks (state=3) whose retry_count is
// still below the retry_times of the matching model row, least recently
// updated first. The model's retry_queue_max_seconds is selected alongside
// the task columns. A limit <= 0 falls back to 200.
func (d *taskDao) ListFailedRetryableGlobal(ctx context.Context, limit int) ([]*entity.AsynchTask, error) {
	return queryTasks(ctx,
		fmt.Sprintf(
			`SELECT t.*, m.retry_queue_max_seconds
			FROM %s t
			JOIN %s m ON t.tenant_id=m.tenant_id AND t.model_name=m.model_name
			WHERE t.deleted_at IS NULL AND t.state=3 AND t.retry_count < m.retry_times
			ORDER BY t.updated_at ASC
			LIMIT ?`,
			public.TableNameTask, public.TableNameModel,
		),
		clampLimit(limit, 200),
	)
}

// ListFailedExhaustedGlobal lists failed tasks (state=3) whose retry_count has
// reached the retry_times of the matching model row, least recently updated
// first. A limit <= 0 falls back to 200.
func (d *taskDao) ListFailedExhaustedGlobal(ctx context.Context, limit int) ([]*entity.AsynchTask, error) {
	return queryTasks(ctx,
		fmt.Sprintf(
			`SELECT t.*
			FROM %s t
			JOIN %s m ON t.tenant_id=m.tenant_id AND t.model_name=m.model_name
			WHERE t.deleted_at IS NULL AND t.state=3 AND t.retry_count >= m.retry_times
			ORDER BY t.updated_at ASC
			LIMIT ?`,
			public.TableNameTask, public.TableNameModel,
		),
		clampLimit(limit, 200),
	)
}

// ListTimeoutTasksGlobal lists tasks in state 0 or 1 that were created more
// than expected_seconds ago, where expected_seconds comes from the matching
// model row; models with expected_seconds <= 0 are ignored. A limit <= 0
// falls back to 200.
func (d *taskDao) ListTimeoutTasksGlobal(ctx context.Context, limit int) ([]*entity.AsynchTask, error) {
	return queryTasks(ctx,
		fmt.Sprintf(
			`SELECT t.*
			FROM %s t
			JOIN %s m ON t.tenant_id=m.tenant_id AND t.model_name=m.model_name
			WHERE t.deleted_at IS NULL AND t.state IN (0,1)
			AND m.expected_seconds > 0
			AND t.created_at < (NOW() - (m.expected_seconds || ' seconds')::interval)
			LIMIT ?`,
			public.TableNameTask, public.TableNameModel,
		),
		clampLimit(limit, 200),
	)
}

// HardDeleteByIDGlobal physically deletes the task row with the given id.
func (d *taskDao) HardDeleteByIDGlobal(ctx context.Context, id int64) error {
	return execUpdate(ctx,
		fmt.Sprintf(`DELETE FROM %s WHERE id=?`, public.TableNameTask),
		id,
	)
}

// ======================== Internal helpers ========================

// queryTasks runs a raw SELECT against the model-gateway database and scans
// all rows into task entities.
func queryTasks(ctx context.Context, sql string, args ...any) ([]*entity.AsynchTask, error) {
	r, err := gfdb.DB(ctx, public.DbNameModelGateway).GetAll(ctx, sql, args...)
	if err != nil {
		return nil, err
	}
	var list []*entity.AsynchTask
	if err = r.Structs(&list); err != nil {
		return nil, err
	}
	return list, nil
}

// clampLimit returns defaultVal when limit is not positive and limit itself
// otherwise. Despite the name, no upper bound is applied.
func clampLimit(limit, defaultVal int) int {
	if limit <= 0 {
		return defaultVal
	}
	return limit
}