Files
model-gateway/dao/task_dao_bg.go

276 lines
7.9 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package dao
import (
"context"
"fmt"
"time"
"model-gateway/consts/public"
"model-gateway/model/entity"
"gitea.com/red-future/common/db/gfdb"
"github.com/gogf/gf/v2/database/gdb"
"github.com/gogf/gf/v2/os/gtime"
)
// ClaimPendingGlobal atomically claims up to batchSize pending tasks (state=0)
// for background workers, across all tenants (no tenant filter is applied).
//
// Rows are picked oldest enqueue_at first with FOR UPDATE SKIP LOCKED, so
// concurrent workers never claim the same row. Each claimed row is moved to
// state=1 with started_at/updated_at set to the current time inside the same
// transaction. A non-positive batchSize is treated as 1.
//
// On success it returns the claimed tasks (nil when nothing is pending). state
// and started_at are not part of the SELECT, so those fields keep their zero
// values in the returned structs. On any error — including a failed commit —
// tasks is nil: the transaction was rolled back, so the rows are still pending
// and must not be executed by the caller.
func (d *taskDao) ClaimPendingGlobal(ctx context.Context, batchSize int) (tasks []*entity.AsynchTask, err error) {
	if batchSize <= 0 {
		batchSize = 1
	}
	err = gfdb.DB(ctx).Transaction(ctx, func(ctx context.Context, tx gdb.TX) error {
		query := fmt.Sprintf(
			`SELECT id, tenant_id, creator, model_name, task_id, biz_name, callback_url, model_key, retry_count, input_ref, request_payload, phase, tmp_file
FROM %s
WHERE deleted_at IS NULL AND state = 0
ORDER BY enqueue_at ASC
LIMIT %d
FOR UPDATE SKIP LOCKED`,
			public.TableNameTask,
			batchSize,
		)
		rows, qErr := tx.GetAll(query)
		if qErr != nil {
			return qErr
		}
		if rows.IsEmpty() {
			tasks = nil
			return nil
		}
		if sErr := rows.Structs(&tasks); sErr != nil {
			return sErr
		}
		now := time.Now()
		for _, t := range tasks {
			if _, uErr := tx.Exec(
				fmt.Sprintf(`UPDATE %s SET state=1, started_at=?, updated_at=? WHERE id=?`, public.TableNameTask),
				now, now, t.Id,
			); uErr != nil {
				return uErr
			}
		}
		return nil
	})
	if err != nil {
		// The claim was rolled back; never hand partially-decoded tasks to the
		// caller, or it could run rows that are still pending for other workers.
		return nil, err
	}
	return tasks, nil
}
// ClaimPendingByTaskIDGlobal claims the single pending task (state=0) whose
// task_id equals taskID, across all tenants (no tenant filter is applied).
//
// It is meant to be called right after createTask succeeds, so the new task can
// start immediately instead of relying only on a later runWork scan of the
// queue. The row is locked with FOR UPDATE SKIP LOCKED. If another transaction
// already holds the lock, or the task is no longer pending, nothing is claimed.
// A claimed row is moved to state=1 with started_at/updated_at set to the
// current time inside the same transaction.
//
// Returns (nil, nil) when taskID is empty or no claimable row exists. state and
// started_at are not selected, so the returned struct keeps their zero values.
// On any error — including a failed commit — task is nil, because the claim
// was rolled back and the row is still pending.
func (d *taskDao) ClaimPendingByTaskIDGlobal(ctx context.Context, taskID string) (task *entity.AsynchTask, err error) {
	if taskID == "" {
		return nil, nil
	}
	err = gfdb.DB(ctx).Transaction(ctx, func(ctx context.Context, tx gdb.TX) error {
		query := fmt.Sprintf(
			`SELECT id, tenant_id, creator, model_name, task_id, biz_name, callback_url, model_key, retry_count, input_ref, request_payload, phase, tmp_file
FROM %s
WHERE deleted_at IS NULL AND state = 0 AND task_id = ?
LIMIT 1
FOR UPDATE SKIP LOCKED`,
			public.TableNameTask,
		)
		row, qErr := tx.GetOne(query, taskID)
		if qErr != nil {
			return qErr
		}
		if row.IsEmpty() {
			task = nil
			return nil
		}
		if sErr := row.Struct(&task); sErr != nil {
			return sErr
		}
		now := time.Now()
		_, uErr := tx.Exec(
			fmt.Sprintf(`UPDATE %s SET state=1, started_at=?, updated_at=? WHERE id=?`, public.TableNameTask),
			now, now, task.Id,
		)
		return uErr
	})
	if err != nil {
		// Rolled back: the task was not claimed, so do not return it.
		return nil, err
	}
	return task, nil
}
// UpdateSuccessGlobal marks task t as succeeded. It:
//   - sets state=2;
//   - stores t's result fields (OssFile, FileType, TextResult, FileSize,
//     ExpendTokens) exactly as given, including empty values;
//   - stamps finished_at;
//   - clears leftovers from earlier attempts: error_msg is emptied, phase is
//     reset to 0 and tmp_file is emptied.
//
// The update uses an explicit column map rather than a struct with OmitEmpty().
// OmitEmpty drops zero values from the data, so the resets above (ErrorMsg "",
// Phase 0, TmpFile "") were silently never written. It also drops empty WHERE
// values, so a zero t.Id would have removed the id filter.
// NOTE(review): the result-field column names are derived from the entity field
// names (OssFile -> oss_file, ...). gf maps map keys onto table fields, but
// confirm them against the schema.
func (d *taskDao) UpdateSuccessGlobal(ctx context.Context, t *entity.AsynchTask) error {
	_, err := gfdb.DB(ctx).Model(ctx, public.TableNameTask).
		Where(entity.AsynchTaskCol.Id, t.Id).
		Data(gdb.Map{
			"state":         2,
			"oss_file":      t.OssFile,
			"file_type":     t.FileType,
			"text_result":   t.TextResult,
			"file_size":     t.FileSize,
			"error_msg":     "",
			"finished_at":   gtime.Now(),
			"phase":         0,
			"tmp_file":      "",
			"expend_tokens": t.ExpendTokens,
		}).
		Update()
	return err
}
// UpdateFailedGlobal marks task t as failed after a model-call failure. It:
//   - sets state=3;
//   - stores t.ErrorMsg as-is;
//   - stamps finished_at;
//   - resets phase to 0 and empties tmp_file, so a later retry starts again
//     from the model call.
//
// UpdateFailedKeepTmpGlobal is the variant that keeps tmp_file for an OSS-only
// retry.
//
// An explicit column map is used instead of a struct with OmitEmpty().
// OmitEmpty dropped the zero-valued resets (Phase 0, TmpFile "") and any empty
// ErrorMsg. It also dropped the WHERE condition when t.Id was 0.
func (d *taskDao) UpdateFailedGlobal(ctx context.Context, t *entity.AsynchTask) error {
	_, err := gfdb.DB(ctx).Model(ctx, public.TableNameTask).
		Where(entity.AsynchTaskCol.Id, t.Id).
		Data(gdb.Map{
			"state":       3,
			"error_msg":   t.ErrorMsg,
			"finished_at": gtime.Now(),
			"phase":       0,
			"tmp_file":    "",
		}).
		Update()
	return err
}
// UpdateFailedKeepTmpGlobal records an OSS-upload failure for task id. It:
//   - moves the task to state=3 and stores errorMsg;
//   - sets finished_at and updated_at to the same current timestamp;
//   - pins phase to 1 and leaves tmp_file untouched, so the next round only
//     has to retry the OSS upload.
func (d *taskDao) UpdateFailedKeepTmpGlobal(ctx context.Context, id int64, errorMsg string) error {
	stamp := gtime.Now()
	stmt := fmt.Sprintf(`UPDATE %s SET state=3, error_msg=?, finished_at=?, phase=1, updated_at=? WHERE id=?`, public.TableNameTask)
	if _, execErr := gfdb.DB(ctx).Exec(ctx, stmt, errorMsg, stamp, stamp, id); execErr != nil {
		return execErr
	}
	return nil
}
// UpdateTmpAfterModelGlobal runs after the model call has succeeded. It records
// tmpFile as the task's tmp_file, sets phase=1 for task id, and refreshes
// updated_at from the database clock (NOW()).
func (d *taskDao) UpdateTmpAfterModelGlobal(ctx context.Context, id int64, tmpFile string) error {
	stmt := fmt.Sprintf(`UPDATE %s SET phase=1, tmp_file=?, updated_at=NOW() WHERE id=?`, public.TableNameTask)
	_, execErr := gfdb.DB(ctx).Exec(ctx, stmt, tmpFile, id)
	return execErr
}
// RollbackToPendingGlobal returns a claimed task to the pending queue. If task
// id is currently in state=1, it is set back to state=0 and both enqueue_at and
// updated_at are refreshed to NOW().
//
// Refreshing enqueue_at places the task after the other pending tasks in
// enqueue_at order. A task in any other state is left unchanged, and no error
// is reported when nothing matched.
func (d *taskDao) RollbackToPendingGlobal(ctx context.Context, id int64) error {
	stmt := fmt.Sprintf(`UPDATE %s SET state=0, enqueue_at=NOW(), updated_at=NOW() WHERE id=? AND state=1`, public.TableNameTask)
	if _, execErr := gfdb.DB(ctx).Exec(ctx, stmt, id); execErr != nil {
		return execErr
	}
	return nil
}
// ListExpiredDownloadedGlobal returns up to limit tasks for cleanup
// (limit <= 0 means 200). A task qualifies when it:
//   - is not soft-deleted;
//   - has been downloaded (state=4);
//   - has a non-NULL expire_at earlier than the current time.
//
// No ORDER BY is applied, so which rows are returned when more than limit
// qualify is up to the database.
func (d *taskDao) ListExpiredDownloadedGlobal(ctx context.Context, limit int) (list []*entity.AsynchTask, err error) {
	if limit <= 0 {
		limit = 200
	}
	query := fmt.Sprintf(`SELECT * FROM %s WHERE deleted_at IS NULL AND state=4 AND expire_at IS NOT NULL AND expire_at < ? LIMIT ?`, public.TableNameTask)
	rows, err := gfdb.DB(ctx).GetAll(ctx, query, gtime.Now(), limit)
	if err != nil {
		return nil, err
	}
	err = rows.Structs(&list)
	return list, err
}
// ListFailedRetryableGlobal returns up to limit failed tasks (state=3) that
// still have retries left (limit <= 0 means 200), oldest updated_at first.
//
// retry_count does not count the first execution. The model's retry_times is
// the maximum number of extra attempts after a failure, so a task qualifies
// while retry_count < retry_times.
//
// Each row also carries the model's retry_queue_max_seconds, presumably mapped
// onto a same-named entity field (confirm). Because of the inner JOIN, only
// tasks with a matching model row (same tenant_id and model_name) are returned.
func (d *taskDao) ListFailedRetryableGlobal(ctx context.Context, limit int) (list []*entity.AsynchTask, err error) {
	if limit <= 0 {
		limit = 200
	}
	query := fmt.Sprintf(`
SELECT t.*,
m.retry_queue_max_seconds AS retry_queue_max_seconds
FROM %s t
JOIN %s m
ON t.tenant_id = m.tenant_id
AND t.model_name = m.model_name
WHERE t.deleted_at IS NULL
AND t.state = 3
AND t.retry_count < m.retry_times
ORDER BY t.updated_at ASC
LIMIT ?`, public.TableNameTask, public.TableNameModel)
	rows, err := gfdb.DB(ctx).GetAll(ctx, query, limit)
	if err != nil {
		return nil, err
	}
	err = rows.Structs(&list)
	return list, err
}
// RequeueForRetryGlobal puts a failed task back on the queue for another
// attempt. A non-deleted task id in state=3 is:
//   - moved to state=0;
//   - given retry_count+1;
//   - assigned enqueue_at = enqueueAt, with updated_at set to NOW().
//
// enqueueAt controls where the retry lands in the queue: pending tasks are
// claimed in ascending enqueue_at order, so an earlier value is picked sooner.
// A task that is not in state=3, or is soft-deleted, is left unchanged and no
// error is reported.
func (d *taskDao) RequeueForRetryGlobal(ctx context.Context, id int64, enqueueAt time.Time) error {
	stmt := fmt.Sprintf(`UPDATE %s SET state=0, retry_count=retry_count+1, enqueue_at=?, updated_at=NOW() WHERE id=? AND state=3 AND deleted_at IS NULL`, public.TableNameTask)
	_, execErr := gfdb.DB(ctx).Exec(ctx, stmt, enqueueAt, id)
	return execErr
}
// ListFailedExhaustedGlobal returns up to limit failed tasks (state=3) that
// have used up their retries (retry_count >= the model's retry_times), oldest
// updated_at first. limit <= 0 means 200. The result is intended for hard
// deletion.
//
// Because of the inner JOIN, tasks with no matching model row (same tenant_id
// and model_name) are never returned by this query.
func (d *taskDao) ListFailedExhaustedGlobal(ctx context.Context, limit int) (list []*entity.AsynchTask, err error) {
	if limit <= 0 {
		limit = 200
	}
	query := fmt.Sprintf(`
SELECT t.*
FROM %s t
JOIN %s m
ON t.tenant_id = m.tenant_id
AND t.model_name = m.model_name
WHERE t.deleted_at IS NULL
AND t.state = 3
AND t.retry_count >= m.retry_times
ORDER BY t.updated_at ASC
LIMIT ?`, public.TableNameTask, public.TableNameModel)
	rows, err := gfdb.DB(ctx).GetAll(ctx, query, limit)
	if err != nil {
		return nil, err
	}
	err = rows.Structs(&list)
	return list, err
}
// HardDeleteByIDGlobal permanently removes the task row with the given id. It
// is a real DELETE, not a soft delete via deleted_at, and it does not check
// the task's state or deleted_at.
func (d *taskDao) HardDeleteByIDGlobal(ctx context.Context, id int64) error {
	stmt := fmt.Sprintf(`DELETE FROM %s WHERE id=?`, public.TableNameTask)
	if _, execErr := gfdb.DB(ctx).Exec(ctx, stmt, id); execErr != nil {
		return execErr
	}
	return nil
}
// ListTimeoutTasksGlobal returns up to limit tasks that have exceeded their
// model's expected run time (limit <= 0 means 200). A task qualifies when:
//   - it is not soft-deleted and is pending or running (state 0 or 1);
//   - the model row for the task's tenant_id/model_name has
//     expected_seconds > 0;
//   - created_at is strictly earlier than NOW() minus expected_seconds, i.e.
//     now - created_at > expected_seconds.
//
// The deadline is measured from created_at, not from started_at or enqueue_at.
// NOTE(review): a task re-queued for retry keeps its original created_at, so it
// may be reported as timed out right after being re-queued — confirm this is
// intended.
func (d *taskDao) ListTimeoutTasksGlobal(ctx context.Context, limit int) (list []*entity.AsynchTask, err error) {
	if limit <= 0 {
		limit = 200
	}
	query := fmt.Sprintf(`
SELECT t.*
FROM %s t
JOIN %s m
ON t.tenant_id = m.tenant_id
AND t.model_name = m.model_name
WHERE t.deleted_at IS NULL
AND t.state IN (0,1)
AND m.expected_seconds > 0
AND t.created_at < (NOW() - (m.expected_seconds || ' seconds')::interval)
LIMIT ?`, public.TableNameTask, public.TableNameModel)
	rows, err := gfdb.DB(ctx).GetAll(ctx, query, limit)
	if err != nil {
		return nil, err
	}
	err = rows.Structs(&list)
	return list, err
}