// Files
// prompts-core/service/session/prompt_session_redis_service.go
//
// 170 lines
// 4.6 KiB
// Go
// Raw Blame History
//
// This file contains ambiguous Unicode characters
// This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package session
import (
	"context"
	"encoding/json"
	"fmt"
	"strings"
	"time"

	"github.com/gogf/gf/v2/frame/g"

	"prompts-core/common/util"
	"prompts-core/model/dto"
)
const (
	// RedisKeySessionHistory is the fmt template for the session-history cache
	// key: session:history:{tenantId}:{sessionId}.
	RedisKeySessionHistory = "session:history:%d:%s"
)

// formatRedisKey renders the session-history Redis key for the given tenant
// and session by filling in the RedisKeySessionHistory template.
func formatRedisKey(tenantID uint64, sessionID string) string {
	key := fmt.Sprintf(RedisKeySessionHistory, tenantID, sessionID)
	return key
}
// ============================================
// 写操作
// ============================================
// SaveToRedis appends one conversation round to the session's Redis sorted set.
//
// The round is JSON-encoded and added with the current Unix time in
// milliseconds as its score. The set is then trimmed so that only the
// maxRounds highest-scored members remain, and the key's TTL is reset to the
// configured expiry. The three commands are sent one after another, so a
// failure part-way leaves the earlier steps applied.
//
// A nil round is rejected up front: json.Marshal would encode it as "null"
// and that placeholder would otherwise be stored as a member.
//
// It returns a wrapped error if encoding or any Redis command fails.
func SaveToRedis(ctx context.Context, tenantID uint64, sessionID string, round *dto.HistoryRound) error {
	if round == nil {
		return fmt.Errorf("会话数据不能为空")
	}

	key := formatRedisKey(tenantID, sessionID)
	maxRounds := util.GetMaxRounds(ctx)
	expireSeconds := int64(util.GetExpireMinutes(ctx) * 60)

	payload, err := json.Marshal(round)
	if err != nil {
		return fmt.Errorf("序列化会话数据失败: %w", err)
	}

	score := float64(time.Now().UnixMilli())
	if _, err = g.Redis().Do(ctx, "ZADD", key, score, string(payload)); err != nil {
		return fmt.Errorf("ZADD失败: %w", err)
	}

	// Ranks 0 .. -(maxRounds+1) are everything below the top maxRounds members.
	if _, err = g.Redis().Do(ctx, "ZREMRANGEBYRANK", key, 0, -(maxRounds + 1)); err != nil {
		return fmt.Errorf("裁剪失败: %w", err)
	}

	if _, err = g.Redis().Do(ctx, "EXPIRE", key, expireSeconds); err != nil {
		return fmt.Errorf("设置过期失败: %w", err)
	}
	return nil
}
// DeleteSingleMessage removes from the session's Redis sorted set every member
// whose serialized JSON contains the message ID msgID.
//
// The whole member is removed with ZREM, not just the matching message inside
// it. Members are read with ZRANGE 0 -1 and filtered locally; SaveToRedis
// trims the set on every write, so the full read stays small. Filtering
// locally keeps the ID match exact (a glob such as *"id":1* would also match
// 10 or 123) and avoids decoding the cursor/member/score layout of a ZSCAN
// reply.
//
// A failed ZREM is logged and skipped so the remaining matches are still
// attempted; only a failure to read the set is returned as an error.
func DeleteSingleMessage(ctx context.Context, tenantID uint64, sessionID string, msgID int64) error {
	key := formatRedisKey(tenantID, sessionID)

	result, err := g.Redis().Do(ctx, "ZRANGE", key, 0, -1)
	if err != nil {
		return fmt.Errorf("ZRANGE失败: %w", err)
	}
	if result == nil || result.IsNil() {
		return nil
	}

	for _, member := range result.Strings() {
		if !memberHasMessageID(member, msgID) {
			continue
		}
		if _, err := g.Redis().Do(ctx, "ZREM", key, member); err != nil {
			g.Log().Warningf(ctx, "[会话Redis] ZREM单条失败 key=%s err=%v", key, err)
		}
	}
	return nil
}

