package task import ( "bytes" "context" "encoding/json" "errors" "fmt" "io" "model-gateway/common/util" "model-gateway/model/dto" "model-gateway/service/gateway" "model-gateway/service/queue" "net/http" "os" "path/filepath" "strings" "time" "unicode/utf8" "model-gateway/dao" "model-gateway/model/entity" "github.com/gogf/gf/v2/encoding/gjson" "github.com/gogf/gf/v2/frame/g" "github.com/gogf/gf/v2/os/grpool" ) var AsyncWorker = &asyncWorker{} type asyncWorker struct { } // RunOnce 由上层定时任务触发:一次性抢占并处理一批任务 // - batchSize: 本次抢占数量 // - goroutines: 本次并发数(协程池大小) func (w *asyncWorker) RunOnce(ctx context.Context, req *dto.RunWorkReq) (res *dto.RunWorkRes, err error) { if req.BatchSize <= 0 { req.BatchSize = 10 } if req.Goroutines <= 0 { req.Goroutines = 1 } tasks, err := dao.Task.ClaimPendingGlobal(ctx, req.BatchSize) if err != nil { return nil, err } if len(tasks) == 0 { return nil, errors.New("no task to run") } pool := grpool.New(req.Goroutines) defer pool.Close() claimed := len(tasks) done := make(chan struct{}, claimed) for _, t := range tasks { task := t _ = pool.AddWithRecover(ctx, func(ctx context.Context) { w.handleOne(ctx, task, &dto.CreateTaskReq{EpicycleId: 0}) done <- struct{}{} }, func(ctx context.Context, e error) { if e != nil { task.ErrorMsg = fmt.Sprintf("worker panic: %v", e) _ = dao.Task.UpdateFailedGlobal(ctx, task) queue.ReleaseQueueSlot(ctx, task.ModelName, task.TaskID) } done <- struct{}{} }) } for i := 0; i < claimed; i++ { <-done } return &dto.RunWorkRes{ Claimed: claimed, }, nil } // RunByTaskID 创建任务后立即异步尝试执行当前任务: // - 只定向抢占当前 taskId 对应的 pending 任务 // - 若任务已被其它 worker 抢走/已不在 pending,则直接返回 func (w *asyncWorker) RunByTaskID(ctx context.Context, taskID string, req *dto.CreateTaskReq) error { task, err := dao.Task.ClaimPendingByTaskIDGlobal(ctx, taskID) if err != nil { return err } if task == nil { return nil } w.handleOne(ctx, task, req) return nil } // handleOne 执行一次完整的任务 func (w *asyncWorker) handleOne(ctx context.Context, t *entity.AsynchTask, req *dto.CreateTaskReq) { payload := util.ParseStoredPayload(t.RequestPayload) maxRetry := 0 // 后面从 model 取 g.Log().Infof(ctx, "[执行任务][开始] taskId=%s model=%s", t.TaskID, t.ModelName) // 1) 获取模型配置 model, err := dao.Model.GetByModelNameForTenant(ctx, t.TenantId, t.ModelName) if err != nil || model == nil { w.failTask(ctx, t, "模型不存在或未启用") return } maxRetry = model.RetryTimes // 2) 分布式并发控制 semKey := fmt.Sprintf("asynch:sem:%s", t.ModelName) maxC := queue.GetRuntimeMaxConcurrency(ctx, t.ModelName, model.MaxConcurrency) acquired, err := queue.AcquireSemaphore(ctx, semKey, maxC, 3600) if err != nil { w.failTask(ctx, t, err.Error()) return } if !acquired { g.Log().Infof(ctx, "[执行任务][排队] 并发已满,放回队列 taskId=%s", t.TaskID) _ = w.rollbackToPending(ctx, t.Id) return } defer func() { _ = queue.ReleaseSemaphore(ctx, semKey) }() // 3) request_payload 校验 if payload == nil { w.failTask(ctx, t, "request_payload 为空") return } // 4) 调用模型(不重试,失败直接回调) textResult, err := w.callModel(ctx, t, model, payload) if err != nil { w.failTask(ctx, t, err.Error()) return } // 5) 上传 OSS(可重试) var oss *gateway.UploadFileResponse for attempt := 0; attempt <= maxRetry; attempt++ { if attempt > 0 { g.Log().Infof(ctx, "[执行任务][重试] OSS上传 第%d/%d次 taskId=%s", attempt, maxRetry, t.TaskID) } oss, err = w.uploadOSS(ctx, t) if err == nil { break } g.Log().Errorf(ctx, "[执行任务][失败] OSS上传失败 taskId=%s attempt=%d/%d err=%v", t.TaskID, attempt, maxRetry, err) if attempt == maxRetry { _ = dao.Task.UpdateFailedKeepTmpGlobal(ctx, t.Id, err.Error()) w.failTask(ctx, t, fmt.Sprintf("OSS上传重试耗尽: %v", err)) return } } // 6) 解析校验(可重试,失败重新调模型) //if req.BuildType == 1 { // for attempt := 0; attempt <= maxRetry; attempt++ { // if attempt > 0 { // g.Log().Infof(ctx, "[执行任务][重试] JSON解析 第%d/%d次 taskId=%s", attempt, maxRetry, t.TaskID) // } // // 6.1) 校验数据 // err = util.ValidatePromptResult(textResult, model) // if err == nil { // break // } // g.Log().Warningf(ctx, "[执行任务][解析失败] taskId=%s attempt=%d/%d err=%v", // t.TaskID, attempt, maxRetry, err) // if attempt == maxRetry { // w.failTask(ctx, t, fmt.Sprintf("JSON解析重试耗尽: %v", err)) // return // } // // 6.2) 重新调模型 // newResult, modelErr := w.callModel(ctx, t, model, payload) // if modelErr != nil { // g.Log().Warningf(ctx, "[执行任务][重试] 重新调模型失败 taskId=%s attempt=%d/%d err=%v", // t.TaskID, attempt, maxRetry, modelErr) // continue // } // textResult = newResult // } //} // 7) 成功回调 t.State = 2 t.OssFile = oss.FileAddressPrefix + oss.FileURL t.FileType = oss.FileFormat t.TextResult = textResult t.FileSize = int64(oss.FileSize) t.ExpendTokens = int64(GetExpendTokens(model.ResponseTokenField, textResult)) if err = dao.Task.UpdateSuccessGlobal(ctx, t); err != nil { g.Log().Errorf(ctx, "[执行任务][失败] 更新数据库失败 taskId=%s err=%v", t.TaskID, err) return } queue.ReleaseQueueSlot(ctx, t.ModelName, t.TaskID) go gateway.TriggerCallback(context.WithoutCancel(ctx), t) if req.EpicycleId != 0 { go gateway.TriggerPromptsCallback(context.WithoutCancel(ctx), t, req.EpicycleId) } g.Log().Infof(ctx, "[执行任务][成功] taskId=%s fileType=%s textLen=%d callbackUrl=%s", t.TaskID, oss.FileFormat, len(textResult), t.CallbackURL) _ = os.Remove(t.TmpFile) } // 返回: ossURL(成功时有值), fileType, textResult(失败时是错误信息), retryable(是否可重试) // callModel 调用模型 + 检测文件类型 + 保存临时文件 func (w *asyncWorker) callModel(ctx context.Context, task *entity.AsynchTask, model *entity.AsynchModel, payload map[string]any) (map[string]any, error) { var data []byte var contentType, ext, textResult string var err error if task.Phase == 1 && strings.TrimSpace(task.TmpFile) != "" { data, err = os.ReadFile(task.TmpFile) if err != nil || len(data) == 0 { data = nil } } if data == nil { _ = dao.Stat.IncRequestCount(ctx, time.Now(), int64(task.TenantId), task.Creator, task.ModelName) data, err = InvokeModel(ctx, model, payload, task.ModelKey) if err != nil { return nil, err } tmpPath, tmpErr := saveTmpResult(task.TaskID, data, ext) if tmpErr == nil && tmpPath != "" { task.TmpFile = tmpPath task.Phase = 1 _ = dao.Task.UpdateTmpAfterModelGlobal(ctx, task.Id, tmpPath) } } contentType, ext = util.DetectFileType(data) if utf8.Valid(data) && (strings.HasPrefix(contentType, "text/") || contentType == "application/json") { textResult = string(data) } return gjson.New(textResult).Map(), nil } // InvokeModel 调用模型服务,返回二进制结果 // modelKey 用于覆盖/补充模型配置 head_msg(例如每次请求携带不同的 X-API-Key) func InvokeModel(ctx context.Context, model *entity.AsynchModel, payload map[string]any, modelKey string) ([]byte, error) { // 1)请求参数映射:将标准 payload 按模型配置的 requestMapping 转为模型需要的格式 //mappedPayload := util.ReverseMap(model.RequestMapping, payload) // 2)构建请求 URL 和超时 baseURL := strings.TrimRight(model.BaseURL, "/") timeout := time.Duration(model.TimeoutSeconds) * time.Second client := &http.Client{Timeout: timeout} method := strings.ToUpper(strings.TrimSpace(model.HttpMethod)) // 3)构建 HTTP 请求 var req *http.Request switch method { case http.MethodGet: q, err := util.PayloadToQuery(payload) if err != nil { return nil, err } if len(q) > 0 { if strings.Contains(baseURL, "?") { baseURL = baseURL + "&" + q.Encode() } else { baseURL = baseURL + "?" + q.Encode() } } req, err = http.NewRequestWithContext(ctx, http.MethodGet, baseURL, nil) default: bodyBytes, err := json.Marshal(payload) if err != nil { return nil, err } req, err = http.NewRequestWithContext(ctx, http.MethodPost, baseURL, bytes.NewReader(bodyBytes)) } // 4)注入请求头:先模型静态配置,再动态 modelKey(后者可覆盖前者) for hk, hv := range util.ParseHeadMsgHeaders(model.HeadMsg) { req.Header.Set(hk, hv) } for hk, hv := range util.ParseHeadMsgHeaders(modelKey) { req.Header.Set(hk, hv) } if method != http.MethodGet { req.Header.Set("Content-Type", "application/json") } // 5)发送请求 resp, err := client.Do(req) if err != nil { return nil, err } defer resp.Body.Close() // 6)读取响应体 b, err := io.ReadAll(resp.Body) if err != nil { return nil, err } // 7)检查 HTTP 状态码 if resp.StatusCode < 200 || resp.StatusCode >= 300 { msg := string(b) return nil, fmt.Errorf("模型服务返回非2xx: %d, body=%s", resp.StatusCode, msg) } // 8)响应参数映射 mappedResponse, err := util.MapResponsePayload(model.ResponseMapping, b) if err != nil { g.Log().Warningf(ctx, "响应参数映射失败: %v,返回原始数据", err) return b, nil } return mappedResponse, nil } // // InvokeModel 调用模型服务,返回二进制结果 // // func InvokeModel(ctx context.Context, m *entity.AsynchModel, payload any, modelKey string) ([]byte, error) { // if m == nil || m.BaseURL == "" { // return nil, fmt.Errorf("模型配置不完整") // } // // 请求参数映射 // mappedPayload, err := mapRequestPayload(m.RequestMapping, payload) // if err != nil { // return nil, fmt.Errorf("请求参数映射失败: %w", err) // } // // 合并请求头 // headers := util.ForwardHeaders(ctx) // for hk, hv := range parseHeadMsgHeaders(m.HeadMsg) { // headers[hk] = hv // } // for hk, hv := range parseHeadMsgHeaders(modelKey) { // headers[hk] = hv // } // // // 设置超时 // timeout := time.Duration(m.TimeoutSeconds) * time.Second // if timeout <= 0 { // timeout = 600 * time.Second // } // ctx, cancel := context.WithTimeout(ctx, timeout) // defer cancel() // // invokeUrl := strings.TrimRight(m.BaseURL, "/") // method := strings.ToUpper(strings.TrimSpace(m.HttpMethod)) // if method == "" { // method = http.MethodPost // } // // var respBytes []byte // // switch method { // case http.MethodGet: // err = commonHttp.Get(ctx, invokeUrl, headers, &respBytes, mappedPayload) // default: // err = commonHttp.Post(ctx, invokeUrl, headers, &respBytes, mappedPayload) // } // if err != nil { // return nil, err // } // // 响应参数映射 // mappedResponse, err := mapResponsePayload(m.ResponseMapping, respBytes) // if err != nil { // g.Log().Warningf(ctx, "响应参数映射失败: %v,返回原始数据", err) // return respBytes, nil // } // return mappedResponse, nil // } // uploadOSS 从临时文件上传 OSS func (w *asyncWorker) uploadOSS(ctx context.Context, t *entity.AsynchTask) (*gateway.UploadFileResponse, error) { data, err := os.ReadFile(t.TmpFile) if err != nil { return nil, fmt.Errorf("读取临时文件失败: %w", err) } _, ext := util.DetectFileType(data) return gateway.UploadByTask(ctx, data, ext) } // failTask 任务失败统一处理:更新数据库 + 释放排队 + 回调 func (w *asyncWorker) failTask(ctx context.Context, t *entity.AsynchTask, errMsg string) { t.State = 3 t.ErrorMsg = errMsg _ = dao.Task.UpdateFailedGlobal(ctx, t) queue.ReleaseQueueSlot(ctx, t.ModelName, t.TaskID) go gateway.TriggerCallback(context.WithoutCancel(ctx), t) } // saveTmpResult 将模型输出写入临时文件,用于 OSS 上传失败后的“仅重试 OSS”。 func saveTmpResult(taskID string, data []byte, ext string) (string, error) { dir := filepath.Join(os.TempDir(), "model-asynch") if err := os.MkdirAll(dir, 0o755); err != nil { return "", err } if ext == "" { ext = ".bin" } if ext[0] != '.' { ext = "." + ext } path := filepath.Join(dir, fmt.Sprintf("%s%s", taskID, ext)) if err := os.WriteFile(path, data, 0o644); err != nil { return "", err } return path, nil } func (w *asyncWorker) rollbackToPending(ctx context.Context, id int64) error { return dao.Task.RollbackToPendingGlobal(ctx, id) } // GetExpendTokens 根据映射路径从 result 中提取消耗 token 值 func GetExpendTokens(responseTokenField string, result map[string]any) int { val := gjson.New(result).Get(responseTokenField) if val.IsNil() { return 0 } return val.Int() }