From df2632983601ece38c4b20b89be1176c2bc26553 Mon Sep 17 00:00:00 2001 From: WangLiZhao <1838393649@qq.com> Date: Wed, 10 Jun 2026 16:48:35 +0800 Subject: [PATCH] =?UTF-8?q?feat(session):=20=E9=87=8D=E6=9E=84=E4=BC=9A?= =?UTF-8?q?=E8=AF=9D=E6=9C=8D=E5=8A=A1=E6=94=AF=E6=8C=81=E8=8A=82=E7=82=B9?= =?UTF-8?q?=E7=BB=B4=E5=BA=A6=E7=9A=84Redis=E7=BC=93=E5=AD=98=E7=AE=A1?= =?UTF-8?q?=E7=90=86?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- dao/compose_session_dao.go | 33 ++++++++++++ .../session/prompt_session_redis_service.go | 52 +++++++++---------- service/session/prompt_session_service.go | 36 ++++++------- 3 files changed, 75 insertions(+), 46 deletions(-) diff --git a/dao/compose_session_dao.go b/dao/compose_session_dao.go index c890f20..d541b52 100644 --- a/dao/compose_session_dao.go +++ b/dao/compose_session_dao.go @@ -89,3 +89,36 @@ func (d *composeSessionDao) Delete(ctx context.Context, req *entity.ComposeSessi } return r.RowsAffected() } + +// ListByIds 根据 ID 列表批量查询 +func (d *composeSessionDao) ListByIds(ctx context.Context, ids []int64, creator, sessionId string) (list []*entity.ComposeSession, err error) { + if len(ids) == 0 { + return nil, nil + } + r, err := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameComposeSession). + WhereIn(entity.ComposeSessionCol.Id, ids). + Where(entity.ComposeSessionCol.Creator, creator). + Where(entity.ComposeSessionCol.SessionId, sessionId). + All() + if err != nil { + return nil, err + } + err = r.Structs(&list) + return +} + +// DeleteByIds 批量删除编排会话 +func (d *composeSessionDao) DeleteByIds(ctx context.Context, ids []int64, creator, sessionId string) (int64, error) { + if len(ids) == 0 { + return 0, nil + } + r, err := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameComposeSession). + WhereIn(entity.ComposeSessionCol.Id, ids). + Where(entity.ComposeSessionCol.Creator, creator). + Where(entity.ComposeSessionCol.SessionId, sessionId). + Delete() + if err != nil { + return 0, err + } + return r.RowsAffected() +} diff --git a/service/session/prompt_session_redis_service.go b/service/session/prompt_session_redis_service.go index 886c98b..0925774 100644 --- a/service/session/prompt_session_redis_service.go +++ b/service/session/prompt_session_redis_service.go @@ -13,13 +13,13 @@ import ( ) const ( - // RedisKeySessionHistory 会话历史缓存 key: session:history:{tenantId}:{sessionId} - RedisKeySessionHistory = "session:history:%d:%s" + // RedisKeySessionHistory 会话历史缓存 key: session:history:{tenantId}:{sessionId}:{nodeId} + RedisKeySessionHistory = "session:history:%d:%s:%s" ) // formatRedisKey 格式化 Redis key -func formatRedisKey(tenantID uint64, sessionID string) string { - return fmt.Sprintf(RedisKeySessionHistory, tenantID, sessionID) +func formatRedisKey(tenantID uint64, sessionID, nodeID string) string { + return fmt.Sprintf(RedisKeySessionHistory, tenantID, sessionID, nodeID) } // ============================================ @@ -27,8 +27,8 @@ func formatRedisKey(tenantID uint64, sessionID string) string { // ============================================ // SaveToRedis 保存一轮对话到 Redis ZSET -func SaveToRedis(ctx context.Context, tenantID uint64, sessionID string, round *dto.HistoryRound) error { - key := formatRedisKey(tenantID, sessionID) +func SaveToRedis(ctx context.Context, tenantID uint64, sessionID, nodeID string, round *dto.HistoryRound) error { + key := formatRedisKey(tenantID, sessionID, nodeID) maxRounds := util.GetMaxRounds(ctx) expireSeconds := int64(util.GetExpireMinutes(ctx) * 60) @@ -52,10 +52,22 @@ func SaveToRedis(ctx context.Context, tenantID uint64, sessionID string, round * return nil } -// DeleteRedisMessages 批量删除 Redis 中多条消息(按消息ID列表) -func DeleteRedisMessages(ctx context.Context, tenantID uint64, sessionID string, msgIDs []int64) error { - key := formatRedisKey(tenantID, sessionID) +// DeleteSessionHistory 删除整个 session 下所有 node 的缓存 +func DeleteSessionHistory(ctx context.Context, tenantID uint64, sessionID string) error { + pattern := fmt.Sprintf(RedisKeySessionHistory, tenantID, sessionID, "*") + keys, err := g.Redis().Do(ctx, "KEYS", pattern) + if err != nil { + return err + } + for _, key := range keys.Strings() { + _, _ = g.Redis().Do(ctx, "DEL", key) + } + return nil +} +// DeleteRedisMessages 批量删除指定 node 下的消息 +func DeleteRedisMessages(ctx context.Context, tenantID uint64, sessionID, nodeID string, msgIDs []int64) error { + key := formatRedisKey(tenantID, sessionID, nodeID) for _, msgID := range msgIDs { cursor := "0" for { @@ -64,42 +76,29 @@ func DeleteRedisMessages(ctx context.Context, tenantID uint64, sessionID string, g.Log().Warningf(ctx, "[会话Redis] ZSCAN失败 msgID=%d err=%v", msgID, err) break } - parts := result.Strings() if len(parts) < 2 { break } - cursor = parts[0] for _, member := range parts[1:] { - if _, err := g.Redis().Do(ctx, "ZREM", key, member); err != nil { - g.Log().Warningf(ctx, "[会话Redis] ZREM失败 err=%v", err) - } + _, _ = g.Redis().Do(ctx, "ZREM", key, member) } - if cursor == "0" { break } } } - return nil } -// DeleteSessionHistory 删除整个会话的 Redis 缓存 -func DeleteSessionHistory(ctx context.Context, tenantID uint64, sessionID string) error { - key := formatRedisKey(tenantID, sessionID) - _, err := g.Redis().Do(ctx, "DEL", key) - return err -} - // ============================================ // 读操作 // ============================================ -// GetFromRedis 从 Redis ZSET 获取会话历史,返回 HistoryRound 切片 -func GetFromRedis(ctx context.Context, tenantID uint64, sessionID string) ([]dto.HistoryRound, error) { - key := formatRedisKey(tenantID, sessionID) +// GetFromRedis 从 Redis ZSET 获取会话历史 +func GetFromRedis(ctx context.Context, tenantID uint64, sessionID, nodeID string) ([]dto.HistoryRound, error) { + key := formatRedisKey(tenantID, sessionID, nodeID) maxRounds := util.GetMaxRounds(ctx) result, err := g.Redis().Do(ctx, "ZREVRANGE", key, 0, maxRounds-1) @@ -118,7 +117,6 @@ func GetFromRedis(ctx context.Context, tenantID uint64, sessionID string) ([]dto // 解析 // ============================================ -// parseRounds 解析 Redis ZSET members 为 HistoryRound 切片 func parseRounds(members []string) []dto.HistoryRound { rounds := make([]dto.HistoryRound, 0, len(members)) for _, member := range members { diff --git a/service/session/prompt_session_service.go b/service/session/prompt_session_service.go index 18c3f8f..4271aff 100644 --- a/service/session/prompt_session_service.go +++ b/service/session/prompt_session_service.go @@ -21,7 +21,6 @@ import ( // Callback 会话回调 func Callback(ctx context.Context, req *dto.SessionCallbackReq) (*dto.SessionCallbackRes, error) { - fmt.Println("打印会话回调", req) req.Messages["role"] = "assistant" // 1) 更新 DB _, err := dao.ComposeSession.Update(ctx, &entity.ComposeSession{ @@ -44,7 +43,7 @@ func Callback(ctx context.Context, req *dto.SessionCallbackReq) (*dto.SessionCal // 3) entity → HistoryRound → 写入 Redis round := entityToHistoryRound(session) round.Assistant = req.Messages - if err = SaveToRedis(ctx, session.TenantId, session.SessionId, round); err != nil { + if err = SaveToRedis(ctx, session.TenantId, session.SessionId, session.NodeId, round); err != nil { return nil, fmt.Errorf("redis存储失败: %w", err) } @@ -84,7 +83,7 @@ func GetHistoryMessages(ctx context.Context, req *dto.GetHistoryMessagesReq) (*d } // 1) Redis - if rounds, err := GetFromRedis(ctx, user.TenantId, req.SessionId); err == nil && len(rounds) > 0 { + if rounds, err := GetFromRedis(ctx, user.TenantId, req.SessionId, req.NodeId); err == nil && len(rounds) > 0 { g.Log().Debugf(ctx, "[历史消息] Redis命中 sessionId=%s count=%d", req.SessionId, len(rounds)) return &dto.GetHistoryMessagesRes{Messages: flattenRounds(rounds)}, nil } @@ -105,7 +104,7 @@ func GetHistoryMessages(ctx context.Context, req *dto.GetHistoryMessagesReq) (*d // 3) 转换 + 异步回种 rounds := sessionsToHistoryRounds(sessions) - go asyncCacheToRedis(context.WithoutCancel(ctx), user.TenantId, req.SessionId, rounds) + go asyncCacheToRedis(context.WithoutCancel(ctx), user.TenantId, req.SessionId, req.NodeId, rounds) return &dto.GetHistoryMessagesRes{Messages: flattenRounds(rounds)}, nil } @@ -114,25 +113,24 @@ func GetHistoryMessages(ctx context.Context, req *dto.GetHistoryMessagesReq) (*d // 删除 // ============================================ -// DeleteMessages 批量删除消息 +// DeleteMessages 删除消息 func DeleteMessages(ctx context.Context, req *dto.DeleteMessagesReq) (*dto.DeleteMessagesRes, error) { if len(req.MsgIds) == 0 { return &dto.DeleteMessagesRes{Ok: false}, fmt.Errorf("msgIds不能为空") } - // 1) 删 DB - for _, id := range req.MsgIds { - _, _ = dao.ComposeSession.Delete(ctx, &entity.ComposeSession{ - SQLBaseDO: beans.SQLBaseDO{Id: id}, - }) - } - user, err := utils.GetUserInfo(ctx) - if err != nil { - return nil, err - } - // 2) 删 Redis - _ = DeleteRedisMessages(ctx, user.TenantId, req.SessionId, req.MsgIds) + user, _ := utils.GetUserInfo(ctx) + // 1) 批量查询 + sessions, _ := dao.ComposeSession.ListByIds(ctx, req.MsgIds, user.UserName, req.SessionId) + + // 2) 批量删 DB + _, _ = dao.ComposeSession.DeleteByIds(ctx, req.MsgIds, user.UserName, req.SessionId) + + // 3) 按 nodeId 分组删 Redis + for _, s := range sessions { + _ = DeleteRedisMessages(ctx, user.TenantId, req.SessionId, s.NodeId, req.MsgIds) + } return &dto.DeleteMessagesRes{Ok: true}, nil } @@ -184,10 +182,10 @@ func sessionsToHistoryRounds(sessions []*entity.ComposeSession) []dto.HistoryRou } // asyncCacheToRedis 异步缓存到 Redis -func asyncCacheToRedis(ctx context.Context, tenantID uint64, sessionID string, rounds []dto.HistoryRound) { +func asyncCacheToRedis(ctx context.Context, tenantID uint64, sessionID, nodeID string, rounds []dto.HistoryRound) { for i := range rounds { if rounds[i].User != nil || rounds[i].Assistant != nil { - _ = SaveToRedis(ctx, tenantID, sessionID, &rounds[i]) + _ = SaveToRedis(ctx, tenantID, sessionID, nodeID, &rounds[i]) } } }