Files
model-gateway/dao/task_dao_bg.go

221 lines
7.0 KiB
Go
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package dao
import (
"context"
"fmt"
"time"
"model-gateway/consts/public"
"model-gateway/model/entity"
"gitea.redpowerfuture.com/red-future/common/db/gfdb"
"github.com/gogf/gf/v2/database/gdb"
"github.com/gogf/gf/v2/os/gtime"
)
// ======================== Query helpers ========================

// taskColumns is the explicit column list selected by the claim queries in
// this file (claimTasks and ClaimPendingGlobal).
const taskColumns = `id, tenant_id, creator, model_name, task_id, biz_name, callback_url, model_key, retry_count, input_ref, request_payload, phase, tmp_file`
// ======================== Generic CRUD ========================
// UpdateFields writes the given column/value pairs to the task row whose id
// column equals id and returns the number of rows affected.
//
// This is the map-based variant; per the original author's note it is meant
// for cases where zero values must be persisted.
func (d *taskDao) UpdateFields(ctx context.Context, id int64, data map[string]any) (int64, error) {
	model := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameTask)
	result, err := model.Data(data).Where(entity.AsynchTaskCol.Id, id).Update()
	if err != nil {
		return 0, err
	}
	return result.RowsAffected()
}
// execUpdate runs a raw SQL statement against the database named by
// public.DbNameModelGateway and reports only the error; the driver result
// (including the affected-row count) is discarded.
//
// The statement is executed as given: nothing is appended to it, so callers
// that want updated_at refreshed must set it in their own SQL.
func execUpdate(ctx context.Context, sql string, args ...any) error {
	db := gfdb.DB(ctx, public.DbNameModelGateway)
	if _, err := db.Exec(ctx, sql, args...); err != nil {
		return err
	}
	return nil
}
// ======================== Transactional claiming ========================
// claimTasks claims at most one pending task that matches an extra filter.
//
// Inside a single transaction it selects one row with state = 0 and
// deleted_at IS NULL (plus the caller's filter) using
// FOR UPDATE SKIP LOCKED, then sets state=1, started_at and updated_at on
// that row.
//
// where is spliced verbatim into the SQL after the built-in predicates, so it
// must be a trusted fragment of the form "AND ..." (or empty). Values belong
// in args as placeholders, never in the fragment itself.
//
// It returns a one-element slice holding the claimed task, or a nil slice when
// no row matched. The returned entity is the row as it was selected; it is not
// refreshed after the UPDATE. On any error, including a failure reported by
// Transaction after the closure succeeded, the returned slice is nil so that
// callers never receive a task whose claim did not take effect.
func claimTasks(ctx context.Context, where string, args ...any) ([]*entity.AsynchTask, error) {
	var claimed []*entity.AsynchTask
	err := gfdb.DB(ctx, public.DbNameModelGateway).Transaction(ctx, func(ctx context.Context, tx gdb.TX) error {
		query := fmt.Sprintf(
			`SELECT %s FROM %s WHERE deleted_at IS NULL AND state = 0 %s LIMIT 1 FOR UPDATE SKIP LOCKED`,
			taskColumns, public.TableNameTask, where,
		)
		row, err := tx.GetOne(query, args...)
		if err != nil {
			return err
		}
		if row.IsEmpty() {
			// Nothing to claim; leave claimed nil.
			return nil
		}
		var task entity.AsynchTask
		if err := row.Struct(&task); err != nil {
			return err
		}
		now := time.Now()
		if _, err := tx.Exec(
			fmt.Sprintf(`UPDATE %s SET state=1, started_at=?, updated_at=? WHERE id=?`, public.TableNameTask),
			now, now, task.Id,
		); err != nil {
			return err
		}
		claimed = []*entity.AsynchTask{&task}
		return nil
	})
	if err != nil {
		// The transaction did not complete: do not hand back a task that was
		// never actually claimed.
		return nil, err
	}
	return claimed, nil
}
// ClaimPendingGlobal claims up to batchSize pending tasks in one transaction.
//
// It selects rows with state = 0 and deleted_at IS NULL, ordered by
// enqueue_at ascending, using FOR UPDATE SKIP LOCKED, and then sets state=1,
// started_at and updated_at on each selected row. A batchSize <= 0 is treated
// as 1.
//
// It returns the claimed tasks, or a nil slice when no row matched. The
// returned entities are the rows as selected; they are not refreshed after the
// UPDATE. On any error the returned slice is nil, so callers never receive
// tasks whose claim was not completed.
//
// NOTE(review): the SELECT does not filter on enqueue_at <= now, so a task
// requeued with a future enqueue_at (see RequeueForRetryGlobal) is only sorted
// later, not excluded. Confirm this is intended.
func (d *taskDao) ClaimPendingGlobal(ctx context.Context, batchSize int) ([]*entity.AsynchTask, error) {
	if batchSize <= 0 {
		batchSize = 1
	}
	var claimed []*entity.AsynchTask
	err := gfdb.DB(ctx, public.DbNameModelGateway).Transaction(ctx, func(ctx context.Context, tx gdb.TX) error {
		selectSQL := fmt.Sprintf(
			`SELECT %s FROM %s WHERE deleted_at IS NULL AND state = 0 ORDER BY enqueue_at ASC LIMIT %d FOR UPDATE SKIP LOCKED`,
			taskColumns, public.TableNameTask, batchSize,
		)
		rows, err := tx.GetAll(selectSQL)
		if err != nil {
			return err
		}
		if rows.IsEmpty() {
			return nil
		}
		if err := rows.Structs(&claimed); err != nil {
			return err
		}
		// The UPDATE text is identical for every row; build it once.
		updateSQL := fmt.Sprintf(`UPDATE %s SET state=1, started_at=?, updated_at=? WHERE id=?`, public.TableNameTask)
		now := time.Now()
		for _, task := range claimed {
			if _, err := tx.Exec(updateSQL, now, now, task.Id); err != nil {
				return err
			}
		}
		return nil
	})
	if err != nil {
		// claimed may already be populated by Structs; drop it so the caller
		// cannot process tasks whose claim did not take effect.
		return nil, err
	}
	return claimed, nil
}
// ClaimPendingByTaskIDGlobal claims the pending task with the given task_id
// via claimTasks.
//
// It returns (nil, nil) when taskID is empty or when no pending row matched,
// and (nil, err) when the claim failed.
func (d *taskDao) ClaimPendingByTaskIDGlobal(ctx context.Context, taskID string) (*entity.AsynchTask, error) {
	if len(taskID) == 0 {
		return nil, nil
	}
	claimed, err := claimTasks(ctx, "AND task_id = ?", taskID)
	switch {
	case err != nil:
		return nil, err
	case len(claimed) == 0:
		return nil, nil
	default:
		return claimed[0], nil
	}
}
// ======================== Business update methods ========================
// RollbackToPendingGlobal puts a task back into the pending state: it sets
// state=0 and refreshes enqueue_at and updated_at to NOW(), but only if the
// row is currently in state 1.
//
// Because the affected-row count is not inspected, a row that is not in
// state 1 results in a silent no-op rather than an error.
func (d *taskDao) RollbackToPendingGlobal(ctx context.Context, id int64) error {
	// Per the original author's note: state=0 may be skipped by OmitEmpty, so
	// raw SQL is used; the state=1 condition guards against concurrent changes.
	stmt := fmt.Sprintf(`UPDATE %s SET state=0, enqueue_at=NOW(), updated_at=NOW() WHERE id=? AND state=1`, public.TableNameTask)
	return execUpdate(ctx, stmt, id)
}
// IncRetryCountGlobal increments retry_count by one and sets updated_at to
// NOW() for the task with the given id. The row's state is not checked.
func (d *taskDao) IncRetryCountGlobal(ctx context.Context, id int64) error {
	stmt := fmt.Sprintf(`UPDATE %s SET retry_count=retry_count+1, updated_at=NOW() WHERE id=?`, public.TableNameTask)
	return execUpdate(ctx, stmt, id)
}
// RequeueForRetryGlobal moves a task from state 3 back to state 0, increments
// retry_count, sets enqueue_at to enqueueAt and updated_at to NOW().
//
// The update only applies to a row that is in state 3 and not soft-deleted.
// When no row qualifies nothing changes and no error is reported, because the
// affected-row count is not inspected.
func (d *taskDao) RequeueForRetryGlobal(ctx context.Context, id int64, enqueueAt time.Time) error {
	stmt := fmt.Sprintf(`UPDATE %s SET state=0, retry_count=retry_count+1, enqueue_at=?, updated_at=NOW() WHERE id=? AND state=3 AND deleted_at IS NULL`, public.TableNameTask)
	return execUpdate(ctx, stmt, enqueueAt, id)
}
// ======================== List queries ========================
// ListExpiredDownloadedGlobal returns non-deleted tasks in state 4 whose
// expire_at is set and earlier than the current time (gtime.Now()).
//
// At most limit rows are returned; a limit <= 0 falls back to 200. The query
// has no ORDER BY, so the row order is unspecified.
func (d *taskDao) ListExpiredDownloadedGlobal(ctx context.Context, limit int) ([]*entity.AsynchTask, error) {
	query := fmt.Sprintf(`SELECT * FROM %s WHERE deleted_at IS NULL AND state=4 AND expire_at IS NOT NULL AND expire_at < ? LIMIT ?`, public.TableNameTask)
	cutoff := gtime.Now()
	maxRows := clampLimit(limit, 200)
	return queryTasks(ctx, query, cutoff, maxRows)
}
// ListFailedRetryableGlobal returns non-deleted tasks in state 3 whose
// retry_count is still below the retry_times of the matching model row.
//
// Tasks are joined to the model table on tenant_id and model_name, the model's
// retry_queue_max_seconds is selected alongside the task columns, and rows are
// ordered by the task's updated_at ascending. At most limit rows are returned;
// a limit <= 0 falls back to 200.
func (d *taskDao) ListFailedRetryableGlobal(ctx context.Context, limit int) ([]*entity.AsynchTask, error) {
	query := fmt.Sprintf(
		`SELECT t.*, m.retry_queue_max_seconds FROM %s t JOIN %s m ON t.tenant_id=m.tenant_id AND t.model_name=m.model_name WHERE t.deleted_at IS NULL AND t.state=3 AND t.retry_count < m.retry_times ORDER BY t.updated_at ASC LIMIT ?`,
		public.TableNameTask, public.TableNameModel,
	)
	maxRows := clampLimit(limit, 200)
	return queryTasks(ctx, query, maxRows)
}
// ListFailedExhaustedGlobal returns non-deleted tasks in state 3 whose
// retry_count has reached or exceeded the retry_times of the matching model
// row.
//
// Tasks are joined to the model table on tenant_id and model_name and ordered
// by the task's updated_at ascending. At most limit rows are returned; a
// limit <= 0 falls back to 200.
func (d *taskDao) ListFailedExhaustedGlobal(ctx context.Context, limit int) ([]*entity.AsynchTask, error) {
	query := fmt.Sprintf(
		`SELECT t.* FROM %s t JOIN %s m ON t.tenant_id=m.tenant_id AND t.model_name=m.model_name WHERE t.deleted_at IS NULL AND t.state=3 AND t.retry_count >= m.retry_times ORDER BY t.updated_at ASC LIMIT ?`,
		public.TableNameTask, public.TableNameModel,
	)
	maxRows := clampLimit(limit, 200)
	return queryTasks(ctx, query, maxRows)
}
// ListTimeoutTasksGlobal returns non-deleted tasks in state 0 or 1 that were
// created more than expected_seconds ago, where expected_seconds comes from
// the matching model row and must be greater than zero.
//
// Tasks are joined to the model table on tenant_id and model_name. The cutoff
// is computed in SQL from NOW() using the `::interval` cast. At most limit
// rows are returned; a limit <= 0 falls back to 200. The query has no
// ORDER BY, so the row order is unspecified.
func (d *taskDao) ListTimeoutTasksGlobal(ctx context.Context, limit int) ([]*entity.AsynchTask, error) {
	query := fmt.Sprintf(
		`SELECT t.* FROM %s t JOIN %s m ON t.tenant_id=m.tenant_id AND t.model_name=m.model_name WHERE t.deleted_at IS NULL AND t.state IN (0,1) AND m.expected_seconds > 0 AND t.created_at < (NOW() - (m.expected_seconds || ' seconds')::interval) LIMIT ?`,
		public.TableNameTask, public.TableNameModel,
	)
	maxRows := clampLimit(limit, 200)
	return queryTasks(ctx, query, maxRows)
}
// HardDeleteByIDGlobal physically removes the task row with the given id by
// issuing a DELETE statement (as opposed to setting deleted_at).
//
// No error is reported when no row has that id, because the affected-row
// count is not inspected.
func (d *taskDao) HardDeleteByIDGlobal(ctx context.Context, id int64) error {
	stmt := fmt.Sprintf(`DELETE FROM %s WHERE id=?`, public.TableNameTask)
	return execUpdate(ctx, stmt, id)
}
// ======================== Internal helpers ========================
// queryTasks runs a raw SELECT against the database named by
// public.DbNameModelGateway and decodes every returned row into an
// entity.AsynchTask.
//
// It returns (nil, err) if either the query or the decoding fails.
func queryTasks(ctx context.Context, sql string, args ...any) ([]*entity.AsynchTask, error) {
	db := gfdb.DB(ctx, public.DbNameModelGateway)
	rows, err := db.GetAll(ctx, sql, args...)
	if err != nil {
		return nil, err
	}
	var tasks []*entity.AsynchTask
	if scanErr := rows.Structs(&tasks); scanErr != nil {
		return nil, scanErr
	}
	return tasks, nil
}
// clampLimit returns limit when it is positive and defaultVal otherwise.
//
// Despite the name, no upper bound is applied: a positive limit larger than
// defaultVal is returned unchanged.
func clampLimit(limit, defaultVal int) int {
	if limit > 0 {
		return limit
	}
	return defaultVal
}