diff --git a/model/dto/prompt_compose_dto.go b/model/dto/prompt_compose_dto.go index bb0a2b8..731fe24 100644 --- a/model/dto/prompt_compose_dto.go +++ b/model/dto/prompt_compose_dto.go @@ -26,14 +26,12 @@ type ComposeMessagesRes struct { } type CallbackReq struct { - g.Meta `path:"/callback" method:"post" tags:"提示词处理" summary:"model-gateway 回调" dc:"model-gateway 成功后 POST 回调:callbackUrl/{bizName}"` - TaskId string `json:"task_id" v:"required#task_id不能为空" dc:"网关任务ID"` - State int `json:"state" dc:"网关任务状态"` - OssFile string `json:"oss_file" dc:"结果文件地址"` - FileType string `json:"file_type" dc:"结果文件类型"` - Messages map[string]any `json:"messages" dc:"消息数组"` - ErrorMsg string `json:"error_msg" dc:"错误信息"` - EpicycleId int64 `json:"epicycleId" dc:"轮次ID"` + g.Meta `path:"/callback" method:"post" tags:"提示词处理" summary:"model-gateway 回调" dc:"model-gateway 成功后 POST 回调:callbackUrl/{bizName}"` + TaskId string `json:"task_id" v:"required#task_id不能为空" dc:"网关任务ID"` + State int `json:"state" dc:"网关任务状态"` + OssFile string `json:"oss_file" dc:"结果文件地址"` + FileType string `json:"file_type" dc:"结果文件类型"` + ErrorMsg string `json:"error_msg" dc:"错误信息"` } type CallbackRes struct { diff --git a/service/gateway/gateway_http_service.go b/service/gateway/gateway_http_service.go index 9e07855..6f19e2f 100644 --- a/service/gateway/gateway_http_service.go +++ b/service/gateway/gateway_http_service.go @@ -4,13 +4,14 @@ import ( "context" "encoding/json" "fmt" + "io" + "net/http" "prompts-core/common/util" "prompts-core/model/entity" "strings" "gitea.redpowerfuture.com/red-future/common/beans" commonHttp "gitea.redpowerfuture.com/red-future/common/http" - "github.com/gogf/gf/v2/encoding/gjson" "github.com/gogf/gf/v2/frame/g" "github.com/gogf/gf/v2/os/gtime" ) @@ -147,11 +148,10 @@ func GetSkillUser(ctx context.Context, name string) (*SkillUserVO, error) { // SendCallbackReq 发送回调的请求体 type SendCallbackReq struct { - TaskId string `json:"taskId"` - Status string `json:"status"` - Messages map[string]any `json:"messages,omitempty"` - EpicycleId int64 `json:"epicycleId"` - ErrorMsg string `json:"errorMsg,omitempty"` + TaskId string `json:"taskId"` + Status string `json:"status"` + EpicycleId int64 `json:"epicycleId"` + ErrorMsg string `json:"errorMsg,omitempty"` } // SendCallback 向业务方发送回调 @@ -164,18 +164,32 @@ func SendCallback(ctx context.Context, composeTask *entity.ComposeTask, epicycle req := SendCallbackReq{ TaskId: composeTask.TaskId, Status: composeTask.Status, - Messages: composeTask.ResultJson, ErrorMsg: composeTask.ErrorMessage, EpicycleId: epicycleId, } // 3. 发送 POST 请求 headers := util.ForwardHeaders(ctx) var resp struct{} - g.Log().Infof(ctx, "[回调业务] 开始发送 taskId=%s 回调地址=%s 消息=%v", - composeTask.TaskId, composeTask.CallbackUrl, gjson.New(req.Messages).String()) + g.Log().Infof(ctx, "[回调业务] 开始发送 taskId=%s 回调地址=%s", + composeTask.TaskId, composeTask.CallbackUrl) if err := commonHttp.Post(ctx, composeTask.CallbackUrl, headers, &resp, req); err != nil { return fmt.Errorf("[回调业务] 发送失败 taskId=%s url=%s err=%w", composeTask.TaskId, composeTask.CallbackUrl, err) } g.Log().Infof(ctx, "[回调业务] 发送成功 taskId=%s 回调地址=%s ", composeTask.TaskId, composeTask.CallbackUrl) return nil } + +// DownloadFile 从 OSS 下载文件内容 +func DownloadFile(ossURL string) ([]byte, error) { + resp, err := http.Get(ossURL) + if err != nil { + return nil, fmt.Errorf("下载OSS文件失败: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("下载OSS文件返回非200: %d", resp.StatusCode) + } + + return io.ReadAll(resp.Body) +} diff --git a/service/prompt/prompt_compose_service.go b/service/prompt/prompt_compose_service.go index 3bd8d5e..b23310c 100644 --- a/service/prompt/prompt_compose_service.go +++ b/service/prompt/prompt_compose_service.go @@ -15,6 +15,7 @@ import ( "gitea.redpowerfuture.com/red-future/common/beans" "gitea.redpowerfuture.com/red-future/common/utils" + "github.com/gogf/gf/v2/encoding/gjson" "github.com/gogf/gf/v2/frame/g" "github.com/gogf/gf/v2/util/gconv" ) @@ -128,24 +129,43 @@ func handleBuild(ctx context.Context, req *dto.ComposeMessagesReq, chatModel, ai // Callback 回调处理 func Callback(ctx context.Context, req *dto.CallbackReq) error { g.Log().Infof(ctx, "[开始回调处理] taskId=%s state=%d", req.TaskId, req.State) + // 1) 查询任务 composeTask, err := dao.ComposeTask.Get(ctx, &entity.ComposeTask{TaskId: req.TaskId}) if err != nil { return fmt.Errorf("查询任务失败: %w", err) } - // 2) 处理失败 + + // 2) 读取 OSS 文件内容 + var ossContent []byte + if req.OssFile != "" { + ossContent, err = gateway.DownloadFile(req.OssFile) + if err != nil { + g.Log().Warningf(ctx, "[回调处理] 读取OSS失败 taskId=%s err=%v", req.TaskId, err) + } + } + + // 3) 解析 OSS 内容为消息 + var messages map[string]any + if len(ossContent) > 0 { + messages, _ = gjson.New(ossContent).Map(), nil + } + + // 4) 处理失败 if req.State == 3 { - return handleCallbackFailed(ctx, req, composeTask) + return handleCallbackFailed(ctx, req, composeTask, messages) } - // 3) 处理成功 + + // 5) 处理成功 if req.State == 2 { - return handleCallbackSuccess(ctx, req, composeTask) + return handleCallbackSuccess(ctx, req, composeTask, messages) } + return nil } // handleCallbackFailed 处理回调失败 -func handleCallbackFailed(ctx context.Context, req *dto.CallbackReq, composeTask *entity.ComposeTask) error { +func handleCallbackFailed(ctx context.Context, req *dto.CallbackReq, composeTask *entity.ComposeTask, messages map[string]any) error { _, err := dao.ComposeTask.Update(ctx, &entity.ComposeTask{ TaskId: req.TaskId, Status: public.ComposeStatusFailed, @@ -153,7 +173,7 @@ func handleCallbackFailed(ctx context.Context, req *dto.CallbackReq, composeTask GatewayState: req.State, OssFile: req.OssFile, FileType: req.FileType, - ResultJson: req.Messages, + ResultJson: messages, }) if composeTask.CallbackUrl != "" { composeTask.Status = public.ComposeStatusFailed @@ -164,7 +184,7 @@ func handleCallbackFailed(ctx context.Context, req *dto.CallbackReq, composeTask } // handleCallbackSuccess 处理回调成功 -func handleCallbackSuccess(ctx context.Context, req *dto.CallbackReq, composeTask *entity.ComposeTask) error { +func handleCallbackSuccess(ctx context.Context, req *dto.CallbackReq, composeTask *entity.ComposeTask, messages map[string]any) error { // 1) 获取模型配置 model, err := gateway.GetModelConfig(ctx, &gateway.AsynchModel{ SQLBaseDO: beans.SQLBaseDO{Creator: composeTask.Creator}, @@ -198,7 +218,7 @@ func handleCallbackSuccess(ctx context.Context, req *dto.CallbackReq, composeTas } // 3.2 保存当前轮(先存,下次查询就能拿到) - if userMsg := util.ExtractUserText(req.Messages); userMsg != nil { + if userMsg := util.ExtractUserText(messages); userMsg != nil { epicycleId, _ = dao.ComposeSession.Insert(ctx, &entity.ComposeSession{ NodeId: nodeId, SessionId: sessionId, @@ -208,7 +228,7 @@ func handleCallbackSuccess(ctx context.Context, req *dto.CallbackReq, composeTas } // 4) 合并附加结构 - messages := util.MergeConsult(composeTask.RequestPayload, req.Messages, model.ExtendMapping) + messages = util.MergeConsult(composeTask.RequestPayload, messages, model.ExtendMapping) // 5) 注入历史 if len(history) > 0 { messages = InjectHistory(messages, history, protocol)