2026-04-29 15:54:14 +08:00
package dao
import (
"context"
"fmt"
"time"
2026-05-15 14:56:26 +08:00
"model-gateway/consts/public"
"model-gateway/model/entity"
2026-04-29 15:54:14 +08:00
"gitea.com/red-future/common/db/gfdb"
"github.com/gogf/gf/v2/database/gdb"
"github.com/gogf/gf/v2/os/gtime"
)
2026-06-08 18:01:53 +08:00
// ======================== Query helpers ========================

const (
	// taskColumns is the column list selected when claiming pending tasks
	// (claimTasks and ClaimPendingGlobal). It does not include state,
	// enqueue_at or any result columns, so those fields stay at their zero
	// value on tasks returned by the claim functions.
	taskColumns = ` id, tenant_id, creator, model_name, task_id, biz_name, callback_url, model_key, retry_count, input_ref, request_payload, phase, tmp_file `
)
// ======================== Transactional claiming ========================

// claimTasks claims at most one pending task that also matches an extra filter.
//
// Inside one transaction it selects a single row with deleted_at IS NULL and
// state = 0 plus the caller's condition, locking it with
// FOR UPDATE SKIP LOCKED. It then sets state=1 and stamps started_at and
// updated_at with the same Go-side timestamp.
//
// where is spliced verbatim into the SQL right after "state = 0". It must
// therefore start with "AND" and must only come from trusted code. Pass
// values through ? placeholders and args, never by formatting them into where.
//
// The returned slice has length 0 (nothing claimable) or 1. On any error,
// including a failed commit, it returns nil tasks. The claim did not persist,
// so the row must not be processed by the caller.
func claimTasks(ctx context.Context, where string, args ...any) ([]*entity.AsynchTask, error) {
	var claimed []*entity.AsynchTask
	err := gfdb.DB(ctx).Transaction(ctx, func(ctx context.Context, tx gdb.TX) error {
		query := fmt.Sprintf(` SELECT %s FROM %s WHERE deleted_at IS NULL AND state = 0 %s LIMIT 1 FOR UPDATE SKIP LOCKED `, taskColumns, public.TableNameTask, where)
		r, err := tx.GetOne(query, args...)
		if err != nil {
			return err
		}
		if r.IsEmpty() {
			return nil
		}
		var task entity.AsynchTask
		if err := r.Struct(&task); err != nil {
			return err
		}
		now := time.Now()
		_, err = tx.Exec(fmt.Sprintf(` UPDATE %s SET state=1, started_at=?, updated_at=? WHERE id=? `, public.TableNameTask), now, now, task.Id)
		if err != nil {
			return err
		}
		claimed = []*entity.AsynchTask{&task}
		return nil
	})
	if err != nil {
		// The transaction did not commit; never hand back rows that were not claimed.
		return nil, err
	}
	return claimed, nil
}
// ClaimPendingGlobal claims up to batchSize pending tasks across all tenants.
//
// Rows with deleted_at IS NULL and state = 0 are selected oldest enqueue_at
// first and locked with FOR UPDATE SKIP LOCKED. Each is then moved to state=1
// with started_at/updated_at set, all in one transaction. A non-positive
// batchSize is treated as 1.
//
// On any error, including a failed per-row UPDATE or commit, nil tasks are
// returned. The rows were not actually claimed and must not be processed.
//
// NOTE(review): rows are not filtered on enqueue_at <= NOW(). A task that
// RequeueForRetryGlobal scheduled with a future enqueue_at is therefore
// claimable immediately; it only sorts later. Confirm this is intended.
func (d *taskDao) ClaimPendingGlobal(ctx context.Context, batchSize int) ([]*entity.AsynchTask, error) {
	if batchSize <= 0 {
		batchSize = 1
	}
	var tasks []*entity.AsynchTask
	err := gfdb.DB(ctx).Transaction(ctx, func(ctx context.Context, tx gdb.TX) error {
		selectSQL := fmt.Sprintf(` SELECT %s FROM %s WHERE deleted_at IS NULL AND state = 0 ORDER BY enqueue_at ASC LIMIT %d FOR UPDATE SKIP LOCKED `, taskColumns, public.TableNameTask, batchSize)
		r, err := tx.GetAll(selectSQL)
		if err != nil {
			return err
		}
		if r.IsEmpty() {
			return nil
		}
		if err := r.Structs(&tasks); err != nil {
			return err
		}
		// The statement text is the same for every row; build it once.
		updateSQL := fmt.Sprintf(` UPDATE %s SET state=1, started_at=?, updated_at=? WHERE id=? `, public.TableNameTask)
		now := time.Now()
		for _, t := range tasks {
			if _, err := tx.Exec(updateSQL, now, now, t.Id); err != nil {
				return err
			}
		}
		return nil
	})
	if err != nil {
		// tasks may already be populated by Structs; the rollback means none were claimed.
		return nil, err
	}
	return tasks, nil
}
2026-06-08 18:01:53 +08:00
// ClaimPendingByTaskIDGlobal 按 task_id 抢占
func ( d * taskDao ) ClaimPendingByTaskIDGlobal ( ctx context . Context , taskID string ) ( * entity . AsynchTask , error ) {
2026-05-12 13:45:08 +08:00
if taskID == "" {
return nil , nil
}
2026-06-08 18:01:53 +08:00
tasks , err := claimTasks ( ctx , "AND task_id = ?" , taskID )
if err != nil || len ( tasks ) == 0 {
return nil , err
}
return tasks [ 0 ] , nil
2026-05-12 13:45:08 +08:00
}
2026-06-08 18:01:53 +08:00
// ======================== Update helpers ========================

