feat(session): 重构会话管理和Redis缓存机制
This commit is contained in:
@@ -4,134 +4,165 @@ import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"prompts-core/model/entity"
|
||||
"prompts-core/common/util"
|
||||
"prompts-core/model/dto"
|
||||
"time"
|
||||
|
||||
"github.com/gogf/gf/v2/frame/g"
|
||||
)
|
||||
|
||||
const (
|
||||
redisKeyPrefix = "chat:session:%s"
|
||||
// RedisKeySessionHistory 会话历史缓存 key: session:history:{tenantId}:{sessionId}
|
||||
RedisKeySessionHistory = "session:history:%d:%s"
|
||||
)
|
||||
|
||||
// formatRedisKey 格式化Redis键
|
||||
func formatRedisKey(sessionId string) string {
|
||||
return fmt.Sprintf(redisKeyPrefix, sessionId)
|
||||
// formatRedisKey 格式化 Redis key
|
||||
func formatRedisKey(tenantID uint64, sessionID string) string {
|
||||
return fmt.Sprintf(RedisKeySessionHistory, tenantID, sessionID)
|
||||
}
|
||||
|
||||
// saveToRedis 保存会话数据到Redis
|
||||
func saveToRedis(ctx context.Context, session *entity.ComposeSession) error {
|
||||
key := formatRedisKey(session.SessionId)
|
||||
maxRounds := g.Cfg().MustGet(ctx, "session.maxRounds", 10).Int()
|
||||
expireSeconds := g.Cfg().MustGet(ctx, "session.expireTime", 1800).Int64()
|
||||
data := map[string]any{
|
||||
"sessionId": session.SessionId,
|
||||
"requestContent": session.RequestContent,
|
||||
"responseContent": session.ResponseContent,
|
||||
"timestamp": time.Now().Unix(),
|
||||
}
|
||||
b, err := json.Marshal(data)
|
||||
// ============================================
|
||||
// 写操作
|
||||
// ============================================
|
||||
|
||||
// SaveToRedis 保存一轮对话到 Redis ZSET
|
||||
func SaveToRedis(ctx context.Context, tenantID uint64, sessionID string, round *dto.HistoryRound) error {
|
||||
key := formatRedisKey(tenantID, sessionID)
|
||||
maxRounds := util.GetMaxRounds(ctx)
|
||||
expireSeconds := int64(util.GetExpireMinutes(ctx) * 60)
|
||||
|
||||
b, err := json.Marshal(round)
|
||||
if err != nil {
|
||||
return fmt.Errorf("序列化会话数据失败: %w", err)
|
||||
}
|
||||
if err = executeRedisCommands(ctx, key, string(b), maxRounds, expireSeconds); err != nil {
|
||||
return err
|
||||
|
||||
score := float64(time.Now().UnixMilli())
|
||||
|
||||
if _, err = g.Redis().Do(ctx, "ZADD", key, score, string(b)); err != nil {
|
||||
return fmt.Errorf("ZADD失败: %w", err)
|
||||
}
|
||||
if _, err = g.Redis().Do(ctx, "ZREMRANGEBYRANK", key, 0, -(maxRounds + 1)); err != nil {
|
||||
return fmt.Errorf("裁剪失败: %w", err)
|
||||
}
|
||||
if _, err = g.Redis().Do(ctx, "EXPIRE", key, expireSeconds); err != nil {
|
||||
return fmt.Errorf("设置过期失败: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// executeRedisCommands 执行Redis命令
|
||||
func executeRedisCommands(ctx context.Context, key string, value string, maxRounds int, expireSeconds int64) error {
|
||||
if _, err := g.Redis().Do(ctx, "LPUSH", key, value); err != nil {
|
||||
return fmt.Errorf("写入Redis失败: %w", err)
|
||||
// DeleteSingleMessage 删除 Redis 中单条消息(按消息ID)
|
||||
func DeleteSingleMessage(ctx context.Context, tenantID uint64, sessionID string, msgID int64) error {
|
||||
key := formatRedisKey(tenantID, sessionID)
|
||||
|
||||
cursor := "0"
|
||||
for {
|
||||
result, err := g.Redis().Do(ctx, "ZSCAN", key, cursor, "MATCH", fmt.Sprintf("*\"id\":%d*", msgID), "COUNT", 10)
|
||||
if err != nil {
|
||||
return fmt.Errorf("ZSCAN失败: %w", err)
|
||||
}
|
||||
|
||||
parts := result.Strings()
|
||||
if len(parts) < 2 {
|
||||
break
|
||||
}
|
||||
|
||||
cursor = parts[0]
|
||||
members := parts[1:]
|
||||
|
||||
for _, member := range members {
|
||||
if _, err := g.Redis().Do(ctx, "ZREM", key, member); err != nil {
|
||||
g.Log().Warningf(ctx, "[会话Redis] ZREM单条失败 key=%s err=%v", key, err)
|
||||
}
|
||||
}
|
||||
|
||||
if cursor == "0" {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if _, err := g.Redis().Do(ctx, "LTRIM", key, 0, maxRounds-1); err != nil {
|
||||
return fmt.Errorf("裁剪Redis列表失败: %w", err)
|
||||
}
|
||||
if _, err := g.Redis().Do(ctx, "EXPIRE", key, expireSeconds); err != nil {
|
||||
return fmt.Errorf("设置过期时间失败: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// getFromRedis 从Redis获取会话历史
|
||||
func getFromRedis(ctx context.Context, sessionId string) ([]map[string]any, error) {
|
||||
key := formatRedisKey(sessionId)
|
||||
result, err := g.Redis().Do(ctx, "LRANGE", key, 0, -1)
|
||||
// DeleteSessionHistory 删除整个会话的 Redis 缓存
|
||||
func DeleteSessionHistory(ctx context.Context, tenantID uint64, sessionID string) error {
|
||||
key := formatRedisKey(tenantID, sessionID)
|
||||
_, err := g.Redis().Do(ctx, "DEL", key)
|
||||
return err
|
||||
}
|
||||
|
||||
// ============================================
|
||||
// 读操作
|
||||
// ============================================
|
||||
|
||||
// GetFromRedis 从 Redis ZSET 获取会话历史
|
||||
func GetFromRedis(ctx context.Context, tenantID uint64, sessionID string) ([]map[string]any, error) {
|
||||
key := formatRedisKey(tenantID, sessionID)
|
||||
maxRounds := util.GetMaxRounds(ctx)
|
||||
|
||||
result, err := g.Redis().Do(ctx, "ZREVRANGE", key, 0, maxRounds-1)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("从Redis获取数据失败: %w", err)
|
||||
return nil, fmt.Errorf("ZREVRANGE失败: %w", err)
|
||||
}
|
||||
|
||||
if result == nil || result.IsNil() {
|
||||
return []map[string]any{}, nil
|
||||
}
|
||||
|
||||
sessions := parseRedisSessions(ctx, result.Strings())
|
||||
|
||||
reverseSlice(sessions)
|
||||
|
||||
return sessions, nil
|
||||
return parseRedisRounds(ctx, result.Strings()), nil
|
||||
}
|
||||
|
||||
// parseRedisSessions 解析Redis会话数据
|
||||
func parseRedisSessions(ctx context.Context, values []string) []map[string]any {
|
||||
var sessions []map[string]any
|
||||
|
||||
for _, str := range values {
|
||||
var data map[string]any
|
||||
if err := json.Unmarshal([]byte(str), &data); err != nil {
|
||||
g.Log().Warningf(ctx, "[会话] 解析Redis数据失败 err=%v", err)
|
||||
continue
|
||||
}
|
||||
sessions = append(sessions, data)
|
||||
}
|
||||
|
||||
return sessions
|
||||
}
|
||||
|
||||
// reverseSlice 反转切片
|
||||
func reverseSlice(s []map[string]any) {
|
||||
for i, j := 0, len(s)-1; i < j; i, j = i+1, j-1 {
|
||||
s[i], s[j] = s[j], s[i]
|
||||
}
|
||||
}
|
||||
|
||||
// GetSessionHistoryForInference 获取历史会话,返回扁平消息数组(给推理用)
|
||||
func GetSessionHistoryForInference(ctx context.Context, sessionId string) ([]map[string]any, error) {
|
||||
historyData, err := getFromRedis(ctx, sessionId)
|
||||
// GetSessionHistoryForInference 获取扁平消息数组(给推理用)
|
||||
func GetSessionHistoryForInference(ctx context.Context, tenantID uint64, sessionID string) ([]map[string]any, error) {
|
||||
rounds, err := GetFromRedis(ctx, tenantID, sessionID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("获取历史会话失败: %w", err)
|
||||
}
|
||||
|
||||
if len(historyData) == 0 {
|
||||
if len(rounds) == 0 {
|
||||
return []map[string]any{}, nil
|
||||
}
|
||||
|
||||
return flattenHistoryMessages(historyData), nil
|
||||
return flattenRounds(rounds), nil
|
||||
}
|
||||
|
||||
// flattenHistoryMessages 扁平化历史消息
|
||||
func flattenHistoryMessages(historyData []map[string]any) []map[string]any {
|
||||
var messages []map[string]any
|
||||
// ============================================
|
||||
// 解析
|
||||
// ============================================
|
||||
|
||||
for _, round := range historyData {
|
||||
appendMessagesFromField(round, "requestContent", &messages)
|
||||
appendMessagesFromField(round, "responseContent", &messages)
|
||||
func parseRedisRounds(ctx context.Context, members []string) []map[string]any {
|
||||
rounds := make([]map[string]any, 0, len(members))
|
||||
for _, member := range members {
|
||||
var data map[string]any
|
||||
if err := json.Unmarshal([]byte(member), &data); err != nil {
|
||||
g.Log().Warningf(ctx, "[会话Redis] 解析数据失败 err=%v", err)
|
||||
continue
|
||||
}
|
||||
rounds = append(rounds, data)
|
||||
}
|
||||
return rounds
|
||||
}
|
||||
|
||||
func flattenRounds(rounds []map[string]any) []map[string]any {
|
||||
var messages []map[string]any
|
||||
for i := len(rounds) - 1; i >= 0; i-- {
|
||||
if user, ok := rounds[i]["user"].(map[string]any); ok && len(user) > 0 {
|
||||
messages = append(messages, user)
|
||||
}
|
||||
if assistant, ok := rounds[i]["assistant"].(map[string]any); ok && len(assistant) > 0 {
|
||||
messages = append(messages, assistant)
|
||||
}
|
||||
}
|
||||
return messages
|
||||
}
|
||||
|
||||
// appendMessagesFromField 从指定字段追加消息
|
||||
func appendMessagesFromField(data map[string]any, field string, messages *[]map[string]any) {
|
||||
msgs, ok := data[field].([]interface{})
|
||||
func appendFieldToMessages(data map[string]any, field string, messages *[]map[string]any) {
|
||||
msgs, ok := data[field].([]any)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
for _, m := range msgs {
|
||||
if msg, ok := m.(map[string]interface{}); ok {
|
||||
if msg, ok := m.(map[string]any); ok {
|
||||
*messages = append(*messages, msg)
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user