146 lines
3.7 KiB
Go
146 lines
3.7 KiB
Go
package prompt
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"fmt"
|
|
"time"
|
|
|
|
"github.com/gogf/gf/v2/frame/g"
|
|
)
|
|
|
|
// redisKeyPrefix is the fmt format string for per-session history keys;
// the %s verb is replaced with the session ID (see formatRedisKey).
const redisKeyPrefix = "chat:session:%s"
|
|
|
|
// saveToRedis 保存会话数据到Redis
|
|
func saveToRedis(ctx context.Context, sessionId string, requestMessages []map[string]any, responseMessages []map[string]any) error {
|
|
key := formatRedisKey(sessionId)
|
|
|
|
maxRounds := g.Cfg().MustGet(ctx, "session.maxRounds", 10).Int()
|
|
expireSeconds := g.Cfg().MustGet(ctx, "session.expireTime", 1800).Int64()
|
|
|
|
data := map[string]any{
|
|
"sessionId": sessionId,
|
|
"requestContent": requestMessages,
|
|
"responseContent": responseMessages,
|
|
"timestamp": time.Now().Unix(),
|
|
}
|
|
|
|
b, err := json.Marshal(data)
|
|
if err != nil {
|
|
return fmt.Errorf("序列化会话数据失败: %w", err)
|
|
}
|
|
|
|
if err := executeRedisCommands(ctx, key, string(b), maxRounds, expireSeconds); err != nil {
|
|
return err
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// formatRedisKey 格式化Redis键
|
|
func formatRedisKey(sessionId string) string {
|
|
return fmt.Sprintf(redisKeyPrefix, sessionId)
|
|
}
|
|
|
|
// executeRedisCommands 执行Redis命令
|
|
func executeRedisCommands(ctx context.Context, key string, value string, maxRounds int, expireSeconds int64) error {
|
|
if _, err := g.Redis().Do(ctx, "LPUSH", key, value); err != nil {
|
|
return fmt.Errorf("写入Redis失败: %w", err)
|
|
}
|
|
|
|
if _, err := g.Redis().Do(ctx, "LTRIM", key, 0, maxRounds-1); err != nil {
|
|
return fmt.Errorf("裁剪Redis列表失败: %w", err)
|
|
}
|
|
|
|
if _, err := g.Redis().Do(ctx, "EXPIRE", key, expireSeconds); err != nil {
|
|
return fmt.Errorf("设置过期时间失败: %w", err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// getFromRedis 从Redis获取会话历史
|
|
func getFromRedis(ctx context.Context, sessionId string) ([]map[string]any, error) {
|
|
key := formatRedisKey(sessionId)
|
|
|
|
result, err := g.Redis().Do(ctx, "LRANGE", key, 0, -1)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("从Redis获取数据失败: %w", err)
|
|
}
|
|
|
|
if result == nil || result.IsNil() {
|
|
return []map[string]any{}, nil
|
|
}
|
|
|
|
sessions := parseRedisSessions(ctx, result.Strings())
|
|
|
|
reverseSlice(sessions)
|
|
|
|
return sessions, nil
|
|
}
|
|
|
|
// parseRedisSessions 解析Redis会话数据
|
|
func parseRedisSessions(ctx context.Context, values []string) []map[string]any {
|
|
var sessions []map[string]any
|
|
|
|
for _, str := range values {
|
|
var data map[string]any
|
|
if err := json.Unmarshal([]byte(str), &data); err != nil {
|
|
g.Log().Warningf(ctx, "[会话] 解析Redis数据失败 err=%v", err)
|
|
continue
|
|
}
|
|
sessions = append(sessions, data)
|
|
}
|
|
|
|
return sessions
|
|
}
|
|
|
|
// reverseSlice reverses the order of s in place. Nil and single-element
// slices are left unchanged.
func reverseSlice(s []map[string]any) {
	n := len(s)
	for i := 0; i < n/2; i++ {
		mirror := n - 1 - i
		s[i], s[mirror] = s[mirror], s[i]
	}
}
|
|
|
|
// GetSessionHistoryForInference 获取历史会话,返回扁平消息数组(给推理用)
|
|
func GetSessionHistoryForInference(ctx context.Context, sessionId string) ([]map[string]any, error) {
|
|
historyData, err := getFromRedis(ctx, sessionId)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("获取历史会话失败: %w", err)
|
|
}
|
|
|
|
if len(historyData) == 0 {
|
|
return []map[string]any{}, nil
|
|
}
|
|
|
|
return flattenHistoryMessages(historyData), nil
|
|
}
|
|
|
|
// flattenHistoryMessages 扁平化历史消息
|
|
func flattenHistoryMessages(historyData []map[string]any) []map[string]any {
|
|
var messages []map[string]any
|
|
|
|
for _, round := range historyData {
|
|
appendMessagesFromField(round, "requestContent", &messages)
|
|
appendMessagesFromField(round, "responseContent", &messages)
|
|
}
|
|
|
|
return messages
|
|
}
|
|
|
|
// appendMessagesFromField appends every message object stored under
// data[field] to *messages, preserving their order.
//
// The field is expected to hold a list of message objects. Two shapes are
// accepted:
//   - []any whose elements are map[string]any, which is what encoding/json
//     produces when decoding into map[string]any;
//   - an already-typed []map[string]any, so that in-memory (not JSON
//     round-tripped) data is not silently dropped.
//
// The following are ignored: non-object list elements, a missing or
// non-list field, a nil data map, and a nil messages pointer.
func appendMessagesFromField(data map[string]any, field string, messages *[]map[string]any) {
	if messages == nil {
		return
	}

	switch msgs := data[field].(type) {
	case []any:
		for _, m := range msgs {
			if msg, ok := m.(map[string]any); ok {
				*messages = append(*messages, msg)
			}
		}
	case []map[string]any:
		*messages = append(*messages, msgs...)
	}
}
|