package service

import (
	"context"
	"encoding/json"
	"fmt"
	"slices"
	"strings"
	"time"

	"customer-server/consts/account"
	"customer-server/consts/public"
	"customer-server/consts/scriptedSpeech"
	"customer-server/dao"
	"customer-server/model/dto"
	"customer-server/model/entity"

	"gitea.com/red-future/common/beans"
	"gitea.com/red-future/common/http"
	"gitea.com/red-future/common/jaeger"
	"gitea.com/red-future/common/utils"
	gmq "github.com/bjang03/gmq/core/gmq"
	"github.com/bjang03/gmq/mq"
	"github.com/bjang03/gmq/types"
	"github.com/gogf/gf/v2/container/gvar"
	"github.com/gogf/gf/v2/encoding/gjson"
	"github.com/gogf/gf/v2/frame/g"
	"github.com/gogf/gf/v2/util/gconv"
)

// SessionToolService is the shared instance of the session helper service.
var SessionToolService = new(sessionToolService)

// sessionToolService groups the helpers that drive a customer-service chat session:
// opening remark, dialog turns, delayed follow-up messages and Redis-backed history.
type sessionToolService struct{}

// dialogTTLSeconds returns the Redis TTL, in whole seconds, used for every
// dialog-scoped key (session marker, greeting options, dialog history).
//
// SetEX expects seconds; converting the time.Duration directly with gconv.Int64
// would yield nanoseconds. NOTE(review): assumes public.DialogTimeout is a count of
// seconds (the expression below only compiles for an untyped constant or a
// time.Duration) — confirm.
func (s *sessionToolService) dialogTTLSeconds() int64 {
	return int64((public.DialogTimeout * time.Second).Seconds())
}

// sessionMsgKey builds the Redis key that marks an active session of userId on the account.
func (s *sessionToolService) sessionMsgKey(accountInfo *dto.AccountVO, userId string) string {
	return fmt.Sprintf(public.AccountMsgKey, accountInfo.AccountCode, account.GetDescByCode(accountInfo.Platform), userId)
}

// PushOpeningRemark builds the opening remark (greeting plus numbered options) when the
// user has no session marker for this account, then schedules the follow-up message.
//
// It returns an empty content and nil error when a session marker already exists.
// On error the returned content is empty.
func (s *sessionToolService) PushOpeningRemark(ctx context.Context, userId string, accountInfo *dto.AccountVO, headers map[string]string) (content string, err error) {
	sceneType := scriptedSpeech.SceneTypeOpeningRemark
	key := s.sessionMsgKey(accountInfo, userId)

	active, err := g.Redis().Get(ctx, key)
	if err != nil {
		return "", err
	}
	if !g.IsEmpty(active) {
		// A session is already open; the opening remark was sent before.
		return "", nil
	}

	options, err := s.greetingOptions(ctx, accountInfo, headers)
	if err != nil {
		return "", err
	}
	content, err = s.BuildGreeting(ctx, userId, accountInfo.Greeting, options, len(accountInfo.DatasetIds))
	if err != nil {
		return "", err
	}
	if err = s.pushDelayMsg(ctx, key, sceneType.Code(), sceneType.Desc(), accountInfo.DatasetIds); err != nil {
		return "", err
	}
	return content, nil
}

// greetingOptions builds the (label, datasetId) pairs listed in the opening remark.
//
// With more than one dataset, each dataset found by GetDatasetInfo becomes an option
// labelled with its name. Otherwise each configured keyword becomes an option pointing
// at the account's single dataset; with no dataset at all no option is produced.
func (s *sessionToolService) greetingOptions(ctx context.Context, accountInfo *dto.AccountVO, headers map[string]string) ([][]string, error) {
	if len(accountInfo.DatasetIds) > 1 {
		datasetInfo, err := s.GetDatasetInfo(ctx, accountInfo.DatasetIds, headers)
		if err != nil {
			return nil, err
		}
		if datasetInfo == nil || len(datasetInfo.List) == 0 {
			return nil, fmt.Errorf("数据集不存在")
		}
		options := make([][]string, 0, len(datasetInfo.List))
		for _, dataset := range datasetInfo.List {
			options = append(options, []string{dataset.Name, gconv.String(dataset.Id)})
		}
		return options, nil
	}

	if len(accountInfo.DatasetIds) == 0 {
		// Nothing for keyword options to point at (indexing DatasetIds[0] would panic).
		return nil, nil
	}
	datasetId := gconv.String(accountInfo.DatasetIds[0])
	options := make([][]string, 0, len(accountInfo.KeywordOption))
	for _, keyword := range accountInfo.KeywordOption {
		options = append(options, []string{keyword, datasetId})
	}
	return options, nil
}

