Files
prompts-core/service/session/prompt_session_redis_service.go

154 lines
4.3 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package session
import (
"context"
"encoding/json"
"fmt"
"prompts-core/common/util"
"prompts-core/model/dto"
"time"
"github.com/gogf/gf/v2/frame/g"
"github.com/gogf/gf/v2/util/gconv"
)
const (
// RedisKeySessionHistory is the fmt format string for the Redis key that
// caches a session's history: session:history:{tenantId}:{sessionId}.
// The first verb (%d) receives the tenant ID and the second (%s) the
// session ID.
RedisKeySessionHistory = "session:history:%d:%s"
)
// formatRedisKey builds the Redis key of a session's history cache by filling
// the RedisKeySessionHistory format with the tenant ID and the session ID.
func formatRedisKey(tenantID uint64, sessionID string) string {
	key := fmt.Sprintf(RedisKeySessionHistory, tenantID, sessionID)
	return key
}
// ============================================
// 写操作
// ============================================
// SaveToRedis stores one conversation round in the session's Redis sorted set.
//
// The round is JSON-encoded and added with ZADD, scored by the current Unix
// time in milliseconds. The set is then trimmed with ZREMRANGEBYRANK so that
// only the highest-scored members remain (the limit comes from
// util.GetMaxRounds), and the key's TTL is reset with EXPIRE to
// util.GetExpireMinutes converted to seconds.
//
// The three commands are issued one after another; an error from any of them
// is wrapped and returned immediately, leaving the earlier commands applied.
// A nil round is rejected with an error.
func SaveToRedis(ctx context.Context, tenantID uint64, sessionID string, round *dto.HistoryRound) error {
	// json.Marshal encodes a nil pointer as the literal "null". Storing that
	// would occupy a slot in the capped set (possibly evicting a real round)
	// only to be thrown away when the history is parsed back.
	if round == nil {
		return fmt.Errorf("会话数据为空")
	}

	key := formatRedisKey(tenantID, sessionID)
	maxRounds := util.GetMaxRounds(ctx)
	expireSeconds := int64(util.GetExpireMinutes(ctx) * 60)

	b, err := json.Marshal(round)
	if err != nil {
		return fmt.Errorf("序列化会话数据失败: %w", err)
	}

	// The millisecond timestamp is the score, so rank order is insertion order.
	score := float64(time.Now().UnixMilli())
	if _, err = g.Redis().Do(ctx, "ZADD", key, score, string(b)); err != nil {
		return fmt.Errorf("ZADD失败: %w", err)
	}

	// Ranks 0 .. -(maxRounds+1) are the lowest-scored members; removing them
	// keeps at most maxRounds of the highest-scored ones.
	if _, err = g.Redis().Do(ctx, "ZREMRANGEBYRANK", key, 0, -(maxRounds + 1)); err != nil {
		return fmt.Errorf("裁剪失败: %w", err)
	}

	// Every write restarts the key's time to live.
	if _, err = g.Redis().Do(ctx, "EXPIRE", key, expireSeconds); err != nil {
		return fmt.Errorf("设置过期失败: %w", err)
	}
	return nil
}
// DeleteRedisMessages removes from the session's Redis sorted set every member
// whose JSON payload contains an "id" field equal to one of msgIDs.
//
// For each ID the set is walked with ZSCAN using a MATCH glob as a coarse
// server-side filter. Because the glob `*"id":1*` also matches `"id":12`,
// every candidate is re-checked locally for an exact ID token before ZREM is
// issued; this also skips the score entries that ZSCAN interleaves with the
// members.
//
// Deletion is best-effort: ZSCAN and ZREM failures are logged as warnings and
// the function always returns nil.
func DeleteRedisMessages(ctx context.Context, tenantID uint64, sessionID string, msgIDs []int64) error {
	key := formatRedisKey(tenantID, sessionID)
	for _, msgID := range msgIDs {
		token := fmt.Sprintf("\"id\":%d", msgID)
		pattern := "*" + token + "*"
		cursor := "0"
		for {
			result, err := g.Redis().Do(ctx, "ZSCAN", key, cursor, "MATCH", pattern, "COUNT", 10)
			if err != nil {
				g.Log().Warningf(ctx, "[会话Redis] ZSCAN失败 msgID=%d err=%v", msgID, err)
				break
			}
			items := result.Slice()
			if len(items) == 0 {
				// No cursor in the reply: nothing more can be scanned.
				break
			}
			cursor = gconv.String(items[0])
			for _, member := range zscanReplyEntries(items) {
				if !hasMsgIDToken(member, token) {
					// Score entry, or an ID that merely starts with msgID's digits.
					continue
				}
				if _, err := g.Redis().Do(ctx, "ZREM", key, member); err != nil {
					g.Log().Warningf(ctx, "[会话Redis] ZREM失败 err=%v", err)
				}
			}
			// A page may be empty while the cursor is still non-zero, so only
			// the cursor decides when the iteration is complete.
			if cursor == "0" {
				break
			}
		}
	}
	return nil
}

