Files
model-gateway/dao/task_dao_bg.go

221 lines
7.0 KiB
Go
Raw Normal View History

2026-04-29 15:54:14 +08:00
package dao
import (
"context"
"fmt"
"time"
"model-gateway/consts/public"
"model-gateway/model/entity"
2026-04-29 15:54:14 +08:00
2026-06-10 16:16:05 +08:00
"gitea.redpowerfuture.com/red-future/common/db/gfdb"
2026-04-29 15:54:14 +08:00
"github.com/gogf/gf/v2/database/gdb"
"github.com/gogf/gf/v2/os/gtime"
)
// ======================== Query helpers ========================

// taskColumns is the explicit column list selected by the task-claiming
// queries in this file (claimTasks and ClaimPendingGlobal).
const taskColumns = `id, tenant_id, creator, model_name, task_id, biz_name, callback_url, model_key, retry_count, input_ref, request_payload, phase, tmp_file`
// ======================== Generic CRUD ========================

// UpdateFields writes the given column/value pairs to the task row whose
// primary key equals id and reports how many rows were affected.
//
// This is the map-based variant, intended for callers that must be able to
// write zero values.
func (d *taskDao) UpdateFields(ctx context.Context, id int64, data map[string]any) (int64, error) {
	query := gfdb.DB(ctx, public.DbNameModelGateway).
		Model(ctx, public.TableNameTask).
		Data(data).
		Where(entity.AsynchTaskCol.Id, id)

	result, err := query.Update()
	if err != nil {
		return 0, err
	}
	return result.RowsAffected()
}
// execUpdate is an internal helper that runs a raw SQL statement against the
// model-gateway database and discards the result, returning only the error.
//
// The statement is passed through unchanged, so callers that want updated_at
// refreshed must set it in their own SQL.
func execUpdate(ctx context.Context, sql string, args ...any) error {
	db := gfdb.DB(ctx, public.DbNameModelGateway)
	if _, err := db.Exec(ctx, sql, args...); err != nil {
		return err
	}
	return nil
}
// ======================== Transactional claiming ========================

