refactor(files): 优化文件处理和任务服务逻辑

This commit is contained in:
2026-06-18 13:39:40 +08:00
parent b21d7a8dbf
commit ecaaa5bdbc
3 changed files with 94 additions and 194 deletions

View File

@@ -1,7 +1,6 @@
package util package util
import ( import (
"encoding/json"
"fmt" "fmt"
"net/http" "net/http"
"os" "os"
@@ -9,107 +8,75 @@ import (
"strings" "strings"
) )
// DetectFileType 根据返回的二进制内容推断 contentType + 扩展名(尽量稳定) // DetectFileType 根据返回的二进制内容推断 contentType + 扩展名
func DetectFileType(data []byte) (contentType string, ext string) { func DetectFileType(data []byte) (contentType string, ext string) {
if len(data) == 0 { if len(data) == 0 {
return "application/octet-stream", "" return "application/octet-stream", ".bin"
} }
ct := http.DetectContentType(data) ct := http.DetectContentType(data)
// gateway.DetectContentType 可能带 charset 等参数text/plain; charset=utf-8
if idx := strings.Index(ct, ";"); idx > 0 { if idx := strings.Index(ct, ";"); idx > 0 {
ct = strings.TrimSpace(ct[:idx]) ct = strings.TrimSpace(ct[:idx])
} }
switch ct { switch ct {
case "audio/mpeg": case "audio/mpeg":
return ct, ".mp3" return ct, ".mp3"
case "audio/wave", "audio/wav", "audio/x-wav": case "audio/wave", "audio/wav", "audio/x-wav":
return ct, ".wav" return ct, ".wav"
case "audio/mp4", "audio/x-m4a":
return ct, ".m4a"
case "video/mp4": case "video/mp4":
return ct, ".mp4" return ct, ".mp4"
case "video/webm":
return ct, ".webm"
case "image/png": case "image/png":
return ct, ".png" return ct, ".png"
case "image/jpeg": case "image/jpeg":
return ct, ".jpg" return ct, ".jpg"
case "image/gif":
return ct, ".gif"
case "image/webp":
return ct, ".webp"
case "application/pdf": case "application/pdf":
return ct, ".pdf" return ct, ".pdf"
case "text/plain": case "text/plain":
return ct, ".txt" return ct, ".txt"
case "application/json": case "application/json":
return ct, ".json" return ct, ".json"
case "application/zip":
return ct, ".zip"
case "application/octet-stream":
return ct, ".bin"
default: default:
// 兜底:尝试从 ct 截取 subtype 作为后缀(例如 application/json
if parts := strings.Split(ct, "/"); len(parts) == 2 { if parts := strings.Split(ct, "/"); len(parts) == 2 {
sub := parts[1] sub := parts[1]
// 避免出现 "plain; charset=utf-8" 之类的后缀
if idx := strings.Index(sub, ";"); idx > 0 { if idx := strings.Index(sub, ";"); idx > 0 {
sub = strings.TrimSpace(sub[:idx]) sub = strings.TrimSpace(sub[:idx])
} }
return ct, "." + sub return ct, "." + sub
} }
return ct, "" return ct, ".bin"
} }
} }
// SaveTmpResult 将模型输出写入临时文件,用于 OSS 上传失败后的“仅重试 OSS”。 // SaveTmpResult 将二进制数据写入临时文件
func SaveTmpResult(taskID string, data []byte, ext string) (string, error) { func SaveTmpResult(taskID string, data []byte, ext string) (string, error) {
dir := filepath.Join(os.TempDir(), "model-asynch") dir := filepath.Join(os.TempDir(), "model-asynch")
if err := os.MkdirAll(dir, 0o755); err != nil { if err := os.MkdirAll(dir, 0o755); err != nil {
return "", err return "", fmt.Errorf("创建临时目录失败: %w", err)
} }
if ext == "" { if ext == "" {
ext = ".bin" ext = ".bin"
} }
if ext[0] != '.' { if ext[0] != '.' {
ext = "." + ext ext = "." + ext
} }
path := filepath.Join(dir, fmt.Sprintf("%s%s", taskID, ext)) path := filepath.Join(dir, fmt.Sprintf("%s%s", taskID, ext))
if err := os.WriteFile(path, data, 0o644); err != nil { if err := os.WriteFile(path, data, 0o644); err != nil {
return "", err return "", fmt.Errorf("写入临时文件失败: %w", err)
} }
return path, nil return path, nil
} }
// SaveTempFileByType
// 根据传入的数据自动判断:
// 若是 []byte 且后缀为 .mp3 → 保存二进制音频
// 若是任意结构体/map → 自动转 JSON 保存
// 返回:新临时文件路径、错误
func SaveTempFileByType(taskID string, data any, oldTmpFile string) (string, error) {
// 1. 先清理旧临时文件(统一逻辑)
if oldTmpFile != "" {
_ = os.Remove(oldTmpFile)
}
var tmpPath string
var tmpErr error
// 2. 判断是否是二进制音频([]byte + .mp3
if audioData, ok := data.([]byte); ok {
tmpPath, tmpErr = saveTmpResult(taskID, audioData, ".mp3")
} else {
// 3. 其他类型 → 序列化为 JSON 保存
mappedBytes, err := json.Marshal(data)
if err != nil {
return "", err
}
if len(mappedBytes) == 0 {
return "", nil
}
tmpPath, tmpErr = saveTmpResult(taskID, mappedBytes, ".json")
}
if tmpErr != nil || tmpPath == "" {
return "", tmpErr
}
return tmpPath, nil
}
// saveTmpResult 你原有的底层保存文件方法(保留不动)
func saveTmpResult(taskID string, data []byte, ext string) (string, error) {
// 你原来实现,比如:
filename := taskID + ext
tmpPath := filepath.Join(os.TempDir(), filename)
err := os.WriteFile(tmpPath, data, 0644)
return tmpPath, err
}

View File

@@ -6,7 +6,6 @@ import (
"fmt" "fmt"
"model-gateway/common/util" "model-gateway/common/util"
"model-gateway/consts/public" "model-gateway/consts/public"
"model-gateway/service/queue"
"time" "time"
"model-gateway/dao" "model-gateway/dao"
@@ -28,12 +27,15 @@ type taskService struct{}
// Create 创建任务 // Create 创建任务
func (s *taskService) Create(ctx context.Context, req *dto.CreateTaskReq) (res *dto.CreateTaskRes, err error) { func (s *taskService) Create(ctx context.Context, req *dto.CreateTaskReq) (res *dto.CreateTaskRes, err error) {
taskID := uuid.NewString() taskID := uuid.NewString()
startAt := time.Now()
// 1) 检查模型配置,并且获取模型 // 1) 获取用户信息
userInfo, err := utils.GetUserInfo(ctx) userInfo, err := utils.GetUserInfo(ctx)
if err != nil { if err != nil {
return nil, err return nil, err
} }
// 2) 检查模型配置
model, err := dao.ModelGatewayModels.Get(ctx, &entity.ModelGatewayModel{ model, err := dao.ModelGatewayModels.Get(ctx, &entity.ModelGatewayModel{
SQLBaseDO: beans.SQLBaseDO{ SQLBaseDO: beans.SQLBaseDO{
TenantId: userInfo.TenantId, TenantId: userInfo.TenantId,
@@ -48,86 +50,63 @@ func (s *taskService) Create(ctx context.Context, req *dto.CreateTaskReq) (res *
return nil, errors.New("模型不存在或未启用") return nil, errors.New("模型不存在或未启用")
} }
// 2) 排队上限严格控制Redis 原子闸门) // TODO: 排队控制暂时关闭,后续需要时取消注释
limit := queue.GetRuntimeQueueLimit(ctx, req.ModelName, model.MaxConcurrency*2) // limit := queue.GetRuntimeQueueLimit(ctx, req.ModelName, model.MaxConcurrency*2)
if limit > 0 { // if limit > 0 {
ok, err := queue.AcquireQueueSlot(ctx, req.ModelName, taskID, limit, model.TimeoutSeconds) // ok, err := queue.AcquireQueueSlot(ctx, req.ModelName, taskID, limit, model.TimeoutSeconds)
if err != nil { // if err != nil {
return nil, err // return nil, err
} // }
if !ok { // if !ok {
return nil, errors.New("任务排队已满,请稍后再试") // return nil, errors.New("任务排队已满,请稍后再试")
} // }
// }
// 3) 构建任务实体
task := &entity.ModelGatewayTask{
ModelName: model.ModelName,
TaskID: taskID,
State: public.TaskStatusRunning,
BizName: req.BizName,
CallbackURL: req.CallbackUrl,
RequestPayload: &entity.RequestPayload{
Body: req.RequestPayload,
Headers: util.ParseHeadMsgHeaders(model.HeadMsg),
},
EpicycleId: req.EpicycleId,
} }
// 3) 插入任务记录 // 4) 插入任务记录
requestPayload := entity.RequestPayload{
Body: req.RequestPayload,
Headers: util.ParseHeadMsgHeaders(model.HeadMsg),
}
task := new(entity.ModelGatewayTask)
task.ModelName = model.ModelName
task.TaskID = taskID
task.State = public.TaskStatusRunning
task.BizName = req.BizName
task.CallbackURL = req.CallbackUrl
task.RequestPayload = &requestPayload
task.EpicycleId = req.EpicycleId
id, err := dao.ModelGatewayTask.Insert(ctx, task) id, err := dao.ModelGatewayTask.Insert(ctx, task)
if err != nil { // 入库失败:回滚闸门占位 if err != nil {
queue.ReleaseQueueSlot(ctx, req.ModelName, taskID) // TODO: 恢复排队逻辑后,此处需要回滚排队占位
// queue.ReleaseQueueSlot(ctx, req.ModelName, taskID)
return nil, err return nil, err
} }
task.Id = id task.Id = id
// 4) 写操作日志(不影响主流程,失败忽略)
ip := "" // 5) 记录操作日志(非关键路径,失败不影响主流程)
ua := "" ip, ua := "", ""
apiPath := "/task/createTask"
httpMethod := "POST"
if r := g.RequestFromCtx(ctx); r != nil { if r := g.RequestFromCtx(ctx); r != nil {
ip = utils.GetLocalIP() ip = utils.GetLocalIP()
ua = r.UserAgent() ua = r.UserAgent()
apiPath = r.URL.Path
httpMethod = r.Method
} }
_, _ = dao.ModelGatewayLogsOp.Insert(ctx, &entity.ModelGatewayLogsOp{ _, _ = dao.ModelGatewayLogsOp.Insert(ctx, &entity.ModelGatewayLogsOp{
IP: ip, IP: ip,
UserAgent: ua, UserAgent: ua,
APIPath: apiPath, APIPath: "/task/createTask",
HttpMethod: httpMethod, HttpMethod: "POST",
BizName: req.BizName, BizName: req.BizName,
ModelName: req.ModelName, ModelName: req.ModelName,
TaskID: taskID, TaskID: taskID,
OpType: "createTask", OpType: "createTask",
Success: 1, Success: 1,
CostMs: time.Since(time.Now()).Milliseconds(), CostMs: time.Since(startAt).Milliseconds(),
RequestPayload: &requestPayload, RequestPayload: task.RequestPayload,
ResponsePayload: gdb.Map{ ResponsePayload: gdb.Map{"taskId": taskID},
"taskId": taskID,
},
}) })
//// 5) 抢占任务:改为执行中 // 6) 异步执行任务
//rows, err := dao.ModelGatewayTask.Update(ctx, &entity.ModelGatewayTask{
// SQLBaseDO: beans.SQLBaseDO{Id: id},
// State: public.TaskStatusRunning,
//})
//if err != nil {
// return nil, err
//}
//if rows == 0 {
// return nil, fmt.Errorf("任务不存在: id=%d", id)
//}
// 6) 查询任务信息
//task, err := dao.ModelGatewayTask.Get(ctx, &entity.ModelGatewayTask{
// SQLBaseDO: beans.SQLBaseDO{Id: id},
//})
//if err != nil {
// return nil, err
//}
// 7) 创建成功后立即异步尝试执行当前任务
go AsyncWorker.handleOne(util.AsyncCtx(ctx), task, model, req) go AsyncWorker.handleOne(util.AsyncCtx(ctx), task, model, req)
return &dto.CreateTaskRes{TaskID: taskID}, nil return &dto.CreateTaskRes{TaskID: taskID}, nil

View File

@@ -19,7 +19,6 @@ import (
"model-gateway/model/dto" "model-gateway/model/dto"
"model-gateway/model/entity" "model-gateway/model/entity"
"model-gateway/service/gateway" "model-gateway/service/gateway"
"model-gateway/service/queue"
"gitea.redpowerfuture.com/red-future/common/beans" "gitea.redpowerfuture.com/red-future/common/beans"
"github.com/gogf/gf/v2/encoding/gjson" "github.com/gogf/gf/v2/encoding/gjson"
@@ -38,50 +37,28 @@ func (w *asyncWorker) handleOne(ctx context.Context, task *entity.ModelGatewayTa
body = task.RequestPayload.Body body = task.RequestPayload.Body
maxRetry = model.RetryTimes maxRetry = model.RetryTimes
startTime = time.Now() startTime = time.Now()
rawBytes []byte
result map[string]any result map[string]any
err error err error
) )
g.Log().Infof(ctx, "[执行任务][开始] taskId=%s model=%s", task.TaskID, task.ModelName)
g.Log().Infof(ctx, "[handleOne] 开始 taskId=%s model=%s", task.TaskID, task.ModelName)
// ============================================ // ============================================
// 1) 分布式并发控制 // 1) 调用模型
// ============================================
semKey := fmt.Sprintf("asynch:sem:%s", task.ModelName)
maxC := queue.GetRuntimeMaxConcurrency(ctx, task.ModelName, model.MaxConcurrency)
acquired, err := queue.AcquireSemaphore(ctx, semKey, maxC, 3600)
if err != nil {
task.DurationSeconds = int64(time.Since(startTime).Seconds())
w.failTask(ctx, task, startTime, err.Error())
return
}
if !acquired {
_, _ = dao.ModelGatewayTask.Update(ctx, &entity.ModelGatewayTask{
SQLBaseDO: beans.SQLBaseDO{Id: task.Id},
State: public.TaskStatusPending,
})
g.Log().Infof(ctx, "[执行任务][排队] 并发已满,放回队列 taskId=%s", task.TaskID)
return
}
defer func() { _ = queue.ReleaseSemaphore(ctx, semKey) }()
// ============================================
// 2) 调用模型
// ============================================ // ============================================
for attempt := 0; ; attempt++ { for attempt := 0; ; attempt++ {
if attempt > 0 { if attempt > 0 {
g.Log().Infof(ctx, "[执行任务][重试] 调模型 第%d次 taskId=%s", attempt, task.TaskID) g.Log().Infof(ctx, "[handleOne] 调模型重试 第%d次 taskId=%s", attempt, task.TaskID)
time.Sleep(time.Duration(attempt) * time.Second) time.Sleep(time.Duration(attempt) * time.Second)
} }
switch { switch {
case model.CallMode != nil && *model.CallMode == public.CallModeStream: case model.CallMode != nil && *model.CallMode == public.CallModeStream:
rawBytes, streamErr := w.callModelStream(ctx, task, model, body) rawBytes, err = w.callModelStream(ctx, task, model, body)
if streamErr != nil { if err == nil {
err = streamErr result, err = util.ParseStreamResponse(rawBytes, model.StreamConfig)
g.Log().Warningf(ctx, "[执行任务][调用失败] taskId=%s attempt=%d err=%v", task.TaskID, attempt, err)
continue
} }
result, err = util.ParseStreamResponse(rawBytes, model.StreamConfig)
case model.CallMode != nil && *model.CallMode == public.CallModeAsync: case model.CallMode != nil && *model.CallMode == public.CallModeAsync:
result, err = w.callModel(ctx, task, model, body) result, err = w.callModel(ctx, task, model, body)
if err == nil { if err == nil {
@@ -95,24 +72,17 @@ func (w *asyncWorker) handleOne(ctx context.Context, task *entity.ModelGatewayTa
break break
} }
if !strings.Contains(err.Error(), "Timeout") { if !strings.Contains(err.Error(), "Timeout") &&
!strings.Contains(err.Error(), "InternalServiceError") {
w.failTask(ctx, task, startTime, err.Error()) w.failTask(ctx, task, startTime, err.Error())
return return
} }
g.Log().Warningf(ctx, "[执行任务][调用失败] taskId=%s attempt=%d err=%v", task.TaskID, attempt, err)
g.Log().Warningf(ctx, "[handleOne] 调模型失败 taskId=%s attempt=%d err=%v", task.TaskID, attempt, err)
} }
// ============================================ // ============================================
// 3) 缓存临时文件 // 2) 解析校验 + 响应映射(可重试)
// ============================================
if tmpPath, tmpErr := util.SaveTempFileByType(task.TaskID, result, task.TmpFile); tmpErr == nil && tmpPath != "" {
task.TmpFile = tmpPath
task.Phase = 1
_, _ = dao.ModelGatewayTask.Update(ctx, task)
}
// ============================================
// 4) 解析校验 + 响应映射(可重试)
// ============================================ // ============================================
result, err = w.parseAndRetry(ctx, result, task, model, req, maxRetry, startTime) result, err = w.parseAndRetry(ctx, result, task, model, req, maxRetry, startTime)
if err != nil { if err != nil {
@@ -122,34 +92,26 @@ func (w *asyncWorker) handleOne(ctx context.Context, task *entity.ModelGatewayTa
} }
// ============================================ // ============================================
// 5) 上传 OSS可重试 // 3) 上传 OSS可重试
// ============================================ // ============================================
var oss *gateway.UploadFileResponse var oss *gateway.UploadFileResponse
for attempt := 0; attempt <= maxRetry; attempt++ { for attempt := 0; attempt <= maxRetry; attempt++ {
if attempt > 0 { if attempt > 0 {
g.Log().Infof(ctx, "[执行任务][重试] OSS上传 第%d/%d次 taskId=%s", attempt, maxRetry, task.TaskID) g.Log().Infof(ctx, "[handleOne] OSS上传重试 第%d/%d次 taskId=%s", attempt, maxRetry, task.TaskID)
} }
startUpload := time.Now()
oss, err = gateway.UploadByTask(ctx, gjson.New(result).MustToJson(), "json") oss, err = gateway.UploadByTask(ctx, gjson.New(result).MustToJson(), "json")
if err == nil { if err == nil {
break break
} }
cost := time.Since(startUpload) g.Log().Errorf(ctx, "[handleOne] OSS上传失败 taskId=%s attempt=%d/%d err=%v", task.TaskID, attempt, maxRetry, err)
g.Log().Infof(ctx, "本次上传耗时:%s", cost)
g.Log().Errorf(ctx, "[执行任务][失败] OSS上传失败 taskId=%s attempt=%d/%d err=%v", task.TaskID, attempt, maxRetry, err)
if attempt == maxRetry { if attempt == maxRetry {
task.State = public.TaskStatusFailed
task.ErrorMsg = err.Error()
task.Phase = 1
_, _ = dao.ModelGatewayTask.Update(ctx, task)
w.failTask(ctx, task, startTime, fmt.Sprintf("OSS上传重试耗尽: %v", err)) w.failTask(ctx, task, startTime, fmt.Sprintf("OSS上传重试耗尽: %v", err))
return return
} }
} }
// ============================================ // ============================================
// 6) 成功收尾 // 4) 成功收尾
// ============================================ // ============================================
task.State = public.TaskStatusSuccess task.State = public.TaskStatusSuccess
task.DurationSeconds = int64(time.Since(startTime).Seconds()) task.DurationSeconds = int64(time.Since(startTime).Seconds())
@@ -159,21 +121,19 @@ func (w *asyncWorker) handleOne(ctx context.Context, task *entity.ModelGatewayTa
FileSize: int64(oss.FileSize), FileSize: int64(oss.FileSize),
} }
task.TextResult = result task.TextResult = result
if _, err = dao.ModelGatewayTask.Update(ctx, task); err != nil { if _, err = dao.ModelGatewayTask.Update(ctx, task); err != nil {
g.Log().Errorf(ctx, "[执行任务][失败] 更新数据库失败 taskId=%s err=%v", task.TaskID, err) g.Log().Errorf(ctx, "[handleOne] 更新DB失败 taskId=%s err=%v", task.TaskID, err)
return return
} }
queue.ReleaseQueueSlot(ctx, task.ModelName, task.TaskID)
go gateway.TriggerCallback(util.AsyncCtx(ctx), task) go gateway.TriggerCallback(util.AsyncCtx(ctx), task)
if req.EpicycleId != 0 { if req.EpicycleId != 0 {
go gateway.TriggerPromptsCallback(util.AsyncCtx(ctx), task, req.EpicycleId) go gateway.TriggerPromptsCallback(util.AsyncCtx(ctx), task, req.EpicycleId)
} }
g.Log().Infof(ctx, "[执行任务][成功] taskId=%s duration=%ds fileType=%s", g.Log().Infof(ctx, "[handleOne] 成功 taskId=%s duration=%ds fileType=%s",
task.TaskID, task.DurationSeconds, oss.FileFormat) task.TaskID, task.DurationSeconds, oss.FileFormat)
_ = os.Remove(task.TmpFile)
} }
// callModelStream 调用模型,返回原始字节(不做响应映射,用于流式输出) // callModelStream 调用模型,返回原始字节(不做响应映射,用于流式输出)
@@ -313,12 +273,10 @@ func (w *asyncWorker) callModel(ctx context.Context, task *entity.ModelGatewayTa
// parseAndRetry 解析模型返回结果,并重试 // parseAndRetry 解析模型返回结果,并重试
func (w *asyncWorker) parseAndRetry(ctx context.Context, body map[string]any, task *entity.ModelGatewayTask, model *entity.ModelGatewayModel, req *dto.CreateTaskReq, maxRetry int, startTime time.Time) (map[string]any, error) { func (w *asyncWorker) parseAndRetry(ctx context.Context, body map[string]any, task *entity.ModelGatewayTask, model *entity.ModelGatewayModel, req *dto.CreateTaskReq, maxRetry int, startTime time.Time) (map[string]any, error) {
var lastErr error var lastErr error
for attempt := 0; attempt <= maxRetry; attempt++ { for attempt := 0; attempt <= maxRetry; attempt++ {
if attempt > 0 { if attempt > 0 {
g.Log().Infof(ctx, "[执行任务][重试] JSON解析 第%d/%d次 taskId=%s", attempt, maxRetry, task.TaskID) g.Log().Infof(ctx, "[执行任务][重试] JSON解析 第%d/%d次 taskId=%s", attempt, maxRetry, task.TaskID)
} }
// 1) 响应映射 // 1) 响应映射
mapped, err := util.MapResponsePayload(model.ResponseMapping, body) mapped, err := util.MapResponsePayload(model.ResponseMapping, body)
if err != nil { if err != nil {
@@ -372,7 +330,7 @@ func (w *asyncWorker) parseAndRetry(ctx context.Context, body map[string]any, ta
} }
var rawResp map[string]any var rawResp map[string]any
if err := json.Unmarshal(rawData, &rawResp); err != nil { if err = json.Unmarshal(rawData, &rawResp); err != nil {
g.Log().Warningf(ctx, "[执行任务][Unmarshal失败] taskId=%s err=%v", task.TaskID, err) g.Log().Warningf(ctx, "[执行任务][Unmarshal失败] taskId=%s err=%v", task.TaskID, err)
continue continue
} }
@@ -553,15 +511,11 @@ func InvokeModel(ctx context.Context, model *entity.ModelGatewayModel, body map[
// return mappedResponse, nil // return mappedResponse, nil
// } // }
// failTask 任务失败统一处理:更新数据库 + 释放排队 + 回调 // failTask 任务失败统一处理
func (w *asyncWorker) failTask(ctx context.Context, t *entity.ModelGatewayTask, startTime time.Time, errMsg string) { func (w *asyncWorker) failTask(ctx context.Context, t *entity.ModelGatewayTask, startTime time.Time, errMsg string) {
t.State = 3 t.State = 3
t.ErrorMsg = errMsg t.ErrorMsg = errMsg
t.DurationSeconds = int64(time.Since(startTime).Seconds()) t.DurationSeconds = int64(time.Since(startTime).Seconds())
_, err := dao.ModelGatewayTask.Update(ctx, t) _, _ = dao.ModelGatewayTask.Update(ctx, t) // 更新任务状态
if err != nil { go gateway.TriggerCallback(util.AsyncCtx(ctx), t) // 触发回调
g.Log().Warningf(ctx, "[执行任务][更新数据库失败] taskId=%s err=%v", t.TaskID, err)
}
queue.ReleaseQueueSlot(ctx, t.ModelName, t.TaskID)
go gateway.TriggerCallback(util.AsyncCtx(ctx), t)
} }