// zscanReplyEntries extracts the member/score entries from a decoded ZSCAN
// reply whose first element is the cursor. The reply shape exposed by the
// client is not visible from this file, so both the nested layout
// [cursor, [member, score, ...]] and a flattened [cursor, member, score, ...]
// layout are accepted.
func zscanReplyEntries(items []interface{}) []string {
	if len(items) < 2 {
		return nil
	}
	if len(items) == 2 {
		switch page := items[1].(type) {
		case []string:
			return page
		case []interface{}:
			return gconv.Strings(page)
		}
	}
	return gconv.Strings(items[1:])
}

// hasMsgIDToken reports whether member contains token (for example `"id":12`)
// at a position where it is not immediately followed by another digit, so
// that `"id":12` does not match inside `"id":123`.
func hasMsgIDToken(member, token string) bool {
	for i := 0; i+len(token) <= len(member); i++ {
		if member[i:i+len(token)] != token {
			continue
		}
		end := i + len(token)
		if end == len(member) || member[end] < '0' || member[end] > '9' {
			return true
		}
	}
	return false
}
// DeleteSessionHistory 删除整个会话的 Redis 缓存
func DeleteSessionHistory(ctx context.Context, tenantID uint64, sessionID string) error {
key := formatRedisKey(tenantID, sessionID)
_, err := g.Redis().Do(ctx, "DEL", key)
return err
}
// ============================================
// 读操作
// ============================================
// GetFromRedis loads a session's cached history from its Redis sorted set.
//
// It runs ZREVRANGE from rank 0 to util.GetMaxRounds(ctx)-1, so members come
// back highest score first, and decodes them with parseRounds. A nil reply
// yields an empty, non-nil slice; a Redis error is wrapped and returned.
func GetFromRedis(ctx context.Context, tenantID uint64, sessionID string) ([]dto.HistoryRound, error) {
	var (
		key  = formatRedisKey(tenantID, sessionID)
		stop = util.GetMaxRounds(ctx) - 1
	)

	reply, err := g.Redis().Do(ctx, "ZREVRANGE", key, 0, stop)
	switch {
	case err != nil:
		return nil, fmt.Errorf("ZREVRANGE失败: %w", err)
	case reply == nil || reply.IsNil():
		return []dto.HistoryRound{}, nil
	}
	return parseRounds(reply.Strings()), nil
}
// ============================================
// 解析
// ============================================
// parseRounds decodes sorted-set members into HistoryRound values, keeping
// their input order. Members that are not valid JSON for a HistoryRound, and
// rounds whose User and Assistant are both nil, are silently dropped.
func parseRounds(members []string) []dto.HistoryRound {
	parsed := make([]dto.HistoryRound, 0, len(members))
	for i := range members {
		// A fresh value per member so nothing carries over between decodes.
		var item dto.HistoryRound
		err := json.Unmarshal([]byte(members[i]), &item)
		if err != nil || (item.User == nil && item.Assistant == nil) {
			continue
		}
		parsed = append(parsed, item)
	}
	return parsed
}
// flattenRounds turns rounds into a flat message list. The rounds are walked
// from the last element to the first; for each round the user message is
// emitted before the assistant message. A side is skipped when it is nil or
// when its "content" converts to an empty string. The result is nil when no
// message qualifies.
func flattenRounds(rounds []dto.HistoryRound) []dto.FlatMessage {
	var flat []dto.FlatMessage
	for idx := len(rounds); idx > 0; idx-- {
		r := &rounds[idx-1]

		if r.User != nil {
			if text := gconv.String(r.User["content"]); text != "" {
				flat = append(flat, dto.FlatMessage{
					Role:    gconv.String(r.User["role"]),
					Content: text,
				})
			}
		}

		if r.Assistant != nil {
			if text := gconv.String(r.Assistant["content"]); text != "" {
				flat = append(flat, dto.FlatMessage{
					Role:    gconv.String(r.Assistant["role"]),
					Content: text,
				})
			}
		}
	}
	return flat
}