package service import ( "context" "fmt" "strings" "time" "unicode/utf8" "model-gateway/dao" "model-gateway/model/entity" "github.com/gogf/gf/v2/frame/g" "github.com/gogf/gf/v2/os/grpool" "github.com/tidwall/gjson" ) var AsyncWorker = &asyncWorker{} type asyncWorker struct { } // RunOnce 由上层定时任务触发:一次性抢占并处理一批任务 // - batchSize: 本次抢占数量 // - goroutines: 本次并发数(协程池大小) func (w *asyncWorker) RunOnce(ctx context.Context, batchSize, goroutines int) (claimed int, err error) { if batchSize <= 0 { batchSize = 10 } if goroutines <= 0 { goroutines = 1 } tasks, err := dao.Task.ClaimPendingGlobal(ctx, batchSize) if err != nil { return 0, err } if len(tasks) == 0 { return 0, nil } pool := grpool.New(goroutines) defer pool.Close() claimed = len(tasks) done := make(chan struct{}, claimed) for _, t := range tasks { task := t _ = pool.AddWithRecover(ctx, func(ctx context.Context) { w.handleOne(ctx, task, 0) done <- struct{}{} }, func(ctx context.Context, e error) { if e != nil { _ = dao.Task.UpdateFailedGlobal(ctx, task.Id, fmt.Sprintf("worker panic: %v", e)) ReleaseQueueSlot(ctx, task.ModelName, task.TaskID) } done <- struct{}{} }) } for i := 0; i < claimed; i++ { <-done } return claimed, nil } // RunByTaskID 创建任务后立即异步尝试执行当前任务: // - 只定向抢占当前 taskId 对应的 pending 任务 // - 若任务已被其它 worker 抢走/已不在 pending,则直接返回 func (w *asyncWorker) RunByTaskID(ctx context.Context, taskID string, epicycleId int64) error { task, err := dao.Task.ClaimPendingByTaskIDGlobal(ctx, taskID) if err != nil { return err } if task == nil { return nil } w.handleOne(ctx, task, epicycleId) return nil } func (w *asyncWorker) handleOne(ctx context.Context, t *entity.AsynchTask, epicycleId int64) { // 从任务入库的 request_payload 里恢复 payload + headers payload, headers := parseStoredPayload(t.RequestPayload) if len(headers) > 0 { ctx = setTaskHeadersToCtx(ctx, headers) } // 1) 拉取模型配置 m, err := dao.Model.GetByModelNameForTenant(ctx, t.TenantId, t.ModelName) if err != nil { _ = dao.Task.UpdateFailedGlobal(ctx, t.Id, err.Error()) ReleaseQueueSlot(ctx, t.ModelName, t.TaskID) // ============ 失败回调 ============ t.State = 3 t.ErrorMsg = err.Error() go triggerCallback(context.WithoutCancel(ctx), t) // ================================ return } if m == nil || (m.Enabled != nil && *m.Enabled != 1) { errMsg := "模型不存在或未启用" _ = dao.Task.UpdateFailedGlobal(ctx, t.Id, errMsg) ReleaseQueueSlot(ctx, t.ModelName, t.TaskID) // ============ 失败回调 ============ t.State = 3 t.ErrorMsg = errMsg go triggerCallback(context.WithoutCancel(ctx), t) // ================================ return } // 2) 分布式并发限制 semKey := fmt.Sprintf("asynch:sem:%s", t.ModelName) leaseSeconds := int64(3600) maxC := GetRuntimeMaxConcurrency(ctx, t.ModelName, m.MaxConcurrency) acquired, err := acquireSemaphore(ctx, semKey, maxC, leaseSeconds) if err != nil { _ = dao.Task.UpdateFailedGlobal(ctx, t.Id, err.Error()) ReleaseQueueSlot(ctx, t.ModelName, t.TaskID) // ============ 失败回调 ============ t.State = 3 t.ErrorMsg = err.Error() go triggerCallback(context.WithoutCancel(ctx), t) // ================================ return } if !acquired { // 并发满了:放回排队,不回调(不是失败) _ = w.rollbackToPending(ctx, t.Id) return } defer func() { _ = releaseSemaphore(ctx, semKey) }() // 3) 调用模型服务 if payload == nil { payload = map[string]any{ "taskId": t.TaskID, "inputRef": t.InputRef, } } var ( data []byte contentType string ext string textResult string ) // phase=1 表示模型已成功但 OSS 上传失败:优先从临时文件加载 if t.Phase == 1 && strings.TrimSpace(t.TmpFile) != "" { data, err = loadTmpResult(t.TmpFile) if err == nil && len(data) > 0 { contentType, ext = DetectFileType(data) } else { data = nil } } if data == nil { // 统计 _ = dao.Stat.IncRequestCount(ctx, time.Now(), int64(t.TenantId), t.Creator, t.ModelName) // 核心调用 data, err = InvokeModel(ctx, m, payload, t.ModelKey) if err != nil { _ = dao.Task.UpdateFailedGlobal(ctx, t.Id, err.Error()) ReleaseQueueSlot(ctx, t.ModelName, t.TaskID) // ============ 失败回调 ============ t.State = 3 t.ErrorMsg = err.Error() go triggerCallback(context.WithoutCancel(ctx), t) // ================================ return } contentType, ext = DetectFileType(data) if utf8.Valid(data) && (strings.HasPrefix(contentType, "text/") || contentType == "application/json") { textResult = string(data) } tmpPath, err := saveTmpResult(t.TaskID, data, ext) if err == nil && tmpPath != "" { t.TmpFile = tmpPath t.Phase = 1 _ = dao.Task.UpdateTmpAfterModelGlobal(ctx, t.Id, tmpPath) } } // 4) 存储 OSS ossURL, err := Storage.UploadByTask(ctx, t, data, ext, contentType) if err != nil { // OSS 阶段失败:保留临时文件,下一轮仅重试 OSS _ = dao.Task.UpdateFailedKeepTmpGlobal(ctx, t.Id, err.Error()) ReleaseQueueSlot(ctx, t.ModelName, t.TaskID) // ============ OSS失败不回调(还会重试) ============ // 注意:OSS失败保留临时文件,下次重试,所以这里不触发最终回调 // 如果已经重试多次还没成功,需要在任务超时或超过最大重试次数时才回调失败 return } // 5) 更新任务状态成功 fileType := strings.TrimPrefix(ext, ".") if fileType == "" { fileType = contentType } if err := dao.Task.UpdateSuccessGlobal( ctx, t.Id, ossURL, fileType, textResult, int64(len(data)), nil, GetExpendTokens(m.TokenMapping, textResult), ); err != nil { g.Log().Errorf(ctx, "[worker] update success failed: %v", err) return } // 成功/失败均不再占用 queue_limit ReleaseQueueSlot(ctx, t.ModelName, t.TaskID) // 6) 成功回调 t.State = 2 t.OssFile = ossURL t.FileType = fileType t.TextResult = textResult g.Log().Infof(ctx, "[CALLBACK][DISPATCH] taskId=%s bizName=%s callbackUrl=%s", t.TaskID, t.BizName, t.CallbackURL) go triggerCallback(context.WithoutCancel(ctx), t) // ============ 如果有 epicycleId,也触发业务回调 ============ if epicycleId != 0 { go triggerPromptsCallback(context.WithoutCancel(ctx), t, epicycleId) } // 成功后清理临时文件 deleteTmpResult(t.TmpFile) } func (w *asyncWorker) rollbackToPending(ctx context.Context, id int64) error { return dao.Task.RollbackToPendingGlobal(ctx, id) } // GetExpendTokens 根据映射路径从 textResult 中提取消耗 token 值 func GetExpendTokens(tokenMapping string, textResult string) int { value := gjson.Get(textResult, tokenMapping) if value.Exists() { return int(value.Int()) } else { return len(textResult) } }