Files
prompts-core/service/prompt/prompt_session_service.go

115 lines
3.4 KiB
Go
Raw Normal View History

package prompt
2026-05-12 13:59:15 +08:00
import (
"context"
2026-05-15 09:45:51 +08:00
"fmt"
sessionDao "prompts-core/dao"
2026-05-12 13:59:15 +08:00
"prompts-core/model/entity"
"prompts-core/common/util"
sessionDto "prompts-core/model/dto/prompt"
2026-05-12 13:59:15 +08:00
"gitea.com/red-future/common/beans"
"github.com/gogf/gf/v2/frame/g"
2026-05-15 09:45:51 +08:00
"github.com/gogf/gf/v2/util/gconv"
2026-05-12 13:59:15 +08:00
)
func SessionCallback(ctx context.Context, req *sessionDto.SessionCallbackReq) (res *sessionDto.SessionCallbackRes, err error) {
2026-05-12 13:59:15 +08:00
// 1. 解析AI返回的文本
result, err := util.ParseOutput(req.Text)
2026-05-12 13:59:15 +08:00
if err != nil {
g.Log().Errorf(ctx, "[会话回调] 解析模型输出失败 epicycleId=%d err=%v", req.EpicycleId, err)
return nil, err
}
2026-05-15 09:45:51 +08:00
// 2. 更新数据库
result["role"] = "assistant"
_, err = sessionDao.ComposeSession.Update(ctx, &entity.ComposeSession{
2026-05-15 09:45:51 +08:00
SQLBaseDO: beans.SQLBaseDO{Id: req.EpicycleId},
ResponseContent: result,
2026-05-12 13:59:15 +08:00
})
if err != nil {
g.Log().Errorf(ctx, "[会话回调] 更新数据库失败 epicycleId=%d err=%v", req.EpicycleId, err)
return nil, err
}
2026-05-15 09:45:51 +08:00
// 3. 获取当前轮次完整数据
session, err := sessionDao.ComposeSession.Get(ctx, &entity.ComposeSession{
SQLBaseDO: beans.SQLBaseDO{Id: req.EpicycleId},
})
2026-05-12 13:59:15 +08:00
if err != nil {
g.Log().Errorf(ctx, "[会话回调] 获取会话数据失败 epicycleId=%d err=%v", req.EpicycleId, err)
return nil, err
}
2026-05-15 09:45:51 +08:00
// 4. 转换 json 并存入 Redis
requestMessages := util.ConvertToMessages(session.RequestContent)
responseMessages := util.ConvertToMessages(session.ResponseContent)
2026-05-15 09:45:51 +08:00
if err = saveToRedis(ctx, session.SessionId, requestMessages, responseMessages); err != nil {
2026-05-12 13:59:15 +08:00
g.Log().Errorf(ctx, "[会话回调] Redis存储失败 sessionId=%s id=%d err=%v",
session.SessionId, session.Id, err)
return nil, err
}
g.Log().Infof(ctx, "[会话回调] 存储成功 sessionId=%s id=%d requestLen=%d responseLen=%d",
session.SessionId, session.Id, len(requestMessages), len(responseMessages))
return &sessionDto.SessionCallbackRes{}, nil
2026-05-12 13:59:15 +08:00
}
2026-05-15 09:45:51 +08:00
// GetHistoryMessages 获取历史信息
func GetHistoryMessages(ctx context.Context, sessionId string) ([]map[string]any, error) {
2026-05-15 09:45:51 +08:00
maxRounds := g.Cfg().MustGet(ctx, "session.maxRounds", 10).Int()
// 1. 先从 Redis 拿
redisHistory, err := GetSessionHistoryForInference(ctx, sessionId)
2026-05-15 09:45:51 +08:00
if err == nil && len(redisHistory) > 0 {
return redisHistory, nil
}
// 2. Redis 没有 → fallback DB
sessions, _, err := sessionDao.ComposeSession.List(ctx, &entity.ComposeSession{
SessionId: sessionId,
}, 1, maxRounds)
2026-05-15 09:45:51 +08:00
if err != nil {
return nil, fmt.Errorf("DB获取历史失败: %w", err)
}
var messages []map[string]any
for _, session := range sessions {
// request
reqMsgs := util.ConvertToMessages(session.RequestContent)
2026-05-15 09:45:51 +08:00
for _, m := range reqMsgs {
role := gconv.String(m["role"])
if role == "user" || role == "assistant" {
messages = append(messages, m)
}
}
// response
respMsgs := util.ConvertToMessages(session.ResponseContent)
2026-05-15 09:45:51 +08:00
for _, m := range respMsgs {
if m["role"] == nil {
m["role"] = "assistant"
}
messages = append(messages, m)
}
}
// 3. 回写 Redis
for _, session := range sessions {
reqMsgs := util.ConvertToMessages(session.RequestContent)
respMsgs := util.ConvertToMessages(session.ResponseContent)
2026-05-15 09:45:51 +08:00
for i := range respMsgs {
if respMsgs[i]["role"] == nil {
respMsgs[i]["role"] = "assistant"
}
}
if len(reqMsgs) > 0 || len(respMsgs) > 0 {
_ = saveToRedis(ctx, session.SessionId, reqMsgs, respMsgs)
2026-05-15 09:45:51 +08:00
}
}
return messages, nil
}