// execSQL runs a raw statement and discards its result, reporting only the
// execution error. Zero affected rows is not treated as an error.
func execSQL(ctx context.Context, sql string, args ...any) error {
	if _, execErr := gfdb.DB(ctx).Exec(ctx, sql, args...); execErr != nil {
		return execErr
	}
	return nil
}
// updateTask 通用更新
func updateTask ( ctx context . Context , id int64 , data entity . AsynchTask ) error {
2026-05-27 09:36:25 +08:00
_ , err := gfdb . DB ( ctx ) . Model ( ctx , public . TableNameTask ) . OmitEmpty ( ) .
2026-06-08 18:01:53 +08:00
Where ( entity . AsynchTaskCol . Id , id ) . Data ( data ) . Update ( )
2026-04-29 15:54:14 +08:00
return err
}
2026-06-08 18:01:53 +08:00
// UpdateSuccessGlobal marks task t as succeeded (state=2) and stores its results.
//
// It writes oss_file, file_type, text_result, file_size, expend_tokens and
// duration_seconds from t and stamps finished_at. It also clears error_msg,
// resets phase to 0 and clears tmp_file.
//
// An explicit column map is used instead of updateTask. updateTask applies
// OmitEmpty, which drops zero values, so the resets above would never reach
// the database. Every listed column is written, including zero values (e.g.
// an empty oss_file).
//
// NOTE(review): the result column names are inferred from the entity field
// names. Confirm them against the table schema; gf drops keys that match no
// column.
func (d *taskDao) UpdateSuccessGlobal(ctx context.Context, t *entity.AsynchTask) error {
	_, err := gfdb.DB(ctx).Model(ctx, public.TableNameTask).
		Where(entity.AsynchTaskCol.Id, t.Id).
		Data(gdb.Map{
			"state":            2,
			"oss_file":         t.OssFile,
			"file_type":        t.FileType,
			"text_result":      t.TextResult,
			"file_size":        t.FileSize,
			"error_msg":        "",
			"finished_at":      gtime.Now(),
			"phase":            0,
			"tmp_file":         "",
			"expend_tokens":    t.ExpendTokens,
			"duration_seconds": t.DurationSeconds,
		}).
		Update()
	return err
}
2026-05-27 09:36:25 +08:00
// UpdateFailedGlobal 模型调用失败
func ( d * taskDao ) UpdateFailedGlobal ( ctx context . Context , t * entity . AsynchTask ) error {
2026-06-08 18:01:53 +08:00
return updateTask ( ctx , t . Id , entity . AsynchTask {
State : 3 ,
ErrorMsg : t . ErrorMsg ,
FinishedAt : gtime . Now ( ) ,
Phase : 0 ,
TmpFile : "" ,
TextResult : t . TextResult ,
DurationSeconds : t . DurationSeconds ,
} )
2026-04-29 15:54:14 +08:00
}
2026-06-08 18:01:53 +08:00
// UpdateFailedKeepTmpGlobal OSS 上传失败
2026-04-29 15:54:14 +08:00
func ( d * taskDao ) UpdateFailedKeepTmpGlobal ( ctx context . Context , id int64 , errorMsg string ) error {
2026-06-08 18:01:53 +08:00
return execSQL ( ctx , fmt . Sprintf ( ` UPDATE %s SET state=3, error_msg=?, finished_at=?, phase=1, updated_at=? WHERE id=? ` , public . TableNameTask ) , errorMsg , gtime . Now ( ) , gtime . Now ( ) , id )
2026-04-29 15:54:14 +08:00
}
2026-06-08 18:01:53 +08:00
// UpdateTmpAfterModelGlobal 写临时文件
2026-04-29 15:54:14 +08:00
func ( d * taskDao ) UpdateTmpAfterModelGlobal ( ctx context . Context , id int64 , tmpFile string ) error {
2026-06-08 18:01:53 +08:00
return execSQL ( ctx , fmt . Sprintf ( ` UPDATE %s SET phase=1, tmp_file=?, updated_at=NOW() WHERE id=? ` , public . TableNameTask ) , tmpFile , id )
2026-04-29 15:54:14 +08:00
}
2026-06-08 18:01:53 +08:00
// RollbackToPendingGlobal 回滚
2026-04-29 15:54:14 +08:00
func ( d * taskDao ) RollbackToPendingGlobal ( ctx context . Context , id int64 ) error {
2026-06-08 18:01:53 +08:00
return execSQL ( ctx , fmt . Sprintf ( ` UPDATE %s SET state=0, enqueue_at=NOW(), updated_at=NOW() WHERE id=? AND state=1 ` , public . TableNameTask ) , id )
2026-04-29 15:54:14 +08:00
}
2026-06-08 18:01:53 +08:00
// IncRetryCountGlobal 重试计数+1
func ( d * taskDao ) IncRetryCountGlobal ( ctx context . Context , id int64 ) error {
return execSQL ( ctx , fmt . Sprintf ( ` UPDATE %s SET retry_count=retry_count+1, updated_at=NOW() WHERE id=? ` , public . TableNameTask ) , id )
2026-04-29 15:54:14 +08:00
}
2026-06-08 18:01:53 +08:00
// RequeueForRetryGlobal 重新入队
func ( d * taskDao ) RequeueForRetryGlobal ( ctx context . Context , id int64 , enqueueAt time . Time ) error {
return execSQL ( ctx , fmt . Sprintf ( ` UPDATE %s SET state=0, retry_count=retry_count+1, enqueue_at=?, updated_at=NOW() WHERE id=? AND state=3 AND deleted_at IS NULL ` , public . TableNameTask ) , enqueueAt , id )
2026-04-29 15:54:14 +08:00
}
2026-06-08 18:01:53 +08:00
// ======================== 列表查询 ========================
// ListExpiredDownloadedGlobal
func ( d * taskDao ) ListExpiredDownloadedGlobal ( ctx context . Context , limit int ) ( [ ] * entity . AsynchTask , error ) {
return queryTasks ( ctx , fmt . Sprintf ( ` SELECT * FROM %s WHERE deleted_at IS NULL AND state=4 AND expire_at IS NOT NULL AND expire_at < ? LIMIT ? ` , public . TableNameTask ) , gtime . Now ( ) , clampLimit ( limit , 200 ) )
2026-04-29 15:54:14 +08:00
}
2026-06-08 18:01:53 +08:00
// ListFailedRetryableGlobal
func ( d * taskDao ) ListFailedRetryableGlobal ( ctx context . Context , limit int ) ( [ ] * entity . AsynchTask , error ) {
return queryTasks ( ctx , fmt . Sprintf ( ` SELECT t.*, m.retry_queue_max_seconds FROM %s t JOIN %s m ON t.tenant_id=m.tenant_id AND t.model_name=m.model_name WHERE t.deleted_at IS NULL AND t.state=3 AND t.retry_count < m.retry_times ORDER BY t.updated_at ASC LIMIT ? ` , public . TableNameTask , public . TableNameModel ) , clampLimit ( limit , 200 ) )
2026-04-29 15:54:14 +08:00
}
2026-06-08 18:01:53 +08:00
// ListFailedExhaustedGlobal lists failed tasks (state=3) that have used up their retries.
//
// A task has used up its retries when retry_count >= the model's retry_times
// (joined on tenant_id, model_name). Rows come oldest updated_at first; at
// most limit rows are returned, and a limit <= 0 means 200.
func (d *taskDao) ListFailedExhaustedGlobal(ctx context.Context, limit int) ([]*entity.AsynchTask, error) {
	query := fmt.Sprintf(` SELECT t.* FROM %s t JOIN %s m ON t.tenant_id=m.tenant_id AND t.model_name=m.model_name WHERE t.deleted_at IS NULL AND t.state=3 AND t.retry_count >= m.retry_times ORDER BY t.updated_at ASC LIMIT ? `, public.TableNameTask, public.TableNameModel)
	return queryTasks(ctx, query, clampLimit(limit, 200))
}
// ListTimeoutTasksGlobal lists pending or running tasks (state 0 or 1) that have exceeded
// their model's expected_seconds.
//
// Only models with expected_seconds > 0 are considered. A task has exceeded
// its budget when created_at is older than NOW() minus that many seconds.
// The ::interval cast is PostgreSQL syntax. At most limit rows are returned;
// a limit <= 0 means 200.
func (d *taskDao) ListTimeoutTasksGlobal(ctx context.Context, limit int) ([]*entity.AsynchTask, error) {
	query := fmt.Sprintf(` SELECT t.* FROM %s t JOIN %s m ON t.tenant_id=m.tenant_id AND t.model_name=m.model_name WHERE t.deleted_at IS NULL AND t.state IN (0,1) AND m.expected_seconds > 0 AND t.created_at < (NOW() - (m.expected_seconds || ' seconds')::interval) LIMIT ? `, public.TableNameTask, public.TableNameModel)
	return queryTasks(ctx, query, clampLimit(limit, 200))
}
// HardDeleteByIDGlobal
2026-04-29 15:54:14 +08:00
func ( d * taskDao ) HardDeleteByIDGlobal ( ctx context . Context , id int64 ) error {
2026-06-08 18:01:53 +08:00
return execSQL ( ctx , fmt . Sprintf ( ` DELETE FROM %s WHERE id=? ` , public . TableNameTask ) , id )
2026-04-29 15:54:14 +08:00
}
2026-06-08 18:01:53 +08:00
// ======================== Internal helpers ========================

