feat(session): 重构会话管理和Redis缓存机制

This commit is contained in:
2026-06-09 14:00:01 +08:00
parent 1f9a2b9b5f
commit 9410199fbe
8 changed files with 324 additions and 196 deletions

View File

@@ -5,6 +5,7 @@ import (
"fmt"
"gitea.com/red-future/common/beans"
"gitea.com/red-future/common/utils"
"github.com/gogf/gf/v2/frame/g"
"github.com/gogf/gf/v2/util/gconv"
@@ -17,6 +18,7 @@ import (
// Callback 会话回调
func Callback(ctx context.Context, req *dto.SessionCallbackReq) (*dto.SessionCallbackRes, error) {
req.Messages["role"] = "assistant"
// 1) 更新 DB
_, err := dao.ComposeSession.Update(ctx, &entity.ComposeSession{
SQLBaseDO: beans.SQLBaseDO{Id: req.EpicycleId},
ResponseContent: req.Messages,
@@ -25,121 +27,172 @@ func Callback(ctx context.Context, req *dto.SessionCallbackReq) (*dto.SessionCal
g.Log().Errorf(ctx, "[会话回调] 更新数据库失败 epicycleId=%d err=%v", req.EpicycleId, err)
return nil, fmt.Errorf("更新数据库失败: %w", err)
}
// 2) 查询完整记录
session, err := dao.ComposeSession.Get(ctx, &entity.ComposeSession{
SQLBaseDO: beans.SQLBaseDO{Id: req.EpicycleId},
})
if session == nil {
if err != nil || session == nil {
return nil, fmt.Errorf("会话不存在: epicycleId=%d", req.EpicycleId)
}
if err != nil {
g.Log().Errorf(ctx, "[会话回调] 获取会话数据失败 epicycleId=%d err=%v", req.EpicycleId, err)
return nil, fmt.Errorf("获取会话数据失败: %w", err)
}
if err = saveToRedis(ctx, session); err != nil {
// 3) 写入 Redis
if err = SaveToRedis(ctx, session.TenantId, session.SessionId, &dto.HistoryRound{
Id: session.Id,
User: session.RequestContent,
Assistant: req.Messages,
CreatedAt: gconv.String(session.CreatedAt),
}); err != nil {
return nil, fmt.Errorf("redis存储失败: %w", err)
}
g.Log().Infof(ctx, "[会话回调] 存储成功 sessionId=%s id=%d requestLen=%d responseLen=%d",
session.SessionId, session.Id, len(session.RequestContent), len(session.ResponseContent))
return &dto.SessionCallbackRes{
Status: true,
SessionId: session.SessionId,
}, nil
// 4) 返回
g.Log().Infof(ctx, "[会话回调] 存储成功 sessionId=%s id=%d", session.SessionId, session.Id)
return &dto.SessionCallbackRes{Status: true, SessionId: session.SessionId}, nil
}
// GetHistoryMessages 获取历史
func GetHistoryMessages(ctx context.Context, sessionId string, nodeId string) ([]map[string]any, error) {
// 1) 获取最大轮次
maxRounds := g.Cfg().MustGet(ctx, "session.maxRounds", 10).Int()
// 2) 从 Redis 获取历史记录
redisHistory, err := GetSessionHistoryForInference(ctx, sessionId)
if err == nil && len(redisHistory) > 0 {
return redisHistory, nil
// GetHistoryMessages 获取历史
func GetHistoryMessages(ctx context.Context, req *dto.GetHistoryMessagesReq) (*dto.GetHistoryMessagesRes, error) {
user, err := utils.GetUserInfo(ctx)
if err != nil {
return nil, err
}
// 3) Redis 没有,从数据库查最新 maxRounds 条
// 1) Redis
redisRounds, err := GetFromRedis(ctx, user.TenantId, req.SessionId)
if err == nil && len(redisRounds) > 0 {
g.Log().Debugf(ctx, "[历史消息] Redis命中 sessionId=%s count=%d", req.SessionId, len(redisRounds))
return &dto.GetHistoryMessagesRes{Messages: parseHistoryRounds(redisRounds)}, nil
}
// 2) DB
maxRounds := util.GetMaxRounds(ctx)
sessions, _, err := dao.ComposeSession.List(ctx, &entity.ComposeSession{
SessionId: sessionId,
NodeId: nodeId,
SQLBaseDO: beans.SQLBaseDO{Creator: user.UserName},
SessionId: req.SessionId,
}, 1, maxRounds)
if err != nil {
return nil, fmt.Errorf("DB获取历史失败: %w", err)
}
// 4) 为空返回报错
if len(sessions) == 0 {
return nil, fmt.Errorf("会话不存在: sessionId=%s nodeId=%s", sessionId, nodeId)
}
// 5) 提取为统一格式
messages := extractMessagesFromSessions(sessions)
// 6) 缓存 Redis 半小时
//_ = CacheSessionHistoryForInference(ctx, sessionId, messages, 30*time.Minute)
return messages, nil
}
// getHistoryFromDatabase 从数据库获取历史记录
func getHistoryFromDatabase(ctx context.Context, sessionId string, maxRounds int) ([]map[string]any, error) {
sessions, _, err := dao.ComposeSession.List(ctx, &entity.ComposeSession{
SessionId: sessionId,
}, 1, maxRounds)
if err != nil {
return nil, fmt.Errorf("DB获取历史失败: %w", err)
return &dto.GetHistoryMessagesRes{Messages: []dto.HistoryRound{}}, nil
}
messages := extractMessagesFromSessions(sessions)
// 3) 转换 + 异步回种
rounds := sessionsToHistoryRounds(sessions)
go asyncCacheToRedis(context.WithoutCancel(ctx), user.TenantId, req.SessionId, sessions)
cacheSessionsToRedis(ctx, sessions)
return messages, nil
return &dto.GetHistoryMessagesRes{Messages: rounds}, nil
}
// extractMessagesFromSessions 从会话列表中提取消息
// parseHistoryRounds Redis 数据转为 HistoryRound
func parseHistoryRounds(redisRounds []map[string]any) []dto.HistoryRound {
rounds := make([]dto.HistoryRound, 0, len(redisRounds))
for _, r := range redisRounds {
round := dto.HistoryRound{
Id: gconv.Int64(r["id"]),
CreatedAt: gconv.String(r["createdAt"]),
}
if user, ok := r["user"].(map[string]any); ok {
round.User = user
}
if assistant, ok := r["assistant"].(map[string]any); ok {
round.Assistant = assistant
}
rounds = append(rounds, round)
}
return rounds
}
// sessionsToHistoryRounds DB 数据转为 HistoryRound
func sessionsToHistoryRounds(sessions []*entity.ComposeSession) []dto.HistoryRound {
rounds := make([]dto.HistoryRound, 0, len(sessions))
for _, s := range sessions {
reqMsgs := util.ConvertToMessages(s.RequestContent)
respMsgs := util.ConvertToMessages(s.ResponseContent)
round := dto.HistoryRound{
Id: s.Id,
CreatedAt: gconv.String(s.CreatedAt),
}
if len(reqMsgs) > 0 {
round.User = reqMsgs[0]
}
if len(respMsgs) > 0 {
if respMsgs[0]["role"] == nil {
respMsgs[0]["role"] = "assistant"
}
round.Assistant = respMsgs[0]
}
rounds = append(rounds, round)
}
return rounds
}
// DeleteSession 删除会话
func DeleteSession(ctx context.Context, req *dto.DeleteSessionReq) (*dto.DeleteSessionRes, error) {
hasMsgID := len(req.MsgIds) > 0 && req.MsgIds[0] > 0
deleteReq := &entity.ComposeSession{
SessionId: req.SessionId,
NodeId: req.NodeId,
}
if hasMsgID {
deleteReq.Id = req.MsgIds[0]
}
if _, err := dao.ComposeSession.Delete(ctx, deleteReq); err != nil {
return nil, fmt.Errorf("DB删除失败: %w", err)
}
if hasMsgID {
if err := DeleteSingleMessage(ctx, req.TenantId, req.SessionId, req.MsgIds[0]); err != nil {
g.Log().Warningf(ctx, "[删除会话] Redis删除单条失败 msgID=%d err=%v", req.MsgIds[0], err)
}
} else {
if err := DeleteSessionHistory(ctx, req.TenantId, req.SessionId); err != nil {
g.Log().Warningf(ctx, "[删除会话] Redis删除失败 sessionId=%s err=%v", req.SessionId, err)
}
}
return &dto.DeleteSessionRes{Ok: true}, nil
}
// ============================================
// 内部方法
// ============================================
func extractMessagesFromSessions(sessions []*entity.ComposeSession) []map[string]any {
var messages []map[string]any
for _, session := range sessions {
appendRequestMessages(session.RequestContent, &messages)
appendResponseMessages(session.ResponseContent, &messages)
for i := len(sessions) - 1; i >= 0; i-- {
appendRoleMessages(sessions[i].RequestContent, "user", &messages)
appendRoleMessages(sessions[i].ResponseContent, "assistant", &messages)
}
return messages
}
// appendRequestMessages 追加请求消息
func appendRequestMessages(requestContent any, messages *[]map[string]any) {
reqMsgs := util.ConvertToMessages(requestContent)
for _, m := range reqMsgs {
role := gconv.String(m["role"])
if role == "user" || role == "assistant" {
*messages = append(*messages, m)
}
}
}
// appendResponseMessages 追加响应消息
func appendResponseMessages(responseContent any, messages *[]map[string]any) {
respMsgs := util.ConvertToMessages(responseContent)
for _, m := range respMsgs {
if m["role"] == nil {
m["role"] = "assistant"
func appendRoleMessages(content any, defaultRole string, messages *[]map[string]any) {
msgs := util.ConvertToMessages(content)
for _, m := range msgs {
if m["role"] == nil || gconv.String(m["role"]) == "" {
m["role"] = defaultRole
}
*messages = append(*messages, m)
}
}
// cacheSessionsToRedis 将会话缓存到Redis
func cacheSessionsToRedis(ctx context.Context, sessions []*entity.ComposeSession) {
for _, session := range sessions {
reqMsgs := util.ConvertToMessages(session.RequestContent)
respMsgs := util.ConvertToMessages(session.ResponseContent)
for i := range respMsgs {
if respMsgs[i]["role"] == nil {
respMsgs[i]["role"] = "assistant"
}
}
// asyncCacheToRedis 异步缓存会话数据到 Redis
func asyncCacheToRedis(ctx context.Context, tenantID uint64, sessionID string, sessions []*entity.ComposeSession) {
for _, s := range sessions {
reqMsgs := util.ConvertToMessages(s.RequestContent)
respMsgs := util.ConvertToMessages(s.ResponseContent)
if len(reqMsgs) > 0 || len(respMsgs) > 0 {
_ = saveToRedis(ctx, session)
_ = SaveToRedis(ctx, tenantID, sessionID, &dto.HistoryRound{
Id: s.Id,
User: s.RequestContent,
Assistant: s.ResponseContent,
CreatedAt: gconv.String(s.CreatedAt),
})
}
}
}