package dao

import (
	"context"
	"fmt"
	"time"

	"model-gateway/consts/public"
	"model-gateway/model/entity"

	"gitea.com/red-future/common/db/gfdb"
	"github.com/gogf/gf/v2/database/gdb"
	"github.com/gogf/gf/v2/os/gtime"
)

// ClaimPendingGlobal is used by background workers: it claims pending tasks
// globally, i.e. without any tenant filter.
//
// Inside a single transaction it selects up to batchSize rows that have
// state = 0 and are not soft-deleted, ordered by enqueue_at ascending, using
// FOR UPDATE SKIP LOCKED, and then sets state=1, started_at and updated_at on
// each selected row. A batchSize <= 0 is treated as 1.
//
// It returns (nil, nil) when no row matched. The returned entities only carry
// the columns listed in the SELECT; the state/started_at written by the UPDATE
// are not copied back into them. On any error the returned slice is nil, so a
// caller never receives tasks whose claim was not persisted.
func (d *taskDao) ClaimPendingGlobal(ctx context.Context, batchSize int) (tasks []*entity.AsynchTask, err error) {
	if batchSize <= 0 {
		batchSize = 1
	}

	// Both statements are invariant across rows, so they are built once.
	selectSQL := fmt.Sprintf(
		`SELECT id, tenant_id, creator, model_name, task_id, biz_name, callback_url, model_key, input_ref, request_payload, phase, tmp_file
		FROM %s
		WHERE deleted_at IS NULL AND state = 0
		ORDER BY enqueue_at ASC
		LIMIT %d
		FOR UPDATE SKIP LOCKED`,
		public.TableNameTask, batchSize,
	)
	updateSQL := fmt.Sprintf(`UPDATE %s SET state=1, started_at=?, updated_at=? WHERE id=?`, public.TableNameTask)

	err = gfdb.DB(ctx).Transaction(ctx, func(ctx context.Context, tx gdb.TX) error {
		rows, err := tx.GetAll(selectSQL)
		if err != nil {
			return err
		}
		if rows.IsEmpty() {
			tasks = nil
			return nil
		}
		if err := rows.Structs(&tasks); err != nil {
			return err
		}

		// One timestamp for the whole batch keeps started_at/updated_at aligned.
		now := time.Now()
		for _, t := range tasks {
			if _, err := tx.Exec(updateSQL, now, now, t.Id); err != nil {
				return err
			}
		}
		return nil
	})
	if err != nil {
		// The claim was not persisted; do not hand out the scanned rows.
		return nil, err
	}
	return tasks, nil
}

// ClaimPendingByTaskIDGlobal claims one specific pending task by task_id,
// without any tenant filter.
//
// It is meant to be called right after createTask succeeds, so that the new
// task can be attempted asynchronously at once instead of waiting for a later
// runWork queue scan.
//
// An empty taskID returns (nil, nil) without touching the database. Otherwise,
// inside a single transaction, the row with that task_id, state = 0 and no
// soft-delete mark is selected with FOR UPDATE SKIP LOCKED and then updated to
// state=1 with started_at/updated_at set. It returns (nil, nil) when no such
// row was selected. On any error the returned task is nil.
func (d *taskDao) ClaimPendingByTaskIDGlobal(ctx context.Context, taskID string) (task *entity.AsynchTask, err error) {
	if taskID == "" {
		return nil, nil
	}

	selectSQL := fmt.Sprintf(
		`SELECT id, tenant_id, creator, model_name, task_id, biz_name, callback_url, model_key, input_ref, request_payload, phase, tmp_file
		FROM %s
		WHERE deleted_at IS NULL AND state = 0 AND task_id = ?
		LIMIT 1
		FOR UPDATE SKIP LOCKED`,
		public.TableNameTask,
	)
	updateSQL := fmt.Sprintf(`UPDATE %s SET state=1, started_at=?, updated_at=? WHERE id=?`, public.TableNameTask)

	err = gfdb.DB(ctx).Transaction(ctx, func(ctx context.Context, tx gdb.TX) error {
		row, err := tx.GetOne(selectSQL, taskID)
		if err != nil {
			return err
		}
		if row.IsEmpty() {
			task = nil
			return nil
		}
		if err := row.Struct(&task); err != nil {
			return err
		}

		now := time.Now()
		_, err = tx.Exec(updateSQL, now, now, task.Id)
		return err
	})
	if err != nil {
		// The claim was not persisted; do not hand out the scanned row.
		return nil, err
	}
	return task, nil
}

