fix: 修复请求头转发与任务状态流转问题

移除 util.ForwardHeaders,改为从原始请求精确提取 Authorization 或全部请求头;
任务创建时直接设为 Running 状态,避免二次更新与查询;
模型调用使用独立超时上下文,防止外层取消影响回调;
增加 OSS 上传耗时日志,调整数据库连接池参数。
This commit is contained in:
2026-06-18 10:08:36 +08:00
parent fddaf36f48
commit b21d7a8dbf
4 changed files with 90 additions and 45 deletions

View File

@@ -39,8 +39,8 @@ database:
dryRun: false dryRun: false
charset: "utf8" charset: "utf8"
timezone: "Asia/Shanghai" timezone: "Asia/Shanghai"
maxIdle: 5 maxIdle: 15
maxOpen: 20 maxOpen: 60
maxLifetime: "30s" maxLifetime: "30s"
maxIdleConnTime: "30s" maxIdleConnTime: "30s"
createdAt: "created_at" createdAt: "created_at"

View File

@@ -7,7 +7,6 @@ import (
"errors" "errors"
"fmt" "fmt"
"mime/multipart" "mime/multipart"
"model-gateway/common/util"
"model-gateway/model/entity" "model-gateway/model/entity"
"time" "time"
@@ -43,16 +42,25 @@ func UploadByTask(ctx context.Context, data []byte, fileExt string) (oss *Upload
if err != nil { if err != nil {
return nil, err return nil, err
} }
if _, err := part.Write(data); err != nil { if _, err = part.Write(data); err != nil {
return nil, err return nil, err
} }
contentType := writer.FormDataContentType() //contentType := writer.FormDataContentType()
if err = writer.Close(); err != nil { if err = writer.Close(); err != nil {
return nil, err return nil, err
} }
headers := util.ForwardHeaders(ctx) //headers := util.ForwardHeaders(ctx)
headers["Content-Type"] = contentType //headers["Content-Type"] = contentType
headers := make(map[string]string)
headers["Content-Type"] = writer.FormDataContentType()
if r := g.RequestFromCtx(ctx); r != nil {
if auth := r.Header.Get("Authorization"); auth != "" {
headers["Authorization"] = auth
}
}
fullURL := "oss/file/uploadFile" fullURL := "oss/file/uploadFile"
g.Log().Infof(ctx, "[OSS] upload start url=%s filename=%s size=%d", fullURL, filename, len(data)) g.Log().Infof(ctx, "[OSS] upload start url=%s filename=%s size=%d", fullURL, filename, len(data))
@@ -78,15 +86,25 @@ type CallbackPayload struct {
// TriggerCallback 任务的回调 // TriggerCallback 任务的回调
func TriggerCallback(ctx context.Context, t *entity.ModelGatewayTask) { func TriggerCallback(ctx context.Context, t *entity.ModelGatewayTask) {
headers := util.ForwardHeaders(ctx) //headers := util.ForwardHeaders(ctx)
headers := make(map[string]string)
if r := g.RequestFromCtx(ctx); r != nil {
for k, v := range r.Request.Header {
if len(v) > 0 {
headers[k] = v[0]
}
}
}
var resp struct{} var resp struct{}
payload := CallbackPayload{ payload := CallbackPayload{
TaskId: t.TaskID, TaskId: t.TaskID,
State: t.State, State: t.State,
OssFile: t.ResultFile.OssFile,
FileType: t.ResultFile.FileType,
ErrorMsg: t.ErrorMsg, ErrorMsg: t.ErrorMsg,
} }
if !g.IsEmpty(t.ResultFile) {
payload.OssFile = t.ResultFile.OssFile
payload.FileType = t.ResultFile.FileType
}
jsonData, err := json.Marshal(payload) jsonData, err := json.Marshal(payload)
if err != nil { if err != nil {
g.Log().Warningf(ctx, "[回调] JSON序列化失败 taskId=%s 错误=%v", t.TaskID, err) g.Log().Warningf(ctx, "[回调] JSON序列化失败 taskId=%s 错误=%v", t.TaskID, err)
@@ -112,7 +130,15 @@ type PromptsCallbackPayload struct {
// TriggerPromptsCallback 任务成功后的提示词回调 // TriggerPromptsCallback 任务成功后的提示词回调
func TriggerPromptsCallback(ctx context.Context, t *entity.ModelGatewayTask, epicycleId int64) { func TriggerPromptsCallback(ctx context.Context, t *entity.ModelGatewayTask, epicycleId int64) {
callbackURL := "prompts-core/session/callback" callbackURL := "prompts-core/session/callback"
headers := util.ForwardHeaders(ctx) //headers := util.ForwardHeaders(ctx)
headers := make(map[string]string)
if r := g.RequestFromCtx(ctx); r != nil {
for k, v := range r.Request.Header {
if len(v) > 0 {
headers[k] = v[0]
}
}
}
var resp struct{} var resp struct{}
payload := PromptsCallbackPayload{ payload := PromptsCallbackPayload{
EpicycleId: epicycleId, EpicycleId: epicycleId,
@@ -136,7 +162,15 @@ func TriggerPromptsCallback(ctx context.Context, t *entity.ModelGatewayTask, epi
// IsSuperAdmin 调用admin-go服务检查是否是超级管理员 // IsSuperAdmin 调用admin-go服务检查是否是超级管理员
func IsSuperAdmin(ctx context.Context) (res bool, err error) { func IsSuperAdmin(ctx context.Context) (res bool, err error) {
headers := util.ForwardHeaders(ctx) //headers := util.ForwardHeaders(ctx)
headers := make(map[string]string)
if r := g.RequestFromCtx(ctx); r != nil {
for k, v := range r.Request.Header {
if len(v) > 0 {
headers[k] = v[0]
}
}
}
var r = make(map[string]bool) var r = make(map[string]bool)
if err = commonHttp.Get(ctx, "admin-go/api/v1/system/user/checkIsSuperAdmin", headers, &r); err != nil { if err = commonHttp.Get(ctx, "admin-go/api/v1/system/user/checkIsSuperAdmin", headers, &r); err != nil {
return false, err return false, err

View File

@@ -65,20 +65,20 @@ func (s *taskService) Create(ctx context.Context, req *dto.CreateTaskReq) (res *
Body: req.RequestPayload, Body: req.RequestPayload,
Headers: util.ParseHeadMsgHeaders(model.HeadMsg), Headers: util.ParseHeadMsgHeaders(model.HeadMsg),
} }
id, err := dao.ModelGatewayTask.Insert(ctx, &entity.ModelGatewayTask{ task := new(entity.ModelGatewayTask)
ModelName: req.ModelName, task.ModelName = model.ModelName
TaskID: taskID, task.TaskID = taskID
State: public.TaskStatusPending, task.State = public.TaskStatusRunning
BizName: req.BizName, task.BizName = req.BizName
CallbackURL: req.CallbackUrl, task.CallbackURL = req.CallbackUrl
RequestPayload: &requestPayload, task.RequestPayload = &requestPayload
EpicycleId: req.EpicycleId, task.EpicycleId = req.EpicycleId
}) id, err := dao.ModelGatewayTask.Insert(ctx, task)
if err != nil { // 入库失败:回滚闸门占位 if err != nil { // 入库失败:回滚闸门占位
queue.ReleaseQueueSlot(ctx, req.ModelName, taskID) queue.ReleaseQueueSlot(ctx, req.ModelName, taskID)
return nil, err return nil, err
} }
task.Id = id
// 4) 写操作日志(不影响主流程,失败忽略) // 4) 写操作日志(不影响主流程,失败忽略)
ip := "" ip := ""
ua := "" ua := ""
@@ -107,25 +107,25 @@ func (s *taskService) Create(ctx context.Context, req *dto.CreateTaskReq) (res *
}, },
}) })
// 5) 抢占任务:改为执行中 //// 5) 抢占任务:改为执行中
rows, err := dao.ModelGatewayTask.Update(ctx, &entity.ModelGatewayTask{ //rows, err := dao.ModelGatewayTask.Update(ctx, &entity.ModelGatewayTask{
SQLBaseDO: beans.SQLBaseDO{Id: id}, // SQLBaseDO: beans.SQLBaseDO{Id: id},
State: public.TaskStatusRunning, // State: public.TaskStatusRunning,
}) //})
if err != nil { //if err != nil {
return nil, err // return nil, err
} //}
if rows == 0 { //if rows == 0 {
return nil, fmt.Errorf("任务不存在: id=%d", id) // return nil, fmt.Errorf("任务不存在: id=%d", id)
} //}
// 6) 查询任务信息 // 6) 查询任务信息
task, err := dao.ModelGatewayTask.Get(ctx, &entity.ModelGatewayTask{ //task, err := dao.ModelGatewayTask.Get(ctx, &entity.ModelGatewayTask{
SQLBaseDO: beans.SQLBaseDO{Id: id}, // SQLBaseDO: beans.SQLBaseDO{Id: id},
}) //})
if err != nil { //if err != nil {
return nil, err // return nil, err
} //}
// 7) 创建成功后立即异步尝试执行当前任务 // 7) 创建成功后立即异步尝试执行当前任务
go AsyncWorker.handleOne(util.AsyncCtx(ctx), task, model, req) go AsyncWorker.handleOne(util.AsyncCtx(ctx), task, model, req)

View File

@@ -129,10 +129,14 @@ func (w *asyncWorker) handleOne(ctx context.Context, task *entity.ModelGatewayTa
if attempt > 0 { if attempt > 0 {
g.Log().Infof(ctx, "[执行任务][重试] OSS上传 第%d/%d次 taskId=%s", attempt, maxRetry, task.TaskID) g.Log().Infof(ctx, "[执行任务][重试] OSS上传 第%d/%d次 taskId=%s", attempt, maxRetry, task.TaskID)
} }
startUpload := time.Now()
oss, err = gateway.UploadByTask(ctx, gjson.New(result).MustToJson(), "json") oss, err = gateway.UploadByTask(ctx, gjson.New(result).MustToJson(), "json")
if err == nil { if err == nil {
break break
} }
cost := time.Since(startUpload)
g.Log().Infof(ctx, "本次上传耗时:%s", cost)
g.Log().Errorf(ctx, "[执行任务][失败] OSS上传失败 taskId=%s attempt=%d/%d err=%v", task.TaskID, attempt, maxRetry, err) g.Log().Errorf(ctx, "[执行任务][失败] OSS上传失败 taskId=%s attempt=%d/%d err=%v", task.TaskID, attempt, maxRetry, err)
if attempt == maxRetry { if attempt == maxRetry {
task.State = public.TaskStatusFailed task.State = public.TaskStatusFailed
@@ -161,9 +165,9 @@ func (w *asyncWorker) handleOne(ctx context.Context, task *entity.ModelGatewayTa
} }
queue.ReleaseQueueSlot(ctx, task.ModelName, task.TaskID) queue.ReleaseQueueSlot(ctx, task.ModelName, task.TaskID)
go gateway.TriggerCallback(context.WithoutCancel(ctx), task) go gateway.TriggerCallback(util.AsyncCtx(ctx), task)
if req.EpicycleId != 0 { if req.EpicycleId != 0 {
go gateway.TriggerPromptsCallback(context.WithoutCancel(ctx), task, req.EpicycleId) go gateway.TriggerPromptsCallback(util.AsyncCtx(ctx), task, req.EpicycleId)
} }
g.Log().Infof(ctx, "[执行任务][成功] taskId=%s duration=%ds fileType=%s", g.Log().Infof(ctx, "[执行任务][成功] taskId=%s duration=%ds fileType=%s",
@@ -420,7 +424,7 @@ func injectErrorMessage(payload map[string]any, err error) map[string]any {
// modelKey 用于覆盖/补充模型配置 head_msg例如每次请求携带不同的 X-API-Key // modelKey 用于覆盖/补充模型配置 head_msg例如每次请求携带不同的 X-API-Key
func InvokeModel(ctx context.Context, model *entity.ModelGatewayModel, body map[string]any) ([]byte, error) { func InvokeModel(ctx context.Context, model *entity.ModelGatewayModel, body map[string]any) ([]byte, error) {
// 1) 记录模型调用次数 // 1) 记录模型调用次数
_ = dao.ModelGatewayLogsStat.IncRequestCount(ctx, time.Now(), model.TenantId, model.Creator, model.ModelName) //_ = dao.ModelGatewayLogsStat.IncRequestCount(ctx, time.Now(), model.TenantId, model.Creator, model.ModelName)
// 2请求参数映射将标准 payload 按模型配置的 requestMapping 转为模型需要的格式 // 2请求参数映射将标准 payload 按模型配置的 requestMapping 转为模型需要的格式
//—— 请求映射实际处理为提示词构建请求,因为有附加字段及其他字段的拼接。这里不方便做请求映射 //—— 请求映射实际处理为提示词构建请求,因为有附加字段及其他字段的拼接。这里不方便做请求映射
@@ -447,13 +451,20 @@ func InvokeModel(ctx context.Context, model *entity.ModelGatewayModel, body map[
baseURL = baseURL + "?" + q.Encode() baseURL = baseURL + "?" + q.Encode()
} }
} }
req, err = http.NewRequestWithContext(ctx, http.MethodGet, baseURL, nil) // 改用独立超时ctx隔绝外层截止
reqCtx, reqCancel := context.WithTimeout(context.Background(), timeout)
defer reqCancel()
req, err = http.NewRequestWithContext(reqCtx, http.MethodGet, baseURL, nil)
//req, err = http.NewRequestWithContext(ctx, http.MethodGet, baseURL, nil)
default: default:
bodyBytes, err := json.Marshal(body) bodyBytes, err := json.Marshal(body)
if err != nil { if err != nil {
return nil, err return nil, err
} }
req, err = http.NewRequestWithContext(ctx, http.MethodPost, baseURL, bytes.NewReader(bodyBytes)) reqCtx, reqCancel := context.WithTimeout(context.Background(), timeout)
defer reqCancel()
req, err = http.NewRequestWithContext(reqCtx, http.MethodPost, baseURL, bytes.NewReader(bodyBytes))
//req, err = http.NewRequestWithContext(ctx, http.MethodPost, baseURL, bytes.NewReader(bodyBytes))
} }
// 5注入请求头先模型静态配置再动态 modelKey后者可覆盖前者 // 5注入请求头先模型静态配置再动态 modelKey后者可覆盖前者
@@ -552,5 +563,5 @@ func (w *asyncWorker) failTask(ctx context.Context, t *entity.ModelGatewayTask,
g.Log().Warningf(ctx, "[执行任务][更新数据库失败] taskId=%s err=%v", t.TaskID, err) g.Log().Warningf(ctx, "[执行任务][更新数据库失败] taskId=%s err=%v", t.TaskID, err)
} }
queue.ReleaseQueueSlot(ctx, t.ModelName, t.TaskID) queue.ReleaseQueueSlot(ctx, t.ModelName, t.TaskID)
go gateway.TriggerCallback(context.WithoutCancel(ctx), t) go gateway.TriggerCallback(util.AsyncCtx(ctx), t)
} }