refactor(task): 重构任务服务和数据结构
This commit is contained in:
@@ -26,14 +26,12 @@ type ComposeMessagesRes struct {
|
||||
}
|
||||
|
||||
type CallbackReq struct {
|
||||
g.Meta `path:"/callback" method:"post" tags:"提示词处理" summary:"model-gateway 回调" dc:"model-gateway 成功后 POST 回调:callbackUrl/{bizName}"`
|
||||
TaskId string `json:"task_id" v:"required#task_id不能为空" dc:"网关任务ID"`
|
||||
State int `json:"state" dc:"网关任务状态"`
|
||||
OssFile string `json:"oss_file" dc:"结果文件地址"`
|
||||
FileType string `json:"file_type" dc:"结果文件类型"`
|
||||
Messages map[string]any `json:"messages" dc:"消息数组"`
|
||||
ErrorMsg string `json:"error_msg" dc:"错误信息"`
|
||||
EpicycleId int64 `json:"epicycleId" dc:"轮次ID"`
|
||||
g.Meta `path:"/callback" method:"post" tags:"提示词处理" summary:"model-gateway 回调" dc:"model-gateway 成功后 POST 回调:callbackUrl/{bizName}"`
|
||||
TaskId string `json:"task_id" v:"required#task_id不能为空" dc:"网关任务ID"`
|
||||
State int `json:"state" dc:"网关任务状态"`
|
||||
OssFile string `json:"oss_file" dc:"结果文件地址"`
|
||||
FileType string `json:"file_type" dc:"结果文件类型"`
|
||||
ErrorMsg string `json:"error_msg" dc:"错误信息"`
|
||||
}
|
||||
|
||||
type CallbackRes struct {
|
||||
|
||||
@@ -4,13 +4,14 @@ import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"prompts-core/common/util"
|
||||
"prompts-core/model/entity"
|
||||
"strings"
|
||||
|
||||
"gitea.redpowerfuture.com/red-future/common/beans"
|
||||
commonHttp "gitea.redpowerfuture.com/red-future/common/http"
|
||||
"github.com/gogf/gf/v2/encoding/gjson"
|
||||
"github.com/gogf/gf/v2/frame/g"
|
||||
"github.com/gogf/gf/v2/os/gtime"
|
||||
)
|
||||
@@ -147,11 +148,10 @@ func GetSkillUser(ctx context.Context, name string) (*SkillUserVO, error) {
|
||||
|
||||
// SendCallbackReq 发送回调的请求体
|
||||
type SendCallbackReq struct {
|
||||
TaskId string `json:"taskId"`
|
||||
Status string `json:"status"`
|
||||
Messages map[string]any `json:"messages,omitempty"`
|
||||
EpicycleId int64 `json:"epicycleId"`
|
||||
ErrorMsg string `json:"errorMsg,omitempty"`
|
||||
TaskId string `json:"taskId"`
|
||||
Status string `json:"status"`
|
||||
EpicycleId int64 `json:"epicycleId"`
|
||||
ErrorMsg string `json:"errorMsg,omitempty"`
|
||||
}
|
||||
|
||||
// SendCallback 向业务方发送回调
|
||||
@@ -164,18 +164,32 @@ func SendCallback(ctx context.Context, composeTask *entity.ComposeTask, epicycle
|
||||
req := SendCallbackReq{
|
||||
TaskId: composeTask.TaskId,
|
||||
Status: composeTask.Status,
|
||||
Messages: composeTask.ResultJson,
|
||||
ErrorMsg: composeTask.ErrorMessage,
|
||||
EpicycleId: epicycleId,
|
||||
}
|
||||
// 3. 发送 POST 请求
|
||||
headers := util.ForwardHeaders(ctx)
|
||||
var resp struct{}
|
||||
g.Log().Infof(ctx, "[回调业务] 开始发送 taskId=%s 回调地址=%s 消息=%v",
|
||||
composeTask.TaskId, composeTask.CallbackUrl, gjson.New(req.Messages).String())
|
||||
g.Log().Infof(ctx, "[回调业务] 开始发送 taskId=%s 回调地址=%s",
|
||||
composeTask.TaskId, composeTask.CallbackUrl)
|
||||
if err := commonHttp.Post(ctx, composeTask.CallbackUrl, headers, &resp, req); err != nil {
|
||||
return fmt.Errorf("[回调业务] 发送失败 taskId=%s url=%s err=%w", composeTask.TaskId, composeTask.CallbackUrl, err)
|
||||
}
|
||||
g.Log().Infof(ctx, "[回调业务] 发送成功 taskId=%s 回调地址=%s ", composeTask.TaskId, composeTask.CallbackUrl)
|
||||
return nil
|
||||
}
|
||||
|
||||
// DownloadFile 从 OSS 下载文件内容
|
||||
func DownloadFile(ossURL string) ([]byte, error) {
|
||||
resp, err := http.Get(ossURL)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("下载OSS文件失败: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("下载OSS文件返回非200: %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
return io.ReadAll(resp.Body)
|
||||
}
|
||||
|
||||
@@ -15,6 +15,7 @@ import (
|
||||
|
||||
"gitea.redpowerfuture.com/red-future/common/beans"
|
||||
"gitea.redpowerfuture.com/red-future/common/utils"
|
||||
"github.com/gogf/gf/v2/encoding/gjson"
|
||||
"github.com/gogf/gf/v2/frame/g"
|
||||
"github.com/gogf/gf/v2/util/gconv"
|
||||
)
|
||||
@@ -128,24 +129,43 @@ func handleBuild(ctx context.Context, req *dto.ComposeMessagesReq, chatModel, ai
|
||||
// Callback 回调处理
|
||||
func Callback(ctx context.Context, req *dto.CallbackReq) error {
|
||||
g.Log().Infof(ctx, "[开始回调处理] taskId=%s state=%d", req.TaskId, req.State)
|
||||
|
||||
// 1) 查询任务
|
||||
composeTask, err := dao.ComposeTask.Get(ctx, &entity.ComposeTask{TaskId: req.TaskId})
|
||||
if err != nil {
|
||||
return fmt.Errorf("查询任务失败: %w", err)
|
||||
}
|
||||
// 2) 处理失败
|
||||
|
||||
// 2) 读取 OSS 文件内容
|
||||
var ossContent []byte
|
||||
if req.OssFile != "" {
|
||||
ossContent, err = gateway.DownloadFile(req.OssFile)
|
||||
if err != nil {
|
||||
g.Log().Warningf(ctx, "[回调处理] 读取OSS失败 taskId=%s err=%v", req.TaskId, err)
|
||||
}
|
||||
}
|
||||
|
||||
// 3) 解析 OSS 内容为消息
|
||||
var messages map[string]any
|
||||
if len(ossContent) > 0 {
|
||||
messages, _ = gjson.New(ossContent).Map(), nil
|
||||
}
|
||||
|
||||
// 4) 处理失败
|
||||
if req.State == 3 {
|
||||
return handleCallbackFailed(ctx, req, composeTask)
|
||||
return handleCallbackFailed(ctx, req, composeTask, messages)
|
||||
}
|
||||
// 3) 处理成功
|
||||
|
||||
// 5) 处理成功
|
||||
if req.State == 2 {
|
||||
return handleCallbackSuccess(ctx, req, composeTask)
|
||||
return handleCallbackSuccess(ctx, req, composeTask, messages)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// handleCallbackFailed 处理回调失败
|
||||
func handleCallbackFailed(ctx context.Context, req *dto.CallbackReq, composeTask *entity.ComposeTask) error {
|
||||
func handleCallbackFailed(ctx context.Context, req *dto.CallbackReq, composeTask *entity.ComposeTask, messages map[string]any) error {
|
||||
_, err := dao.ComposeTask.Update(ctx, &entity.ComposeTask{
|
||||
TaskId: req.TaskId,
|
||||
Status: public.ComposeStatusFailed,
|
||||
@@ -153,7 +173,7 @@ func handleCallbackFailed(ctx context.Context, req *dto.CallbackReq, composeTask
|
||||
GatewayState: req.State,
|
||||
OssFile: req.OssFile,
|
||||
FileType: req.FileType,
|
||||
ResultJson: req.Messages,
|
||||
ResultJson: messages,
|
||||
})
|
||||
if composeTask.CallbackUrl != "" {
|
||||
composeTask.Status = public.ComposeStatusFailed
|
||||
@@ -164,7 +184,7 @@ func handleCallbackFailed(ctx context.Context, req *dto.CallbackReq, composeTask
|
||||
}
|
||||
|
||||
// handleCallbackSuccess 处理回调成功
|
||||
func handleCallbackSuccess(ctx context.Context, req *dto.CallbackReq, composeTask *entity.ComposeTask) error {
|
||||
func handleCallbackSuccess(ctx context.Context, req *dto.CallbackReq, composeTask *entity.ComposeTask, messages map[string]any) error {
|
||||
// 1) 获取模型配置
|
||||
model, err := gateway.GetModelConfig(ctx, &gateway.AsynchModel{
|
||||
SQLBaseDO: beans.SQLBaseDO{Creator: composeTask.Creator},
|
||||
@@ -198,7 +218,7 @@ func handleCallbackSuccess(ctx context.Context, req *dto.CallbackReq, composeTas
|
||||
}
|
||||
|
||||
// 3.2 保存当前轮(先存,下次查询就能拿到)
|
||||
if userMsg := util.ExtractUserText(req.Messages); userMsg != nil {
|
||||
if userMsg := util.ExtractUserText(messages); userMsg != nil {
|
||||
epicycleId, _ = dao.ComposeSession.Insert(ctx, &entity.ComposeSession{
|
||||
NodeId: nodeId,
|
||||
SessionId: sessionId,
|
||||
@@ -208,7 +228,7 @@ func handleCallbackSuccess(ctx context.Context, req *dto.CallbackReq, composeTas
|
||||
}
|
||||
|
||||
// 4) 合并附加结构
|
||||
messages := util.MergeConsult(composeTask.RequestPayload, req.Messages, model.ExtendMapping)
|
||||
messages = util.MergeConsult(composeTask.RequestPayload, messages, model.ExtendMapping)
|
||||
// 5) 注入历史
|
||||
if len(history) > 0 {
|
||||
messages = InjectHistory(messages, history, protocol)
|
||||
|
||||
Reference in New Issue
Block a user