// UpdateSuccessGlobal records a successful result for the task with the given
// id: it sets state=2, stores the OSS file, file type, text result, token
// usage and file size, clears error_msg, sets finished_at/updated_at to the
// current time, computes duration_seconds as the whole seconds between
// created_at and now, and resets phase to 0 and tmp_file to ''.
//
// NOTE(review): expireAt is accepted but not used; the statement always writes
// expire_at=NULL. Confirm whether that is intentional.
func (d *taskDao) UpdateSuccessGlobal(ctx context.Context, id int64, ossFile, fileType, textResult string, fileSize int64, expireAt *gtime.Time, expendTokens int) error {
	now := gtime.Now()
	_, err := gfdb.DB(ctx).Exec(ctx,
		fmt.Sprintf(`UPDATE %s SET state=2, oss_file=?, file_type=?, text_result=?, expend_tokens=?, file_size=?, error_msg='', finished_at=?, duration_seconds=EXTRACT(EPOCH FROM (? - created_at))::BIGINT, expire_at=NULL, phase=0, tmp_file='', updated_at=? WHERE id=?`, public.TableNameTask),
		ossFile, fileType, textResult, expendTokens, fileSize, now, now, now, id,
	)
	return err
}

// UpdateFailedGlobal marks the task with the given id as failed: it sets
// state=3, stores errorMsg, sets finished_at/updated_at to the current time,
// computes duration_seconds as the whole seconds between created_at and now,
// and resets phase to 0 and tmp_file to ''.
func (d *taskDao) UpdateFailedGlobal(ctx context.Context, id int64, errorMsg string) error {
	now := gtime.Now()
	_, err := gfdb.DB(ctx).Exec(ctx,
		fmt.Sprintf(`UPDATE %s SET state=3, error_msg=?, finished_at=?, duration_seconds=EXTRACT(EPOCH FROM (? - created_at))::BIGINT, phase=0, tmp_file='', updated_at=? WHERE id=?`, public.TableNameTask),
		errorMsg, now, now, now, id,
	)
	return err
}

// UpdateFailedKeepTmpGlobal marks the task as failed (state=3) after an OSS
// upload failure. Unlike UpdateFailedGlobal it writes phase=1 and leaves
// tmp_file untouched, so that the next round only has to retry the OSS upload.
// duration_seconds is not updated by this statement.
func (d *taskDao) UpdateFailedKeepTmpGlobal(ctx context.Context, id int64, errorMsg string) error {
	now := gtime.Now()
	_, err := gfdb.DB(ctx).Exec(ctx,
		fmt.Sprintf(`UPDATE %s SET state=3, error_msg=?, finished_at=?, phase=1, updated_at=?
		WHERE id=?`, public.TableNameTask),
		errorMsg, now, now, id,
	)
	return err
}

// UpdateTmpAfterModelGlobal is called after the model invocation succeeded: it
// stores the temporary file path in tmp_file and marks phase=1.
func (d *taskDao) UpdateTmpAfterModelGlobal(ctx context.Context, id int64, tmpFile string) error {
	_, err := gfdb.DB(ctx).Exec(ctx,
		fmt.Sprintf(`UPDATE %s SET phase=1, tmp_file=?, updated_at=NOW() WHERE id=?`, public.TableNameTask),
		tmpFile, id,
	)
	return err
}

// SoftDeleteByTaskIDGlobal soft-deletes the task with the given task_id by
// setting deleted_at (and updated_at) to NOW(). Rows that already have
// deleted_at set are not touched.
func (d *taskDao) SoftDeleteByTaskIDGlobal(ctx context.Context, taskID string) error {
	_, err := gfdb.DB(ctx).Exec(ctx,
		fmt.Sprintf(`UPDATE %s SET deleted_at=NOW(), updated_at=NOW() WHERE task_id=? AND deleted_at IS NULL`, public.TableNameTask),
		taskID,
	)
	return err
}

// RollbackToPendingGlobal puts a claimed task back into the pending queue: the
// row is set to state=0 with enqueue_at/updated_at = NOW(), but only if its
// current state is 1.
func (d *taskDao) RollbackToPendingGlobal(ctx context.Context, id int64) error {
	_, err := gfdb.DB(ctx).Exec(ctx,
		fmt.Sprintf(`UPDATE %s SET state=0, enqueue_at=NOW(), updated_at=NOW() WHERE id=? AND state=1`, public.TableNameTask),
		id,
	)
	return err
}

// ListExpiredDownloadedGlobal returns tasks that have been downloaded
// (state=4) and whose expire_at is set and earlier than the current time, for
// cleanup. A limit <= 0 is treated as 200. The query has no ORDER BY, so which
// rows are returned when more than limit match is up to the database.
func (d *taskDao) ListExpiredDownloadedGlobal(ctx context.Context, limit int) (list []*entity.AsynchTask, err error) {
	if limit <= 0 {
		limit = 200
	}
	rows, err := gfdb.DB(ctx).GetAll(ctx,
		fmt.Sprintf(`SELECT * FROM %s WHERE deleted_at IS NULL AND state=4 AND expire_at IS NOT NULL AND expire_at < ?
		LIMIT ?`, public.TableNameTask),
		gtime.Now(), limit,
	)
	if err != nil {
		return nil, err
	}
	if err = rows.Structs(&list); err != nil {
		return nil, err
	}
	return list, nil
}

