refactor(asynch): 重构异步模型配置和队列管理

This commit is contained in:
2026-06-02 20:26:45 +08:00
parent c7e9eb889b
commit 52124385a1
18 changed files with 726 additions and 1006 deletions

View File

@@ -3,6 +3,7 @@ package task
import (
"context"
"errors"
"fmt"
"model-gateway/common/util"
"model-gateway/service/queue"
"time"
@@ -11,9 +12,12 @@ import (
"model-gateway/model/dto"
"model-gateway/model/entity"
"gitea.com/red-future/common/beans"
"gitea.com/red-future/common/utils"
"github.com/gogf/gf/v2/database/gdb"
"github.com/gogf/gf/v2/frame/g"
"github.com/gogf/gf/v2/os/gtime"
"github.com/gogf/gf/v2/util/gconv"
"github.com/google/uuid"
)
@@ -25,22 +29,29 @@ type taskService struct{}
func (s *taskService) Create(ctx context.Context, req *dto.CreateTaskReq) (res *dto.CreateTaskRes, err error) {
startAt := time.Now()
taskID := uuid.NewString()
// 1) 检查模型配置
m, err := dao.Model.Get(ctx, &entity.AsynchModel{
// 1) 检查模型配置,并且获取模型
userInfo, err := utils.GetUserInfo(ctx)
if err != nil {
return nil, err
}
model, err := dao.Model.Get(ctx, &entity.AsynchModel{
SQLBaseDO: beans.SQLBaseDO{
TenantId: userInfo.TenantId,
Creator: userInfo.UserName,
},
ModelName: req.ModelName,
})
if err != nil {
return nil, err
}
if m == nil || (m.Enabled != nil && *m.Enabled != 1) {
if model == nil || (model.Enabled != nil && *model.Enabled != 1) {
return nil, errors.New("模型不存在或未启用")
}
// 2) 排队上限严格控制Redis 原子闸门)
limit := queue.GetRuntimeQueueLimit(ctx, req.ModelName, m.QueueLimit)
limit := queue.GetRuntimeQueueLimit(ctx, req.ModelName, model.MaxConcurrency*2)
if limit > 0 {
ok, err := queue.AcquireQueueSlot(ctx, req.ModelName, taskID, limit, m.ExpectedSeconds)
ok, err := queue.AcquireQueueSlot(ctx, req.ModelName, taskID, limit, model.TimeoutSeconds)
if err != nil {
return nil, err
}
@@ -50,9 +61,13 @@ func (s *taskService) Create(ctx context.Context, req *dto.CreateTaskReq) (res *
}
// 3) 插入任务记录
if model.IsAsync != nil && *model.IsAsync == 1 {
// 异步调用:注入回调地址后提交,拿到 task_id 轮询
req.RequestPayload = util.InjectCallbackURL(ctx, req.RequestPayload, model.CallbackUrl)
}
storedPayload := map[string]any{
"payload": req.RequestPayload,
"headers": util.ForwardHeaders(ctx),
"headers": util.ParseHeadMsgHeaders(model.HeadMsg),
"body": req.RequestPayload,
}
_, err = dao.Task.Insert(ctx, &entity.AsynchTask{
ModelName: req.ModelName,
@@ -60,13 +75,12 @@ func (s *taskService) Create(ctx context.Context, req *dto.CreateTaskReq) (res *
State: 0,
BizName: req.BizName,
CallbackURL: req.CallbackUrl,
ModelKey: m.ApiKey,
ModelKey: model.ApiKey,
InputRef: req.InputRef,
RequestPayload: storedPayload,
EpicycleId: req.EpicycleId,
})
if err != nil {
// 入库失败:回滚闸门占位
if err != nil { // 入库失败:回滚闸门占位
queue.ReleaseQueueSlot(ctx, req.ModelName, taskID)
return nil, err
}
@@ -100,75 +114,96 @@ func (s *taskService) Create(ctx context.Context, req *dto.CreateTaskReq) (res *
},
})
// 5) 创建成功后立即异步尝试执行当前任务,并仅在任务仍处于 pending(state=0) 时做定向轮询。
// 一旦任务进入 running/success/failed/downloaded就停止轮询避免一直空转。
go s.pollAndRunUntilPicked(util.AsyncCtx(ctx), taskID, req)
// 5) 获取任务信息
task, err := dao.Task.ClaimPendingByTaskIDGlobal(ctx, taskID)
if err != nil {
return nil, err
}
if task == nil {
return nil, err
}
// 5) 创建成功后立即异步尝试执行当前任务
go AsyncWorker.handleOne(util.AsyncCtx(ctx), task, model, req)
return &dto.CreateTaskRes{TaskID: taskID}, nil
}
// pollAndRunUntilPicked 定向轮询执行刚创建的任务
// - 目标:尽快把刚创建的任务拉起来执行
// - 只在任务仍为 pending(state=0) 时继续尝试抢占
// - 一旦任务进入 running(1) / success(2) / failed(3) / downloaded(4),立即停止
// - 不会无限轮询runWork 仍负责处理积压队列和未处理到的任务
func (s *taskService) pollAndRunUntilPicked(ctx context.Context, taskID string, req *dto.CreateTaskReq) {
interval := g.Cfg().MustGet(ctx, "asynch.worker.intervalSeconds", 5).Int()
pollTimeout := g.Cfg().MustGet(ctx, "asynch.worker.pollTimeoutSeconds", 300).Int()
pollCtx, cancel := context.WithTimeout(ctx, time.Duration(pollTimeout)*time.Second)
defer cancel()
func (s *taskService) ModelTaskCallback(ctx context.Context, req *dto.ModelTaskCallbackReq) (*dto.ModelTaskCallbackRes, error) {
g.Log().Infof(ctx, "[模型回调] 收到通知 taskID=%s status=%s", req.TaskID, req.Status)
// 1. 查本地任务
task, err := dao.Task.Get(ctx, &entity.AsynchTask{
TaskID: req.TaskID,
})
if err != nil || task == nil {
return nil, fmt.Errorf("任务不存在: %s", req.TaskID)
}
ticker := time.NewTicker(time.Duration(interval) * time.Second)
defer ticker.Stop()
// 2. 成功:取 video_url 和 usage
if req.Status == "succeeded" {
result := map[string]any{
"video_url": req.Content["video_url"],
"usage": req.Usage,
}
NotifyAsyncResult(req.TaskID, result, nil)
return &dto.ModelTaskCallbackRes{Success: true}, nil
}
g.Log().Infof(ctx, "[任务自动执行][开始] taskId=%s 轮询间隔=%ds 超时=%ds", taskID, interval, pollTimeout)
tryRun := func() bool {
t, err := dao.Task.Get(ctx, &entity.AsynchTask{
TaskID: taskID,
})
// 3. 失败/过期
if req.Status == "failed" || req.Status == "expired" {
NotifyAsyncResult(req.TaskID, nil, fmt.Errorf(req.Status))
return &dto.ModelTaskCallbackRes{Success: true}, nil
}
return &dto.ModelTaskCallbackRes{Success: true}, nil
}
// QueryPendingTasks 批量轮询进行中的异步任务
func (s *taskService) QueryPendingTasks(ctx context.Context, req *dto.QueryPendingTasksReq) (*dto.QueryPendingTasksRes, error) {
limit := req.Limit
if limit <= 0 {
limit = g.Cfg().MustGet(ctx, "asynch.queryPending.limit", 10).Int()
}
// 1. 查 state=1执行中的异步任务
tasks, err := dao.Task.GetPendingAsyncTasks(ctx, limit)
if err != nil {
return nil, err
}
// 2. 逐个查询
var results []dto.QueryTaskItem
for _, t := range tasks {
// 拿到模型配置
model, err := dao.Model.GetByModelNameForTenant(ctx, t.TenantId, t.ModelName)
if err != nil || model == nil || model.QueryConfig == nil {
continue
}
result, err := util.PullTaskResult(ctx, t.TaskID, model.QueryConfig)
if err != nil {
g.Log().Warningf(ctx, "[任务自动执行][停止] taskId=%s 原因=查询失败 err=%v", taskID, err)
return true
}
if t == nil {
g.Log().Warningf(ctx, "[任务自动执行][停止] taskId=%s 原因=任务不存在", taskID)
return true
g.Log().Warningf(ctx, "[轮询] 查询失败 taskID=%s err=%v", t.TaskID, err)
continue
}
switch t.State {
case 0:
//RunByTaskID 尝试执行任务
if err = AsyncWorker.RunByTaskID(ctx, taskID, req); err != nil {
g.Log().Warningf(ctx, "[任务自动执行][重试] taskId=%s 状态=待处理 err=%v", taskID, err)
} else {
g.Log().Infof(ctx, "[任务自动执行][已触发] taskId=%s 状态=待处理", taskID)
}
return false
case 1:
g.Log().Infof(ctx, "[任务自动执行][停止] taskId=%s 原因=执行中", taskID)
return true
case 2, 3, 4:
g.Log().Infof(ctx, "[任务自动执行][停止] taskId=%s 原因=终态 状态=%d", taskID, t.State)
return true
default:
g.Log().Infof(ctx, "[任务自动执行][停止] taskId=%s 原因=未知状态 状态=%d", taskID, t.State)
return true
}
}
// 立即尝试一次
if stop := tryRun(); stop {
return
}
for {
select {
case <-pollCtx.Done():
g.Log().Infof(ctx, "[任务自动执行][停止] taskId=%s 原因=轮询超时", taskID)
return
case <-ticker.C:
if stop := tryRun(); stop {
return
}
status := gconv.String(result["status"])
item := dto.QueryTaskItem{
TaskID: t.TaskID,
Status: status,
Content: result["content"].(map[string]any),
Usage: result["usage"].(map[string]any),
}
results = append(results, item)
// 如果任务完成,通知等待通道
if status == "succeeded" || status == "failed" || status == "expired" {
NotifyAsyncResult(t.TaskID, result["content"].(map[string]any), nil)
}
}
return &dto.QueryPendingTasksRes{
Total: len(results),
Results: results,
}, nil
}
// GetResult 获取任务结果

View File

@@ -13,8 +13,8 @@ import (
"model-gateway/service/queue"
"net/http"
"os"
"path/filepath"
"strings"
"sync"
"time"
"unicode/utf8"
@@ -55,7 +55,7 @@ func (w *asyncWorker) RunOnce(ctx context.Context, req *dto.RunWorkReq) (res *dt
for _, t := range tasks {
task := t
_ = pool.AddWithRecover(ctx, func(ctx context.Context) {
w.handleOne(ctx, task, &dto.CreateTaskReq{EpicycleId: 0})
//w.handleOne(ctx, task, &dto.CreateTaskReq{EpicycleId: 0})
done <- struct{}{}
}, func(ctx context.Context, e error) {
if e != nil {
@@ -74,185 +74,159 @@ func (w *asyncWorker) RunOnce(ctx context.Context, req *dto.RunWorkReq) (res *dt
}, nil
}
// RunByTaskID 创建任务后立即异步尝试执行当前任务:
// - 只定向抢占当前 taskId 对应的 pending 任务
// - 若任务已被其它 worker 抢走/已不在 pending则直接返回
func (w *asyncWorker) RunByTaskID(ctx context.Context, taskID string, req *dto.CreateTaskReq) error {
task, err := dao.Task.ClaimPendingByTaskIDGlobal(ctx, taskID)
if err != nil {
return err
}
if task == nil {
return nil
}
w.handleOne(ctx, task, req)
return nil
}
// handleOne 执行一次完整的任务
func (w *asyncWorker) handleOne(ctx context.Context, t *entity.AsynchTask, req *dto.CreateTaskReq) {
payload := util.ParseStoredPayload(t.RequestPayload)
maxRetry := 0 // 后面从 model 取
g.Log().Infof(ctx, "[执行任务][开始] taskId=%s model=%s", t.TaskID, t.ModelName)
func (w *asyncWorker) handleOne(ctx context.Context, task *entity.AsynchTask, model *entity.AsynchModel, req *dto.CreateTaskReq) {
body := util.GetModelBody(task.RequestPayload) //核心请求参数
maxRetry := model.RetryTimes //重试次数
// 1) 获取模型配置
model, err := dao.Model.GetByModelNameForTenant(ctx, t.TenantId, t.ModelName)
if err != nil || model == nil {
w.failTask(ctx, t, "模型不存在或未启用")
return
}
maxRetry = model.RetryTimes
// 2) 分布式并发控制
semKey := fmt.Sprintf("asynch:sem:%s", t.ModelName)
maxC := queue.GetRuntimeMaxConcurrency(ctx, t.ModelName, model.MaxConcurrency)
g.Log().Infof(ctx, "[执行任务][开始] taskId=%s model=%s", task.TaskID, task.ModelName)
// 1) 分布式并发控制
semKey := fmt.Sprintf("asynch:sem:%s", task.ModelName)
maxC := queue.GetRuntimeMaxConcurrency(ctx, task.ModelName, model.MaxConcurrency)
acquired, err := queue.AcquireSemaphore(ctx, semKey, maxC, 3600)
if err != nil {
w.failTask(ctx, t, err.Error())
w.failTask(ctx, task, err.Error())
return
}
if !acquired {
g.Log().Infof(ctx, "[执行任务][排队] 并发已满,放回队列 taskId=%s", t.TaskID)
_ = w.rollbackToPending(ctx, t.Id)
g.Log().Infof(ctx, "[执行任务][排队] 并发已满,放回队列 taskId=%s", task.TaskID)
_ = w.rollbackToPending(ctx, task.Id)
return
}
defer func() { _ = queue.ReleaseSemaphore(ctx, semKey) }()
// 3) request_payload 校验
if payload == nil {
w.failTask(ctx, t, "request_payload 为空")
// 2) request_payload 校验
if body == nil {
w.failTask(ctx, task, "请求模型为空")
return
}
// 4) 调用模型
var textResult map[string]any
if streamEnabled, _ := model.StreamConfig["enabled"].(bool); streamEnabled {
rawBytes, modelErr := w.callModelRaw(ctx, t, model, payload)
if modelErr != nil {
w.failTask(ctx, t, modelErr.Error())
// 3) 调用模型
switch {
case model.IsStream != nil && *model.IsStream == 1: // 流式调用
rawBytes, err := w.callModelStream(ctx, task, model, body)
if err != nil {
w.failTask(ctx, task, err.Error())
return
}
textResult, err = util.ParseStreamResponse(rawBytes, model.StreamConfig)
// 解析流式结果
body, err = util.ParseStreamResponse(rawBytes, model.StreamConfig)
if err != nil {
w.failTask(ctx, t, err.Error())
w.failTask(ctx, task, err.Error())
return
}
} else {
textResult, err = w.callModel(ctx, t, model, payload)
case model.IsAsync != nil && *model.IsAsync == 1: // 异步调用:注入回调地址后提交,拿到 task_id 轮询
// 异步调用:提交任务
body, err = w.callModel(ctx, task, model, body)
if err != nil {
w.failTask(ctx, t, err.Error())
w.failTask(ctx, task, err.Error())
return
}
// 拿到 task_id启动轮询
taskID := gjson.New(body).Get(model.ResponseBody).String()
body, err = util.PullTaskResult(ctx, taskID, model.QueryConfig)
if err != nil {
w.failTask(ctx, task, err.Error())
return
}
default: // 同步调用
body, err = w.callModel(ctx, task, model, body)
if err != nil {
w.failTask(ctx, task, err.Error())
return
}
}
// 5) 模型返回映射处理
textResult, err = util.MapResponsePayload(model.ResponseMapping, textResult)
// 5) 解析响应映射
body, err = util.MapResponsePayload(model.ResponseMapping, body)
if err != nil {
w.failTask(ctx, t, err.Error())
w.failTask(ctx, task, err.Error())
return
}
// 6) 保存临时文件(区分二进制音频和JSON文本
if audioData, ok := textResult["audio"].([]byte); ok {
tmpPath, tmpErr := saveTmpResult(t.TaskID, audioData, ".mp3")
if tmpErr == nil && tmpPath != "" {
if t.TmpFile != "" {
_ = os.Remove(t.TmpFile)
}
t.TmpFile = tmpPath
t.Phase = 1
_ = dao.Task.UpdateTmpAfterModelGlobal(ctx, t.Id, tmpPath)
}
} else {
mappedBytes, _ := json.Marshal(textResult)
if len(mappedBytes) > 0 {
tmpPath, tmpErr := saveTmpResult(t.TaskID, mappedBytes, ".json")
if tmpErr == nil && tmpPath != "" {
if t.TmpFile != "" {
_ = os.Remove(t.TmpFile)
}
t.TmpFile = tmpPath
t.Phase = 1
_ = dao.Task.UpdateTmpAfterModelGlobal(ctx, t.Id, tmpPath)
}
}
// 5) 保存临时文件(通用工具方法
tmpPath, tmpErr := util.SaveTempFileByType(task.TaskID, body, task.TmpFile)
if tmpErr == nil && tmpPath != "" {
task.TmpFile = tmpPath
task.Phase = 1
_ = dao.Task.UpdateTmpAfterModelGlobal(ctx, task.Id, tmpPath)
}
// 7) 上传 OSS可重试
// 6) 上传 OSS可重试
var oss *gateway.UploadFileResponse
for attempt := 0; attempt <= maxRetry; attempt++ {
if attempt > 0 {
g.Log().Infof(ctx, "[执行任务][重试] OSS上传 第%d/%d次 taskId=%s", attempt, maxRetry, t.TaskID)
g.Log().Infof(ctx, "[执行任务][重试] OSS上传 第%d/%d次 taskId=%s", attempt, maxRetry, task.TaskID)
}
oss, err = w.uploadOSS(ctx, t)
oss, err = w.uploadOSS(ctx, task)
if err == nil {
break
}
g.Log().Errorf(ctx, "[执行任务][失败] OSS上传失败 taskId=%s attempt=%d/%d err=%v",
t.TaskID, attempt, maxRetry, err)
task.TaskID, attempt, maxRetry, err)
if attempt == maxRetry {
_ = dao.Task.UpdateFailedKeepTmpGlobal(ctx, t.Id, err.Error())
w.failTask(ctx, t, fmt.Sprintf("OSS上传重试耗尽: %v", err))
_ = dao.Task.UpdateFailedKeepTmpGlobal(ctx, task.Id, err.Error())
w.failTask(ctx, task, fmt.Sprintf("OSS上传重试耗尽: %v", err))
return
}
}
//8) 解析校验(可重试,失败重新调模型)
// 7) 解析校验(可重试,失败重新调模型)
if req.BuildType == 1 {
for attempt := 0; attempt <= maxRetry; attempt++ {
if attempt > 0 {
g.Log().Infof(ctx, "[执行任务][重试] JSON解析 第%d/%d次 taskId=%s", attempt, maxRetry, t.TaskID)
g.Log().Infof(ctx, "[执行任务][重试] JSON解析 第%d/%d次 taskId=%s", attempt, maxRetry, task.TaskID)
}
// 6.1) 校验数据
err = util.ValidatePromptResult(textResult, model)
err = util.ValidatePromptResult(body, model)
if err == nil {
break
}
g.Log().Warningf(ctx, "[执行任务][解析失败] taskId=%s attempt=%d/%d err=%v",
t.TaskID, attempt, maxRetry, err)
task.TaskID, attempt, maxRetry, err)
if attempt == maxRetry {
w.failTask(ctx, t, fmt.Sprintf("JSON解析重试耗尽: %v", err))
w.failTask(ctx, task, fmt.Sprintf("JSON解析重试耗尽: %v", err))
return
}
// 6.2) 重新调模型
newResult, modelErr := w.callModel(ctx, t, model, payload)
newResult, modelErr := w.callModel(ctx, task, model, body)
if modelErr != nil {
g.Log().Warningf(ctx, "[执行任务][重试] 重新调模型失败 taskId=%s attempt=%d/%d err=%v",
t.TaskID, attempt, maxRetry, modelErr)
task.TaskID, attempt, maxRetry, modelErr)
continue
}
textResult = newResult
body = newResult
}
}
// 9) 成功回调
t.State = 2
t.OssFile = oss.FileAddressPrefix + oss.FileURL
t.FileType = oss.FileFormat
t.TextResult = textResult
t.FileSize = int64(oss.FileSize)
t.ExpendTokens = int64(GetExpendTokens(model.ResponseTokenField, textResult))
// 8) 成功回调
task.State = 2
task.OssFile = oss.FileAddressPrefix + oss.FileURL
task.FileType = oss.FileFormat
task.TextResult = body
task.FileSize = int64(oss.FileSize)
task.ExpendTokens = int64(GetExpendTokens(model.ResponseTokenField, body))
if err = dao.Task.UpdateSuccessGlobal(ctx, t); err != nil {
g.Log().Errorf(ctx, "[执行任务][失败] 更新数据库失败 taskId=%s err=%v", t.TaskID, err)
if err = dao.Task.UpdateSuccessGlobal(ctx, task); err != nil {
g.Log().Errorf(ctx, "[执行任务][失败] 更新数据库失败 taskId=%s err=%v", task.TaskID, err)
return
}
queue.ReleaseQueueSlot(ctx, t.ModelName, t.TaskID)
go gateway.TriggerCallback(context.WithoutCancel(ctx), t)
queue.ReleaseQueueSlot(ctx, task.ModelName, task.TaskID)
go gateway.TriggerCallback(context.WithoutCancel(ctx), task)
if req.EpicycleId != 0 {
go gateway.TriggerPromptsCallback(context.WithoutCancel(ctx), t, req.EpicycleId)
go gateway.TriggerPromptsCallback(context.WithoutCancel(ctx), task, req.EpicycleId)
}
g.Log().Infof(ctx, "[执行任务][成功] taskId=%s fileType=%s textLen=%d callbackUrl=%s",
t.TaskID, oss.FileFormat, len(textResult), t.CallbackURL)
task.TaskID, oss.FileFormat, len(body), task.CallbackURL)
// 10) 删除临时文件
_ = os.Remove(t.TmpFile)
// 9) 删除临时文件
_ = os.Remove(task.TmpFile)
}
// callModelRaw 调用模型,返回原始字节(不做响应映射,用于流式输出)
func (w *asyncWorker) callModelRaw(ctx context.Context, task *entity.AsynchTask, model *entity.AsynchModel, payload map[string]any) ([]byte, error) {
// callModelStream 调用模型,返回原始字节(不做响应映射,用于流式输出)
func (w *asyncWorker) callModelStream(ctx context.Context, task *entity.AsynchTask, model *entity.AsynchModel, body map[string]any) ([]byte, error) {
var data []byte
var err error
@@ -265,11 +239,11 @@ func (w *asyncWorker) callModelRaw(ctx context.Context, task *entity.AsynchTask,
if data == nil {
_ = dao.Stat.IncRequestCount(ctx, time.Now(), int64(task.TenantId), task.Creator, task.ModelName)
data, err = InvokeModel(ctx, model, payload, task.ModelKey)
data, err = InvokeModel(ctx, model, body, task.ModelKey)
if err != nil {
return nil, err
}
tmpPath, tmpErr := saveTmpResult(task.TaskID, data, "")
tmpPath, tmpErr := util.SaveTmpResult(task.TaskID, data, "")
if tmpErr == nil && tmpPath != "" {
task.TmpFile = tmpPath
task.Phase = 1
@@ -280,9 +254,61 @@ func (w *asyncWorker) callModelRaw(ctx context.Context, task *entity.AsynchTask,
return data, nil
}
// asyncResult 异步任务结果
type asyncResult struct {
result map[string]any
err error
}
// asyncTaskChan 全局异步任务等待通道
var asyncTaskChan = sync.Map{} // taskID → chan asyncResult
func (w *asyncWorker) callModelAsync(ctx context.Context, task *entity.AsynchTask, model *entity.AsynchModel, body map[string]any) (map[string]any, error) {
// 1. 提交异步任务
result, err := w.callModel(ctx, task, model, body)
if err != nil {
return nil, err
}
// 2. 拿到 task_id
taskID := gjson.New(result).Get(model.ResponseBody).String()
// 3. 创建等待通道
ch := make(chan asyncResult, 1)
asyncTaskChan.Store(taskID, ch)
defer func() {
asyncTaskChan.Delete(taskID)
close(ch)
}()
// 4. 阻塞等待回调或超时
timeout := time.Duration(model.TimeoutSeconds) * time.Second
ctx, cancel := context.WithTimeout(ctx, timeout)
defer cancel()
g.Log().Infof(ctx, "[异步任务] 开始等待结果 taskID=%s timeout=%v", taskID, timeout)
select {
case res, ok := <-ch:
if !ok {
return nil, fmt.Errorf("异步任务通道已关闭: taskID=%s", taskID)
}
g.Log().Infof(ctx, "[异步任务] 获取结果成功 taskID=%s", taskID)
return res.result, res.err
case <-ctx.Done():
return nil, fmt.Errorf("异步任务超时: taskID=%s", taskID)
}
}
// NotifyAsyncResult 回调接口调用此方法通知结果
func NotifyAsyncResult(taskID string, result map[string]any, err error) {
if ch, ok := asyncTaskChan.Load(taskID); ok {
ch.(chan asyncResult) <- asyncResult{result: result, err: err}
}
}
// 返回: ossURL(成功时有值), fileType, textResult(失败时是错误信息), retryable(是否可重试)
// callModel 调用模型 + 检测文件类型 + 保存临时文件
func (w *asyncWorker) callModel(ctx context.Context, task *entity.AsynchTask, model *entity.AsynchModel, payload map[string]any) (map[string]any, error) {
func (w *asyncWorker) callModel(ctx context.Context, task *entity.AsynchTask, model *entity.AsynchModel, body map[string]any) (map[string]any, error) {
var data []byte
var contentType, ext, textResult string
var err error
@@ -296,11 +322,11 @@ func (w *asyncWorker) callModel(ctx context.Context, task *entity.AsynchTask, mo
if data == nil {
_ = dao.Stat.IncRequestCount(ctx, time.Now(), int64(task.TenantId), task.Creator, task.ModelName)
data, err = InvokeModel(ctx, model, payload, task.ModelKey)
data, err = InvokeModel(ctx, model, body, task.ModelKey)
if err != nil {
return nil, err
}
tmpPath, tmpErr := saveTmpResult(task.TaskID, data, ext)
tmpPath, tmpErr := util.SaveTmpResult(task.TaskID, data, ext)
if tmpErr == nil && tmpPath != "" {
task.TmpFile = tmpPath
task.Phase = 1
@@ -317,7 +343,7 @@ func (w *asyncWorker) callModel(ctx context.Context, task *entity.AsynchTask, mo
// InvokeModel 调用模型服务,返回二进制结果
// modelKey 用于覆盖/补充模型配置 head_msg例如每次请求携带不同的 X-API-Key
func InvokeModel(ctx context.Context, model *entity.AsynchModel, payload map[string]any, modelKey string) ([]byte, error) {
func InvokeModel(ctx context.Context, model *entity.AsynchModel, body map[string]any, modelKey string) ([]byte, error) {
// 1请求参数映射将标准 payload 按模型配置的 requestMapping 转为模型需要的格式
//mappedPayload := util.ReverseMap(model.RequestMapping, payload)
@@ -331,7 +357,7 @@ func InvokeModel(ctx context.Context, model *entity.AsynchModel, payload map[str
var req *http.Request
switch method {
case http.MethodGet:
q, err := util.PayloadToQuery(payload)
q, err := util.BodyToQuery(body)
if err != nil {
return nil, err
}
@@ -344,7 +370,7 @@ func InvokeModel(ctx context.Context, model *entity.AsynchModel, payload map[str
}
req, err = http.NewRequestWithContext(ctx, http.MethodGet, baseURL, nil)
default:
bodyBytes, err := json.Marshal(payload)
bodyBytes, err := json.Marshal(body)
if err != nil {
return nil, err
}
@@ -355,8 +381,8 @@ func InvokeModel(ctx context.Context, model *entity.AsynchModel, payload map[str
for hk, hv := range util.ParseHeadMsgHeaders(model.HeadMsg) {
req.Header.Set(hk, hv)
}
for hk, hv := range util.ParseHeadMsgHeaders(modelKey) {
req.Header.Set(hk, hv)
if modelKey != "" {
req.Header.Set("Authorization", "Bearer "+modelKey)
}
if method != http.MethodGet {
req.Header.Set("Content-Type", "application/json")
@@ -456,25 +482,7 @@ func (w *asyncWorker) failTask(ctx context.Context, t *entity.AsynchTask, errMsg
go gateway.TriggerCallback(context.WithoutCancel(ctx), t)
}
// saveTmpResult 将模型输出写入临时文件,用于 OSS 上传失败后的“仅重试 OSS”。
func saveTmpResult(taskID string, data []byte, ext string) (string, error) {
dir := filepath.Join(os.TempDir(), "model-asynch")
if err := os.MkdirAll(dir, 0o755); err != nil {
return "", err
}
if ext == "" {
ext = ".bin"
}
if ext[0] != '.' {
ext = "." + ext
}
path := filepath.Join(dir, fmt.Sprintf("%s%s", taskID, ext))
if err := os.WriteFile(path, data, 0o644); err != nil {
return "", err
}
return path, nil
}
// rollbackToPending 恢复任务状态为 PENDING
func (w *asyncWorker) rollbackToPending(ctx context.Context, id int64) error {
return dao.Task.RollbackToPendingGlobal(ctx, id)
}