// queryTasks runs a raw SELECT and scans every returned row into an AsynchTask.
//
// If the query fails it returns (nil, err). If scanning fails it returns
// whatever Structs filled in, together with the scan error.
func queryTasks(ctx context.Context, sql string, args ...any) ([]*entity.AsynchTask, error) {
	rows, queryErr := gfdb.DB(ctx).GetAll(ctx, sql, args...)
	if queryErr != nil {
		return nil, queryErr
	}
	var tasks []*entity.AsynchTask
	scanErr := rows.Structs(&tasks)
	return tasks, scanErr
}
// clampLimit returns limit when it is positive and defaultVal otherwise.
// Despite its name it does not cap large values: a limit above defaultVal is
// returned unchanged.
func clampLimit(limit, defaultVal int) int {
	if limit > 0 {
		return limit
	}
	return defaultVal
}
// UpdateColumns 更新指定字段(结构体版)
func ( d * taskDao ) UpdateColumns ( ctx context . Context , id int64 , data entity . AsynchTask ) error {
_ , err := gfdb . DB ( ctx ) . Model ( ctx , public . TableNameTask ) . OmitEmpty ( ) .
Where ( entity . AsynchTaskCol . Id , id ) .
Data ( data ) .
Update ( )
return err
2026-04-29 15:54:14 +08:00
}