// ListFailedRetryableGlobal returns failed tasks (state=3) that may still be
// retried, oldest updated_at first. A limit <= 0 is treated as 200.
//
// retry_count does not include the first execution; the model's retry_times is
// the maximum number of additional attempts after a failure, so a task is
// retryable while retry_count < retry_times. Tasks are joined to their model
// configuration on (tenant_id, model_name), and the model's
// retry_queue_max_seconds is selected alongside the task columns.
func (d *taskDao) ListFailedRetryableGlobal(ctx context.Context, limit int) (list []*entity.AsynchTask, err error) {
	if limit <= 0 {
		limit = 200
	}
	rows, err := gfdb.DB(ctx).GetAll(ctx,
		fmt.Sprintf(`
			SELECT t.*, m.retry_queue_max_seconds AS retry_queue_max_seconds
			FROM %s t
			JOIN %s m ON t.tenant_id = m.tenant_id AND t.model_name = m.model_name
			WHERE t.deleted_at IS NULL AND t.state = 3 AND t.retry_count < m.retry_times
			ORDER BY t.updated_at ASC
			LIMIT ?`, public.TableNameTask, public.TableNameModel),
		limit,
	)
	if err != nil {
		return nil, err
	}
	if err = rows.Structs(&list); err != nil {
		return nil, err
	}
	return list, nil
}

// RequeueForRetryGlobal puts a failed task back into the queue (state=0) and
// increments retry_count by one. Only rows that are currently in state=3 and
// not soft-deleted are affected.
//
// enqueueAt controls the position of the retried task in the queue: the
// earlier enqueueAt is, the sooner the task is picked, because
// ClaimPendingGlobal claims in enqueue_at ascending order.
func (d *taskDao) RequeueForRetryGlobal(ctx context.Context, id int64, enqueueAt time.Time) error {
	_, err := gfdb.DB(ctx).Exec(ctx,
		fmt.Sprintf(`UPDATE %s SET state=0, retry_count=retry_count+1, enqueue_at=?, updated_at=NOW() WHERE id=?
		AND state=3 AND deleted_at IS NULL`, public.TableNameTask),
		enqueueAt, id,
	)
	return err
}

// ListFailedExhaustedGlobal returns failed tasks (state=3) whose retry_count
// has reached the model's retry_times, oldest updated_at first, so that they
// can be hard-deleted. A limit <= 0 is treated as 200. Tasks are joined to
// their model configuration on (tenant_id, model_name).
func (d *taskDao) ListFailedExhaustedGlobal(ctx context.Context, limit int) (list []*entity.AsynchTask, err error) {
	if limit <= 0 {
		limit = 200
	}
	rows, err := gfdb.DB(ctx).GetAll(ctx,
		fmt.Sprintf(`
			SELECT t.*
			FROM %s t
			JOIN %s m ON t.tenant_id = m.tenant_id AND t.model_name = m.model_name
			WHERE t.deleted_at IS NULL AND t.state = 3 AND t.retry_count >= m.retry_times
			ORDER BY t.updated_at ASC
			LIMIT ?`, public.TableNameTask, public.TableNameModel),
		limit,
	)
	if err != nil {
		return nil, err
	}
	if err = rows.Structs(&list); err != nil {
		return nil, err
	}
	return list, nil
}

// HardDeleteByIDGlobal physically deletes the task row with the given id.
func (d *taskDao) HardDeleteByIDGlobal(ctx context.Context, id int64) error {
	_, err := gfdb.DB(ctx).Exec(ctx,
		fmt.Sprintf(`DELETE FROM %s WHERE id=?`, public.TableNameTask),
		id,
	)
	return err
}

// ListTimeoutTasksGlobal returns tasks considered timed out according to the
// model configuration's expected_seconds:
//   - the task is in state 0 or 1 and not soft-deleted
//   - the model's expected_seconds is > 0
//   - created_at is more than expected_seconds before NOW()
//
// A limit <= 0 is treated as 200. The query has no ORDER BY, so which rows are
// returned when more than limit match is up to the database.
func (d *taskDao) ListTimeoutTasksGlobal(ctx context.Context, limit int) (list []*entity.AsynchTask, err error) {
	if limit <= 0 {
		limit = 200
	}
	rows, err := gfdb.DB(ctx).GetAll(ctx,
		fmt.Sprintf(`
			SELECT t.*
			FROM %s t
			JOIN %s m ON t.tenant_id = m.tenant_id AND t.model_name = m.model_name
			WHERE t.deleted_at IS NULL
			AND t.state IN (0,1)
			AND m.expected_seconds > 0
			AND t.created_at < (NOW() - (m.expected_seconds || ' seconds')::interval)
			LIMIT ?`, public.TableNameTask, public.TableNameModel),
		limit,
	)
	if err != nil {
		return nil, err
	}
	if err = rows.Structs(&list); err != nil {
		return nil, err
	}
	return list, nil
}