refactor(task): 重构异步任务处理流程

This commit is contained in:
2026-05-27 09:36:26 +08:00
parent 2548ffc7ac
commit d74559ae74
10 changed files with 162 additions and 212 deletions

View File

@@ -25,6 +25,7 @@ type UserPromptPayload struct {
Consult []dto.ConsultItem `json:"consult"`
UserFilesText map[string]string `json:"userFilesText"`
Skills string `json:"skills"`
BuildType int `json:"buildType"`
}
// buildInferenceRequest 构建推理请求
@@ -33,9 +34,7 @@ func buildInferenceRequest(ctx context.Context, req *dto.ComposeMessagesReq, cha
if err != nil {
return nil, fmt.Errorf("处理用户表单分批失败: %w", err)
}
ir := NewPromptIR()
switch req.BuildType {
case public.BuildTypePrompt:
return buildPromptTypeRequest(ctx, processedReq, aiModel, chatModel, history, ir, totalBatches)
@@ -65,11 +64,6 @@ func buildPromptTypeRequest(ctx context.Context, req *dto.ComposeMessagesReq, ai
availableWindow := util.GetAvailableWindow(aiModel.TokenConfig)
return nil, fmt.Errorf("整体内容超出模型窗口大小限制(可用窗口=%d tokens),请精简后重试", availableWindow)
}
// 记录历史会话
_, _ = dao.ComposeSession.Insert(ctx, &entity.ComposeSession{
SessionId: req.SessionId,
RequestContent: ir.User,
})
return compileToProviderRequest(ctx, ir, chatModel)
}
@@ -168,6 +162,7 @@ func buildUserPrompt(ctx context.Context, req *dto.ComposeMessagesReq, prompt st
Consult: req.Consult,
UserFilesText: ExtractFileTexts(ctx, req.Consult),
Skills: SkillMdContent(ctx, req.SkillName),
BuildType: req.BuildType,
}
return gjson.New(payload).String()
}

View File

