diff --git a/common/util/files.go b/common/util/files.go index 45e28fc..a5f57c4 100644 --- a/common/util/files.go +++ b/common/util/files.go @@ -1,7 +1,6 @@ package util import ( - "encoding/json" "fmt" "net/http" "os" @@ -9,107 +8,75 @@ import ( "strings" ) -// DetectFileType 根据返回的二进制内容推断 contentType + 扩展名(尽量稳定) +// DetectFileType 根据返回的二进制内容推断 contentType + 扩展名 func DetectFileType(data []byte) (contentType string, ext string) { if len(data) == 0 { - return "application/octet-stream", "" + return "application/octet-stream", ".bin" } + ct := http.DetectContentType(data) - // gateway.DetectContentType 可能带 charset 等参数:text/plain; charset=utf-8 if idx := strings.Index(ct, ";"); idx > 0 { ct = strings.TrimSpace(ct[:idx]) } + switch ct { case "audio/mpeg": return ct, ".mp3" case "audio/wave", "audio/wav", "audio/x-wav": return ct, ".wav" + case "audio/mp4", "audio/x-m4a": + return ct, ".m4a" case "video/mp4": return ct, ".mp4" + case "video/webm": + return ct, ".webm" case "image/png": return ct, ".png" case "image/jpeg": return ct, ".jpg" + case "image/gif": + return ct, ".gif" + case "image/webp": + return ct, ".webp" case "application/pdf": return ct, ".pdf" case "text/plain": return ct, ".txt" case "application/json": return ct, ".json" + case "application/zip": + return ct, ".zip" + case "application/octet-stream": + return ct, ".bin" default: - // 兜底:尝试从 ct 截取 subtype 作为后缀(例如 application/json) if parts := strings.Split(ct, "/"); len(parts) == 2 { sub := parts[1] - // 避免出现 "plain; charset=utf-8" 之类的后缀 if idx := strings.Index(sub, ";"); idx > 0 { sub = strings.TrimSpace(sub[:idx]) } return ct, "." + sub } - return ct, "" + return ct, ".bin" } } -// SaveTmpResult 将模型输出写入临时文件,用于 OSS 上传失败后的“仅重试 OSS”。 +// SaveTmpResult 将二进制数据写入临时文件 func SaveTmpResult(taskID string, data []byte, ext string) (string, error) { dir := filepath.Join(os.TempDir(), "model-asynch") if err := os.MkdirAll(dir, 0o755); err != nil { - return "", err + return "", fmt.Errorf("创建临时目录失败: %w", err) } + if ext == "" { ext = ".bin" } if ext[0] != '.' { ext = "." + ext } + path := filepath.Join(dir, fmt.Sprintf("%s%s", taskID, ext)) if err := os.WriteFile(path, data, 0o644); err != nil { - return "", err + return "", fmt.Errorf("写入临时文件失败: %w", err) } return path, nil } - -// SaveTempFileByType -// 根据传入的数据自动判断: -// 若是 []byte 且后缀为 .mp3 → 保存二进制音频 -// 若是任意结构体/map → 自动转 JSON 保存 -// 返回:新临时文件路径、错误 -func SaveTempFileByType(taskID string, data any, oldTmpFile string) (string, error) { - // 1. 先清理旧临时文件(统一逻辑) - if oldTmpFile != "" { - _ = os.Remove(oldTmpFile) - } - - var tmpPath string - var tmpErr error - - // 2. 判断是否是二进制音频([]byte + .mp3) - if audioData, ok := data.([]byte); ok { - tmpPath, tmpErr = saveTmpResult(taskID, audioData, ".mp3") - } else { - // 3. 其他类型 → 序列化为 JSON 保存 - mappedBytes, err := json.Marshal(data) - if err != nil { - return "", err - } - if len(mappedBytes) == 0 { - return "", nil - } - tmpPath, tmpErr = saveTmpResult(taskID, mappedBytes, ".json") - } - - if tmpErr != nil || tmpPath == "" { - return "", tmpErr - } - - return tmpPath, nil -} - -// saveTmpResult 你原有的底层保存文件方法(保留不动) -func saveTmpResult(taskID string, data []byte, ext string) (string, error) { - // 你原来实现,比如: - filename := taskID + ext - tmpPath := filepath.Join(os.TempDir(), filename) - err := os.WriteFile(tmpPath, data, 0644) - return tmpPath, err -} diff --git a/service/task/task_service.go b/service/task/task_service.go index b1d39f5..f8905b0 100644 --- a/service/task/task_service.go +++ b/service/task/task_service.go @@ -6,7 +6,6 @@ import ( "fmt" "model-gateway/common/util" "model-gateway/consts/public" - "model-gateway/service/queue" "time" "model-gateway/dao" @@ -28,12 +27,15 @@ type taskService struct{} // Create 创建任务 func (s *taskService) Create(ctx context.Context, req *dto.CreateTaskReq) (res *dto.CreateTaskRes, err error) { taskID := uuid.NewString() + startAt := time.Now() - // 1) 检查模型配置,并且获取模型 + // 1) 获取用户信息 userInfo, err := utils.GetUserInfo(ctx) if err != nil { return nil, err } + + // 2) 检查模型配置 model, err := dao.ModelGatewayModels.Get(ctx, &entity.ModelGatewayModel{ SQLBaseDO: beans.SQLBaseDO{ TenantId: userInfo.TenantId, @@ -48,86 +50,63 @@ func (s *taskService) Create(ctx context.Context, req *dto.CreateTaskReq) (res * return nil, errors.New("模型不存在或未启用") } - // 2) 排队上限(严格控制:Redis 原子闸门) - limit := queue.GetRuntimeQueueLimit(ctx, req.ModelName, model.MaxConcurrency*2) - if limit > 0 { - ok, err := queue.AcquireQueueSlot(ctx, req.ModelName, taskID, limit, model.TimeoutSeconds) - if err != nil { - return nil, err - } - if !ok { - return nil, errors.New("任务排队已满,请稍后再试") - } + // TODO: 排队控制暂时关闭,后续需要时取消注释 + // limit := queue.GetRuntimeQueueLimit(ctx, req.ModelName, model.MaxConcurrency*2) + // if limit > 0 { + // ok, err := queue.AcquireQueueSlot(ctx, req.ModelName, taskID, limit, model.TimeoutSeconds) + // if err != nil { + // return nil, err + // } + // if !ok { + // return nil, errors.New("任务排队已满,请稍后再试") + // } + // } + + // 3) 构建任务实体 + task := &entity.ModelGatewayTask{ + ModelName: model.ModelName, + TaskID: taskID, + State: public.TaskStatusRunning, + BizName: req.BizName, + CallbackURL: req.CallbackUrl, + RequestPayload: &entity.RequestPayload{ + Body: req.RequestPayload, + Headers: util.ParseHeadMsgHeaders(model.HeadMsg), + }, + EpicycleId: req.EpicycleId, } - // 3) 插入任务记录 - requestPayload := entity.RequestPayload{ - Body: req.RequestPayload, - Headers: util.ParseHeadMsgHeaders(model.HeadMsg), - } - task := new(entity.ModelGatewayTask) - task.ModelName = model.ModelName - task.TaskID = taskID - task.State = public.TaskStatusRunning - task.BizName = req.BizName - task.CallbackURL = req.CallbackUrl - task.RequestPayload = &requestPayload - task.EpicycleId = req.EpicycleId + // 4) 插入任务记录 id, err := dao.ModelGatewayTask.Insert(ctx, task) - if err != nil { // 入库失败:回滚闸门占位 - queue.ReleaseQueueSlot(ctx, req.ModelName, taskID) + if err != nil { + // TODO: 恢复排队逻辑后,此处需要回滚排队占位 + // queue.ReleaseQueueSlot(ctx, req.ModelName, taskID) return nil, err } task.Id = id - // 4) 写操作日志(不影响主流程,失败忽略) - ip := "" - ua := "" - apiPath := "/task/createTask" - httpMethod := "POST" + + // 5) 记录操作日志(非关键路径,失败不影响主流程) + ip, ua := "", "" if r := g.RequestFromCtx(ctx); r != nil { ip = utils.GetLocalIP() ua = r.UserAgent() - apiPath = r.URL.Path - httpMethod = r.Method } _, _ = dao.ModelGatewayLogsOp.Insert(ctx, &entity.ModelGatewayLogsOp{ - IP: ip, - UserAgent: ua, - APIPath: apiPath, - HttpMethod: httpMethod, - BizName: req.BizName, - ModelName: req.ModelName, - TaskID: taskID, - OpType: "createTask", - Success: 1, - CostMs: time.Since(time.Now()).Milliseconds(), - RequestPayload: &requestPayload, - ResponsePayload: gdb.Map{ - "taskId": taskID, - }, + IP: ip, + UserAgent: ua, + APIPath: "/task/createTask", + HttpMethod: "POST", + BizName: req.BizName, + ModelName: req.ModelName, + TaskID: taskID, + OpType: "createTask", + Success: 1, + CostMs: time.Since(startAt).Milliseconds(), + RequestPayload: task.RequestPayload, + ResponsePayload: gdb.Map{"taskId": taskID}, }) - //// 5) 抢占任务:改为执行中 - //rows, err := dao.ModelGatewayTask.Update(ctx, &entity.ModelGatewayTask{ - // SQLBaseDO: beans.SQLBaseDO{Id: id}, - // State: public.TaskStatusRunning, - //}) - //if err != nil { - // return nil, err - //} - //if rows == 0 { - // return nil, fmt.Errorf("任务不存在: id=%d", id) - //} - - // 6) 查询任务信息 - //task, err := dao.ModelGatewayTask.Get(ctx, &entity.ModelGatewayTask{ - // SQLBaseDO: beans.SQLBaseDO{Id: id}, - //}) - //if err != nil { - // return nil, err - //} - - // 7) 创建成功后立即异步尝试执行当前任务 + // 6) 异步执行任务 go AsyncWorker.handleOne(util.AsyncCtx(ctx), task, model, req) return &dto.CreateTaskRes{TaskID: taskID}, nil diff --git a/service/task/worker.go b/service/task/worker.go index b74b91e..b18f399 100644 --- a/service/task/worker.go +++ b/service/task/worker.go @@ -19,7 +19,6 @@ import ( "model-gateway/model/dto" "model-gateway/model/entity" "model-gateway/service/gateway" - "model-gateway/service/queue" "gitea.redpowerfuture.com/red-future/common/beans" "github.com/gogf/gf/v2/encoding/gjson" @@ -38,50 +37,28 @@ func (w *asyncWorker) handleOne(ctx context.Context, task *entity.ModelGatewayTa body = task.RequestPayload.Body maxRetry = model.RetryTimes startTime = time.Now() + rawBytes []byte result map[string]any err error ) - g.Log().Infof(ctx, "[执行任务][开始] taskId=%s model=%s", task.TaskID, task.ModelName) + + g.Log().Infof(ctx, "[handleOne] 开始 taskId=%s model=%s", task.TaskID, task.ModelName) // ============================================ - // 1) 分布式并发控制 - // ============================================ - semKey := fmt.Sprintf("asynch:sem:%s", task.ModelName) - maxC := queue.GetRuntimeMaxConcurrency(ctx, task.ModelName, model.MaxConcurrency) - acquired, err := queue.AcquireSemaphore(ctx, semKey, maxC, 3600) - if err != nil { - task.DurationSeconds = int64(time.Since(startTime).Seconds()) - w.failTask(ctx, task, startTime, err.Error()) - return - } - if !acquired { - _, _ = dao.ModelGatewayTask.Update(ctx, &entity.ModelGatewayTask{ - SQLBaseDO: beans.SQLBaseDO{Id: task.Id}, - State: public.TaskStatusPending, - }) - g.Log().Infof(ctx, "[执行任务][排队] 并发已满,放回队列 taskId=%s", task.TaskID) - return - } - defer func() { _ = queue.ReleaseSemaphore(ctx, semKey) }() - - // ============================================ - // 2) 调用模型 + // 1) 调用模型 // ============================================ for attempt := 0; ; attempt++ { if attempt > 0 { - g.Log().Infof(ctx, "[执行任务][重试] 调用模型 第%d次 taskId=%s", attempt, task.TaskID) + g.Log().Infof(ctx, "[handleOne] 调模型重试 第%d次 taskId=%s", attempt, task.TaskID) time.Sleep(time.Duration(attempt) * time.Second) } switch { case model.CallMode != nil && *model.CallMode == public.CallModeStream: - rawBytes, streamErr := w.callModelStream(ctx, task, model, body) - if streamErr != nil { - err = streamErr - g.Log().Warningf(ctx, "[执行任务][调用失败] taskId=%s attempt=%d err=%v", task.TaskID, attempt, err) - continue + rawBytes, err = w.callModelStream(ctx, task, model, body) + if err == nil { + result, err = util.ParseStreamResponse(rawBytes, model.StreamConfig) } - result, err = util.ParseStreamResponse(rawBytes, model.StreamConfig) case model.CallMode != nil && *model.CallMode == public.CallModeAsync: result, err = w.callModel(ctx, task, model, body) if err == nil { @@ -95,24 +72,17 @@ func (w *asyncWorker) handleOne(ctx context.Context, task *entity.ModelGatewayTa break } - if !strings.Contains(err.Error(), "Timeout") { + if !strings.Contains(err.Error(), "Timeout") && + !strings.Contains(err.Error(), "InternalServiceError") { w.failTask(ctx, task, startTime, err.Error()) return } - g.Log().Warningf(ctx, "[执行任务][调用失败] taskId=%s attempt=%d err=%v", task.TaskID, attempt, err) + + g.Log().Warningf(ctx, "[handleOne] 调模型失败 taskId=%s attempt=%d err=%v", task.TaskID, attempt, err) } // ============================================ - // 3) 缓存临时文件 - // ============================================ - if tmpPath, tmpErr := util.SaveTempFileByType(task.TaskID, result, task.TmpFile); tmpErr == nil && tmpPath != "" { - task.TmpFile = tmpPath - task.Phase = 1 - _, _ = dao.ModelGatewayTask.Update(ctx, task) - } - - // ============================================ - // 4) 解析校验 + 响应映射(可重试) + // 2) 解析校验 + 响应映射(可重试) // ============================================ result, err = w.parseAndRetry(ctx, result, task, model, req, maxRetry, startTime) if err != nil { @@ -122,34 +92,26 @@ func (w *asyncWorker) handleOne(ctx context.Context, task *entity.ModelGatewayTa } // ============================================ - // 5) 上传 OSS(可重试) + // 3) 上传 OSS(可重试) // ============================================ var oss *gateway.UploadFileResponse for attempt := 0; attempt <= maxRetry; attempt++ { if attempt > 0 { - g.Log().Infof(ctx, "[执行任务][重试] OSS上传 第%d/%d次 taskId=%s", attempt, maxRetry, task.TaskID) + g.Log().Infof(ctx, "[handleOne] OSS上传重试 第%d/%d次 taskId=%s", attempt, maxRetry, task.TaskID) } - startUpload := time.Now() oss, err = gateway.UploadByTask(ctx, gjson.New(result).MustToJson(), "json") if err == nil { break } - cost := time.Since(startUpload) - g.Log().Infof(ctx, "本次上传耗时:%s", cost) - - g.Log().Errorf(ctx, "[执行任务][失败] OSS上传失败 taskId=%s attempt=%d/%d err=%v", task.TaskID, attempt, maxRetry, err) + g.Log().Errorf(ctx, "[handleOne] OSS上传失败 taskId=%s attempt=%d/%d err=%v", task.TaskID, attempt, maxRetry, err) if attempt == maxRetry { - task.State = public.TaskStatusFailed - task.ErrorMsg = err.Error() - task.Phase = 1 - _, _ = dao.ModelGatewayTask.Update(ctx, task) w.failTask(ctx, task, startTime, fmt.Sprintf("OSS上传重试耗尽: %v", err)) return } } // ============================================ - // 6) 成功收尾 + // 4) 成功收尾 // ============================================ task.State = public.TaskStatusSuccess task.DurationSeconds = int64(time.Since(startTime).Seconds()) @@ -159,21 +121,19 @@ func (w *asyncWorker) handleOne(ctx context.Context, task *entity.ModelGatewayTa FileSize: int64(oss.FileSize), } task.TextResult = result + if _, err = dao.ModelGatewayTask.Update(ctx, task); err != nil { - g.Log().Errorf(ctx, "[执行任务][失败] 更新数据库失败 taskId=%s err=%v", task.TaskID, err) + g.Log().Errorf(ctx, "[handleOne] 更新DB失败 taskId=%s err=%v", task.TaskID, err) return } - queue.ReleaseQueueSlot(ctx, task.ModelName, task.TaskID) go gateway.TriggerCallback(util.AsyncCtx(ctx), task) if req.EpicycleId != 0 { go gateway.TriggerPromptsCallback(util.AsyncCtx(ctx), task, req.EpicycleId) } - g.Log().Infof(ctx, "[执行任务][成功] taskId=%s duration=%ds fileType=%s", + g.Log().Infof(ctx, "[handleOne] 成功 taskId=%s duration=%ds fileType=%s", task.TaskID, task.DurationSeconds, oss.FileFormat) - - _ = os.Remove(task.TmpFile) } // callModelStream 调用模型,返回原始字节(不做响应映射,用于流式输出) @@ -313,12 +273,10 @@ func (w *asyncWorker) callModel(ctx context.Context, task *entity.ModelGatewayTa // parseAndRetry 解析模型返回结果,并重试 func (w *asyncWorker) parseAndRetry(ctx context.Context, body map[string]any, task *entity.ModelGatewayTask, model *entity.ModelGatewayModel, req *dto.CreateTaskReq, maxRetry int, startTime time.Time) (map[string]any, error) { var lastErr error - for attempt := 0; attempt <= maxRetry; attempt++ { if attempt > 0 { g.Log().Infof(ctx, "[执行任务][重试] JSON解析 第%d/%d次 taskId=%s", attempt, maxRetry, task.TaskID) } - // 1) 响应映射 mapped, err := util.MapResponsePayload(model.ResponseMapping, body) if err != nil { @@ -372,7 +330,7 @@ func (w *asyncWorker) parseAndRetry(ctx context.Context, body map[string]any, ta } var rawResp map[string]any - if err := json.Unmarshal(rawData, &rawResp); err != nil { + if err = json.Unmarshal(rawData, &rawResp); err != nil { g.Log().Warningf(ctx, "[执行任务][Unmarshal失败] taskId=%s err=%v", task.TaskID, err) continue } @@ -553,15 +511,11 @@ func InvokeModel(ctx context.Context, model *entity.ModelGatewayModel, body map[ // return mappedResponse, nil // } -// failTask 任务失败统一处理:更新数据库 + 释放排队 + 回调 +// failTask 任务失败统一处理 func (w *asyncWorker) failTask(ctx context.Context, t *entity.ModelGatewayTask, startTime time.Time, errMsg string) { t.State = 3 t.ErrorMsg = errMsg t.DurationSeconds = int64(time.Since(startTime).Seconds()) - _, err := dao.ModelGatewayTask.Update(ctx, t) - if err != nil { - g.Log().Warningf(ctx, "[执行任务][更新数据库失败] taskId=%s err=%v", t.TaskID, err) - } - queue.ReleaseQueueSlot(ctx, t.ModelName, t.TaskID) - go gateway.TriggerCallback(util.AsyncCtx(ctx), t) + _, _ = dao.ModelGatewayTask.Update(ctx, t) // 更新任务状态 + go gateway.TriggerCallback(util.AsyncCtx(ctx), t) // 触发回调 }