feat(session): 重构会话管理和Redis缓存机制
This commit is contained in:
@@ -2,7 +2,6 @@ package util
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
|
||||
"github.com/gogf/gf/v2/frame/g"
|
||||
"github.com/gogf/gf/v2/util/gconv"
|
||||
@@ -13,16 +12,6 @@ func GetServerName(ctx context.Context) string {
|
||||
return g.Cfg().MustGet(ctx, "server.name", "").String()
|
||||
}
|
||||
|
||||
// GetServerPort 从配置获取服务端口
|
||||
func GetServerPort(ctx context.Context) string {
|
||||
address := g.Cfg().MustGet(ctx, "server.address", ":8080").String()
|
||||
// address 格式如 ":3009",去掉冒号
|
||||
if strings.HasPrefix(address, ":") {
|
||||
return address[1:]
|
||||
}
|
||||
return "8080"
|
||||
}
|
||||
|
||||
// GetModelPrompt 获取请求模型的提示词
|
||||
func GetModelPrompt(ctx context.Context, modelType int) string {
|
||||
key := "modelPrompts.types." + gconv.String(modelType)
|
||||
@@ -33,3 +22,13 @@ func GetModelPrompt(ctx context.Context, modelType int) string {
|
||||
func GetBuildPrompt(ctx context.Context) string {
|
||||
return g.Cfg().MustGet(ctx, "nodePrompts", "").String()
|
||||
}
|
||||
|
||||
// GetMaxRounds 获取最大轮数配置
|
||||
func GetMaxRounds(ctx context.Context) int {
|
||||
return g.Cfg().MustGet(ctx, "session.maxRounds", 10).Int()
|
||||
}
|
||||
|
||||
// GetExpireMinutes 获取过期时间配置
|
||||
func GetExpireMinutes(ctx context.Context) int {
|
||||
return g.Cfg().MustGet(ctx, "session.expireMinutes", 30).Int()
|
||||
}
|
||||
|
||||
@@ -50,14 +50,14 @@ database:
|
||||
|
||||
redis:
|
||||
default:
|
||||
address: 116.204.74.41:6379
|
||||
address: 192.168.3.30:6379
|
||||
db: 0
|
||||
|
||||
consul:
|
||||
address: 116.204.74.41:8500
|
||||
address: 192.168.3.30:8500
|
||||
|
||||
jaeger:
|
||||
addr: 116.204.74.41:4318
|
||||
addr: 192.168.3.30:4318
|
||||
|
||||
task:
|
||||
waitTimeoutSeconds: 600 # /composeMessages 同步等待最终结果的最长时间(秒)
|
||||
|
||||
@@ -1,18 +1,31 @@
|
||||
// ============================================
|
||||
// controller/session.go
|
||||
// ============================================
|
||||
|
||||
package controller
|
||||
|
||||
import (
|
||||
"context"
|
||||
"prompts-core/model/dto"
|
||||
|
||||
"prompts-core/model/dto"
|
||||
sessionService "prompts-core/service/session"
|
||||
)
|
||||
|
||||
type session struct{}
|
||||
|
||||
// Session 提示词会话控制器
|
||||
var Session = new(session)
|
||||
|
||||
// SessionCallback 会话回调
|
||||
// SessionCallback 接收会话回调通知
|
||||
func (c *session) SessionCallback(ctx context.Context, req *dto.SessionCallbackReq) (res *dto.SessionCallbackRes, err error) {
|
||||
return sessionService.Callback(ctx, req)
|
||||
}
|
||||
|
||||
// GetHistoryMessages 获取历史消息
|
||||
func (c *session) GetHistoryMessages(ctx context.Context, req *dto.GetHistoryMessagesReq) (res *dto.GetHistoryMessagesRes, err error) {
|
||||
return sessionService.GetHistoryMessages(ctx, req)
|
||||
}
|
||||
|
||||
// DeleteSession 删除会话
|
||||
func (c *session) DeleteSession(ctx context.Context, req *dto.DeleteSessionReq) (res *dto.DeleteSessionRes, err error) {
|
||||
return sessionService.DeleteSession(ctx, req)
|
||||
}
|
||||
|
||||
@@ -6,7 +6,6 @@ import (
|
||||
"prompts-core/model/entity"
|
||||
|
||||
"gitea.com/red-future/common/db/gfdb"
|
||||
"github.com/gogf/gf/v2/util/gconv"
|
||||
)
|
||||
|
||||
var ComposeSession = &composeSessionDao{}
|
||||
@@ -15,13 +14,8 @@ type composeSessionDao struct{}
|
||||
|
||||
// Insert 插入
|
||||
func (d *composeSessionDao) Insert(ctx context.Context, req *entity.ComposeSession) (id int64, err error) {
|
||||
var m = new(entity.ComposeSession)
|
||||
err = gconv.Struct(req, &m)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
r, err := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameComposeSession).
|
||||
Insert(m)
|
||||
Insert(req)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
@@ -54,7 +48,6 @@ func (d *composeSessionDao) List(ctx context.Context, req *entity.ComposeSession
|
||||
OmitEmpty()
|
||||
model.Where(entity.ComposeSessionCol.Creator, req.Creator)
|
||||
model.Where(entity.ComposeSessionCol.SessionId, req.SessionId)
|
||||
model.Where(entity.ComposeSessionCol.NodeId, req.NodeId)
|
||||
model.OrderDesc(entity.ComposeSessionCol.CreatedAt)
|
||||
model.Page(page, size)
|
||||
r, total, err := model.AllAndCount(false)
|
||||
@@ -70,6 +63,7 @@ func (d *composeSessionDao) Get(ctx context.Context, req *entity.ComposeSession,
|
||||
r, err := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameComposeSession).
|
||||
OmitEmpty().
|
||||
Where(entity.ComposeSessionCol.Id, req.Id).
|
||||
Where(entity.ComposeSessionCol.Creator, req.Creator).
|
||||
Where(entity.ComposeSessionCol.SessionId, req.SessionId).
|
||||
Fields(fields).One()
|
||||
if err != nil {
|
||||
@@ -87,6 +81,7 @@ func (d *composeSessionDao) Delete(ctx context.Context, req *entity.ComposeSessi
|
||||
r, err := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameComposeSession).
|
||||
OmitEmpty().
|
||||
Where(entity.ComposeSessionCol.Id, req.Id).
|
||||
Where(entity.ComposeSessionCol.Creator, req.Creator).
|
||||
Where(entity.ComposeSessionCol.SessionId, req.SessionId).
|
||||
Delete()
|
||||
if err != nil {
|
||||
|
||||
@@ -2,13 +2,49 @@ package dto
|
||||
|
||||
import "github.com/gogf/gf/v2/frame/g"
|
||||
|
||||
// SessionCallbackReq 会话回调请求
|
||||
type SessionCallbackReq struct {
|
||||
g.Meta `path:"/sessionCallback" method:"post" tags:"提示词处理"`
|
||||
g.Meta `path:"/callback" method:"post" tags:"会话管理" summary:"会话回调"`
|
||||
Messages map[string]any `json:"messages" dc:"消息数组"`
|
||||
EpicycleId int64 `json:"epicycleId" dc:"轮次ID"`
|
||||
}
|
||||
|
||||
// SessionCallbackRes 会话回调响应
|
||||
type SessionCallbackRes struct {
|
||||
Status bool `json:"status" dc:"状态"`
|
||||
SessionId string `json:"sessionId" dc:"会话ID"`
|
||||
}
|
||||
|
||||
// GetHistoryMessagesReq 获取历史消息请求
|
||||
type GetHistoryMessagesReq struct {
|
||||
g.Meta `path:"/history" method:"get" tags:"会话管理" summary:"获取历史消息"`
|
||||
SessionId string `json:"sessionId" v:"required" dc:"会话ID"`
|
||||
NodeId string `json:"nodeId" dc:"节点ID"`
|
||||
}
|
||||
|
||||
// GetHistoryMessagesRes 获取历史消息响应
|
||||
type GetHistoryMessagesRes struct {
|
||||
Messages []HistoryRound `json:"messages" dc:"历史消息列表"`
|
||||
}
|
||||
|
||||
// HistoryRound 一轮对话
|
||||
type HistoryRound struct {
|
||||
Id int64 `json:"id" dc:"记录ID"`
|
||||
User map[string]any `json:"user" dc:"用户消息"`
|
||||
Assistant map[string]any `json:"assistant" dc:"助手回复"`
|
||||
CreatedAt string `json:"createdAt" dc:"创建时间"`
|
||||
}
|
||||
|
||||
// DeleteSessionReq 删除会话请求
|
||||
type DeleteSessionReq struct {
|
||||
g.Meta `path:"/delete" method:"post" tags:"会话管理" summary:"删除会话"`
|
||||
TenantId uint64 `json:"tenantId" dc:"租户ID"`
|
||||
SessionId string `json:"sessionId" v:"required" dc:"会话ID"`
|
||||
NodeId string `json:"nodeId" dc:"节点ID"`
|
||||
MsgIds []int64 `json:"msgIds" dc:"消息ID列表,传则删单条,不传删整个会话"`
|
||||
}
|
||||
|
||||
// DeleteSessionRes 删除会话响应
|
||||
type DeleteSessionRes struct {
|
||||
Ok bool `json:"ok" dc:"是否成功"`
|
||||
}
|
||||
|
||||
@@ -198,13 +198,10 @@ func handleCallbackSuccess(ctx context.Context, req *dto.CallbackReq, composeTas
|
||||
buildType := gconv.Int(payload["buildType"])
|
||||
if buildType == public.BuildTypePrompt && sessionId != "" && nodeId != "" {
|
||||
// 4) 获取历史内容并拼接
|
||||
history, _ := session.GetHistoryMessages(ctx, sessionId, nodeId)
|
||||
for _, msg := range history {
|
||||
role := gconv.String(msg["role"])
|
||||
if role != "user" && role != "assistant" {
|
||||
continue
|
||||
}
|
||||
}
|
||||
_, _ = session.GetHistoryMessages(ctx, &dto.GetHistoryMessagesReq{
|
||||
SessionId: sessionId,
|
||||
NodeId: nodeId,
|
||||
})
|
||||
// 5) 存储提示词结果作为历史请求
|
||||
if userMsg := util.ExtractUserText(messages); userMsg != nil {
|
||||
epicycleId, err = dao.ComposeSession.Insert(ctx, &entity.ComposeSession{
|
||||
@@ -261,11 +258,15 @@ func parseMessagesForResponse(messages any) any {
|
||||
}
|
||||
|
||||
func GetPromptText(ctx context.Context, req *dto.GetPromptTextReq) (*dto.GetPromptTextRes, error) {
|
||||
// 1) 获取基础数据
|
||||
|
||||
// 4) 模拟历史拼接
|
||||
history, _ := session.GetHistoryMessages(ctx, "88888888", "node1")
|
||||
history, err := session.GetHistoryMessages(ctx, &dto.GetHistoryMessagesReq{
|
||||
SessionId: "88888888",
|
||||
NodeId: "node1",
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &dto.GetPromptTextRes{
|
||||
Messages: history,
|
||||
Messages: history.Messages,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -4,134 +4,165 @@ import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"prompts-core/model/entity"
|
||||
"prompts-core/common/util"
|
||||
"prompts-core/model/dto"
|
||||
"time"
|
||||
|
||||
"github.com/gogf/gf/v2/frame/g"
|
||||
)
|
||||
|
||||
const (
|
||||
redisKeyPrefix = "chat:session:%s"
|
||||
// RedisKeySessionHistory 会话历史缓存 key: session:history:{tenantId}:{sessionId}
|
||||
RedisKeySessionHistory = "session:history:%d:%s"
|
||||
)
|
||||
|
||||
// formatRedisKey 格式化Redis键
|
||||
func formatRedisKey(sessionId string) string {
|
||||
return fmt.Sprintf(redisKeyPrefix, sessionId)
|
||||
// formatRedisKey 格式化 Redis key
|
||||
func formatRedisKey(tenantID uint64, sessionID string) string {
|
||||
return fmt.Sprintf(RedisKeySessionHistory, tenantID, sessionID)
|
||||
}
|
||||
|
||||
// saveToRedis 保存会话数据到Redis
|
||||
func saveToRedis(ctx context.Context, session *entity.ComposeSession) error {
|
||||
key := formatRedisKey(session.SessionId)
|
||||
maxRounds := g.Cfg().MustGet(ctx, "session.maxRounds", 10).Int()
|
||||
expireSeconds := g.Cfg().MustGet(ctx, "session.expireTime", 1800).Int64()
|
||||
data := map[string]any{
|
||||
"sessionId": session.SessionId,
|
||||
"requestContent": session.RequestContent,
|
||||
"responseContent": session.ResponseContent,
|
||||
"timestamp": time.Now().Unix(),
|
||||
}
|
||||
b, err := json.Marshal(data)
|
||||
// ============================================
|
||||
// 写操作
|
||||
// ============================================
|
||||
|
||||
// SaveToRedis 保存一轮对话到 Redis ZSET
|
||||
func SaveToRedis(ctx context.Context, tenantID uint64, sessionID string, round *dto.HistoryRound) error {
|
||||
key := formatRedisKey(tenantID, sessionID)
|
||||
maxRounds := util.GetMaxRounds(ctx)
|
||||
expireSeconds := int64(util.GetExpireMinutes(ctx) * 60)
|
||||
|
||||
b, err := json.Marshal(round)
|
||||
if err != nil {
|
||||
return fmt.Errorf("序列化会话数据失败: %w", err)
|
||||
}
|
||||
if err = executeRedisCommands(ctx, key, string(b), maxRounds, expireSeconds); err != nil {
|
||||
|
||||
score := float64(time.Now().UnixMilli())
|
||||
|
||||
if _, err = g.Redis().Do(ctx, "ZADD", key, score, string(b)); err != nil {
|
||||
return fmt.Errorf("ZADD失败: %w", err)
|
||||
}
|
||||
if _, err = g.Redis().Do(ctx, "ZREMRANGEBYRANK", key, 0, -(maxRounds + 1)); err != nil {
|
||||
return fmt.Errorf("裁剪失败: %w", err)
|
||||
}
|
||||
if _, err = g.Redis().Do(ctx, "EXPIRE", key, expireSeconds); err != nil {
|
||||
return fmt.Errorf("设置过期失败: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// DeleteSingleMessage 删除 Redis 中单条消息(按消息ID)
|
||||
func DeleteSingleMessage(ctx context.Context, tenantID uint64, sessionID string, msgID int64) error {
|
||||
key := formatRedisKey(tenantID, sessionID)
|
||||
|
||||
cursor := "0"
|
||||
for {
|
||||
result, err := g.Redis().Do(ctx, "ZSCAN", key, cursor, "MATCH", fmt.Sprintf("*\"id\":%d*", msgID), "COUNT", 10)
|
||||
if err != nil {
|
||||
return fmt.Errorf("ZSCAN失败: %w", err)
|
||||
}
|
||||
|
||||
parts := result.Strings()
|
||||
if len(parts) < 2 {
|
||||
break
|
||||
}
|
||||
|
||||
cursor = parts[0]
|
||||
members := parts[1:]
|
||||
|
||||
for _, member := range members {
|
||||
if _, err := g.Redis().Do(ctx, "ZREM", key, member); err != nil {
|
||||
g.Log().Warningf(ctx, "[会话Redis] ZREM单条失败 key=%s err=%v", key, err)
|
||||
}
|
||||
}
|
||||
|
||||
if cursor == "0" {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// DeleteSessionHistory 删除整个会话的 Redis 缓存
|
||||
func DeleteSessionHistory(ctx context.Context, tenantID uint64, sessionID string) error {
|
||||
key := formatRedisKey(tenantID, sessionID)
|
||||
_, err := g.Redis().Do(ctx, "DEL", key)
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// executeRedisCommands 执行Redis命令
|
||||
func executeRedisCommands(ctx context.Context, key string, value string, maxRounds int, expireSeconds int64) error {
|
||||
if _, err := g.Redis().Do(ctx, "LPUSH", key, value); err != nil {
|
||||
return fmt.Errorf("写入Redis失败: %w", err)
|
||||
}
|
||||
// ============================================
|
||||
// 读操作
|
||||
// ============================================
|
||||
|
||||
if _, err := g.Redis().Do(ctx, "LTRIM", key, 0, maxRounds-1); err != nil {
|
||||
return fmt.Errorf("裁剪Redis列表失败: %w", err)
|
||||
}
|
||||
if _, err := g.Redis().Do(ctx, "EXPIRE", key, expireSeconds); err != nil {
|
||||
return fmt.Errorf("设置过期时间失败: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
// GetFromRedis 从 Redis ZSET 获取会话历史
|
||||
func GetFromRedis(ctx context.Context, tenantID uint64, sessionID string) ([]map[string]any, error) {
|
||||
key := formatRedisKey(tenantID, sessionID)
|
||||
maxRounds := util.GetMaxRounds(ctx)
|
||||
|
||||
// getFromRedis 从Redis获取会话历史
|
||||
func getFromRedis(ctx context.Context, sessionId string) ([]map[string]any, error) {
|
||||
key := formatRedisKey(sessionId)
|
||||
result, err := g.Redis().Do(ctx, "LRANGE", key, 0, -1)
|
||||
result, err := g.Redis().Do(ctx, "ZREVRANGE", key, 0, maxRounds-1)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("从Redis获取数据失败: %w", err)
|
||||
return nil, fmt.Errorf("ZREVRANGE失败: %w", err)
|
||||
}
|
||||
|
||||
if result == nil || result.IsNil() {
|
||||
return []map[string]any{}, nil
|
||||
}
|
||||
|
||||
sessions := parseRedisSessions(ctx, result.Strings())
|
||||
|
||||
reverseSlice(sessions)
|
||||
|
||||
return sessions, nil
|
||||
return parseRedisRounds(ctx, result.Strings()), nil
|
||||
}
|
||||
|
||||
// parseRedisSessions 解析Redis会话数据
|
||||
func parseRedisSessions(ctx context.Context, values []string) []map[string]any {
|
||||
var sessions []map[string]any
|
||||
|
||||
for _, str := range values {
|
||||
var data map[string]any
|
||||
if err := json.Unmarshal([]byte(str), &data); err != nil {
|
||||
g.Log().Warningf(ctx, "[会话] 解析Redis数据失败 err=%v", err)
|
||||
continue
|
||||
}
|
||||
sessions = append(sessions, data)
|
||||
}
|
||||
|
||||
return sessions
|
||||
}
|
||||
|
||||
// reverseSlice 反转切片
|
||||
func reverseSlice(s []map[string]any) {
|
||||
for i, j := 0, len(s)-1; i < j; i, j = i+1, j-1 {
|
||||
s[i], s[j] = s[j], s[i]
|
||||
}
|
||||
}
|
||||
|
||||
// GetSessionHistoryForInference 获取历史会话,返回扁平消息数组(给推理用)
|
||||
func GetSessionHistoryForInference(ctx context.Context, sessionId string) ([]map[string]any, error) {
|
||||
historyData, err := getFromRedis(ctx, sessionId)
|
||||
// GetSessionHistoryForInference 获取扁平消息数组(给推理用)
|
||||
func GetSessionHistoryForInference(ctx context.Context, tenantID uint64, sessionID string) ([]map[string]any, error) {
|
||||
rounds, err := GetFromRedis(ctx, tenantID, sessionID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("获取历史会话失败: %w", err)
|
||||
}
|
||||
|
||||
if len(historyData) == 0 {
|
||||
if len(rounds) == 0 {
|
||||
return []map[string]any{}, nil
|
||||
}
|
||||
|
||||
return flattenHistoryMessages(historyData), nil
|
||||
return flattenRounds(rounds), nil
|
||||
}
|
||||
|
||||
// flattenHistoryMessages 扁平化历史消息
|
||||
func flattenHistoryMessages(historyData []map[string]any) []map[string]any {
|
||||
// ============================================
|
||||
// 解析
|
||||
// ============================================
|
||||
|
||||
func parseRedisRounds(ctx context.Context, members []string) []map[string]any {
|
||||
rounds := make([]map[string]any, 0, len(members))
|
||||
for _, member := range members {
|
||||
var data map[string]any
|
||||
if err := json.Unmarshal([]byte(member), &data); err != nil {
|
||||
g.Log().Warningf(ctx, "[会话Redis] 解析数据失败 err=%v", err)
|
||||
continue
|
||||
}
|
||||
rounds = append(rounds, data)
|
||||
}
|
||||
return rounds
|
||||
}
|
||||
|
||||
func flattenRounds(rounds []map[string]any) []map[string]any {
|
||||
var messages []map[string]any
|
||||
|
||||
for _, round := range historyData {
|
||||
appendMessagesFromField(round, "requestContent", &messages)
|
||||
appendMessagesFromField(round, "responseContent", &messages)
|
||||
for i := len(rounds) - 1; i >= 0; i-- {
|
||||
if user, ok := rounds[i]["user"].(map[string]any); ok && len(user) > 0 {
|
||||
messages = append(messages, user)
|
||||
}
|
||||
if assistant, ok := rounds[i]["assistant"].(map[string]any); ok && len(assistant) > 0 {
|
||||
messages = append(messages, assistant)
|
||||
}
|
||||
}
|
||||
|
||||
return messages
|
||||
}
|
||||
|
||||
// appendMessagesFromField 从指定字段追加消息
|
||||
func appendMessagesFromField(data map[string]any, field string, messages *[]map[string]any) {
|
||||
msgs, ok := data[field].([]interface{})
|
||||
func appendFieldToMessages(data map[string]any, field string, messages *[]map[string]any) {
|
||||
msgs, ok := data[field].([]any)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
for _, m := range msgs {
|
||||
if msg, ok := m.(map[string]interface{}); ok {
|
||||
if msg, ok := m.(map[string]any); ok {
|
||||
*messages = append(*messages, msg)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"fmt"
|
||||
|
||||
"gitea.com/red-future/common/beans"
|
||||
"gitea.com/red-future/common/utils"
|
||||
"github.com/gogf/gf/v2/frame/g"
|
||||
"github.com/gogf/gf/v2/util/gconv"
|
||||
|
||||
@@ -17,6 +18,7 @@ import (
|
||||
// Callback 会话回调
|
||||
func Callback(ctx context.Context, req *dto.SessionCallbackReq) (*dto.SessionCallbackRes, error) {
|
||||
req.Messages["role"] = "assistant"
|
||||
// 1) 更新 DB
|
||||
_, err := dao.ComposeSession.Update(ctx, &entity.ComposeSession{
|
||||
SQLBaseDO: beans.SQLBaseDO{Id: req.EpicycleId},
|
||||
ResponseContent: req.Messages,
|
||||
@@ -25,121 +27,172 @@ func Callback(ctx context.Context, req *dto.SessionCallbackReq) (*dto.SessionCal
|
||||
g.Log().Errorf(ctx, "[会话回调] 更新数据库失败 epicycleId=%d err=%v", req.EpicycleId, err)
|
||||
return nil, fmt.Errorf("更新数据库失败: %w", err)
|
||||
}
|
||||
|
||||
// 2) 查询完整记录
|
||||
session, err := dao.ComposeSession.Get(ctx, &entity.ComposeSession{
|
||||
SQLBaseDO: beans.SQLBaseDO{Id: req.EpicycleId},
|
||||
})
|
||||
if session == nil {
|
||||
if err != nil || session == nil {
|
||||
return nil, fmt.Errorf("会话不存在: epicycleId=%d", req.EpicycleId)
|
||||
}
|
||||
if err != nil {
|
||||
g.Log().Errorf(ctx, "[会话回调] 获取会话数据失败 epicycleId=%d err=%v", req.EpicycleId, err)
|
||||
return nil, fmt.Errorf("获取会话数据失败: %w", err)
|
||||
}
|
||||
if err = saveToRedis(ctx, session); err != nil {
|
||||
|
||||
// 3) 写入 Redis
|
||||
if err = SaveToRedis(ctx, session.TenantId, session.SessionId, &dto.HistoryRound{
|
||||
Id: session.Id,
|
||||
User: session.RequestContent,
|
||||
Assistant: req.Messages,
|
||||
CreatedAt: gconv.String(session.CreatedAt),
|
||||
}); err != nil {
|
||||
return nil, fmt.Errorf("redis存储失败: %w", err)
|
||||
}
|
||||
g.Log().Infof(ctx, "[会话回调] 存储成功 sessionId=%s id=%d requestLen=%d responseLen=%d",
|
||||
session.SessionId, session.Id, len(session.RequestContent), len(session.ResponseContent))
|
||||
return &dto.SessionCallbackRes{
|
||||
Status: true,
|
||||
SessionId: session.SessionId,
|
||||
}, nil
|
||||
|
||||
// 4) 返回
|
||||
g.Log().Infof(ctx, "[会话回调] 存储成功 sessionId=%s id=%d", session.SessionId, session.Id)
|
||||
return &dto.SessionCallbackRes{Status: true, SessionId: session.SessionId}, nil
|
||||
}
|
||||
|
||||
// GetHistoryMessages 获取历史信息
|
||||
func GetHistoryMessages(ctx context.Context, sessionId string, nodeId string) ([]map[string]any, error) {
|
||||
// 1) 获取最大轮次
|
||||
maxRounds := g.Cfg().MustGet(ctx, "session.maxRounds", 10).Int()
|
||||
|
||||
// 2) 从 Redis 获取历史记录
|
||||
redisHistory, err := GetSessionHistoryForInference(ctx, sessionId)
|
||||
if err == nil && len(redisHistory) > 0 {
|
||||
return redisHistory, nil
|
||||
// GetHistoryMessages 获取历史消息
|
||||
func GetHistoryMessages(ctx context.Context, req *dto.GetHistoryMessagesReq) (*dto.GetHistoryMessagesRes, error) {
|
||||
user, err := utils.GetUserInfo(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 3) Redis 没有,从数据库查最新 maxRounds 条
|
||||
// 1) Redis
|
||||
redisRounds, err := GetFromRedis(ctx, user.TenantId, req.SessionId)
|
||||
if err == nil && len(redisRounds) > 0 {
|
||||
g.Log().Debugf(ctx, "[历史消息] Redis命中 sessionId=%s count=%d", req.SessionId, len(redisRounds))
|
||||
return &dto.GetHistoryMessagesRes{Messages: parseHistoryRounds(redisRounds)}, nil
|
||||
}
|
||||
|
||||
// 2) DB
|
||||
maxRounds := util.GetMaxRounds(ctx)
|
||||
sessions, _, err := dao.ComposeSession.List(ctx, &entity.ComposeSession{
|
||||
SessionId: sessionId,
|
||||
NodeId: nodeId,
|
||||
SQLBaseDO: beans.SQLBaseDO{Creator: user.UserName},
|
||||
SessionId: req.SessionId,
|
||||
}, 1, maxRounds)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("DB获取历史失败: %w", err)
|
||||
}
|
||||
// 4) 为空返回报错
|
||||
if len(sessions) == 0 {
|
||||
return nil, fmt.Errorf("会话不存在: sessionId=%s nodeId=%s", sessionId, nodeId)
|
||||
}
|
||||
// 5) 提取为统一格式
|
||||
messages := extractMessagesFromSessions(sessions)
|
||||
|
||||
// 6) 缓存 Redis 半小时
|
||||
//_ = CacheSessionHistoryForInference(ctx, sessionId, messages, 30*time.Minute)
|
||||
|
||||
return messages, nil
|
||||
return &dto.GetHistoryMessagesRes{Messages: []dto.HistoryRound{}}, nil
|
||||
}
|
||||
|
||||
// getHistoryFromDatabase 从数据库获取历史记录
|
||||
func getHistoryFromDatabase(ctx context.Context, sessionId string, maxRounds int) ([]map[string]any, error) {
|
||||
sessions, _, err := dao.ComposeSession.List(ctx, &entity.ComposeSession{
|
||||
SessionId: sessionId,
|
||||
}, 1, maxRounds)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("DB获取历史失败: %w", err)
|
||||
// 3) 转换 + 异步回种
|
||||
rounds := sessionsToHistoryRounds(sessions)
|
||||
go asyncCacheToRedis(context.WithoutCancel(ctx), user.TenantId, req.SessionId, sessions)
|
||||
|
||||
return &dto.GetHistoryMessagesRes{Messages: rounds}, nil
|
||||
}
|
||||
|
||||
messages := extractMessagesFromSessions(sessions)
|
||||
|
||||
cacheSessionsToRedis(ctx, sessions)
|
||||
|
||||
return messages, nil
|
||||
// parseHistoryRounds Redis 数据转为 HistoryRound
|
||||
func parseHistoryRounds(redisRounds []map[string]any) []dto.HistoryRound {
|
||||
rounds := make([]dto.HistoryRound, 0, len(redisRounds))
|
||||
for _, r := range redisRounds {
|
||||
round := dto.HistoryRound{
|
||||
Id: gconv.Int64(r["id"]),
|
||||
CreatedAt: gconv.String(r["createdAt"]),
|
||||
}
|
||||
if user, ok := r["user"].(map[string]any); ok {
|
||||
round.User = user
|
||||
}
|
||||
if assistant, ok := r["assistant"].(map[string]any); ok {
|
||||
round.Assistant = assistant
|
||||
}
|
||||
rounds = append(rounds, round)
|
||||
}
|
||||
return rounds
|
||||
}
|
||||
|
||||
// extractMessagesFromSessions 从会话列表中提取消息
|
||||
// sessionsToHistoryRounds DB 数据转为 HistoryRound
|
||||
func sessionsToHistoryRounds(sessions []*entity.ComposeSession) []dto.HistoryRound {
|
||||
rounds := make([]dto.HistoryRound, 0, len(sessions))
|
||||
for _, s := range sessions {
|
||||
reqMsgs := util.ConvertToMessages(s.RequestContent)
|
||||
respMsgs := util.ConvertToMessages(s.ResponseContent)
|
||||
|
||||
round := dto.HistoryRound{
|
||||
Id: s.Id,
|
||||
CreatedAt: gconv.String(s.CreatedAt),
|
||||
}
|
||||
if len(reqMsgs) > 0 {
|
||||
round.User = reqMsgs[0]
|
||||
}
|
||||
if len(respMsgs) > 0 {
|
||||
if respMsgs[0]["role"] == nil {
|
||||
respMsgs[0]["role"] = "assistant"
|
||||
}
|
||||
round.Assistant = respMsgs[0]
|
||||
}
|
||||
rounds = append(rounds, round)
|
||||
}
|
||||
return rounds
|
||||
}
|
||||
|
||||
// DeleteSession 删除会话
|
||||
func DeleteSession(ctx context.Context, req *dto.DeleteSessionReq) (*dto.DeleteSessionRes, error) {
|
||||
hasMsgID := len(req.MsgIds) > 0 && req.MsgIds[0] > 0
|
||||
|
||||
deleteReq := &entity.ComposeSession{
|
||||
SessionId: req.SessionId,
|
||||
NodeId: req.NodeId,
|
||||
}
|
||||
if hasMsgID {
|
||||
deleteReq.Id = req.MsgIds[0]
|
||||
}
|
||||
|
||||
if _, err := dao.ComposeSession.Delete(ctx, deleteReq); err != nil {
|
||||
return nil, fmt.Errorf("DB删除失败: %w", err)
|
||||
}
|
||||
|
||||
if hasMsgID {
|
||||
if err := DeleteSingleMessage(ctx, req.TenantId, req.SessionId, req.MsgIds[0]); err != nil {
|
||||
g.Log().Warningf(ctx, "[删除会话] Redis删除单条失败 msgID=%d err=%v", req.MsgIds[0], err)
|
||||
}
|
||||
} else {
|
||||
if err := DeleteSessionHistory(ctx, req.TenantId, req.SessionId); err != nil {
|
||||
g.Log().Warningf(ctx, "[删除会话] Redis删除失败 sessionId=%s err=%v", req.SessionId, err)
|
||||
}
|
||||
}
|
||||
|
||||
return &dto.DeleteSessionRes{Ok: true}, nil
|
||||
}
|
||||
|
||||
// ============================================
|
||||
// 内部方法
|
||||
// ============================================
|
||||
|
||||
func extractMessagesFromSessions(sessions []*entity.ComposeSession) []map[string]any {
|
||||
var messages []map[string]any
|
||||
for _, session := range sessions {
|
||||
appendRequestMessages(session.RequestContent, &messages)
|
||||
appendResponseMessages(session.ResponseContent, &messages)
|
||||
for i := len(sessions) - 1; i >= 0; i-- {
|
||||
appendRoleMessages(sessions[i].RequestContent, "user", &messages)
|
||||
appendRoleMessages(sessions[i].ResponseContent, "assistant", &messages)
|
||||
}
|
||||
return messages
|
||||
}
|
||||
|
||||
// appendRequestMessages 追加请求消息
|
||||
func appendRequestMessages(requestContent any, messages *[]map[string]any) {
|
||||
reqMsgs := util.ConvertToMessages(requestContent)
|
||||
for _, m := range reqMsgs {
|
||||
role := gconv.String(m["role"])
|
||||
if role == "user" || role == "assistant" {
|
||||
*messages = append(*messages, m)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// appendResponseMessages 追加响应消息
|
||||
func appendResponseMessages(responseContent any, messages *[]map[string]any) {
|
||||
respMsgs := util.ConvertToMessages(responseContent)
|
||||
for _, m := range respMsgs {
|
||||
if m["role"] == nil {
|
||||
m["role"] = "assistant"
|
||||
func appendRoleMessages(content any, defaultRole string, messages *[]map[string]any) {
|
||||
msgs := util.ConvertToMessages(content)
|
||||
for _, m := range msgs {
|
||||
if m["role"] == nil || gconv.String(m["role"]) == "" {
|
||||
m["role"] = defaultRole
|
||||
}
|
||||
*messages = append(*messages, m)
|
||||
}
|
||||
}
|
||||
|
||||
// cacheSessionsToRedis 将会话缓存到Redis
|
||||
func cacheSessionsToRedis(ctx context.Context, sessions []*entity.ComposeSession) {
|
||||
for _, session := range sessions {
|
||||
reqMsgs := util.ConvertToMessages(session.RequestContent)
|
||||
respMsgs := util.ConvertToMessages(session.ResponseContent)
|
||||
|
||||
for i := range respMsgs {
|
||||
if respMsgs[i]["role"] == nil {
|
||||
respMsgs[i]["role"] = "assistant"
|
||||
}
|
||||
}
|
||||
|
||||
// asyncCacheToRedis 异步缓存会话数据到 Redis
|
||||
func asyncCacheToRedis(ctx context.Context, tenantID uint64, sessionID string, sessions []*entity.ComposeSession) {
|
||||
for _, s := range sessions {
|
||||
reqMsgs := util.ConvertToMessages(s.RequestContent)
|
||||
respMsgs := util.ConvertToMessages(s.ResponseContent)
|
||||
if len(reqMsgs) > 0 || len(respMsgs) > 0 {
|
||||
_ = saveToRedis(ctx, session)
|
||||
_ = SaveToRedis(ctx, tenantID, sessionID, &dto.HistoryRound{
|
||||
Id: s.Id,
|
||||
User: s.RequestContent,
|
||||
Assistant: s.ResponseContent,
|
||||
CreatedAt: gconv.String(s.CreatedAt),
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user