refactor(model-gateway): 重构代码结构并优化数据库查询

This commit is contained in:
2026-06-03 18:37:18 +08:00
parent 05cf1b9828
commit b2cad4cac2
10 changed files with 190 additions and 470 deletions

View File

@@ -2,10 +2,9 @@ package prompt
import (
"context"
"errors"
"fmt"
"prompts-core/consts/public"
"prompts-core/service/gateway"
"prompts-core/service/session"
"strings"
"prompts-core/common/util"
@@ -13,6 +12,7 @@ import (
"prompts-core/model/dto"
"prompts-core/model/entity"
"gitea.com/red-future/common/utils"
"github.com/gogf/gf/v2/encoding/gjson"
"github.com/gogf/gf/v2/util/gconv"
)
@@ -29,30 +29,13 @@ type UserPromptPayload struct {
BuildType int `json:"buildType"`
}
// buildInferenceRequest 构建推理请求
func buildInferenceRequest(ctx context.Context, req *dto.ComposeMessagesReq, chatModel *gateway.AsynchModel, aiModel *gateway.AsynchModel, history []map[string]any) (map[string]any, error) {
//1) 处理表单分批
processedReq, totalBatches, err := ProcessUserFormBatches(ctx, req, aiModel)
if err != nil {
return nil, fmt.Errorf("处理用户表单分批失败: %w", err)
}
ir := NewPromptIR()
switch req.BuildType {
case public.BuildTypePrompt:
return buildPromptTypeRequest(ctx, processedReq, aiModel, chatModel, history, ir, totalBatches)
case public.BuildTypeNode:
return buildNodeTypeRequest(ctx, req, chatModel, ir)
default:
return nil, errors.New("不支持的构建类型")
}
}
// buildPromptTypeRequest 构建提示词类型请求BuildType=1
func buildPromptTypeRequest(ctx context.Context, req *dto.ComposeMessagesReq, aiModel *gateway.AsynchModel, chatModel *gateway.AsynchModel, history []map[string]any, ir *PromptIR, totalBatches int) (map[string]any, error) {
func buildPromptTypeRequest(ctx context.Context, req *dto.ComposeMessagesReq, aiModel *gateway.AsynchModel, chatModel *gateway.AsynchModel, ir *PromptIR, totalBatches int) (map[string]any, error) {
//1) 构建系统提示词
systemPrompt := promptBuildWithRounds(ctx, req, chatModel, aiModel, totalBatches)
ir.AddSystem(systemPrompt)
//2) 构建历史对话
history, _ := session.GetHistoryMessages(ctx, req.SessionId)
for _, msg := range history {
role := gconv.String(msg["role"])
if role != "user" && role != "assistant" {
@@ -75,24 +58,40 @@ func buildNodeTypeRequest(ctx context.Context, req *dto.ComposeMessagesReq, chat
return compileToProviderRequest(ctx, ir, chatModel)
}
// compileToProviderRequest 编译为 Provider 请求
func compileToProviderRequest(ctx context.Context, ir *PromptIR, chatModel *gateway.AsynchModel) (map[string]any, error) {
protocol, err := GetProtocolByProvider(ctx, chatModel.OperatorName)
if err != nil {
return nil, fmt.Errorf("获取协议配置失败: %w", err)
// buildStructTypeRequest 构建结构体类型请求BuildType=3
func buildStructTypeRequest(ctx context.Context, req *dto.ComposeMessagesReq, chatModel *gateway.AsynchModel, ir *PromptIR) (map[string]any, error) {
// 提取 userForm 中的 prompt 作为自定义提示词
var customPrompt string
for _, item := range req.UserForm {
if prompt, ok := item["prompt"]; ok && gconv.String(prompt) != "" {
customPrompt = gconv.String(prompt)
break
}
}
if protocol == nil {
return nil, errors.New("协议配置不存在")
// 用户消息
ir.AddSystem(customPrompt)
ir.AddUser(buildUserPrompt(ctx, req, ""))
return compileToProviderRequest(ctx, ir, chatModel, customPrompt)
}
// compileToProviderRequest 编译为 Provider 请求
func compileToProviderRequest(ctx context.Context, ir *PromptIR, chatModel *gateway.AsynchModel, customPrompt ...string) (map[string]any, error) {
protocol, err := GetProtocolByProvider(ctx, chatModel.OperatorName)
if err != nil || protocol == nil {
return nil, fmt.Errorf("协议配置不存在或获取失败: %w", err)
}
// 如果传了自定义提示词,替换掉协议模板
if len(customPrompt) > 0 && customPrompt[0] != "" {
protocol.SystemPromptTemplate = customPrompt[0]
}
providerReq, err := Compile(ir, protocol, chatModel)
if err != nil {
return nil, fmt.Errorf("编译请求失败: %w", err)
}
return map[string]any{
"modelName": chatModel.ModelName,
"bizName": util.GetServerName(ctx),
"callbackUrl": util.GetCallbackURL(ctx, "/prompt/callback"),
"callbackUrl": utils.GetCallbackURL(ctx, "/prompt/callback"),
"requestPayload": providerReq,
}, nil
}