feat: 重构异步模型字段并更新依赖

This commit is contained in:
2026-06-08 18:01:53 +08:00
parent 9049e0d2e8
commit 0cf8948cd2
11 changed files with 512 additions and 441 deletions

View File

@@ -4,13 +4,8 @@ import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"model-gateway/common/util"
"model-gateway/model/dto"
"model-gateway/service/gateway"
"model-gateway/service/queue"
"net/http"
"os"
"strings"
@@ -18,12 +13,17 @@ import (
"time"
"unicode/utf8"
"model-gateway/common/util"
"model-gateway/consts/public"
"model-gateway/dao"
"model-gateway/model/dto"
"model-gateway/model/entity"
"model-gateway/service/gateway"
"model-gateway/service/queue"
"github.com/gogf/gf/v2/encoding/gjson"
"github.com/gogf/gf/v2/frame/g"
"github.com/gogf/gf/v2/os/grpool"
"github.com/gogf/gf/v2/util/gconv"
)
var AsyncWorker = &asyncWorker{}
@@ -31,61 +31,21 @@ var AsyncWorker = &asyncWorker{}
type asyncWorker struct {
}
// RunOnce 由上层定时任务触发:一次性抢占并处理一批任务
// - batchSize: 本次抢占数量
// - goroutines: 本次并发数(协程池大小)
func (w *asyncWorker) RunOnce(ctx context.Context, req *dto.RunWorkReq) (res *dto.RunWorkRes, err error) {
if req.BatchSize <= 0 {
req.BatchSize = 10
}
if req.Goroutines <= 0 {
req.Goroutines = 1
}
tasks, err := dao.Task.ClaimPendingGlobal(ctx, req.BatchSize)
if err != nil {
return nil, err
}
if len(tasks) == 0 {
return nil, errors.New("no task to run")
}
pool := grpool.New(req.Goroutines)
defer pool.Close()
claimed := len(tasks)
done := make(chan struct{}, claimed)
for _, t := range tasks {
task := t
_ = pool.AddWithRecover(ctx, func(ctx context.Context) {
//w.handleOne(ctx, task, &dto.CreateTaskReq{EpicycleId: 0})
done <- struct{}{}
}, func(ctx context.Context, e error) {
if e != nil {
task.ErrorMsg = fmt.Sprintf("worker panic: %v", e)
_ = dao.Task.UpdateFailedGlobal(ctx, task)
queue.ReleaseQueueSlot(ctx, task.ModelName, task.TaskID)
}
done <- struct{}{}
})
}
for i := 0; i < claimed; i++ {
<-done
}
return &dto.RunWorkRes{
Claimed: claimed,
}, nil
}
// handleOne 执行一次完整的任务
func (w *asyncWorker) handleOne(ctx context.Context, task *entity.AsynchTask, model *entity.AsynchModel, req *dto.CreateTaskReq) {
body := util.GetModelBody(task.RequestPayload) //核心请求参数
maxRetry := model.RetryTimes //重试次数
body := util.GetModelBody(task.RequestPayload) // 核心请求参数
maxRetry := model.RetryTimes // 重试次数
startTime := time.Now()
g.Log().Infof(ctx, "[执行任务][开始] taskId=%s model=%s", task.TaskID, task.ModelName)
// 1) 分布式并发控制
semKey := fmt.Sprintf("asynch:sem:%s", task.ModelName)
maxC := queue.GetRuntimeMaxConcurrency(ctx, task.ModelName, model.MaxConcurrency)
acquired, err := queue.AcquireSemaphore(ctx, semKey, maxC, 3600)
if err != nil {
w.failTask(ctx, task, err.Error())
task.DurationSeconds = int64(time.Since(startTime).Seconds())
w.failTask(ctx, task, startTime, err.Error())
return
}
if !acquired {
@@ -95,62 +55,55 @@ func (w *asyncWorker) handleOne(ctx context.Context, task *entity.AsynchTask, mo
}
defer func() { _ = queue.ReleaseSemaphore(ctx, semKey) }()
// 2) request_payload 校验
if body == nil {
w.failTask(ctx, task, "请求模型为空")
return
}
// 3) 调用模型
// 2) 调用模型
switch {
case model.IsStream != nil && *model.IsStream == 1: // 流式调用
case model.CallMode != nil && *model.CallMode == public.CallModeStream:
rawBytes, err := w.callModelStream(ctx, task, model, body)
if err != nil {
w.failTask(ctx, task, err.Error())
w.failTask(ctx, task, startTime, err.Error())
return
}
// 解析流式结果
body, err = util.ParseStreamResponse(rawBytes, model.StreamConfig)
if err != nil {
w.failTask(ctx, task, err.Error())
w.failTask(ctx, task, startTime, err.Error())
return
}
case model.IsAsync != nil && *model.IsAsync == 1: // 异步调用:注入回调地址后提交,拿到 task_id 轮询
// 异步调用:提交任务
case model.CallMode != nil && *model.CallMode == public.CallModeAsync:
body, err = w.callModel(ctx, task, model, body)
if err != nil {
w.failTask(ctx, task, err.Error())
w.failTask(ctx, task, startTime, err.Error())
return
}
body, err = util.PullTaskResult(ctx, body, model.QueryConfig, model.HeadMsg)
if err != nil {
w.failTask(ctx, task, err.Error())
w.failTask(ctx, task, startTime, err.Error())
return
}
default: // 同步调用
default:
body, err = w.callModel(ctx, task, model, body)
if err != nil {
w.failTask(ctx, task, err.Error())
w.failTask(ctx, task, startTime, err.Error())
return
}
}
// 5) 解析响应映射
body, err = util.MapResponsePayload(model.ResponseMapping, body)
if err != nil {
w.failTask(ctx, task, err.Error())
return
}
// 5) 保存临时文件(通用工具方法)
tmpPath, tmpErr := util.SaveTempFileByType(task.TaskID, body, task.TmpFile)
if tmpErr == nil && tmpPath != "" {
// 3) 保存临时文件
tmpPath, err := util.SaveTempFileByType(task.TaskID, body, task.TmpFile)
if err == nil && tmpPath != "" {
task.TmpFile = tmpPath
task.Phase = 1
_ = dao.Task.UpdateTmpAfterModelGlobal(ctx, task.Id, tmpPath)
}
// 6) 上传 OSS可重试
// 4) 解析校验 + 响应映射(可重试,失败重新调模型
body, err = w.parseAndRetry(ctx, body, task, model, req, maxRetry, startTime)
if err != nil {
task.TextResult = body
w.failTask(ctx, task, startTime, err.Error())
return
}
// 5) 上传 OSS可重试
var oss *gateway.UploadFileResponse
for attempt := 0; attempt <= maxRetry; attempt++ {
if attempt > 0 {
@@ -164,46 +117,18 @@ func (w *asyncWorker) handleOne(ctx context.Context, task *entity.AsynchTask, mo
task.TaskID, attempt, maxRetry, err)
if attempt == maxRetry {
_ = dao.Task.UpdateFailedKeepTmpGlobal(ctx, task.Id, err.Error())
w.failTask(ctx, task, fmt.Sprintf("OSS上传重试耗尽: %v", err))
w.failTask(ctx, task, startTime, fmt.Sprintf("OSS上传重试耗尽: %v", err))
return
}
}
// 7) 解析校验(可重试,失败重新调模型)
if req.BuildType == 1 {
for attempt := 0; attempt <= maxRetry; attempt++ {
if attempt > 0 {
g.Log().Infof(ctx, "[执行任务][重试] JSON解析 第%d/%d次 taskId=%s", attempt, maxRetry, task.TaskID)
}
// 6.1) 校验数据
err = util.ValidatePromptResult(body, model)
if err == nil {
break
}
g.Log().Warningf(ctx, "[执行任务][解析失败] taskId=%s attempt=%d/%d err=%v",
task.TaskID, attempt, maxRetry, err)
if attempt == maxRetry {
w.failTask(ctx, task, fmt.Sprintf("JSON解析重试耗尽: %v", err))
return
}
// 6.2) 重新调模型
newResult, modelErr := w.callModel(ctx, task, model, body)
if modelErr != nil {
g.Log().Warningf(ctx, "[执行任务][重试] 重新调模型失败 taskId=%s attempt=%d/%d err=%v",
task.TaskID, attempt, maxRetry, modelErr)
continue
}
body = newResult
}
}
// 8) 成功回调
// 6) 成功回调
task.State = 2
task.DurationSeconds = int64(time.Since(startTime).Seconds())
task.OssFile = oss.FileAddressPrefix + oss.FileURL
task.FileType = oss.FileFormat
task.TextResult = body
task.FileSize = int64(oss.FileSize)
task.ExpendTokens = int64(GetExpendTokens(model.ResponseTokenField, body))
if err = dao.Task.UpdateSuccessGlobal(ctx, task); err != nil {
g.Log().Errorf(ctx, "[执行任务][失败] 更新数据库失败 taskId=%s err=%v", task.TaskID, err)
@@ -216,10 +141,10 @@ func (w *asyncWorker) handleOne(ctx context.Context, task *entity.AsynchTask, mo
go gateway.TriggerPromptsCallback(context.WithoutCancel(ctx), task, req.EpicycleId)
}
g.Log().Infof(ctx, "[执行任务][成功] taskId=%s fileType=%s textLen=%d callbackUrl=%s",
task.TaskID, oss.FileFormat, len(body), task.CallbackURL)
g.Log().Infof(ctx, "[执行任务][成功] taskId=%s duration=%ds fileType=%s textLen=%d callbackUrl=%s",
task.TaskID, task.DurationSeconds, oss.FileFormat, len(body), task.CallbackURL)
// 9) 删除临时文件
// 7) 删除临时文件
_ = os.Remove(task.TmpFile)
}
@@ -263,12 +188,12 @@ var asyncTaskChan = sync.Map{} // taskID → chan asyncResult
func (w *asyncWorker) callModelAsync(ctx context.Context, task *entity.AsynchTask, model *entity.AsynchModel, body map[string]any) (map[string]any, error) {
// 1. 提交异步任务
result, err := w.callModel(ctx, task, model, body)
body, err := w.callModel(ctx, task, model, body)
if err != nil {
return nil, err
}
// 2. 拿到 task_id
taskID := gjson.New(result).Get(model.ResponseBody).String()
taskID := gjson.New(body).Get(model.ResponseBody).String()
// 3. 创建等待通道
ch := make(chan asyncResult, 1)
@@ -304,26 +229,31 @@ func NotifyAsyncResult(taskID string, result map[string]any, err error) {
}
}
// 返回: ossURL(成功时有值), fileType, textResult(失败时是错误信息), retryable(是否可重试)
// callModel 调用模型 + 检测文件类型 + 保存临时文件
// 返回: 解析后的响应体, error
func (w *asyncWorker) callModel(ctx context.Context, task *entity.AsynchTask, model *entity.AsynchModel, body map[string]any) (map[string]any, error) {
var data []byte
var contentType, ext, textResult string
var err error
// 1) 如果已有临时文件且 phase=1直接读取
if task.Phase == 1 && strings.TrimSpace(task.TmpFile) != "" {
data, err = os.ReadFile(task.TmpFile)
if err != nil || len(data) == 0 {
g.Log().Warningf(ctx, "[callModel] 读取临时文件失败,重新调用模型 taskId=%s err=%v", task.TaskID, err)
data = nil
}
}
// 2) 没有可用数据,调用模型
if data == nil {
_ = dao.Stat.IncRequestCount(ctx, time.Now(), int64(task.TenantId), task.Creator, task.ModelName)
data, err = InvokeModel(ctx, model, body, task.ModelKey)
if err != nil {
return nil, err
}
// 3) 检测文件类型,保存临时文件
_, ext := util.DetectFileType(data)
tmpPath, tmpErr := util.SaveTmpResult(task.TaskID, data, ext)
if tmpErr == nil && tmpPath != "" {
task.TmpFile = tmpPath
@@ -332,17 +262,94 @@ func (w *asyncWorker) callModel(ctx context.Context, task *entity.AsynchTask, mo
}
}
contentType, ext = util.DetectFileType(data)
// 4) 检测文件类型,提取文本结果
contentType, _ := util.DetectFileType(data)
var textResult string
if utf8.Valid(data) && (strings.HasPrefix(contentType, "text/") || contentType == "application/json") {
textResult = string(data)
}
// 5) 非文本内容,返回错误
if textResult == "" {
return nil, fmt.Errorf("模型返回非文本内容contentType=%s", contentType)
}
// 6) 解析并返回
return gjson.New(textResult).Map(), nil
}
// parseAndRetry 解析模型返回结果,并重试
func (w *asyncWorker) parseAndRetry(ctx context.Context, body map[string]any, task *entity.AsynchTask, model *entity.AsynchModel, req *dto.CreateTaskReq, maxRetry int, startTime time.Time) (map[string]any, error) {
for attempt := 0; attempt <= maxRetry; attempt++ {
if attempt > 0 {
g.Log().Infof(ctx, "[执行任务][重试] JSON解析 第%d/%d次 taskId=%s", attempt, maxRetry, task.TaskID)
}
// 1) 响应映射
mapped, err := util.MapResponsePayload(model.ResponseMapping, body)
if err != nil {
g.Log().Warningf(ctx, "[执行任务][映射失败] taskId=%s attempt=%d/%d err=%v", task.TaskID, attempt, maxRetry, err)
if attempt == maxRetry {
return nil, fmt.Errorf("响应映射重试耗尽: %w", err)
}
continue
}
// 2) 先存 token 到数据库,防止后续失败丢失
if tokens, ok := mapped[model.ResponseTokenField]; ok {
task.ExpendTokens = gconv.Int64(tokens)
_ = dao.Task.UpdateColumns(ctx, task.Id, entity.AsynchTask{
ExpendTokens: gconv.Int64(body[model.ResponseTokenField]),
})
}
// 3) 解析 + 校验
var parsed map[string]any
switch req.BuildType {
case public.BuildTypePrompt, public.BuildTypeNode:
parsed, err = util.ParseAndValidate(mapped, model)
if err == nil {
return parsed, nil
}
case public.BuildTypeStruct:
parsed = util.ParseStructResult(mapped, model.ResponseBody)
return parsed, nil
default:
return mapped, nil
}
g.Log().Warningf(ctx, "[执行任务][解析失败] taskId=%s attempt=%d/%d err=%v", task.TaskID, attempt, maxRetry, err)
if attempt == maxRetry {
return nil, fmt.Errorf("JSON解析重试耗尽: %w", err)
}
// 4) 重新调模型(直接调,不走缓存)
_ = dao.Task.IncRetryCountGlobal(ctx, task.Id)
reqBody := util.GetModelBody(task.RequestPayload)
rawData, callErr := InvokeModel(ctx, model, reqBody, task.ModelKey)
if callErr != nil {
g.Log().Warningf(ctx, "[执行任务][重调模型失败] taskId=%s attempt=%d/%d err=%v", task.TaskID, attempt, maxRetry, callErr)
continue
}
// 5) 解析原始响应,覆盖 body 进入下一轮
var rawResp map[string]any
if err := json.Unmarshal(rawData, &rawResp); err != nil {
g.Log().Warningf(ctx, "[执行任务][Unmarshal失败] taskId=%s err=%v", task.TaskID, err)
continue
}
body = rawResp
}
return body, nil
}
// InvokeModel 调用模型服务,返回二进制结果
// modelKey 用于覆盖/补充模型配置 head_msg例如每次请求携带不同的 X-API-Key
func InvokeModel(ctx context.Context, model *entity.AsynchModel, body map[string]any, modelKey string) ([]byte, error) {
// 1请求参数映射将标准 payload 按模型配置的 requestMapping 转为模型需要的格式
//—— 请求映射实际处理为提示词构建请求,因为有附加字段及其他字段的拼接。这里不方便做请求映射
//mappedPayload := util.ReverseMap(model.RequestMapping, payload)
// 2构建请求 URL 和超时
@@ -472,9 +479,10 @@ func (w *asyncWorker) uploadOSS(ctx context.Context, t *entity.AsynchTask) (*gat
}
// failTask 任务失败统一处理:更新数据库 + 释放排队 + 回调
func (w *asyncWorker) failTask(ctx context.Context, t *entity.AsynchTask, errMsg string) {
func (w *asyncWorker) failTask(ctx context.Context, t *entity.AsynchTask, startTime time.Time, errMsg string) {
t.State = 3
t.ErrorMsg = errMsg
t.DurationSeconds = int64(time.Since(startTime).Seconds())
_ = dao.Task.UpdateFailedGlobal(ctx, t)
queue.ReleaseQueueSlot(ctx, t.ModelName, t.TaskID)
go gateway.TriggerCallback(context.WithoutCancel(ctx), t)
@@ -484,12 +492,3 @@ func (w *asyncWorker) failTask(ctx context.Context, t *entity.AsynchTask, errMsg
func (w *asyncWorker) rollbackToPending(ctx context.Context, id int64) error {
return dao.Task.RollbackToPendingGlobal(ctx, id)
}
// GetExpendTokens 根据映射路径从 result 中提取消耗 token 值
func GetExpendTokens(responseTokenField string, result map[string]any) int {
val := gjson.New(result).Get(responseTokenField)
if val.IsNil() {
return 0
}
return val.Int()
}