// claimTasks claims at most one task with state = 0 inside a transaction.
//
// It selects a single non-deleted row using FOR UPDATE SKIP LOCKED, then sets
// state=1 together with started_at and updated_at on that row. The where
// argument is spliced verbatim into the SELECT after the fixed predicates, so
// it must be a trusted SQL fragment (e.g. "AND task_id = ?"); its values are
// bound through args.
//
// The returned slice is nil when no row matched, otherwise it holds exactly
// one task.
func claimTasks(ctx context.Context, where string, args ...any) ([]*entity.AsynchTask, error) {
	selectSQL := fmt.Sprintf(
		`SELECT %s FROM %s WHERE deleted_at IS NULL AND state = 0 %s LIMIT 1 FOR UPDATE SKIP LOCKED`,
		taskColumns, public.TableNameTask, where,
	)
	markSQL := fmt.Sprintf(`UPDATE %s SET state=1, started_at=?, updated_at=? WHERE id=?`, public.TableNameTask)

	var claimed []*entity.AsynchTask
	txErr := gfdb.DB(ctx, public.DbNameModelGateway).Transaction(ctx, func(ctx context.Context, tx gdb.TX) error {
		row, err := tx.GetOne(selectSQL, args...)
		if err != nil {
			return err
		}
		if row.IsEmpty() {
			// Nothing to claim; leave the result nil.
			return nil
		}

		task := new(entity.AsynchTask)
		if err := row.Struct(task); err != nil {
			return err
		}

		// One timestamp is used for both started_at and updated_at.
		claimedAt := time.Now()
		if _, err := tx.Exec(markSQL, claimedAt, claimedAt, task.Id); err != nil {
			return err
		}

		claimed = []*entity.AsynchTask{task}
		return nil
	})
	return claimed, txErr
}
// ClaimPendingGlobal 批量抢占 pending 任务
func (d *taskDao) ClaimPendingGlobal(ctx context.Context, batchSize int) ([]*entity.AsynchTask, error) {
2026-04-29 15:54:14 +08:00
if batchSize <= 0 {
batchSize = 1
}
var tasks []*entity.AsynchTask
err := gfdb.DB(ctx, public.DbNameModelGateway).Transaction(ctx, func(ctx context.Context, tx gdb.TX) error {
sql := fmt.Sprintf(
`SELECT %s FROM %s WHERE deleted_at IS NULL AND state = 0 ORDER BY enqueue_at ASC LIMIT %d FOR UPDATE SKIP LOCKED`,
taskColumns, public.TableNameTask, batchSize,
)
2026-04-29 15:54:14 +08:00
r, err := tx.GetAll(sql)
if err != nil {
return err
}
if r.IsEmpty() {
return nil
}
if err := r.Structs(&tasks); err != nil {
return err
}
now := time.Now()
for _, t := range tasks {
_, err = tx.Exec(
fmt.Sprintf(`UPDATE %s SET state=1, started_at=?, updated_at=? WHERE id=?`, public.TableNameTask),
now, now, t.Id,
)
2026-04-29 15:54:14 +08:00
if err != nil {
return err
}
}
return nil
})
return tasks, err
2026-04-29 15:54:14 +08:00
}
// ClaimPendingByTaskIDGlobal 按 task_id 抢占
func (d *taskDao) ClaimPendingByTaskIDGlobal(ctx context.Context, taskID string) (*entity.AsynchTask, error) {
2026-05-12 13:45:08 +08:00
if taskID == "" {
return nil, nil
}
tasks, err := claimTasks(ctx, "AND task_id = ?", taskID)
if err != nil || len(tasks) == 0 {
return nil, err
}
return tasks[0], nil
2026-05-12 13:45:08 +08:00
}
// ======================== 业务更新方法 ========================
// RollbackToPendingGlobal 回滚到 pending 状态
2026-04-29 15:54:14 +08:00
func (d *taskDao) RollbackToPendingGlobal(ctx context.Context, id int64) error {
// state=0 可能被 OmitEmpty 跳过,所以用原生 SQL + 条件 state=1 防并发
return execUpdate(ctx,
fmt.Sprintf(`UPDATE %s SET state=0, enqueue_at=NOW(), updated_at=NOW() WHERE id=? AND state=1`, public.TableNameTask),
id,
)
2026-04-29 15:54:14 +08:00
}
// IncRetryCountGlobal 重试计数 +1
func (d *taskDao) IncRetryCountGlobal(ctx context.Context, id int64) error {
return execUpdate(ctx,
fmt.Sprintf(`UPDATE %s SET retry_count=retry_count+1, updated_at=NOW() WHERE id=?`, public.TableNameTask),
id,
)
2026-04-29 15:54:14 +08:00
}
// RequeueForRetryGlobal 重新入队
func (d *taskDao) RequeueForRetryGlobal(ctx context.Context, id int64, enqueueAt time.Time) error {
return execUpdate(ctx,
fmt.Sprintf(`UPDATE %s SET state=0, retry_count=retry_count+1, enqueue_at=?, updated_at=NOW() WHERE id=? AND state=3 AND deleted_at IS NULL`, public.TableNameTask),
enqueueAt, id,
)
2026-04-29 15:54:14 +08:00
}
// ======================== 列表查询 ========================
// ListExpiredDownloadedGlobal 查询已过期下载的任务
func (d *taskDao) ListExpiredDownloadedGlobal(ctx context.Context, limit int) ([]*entity.AsynchTask, error) {
return queryTasks(ctx,
fmt.Sprintf(`SELECT * FROM %s WHERE deleted_at IS NULL AND state=4 AND expire_at IS NOT NULL AND expire_at < ? LIMIT ?`, public.TableNameTask),
gtime.Now(), clampLimit(limit, 200),
)
2026-04-29 15:54:14 +08:00
}
// ListFailedRetryableGlobal 查询可重试的失败任务
func (d *taskDao) ListFailedRetryableGlobal(ctx context.Context, limit int) ([]*entity.AsynchTask, error) {
return queryTasks(ctx,
fmt.Sprintf(
`SELECT t.*, m.retry_queue_max_seconds FROM %s t JOIN %s m ON t.tenant_id=m.tenant_id AND t.model_name=m.model_name WHERE t.deleted_at IS NULL AND t.state=3 AND t.retry_count < m.retry_times ORDER BY t.updated_at ASC LIMIT ?`,
public.TableNameTask, public.TableNameModel,
),
clampLimit(limit, 200),
)
2026-04-29 15:54:14 +08:00
}
// ListFailedExhaustedGlobal lists failed tasks (state=3) whose retries are
// used up, i.e. whose retry_count has reached or exceeded the retry_times of
// the model row joined on tenant_id and model_name.
//
// Rows are ordered by updated_at ascending. A limit <= 0 falls back to 200.
func (d *taskDao) ListFailedExhaustedGlobal(ctx context.Context, limit int) ([]*entity.AsynchTask, error) {
	stmt := fmt.Sprintf(
		`SELECT t.* FROM %s t JOIN %s m ON t.tenant_id=m.tenant_id AND t.model_name=m.model_name WHERE t.deleted_at IS NULL AND t.state=3 AND t.retry_count >= m.retry_times ORDER BY t.updated_at ASC LIMIT ?`,
		public.TableNameTask, public.TableNameModel,
	)
	maxRows := clampLimit(limit, 200)
	return queryTasks(ctx, stmt, maxRows)
}
// ListTimeoutTasksGlobal lists timed-out tasks: non-deleted rows in state 0
// or 1 whose created_at is older than NOW() minus the expected_seconds of the
// model row joined on tenant_id and model_name.
//
// Models with expected_seconds <= 0 are excluded. The query has no ORDER BY.
// A limit <= 0 falls back to 200 rows.
func (d *taskDao) ListTimeoutTasksGlobal(ctx context.Context, limit int) ([]*entity.AsynchTask, error) {
	stmt := fmt.Sprintf(
		`SELECT t.* FROM %s t JOIN %s m ON t.tenant_id=m.tenant_id AND t.model_name=m.model_name WHERE t.deleted_at IS NULL AND t.state IN (0,1) AND m.expected_seconds > 0 AND t.created_at < (NOW() - (m.expected_seconds || ' seconds')::interval) LIMIT ?`,
		public.TableNameTask, public.TableNameModel,
	)
	maxRows := clampLimit(limit, 200)
	return queryTasks(ctx, stmt, maxRows)
}
// HardDeleteByIDGlobal 物理删除任务
2026-04-29 15:54:14 +08:00
func (d *taskDao) HardDeleteByIDGlobal(ctx context.Context, id int64) error {
return execUpdate(ctx,
fmt.Sprintf(`DELETE FROM %s WHERE id=?`, public.TableNameTask),
id,
)
2026-04-29 15:54:14 +08:00
}
// ======================== 内部辅助 ========================
func queryTasks(ctx context.Context, sql string, args ...any) ([]*entity.AsynchTask, error) {
r, err := gfdb.DB(ctx, public.DbNameModelGateway).GetAll(ctx, sql, args...)
2026-04-29 15:54:14 +08:00
if err != nil {
return nil, err
}
var list []*entity.AsynchTask
if err = r.Structs(&list); err != nil {
return nil, err
}
return list, nil
}
// clampLimit returns limit when it is positive and defaultVal otherwise.
//
// Despite the name, no upper bound is applied: a positive limit larger than
// defaultVal is returned unchanged.
func clampLimit(limit, defaultVal int) int {
	if limit > 0 {
		return limit
	}
	return defaultVal
}