From afd60caf562c4c3605b38e92e9a788b67fab38d7 Mon Sep 17 00:00:00 2001 From: WangLiZhao <1838393649@qq.com> Date: Thu, 11 Jun 2026 11:27:14 +0800 Subject: [PATCH] =?UTF-8?q?fix(task):=20=E4=BF=AE=E5=A4=8D=E4=BB=BB?= =?UTF-8?q?=E5=8A=A1=E7=8A=B6=E6=80=81=E6=9B=B4=E6=96=B0=E5=92=8C=E8=B6=85?= =?UTF-8?q?=E6=97=B6=E5=A4=84=E7=90=86=E9=97=AE=E9=A2=98?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- dao/task_dao.go | 1 + dao/task_dao_bg.go | 182 +++++++++++++++++++++-------------------- service/job/cleaner.go | 5 +- service/task/worker.go | 42 +++++++--- 4 files changed, 131 insertions(+), 99 deletions(-) diff --git a/dao/task_dao.go b/dao/task_dao.go index b8bc87c..dcb7766 100644 --- a/dao/task_dao.go +++ b/dao/task_dao.go @@ -36,6 +36,7 @@ func (d *taskDao) Update(ctx context.Context, req *entity.AsynchTask) (rows int6 OmitEmpty(). Data(&req). Where(entity.AsynchTaskCol.Id, req.Id). + Where(entity.AsynchTaskCol.TaskID, req.TaskID). Update() if err != nil { return diff --git a/dao/task_dao_bg.go b/dao/task_dao_bg.go index 605cd24..1141ac9 100644 --- a/dao/task_dao_bg.go +++ b/dao/task_dao_bg.go @@ -15,16 +15,38 @@ import ( // ======================== 查询辅助 ======================== -// taskColumns 查询用的公共字段 const taskColumns = `id, tenant_id, creator, model_name, task_id, biz_name, callback_url, model_key, retry_count, input_ref, request_payload, phase, tmp_file` +// ======================== 通用 CRUD ======================== + +// UpdateFields 更新指定字段(map 版,用于必须更新零值的场景) +func (d *taskDao) UpdateFields(ctx context.Context, id int64, data map[string]any) (int64, error) { + r, err := gfdb.DB(ctx, public.DbNameModelGateway). + Model(ctx, public.TableNameTask). + Data(data). + Where(entity.AsynchTaskCol.Id, id). + Update() + if err != nil { + return 0, err + } + return r.RowsAffected() +} + +// execUpdate 内部辅助:执行原生 UPDATE,自动补 updated_at +func execUpdate(ctx context.Context, sql string, args ...any) error { + _, err := gfdb.DB(ctx, public.DbNameModelGateway).Exec(ctx, sql, args...) + return err +} + // ======================== 事务抢占 ======================== -// claimTasks 事务内抢占任务并更新 state=1 func claimTasks(ctx context.Context, where string, args ...any) ([]*entity.AsynchTask, error) { var tasks []*entity.AsynchTask - err := gfdb.DB(ctx).Transaction(ctx, func(ctx context.Context, tx gdb.TX) error { - sql := fmt.Sprintf(`SELECT %s FROM %s WHERE deleted_at IS NULL AND state = 0 %s LIMIT 1 FOR UPDATE SKIP LOCKED`, taskColumns, public.TableNameTask, where) + err := gfdb.DB(ctx, public.DbNameModelGateway).Transaction(ctx, func(ctx context.Context, tx gdb.TX) error { + sql := fmt.Sprintf( + `SELECT %s FROM %s WHERE deleted_at IS NULL AND state = 0 %s LIMIT 1 FOR UPDATE SKIP LOCKED`, + taskColumns, public.TableNameTask, where, + ) r, err := tx.GetOne(sql, args...) if err != nil { return err @@ -37,7 +59,10 @@ func claimTasks(ctx context.Context, where string, args ...any) ([]*entity.Async return err } now := time.Now() - _, err = tx.Exec(fmt.Sprintf(`UPDATE %s SET state=1, started_at=?, updated_at=? WHERE id=?`, public.TableNameTask), now, now, task.Id) + _, err = tx.Exec( + fmt.Sprintf(`UPDATE %s SET state=1, started_at=?, updated_at=? WHERE id=?`, public.TableNameTask), + now, now, task.Id, + ) if err != nil { return err } @@ -53,8 +78,11 @@ func (d *taskDao) ClaimPendingGlobal(ctx context.Context, batchSize int) ([]*ent batchSize = 1 } var tasks []*entity.AsynchTask - err := gfdb.DB(ctx).Transaction(ctx, func(ctx context.Context, tx gdb.TX) error { - sql := fmt.Sprintf(`SELECT %s FROM %s WHERE deleted_at IS NULL AND state = 0 ORDER BY enqueue_at ASC LIMIT %d FOR UPDATE SKIP LOCKED`, taskColumns, public.TableNameTask, batchSize) + err := gfdb.DB(ctx, public.DbNameModelGateway).Transaction(ctx, func(ctx context.Context, tx gdb.TX) error { + sql := fmt.Sprintf( + `SELECT %s FROM %s WHERE deleted_at IS NULL AND state = 0 ORDER BY enqueue_at ASC LIMIT %d FOR UPDATE SKIP LOCKED`, + taskColumns, public.TableNameTask, batchSize, + ) r, err := tx.GetAll(sql) if err != nil { return err @@ -67,7 +95,10 @@ func (d *taskDao) ClaimPendingGlobal(ctx context.Context, batchSize int) ([]*ent } now := time.Now() for _, t := range tasks { - _, err = tx.Exec(fmt.Sprintf(`UPDATE %s SET state=1, started_at=?, updated_at=? WHERE id=?`, public.TableNameTask), now, now, t.Id) + _, err = tx.Exec( + fmt.Sprintf(`UPDATE %s SET state=1, started_at=?, updated_at=? WHERE id=?`, public.TableNameTask), + now, now, t.Id, + ) if err != nil { return err } @@ -89,112 +120,96 @@ func (d *taskDao) ClaimPendingByTaskIDGlobal(ctx context.Context, taskID string) return tasks[0], nil } -// ======================== 更新辅助 ======================== +// ======================== 业务更新方法 ======================== -func execSQL(ctx context.Context, sql string, args ...any) error { - _, err := gfdb.DB(ctx).Exec(ctx, sql, args...) - return err -} - -// updateTask 通用更新 -func updateTask(ctx context.Context, id int64, data entity.AsynchTask) error { - _, err := gfdb.DB(ctx).Model(ctx, public.TableNameTask).OmitEmpty(). - Where(entity.AsynchTaskCol.Id, id).Data(data).Update() - return err -} - -// UpdateSuccessGlobal 更新任务成功 -func (d *taskDao) UpdateSuccessGlobal(ctx context.Context, t *entity.AsynchTask) error { - return updateTask(ctx, t.Id, entity.AsynchTask{ - State: 2, - OssFile: t.OssFile, - FileType: t.FileType, - TextResult: t.TextResult, - FileSize: t.FileSize, - ErrorMsg: "", - FinishedAt: gtime.Now(), - Phase: 0, - TmpFile: "", - ExpendTokens: t.ExpendTokens, - DurationSeconds: t.DurationSeconds, - }) -} - -// UpdateFailedGlobal 模型调用失败 -func (d *taskDao) UpdateFailedGlobal(ctx context.Context, t *entity.AsynchTask) error { - return updateTask(ctx, t.Id, entity.AsynchTask{ - State: 3, - ErrorMsg: t.ErrorMsg, - FinishedAt: gtime.Now(), - Phase: 0, - TmpFile: "", - TextResult: t.TextResult, - DurationSeconds: t.DurationSeconds, - }) -} - -// UpdateFailedKeepTmpGlobal OSS 上传失败 -func (d *taskDao) UpdateFailedKeepTmpGlobal(ctx context.Context, id int64, errorMsg string) error { - return execSQL(ctx, fmt.Sprintf(`UPDATE %s SET state=3, error_msg=?, finished_at=?, phase=1, updated_at=? WHERE id=?`, public.TableNameTask), errorMsg, gtime.Now(), gtime.Now(), id) -} - -// UpdateTmpAfterModelGlobal 写临时文件 -func (d *taskDao) UpdateTmpAfterModelGlobal(ctx context.Context, id int64, tmpFile string) error { - return execSQL(ctx, fmt.Sprintf(`UPDATE %s SET phase=1, tmp_file=?, updated_at=NOW() WHERE id=?`, public.TableNameTask), tmpFile, id) -} - -// RollbackToPendingGlobal 回滚 +// RollbackToPendingGlobal 回滚到 pending 状态 func (d *taskDao) RollbackToPendingGlobal(ctx context.Context, id int64) error { - return execSQL(ctx, fmt.Sprintf(`UPDATE %s SET state=0, enqueue_at=NOW(), updated_at=NOW() WHERE id=? AND state=1`, public.TableNameTask), id) + // state=0 可能被 OmitEmpty 跳过,所以用原生 SQL + 条件 state=1 防并发 + return execUpdate(ctx, + fmt.Sprintf(`UPDATE %s SET state=0, enqueue_at=NOW(), updated_at=NOW() WHERE id=? AND state=1`, public.TableNameTask), + id, + ) } -// IncRetryCountGlobal 重试计数+1 +// IncRetryCountGlobal 重试计数 +1 func (d *taskDao) IncRetryCountGlobal(ctx context.Context, id int64) error { - return execSQL(ctx, fmt.Sprintf(`UPDATE %s SET retry_count=retry_count+1, updated_at=NOW() WHERE id=?`, public.TableNameTask), id) + return execUpdate(ctx, + fmt.Sprintf(`UPDATE %s SET retry_count=retry_count+1, updated_at=NOW() WHERE id=?`, public.TableNameTask), + id, + ) } // RequeueForRetryGlobal 重新入队 func (d *taskDao) RequeueForRetryGlobal(ctx context.Context, id int64, enqueueAt time.Time) error { - return execSQL(ctx, fmt.Sprintf(`UPDATE %s SET state=0, retry_count=retry_count+1, enqueue_at=?, updated_at=NOW() WHERE id=? AND state=3 AND deleted_at IS NULL`, public.TableNameTask), enqueueAt, id) + return execUpdate(ctx, + fmt.Sprintf(`UPDATE %s SET state=0, retry_count=retry_count+1, enqueue_at=?, updated_at=NOW() WHERE id=? AND state=3 AND deleted_at IS NULL`, public.TableNameTask), + enqueueAt, id, + ) } // ======================== 列表查询 ======================== -// ListExpiredDownloadedGlobal +// ListExpiredDownloadedGlobal 查询已过期下载的任务 func (d *taskDao) ListExpiredDownloadedGlobal(ctx context.Context, limit int) ([]*entity.AsynchTask, error) { - return queryTasks(ctx, fmt.Sprintf(`SELECT * FROM %s WHERE deleted_at IS NULL AND state=4 AND expire_at IS NOT NULL AND expire_at < ? LIMIT ?`, public.TableNameTask), gtime.Now(), clampLimit(limit, 200)) + return queryTasks(ctx, + fmt.Sprintf(`SELECT * FROM %s WHERE deleted_at IS NULL AND state=4 AND expire_at IS NOT NULL AND expire_at < ? LIMIT ?`, public.TableNameTask), + gtime.Now(), clampLimit(limit, 200), + ) } -// ListFailedRetryableGlobal +// ListFailedRetryableGlobal 查询可重试的失败任务 func (d *taskDao) ListFailedRetryableGlobal(ctx context.Context, limit int) ([]*entity.AsynchTask, error) { - return queryTasks(ctx, fmt.Sprintf(`SELECT t.*, m.retry_queue_max_seconds FROM %s t JOIN %s m ON t.tenant_id=m.tenant_id AND t.model_name=m.model_name WHERE t.deleted_at IS NULL AND t.state=3 AND t.retry_count < m.retry_times ORDER BY t.updated_at ASC LIMIT ?`, public.TableNameTask, public.TableNameModel), clampLimit(limit, 200)) + return queryTasks(ctx, + fmt.Sprintf( + `SELECT t.*, m.retry_queue_max_seconds FROM %s t JOIN %s m ON t.tenant_id=m.tenant_id AND t.model_name=m.model_name WHERE t.deleted_at IS NULL AND t.state=3 AND t.retry_count < m.retry_times ORDER BY t.updated_at ASC LIMIT ?`, + public.TableNameTask, public.TableNameModel, + ), + clampLimit(limit, 200), + ) } -// ListFailedExhaustedGlobal +// ListFailedExhaustedGlobal 查询重试次数耗尽的失败任务 func (d *taskDao) ListFailedExhaustedGlobal(ctx context.Context, limit int) ([]*entity.AsynchTask, error) { - return queryTasks(ctx, fmt.Sprintf(`SELECT t.* FROM %s t JOIN %s m ON t.tenant_id=m.tenant_id AND t.model_name=m.model_name WHERE t.deleted_at IS NULL AND t.state=3 AND t.retry_count >= m.retry_times ORDER BY t.updated_at ASC LIMIT ?`, public.TableNameTask, public.TableNameModel), clampLimit(limit, 200)) + return queryTasks(ctx, + fmt.Sprintf( + `SELECT t.* FROM %s t JOIN %s m ON t.tenant_id=m.tenant_id AND t.model_name=m.model_name WHERE t.deleted_at IS NULL AND t.state=3 AND t.retry_count >= m.retry_times ORDER BY t.updated_at ASC LIMIT ?`, + public.TableNameTask, public.TableNameModel, + ), + clampLimit(limit, 200), + ) } -// ListTimeoutTasksGlobal +// ListTimeoutTasksGlobal 查询超时任务 func (d *taskDao) ListTimeoutTasksGlobal(ctx context.Context, limit int) ([]*entity.AsynchTask, error) { - return queryTasks(ctx, fmt.Sprintf(`SELECT t.* FROM %s t JOIN %s m ON t.tenant_id=m.tenant_id AND t.model_name=m.model_name WHERE t.deleted_at IS NULL AND t.state IN (0,1) AND m.expected_seconds > 0 AND t.created_at < (NOW() - (m.expected_seconds || ' seconds')::interval) LIMIT ?`, public.TableNameTask, public.TableNameModel), clampLimit(limit, 200)) + return queryTasks(ctx, + fmt.Sprintf( + `SELECT t.* FROM %s t JOIN %s m ON t.tenant_id=m.tenant_id AND t.model_name=m.model_name WHERE t.deleted_at IS NULL AND t.state IN (0,1) AND m.expected_seconds > 0 AND t.created_at < (NOW() - (m.expected_seconds || ' seconds')::interval) LIMIT ?`, + public.TableNameTask, public.TableNameModel, + ), + clampLimit(limit, 200), + ) } -// HardDeleteByIDGlobal +// HardDeleteByIDGlobal 物理删除任务 func (d *taskDao) HardDeleteByIDGlobal(ctx context.Context, id int64) error { - return execSQL(ctx, fmt.Sprintf(`DELETE FROM %s WHERE id=?`, public.TableNameTask), id) + return execUpdate(ctx, + fmt.Sprintf(`DELETE FROM %s WHERE id=?`, public.TableNameTask), + id, + ) } // ======================== 内部辅助 ======================== func queryTasks(ctx context.Context, sql string, args ...any) ([]*entity.AsynchTask, error) { - r, err := gfdb.DB(ctx).GetAll(ctx, sql, args...) + r, err := gfdb.DB(ctx, public.DbNameModelGateway).GetAll(ctx, sql, args...) if err != nil { return nil, err } var list []*entity.AsynchTask - err = r.Structs(&list) - return list, err + if err = r.Structs(&list); err != nil { + return nil, err + } + return list, nil } func clampLimit(limit, defaultVal int) int { @@ -203,12 +218,3 @@ func clampLimit(limit, defaultVal int) int { } return limit } - -// UpdateColumns 更新指定字段(结构体版) -func (d *taskDao) UpdateColumns(ctx context.Context, id int64, data entity.AsynchTask) error { - _, err := gfdb.DB(ctx).Model(ctx, public.TableNameTask).OmitEmpty(). - Where(entity.AsynchTaskCol.Id, id). - Data(data). - Update() - return err -} diff --git a/service/job/cleaner.go b/service/job/cleaner.go index c133c56..d0117c4 100644 --- a/service/job/cleaner.go +++ b/service/job/cleaner.go @@ -37,7 +37,10 @@ func (c *cleaner) RunOnce(ctx context.Context) (res *dto.CleanWorkRes, err error } else { for _, t := range list { t.ErrorMsg = "任务超时自动失败" - _ = dao.Task.UpdateFailedGlobal(ctx, t) + _, err = dao.Task.Update(ctx, t) + if err != nil { + g.Log().Errorf(ctx, "[清理] 标记任务失败: %v", err) + } queue.ReleaseQueueSlot(ctx, t.ModelName, t.TaskID) } g.Log().Infof(ctx, "[清理] 超时任务处理完成, count=%d", len(list)) diff --git a/service/task/worker.go b/service/task/worker.go index a9c9553..e3acd91 100644 --- a/service/task/worker.go +++ b/service/task/worker.go @@ -21,8 +21,10 @@ import ( "model-gateway/service/gateway" "model-gateway/service/queue" + "gitea.redpowerfuture.com/red-future/common/beans" "github.com/gogf/gf/v2/encoding/gjson" "github.com/gogf/gf/v2/frame/g" + "github.com/gogf/gf/v2/os/gtime" "github.com/gogf/gf/v2/util/gconv" ) @@ -92,7 +94,10 @@ func (w *asyncWorker) handleOne(ctx context.Context, task *entity.AsynchTask, mo if err == nil && tmpPath != "" { task.TmpFile = tmpPath task.Phase = 1 - _ = dao.Task.UpdateTmpAfterModelGlobal(ctx, task.Id, tmpPath) + _, err = dao.Task.Update(ctx, task) + if err != nil { + g.Log().Errorf(ctx, "[执行任务][失败] 更新数据库失败 taskId=%s err=%v", task.TaskID, err) + } } // 4) 解析校验 + 响应映射(可重试,失败重新调模型) @@ -116,7 +121,14 @@ func (w *asyncWorker) handleOne(ctx context.Context, task *entity.AsynchTask, mo g.Log().Errorf(ctx, "[执行任务][失败] OSS上传失败 taskId=%s attempt=%d/%d err=%v", task.TaskID, attempt, maxRetry, err) if attempt == maxRetry { - _ = dao.Task.UpdateFailedKeepTmpGlobal(ctx, task.Id, err.Error()) + task.State = 3 + task.ErrorMsg = err.Error() + task.FinishedAt = gtime.Now() + task.Phase = 1 + _, err = dao.Task.Update(ctx, task) + if err != nil { + g.Log().Errorf(ctx, "[执行任务][失败] 更新数据库失败 taskId=%s err=%v", task.TaskID, err) + } w.failTask(ctx, task, startTime, fmt.Sprintf("OSS上传重试耗尽: %v", err)) return } @@ -130,7 +142,7 @@ func (w *asyncWorker) handleOne(ctx context.Context, task *entity.AsynchTask, mo task.TextResult = body task.FileSize = int64(oss.FileSize) - if err = dao.Task.UpdateSuccessGlobal(ctx, task); err != nil { + if _, err = dao.Task.Update(ctx, task); err != nil { g.Log().Errorf(ctx, "[执行任务][失败] 更新数据库失败 taskId=%s err=%v", task.TaskID, err) return } @@ -170,7 +182,10 @@ func (w *asyncWorker) callModelStream(ctx context.Context, task *entity.AsynchTa if tmpErr == nil && tmpPath != "" { task.TmpFile = tmpPath task.Phase = 1 - _ = dao.Task.UpdateTmpAfterModelGlobal(ctx, task.Id, tmpPath) + _, err = dao.Task.Update(ctx, task) + if err != nil { + g.Log().Errorf(ctx, "[执行任务][失败] 临时文件保存失败 taskId=%s err=%v", task.TaskID, tmpErr) + } } } @@ -258,7 +273,10 @@ func (w *asyncWorker) callModel(ctx context.Context, task *entity.AsynchTask, mo if tmpErr == nil && tmpPath != "" { task.TmpFile = tmpPath task.Phase = 1 - _ = dao.Task.UpdateTmpAfterModelGlobal(ctx, task.Id, tmpPath) + _, err = dao.Task.Update(ctx, task) + if err != nil { + g.Log().Errorf(ctx, "[执行任务][失败] 临时文件保存失败 taskId=%s err=%v", task.TaskID, tmpErr) + } } } @@ -296,10 +314,11 @@ func (w *asyncWorker) parseAndRetry(ctx context.Context, body map[string]any, ta } // 2) 先存 token 到数据库,防止后续失败丢失 - if tokens, ok := mapped[model.ResponseTokenField]; ok { - task.ExpendTokens = gconv.Int64(tokens) - _ = dao.Task.UpdateColumns(ctx, task.Id, entity.AsynchTask{ - ExpendTokens: gconv.Int64(body[model.ResponseTokenField]), + if _, ok := mapped[model.ResponseTokenField]; ok { + task.ExpendTokens = gconv.Int64(mapped[model.ResponseTokenField]) + _, err = dao.Task.Update(ctx, &entity.AsynchTask{ + SQLBaseDO: beans.SQLBaseDO{Id: task.Id}, + ExpendTokens: task.ExpendTokens, }) } @@ -483,7 +502,10 @@ func (w *asyncWorker) failTask(ctx context.Context, t *entity.AsynchTask, startT t.State = 3 t.ErrorMsg = errMsg t.DurationSeconds = int64(time.Since(startTime).Seconds()) - _ = dao.Task.UpdateFailedGlobal(ctx, t) + _, err := dao.Task.Update(ctx, t) + if err != nil { + g.Log().Warningf(ctx, "[执行任务][更新数据库失败] taskId=%s err=%v", t.TaskID, err) + } queue.ReleaseQueueSlot(ctx, t.ModelName, t.TaskID) go gateway.TriggerCallback(context.WithoutCancel(ctx), t) }