fix: 替换请求头转发方式并修复空指针拦截

移除 util.ForwardHeaders,改为从 GoFrame 请求上下文直接提取原始请求头;新增 GetFileBytesFromURL 方法替代 DownloadFile 下载 OSS 文件;增加 composeTask 空指针校验防止异常;调整数据库连接池参数。
This commit is contained in:
2026-06-18 10:06:49 +08:00
parent eb28c2d1e0
commit dd79643170
3 changed files with 80 additions and 12 deletions

View File

@@ -39,8 +39,8 @@ database:
dryRun: false
charset: "utf8"
timezone: "Asia/Shanghai"
maxIdle: 5
maxOpen: 20
maxIdle: 15
maxOpen: 60
maxLifetime: "30s"
maxIdleConnTime: "30s"
createdAt: "created_at"

View File

@@ -6,12 +6,12 @@ import (
"fmt"
"io"
"net/http"
"prompts-core/common/util"
"prompts-core/model/entity"
"strings"
"gitea.redpowerfuture.com/red-future/common/beans"
commonHttp "gitea.redpowerfuture.com/red-future/common/http"
"github.com/gogf/gf/v2/errors/gerror"
"github.com/gogf/gf/v2/frame/g"
"github.com/gogf/gf/v2/os/gtime"
)
@@ -29,7 +29,15 @@ type CreateTaskReq struct {
// CreateGatewayTask 创建网关异步任务
func CreateGatewayTask(ctx context.Context, payload map[string]any) (string, error) {
fullURL := "model-gateway/task/createTask"
headers := util.ForwardHeaders(ctx)
//headers := util.ForwardHeaders(ctx)
headers := make(map[string]string)
if r := g.RequestFromCtx(ctx); r != nil {
for k, v := range r.Request.Header {
if len(v) > 0 {
headers[k] = v[0]
}
}
}
var req CreateTaskReq
body, err := json.Marshal(payload)
if err != nil {
@@ -94,7 +102,17 @@ func GetModelConfig(ctx context.Context, req *AsynchModel) (model *AsynchModel,
if len(params) > 0 {
fullURL += "?" + strings.Join(params, "&")
}
headers := util.ForwardHeaders(ctx)
//headers := util.ForwardHeaders(ctx)
headers := make(map[string]string)
if r := g.RequestFromCtx(ctx); r != nil {
for k, v := range r.Request.Header {
if len(v) > 0 {
headers[k] = v[0]
}
}
}
var resp GetModelConfigResp
if err = commonHttp.Get(ctx, fullURL, headers, &resp, nil); err != nil {
return nil, fmt.Errorf("获取模型配置失败: %w", err)
@@ -114,7 +132,15 @@ type GetTaskResultRes struct {
// QueryGatewayTaskState 查询网关任务状态
func QueryGatewayTaskState(ctx context.Context, taskID string) (int, error) {
fullURL := fmt.Sprintf("model-gateway/task/getTaskResult?taskId=%s", taskID)
headers := util.ForwardHeaders(ctx)
//headers := util.ForwardHeaders(ctx)
headers := make(map[string]string)
if r := g.RequestFromCtx(ctx); r != nil {
for k, v := range r.Request.Header {
if len(v) > 0 {
headers[k] = v[0]
}
}
}
var req GetTaskResultRes
if err := commonHttp.Get(ctx, fullURL, headers, &req, nil); err != nil {
return 0, err
@@ -137,7 +163,15 @@ type SkillUserVO struct {
// GetSkillUser 获取技能用户信息
func GetSkillUser(ctx context.Context, name string) (*SkillUserVO, error) {
fullURL := fmt.Sprintf("ai-agent/skill/user/getUserOrTemplate?name=%s", name)
headers := util.ForwardHeaders(ctx)
//headers := util.ForwardHeaders(ctx)
headers := make(map[string]string)
if r := g.RequestFromCtx(ctx); r != nil {
for k, v := range r.Request.Header {
if len(v) > 0 {
headers[k] = v[0]
}
}
}
var resp SkillUserVO
var req struct{}
if err := commonHttp.Get(ctx, fullURL, headers, &resp, req); err != nil {
@@ -170,7 +204,15 @@ func SendCallback(ctx context.Context, composeTask *entity.ComposeTask, epicycle
EpicycleId: epicycleId,
}
// 3. 发送 POST 请求
headers := util.ForwardHeaders(ctx)
//headers := util.ForwardHeaders(ctx)
headers := make(map[string]string)
if r := g.RequestFromCtx(ctx); r != nil {
for k, v := range r.Request.Header {
if len(v) > 0 {
headers[k] = v[0]
}
}
}
var resp struct{}
g.Log().Infof(ctx, "[回调业务] 开始发送 taskId=%s 回调地址=%s",
composeTask.TaskId, composeTask.CallbackUrl)
@@ -195,3 +237,25 @@ func DownloadFile(ossURL string) ([]byte, error) {
return io.ReadAll(resp.Body)
}
func GetFileBytesFromURL(ctx context.Context, fileUrl string) ([]byte, error) {
// 使用 GoFrame 客户端(自带超时、追踪、日志等能力)
resp, err := g.Client().Get(ctx, fileUrl)
if err != nil {
return nil, fmt.Errorf("下载OSS文件失败: %w", err)
}
defer resp.Close()
// 校验状态码
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("下载OSS文件返回非200: %d", resp.StatusCode)
}
// 读取全部内容
allBytes, err := io.ReadAll(resp.Body)
if err != nil {
return nil, gerror.Wrapf(err, "failed to read response body, url: %s", fileUrl)
}
return allBytes, nil
}

View File

@@ -4,14 +4,13 @@ import (
"context"
"errors"
"fmt"
"prompts-core/service/session"
"prompts-core/common/util"
"prompts-core/consts/public"
"prompts-core/dao"
"prompts-core/model/dto"
"prompts-core/model/entity"
"prompts-core/service/gateway"
"prompts-core/service/session"
"gitea.redpowerfuture.com/red-future/common/beans"
"gitea.redpowerfuture.com/red-future/common/utils"
@@ -137,12 +136,17 @@ func Callback(ctx context.Context, req *dto.CallbackReq) error {
if err != nil {
return fmt.Errorf("查询任务失败: %w", err)
}
// 新增空指针拦截
if composeTask == nil {
g.Log().Infof(ctx, "[回调处理] composeTask 模型配置为空,无法查询模型配置 taskId=%s,req=%v", req.TaskId, req.State)
return fmt.Errorf("composeTask 任务对象为空,无法查询模型配置 taskId=%s,req=%v", req.TaskId, req.State)
}
// 2) 读取 OSS 文件内容
var ossContent []byte
if req.OssFile != "" {
ossContent, err = gateway.DownloadFile(req.OssFile)
ossContent, err = gateway.GetFileBytesFromURL(ctx, req.OssFile)
if err != nil {
g.Log().Infof(ctx, "[回调处理] 读取OSS文件 taskId=%s,state=%v,ossFile=%v", req.TaskId, req.State, req.OssFile)
g.Log().Warningf(ctx, "[回调处理] 读取OSS失败 taskId=%s err=%v", req.TaskId, err)
}
}