@@ -5,10 +5,10 @@ import (
"encoding/json"
"errors"
"fmt"
"prompts-core/service/session"
"gitea.com/red-future/common/beans"
"gitea.com/red-future/common/utils"
"github.com/gogf/gf/v2/encoding/gjson"
"github.com/gogf/gf/v2/frame/g"
"prompts-core/common/util"
@@ -44,17 +44,27 @@ func GetModelMessage(ctx context.Context, req *dto.ComposeMessagesReq) (*entity.
if err != nil {
return nil, nil, fmt.Errorf("获取用户信息失败: %w", err)
}
chatModel, err := getChatModel(ctx, userInfo.UserName)
chatModel, err := dao.Model.Get(ctx, &entity.AsynchModel{
SQLBaseDO: beans.SQLBaseDO{Creator: userInfo.UserName},
IsChatModel: new(1),
})
if err != nil {
return nil, nil, err
}
if chatModel == nil {
return nil, nil, errors.New("当前没有对话模型,请添加")
}
aiModel, err := dao.Model.Get(ctx, &entity.AsynchModel{
SQLBaseDO: beans.SQLBaseDO{Creator: userInfo.UserName},
ModelName: req.ModelName,
})
if err != nil {
return nil, nil, err
}
aiModel, err := getAIModel(ctx, userInfo.UserName, req.ModelName)
if err != nil {
return nil, nil, err
if aiModel == nil {
return nil, nil, errors.New("需要构建的模型不存在")
}
return chatModel, aiModel, nil
}
@@ -73,51 +83,24 @@ func validateUserForm(req *dto.ComposeMessagesReq, model *entity.AsynchModel) er
return fmt.Errorf("UserForm 内容超出窗口大小: 超出 %d tokens可用窗口 %d tokens请精简后重试",
exceedTokens, availableWindow)
}
return nil
}
// handlePromptBuild 处理提示词构建BuildType=1
func handlePromptBuild(ctx context.Context, req *dto.ComposeMessagesReq, chatModel, aiModel *entity.AsynchModel) (*dto.ComposeMessagesRes, error) {
// 获取历史会话
history, err := GetHistoryMessages(ctx, req.SessionId)
history, err := session.GetHistoryMessages(ctx, req.SessionId)
if err != nil {
g.Log().Errorf(ctx, "获取历史会话失败: %v将不使用历史会话", err)
history = nil
}
// 调用推理模型
taskID, err := callInferenceModel(ctx, req, chatModel, aiModel, history)
taskID, id, err := callInferenceModel(ctx, req, chatModel, aiModel, history)
if err != nil {
return nil, fmt.Errorf("调用推理模型失败: %w", err)
}
// 保存任务记录
if err = saveComposeTask(ctx, taskID, req); err != nil {
return nil, fmt.Errorf("保存任务记录失败: %w", err)
}
return &dto.ComposeMessagesRes{
TaskId: taskID,
}, nil
}
// handleNodeBuild 处理节点构建BuildType=2
func handleNodeBuild(ctx context.Context, req *dto.ComposeMessagesReq, chatModel, aiModel *entity.AsynchModel) (*dto.ComposeMessagesRes, error) {
taskID, err := callInferenceModel(ctx, req, chatModel, aiModel, nil)
if err != nil {
return nil, fmt.Errorf("调用推理模型失败: %w", err)
}
if err := saveComposeTask(ctx, taskID, req); err != nil {
return nil, fmt.Errorf("保存任务记录失败: %w", err)
}
return &dto.ComposeMessagesRes{
TaskId: taskID,
}, nil
}
// saveComposeTask 保存组合任务记录
func saveComposeTask(ctx context.Context, taskID string, req *dto.ComposeMessagesReq) error {
_, err := dao.ComposeTask.Insert(ctx, &entity.ComposeTask{
_, err = dao.ComposeTask.Insert(ctx, &entity.ComposeTask{
TaskId: taskID,
ModelName: req.ModelName,
SkillName: req.SkillName,
@@ -126,77 +109,70 @@ func saveComposeTask(ctx context.Context, taskID string, req *dto.ComposeMessage
RequestPayload: util.MustMarshalToMap(req),
Status: public.ComposeStatusPending,
})
return err
if err != nil {
return nil, fmt.Errorf("保存任务记录失败: %w", err)
}
return &dto.ComposeMessagesRes{
TaskId: taskID,
EpicycleId: id,
}, nil
}
// getChatModel 获取聊天模型
func getChatModel(ctx context.Context, userName string) (*entity.AsynchModel, error) {
chatModel, err := dao.Model.Get(ctx, &entity.AsynchModel{
SQLBaseDO: beans.SQLBaseDO{Creator: userName},
IsChatModel: new(1),
// handleNodeBuild 处理节点构建BuildType=2
func handleNodeBuild(ctx context.Context, req *dto.ComposeMessagesReq, chatModel, aiModel *entity.AsynchModel) (*dto.ComposeMessagesRes, error) {
taskID, id, err := callInferenceModel(ctx, req, chatModel, aiModel, nil)
if err != nil {
return nil, fmt.Errorf("调用推理模型失败: %w", err)
}
// 保存任务记录
_, err = dao.ComposeTask.Insert(ctx, &entity.ComposeTask{
TaskId: taskID,
ModelName: req.ModelName,
SkillName: req.SkillName,
BuildType: req.BuildType,
CallbackUrl: req.CallbackUrl,
RequestPayload: util.MustMarshalToMap(req),
Status: public.ComposeStatusPending,
})
if err != nil {
return nil, fmt.Errorf("查询聊天模型失败: %w", err)
return nil, fmt.Errorf("保存任务记录失败: %w", err)
}
if chatModel == nil {
return nil, errors.New("当前没有对话模型,请添加")
}
return chatModel, nil
}
// getAIModel 获取AI模型
func getAIModel(ctx context.Context, userName, modelName string) (*entity.AsynchModel, error) {
aiModel, err := dao.Model.Get(ctx, &entity.AsynchModel{
SQLBaseDO: beans.SQLBaseDO{Creator: userName},
ModelName: modelName,
})
if err != nil {
return nil, fmt.Errorf("查询AI模型失败: %w", err)
}
if aiModel == nil {
return nil, fmt.Errorf("需要构建的模型 %s 不存在", modelName)
}
return aiModel, nil
return &dto.ComposeMessagesRes{
TaskId: taskID,
EpicycleId: id,
}, nil
}
// callInferenceModel 调用推理模型
func callInferenceModel(ctx context.Context, req *dto.ComposeMessagesReq, chatModel *entity.AsynchModel, idModel *entity.AsynchModel, history []map[string]any) (string, error) {
taskReq, err := buildInferenceRequest(ctx, req, chatModel, idModel, history)
func callInferenceModel(ctx context.Context, req *dto.ComposeMessagesReq, chatModel *entity.AsynchModel, aiModel *entity.AsynchModel, history []map[string]any) (string, int64, error) {
taskReq, err := buildInferenceRequest(ctx, req, chatModel, aiModel, history)
if err != nil {
return "", fmt.Errorf("构建推理请求失败: %w", err)
return "", 0, fmt.Errorf("构建推理请求失败: %w", err)
}
id, err := dao.ComposeSession.Insert(ctx, &entity.ComposeSession{
SessionId: req.SessionId,
RequestContent: util.GetUserMessage(taskReq),
})
if err != nil {
return "", 0, fmt.Errorf("保存历史会话失败: %w", err)
}
taskID, err := gateway.CreateGatewayTask(ctx, taskReq)
if err != nil {
return "", fmt.Errorf("创建网关任务失败: %w", err)
return "", 0, fmt.Errorf("创建网关任务失败: %w", err)
}
if taskID == "" {
return "", errors.New("网关未返回taskId")
return "", 0, errors.New("网关未返回taskId")
}
return taskID, nil
}
// createDefaultResult 创建默认结果
func createDefaultResult(data map[string]any) map[string]any {
if data == nil {
data = make(map[string]any)
}
return map[string]any{
"total_rounds": 1,
"rounds": []map[string]any{data},
}
return taskID, id, nil
}
// Callback 回调处理
func Callback(ctx context.Context, req *dto.CallbackReq) error {
g.Log().Infof(ctx, "[Callback][RECV] taskId=%s state=%d ossFile=%s fileType=%s textLen=%d",
req.TaskId, req.State, req.OssFile, req.FileType, len(req.Text))
req.TaskId, req.State, req.OssFile, req.FileType, len(req.Messages))
// 查询任务
composeTask, err := dao.ComposeTask.Get(ctx, &entity.ComposeTask{
TaskId: req.TaskId,
@@ -220,7 +196,7 @@ func Callback(ctx context.Context, req *dto.CallbackReq) error {
GatewayState: req.State,
OssFile: req.OssFile,
FileType: req.FileType,
ResultText: req.Text,
ResultText: req.Messages,
})
// 用更新后的值发送回调
if composeTask.CallbackUrl != "" {
@@ -241,11 +217,11 @@ func Callback(ctx context.Context, req *dto.CallbackReq) error {
var messages map[string]any
switch composeTask.BuildType {
case public.BuildTypePrompt: // 提示词构建解析
messages = ParsePromptResult(req.Text)
messages = ParsePromptResult(req.Messages)
case public.BuildTypeNode: // 节点构建解析
messages = ParseNodeResult(req.Text)
messages = ParseNodeResult(req.Messages)
default:
messages = gjson.New(req.Text).Map()
messages = req.Messages
}
// 2. 处理附加字段
messages = util.MergeConsult(composeTask.RequestPayload, messages, model.ExtendMapping)
@@ -257,7 +233,7 @@ func Callback(ctx context.Context, req *dto.CallbackReq) error {
GatewayState: req.State,
OssFile: req.OssFile,
FileType: req.FileType,
ResultText: req.Text,
ResultText: req.Messages,
})
if err != nil {
g.Log().Errorf(ctx, "[Callback] 更新成功状态失败 taskId=%s err=%v", req.TaskId, err)
@@ -278,18 +254,12 @@ func Callback(ctx context.Context, req *dto.CallbackReq) error {
}
// ParsePromptResult 解析提示词构建结果
func ParsePromptResult(raw string) map[string]any {
var wrapper map[string]any
if err := json.Unmarshal([]byte(raw), &wrapper); err != nil {
return createDefaultResult(map[string]any{"raw": raw})
}
contentStr, ok := wrapper["content"].(string)
func ParsePromptResult(raw map[string]any) map[string]any {
contentStr, ok := raw["content"].(string)
if !ok || contentStr == "" {
return createDefaultResult(wrapper)
return raw
}
// 先尝试解析为数组
if roundsArray := tryParseAsMapArray(contentStr); roundsArray != nil {
return map[string]any{
"total_rounds": len(roundsArray),
@@ -297,7 +267,6 @@ func ParsePromptResult(raw string) map[string]any {
}
}
// 再尝试解析为单个对象
if singleRound := tryParseAsMap(contentStr); singleRound != nil {
return map[string]any{
"total_rounds": 1,
@@ -305,7 +274,7 @@ func ParsePromptResult(raw string) map[string]any {
}
}
return createDefaultResult(map[string]any{"content": contentStr})
return map[string]any{"content": contentStr}
}
func tryParseAsMapArray(jsonStr string) []map[string]any {
@@ -330,22 +299,20 @@ func tryParseAsMap(jsonStr string) map[string]any {
return obj
}
// ParseNodeResult 解析节点构建结果
func ParseNodeResult(raw string) map[string]any {
var result map[string]any
if err := json.Unmarshal([]byte(raw), &result); err != nil {
return createDefaultResult(map[string]any{"raw": raw})
}
if contentStr, ok := result["content"].(string); ok && contentStr != "" {
func ParseNodeResult(raw map[string]any) map[string]any {
contentStr, ok := raw["content"].(string)
if ok && contentStr != "" {
var inner map[string]any
if err := json.Unmarshal([]byte(contentStr), &inner); err == nil {
result = inner
return map[string]any{
"total_rounds": 1,
"rounds": []map[string]any{inner},
}
}
}
return map[string]any{
"total_rounds": 1,
"rounds": []map[string]any{result},
"rounds": []map[string]any{raw},
}
}

View File

@@ -1,9 +1,10 @@
package prompt
package session
import (
"context"
"encoding/json"
"fmt"
"prompts-core/model/entity"
"time"
"github.com/gogf/gf/v2/frame/g"
@@ -13,37 +14,33 @@ const (
redisKeyPrefix = "chat:session:%s"
)
// saveToRedis 保存会话数据到Redis
func saveToRedis(ctx context.Context, sessionId string, requestMessages []map[string]any, responseMessages []map[string]any) error {
key := formatRedisKey(sessionId)
// formatRedisKey 格式化Redis
func formatRedisKey(sessionId string) string {
return fmt.Sprintf(redisKeyPrefix, sessionId)
}
// saveToRedis 保存会话数据到Redis
func saveToRedis(ctx context.Context, session *entity.ComposeSession) error {
key := formatRedisKey(session.SessionId)
maxRounds := g.Cfg().MustGet(ctx, "session.maxRounds", 10).Int()
expireSeconds := g.Cfg().MustGet(ctx, "session.expireTime", 1800).Int64()
data := map[string]any{
"sessionId": sessionId,
"requestContent": requestMessages,
"responseContent": responseMessages,
"sessionId": session.SessionId,
"requestContent": session.RequestContent,
"responseContent": session.ResponseContent,
"timestamp": time.Now().Unix(),
}
b, err := json.Marshal(data)
if err != nil {
return fmt.Errorf("序列化会话数据失败: %w", err)
}
if err := executeRedisCommands(ctx, key, string(b), maxRounds, expireSeconds); err != nil {
if err = executeRedisCommands(ctx, key, string(b), maxRounds, expireSeconds); err != nil {
return err
}
return nil
}
// formatRedisKey 格式化Redis键
func formatRedisKey(sessionId string) string {
return fmt.Sprintf(redisKeyPrefix, sessionId)
}
// executeRedisCommands 执行Redis命令
func executeRedisCommands(ctx context.Context, key string, value string, maxRounds int, expireSeconds int64) error {
if _, err := g.Redis().Do(ctx, "LPUSH", key, value); err != nil {

View File

@@ -1,4 +1,4 @@
package prompt
package session
import (
"context"
@@ -14,74 +14,36 @@ import (
"prompts-core/model/entity"
)
// SessionCallback 会话回调
func SessionCallback(ctx context.Context, req *dto.SessionCallbackReq) (*dto.SessionCallbackRes, error) {
result, err := util.ParseOutput(req.Text)
if err != nil {
g.Log().Errorf(ctx, "[会话回调] 解析模型输出失败 epicycleId=%d err=%v", req.EpicycleId, err)
return nil, fmt.Errorf("解析模型输出失败: %w", err)
}
result["role"] = "assistant"
if err = updateSessionResponse(ctx, req.EpicycleId, result); err != nil {
return nil, err
}
session, err := getSessionById(ctx, req.EpicycleId)
if err != nil {
return nil, err
}
if err := saveSessionToRedis(ctx, session); err != nil {
return nil, err
}
requestMessages := util.ConvertToMessages(session.RequestContent)
responseMessages := util.ConvertToMessages(session.ResponseContent)
g.Log().Infof(ctx, "[会话回调] 存储成功 sessionId=%s id=%d requestLen=%d responseLen=%d",
session.SessionId, session.Id, len(requestMessages), len(responseMessages))
return &dto.SessionCallbackRes{}, nil
}
// updateSessionResponse 更新会话响应
func updateSessionResponse(ctx context.Context, epicycleId int64, response any) error {
// Callback 会话回调
func Callback(ctx context.Context, req *dto.SessionCallbackReq) (*dto.SessionCallbackRes, error) {
req.Messages["role"] = "assistant"
_, err := dao.ComposeSession.Update(ctx, &entity.ComposeSession{
SQLBaseDO: beans.SQLBaseDO{Id: epicycleId},
ResponseContent: response,
SQLBaseDO: beans.SQLBaseDO{Id: req.EpicycleId},
ResponseContent: req.Messages,
})
if err != nil {
g.Log().Errorf(ctx, "[会话回调] 更新数据库失败 epicycleId=%d err=%v", epicycleId, err)
return fmt.Errorf("更新数据库失败: %w", err)
g.Log().Errorf(ctx, "[会话回调] 更新数据库失败 epicycleId=%d err=%v", req.EpicycleId, err)
return nil, fmt.Errorf("更新数据库失败: %w", err)
}
return nil
}
// getSessionById 根据ID获取会话
func getSessionById(ctx context.Context, epicycleId int64) (*entity.ComposeSession, error) {
session, err := dao.ComposeSession.Get(ctx, &entity.ComposeSession{
SQLBaseDO: beans.SQLBaseDO{Id: epicycleId},
SQLBaseDO: beans.SQLBaseDO{Id: req.EpicycleId},
})
if session == nil {
return nil, fmt.Errorf("会话不存在: epicycleId=%d", req.EpicycleId)
}
if err != nil {
g.Log().Errorf(ctx, "[会话回调] 获取会话数据失败 epicycleId=%d err=%v", epicycleId, err)
g.Log().Errorf(ctx, "[会话回调] 获取会话数据失败 epicycleId=%d err=%v", req.EpicycleId, err)
return nil, fmt.Errorf("获取会话数据失败: %w", err)
}
return session, nil
}
// saveSessionToRedis 保存会话到Redis
func saveSessionToRedis(ctx context.Context, session *entity.ComposeSession) error {
requestMessages := util.ConvertToMessages(session.RequestContent)
responseMessages := util.ConvertToMessages(session.ResponseContent)
if err := saveToRedis(ctx, session.SessionId, requestMessages, responseMessages); err != nil {
g.Log().Errorf(ctx, "[会话回调] Redis存储失败 sessionId=%s id=%d err=%v",
session.SessionId, session.Id, err)
return fmt.Errorf("Redis存储失败: %w", err)
if err = saveToRedis(ctx, session); err != nil {
return nil, fmt.Errorf("redis存储失败: %w", err)
}
return nil
g.Log().Infof(ctx, "[会话回调] 存储成功 sessionId=%s id=%d requestLen=%d responseLen=%d",
session.SessionId, session.Id, len(session.RequestContent), len(session.ResponseContent))
return &dto.SessionCallbackRes{
Status: true,
SessionId: session.SessionId,
}, nil
}
// GetHistoryMessages 获取历史信息
@@ -159,7 +121,7 @@ func cacheSessionsToRedis(ctx context.Context, sessions []*entity.ComposeSession
}
if len(reqMsgs) > 0 || len(respMsgs) > 0 {
_ = saveToRedis(ctx, session.SessionId, reqMsgs, respMsgs)
_ = saveToRedis(ctx, session)
}
}
}