113 lines
3.5 KiB
Go
113 lines
3.5 KiB
Go
package prompt
|
||
|
||
import (
	"context"
	"errors"
	"fmt"
	"sort"
	"strings"

	"prompts-core/common/util"
	"prompts-core/dao"
	"prompts-core/model/dto/prompt"
	"prompts-core/model/entity"

	"github.com/gogf/gf/v2/util/gconv"
)
|
||
|
||
// buildInferenceRequest 构建返回请求
|
||
func buildInferenceRequest(ctx context.Context, req *prompt.ComposeMessagesReq, chatModel *entity.AsynchModel, model *entity.AsynchModel, history []map[string]any) (map[string]any, error) {
|
||
ir := NewPromptIR()
|
||
// 1. 统一 Prompt IR
|
||
switch req.BuildType {
|
||
case 1: //构建提示词请求
|
||
ir.AddSystem(promptBuild(ctx, req, model))
|
||
for _, msg := range history {
|
||
role := gconv.String(msg["role"])
|
||
if role != "user" && role != "assistant" {
|
||
continue
|
||
}
|
||
ir.AddHistory(role, gconv.String(msg["content"]))
|
||
}
|
||
ir.AddUser(buildUserPrompt(ctx, req, util.GetModelPrompt(ctx, model.ModelType)))
|
||
case 2: //构建节点请求
|
||
ir.AddUser(NodeBuild(ctx, req))
|
||
default:
|
||
return nil, errors.New("不支持的构建类型")
|
||
}
|
||
|
||
// 2. 获取协议配置
|
||
protocol, err := GetProtocolByProvider(ctx, "qwen")
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
if protocol == nil {
|
||
return nil, errors.New("协议配置不存在")
|
||
}
|
||
|
||
// 3. 编译为 Provider Request
|
||
providerReq, err := Compile(ir, protocol, chatModel)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
// 4. 构建请求体
|
||
return map[string]any{
|
||
"modelName": chatModel.ModelName,
|
||
"bizName": "prompts-core",
|
||
"callbackUrl": "/prompt/callback",
|
||
"requestPayload": providerReq,
|
||
}, nil
|
||
}
|
||
|
||
// promptBuild 构建系统提示词
|
||
func promptBuild(ctx context.Context, req *prompt.ComposeMessagesReq, model *entity.AsynchModel) string {
|
||
providerProtocol, err := dao.ProviderProtocol.Get(ctx, &entity.ProviderProtocol{
|
||
ProviderName: "qwen",
|
||
Status: 1,
|
||
})
|
||
if err != nil || providerProtocol == nil {
|
||
return ""
|
||
}
|
||
|
||
outputJSON := util.JSONPretty(model.RequestMapping)
|
||
var userFormContent strings.Builder
|
||
for k, v := range req.UserForm {
|
||
userFormContent.WriteString(fmt.Sprintf("%s=%v;", k, v))
|
||
}
|
||
userFormFullText := strings.TrimSuffix(userFormContent.String(), ";")
|
||
|
||
formInfo := fmt.Sprintf(`
|
||
【系统表单(系统提示词/参数)】
|
||
%s
|
||
【用户表单全文(必须完整阅读,全部作为用户提示词来源)】
|
||
%s
|
||
`, util.FormToJSON(req.Form), userFormFullText)
|
||
|
||
return fmt.Sprintf(providerProtocol.SystemPromptTemplate, outputJSON, formInfo)
|
||
}
|
||
|
||
// 构建用户提示词
|
||
func buildUserPrompt(ctx context.Context, req *prompt.ComposeMessagesReq, prompt string) string {
|
||
payload := map[string]any{
|
||
"model": req.ModelName, // 请求模型名称
|
||
"promptInfo": prompt, // 数据库提示信息
|
||
"form": req.Form, // 系统表单
|
||
"userForm": req.UserForm, // 用户表单
|
||
"userFiles": req.UserFiles, //文件url
|
||
"userFilesText": FetchFileTexts(ctx, req.UserFiles), //解读文件(只支持可读类型 如:xml,json,yaml)
|
||
"skills": SkillMdContent(ctx, req.SkillName), //skill 相关(根据传入的 skillName 获取 zip 内所有 md 文件拼接内容)
|
||
}
|
||
return util.MustMarshal(payload)
|
||
}
|
||
|
||
// NodeBuild 节点构建
|
||
func NodeBuild(ctx context.Context, req *prompt.ComposeMessagesReq) string {
|
||
promptTpl := util.GetBuildPrompt(ctx, req.BuildType)
|
||
if promptTpl == "" {
|
||
return ""
|
||
}
|
||
formStr := util.FormToJSON(req.Form)
|
||
userFormStr := util.FormToJSON(req.UserForm)
|
||
return fmt.Sprintf(promptTpl, formStr, userFormStr)
|
||
}
|