feat(session): 重构会话管理和消息存储功能

This commit is contained in:
2026-06-09 15:46:09 +08:00
parent 9410199fbe
commit 78114f99c7
8 changed files with 221 additions and 203 deletions

View File

@@ -164,7 +164,7 @@ func SendCallback(ctx context.Context, composeTask *entity.ComposeTask, epicycle
req := SendCallbackReq{
TaskId: composeTask.TaskId,
Status: composeTask.Status,
Messages: composeTask.Messages,
Messages: composeTask.ResultJson,
ErrorMsg: composeTask.ErrorMessage,
EpicycleId: epicycleId,
}

View File

@@ -154,7 +154,7 @@ func handleCallbackFailed(ctx context.Context, req *dto.CallbackReq, composeTask
GatewayState: req.State,
OssFile: req.OssFile,
FileType: req.FileType,
ResultText: req.Messages,
ResultJson: req.Messages,
})
if composeTask.CallbackUrl != "" {
composeTask.Status = public.ComposeStatusFailed
@@ -181,11 +181,10 @@ func handleCallbackSuccess(ctx context.Context, req *dto.CallbackReq, composeTas
_, err = dao.ComposeTask.Update(ctx, &entity.ComposeTask{
TaskId: req.TaskId,
Status: public.ComposeStatusSuccess,
Messages: messages,
GatewayState: req.State,
OssFile: req.OssFile,
FileType: req.FileType,
ResultText: req.Messages,
ResultJson: messages,
})
if err != nil {
return err
@@ -214,7 +213,7 @@ func handleCallbackSuccess(ctx context.Context, req *dto.CallbackReq, composeTas
// 6) 回调业务方
if composeTask.CallbackUrl != "" {
composeTask.Status = public.ComposeStatusSuccess
composeTask.Messages = messages
composeTask.ResultJson = messages
_ = gateway.SendCallback(ctx, composeTask, epicycleId)
}
return nil
@@ -232,7 +231,7 @@ func GetComposeTask(ctx context.Context, taskID string) (*dto.GetComposeTaskRes,
return nil, fmt.Errorf("未找到任务(taskId=%s)", taskID)
}
messages := parseMessagesForResponse(record.Messages)
messages := parseMessagesForResponse(record.ResultJson)
return &dto.GetComposeTaskRes{
TaskId: record.TaskId,

View File

@@ -9,6 +9,7 @@ import (
"time"
"github.com/gogf/gf/v2/frame/g"
"github.com/gogf/gf/v2/util/gconv"
)
const (
@@ -51,33 +52,34 @@ func SaveToRedis(ctx context.Context, tenantID uint64, sessionID string, round *
return nil
}
// DeleteSingleMessage 删除 Redis 中条消息按消息ID
func DeleteSingleMessage(ctx context.Context, tenantID uint64, sessionID string, msgID int64) error {
// DeleteRedisMessages 批量删除 Redis 中条消息按消息ID列表
func DeleteRedisMessages(ctx context.Context, tenantID uint64, sessionID string, msgIDs []int64) error {
key := formatRedisKey(tenantID, sessionID)
cursor := "0"
for {
result, err := g.Redis().Do(ctx, "ZSCAN", key, cursor, "MATCH", fmt.Sprintf("*\"id\":%d*", msgID), "COUNT", 10)
if err != nil {
return fmt.Errorf("ZSCAN失败: %w", err)
}
parts := result.Strings()
if len(parts) < 2 {
break
}
cursor = parts[0]
members := parts[1:]
for _, member := range members {
if _, err := g.Redis().Do(ctx, "ZREM", key, member); err != nil {
g.Log().Warningf(ctx, "[会话Redis] ZREM单条失败 key=%s err=%v", key, err)
for _, msgID := range msgIDs {
cursor := "0"
for {
result, err := g.Redis().Do(ctx, "ZSCAN", key, cursor, "MATCH", fmt.Sprintf("*\"id\":%d*", msgID), "COUNT", 10)
if err != nil {
g.Log().Warningf(ctx, "[会话Redis] ZSCAN失败 msgID=%d err=%v", msgID, err)
break
}
}
if cursor == "0" {
break
parts := result.Strings()
if len(parts) < 2 {
break
}
cursor = parts[0]
for _, member := range parts[1:] {
if _, err := g.Redis().Do(ctx, "ZREM", key, member); err != nil {
g.Log().Warningf(ctx, "[会话Redis] ZREM失败 err=%v", err)
}
}
if cursor == "0" {
break
}
}
}
@@ -95,8 +97,8 @@ func DeleteSessionHistory(ctx context.Context, tenantID uint64, sessionID string
// 读操作
// ============================================
// GetFromRedis 从 Redis ZSET 获取会话历史
func GetFromRedis(ctx context.Context, tenantID uint64, sessionID string) ([]map[string]any, error) {
// GetFromRedis 从 Redis ZSET 获取会话历史,返回 HistoryRound 切片
func GetFromRedis(ctx context.Context, tenantID uint64, sessionID string) ([]dto.HistoryRound, error) {
key := formatRedisKey(tenantID, sessionID)
maxRounds := util.GetMaxRounds(ctx)
@@ -106,64 +108,46 @@ func GetFromRedis(ctx context.Context, tenantID uint64, sessionID string) ([]map
}
if result == nil || result.IsNil() {
return []map[string]any{}, nil
return []dto.HistoryRound{}, nil
}
return parseRedisRounds(ctx, result.Strings()), nil
}
// GetSessionHistoryForInference 获取扁平消息数组(给推理用)
func GetSessionHistoryForInference(ctx context.Context, tenantID uint64, sessionID string) ([]map[string]any, error) {
rounds, err := GetFromRedis(ctx, tenantID, sessionID)
if err != nil {
return nil, fmt.Errorf("获取历史会话失败: %w", err)
}
if len(rounds) == 0 {
return []map[string]any{}, nil
}
return flattenRounds(rounds), nil
return parseRounds(result.Strings()), nil
}
// ============================================
// 解析
// ============================================
func parseRedisRounds(ctx context.Context, members []string) []map[string]any {
rounds := make([]map[string]any, 0, len(members))
// parseRounds 解析 Redis ZSET members 为 HistoryRound 切片
func parseRounds(members []string) []dto.HistoryRound {
rounds := make([]dto.HistoryRound, 0, len(members))
for _, member := range members {
var data map[string]any
if err := json.Unmarshal([]byte(member), &data); err != nil {
g.Log().Warningf(ctx, "[会话Redis] 解析数据失败 err=%v", err)
var round dto.HistoryRound
if err := json.Unmarshal([]byte(member), &round); err != nil {
continue
}
rounds = append(rounds, data)
if round.User != nil || round.Assistant != nil {
rounds = append(rounds, round)
}
}
return rounds
}
func flattenRounds(rounds []map[string]any) []map[string]any {
var messages []map[string]any
func flattenRounds(rounds []dto.HistoryRound) []dto.FlatMessage {
var messages []dto.FlatMessage
for i := len(rounds) - 1; i >= 0; i-- {
if user, ok := rounds[i]["user"].(map[string]any); ok && len(user) > 0 {
messages = append(messages, user)
if rounds[i].User != nil && gconv.String(rounds[i].User["content"]) != "" {
messages = append(messages, dto.FlatMessage{
Role: gconv.String(rounds[i].User["role"]),
Content: gconv.String(rounds[i].User["content"]),
})
}
if assistant, ok := rounds[i]["assistant"].(map[string]any); ok && len(assistant) > 0 {
messages = append(messages, assistant)
if rounds[i].Assistant != nil && gconv.String(rounds[i].Assistant["content"]) != "" {
messages = append(messages, dto.FlatMessage{
Role: gconv.String(rounds[i].Assistant["role"]),
Content: gconv.String(rounds[i].Assistant["content"]),
})
}
}
return messages
}
func appendFieldToMessages(data map[string]any, field string, messages *[]map[string]any) {
msgs, ok := data[field].([]any)
if !ok {
return
}
for _, m := range msgs {
if msg, ok := m.(map[string]any); ok {
*messages = append(*messages, msg)
}
}
}

View File

@@ -15,9 +15,14 @@ import (
"prompts-core/model/entity"
)
// ============================================
// 回调存储
// ============================================
// Callback 会话回调
func Callback(ctx context.Context, req *dto.SessionCallbackReq) (*dto.SessionCallbackRes, error) {
req.Messages["role"] = "assistant"
// 1) 更新 DB
_, err := dao.ComposeSession.Update(ctx, &entity.ComposeSession{
SQLBaseDO: beans.SQLBaseDO{Id: req.EpicycleId},
@@ -36,22 +41,42 @@ func Callback(ctx context.Context, req *dto.SessionCallbackReq) (*dto.SessionCal
return nil, fmt.Errorf("会话不存在: epicycleId=%d", req.EpicycleId)
}
// 3) 写入 Redis
if err = SaveToRedis(ctx, session.TenantId, session.SessionId, &dto.HistoryRound{
Id: session.Id,
User: session.RequestContent,
Assistant: req.Messages,
CreatedAt: gconv.String(session.CreatedAt),
}); err != nil {
// 3) entity → HistoryRound → 写入 Redis
round := entityToHistoryRound(session)
round.Assistant = req.Messages
if err = SaveToRedis(ctx, session.TenantId, session.SessionId, round); err != nil {
return nil, fmt.Errorf("redis存储失败: %w", err)
}
// 4) 返回
g.Log().Infof(ctx, "[会话回调] 存储成功 sessionId=%s id=%d", session.SessionId, session.Id)
return &dto.SessionCallbackRes{Status: true, SessionId: session.SessionId}, nil
}
// GetHistoryMessages 获取历史消息
// ============================================
// 场景1前端历史列表按 creator
// ============================================
// GetHistoryList 获取历史列表
func GetHistoryList(ctx context.Context, req *dto.GetHistoryListReq) (*dto.GetHistoryListRes, error) {
user, err := utils.GetUserInfo(ctx)
if err != nil {
return nil, err
}
sessions, total, err := dao.ComposeSession.List(ctx, &entity.ComposeSession{
SQLBaseDO: beans.SQLBaseDO{Creator: user.UserName},
}, req.Page, req.Size)
if err != nil {
return nil, fmt.Errorf("DB获取历史列表失败: %w", err)
}
rounds := sessionsToHistoryRounds(sessions)
return &dto.GetHistoryListRes{List: rounds, Total: total}, nil
}
// ============================================
// 场景2提示词拼接按 sessionId + nodeId
// ============================================
// GetHistoryMessages 获取历史消息Redis → DB → 异步回种)
func GetHistoryMessages(ctx context.Context, req *dto.GetHistoryMessagesReq) (*dto.GetHistoryMessagesRes, error) {
user, err := utils.GetUserInfo(ctx)
if err != nil {
@@ -59,10 +84,9 @@ func GetHistoryMessages(ctx context.Context, req *dto.GetHistoryMessagesReq) (*d
}
// 1) Redis
redisRounds, err := GetFromRedis(ctx, user.TenantId, req.SessionId)
if err == nil && len(redisRounds) > 0 {
g.Log().Debugf(ctx, "[历史消息] Redis命中 sessionId=%s count=%d", req.SessionId, len(redisRounds))
return &dto.GetHistoryMessagesRes{Messages: parseHistoryRounds(redisRounds)}, nil
if rounds, err := GetFromRedis(ctx, user.TenantId, req.SessionId); err == nil && len(rounds) > 0 {
g.Log().Debugf(ctx, "[历史消息] Redis命中 sessionId=%s count=%d", req.SessionId, len(rounds))
return &dto.GetHistoryMessagesRes{Messages: flattenRounds(rounds)}, nil
}
// 2) DB
@@ -70,129 +94,108 @@ func GetHistoryMessages(ctx context.Context, req *dto.GetHistoryMessagesReq) (*d
sessions, _, err := dao.ComposeSession.List(ctx, &entity.ComposeSession{
SQLBaseDO: beans.SQLBaseDO{Creator: user.UserName},
SessionId: req.SessionId,
NodeId: req.NodeId,
}, 1, maxRounds)
if err != nil {
return nil, fmt.Errorf("DB获取历史失败: %w", err)
}
if len(sessions) == 0 {
return &dto.GetHistoryMessagesRes{Messages: []dto.HistoryRound{}}, nil
return &dto.GetHistoryMessagesRes{Messages: []dto.FlatMessage{}}, nil
}
// 3) 转换 + 异步回种
rounds := sessionsToHistoryRounds(sessions)
go asyncCacheToRedis(context.WithoutCancel(ctx), user.TenantId, req.SessionId, sessions)
go asyncCacheToRedis(context.WithoutCancel(ctx), user.TenantId, req.SessionId, rounds)
return &dto.GetHistoryMessagesRes{Messages: rounds}, nil
return &dto.GetHistoryMessagesRes{Messages: flattenRounds(rounds)}, nil
}
// parseHistoryRounds Redis 数据转为 HistoryRound
func parseHistoryRounds(redisRounds []map[string]any) []dto.HistoryRound {
rounds := make([]dto.HistoryRound, 0, len(redisRounds))
for _, r := range redisRounds {
round := dto.HistoryRound{
Id: gconv.Int64(r["id"]),
CreatedAt: gconv.String(r["createdAt"]),
}
if user, ok := r["user"].(map[string]any); ok {
round.User = user
}
if assistant, ok := r["assistant"].(map[string]any); ok {
round.Assistant = assistant
}
rounds = append(rounds, round)
// ============================================
// 删除
// ============================================
// DeleteMessages 批量删除消息
func DeleteMessages(ctx context.Context, req *dto.DeleteMessagesReq) (*dto.DeleteMessagesRes, error) {
if len(req.MsgIds) == 0 {
return &dto.DeleteMessagesRes{Ok: false}, fmt.Errorf("msgIds不能为空")
}
return rounds
}
// sessionsToHistoryRounds DB 数据转为 HistoryRound
func sessionsToHistoryRounds(sessions []*entity.ComposeSession) []dto.HistoryRound {
rounds := make([]dto.HistoryRound, 0, len(sessions))
for _, s := range sessions {
reqMsgs := util.ConvertToMessages(s.RequestContent)
respMsgs := util.ConvertToMessages(s.ResponseContent)
round := dto.HistoryRound{
Id: s.Id,
CreatedAt: gconv.String(s.CreatedAt),
}
if len(reqMsgs) > 0 {
round.User = reqMsgs[0]
}
if len(respMsgs) > 0 {
if respMsgs[0]["role"] == nil {
respMsgs[0]["role"] = "assistant"
}
round.Assistant = respMsgs[0]
}
rounds = append(rounds, round)
// 1) 删 DB
for _, id := range req.MsgIds {
_, _ = dao.ComposeSession.Delete(ctx, &entity.ComposeSession{
SQLBaseDO: beans.SQLBaseDO{Id: id},
})
}
return rounds
user, err := utils.GetUserInfo(ctx)
if err != nil {
return nil, err
}
// 2) 删 Redis
_ = DeleteRedisMessages(ctx, user.TenantId, req.SessionId, req.MsgIds)
return &dto.DeleteMessagesRes{Ok: true}, nil
}
// DeleteSession 删除会话
// DeleteSession 删除整个会话
func DeleteSession(ctx context.Context, req *dto.DeleteSessionReq) (*dto.DeleteSessionRes, error) {
hasMsgID := len(req.MsgIds) > 0 && req.MsgIds[0] > 0
deleteReq := &entity.ComposeSession{
// 1) 删 DB
if _, err := dao.ComposeSession.Delete(ctx, &entity.ComposeSession{
SessionId: req.SessionId,
NodeId: req.NodeId,
}
if hasMsgID {
deleteReq.Id = req.MsgIds[0]
}
if _, err := dao.ComposeSession.Delete(ctx, deleteReq); err != nil {
}); err != nil {
return nil, fmt.Errorf("DB删除失败: %w", err)
}
if hasMsgID {
if err := DeleteSingleMessage(ctx, req.TenantId, req.SessionId, req.MsgIds[0]); err != nil {
g.Log().Warningf(ctx, "[删除会话] Redis删除单条失败 msgID=%d err=%v", req.MsgIds[0], err)
}
} else {
if err := DeleteSessionHistory(ctx, req.TenantId, req.SessionId); err != nil {
g.Log().Warningf(ctx, "[删除会话] Redis删除失败 sessionId=%s err=%v", req.SessionId, err)
}
user, err := utils.GetUserInfo(ctx)
if err != nil {
return nil, err
}
// 2) 删 Redis
if err := DeleteSessionHistory(ctx, user.TenantId, req.SessionId); err != nil {
g.Log().Warningf(ctx, "[删除会话] Redis删除失败 sessionId=%s err=%v", req.SessionId, err)
}
return &dto.DeleteSessionRes{Ok: true}, nil
}
// ============================================
// 内部方法
// 转换方法entity ↔ dto集中管理
// ============================================
func extractMessagesFromSessions(sessions []*entity.ComposeSession) []map[string]any {
var messages []map[string]any
for i := len(sessions) - 1; i >= 0; i-- {
appendRoleMessages(sessions[i].RequestContent, "user", &messages)
appendRoleMessages(sessions[i].ResponseContent, "assistant", &messages)
// entityToHistoryRound entity → HistoryRound
func entityToHistoryRound(s *entity.ComposeSession) *dto.HistoryRound {
reqMsgs := util.ConvertToMessages(s.RequestContent)
respMsgs := util.ConvertToMessages(s.ResponseContent)
round := &dto.HistoryRound{
Id: s.Id,
SessionId: s.SessionId,
NodeId: s.NodeId,
CreatedAt: gconv.String(s.CreatedAt),
UpdatedAt: gconv.String(s.UpdatedAt),
}
return messages
if len(reqMsgs) > 0 {
round.User = reqMsgs[0]
}
if len(respMsgs) > 0 {
round.Assistant = respMsgs[0]
}
return round
}
func appendRoleMessages(content any, defaultRole string, messages *[]map[string]any) {
msgs := util.ConvertToMessages(content)
for _, m := range msgs {
if m["role"] == nil || gconv.String(m["role"]) == "" {
m["role"] = defaultRole
}
*messages = append(*messages, m)
}
}
// asyncCacheToRedis 异步缓存会话数据到 Redis
func asyncCacheToRedis(ctx context.Context, tenantID uint64, sessionID string, sessions []*entity.ComposeSession) {
// sessionsToHistoryRounds 批量转换
func sessionsToHistoryRounds(sessions []*entity.ComposeSession) []dto.HistoryRound {
rounds := make([]dto.HistoryRound, 0, len(sessions))
for _, s := range sessions {
reqMsgs := util.ConvertToMessages(s.RequestContent)
respMsgs := util.ConvertToMessages(s.ResponseContent)
if len(reqMsgs) > 0 || len(respMsgs) > 0 {
_ = SaveToRedis(ctx, tenantID, sessionID, &dto.HistoryRound{
Id: s.Id,
User: s.RequestContent,
Assistant: s.ResponseContent,
CreatedAt: gconv.String(s.CreatedAt),
})
rounds = append(rounds, *entityToHistoryRound(s))
}
return rounds
}
// asyncCacheToRedis 异步缓存到 Redis
func asyncCacheToRedis(ctx context.Context, tenantID uint64, sessionID string, rounds []dto.HistoryRound) {
for i := range rounds {
if rounds[i].User != nil || rounds[i].Assistant != nil {
_ = SaveToRedis(ctx, tenantID, sessionID, &rounds[i])
}
}
}