package session

import (
	"context"
	"encoding/json"
	"fmt"
	"time"

	"prompts-core/common/util"
	"prompts-core/model/dto"

	"github.com/gogf/gf/v2/frame/g"
)

const (
	// RedisKeySessionHistory is the key template for a session's history ZSET:
	// session:history:{tenantId}:{sessionId}
	RedisKeySessionHistory = "session:history:%d:%s"
)

// formatRedisKey builds the Redis key holding the history of one session.
func formatRedisKey(tenantID uint64, sessionID string) string {
	return fmt.Sprintf(RedisKeySessionHistory, tenantID, sessionID)
}

// ============================================
// Write operations
// ============================================

// SaveToRedis appends one conversation round to the session's Redis ZSET.
//
// The round is JSON-encoded and stored as a member whose score is the current
// Unix time in milliseconds, so newer rounds rank higher. After insertion the
// set is trimmed to the newest maxRounds members (when maxRounds > 0) and the
// key's TTL is refreshed (when the configured expiry is > 0).
//
// NOTE(review): ZADD, the trim and EXPIRE are three separate round-trips, not
// a transaction; a failure after ZADD can leave the key untrimmed or without a
// refreshed TTL. Consider a single EVAL script if atomicity matters.
func SaveToRedis(ctx context.Context, tenantID uint64, sessionID string, round *dto.HistoryRound) error {
	if round == nil {
		// Marshalling nil would store the literal member "null".
		return fmt.Errorf("会话数据为空")
	}

	key := formatRedisKey(tenantID, sessionID)
	maxRounds := util.GetMaxRounds(ctx)
	expireSeconds := int64(util.GetExpireMinutes(ctx) * 60)

	b, err := json.Marshal(round)
	if err != nil {
		return fmt.Errorf("序列化会话数据失败: %w", err)
	}

	score := float64(time.Now().UnixMilli())
	if _, err = g.Redis().Do(ctx, "ZADD", key, score, string(b)); err != nil {
		return fmt.Errorf("ZADD失败: %w", err)
	}

	// Keep only the newest maxRounds members: ranks are ascending by score, so
	// removing 0..-(maxRounds+1) drops everything but the top maxRounds.
	// With maxRounds <= 0 that range would be 0..-1 and wipe the whole set
	// (including the round just added), so the trim is skipped instead; this
	// matches GetFromRedis, which reads everything in that case.
	if maxRounds > 0 {
		if _, err = g.Redis().Do(ctx, "ZREMRANGEBYRANK", key, 0, -(maxRounds + 1)); err != nil {
			return fmt.Errorf("裁剪失败: %w", err)
		}
	}

	// EXPIRE with a non-positive timeout deletes the key immediately, which
	// would discard the round just written; only refresh a positive TTL.
	if expireSeconds > 0 {
		if _, err = g.Redis().Do(ctx, "EXPIRE", key, expireSeconds); err != nil {
			return fmt.Errorf("设置过期失败: %w", err)
		}
	}
	return nil
}

// DeleteSingleMessage removes from the session ZSET every stored round whose
// JSON contains a numeric "id" field equal to msgID.
//
// Candidates are found server-side with ZSCAN MATCH. The pattern requires the
// number to be followed by ',' or '}' so that, e.g., id 12 does not also match
// rounds containing id 120 or 123. ZREM failures are logged and skipped
// (best-effort); only a ZSCAN failure is returned.
func DeleteSingleMessage(ctx context.Context, tenantID uint64, sessionID string, msgID int64) error {
	key := formatRedisKey(tenantID, sessionID)
	// Members are written by json.Marshal, which emits no whitespace, so a
	// numeric id always appears as "id":<n> followed by ',' or '}'.
	pattern := fmt.Sprintf(`*"id":%d[,}]*`, msgID)

	cursor := "0"
	for {
		result, err := g.Redis().Do(ctx, "ZSCAN", key, cursor, "MATCH", pattern, "COUNT", 10)
		if err != nil {
			return fmt.Errorf("ZSCAN失败: %w", err)
		}

		// ZSCAN replies with a nested array: [cursor, [m1, s1, m2, s2, ...]].
		reply := result.Interfaces()
		if len(reply) < 2 {
			break
		}
		cursor = g.NewVar(reply[0]).String()

		members := zscanMembers(g.NewVar(reply[1]).Strings())
		if len(members) > 0 {
			args := append([]any{key}, members...)
			if _, err := g.Redis().Do(ctx, "ZREM", args...); err != nil {
				g.Log().Warningf(ctx, "[会话Redis] ZREM单条失败 key=%s err=%v", key, err)
			}
		}

		// A returned cursor of "0" means the full iteration is complete.
		if cursor == "0" {
			break
		}
	}
	return nil
}

