From dd79643170a330b2f96a2d18a0f30e112e4be36b Mon Sep 17 00:00:00 2001 From: qhd <1766646056@qq.com> Date: Thu, 18 Jun 2026 10:06:49 +0800 Subject: [PATCH] =?UTF-8?q?fix:=20=E6=9B=BF=E6=8D=A2=E8=AF=B7=E6=B1=82?= =?UTF-8?q?=E5=A4=B4=E8=BD=AC=E5=8F=91=E6=96=B9=E5=BC=8F=E5=B9=B6=E4=BF=AE?= =?UTF-8?q?=E5=A4=8D=E7=A9=BA=E6=8C=87=E9=92=88=E6=8B=A6=E6=88=AA?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 移除 util.ForwardHeaders,改为从 GoFrame 请求上下文直接提取原始请求头;新增 GetFileBytesFromURL 方法替代 DownloadFile 下载 OSS 文件;增加 composeTask 空指针校验防止异常;调整数据库连接池参数。 --- config.yml | 4 +- service/gateway/gateway_http_service.go | 76 ++++++++++++++++++++++-- service/prompt/prompt_compose_service.go | 12 ++-- 3 files changed, 80 insertions(+), 12 deletions(-) diff --git a/config.yml b/config.yml index a00b2ed..eb9a702 100644 --- a/config.yml +++ b/config.yml @@ -39,8 +39,8 @@ database: dryRun: false charset: "utf8" timezone: "Asia/Shanghai" - maxIdle: 5 - maxOpen: 20 + maxIdle: 15 + maxOpen: 60 maxLifetime: "30s" maxIdleConnTime: "30s" createdAt: "created_at" diff --git a/service/gateway/gateway_http_service.go b/service/gateway/gateway_http_service.go index f2641e9..29d2c2b 100644 --- a/service/gateway/gateway_http_service.go +++ b/service/gateway/gateway_http_service.go @@ -6,12 +6,12 @@ import ( "fmt" "io" "net/http" - "prompts-core/common/util" "prompts-core/model/entity" "strings" "gitea.redpowerfuture.com/red-future/common/beans" commonHttp "gitea.redpowerfuture.com/red-future/common/http" + "github.com/gogf/gf/v2/errors/gerror" "github.com/gogf/gf/v2/frame/g" "github.com/gogf/gf/v2/os/gtime" ) @@ -29,7 +29,15 @@ type CreateTaskReq struct { // CreateGatewayTask 创建网关异步任务 func CreateGatewayTask(ctx context.Context, payload map[string]any) (string, error) { fullURL := "model-gateway/task/createTask" - headers := util.ForwardHeaders(ctx) + //headers := util.ForwardHeaders(ctx) + headers := make(map[string]string) + if r := g.RequestFromCtx(ctx); r != nil { + for k, v := range r.Request.Header { + if len(v) > 0 { + headers[k] = v[0] + } + } + } var req CreateTaskReq body, err := json.Marshal(payload) if err != nil { @@ -94,7 +102,17 @@ func GetModelConfig(ctx context.Context, req *AsynchModel) (model *AsynchModel, if len(params) > 0 { fullURL += "?" + strings.Join(params, "&") } - headers := util.ForwardHeaders(ctx) + //headers := util.ForwardHeaders(ctx) + + headers := make(map[string]string) + if r := g.RequestFromCtx(ctx); r != nil { + for k, v := range r.Request.Header { + if len(v) > 0 { + headers[k] = v[0] + } + } + } + var resp GetModelConfigResp if err = commonHttp.Get(ctx, fullURL, headers, &resp, nil); err != nil { return nil, fmt.Errorf("获取模型配置失败: %w", err) @@ -114,7 +132,15 @@ type GetTaskResultRes struct { // QueryGatewayTaskState 查询网关任务状态 func QueryGatewayTaskState(ctx context.Context, taskID string) (int, error) { fullURL := fmt.Sprintf("model-gateway/task/getTaskResult?taskId=%s", taskID) - headers := util.ForwardHeaders(ctx) + //headers := util.ForwardHeaders(ctx) + headers := make(map[string]string) + if r := g.RequestFromCtx(ctx); r != nil { + for k, v := range r.Request.Header { + if len(v) > 0 { + headers[k] = v[0] + } + } + } var req GetTaskResultRes if err := commonHttp.Get(ctx, fullURL, headers, &req, nil); err != nil { return 0, err @@ -137,7 +163,15 @@ type SkillUserVO struct { // GetSkillUser 获取技能用户信息 func GetSkillUser(ctx context.Context, name string) (*SkillUserVO, error) { fullURL := fmt.Sprintf("ai-agent/skill/user/getUserOrTemplate?name=%s", name) - headers := util.ForwardHeaders(ctx) + //headers := util.ForwardHeaders(ctx) + headers := make(map[string]string) + if r := g.RequestFromCtx(ctx); r != nil { + for k, v := range r.Request.Header { + if len(v) > 0 { + headers[k] = v[0] + } + } + } var resp SkillUserVO var req struct{} if err := commonHttp.Get(ctx, fullURL, headers, &resp, req); err != nil { @@ -170,7 +204,15 @@ func SendCallback(ctx context.Context, composeTask *entity.ComposeTask, epicycle EpicycleId: epicycleId, } // 3. 发送 POST 请求 - headers := util.ForwardHeaders(ctx) + //headers := util.ForwardHeaders(ctx) + headers := make(map[string]string) + if r := g.RequestFromCtx(ctx); r != nil { + for k, v := range r.Request.Header { + if len(v) > 0 { + headers[k] = v[0] + } + } + } var resp struct{} g.Log().Infof(ctx, "[回调业务] 开始发送 taskId=%s 回调地址=%s", composeTask.TaskId, composeTask.CallbackUrl) @@ -195,3 +237,25 @@ func DownloadFile(ossURL string) ([]byte, error) { return io.ReadAll(resp.Body) } + +func GetFileBytesFromURL(ctx context.Context, fileUrl string) ([]byte, error) { + // 使用 GoFrame 客户端(自带超时、追踪、日志等能力) + resp, err := g.Client().Get(ctx, fileUrl) + if err != nil { + return nil, fmt.Errorf("下载OSS文件失败: %w", err) + } + defer resp.Close() + + // 校验状态码 + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("下载OSS文件返回非200: %d", resp.StatusCode) + } + + // 读取全部内容 + allBytes, err := io.ReadAll(resp.Body) + if err != nil { + return nil, gerror.Wrapf(err, "failed to read response body, url: %s", fileUrl) + } + + return allBytes, nil +} diff --git a/service/prompt/prompt_compose_service.go b/service/prompt/prompt_compose_service.go index 42eaf1b..f2028f9 100644 --- a/service/prompt/prompt_compose_service.go +++ b/service/prompt/prompt_compose_service.go @@ -4,14 +4,13 @@ import ( "context" "errors" "fmt" - "prompts-core/service/session" - "prompts-core/common/util" "prompts-core/consts/public" "prompts-core/dao" "prompts-core/model/dto" "prompts-core/model/entity" "prompts-core/service/gateway" + "prompts-core/service/session" "gitea.redpowerfuture.com/red-future/common/beans" "gitea.redpowerfuture.com/red-future/common/utils" @@ -137,12 +136,17 @@ func Callback(ctx context.Context, req *dto.CallbackReq) error { if err != nil { return fmt.Errorf("查询任务失败: %w", err) } - + // 新增空指针拦截 + if composeTask == nil { + g.Log().Infof(ctx, "[回调处理] composeTask 模型配置为空,无法查询模型配置 taskId=%s,req=%v", req.TaskId, req.State) + return fmt.Errorf("composeTask 任务对象为空,无法查询模型配置 taskId=%s,req=%v", req.TaskId, req.State) + } // 2) 读取 OSS 文件内容 var ossContent []byte if req.OssFile != "" { - ossContent, err = gateway.DownloadFile(req.OssFile) + ossContent, err = gateway.GetFileBytesFromURL(ctx, req.OssFile) if err != nil { + g.Log().Infof(ctx, "[回调处理] 读取OSS文件 taskId=%s,state=%v,ossFile=%v", req.TaskId, req.State, req.OssFile) g.Log().Warningf(ctx, "[回调处理] 读取OSS失败 taskId=%s err=%v", req.TaskId, err) } }