refactor(files): 优化文件处理和任务服务逻辑
This commit is contained in:
@@ -1,7 +1,6 @@
|
|||||||
package util
|
package util
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/json"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"os"
|
"os"
|
||||||
@@ -9,107 +8,75 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
)
|
)
|
||||||
|
|
||||||
// DetectFileType 根据返回的二进制内容推断 contentType + 扩展名(尽量稳定)
|
// DetectFileType 根据返回的二进制内容推断 contentType + 扩展名
|
||||||
func DetectFileType(data []byte) (contentType string, ext string) {
|
func DetectFileType(data []byte) (contentType string, ext string) {
|
||||||
if len(data) == 0 {
|
if len(data) == 0 {
|
||||||
return "application/octet-stream", ""
|
return "application/octet-stream", ".bin"
|
||||||
}
|
}
|
||||||
|
|
||||||
ct := http.DetectContentType(data)
|
ct := http.DetectContentType(data)
|
||||||
// gateway.DetectContentType 可能带 charset 等参数:text/plain; charset=utf-8
|
|
||||||
if idx := strings.Index(ct, ";"); idx > 0 {
|
if idx := strings.Index(ct, ";"); idx > 0 {
|
||||||
ct = strings.TrimSpace(ct[:idx])
|
ct = strings.TrimSpace(ct[:idx])
|
||||||
}
|
}
|
||||||
|
|
||||||
switch ct {
|
switch ct {
|
||||||
case "audio/mpeg":
|
case "audio/mpeg":
|
||||||
return ct, ".mp3"
|
return ct, ".mp3"
|
||||||
case "audio/wave", "audio/wav", "audio/x-wav":
|
case "audio/wave", "audio/wav", "audio/x-wav":
|
||||||
return ct, ".wav"
|
return ct, ".wav"
|
||||||
|
case "audio/mp4", "audio/x-m4a":
|
||||||
|
return ct, ".m4a"
|
||||||
case "video/mp4":
|
case "video/mp4":
|
||||||
return ct, ".mp4"
|
return ct, ".mp4"
|
||||||
|
case "video/webm":
|
||||||
|
return ct, ".webm"
|
||||||
case "image/png":
|
case "image/png":
|
||||||
return ct, ".png"
|
return ct, ".png"
|
||||||
case "image/jpeg":
|
case "image/jpeg":
|
||||||
return ct, ".jpg"
|
return ct, ".jpg"
|
||||||
|
case "image/gif":
|
||||||
|
return ct, ".gif"
|
||||||
|
case "image/webp":
|
||||||
|
return ct, ".webp"
|
||||||
case "application/pdf":
|
case "application/pdf":
|
||||||
return ct, ".pdf"
|
return ct, ".pdf"
|
||||||
case "text/plain":
|
case "text/plain":
|
||||||
return ct, ".txt"
|
return ct, ".txt"
|
||||||
case "application/json":
|
case "application/json":
|
||||||
return ct, ".json"
|
return ct, ".json"
|
||||||
|
case "application/zip":
|
||||||
|
return ct, ".zip"
|
||||||
|
case "application/octet-stream":
|
||||||
|
return ct, ".bin"
|
||||||
default:
|
default:
|
||||||
// 兜底:尝试从 ct 截取 subtype 作为后缀(例如 application/json)
|
|
||||||
if parts := strings.Split(ct, "/"); len(parts) == 2 {
|
if parts := strings.Split(ct, "/"); len(parts) == 2 {
|
||||||
sub := parts[1]
|
sub := parts[1]
|
||||||
// 避免出现 "plain; charset=utf-8" 之类的后缀
|
|
||||||
if idx := strings.Index(sub, ";"); idx > 0 {
|
if idx := strings.Index(sub, ";"); idx > 0 {
|
||||||
sub = strings.TrimSpace(sub[:idx])
|
sub = strings.TrimSpace(sub[:idx])
|
||||||
}
|
}
|
||||||
return ct, "." + sub
|
return ct, "." + sub
|
||||||
}
|
}
|
||||||
return ct, ""
|
return ct, ".bin"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// SaveTmpResult 将模型输出写入临时文件,用于 OSS 上传失败后的“仅重试 OSS”。
|
// SaveTmpResult 将二进制数据写入临时文件
|
||||||
func SaveTmpResult(taskID string, data []byte, ext string) (string, error) {
|
func SaveTmpResult(taskID string, data []byte, ext string) (string, error) {
|
||||||
dir := filepath.Join(os.TempDir(), "model-asynch")
|
dir := filepath.Join(os.TempDir(), "model-asynch")
|
||||||
if err := os.MkdirAll(dir, 0o755); err != nil {
|
if err := os.MkdirAll(dir, 0o755); err != nil {
|
||||||
return "", err
|
return "", fmt.Errorf("创建临时目录失败: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if ext == "" {
|
if ext == "" {
|
||||||
ext = ".bin"
|
ext = ".bin"
|
||||||
}
|
}
|
||||||
if ext[0] != '.' {
|
if ext[0] != '.' {
|
||||||
ext = "." + ext
|
ext = "." + ext
|
||||||
}
|
}
|
||||||
|
|
||||||
path := filepath.Join(dir, fmt.Sprintf("%s%s", taskID, ext))
|
path := filepath.Join(dir, fmt.Sprintf("%s%s", taskID, ext))
|
||||||
if err := os.WriteFile(path, data, 0o644); err != nil {
|
if err := os.WriteFile(path, data, 0o644); err != nil {
|
||||||
return "", err
|
return "", fmt.Errorf("写入临时文件失败: %w", err)
|
||||||
}
|
}
|
||||||
return path, nil
|
return path, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// SaveTempFileByType
|
|
||||||
// 根据传入的数据自动判断:
|
|
||||||
// 若是 []byte 且后缀为 .mp3 → 保存二进制音频
|
|
||||||
// 若是任意结构体/map → 自动转 JSON 保存
|
|
||||||
// 返回:新临时文件路径、错误
|
|
||||||
func SaveTempFileByType(taskID string, data any, oldTmpFile string) (string, error) {
|
|
||||||
// 1. 先清理旧临时文件(统一逻辑)
|
|
||||||
if oldTmpFile != "" {
|
|
||||||
_ = os.Remove(oldTmpFile)
|
|
||||||
}
|
|
||||||
|
|
||||||
var tmpPath string
|
|
||||||
var tmpErr error
|
|
||||||
|
|
||||||
// 2. 判断是否是二进制音频([]byte + .mp3)
|
|
||||||
if audioData, ok := data.([]byte); ok {
|
|
||||||
tmpPath, tmpErr = saveTmpResult(taskID, audioData, ".mp3")
|
|
||||||
} else {
|
|
||||||
// 3. 其他类型 → 序列化为 JSON 保存
|
|
||||||
mappedBytes, err := json.Marshal(data)
|
|
||||||
if err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
if len(mappedBytes) == 0 {
|
|
||||||
return "", nil
|
|
||||||
}
|
|
||||||
tmpPath, tmpErr = saveTmpResult(taskID, mappedBytes, ".json")
|
|
||||||
}
|
|
||||||
|
|
||||||
if tmpErr != nil || tmpPath == "" {
|
|
||||||
return "", tmpErr
|
|
||||||
}
|
|
||||||
|
|
||||||
return tmpPath, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// saveTmpResult 你原有的底层保存文件方法(保留不动)
|
|
||||||
func saveTmpResult(taskID string, data []byte, ext string) (string, error) {
|
|
||||||
// 你原来实现,比如:
|
|
||||||
filename := taskID + ext
|
|
||||||
tmpPath := filepath.Join(os.TempDir(), filename)
|
|
||||||
err := os.WriteFile(tmpPath, data, 0644)
|
|
||||||
return tmpPath, err
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -6,7 +6,6 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"model-gateway/common/util"
|
"model-gateway/common/util"
|
||||||
"model-gateway/consts/public"
|
"model-gateway/consts/public"
|
||||||
"model-gateway/service/queue"
|
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"model-gateway/dao"
|
"model-gateway/dao"
|
||||||
@@ -28,12 +27,15 @@ type taskService struct{}
|
|||||||
// Create 创建任务
|
// Create 创建任务
|
||||||
func (s *taskService) Create(ctx context.Context, req *dto.CreateTaskReq) (res *dto.CreateTaskRes, err error) {
|
func (s *taskService) Create(ctx context.Context, req *dto.CreateTaskReq) (res *dto.CreateTaskRes, err error) {
|
||||||
taskID := uuid.NewString()
|
taskID := uuid.NewString()
|
||||||
|
startAt := time.Now()
|
||||||
|
|
||||||
// 1) 检查模型配置,并且获取模型
|
// 1) 获取用户信息
|
||||||
userInfo, err := utils.GetUserInfo(ctx)
|
userInfo, err := utils.GetUserInfo(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 2) 检查模型配置
|
||||||
model, err := dao.ModelGatewayModels.Get(ctx, &entity.ModelGatewayModel{
|
model, err := dao.ModelGatewayModels.Get(ctx, &entity.ModelGatewayModel{
|
||||||
SQLBaseDO: beans.SQLBaseDO{
|
SQLBaseDO: beans.SQLBaseDO{
|
||||||
TenantId: userInfo.TenantId,
|
TenantId: userInfo.TenantId,
|
||||||
@@ -48,86 +50,63 @@ func (s *taskService) Create(ctx context.Context, req *dto.CreateTaskReq) (res *
|
|||||||
return nil, errors.New("模型不存在或未启用")
|
return nil, errors.New("模型不存在或未启用")
|
||||||
}
|
}
|
||||||
|
|
||||||
// 2) 排队上限(严格控制:Redis 原子闸门)
|
// TODO: 排队控制暂时关闭,后续需要时取消注释
|
||||||
limit := queue.GetRuntimeQueueLimit(ctx, req.ModelName, model.MaxConcurrency*2)
|
// limit := queue.GetRuntimeQueueLimit(ctx, req.ModelName, model.MaxConcurrency*2)
|
||||||
if limit > 0 {
|
// if limit > 0 {
|
||||||
ok, err := queue.AcquireQueueSlot(ctx, req.ModelName, taskID, limit, model.TimeoutSeconds)
|
// ok, err := queue.AcquireQueueSlot(ctx, req.ModelName, taskID, limit, model.TimeoutSeconds)
|
||||||
if err != nil {
|
// if err != nil {
|
||||||
return nil, err
|
// return nil, err
|
||||||
}
|
// }
|
||||||
if !ok {
|
// if !ok {
|
||||||
return nil, errors.New("任务排队已满,请稍后再试")
|
// return nil, errors.New("任务排队已满,请稍后再试")
|
||||||
}
|
// }
|
||||||
}
|
// }
|
||||||
|
|
||||||
// 3) 插入任务记录
|
// 3) 构建任务实体
|
||||||
requestPayload := entity.RequestPayload{
|
task := &entity.ModelGatewayTask{
|
||||||
|
ModelName: model.ModelName,
|
||||||
|
TaskID: taskID,
|
||||||
|
State: public.TaskStatusRunning,
|
||||||
|
BizName: req.BizName,
|
||||||
|
CallbackURL: req.CallbackUrl,
|
||||||
|
RequestPayload: &entity.RequestPayload{
|
||||||
Body: req.RequestPayload,
|
Body: req.RequestPayload,
|
||||||
Headers: util.ParseHeadMsgHeaders(model.HeadMsg),
|
Headers: util.ParseHeadMsgHeaders(model.HeadMsg),
|
||||||
|
},
|
||||||
|
EpicycleId: req.EpicycleId,
|
||||||
}
|
}
|
||||||
task := new(entity.ModelGatewayTask)
|
|
||||||
task.ModelName = model.ModelName
|
// 4) 插入任务记录
|
||||||
task.TaskID = taskID
|
|
||||||
task.State = public.TaskStatusRunning
|
|
||||||
task.BizName = req.BizName
|
|
||||||
task.CallbackURL = req.CallbackUrl
|
|
||||||
task.RequestPayload = &requestPayload
|
|
||||||
task.EpicycleId = req.EpicycleId
|
|
||||||
id, err := dao.ModelGatewayTask.Insert(ctx, task)
|
id, err := dao.ModelGatewayTask.Insert(ctx, task)
|
||||||
if err != nil { // 入库失败:回滚闸门占位
|
if err != nil {
|
||||||
queue.ReleaseQueueSlot(ctx, req.ModelName, taskID)
|
// TODO: 恢复排队逻辑后,此处需要回滚排队占位
|
||||||
|
// queue.ReleaseQueueSlot(ctx, req.ModelName, taskID)
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
task.Id = id
|
task.Id = id
|
||||||
// 4) 写操作日志(不影响主流程,失败忽略)
|
|
||||||
ip := ""
|
// 5) 记录操作日志(非关键路径,失败不影响主流程)
|
||||||
ua := ""
|
ip, ua := "", ""
|
||||||
apiPath := "/task/createTask"
|
|
||||||
httpMethod := "POST"
|
|
||||||
if r := g.RequestFromCtx(ctx); r != nil {
|
if r := g.RequestFromCtx(ctx); r != nil {
|
||||||
ip = utils.GetLocalIP()
|
ip = utils.GetLocalIP()
|
||||||
ua = r.UserAgent()
|
ua = r.UserAgent()
|
||||||
apiPath = r.URL.Path
|
|
||||||
httpMethod = r.Method
|
|
||||||
}
|
}
|
||||||
_, _ = dao.ModelGatewayLogsOp.Insert(ctx, &entity.ModelGatewayLogsOp{
|
_, _ = dao.ModelGatewayLogsOp.Insert(ctx, &entity.ModelGatewayLogsOp{
|
||||||
IP: ip,
|
IP: ip,
|
||||||
UserAgent: ua,
|
UserAgent: ua,
|
||||||
APIPath: apiPath,
|
APIPath: "/task/createTask",
|
||||||
HttpMethod: httpMethod,
|
HttpMethod: "POST",
|
||||||
BizName: req.BizName,
|
BizName: req.BizName,
|
||||||
ModelName: req.ModelName,
|
ModelName: req.ModelName,
|
||||||
TaskID: taskID,
|
TaskID: taskID,
|
||||||
OpType: "createTask",
|
OpType: "createTask",
|
||||||
Success: 1,
|
Success: 1,
|
||||||
CostMs: time.Since(time.Now()).Milliseconds(),
|
CostMs: time.Since(startAt).Milliseconds(),
|
||||||
RequestPayload: &requestPayload,
|
RequestPayload: task.RequestPayload,
|
||||||
ResponsePayload: gdb.Map{
|
ResponsePayload: gdb.Map{"taskId": taskID},
|
||||||
"taskId": taskID,
|
|
||||||
},
|
|
||||||
})
|
})
|
||||||
|
|
||||||
//// 5) 抢占任务:改为执行中
|
// 6) 异步执行任务
|
||||||
//rows, err := dao.ModelGatewayTask.Update(ctx, &entity.ModelGatewayTask{
|
|
||||||
// SQLBaseDO: beans.SQLBaseDO{Id: id},
|
|
||||||
// State: public.TaskStatusRunning,
|
|
||||||
//})
|
|
||||||
//if err != nil {
|
|
||||||
// return nil, err
|
|
||||||
//}
|
|
||||||
//if rows == 0 {
|
|
||||||
// return nil, fmt.Errorf("任务不存在: id=%d", id)
|
|
||||||
//}
|
|
||||||
|
|
||||||
// 6) 查询任务信息
|
|
||||||
//task, err := dao.ModelGatewayTask.Get(ctx, &entity.ModelGatewayTask{
|
|
||||||
// SQLBaseDO: beans.SQLBaseDO{Id: id},
|
|
||||||
//})
|
|
||||||
//if err != nil {
|
|
||||||
// return nil, err
|
|
||||||
//}
|
|
||||||
|
|
||||||
// 7) 创建成功后立即异步尝试执行当前任务
|
|
||||||
go AsyncWorker.handleOne(util.AsyncCtx(ctx), task, model, req)
|
go AsyncWorker.handleOne(util.AsyncCtx(ctx), task, model, req)
|
||||||
|
|
||||||
return &dto.CreateTaskRes{TaskID: taskID}, nil
|
return &dto.CreateTaskRes{TaskID: taskID}, nil
|
||||||
|
|||||||
@@ -19,7 +19,6 @@ import (
|
|||||||
"model-gateway/model/dto"
|
"model-gateway/model/dto"
|
||||||
"model-gateway/model/entity"
|
"model-gateway/model/entity"
|
||||||
"model-gateway/service/gateway"
|
"model-gateway/service/gateway"
|
||||||
"model-gateway/service/queue"
|
|
||||||
|
|
||||||
"gitea.redpowerfuture.com/red-future/common/beans"
|
"gitea.redpowerfuture.com/red-future/common/beans"
|
||||||
"github.com/gogf/gf/v2/encoding/gjson"
|
"github.com/gogf/gf/v2/encoding/gjson"
|
||||||
@@ -38,50 +37,28 @@ func (w *asyncWorker) handleOne(ctx context.Context, task *entity.ModelGatewayTa
|
|||||||
body = task.RequestPayload.Body
|
body = task.RequestPayload.Body
|
||||||
maxRetry = model.RetryTimes
|
maxRetry = model.RetryTimes
|
||||||
startTime = time.Now()
|
startTime = time.Now()
|
||||||
|
rawBytes []byte
|
||||||
result map[string]any
|
result map[string]any
|
||||||
err error
|
err error
|
||||||
)
|
)
|
||||||
g.Log().Infof(ctx, "[执行任务][开始] taskId=%s model=%s", task.TaskID, task.ModelName)
|
|
||||||
|
g.Log().Infof(ctx, "[handleOne] 开始 taskId=%s model=%s", task.TaskID, task.ModelName)
|
||||||
|
|
||||||
// ============================================
|
// ============================================
|
||||||
// 1) 分布式并发控制
|
// 1) 调用模型
|
||||||
// ============================================
|
|
||||||
semKey := fmt.Sprintf("asynch:sem:%s", task.ModelName)
|
|
||||||
maxC := queue.GetRuntimeMaxConcurrency(ctx, task.ModelName, model.MaxConcurrency)
|
|
||||||
acquired, err := queue.AcquireSemaphore(ctx, semKey, maxC, 3600)
|
|
||||||
if err != nil {
|
|
||||||
task.DurationSeconds = int64(time.Since(startTime).Seconds())
|
|
||||||
w.failTask(ctx, task, startTime, err.Error())
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if !acquired {
|
|
||||||
_, _ = dao.ModelGatewayTask.Update(ctx, &entity.ModelGatewayTask{
|
|
||||||
SQLBaseDO: beans.SQLBaseDO{Id: task.Id},
|
|
||||||
State: public.TaskStatusPending,
|
|
||||||
})
|
|
||||||
g.Log().Infof(ctx, "[执行任务][排队] 并发已满,放回队列 taskId=%s", task.TaskID)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
defer func() { _ = queue.ReleaseSemaphore(ctx, semKey) }()
|
|
||||||
|
|
||||||
// ============================================
|
|
||||||
// 2) 调用模型
|
|
||||||
// ============================================
|
// ============================================
|
||||||
for attempt := 0; ; attempt++ {
|
for attempt := 0; ; attempt++ {
|
||||||
if attempt > 0 {
|
if attempt > 0 {
|
||||||
g.Log().Infof(ctx, "[执行任务][重试] 调用模型 第%d次 taskId=%s", attempt, task.TaskID)
|
g.Log().Infof(ctx, "[handleOne] 调模型重试 第%d次 taskId=%s", attempt, task.TaskID)
|
||||||
time.Sleep(time.Duration(attempt) * time.Second)
|
time.Sleep(time.Duration(attempt) * time.Second)
|
||||||
}
|
}
|
||||||
|
|
||||||
switch {
|
switch {
|
||||||
case model.CallMode != nil && *model.CallMode == public.CallModeStream:
|
case model.CallMode != nil && *model.CallMode == public.CallModeStream:
|
||||||
rawBytes, streamErr := w.callModelStream(ctx, task, model, body)
|
rawBytes, err = w.callModelStream(ctx, task, model, body)
|
||||||
if streamErr != nil {
|
if err == nil {
|
||||||
err = streamErr
|
|
||||||
g.Log().Warningf(ctx, "[执行任务][调用失败] taskId=%s attempt=%d err=%v", task.TaskID, attempt, err)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
result, err = util.ParseStreamResponse(rawBytes, model.StreamConfig)
|
result, err = util.ParseStreamResponse(rawBytes, model.StreamConfig)
|
||||||
|
}
|
||||||
case model.CallMode != nil && *model.CallMode == public.CallModeAsync:
|
case model.CallMode != nil && *model.CallMode == public.CallModeAsync:
|
||||||
result, err = w.callModel(ctx, task, model, body)
|
result, err = w.callModel(ctx, task, model, body)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
@@ -95,24 +72,17 @@ func (w *asyncWorker) handleOne(ctx context.Context, task *entity.ModelGatewayTa
|
|||||||
break
|
break
|
||||||
}
|
}
|
||||||
|
|
||||||
if !strings.Contains(err.Error(), "Timeout") {
|
if !strings.Contains(err.Error(), "Timeout") &&
|
||||||
|
!strings.Contains(err.Error(), "InternalServiceError") {
|
||||||
w.failTask(ctx, task, startTime, err.Error())
|
w.failTask(ctx, task, startTime, err.Error())
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
g.Log().Warningf(ctx, "[执行任务][调用失败] taskId=%s attempt=%d err=%v", task.TaskID, attempt, err)
|
|
||||||
|
g.Log().Warningf(ctx, "[handleOne] 调模型失败 taskId=%s attempt=%d err=%v", task.TaskID, attempt, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// ============================================
|
// ============================================
|
||||||
// 3) 缓存临时文件
|
// 2) 解析校验 + 响应映射(可重试)
|
||||||
// ============================================
|
|
||||||
if tmpPath, tmpErr := util.SaveTempFileByType(task.TaskID, result, task.TmpFile); tmpErr == nil && tmpPath != "" {
|
|
||||||
task.TmpFile = tmpPath
|
|
||||||
task.Phase = 1
|
|
||||||
_, _ = dao.ModelGatewayTask.Update(ctx, task)
|
|
||||||
}
|
|
||||||
|
|
||||||
// ============================================
|
|
||||||
// 4) 解析校验 + 响应映射(可重试)
|
|
||||||
// ============================================
|
// ============================================
|
||||||
result, err = w.parseAndRetry(ctx, result, task, model, req, maxRetry, startTime)
|
result, err = w.parseAndRetry(ctx, result, task, model, req, maxRetry, startTime)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -122,34 +92,26 @@ func (w *asyncWorker) handleOne(ctx context.Context, task *entity.ModelGatewayTa
|
|||||||
}
|
}
|
||||||
|
|
||||||
// ============================================
|
// ============================================
|
||||||
// 5) 上传 OSS(可重试)
|
// 3) 上传 OSS(可重试)
|
||||||
// ============================================
|
// ============================================
|
||||||
var oss *gateway.UploadFileResponse
|
var oss *gateway.UploadFileResponse
|
||||||
for attempt := 0; attempt <= maxRetry; attempt++ {
|
for attempt := 0; attempt <= maxRetry; attempt++ {
|
||||||
if attempt > 0 {
|
if attempt > 0 {
|
||||||
g.Log().Infof(ctx, "[执行任务][重试] OSS上传 第%d/%d次 taskId=%s", attempt, maxRetry, task.TaskID)
|
g.Log().Infof(ctx, "[handleOne] OSS上传重试 第%d/%d次 taskId=%s", attempt, maxRetry, task.TaskID)
|
||||||
}
|
}
|
||||||
startUpload := time.Now()
|
|
||||||
oss, err = gateway.UploadByTask(ctx, gjson.New(result).MustToJson(), "json")
|
oss, err = gateway.UploadByTask(ctx, gjson.New(result).MustToJson(), "json")
|
||||||
if err == nil {
|
if err == nil {
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
cost := time.Since(startUpload)
|
g.Log().Errorf(ctx, "[handleOne] OSS上传失败 taskId=%s attempt=%d/%d err=%v", task.TaskID, attempt, maxRetry, err)
|
||||||
g.Log().Infof(ctx, "本次上传耗时:%s", cost)
|
|
||||||
|
|
||||||
g.Log().Errorf(ctx, "[执行任务][失败] OSS上传失败 taskId=%s attempt=%d/%d err=%v", task.TaskID, attempt, maxRetry, err)
|
|
||||||
if attempt == maxRetry {
|
if attempt == maxRetry {
|
||||||
task.State = public.TaskStatusFailed
|
|
||||||
task.ErrorMsg = err.Error()
|
|
||||||
task.Phase = 1
|
|
||||||
_, _ = dao.ModelGatewayTask.Update(ctx, task)
|
|
||||||
w.failTask(ctx, task, startTime, fmt.Sprintf("OSS上传重试耗尽: %v", err))
|
w.failTask(ctx, task, startTime, fmt.Sprintf("OSS上传重试耗尽: %v", err))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// ============================================
|
// ============================================
|
||||||
// 6) 成功收尾
|
// 4) 成功收尾
|
||||||
// ============================================
|
// ============================================
|
||||||
task.State = public.TaskStatusSuccess
|
task.State = public.TaskStatusSuccess
|
||||||
task.DurationSeconds = int64(time.Since(startTime).Seconds())
|
task.DurationSeconds = int64(time.Since(startTime).Seconds())
|
||||||
@@ -159,21 +121,19 @@ func (w *asyncWorker) handleOne(ctx context.Context, task *entity.ModelGatewayTa
|
|||||||
FileSize: int64(oss.FileSize),
|
FileSize: int64(oss.FileSize),
|
||||||
}
|
}
|
||||||
task.TextResult = result
|
task.TextResult = result
|
||||||
|
|
||||||
if _, err = dao.ModelGatewayTask.Update(ctx, task); err != nil {
|
if _, err = dao.ModelGatewayTask.Update(ctx, task); err != nil {
|
||||||
g.Log().Errorf(ctx, "[执行任务][失败] 更新数据库失败 taskId=%s err=%v", task.TaskID, err)
|
g.Log().Errorf(ctx, "[handleOne] 更新DB失败 taskId=%s err=%v", task.TaskID, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
queue.ReleaseQueueSlot(ctx, task.ModelName, task.TaskID)
|
|
||||||
go gateway.TriggerCallback(util.AsyncCtx(ctx), task)
|
go gateway.TriggerCallback(util.AsyncCtx(ctx), task)
|
||||||
if req.EpicycleId != 0 {
|
if req.EpicycleId != 0 {
|
||||||
go gateway.TriggerPromptsCallback(util.AsyncCtx(ctx), task, req.EpicycleId)
|
go gateway.TriggerPromptsCallback(util.AsyncCtx(ctx), task, req.EpicycleId)
|
||||||
}
|
}
|
||||||
|
|
||||||
g.Log().Infof(ctx, "[执行任务][成功] taskId=%s duration=%ds fileType=%s",
|
g.Log().Infof(ctx, "[handleOne] 成功 taskId=%s duration=%ds fileType=%s",
|
||||||
task.TaskID, task.DurationSeconds, oss.FileFormat)
|
task.TaskID, task.DurationSeconds, oss.FileFormat)
|
||||||
|
|
||||||
_ = os.Remove(task.TmpFile)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// callModelStream 调用模型,返回原始字节(不做响应映射,用于流式输出)
|
// callModelStream 调用模型,返回原始字节(不做响应映射,用于流式输出)
|
||||||
@@ -313,12 +273,10 @@ func (w *asyncWorker) callModel(ctx context.Context, task *entity.ModelGatewayTa
|
|||||||
// parseAndRetry 解析模型返回结果,并重试
|
// parseAndRetry 解析模型返回结果,并重试
|
||||||
func (w *asyncWorker) parseAndRetry(ctx context.Context, body map[string]any, task *entity.ModelGatewayTask, model *entity.ModelGatewayModel, req *dto.CreateTaskReq, maxRetry int, startTime time.Time) (map[string]any, error) {
|
func (w *asyncWorker) parseAndRetry(ctx context.Context, body map[string]any, task *entity.ModelGatewayTask, model *entity.ModelGatewayModel, req *dto.CreateTaskReq, maxRetry int, startTime time.Time) (map[string]any, error) {
|
||||||
var lastErr error
|
var lastErr error
|
||||||
|
|
||||||
for attempt := 0; attempt <= maxRetry; attempt++ {
|
for attempt := 0; attempt <= maxRetry; attempt++ {
|
||||||
if attempt > 0 {
|
if attempt > 0 {
|
||||||
g.Log().Infof(ctx, "[执行任务][重试] JSON解析 第%d/%d次 taskId=%s", attempt, maxRetry, task.TaskID)
|
g.Log().Infof(ctx, "[执行任务][重试] JSON解析 第%d/%d次 taskId=%s", attempt, maxRetry, task.TaskID)
|
||||||
}
|
}
|
||||||
|
|
||||||
// 1) 响应映射
|
// 1) 响应映射
|
||||||
mapped, err := util.MapResponsePayload(model.ResponseMapping, body)
|
mapped, err := util.MapResponsePayload(model.ResponseMapping, body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -372,7 +330,7 @@ func (w *asyncWorker) parseAndRetry(ctx context.Context, body map[string]any, ta
|
|||||||
}
|
}
|
||||||
|
|
||||||
var rawResp map[string]any
|
var rawResp map[string]any
|
||||||
if err := json.Unmarshal(rawData, &rawResp); err != nil {
|
if err = json.Unmarshal(rawData, &rawResp); err != nil {
|
||||||
g.Log().Warningf(ctx, "[执行任务][Unmarshal失败] taskId=%s err=%v", task.TaskID, err)
|
g.Log().Warningf(ctx, "[执行任务][Unmarshal失败] taskId=%s err=%v", task.TaskID, err)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
@@ -553,15 +511,11 @@ func InvokeModel(ctx context.Context, model *entity.ModelGatewayModel, body map[
|
|||||||
// return mappedResponse, nil
|
// return mappedResponse, nil
|
||||||
// }
|
// }
|
||||||
|
|
||||||
// failTask 任务失败统一处理:更新数据库 + 释放排队 + 回调
|
// failTask 任务失败统一处理
|
||||||
func (w *asyncWorker) failTask(ctx context.Context, t *entity.ModelGatewayTask, startTime time.Time, errMsg string) {
|
func (w *asyncWorker) failTask(ctx context.Context, t *entity.ModelGatewayTask, startTime time.Time, errMsg string) {
|
||||||
t.State = 3
|
t.State = 3
|
||||||
t.ErrorMsg = errMsg
|
t.ErrorMsg = errMsg
|
||||||
t.DurationSeconds = int64(time.Since(startTime).Seconds())
|
t.DurationSeconds = int64(time.Since(startTime).Seconds())
|
||||||
_, err := dao.ModelGatewayTask.Update(ctx, t)
|
_, _ = dao.ModelGatewayTask.Update(ctx, t) // 更新任务状态
|
||||||
if err != nil {
|
go gateway.TriggerCallback(util.AsyncCtx(ctx), t) // 触发回调
|
||||||
g.Log().Warningf(ctx, "[执行任务][更新数据库失败] taskId=%s err=%v", t.TaskID, err)
|
|
||||||
}
|
|
||||||
queue.ReleaseQueueSlot(ctx, t.ModelName, t.TaskID)
|
|
||||||
go gateway.TriggerCallback(util.AsyncCtx(ctx), t)
|
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user