// zscanMembers extracts the members from one ZSCAN page, whose entries
// alternate member and score; the scores are discarded.
func zscanMembers(pairs []string) []any {
	members := make([]any, 0, (len(pairs)+1)/2)
	for i := 0; i < len(pairs); i += 2 {
		members = append(members, pairs[i])
	}
	return members
}

// DeleteSessionHistory deletes the whole Redis history key of a session.
func DeleteSessionHistory(ctx context.Context, tenantID uint64, sessionID string) error {
	key := formatRedisKey(tenantID, sessionID)
	if _, err := g.Redis().Do(ctx, "DEL", key); err != nil {
		return fmt.Errorf("DEL失败: %w", err)
	}
	return nil
}

// ============================================
// Read operations
// ============================================

// GetFromRedis returns the session's stored rounds, newest first, limited to
// the newest maxRounds entries (ZREVRANGE 0..maxRounds-1; with maxRounds == 0
// this range becomes 0..-1, i.e. all entries).
//
// Members that fail to decode as JSON objects are logged and skipped. A
// missing key yields an empty, non-nil slice.
func GetFromRedis(ctx context.Context, tenantID uint64, sessionID string) ([]map[string]any, error) {
	key := formatRedisKey(tenantID, sessionID)
	maxRounds := util.GetMaxRounds(ctx)

	result, err := g.Redis().Do(ctx, "ZREVRANGE", key, 0, maxRounds-1)
	if err != nil {
		return nil, fmt.Errorf("ZREVRANGE失败: %w", err)
	}
	if result == nil || result.IsNil() {
		return []map[string]any{}, nil
	}
	return parseRedisRounds(ctx, result.Strings()), nil
}

// GetSessionHistoryForInference returns the session history as a flat list
// of messages in chronological order (oldest first), taking each round's
// "user" message followed by its "assistant" message. Used as model input.
func GetSessionHistoryForInference(ctx context.Context, tenantID uint64, sessionID string) ([]map[string]any, error) {
	rounds, err := GetFromRedis(ctx, tenantID, sessionID)
	if err != nil {
		return nil, fmt.Errorf("获取历史会话失败: %w", err)
	}
	if len(rounds) == 0 {
		return []map[string]any{}, nil
	}
	return flattenRounds(rounds), nil
}

// ============================================
// Parsing
// ============================================

// parseRedisRounds decodes each ZSET member as a JSON object, preserving
// order. Members that fail to decode are logged and skipped.
func parseRedisRounds(ctx context.Context, members []string) []map[string]any {
	rounds := make([]map[string]any, 0, len(members))
	for _, member := range members {
		var data map[string]any
		if err := json.Unmarshal([]byte(member), &data); err != nil {
			g.Log().Warningf(ctx, "[会话Redis] 解析数据失败 err=%v", err)
			continue
		}
		rounds = append(rounds, data)
	}
	return rounds
}

// flattenRounds converts newest-first rounds into an oldest-first message
// list. For each round it emits the non-empty "user" object and then the
// non-empty "assistant" object; entries of any other type are ignored.
// The result is never nil, so it encodes as [] rather than null.
func flattenRounds(rounds []map[string]any) []map[string]any {
	messages := make([]map[string]any, 0, 2*len(rounds))
	// Rounds arrive newest first (ZREVRANGE); walk backwards for chronology.
	for i := len(rounds) - 1; i >= 0; i-- {
		if user, ok := rounds[i]["user"].(map[string]any); ok && len(user) > 0 {
			messages = append(messages, user)
		}
		if assistant, ok := rounds[i]["assistant"].(map[string]any); ok && len(assistant) > 0 {
			messages = append(messages, assistant)
		}
	}
	return messages
}

// appendFieldToMessages appends to *messages every JSON-object element of the
// array stored at data[field]. It does nothing if the field is absent, is not
// an array, or messages is nil; non-object elements are skipped.
func appendFieldToMessages(data map[string]any, field string, messages *[]map[string]any) {
	if messages == nil {
		return
	}
	msgs, ok := data[field].([]any)
	if !ok {
		return
	}
	for _, m := range msgs {
		if msg, ok := m.(map[string]any); ok {
			*messages = append(*messages, msg)
		}
	}
}