2026-05-27 09:36:26 +08:00
|
|
|
|
package session
|
2026-05-12 13:59:15 +08:00
|
|
|
|
|
|
|
|
|
|
import (
	"context"
	"encoding/json"
	"fmt"
	"strings"
	"time"

	"prompts-core/common/util"
	"prompts-core/model/dto"

	"github.com/gogf/gf/v2/frame/g"
)
|
|
|
|
|
|
|
2026-05-20 11:36:39 +08:00
|
|
|
|
// RedisKeySessionHistory is the fmt layout of the Redis key that caches a
// session's history: session:history:{tenantId}:{sessionId}.
const RedisKeySessionHistory = "session:history:%d:%s"
|
2026-05-12 13:59:15 +08:00
|
|
|
|
|
2026-06-09 14:00:01 +08:00
|
|
|
|
// formatRedisKey 格式化 Redis key
|
|
|
|
|
|
func formatRedisKey(tenantID uint64, sessionID string) string {
|
|
|
|
|
|
return fmt.Sprintf(RedisKeySessionHistory, tenantID, sessionID)
|
2026-05-27 09:36:26 +08:00
|
|
|
|
}
|
2026-05-12 13:59:15 +08:00
|
|
|
|
|
2026-06-09 14:00:01 +08:00
|
|
|
|
// ============================================
|
|
|
|
|
|
// 写操作
|
|
|
|
|
|
// ============================================
|
|
|
|
|
|
|
|
|
|
|
|
// SaveToRedis 保存一轮对话到 Redis ZSET
|
|
|
|
|
|
func SaveToRedis(ctx context.Context, tenantID uint64, sessionID string, round *dto.HistoryRound) error {
|
|
|
|
|
|
key := formatRedisKey(tenantID, sessionID)
|
|
|
|
|
|
maxRounds := util.GetMaxRounds(ctx)
|
|
|
|
|
|
expireSeconds := int64(util.GetExpireMinutes(ctx) * 60)
|
|
|
|
|
|
|
|
|
|
|
|
b, err := json.Marshal(round)
|
2026-05-12 13:59:15 +08:00
|
|
|
|
if err != nil {
|
|
|
|
|
|
return fmt.Errorf("序列化会话数据失败: %w", err)
|
|
|
|
|
|
}
|
2026-05-20 11:36:39 +08:00
|
|
|
|
|
2026-06-09 14:00:01 +08:00
|
|
|
|
score := float64(time.Now().UnixMilli())
|
2026-05-12 13:59:15 +08:00
|
|
|
|
|
2026-06-09 14:00:01 +08:00
|
|
|
|
if _, err = g.Redis().Do(ctx, "ZADD", key, score, string(b)); err != nil {
|
|
|
|
|
|
return fmt.Errorf("ZADD失败: %w", err)
|
|
|
|
|
|
}
|
|
|
|
|
|
if _, err = g.Redis().Do(ctx, "ZREMRANGEBYRANK", key, 0, -(maxRounds + 1)); err != nil {
|
|
|
|
|
|
return fmt.Errorf("裁剪失败: %w", err)
|
2026-05-12 13:59:15 +08:00
|
|
|
|
}
|
2026-06-09 14:00:01 +08:00
|
|
|
|
if _, err = g.Redis().Do(ctx, "EXPIRE", key, expireSeconds); err != nil {
|
|
|
|
|
|
return fmt.Errorf("设置过期失败: %w", err)
|
2026-05-12 13:59:15 +08:00
|
|
|
|
}
|
2026-06-09 14:00:01 +08:00
|
|
|
|
|
2026-05-12 13:59:15 +08:00
|
|
|
|
return nil
|
|
|
|
|
|
}
|
|
|
|
|
|
|
2026-06-09 14:00:01 +08:00
|
|
|
|
// DeleteSingleMessage 删除 Redis 中单条消息(按消息ID)
|
|
|
|
|
|
func DeleteSingleMessage(ctx context.Context, tenantID uint64, sessionID string, msgID int64) error {
|
|
|
|
|
|
key := formatRedisKey(tenantID, sessionID)
|
2026-05-12 13:59:15 +08:00
|
|
|
|
|
2026-06-09 14:00:01 +08:00
|
|
|
|
cursor := "0"
|
|
|
|
|
|
for {
|
|
|
|
|
|
result, err := g.Redis().Do(ctx, "ZSCAN", key, cursor, "MATCH", fmt.Sprintf("*\"id\":%d*", msgID), "COUNT", 10)
|
|
|
|
|
|
if err != nil {
|
|
|
|
|
|
return fmt.Errorf("ZSCAN失败: %w", err)
|
|
|
|
|
|
}
|
2026-05-20 11:36:39 +08:00
|
|
|
|
|
2026-06-09 14:00:01 +08:00
|
|
|
|
parts := result.Strings()
|
|
|
|
|
|
if len(parts) < 2 {
|
|
|
|
|
|
break
|
|
|
|
|
|
}
|
2026-05-20 11:36:39 +08:00
|
|
|
|
|
2026-06-09 14:00:01 +08:00
|
|
|
|
cursor = parts[0]
|
|
|
|
|
|
members := parts[1:]
|
2026-05-20 11:36:39 +08:00
|
|
|
|
|
2026-06-09 14:00:01 +08:00
|
|
|
|
for _, member := range members {
|
|
|
|
|
|
if _, err := g.Redis().Do(ctx, "ZREM", key, member); err != nil {
|
|
|
|
|
|
g.Log().Warningf(ctx, "[会话Redis] ZREM单条失败 key=%s err=%v", key, err)
|
|
|
|
|
|
}
|
|
|
|
|
|
}
|
2026-05-20 11:36:39 +08:00
|
|
|
|
|
2026-06-09 14:00:01 +08:00
|
|
|
|
if cursor == "0" {
|
|
|
|
|
|
break
|
2026-05-12 13:59:15 +08:00
|
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
|
2026-06-09 14:00:01 +08:00
|
|
|
|
return nil
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// DeleteSessionHistory 删除整个会话的 Redis 缓存
|
|
|
|
|
|
func DeleteSessionHistory(ctx context.Context, tenantID uint64, sessionID string) error {
|
|
|
|
|
|
key := formatRedisKey(tenantID, sessionID)
|
|
|
|
|
|
_, err := g.Redis().Do(ctx, "DEL", key)
|
|
|
|
|
|
return err
|
2026-05-20 11:36:39 +08:00
|
|
|
|
}
|
2026-05-12 13:59:15 +08:00
|
|
|
|
|
2026-06-09 14:00:01 +08:00
|
|
|
|
// ============================================
|
|
|
|
|
|
// 读操作
|
|
|
|
|
|
// ============================================
|
|
|
|
|
|
|
|
|
|
|
|
// GetFromRedis 从 Redis ZSET 获取会话历史
|
|
|
|
|
|
func GetFromRedis(ctx context.Context, tenantID uint64, sessionID string) ([]map[string]any, error) {
|
|
|
|
|
|
key := formatRedisKey(tenantID, sessionID)
|
|
|
|
|
|
maxRounds := util.GetMaxRounds(ctx)
|
|
|
|
|
|
|
|
|
|
|
|
result, err := g.Redis().Do(ctx, "ZREVRANGE", key, 0, maxRounds-1)
|
|
|
|
|
|
if err != nil {
|
|
|
|
|
|
return nil, fmt.Errorf("ZREVRANGE失败: %w", err)
|
2026-05-20 11:36:39 +08:00
|
|
|
|
}
|
2026-06-09 14:00:01 +08:00
|
|
|
|
|
|
|
|
|
|
if result == nil || result.IsNil() {
|
|
|
|
|
|
return []map[string]any{}, nil
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
return parseRedisRounds(ctx, result.Strings()), nil
|
2026-05-12 13:59:15 +08:00
|
|
|
|
}
|
|
|
|
|
|
|
2026-06-09 14:00:01 +08:00
|
|
|
|
// GetSessionHistoryForInference 获取扁平消息数组(给推理用)
|
|
|
|
|
|
func GetSessionHistoryForInference(ctx context.Context, tenantID uint64, sessionID string) ([]map[string]any, error) {
|
|
|
|
|
|
rounds, err := GetFromRedis(ctx, tenantID, sessionID)
|
2026-05-12 13:59:15 +08:00
|
|
|
|
if err != nil {
|
|
|
|
|
|
return nil, fmt.Errorf("获取历史会话失败: %w", err)
|
|
|
|
|
|
}
|
|
|
|
|
|
|
2026-06-09 14:00:01 +08:00
|
|
|
|
if len(rounds) == 0 {
|
2026-05-15 09:45:51 +08:00
|
|
|
|
return []map[string]any{}, nil
|
2026-05-12 13:59:15 +08:00
|
|
|
|
}
|
|
|
|
|
|
|
2026-06-09 14:00:01 +08:00
|
|
|
|
return flattenRounds(rounds), nil
|
2026-05-20 11:36:39 +08:00
|
|
|
|
}
|
|
|
|
|
|
|
2026-06-09 14:00:01 +08:00
|
|
|
|
// ============================================
|
|
|
|
|
|
// 解析
|
|
|
|
|
|
// ============================================
|
2026-05-20 11:36:39 +08:00
|
|
|
|
|
2026-06-09 14:00:01 +08:00
|
|
|
|
func parseRedisRounds(ctx context.Context, members []string) []map[string]any {
|
|
|
|
|
|
rounds := make([]map[string]any, 0, len(members))
|
|
|
|
|
|
for _, member := range members {
|
|
|
|
|
|
var data map[string]any
|
|
|
|
|
|
if err := json.Unmarshal([]byte(member), &data); err != nil {
|
|
|
|
|
|
g.Log().Warningf(ctx, "[会话Redis] 解析数据失败 err=%v", err)
|
|
|
|
|
|
continue
|
|
|
|
|
|
}
|
|
|
|
|
|
rounds = append(rounds, data)
|
2026-05-20 11:36:39 +08:00
|
|
|
|
}
|
2026-06-09 14:00:01 +08:00
|
|
|
|
return rounds
|
|
|
|
|
|
}
|
2026-05-20 11:36:39 +08:00
|
|
|
|
|
2026-06-09 14:00:01 +08:00
|
|
|
|
// flattenRounds turns a list of rounds into one flat list of messages.
//
// The rounds are walked from the last element to the first, so a newest-first
// input comes out oldest-first. For every round the "user" message is emitted
// before the "assistant" message; an entry that is missing, is not a JSON
// object, or is an empty object is skipped.
//
// The result is never nil, even when no round contributes a message, so it
// JSON-encodes as [] rather than null.
func flattenRounds(rounds []map[string]any) []map[string]any {
	// Each round contributes at most two messages.
	messages := make([]map[string]any, 0, 2*len(rounds))
	for i := len(rounds) - 1; i >= 0; i-- {
		for _, role := range [...]string{"user", "assistant"} {
			if msg, ok := rounds[i][role].(map[string]any); ok && len(msg) > 0 {
				messages = append(messages, msg)
			}
		}
	}
	return messages
}
|
|
|
|
|
|
|
2026-06-09 14:00:01 +08:00
|
|
|
|
// appendFieldToMessages appends the JSON objects found in the list stored at
// data[field] to *messages, keeping their order.
//
// Nothing is appended when data[field] is missing or is not a list; list
// elements that are not JSON objects are skipped. A nil messages pointer is
// ignored instead of being dereferenced.
func appendFieldToMessages(data map[string]any, field string, messages *[]map[string]any) {
	if messages == nil {
		return
	}
	msgs, ok := data[field].([]any)
	if !ok {
		return
	}
	for _, m := range msgs {
		if msg, ok := m.(map[string]any); ok {
			*messages = append(*messages, msg)
		}
	}
}
|