package prompt

import (
	"context"
	"fmt"

	"gitea.com/red-future/common/beans"
	"github.com/gogf/gf/v2/frame/g"
	"github.com/gogf/gf/v2/util/gconv"
	"prompts-core/common/util"
	"prompts-core/dao"
	"prompts-core/model/dto"
	"prompts-core/model/entity"
)

// SessionCallback handles the model-output callback for one conversation round.
//
// It parses req.Text with util.ParseOutput, tags the parsed result with
// role "assistant", stores it as the ResponseContent of the ComposeSession
// whose ID is req.EpicycleId, re-reads that session and pushes its
// request/response messages to Redis. The first failing step aborts the
// callback and its (wrapped) error is returned. On success an empty
// dto.SessionCallbackRes is returned.
func SessionCallback(ctx context.Context, req *dto.SessionCallbackReq) (*dto.SessionCallbackRes, error) {
	result, err := util.ParseOutput(req.Text)
	if err != nil {
		g.Log().Errorf(ctx, "[会话回调] 解析模型输出失败 epicycleId=%d err=%v", req.EpicycleId, err)
		return nil, fmt.Errorf("解析模型输出失败: %w", err)
	}
	// NOTE(review): assumes ParseOutput never returns a nil map together with a
	// nil error — assigning into a nil map would panic. Confirm in util.
	result["role"] = "assistant"

	if err = updateSessionResponse(ctx, req.EpicycleId, result); err != nil {
		return nil, err
	}

	// Re-read the row so Redis receives what was persisted, including the
	// RequestContent that this callback does not carry itself.
	session, err := getSessionById(ctx, req.EpicycleId)
	if err != nil {
		return nil, err
	}
	if err = saveSessionToRedis(ctx, session); err != nil {
		return nil, err
	}

	requestMessages := util.ConvertToMessages(session.RequestContent)
	responseMessages := util.ConvertToMessages(session.ResponseContent)
	g.Log().Infof(ctx, "[会话回调] 存储成功 sessionId=%s id=%d requestLen=%d responseLen=%d",
		session.SessionId, session.Id, len(requestMessages), len(responseMessages))

	return &dto.SessionCallbackRes{}, nil
}

// updateSessionResponse writes response as the ResponseContent of the
// ComposeSession identified by epicycleId (presumably an update by primary
// key — the exact semantics live in dao.ComposeSession.Update).
// Failures are logged and returned wrapped.
func updateSessionResponse(ctx context.Context, epicycleId int64, response any) error {
	_, err := dao.ComposeSession.Update(ctx, &entity.ComposeSession{
		SQLBaseDO:       beans.SQLBaseDO{Id: epicycleId},
		ResponseContent: response,
	})
	if err != nil {
		g.Log().Errorf(ctx, "[会话回调] 更新数据库失败 epicycleId=%d err=%v", epicycleId, err)
		return fmt.Errorf("更新数据库失败: %w", err)
	}
	return nil
}

// getSessionById loads the ComposeSession whose ID is epicycleId.
//
// A lookup error is logged and returned wrapped. A missing row (nil session
// with nil error) is also reported as an error, so callers never receive a
// nil session and can dereference the result safely.
func getSessionById(ctx context.Context, epicycleId int64) (*entity.ComposeSession, error) {
	session, err := dao.ComposeSession.Get(ctx, &entity.ComposeSession{
		SQLBaseDO: beans.SQLBaseDO{Id: epicycleId},
	})
	if err != nil {
		g.Log().Errorf(ctx, "[会话回调] 获取会话数据失败 epicycleId=%d err=%v", epicycleId, err)
		return nil, fmt.Errorf("获取会话数据失败: %w", err)
	}
	if session == nil {
		// An empty result can come back as (nil, nil); without this check the
		// caller would panic on session.RequestContent.
		g.Log().Errorf(ctx, "[会话回调] 会话不存在 epicycleId=%d", epicycleId)
		return nil, fmt.Errorf("会话不存在: epicycleId=%d", epicycleId)
	}
	return session, nil
}
保存会话到Redis func saveSessionToRedis(ctx context.Context, session *entity.ComposeSession) error { requestMessages := util.ConvertToMessages(session.RequestContent) responseMessages := util.ConvertToMessages(session.ResponseContent) if err := saveToRedis(ctx, session.SessionId, requestMessages, responseMessages); err != nil { g.Log().Errorf(ctx, "[会话回调] Redis存储失败 sessionId=%s id=%d err=%v", session.SessionId, session.Id, err) return fmt.Errorf("Redis存储失败: %w", err) } return nil } // GetHistoryMessages 获取历史信息 func GetHistoryMessages(ctx context.Context, sessionId string) ([]map[string]any, error) { maxRounds := g.Cfg().MustGet(ctx, "session.maxRounds", 10).Int() redisHistory, err := GetSessionHistoryForInference(ctx, sessionId) if err == nil && len(redisHistory) > 0 { return redisHistory, nil } return getHistoryFromDatabase(ctx, sessionId, maxRounds) } // getHistoryFromDatabase 从数据库获取历史记录 func getHistoryFromDatabase(ctx context.Context, sessionId string, maxRounds int) ([]map[string]any, error) { sessions, _, err := dao.ComposeSession.List(ctx, &entity.ComposeSession{ SessionId: sessionId, }, 1, maxRounds) if err != nil { return nil, fmt.Errorf("DB获取历史失败: %w", err) } messages := extractMessagesFromSessions(sessions) cacheSessionsToRedis(ctx, sessions) return messages, nil } // extractMessagesFromSessions 从会话列表中提取消息 func extractMessagesFromSessions(sessions []*entity.ComposeSession) []map[string]any { var messages []map[string]any for _, session := range sessions { appendRequestMessages(session.RequestContent, &messages) appendResponseMessages(session.ResponseContent, &messages) } return messages } // appendRequestMessages 追加请求消息 func appendRequestMessages(requestContent any, messages *[]map[string]any) { reqMsgs := util.ConvertToMessages(requestContent) for _, m := range reqMsgs { role := gconv.String(m["role"]) if role == "user" || role == "assistant" { *messages = append(*messages, m) } } } // appendResponseMessages 追加响应消息 func appendResponseMessages(responseContent 
any, messages *[]map[string]any) { respMsgs := util.ConvertToMessages(responseContent) for _, m := range respMsgs { if m["role"] == nil { m["role"] = "assistant" } *messages = append(*messages, m) } } // cacheSessionsToRedis 将会话缓存到Redis func cacheSessionsToRedis(ctx context.Context, sessions []*entity.ComposeSession) { for _, session := range sessions { reqMsgs := util.ConvertToMessages(session.RequestContent) respMsgs := util.ConvertToMessages(session.ResponseContent) for i := range respMsgs { if respMsgs[i]["role"] == nil { respMsgs[i]["role"] = "assistant" } } if len(reqMsgs) > 0 || len(respMsgs) > 0 { _ = saveToRedis(ctx, session.SessionId, reqMsgs, respMsgs) } } }