182 lines
5.3 KiB
Go
182 lines
5.3 KiB
Go
|
|
package service
|
|||
|
|
|
|||
|
|
import (
|
|||
|
|
"context"
|
|||
|
|
"encoding/json"
|
|||
|
|
"fmt"
|
|||
|
|
"strings"
|
|||
|
|
"time"
|
|||
|
|
|
|||
|
|
"github.com/gogf/gf/v2/frame/g"
|
|||
|
|
)
|
|||
|
|
|
|||
|
|
// Message 消息结构(content 支持 string 或 []string)
|
|||
|
|
type Message struct {
|
|||
|
|
Role string `json:"role"` // user / assistant / system
|
|||
|
|
Content any `json:"content"` // 内容:string 或 []string
|
|||
|
|
Type string `json:"type,omitempty"` // text / file(可选扩展)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// GetContentString 获取 Content 的字符串形式
|
|||
|
|
func (m Message) GetContentString() string {
|
|||
|
|
switch v := m.Content.(type) {
|
|||
|
|
case string:
|
|||
|
|
return v
|
|||
|
|
case []interface{}:
|
|||
|
|
var parts []string
|
|||
|
|
for _, item := range v {
|
|||
|
|
if s, ok := item.(string); ok {
|
|||
|
|
parts = append(parts, s)
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
return strings.Join(parts, "\n")
|
|||
|
|
default:
|
|||
|
|
b, _ := json.Marshal(m.Content)
|
|||
|
|
return string(b)
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// SessionRoundData Redis存储的单轮会话数据
|
|||
|
|
type SessionRoundData struct {
|
|||
|
|
SessionId string `json:"sessionId"` // 会话ID
|
|||
|
|
RequestContent []Message `json:"requestContent"` // 用户请求会话
|
|||
|
|
ResponseContent []Message `json:"responseContent"` // AI回调会话
|
|||
|
|
Timestamp int64 `json:"timestamp"` // 存入时间戳
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// GetSessionHistory 获取多轮会话历史(供推理时使用)
|
|||
|
|
func (s *sessionService) GetSessionHistory(ctx context.Context, sessionId string) ([]SessionRoundData, error) {
|
|||
|
|
return s.getFromRedis(ctx, sessionId)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// BuildMessages 根据Redis历史构建完整的Messages数组
|
|||
|
|
func (s *sessionService) BuildMessages(ctx context.Context, sessionId string, currentMessages []Message) ([]Message, error) {
|
|||
|
|
// 获取历史会话
|
|||
|
|
history, err := s.getFromRedis(ctx, sessionId)
|
|||
|
|
if err != nil {
|
|||
|
|
return nil, fmt.Errorf("获取历史会话失败: %w", err)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
var allMessages []Message
|
|||
|
|
|
|||
|
|
// 按时间顺序拼接历史消息
|
|||
|
|
for _, round := range history {
|
|||
|
|
allMessages = append(allMessages, round.RequestContent...)
|
|||
|
|
allMessages = append(allMessages, round.ResponseContent...)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// 添加当前轮次的请求消息
|
|||
|
|
allMessages = append(allMessages, currentMessages...)
|
|||
|
|
|
|||
|
|
return allMessages, nil
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// ==================== Redis 操作 ====================
|
|||
|
|
|
|||
|
|
// saveToRedis 保存会话数据到Redis
|
|||
|
|
// sessionId: 会话ID作为key
|
|||
|
|
// 最大10轮,超出替换最早的,过期时间30分钟
|
|||
|
|
func (s *sessionService) saveToRedis(ctx context.Context, sessionId string, requestMessages []Message, responseMessages []Message) error {
|
|||
|
|
key := fmt.Sprintf("chat:session:%s", sessionId)
|
|||
|
|
|
|||
|
|
// 从配置读取,提供默认值
|
|||
|
|
maxRounds := g.Cfg().MustGet(ctx, "session.maxRounds", 10).Int()
|
|||
|
|
expireSeconds := g.Cfg().MustGet(ctx, "session.expireTime", 1800).Int64()
|
|||
|
|
expireTime := time.Duration(expireSeconds) * time.Second
|
|||
|
|
|
|||
|
|
// 构造存储数据
|
|||
|
|
data := SessionRoundData{
|
|||
|
|
SessionId: sessionId,
|
|||
|
|
RequestContent: requestMessages,
|
|||
|
|
ResponseContent: responseMessages,
|
|||
|
|
Timestamp: time.Now().Unix(),
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// 序列化
|
|||
|
|
b, err := json.Marshal(data)
|
|||
|
|
if err != nil {
|
|||
|
|
return fmt.Errorf("序列化会话数据失败: %w", err)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// 写入 Redis(LPUSH 添加到最前面,新的在前)
|
|||
|
|
_, err = g.Redis().Do(ctx, "LPUSH", key, string(b))
|
|||
|
|
if err != nil {
|
|||
|
|
return fmt.Errorf("写入Redis失败: %w", err)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// 裁剪到最新10轮(保留前10条)
|
|||
|
|
_, err = g.Redis().Do(ctx, "LTRIM", key, 0, maxRounds-1)
|
|||
|
|
if err != nil {
|
|||
|
|
return fmt.Errorf("裁剪Redis列表失败: %w", err)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// 重置过期时间
|
|||
|
|
_, err = g.Redis().Do(ctx, "EXPIRE", key, int64(expireTime.Seconds()))
|
|||
|
|
if err != nil {
|
|||
|
|
return fmt.Errorf("设置过期时间失败: %w", err)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
return nil
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// getFromRedis 从Redis获取会话历史
|
|||
|
|
func (s *sessionService) getFromRedis(ctx context.Context, sessionId string) ([]SessionRoundData, error) {
|
|||
|
|
key := fmt.Sprintf("chat:session:%s", sessionId)
|
|||
|
|
|
|||
|
|
// 获取列表中所有数据(最多10条)
|
|||
|
|
result, err := g.Redis().Do(ctx, "LRANGE", key, 0, -1)
|
|||
|
|
if err != nil {
|
|||
|
|
return nil, fmt.Errorf("从Redis获取数据失败: %w", err)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
if result == nil || result.IsNil() {
|
|||
|
|
return []SessionRoundData{}, nil
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// 解析数据
|
|||
|
|
var sessions []SessionRoundData
|
|||
|
|
|
|||
|
|
// 将结果转换为字符串数组
|
|||
|
|
values := result.Strings()
|
|||
|
|
for _, str := range values {
|
|||
|
|
var data SessionRoundData
|
|||
|
|
if err := json.Unmarshal([]byte(str), &data); err != nil {
|
|||
|
|
g.Log().Warningf(ctx, "[会话] 解析Redis数据失败 err=%v", err)
|
|||
|
|
continue
|
|||
|
|
}
|
|||
|
|
sessions = append(sessions, data)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// 反转顺序(Redis存储最新在前,使用时按时间正序)
|
|||
|
|
for i, j := 0, len(sessions)-1; i < j; i, j = i+1, j-1 {
|
|||
|
|
sessions[i], sessions[j] = sessions[j], sessions[i]
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
return sessions, nil
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// GetSessionHistoryForInference 获取历史会话,直接返回Message数组(给推理用)
|
|||
|
|
func (s *sessionService) GetSessionHistoryForInference(ctx context.Context, sessionId string) ([]Message, error) {
|
|||
|
|
// 从Redis获取历史会话数据
|
|||
|
|
historyData, err := s.getFromRedis(ctx, sessionId)
|
|||
|
|
if err != nil {
|
|||
|
|
return nil, fmt.Errorf("获取历史会话失败: %w", err)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// 如果没有任何历史数据,返回空
|
|||
|
|
if len(historyData) == 0 {
|
|||
|
|
return []Message{}, nil
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// 把SessionRoundData转换成扁平的Message数组
|
|||
|
|
var messages []Message
|
|||
|
|
for _, round := range historyData {
|
|||
|
|
// 先加用户的请求
|
|||
|
|
messages = append(messages, round.RequestContent...)
|
|||
|
|
// 再加AI的回答
|
|||
|
|
messages = append(messages, round.ResponseContent...)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
return messages, nil
|
|||
|
|
}
|