refactor(asynch): 重构异步模型配置和队列管理
This commit is contained in:
@@ -13,8 +13,8 @@ import (
|
||||
"model-gateway/service/queue"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
"unicode/utf8"
|
||||
|
||||
@@ -55,7 +55,7 @@ func (w *asyncWorker) RunOnce(ctx context.Context, req *dto.RunWorkReq) (res *dt
|
||||
for _, t := range tasks {
|
||||
task := t
|
||||
_ = pool.AddWithRecover(ctx, func(ctx context.Context) {
|
||||
w.handleOne(ctx, task, &dto.CreateTaskReq{EpicycleId: 0})
|
||||
//w.handleOne(ctx, task, &dto.CreateTaskReq{EpicycleId: 0})
|
||||
done <- struct{}{}
|
||||
}, func(ctx context.Context, e error) {
|
||||
if e != nil {
|
||||
@@ -74,185 +74,159 @@ func (w *asyncWorker) RunOnce(ctx context.Context, req *dto.RunWorkReq) (res *dt
|
||||
}, nil
|
||||
}
|
||||
|
||||
// RunByTaskID 创建任务后立即异步尝试执行当前任务:
|
||||
// - 只定向抢占当前 taskId 对应的 pending 任务
|
||||
// - 若任务已被其它 worker 抢走/已不在 pending,则直接返回
|
||||
func (w *asyncWorker) RunByTaskID(ctx context.Context, taskID string, req *dto.CreateTaskReq) error {
|
||||
task, err := dao.Task.ClaimPendingByTaskIDGlobal(ctx, taskID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if task == nil {
|
||||
return nil
|
||||
}
|
||||
w.handleOne(ctx, task, req)
|
||||
return nil
|
||||
}
|
||||
|
||||
// handleOne 执行一次完整的任务
|
||||
func (w *asyncWorker) handleOne(ctx context.Context, t *entity.AsynchTask, req *dto.CreateTaskReq) {
|
||||
payload := util.ParseStoredPayload(t.RequestPayload)
|
||||
maxRetry := 0 // 后面从 model 取
|
||||
g.Log().Infof(ctx, "[执行任务][开始] taskId=%s model=%s", t.TaskID, t.ModelName)
|
||||
func (w *asyncWorker) handleOne(ctx context.Context, task *entity.AsynchTask, model *entity.AsynchModel, req *dto.CreateTaskReq) {
|
||||
body := util.GetModelBody(task.RequestPayload) //核心请求参数
|
||||
maxRetry := model.RetryTimes //重试次数
|
||||
|
||||
// 1) 获取模型配置
|
||||
model, err := dao.Model.GetByModelNameForTenant(ctx, t.TenantId, t.ModelName)
|
||||
if err != nil || model == nil {
|
||||
w.failTask(ctx, t, "模型不存在或未启用")
|
||||
return
|
||||
}
|
||||
maxRetry = model.RetryTimes
|
||||
|
||||
// 2) 分布式并发控制
|
||||
semKey := fmt.Sprintf("asynch:sem:%s", t.ModelName)
|
||||
maxC := queue.GetRuntimeMaxConcurrency(ctx, t.ModelName, model.MaxConcurrency)
|
||||
g.Log().Infof(ctx, "[执行任务][开始] taskId=%s model=%s", task.TaskID, task.ModelName)
|
||||
// 1) 分布式并发控制
|
||||
semKey := fmt.Sprintf("asynch:sem:%s", task.ModelName)
|
||||
maxC := queue.GetRuntimeMaxConcurrency(ctx, task.ModelName, model.MaxConcurrency)
|
||||
acquired, err := queue.AcquireSemaphore(ctx, semKey, maxC, 3600)
|
||||
if err != nil {
|
||||
w.failTask(ctx, t, err.Error())
|
||||
w.failTask(ctx, task, err.Error())
|
||||
return
|
||||
}
|
||||
if !acquired {
|
||||
g.Log().Infof(ctx, "[执行任务][排队] 并发已满,放回队列 taskId=%s", t.TaskID)
|
||||
_ = w.rollbackToPending(ctx, t.Id)
|
||||
g.Log().Infof(ctx, "[执行任务][排队] 并发已满,放回队列 taskId=%s", task.TaskID)
|
||||
_ = w.rollbackToPending(ctx, task.Id)
|
||||
return
|
||||
}
|
||||
defer func() { _ = queue.ReleaseSemaphore(ctx, semKey) }()
|
||||
|
||||
// 3) request_payload 校验
|
||||
if payload == nil {
|
||||
w.failTask(ctx, t, "request_payload 为空")
|
||||
// 2) request_payload 校验
|
||||
if body == nil {
|
||||
w.failTask(ctx, task, "请求模型为空")
|
||||
return
|
||||
}
|
||||
|
||||
// 4) 调用模型
|
||||
var textResult map[string]any
|
||||
if streamEnabled, _ := model.StreamConfig["enabled"].(bool); streamEnabled {
|
||||
rawBytes, modelErr := w.callModelRaw(ctx, t, model, payload)
|
||||
if modelErr != nil {
|
||||
w.failTask(ctx, t, modelErr.Error())
|
||||
// 3) 调用模型
|
||||
switch {
|
||||
case model.IsStream != nil && *model.IsStream == 1: // 流式调用
|
||||
rawBytes, err := w.callModelStream(ctx, task, model, body)
|
||||
if err != nil {
|
||||
w.failTask(ctx, task, err.Error())
|
||||
return
|
||||
}
|
||||
textResult, err = util.ParseStreamResponse(rawBytes, model.StreamConfig)
|
||||
// 解析流式结果
|
||||
body, err = util.ParseStreamResponse(rawBytes, model.StreamConfig)
|
||||
if err != nil {
|
||||
w.failTask(ctx, t, err.Error())
|
||||
w.failTask(ctx, task, err.Error())
|
||||
return
|
||||
}
|
||||
} else {
|
||||
textResult, err = w.callModel(ctx, t, model, payload)
|
||||
case model.IsAsync != nil && *model.IsAsync == 1: // 异步调用:注入回调地址后提交,拿到 task_id 轮询
|
||||
// 异步调用:提交任务
|
||||
body, err = w.callModel(ctx, task, model, body)
|
||||
if err != nil {
|
||||
w.failTask(ctx, t, err.Error())
|
||||
w.failTask(ctx, task, err.Error())
|
||||
return
|
||||
}
|
||||
// 拿到 task_id,启动轮询
|
||||
taskID := gjson.New(body).Get(model.ResponseBody).String()
|
||||
body, err = util.PullTaskResult(ctx, taskID, model.QueryConfig)
|
||||
if err != nil {
|
||||
w.failTask(ctx, task, err.Error())
|
||||
return
|
||||
}
|
||||
default: // 同步调用
|
||||
body, err = w.callModel(ctx, task, model, body)
|
||||
if err != nil {
|
||||
w.failTask(ctx, task, err.Error())
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// 5) 模型返回映射处理
|
||||
textResult, err = util.MapResponsePayload(model.ResponseMapping, textResult)
|
||||
// 5) 解析响应映射
|
||||
body, err = util.MapResponsePayload(model.ResponseMapping, body)
|
||||
if err != nil {
|
||||
w.failTask(ctx, t, err.Error())
|
||||
w.failTask(ctx, task, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// 6) 保存临时文件(区分二进制音频和JSON文本)
|
||||
if audioData, ok := textResult["audio"].([]byte); ok {
|
||||
tmpPath, tmpErr := saveTmpResult(t.TaskID, audioData, ".mp3")
|
||||
if tmpErr == nil && tmpPath != "" {
|
||||
if t.TmpFile != "" {
|
||||
_ = os.Remove(t.TmpFile)
|
||||
}
|
||||
t.TmpFile = tmpPath
|
||||
t.Phase = 1
|
||||
_ = dao.Task.UpdateTmpAfterModelGlobal(ctx, t.Id, tmpPath)
|
||||
}
|
||||
} else {
|
||||
mappedBytes, _ := json.Marshal(textResult)
|
||||
if len(mappedBytes) > 0 {
|
||||
tmpPath, tmpErr := saveTmpResult(t.TaskID, mappedBytes, ".json")
|
||||
if tmpErr == nil && tmpPath != "" {
|
||||
if t.TmpFile != "" {
|
||||
_ = os.Remove(t.TmpFile)
|
||||
}
|
||||
t.TmpFile = tmpPath
|
||||
t.Phase = 1
|
||||
_ = dao.Task.UpdateTmpAfterModelGlobal(ctx, t.Id, tmpPath)
|
||||
}
|
||||
}
|
||||
// 5) 保存临时文件(通用工具方法)
|
||||
tmpPath, tmpErr := util.SaveTempFileByType(task.TaskID, body, task.TmpFile)
|
||||
if tmpErr == nil && tmpPath != "" {
|
||||
task.TmpFile = tmpPath
|
||||
task.Phase = 1
|
||||
_ = dao.Task.UpdateTmpAfterModelGlobal(ctx, task.Id, tmpPath)
|
||||
}
|
||||
|
||||
// 7) 上传 OSS(可重试)
|
||||
// 6) 上传 OSS(可重试)
|
||||
var oss *gateway.UploadFileResponse
|
||||
for attempt := 0; attempt <= maxRetry; attempt++ {
|
||||
if attempt > 0 {
|
||||
g.Log().Infof(ctx, "[执行任务][重试] OSS上传 第%d/%d次 taskId=%s", attempt, maxRetry, t.TaskID)
|
||||
g.Log().Infof(ctx, "[执行任务][重试] OSS上传 第%d/%d次 taskId=%s", attempt, maxRetry, task.TaskID)
|
||||
}
|
||||
oss, err = w.uploadOSS(ctx, t)
|
||||
oss, err = w.uploadOSS(ctx, task)
|
||||
if err == nil {
|
||||
break
|
||||
}
|
||||
g.Log().Errorf(ctx, "[执行任务][失败] OSS上传失败 taskId=%s attempt=%d/%d err=%v",
|
||||
t.TaskID, attempt, maxRetry, err)
|
||||
task.TaskID, attempt, maxRetry, err)
|
||||
if attempt == maxRetry {
|
||||
_ = dao.Task.UpdateFailedKeepTmpGlobal(ctx, t.Id, err.Error())
|
||||
w.failTask(ctx, t, fmt.Sprintf("OSS上传重试耗尽: %v", err))
|
||||
_ = dao.Task.UpdateFailedKeepTmpGlobal(ctx, task.Id, err.Error())
|
||||
w.failTask(ctx, task, fmt.Sprintf("OSS上传重试耗尽: %v", err))
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
//8) 解析校验(可重试,失败重新调模型)
|
||||
// 7) 解析校验(可重试,失败重新调模型)
|
||||
if req.BuildType == 1 {
|
||||
for attempt := 0; attempt <= maxRetry; attempt++ {
|
||||
if attempt > 0 {
|
||||
g.Log().Infof(ctx, "[执行任务][重试] JSON解析 第%d/%d次 taskId=%s", attempt, maxRetry, t.TaskID)
|
||||
g.Log().Infof(ctx, "[执行任务][重试] JSON解析 第%d/%d次 taskId=%s", attempt, maxRetry, task.TaskID)
|
||||
}
|
||||
// 6.1) 校验数据
|
||||
err = util.ValidatePromptResult(textResult, model)
|
||||
err = util.ValidatePromptResult(body, model)
|
||||
if err == nil {
|
||||
break
|
||||
}
|
||||
g.Log().Warningf(ctx, "[执行任务][解析失败] taskId=%s attempt=%d/%d err=%v",
|
||||
t.TaskID, attempt, maxRetry, err)
|
||||
task.TaskID, attempt, maxRetry, err)
|
||||
if attempt == maxRetry {
|
||||
w.failTask(ctx, t, fmt.Sprintf("JSON解析重试耗尽: %v", err))
|
||||
w.failTask(ctx, task, fmt.Sprintf("JSON解析重试耗尽: %v", err))
|
||||
return
|
||||
}
|
||||
// 6.2) 重新调模型
|
||||
newResult, modelErr := w.callModel(ctx, t, model, payload)
|
||||
newResult, modelErr := w.callModel(ctx, task, model, body)
|
||||
if modelErr != nil {
|
||||
g.Log().Warningf(ctx, "[执行任务][重试] 重新调模型失败 taskId=%s attempt=%d/%d err=%v",
|
||||
t.TaskID, attempt, maxRetry, modelErr)
|
||||
task.TaskID, attempt, maxRetry, modelErr)
|
||||
continue
|
||||
}
|
||||
textResult = newResult
|
||||
body = newResult
|
||||
}
|
||||
}
|
||||
|
||||
// 9) 成功回调
|
||||
t.State = 2
|
||||
t.OssFile = oss.FileAddressPrefix + oss.FileURL
|
||||
t.FileType = oss.FileFormat
|
||||
t.TextResult = textResult
|
||||
t.FileSize = int64(oss.FileSize)
|
||||
t.ExpendTokens = int64(GetExpendTokens(model.ResponseTokenField, textResult))
|
||||
// 8) 成功回调
|
||||
task.State = 2
|
||||
task.OssFile = oss.FileAddressPrefix + oss.FileURL
|
||||
task.FileType = oss.FileFormat
|
||||
task.TextResult = body
|
||||
task.FileSize = int64(oss.FileSize)
|
||||
task.ExpendTokens = int64(GetExpendTokens(model.ResponseTokenField, body))
|
||||
|
||||
if err = dao.Task.UpdateSuccessGlobal(ctx, t); err != nil {
|
||||
g.Log().Errorf(ctx, "[执行任务][失败] 更新数据库失败 taskId=%s err=%v", t.TaskID, err)
|
||||
if err = dao.Task.UpdateSuccessGlobal(ctx, task); err != nil {
|
||||
g.Log().Errorf(ctx, "[执行任务][失败] 更新数据库失败 taskId=%s err=%v", task.TaskID, err)
|
||||
return
|
||||
}
|
||||
|
||||
queue.ReleaseQueueSlot(ctx, t.ModelName, t.TaskID)
|
||||
go gateway.TriggerCallback(context.WithoutCancel(ctx), t)
|
||||
queue.ReleaseQueueSlot(ctx, task.ModelName, task.TaskID)
|
||||
go gateway.TriggerCallback(context.WithoutCancel(ctx), task)
|
||||
if req.EpicycleId != 0 {
|
||||
go gateway.TriggerPromptsCallback(context.WithoutCancel(ctx), t, req.EpicycleId)
|
||||
go gateway.TriggerPromptsCallback(context.WithoutCancel(ctx), task, req.EpicycleId)
|
||||
}
|
||||
|
||||
g.Log().Infof(ctx, "[执行任务][成功] taskId=%s fileType=%s textLen=%d callbackUrl=%s",
|
||||
t.TaskID, oss.FileFormat, len(textResult), t.CallbackURL)
|
||||
task.TaskID, oss.FileFormat, len(body), task.CallbackURL)
|
||||
|
||||
// 10) 删除临时文件
|
||||
_ = os.Remove(t.TmpFile)
|
||||
// 9) 删除临时文件
|
||||
_ = os.Remove(task.TmpFile)
|
||||
}
|
||||
|
||||
// callModelRaw 调用模型,返回原始字节(不做响应映射,用于流式输出)
|
||||
func (w *asyncWorker) callModelRaw(ctx context.Context, task *entity.AsynchTask, model *entity.AsynchModel, payload map[string]any) ([]byte, error) {
|
||||
// callModelStream 调用模型,返回原始字节(不做响应映射,用于流式输出)
|
||||
func (w *asyncWorker) callModelStream(ctx context.Context, task *entity.AsynchTask, model *entity.AsynchModel, body map[string]any) ([]byte, error) {
|
||||
var data []byte
|
||||
var err error
|
||||
|
||||
@@ -265,11 +239,11 @@ func (w *asyncWorker) callModelRaw(ctx context.Context, task *entity.AsynchTask,
|
||||
|
||||
if data == nil {
|
||||
_ = dao.Stat.IncRequestCount(ctx, time.Now(), int64(task.TenantId), task.Creator, task.ModelName)
|
||||
data, err = InvokeModel(ctx, model, payload, task.ModelKey)
|
||||
data, err = InvokeModel(ctx, model, body, task.ModelKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
tmpPath, tmpErr := saveTmpResult(task.TaskID, data, "")
|
||||
tmpPath, tmpErr := util.SaveTmpResult(task.TaskID, data, "")
|
||||
if tmpErr == nil && tmpPath != "" {
|
||||
task.TmpFile = tmpPath
|
||||
task.Phase = 1
|
||||
@@ -280,9 +254,61 @@ func (w *asyncWorker) callModelRaw(ctx context.Context, task *entity.AsynchTask,
|
||||
return data, nil
|
||||
}
|
||||
|
||||
// asyncResult 异步任务结果
|
||||
type asyncResult struct {
|
||||
result map[string]any
|
||||
err error
|
||||
}
|
||||
|
||||
// asyncTaskChan 全局异步任务等待通道
|
||||
var asyncTaskChan = sync.Map{} // taskID → chan asyncResult
|
||||
|
||||
func (w *asyncWorker) callModelAsync(ctx context.Context, task *entity.AsynchTask, model *entity.AsynchModel, body map[string]any) (map[string]any, error) {
|
||||
// 1. 提交异步任务
|
||||
result, err := w.callModel(ctx, task, model, body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// 2. 拿到 task_id
|
||||
taskID := gjson.New(result).Get(model.ResponseBody).String()
|
||||
|
||||
// 3. 创建等待通道
|
||||
ch := make(chan asyncResult, 1)
|
||||
asyncTaskChan.Store(taskID, ch)
|
||||
defer func() {
|
||||
asyncTaskChan.Delete(taskID)
|
||||
close(ch)
|
||||
}()
|
||||
|
||||
// 4. 阻塞等待回调或超时
|
||||
timeout := time.Duration(model.TimeoutSeconds) * time.Second
|
||||
ctx, cancel := context.WithTimeout(ctx, timeout)
|
||||
defer cancel()
|
||||
|
||||
g.Log().Infof(ctx, "[异步任务] 开始等待结果 taskID=%s timeout=%v", taskID, timeout)
|
||||
|
||||
select {
|
||||
case res, ok := <-ch:
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("异步任务通道已关闭: taskID=%s", taskID)
|
||||
}
|
||||
g.Log().Infof(ctx, "[异步任务] 获取结果成功 taskID=%s", taskID)
|
||||
return res.result, res.err
|
||||
case <-ctx.Done():
|
||||
return nil, fmt.Errorf("异步任务超时: taskID=%s", taskID)
|
||||
}
|
||||
}
|
||||
|
||||
// NotifyAsyncResult 回调接口调用此方法通知结果
|
||||
func NotifyAsyncResult(taskID string, result map[string]any, err error) {
|
||||
if ch, ok := asyncTaskChan.Load(taskID); ok {
|
||||
ch.(chan asyncResult) <- asyncResult{result: result, err: err}
|
||||
}
|
||||
}
|
||||
|
||||
// 返回: ossURL(成功时有值), fileType, textResult(失败时是错误信息), retryable(是否可重试)
|
||||
// callModel 调用模型 + 检测文件类型 + 保存临时文件
|
||||
func (w *asyncWorker) callModel(ctx context.Context, task *entity.AsynchTask, model *entity.AsynchModel, payload map[string]any) (map[string]any, error) {
|
||||
func (w *asyncWorker) callModel(ctx context.Context, task *entity.AsynchTask, model *entity.AsynchModel, body map[string]any) (map[string]any, error) {
|
||||
var data []byte
|
||||
var contentType, ext, textResult string
|
||||
var err error
|
||||
@@ -296,11 +322,11 @@ func (w *asyncWorker) callModel(ctx context.Context, task *entity.AsynchTask, mo
|
||||
|
||||
if data == nil {
|
||||
_ = dao.Stat.IncRequestCount(ctx, time.Now(), int64(task.TenantId), task.Creator, task.ModelName)
|
||||
data, err = InvokeModel(ctx, model, payload, task.ModelKey)
|
||||
data, err = InvokeModel(ctx, model, body, task.ModelKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
tmpPath, tmpErr := saveTmpResult(task.TaskID, data, ext)
|
||||
tmpPath, tmpErr := util.SaveTmpResult(task.TaskID, data, ext)
|
||||
if tmpErr == nil && tmpPath != "" {
|
||||
task.TmpFile = tmpPath
|
||||
task.Phase = 1
|
||||
@@ -317,7 +343,7 @@ func (w *asyncWorker) callModel(ctx context.Context, task *entity.AsynchTask, mo
|
||||
|
||||
// InvokeModel 调用模型服务,返回二进制结果
|
||||
// modelKey 用于覆盖/补充模型配置 head_msg(例如每次请求携带不同的 X-API-Key)
|
||||
func InvokeModel(ctx context.Context, model *entity.AsynchModel, payload map[string]any, modelKey string) ([]byte, error) {
|
||||
func InvokeModel(ctx context.Context, model *entity.AsynchModel, body map[string]any, modelKey string) ([]byte, error) {
|
||||
// 1)请求参数映射:将标准 payload 按模型配置的 requestMapping 转为模型需要的格式
|
||||
//mappedPayload := util.ReverseMap(model.RequestMapping, payload)
|
||||
|
||||
@@ -331,7 +357,7 @@ func InvokeModel(ctx context.Context, model *entity.AsynchModel, payload map[str
|
||||
var req *http.Request
|
||||
switch method {
|
||||
case http.MethodGet:
|
||||
q, err := util.PayloadToQuery(payload)
|
||||
q, err := util.BodyToQuery(body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -344,7 +370,7 @@ func InvokeModel(ctx context.Context, model *entity.AsynchModel, payload map[str
|
||||
}
|
||||
req, err = http.NewRequestWithContext(ctx, http.MethodGet, baseURL, nil)
|
||||
default:
|
||||
bodyBytes, err := json.Marshal(payload)
|
||||
bodyBytes, err := json.Marshal(body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -355,8 +381,8 @@ func InvokeModel(ctx context.Context, model *entity.AsynchModel, payload map[str
|
||||
for hk, hv := range util.ParseHeadMsgHeaders(model.HeadMsg) {
|
||||
req.Header.Set(hk, hv)
|
||||
}
|
||||
for hk, hv := range util.ParseHeadMsgHeaders(modelKey) {
|
||||
req.Header.Set(hk, hv)
|
||||
if modelKey != "" {
|
||||
req.Header.Set("Authorization", "Bearer "+modelKey)
|
||||
}
|
||||
if method != http.MethodGet {
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
@@ -456,25 +482,7 @@ func (w *asyncWorker) failTask(ctx context.Context, t *entity.AsynchTask, errMsg
|
||||
go gateway.TriggerCallback(context.WithoutCancel(ctx), t)
|
||||
}
|
||||
|
||||
// saveTmpResult 将模型输出写入临时文件,用于 OSS 上传失败后的“仅重试 OSS”。
|
||||
func saveTmpResult(taskID string, data []byte, ext string) (string, error) {
|
||||
dir := filepath.Join(os.TempDir(), "model-asynch")
|
||||
if err := os.MkdirAll(dir, 0o755); err != nil {
|
||||
return "", err
|
||||
}
|
||||
if ext == "" {
|
||||
ext = ".bin"
|
||||
}
|
||||
if ext[0] != '.' {
|
||||
ext = "." + ext
|
||||
}
|
||||
path := filepath.Join(dir, fmt.Sprintf("%s%s", taskID, ext))
|
||||
if err := os.WriteFile(path, data, 0o644); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return path, nil
|
||||
}
|
||||
|
||||
// rollbackToPending 恢复任务状态为 PENDING
|
||||
func (w *asyncWorker) rollbackToPending(ctx context.Context, id int64) error {
|
||||
return dao.Task.RollbackToPendingGlobal(ctx, id)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user