refactor(task): 重构异步任务处理流程
This commit is contained in:
@@ -25,6 +25,7 @@ type UserPromptPayload struct {
|
||||
Consult []dto.ConsultItem `json:"consult"`
|
||||
UserFilesText map[string]string `json:"userFilesText"`
|
||||
Skills string `json:"skills"`
|
||||
BuildType int `json:"buildType"`
|
||||
}
|
||||
|
||||
// buildInferenceRequest 构建推理请求
|
||||
@@ -33,9 +34,7 @@ func buildInferenceRequest(ctx context.Context, req *dto.ComposeMessagesReq, cha
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("处理用户表单分批失败: %w", err)
|
||||
}
|
||||
|
||||
ir := NewPromptIR()
|
||||
|
||||
switch req.BuildType {
|
||||
case public.BuildTypePrompt:
|
||||
return buildPromptTypeRequest(ctx, processedReq, aiModel, chatModel, history, ir, totalBatches)
|
||||
@@ -65,11 +64,6 @@ func buildPromptTypeRequest(ctx context.Context, req *dto.ComposeMessagesReq, ai
|
||||
availableWindow := util.GetAvailableWindow(aiModel.TokenConfig)
|
||||
return nil, fmt.Errorf("整体内容超出模型窗口大小限制(可用窗口=%d tokens),请精简后重试", availableWindow)
|
||||
}
|
||||
// 记录历史会话
|
||||
_, _ = dao.ComposeSession.Insert(ctx, &entity.ComposeSession{
|
||||
SessionId: req.SessionId,
|
||||
RequestContent: ir.User,
|
||||
})
|
||||
return compileToProviderRequest(ctx, ir, chatModel)
|
||||
}
|
||||
|
||||
@@ -168,6 +162,7 @@ func buildUserPrompt(ctx context.Context, req *dto.ComposeMessagesReq, prompt st
|
||||
Consult: req.Consult,
|
||||
UserFilesText: ExtractFileTexts(ctx, req.Consult),
|
||||
Skills: SkillMdContent(ctx, req.SkillName),
|
||||
BuildType: req.BuildType,
|
||||
}
|
||||
return gjson.New(payload).String()
|
||||
}
|
||||
|
||||
@@ -5,10 +5,10 @@ import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"prompts-core/service/session"
|
||||
|
||||
"gitea.com/red-future/common/beans"
|
||||
"gitea.com/red-future/common/utils"
|
||||
"github.com/gogf/gf/v2/encoding/gjson"
|
||||
"github.com/gogf/gf/v2/frame/g"
|
||||
|
||||
"prompts-core/common/util"
|
||||
@@ -44,17 +44,27 @@ func GetModelMessage(ctx context.Context, req *dto.ComposeMessagesReq) (*entity.
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("获取用户信息失败: %w", err)
|
||||
}
|
||||
|
||||
chatModel, err := getChatModel(ctx, userInfo.UserName)
|
||||
chatModel, err := dao.Model.Get(ctx, &entity.AsynchModel{
|
||||
SQLBaseDO: beans.SQLBaseDO{Creator: userInfo.UserName},
|
||||
IsChatModel: new(1),
|
||||
})
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
if chatModel == nil {
|
||||
return nil, nil, errors.New("当前没有对话模型,请添加")
|
||||
}
|
||||
aiModel, err := dao.Model.Get(ctx, &entity.AsynchModel{
|
||||
SQLBaseDO: beans.SQLBaseDO{Creator: userInfo.UserName},
|
||||
ModelName: req.ModelName,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
aiModel, err := getAIModel(ctx, userInfo.UserName, req.ModelName)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
if aiModel == nil {
|
||||
return nil, nil, errors.New("需要构建的模型不存在")
|
||||
}
|
||||
|
||||
return chatModel, aiModel, nil
|
||||
}
|
||||
|
||||
@@ -73,51 +83,24 @@ func validateUserForm(req *dto.ComposeMessagesReq, model *entity.AsynchModel) er
|
||||
return fmt.Errorf("UserForm 内容超出窗口大小: 超出 %d tokens,可用窗口 %d tokens,请精简后重试",
|
||||
exceedTokens, availableWindow)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// handlePromptBuild 处理提示词构建(BuildType=1)
|
||||
func handlePromptBuild(ctx context.Context, req *dto.ComposeMessagesReq, chatModel, aiModel *entity.AsynchModel) (*dto.ComposeMessagesRes, error) {
|
||||
// 获取历史会话
|
||||
history, err := GetHistoryMessages(ctx, req.SessionId)
|
||||
history, err := session.GetHistoryMessages(ctx, req.SessionId)
|
||||
if err != nil {
|
||||
g.Log().Errorf(ctx, "获取历史会话失败: %v,将不使用历史会话", err)
|
||||
history = nil
|
||||
}
|
||||
// 调用推理模型
|
||||
taskID, err := callInferenceModel(ctx, req, chatModel, aiModel, history)
|
||||
taskID, id, err := callInferenceModel(ctx, req, chatModel, aiModel, history)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("调用推理模型失败: %w", err)
|
||||
}
|
||||
// 保存任务记录
|
||||
if err = saveComposeTask(ctx, taskID, req); err != nil {
|
||||
return nil, fmt.Errorf("保存任务记录失败: %w", err)
|
||||
}
|
||||
return &dto.ComposeMessagesRes{
|
||||
TaskId: taskID,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// handleNodeBuild 处理节点构建(BuildType=2)
|
||||
func handleNodeBuild(ctx context.Context, req *dto.ComposeMessagesReq, chatModel, aiModel *entity.AsynchModel) (*dto.ComposeMessagesRes, error) {
|
||||
taskID, err := callInferenceModel(ctx, req, chatModel, aiModel, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("调用推理模型失败: %w", err)
|
||||
}
|
||||
|
||||
if err := saveComposeTask(ctx, taskID, req); err != nil {
|
||||
return nil, fmt.Errorf("保存任务记录失败: %w", err)
|
||||
}
|
||||
|
||||
return &dto.ComposeMessagesRes{
|
||||
TaskId: taskID,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// saveComposeTask 保存组合任务记录
|
||||
func saveComposeTask(ctx context.Context, taskID string, req *dto.ComposeMessagesReq) error {
|
||||
_, err := dao.ComposeTask.Insert(ctx, &entity.ComposeTask{
|
||||
_, err = dao.ComposeTask.Insert(ctx, &entity.ComposeTask{
|
||||
TaskId: taskID,
|
||||
ModelName: req.ModelName,
|
||||
SkillName: req.SkillName,
|
||||
@@ -126,77 +109,70 @@ func saveComposeTask(ctx context.Context, taskID string, req *dto.ComposeMessage
|
||||
RequestPayload: util.MustMarshalToMap(req),
|
||||
Status: public.ComposeStatusPending,
|
||||
})
|
||||
return err
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("保存任务记录失败: %w", err)
|
||||
}
|
||||
return &dto.ComposeMessagesRes{
|
||||
TaskId: taskID,
|
||||
EpicycleId: id,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// getChatModel 获取聊天模型
|
||||
func getChatModel(ctx context.Context, userName string) (*entity.AsynchModel, error) {
|
||||
chatModel, err := dao.Model.Get(ctx, &entity.AsynchModel{
|
||||
SQLBaseDO: beans.SQLBaseDO{Creator: userName},
|
||||
IsChatModel: new(1),
|
||||
// handleNodeBuild 处理节点构建(BuildType=2)
|
||||
func handleNodeBuild(ctx context.Context, req *dto.ComposeMessagesReq, chatModel, aiModel *entity.AsynchModel) (*dto.ComposeMessagesRes, error) {
|
||||
taskID, id, err := callInferenceModel(ctx, req, chatModel, aiModel, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("调用推理模型失败: %w", err)
|
||||
}
|
||||
// 保存任务记录
|
||||
_, err = dao.ComposeTask.Insert(ctx, &entity.ComposeTask{
|
||||
TaskId: taskID,
|
||||
ModelName: req.ModelName,
|
||||
SkillName: req.SkillName,
|
||||
BuildType: req.BuildType,
|
||||
CallbackUrl: req.CallbackUrl,
|
||||
RequestPayload: util.MustMarshalToMap(req),
|
||||
Status: public.ComposeStatusPending,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("查询聊天模型失败: %w", err)
|
||||
return nil, fmt.Errorf("保存任务记录失败: %w", err)
|
||||
}
|
||||
|
||||
if chatModel == nil {
|
||||
return nil, errors.New("当前没有对话模型,请添加")
|
||||
}
|
||||
|
||||
return chatModel, nil
|
||||
}
|
||||
|
||||
// getAIModel 获取AI模型
|
||||
func getAIModel(ctx context.Context, userName, modelName string) (*entity.AsynchModel, error) {
|
||||
aiModel, err := dao.Model.Get(ctx, &entity.AsynchModel{
|
||||
SQLBaseDO: beans.SQLBaseDO{Creator: userName},
|
||||
ModelName: modelName,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("查询AI模型失败: %w", err)
|
||||
}
|
||||
|
||||
if aiModel == nil {
|
||||
return nil, fmt.Errorf("需要构建的模型 %s 不存在", modelName)
|
||||
}
|
||||
|
||||
return aiModel, nil
|
||||
return &dto.ComposeMessagesRes{
|
||||
TaskId: taskID,
|
||||
EpicycleId: id,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// callInferenceModel 调用推理模型
|
||||
func callInferenceModel(ctx context.Context, req *dto.ComposeMessagesReq, chatModel *entity.AsynchModel, idModel *entity.AsynchModel, history []map[string]any) (string, error) {
|
||||
taskReq, err := buildInferenceRequest(ctx, req, chatModel, idModel, history)
|
||||
func callInferenceModel(ctx context.Context, req *dto.ComposeMessagesReq, chatModel *entity.AsynchModel, aiModel *entity.AsynchModel, history []map[string]any) (string, int64, error) {
|
||||
taskReq, err := buildInferenceRequest(ctx, req, chatModel, aiModel, history)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("构建推理请求失败: %w", err)
|
||||
return "", 0, fmt.Errorf("构建推理请求失败: %w", err)
|
||||
}
|
||||
id, err := dao.ComposeSession.Insert(ctx, &entity.ComposeSession{
|
||||
SessionId: req.SessionId,
|
||||
RequestContent: util.GetUserMessage(taskReq),
|
||||
})
|
||||
if err != nil {
|
||||
return "", 0, fmt.Errorf("保存历史会话失败: %w", err)
|
||||
}
|
||||
|
||||
taskID, err := gateway.CreateGatewayTask(ctx, taskReq)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("创建网关任务失败: %w", err)
|
||||
return "", 0, fmt.Errorf("创建网关任务失败: %w", err)
|
||||
}
|
||||
|
||||
if taskID == "" {
|
||||
return "", errors.New("网关未返回taskId")
|
||||
return "", 0, errors.New("网关未返回taskId")
|
||||
}
|
||||
|
||||
return taskID, nil
|
||||
}
|
||||
|
||||
// createDefaultResult 创建默认结果
|
||||
func createDefaultResult(data map[string]any) map[string]any {
|
||||
if data == nil {
|
||||
data = make(map[string]any)
|
||||
}
|
||||
return map[string]any{
|
||||
"total_rounds": 1,
|
||||
"rounds": []map[string]any{data},
|
||||
}
|
||||
return taskID, id, nil
|
||||
}
|
||||
|
||||
// Callback 回调处理
|
||||
func Callback(ctx context.Context, req *dto.CallbackReq) error {
|
||||
g.Log().Infof(ctx, "[Callback][RECV] taskId=%s state=%d ossFile=%s fileType=%s textLen=%d",
|
||||
req.TaskId, req.State, req.OssFile, req.FileType, len(req.Text))
|
||||
req.TaskId, req.State, req.OssFile, req.FileType, len(req.Messages))
|
||||
// 查询任务
|
||||
composeTask, err := dao.ComposeTask.Get(ctx, &entity.ComposeTask{
|
||||
TaskId: req.TaskId,
|
||||
@@ -220,7 +196,7 @@ func Callback(ctx context.Context, req *dto.CallbackReq) error {
|
||||
GatewayState: req.State,
|
||||
OssFile: req.OssFile,
|
||||
FileType: req.FileType,
|
||||
ResultText: req.Text,
|
||||
ResultText: req.Messages,
|
||||
})
|
||||
// 用更新后的值发送回调
|
||||
if composeTask.CallbackUrl != "" {
|
||||
@@ -241,11 +217,11 @@ func Callback(ctx context.Context, req *dto.CallbackReq) error {
|
||||
var messages map[string]any
|
||||
switch composeTask.BuildType {
|
||||
case public.BuildTypePrompt: // 提示词构建解析
|
||||
messages = ParsePromptResult(req.Text)
|
||||
messages = ParsePromptResult(req.Messages)
|
||||
case public.BuildTypeNode: // 节点构建解析
|
||||
messages = ParseNodeResult(req.Text)
|
||||
messages = ParseNodeResult(req.Messages)
|
||||
default:
|
||||
messages = gjson.New(req.Text).Map()
|
||||
messages = req.Messages
|
||||
}
|
||||
// 2. 处理附加字段
|
||||
messages = util.MergeConsult(composeTask.RequestPayload, messages, model.ExtendMapping)
|
||||
@@ -257,7 +233,7 @@ func Callback(ctx context.Context, req *dto.CallbackReq) error {
|
||||
GatewayState: req.State,
|
||||
OssFile: req.OssFile,
|
||||
FileType: req.FileType,
|
||||
ResultText: req.Text,
|
||||
ResultText: req.Messages,
|
||||
})
|
||||
if err != nil {
|
||||
g.Log().Errorf(ctx, "[Callback] 更新成功状态失败 taskId=%s err=%v", req.TaskId, err)
|
||||
@@ -278,18 +254,12 @@ func Callback(ctx context.Context, req *dto.CallbackReq) error {
|
||||
}
|
||||
|
||||
// ParsePromptResult 解析提示词构建结果
|
||||
func ParsePromptResult(raw string) map[string]any {
|
||||
var wrapper map[string]any
|
||||
if err := json.Unmarshal([]byte(raw), &wrapper); err != nil {
|
||||
return createDefaultResult(map[string]any{"raw": raw})
|
||||
}
|
||||
|
||||
contentStr, ok := wrapper["content"].(string)
|
||||
func ParsePromptResult(raw map[string]any) map[string]any {
|
||||
contentStr, ok := raw["content"].(string)
|
||||
if !ok || contentStr == "" {
|
||||
return createDefaultResult(wrapper)
|
||||
return raw
|
||||
}
|
||||
|
||||
// 先尝试解析为数组
|
||||
if roundsArray := tryParseAsMapArray(contentStr); roundsArray != nil {
|
||||
return map[string]any{
|
||||
"total_rounds": len(roundsArray),
|
||||
@@ -297,7 +267,6 @@ func ParsePromptResult(raw string) map[string]any {
|
||||
}
|
||||
}
|
||||
|
||||
// 再尝试解析为单个对象
|
||||
if singleRound := tryParseAsMap(contentStr); singleRound != nil {
|
||||
return map[string]any{
|
||||
"total_rounds": 1,
|
||||
@@ -305,7 +274,7 @@ func ParsePromptResult(raw string) map[string]any {
|
||||
}
|
||||
}
|
||||
|
||||
return createDefaultResult(map[string]any{"content": contentStr})
|
||||
return map[string]any{"content": contentStr}
|
||||
}
|
||||
|
||||
func tryParseAsMapArray(jsonStr string) []map[string]any {
|
||||
@@ -330,22 +299,20 @@ func tryParseAsMap(jsonStr string) map[string]any {
|
||||
return obj
|
||||
}
|
||||
|
||||
// ParseNodeResult 解析节点构建结果
|
||||
func ParseNodeResult(raw string) map[string]any {
|
||||
var result map[string]any
|
||||
if err := json.Unmarshal([]byte(raw), &result); err != nil {
|
||||
return createDefaultResult(map[string]any{"raw": raw})
|
||||
}
|
||||
|
||||
if contentStr, ok := result["content"].(string); ok && contentStr != "" {
|
||||
func ParseNodeResult(raw map[string]any) map[string]any {
|
||||
contentStr, ok := raw["content"].(string)
|
||||
if ok && contentStr != "" {
|
||||
var inner map[string]any
|
||||
if err := json.Unmarshal([]byte(contentStr), &inner); err == nil {
|
||||
result = inner
|
||||
return map[string]any{
|
||||
"total_rounds": 1,
|
||||
"rounds": []map[string]any{inner},
|
||||
}
|
||||
}
|
||||
}
|
||||
return map[string]any{
|
||||
"total_rounds": 1,
|
||||
"rounds": []map[string]any{result},
|
||||
"rounds": []map[string]any{raw},
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -1,9 +1,10 @@
|
||||
package prompt
|
||||
package session
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"prompts-core/model/entity"
|
||||
"time"
|
||||
|
||||
"github.com/gogf/gf/v2/frame/g"
|
||||
@@ -13,37 +14,33 @@ const (
|
||||
redisKeyPrefix = "chat:session:%s"
|
||||
)
|
||||
|
||||
// saveToRedis 保存会话数据到Redis
|
||||
func saveToRedis(ctx context.Context, sessionId string, requestMessages []map[string]any, responseMessages []map[string]any) error {
|
||||
key := formatRedisKey(sessionId)
|
||||
// formatRedisKey 格式化Redis键
|
||||
func formatRedisKey(sessionId string) string {
|
||||
return fmt.Sprintf(redisKeyPrefix, sessionId)
|
||||
}
|
||||
|
||||
// saveToRedis 保存会话数据到Redis
|
||||
func saveToRedis(ctx context.Context, session *entity.ComposeSession) error {
|
||||
key := formatRedisKey(session.SessionId)
|
||||
maxRounds := g.Cfg().MustGet(ctx, "session.maxRounds", 10).Int()
|
||||
expireSeconds := g.Cfg().MustGet(ctx, "session.expireTime", 1800).Int64()
|
||||
|
||||
data := map[string]any{
|
||||
"sessionId": sessionId,
|
||||
"requestContent": requestMessages,
|
||||
"responseContent": responseMessages,
|
||||
"sessionId": session.SessionId,
|
||||
"requestContent": session.RequestContent,
|
||||
"responseContent": session.ResponseContent,
|
||||
"timestamp": time.Now().Unix(),
|
||||
}
|
||||
|
||||
b, err := json.Marshal(data)
|
||||
if err != nil {
|
||||
return fmt.Errorf("序列化会话数据失败: %w", err)
|
||||
}
|
||||
|
||||
if err := executeRedisCommands(ctx, key, string(b), maxRounds, expireSeconds); err != nil {
|
||||
if err = executeRedisCommands(ctx, key, string(b), maxRounds, expireSeconds); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// formatRedisKey 格式化Redis键
|
||||
func formatRedisKey(sessionId string) string {
|
||||
return fmt.Sprintf(redisKeyPrefix, sessionId)
|
||||
}
|
||||
|
||||
// executeRedisCommands 执行Redis命令
|
||||
func executeRedisCommands(ctx context.Context, key string, value string, maxRounds int, expireSeconds int64) error {
|
||||
if _, err := g.Redis().Do(ctx, "LPUSH", key, value); err != nil {
|
||||
@@ -1,4 +1,4 @@
|
||||
package prompt
|
||||
package session
|
||||
|
||||
import (
|
||||
"context"
|
||||
@@ -14,74 +14,36 @@ import (
|
||||
"prompts-core/model/entity"
|
||||
)
|
||||
|
||||
// SessionCallback 会话回调
|
||||
func SessionCallback(ctx context.Context, req *dto.SessionCallbackReq) (*dto.SessionCallbackRes, error) {
|
||||
result, err := util.ParseOutput(req.Text)
|
||||
if err != nil {
|
||||
g.Log().Errorf(ctx, "[会话回调] 解析模型输出失败 epicycleId=%d err=%v", req.EpicycleId, err)
|
||||
return nil, fmt.Errorf("解析模型输出失败: %w", err)
|
||||
}
|
||||
|
||||
result["role"] = "assistant"
|
||||
if err = updateSessionResponse(ctx, req.EpicycleId, result); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
session, err := getSessionById(ctx, req.EpicycleId)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := saveSessionToRedis(ctx, session); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
requestMessages := util.ConvertToMessages(session.RequestContent)
|
||||
responseMessages := util.ConvertToMessages(session.ResponseContent)
|
||||
|
||||
g.Log().Infof(ctx, "[会话回调] 存储成功 sessionId=%s id=%d requestLen=%d responseLen=%d",
|
||||
session.SessionId, session.Id, len(requestMessages), len(responseMessages))
|
||||
|
||||
return &dto.SessionCallbackRes{}, nil
|
||||
}
|
||||
|
||||
// updateSessionResponse 更新会话响应
|
||||
func updateSessionResponse(ctx context.Context, epicycleId int64, response any) error {
|
||||
// Callback 会话回调
|
||||
func Callback(ctx context.Context, req *dto.SessionCallbackReq) (*dto.SessionCallbackRes, error) {
|
||||
req.Messages["role"] = "assistant"
|
||||
_, err := dao.ComposeSession.Update(ctx, &entity.ComposeSession{
|
||||
SQLBaseDO: beans.SQLBaseDO{Id: epicycleId},
|
||||
ResponseContent: response,
|
||||
SQLBaseDO: beans.SQLBaseDO{Id: req.EpicycleId},
|
||||
ResponseContent: req.Messages,
|
||||
})
|
||||
if err != nil {
|
||||
g.Log().Errorf(ctx, "[会话回调] 更新数据库失败 epicycleId=%d err=%v", epicycleId, err)
|
||||
return fmt.Errorf("更新数据库失败: %w", err)
|
||||
g.Log().Errorf(ctx, "[会话回调] 更新数据库失败 epicycleId=%d err=%v", req.EpicycleId, err)
|
||||
return nil, fmt.Errorf("更新数据库失败: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// getSessionById 根据ID获取会话
|
||||
func getSessionById(ctx context.Context, epicycleId int64) (*entity.ComposeSession, error) {
|
||||
session, err := dao.ComposeSession.Get(ctx, &entity.ComposeSession{
|
||||
SQLBaseDO: beans.SQLBaseDO{Id: epicycleId},
|
||||
SQLBaseDO: beans.SQLBaseDO{Id: req.EpicycleId},
|
||||
})
|
||||
if session == nil {
|
||||
return nil, fmt.Errorf("会话不存在: epicycleId=%d", req.EpicycleId)
|
||||
}
|
||||
if err != nil {
|
||||
g.Log().Errorf(ctx, "[会话回调] 获取会话数据失败 epicycleId=%d err=%v", epicycleId, err)
|
||||
g.Log().Errorf(ctx, "[会话回调] 获取会话数据失败 epicycleId=%d err=%v", req.EpicycleId, err)
|
||||
return nil, fmt.Errorf("获取会话数据失败: %w", err)
|
||||
}
|
||||
return session, nil
|
||||
}
|
||||
|
||||
// saveSessionToRedis 保存会话到Redis
|
||||
func saveSessionToRedis(ctx context.Context, session *entity.ComposeSession) error {
|
||||
requestMessages := util.ConvertToMessages(session.RequestContent)
|
||||
responseMessages := util.ConvertToMessages(session.ResponseContent)
|
||||
|
||||
if err := saveToRedis(ctx, session.SessionId, requestMessages, responseMessages); err != nil {
|
||||
g.Log().Errorf(ctx, "[会话回调] Redis存储失败 sessionId=%s id=%d err=%v",
|
||||
session.SessionId, session.Id, err)
|
||||
return fmt.Errorf("Redis存储失败: %w", err)
|
||||
if err = saveToRedis(ctx, session); err != nil {
|
||||
return nil, fmt.Errorf("redis存储失败: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
g.Log().Infof(ctx, "[会话回调] 存储成功 sessionId=%s id=%d requestLen=%d responseLen=%d",
|
||||
session.SessionId, session.Id, len(session.RequestContent), len(session.ResponseContent))
|
||||
return &dto.SessionCallbackRes{
|
||||
Status: true,
|
||||
SessionId: session.SessionId,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// GetHistoryMessages 获取历史信息
|
||||
@@ -159,7 +121,7 @@ func cacheSessionsToRedis(ctx context.Context, sessions []*entity.ComposeSession
|
||||
}
|
||||
|
||||
if len(reqMsgs) > 0 || len(respMsgs) > 0 {
|
||||
_ = saveToRedis(ctx, session.SessionId, reqMsgs, respMsgs)
|
||||
_ = saveToRedis(ctx, session)
|
||||
}
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user