// memberHasMessageID reports whether member contains the JSON fragment
// "id":<msgID> that is not immediately followed by another digit.
func memberHasMessageID(member string, msgID int64) bool {
	needle := fmt.Sprintf("\"id\":%d", msgID)
	rest := member
	for {
		i := strings.Index(rest, needle)
		if i < 0 {
			return false
		}
		rest = rest[i+len(needle):]
		// A trailing digit means the needle was only a prefix of a longer ID.
		if rest == "" || rest[0] < '0' || rest[0] > '9' {
			return true
		}
	}
}
// DeleteSessionHistory drops the whole cached history of a session by issuing
// DEL on its Redis key. The error from the Redis call, if any, is returned
// unchanged.
func DeleteSessionHistory(ctx context.Context, tenantID uint64, sessionID string) error {
	cacheKey := formatRedisKey(tenantID, sessionID)
	if _, err := g.Redis().Do(ctx, "DEL", cacheKey); err != nil {
		return err
	}
	return nil
}
// ============================================
// 读操作
// ============================================
// GetFromRedis reads a session's cached rounds from its Redis sorted set.
//
// It issues ZREVRANGE key 0 maxRounds-1, so members arrive highest score
// first, and decodes each member with parseRedisRounds. A nil reply yields an
// empty, non-nil slice. A Redis failure is returned as a wrapped error.
func GetFromRedis(ctx context.Context, tenantID uint64, sessionID string) ([]map[string]any, error) {
	cacheKey := formatRedisKey(tenantID, sessionID)
	limit := util.GetMaxRounds(ctx)

	reply, err := g.Redis().Do(ctx, "ZREVRANGE", cacheKey, 0, limit-1)
	switch {
	case err != nil:
		return nil, fmt.Errorf("ZREVRANGE失败: %w", err)
	case reply == nil || reply.IsNil():
		return []map[string]any{}, nil
	}
	return parseRedisRounds(ctx, reply.Strings()), nil
}
// GetSessionHistoryForInference returns the session history as a flat message
// slice for use by inference.
//
// It loads the rounds through GetFromRedis and flattens them with
// flattenRounds. When no rounds are cached it returns an empty, non-nil
// slice; a lookup failure is returned as a wrapped error.
func GetSessionHistoryForInference(ctx context.Context, tenantID uint64, sessionID string) ([]map[string]any, error) {
	history, err := GetFromRedis(ctx, tenantID, sessionID)
	if err != nil {
		return nil, fmt.Errorf("获取历史会话失败: %w", err)
	}
	if len(history) > 0 {
		return flattenRounds(history), nil
	}
	return []map[string]any{}, nil
}
// ============================================
// 解析
// ============================================
// parseRedisRounds decodes each sorted-set member as a JSON object.
//
// Members that fail to decode are logged as warnings and skipped. The
// returned slice keeps the input order of the members that decoded and is
// never nil.
func parseRedisRounds(ctx context.Context, members []string) []map[string]any {
	parsed := make([]map[string]any, 0, len(members))
	for i := range members {
		// Declared per iteration so each member decodes into a fresh map.
		var round map[string]any
		err := json.Unmarshal([]byte(members[i]), &round)
		if err == nil {
			parsed = append(parsed, round)
			continue
		}
		g.Log().Warningf(ctx, "[会话Redis] 解析数据失败 err=%v", err)
	}
	return parsed
}
// flattenRounds turns a list of rounds into a flat list of messages.
//
// Rounds are walked from last to first, reversing the input order. For each
// round the "user" message is emitted before the "assistant" message. An
// entry is skipped when it is missing, is not a JSON object, or is empty.
//
// The result is always non-nil, even when no message qualifies, so it agrees
// with the empty-slice returns used elsewhere in this file.
func flattenRounds(rounds []map[string]any) []map[string]any {
	// Each round contributes at most two messages.
	messages := make([]map[string]any, 0, 2*len(rounds))
	for i := len(rounds) - 1; i >= 0; i-- {
		for _, role := range [...]string{"user", "assistant"} {
			if msg, ok := rounds[i][role].(map[string]any); ok && len(msg) > 0 {
				messages = append(messages, msg)
			}
		}
	}
	return messages
}
// appendFieldToMessages appends to *messages every JSON-object element found
// in the list stored at data[field].
//
// Nothing happens when the field is missing or is not a []any. Elements of
// the list that are not map[string]any are skipped.
func appendFieldToMessages(data map[string]any, field string, messages *[]map[string]any) {
	items, isList := data[field].([]any)
	if !isList {
		return
	}
	for idx := 0; idx < len(items); idx++ {
		msg, isMap := items[idx].(map[string]any)
		if !isMap {
			continue
		}
		*messages = append(*messages, msg)
	}
}