refactor(prompt): 重构提示词构建服务与数据模型

This commit is contained in:
2026-05-20 11:36:39 +08:00
parent c49144794d
commit 35bc3bd6ec
24 changed files with 1682 additions and 759 deletions

View File

@@ -4,65 +4,113 @@ import (
"context"
"errors"
"fmt"
"prompts-core/consts/public"
"strings"
"prompts-core/common/util"
"prompts-core/dao"
"prompts-core/model/dto/prompt"
"prompts-core/model/dto"
"prompts-core/model/entity"
"github.com/gogf/gf/v2/util/gconv"
)
// buildInferenceRequest 构建返回请求
func buildInferenceRequest(ctx context.Context, req *prompt.ComposeMessagesReq, chatModel *entity.AsynchModel, model *entity.AsynchModel, history []map[string]any) (map[string]any, error) {
// buildInferenceRequest 构建推理请求
func buildInferenceRequest(ctx context.Context, req *dto.ComposeMessagesReq, chatModel *entity.AsynchModel, targetModel *entity.AsynchModel, history []map[string]any) (map[string]any, error) {
processedReq, totalBatches, err := ProcessUserFormBatches(ctx, req, targetModel)
if err != nil {
return nil, fmt.Errorf("处理用户表单分批失败: %w", err)
}
ir := NewPromptIR()
// 1. 统一 Prompt IR
switch req.BuildType {
case 1: //构建提示词请求
ir.AddSystem(promptBuild(ctx, req, model))
for _, msg := range history {
role := gconv.String(msg["role"])
if role != "user" && role != "assistant" {
continue
}
ir.AddHistory(role, gconv.String(msg["content"]))
}
ir.AddUser(buildUserPrompt(ctx, req, util.GetModelPrompt(ctx, model.ModelType)))
case 2: //构建节点请求
ir.AddUser(NodeBuild(ctx, req))
case public.BuildTypePrompt:
return buildPromptTypeRequest(ctx, processedReq, targetModel, history, ir, totalBatches)
case public.BuildTypeNode:
return buildNodeTypeRequest(ctx, req, ir)
default:
return nil, errors.New("不支持的构建类型")
}
}
// 2. 获取协议配置
protocol, err := GetProtocolByProvider(ctx, "qwen")
// buildPromptTypeRequest 构建提示词类型请求BuildType=1
func buildPromptTypeRequest(ctx context.Context, req *dto.ComposeMessagesReq, targetModel *entity.AsynchModel, history []map[string]any, ir *PromptIR, totalBatches int) (map[string]any, error) {
systemPrompt := promptBuildWithRounds(ctx, req, targetModel, totalBatches)
ir.AddSystem(systemPrompt)
for _, msg := range history {
role := gconv.String(msg["role"])
if role != "user" && role != "assistant" {
continue
}
ir.AddHistory(role, gconv.String(msg["content"]))
}
userPrompt := buildUserPrompt(ctx, req, util.GetModelPrompt(ctx, targetModel.ModelType))
ir.AddUser(userPrompt)
if !checkOverallContent(ir, targetModel) {
availableWindow := util.GetAvailableWindow(targetModel.TokenConfig)
return nil, fmt.Errorf("整体内容超出模型窗口大小限制(可用窗口=%d tokens),请精简后重试", availableWindow)
}
return compileToProviderRequest(ctx, ir, targetModel.OperatorName, targetModel)
}
// buildNodeTypeRequest 构建节点类型请求BuildType=2
func buildNodeTypeRequest(ctx context.Context, req *dto.ComposeMessagesReq, ir *PromptIR) (map[string]any, error) {
ir.AddUser(NodeBuild(ctx, req))
protocol, err := GetProtocolByProvider(ctx, req.ModelName)
if err != nil {
return nil, err
return nil, fmt.Errorf("获取协议配置失败: %w", err)
}
if protocol == nil {
return nil, errors.New("协议配置不存在")
}
// 3. 编译为 Provider Request
providerReq, err := Compile(ir, protocol, chatModel)
providerReq, err := Compile(ir, protocol, nil)
if err != nil {
return nil, err
return nil, fmt.Errorf("编译请求失败: %w", err)
}
// 4. 构建请求体
return map[string]any{
"modelName": chatModel.ModelName,
"modelName": req.ModelName,
"bizName": "prompts-core",
"callbackUrl": "/prompt/callback",
"requestPayload": providerReq,
}, nil
}
// promptBuild 构建系统提示词
func promptBuild(ctx context.Context, req *prompt.ComposeMessagesReq, model *entity.AsynchModel) string {
// compileToProviderRequest 编译为 Provider 请求
func compileToProviderRequest(ctx context.Context, ir *PromptIR, providerName string, model *entity.AsynchModel) (map[string]any, error) {
protocol, err := GetProtocolByProvider(ctx, providerName)
if err != nil {
return nil, fmt.Errorf("获取协议配置失败: %w", err)
}
if protocol == nil {
return nil, errors.New("协议配置不存在")
}
providerReq, err := Compile(ir, protocol, model)
if err != nil {
return nil, fmt.Errorf("编译请求失败: %w", err)
}
fmt.Println("providerReq打印:", util.MustMarshal(providerReq))
return map[string]any{
"modelName": model.ModelName,
"bizName": "prompts-core",
"callbackUrl": "/prompt/callback",
"requestPayload": providerReq,
}, nil
}
// promptBuildWithRounds 构建系统提示词(包含轮次信息)
func promptBuildWithRounds(ctx context.Context, req *dto.ComposeMessagesReq, model *entity.AsynchModel, totalRounds int) string {
providerProtocol, err := dao.ProviderProtocol.Get(ctx, &entity.ProviderProtocol{
ProviderName: "qwen",
ProviderName: model.OperatorName,
Status: 1,
})
if err != nil || providerProtocol == nil {
@@ -70,43 +118,104 @@ func promptBuild(ctx context.Context, req *prompt.ComposeMessagesReq, model *ent
}
outputJSON := util.JSONPretty(model.RequestMapping)
var userFormContent strings.Builder
for k, v := range req.UserForm {
userFormContent.WriteString(fmt.Sprintf("%s=%v", k, v))
}
userFormFullText := strings.TrimSuffix(userFormContent.String(), "")
maxWindowSize := util.GetMaxWindowSize(model.TokenConfig)
availableWindow := util.GetAvailableWindow(model.TokenConfig)
userFormContent := buildUserFormContent(req.UserForm)
formInfo := fmt.Sprintf(`
【系统表单(系统提示词/参数)】
%s
【用户表单全文(必须完整阅读,全部作为用户提示词来源)】
%s
`, util.FormToJSON(req.Form), userFormFullText)
`, util.FormToJSON(req.Form), userFormContent)
return fmt.Sprintf(providerProtocol.SystemPromptTemplate, outputJSON, formInfo)
inputInfo := fmt.Sprintf(`
目标模型: %s
%s
技能名称: %s
用户文件: %v
`, req.ModelName, formInfo, req.SkillName, req.UserFiles)
return fmt.Sprintf(providerProtocol.SystemPromptTemplate,
req.ModelName,
maxWindowSize,
availableWindow,
totalRounds,
totalRounds,
totalRounds,
outputJSON,
inputInfo,
totalRounds,
)
}
// 构建用户提示词
func buildUserPrompt(ctx context.Context, req *prompt.ComposeMessagesReq, prompt string) string {
payload := map[string]any{
"model": req.ModelName, // 请求模型名称
"promptInfo": prompt, // 数据库提示信息
"form": req.Form, // 系统表单
"userForm": req.UserForm, // 用户表单
"userFiles": req.UserFiles, //文件url
"userFilesText": FetchFileTexts(ctx, req.UserFiles), //解读文件(只支持可读类型 如xmljson,yaml
"skills": SkillMdContent(ctx, req.SkillName), //skill 相关(根据传入的 skillName 获取 zip 内所有 md 文件拼接内容)
// buildUserFormContent 构建用户表单内容字符串
func buildUserFormContent(userForm []map[string]any) string {
var builder strings.Builder
for _, item := range userForm {
builder.WriteString(fmt.Sprintf("%v\n", item))
}
return builder.String()
}
// checkOverallContent 检查整体内容是否超出窗口
func checkOverallContent(ir *PromptIR, model *entity.AsynchModel) bool {
fullContent := ir.String()
return util.CountToken(fullContent, model.TokenConfig)
}
// buildUserPrompt 构建用户提示词
func buildUserPrompt(ctx context.Context, req *dto.ComposeMessagesReq, prompt string) string {
userFormForPayload := prepareUserFormPayload(req.UserForm)
payload := map[string]any{
"model": req.ModelName,
"promptInfo": prompt,
"form": req.Form,
"userForm": userFormForPayload,
"userFiles": req.UserFiles,
"userFilesText": FetchFileTexts(ctx, req.UserFiles),
"skills": SkillMdContent(ctx, req.SkillName),
}
return util.MustMarshal(payload)
}
// prepareUserFormPayload 准备用户表单载荷
func prepareUserFormPayload(userForm []map[string]any) any {
if len(userForm) == 0 {
return nil
}
if _, ok := userForm[0]["batch_index"]; ok {
return userForm
}
return mergeUserFormTexts(userForm)
}
// mergeUserFormTexts 合并 UserForm 中的所有文本内容
func mergeUserFormTexts(userForm []map[string]any) string {
var builder strings.Builder
for i, item := range userForm {
text := getItemText(item)
if i > 0 {
builder.WriteString("\n\n")
}
builder.WriteString(text)
}
return builder.String()
}
// NodeBuild 节点构建
func NodeBuild(ctx context.Context, req *prompt.ComposeMessagesReq) string {
promptTpl := util.GetBuildPrompt(ctx, req.BuildType)
func NodeBuild(ctx context.Context, req *dto.ComposeMessagesReq) string {
promptTpl := util.GetBuildPrompt(ctx)
if promptTpl == "" {
return ""
}
formStr := util.FormToJSON(req.Form)
userFormStr := util.FormToJSON(req.UserForm)
userFormStr := util.UserFormToJSON(req.UserForm)
return fmt.Sprintf(promptTpl, formStr, userFormStr)
}