// PushDialog answers one user message inside an active session.
//
// It cancels the pending follow-up and resolves the question and target datasets.
// Once the dialog count reaches card.triggerCount it returns the card message and
// clears history. Otherwise it queries RAG with the stored history and saves the new
// turn. Finally it reschedules the follow-up. It returns empty content when the user
// has no session marker. On error the returned content is empty.
func (s *sessionToolService) PushDialog(ctx context.Context, userId string, questionContent string, accountInfo *dto.AccountVO, headers map[string]string) (content string, err error) {
	// A new user message makes any pending follow-up obsolete.
	if err = s.DeleteDelayMsg(ctx); err != nil {
		return "", err
	}

	key := s.sessionMsgKey(accountInfo, userId)
	active, err := g.Redis().Get(ctx, key)
	if err != nil {
		return "", err
	}
	if g.IsEmpty(active) {
		return "", nil
	}

	var datasetIds []int64
	questionContent, datasetIds, err = s.resolveQuestion(ctx, userId, questionContent, accountInfo, headers)
	if err != nil {
		return "", err
	}

	// Read history from the same key SaveUserHistory/ClearUserHistory use.
	var history []*dto.Message
	history, err = s.GetUserHistory(ctx, fmt.Sprintf(public.AccountDialogHistoryKey, userId))
	if err != nil {
		return "", fmt.Errorf("获取用户对话上下文失败: %w", err)
	}

	var cardTriggered bool
	cardTriggered, err = s.recordDialogTurn(ctx, accountInfo, userId)
	if err != nil {
		return "", err
	}

	sceneType := scriptedSpeech.SceneTypeDialog
	if cardTriggered {
		// TODO: replace with the actual card-sending logic.
		content = "请加一下卡片的联系方式,进行更专业的咨询"
		sceneType = scriptedSpeech.SceneTypeCardSend
		if _, err = s.ClearUserHistory(ctx, userId); err != nil {
			return "", fmt.Errorf("清除用户对话上下文失败: %w", err)
		}
	} else {
		// Ask the rag service over HTTP.
		var ragQuery *dto.RagQueryRes
		ragQuery, err = s.GetRagQuery(ctx, questionContent, datasetIds, history, headers)
		if err != nil {
			return "", fmt.Errorf("调用rag服务的RAG查询接口失败: %w", err)
		}
		content = ragQuery.Answer

		err = s.SaveUserHistory(ctx, userId, []*dto.Message{
			{Role: "user", Content: questionContent},
			{Role: "assistant", Content: content},
		})
		if err != nil {
			return "", fmt.Errorf("保存用户对话上下文失败: %w", err)
		}
	}

	if err = s.pushDelayMsg(ctx, key, sceneType.Code(), sceneType.Desc(), datasetIds); err != nil {
		return "", err
	}
	return content, nil
}

// resolveQuestion maps the user's input to the question text and dataset IDs to query.
//
// If the trimmed input equals the number of a greeting option stored by BuildGreeting,
// the option's label replaces the question and its dataset ID is used. Otherwise the
// datasets are looked up by keywords extracted from the question. It falls back to all
// of the account's datasets when no keyword matches.
func (s *sessionToolService) resolveQuestion(ctx context.Context, userId, questionContent string, accountInfo *dto.AccountVO, headers map[string]string) (string, []int64, error) {
	var (
		datasetIds []int64
		optionsVar *gvar.Var
		err        error
	)

	optionsVar, err = g.Redis().Get(ctx, fmt.Sprintf(public.AccountGreetingOptionsKey, userId))
	if err != nil {
		return "", nil, err
	}
	// The options key may be absent (expired or never written); skip the lookup then.
	if !optionsVar.IsEmpty() {
		var options map[string]interface{}
		if err = gconv.Scan(optionsVar.String(), &options); err != nil {
			return "", nil, err
		}
		if item, ok := options[strings.TrimSpace(questionContent)]; ok {
			option := gconv.Map(item)
			questionContent = gconv.String(option["datasetName"])
			datasetIds = gconv.Int64s(option["datasetId"])
		}
	}

	if len(datasetIds) == 0 {
		var matched []int64
		matched, err = s.getDatasetIdsByKeywords(ctx, questionContent, headers)
		if err != nil {
			return "", nil, err
		}
		if len(matched) > 0 {
			datasetIds = matched
		} else {
			datasetIds = accountInfo.DatasetIds
		}
	}
	return questionContent, datasetIds, nil
}

