Files
prompts-core/service/prompt/prompt_build_service.go

184 lines
6.3 KiB
Go
Raw Normal View History

package prompt
import (
	"context"
	"fmt"
	"sort"
	"strings"

	"gitea.com/red-future/common/utils"
	"github.com/gogf/gf/v2/encoding/gjson"
	"github.com/gogf/gf/v2/util/gconv"

	"prompts-core/common/util"
	"prompts-core/dao"
	"prompts-core/model/dto"
	"prompts-core/model/entity"
	"prompts-core/service/gateway"
)
// buildPromptTypeRequest assembles the provider request for the prompt build
// type (BuildType=1).
//
// It appends to ir a system message produced by promptBuildWithRounds and a
// user message produced by buildUserPrompt, then compiles ir for chatModel via
// compileToProviderRequest. If checkOverallContent reports false for aiModel,
// an error quoting util.GetAvailableWindow(aiModel.TokenConfig) is returned
// instead. totalBatches is accepted but not read by this function.
func buildPromptTypeRequest(ctx context.Context, req *dto.ComposeMessagesReq, aiModel *gateway.AsynchModel, chatModel *gateway.AsynchModel, ir *PromptIR, totalBatches int) (map[string]any, error) {
	// 1) System message first, then the user message.
	ir.AddSystem(promptBuildWithRounds(ctx, chatModel, aiModel))
	modelPrompt := util.GetModelPrompt(ctx, aiModel.ModelType)
	ir.AddUser(buildUserPrompt(ctx, req, modelPrompt))

	// 2) Refuse to compile when the assembled content fails the window check.
	fits := checkOverallContent(ir, aiModel)
	if !fits {
		return nil, fmt.Errorf(
			"整体内容超出模型窗口大小限制(可用窗口=%d tokens),请精简后重试",
			util.GetAvailableWindow(aiModel.TokenConfig),
		)
	}

	return compileToProviderRequest(ctx, ir, chatModel)
}
// buildNodeTypeRequest assembles the provider request for the node build type
// (BuildType=2): the text returned by NodeBuild becomes the only message this
// function adds to ir (as a user message) before ir is compiled for chatModel.
func buildNodeTypeRequest(ctx context.Context, req *dto.ComposeMessagesReq, chatModel *gateway.AsynchModel, ir *PromptIR) (map[string]any, error) {
	nodePrompt := NodeBuild(ctx, req)
	ir.AddUser(nodePrompt)

	return compileToProviderRequest(ctx, ir, chatModel)
}
// buildStructTypeRequest assembles the provider request for the struct build
// type (BuildType=3).
//
// The first non-empty "prompt" value found in req.UserForm is used as a custom
// prompt: it is added to ir as the system message and also forwarded to
// compileToProviderRequest. When no such value exists the custom prompt is the
// empty string, which is still passed to ir.AddSystem.
func buildStructTypeRequest(ctx context.Context, req *dto.ComposeMessagesReq, chatModel *gateway.AsynchModel, ir *PromptIR) (map[string]any, error) {
	// Look for the first user-form entry carrying a non-empty "prompt".
	customPrompt := ""
	for _, entry := range req.UserForm {
		raw, found := entry["prompt"]
		if !found {
			continue
		}
		if text := gconv.String(raw); text != "" {
			customPrompt = text
			break
		}
	}

	// System message (custom prompt), then the user message built without an
	// extra model prompt.
	ir.AddSystem(customPrompt)
	userPrompt := buildUserPrompt(ctx, req, "")
	ir.AddUser(userPrompt)

	return compileToProviderRequest(ctx, ir, chatModel, customPrompt)
}
// compileToProviderRequest compiles ir into a provider-specific payload and
// wraps it in the request envelope.
//
// The protocol is looked up by chatModel.OperatorName. customPrompt is
// optional; when its first element is non-empty it replaces the protocol's
// SystemPromptTemplate before compilation.
//
// The returned map has the keys "modelName", "bizName", "callbackUrl" and
// "requestPayload". An error is returned when the protocol lookup fails, when
// no protocol is found, or when Compile fails.
func compileToProviderRequest(ctx context.Context, ir *PromptIR, chatModel *gateway.AsynchModel, customPrompt ...string) (map[string]any, error) {
	protocol, err := GetProtocolByProvider(ctx, chatModel.OperatorName)
	if err != nil {
		return nil, fmt.Errorf("协议配置不存在或获取失败: %w", err)
	}
	if protocol == nil {
		// There is no underlying error in this case; wrapping a nil error
		// with %w would render as "%!w(<nil>)", so report the provider instead.
		return nil, fmt.Errorf("协议配置不存在或获取失败: provider=%v", chatModel.OperatorName)
	}

	// A non-empty custom prompt overrides the protocol's template.
	// NOTE(review): this writes to the value returned by GetProtocolByProvider;
	// if that function hands out a shared/cached pointer the override would
	// leak into later calls — confirm against its implementation.
	if len(customPrompt) > 0 && customPrompt[0] != "" {
		protocol.SystemPromptTemplate = customPrompt[0]
	}

	providerReq, err := Compile(ir, protocol, chatModel)
	if err != nil {
		return nil, fmt.Errorf("编译请求失败: %w", err)
	}

	return map[string]any{
		"modelName":      chatModel.ModelName,
		"bizName":        util.GetServerName(ctx),
		"callbackUrl":    utils.GetCallbackURL(ctx, "/prompt/callback"),
		"requestPayload": providerReq,
	}, nil
}
// promptBuildWithRounds renders the system prompt for chatModel's provider.
//
// It loads the provider protocol whose ProviderName equals
// chatModel.OperatorName and whose Status is 1 (presumably "enabled" — confirm
// against the entity definition), then formats that protocol's
// SystemPromptTemplate with a single argument: the pretty-printed JSON of
// util.ReverseMap applied to aiModel.RequestMapping (labelled "output
// structure" in the original code).
//
// Lookup failures are not reported: an empty string is returned when the
// query errors or finds nothing.
func promptBuildWithRounds(ctx context.Context, chatModel *gateway.AsynchModel, aiModel *gateway.AsynchModel) string {
	query := &entity.ProviderProtocol{
		ProviderName: chatModel.OperatorName,
		Status:       1,
	}
	protocol, err := dao.ProviderProtocol.Get(ctx, query)
	if err != nil {
		return ""
	}
	if protocol == nil {
		return ""
	}

	reversed := util.ReverseMap(aiModel.RequestMapping, map[string]any{})
	outputJSON := util.JSONPretty(reversed)

	// The template's %s receives the output structure.
	return fmt.Sprintf(protocol.SystemPromptTemplate, outputJSON)
}
// buildUserFormContent renders every user-form item with fmt's default %v
// formatting, one item per line (each line ends with a newline). A nil or
// empty slice yields the empty string.
func buildUserFormContent(userForm []map[string]any) string {
	var out strings.Builder
	for idx := range userForm {
		fmt.Fprintf(&out, "%v\n", userForm[idx])
	}
	return out.String()
}
// checkOverallContent renders the whole IR to a string and returns the result
// of util.CountToken for that text under model.TokenConfig. Callers in this
// file treat a false result as "content exceeds the model window".
func checkOverallContent(ir *PromptIR, model *gateway.AsynchModel) bool {
	rendered := ir.String()
	withinWindow := util.CountToken(rendered, model.TokenConfig)
	return withinWindow
}
// buildUserPrompt composes the user message text from the request.
//
// Sections are emitted in a fixed order and, apart from the target model line
// which is always written, each is skipped when its content is empty:
// target model, system prompt (the prompt argument), skill content, system
// parameters (req.Form), user requirements (req.UserForm), the reference
// attachments serialized through gjson, and the text extracted from those
// attachments.
func buildUserPrompt(ctx context.Context, req *dto.ComposeMessagesReq, prompt string) string {
	var out strings.Builder

	fmt.Fprintf(&out, "目标模型:%s\n", req.ModelName)

	if prompt != "" {
		fmt.Fprintf(&out, "系统提示词:%s\n", prompt)
	}

	skills := SkillMdContent(ctx, req.SkillName)
	if skills != "" {
		fmt.Fprintf(&out, "技能内容:\n%s\n", skills)
	}

	formText := buildUserFormText(req.Form)
	if formText != "" {
		fmt.Fprintf(&out, "系统参数:\n%s\n", formText)
	}

	userFormText := buildUserFormText(req.UserForm)
	if userFormText != "" {
		fmt.Fprintf(&out, "用户需求:\n%s\n", userFormText)
	}

	if len(req.Consult) > 0 {
		consultJSON := gjson.New(req.Consult).String()
		fmt.Fprintf(&out, "参考附件:%s\n", consultJSON)
	}

	fileTexts := ExtractFileTexts(ctx, req.Consult)
	if fileTexts != "" {
		fmt.Fprintf(&out, "附件内容:\n%s\n", fileTexts)
	}

	return out.String()
}
// buildUserFormText renders a list of form items as plain text for a prompt.
//
// Each key/value pair of each item becomes one entry:
//   - a []any or []map[string]any value is written as the key on its own line
//     followed by a numbered list; list elements that are maps are written as
//     space-separated "keyvalue" pairs on a single line;
//   - any other value is written as "keyvalue" on one line.
//
// Keys are visited in ascending order, both at the item level and inside
// nested maps, so the same input always yields the same text (Go's native map
// iteration order is unspecified and varies between runs). Surrounding
// whitespace is trimmed from the result; an empty form yields "".
func buildUserFormText(form []map[string]any) string {
	if len(form) == 0 {
		return ""
	}

	var b strings.Builder
	for _, item := range form {
		for _, k := range sortedFormKeys(item) {
			switch val := item[k].(type) {
			case []any:
				// Generic array: list every element on its own numbered line.
				fmt.Fprintf(&b, "%s\n", k)
				for i, elem := range val {
					if m, ok := elem.(map[string]any); ok {
						writeNumberedFormMap(&b, i+1, m)
						continue
					}
					fmt.Fprintf(&b, " %d. %v\n", i+1, elem)
				}
			case []map[string]any:
				fmt.Fprintf(&b, "%s\n", k)
				for i, m := range val {
					writeNumberedFormMap(&b, i+1, m)
				}
			default:
				fmt.Fprintf(&b, "%s%v\n", k, val)
			}
		}
	}
	return strings.TrimSpace(b.String())
}

