refactor(prompt): 重构提示词构建服务和回调处理

This commit is contained in:
2026-06-05 11:00:05 +08:00
parent b2cad4cac2
commit de70d33115
5 changed files with 166 additions and 157 deletions

View File

@@ -23,7 +23,6 @@ type ConsultItem struct {
} }
type ComposeMessagesRes struct { type ComposeMessagesRes struct {
TaskId string `json:"taskId" dc:"任务ID"` TaskId string `json:"taskId" dc:"任务ID"`
EpicycleId int64 `json:"epicycle_id" dc:"轮次ID"`
} }
// MultiRoundResult 多轮返回结果 // MultiRoundResult 多轮返回结果

View File

@@ -156,7 +156,7 @@ type SendCallbackReq struct {
} }
// SendCallback 向业务方发送回调 // SendCallback 向业务方发送回调
func SendCallback(ctx context.Context, composeTask *entity.ComposeTask) error { func SendCallback(ctx context.Context, composeTask *entity.ComposeTask, epicycleId int64) error {
// 1. 检查回调地址 // 1. 检查回调地址
if composeTask.CallbackUrl == "" { if composeTask.CallbackUrl == "" {
return fmt.Errorf("回调地址为空taskId=%s", composeTask.TaskId) return fmt.Errorf("回调地址为空taskId=%s", composeTask.TaskId)
@@ -167,6 +167,7 @@ func SendCallback(ctx context.Context, composeTask *entity.ComposeTask) error {
Status: composeTask.Status, Status: composeTask.Status,
Messages: composeTask.Messages, Messages: composeTask.Messages,
ErrorMsg: composeTask.ErrorMessage, ErrorMsg: composeTask.ErrorMessage,
EpicycleId: epicycleId,
} }
// 3. 发送 POST 请求 // 3. 发送 POST 请求
headers := util.ForwardHeaders(ctx) headers := util.ForwardHeaders(ctx)

View File

@@ -4,7 +4,6 @@ import (
"context" "context"
"fmt" "fmt"
"prompts-core/service/gateway" "prompts-core/service/gateway"
"prompts-core/service/session"
"strings" "strings"
"prompts-core/common/util" "prompts-core/common/util"
@@ -17,34 +16,14 @@ import (
"github.com/gogf/gf/v2/util/gconv" "github.com/gogf/gf/v2/util/gconv"
) )
// UserPromptPayload 用户提示词请求体
type UserPromptPayload struct {
Model string `json:"model"`
PromptInfo string `json:"promptInfo"`
Form any `json:"form"`
UserForm any `json:"userForm"`
Consult []dto.ConsultItem `json:"consult"`
UserFilesText map[string]string `json:"userFilesText"`
Skills string `json:"skills"`
BuildType int `json:"buildType"`
}
// buildPromptTypeRequest 构建提示词类型请求BuildType=1 // buildPromptTypeRequest 构建提示词类型请求BuildType=1
func buildPromptTypeRequest(ctx context.Context, req *dto.ComposeMessagesReq, aiModel *gateway.AsynchModel, chatModel *gateway.AsynchModel, ir *PromptIR, totalBatches int) (map[string]any, error) { func buildPromptTypeRequest(ctx context.Context, req *dto.ComposeMessagesReq, aiModel *gateway.AsynchModel, chatModel *gateway.AsynchModel, ir *PromptIR, totalBatches int) (map[string]any, error) {
//1) 构建系统提示词 //1) 构建系统提示词
systemPrompt := promptBuildWithRounds(ctx, req, chatModel, aiModel, totalBatches) systemPrompt := promptBuildWithRounds(ctx, chatModel, aiModel)
ir.AddSystem(systemPrompt) ir.AddSystem(systemPrompt)
//2) 构建历史对话
history, _ := session.GetHistoryMessages(ctx, req.SessionId)
for _, msg := range history {
role := gconv.String(msg["role"])
if role != "user" && role != "assistant" {
continue
}
ir.AddHistory(role, gconv.String(msg["content"]))
}
userPrompt := buildUserPrompt(ctx, req, util.GetModelPrompt(ctx, aiModel.ModelType)) userPrompt := buildUserPrompt(ctx, req, util.GetModelPrompt(ctx, aiModel.ModelType))
ir.AddUser(userPrompt) ir.AddUser(userPrompt)
//2) 检查整体内容是否超出窗口
if !checkOverallContent(ir, aiModel) { if !checkOverallContent(ir, aiModel) {
availableWindow := util.GetAvailableWindow(aiModel.TokenConfig) availableWindow := util.GetAvailableWindow(aiModel.TokenConfig)
return nil, fmt.Errorf("整体内容超出模型窗口大小限制(可用窗口=%d tokens),请精简后重试", availableWindow) return nil, fmt.Errorf("整体内容超出模型窗口大小限制(可用窗口=%d tokens),请精简后重试", availableWindow)
@@ -96,8 +75,7 @@ func compileToProviderRequest(ctx context.Context, ir *PromptIR, chatModel *gate
}, nil }, nil
} }
// promptBuildWithRounds 构建系统提示词 func promptBuildWithRounds(ctx context.Context, chatModel *gateway.AsynchModel, aiModel *gateway.AsynchModel) string {
func promptBuildWithRounds(ctx context.Context, req *dto.ComposeMessagesReq, chatModel *gateway.AsynchModel, aiModel *gateway.AsynchModel, batches int) string {
providerProtocol, err := dao.ProviderProtocol.Get(ctx, &entity.ProviderProtocol{ providerProtocol, err := dao.ProviderProtocol.Get(ctx, &entity.ProviderProtocol{
ProviderName: chatModel.OperatorName, ProviderName: chatModel.OperatorName,
Status: 1, Status: 1,
@@ -105,32 +83,9 @@ func promptBuildWithRounds(ctx context.Context, req *dto.ComposeMessagesReq, cha
if err != nil || providerProtocol == nil { if err != nil || providerProtocol == nil {
return "" return ""
} }
outputJSON := util.JSONPretty(util.ReverseMap(aiModel.RequestMapping, map[string]any{})) outputJSON := util.JSONPretty(util.ReverseMap(aiModel.RequestMapping, map[string]any{}))
maxWindowSize := util.GetMaxWindowSize(chatModel.TokenConfig)
availableWindow := util.GetAvailableWindow(chatModel.TokenConfig)
formContent := buildUserFormContent(req.Form)
userFormContent := buildUserFormContent(req.UserForm)
formInfo := fmt.Sprintf(`
【系统表单(系统提示词/参数)】
%s
【用户表单全文(必须完整阅读,全部作为用户提示词来源)】
%s
`, formContent, userFormContent)
inputInfo := fmt.Sprintf(`
目标模型: %s
%s
技能名称: %s
用户文件: %v
`, req.ModelName, formInfo, req.SkillName, req.Consult)
return fmt.Sprintf(providerProtocol.SystemPromptTemplate, return fmt.Sprintf(providerProtocol.SystemPromptTemplate,
req.ModelName, // %s 目标模型名称 outputJSON, //【输出结构】 %s
maxWindowSize, // %d 最大窗口
availableWindow, // %d 可用窗口
outputJSON, // %s 输出结构
inputInfo, // %s 完整输入信息
) )
} }
@@ -151,43 +106,67 @@ func checkOverallContent(ir *PromptIR, model *gateway.AsynchModel) bool {
// buildUserPrompt 构建用户提示词 // buildUserPrompt 构建用户提示词
func buildUserPrompt(ctx context.Context, req *dto.ComposeMessagesReq, prompt string) string { func buildUserPrompt(ctx context.Context, req *dto.ComposeMessagesReq, prompt string) string {
payload := UserPromptPayload{ var b strings.Builder
Model: req.ModelName, b.WriteString(fmt.Sprintf("目标模型:%s\n", req.ModelName))
PromptInfo: prompt, if prompt != "" {
Form: prepareUserFormPayload(req.Form), b.WriteString(fmt.Sprintf("系统提示词:%s\n", prompt))
UserForm: prepareUserFormPayload(req.UserForm),
Consult: req.Consult,
UserFilesText: ExtractFileTexts(ctx, req.Consult),
Skills: SkillMdContent(ctx, req.SkillName),
BuildType: req.BuildType,
} }
return gjson.New(payload).String() if skills := SkillMdContent(ctx, req.SkillName); skills != "" {
b.WriteString(fmt.Sprintf("技能内容:\n%s\n", skills))
}
if formText := buildUserFormText(req.Form); formText != "" {
b.WriteString(fmt.Sprintf("系统参数:\n%s\n", formText))
}
if userFormText := buildUserFormText(req.UserForm); userFormText != "" {
b.WriteString(fmt.Sprintf("用户需求:\n%s\n", userFormText))
}
if len(req.Consult) > 0 {
b.WriteString(fmt.Sprintf("参考附件:%s\n", gjson.New(req.Consult).String()))
}
if fileTexts := ExtractFileTexts(ctx, req.Consult); fileTexts != "" {
b.WriteString(fmt.Sprintf("附件内容:\n%s\n", fileTexts))
}
return b.String()
} }
// prepareUserFormPayload 准备用户表单载荷 // buildUserFormText 构建用户表单内容字符串
func prepareUserFormPayload(userForm []map[string]any) any { func buildUserFormText(form []map[string]any) string {
if len(userForm) == 0 { if len(form) == 0 {
return nil return ""
} }
if _, ok := userForm[0]["batch_index"]; ok {
return userForm
}
return mergeUserFormTexts(userForm)
}
// mergeUserFormTexts 合并 UserForm 中的所有文本内容
func mergeUserFormTexts(userForm []map[string]any) string {
var builder strings.Builder var builder strings.Builder
for i, item := range userForm { for _, item := range form {
text := getItemText(item) for k, v := range item {
if i > 0 { switch val := v.(type) {
builder.WriteString("\n\n") case []any:
// 数组类型:逐条列出
builder.WriteString(fmt.Sprintf("%s\n", k))
for i, elem := range val {
if m, ok := elem.(map[string]any); ok {
builder.WriteString(fmt.Sprintf(" %d. ", i+1))
for mk, mv := range m {
builder.WriteString(fmt.Sprintf("%s%v ", mk, mv))
} }
builder.WriteString(text) builder.WriteString("\n")
} else {
builder.WriteString(fmt.Sprintf(" %d. %v\n", i+1, elem))
} }
return builder.String() }
case []map[string]any:
builder.WriteString(fmt.Sprintf("%s\n", k))
for i, m := range val {
builder.WriteString(fmt.Sprintf(" %d. ", i+1))
for mk, mv := range m {
builder.WriteString(fmt.Sprintf("%s%v ", mk, mv))
}
builder.WriteString("\n")
}
default:
builder.WriteString(fmt.Sprintf("%s%v\n", k, v))
}
}
}
return strings.TrimSpace(builder.String())
} }
// NodeBuild 节点构建 // NodeBuild 节点构建

View File

@@ -16,6 +16,7 @@ import (
"gitea.com/red-future/common/beans" "gitea.com/red-future/common/beans"
"gitea.com/red-future/common/utils" "gitea.com/red-future/common/utils"
"github.com/gogf/gf/v2/frame/g" "github.com/gogf/gf/v2/frame/g"
"github.com/gogf/gf/v2/util/gconv"
) )
// ComposeMessages 核心拼接提示词主流程 // ComposeMessages 核心拼接提示词主流程
@@ -157,14 +158,14 @@ func handleCallbackFailed(ctx context.Context, req *dto.CallbackReq, composeTask
if composeTask.CallbackUrl != "" { if composeTask.CallbackUrl != "" {
composeTask.Status = public.ComposeStatusFailed composeTask.Status = public.ComposeStatusFailed
composeTask.ErrorMessage = req.ErrorMsg composeTask.ErrorMessage = req.ErrorMsg
_ = gateway.SendCallback(ctx, composeTask) _ = gateway.SendCallback(ctx, composeTask, 0)
} }
return err return err
} }
// handleCallbackSuccess 处理回调成功 // handleCallbackSuccess 处理回调成功
func handleCallbackSuccess(ctx context.Context, req *dto.CallbackReq, composeTask *entity.ComposeTask) error { func handleCallbackSuccess(ctx context.Context, req *dto.CallbackReq, composeTask *entity.ComposeTask) error {
// 1) 模型配置 // 1) 获取模型配置
model, err := gateway.GetModelConfig(ctx, &gateway.AsynchModel{ model, err := gateway.GetModelConfig(ctx, &gateway.AsynchModel{
SQLBaseDO: beans.SQLBaseDO{Creator: composeTask.Creator}, SQLBaseDO: beans.SQLBaseDO{Creator: composeTask.Creator},
ModelName: composeTask.ModelName, ModelName: composeTask.ModelName,
@@ -172,6 +173,10 @@ func handleCallbackSuccess(ctx context.Context, req *dto.CallbackReq, composeTas
if err != nil { if err != nil {
return fmt.Errorf("查询模型失败: %w", err) return fmt.Errorf("查询模型失败: %w", err)
} }
// 2) 根据运营商获取协议配置
//protocol, err := dao.ProviderProtocol.Get(ctx, &entity.ProviderProtocol{
// ProviderName: model.OperatorName,
//})
// 2) 解析结果 // 2) 解析结果
var messages map[string]any var messages map[string]any
@@ -179,10 +184,7 @@ func handleCallbackSuccess(ctx context.Context, req *dto.CallbackReq, composeTas
case public.BuildTypePrompt, public.BuildTypeNode: case public.BuildTypePrompt, public.BuildTypeNode:
messages = ParseResult(req.Messages, model.ResponseBody) messages = ParseResult(req.Messages, model.ResponseBody)
case public.BuildTypeStruct: case public.BuildTypeStruct:
messages = map[string]any{ messages = ParseStructResult(req.Messages, model.ResponseBody)
"total_rounds": 1,
"rounds": []map[string]any{req.Messages},
}
default: default:
messages = req.Messages messages = req.Messages
} }
@@ -201,80 +203,113 @@ func handleCallbackSuccess(ctx context.Context, req *dto.CallbackReq, composeTas
if err != nil { if err != nil {
return err return err
} }
// 5) 存储历史结果 // 5) 存储提示词结果作为历史请求
var epicycleId int64
// 6) 回调业务方 payload := composeTask.RequestPayload
sessionId := gconv.String(payload["sessionId"])
nodeId := gconv.String(payload["nodeId"])
buildType := gconv.Int(payload["buildType"])
if buildType == public.BuildTypePrompt && sessionId != "" && nodeId != "" {
epicycleId, err = dao.ComposeSession.Insert(ctx, &entity.ComposeSession{
NodeId: nodeId,
SessionId: sessionId,
RequestContent: messages,
})
}
// 6) 拼接历史内容
// 7) 回调业务方
if composeTask.CallbackUrl != "" { if composeTask.CallbackUrl != "" {
composeTask.Status = public.ComposeStatusSuccess composeTask.Status = public.ComposeStatusSuccess
composeTask.Messages = messages composeTask.Messages = messages
_ = gateway.SendCallback(ctx, composeTask) _ = gateway.SendCallback(ctx, composeTask, epicycleId)
} }
return nil return nil
} }
// ParseResult 解析回调结果 // ParseResult 解析结果
func ParseResult(raw map[string]any, responseBody string) map[string]any { func ParseResult(raw map[string]any, responseBody string) map[string]any {
// responseBody 为空,直接返回原始数据
if responseBody == "" { if responseBody == "" {
return raw return raw
} }
// 按 responseBody 路径取值 contentVal := raw[responseBody]
contentStr, ok := raw[responseBody].(string) if contentVal == nil {
if !ok || contentStr == "" {
return raw return raw
} }
if roundsArray := tryParseAsMapArray(contentStr); roundsArray != nil { // 已经是数组
return map[string]any{ if arr, ok := contentVal.([]any); ok {
"total_rounds": len(roundsArray), rounds := gconv.Maps(arr)
"rounds": roundsArray, if len(rounds) > 0 {
return map[string]any{"total_rounds": len(rounds), "rounds": rounds}
} }
return raw
} }
if singleRound := tryParseAsMap(contentStr); singleRound != nil { // 是字符串
return map[string]any{ contentStr := gconv.String(contentVal)
"total_rounds": 1, if contentStr == "" {
"rounds": []map[string]any{singleRound}, return raw
} }
// 尝试解析为数组
var arr []map[string]any
if err := json.Unmarshal([]byte(contentStr), &arr); err == nil && len(arr) > 0 {
return map[string]any{"total_rounds": len(arr), "rounds": arr}
}
// 尝试解析为单对象
var obj map[string]any
if err := json.Unmarshal([]byte(contentStr), &obj); err == nil && len(obj) > 0 {
return map[string]any{"total_rounds": 1, "rounds": []map[string]any{obj}}
} }
return map[string]any{"content": contentStr} return map[string]any{"content": contentStr}
} }
func tryParseAsMapArray(jsonStr string) []map[string]any { func ParseStructResult(raw map[string]any, responseBody string) map[string]any {
var arr []map[string]any // 如果外层已有 rounds直接返回
if err := json.Unmarshal([]byte(jsonStr), &arr); err != nil { if _, ok := raw["rounds"]; ok {
return nil return raw
}
if len(arr) == 0 {
return nil
}
return arr
} }
func tryParseAsMap(jsonStr string) map[string]any { contentVal := raw[responseBody]
var obj map[string]any if contentVal == nil {
if err := json.Unmarshal([]byte(jsonStr), &obj); err != nil {
return nil
}
if len(obj) == 0 {
return nil
}
return obj
}
func ParseNodeResult(raw map[string]any) map[string]any {
contentStr, ok := raw["content"].(string)
if ok && contentStr != "" {
var inner map[string]any
if err := json.Unmarshal([]byte(contentStr), &inner); err == nil {
return map[string]any{ return map[string]any{
"total_rounds": 1, "total_rounds": 1,
"rounds": []map[string]any{inner}, "rounds": []map[string]any{raw},
} }
} }
contentStr := gconv.String(contentVal)
if contentStr == "" || contentStr == "0" {
return map[string]any{
"total_rounds": 1,
"rounds": []map[string]any{raw},
} }
}
// 尝试解析为数组
var arr []map[string]any
if err := json.Unmarshal([]byte(contentStr), &arr); err == nil && len(arr) > 0 {
// 数组的每个元素包一层 content
var rounds []map[string]any
for _, item := range arr {
rounds = append(rounds, map[string]any{"content": item})
}
return map[string]any{
"total_rounds": len(rounds),
"rounds": rounds,
}
}
// 尝试解析为单个对象
var parsed map[string]any
if err := json.Unmarshal([]byte(contentStr), &parsed); err == nil {
raw[responseBody] = parsed
}
// 兜底:包标准结构
return map[string]any{ return map[string]any{
"total_rounds": 1, "total_rounds": 1,
"rounds": []map[string]any{raw}, "rounds": []map[string]any{raw},

View File

@@ -22,26 +22,25 @@ const (
bytesPerMB = 1024 * 1024 bytesPerMB = 1024 * 1024
) )
// ExtractFileTexts 从 ConsultItem 列表中提取文件内容 // ExtractFileTexts 从 ConsultItem 列表中提取文件内容,返回拼接文本
func ExtractFileTexts(ctx context.Context, consult []dto.ConsultItem) map[string]string { func ExtractFileTexts(ctx context.Context, consult []dto.ConsultItem) string {
urls := make([]string, 0, len(consult)) urls := make([]string, 0, len(consult))
for _, item := range consult { for _, item := range consult {
if item.Url != "" { if item.Url != "" {
urls = append(urls, item.Url) urls = append(urls, item.Url)
} }
} }
return FetchFileTexts(ctx, urls) return FetchFileTextsAsString(ctx, urls)
} }
// FetchFileTexts 从 URL 列表获取文件内容,支持 zip 内文件 // FetchFileTextsAsString 从 URL 列表获取文件内容,拼接为字符串
func FetchFileTexts(ctx context.Context, urls []string) map[string]string { func FetchFileTextsAsString(ctx context.Context, urls []string) string {
result := make(map[string]string)
if len(urls) == 0 { if len(urls) == 0 {
return result return ""
} }
client := createHTTPClient(ctx, "userFiles.httpTimeoutSec", 8) client := createHTTPClient(ctx, "userFiles.httpTimeoutSec", 8)
var builder strings.Builder
for _, rawURL := range urls { for _, rawURL := range urls {
url := util.SanitizeURL(rawURL) url := util.SanitizeURL(rawURL)
@@ -50,23 +49,19 @@ func FetchFileTexts(ctx context.Context, urls []string) map[string]string {
} }
if util.IsZipExtension(url) { if util.IsZipExtension(url) {
mergeMap(result, fetchZipFileTexts(ctx, client, url)) for _, text := range fetchZipFileTexts(ctx, client, url) {
builder.WriteString(text)
builder.WriteString("\n")
}
continue continue
} }
if text := fetchAndCleanFileContent(ctx, client, url); text != "" { if text := fetchAndCleanFileContent(ctx, client, url); text != "" {
result[url] = text builder.WriteString(fmt.Sprintf("【文件:%s】\n%s\n", url, text))
} }
} }
return result return builder.String()
}
// mergeMap 合并 map
func mergeMap(dst, src map[string]string) {
for k, v := range src {
dst[k] = v
}
} }
// fetchAndCleanFileContent 获取并清理文件内容 // fetchAndCleanFileContent 获取并清理文件内容