fix(task): 修复任务状态更新和超时处理问题
This commit is contained in:
@@ -36,6 +36,7 @@ func (d *taskDao) Update(ctx context.Context, req *entity.AsynchTask) (rows int6
|
|||||||
OmitEmpty().
|
OmitEmpty().
|
||||||
Data(&req).
|
Data(&req).
|
||||||
Where(entity.AsynchTaskCol.Id, req.Id).
|
Where(entity.AsynchTaskCol.Id, req.Id).
|
||||||
|
Where(entity.AsynchTaskCol.TaskID, req.TaskID).
|
||||||
Update()
|
Update()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
|
|||||||
@@ -15,16 +15,38 @@ import (
|
|||||||
|
|
||||||
// ======================== 查询辅助 ========================
|
// ======================== 查询辅助 ========================
|
||||||
|
|
||||||
// taskColumns 查询用的公共字段
|
|
||||||
const taskColumns = `id, tenant_id, creator, model_name, task_id, biz_name, callback_url, model_key, retry_count, input_ref, request_payload, phase, tmp_file`
|
const taskColumns = `id, tenant_id, creator, model_name, task_id, biz_name, callback_url, model_key, retry_count, input_ref, request_payload, phase, tmp_file`
|
||||||
|
|
||||||
|
// ======================== 通用 CRUD ========================
|
||||||
|
|
||||||
|
// UpdateFields updates the columns given in data on the task row whose id
// column equals id (map-based variant, intended for cases where zero values
// must be written).
//
// The query is built on public.TableNameTask in the database selected by
// public.DbNameModelGateway. Unlike the Update method's chain, no OmitEmpty()
// is applied here.
//
// It returns the number of affected rows as reported by the driver result,
// or 0 and the error if the update itself fails.
|
||||||
|
func (d *taskDao) UpdateFields(ctx context.Context, id int64, data map[string]any) (int64, error) {
|
||||||
|
r, err := gfdb.DB(ctx, public.DbNameModelGateway).
|
||||||
|
Model(ctx, public.TableNameTask).
|
||||||
|
Data(data).
|
||||||
|
Where(entity.AsynchTaskCol.Id, id).
|
||||||
|
Update()
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
return r.RowsAffected()
|
||||||
|
}
|
||||||
|
|
||||||
|
// execUpdate is an internal helper that executes a raw SQL statement with the
// given args against the database selected by public.DbNameModelGateway.
// The Exec result is discarded; only the error is returned.
//
// NOTE(review): the statement is passed through unchanged. Nothing in this
// helper appends or sets updated_at, so every caller must include it in its
// own SQL. Despite the name, it is also used to run a DELETE statement.
|
||||||
|
func execUpdate(ctx context.Context, sql string, args ...any) error {
|
||||||
|
_, err := gfdb.DB(ctx, public.DbNameModelGateway).Exec(ctx, sql, args...)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
// ======================== 事务抢占 ========================
|
// ======================== 事务抢占 ========================
|
||||||
|
|
||||||
// claimTasks 事务内抢占任务并更新 state=1
|
|
||||||
func claimTasks(ctx context.Context, where string, args ...any) ([]*entity.AsynchTask, error) {
|
func claimTasks(ctx context.Context, where string, args ...any) ([]*entity.AsynchTask, error) {
|
||||||
var tasks []*entity.AsynchTask
|
var tasks []*entity.AsynchTask
|
||||||
err := gfdb.DB(ctx).Transaction(ctx, func(ctx context.Context, tx gdb.TX) error {
|
err := gfdb.DB(ctx, public.DbNameModelGateway).Transaction(ctx, func(ctx context.Context, tx gdb.TX) error {
|
||||||
sql := fmt.Sprintf(`SELECT %s FROM %s WHERE deleted_at IS NULL AND state = 0 %s LIMIT 1 FOR UPDATE SKIP LOCKED`, taskColumns, public.TableNameTask, where)
|
sql := fmt.Sprintf(
|
||||||
|
`SELECT %s FROM %s WHERE deleted_at IS NULL AND state = 0 %s LIMIT 1 FOR UPDATE SKIP LOCKED`,
|
||||||
|
taskColumns, public.TableNameTask, where,
|
||||||
|
)
|
||||||
r, err := tx.GetOne(sql, args...)
|
r, err := tx.GetOne(sql, args...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
@@ -37,7 +59,10 @@ func claimTasks(ctx context.Context, where string, args ...any) ([]*entity.Async
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
now := time.Now()
|
now := time.Now()
|
||||||
_, err = tx.Exec(fmt.Sprintf(`UPDATE %s SET state=1, started_at=?, updated_at=? WHERE id=?`, public.TableNameTask), now, now, task.Id)
|
_, err = tx.Exec(
|
||||||
|
fmt.Sprintf(`UPDATE %s SET state=1, started_at=?, updated_at=? WHERE id=?`, public.TableNameTask),
|
||||||
|
now, now, task.Id,
|
||||||
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -53,8 +78,11 @@ func (d *taskDao) ClaimPendingGlobal(ctx context.Context, batchSize int) ([]*ent
|
|||||||
batchSize = 1
|
batchSize = 1
|
||||||
}
|
}
|
||||||
var tasks []*entity.AsynchTask
|
var tasks []*entity.AsynchTask
|
||||||
err := gfdb.DB(ctx).Transaction(ctx, func(ctx context.Context, tx gdb.TX) error {
|
err := gfdb.DB(ctx, public.DbNameModelGateway).Transaction(ctx, func(ctx context.Context, tx gdb.TX) error {
|
||||||
sql := fmt.Sprintf(`SELECT %s FROM %s WHERE deleted_at IS NULL AND state = 0 ORDER BY enqueue_at ASC LIMIT %d FOR UPDATE SKIP LOCKED`, taskColumns, public.TableNameTask, batchSize)
|
sql := fmt.Sprintf(
|
||||||
|
`SELECT %s FROM %s WHERE deleted_at IS NULL AND state = 0 ORDER BY enqueue_at ASC LIMIT %d FOR UPDATE SKIP LOCKED`,
|
||||||
|
taskColumns, public.TableNameTask, batchSize,
|
||||||
|
)
|
||||||
r, err := tx.GetAll(sql)
|
r, err := tx.GetAll(sql)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
@@ -67,7 +95,10 @@ func (d *taskDao) ClaimPendingGlobal(ctx context.Context, batchSize int) ([]*ent
|
|||||||
}
|
}
|
||||||
now := time.Now()
|
now := time.Now()
|
||||||
for _, t := range tasks {
|
for _, t := range tasks {
|
||||||
_, err = tx.Exec(fmt.Sprintf(`UPDATE %s SET state=1, started_at=?, updated_at=? WHERE id=?`, public.TableNameTask), now, now, t.Id)
|
_, err = tx.Exec(
|
||||||
|
fmt.Sprintf(`UPDATE %s SET state=1, started_at=?, updated_at=? WHERE id=?`, public.TableNameTask),
|
||||||
|
now, now, t.Id,
|
||||||
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -89,112 +120,96 @@ func (d *taskDao) ClaimPendingByTaskIDGlobal(ctx context.Context, taskID string)
|
|||||||
return tasks[0], nil
|
return tasks[0], nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// ======================== 更新辅助 ========================
|
// ======================== 业务更新方法 ========================
|
||||||
|
|
||||||
func execSQL(ctx context.Context, sql string, args ...any) error {
|
// RollbackToPendingGlobal 回滚到 pending 状态
|
||||||
_, err := gfdb.DB(ctx).Exec(ctx, sql, args...)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// updateTask 通用更新
|
|
||||||
func updateTask(ctx context.Context, id int64, data entity.AsynchTask) error {
|
|
||||||
_, err := gfdb.DB(ctx).Model(ctx, public.TableNameTask).OmitEmpty().
|
|
||||||
Where(entity.AsynchTaskCol.Id, id).Data(data).Update()
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// UpdateSuccessGlobal 更新任务成功
|
|
||||||
func (d *taskDao) UpdateSuccessGlobal(ctx context.Context, t *entity.AsynchTask) error {
|
|
||||||
return updateTask(ctx, t.Id, entity.AsynchTask{
|
|
||||||
State: 2,
|
|
||||||
OssFile: t.OssFile,
|
|
||||||
FileType: t.FileType,
|
|
||||||
TextResult: t.TextResult,
|
|
||||||
FileSize: t.FileSize,
|
|
||||||
ErrorMsg: "",
|
|
||||||
FinishedAt: gtime.Now(),
|
|
||||||
Phase: 0,
|
|
||||||
TmpFile: "",
|
|
||||||
ExpendTokens: t.ExpendTokens,
|
|
||||||
DurationSeconds: t.DurationSeconds,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// UpdateFailedGlobal 模型调用失败
|
|
||||||
func (d *taskDao) UpdateFailedGlobal(ctx context.Context, t *entity.AsynchTask) error {
|
|
||||||
return updateTask(ctx, t.Id, entity.AsynchTask{
|
|
||||||
State: 3,
|
|
||||||
ErrorMsg: t.ErrorMsg,
|
|
||||||
FinishedAt: gtime.Now(),
|
|
||||||
Phase: 0,
|
|
||||||
TmpFile: "",
|
|
||||||
TextResult: t.TextResult,
|
|
||||||
DurationSeconds: t.DurationSeconds,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// UpdateFailedKeepTmpGlobal OSS 上传失败
|
|
||||||
func (d *taskDao) UpdateFailedKeepTmpGlobal(ctx context.Context, id int64, errorMsg string) error {
|
|
||||||
return execSQL(ctx, fmt.Sprintf(`UPDATE %s SET state=3, error_msg=?, finished_at=?, phase=1, updated_at=? WHERE id=?`, public.TableNameTask), errorMsg, gtime.Now(), gtime.Now(), id)
|
|
||||||
}
|
|
||||||
|
|
||||||
// UpdateTmpAfterModelGlobal 写临时文件
|
|
||||||
func (d *taskDao) UpdateTmpAfterModelGlobal(ctx context.Context, id int64, tmpFile string) error {
|
|
||||||
return execSQL(ctx, fmt.Sprintf(`UPDATE %s SET phase=1, tmp_file=?, updated_at=NOW() WHERE id=?`, public.TableNameTask), tmpFile, id)
|
|
||||||
}
|
|
||||||
|
|
||||||
// RollbackToPendingGlobal 回滚
|
|
||||||
func (d *taskDao) RollbackToPendingGlobal(ctx context.Context, id int64) error {
|
func (d *taskDao) RollbackToPendingGlobal(ctx context.Context, id int64) error {
|
||||||
return execSQL(ctx, fmt.Sprintf(`UPDATE %s SET state=0, enqueue_at=NOW(), updated_at=NOW() WHERE id=? AND state=1`, public.TableNameTask), id)
|
// state=0 可能被 OmitEmpty 跳过,所以用原生 SQL + 条件 state=1 防并发
|
||||||
|
return execUpdate(ctx,
|
||||||
|
fmt.Sprintf(`UPDATE %s SET state=0, enqueue_at=NOW(), updated_at=NOW() WHERE id=? AND state=1`, public.TableNameTask),
|
||||||
|
id,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
// IncRetryCountGlobal 重试计数+1
|
// IncRetryCountGlobal 重试计数 +1
|
||||||
func (d *taskDao) IncRetryCountGlobal(ctx context.Context, id int64) error {
|
func (d *taskDao) IncRetryCountGlobal(ctx context.Context, id int64) error {
|
||||||
return execSQL(ctx, fmt.Sprintf(`UPDATE %s SET retry_count=retry_count+1, updated_at=NOW() WHERE id=?`, public.TableNameTask), id)
|
return execUpdate(ctx,
|
||||||
|
fmt.Sprintf(`UPDATE %s SET retry_count=retry_count+1, updated_at=NOW() WHERE id=?`, public.TableNameTask),
|
||||||
|
id,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
// RequeueForRetryGlobal 重新入队
|
// RequeueForRetryGlobal 重新入队
|
||||||
func (d *taskDao) RequeueForRetryGlobal(ctx context.Context, id int64, enqueueAt time.Time) error {
|
func (d *taskDao) RequeueForRetryGlobal(ctx context.Context, id int64, enqueueAt time.Time) error {
|
||||||
return execSQL(ctx, fmt.Sprintf(`UPDATE %s SET state=0, retry_count=retry_count+1, enqueue_at=?, updated_at=NOW() WHERE id=? AND state=3 AND deleted_at IS NULL`, public.TableNameTask), enqueueAt, id)
|
return execUpdate(ctx,
|
||||||
|
fmt.Sprintf(`UPDATE %s SET state=0, retry_count=retry_count+1, enqueue_at=?, updated_at=NOW() WHERE id=? AND state=3 AND deleted_at IS NULL`, public.TableNameTask),
|
||||||
|
enqueueAt, id,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
// ======================== 列表查询 ========================
|
// ======================== 列表查询 ========================
|
||||||
|
|
||||||
// ListExpiredDownloadedGlobal
|
// ListExpiredDownloadedGlobal 查询已过期下载的任务
|
||||||
func (d *taskDao) ListExpiredDownloadedGlobal(ctx context.Context, limit int) ([]*entity.AsynchTask, error) {
|
func (d *taskDao) ListExpiredDownloadedGlobal(ctx context.Context, limit int) ([]*entity.AsynchTask, error) {
|
||||||
return queryTasks(ctx, fmt.Sprintf(`SELECT * FROM %s WHERE deleted_at IS NULL AND state=4 AND expire_at IS NOT NULL AND expire_at < ? LIMIT ?`, public.TableNameTask), gtime.Now(), clampLimit(limit, 200))
|
return queryTasks(ctx,
|
||||||
|
fmt.Sprintf(`SELECT * FROM %s WHERE deleted_at IS NULL AND state=4 AND expire_at IS NOT NULL AND expire_at < ? LIMIT ?`, public.TableNameTask),
|
||||||
|
gtime.Now(), clampLimit(limit, 200),
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
// ListFailedRetryableGlobal
|
// ListFailedRetryableGlobal 查询可重试的失败任务
|
||||||
func (d *taskDao) ListFailedRetryableGlobal(ctx context.Context, limit int) ([]*entity.AsynchTask, error) {
|
func (d *taskDao) ListFailedRetryableGlobal(ctx context.Context, limit int) ([]*entity.AsynchTask, error) {
|
||||||
return queryTasks(ctx, fmt.Sprintf(`SELECT t.*, m.retry_queue_max_seconds FROM %s t JOIN %s m ON t.tenant_id=m.tenant_id AND t.model_name=m.model_name WHERE t.deleted_at IS NULL AND t.state=3 AND t.retry_count < m.retry_times ORDER BY t.updated_at ASC LIMIT ?`, public.TableNameTask, public.TableNameModel), clampLimit(limit, 200))
|
return queryTasks(ctx,
|
||||||
|
fmt.Sprintf(
|
||||||
|
`SELECT t.*, m.retry_queue_max_seconds FROM %s t JOIN %s m ON t.tenant_id=m.tenant_id AND t.model_name=m.model_name WHERE t.deleted_at IS NULL AND t.state=3 AND t.retry_count < m.retry_times ORDER BY t.updated_at ASC LIMIT ?`,
|
||||||
|
public.TableNameTask, public.TableNameModel,
|
||||||
|
),
|
||||||
|
clampLimit(limit, 200),
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
// ListFailedExhaustedGlobal
|
// ListFailedExhaustedGlobal 查询重试次数耗尽的失败任务
|
||||||
func (d *taskDao) ListFailedExhaustedGlobal(ctx context.Context, limit int) ([]*entity.AsynchTask, error) {
|
func (d *taskDao) ListFailedExhaustedGlobal(ctx context.Context, limit int) ([]*entity.AsynchTask, error) {
|
||||||
return queryTasks(ctx, fmt.Sprintf(`SELECT t.* FROM %s t JOIN %s m ON t.tenant_id=m.tenant_id AND t.model_name=m.model_name WHERE t.deleted_at IS NULL AND t.state=3 AND t.retry_count >= m.retry_times ORDER BY t.updated_at ASC LIMIT ?`, public.TableNameTask, public.TableNameModel), clampLimit(limit, 200))
|
return queryTasks(ctx,
|
||||||
|
fmt.Sprintf(
|
||||||
|
`SELECT t.* FROM %s t JOIN %s m ON t.tenant_id=m.tenant_id AND t.model_name=m.model_name WHERE t.deleted_at IS NULL AND t.state=3 AND t.retry_count >= m.retry_times ORDER BY t.updated_at ASC LIMIT ?`,
|
||||||
|
public.TableNameTask, public.TableNameModel,
|
||||||
|
),
|
||||||
|
clampLimit(limit, 200),
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
// ListTimeoutTasksGlobal
|
// ListTimeoutTasksGlobal 查询超时任务
|
||||||
func (d *taskDao) ListTimeoutTasksGlobal(ctx context.Context, limit int) ([]*entity.AsynchTask, error) {
|
func (d *taskDao) ListTimeoutTasksGlobal(ctx context.Context, limit int) ([]*entity.AsynchTask, error) {
|
||||||
return queryTasks(ctx, fmt.Sprintf(`SELECT t.* FROM %s t JOIN %s m ON t.tenant_id=m.tenant_id AND t.model_name=m.model_name WHERE t.deleted_at IS NULL AND t.state IN (0,1) AND m.expected_seconds > 0 AND t.created_at < (NOW() - (m.expected_seconds || ' seconds')::interval) LIMIT ?`, public.TableNameTask, public.TableNameModel), clampLimit(limit, 200))
|
return queryTasks(ctx,
|
||||||
|
fmt.Sprintf(
|
||||||
|
`SELECT t.* FROM %s t JOIN %s m ON t.tenant_id=m.tenant_id AND t.model_name=m.model_name WHERE t.deleted_at IS NULL AND t.state IN (0,1) AND m.expected_seconds > 0 AND t.created_at < (NOW() - (m.expected_seconds || ' seconds')::interval) LIMIT ?`,
|
||||||
|
public.TableNameTask, public.TableNameModel,
|
||||||
|
),
|
||||||
|
clampLimit(limit, 200),
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
// HardDeleteByIDGlobal
|
// HardDeleteByIDGlobal 物理删除任务
|
||||||
func (d *taskDao) HardDeleteByIDGlobal(ctx context.Context, id int64) error {
|
func (d *taskDao) HardDeleteByIDGlobal(ctx context.Context, id int64) error {
|
||||||
return execSQL(ctx, fmt.Sprintf(`DELETE FROM %s WHERE id=?`, public.TableNameTask), id)
|
return execUpdate(ctx,
|
||||||
|
fmt.Sprintf(`DELETE FROM %s WHERE id=?`, public.TableNameTask),
|
||||||
|
id,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
// ======================== 内部辅助 ========================
|
// ======================== 内部辅助 ========================
|
||||||
|
|
||||||
func queryTasks(ctx context.Context, sql string, args ...any) ([]*entity.AsynchTask, error) {
|
func queryTasks(ctx context.Context, sql string, args ...any) ([]*entity.AsynchTask, error) {
|
||||||
r, err := gfdb.DB(ctx).GetAll(ctx, sql, args...)
|
r, err := gfdb.DB(ctx, public.DbNameModelGateway).GetAll(ctx, sql, args...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
var list []*entity.AsynchTask
|
var list []*entity.AsynchTask
|
||||||
err = r.Structs(&list)
|
if err = r.Structs(&list); err != nil {
|
||||||
return list, err
|
return nil, err
|
||||||
|
}
|
||||||
|
return list, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func clampLimit(limit, defaultVal int) int {
|
func clampLimit(limit, defaultVal int) int {
|
||||||
@@ -203,12 +218,3 @@ func clampLimit(limit, defaultVal int) int {
|
|||||||
}
|
}
|
||||||
return limit
|
return limit
|
||||||
}
|
}
|
||||||
|
|
||||||
// UpdateColumns 更新指定字段(结构体版)
|
|
||||||
func (d *taskDao) UpdateColumns(ctx context.Context, id int64, data entity.AsynchTask) error {
|
|
||||||
_, err := gfdb.DB(ctx).Model(ctx, public.TableNameTask).OmitEmpty().
|
|
||||||
Where(entity.AsynchTaskCol.Id, id).
|
|
||||||
Data(data).
|
|
||||||
Update()
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -37,7 +37,10 @@ func (c *cleaner) RunOnce(ctx context.Context) (res *dto.CleanWorkRes, err error
|
|||||||
} else {
|
} else {
|
||||||
for _, t := range list {
|
for _, t := range list {
|
||||||
t.ErrorMsg = "任务超时自动失败"
|
t.ErrorMsg = "任务超时自动失败"
|
||||||
_ = dao.Task.UpdateFailedGlobal(ctx, t)
|
_, err = dao.Task.Update(ctx, t)
|
||||||
|
if err != nil {
|
||||||
|
g.Log().Errorf(ctx, "[清理] 标记任务失败: %v", err)
|
||||||
|
}
|
||||||
queue.ReleaseQueueSlot(ctx, t.ModelName, t.TaskID)
|
queue.ReleaseQueueSlot(ctx, t.ModelName, t.TaskID)
|
||||||
}
|
}
|
||||||
g.Log().Infof(ctx, "[清理] 超时任务处理完成, count=%d", len(list))
|
g.Log().Infof(ctx, "[清理] 超时任务处理完成, count=%d", len(list))
|
||||||
|
|||||||
@@ -21,8 +21,10 @@ import (
|
|||||||
"model-gateway/service/gateway"
|
"model-gateway/service/gateway"
|
||||||
"model-gateway/service/queue"
|
"model-gateway/service/queue"
|
||||||
|
|
||||||
|
"gitea.redpowerfuture.com/red-future/common/beans"
|
||||||
"github.com/gogf/gf/v2/encoding/gjson"
|
"github.com/gogf/gf/v2/encoding/gjson"
|
||||||
"github.com/gogf/gf/v2/frame/g"
|
"github.com/gogf/gf/v2/frame/g"
|
||||||
|
"github.com/gogf/gf/v2/os/gtime"
|
||||||
"github.com/gogf/gf/v2/util/gconv"
|
"github.com/gogf/gf/v2/util/gconv"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -92,7 +94,10 @@ func (w *asyncWorker) handleOne(ctx context.Context, task *entity.AsynchTask, mo
|
|||||||
if err == nil && tmpPath != "" {
|
if err == nil && tmpPath != "" {
|
||||||
task.TmpFile = tmpPath
|
task.TmpFile = tmpPath
|
||||||
task.Phase = 1
|
task.Phase = 1
|
||||||
_ = dao.Task.UpdateTmpAfterModelGlobal(ctx, task.Id, tmpPath)
|
_, err = dao.Task.Update(ctx, task)
|
||||||
|
if err != nil {
|
||||||
|
g.Log().Errorf(ctx, "[执行任务][失败] 更新数据库失败 taskId=%s err=%v", task.TaskID, err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// 4) 解析校验 + 响应映射(可重试,失败重新调模型)
|
// 4) 解析校验 + 响应映射(可重试,失败重新调模型)
|
||||||
@@ -116,7 +121,14 @@ func (w *asyncWorker) handleOne(ctx context.Context, task *entity.AsynchTask, mo
|
|||||||
g.Log().Errorf(ctx, "[执行任务][失败] OSS上传失败 taskId=%s attempt=%d/%d err=%v",
|
g.Log().Errorf(ctx, "[执行任务][失败] OSS上传失败 taskId=%s attempt=%d/%d err=%v",
|
||||||
task.TaskID, attempt, maxRetry, err)
|
task.TaskID, attempt, maxRetry, err)
|
||||||
if attempt == maxRetry {
|
if attempt == maxRetry {
|
||||||
_ = dao.Task.UpdateFailedKeepTmpGlobal(ctx, task.Id, err.Error())
|
task.State = 3
|
||||||
|
task.ErrorMsg = err.Error()
|
||||||
|
task.FinishedAt = gtime.Now()
|
||||||
|
task.Phase = 1
|
||||||
|
_, err = dao.Task.Update(ctx, task)
|
||||||
|
if err != nil {
|
||||||
|
g.Log().Errorf(ctx, "[执行任务][失败] 更新数据库失败 taskId=%s err=%v", task.TaskID, err)
|
||||||
|
}
|
||||||
w.failTask(ctx, task, startTime, fmt.Sprintf("OSS上传重试耗尽: %v", err))
|
w.failTask(ctx, task, startTime, fmt.Sprintf("OSS上传重试耗尽: %v", err))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -130,7 +142,7 @@ func (w *asyncWorker) handleOne(ctx context.Context, task *entity.AsynchTask, mo
|
|||||||
task.TextResult = body
|
task.TextResult = body
|
||||||
task.FileSize = int64(oss.FileSize)
|
task.FileSize = int64(oss.FileSize)
|
||||||
|
|
||||||
if err = dao.Task.UpdateSuccessGlobal(ctx, task); err != nil {
|
if _, err = dao.Task.Update(ctx, task); err != nil {
|
||||||
g.Log().Errorf(ctx, "[执行任务][失败] 更新数据库失败 taskId=%s err=%v", task.TaskID, err)
|
g.Log().Errorf(ctx, "[执行任务][失败] 更新数据库失败 taskId=%s err=%v", task.TaskID, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -170,7 +182,10 @@ func (w *asyncWorker) callModelStream(ctx context.Context, task *entity.AsynchTa
|
|||||||
if tmpErr == nil && tmpPath != "" {
|
if tmpErr == nil && tmpPath != "" {
|
||||||
task.TmpFile = tmpPath
|
task.TmpFile = tmpPath
|
||||||
task.Phase = 1
|
task.Phase = 1
|
||||||
_ = dao.Task.UpdateTmpAfterModelGlobal(ctx, task.Id, tmpPath)
|
_, err = dao.Task.Update(ctx, task)
|
||||||
|
if err != nil {
|
||||||
|
g.Log().Errorf(ctx, "[执行任务][失败] 临时文件保存失败 taskId=%s err=%v", task.TaskID, tmpErr)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -258,7 +273,10 @@ func (w *asyncWorker) callModel(ctx context.Context, task *entity.AsynchTask, mo
|
|||||||
if tmpErr == nil && tmpPath != "" {
|
if tmpErr == nil && tmpPath != "" {
|
||||||
task.TmpFile = tmpPath
|
task.TmpFile = tmpPath
|
||||||
task.Phase = 1
|
task.Phase = 1
|
||||||
_ = dao.Task.UpdateTmpAfterModelGlobal(ctx, task.Id, tmpPath)
|
_, err = dao.Task.Update(ctx, task)
|
||||||
|
if err != nil {
|
||||||
|
g.Log().Errorf(ctx, "[执行任务][失败] 临时文件保存失败 taskId=%s err=%v", task.TaskID, tmpErr)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -296,10 +314,11 @@ func (w *asyncWorker) parseAndRetry(ctx context.Context, body map[string]any, ta
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 2) 先存 token 到数据库,防止后续失败丢失
|
// 2) 先存 token 到数据库,防止后续失败丢失
|
||||||
if tokens, ok := mapped[model.ResponseTokenField]; ok {
|
if _, ok := mapped[model.ResponseTokenField]; ok {
|
||||||
task.ExpendTokens = gconv.Int64(tokens)
|
task.ExpendTokens = gconv.Int64(mapped[model.ResponseTokenField])
|
||||||
_ = dao.Task.UpdateColumns(ctx, task.Id, entity.AsynchTask{
|
_, err = dao.Task.Update(ctx, &entity.AsynchTask{
|
||||||
ExpendTokens: gconv.Int64(body[model.ResponseTokenField]),
|
SQLBaseDO: beans.SQLBaseDO{Id: task.Id},
|
||||||
|
ExpendTokens: task.ExpendTokens,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -483,7 +502,10 @@ func (w *asyncWorker) failTask(ctx context.Context, t *entity.AsynchTask, startT
|
|||||||
t.State = 3
|
t.State = 3
|
||||||
t.ErrorMsg = errMsg
|
t.ErrorMsg = errMsg
|
||||||
t.DurationSeconds = int64(time.Since(startTime).Seconds())
|
t.DurationSeconds = int64(time.Since(startTime).Seconds())
|
||||||
_ = dao.Task.UpdateFailedGlobal(ctx, t)
|
_, err := dao.Task.Update(ctx, t)
|
||||||
|
if err != nil {
|
||||||
|
g.Log().Warningf(ctx, "[执行任务][更新数据库失败] taskId=%s err=%v", t.TaskID, err)
|
||||||
|
}
|
||||||
queue.ReleaseQueueSlot(ctx, t.ModelName, t.TaskID)
|
queue.ReleaseQueueSlot(ctx, t.ModelName, t.TaskID)
|
||||||
go gateway.TriggerCallback(context.WithoutCancel(ctx), t)
|
go gateway.TriggerCallback(context.WithoutCancel(ctx), t)
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user