154 lines
4.3 KiB
Go
154 lines
4.3 KiB
Go
package session
|
||
|
||
import (
|
||
"context"
|
||
"encoding/json"
|
||
"fmt"
|
||
"prompts-core/common/util"
|
||
"prompts-core/model/dto"
|
||
"time"
|
||
|
||
"github.com/gogf/gf/v2/frame/g"
|
||
"github.com/gogf/gf/v2/util/gconv"
|
||
)
|
||
|
||
// RedisKeySessionHistory is the fmt template for a session's history ZSET key,
// laid out as session:history:{tenantId}:{sessionId}. The tenant ID fills the
// %d verb and the session ID fills the %s verb.
const RedisKeySessionHistory = "session:history:%d:%s"
|
||
|
||
// formatRedisKey 格式化 Redis key
|
||
func formatRedisKey(tenantID uint64, sessionID string) string {
|
||
return fmt.Sprintf(RedisKeySessionHistory, tenantID, sessionID)
|
||
}
|
||
|
||
// ============================================
|
||
// 写操作
|
||
// ============================================
|
||
|
||
// SaveToRedis 保存一轮对话到 Redis ZSET
|
||
func SaveToRedis(ctx context.Context, tenantID uint64, sessionID string, round *dto.HistoryRound) error {
|
||
key := formatRedisKey(tenantID, sessionID)
|
||
maxRounds := util.GetMaxRounds(ctx)
|
||
expireSeconds := int64(util.GetExpireMinutes(ctx) * 60)
|
||
|
||
b, err := json.Marshal(round)
|
||
if err != nil {
|
||
return fmt.Errorf("序列化会话数据失败: %w", err)
|
||
}
|
||
|
||
score := float64(time.Now().UnixMilli())
|
||
|
||
if _, err = g.Redis().Do(ctx, "ZADD", key, score, string(b)); err != nil {
|
||
return fmt.Errorf("ZADD失败: %w", err)
|
||
}
|
||
if _, err = g.Redis().Do(ctx, "ZREMRANGEBYRANK", key, 0, -(maxRounds + 1)); err != nil {
|
||
return fmt.Errorf("裁剪失败: %w", err)
|
||
}
|
||
if _, err = g.Redis().Do(ctx, "EXPIRE", key, expireSeconds); err != nil {
|
||
return fmt.Errorf("设置过期失败: %w", err)
|
||
}
|
||
|
||
return nil
|
||
}
|
||
|
||
// DeleteRedisMessages 批量删除 Redis 中多条消息(按消息ID列表)
|
||
func DeleteRedisMessages(ctx context.Context, tenantID uint64, sessionID string, msgIDs []int64) error {
|
||
key := formatRedisKey(tenantID, sessionID)
|
||
|
||
for _, msgID := range msgIDs {
|
||
cursor := "0"
|
||
for {
|
||
result, err := g.Redis().Do(ctx, "ZSCAN", key, cursor, "MATCH", fmt.Sprintf("*\"id\":%d*", msgID), "COUNT", 10)
|
||
if err != nil {
|
||
g.Log().Warningf(ctx, "[会话Redis] ZSCAN失败 msgID=%d err=%v", msgID, err)
|
||
break
|
||
}
|
||
|
||
parts := result.Strings()
|
||
if len(parts) < 2 {
|
||
break
|
||
}
|
||
|
||
cursor = parts[0]
|
||
for _, member := range parts[1:] {
|
||
if _, err := g.Redis().Do(ctx, "ZREM", key, member); err != nil {
|
||
g.Log().Warningf(ctx, "[会话Redis] ZREM失败 err=%v", err)
|
||
}
|
||
}
|
||
|
||
if cursor == "0" {
|
||
break
|
||
}
|
||
}
|
||
}
|
||
|
||
return nil
|
||
}
|
||
|
||
// DeleteSessionHistory 删除整个会话的 Redis 缓存
|
||
func DeleteSessionHistory(ctx context.Context, tenantID uint64, sessionID string) error {
|
||
key := formatRedisKey(tenantID, sessionID)
|
||
_, err := g.Redis().Do(ctx, "DEL", key)
|
||
return err
|
||
}
|
||
|
||
// ============================================
|
||
// 读操作
|
||
// ============================================
|
||
|
||
// GetFromRedis 从 Redis ZSET 获取会话历史,返回 HistoryRound 切片
|
||
func GetFromRedis(ctx context.Context, tenantID uint64, sessionID string) ([]dto.HistoryRound, error) {
|
||
key := formatRedisKey(tenantID, sessionID)
|
||
maxRounds := util.GetMaxRounds(ctx)
|
||
|
||
result, err := g.Redis().Do(ctx, "ZREVRANGE", key, 0, maxRounds-1)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("ZREVRANGE失败: %w", err)
|
||
}
|
||
|
||
if result == nil || result.IsNil() {
|
||
return []dto.HistoryRound{}, nil
|
||
}
|
||
|
||
return parseRounds(result.Strings()), nil
|
||
}
|
||
|
||
// ============================================
|
||
// 解析
|
||
// ============================================
|
||
|
||
// parseRounds 解析 Redis ZSET members 为 HistoryRound 切片
|
||
func parseRounds(members []string) []dto.HistoryRound {
|
||
rounds := make([]dto.HistoryRound, 0, len(members))
|
||
for _, member := range members {
|
||
var round dto.HistoryRound
|
||
if err := json.Unmarshal([]byte(member), &round); err != nil {
|
||
continue
|
||
}
|
||
if round.User != nil || round.Assistant != nil {
|
||
rounds = append(rounds, round)
|
||
}
|
||
}
|
||
return rounds
|
||
}
|
||
|
||
func flattenRounds(rounds []dto.HistoryRound) []dto.FlatMessage {
|
||
var messages []dto.FlatMessage
|
||
for i := len(rounds) - 1; i >= 0; i-- {
|
||
if rounds[i].User != nil && gconv.String(rounds[i].User["content"]) != "" {
|
||
messages = append(messages, dto.FlatMessage{
|
||
Role: gconv.String(rounds[i].User["role"]),
|
||
Content: gconv.String(rounds[i].User["content"]),
|
||
})
|
||
}
|
||
if rounds[i].Assistant != nil && gconv.String(rounds[i].Assistant["content"]) != "" {
|
||
messages = append(messages, dto.FlatMessage{
|
||
Role: gconv.String(rounds[i].Assistant["role"]),
|
||
Content: gconv.String(rounds[i].Assistant["content"]),
|
||
})
|
||
}
|
||
}
|
||
return messages
|
||
}
|