refactor(model-gateway): 重构代码结构并优化数据库查询
This commit is contained in:
@@ -2,10 +2,9 @@ package prompt
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"prompts-core/consts/public"
|
||||
"prompts-core/service/gateway"
|
||||
"prompts-core/service/session"
|
||||
"strings"
|
||||
|
||||
"prompts-core/common/util"
|
||||
@@ -13,6 +12,7 @@ import (
|
||||
"prompts-core/model/dto"
|
||||
"prompts-core/model/entity"
|
||||
|
||||
"gitea.com/red-future/common/utils"
|
||||
"github.com/gogf/gf/v2/encoding/gjson"
|
||||
"github.com/gogf/gf/v2/util/gconv"
|
||||
)
|
||||
@@ -29,30 +29,13 @@ type UserPromptPayload struct {
|
||||
BuildType int `json:"buildType"`
|
||||
}
|
||||
|
||||
// buildInferenceRequest 构建推理请求
|
||||
func buildInferenceRequest(ctx context.Context, req *dto.ComposeMessagesReq, chatModel *gateway.AsynchModel, aiModel *gateway.AsynchModel, history []map[string]any) (map[string]any, error) {
|
||||
//1) 处理表单分批
|
||||
processedReq, totalBatches, err := ProcessUserFormBatches(ctx, req, aiModel)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("处理用户表单分批失败: %w", err)
|
||||
}
|
||||
ir := NewPromptIR()
|
||||
switch req.BuildType {
|
||||
case public.BuildTypePrompt:
|
||||
return buildPromptTypeRequest(ctx, processedReq, aiModel, chatModel, history, ir, totalBatches)
|
||||
case public.BuildTypeNode:
|
||||
return buildNodeTypeRequest(ctx, req, chatModel, ir)
|
||||
default:
|
||||
return nil, errors.New("不支持的构建类型")
|
||||
}
|
||||
}
|
||||
|
||||
// buildPromptTypeRequest 构建提示词类型请求(BuildType=1)
|
||||
func buildPromptTypeRequest(ctx context.Context, req *dto.ComposeMessagesReq, aiModel *gateway.AsynchModel, chatModel *gateway.AsynchModel, history []map[string]any, ir *PromptIR, totalBatches int) (map[string]any, error) {
|
||||
func buildPromptTypeRequest(ctx context.Context, req *dto.ComposeMessagesReq, aiModel *gateway.AsynchModel, chatModel *gateway.AsynchModel, ir *PromptIR, totalBatches int) (map[string]any, error) {
|
||||
//1) 构建系统提示词
|
||||
systemPrompt := promptBuildWithRounds(ctx, req, chatModel, aiModel, totalBatches)
|
||||
ir.AddSystem(systemPrompt)
|
||||
//2) 构建历史对话
|
||||
history, _ := session.GetHistoryMessages(ctx, req.SessionId)
|
||||
for _, msg := range history {
|
||||
role := gconv.String(msg["role"])
|
||||
if role != "user" && role != "assistant" {
|
||||
@@ -75,24 +58,40 @@ func buildNodeTypeRequest(ctx context.Context, req *dto.ComposeMessagesReq, chat
|
||||
return compileToProviderRequest(ctx, ir, chatModel)
|
||||
}
|
||||
|
||||
// compileToProviderRequest 编译为 Provider 请求
|
||||
func compileToProviderRequest(ctx context.Context, ir *PromptIR, chatModel *gateway.AsynchModel) (map[string]any, error) {
|
||||
protocol, err := GetProtocolByProvider(ctx, chatModel.OperatorName)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("获取协议配置失败: %w", err)
|
||||
// buildStructTypeRequest 构建结构体类型请求(BuildType=3)
|
||||
func buildStructTypeRequest(ctx context.Context, req *dto.ComposeMessagesReq, chatModel *gateway.AsynchModel, ir *PromptIR) (map[string]any, error) {
|
||||
// 提取 userForm 中的 prompt 作为自定义提示词
|
||||
var customPrompt string
|
||||
for _, item := range req.UserForm {
|
||||
if prompt, ok := item["prompt"]; ok && gconv.String(prompt) != "" {
|
||||
customPrompt = gconv.String(prompt)
|
||||
break
|
||||
}
|
||||
}
|
||||
if protocol == nil {
|
||||
return nil, errors.New("协议配置不存在")
|
||||
// 用户消息
|
||||
ir.AddSystem(customPrompt)
|
||||
ir.AddUser(buildUserPrompt(ctx, req, ""))
|
||||
return compileToProviderRequest(ctx, ir, chatModel, customPrompt)
|
||||
}
|
||||
|
||||
// compileToProviderRequest 编译为 Provider 请求
|
||||
func compileToProviderRequest(ctx context.Context, ir *PromptIR, chatModel *gateway.AsynchModel, customPrompt ...string) (map[string]any, error) {
|
||||
protocol, err := GetProtocolByProvider(ctx, chatModel.OperatorName)
|
||||
if err != nil || protocol == nil {
|
||||
return nil, fmt.Errorf("协议配置不存在或获取失败: %w", err)
|
||||
}
|
||||
// 如果传了自定义提示词,替换掉协议模板
|
||||
if len(customPrompt) > 0 && customPrompt[0] != "" {
|
||||
protocol.SystemPromptTemplate = customPrompt[0]
|
||||
}
|
||||
providerReq, err := Compile(ir, protocol, chatModel)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("编译请求失败: %w", err)
|
||||
}
|
||||
|
||||
return map[string]any{
|
||||
"modelName": chatModel.ModelName,
|
||||
"bizName": util.GetServerName(ctx),
|
||||
"callbackUrl": util.GetCallbackURL(ctx, "/prompt/callback"),
|
||||
"callbackUrl": utils.GetCallbackURL(ctx, "/prompt/callback"),
|
||||
"requestPayload": providerReq,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -5,11 +5,6 @@ import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"prompts-core/service/session"
|
||||
|
||||
"gitea.com/red-future/common/beans"
|
||||
"gitea.com/red-future/common/utils"
|
||||
"github.com/gogf/gf/v2/frame/g"
|
||||
|
||||
"prompts-core/common/util"
|
||||
"prompts-core/consts/public"
|
||||
@@ -17,28 +12,24 @@ import (
|
||||
"prompts-core/model/dto"
|
||||
"prompts-core/model/entity"
|
||||
"prompts-core/service/gateway"
|
||||
|
||||
"gitea.com/red-future/common/beans"
|
||||
"gitea.com/red-future/common/utils"
|
||||
"github.com/gogf/gf/v2/frame/g"
|
||||
)
|
||||
|
||||
// ComposeMessages 核心拼接提示词主流程
|
||||
func ComposeMessages(ctx context.Context, req *dto.ComposeMessagesReq) (*dto.ComposeMessagesRes, error) {
|
||||
//1) 获取模型信息
|
||||
// 1) 获取模型信息
|
||||
chatModel, aiModel, err := GetModelMessage(ctx, req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
//2) 校验用户表单
|
||||
// 2) 校验用户表单
|
||||
if err = validateUserForm(req, aiModel); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
//3) 处理不同类型
|
||||
switch req.BuildType {
|
||||
case public.BuildTypePrompt:
|
||||
return handlePromptBuild(ctx, req, chatModel, aiModel) // 提示词构建
|
||||
case public.BuildTypeNode:
|
||||
return handleNodeBuild(ctx, req, chatModel, aiModel) // 节点构建
|
||||
default:
|
||||
return nil, errors.New("BuildType 不支持")
|
||||
}
|
||||
return handleBuild(ctx, req, chatModel, aiModel)
|
||||
}
|
||||
|
||||
// GetModelMessage 获取模型信息
|
||||
@@ -51,24 +42,19 @@ func GetModelMessage(ctx context.Context, req *dto.ComposeMessagesReq) (*gateway
|
||||
SQLBaseDO: beans.SQLBaseDO{Creator: userInfo.UserName},
|
||||
IsChatModel: 1,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
if chatModel == nil {
|
||||
if err != nil || chatModel == nil {
|
||||
return nil, nil, errors.New("当前没有对话模型,请添加")
|
||||
}
|
||||
aiModels, err := gateway.GetModelConfig(ctx, &gateway.AsynchModel{
|
||||
|
||||
aiModel, err := gateway.GetModelConfig(ctx, &gateway.AsynchModel{
|
||||
SQLBaseDO: beans.SQLBaseDO{TenantId: userInfo.TenantId, Creator: userInfo.UserName},
|
||||
ModelName: req.ModelName,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
if aiModels == nil {
|
||||
if err != nil || aiModel == nil {
|
||||
return nil, nil, errors.New("需要构建的模型不存在")
|
||||
}
|
||||
return chatModel, aiModels, nil
|
||||
|
||||
return chatModel, aiModel, nil
|
||||
}
|
||||
|
||||
// validateUserForm 校验用户表单
|
||||
@@ -89,103 +75,96 @@ func validateUserForm(req *dto.ComposeMessagesReq, model *gateway.AsynchModel) e
|
||||
return nil
|
||||
}
|
||||
|
||||
// handlePromptBuild 处理提示词构建(BuildType=1)
|
||||
func handlePromptBuild(ctx context.Context, req *dto.ComposeMessagesReq, chatModel, aiModel *gateway.AsynchModel) (*dto.ComposeMessagesRes, error) {
|
||||
// 获取历史会话
|
||||
history, err := session.GetHistoryMessages(ctx, req.SessionId)
|
||||
// handleBuild 通用构建处理
|
||||
func handleBuild(ctx context.Context, req *dto.ComposeMessagesReq, chatModel, aiModel *gateway.AsynchModel) (*dto.ComposeMessagesRes, error) {
|
||||
// 1) 处理表单分批
|
||||
processedReq, totalBatches, err := ProcessUserFormBatches(ctx, req, aiModel)
|
||||
if err != nil {
|
||||
g.Log().Errorf(ctx, "获取历史会话失败: %v,将不使用历史会话", err)
|
||||
history = nil
|
||||
return nil, fmt.Errorf("处理用户表单分批失败: %w", err)
|
||||
}
|
||||
// 调用推理模型
|
||||
taskID, id, err := callInferenceModel(ctx, req, chatModel, aiModel, history)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("调用推理模型失败: %w", err)
|
||||
}
|
||||
// 保存任务记录
|
||||
_, err = dao.ComposeTask.Insert(ctx, &entity.ComposeTask{
|
||||
TaskId: taskID,
|
||||
ModelName: req.ModelName,
|
||||
SkillName: req.SkillName,
|
||||
BuildType: req.BuildType,
|
||||
CallbackUrl: req.CallbackUrl,
|
||||
RequestPayload: util.MustMarshalToMap(req),
|
||||
Status: public.ComposeStatusPending,
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("保存任务记录失败: %w", err)
|
||||
// 2) 构建推理请求
|
||||
ir := NewPromptIR()
|
||||
var taskReq map[string]any
|
||||
switch req.BuildType {
|
||||
case public.BuildTypePrompt:
|
||||
taskReq, err = buildPromptTypeRequest(ctx, processedReq, aiModel, chatModel, ir, totalBatches)
|
||||
case public.BuildTypeNode:
|
||||
taskReq, err = buildNodeTypeRequest(ctx, req, chatModel, ir)
|
||||
case public.BuildTypeStruct:
|
||||
taskReq, err = buildStructTypeRequest(ctx, req, chatModel, ir)
|
||||
default:
|
||||
return nil, errors.New("不支持的构建类型")
|
||||
}
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("构建推理请求失败: %w", err)
|
||||
}
|
||||
return &dto.ComposeMessagesRes{
|
||||
TaskId: taskID,
|
||||
EpicycleId: id,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// handleNodeBuild 处理节点构建(BuildType=2)
|
||||
func handleNodeBuild(ctx context.Context, req *dto.ComposeMessagesReq, chatModel, aiModel *gateway.AsynchModel) (*dto.ComposeMessagesRes, error) {
|
||||
taskID, id, err := callInferenceModel(ctx, req, chatModel, aiModel, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("调用推理模型失败: %w", err)
|
||||
}
|
||||
// 保存任务记录
|
||||
_, err = dao.ComposeTask.Insert(ctx, &entity.ComposeTask{
|
||||
TaskId: taskID,
|
||||
ModelName: req.ModelName,
|
||||
SkillName: req.SkillName,
|
||||
BuildType: req.BuildType,
|
||||
CallbackUrl: req.CallbackUrl,
|
||||
RequestPayload: util.MustMarshalToMap(req),
|
||||
Status: public.ComposeStatusPending,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("保存任务记录失败: %w", err)
|
||||
}
|
||||
return &dto.ComposeMessagesRes{
|
||||
TaskId: taskID,
|
||||
EpicycleId: id,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// callInferenceModel 调用推理模型
|
||||
func callInferenceModel(ctx context.Context, req *dto.ComposeMessagesReq, chatModel *gateway.AsynchModel, aiModel *gateway.AsynchModel, history []map[string]any) (string, int64, error) {
|
||||
taskReq, err := buildInferenceRequest(ctx, req, chatModel, aiModel, history)
|
||||
if err != nil {
|
||||
return "", 0, fmt.Errorf("构建推理请求失败: %w", err)
|
||||
}
|
||||
id := int64(0)
|
||||
if req.SessionId != "" {
|
||||
id, err = dao.ComposeSession.Insert(ctx, &entity.ComposeSession{
|
||||
SessionId: req.SessionId,
|
||||
RequestContent: util.GetUserMessage(taskReq),
|
||||
})
|
||||
if err != nil {
|
||||
return "", 0, fmt.Errorf("保存历史会话失败: %w", err)
|
||||
}
|
||||
}
|
||||
// 3) 调用网关创建任务
|
||||
taskID, err := gateway.CreateGatewayTask(ctx, taskReq)
|
||||
if err != nil {
|
||||
return "", 0, fmt.Errorf("创建网关任务失败: %w", err)
|
||||
return nil, fmt.Errorf("创建网关任务失败: %w", err)
|
||||
}
|
||||
|
||||
if taskID == "" {
|
||||
return "", 0, errors.New("网关未返回taskId")
|
||||
return nil, errors.New("网关未返回taskId")
|
||||
}
|
||||
|
||||
return taskID, id, nil
|
||||
// 4) 保存任务记录
|
||||
if _, err = dao.ComposeTask.Insert(ctx, &entity.ComposeTask{
|
||||
TaskId: taskID,
|
||||
ModelName: req.ModelName,
|
||||
SkillName: req.SkillName,
|
||||
BuildType: req.BuildType,
|
||||
CallbackUrl: req.CallbackUrl,
|
||||
RequestPayload: util.MustMarshalToMap(req),
|
||||
Status: public.ComposeStatusPending,
|
||||
}); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &dto.ComposeMessagesRes{TaskId: taskID}, nil
|
||||
}
|
||||
|
||||
// Callback 回调处理
|
||||
func Callback(ctx context.Context, req *dto.CallbackReq) error {
|
||||
g.Log().Infof(ctx, "[Callback][RECV] taskId=%s state=%d ossFile=%s fileType=%s textLen=%d",
|
||||
req.TaskId, req.State, req.OssFile, req.FileType, len(req.Messages))
|
||||
// 查询任务
|
||||
composeTask, err := dao.ComposeTask.Get(ctx, &entity.ComposeTask{
|
||||
TaskId: req.TaskId,
|
||||
})
|
||||
g.Log().Infof(ctx, "[开始回调处理] taskId=%s state=%d", req.TaskId, req.State)
|
||||
// 1) 查询任务
|
||||
composeTask, err := dao.ComposeTask.Get(ctx, &entity.ComposeTask{TaskId: req.TaskId})
|
||||
if err != nil {
|
||||
return fmt.Errorf("查询任务失败: %w", err)
|
||||
}
|
||||
// 2) 处理失败
|
||||
if req.State == 3 {
|
||||
return handleCallbackFailed(ctx, req, composeTask)
|
||||
}
|
||||
// 3) 处理成功
|
||||
if req.State == 2 {
|
||||
return handleCallbackSuccess(ctx, req, composeTask)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// handleCallbackFailed 处理回调失败
|
||||
func handleCallbackFailed(ctx context.Context, req *dto.CallbackReq, composeTask *entity.ComposeTask) error {
|
||||
_, err := dao.ComposeTask.Update(ctx, &entity.ComposeTask{
|
||||
TaskId: req.TaskId,
|
||||
Status: public.ComposeStatusFailed,
|
||||
ErrorMessage: req.ErrorMsg,
|
||||
GatewayState: req.State,
|
||||
OssFile: req.OssFile,
|
||||
FileType: req.FileType,
|
||||
ResultText: req.Messages,
|
||||
})
|
||||
if composeTask.CallbackUrl != "" {
|
||||
composeTask.Status = public.ComposeStatusFailed
|
||||
composeTask.ErrorMessage = req.ErrorMsg
|
||||
_ = gateway.SendCallback(ctx, composeTask)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// handleCallbackSuccess 处理回调成功
|
||||
func handleCallbackSuccess(ctx context.Context, req *dto.CallbackReq, composeTask *entity.ComposeTask) error {
|
||||
// 1) 查模型配置
|
||||
model, err := gateway.GetModelConfig(ctx, &gateway.AsynchModel{
|
||||
SQLBaseDO: beans.SQLBaseDO{Creator: composeTask.Creator},
|
||||
ModelName: composeTask.ModelName,
|
||||
@@ -193,75 +172,55 @@ func Callback(ctx context.Context, req *dto.CallbackReq) error {
|
||||
if err != nil {
|
||||
return fmt.Errorf("查询模型失败: %w", err)
|
||||
}
|
||||
//处理失败
|
||||
if req.State == 3 {
|
||||
_, err = dao.ComposeTask.Update(ctx, &entity.ComposeTask{
|
||||
TaskId: req.TaskId,
|
||||
Status: public.ComposeStatusFailed,
|
||||
ErrorMessage: req.ErrorMsg,
|
||||
GatewayState: req.State,
|
||||
OssFile: req.OssFile,
|
||||
FileType: req.FileType,
|
||||
ResultText: req.Messages,
|
||||
})
|
||||
// 用更新后的值发送回调
|
||||
if composeTask.CallbackUrl != "" {
|
||||
failedTask := &entity.ComposeTask{
|
||||
TaskId: req.TaskId,
|
||||
Status: public.ComposeStatusFailed,
|
||||
ErrorMessage: req.ErrorMsg,
|
||||
CallbackUrl: composeTask.CallbackUrl,
|
||||
Messages: composeTask.Messages,
|
||||
}
|
||||
gateway.SendCallback(ctx, failedTask)
|
||||
|
||||
// 2) 解析结果
|
||||
var messages map[string]any
|
||||
switch composeTask.BuildType {
|
||||
case public.BuildTypePrompt, public.BuildTypeNode:
|
||||
messages = ParseResult(req.Messages, model.ResponseBody)
|
||||
case public.BuildTypeStruct:
|
||||
messages = map[string]any{
|
||||
"total_rounds": 1,
|
||||
"rounds": []map[string]any{req.Messages},
|
||||
}
|
||||
default:
|
||||
messages = req.Messages
|
||||
}
|
||||
// 3) 合并附加结构
|
||||
messages = util.MergeConsult(composeTask.RequestPayload, messages, model.ExtendMapping)
|
||||
// 4) 更新数据库
|
||||
_, err = dao.ComposeTask.Update(ctx, &entity.ComposeTask{
|
||||
TaskId: req.TaskId,
|
||||
Status: public.ComposeStatusSuccess,
|
||||
Messages: messages,
|
||||
GatewayState: req.State,
|
||||
OssFile: req.OssFile,
|
||||
FileType: req.FileType,
|
||||
ResultText: req.Messages,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
//处理成功
|
||||
if req.State == 2 {
|
||||
// 1. 根据 BuildType 解析结果
|
||||
var messages map[string]any
|
||||
switch composeTask.BuildType {
|
||||
case public.BuildTypePrompt: // 提示词构建解析
|
||||
messages = ParsePromptResult(req.Messages)
|
||||
case public.BuildTypeNode: // 节点构建解析
|
||||
messages = ParseNodeResult(req.Messages)
|
||||
default:
|
||||
messages = req.Messages
|
||||
}
|
||||
// 2. 处理附加字段
|
||||
messages = util.MergeConsult(composeTask.RequestPayload, messages, model.ExtendMapping)
|
||||
// 3. 更新数据库
|
||||
_, err = dao.ComposeTask.Update(ctx, &entity.ComposeTask{
|
||||
TaskId: req.TaskId,
|
||||
Status: public.ComposeStatusSuccess,
|
||||
Messages: messages,
|
||||
GatewayState: req.State,
|
||||
OssFile: req.OssFile,
|
||||
FileType: req.FileType,
|
||||
ResultText: req.Messages,
|
||||
})
|
||||
if err != nil {
|
||||
g.Log().Errorf(ctx, "[Callback] 更新成功状态失败 taskId=%s err=%v", req.TaskId, err)
|
||||
return err
|
||||
}
|
||||
// 4. 发送回调给业务方
|
||||
if composeTask.CallbackUrl != "" {
|
||||
successTask := &entity.ComposeTask{
|
||||
TaskId: req.TaskId,
|
||||
Status: public.ComposeStatusSuccess,
|
||||
Messages: messages,
|
||||
CallbackUrl: composeTask.CallbackUrl,
|
||||
}
|
||||
gateway.SendCallback(ctx, successTask)
|
||||
}
|
||||
// 5) 存储历史结果
|
||||
|
||||
// 6) 回调业务方
|
||||
if composeTask.CallbackUrl != "" {
|
||||
composeTask.Status = public.ComposeStatusSuccess
|
||||
composeTask.Messages = messages
|
||||
_ = gateway.SendCallback(ctx, composeTask)
|
||||
}
|
||||
return err
|
||||
return nil
|
||||
}
|
||||
|
||||
// ParsePromptResult 解析提示词构建结果
|
||||
func ParsePromptResult(raw map[string]any) map[string]any {
|
||||
contentStr, ok := raw["content"].(string)
|
||||
// ParseResult 解析回调结果
|
||||
func ParseResult(raw map[string]any, responseBody string) map[string]any {
|
||||
// responseBody 为空,直接返回原始数据
|
||||
if responseBody == "" {
|
||||
return raw
|
||||
}
|
||||
|
||||
// 按 responseBody 路径取值
|
||||
contentStr, ok := raw[responseBody].(string)
|
||||
if !ok || contentStr == "" {
|
||||
return raw
|
||||
}
|
||||
|
||||
@@ -286,7 +286,7 @@ func renderTemplate(tmpl map[string]any, messages []map[string]any, chatModel *g
|
||||
str = strings.ReplaceAll(str, `"{{messages}}"`, string(msgBytes))
|
||||
|
||||
var result map[string]any
|
||||
json.Unmarshal([]byte(str), &result)
|
||||
_ = json.Unmarshal([]byte(str), &result)
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user