refactor(task): 重构异步任务处理流程

This commit is contained in:
2026-05-27 09:36:25 +08:00
parent a28fcbaee9
commit e487b4bb5e
9 changed files with 305 additions and 231 deletions

View File

@@ -35,7 +35,8 @@ func (c *cleaner) RunOnce(ctx context.Context) (res *dto.CleanWorkRes, err error
g.Log().Errorf(ctx, "[cleaner] list timeout error: %v", err)
} else {
for _, t := range list {
_ = dao.Task.UpdateFailedGlobal(ctx, t.Id, "任务超时自动失败")
t.ErrorMsg = "任务超时自动失败"
_ = dao.Task.UpdateFailedGlobal(ctx, t)
ReleaseQueueSlot(ctx, t.ModelName, t.TaskID)
}
g.Log().Infof(ctx, "[cleaner] timeout cleaned, count=%d", len(list))

View File

@@ -4,6 +4,7 @@ import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"mime/multipart"
"model-gateway/common/util"
@@ -15,7 +16,7 @@ import (
"github.com/gogf/gf/v2/util/guid"
)
type uploadFileResponse struct {
type UploadFileResponse struct {
FileURL string `json:"fileURL"` // 文件 URL
FileSize int `json:"fileSize"` // 文件大小(字节)
FileName string `json:"fileName"` // 文件名
@@ -23,7 +24,7 @@ type uploadFileResponse struct {
FileAddressPrefix string `json:"fileAddressPrefix"` // 文件地址前缀
}
func UploadByTask(ctx context.Context, _ *entity.AsynchTask, data []byte, fileExt string, _ string) (ossURL string, err error) {
func UploadByTask(ctx context.Context, data []byte, fileExt string) (oss *UploadFileResponse, err error) {
// multipart
body := &bytes.Buffer{}
writer := multipart.NewWriter(body)
@@ -39,41 +40,43 @@ func UploadByTask(ctx context.Context, _ *entity.AsynchTask, data []byte, fileEx
filename := fmt.Sprintf("asynch_%d_%s%s", time.Now().Unix(), guid.S(), ext)
part, err := writer.CreateFormFile("file", filename)
if err != nil {
return "", err
return nil, err
}
if _, err := part.Write(data); err != nil {
return "", err
return nil, err
}
contentType := writer.FormDataContentType()
if err = writer.Close(); err != nil {
return "", err
return nil, err
}
headers := util.ForwardHeaders(ctx)
headers["Content-Type"] = contentType
//fullURL := "oss/file/uploadFile"
fullURL := "oss/file/uploadFile"
g.Log().Infof(ctx, "[OSS] upload start url=%s filename=%s size=%d", fullURL, filename, len(data))
var resp uploadFileResponse
var resp UploadFileResponse
if err = commonHttp.Post(ctx, fullURL, headers, &resp, body.Bytes()); err != nil {
return "", err
return nil, err
}
g.Log().Infof(ctx, "[OSS] upload success url=%s size=%d format=%s", resp.FileURL, resp.FileSize, resp.FileFormat)
return resp.FileURL, nil
if &resp == nil {
return nil, errors.New("[OSS] 上传文件失败")
}
g.Log().Infof(ctx, "[OSS] 上传成功 url=%s size=%d format=%s", resp.FileURL, resp.FileSize, resp.FileFormat)
return &resp, nil
}
// CallbackPayload 回调请求体
type CallbackPayload struct {
TaskId string `json:"task_id"`
State int `json:"state"`
OssFile string `json:"oss_file"`
FileType string `json:"file_type"`
Text string `json:"text"`
ErrorMsg string `json:"error_msg"`
TaskId string `json:"task_id"`
State int `json:"state"`
OssFile string `json:"oss_file"`
FileType string `json:"file_type"`
Messages map[string]any `json:"messages"`
ErrorMsg string `json:"error_msg"`
}
// TriggerCallback 任务成功后的回调
// TriggerCallback 任务的回调
func TriggerCallback(ctx context.Context, t *entity.AsynchTask) {
headers := util.ForwardHeaders(ctx)
var resp struct{}
@@ -82,7 +85,7 @@ func TriggerCallback(ctx context.Context, t *entity.AsynchTask) {
State: t.State,
OssFile: t.OssFile,
FileType: t.FileType,
Text: t.TextResult,
Messages: t.TextResult,
ErrorMsg: t.ErrorMsg,
}
jsonData, err := json.Marshal(payload)
@@ -103,8 +106,8 @@ func TriggerCallback(ctx context.Context, t *entity.AsynchTask) {
// PromptsCallbackPayload 提示词回调请求体
type PromptsCallbackPayload struct {
EpicycleId int64 `json:"epicycleId"`
Text string `json:"text"`
EpicycleId int64 `json:"epicycleId"`
Messages map[string]any `json:"messages"`
}
// TriggerPromptsCallback 任务成功后的提示词回调
@@ -114,7 +117,7 @@ func TriggerPromptsCallback(ctx context.Context, t *entity.AsynchTask, epicycleI
var resp struct{}
payload := PromptsCallbackPayload{
EpicycleId: epicycleId,
Text: t.TextResult,
Messages: t.TextResult,
}
jsonData, err := json.Marshal(payload)
if err != nil {

View File

@@ -103,7 +103,7 @@ func (s *taskService) Create(ctx context.Context, req *dto.CreateTaskReq) (res *
// 4) 创建成功后立即异步尝试执行当前任务,并仅在任务仍处于 pending(state=0) 时做定向轮询。
// 一旦任务进入 running/success/failed/downloaded就停止轮询避免一直空转。
go s.pollAndRunUntilPicked(context.WithoutCancel(ctx), taskID, req.EpicycleId)
go s.pollAndRunUntilPicked(context.WithoutCancel(ctx), taskID, req)
return &dto.CreateTaskRes{TaskID: taskID}, nil
}
@@ -112,7 +112,7 @@ func (s *taskService) Create(ctx context.Context, req *dto.CreateTaskReq) (res *
// - 只在任务仍为 pending(state=0) 时继续尝试抢占
// - 一旦任务进入 running(1) / success(2) / failed(3) / downloaded(4),立即停止
// - 这样不会无限轮询runWork 仍负责处理积压队列和未处理到的任务
func (s *taskService) pollAndRunUntilPicked(ctx context.Context, taskID string, epicycleId int64) {
func (s *taskService) pollAndRunUntilPicked(ctx context.Context, taskID string, req *dto.CreateTaskReq) {
if taskID == "" {
return
}
@@ -139,7 +139,7 @@ func (s *taskService) pollAndRunUntilPicked(ctx context.Context, taskID string,
}
switch t.State {
case 0:
if err = AsyncWorker.RunByTaskID(ctx, taskID, epicycleId); err != nil {
if err = AsyncWorker.RunByTaskID(ctx, taskID, req); err != nil {
g.Log().Warningf(ctx, "[task-auto-run][retry] taskId=%s state=0 err=%v", taskID, err)
} else {
g.Log().Infof(ctx, "[task-auto-run][triggered] taskId=%s state=0", taskID)

View File

@@ -16,9 +16,9 @@ import (
"model-gateway/dao"
"model-gateway/model/entity"
"github.com/gogf/gf/v2/encoding/gjson"
"github.com/gogf/gf/v2/frame/g"
"github.com/gogf/gf/v2/os/grpool"
"github.com/tidwall/gjson"
)
var AsyncWorker = &asyncWorker{}
@@ -50,11 +50,12 @@ func (w *asyncWorker) RunOnce(ctx context.Context, req *dto.RunWorkReq) (res *dt
for _, t := range tasks {
task := t
_ = pool.AddWithRecover(ctx, func(ctx context.Context) {
w.handleOne(ctx, task, 0)
w.handleOne(ctx, task, &dto.CreateTaskReq{EpicycleId: 0})
done <- struct{}{}
}, func(ctx context.Context, e error) {
if e != nil {
_ = dao.Task.UpdateFailedGlobal(ctx, task.Id, fmt.Sprintf("worker panic: %v", e))
task.ErrorMsg = fmt.Sprintf("worker panic: %v", e)
_ = dao.Task.UpdateFailedGlobal(ctx, task)
ReleaseQueueSlot(ctx, task.ModelName, task.TaskID)
}
done <- struct{}{}
@@ -71,7 +72,7 @@ func (w *asyncWorker) RunOnce(ctx context.Context, req *dto.RunWorkReq) (res *dt
// RunByTaskID 创建任务后立即异步尝试执行当前任务:
// - 只定向抢占当前 taskId 对应的 pending 任务
// - 若任务已被其它 worker 抢走/已不在 pending则直接返回
func (w *asyncWorker) RunByTaskID(ctx context.Context, taskID string, epicycleId int64) error {
func (w *asyncWorker) RunByTaskID(ctx context.Context, taskID string, req *dto.CreateTaskReq) error {
task, err := dao.Task.ClaimPendingByTaskIDGlobal(ctx, taskID)
if err != nil {
return err
@@ -79,163 +80,175 @@ func (w *asyncWorker) RunByTaskID(ctx context.Context, taskID string, epicycleId
if task == nil {
return nil
}
w.handleOne(ctx, task, epicycleId)
w.handleOne(ctx, task, req)
return nil
}
func (w *asyncWorker) handleOne(ctx context.Context, t *entity.AsynchTask, epicycleId int64) {
// 从任务入库的 request_payload 里恢复 payload + headers
payload, headers := util.ParseStoredPayload(t.RequestPayload)
if len(headers) > 0 {
ctx = util.SetTaskHeadersToCtx(ctx, headers)
}
// handleOne 执行一次完整的任务
func (w *asyncWorker) handleOne(ctx context.Context, t *entity.AsynchTask, req *dto.CreateTaskReq) {
payload := util.ParseStoredPayload(t.RequestPayload)
maxRetry := 0 // 后面从 model 取
g.Log().Infof(ctx, "[执行任务][开始] taskId=%s model=%s", t.TaskID, t.ModelName)
// 1) 取模型配置
m, err := dao.Model.GetByModelNameForTenant(ctx, t.TenantId, t.ModelName)
if err != nil {
_ = dao.Task.UpdateFailedGlobal(ctx, t.Id, err.Error())
ReleaseQueueSlot(ctx, t.ModelName, t.TaskID)
// ============ 失败回调 ============
t.State = 3
t.ErrorMsg = err.Error()
go gateway.TriggerCallback(context.WithoutCancel(ctx), t)
// ================================
return
}
if m == nil || (m.Enabled != nil && *m.Enabled != 1) {
errMsg := "模型不存在或未启用"
_ = dao.Task.UpdateFailedGlobal(ctx, t.Id, errMsg)
ReleaseQueueSlot(ctx, t.ModelName, t.TaskID)
// ============ 失败回调 ============
t.State = 3
t.ErrorMsg = errMsg
go gateway.TriggerCallback(context.WithoutCancel(ctx), t)
// ================================
// 1) 取模型配置
model, err := dao.Model.GetByModelNameForTenant(ctx, t.TenantId, t.ModelName)
if err != nil || model == nil {
w.failTask(ctx, t, "模型不存在或未启用")
return
}
maxRetry = model.RetryTimes
// 2) 分布式并发
// 2) 分布式并发
semKey := fmt.Sprintf("asynch:sem:%s", t.ModelName)
leaseSeconds := int64(3600)
maxC := GetRuntimeMaxConcurrency(ctx, t.ModelName, m.MaxConcurrency)
acquired, err := acquireSemaphore(ctx, semKey, maxC, leaseSeconds)
maxC := GetRuntimeMaxConcurrency(ctx, t.ModelName, model.MaxConcurrency)
acquired, err := acquireSemaphore(ctx, semKey, maxC, 3600)
if err != nil {
_ = dao.Task.UpdateFailedGlobal(ctx, t.Id, err.Error())
ReleaseQueueSlot(ctx, t.ModelName, t.TaskID)
// ============ 失败回调 ============
t.State = 3
t.ErrorMsg = err.Error()
go gateway.TriggerCallback(context.WithoutCancel(ctx), t)
// ================================
w.failTask(ctx, t, err.Error())
return
}
if !acquired {
// 并发满了:放回排队,不回调(不是失败)
g.Log().Infof(ctx, "[执行任务][排队] 并发已满,放回队列 taskId=%s", t.TaskID)
_ = w.rollbackToPending(ctx, t.Id)
return
}
defer func() {
_ = releaseSemaphore(ctx, semKey)
}()
defer func() { _ = releaseSemaphore(ctx, semKey) }()
// 3) 调用模型服务
// 3) request_payload 校验
if payload == nil {
payload = map[string]any{
"taskId": t.TaskID,
"inputRef": t.InputRef,
w.failTask(ctx, t, "request_payload 为空")
return
}
// 4) 调用模型(不重试,失败直接回调)
textResult, err := w.callModel(ctx, t, model, payload)
if err != nil {
w.failTask(ctx, t, err.Error())
return
}
// 5) 上传 OSS可重试
var oss *gateway.UploadFileResponse
for attempt := 0; attempt <= maxRetry; attempt++ {
if attempt > 0 {
g.Log().Infof(ctx, "[执行任务][重试] OSS上传 第%d/%d次 taskId=%s", attempt, maxRetry, t.TaskID)
}
oss, err = w.uploadOSS(ctx, t)
if err == nil {
break
}
g.Log().Errorf(ctx, "[执行任务][失败] OSS上传失败 taskId=%s attempt=%d/%d err=%v",
t.TaskID, attempt, maxRetry, err)
if attempt == maxRetry {
_ = dao.Task.UpdateFailedKeepTmpGlobal(ctx, t.Id, err.Error())
w.failTask(ctx, t, fmt.Sprintf("OSS上传重试耗尽: %v", err))
return
}
}
var (
data []byte
contentType string
ext string
textResult string
)
// phase=1 表示模型已成功但 OSS 上传失败:优先从临时文件加载
// 6) 解析校验(可重试,失败重新调模型)
if req.BuildType == 1 {
for attempt := 0; attempt <= maxRetry; attempt++ {
if attempt > 0 {
g.Log().Infof(ctx, "[执行任务][重试] JSON解析 第%d/%d次 taskId=%s", attempt, maxRetry, t.TaskID)
}
err = util.ValidatePromptResult(textResult, model.RequestMapping)
if err == nil {
break
}
g.Log().Warningf(ctx, "[执行任务][解析失败] taskId=%s attempt=%d/%d err=%v",
t.TaskID, attempt, maxRetry, err)
if attempt == maxRetry {
w.failTask(ctx, t, fmt.Sprintf("JSON解析重试耗尽: %v", err))
return
}
// 重新调模型
newResult, modelErr := w.callModel(ctx, t, model, payload)
if modelErr != nil {
g.Log().Warningf(ctx, "[执行任务][重试] 重新调模型失败 taskId=%s attempt=%d/%d err=%v",
t.TaskID, attempt, maxRetry, modelErr)
continue
}
textResult = newResult
}
}
// 7) 成功回调
t.State = 2
t.OssFile = oss.FileAddressPrefix + oss.FileURL
t.FileType = oss.FileFormat
t.TextResult = textResult
t.FileSize = int64(oss.FileSize)
t.ExpendTokens = int64(GetExpendTokens(model.ResponseTokenField, textResult))
if err = dao.Task.UpdateSuccessGlobal(ctx, t); err != nil {
g.Log().Errorf(ctx, "[执行任务][失败] 更新数据库失败 taskId=%s err=%v", t.TaskID, err)
return
}
ReleaseQueueSlot(ctx, t.ModelName, t.TaskID)
go gateway.TriggerCallback(context.WithoutCancel(ctx), t)
if req.EpicycleId != 0 {
go gateway.TriggerPromptsCallback(context.WithoutCancel(ctx), t, req.EpicycleId)
}
g.Log().Infof(ctx, "[执行任务][成功] taskId=%s fileType=%s textLen=%d callbackUrl=%s",
t.TaskID, oss.FileFormat, len(textResult), t.CallbackURL)
_ = os.Remove(t.TmpFile)
}
// 返回: ossURL(成功时有值), fileType, textResult(失败时是错误信息), retryable(是否可重试)
// callModel 调用模型 + 检测文件类型 + 保存临时文件
func (w *asyncWorker) callModel(ctx context.Context, t *entity.AsynchTask, m *entity.AsynchModel, payload map[string]any) (map[string]any, error) {
var data []byte
var contentType, ext, textResult string
var err error
if t.Phase == 1 && strings.TrimSpace(t.TmpFile) != "" {
data, err = os.ReadFile(t.TmpFile)
if err == nil && len(data) > 0 {
contentType, ext = util.DetectFileType(data)
} else {
if err != nil || len(data) == 0 {
data = nil
}
}
if data == nil {
// 统计
_ = dao.Stat.IncRequestCount(ctx, time.Now(), int64(t.TenantId), t.Creator, t.ModelName)
// 核心调用
data, err = InvokeModel(ctx, m, payload, t.ModelKey)
if err != nil {
_ = dao.Task.UpdateFailedGlobal(ctx, t.Id, err.Error())
ReleaseQueueSlot(ctx, t.ModelName, t.TaskID)
// ============ 失败回调 ============
t.State = 3
t.ErrorMsg = err.Error()
go gateway.TriggerCallback(context.WithoutCancel(ctx), t)
// ================================
return
return nil, err
}
contentType, ext = util.DetectFileType(data)
if utf8.Valid(data) && (strings.HasPrefix(contentType, "text/") || contentType == "application/json") {
textResult = string(data)
}
tmpPath, err := saveTmpResult(t.TaskID, data, ext)
if err == nil && tmpPath != "" {
tmpPath, tmpErr := saveTmpResult(t.TaskID, data, ext)
if tmpErr == nil && tmpPath != "" {
t.TmpFile = tmpPath
t.Phase = 1
_ = dao.Task.UpdateTmpAfterModelGlobal(ctx, t.Id, tmpPath)
}
}
// 4) 存储 OSS
ossURL, err := gateway.UploadByTask(ctx, t, data, ext, contentType)
contentType, ext = util.DetectFileType(data)
if utf8.Valid(data) && (strings.HasPrefix(contentType, "text/") || contentType == "application/json") {
textResult = string(data)
}
return gjson.New(textResult).Map(), nil
}
// uploadOSS 从临时文件上传 OSS
func (w *asyncWorker) uploadOSS(ctx context.Context, t *entity.AsynchTask) (*gateway.UploadFileResponse, error) {
data, err := os.ReadFile(t.TmpFile)
if err != nil {
// OSS 阶段失败:保留临时文件,下一轮仅重试 OSS
_ = dao.Task.UpdateFailedKeepTmpGlobal(ctx, t.Id, err.Error())
ReleaseQueueSlot(ctx, t.ModelName, t.TaskID)
// ============ OSS失败不回调还会重试 ============
// 注意OSS失败保留临时文件下次重试所以这里不触发最终回调
// 如果已经重试多次还没成功,需要在任务超时或超过最大重试次数时才回调失败
return
return nil, fmt.Errorf("读取临时文件失败: %w", err)
}
_, ext := util.DetectFileType(data)
return gateway.UploadByTask(ctx, data, ext)
}
// 5) 更新任务状态成功
fileType := strings.TrimPrefix(ext, ".")
if fileType == "" {
fileType = contentType
}
if err = dao.Task.UpdateSuccessGlobal(
ctx,
t.Id,
ossURL,
fileType,
textResult,
int64(len(data)),
nil,
GetExpendTokens(m.ResponseTokenField, textResult),
); err != nil {
g.Log().Errorf(ctx, "[worker] update success failed: %v", err)
return
}
// 成功/失败均不再占用 queue_limit
// failTask 任务失败统一处理:更新数据库 + 释放排队 + 回调
func (w *asyncWorker) failTask(ctx context.Context, t *entity.AsynchTask, errMsg string) {
t.State = 3
t.ErrorMsg = errMsg
_ = dao.Task.UpdateFailedGlobal(ctx, t)
ReleaseQueueSlot(ctx, t.ModelName, t.TaskID)
// 6) 成功回调
t.State = 2
t.OssFile = ossURL
t.FileType = fileType
t.TextResult = textResult
g.Log().Infof(ctx, "[CALLBACK][DISPATCH] taskId=%s bizName=%s callbackUrl=%s", t.TaskID, t.BizName, t.CallbackURL)
go gateway.TriggerCallback(context.WithoutCancel(ctx), t)
// ============ 如果有 epicycleId也触发业务回调 ============
if epicycleId != 0 {
go gateway.TriggerPromptsCallback(context.WithoutCancel(ctx), t, epicycleId)
}
// 成功后清理临时文件
_ = os.Remove(t.TmpFile)
}
// saveTmpResult 将模型输出写入临时文件,用于 OSS 上传失败后的“仅重试 OSS”。
@@ -261,11 +274,11 @@ func (w *asyncWorker) rollbackToPending(ctx context.Context, id int64) error {
return dao.Task.RollbackToPendingGlobal(ctx, id)
}
// GetExpendTokens 根据映射路径从 textResult 中提取消耗 token 值
func GetExpendTokens(responseTokenField string, textResult string) int {
value := gjson.Get(textResult, responseTokenField)
if value.Exists() {
return int(value.Int())
// GetExpendTokens 根据映射路径从 result 中提取消耗 token 值
func GetExpendTokens(responseTokenField string, result map[string]any) int {
val := gjson.New(result).Get(responseTokenField)
if val.IsNil() {
return 0
}
return len(textResult)
return val.Int()
}