package session import ( "context" "encoding/json" "fmt" "prompts-core/model/entity" "time" "github.com/gogf/gf/v2/frame/g" ) const ( redisKeyPrefix = "chat:session:%s" ) // formatRedisKey 格式化Redis键 func formatRedisKey(sessionId string) string { return fmt.Sprintf(redisKeyPrefix, sessionId) } // saveToRedis 保存会话数据到Redis func saveToRedis(ctx context.Context, session *entity.ComposeSession) error { key := formatRedisKey(session.SessionId) maxRounds := g.Cfg().MustGet(ctx, "session.maxRounds", 10).Int() expireSeconds := g.Cfg().MustGet(ctx, "session.expireTime", 1800).Int64() data := map[string]any{ "sessionId": session.SessionId, "requestContent": session.RequestContent, "responseContent": session.ResponseContent, "timestamp": time.Now().Unix(), } b, err := json.Marshal(data) if err != nil { return fmt.Errorf("序列化会话数据失败: %w", err) } if err = executeRedisCommands(ctx, key, string(b), maxRounds, expireSeconds); err != nil { return err } return nil } // executeRedisCommands 执行Redis命令 func executeRedisCommands(ctx context.Context, key string, value string, maxRounds int, expireSeconds int64) error { if _, err := g.Redis().Do(ctx, "LPUSH", key, value); err != nil { return fmt.Errorf("写入Redis失败: %w", err) } if _, err := g.Redis().Do(ctx, "LTRIM", key, 0, maxRounds-1); err != nil { return fmt.Errorf("裁剪Redis列表失败: %w", err) } if _, err := g.Redis().Do(ctx, "EXPIRE", key, expireSeconds); err != nil { return fmt.Errorf("设置过期时间失败: %w", err) } return nil } // getFromRedis 从Redis获取会话历史 func getFromRedis(ctx context.Context, sessionId string) ([]map[string]any, error) { key := formatRedisKey(sessionId) result, err := g.Redis().Do(ctx, "LRANGE", key, 0, -1) if err != nil { return nil, fmt.Errorf("从Redis获取数据失败: %w", err) } if result == nil || result.IsNil() { return []map[string]any{}, nil } sessions := parseRedisSessions(ctx, result.Strings()) reverseSlice(sessions) return sessions, nil } // parseRedisSessions 解析Redis会话数据 func parseRedisSessions(ctx context.Context, values []string) []map[string]any { var sessions []map[string]any for _, str := range values { var data map[string]any if err := json.Unmarshal([]byte(str), &data); err != nil { g.Log().Warningf(ctx, "[会话] 解析Redis数据失败 err=%v", err) continue } sessions = append(sessions, data) } return sessions } // reverseSlice 反转切片 func reverseSlice(s []map[string]any) { for i, j := 0, len(s)-1; i < j; i, j = i+1, j-1 { s[i], s[j] = s[j], s[i] } } // GetSessionHistoryForInference 获取历史会话,返回扁平消息数组(给推理用) func GetSessionHistoryForInference(ctx context.Context, sessionId string) ([]map[string]any, error) { historyData, err := getFromRedis(ctx, sessionId) if err != nil { return nil, fmt.Errorf("获取历史会话失败: %w", err) } if len(historyData) == 0 { return []map[string]any{}, nil } return flattenHistoryMessages(historyData), nil } // flattenHistoryMessages 扁平化历史消息 func flattenHistoryMessages(historyData []map[string]any) []map[string]any { var messages []map[string]any for _, round := range historyData { appendMessagesFromField(round, "requestContent", &messages) appendMessagesFromField(round, "responseContent", &messages) } return messages } // appendMessagesFromField 从指定字段追加消息 func appendMessagesFromField(data map[string]any, field string, messages *[]map[string]any) { msgs, ok := data[field].([]interface{}) if !ok { return } for _, m := range msgs { if msg, ok := m.(map[string]interface{}); ok { *messages = append(*messages, msg) } } }