refactor(asynch): 重构异步模型配置和队列管理

This commit is contained in:
2026-06-02 20:26:45 +08:00
parent c7e9eb889b
commit 52124385a1
18 changed files with 726 additions and 1006 deletions

View File

@@ -3,6 +3,7 @@ package task
import (
"context"
"errors"
"fmt"
"model-gateway/common/util"
"model-gateway/service/queue"
"time"
@@ -11,9 +12,12 @@ import (
"model-gateway/model/dto"
"model-gateway/model/entity"
"gitea.com/red-future/common/beans"
"gitea.com/red-future/common/utils"
"github.com/gogf/gf/v2/database/gdb"
"github.com/gogf/gf/v2/frame/g"
"github.com/gogf/gf/v2/os/gtime"
"github.com/gogf/gf/v2/util/gconv"
"github.com/google/uuid"
)
@@ -25,22 +29,29 @@ type taskService struct{}
// Create registers a new asynchronous task for req.ModelName and returns the
// generated task ID.
//
// Visible flow: generate a UUID task ID, resolve the caller via
// utils.GetUserInfo, load the model scoped to that tenant/creator, reject
// missing or disabled models, take a queue slot when the runtime limit is
// positive, persist the task, claim it, and hand it to AsyncWorker.handleOne
// in a background goroutine.
//
// NOTE(review): this span is a rendered diff. Removed and added versions of
// the same statement appear back to back (e.g. `m` / `model`, the two
// "headers" keys), and each `@@` line marks a place where part of the body is
// not shown.
func (s *taskService) Create(ctx context.Context, req *dto.CreateTaskReq) (res *dto.CreateTaskRes, err error) {
startAt := time.Now()
taskID := uuid.NewString()
// 1) Check the model configuration
m, err := dao.Model.Get(ctx, &entity.AsynchModel{
// 1) Check the model configuration and load the model
userInfo, err := utils.GetUserInfo(ctx)
if err != nil {
return nil, err
}
// The lookup is filtered by the caller's tenant and user name in addition to
// the requested model name.
model, err := dao.Model.Get(ctx, &entity.AsynchModel{
SQLBaseDO: beans.SQLBaseDO{
TenantId: userInfo.TenantId,
Creator: userInfo.UserName,
},
ModelName: req.ModelName,
})
if err != nil {
return nil, err
}
// A nil Enabled pointer is accepted; only an explicit value other than 1 is
// treated as disabled.
if m == nil || (m.Enabled != nil && *m.Enabled != 1) {
if model == nil || (model.Enabled != nil && *model.Enabled != 1) {
return nil, errors.New("模型不存在或未启用")
}
// 2) Strictly enforce the queue limit (Redis atomic gate)
limit := queue.GetRuntimeQueueLimit(ctx, req.ModelName, m.QueueLimit)
// The added line passes twice the model's MaxConcurrency as the limit argument.
limit := queue.GetRuntimeQueueLimit(ctx, req.ModelName, model.MaxConcurrency*2)
// A non-positive limit skips the gate entirely.
if limit > 0 {
ok, err := queue.AcquireQueueSlot(ctx, req.ModelName, taskID, limit, m.ExpectedSeconds)
ok, err := queue.AcquireQueueSlot(ctx, req.ModelName, taskID, limit, model.TimeoutSeconds)
if err != nil {
return nil, err
}
@@ -50,9 +61,13 @@ func (s *taskService) Create(ctx context.Context, req *dto.CreateTaskReq) (res *
}
// 3) Insert the task record
if model.IsAsync != nil && *model.IsAsync == 1 {
// Async call: inject the callback URL before submitting, then poll using the returned task_id
req.RequestPayload = util.InjectCallbackURL(ctx, req.RequestPayload, model.CallbackUrl)
}
storedPayload := map[string]any{
"payload": req.RequestPayload,
"headers": util.ForwardHeaders(ctx),
"headers": util.ParseHeadMsgHeaders(model.HeadMsg),
"body": req.RequestPayload,
}
_, err = dao.Task.Insert(ctx, &entity.AsynchTask{
ModelName: req.ModelName,
@@ -60,13 +75,12 @@ func (s *taskService) Create(ctx context.Context, req *dto.CreateTaskReq) (res *
State: 0,
BizName: req.BizName,
CallbackURL: req.CallbackUrl,
ModelKey: m.ApiKey,
ModelKey: model.ApiKey,
InputRef: req.InputRef,
RequestPayload: storedPayload,
EpicycleId: req.EpicycleId,
})
if err != nil {
// Insert failed: roll back the gate slot
if err != nil { // Insert failed: roll back the gate slot
queue.ReleaseQueueSlot(ctx, req.ModelName, taskID)
return nil, err
}
@@ -100,75 +114,96 @@ func (s *taskService) Create(ctx context.Context, req *dto.CreateTaskReq) (res *
},
})
// 5) Right after creation, asynchronously try to run this task, polling for it only while it is still pending (state=0).
// Stop polling as soon as the task enters running/success/failed/downloaded, to avoid spinning idly.
go s.pollAndRunUntilPicked(util.AsyncCtx(ctx), taskID, req)
// 5) Fetch the task
// NOTE(review): the early returns below do not call queue.ReleaseQueueSlot in
// the visible code — confirm the slot is reclaimed elsewhere (e.g. by expiry).
task, err := dao.Task.ClaimPendingByTaskIDGlobal(ctx, taskID)
if err != nil {
return nil, err
}
// NOTE(review): err is necessarily nil here (the non-nil case returned just
// above), so this path returns (nil, nil) — a nil result with no error.
// Confirm whether an explicit error was intended.
if task == nil {
return nil, err
}
// 5) Right after creation, asynchronously try to run this task
go AsyncWorker.handleOne(util.AsyncCtx(ctx), task, model, req)
return &dto.CreateTaskRes{TaskID: taskID}, nil
}
// pollAndRunUntilPicked 定向轮询执行刚创建的任务
// - 目标:尽快把刚创建的任务拉起来执行
// - 只在任务仍为 pending(state=0) 时继续尝试抢占
// - 一旦任务进入 running(1) / success(2) / failed(3) / downloaded(4),立即停止
// - 不会无限轮询runWork 仍负责处理积压队列和未处理到的任务
func (s *taskService) pollAndRunUntilPicked(ctx context.Context, taskID string, req *dto.CreateTaskReq) {
interval := g.Cfg().MustGet(ctx, "asynch.worker.intervalSeconds", 5).Int()
pollTimeout := g.Cfg().MustGet(ctx, "asynch.worker.pollTimeoutSeconds", 300).Int()
pollCtx, cancel := context.WithTimeout(ctx, time.Duration(pollTimeout)*time.Second)
defer cancel()
func (s *taskService) ModelTaskCallback(ctx context.Context, req *dto.ModelTaskCallbackReq) (*dto.ModelTaskCallbackRes, error) {
g.Log().Infof(ctx, "[模型回调] 收到通知 taskID=%s status=%s", req.TaskID, req.Status)
// 1. 查本地任务
task, err := dao.Task.Get(ctx, &entity.AsynchTask{
TaskID: req.TaskID,
})
if err != nil || task == nil {
return nil, fmt.Errorf("任务不存在: %s", req.TaskID)
}
ticker := time.NewTicker(time.Duration(interval) * time.Second)
defer ticker.Stop()
// 2. 成功:取 video_url 和 usage
if req.Status == "succeeded" {
result := map[string]any{
"video_url": req.Content["video_url"],
"usage": req.Usage,
}
NotifyAsyncResult(req.TaskID, result, nil)
return &dto.ModelTaskCallbackRes{Success: true}, nil
}
g.Log().Infof(ctx, "[任务自动执行][开始] taskId=%s 轮询间隔=%ds 超时=%ds", taskID, interval, pollTimeout)
tryRun := func() bool {
t, err := dao.Task.Get(ctx, &entity.AsynchTask{
TaskID: taskID,
})
// 3. 失败/过期
if req.Status == "failed" || req.Status == "expired" {
NotifyAsyncResult(req.TaskID, nil, fmt.Errorf(req.Status))
return &dto.ModelTaskCallbackRes{Success: true}, nil
}
return &dto.ModelTaskCallbackRes{Success: true}, nil
}
// QueryPendingTasks 批量轮询进行中的异步任务
func (s *taskService) QueryPendingTasks(ctx context.Context, req *dto.QueryPendingTasksReq) (*dto.QueryPendingTasksRes, error) {
limit := req.Limit
if limit <= 0 {
limit = g.Cfg().MustGet(ctx, "asynch.queryPending.limit", 10).Int()
}
// 1. 查 state=1执行中的异步任务
tasks, err := dao.Task.GetPendingAsyncTasks(ctx, limit)
if err != nil {
return nil, err
}
// 2. 逐个查询
var results []dto.QueryTaskItem
for _, t := range tasks {
// 拿到模型配置
model, err := dao.Model.GetByModelNameForTenant(ctx, t.TenantId, t.ModelName)
if err != nil || model == nil || model.QueryConfig == nil {
continue
}
result, err := util.PullTaskResult(ctx, t.TaskID, model.QueryConfig)
if err != nil {
g.Log().Warningf(ctx, "[任务自动执行][停止] taskId=%s 原因=查询失败 err=%v", taskID, err)
return true
}
if t == nil {
g.Log().Warningf(ctx, "[任务自动执行][停止] taskId=%s 原因=任务不存在", taskID)
return true
g.Log().Warningf(ctx, "[轮询] 查询失败 taskID=%s err=%v", t.TaskID, err)
continue
}
switch t.State {
case 0:
//RunByTaskID 尝试执行任务
if err = AsyncWorker.RunByTaskID(ctx, taskID, req); err != nil {
g.Log().Warningf(ctx, "[任务自动执行][重试] taskId=%s 状态=待处理 err=%v", taskID, err)
} else {
g.Log().Infof(ctx, "[任务自动执行][已触发] taskId=%s 状态=待处理", taskID)
}
return false
case 1:
g.Log().Infof(ctx, "[任务自动执行][停止] taskId=%s 原因=执行中", taskID)
return true
case 2, 3, 4:
g.Log().Infof(ctx, "[任务自动执行][停止] taskId=%s 原因=终态 状态=%d", taskID, t.State)
return true
default:
g.Log().Infof(ctx, "[任务自动执行][停止] taskId=%s 原因=未知状态 状态=%d", taskID, t.State)
return true
}
}
// 立即尝试一次
if stop := tryRun(); stop {
return
}
for {
select {
case <-pollCtx.Done():
g.Log().Infof(ctx, "[任务自动执行][停止] taskId=%s 原因=轮询超时", taskID)
return
case <-ticker.C:
if stop := tryRun(); stop {
return
}
status := gconv.String(result["status"])
item := dto.QueryTaskItem{
TaskID: t.TaskID,
Status: status,
Content: result["content"].(map[string]any),
Usage: result["usage"].(map[string]any),
}
results = append(results, item)
// 如果任务完成,通知等待通道
if status == "succeeded" || status == "failed" || status == "expired" {
NotifyAsyncResult(t.TaskID, result["content"].(map[string]any), nil)
}
}
return &dto.QueryPendingTasksRes{
Total: len(results),
Results: results,
}, nil
}
// GetResult 获取任务结果