diff --git a/controller/prompt_session_controller.go b/controller/prompt_session_controller.go index 2dfc535..99bbb4c 100644 --- a/controller/prompt_session_controller.go +++ b/controller/prompt_session_controller.go @@ -15,17 +15,22 @@ type session struct{} var Session = new(session) -// SessionCallback 接收会话回调通知 +// SessionCallback 会话回调 func (c *session) SessionCallback(ctx context.Context, req *dto.SessionCallbackReq) (res *dto.SessionCallbackRes, err error) { return sessionService.Callback(ctx, req) } -// GetHistoryMessages 获取历史消息 -func (c *session) GetHistoryMessages(ctx context.Context, req *dto.GetHistoryMessagesReq) (res *dto.GetHistoryMessagesRes, err error) { - return sessionService.GetHistoryMessages(ctx, req) +// GetHistoryList 获取历史列表(前端列表) +func (c *session) GetHistoryList(ctx context.Context, req *dto.GetHistoryListReq) (res *dto.GetHistoryListRes, err error) { + return sessionService.GetHistoryList(ctx, req) } -// DeleteSession 删除会话 +// DeleteMessages 批量删除消息 +func (c *session) DeleteMessages(ctx context.Context, req *dto.DeleteMessagesReq) (res *dto.DeleteMessagesRes, err error) { + return sessionService.DeleteMessages(ctx, req) +} + +// DeleteSession 删除整个会话 func (c *session) DeleteSession(ctx context.Context, req *dto.DeleteSessionReq) (res *dto.DeleteSessionRes, err error) { return sessionService.DeleteSession(ctx, req) } diff --git a/model/dto/prompt_compose_dto.go b/model/dto/prompt_compose_dto.go index cfc6580..b1fb382 100644 --- a/model/dto/prompt_compose_dto.go +++ b/model/dto/prompt_compose_dto.go @@ -66,5 +66,5 @@ type GetPromptTextReq struct { } type GetPromptTextRes struct { - Messages any `json:"messages" dc:"最终消息数组"` + Messages any `json:"messages" dc:"历史消息"` } diff --git a/model/dto/prompt_session_dto.go b/model/dto/prompt_session_dto.go index 694bbe9..5b1ea87 100644 --- a/model/dto/prompt_session_dto.go +++ b/model/dto/prompt_session_dto.go @@ -2,11 +2,22 @@ package dto import "github.com/gogf/gf/v2/frame/g" +// HistoryRound 一轮对话 +type HistoryRound struct { + Id int64 `json:"id" dc:"记录ID"` + SessionId string `json:"sessionId" dc:"会话ID"` + NodeId string `json:"nodeId" dc:"节点ID"` + User map[string]any `json:"user" dc:"用户消息"` + Assistant map[string]any `json:"assistant" dc:"助手回复"` + CreatedAt string `json:"createdAt" dc:"创建时间"` + UpdatedAt string `json:"updatedAt" dc:"更新时间"` +} + // SessionCallbackReq 会话回调请求 type SessionCallbackReq struct { g.Meta `path:"/callback" method:"post" tags:"会话管理" summary:"会话回调"` - Messages map[string]any `json:"messages" dc:"消息数组"` - EpicycleId int64 `json:"epicycleId" dc:"轮次ID"` + Messages map[string]any `json:"messages" v:"required" dc:"消息数组"` + EpicycleId int64 `json:"epicycleId" v:"required" dc:"轮次ID"` } // SessionCallbackRes 会话回调响应 @@ -15,36 +26,55 @@ type SessionCallbackRes struct { SessionId string `json:"sessionId" dc:"会话ID"` } -// GetHistoryMessagesReq 获取历史消息请求 +// GetHistoryListReq 获取历史列表请求(前端) +type GetHistoryListReq struct { + g.Meta `path:"/historyList" method:"get" tags:"会话管理" summary:"获取历史列表"` + Page int `json:"page" d:"1" dc:"页码"` + Size int `json:"size" d:"10" dc:"每页条数"` +} + +// GetHistoryListRes 获取历史列表响应 +type GetHistoryListRes struct { + List []HistoryRound `json:"list" dc:"历史列表"` + Total int `json:"total" dc:"总数"` +} + +// GetHistoryMessagesReq 获取历史消息请求(提示词拼接) type GetHistoryMessagesReq struct { - g.Meta `path:"/history" method:"get" tags:"会话管理" summary:"获取历史消息"` + g.Meta `path:"/historyMessages" method:"get" tags:"会话管理" summary:"获取历史消息"` SessionId string `json:"sessionId" v:"required" dc:"会话ID"` NodeId string `json:"nodeId" dc:"节点ID"` } // GetHistoryMessagesRes 获取历史消息响应 type GetHistoryMessagesRes struct { - Messages []HistoryRound `json:"messages" dc:"历史消息列表"` + Messages []FlatMessage `json:"messages"` } -// HistoryRound 一轮对话 -type HistoryRound struct { - Id int64 `json:"id" dc:"记录ID"` - User map[string]any `json:"user" dc:"用户消息"` - Assistant map[string]any `json:"assistant" dc:"助手回复"` - CreatedAt string `json:"createdAt" dc:"创建时间"` +type FlatMessage struct { + Role string `json:"role"` + Content string `json:"content"` } -// DeleteSessionReq 删除会话请求 -type DeleteSessionReq struct { - g.Meta `path:"/delete" method:"post" tags:"会话管理" summary:"删除会话"` - TenantId uint64 `json:"tenantId" dc:"租户ID"` +// DeleteMessagesReq 批量删除消息请求 +type DeleteMessagesReq struct { + g.Meta `path:"/deleteMessages" method:"post" tags:"会话管理" summary:"批量删除消息"` SessionId string `json:"sessionId" v:"required" dc:"会话ID"` - NodeId string `json:"nodeId" dc:"节点ID"` - MsgIds []int64 `json:"msgIds" dc:"消息ID列表,传则删单条,不传删整个会话"` + MsgIds []int64 `json:"msgIds" v:"required" dc:"消息ID列表"` } -// DeleteSessionRes 删除会话响应 +// DeleteMessagesRes 批量删除消息响应 +type DeleteMessagesRes struct { + Ok bool `json:"ok" dc:"是否成功"` +} + +// DeleteSessionReq 删除整个会话请求 +type DeleteSessionReq struct { + g.Meta `path:"/deleteSession" method:"post" tags:"会话管理" summary:"删除整个会话"` + SessionId string `json:"sessionId" v:"required" dc:"会话ID"` +} + +// DeleteSessionRes 删除整个会话响应 type DeleteSessionRes struct { Ok bool `json:"ok" dc:"是否成功"` } diff --git a/model/entity/prompts_compose_task.go b/model/entity/prompts_compose_task.go index 55318cc..90ccce7 100644 --- a/model/entity/prompts_compose_task.go +++ b/model/entity/prompts_compose_task.go @@ -11,8 +11,7 @@ type ComposeTask struct { CallbackUrl string `orm:"callback_url" json:"callbackUrl"` GatewayState int `orm:"gateway_state" json:"gatewayState"` RequestPayload map[string]any `orm:"request_payload" json:"requestPayload"` - ResultText map[string]any `orm:"result_text" json:"resultText"` - Messages map[string]any `orm:"messages" json:"messages"` + ResultJson map[string]any `orm:"result_json" json:"resultJson"` Status string `orm:"status" json:"status"` ErrorMessage string `orm:"error_message" json:"errorMessage"` OssFile string `orm:"oss_file" json:"ossFile"` @@ -28,8 +27,7 @@ type composeTaskCol struct { CallbackUrl string GatewayState string RequestPayload string - ResultText string - Messages string + ResultJson string Status string ErrorMessage string OssFile string @@ -45,8 +43,7 @@ var ComposeTaskCol = composeTaskCol{ CallbackUrl: "callback_url", GatewayState: "gateway_state", RequestPayload: "request_payload", - ResultText: "result_text", - Messages: "messages", + ResultJson: "result_json", Status: "status", ErrorMessage: "error_message", OssFile: "oss_file", diff --git a/service/gateway/gateway_http_service.go b/service/gateway/gateway_http_service.go index e466ba6..0a78b5a 100644 --- a/service/gateway/gateway_http_service.go +++ b/service/gateway/gateway_http_service.go @@ -164,7 +164,7 @@ func SendCallback(ctx context.Context, composeTask *entity.ComposeTask, epicycle req := SendCallbackReq{ TaskId: composeTask.TaskId, Status: composeTask.Status, - Messages: composeTask.Messages, + Messages: composeTask.ResultJson, ErrorMsg: composeTask.ErrorMessage, EpicycleId: epicycleId, } diff --git a/service/prompt/prompt_compose_service.go b/service/prompt/prompt_compose_service.go index a064eea..791d239 100644 --- a/service/prompt/prompt_compose_service.go +++ b/service/prompt/prompt_compose_service.go @@ -154,7 +154,7 @@ func handleCallbackFailed(ctx context.Context, req *dto.CallbackReq, composeTask GatewayState: req.State, OssFile: req.OssFile, FileType: req.FileType, - ResultText: req.Messages, + ResultJson: req.Messages, }) if composeTask.CallbackUrl != "" { composeTask.Status = public.ComposeStatusFailed @@ -181,11 +181,10 @@ func handleCallbackSuccess(ctx context.Context, req *dto.CallbackReq, composeTas _, err = dao.ComposeTask.Update(ctx, &entity.ComposeTask{ TaskId: req.TaskId, Status: public.ComposeStatusSuccess, - Messages: messages, GatewayState: req.State, OssFile: req.OssFile, FileType: req.FileType, - ResultText: req.Messages, + ResultJson: messages, }) if err != nil { return err @@ -214,7 +213,7 @@ func handleCallbackSuccess(ctx context.Context, req *dto.CallbackReq, composeTas // 6) 回调业务方 if composeTask.CallbackUrl != "" { composeTask.Status = public.ComposeStatusSuccess - composeTask.Messages = messages + composeTask.ResultJson = messages _ = gateway.SendCallback(ctx, composeTask, epicycleId) } return nil @@ -232,7 +231,7 @@ func GetComposeTask(ctx context.Context, taskID string) (*dto.GetComposeTaskRes, return nil, fmt.Errorf("未找到任务(taskId=%s)", taskID) } - messages := parseMessagesForResponse(record.Messages) + messages := parseMessagesForResponse(record.ResultJson) return &dto.GetComposeTaskRes{ TaskId: record.TaskId, diff --git a/service/session/prompt_session_redis_service.go b/service/session/prompt_session_redis_service.go index 28a5310..886c98b 100644 --- a/service/session/prompt_session_redis_service.go +++ b/service/session/prompt_session_redis_service.go @@ -9,6 +9,7 @@ import ( "time" "github.com/gogf/gf/v2/frame/g" + "github.com/gogf/gf/v2/util/gconv" ) const ( @@ -51,33 +52,34 @@ func SaveToRedis(ctx context.Context, tenantID uint64, sessionID string, round * return nil } -// DeleteSingleMessage 删除 Redis 中单条消息(按消息ID) -func DeleteSingleMessage(ctx context.Context, tenantID uint64, sessionID string, msgID int64) error { +// DeleteRedisMessages 批量删除 Redis 中多条消息(按消息ID列表) +func DeleteRedisMessages(ctx context.Context, tenantID uint64, sessionID string, msgIDs []int64) error { key := formatRedisKey(tenantID, sessionID) - cursor := "0" - for { - result, err := g.Redis().Do(ctx, "ZSCAN", key, cursor, "MATCH", fmt.Sprintf("*\"id\":%d*", msgID), "COUNT", 10) - if err != nil { - return fmt.Errorf("ZSCAN失败: %w", err) - } - - parts := result.Strings() - if len(parts) < 2 { - break - } - - cursor = parts[0] - members := parts[1:] - - for _, member := range members { - if _, err := g.Redis().Do(ctx, "ZREM", key, member); err != nil { - g.Log().Warningf(ctx, "[会话Redis] ZREM单条失败 key=%s err=%v", key, err) + for _, msgID := range msgIDs { + cursor := "0" + for { + result, err := g.Redis().Do(ctx, "ZSCAN", key, cursor, "MATCH", fmt.Sprintf("*\"id\":%d*", msgID), "COUNT", 10) + if err != nil { + g.Log().Warningf(ctx, "[会话Redis] ZSCAN失败 msgID=%d err=%v", msgID, err) + break } - } - if cursor == "0" { - break + parts := result.Strings() + if len(parts) < 2 { + break + } + + cursor = parts[0] + for _, member := range parts[1:] { + if _, err := g.Redis().Do(ctx, "ZREM", key, member); err != nil { + g.Log().Warningf(ctx, "[会话Redis] ZREM失败 err=%v", err) + } + } + + if cursor == "0" { + break + } } } @@ -95,8 +97,8 @@ func DeleteSessionHistory(ctx context.Context, tenantID uint64, sessionID string // 读操作 // ============================================ -// GetFromRedis 从 Redis ZSET 获取会话历史 -func GetFromRedis(ctx context.Context, tenantID uint64, sessionID string) ([]map[string]any, error) { +// GetFromRedis 从 Redis ZSET 获取会话历史,返回 HistoryRound 切片 +func GetFromRedis(ctx context.Context, tenantID uint64, sessionID string) ([]dto.HistoryRound, error) { key := formatRedisKey(tenantID, sessionID) maxRounds := util.GetMaxRounds(ctx) @@ -106,64 +108,46 @@ func GetFromRedis(ctx context.Context, tenantID uint64, sessionID string) ([]map } if result == nil || result.IsNil() { - return []map[string]any{}, nil + return []dto.HistoryRound{}, nil } - return parseRedisRounds(ctx, result.Strings()), nil -} - -// GetSessionHistoryForInference 获取扁平消息数组(给推理用) -func GetSessionHistoryForInference(ctx context.Context, tenantID uint64, sessionID string) ([]map[string]any, error) { - rounds, err := GetFromRedis(ctx, tenantID, sessionID) - if err != nil { - return nil, fmt.Errorf("获取历史会话失败: %w", err) - } - - if len(rounds) == 0 { - return []map[string]any{}, nil - } - - return flattenRounds(rounds), nil + return parseRounds(result.Strings()), nil } // ============================================ // 解析 // ============================================ -func parseRedisRounds(ctx context.Context, members []string) []map[string]any { - rounds := make([]map[string]any, 0, len(members)) +// parseRounds 解析 Redis ZSET members 为 HistoryRound 切片 +func parseRounds(members []string) []dto.HistoryRound { + rounds := make([]dto.HistoryRound, 0, len(members)) for _, member := range members { - var data map[string]any - if err := json.Unmarshal([]byte(member), &data); err != nil { - g.Log().Warningf(ctx, "[会话Redis] 解析数据失败 err=%v", err) + var round dto.HistoryRound + if err := json.Unmarshal([]byte(member), &round); err != nil { continue } - rounds = append(rounds, data) + if round.User != nil || round.Assistant != nil { + rounds = append(rounds, round) + } } return rounds } -func flattenRounds(rounds []map[string]any) []map[string]any { - var messages []map[string]any +func flattenRounds(rounds []dto.HistoryRound) []dto.FlatMessage { + var messages []dto.FlatMessage for i := len(rounds) - 1; i >= 0; i-- { - if user, ok := rounds[i]["user"].(map[string]any); ok && len(user) > 0 { - messages = append(messages, user) + if rounds[i].User != nil && gconv.String(rounds[i].User["content"]) != "" { + messages = append(messages, dto.FlatMessage{ + Role: gconv.String(rounds[i].User["role"]), + Content: gconv.String(rounds[i].User["content"]), + }) } - if assistant, ok := rounds[i]["assistant"].(map[string]any); ok && len(assistant) > 0 { - messages = append(messages, assistant) + if rounds[i].Assistant != nil && gconv.String(rounds[i].Assistant["content"]) != "" { + messages = append(messages, dto.FlatMessage{ + Role: gconv.String(rounds[i].Assistant["role"]), + Content: gconv.String(rounds[i].Assistant["content"]), + }) } } return messages } - -func appendFieldToMessages(data map[string]any, field string, messages *[]map[string]any) { - msgs, ok := data[field].([]any) - if !ok { - return - } - for _, m := range msgs { - if msg, ok := m.(map[string]any); ok { - *messages = append(*messages, msg) - } - } -} diff --git a/service/session/prompt_session_service.go b/service/session/prompt_session_service.go index 7414936..f60bf3b 100644 --- a/service/session/prompt_session_service.go +++ b/service/session/prompt_session_service.go @@ -15,9 +15,14 @@ import ( "prompts-core/model/entity" ) +// ============================================ +// 回调存储 +// ============================================ + // Callback 会话回调 func Callback(ctx context.Context, req *dto.SessionCallbackReq) (*dto.SessionCallbackRes, error) { req.Messages["role"] = "assistant" + // 1) 更新 DB _, err := dao.ComposeSession.Update(ctx, &entity.ComposeSession{ SQLBaseDO: beans.SQLBaseDO{Id: req.EpicycleId}, @@ -36,22 +41,42 @@ func Callback(ctx context.Context, req *dto.SessionCallbackReq) (*dto.SessionCal return nil, fmt.Errorf("会话不存在: epicycleId=%d", req.EpicycleId) } - // 3) 写入 Redis - if err = SaveToRedis(ctx, session.TenantId, session.SessionId, &dto.HistoryRound{ - Id: session.Id, - User: session.RequestContent, - Assistant: req.Messages, - CreatedAt: gconv.String(session.CreatedAt), - }); err != nil { + // 3) entity → HistoryRound → 写入 Redis + round := entityToHistoryRound(session) + round.Assistant = req.Messages + if err = SaveToRedis(ctx, session.TenantId, session.SessionId, round); err != nil { return nil, fmt.Errorf("redis存储失败: %w", err) } - // 4) 返回 g.Log().Infof(ctx, "[会话回调] 存储成功 sessionId=%s id=%d", session.SessionId, session.Id) return &dto.SessionCallbackRes{Status: true, SessionId: session.SessionId}, nil } -// GetHistoryMessages 获取历史消息 +// ============================================ +// 场景1:前端历史列表(按 creator) +// ============================================ + +// GetHistoryList 获取历史列表 +func GetHistoryList(ctx context.Context, req *dto.GetHistoryListReq) (*dto.GetHistoryListRes, error) { + user, err := utils.GetUserInfo(ctx) + if err != nil { + return nil, err + } + sessions, total, err := dao.ComposeSession.List(ctx, &entity.ComposeSession{ + SQLBaseDO: beans.SQLBaseDO{Creator: user.UserName}, + }, req.Page, req.Size) + if err != nil { + return nil, fmt.Errorf("DB获取历史列表失败: %w", err) + } + rounds := sessionsToHistoryRounds(sessions) + return &dto.GetHistoryListRes{List: rounds, Total: total}, nil +} + +// ============================================ +// 场景2:提示词拼接(按 sessionId + nodeId) +// ============================================ + +// GetHistoryMessages 获取历史消息(Redis → DB → 异步回种) func GetHistoryMessages(ctx context.Context, req *dto.GetHistoryMessagesReq) (*dto.GetHistoryMessagesRes, error) { user, err := utils.GetUserInfo(ctx) if err != nil { @@ -59,10 +84,9 @@ func GetHistoryMessages(ctx context.Context, req *dto.GetHistoryMessagesReq) (*d } // 1) Redis - redisRounds, err := GetFromRedis(ctx, user.TenantId, req.SessionId) - if err == nil && len(redisRounds) > 0 { - g.Log().Debugf(ctx, "[历史消息] Redis命中 sessionId=%s count=%d", req.SessionId, len(redisRounds)) - return &dto.GetHistoryMessagesRes{Messages: parseHistoryRounds(redisRounds)}, nil + if rounds, err := GetFromRedis(ctx, user.TenantId, req.SessionId); err == nil && len(rounds) > 0 { + g.Log().Debugf(ctx, "[历史消息] Redis命中 sessionId=%s count=%d", req.SessionId, len(rounds)) + return &dto.GetHistoryMessagesRes{Messages: flattenRounds(rounds)}, nil } // 2) DB @@ -70,129 +94,108 @@ func GetHistoryMessages(ctx context.Context, req *dto.GetHistoryMessagesReq) (*d sessions, _, err := dao.ComposeSession.List(ctx, &entity.ComposeSession{ SQLBaseDO: beans.SQLBaseDO{Creator: user.UserName}, SessionId: req.SessionId, + NodeId: req.NodeId, }, 1, maxRounds) if err != nil { return nil, fmt.Errorf("DB获取历史失败: %w", err) } if len(sessions) == 0 { - return &dto.GetHistoryMessagesRes{Messages: []dto.HistoryRound{}}, nil + return &dto.GetHistoryMessagesRes{Messages: []dto.FlatMessage{}}, nil } // 3) 转换 + 异步回种 rounds := sessionsToHistoryRounds(sessions) - go asyncCacheToRedis(context.WithoutCancel(ctx), user.TenantId, req.SessionId, sessions) + go asyncCacheToRedis(context.WithoutCancel(ctx), user.TenantId, req.SessionId, rounds) - return &dto.GetHistoryMessagesRes{Messages: rounds}, nil + return &dto.GetHistoryMessagesRes{Messages: flattenRounds(rounds)}, nil } -// parseHistoryRounds Redis 数据转为 HistoryRound -func parseHistoryRounds(redisRounds []map[string]any) []dto.HistoryRound { - rounds := make([]dto.HistoryRound, 0, len(redisRounds)) - for _, r := range redisRounds { - round := dto.HistoryRound{ - Id: gconv.Int64(r["id"]), - CreatedAt: gconv.String(r["createdAt"]), - } - if user, ok := r["user"].(map[string]any); ok { - round.User = user - } - if assistant, ok := r["assistant"].(map[string]any); ok { - round.Assistant = assistant - } - rounds = append(rounds, round) +// ============================================ +// 删除 +// ============================================ + +// DeleteMessages 批量删除消息 +func DeleteMessages(ctx context.Context, req *dto.DeleteMessagesReq) (*dto.DeleteMessagesRes, error) { + if len(req.MsgIds) == 0 { + return &dto.DeleteMessagesRes{Ok: false}, fmt.Errorf("msgIds不能为空") } - return rounds -} -// sessionsToHistoryRounds DB 数据转为 HistoryRound -func sessionsToHistoryRounds(sessions []*entity.ComposeSession) []dto.HistoryRound { - rounds := make([]dto.HistoryRound, 0, len(sessions)) - for _, s := range sessions { - reqMsgs := util.ConvertToMessages(s.RequestContent) - respMsgs := util.ConvertToMessages(s.ResponseContent) - - round := dto.HistoryRound{ - Id: s.Id, - CreatedAt: gconv.String(s.CreatedAt), - } - if len(reqMsgs) > 0 { - round.User = reqMsgs[0] - } - if len(respMsgs) > 0 { - if respMsgs[0]["role"] == nil { - respMsgs[0]["role"] = "assistant" - } - round.Assistant = respMsgs[0] - } - rounds = append(rounds, round) + // 1) 删 DB + for _, id := range req.MsgIds { + _, _ = dao.ComposeSession.Delete(ctx, &entity.ComposeSession{ + SQLBaseDO: beans.SQLBaseDO{Id: id}, + }) } - return rounds + user, err := utils.GetUserInfo(ctx) + if err != nil { + return nil, err + } + // 2) 删 Redis + _ = DeleteRedisMessages(ctx, user.TenantId, req.SessionId, req.MsgIds) + + return &dto.DeleteMessagesRes{Ok: true}, nil } -// DeleteSession 删除会话 +// DeleteSession 删除整个会话 func DeleteSession(ctx context.Context, req *dto.DeleteSessionReq) (*dto.DeleteSessionRes, error) { - hasMsgID := len(req.MsgIds) > 0 && req.MsgIds[0] > 0 - - deleteReq := &entity.ComposeSession{ + // 1) 删 DB + if _, err := dao.ComposeSession.Delete(ctx, &entity.ComposeSession{ SessionId: req.SessionId, - NodeId: req.NodeId, - } - if hasMsgID { - deleteReq.Id = req.MsgIds[0] - } - - if _, err := dao.ComposeSession.Delete(ctx, deleteReq); err != nil { + }); err != nil { return nil, fmt.Errorf("DB删除失败: %w", err) } - if hasMsgID { - if err := DeleteSingleMessage(ctx, req.TenantId, req.SessionId, req.MsgIds[0]); err != nil { - g.Log().Warningf(ctx, "[删除会话] Redis删除单条失败 msgID=%d err=%v", req.MsgIds[0], err) - } - } else { - if err := DeleteSessionHistory(ctx, req.TenantId, req.SessionId); err != nil { - g.Log().Warningf(ctx, "[删除会话] Redis删除失败 sessionId=%s err=%v", req.SessionId, err) - } + user, err := utils.GetUserInfo(ctx) + if err != nil { + return nil, err + } + // 2) 删 Redis + if err := DeleteSessionHistory(ctx, user.TenantId, req.SessionId); err != nil { + g.Log().Warningf(ctx, "[删除会话] Redis删除失败 sessionId=%s err=%v", req.SessionId, err) } return &dto.DeleteSessionRes{Ok: true}, nil } // ============================================ -// 内部方法 +// 转换方法(entity ↔ dto,集中管理) // ============================================ -func extractMessagesFromSessions(sessions []*entity.ComposeSession) []map[string]any { - var messages []map[string]any - for i := len(sessions) - 1; i >= 0; i-- { - appendRoleMessages(sessions[i].RequestContent, "user", &messages) - appendRoleMessages(sessions[i].ResponseContent, "assistant", &messages) +// entityToHistoryRound entity → HistoryRound +func entityToHistoryRound(s *entity.ComposeSession) *dto.HistoryRound { + reqMsgs := util.ConvertToMessages(s.RequestContent) + respMsgs := util.ConvertToMessages(s.ResponseContent) + + round := &dto.HistoryRound{ + Id: s.Id, + SessionId: s.SessionId, + NodeId: s.NodeId, + CreatedAt: gconv.String(s.CreatedAt), + UpdatedAt: gconv.String(s.UpdatedAt), } - return messages + if len(reqMsgs) > 0 { + round.User = reqMsgs[0] + } + if len(respMsgs) > 0 { + round.Assistant = respMsgs[0] + } + return round } -func appendRoleMessages(content any, defaultRole string, messages *[]map[string]any) { - msgs := util.ConvertToMessages(content) - for _, m := range msgs { - if m["role"] == nil || gconv.String(m["role"]) == "" { - m["role"] = defaultRole - } - *messages = append(*messages, m) - } -} - -// asyncCacheToRedis 异步缓存会话数据到 Redis -func asyncCacheToRedis(ctx context.Context, tenantID uint64, sessionID string, sessions []*entity.ComposeSession) { +// sessionsToHistoryRounds 批量转换 +func sessionsToHistoryRounds(sessions []*entity.ComposeSession) []dto.HistoryRound { + rounds := make([]dto.HistoryRound, 0, len(sessions)) for _, s := range sessions { - reqMsgs := util.ConvertToMessages(s.RequestContent) - respMsgs := util.ConvertToMessages(s.ResponseContent) - if len(reqMsgs) > 0 || len(respMsgs) > 0 { - _ = SaveToRedis(ctx, tenantID, sessionID, &dto.HistoryRound{ - Id: s.Id, - User: s.RequestContent, - Assistant: s.ResponseContent, - CreatedAt: gconv.String(s.CreatedAt), - }) + rounds = append(rounds, *entityToHistoryRound(s)) + } + return rounds +} + +// asyncCacheToRedis 异步缓存到 Redis +func asyncCacheToRedis(ctx context.Context, tenantID uint64, sessionID string, rounds []dto.HistoryRound) { + for i := range rounds { + if rounds[i].User != nil || rounds[i].Assistant != nil { + _ = SaveToRedis(ctx, tenantID, sessionID, &rounds[i]) } } }