refactor(task): 重构任务服务和数据结构

This commit is contained in:
2026-06-12 15:29:06 +08:00
parent c22d578e1a
commit 0d52b631b9
3 changed files with 58 additions and 26 deletions

View File

@@ -26,14 +26,12 @@ type ComposeMessagesRes struct {
} }
type CallbackReq struct { type CallbackReq struct {
g.Meta `path:"/callback" method:"post" tags:"提示词处理" summary:"model-gateway 回调" dc:"model-gateway 成功后 POST 回调callbackUrl/{bizName}"` g.Meta `path:"/callback" method:"post" tags:"提示词处理" summary:"model-gateway 回调" dc:"model-gateway 成功后 POST 回调callbackUrl/{bizName}"`
TaskId string `json:"task_id" v:"required#task_id不能为空" dc:"网关任务ID"` TaskId string `json:"task_id" v:"required#task_id不能为空" dc:"网关任务ID"`
State int `json:"state" dc:"网关任务状态"` State int `json:"state" dc:"网关任务状态"`
OssFile string `json:"oss_file" dc:"结果文件地址"` OssFile string `json:"oss_file" dc:"结果文件地址"`
FileType string `json:"file_type" dc:"结果文件类型"` FileType string `json:"file_type" dc:"结果文件类型"`
Messages map[string]any `json:"messages" dc:"消息数组"` ErrorMsg string `json:"error_msg" dc:"错误信息"`
ErrorMsg string `json:"error_msg" dc:"错误信息"`
EpicycleId int64 `json:"epicycleId" dc:"轮次ID"`
} }
type CallbackRes struct { type CallbackRes struct {

View File

@@ -4,13 +4,14 @@ import (
"context" "context"
"encoding/json" "encoding/json"
"fmt" "fmt"
"io"
"net/http"
"prompts-core/common/util" "prompts-core/common/util"
"prompts-core/model/entity" "prompts-core/model/entity"
"strings" "strings"
"gitea.redpowerfuture.com/red-future/common/beans" "gitea.redpowerfuture.com/red-future/common/beans"
commonHttp "gitea.redpowerfuture.com/red-future/common/http" commonHttp "gitea.redpowerfuture.com/red-future/common/http"
"github.com/gogf/gf/v2/encoding/gjson"
"github.com/gogf/gf/v2/frame/g" "github.com/gogf/gf/v2/frame/g"
"github.com/gogf/gf/v2/os/gtime" "github.com/gogf/gf/v2/os/gtime"
) )
@@ -147,11 +148,10 @@ func GetSkillUser(ctx context.Context, name string) (*SkillUserVO, error) {
// SendCallbackReq 发送回调的请求体 // SendCallbackReq 发送回调的请求体
type SendCallbackReq struct { type SendCallbackReq struct {
TaskId string `json:"taskId"` TaskId string `json:"taskId"`
Status string `json:"status"` Status string `json:"status"`
Messages map[string]any `json:"messages,omitempty"` EpicycleId int64 `json:"epicycleId"`
EpicycleId int64 `json:"epicycleId"` ErrorMsg string `json:"errorMsg,omitempty"`
ErrorMsg string `json:"errorMsg,omitempty"`
} }
// SendCallback 向业务方发送回调 // SendCallback 向业务方发送回调
@@ -164,18 +164,32 @@ func SendCallback(ctx context.Context, composeTask *entity.ComposeTask, epicycle
req := SendCallbackReq{ req := SendCallbackReq{
TaskId: composeTask.TaskId, TaskId: composeTask.TaskId,
Status: composeTask.Status, Status: composeTask.Status,
Messages: composeTask.ResultJson,
ErrorMsg: composeTask.ErrorMessage, ErrorMsg: composeTask.ErrorMessage,
EpicycleId: epicycleId, EpicycleId: epicycleId,
} }
// 3. 发送 POST 请求 // 3. 发送 POST 请求
headers := util.ForwardHeaders(ctx) headers := util.ForwardHeaders(ctx)
var resp struct{} var resp struct{}
g.Log().Infof(ctx, "[回调业务] 开始发送 taskId=%s 回调地址=%s 消息=%v", g.Log().Infof(ctx, "[回调业务] 开始发送 taskId=%s 回调地址=%s",
composeTask.TaskId, composeTask.CallbackUrl, gjson.New(req.Messages).String()) composeTask.TaskId, composeTask.CallbackUrl)
if err := commonHttp.Post(ctx, composeTask.CallbackUrl, headers, &resp, req); err != nil { if err := commonHttp.Post(ctx, composeTask.CallbackUrl, headers, &resp, req); err != nil {
return fmt.Errorf("[回调业务] 发送失败 taskId=%s url=%s err=%w", composeTask.TaskId, composeTask.CallbackUrl, err) return fmt.Errorf("[回调业务] 发送失败 taskId=%s url=%s err=%w", composeTask.TaskId, composeTask.CallbackUrl, err)
} }
g.Log().Infof(ctx, "[回调业务] 发送成功 taskId=%s 回调地址=%s ", composeTask.TaskId, composeTask.CallbackUrl) g.Log().Infof(ctx, "[回调业务] 发送成功 taskId=%s 回调地址=%s ", composeTask.TaskId, composeTask.CallbackUrl)
return nil return nil
} }
// DownloadFile 从 OSS 下载文件内容
func DownloadFile(ossURL string) ([]byte, error) {
resp, err := http.Get(ossURL)
if err != nil {
return nil, fmt.Errorf("下载OSS文件失败: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("下载OSS文件返回非200: %d", resp.StatusCode)
}
return io.ReadAll(resp.Body)
}

View File

@@ -15,6 +15,7 @@ import (
"gitea.redpowerfuture.com/red-future/common/beans" "gitea.redpowerfuture.com/red-future/common/beans"
"gitea.redpowerfuture.com/red-future/common/utils" "gitea.redpowerfuture.com/red-future/common/utils"
"github.com/gogf/gf/v2/encoding/gjson"
"github.com/gogf/gf/v2/frame/g" "github.com/gogf/gf/v2/frame/g"
"github.com/gogf/gf/v2/util/gconv" "github.com/gogf/gf/v2/util/gconv"
) )
@@ -128,24 +129,43 @@ func handleBuild(ctx context.Context, req *dto.ComposeMessagesReq, chatModel, ai
// Callback 回调处理 // Callback 回调处理
func Callback(ctx context.Context, req *dto.CallbackReq) error { func Callback(ctx context.Context, req *dto.CallbackReq) error {
g.Log().Infof(ctx, "[开始回调处理] taskId=%s state=%d", req.TaskId, req.State) g.Log().Infof(ctx, "[开始回调处理] taskId=%s state=%d", req.TaskId, req.State)
// 1) 查询任务 // 1) 查询任务
composeTask, err := dao.ComposeTask.Get(ctx, &entity.ComposeTask{TaskId: req.TaskId}) composeTask, err := dao.ComposeTask.Get(ctx, &entity.ComposeTask{TaskId: req.TaskId})
if err != nil { if err != nil {
return fmt.Errorf("查询任务失败: %w", err) return fmt.Errorf("查询任务失败: %w", err)
} }
// 2) 处理失败
// 2) 读取 OSS 文件内容
var ossContent []byte
if req.OssFile != "" {
ossContent, err = gateway.DownloadFile(req.OssFile)
if err != nil {
g.Log().Warningf(ctx, "[回调处理] 读取OSS失败 taskId=%s err=%v", req.TaskId, err)
}
}
// 3) 解析 OSS 内容为消息
var messages map[string]any
if len(ossContent) > 0 {
messages, _ = gjson.New(ossContent).Map(), nil
}
// 4) 处理失败
if req.State == 3 { if req.State == 3 {
return handleCallbackFailed(ctx, req, composeTask) return handleCallbackFailed(ctx, req, composeTask, messages)
} }
// 3) 处理成功
// 5) 处理成功
if req.State == 2 { if req.State == 2 {
return handleCallbackSuccess(ctx, req, composeTask) return handleCallbackSuccess(ctx, req, composeTask, messages)
} }
return nil return nil
} }
// handleCallbackFailed 处理回调失败 // handleCallbackFailed 处理回调失败
func handleCallbackFailed(ctx context.Context, req *dto.CallbackReq, composeTask *entity.ComposeTask) error { func handleCallbackFailed(ctx context.Context, req *dto.CallbackReq, composeTask *entity.ComposeTask, messages map[string]any) error {
_, err := dao.ComposeTask.Update(ctx, &entity.ComposeTask{ _, err := dao.ComposeTask.Update(ctx, &entity.ComposeTask{
TaskId: req.TaskId, TaskId: req.TaskId,
Status: public.ComposeStatusFailed, Status: public.ComposeStatusFailed,
@@ -153,7 +173,7 @@ func handleCallbackFailed(ctx context.Context, req *dto.CallbackReq, composeTask
GatewayState: req.State, GatewayState: req.State,
OssFile: req.OssFile, OssFile: req.OssFile,
FileType: req.FileType, FileType: req.FileType,
ResultJson: req.Messages, ResultJson: messages,
}) })
if composeTask.CallbackUrl != "" { if composeTask.CallbackUrl != "" {
composeTask.Status = public.ComposeStatusFailed composeTask.Status = public.ComposeStatusFailed
@@ -164,7 +184,7 @@ func handleCallbackFailed(ctx context.Context, req *dto.CallbackReq, composeTask
} }
// handleCallbackSuccess 处理回调成功 // handleCallbackSuccess 处理回调成功
func handleCallbackSuccess(ctx context.Context, req *dto.CallbackReq, composeTask *entity.ComposeTask) error { func handleCallbackSuccess(ctx context.Context, req *dto.CallbackReq, composeTask *entity.ComposeTask, messages map[string]any) error {
// 1) 获取模型配置 // 1) 获取模型配置
model, err := gateway.GetModelConfig(ctx, &gateway.AsynchModel{ model, err := gateway.GetModelConfig(ctx, &gateway.AsynchModel{
SQLBaseDO: beans.SQLBaseDO{Creator: composeTask.Creator}, SQLBaseDO: beans.SQLBaseDO{Creator: composeTask.Creator},
@@ -198,7 +218,7 @@ func handleCallbackSuccess(ctx context.Context, req *dto.CallbackReq, composeTas
} }
// 3.2 保存当前轮(先存,下次查询就能拿到) // 3.2 保存当前轮(先存,下次查询就能拿到)
if userMsg := util.ExtractUserText(req.Messages); userMsg != nil { if userMsg := util.ExtractUserText(messages); userMsg != nil {
epicycleId, _ = dao.ComposeSession.Insert(ctx, &entity.ComposeSession{ epicycleId, _ = dao.ComposeSession.Insert(ctx, &entity.ComposeSession{
NodeId: nodeId, NodeId: nodeId,
SessionId: sessionId, SessionId: sessionId,
@@ -208,7 +228,7 @@ func handleCallbackSuccess(ctx context.Context, req *dto.CallbackReq, composeTas
} }
// 4) 合并附加结构 // 4) 合并附加结构
messages := util.MergeConsult(composeTask.RequestPayload, req.Messages, model.ExtendMapping) messages = util.MergeConsult(composeTask.RequestPayload, messages, model.ExtendMapping)
// 5) 注入历史 // 5) 注入历史
if len(history) > 0 { if len(history) > 0 {
messages = InjectHistory(messages, history, protocol) messages = InjectHistory(messages, history, protocol)