From b21d7a8dbffddcffa94521b6c697d548e368c54c Mon Sep 17 00:00:00 2001 From: qhd <1766646056@qq.com> Date: Thu, 18 Jun 2026 10:08:36 +0800 Subject: [PATCH] =?UTF-8?q?fix:=20=E4=BF=AE=E5=A4=8D=E8=AF=B7=E6=B1=82?= =?UTF-8?q?=E5=A4=B4=E8=BD=AC=E5=8F=91=E4=B8=8E=E4=BB=BB=E5=8A=A1=E7=8A=B6?= =?UTF-8?q?=E6=80=81=E6=B5=81=E8=BD=AC=E9=97=AE=E9=A2=98?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 移除 util.ForwardHeaders,改为从原始请求精确提取 Authorization 或全部请求头; 任务创建时直接设为 Running 状态,避免二次更新与查询; 模型调用使用独立超时上下文,防止外层取消影响回调; 增加 OSS 上传耗时日志,调整数据库连接池参数。 --- config.yml | 4 +- service/gateway/gateway_http_service.go | 54 ++++++++++++++++++++----- service/task/task_service.go | 54 ++++++++++++------------- service/task/worker.go | 23 ++++++++--- 4 files changed, 90 insertions(+), 45 deletions(-) diff --git a/config.yml b/config.yml index e7d2193..ef1c857 100644 --- a/config.yml +++ b/config.yml @@ -39,8 +39,8 @@ database: dryRun: false charset: "utf8" timezone: "Asia/Shanghai" - maxIdle: 5 - maxOpen: 20 + maxIdle: 15 + maxOpen: 60 maxLifetime: "30s" maxIdleConnTime: "30s" createdAt: "created_at" diff --git a/service/gateway/gateway_http_service.go b/service/gateway/gateway_http_service.go index efc3c35..3782545 100644 --- a/service/gateway/gateway_http_service.go +++ b/service/gateway/gateway_http_service.go @@ -7,7 +7,6 @@ import ( "errors" "fmt" "mime/multipart" - "model-gateway/common/util" "model-gateway/model/entity" "time" @@ -43,16 +42,25 @@ func UploadByTask(ctx context.Context, data []byte, fileExt string) (oss *Upload if err != nil { return nil, err } - if _, err := part.Write(data); err != nil { + if _, err = part.Write(data); err != nil { return nil, err } - contentType := writer.FormDataContentType() + //contentType := writer.FormDataContentType() if err = writer.Close(); err != nil { return nil, err } - headers := util.ForwardHeaders(ctx) - headers["Content-Type"] = contentType + //headers := util.ForwardHeaders(ctx) + //headers["Content-Type"] = contentType + + headers := make(map[string]string) + headers["Content-Type"] = writer.FormDataContentType() + if r := g.RequestFromCtx(ctx); r != nil { + if auth := r.Header.Get("Authorization"); auth != "" { + headers["Authorization"] = auth + } + } + fullURL := "oss/file/uploadFile" g.Log().Infof(ctx, "[OSS] upload start url=%s filename=%s size=%d", fullURL, filename, len(data)) @@ -78,15 +86,25 @@ type CallbackPayload struct { // TriggerCallback 任务的回调 func TriggerCallback(ctx context.Context, t *entity.ModelGatewayTask) { - headers := util.ForwardHeaders(ctx) + //headers := util.ForwardHeaders(ctx) + headers := make(map[string]string) + if r := g.RequestFromCtx(ctx); r != nil { + for k, v := range r.Request.Header { + if len(v) > 0 { + headers[k] = v[0] + } + } + } var resp struct{} payload := CallbackPayload{ TaskId: t.TaskID, State: t.State, - OssFile: t.ResultFile.OssFile, - FileType: t.ResultFile.FileType, ErrorMsg: t.ErrorMsg, } + if !g.IsEmpty(t.ResultFile) { + payload.OssFile = t.ResultFile.OssFile + payload.FileType = t.ResultFile.FileType + } jsonData, err := json.Marshal(payload) if err != nil { g.Log().Warningf(ctx, "[回调] JSON序列化失败 taskId=%s 错误=%v", t.TaskID, err) @@ -112,7 +130,15 @@ type PromptsCallbackPayload struct { // TriggerPromptsCallback 任务成功后的提示词回调 func TriggerPromptsCallback(ctx context.Context, t *entity.ModelGatewayTask, epicycleId int64) { callbackURL := "prompts-core/session/callback" - headers := util.ForwardHeaders(ctx) + //headers := util.ForwardHeaders(ctx) + headers := make(map[string]string) + if r := g.RequestFromCtx(ctx); r != nil { + for k, v := range r.Request.Header { + if len(v) > 0 { + headers[k] = v[0] + } + } + } var resp struct{} payload := PromptsCallbackPayload{ EpicycleId: epicycleId, @@ -136,7 +162,15 @@ func TriggerPromptsCallback(ctx context.Context, t *entity.ModelGatewayTask, epi // IsSuperAdmin 调用admin-go服务检查是否是超级管理员 func IsSuperAdmin(ctx context.Context) (res bool, err error) { - headers := util.ForwardHeaders(ctx) + //headers := util.ForwardHeaders(ctx) + headers := make(map[string]string) + if r := g.RequestFromCtx(ctx); r != nil { + for k, v := range r.Request.Header { + if len(v) > 0 { + headers[k] = v[0] + } + } + } var r = make(map[string]bool) if err = commonHttp.Get(ctx, "admin-go/api/v1/system/user/checkIsSuperAdmin", headers, &r); err != nil { return false, err diff --git a/service/task/task_service.go b/service/task/task_service.go index 6e5f710..b1d39f5 100644 --- a/service/task/task_service.go +++ b/service/task/task_service.go @@ -65,20 +65,20 @@ func (s *taskService) Create(ctx context.Context, req *dto.CreateTaskReq) (res * Body: req.RequestPayload, Headers: util.ParseHeadMsgHeaders(model.HeadMsg), } - id, err := dao.ModelGatewayTask.Insert(ctx, &entity.ModelGatewayTask{ - ModelName: req.ModelName, - TaskID: taskID, - State: public.TaskStatusPending, - BizName: req.BizName, - CallbackURL: req.CallbackUrl, - RequestPayload: &requestPayload, - EpicycleId: req.EpicycleId, - }) + task := new(entity.ModelGatewayTask) + task.ModelName = model.ModelName + task.TaskID = taskID + task.State = public.TaskStatusRunning + task.BizName = req.BizName + task.CallbackURL = req.CallbackUrl + task.RequestPayload = &requestPayload + task.EpicycleId = req.EpicycleId + id, err := dao.ModelGatewayTask.Insert(ctx, task) if err != nil { // 入库失败:回滚闸门占位 queue.ReleaseQueueSlot(ctx, req.ModelName, taskID) return nil, err } - + task.Id = id // 4) 写操作日志(不影响主流程,失败忽略) ip := "" ua := "" @@ -107,25 +107,25 @@ func (s *taskService) Create(ctx context.Context, req *dto.CreateTaskReq) (res * }, }) - // 5) 抢占任务:改为执行中 - rows, err := dao.ModelGatewayTask.Update(ctx, &entity.ModelGatewayTask{ - SQLBaseDO: beans.SQLBaseDO{Id: id}, - State: public.TaskStatusRunning, - }) - if err != nil { - return nil, err - } - if rows == 0 { - return nil, fmt.Errorf("任务不存在: id=%d", id) - } + //// 5) 抢占任务:改为执行中 + //rows, err := dao.ModelGatewayTask.Update(ctx, &entity.ModelGatewayTask{ + // SQLBaseDO: beans.SQLBaseDO{Id: id}, + // State: public.TaskStatusRunning, + //}) + //if err != nil { + // return nil, err + //} + //if rows == 0 { + // return nil, fmt.Errorf("任务不存在: id=%d", id) + //} // 6) 查询任务信息 - task, err := dao.ModelGatewayTask.Get(ctx, &entity.ModelGatewayTask{ - SQLBaseDO: beans.SQLBaseDO{Id: id}, - }) - if err != nil { - return nil, err - } + //task, err := dao.ModelGatewayTask.Get(ctx, &entity.ModelGatewayTask{ + // SQLBaseDO: beans.SQLBaseDO{Id: id}, + //}) + //if err != nil { + // return nil, err + //} // 7) 创建成功后立即异步尝试执行当前任务 go AsyncWorker.handleOne(util.AsyncCtx(ctx), task, model, req) diff --git a/service/task/worker.go b/service/task/worker.go index 2236c36..b74b91e 100644 --- a/service/task/worker.go +++ b/service/task/worker.go @@ -129,10 +129,14 @@ func (w *asyncWorker) handleOne(ctx context.Context, task *entity.ModelGatewayTa if attempt > 0 { g.Log().Infof(ctx, "[执行任务][重试] OSS上传 第%d/%d次 taskId=%s", attempt, maxRetry, task.TaskID) } + startUpload := time.Now() oss, err = gateway.UploadByTask(ctx, gjson.New(result).MustToJson(), "json") if err == nil { break } + cost := time.Since(startUpload) + g.Log().Infof(ctx, "本次上传耗时:%s", cost) + g.Log().Errorf(ctx, "[执行任务][失败] OSS上传失败 taskId=%s attempt=%d/%d err=%v", task.TaskID, attempt, maxRetry, err) if attempt == maxRetry { task.State = public.TaskStatusFailed @@ -161,9 +165,9 @@ func (w *asyncWorker) handleOne(ctx context.Context, task *entity.ModelGatewayTa } queue.ReleaseQueueSlot(ctx, task.ModelName, task.TaskID) - go gateway.TriggerCallback(context.WithoutCancel(ctx), task) + go gateway.TriggerCallback(util.AsyncCtx(ctx), task) if req.EpicycleId != 0 { - go gateway.TriggerPromptsCallback(context.WithoutCancel(ctx), task, req.EpicycleId) + go gateway.TriggerPromptsCallback(util.AsyncCtx(ctx), task, req.EpicycleId) } g.Log().Infof(ctx, "[执行任务][成功] taskId=%s duration=%ds fileType=%s", @@ -420,7 +424,7 @@ func injectErrorMessage(payload map[string]any, err error) map[string]any { // modelKey 用于覆盖/补充模型配置 head_msg(例如每次请求携带不同的 X-API-Key) func InvokeModel(ctx context.Context, model *entity.ModelGatewayModel, body map[string]any) ([]byte, error) { // 1) 记录模型调用次数 - _ = dao.ModelGatewayLogsStat.IncRequestCount(ctx, time.Now(), model.TenantId, model.Creator, model.ModelName) + //_ = dao.ModelGatewayLogsStat.IncRequestCount(ctx, time.Now(), model.TenantId, model.Creator, model.ModelName) // 2)请求参数映射:将标准 payload 按模型配置的 requestMapping 转为模型需要的格式 //—— 请求映射实际处理为提示词构建请求,因为有附加字段及其他字段的拼接。这里不方便做请求映射 @@ -447,13 +451,20 @@ func InvokeModel(ctx context.Context, model *entity.ModelGatewayModel, body map[ baseURL = baseURL + "?" + q.Encode() } } - req, err = http.NewRequestWithContext(ctx, http.MethodGet, baseURL, nil) + // 改用独立超时ctx,隔绝外层截止 + reqCtx, reqCancel := context.WithTimeout(context.Background(), timeout) + defer reqCancel() + req, err = http.NewRequestWithContext(reqCtx, http.MethodGet, baseURL, nil) + //req, err = http.NewRequestWithContext(ctx, http.MethodGet, baseURL, nil) default: bodyBytes, err := json.Marshal(body) if err != nil { return nil, err } - req, err = http.NewRequestWithContext(ctx, http.MethodPost, baseURL, bytes.NewReader(bodyBytes)) + reqCtx, reqCancel := context.WithTimeout(context.Background(), timeout) + defer reqCancel() + req, err = http.NewRequestWithContext(reqCtx, http.MethodPost, baseURL, bytes.NewReader(bodyBytes)) + //req, err = http.NewRequestWithContext(ctx, http.MethodPost, baseURL, bytes.NewReader(bodyBytes)) } // 5)注入请求头:先模型静态配置,再动态 modelKey(后者可覆盖前者) @@ -552,5 +563,5 @@ func (w *asyncWorker) failTask(ctx context.Context, t *entity.ModelGatewayTask, g.Log().Warningf(ctx, "[执行任务][更新数据库失败] taskId=%s err=%v", t.TaskID, err) } queue.ReleaseQueueSlot(ctx, t.ModelName, t.TaskID) - go gateway.TriggerCallback(context.WithoutCancel(ctx), t) + go gateway.TriggerCallback(util.AsyncCtx(ctx), t) }