// recordDialogTurn loads the user's dialog record for the account and records this turn.
//
// A missing record is inserted with DialogCount 1. When the stored count has reached
// card.triggerCount it returns true and leaves the record unchanged. Otherwise it
// updates the record and returns false.
func (s *sessionToolService) recordDialogTurn(ctx context.Context, accountInfo *dto.AccountVO, userId string) (bool, error) {
	var (
		dialog *entity.AccountUserDialog
		err    error
	)
	dialog, err = dao.AccountUserDialog.Get(ctx, &dto.GetAccountUserDialogReq{
		AccountId: accountInfo.Id,
		UserId:    userId,
	})
	if err != nil {
		return false, fmt.Errorf("获取用户对话记录失败: %w", err)
	}

	if dialog == nil || g.IsEmpty(dialog.Id) {
		if _, err = dao.AccountUserDialog.Insert(ctx, &dto.AddAccountUserDialogReq{
			AccountId:   accountInfo.Id,
			UserId:      userId,
			DialogCount: 1,
		}); err != nil {
			return false, fmt.Errorf("保存用户对话记录失败: %w", err)
		}
		return false, nil
	}

	if dialog.DialogCount >= g.Cfg().MustGet(ctx, "card.triggerCount").Int64() {
		return true, nil
	}

	// NOTE(review): DialogCount 1 is sent on every turn; presumably the DAO increments
	// the counter rather than overwriting it — confirm, otherwise the card never triggers.
	if _, err = dao.AccountUserDialog.Update(ctx, &dto.UpdateAccountUserDialogReq{
		Id:          dialog.Id,
		DialogCount: 1,
	}); err != nil {
		return false, err
	}
	return false, nil
}

// pushDelayMsg stores the scene description under the session key for the dialog TTL.
// It then publishes a follow-up message (key + text) to AccountFollowupTopic with a
// 60-second delay.
func (s *sessionToolService) pushDelayMsg(ctx context.Context, key string, sceneTypeCode scriptedSpeech.SceneType, sceneTypeDesc string, datasetIds []int64) (err error) {
	if err = g.Redis().SetEX(ctx, key, sceneTypeDesc, s.dialogTTLSeconds()); err != nil {
		return err
	}

	msg, err := s.followupContent(ctx, sceneTypeCode, datasetIds)
	if err != nil {
		return err
	}

	payload := map[string]string{
		"key":  key,
		"data": msg,
	}
	return gmq.GetGmq(public.GmqMsgPluginsName).GmqPublishDelay(ctx, &mq.NatsPubDelayMessage{
		PubDelayMessage: types.PubDelayMessage{
			PubMessage: types.PubMessage{
				Topic: public.AccountFollowupTopic,
				Data:  payload,
			},
			DelaySeconds: 60,
		},
	})
}

// followupContent picks the follow-up text for a scene.
//
// With exactly one dataset it uses that dataset's scripted speech for the scene. When
// that yields nothing, it falls back to the per-scene default text, and to "" for an
// unknown scene.
func (s *sessionToolService) followupContent(ctx context.Context, sceneTypeCode scriptedSpeech.SceneType, datasetIds []int64) (string, error) {
	if len(datasetIds) == 1 {
		info, err := s.GetScriptedSpeechContent(ctx, datasetIds[0], sceneTypeCode)
		if err != nil {
			return "", fmt.Errorf("获取追问话术内容失败: %w", err)
		}
		// A missing scripted-speech record leaves info nil; fall through to the defaults.
		if info != nil && info.QuestionContent != "" {
			return info.QuestionContent, nil
		}
	}

	switch *sceneTypeCode {
	case *scriptedSpeech.SceneTypeOpeningRemark.Code():
		return public.SceneOpeningRemark, nil
	case *scriptedSpeech.SceneTypeDialog.Code():
		return public.SceneDialog, nil
	case *scriptedSpeech.SceneTypeCardSend.Code():
		return public.SceneCardSend, nil
	}
	return "", nil
}

// DeleteDelayMsg asks the GMQ plugin to delete delayed messages on AccountFollowupTopic.
//
// NOTE(review): the request carries only the topic, no session key — confirm whether
// this drops pending follow-ups of every user rather than just the current one.
func (s *sessionToolService) DeleteDelayMsg(ctx context.Context) (err error) {
	return gmq.GetGmq(public.GmqMsgPluginsName).GmqDeleteDelay(ctx, &mq.NatsDelMessage{
		DelMessage: types.DelMessage{
			Topic: public.AccountFollowupTopic,
		},
	})
}

