2026-04-29 15:54:14 +08:00
package dao
import (
"context"
"fmt"
"time"
2026-05-15 14:56:26 +08:00
"model-gateway/consts/public"
"model-gateway/model/entity"
2026-04-29 15:54:14 +08:00
"gitea.com/red-future/common/db/gfdb"
"github.com/gogf/gf/v2/database/gdb"
"github.com/gogf/gf/v2/os/gtime"
)
// ClaimPendingGlobal 后台任务使用:全局抢占 pending 任务(不加 tenant 过滤)
func ( d * taskDao ) ClaimPendingGlobal ( ctx context . Context , batchSize int ) ( tasks [ ] * entity . AsynchTask , err error ) {
if batchSize <= 0 {
batchSize = 1
}
err = gfdb . DB ( ctx ) . Transaction ( ctx , func ( ctx context . Context , tx gdb . TX ) error {
sql := fmt . Sprintf (
2026-05-27 09:36:25 +08:00
` SELECT id , tenant_id , creator , model_name , task_id , biz_name , callback_url , model_key , retry_count , input_ref , request_payload , phase , tmp_file
2026-04-29 15:54:14 +08:00
FROM % s
WHERE deleted_at IS NULL AND state = 0
ORDER BY enqueue_at ASC
LIMIT % d
FOR UPDATE SKIP LOCKED ` ,
public . TableNameTask ,
batchSize ,
)
r , err := tx . GetAll ( sql )
if err != nil {
return err
}
if r . IsEmpty ( ) {
tasks = nil
return nil
}
if err := r . Structs ( & tasks ) ; err != nil {
return err
}
now := time . Now ( )
for _ , t := range tasks {
_ , err = tx . Exec (
fmt . Sprintf ( ` UPDATE %s SET state=1, started_at=?, updated_at=? WHERE id=? ` , public . TableNameTask ) ,
now , now , t . Id ,
)
if err != nil {
return err
}
}
return nil
} )
return
}
2026-05-12 13:45:08 +08:00
// ClaimPendingByTaskIDGlobal 按 task_id 定向抢占单个 pending 任务(不加 tenant 过滤)
// 用于 createTask 创建成功后立即异步尝试执行当前任务,避免只依赖后续 runWork 扫描队列。
func ( d * taskDao ) ClaimPendingByTaskIDGlobal ( ctx context . Context , taskID string ) ( task * entity . AsynchTask , err error ) {
if taskID == "" {
return nil , nil
}
err = gfdb . DB ( ctx ) . Transaction ( ctx , func ( ctx context . Context , tx gdb . TX ) error {
sql := fmt . Sprintf (
2026-05-27 09:36:25 +08:00
` SELECT id , tenant_id , creator , model_name , task_id , biz_name , callback_url , model_key , retry_count , input_ref , request_payload , phase , tmp_file
2026-05-12 13:45:08 +08:00
FROM % s
WHERE deleted_at IS NULL AND state = 0 AND task_id = ?
LIMIT 1
FOR UPDATE SKIP LOCKED ` ,
public . TableNameTask ,
)
r , err := tx . GetOne ( sql , taskID )
if err != nil {
return err
}
if r . IsEmpty ( ) {
task = nil
return nil
}
if err := r . Struct ( & task ) ; err != nil {
return err
}
now := time . Now ( )
_ , err = tx . Exec (
fmt . Sprintf ( ` UPDATE %s SET state=1, started_at=?, updated_at=? WHERE id=? ` , public . TableNameTask ) ,
now , now , task . Id ,
)
return err
} )
return
}
2026-05-27 09:36:25 +08:00
// UpdateSuccessGlobal 更新任务成功
func ( d * taskDao ) UpdateSuccessGlobal ( ctx context . Context , t * entity . AsynchTask ) error {
2026-04-29 15:54:14 +08:00
now := gtime . Now ( )
2026-05-27 09:36:25 +08:00
_ , err := gfdb . DB ( ctx ) . Model ( ctx , public . TableNameTask ) . OmitEmpty ( ) .
Where ( entity . AsynchTaskCol . Id , t . Id ) .
Data ( entity . AsynchTask {
State : 2 ,
OssFile : t . OssFile ,
FileType : t . FileType ,
TextResult : t . TextResult ,
FileSize : t . FileSize ,
ErrorMsg : "" ,
FinishedAt : now ,
Phase : 0 ,
TmpFile : "" ,
ExpendTokens : t . ExpendTokens ,
} ) .
Update ( )
2026-04-29 15:54:14 +08:00
return err
}
2026-05-27 09:36:25 +08:00
// UpdateFailedGlobal 模型调用失败
func ( d * taskDao ) UpdateFailedGlobal ( ctx context . Context , t * entity . AsynchTask ) error {
2026-04-29 15:54:14 +08:00
now := gtime . Now ( )
2026-05-27 09:36:25 +08:00
_ , err := gfdb . DB ( ctx ) . Model ( ctx , public . TableNameTask ) . OmitEmpty ( ) .
Where ( entity . AsynchTaskCol . Id , t . Id ) .
Data ( entity . AsynchTask {
State : 3 ,
ErrorMsg : t . ErrorMsg ,
FinishedAt : now ,
Phase : 0 ,
TmpFile : "" ,
} ) .
Update ( )
2026-04-29 15:54:14 +08:00
return err
}
// UpdateFailedKeepTmpGlobal records an OSS upload failure for task id. The
// task becomes failed (state=3) with errorMsg, and finished_at/updated_at are
// set to now. phase is pinned to 1 and tmp_file is left unchanged, so the next
// round only has to retry the OSS upload.
func (d *taskDao) UpdateFailedKeepTmpGlobal(ctx context.Context, id int64, errorMsg string) error {
	stmt := fmt.Sprintf(`UPDATE %s SET state=3, error_msg=?, finished_at=?, phase=1, updated_at=? WHERE id=?`, public.TableNameTask)
	ts := gtime.Now()
	if _, err := gfdb.DB(ctx).Exec(ctx, stmt, errorMsg, ts, ts, id); err != nil {
		return err
	}
	return nil
}
// UpdateTmpAfterModelGlobal is used after a successful model call. It stores
// tmpFile as the task's temporary file path, advances the task to phase=1,
// and stamps updated_at with the database clock (NOW()).
func (d *taskDao) UpdateTmpAfterModelGlobal(ctx context.Context, id int64, tmpFile string) error {
	stmt := fmt.Sprintf(`UPDATE %s SET phase=1, tmp_file=?, updated_at=NOW() WHERE id=?`, public.TableNameTask)
	if _, err := gfdb.DB(ctx).Exec(ctx, stmt, tmpFile, id); err != nil {
		return err
	}
	return nil
}
// RollbackToPendingGlobal moves a running task (state=1) back to pending
// (state=0). It re-stamps enqueue_at and updated_at with the database clock
// (NOW()), which places the task behind older entries in ClaimPendingGlobal's
// enqueue_at ordering. Rows not currently in state=1 are left untouched.
func (d *taskDao) RollbackToPendingGlobal(ctx context.Context, id int64) error {
	stmt := fmt.Sprintf(`UPDATE %s SET state=0, enqueue_at=NOW(), updated_at=NOW() WHERE id=? AND state=1`, public.TableNameTask)
	if _, err := gfdb.DB(ctx).Exec(ctx, stmt, id); err != nil {
		return err
	}
	return nil
}
// ListExpiredDownloadedGlobal returns up to limit tasks (200 when limit <= 0)
// that were downloaded (state=4), are not soft-deleted, and whose expire_at is
// set and earlier than the application's current time. Used for cleanup.
func (d *taskDao) ListExpiredDownloadedGlobal(ctx context.Context, limit int) (list []*entity.AsynchTask, err error) {
	const fallbackLimit = 200
	if limit <= 0 {
		limit = fallbackLimit
	}
	query := fmt.Sprintf(`SELECT * FROM %s WHERE deleted_at IS NULL AND state=4 AND expire_at IS NOT NULL AND expire_at < ? LIMIT ?`, public.TableNameTask)
	rows, qErr := gfdb.DB(ctx).GetAll(ctx, query, gtime.Now(), limit)
	if qErr != nil {
		return nil, qErr
	}
	err = rows.Structs(&list)
	return list, err
}
// ListFailedRetryableGlobal returns up to limit failed tasks (state=3, 200
// when limit <= 0) that may still be retried, oldest updated_at first.
//
// retry_count does not include the first execution; the model's retry_times
// is the maximum number of retries allowed after a failure. Tasks are joined
// to their model config on tenant_id + model_name. The config's
// retry_queue_max_seconds is selected alongside the task columns
// (presumably mapped onto a matching entity field; confirm in entity).
func (d *taskDao) ListFailedRetryableGlobal(ctx context.Context, limit int) (list []*entity.AsynchTask, err error) {
	const fallbackLimit = 200
	if limit <= 0 {
		limit = fallbackLimit
	}
	query := fmt.Sprintf(`
SELECT t.*,
       m.retry_queue_max_seconds AS retry_queue_max_seconds
FROM %s t
JOIN %s m
  ON t.tenant_id = m.tenant_id
 AND t.model_name = m.model_name
WHERE t.deleted_at IS NULL
  AND t.state = 3
  AND t.retry_count < m.retry_times
ORDER BY t.updated_at ASC
LIMIT ?`, public.TableNameTask, public.TableNameModel)
	rows, qErr := gfdb.DB(ctx).GetAll(ctx, query, limit)
	if qErr != nil {
		return nil, qErr
	}
	err = rows.Structs(&list)
	return list, err
}
// RequeueForRetryGlobal puts a failed task (state=3, not soft-deleted) back
// into the queue (state=0) and increments retry_count by one.
//
// enqueueAt controls where the task lands in the queue: ClaimPendingGlobal
// claims by enqueue_at ASC, so an earlier enqueueAt means an earlier claim.
func (d *taskDao) RequeueForRetryGlobal(ctx context.Context, id int64, enqueueAt time.Time) error {
	stmt := fmt.Sprintf(`UPDATE %s SET state=0, retry_count=retry_count+1, enqueue_at=?, updated_at=NOW() WHERE id=? AND state=3 AND deleted_at IS NULL`, public.TableNameTask)
	if _, err := gfdb.DB(ctx).Exec(ctx, stmt, enqueueAt, id); err != nil {
		return err
	}
	return nil
}
// ListFailedExhaustedGlobal returns up to limit failed tasks (state=3, 200
// when limit <= 0) whose retry_count has reached the model's retry_times,
// oldest updated_at first. The results are candidates for hard deletion.
// Tasks are joined to their model config on tenant_id + model_name.
func (d *taskDao) ListFailedExhaustedGlobal(ctx context.Context, limit int) (list []*entity.AsynchTask, err error) {
	const fallbackLimit = 200
	if limit <= 0 {
		limit = fallbackLimit
	}
	query := fmt.Sprintf(`
SELECT t.*
FROM %s t
JOIN %s m
  ON t.tenant_id = m.tenant_id
 AND t.model_name = m.model_name
WHERE t.deleted_at IS NULL
  AND t.state = 3
  AND t.retry_count >= m.retry_times
ORDER BY t.updated_at ASC
LIMIT ?`, public.TableNameTask, public.TableNameModel)
	rows, qErr := gfdb.DB(ctx).GetAll(ctx, query, limit)
	if qErr != nil {
		return nil, qErr
	}
	err = rows.Structs(&list)
	return list, err
}
// HardDeleteByIDGlobal physically removes the task row with the given id
// using a plain DELETE. There is no soft-delete and no state check.
func (d *taskDao) HardDeleteByIDGlobal(ctx context.Context, id int64) error {
	stmt := fmt.Sprintf(`DELETE FROM %s WHERE id=?`, public.TableNameTask)
	if _, err := gfdb.DB(ctx).Exec(ctx, stmt, id); err != nil {
		return err
	}
	return nil
}
// ListTimeoutTasksGlobal returns up to limit tasks (200 when limit <= 0),
// across all tenants, that have exceeded their model's expected_seconds.
// A task is returned when:
//   - it is not soft-deleted and its state is 0 or 1 (pending or running)
//   - its model config (matched on tenant_id + model_name) has expected_seconds > 0
//   - now - created_at > expected_seconds (strictly greater, per the SQL below)
//
// NOTE(review): age is measured from created_at, not enqueue_at/started_at,
// so a task requeued for retry keeps its original age; confirm this is intended.
func (d *taskDao) ListTimeoutTasksGlobal(ctx context.Context, limit int) (list []*entity.AsynchTask, err error) {
	if limit <= 0 {
		limit = 200
	}
	r, err := gfdb.DB(ctx).GetAll(ctx,
		fmt.Sprintf(`
SELECT t.*
FROM %s t
JOIN %s m
  ON t.tenant_id = m.tenant_id
 AND t.model_name = m.model_name
WHERE t.deleted_at IS NULL
  AND t.state IN (0, 1)
  AND m.expected_seconds > 0
  AND t.created_at < (NOW() - (m.expected_seconds || ' seconds')::interval)
LIMIT ?`, public.TableNameTask, public.TableNameModel),
		limit,
	)
	if err != nil {
		return nil, err
	}
	err = r.Structs(&list)
	return
}