package task import ( "bytes" "context" "encoding/json" "fmt" "io" "net/http" "os" "strings" "sync" "time" "unicode/utf8" "model-gateway/common/util" "model-gateway/consts/public" "model-gateway/dao" "model-gateway/model/dto" "model-gateway/model/entity" "model-gateway/service/gateway" "model-gateway/service/queue" "github.com/gogf/gf/v2/encoding/gjson" "github.com/gogf/gf/v2/frame/g" "github.com/gogf/gf/v2/util/gconv" ) var AsyncWorker = &asyncWorker{} type asyncWorker struct { } // handleOne 执行一次完整的任务 func (w *asyncWorker) handleOne(ctx context.Context, task *entity.AsynchTask, model *entity.AsynchModel, req *dto.CreateTaskReq) { body := util.GetModelBody(task.RequestPayload) // 核心请求参数 maxRetry := model.RetryTimes // 重试次数 startTime := time.Now() g.Log().Infof(ctx, "[执行任务][开始] taskId=%s model=%s", task.TaskID, task.ModelName) // 1) 分布式并发控制 semKey := fmt.Sprintf("asynch:sem:%s", task.ModelName) maxC := queue.GetRuntimeMaxConcurrency(ctx, task.ModelName, model.MaxConcurrency) acquired, err := queue.AcquireSemaphore(ctx, semKey, maxC, 3600) if err != nil { task.DurationSeconds = int64(time.Since(startTime).Seconds()) w.failTask(ctx, task, startTime, err.Error()) return } if !acquired { g.Log().Infof(ctx, "[执行任务][排队] 并发已满,放回队列 taskId=%s", task.TaskID) _ = w.rollbackToPending(ctx, task.Id) return } defer func() { _ = queue.ReleaseSemaphore(ctx, semKey) }() // 2) 调用模型 switch { case model.CallMode != nil && *model.CallMode == public.CallModeStream: rawBytes, err := w.callModelStream(ctx, task, model, body) if err != nil { w.failTask(ctx, task, startTime, err.Error()) return } body, err = util.ParseStreamResponse(rawBytes, model.StreamConfig) if err != nil { w.failTask(ctx, task, startTime, err.Error()) return } case model.CallMode != nil && *model.CallMode == public.CallModeAsync: body, err = w.callModel(ctx, task, model, body) if err != nil { w.failTask(ctx, task, startTime, err.Error()) return } body, err = util.PullTaskResult(ctx, body, model.QueryConfig, model.HeadMsg) if err != nil { w.failTask(ctx, task, startTime, err.Error()) return } default: body, err = w.callModel(ctx, task, model, body) if err != nil { w.failTask(ctx, task, startTime, err.Error()) return } } // 3) 保存临时文件 tmpPath, err := util.SaveTempFileByType(task.TaskID, body, task.TmpFile) if err == nil && tmpPath != "" { task.TmpFile = tmpPath task.Phase = 1 _ = dao.Task.UpdateTmpAfterModelGlobal(ctx, task.Id, tmpPath) } // 4) 解析校验 + 响应映射(可重试,失败重新调模型) body, err = w.parseAndRetry(ctx, body, task, model, req, maxRetry, startTime) if err != nil { task.TextResult = body w.failTask(ctx, task, startTime, err.Error()) return } // 5) 上传 OSS(可重试) var oss *gateway.UploadFileResponse for attempt := 0; attempt <= maxRetry; attempt++ { if attempt > 0 { g.Log().Infof(ctx, "[执行任务][重试] OSS上传 第%d/%d次 taskId=%s", attempt, maxRetry, task.TaskID) } oss, err = w.uploadOSS(ctx, task) if err == nil { break } g.Log().Errorf(ctx, "[执行任务][失败] OSS上传失败 taskId=%s attempt=%d/%d err=%v", task.TaskID, attempt, maxRetry, err) if attempt == maxRetry { _ = dao.Task.UpdateFailedKeepTmpGlobal(ctx, task.Id, err.Error()) w.failTask(ctx, task, startTime, fmt.Sprintf("OSS上传重试耗尽: %v", err)) return } } // 6) 成功回调 task.State = 2 task.DurationSeconds = int64(time.Since(startTime).Seconds()) task.OssFile = oss.FileAddressPrefix + oss.FileURL task.FileType = oss.FileFormat task.TextResult = body task.FileSize = int64(oss.FileSize) if err = dao.Task.UpdateSuccessGlobal(ctx, task); err != nil { g.Log().Errorf(ctx, "[执行任务][失败] 更新数据库失败 taskId=%s err=%v", task.TaskID, err) return } queue.ReleaseQueueSlot(ctx, task.ModelName, task.TaskID) go gateway.TriggerCallback(context.WithoutCancel(ctx), task) if req.EpicycleId != 0 { go gateway.TriggerPromptsCallback(context.WithoutCancel(ctx), task, req.EpicycleId) } g.Log().Infof(ctx, "[执行任务][成功] taskId=%s duration=%ds fileType=%s textLen=%d callbackUrl=%s", task.TaskID, task.DurationSeconds, oss.FileFormat, len(body), task.CallbackURL) // 7) 删除临时文件 _ = os.Remove(task.TmpFile) } // callModelStream 调用模型,返回原始字节(不做响应映射,用于流式输出) func (w *asyncWorker) callModelStream(ctx context.Context, task *entity.AsynchTask, model *entity.AsynchModel, body map[string]any) ([]byte, error) { var data []byte var err error if task.Phase == 1 && strings.TrimSpace(task.TmpFile) != "" { data, err = os.ReadFile(task.TmpFile) if err != nil || len(data) == 0 { data = nil } } if data == nil { _ = dao.Stat.IncRequestCount(ctx, time.Now(), int64(task.TenantId), task.Creator, task.ModelName) data, err = InvokeModel(ctx, model, body, task.ModelKey) if err != nil { return nil, err } tmpPath, tmpErr := util.SaveTmpResult(task.TaskID, data, "") if tmpErr == nil && tmpPath != "" { task.TmpFile = tmpPath task.Phase = 1 _ = dao.Task.UpdateTmpAfterModelGlobal(ctx, task.Id, tmpPath) } } return data, nil } // asyncResult 异步任务结果 type asyncResult struct { result map[string]any err error } // asyncTaskChan 全局异步任务等待通道 var asyncTaskChan = sync.Map{} // taskID → chan asyncResult func (w *asyncWorker) callModelAsync(ctx context.Context, task *entity.AsynchTask, model *entity.AsynchModel, body map[string]any) (map[string]any, error) { // 1. 提交异步任务 body, err := w.callModel(ctx, task, model, body) if err != nil { return nil, err } // 2. 拿到 task_id taskID := gjson.New(body).Get(model.ResponseBody).String() // 3. 创建等待通道 ch := make(chan asyncResult, 1) asyncTaskChan.Store(taskID, ch) defer func() { asyncTaskChan.Delete(taskID) close(ch) }() // 4. 阻塞等待回调或超时 timeout := time.Duration(model.TimeoutSeconds) * time.Second ctx, cancel := context.WithTimeout(ctx, timeout) defer cancel() g.Log().Infof(ctx, "[异步任务] 开始等待结果 taskID=%s timeout=%v", taskID, timeout) select { case res, ok := <-ch: if !ok { return nil, fmt.Errorf("异步任务通道已关闭: taskID=%s", taskID) } g.Log().Infof(ctx, "[异步任务] 获取结果成功 taskID=%s", taskID) return res.result, res.err case <-ctx.Done(): return nil, fmt.Errorf("异步任务超时: taskID=%s", taskID) } } // NotifyAsyncResult 回调接口调用此方法通知结果 func NotifyAsyncResult(taskID string, result map[string]any, err error) { if ch, ok := asyncTaskChan.Load(taskID); ok { ch.(chan asyncResult) <- asyncResult{result: result, err: err} } } // callModel 调用模型 + 检测文件类型 + 保存临时文件 // 返回: 解析后的响应体, error func (w *asyncWorker) callModel(ctx context.Context, task *entity.AsynchTask, model *entity.AsynchModel, body map[string]any) (map[string]any, error) { var data []byte var err error // 1) 如果已有临时文件且 phase=1,直接读取 if task.Phase == 1 && strings.TrimSpace(task.TmpFile) != "" { data, err = os.ReadFile(task.TmpFile) if err != nil || len(data) == 0 { g.Log().Warningf(ctx, "[callModel] 读取临时文件失败,重新调用模型 taskId=%s err=%v", task.TaskID, err) data = nil } } // 2) 没有可用数据,调用模型 if data == nil { _ = dao.Stat.IncRequestCount(ctx, time.Now(), int64(task.TenantId), task.Creator, task.ModelName) data, err = InvokeModel(ctx, model, body, task.ModelKey) if err != nil { return nil, err } // 3) 检测文件类型,保存临时文件 _, ext := util.DetectFileType(data) tmpPath, tmpErr := util.SaveTmpResult(task.TaskID, data, ext) if tmpErr == nil && tmpPath != "" { task.TmpFile = tmpPath task.Phase = 1 _ = dao.Task.UpdateTmpAfterModelGlobal(ctx, task.Id, tmpPath) } } // 4) 检测文件类型,提取文本结果 contentType, _ := util.DetectFileType(data) var textResult string if utf8.Valid(data) && (strings.HasPrefix(contentType, "text/") || contentType == "application/json") { textResult = string(data) } // 5) 非文本内容,返回错误 if textResult == "" { return nil, fmt.Errorf("模型返回非文本内容,contentType=%s", contentType) } // 6) 解析并返回 return gjson.New(textResult).Map(), nil } // parseAndRetry 解析模型返回结果,并重试 func (w *asyncWorker) parseAndRetry(ctx context.Context, body map[string]any, task *entity.AsynchTask, model *entity.AsynchModel, req *dto.CreateTaskReq, maxRetry int, startTime time.Time) (map[string]any, error) { for attempt := 0; attempt <= maxRetry; attempt++ { if attempt > 0 { g.Log().Infof(ctx, "[执行任务][重试] JSON解析 第%d/%d次 taskId=%s", attempt, maxRetry, task.TaskID) } // 1) 响应映射 mapped, err := util.MapResponsePayload(model.ResponseMapping, body) if err != nil { g.Log().Warningf(ctx, "[执行任务][映射失败] taskId=%s attempt=%d/%d err=%v", task.TaskID, attempt, maxRetry, err) if attempt == maxRetry { return nil, fmt.Errorf("响应映射重试耗尽: %w", err) } continue } // 2) 先存 token 到数据库,防止后续失败丢失 if tokens, ok := mapped[model.ResponseTokenField]; ok { task.ExpendTokens = gconv.Int64(tokens) _ = dao.Task.UpdateColumns(ctx, task.Id, entity.AsynchTask{ ExpendTokens: gconv.Int64(body[model.ResponseTokenField]), }) } // 3) 解析 + 校验 var parsed map[string]any switch req.BuildType { case public.BuildTypePrompt, public.BuildTypeNode: parsed, err = util.ParseAndValidate(mapped, model) if err == nil { return parsed, nil } case public.BuildTypeStruct: parsed = util.ParseStructResult(mapped, model.ResponseBody) return parsed, nil default: return mapped, nil } g.Log().Warningf(ctx, "[执行任务][解析失败] taskId=%s attempt=%d/%d err=%v", task.TaskID, attempt, maxRetry, err) if attempt == maxRetry { return nil, fmt.Errorf("JSON解析重试耗尽: %w", err) } // 4) 重新调模型(直接调,不走缓存) _ = dao.Task.IncRetryCountGlobal(ctx, task.Id) reqBody := util.GetModelBody(task.RequestPayload) rawData, callErr := InvokeModel(ctx, model, reqBody, task.ModelKey) if callErr != nil { g.Log().Warningf(ctx, "[执行任务][重调模型失败] taskId=%s attempt=%d/%d err=%v", task.TaskID, attempt, maxRetry, callErr) continue } // 5) 解析原始响应,覆盖 body 进入下一轮 var rawResp map[string]any if err := json.Unmarshal(rawData, &rawResp); err != nil { g.Log().Warningf(ctx, "[执行任务][Unmarshal失败] taskId=%s err=%v", task.TaskID, err) continue } body = rawResp } return body, nil } // InvokeModel 调用模型服务,返回二进制结果 // modelKey 用于覆盖/补充模型配置 head_msg(例如每次请求携带不同的 X-API-Key) func InvokeModel(ctx context.Context, model *entity.AsynchModel, body map[string]any, modelKey string) ([]byte, error) { // 1)请求参数映射:将标准 payload 按模型配置的 requestMapping 转为模型需要的格式 //—— 请求映射实际处理为提示词构建请求,因为有附加字段及其他字段的拼接。这里不方便做请求映射 //mappedPayload := util.ReverseMap(model.RequestMapping, payload) // 2)构建请求 URL 和超时 baseURL := strings.TrimRight(model.BaseURL, "/") timeout := time.Duration(model.TimeoutSeconds) * time.Second client := &http.Client{Timeout: timeout} method := strings.ToUpper(strings.TrimSpace(model.HttpMethod)) // 3)构建 HTTP 请求 var req *http.Request switch method { case http.MethodGet: q, err := util.BodyToQuery(body) if err != nil { return nil, err } if len(q) > 0 { if strings.Contains(baseURL, "?") { baseURL = baseURL + "&" + q.Encode() } else { baseURL = baseURL + "?" + q.Encode() } } req, err = http.NewRequestWithContext(ctx, http.MethodGet, baseURL, nil) default: bodyBytes, err := json.Marshal(body) if err != nil { return nil, err } req, err = http.NewRequestWithContext(ctx, http.MethodPost, baseURL, bytes.NewReader(bodyBytes)) } // 4)注入请求头:先模型静态配置,再动态 modelKey(后者可覆盖前者) for hk, hv := range util.ParseHeadMsgHeaders(model.HeadMsg) { req.Header.Set(hk, hv) } if modelKey != "" { req.Header.Set("Authorization", "Bearer "+modelKey) } if method != http.MethodGet { req.Header.Set("Content-Type", "application/json") } // 5)发送请求 resp, err := client.Do(req) if err != nil { return nil, err } defer resp.Body.Close() // 6)读取响应体 b, err := io.ReadAll(resp.Body) if err != nil { return nil, err } // 7)检查 HTTP 状态码 if resp.StatusCode < 200 || resp.StatusCode >= 300 { msg := string(b) return nil, fmt.Errorf("模型服务返回非2xx: %d, body=%s", resp.StatusCode, msg) } return b, nil } // // InvokeModel 调用模型服务,返回二进制结果 // // func InvokeModel(ctx context.Context, m *entity.AsynchModel, payload any, modelKey string) ([]byte, error) { // if m == nil || m.BaseURL == "" { // return nil, fmt.Errorf("模型配置不完整") // } // // 请求参数映射 // mappedPayload, err := mapRequestPayload(m.RequestMapping, payload) // if err != nil { // return nil, fmt.Errorf("请求参数映射失败: %w", err) // } // // 合并请求头 // headers := util.ForwardHeaders(ctx) // for hk, hv := range parseHeadMsgHeaders(m.HeadMsg) { // headers[hk] = hv // } // for hk, hv := range parseHeadMsgHeaders(modelKey) { // headers[hk] = hv // } // // // 设置超时 // timeout := time.Duration(m.TimeoutSeconds) * time.Second // if timeout <= 0 { // timeout = 600 * time.Second // } // ctx, cancel := context.WithTimeout(ctx, timeout) // defer cancel() // // invokeUrl := strings.TrimRight(m.BaseURL, "/") // method := strings.ToUpper(strings.TrimSpace(m.HttpMethod)) // if method == "" { // method = http.MethodPost // } // // var respBytes []byte // // switch method { // case http.MethodGet: // err = commonHttp.Get(ctx, invokeUrl, headers, &respBytes, mappedPayload) // default: // err = commonHttp.Post(ctx, invokeUrl, headers, &respBytes, mappedPayload) // } // if err != nil { // return nil, err // } // // 响应参数映射 // mappedResponse, err := mapResponsePayload(m.ResponseMapping, respBytes) // if err != nil { // g.Log().Warningf(ctx, "响应参数映射失败: %v,返回原始数据", err) // return respBytes, nil // } // return mappedResponse, nil // } // uploadOSS 从临时文件上传 OSS func (w *asyncWorker) uploadOSS(ctx context.Context, t *entity.AsynchTask) (*gateway.UploadFileResponse, error) { data, err := os.ReadFile(t.TmpFile) if err != nil { return nil, fmt.Errorf("读取临时文件失败: %w", err) } _, ext := util.DetectFileType(data) return gateway.UploadByTask(ctx, data, ext) } // failTask 任务失败统一处理:更新数据库 + 释放排队 + 回调 func (w *asyncWorker) failTask(ctx context.Context, t *entity.AsynchTask, startTime time.Time, errMsg string) { t.State = 3 t.ErrorMsg = errMsg t.DurationSeconds = int64(time.Since(startTime).Seconds()) _ = dao.Task.UpdateFailedGlobal(ctx, t) queue.ReleaseQueueSlot(ctx, t.ModelName, t.TaskID) go gateway.TriggerCallback(context.WithoutCancel(ctx), t) } // rollbackToPending 恢复任务状态为 PENDING func (w *asyncWorker) rollbackToPending(ctx context.Context, id int64) error { return dao.Task.RollbackToPendingGlobal(ctx, id) }