// GetAccountInfo loads the customer-service account by its code and converts it to an AccountVO.
func (s *sessionToolService) GetAccountInfo(ctx context.Context, accountCode string) (res *dto.AccountVO, err error) {
	r, err := dao.Account.GetByAccountCode(ctx, &dto.GetByAccountCodeReq{
		AccountCode: accountCode,
	})
	if err != nil {
		return nil, fmt.Errorf("获取客服账号信息失败: %w", err)
	}
	err = gconv.Struct(r, &res)
	return res, err
}

// SetUserInfo builds the headers used for downstream service calls.
//
// It copies the first value of every header of the current HTTP request (when ctx
// carries one). It then adds X-User-Info, holding the JSON-encoded user built from
// creator and tenantId.
func (s *sessionToolService) SetUserInfo(ctx context.Context, creator string, tenantId uint64) (headers map[string]string, err error) {
	userInfo := &beans.User{
		UserName: creator,
		TenantId: tenantId,
	}

	headers = make(map[string]string)
	if r := g.RequestFromCtx(ctx); r != nil {
		for name, values := range r.Request.Header {
			if len(values) > 0 {
				headers[name] = values[0]
			}
		}
	}

	userInfoJson, err := gjson.Encode(userInfo)
	if err != nil {
		return nil, fmt.Errorf("用户信息序列化失败: %w", err)
	}
	headers["X-User-Info"] = string(userInfoJson)
	return headers, nil
}

// GetDatasetInfo fetches the given datasets from the rag service's dataset list endpoint.
func (s *sessionToolService) GetDatasetInfo(ctx context.Context, datasetIds []int64, headers map[string]string) (res *dto.RagListDatasetRes, err error) {
	res = &dto.RagListDatasetRes{}
	if err = http.Get(ctx, "rag/dataset/list", headers, &res, &dto.RagListDatasetReq{
		Ids: datasetIds,
	}); err != nil {
		return nil, fmt.Errorf("获取数据集信息失败: %w", err)
	}
	return res, nil
}

// BuildGreeting renders the greeting text and stores the numbered options in Redis.
//
// The greeting is replaced by public.GreetingBegin when there is more than one dataset
// or greeting is empty. Each option is a (label, datasetId) pair and is listed as
// "N、label". public.GreetingBetween follows the list when it is non-empty, and
// public.GreetingEnd closes the text. The options map ("N" -> datasetId/datasetName) is
// saved under AccountGreetingOptionsKey for the dialog TTL.
func (s *sessionToolService) BuildGreeting(ctx context.Context, userId, greeting string, options [][]string, datasetCount int) (content string, err error) {
	if datasetCount > 1 || greeting == "" {
		greeting = public.GreetingBegin
	}

	var sb strings.Builder
	sb.WriteString(greeting)
	sb.WriteByte('\n')

	optionsMap := make(map[string]map[string]string, len(options))
	for i, opt := range options {
		no := i + 1
		optionsMap[gconv.String(no)] = map[string]string{
			"datasetId":   opt[1],
			"datasetName": opt[0],
		}
		// Print only the label; formatting the whole pair would render "[name id]".
		fmt.Fprintf(&sb, "%d、%s\n", no, opt[0])
	}
	if len(options) > 0 {
		fmt.Fprintf(&sb, "%s\n", public.GreetingBetween)
	}
	sb.WriteString(public.GreetingEnd)

	content = sb.String()
	err = g.Redis().SetEX(ctx, fmt.Sprintf(public.AccountGreetingOptionsKey, userId), optionsMap, s.dialogTTLSeconds())
	return content, err
}

// GetScriptedSpeechContent loads the scripted speech for a dataset and scene.
// res may be nil when the DAO returns no record.
func (s *sessionToolService) GetScriptedSpeechContent(ctx context.Context, datasetId int64, sceneType scriptedSpeech.SceneType) (res *dto.ScriptedSpeechVO, err error) {
	r, err := dao.ScriptedSpeech.GetByDatasetIdAndSceneType(ctx, &dto.ListScriptedSpeechReq{
		DatasetId: datasetId,
		SceneType: sceneType,
	})
	if err != nil {
		return nil, err
	}
	err = gconv.Struct(r, &res)
	return res, err
}