// sortedFormKeys returns the keys of m in ascending order.
func sortedFormKeys(m map[string]any) []string {
	keys := make([]string, 0, len(m))
	for k := range m {
		keys = append(keys, k)
	}
	sort.Strings(keys)
	return keys
}

// writeNumberedFormMap writes one numbered list line for m: the index n
// followed by every "keyvalue" pair of m in ascending key order.
func writeNumberedFormMap(b *strings.Builder, n int, m map[string]any) {
	fmt.Fprintf(b, " %d. ", n)
	for _, k := range sortedFormKeys(m) {
		fmt.Fprintf(b, "%s%v ", k, m[k])
	}
	b.WriteString("\n")
}
// NodeBuild renders the node-build prompt for req.
//
// The template comes from util.GetBuildPrompt and is formatted with two
// arguments, in this order: util.FormToJSON(req.Form) and
// util.UserFormToJSON(req.UserForm). An empty template yields an empty string.
func NodeBuild(ctx context.Context, req *dto.ComposeMessagesReq) string {
	template := util.GetBuildPrompt(ctx)
	if len(template) == 0 {
		return ""
	}

	args := []any{
		util.FormToJSON(req.Form),
		util.UserFormToJSON(req.UserForm),
	}
	return fmt.Sprintf(template, args...)
}