fix: 替换请求头转发方式并修复空指针拦截
移除 util.ForwardHeaders,改为从 GoFrame 请求上下文直接提取原始请求头;新增 GetFileBytesFromURL 方法替代 DownloadFile 下载 OSS 文件;增加 composeTask 空指针校验防止异常;调整数据库连接池参数。
This commit is contained in:
@@ -39,8 +39,8 @@ database:
|
|||||||
dryRun: false
|
dryRun: false
|
||||||
charset: "utf8"
|
charset: "utf8"
|
||||||
timezone: "Asia/Shanghai"
|
timezone: "Asia/Shanghai"
|
||||||
maxIdle: 5
|
maxIdle: 15
|
||||||
maxOpen: 20
|
maxOpen: 60
|
||||||
maxLifetime: "30s"
|
maxLifetime: "30s"
|
||||||
maxIdleConnTime: "30s"
|
maxIdleConnTime: "30s"
|
||||||
createdAt: "created_at"
|
createdAt: "created_at"
|
||||||
|
|||||||
@@ -6,12 +6,12 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
"prompts-core/common/util"
|
|
||||||
"prompts-core/model/entity"
|
"prompts-core/model/entity"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"gitea.redpowerfuture.com/red-future/common/beans"
|
"gitea.redpowerfuture.com/red-future/common/beans"
|
||||||
commonHttp "gitea.redpowerfuture.com/red-future/common/http"
|
commonHttp "gitea.redpowerfuture.com/red-future/common/http"
|
||||||
|
"github.com/gogf/gf/v2/errors/gerror"
|
||||||
"github.com/gogf/gf/v2/frame/g"
|
"github.com/gogf/gf/v2/frame/g"
|
||||||
"github.com/gogf/gf/v2/os/gtime"
|
"github.com/gogf/gf/v2/os/gtime"
|
||||||
)
|
)
|
||||||
@@ -29,7 +29,15 @@ type CreateTaskReq struct {
|
|||||||
// CreateGatewayTask 创建网关异步任务
|
// CreateGatewayTask 创建网关异步任务
|
||||||
func CreateGatewayTask(ctx context.Context, payload map[string]any) (string, error) {
|
func CreateGatewayTask(ctx context.Context, payload map[string]any) (string, error) {
|
||||||
fullURL := "model-gateway/task/createTask"
|
fullURL := "model-gateway/task/createTask"
|
||||||
headers := util.ForwardHeaders(ctx)
|
//headers := util.ForwardHeaders(ctx)
|
||||||
|
headers := make(map[string]string)
|
||||||
|
if r := g.RequestFromCtx(ctx); r != nil {
|
||||||
|
for k, v := range r.Request.Header {
|
||||||
|
if len(v) > 0 {
|
||||||
|
headers[k] = v[0]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
var req CreateTaskReq
|
var req CreateTaskReq
|
||||||
body, err := json.Marshal(payload)
|
body, err := json.Marshal(payload)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -94,7 +102,17 @@ func GetModelConfig(ctx context.Context, req *AsynchModel) (model *AsynchModel,
|
|||||||
if len(params) > 0 {
|
if len(params) > 0 {
|
||||||
fullURL += "?" + strings.Join(params, "&")
|
fullURL += "?" + strings.Join(params, "&")
|
||||||
}
|
}
|
||||||
headers := util.ForwardHeaders(ctx)
|
//headers := util.ForwardHeaders(ctx)
|
||||||
|
|
||||||
|
headers := make(map[string]string)
|
||||||
|
if r := g.RequestFromCtx(ctx); r != nil {
|
||||||
|
for k, v := range r.Request.Header {
|
||||||
|
if len(v) > 0 {
|
||||||
|
headers[k] = v[0]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
var resp GetModelConfigResp
|
var resp GetModelConfigResp
|
||||||
if err = commonHttp.Get(ctx, fullURL, headers, &resp, nil); err != nil {
|
if err = commonHttp.Get(ctx, fullURL, headers, &resp, nil); err != nil {
|
||||||
return nil, fmt.Errorf("获取模型配置失败: %w", err)
|
return nil, fmt.Errorf("获取模型配置失败: %w", err)
|
||||||
@@ -114,7 +132,15 @@ type GetTaskResultRes struct {
|
|||||||
// QueryGatewayTaskState 查询网关任务状态
|
// QueryGatewayTaskState 查询网关任务状态
|
||||||
func QueryGatewayTaskState(ctx context.Context, taskID string) (int, error) {
|
func QueryGatewayTaskState(ctx context.Context, taskID string) (int, error) {
|
||||||
fullURL := fmt.Sprintf("model-gateway/task/getTaskResult?taskId=%s", taskID)
|
fullURL := fmt.Sprintf("model-gateway/task/getTaskResult?taskId=%s", taskID)
|
||||||
headers := util.ForwardHeaders(ctx)
|
//headers := util.ForwardHeaders(ctx)
|
||||||
|
headers := make(map[string]string)
|
||||||
|
if r := g.RequestFromCtx(ctx); r != nil {
|
||||||
|
for k, v := range r.Request.Header {
|
||||||
|
if len(v) > 0 {
|
||||||
|
headers[k] = v[0]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
var req GetTaskResultRes
|
var req GetTaskResultRes
|
||||||
if err := commonHttp.Get(ctx, fullURL, headers, &req, nil); err != nil {
|
if err := commonHttp.Get(ctx, fullURL, headers, &req, nil); err != nil {
|
||||||
return 0, err
|
return 0, err
|
||||||
@@ -137,7 +163,15 @@ type SkillUserVO struct {
|
|||||||
// GetSkillUser 获取技能用户信息
|
// GetSkillUser 获取技能用户信息
|
||||||
func GetSkillUser(ctx context.Context, name string) (*SkillUserVO, error) {
|
func GetSkillUser(ctx context.Context, name string) (*SkillUserVO, error) {
|
||||||
fullURL := fmt.Sprintf("ai-agent/skill/user/getUserOrTemplate?name=%s", name)
|
fullURL := fmt.Sprintf("ai-agent/skill/user/getUserOrTemplate?name=%s", name)
|
||||||
headers := util.ForwardHeaders(ctx)
|
//headers := util.ForwardHeaders(ctx)
|
||||||
|
headers := make(map[string]string)
|
||||||
|
if r := g.RequestFromCtx(ctx); r != nil {
|
||||||
|
for k, v := range r.Request.Header {
|
||||||
|
if len(v) > 0 {
|
||||||
|
headers[k] = v[0]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
var resp SkillUserVO
|
var resp SkillUserVO
|
||||||
var req struct{}
|
var req struct{}
|
||||||
if err := commonHttp.Get(ctx, fullURL, headers, &resp, req); err != nil {
|
if err := commonHttp.Get(ctx, fullURL, headers, &resp, req); err != nil {
|
||||||
@@ -170,7 +204,15 @@ func SendCallback(ctx context.Context, composeTask *entity.ComposeTask, epicycle
|
|||||||
EpicycleId: epicycleId,
|
EpicycleId: epicycleId,
|
||||||
}
|
}
|
||||||
// 3. 发送 POST 请求
|
// 3. 发送 POST 请求
|
||||||
headers := util.ForwardHeaders(ctx)
|
//headers := util.ForwardHeaders(ctx)
|
||||||
|
headers := make(map[string]string)
|
||||||
|
if r := g.RequestFromCtx(ctx); r != nil {
|
||||||
|
for k, v := range r.Request.Header {
|
||||||
|
if len(v) > 0 {
|
||||||
|
headers[k] = v[0]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
var resp struct{}
|
var resp struct{}
|
||||||
g.Log().Infof(ctx, "[回调业务] 开始发送 taskId=%s 回调地址=%s",
|
g.Log().Infof(ctx, "[回调业务] 开始发送 taskId=%s 回调地址=%s",
|
||||||
composeTask.TaskId, composeTask.CallbackUrl)
|
composeTask.TaskId, composeTask.CallbackUrl)
|
||||||
@@ -195,3 +237,25 @@ func DownloadFile(ossURL string) ([]byte, error) {
|
|||||||
|
|
||||||
return io.ReadAll(resp.Body)
|
return io.ReadAll(resp.Body)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func GetFileBytesFromURL(ctx context.Context, fileUrl string) ([]byte, error) {
|
||||||
|
// 使用 GoFrame 客户端(自带超时、追踪、日志等能力)
|
||||||
|
resp, err := g.Client().Get(ctx, fileUrl)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("下载OSS文件失败: %w", err)
|
||||||
|
}
|
||||||
|
defer resp.Close()
|
||||||
|
|
||||||
|
// 校验状态码
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
return nil, fmt.Errorf("下载OSS文件返回非200: %d", resp.StatusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 读取全部内容
|
||||||
|
allBytes, err := io.ReadAll(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
return nil, gerror.Wrapf(err, "failed to read response body, url: %s", fileUrl)
|
||||||
|
}
|
||||||
|
|
||||||
|
return allBytes, nil
|
||||||
|
}
|
||||||
|
|||||||
@@ -4,14 +4,13 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"prompts-core/service/session"
|
|
||||||
|
|
||||||
"prompts-core/common/util"
|
"prompts-core/common/util"
|
||||||
"prompts-core/consts/public"
|
"prompts-core/consts/public"
|
||||||
"prompts-core/dao"
|
"prompts-core/dao"
|
||||||
"prompts-core/model/dto"
|
"prompts-core/model/dto"
|
||||||
"prompts-core/model/entity"
|
"prompts-core/model/entity"
|
||||||
"prompts-core/service/gateway"
|
"prompts-core/service/gateway"
|
||||||
|
"prompts-core/service/session"
|
||||||
|
|
||||||
"gitea.redpowerfuture.com/red-future/common/beans"
|
"gitea.redpowerfuture.com/red-future/common/beans"
|
||||||
"gitea.redpowerfuture.com/red-future/common/utils"
|
"gitea.redpowerfuture.com/red-future/common/utils"
|
||||||
@@ -137,12 +136,17 @@ func Callback(ctx context.Context, req *dto.CallbackReq) error {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("查询任务失败: %w", err)
|
return fmt.Errorf("查询任务失败: %w", err)
|
||||||
}
|
}
|
||||||
|
// 新增空指针拦截
|
||||||
|
if composeTask == nil {
|
||||||
|
g.Log().Infof(ctx, "[回调处理] composeTask 模型配置为空,无法查询模型配置 taskId=%s,req=%v", req.TaskId, req.State)
|
||||||
|
return fmt.Errorf("composeTask 任务对象为空,无法查询模型配置 taskId=%s,req=%v", req.TaskId, req.State)
|
||||||
|
}
|
||||||
// 2) 读取 OSS 文件内容
|
// 2) 读取 OSS 文件内容
|
||||||
var ossContent []byte
|
var ossContent []byte
|
||||||
if req.OssFile != "" {
|
if req.OssFile != "" {
|
||||||
ossContent, err = gateway.DownloadFile(req.OssFile)
|
ossContent, err = gateway.GetFileBytesFromURL(ctx, req.OssFile)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
g.Log().Infof(ctx, "[回调处理] 读取OSS文件 taskId=%s,state=%v,ossFile=%v", req.TaskId, req.State, req.OssFile)
|
||||||
g.Log().Warningf(ctx, "[回调处理] 读取OSS失败 taskId=%s err=%v", req.TaskId, err)
|
g.Log().Warningf(ctx, "[回调处理] 读取OSS失败 taskId=%s err=%v", req.TaskId, err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user