// GetRagQuery posts the question, target datasets and history to the rag service's
// ragQuery endpoint (TopK 5) and returns its response.
func (s *sessionToolService) GetRagQuery(ctx context.Context, questionContent string, datasetIds []int64, history []*dto.Message, headers map[string]string) (res *dto.RagQueryRes, err error) {
	resp := new(dto.RagQueryRes)
	if err = http.Post(ctx, "rag/document/vector/ragQuery", headers, &resp, &dto.RagQueryReq{
		Content:    questionContent,
		DatasetIds: datasetIds,
		History:    history,
		TopK:       5,
	}); err != nil {
		return nil, err
	}
	return resp, nil
}

// SaveUserHistory appends newMessages to the user's stored history.
//
// It keeps only the newest 2*history.contextLimit messages (default limit 5) and
// writes the JSON back to Redis with the dialog TTL.
func (s *sessionToolService) SaveUserHistory(ctx context.Context, userKey string, newMessages []*dto.Message) (err error) {
	key := fmt.Sprintf(public.AccountDialogHistoryKey, userKey)

	messages, err := s.GetUserHistory(ctx, key)
	if err != nil {
		return err
	}
	messages = append(messages, newMessages...)

	// Two messages per round; clamp so a negative config value cannot panic the slicing.
	limit := max(2*g.Cfg().MustGet(ctx, "history.contextLimit", 5).Int(), 0)
	if len(messages) > limit {
		messages = messages[len(messages)-limit:]
	}

	data, err := json.Marshal(messages)
	if err != nil {
		return err
	}
	return g.Redis().SetEX(ctx, key, data, s.dialogTTLSeconds())
}

// GetUserHistory reads the JSON-encoded history stored under the full Redis key.
//
// A Redis error or a missing key yields an empty history and nil error (best effort);
// only a decoding failure is reported.
func (s *sessionToolService) GetUserHistory(ctx context.Context, key string) ([]*dto.Message, error) {
	data, err := g.Redis().Get(ctx, key)
	if err != nil {
		g.Log().Warningf(ctx, "read dialog history %q: %v", key, err)
		return []*dto.Message{}, nil
	}
	if data.IsEmpty() {
		return []*dto.Message{}, nil
	}

	var messages []*dto.Message
	if err = json.Unmarshal(data.Bytes(), &messages); err != nil {
		return []*dto.Message{}, err
	}
	return messages, nil
}

// ClearUserHistory deletes the user's stored history and returns Redis' deleted-key count.
func (s *sessionToolService) ClearUserHistory(ctx context.Context, userKey string) (int64, error) {
	key := fmt.Sprintf(public.AccountDialogHistoryKey, userKey)
	return g.Redis().Del(ctx, key)
}

// getDatasetIdsByKeywords extracts keywords from the question and asks the rag
// service's keyword list endpoint which datasets they belong to. It returns the
// distinct dataset IDs in first-seen order.
func (s *sessionToolService) getDatasetIdsByKeywords(ctx context.Context, questionContent string, headers map[string]string) (res []int64, err error) {
	keywords := s.extractKeywords(questionContent)
	g.Log().Infof(ctx, "提取关键词: %v", keywords)

	respKeyword := &dto.RAGListKeywordRes{}
	if err = http.Get(ctx, "rag/keyword/list", headers, &respKeyword, &dto.RAGListKeywordReq{
		Words: keywords,
	}); err != nil {
		jaeger.RecordError(ctx, err, "RAG查询关键词失败")
		g.Log().Errorf(ctx, "RAG查询关键词失败: %v", err)
		return nil, err
	}

	var datasetIds []int64
	for _, item := range respKeyword.List {
		if !slices.Contains(datasetIds, item.DatasetId) {
			datasetIds = append(datasetIds, item.DatasetId)
		}
	}
	return datasetIds, nil
}

// extractKeywords returns up to 5 non-empty keywords chosen by utils.GseTool.Extract.
// If none are found it returns the full segmentation from GseTool.Cut; for "" it
// returns an empty slice.
func (s *sessionToolService) extractKeywords(text string) []string {
	if text == "" {
		return []string{}
	}

	keywords := utils.GseTool.Extract(text, 5)
	words := make([]string, 0, len(keywords))
	for _, kw := range keywords {
		if kw.Word != "" {
			words = append(words, kw.Word)
		}
	}

	if len(words) == 0 {
		words = utils.GseTool.Cut(text)
	}
	return words
}