Files
prompts-core/service/prompt/prompt_session_redis_service.go

115 lines
3.0 KiB
Go
Raw Normal View History

package prompt
2026-05-12 13:59:15 +08:00
import (
"context"
"encoding/json"
"fmt"
"time"
"github.com/gogf/gf/v2/frame/g"
)
// ==================== Redis 操作 ====================
// saveToRedis 保存会话数据到Redis
func saveToRedis(ctx context.Context, sessionId string, requestMessages []map[string]any, responseMessages []map[string]any) error {
2026-05-12 13:59:15 +08:00
key := fmt.Sprintf("chat:session:%s", sessionId)
maxRounds := g.Cfg().MustGet(ctx, "session.maxRounds", 10).Int()
expireSeconds := g.Cfg().MustGet(ctx, "session.expireTime", 1800).Int64()
expireTime := time.Duration(expireSeconds) * time.Second
2026-05-15 09:45:51 +08:00
data := map[string]any{
"sessionId": sessionId,
"requestContent": requestMessages,
"responseContent": responseMessages,
"timestamp": time.Now().Unix(),
2026-05-12 13:59:15 +08:00
}
b, err := json.Marshal(data)
if err != nil {
return fmt.Errorf("序列化会话数据失败: %w", err)
}
_, err = g.Redis().Do(ctx, "LPUSH", key, string(b))
if err != nil {
return fmt.Errorf("写入Redis失败: %w", err)
}
_, err = g.Redis().Do(ctx, "LTRIM", key, 0, maxRounds-1)
if err != nil {
return fmt.Errorf("裁剪Redis列表失败: %w", err)
}
_, err = g.Redis().Do(ctx, "EXPIRE", key, int64(expireTime.Seconds()))
if err != nil {
return fmt.Errorf("设置过期时间失败: %w", err)
}
return nil
}
// getFromRedis 从Redis获取会话历史
func getFromRedis(ctx context.Context, sessionId string) ([]map[string]any, error) {
2026-05-12 13:59:15 +08:00
key := fmt.Sprintf("chat:session:%s", sessionId)
result, err := g.Redis().Do(ctx, "LRANGE", key, 0, -1)
if err != nil {
return nil, fmt.Errorf("从Redis获取数据失败: %w", err)
}
if result == nil || result.IsNil() {
2026-05-15 09:45:51 +08:00
return []map[string]any{}, nil
2026-05-12 13:59:15 +08:00
}
2026-05-15 09:45:51 +08:00
var sessions []map[string]any
2026-05-12 13:59:15 +08:00
values := result.Strings()
for _, str := range values {
2026-05-15 09:45:51 +08:00
var data map[string]any
2026-05-12 13:59:15 +08:00
if err := json.Unmarshal([]byte(str), &data); err != nil {
g.Log().Warningf(ctx, "[会话] 解析Redis数据失败 err=%v", err)
continue
}
sessions = append(sessions, data)
}
2026-05-15 09:45:51 +08:00
// 反转Redis 最新在前 → 时间正序)
2026-05-12 13:59:15 +08:00
for i, j := 0, len(sessions)-1; i < j; i, j = i+1, j-1 {
sessions[i], sessions[j] = sessions[j], sessions[i]
}
return sessions, nil
}
2026-05-15 09:45:51 +08:00
// GetSessionHistoryForInference 获取历史会话,返回扁平消息数组(给推理用)
func GetSessionHistoryForInference(ctx context.Context, sessionId string) ([]map[string]any, error) {
historyData, err := getFromRedis(ctx, sessionId)
2026-05-12 13:59:15 +08:00
if err != nil {
return nil, fmt.Errorf("获取历史会话失败: %w", err)
}
if len(historyData) == 0 {
2026-05-15 09:45:51 +08:00
return []map[string]any{}, nil
2026-05-12 13:59:15 +08:00
}
2026-05-15 09:45:51 +08:00
var messages []map[string]any
2026-05-12 13:59:15 +08:00
for _, round := range historyData {
2026-05-15 09:45:51 +08:00
if reqMsgs, ok := round["requestContent"].([]interface{}); ok {
for _, m := range reqMsgs {
if msg, ok := m.(map[string]interface{}); ok {
messages = append(messages, msg)
}
}
}
if respMsgs, ok := round["responseContent"].([]interface{}); ok {
for _, m := range respMsgs {
if msg, ok := m.(map[string]interface{}); ok {
messages = append(messages, msg)
}
}
}
2026-05-12 13:59:15 +08:00
}
return messages, nil
}