refactor(model-gateway): 重构代码结构并优化数据库查询

This commit is contained in:
2026-06-03 18:37:18 +08:00
parent 05cf1b9828
commit b2cad4cac2
10 changed files with 190 additions and 470 deletions

View File

@@ -5,11 +5,6 @@ import (
"encoding/json"
"errors"
"fmt"
"prompts-core/service/session"
"gitea.com/red-future/common/beans"
"gitea.com/red-future/common/utils"
"github.com/gogf/gf/v2/frame/g"
"prompts-core/common/util"
"prompts-core/consts/public"
@@ -17,28 +12,24 @@ import (
"prompts-core/model/dto"
"prompts-core/model/entity"
"prompts-core/service/gateway"
"gitea.com/red-future/common/beans"
"gitea.com/red-future/common/utils"
"github.com/gogf/gf/v2/frame/g"
)
// ComposeMessages 核心拼接提示词主流程
func ComposeMessages(ctx context.Context, req *dto.ComposeMessagesReq) (*dto.ComposeMessagesRes, error) {
//1) 获取模型信息
// 1) 获取模型信息
chatModel, aiModel, err := GetModelMessage(ctx, req)
if err != nil {
return nil, err
}
//2) 校验用户表单
// 2) 校验用户表单
if err = validateUserForm(req, aiModel); err != nil {
return nil, err
}
//3) 处理不同类型
switch req.BuildType {
case public.BuildTypePrompt:
return handlePromptBuild(ctx, req, chatModel, aiModel) // 提示词构建
case public.BuildTypeNode:
return handleNodeBuild(ctx, req, chatModel, aiModel) // 节点构建
default:
return nil, errors.New("BuildType 不支持")
}
return handleBuild(ctx, req, chatModel, aiModel)
}
// GetModelMessage 获取模型信息
@@ -51,24 +42,19 @@ func GetModelMessage(ctx context.Context, req *dto.ComposeMessagesReq) (*gateway
SQLBaseDO: beans.SQLBaseDO{Creator: userInfo.UserName},
IsChatModel: 1,
})
if err != nil {
return nil, nil, err
}
if chatModel == nil {
if err != nil || chatModel == nil {
return nil, nil, errors.New("当前没有对话模型,请添加")
}
aiModels, err := gateway.GetModelConfig(ctx, &gateway.AsynchModel{
aiModel, err := gateway.GetModelConfig(ctx, &gateway.AsynchModel{
SQLBaseDO: beans.SQLBaseDO{TenantId: userInfo.TenantId, Creator: userInfo.UserName},
ModelName: req.ModelName,
})
if err != nil {
return nil, nil, err
}
if aiModels == nil {
if err != nil || aiModel == nil {
return nil, nil, errors.New("需要构建的模型不存在")
}
return chatModel, aiModels, nil
return chatModel, aiModel, nil
}
// validateUserForm 校验用户表单
@@ -89,103 +75,96 @@ func validateUserForm(req *dto.ComposeMessagesReq, model *gateway.AsynchModel) e
return nil
}
// handlePromptBuild 处理提示词构建BuildType=1
func handlePromptBuild(ctx context.Context, req *dto.ComposeMessagesReq, chatModel, aiModel *gateway.AsynchModel) (*dto.ComposeMessagesRes, error) {
// 获取历史会话
history, err := session.GetHistoryMessages(ctx, req.SessionId)
// handleBuild 通用构建处理
func handleBuild(ctx context.Context, req *dto.ComposeMessagesReq, chatModel, aiModel *gateway.AsynchModel) (*dto.ComposeMessagesRes, error) {
// 1) 处理表单分批
processedReq, totalBatches, err := ProcessUserFormBatches(ctx, req, aiModel)
if err != nil {
g.Log().Errorf(ctx, "获取历史会话失败: %v将不使用历史会话", err)
history = nil
return nil, fmt.Errorf("处理用户表单分批失败: %w", err)
}
// 调用推理模型
taskID, id, err := callInferenceModel(ctx, req, chatModel, aiModel, history)
if err != nil {
return nil, fmt.Errorf("调用推理模型失败: %w", err)
}
// 保存任务记录
_, err = dao.ComposeTask.Insert(ctx, &entity.ComposeTask{
TaskId: taskID,
ModelName: req.ModelName,
SkillName: req.SkillName,
BuildType: req.BuildType,
CallbackUrl: req.CallbackUrl,
RequestPayload: util.MustMarshalToMap(req),
Status: public.ComposeStatusPending,
})
if err != nil {
return nil, fmt.Errorf("保存任务记录失败: %w", err)
// 2) 构建推理请求
ir := NewPromptIR()
var taskReq map[string]any
switch req.BuildType {
case public.BuildTypePrompt:
taskReq, err = buildPromptTypeRequest(ctx, processedReq, aiModel, chatModel, ir, totalBatches)
case public.BuildTypeNode:
taskReq, err = buildNodeTypeRequest(ctx, req, chatModel, ir)
case public.BuildTypeStruct:
taskReq, err = buildStructTypeRequest(ctx, req, chatModel, ir)
default:
return nil, errors.New("不支持的构建类型")
}
if err != nil {
return nil, fmt.Errorf("构建推理请求失败: %w", err)
}
return &dto.ComposeMessagesRes{
TaskId: taskID,
EpicycleId: id,
}, nil
}
// handleNodeBuild 处理节点构建BuildType=2
func handleNodeBuild(ctx context.Context, req *dto.ComposeMessagesReq, chatModel, aiModel *gateway.AsynchModel) (*dto.ComposeMessagesRes, error) {
taskID, id, err := callInferenceModel(ctx, req, chatModel, aiModel, nil)
if err != nil {
return nil, fmt.Errorf("调用推理模型失败: %w", err)
}
// 保存任务记录
_, err = dao.ComposeTask.Insert(ctx, &entity.ComposeTask{
TaskId: taskID,
ModelName: req.ModelName,
SkillName: req.SkillName,
BuildType: req.BuildType,
CallbackUrl: req.CallbackUrl,
RequestPayload: util.MustMarshalToMap(req),
Status: public.ComposeStatusPending,
})
if err != nil {
return nil, fmt.Errorf("保存任务记录失败: %w", err)
}
return &dto.ComposeMessagesRes{
TaskId: taskID,
EpicycleId: id,
}, nil
}
// callInferenceModel 调用推理模型
func callInferenceModel(ctx context.Context, req *dto.ComposeMessagesReq, chatModel *gateway.AsynchModel, aiModel *gateway.AsynchModel, history []map[string]any) (string, int64, error) {
taskReq, err := buildInferenceRequest(ctx, req, chatModel, aiModel, history)
if err != nil {
return "", 0, fmt.Errorf("构建推理请求失败: %w", err)
}
id := int64(0)
if req.SessionId != "" {
id, err = dao.ComposeSession.Insert(ctx, &entity.ComposeSession{
SessionId: req.SessionId,
RequestContent: util.GetUserMessage(taskReq),
})
if err != nil {
return "", 0, fmt.Errorf("保存历史会话失败: %w", err)
}
}
// 3) 调用网关创建任务
taskID, err := gateway.CreateGatewayTask(ctx, taskReq)
if err != nil {
return "", 0, fmt.Errorf("创建网关任务失败: %w", err)
return nil, fmt.Errorf("创建网关任务失败: %w", err)
}
if taskID == "" {
return "", 0, errors.New("网关未返回taskId")
return nil, errors.New("网关未返回taskId")
}
return taskID, id, nil
// 4) 保存任务记录
if _, err = dao.ComposeTask.Insert(ctx, &entity.ComposeTask{
TaskId: taskID,
ModelName: req.ModelName,
SkillName: req.SkillName,
BuildType: req.BuildType,
CallbackUrl: req.CallbackUrl,
RequestPayload: util.MustMarshalToMap(req),
Status: public.ComposeStatusPending,
}); err != nil {
return nil, err
}
return &dto.ComposeMessagesRes{TaskId: taskID}, nil
}
// Callback 回调处理
func Callback(ctx context.Context, req *dto.CallbackReq) error {
g.Log().Infof(ctx, "[Callback][RECV] taskId=%s state=%d ossFile=%s fileType=%s textLen=%d",
req.TaskId, req.State, req.OssFile, req.FileType, len(req.Messages))
// 查询任务
composeTask, err := dao.ComposeTask.Get(ctx, &entity.ComposeTask{
TaskId: req.TaskId,
})
g.Log().Infof(ctx, "[开始回调处理] taskId=%s state=%d", req.TaskId, req.State)
// 1) 查询任务
composeTask, err := dao.ComposeTask.Get(ctx, &entity.ComposeTask{TaskId: req.TaskId})
if err != nil {
return fmt.Errorf("查询任务失败: %w", err)
}
// 2) 处理失败
if req.State == 3 {
return handleCallbackFailed(ctx, req, composeTask)
}
// 3) 处理成功
if req.State == 2 {
return handleCallbackSuccess(ctx, req, composeTask)
}
return nil
}
// handleCallbackFailed 处理回调失败
func handleCallbackFailed(ctx context.Context, req *dto.CallbackReq, composeTask *entity.ComposeTask) error {
_, err := dao.ComposeTask.Update(ctx, &entity.ComposeTask{
TaskId: req.TaskId,
Status: public.ComposeStatusFailed,
ErrorMessage: req.ErrorMsg,
GatewayState: req.State,
OssFile: req.OssFile,
FileType: req.FileType,
ResultText: req.Messages,
})
if composeTask.CallbackUrl != "" {
composeTask.Status = public.ComposeStatusFailed
composeTask.ErrorMessage = req.ErrorMsg
_ = gateway.SendCallback(ctx, composeTask)
}
return err
}
// handleCallbackSuccess 处理回调成功
func handleCallbackSuccess(ctx context.Context, req *dto.CallbackReq, composeTask *entity.ComposeTask) error {
// 1) 查模型配置
model, err := gateway.GetModelConfig(ctx, &gateway.AsynchModel{
SQLBaseDO: beans.SQLBaseDO{Creator: composeTask.Creator},
ModelName: composeTask.ModelName,
@@ -193,75 +172,55 @@ func Callback(ctx context.Context, req *dto.CallbackReq) error {
if err != nil {
return fmt.Errorf("查询模型失败: %w", err)
}
//处理失败
if req.State == 3 {
_, err = dao.ComposeTask.Update(ctx, &entity.ComposeTask{
TaskId: req.TaskId,
Status: public.ComposeStatusFailed,
ErrorMessage: req.ErrorMsg,
GatewayState: req.State,
OssFile: req.OssFile,
FileType: req.FileType,
ResultText: req.Messages,
})
// 用更新后的值发送回调
if composeTask.CallbackUrl != "" {
failedTask := &entity.ComposeTask{
TaskId: req.TaskId,
Status: public.ComposeStatusFailed,
ErrorMessage: req.ErrorMsg,
CallbackUrl: composeTask.CallbackUrl,
Messages: composeTask.Messages,
}
gateway.SendCallback(ctx, failedTask)
// 2) 解析结果
var messages map[string]any
switch composeTask.BuildType {
case public.BuildTypePrompt, public.BuildTypeNode:
messages = ParseResult(req.Messages, model.ResponseBody)
case public.BuildTypeStruct:
messages = map[string]any{
"total_rounds": 1,
"rounds": []map[string]any{req.Messages},
}
default:
messages = req.Messages
}
// 3) 合并附加结构
messages = util.MergeConsult(composeTask.RequestPayload, messages, model.ExtendMapping)
// 4) 更新数据库
_, err = dao.ComposeTask.Update(ctx, &entity.ComposeTask{
TaskId: req.TaskId,
Status: public.ComposeStatusSuccess,
Messages: messages,
GatewayState: req.State,
OssFile: req.OssFile,
FileType: req.FileType,
ResultText: req.Messages,
})
if err != nil {
return err
}
//处理成功
if req.State == 2 {
// 1. 根据 BuildType 解析结果
var messages map[string]any
switch composeTask.BuildType {
case public.BuildTypePrompt: // 提示词构建解析
messages = ParsePromptResult(req.Messages)
case public.BuildTypeNode: // 节点构建解析
messages = ParseNodeResult(req.Messages)
default:
messages = req.Messages
}
// 2. 处理附加字段
messages = util.MergeConsult(composeTask.RequestPayload, messages, model.ExtendMapping)
// 3. 更新数据库
_, err = dao.ComposeTask.Update(ctx, &entity.ComposeTask{
TaskId: req.TaskId,
Status: public.ComposeStatusSuccess,
Messages: messages,
GatewayState: req.State,
OssFile: req.OssFile,
FileType: req.FileType,
ResultText: req.Messages,
})
if err != nil {
g.Log().Errorf(ctx, "[Callback] 更新成功状态失败 taskId=%s err=%v", req.TaskId, err)
return err
}
// 4. 发送回调给业务方
if composeTask.CallbackUrl != "" {
successTask := &entity.ComposeTask{
TaskId: req.TaskId,
Status: public.ComposeStatusSuccess,
Messages: messages,
CallbackUrl: composeTask.CallbackUrl,
}
gateway.SendCallback(ctx, successTask)
}
// 5) 存储历史结果
// 6) 回调业务方
if composeTask.CallbackUrl != "" {
composeTask.Status = public.ComposeStatusSuccess
composeTask.Messages = messages
_ = gateway.SendCallback(ctx, composeTask)
}
return err
return nil
}
// ParsePromptResult 解析提示词构建结果
func ParsePromptResult(raw map[string]any) map[string]any {
contentStr, ok := raw["content"].(string)
// ParseResult 解析回调结果
func ParseResult(raw map[string]any, responseBody string) map[string]any {
// responseBody 为空,直接返回原始数据
if responseBody == "" {
return raw
}
// 按 responseBody 路径取值
contentStr, ok := raw[responseBody].(string)
if !ok || contentStr == "" {
return raw
}