From de70d331151d5e3999f3a41ca938728161e6b400 Mon Sep 17 00:00:00 2001 From: WangLiZhao <1838393649@qq.com> Date: Fri, 5 Jun 2026 11:00:05 +0800 Subject: [PATCH] =?UTF-8?q?refactor(prompt):=20=E9=87=8D=E6=9E=84=E6=8F=90?= =?UTF-8?q?=E7=A4=BA=E8=AF=8D=E6=9E=84=E5=BB=BA=E6=9C=8D=E5=8A=A1=E5=92=8C?= =?UTF-8?q?=E5=9B=9E=E8=B0=83=E5=A4=84=E7=90=86?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- model/dto/prompt_compose_dto.go | 3 +- service/gateway/gateway_http_service.go | 11 +- service/prompt/prompt_build_service.go | 137 +++++++---------- service/prompt/prompt_compose_service.go | 141 +++++++++++------- service/prompt/prompt_files_handle_service.go | 31 ++-- 5 files changed, 166 insertions(+), 157 deletions(-) diff --git a/model/dto/prompt_compose_dto.go b/model/dto/prompt_compose_dto.go index 4ffded8..fd41d55 100644 --- a/model/dto/prompt_compose_dto.go +++ b/model/dto/prompt_compose_dto.go @@ -22,8 +22,7 @@ type ConsultItem struct { Url string `json:"url" dc:"附件地址"` } type ComposeMessagesRes struct { - TaskId string `json:"taskId" dc:"任务ID"` - EpicycleId int64 `json:"epicycle_id" dc:"轮次ID"` + TaskId string `json:"taskId" dc:"任务ID"` } // MultiRoundResult 多轮返回结果 diff --git a/service/gateway/gateway_http_service.go b/service/gateway/gateway_http_service.go index 6768d05..0f3c7fd 100644 --- a/service/gateway/gateway_http_service.go +++ b/service/gateway/gateway_http_service.go @@ -156,17 +156,18 @@ type SendCallbackReq struct { } // SendCallback 向业务方发送回调 -func SendCallback(ctx context.Context, composeTask *entity.ComposeTask) error { +func SendCallback(ctx context.Context, composeTask *entity.ComposeTask, epicycleId int64) error { // 1. 检查回调地址 if composeTask.CallbackUrl == "" { return fmt.Errorf("回调地址为空,taskId=%s", composeTask.TaskId) } // 2. 构造请求体 req := SendCallbackReq{ - TaskId: composeTask.TaskId, - Status: composeTask.Status, - Messages: composeTask.Messages, - ErrorMsg: composeTask.ErrorMessage, + TaskId: composeTask.TaskId, + Status: composeTask.Status, + Messages: composeTask.Messages, + ErrorMsg: composeTask.ErrorMessage, + EpicycleId: epicycleId, } // 3. 发送 POST 请求 headers := util.ForwardHeaders(ctx) diff --git a/service/prompt/prompt_build_service.go b/service/prompt/prompt_build_service.go index b08556c..65a10b8 100644 --- a/service/prompt/prompt_build_service.go +++ b/service/prompt/prompt_build_service.go @@ -4,7 +4,6 @@ import ( "context" "fmt" "prompts-core/service/gateway" - "prompts-core/service/session" "strings" "prompts-core/common/util" @@ -17,34 +16,14 @@ import ( "github.com/gogf/gf/v2/util/gconv" ) -// UserPromptPayload 用户提示词请求体 -type UserPromptPayload struct { - Model string `json:"model"` - PromptInfo string `json:"promptInfo"` - Form any `json:"form"` - UserForm any `json:"userForm"` - Consult []dto.ConsultItem `json:"consult"` - UserFilesText map[string]string `json:"userFilesText"` - Skills string `json:"skills"` - BuildType int `json:"buildType"` -} - // buildPromptTypeRequest 构建提示词类型请求(BuildType=1) func buildPromptTypeRequest(ctx context.Context, req *dto.ComposeMessagesReq, aiModel *gateway.AsynchModel, chatModel *gateway.AsynchModel, ir *PromptIR, totalBatches int) (map[string]any, error) { //1) 构建系统提示词 - systemPrompt := promptBuildWithRounds(ctx, req, chatModel, aiModel, totalBatches) + systemPrompt := promptBuildWithRounds(ctx, chatModel, aiModel) ir.AddSystem(systemPrompt) - //2) 构建历史对话 - history, _ := session.GetHistoryMessages(ctx, req.SessionId) - for _, msg := range history { - role := gconv.String(msg["role"]) - if role != "user" && role != "assistant" { - continue - } - ir.AddHistory(role, gconv.String(msg["content"])) - } userPrompt := buildUserPrompt(ctx, req, util.GetModelPrompt(ctx, aiModel.ModelType)) ir.AddUser(userPrompt) + //2) 检查整体内容是否超出窗口 if !checkOverallContent(ir, aiModel) { availableWindow := util.GetAvailableWindow(aiModel.TokenConfig) return nil, fmt.Errorf("整体内容超出模型窗口大小限制(可用窗口=%d tokens),请精简后重试", availableWindow) @@ -96,8 +75,7 @@ func compileToProviderRequest(ctx context.Context, ir *PromptIR, chatModel *gate }, nil } -// promptBuildWithRounds 构建系统提示词 -func promptBuildWithRounds(ctx context.Context, req *dto.ComposeMessagesReq, chatModel *gateway.AsynchModel, aiModel *gateway.AsynchModel, batches int) string { +func promptBuildWithRounds(ctx context.Context, chatModel *gateway.AsynchModel, aiModel *gateway.AsynchModel) string { providerProtocol, err := dao.ProviderProtocol.Get(ctx, &entity.ProviderProtocol{ ProviderName: chatModel.OperatorName, Status: 1, @@ -105,32 +83,9 @@ func promptBuildWithRounds(ctx context.Context, req *dto.ComposeMessagesReq, cha if err != nil || providerProtocol == nil { return "" } - outputJSON := util.JSONPretty(util.ReverseMap(aiModel.RequestMapping, map[string]any{})) - maxWindowSize := util.GetMaxWindowSize(chatModel.TokenConfig) - availableWindow := util.GetAvailableWindow(chatModel.TokenConfig) - formContent := buildUserFormContent(req.Form) - userFormContent := buildUserFormContent(req.UserForm) - formInfo := fmt.Sprintf(` -【系统表单(系统提示词/参数)】 -%s -【用户表单全文(必须完整阅读,全部作为用户提示词来源)】 -%s -`, formContent, userFormContent) - - inputInfo := fmt.Sprintf(` -目标模型: %s -%s -技能名称: %s -用户文件: %v -`, req.ModelName, formInfo, req.SkillName, req.Consult) - return fmt.Sprintf(providerProtocol.SystemPromptTemplate, - req.ModelName, // %s 目标模型名称 - maxWindowSize, // %d 最大窗口 - availableWindow, // %d 可用窗口 - outputJSON, // %s 输出结构 - inputInfo, // %s 完整输入信息 + outputJSON, //【输出结构】 %s ) } @@ -151,43 +106,67 @@ func checkOverallContent(ir *PromptIR, model *gateway.AsynchModel) bool { // buildUserPrompt 构建用户提示词 func buildUserPrompt(ctx context.Context, req *dto.ComposeMessagesReq, prompt string) string { - payload := UserPromptPayload{ - Model: req.ModelName, - PromptInfo: prompt, - Form: prepareUserFormPayload(req.Form), - UserForm: prepareUserFormPayload(req.UserForm), - Consult: req.Consult, - UserFilesText: ExtractFileTexts(ctx, req.Consult), - Skills: SkillMdContent(ctx, req.SkillName), - BuildType: req.BuildType, + var b strings.Builder + b.WriteString(fmt.Sprintf("目标模型:%s\n", req.ModelName)) + if prompt != "" { + b.WriteString(fmt.Sprintf("系统提示词:%s\n", prompt)) } - return gjson.New(payload).String() + if skills := SkillMdContent(ctx, req.SkillName); skills != "" { + b.WriteString(fmt.Sprintf("技能内容:\n%s\n", skills)) + } + if formText := buildUserFormText(req.Form); formText != "" { + b.WriteString(fmt.Sprintf("系统参数:\n%s\n", formText)) + } + if userFormText := buildUserFormText(req.UserForm); userFormText != "" { + b.WriteString(fmt.Sprintf("用户需求:\n%s\n", userFormText)) + } + if len(req.Consult) > 0 { + b.WriteString(fmt.Sprintf("参考附件:%s\n", gjson.New(req.Consult).String())) + } + if fileTexts := ExtractFileTexts(ctx, req.Consult); fileTexts != "" { + b.WriteString(fmt.Sprintf("附件内容:\n%s\n", fileTexts)) + } + return b.String() } -// prepareUserFormPayload 准备用户表单载荷 -func prepareUserFormPayload(userForm []map[string]any) any { - if len(userForm) == 0 { - return nil +// buildUserFormText 构建用户表单内容字符串 +func buildUserFormText(form []map[string]any) string { + if len(form) == 0 { + return "" } - - if _, ok := userForm[0]["batch_index"]; ok { - return userForm - } - - return mergeUserFormTexts(userForm) -} - -// mergeUserFormTexts 合并 UserForm 中的所有文本内容 -func mergeUserFormTexts(userForm []map[string]any) string { var builder strings.Builder - for i, item := range userForm { - text := getItemText(item) - if i > 0 { - builder.WriteString("\n\n") + for _, item := range form { + for k, v := range item { + switch val := v.(type) { + case []any: + // 数组类型:逐条列出 + builder.WriteString(fmt.Sprintf("%s:\n", k)) + for i, elem := range val { + if m, ok := elem.(map[string]any); ok { + builder.WriteString(fmt.Sprintf(" %d. ", i+1)) + for mk, mv := range m { + builder.WriteString(fmt.Sprintf("%s:%v ", mk, mv)) + } + builder.WriteString("\n") + } else { + builder.WriteString(fmt.Sprintf(" %d. %v\n", i+1, elem)) + } + } + case []map[string]any: + builder.WriteString(fmt.Sprintf("%s:\n", k)) + for i, m := range val { + builder.WriteString(fmt.Sprintf(" %d. ", i+1)) + for mk, mv := range m { + builder.WriteString(fmt.Sprintf("%s:%v ", mk, mv)) + } + builder.WriteString("\n") + } + default: + builder.WriteString(fmt.Sprintf("%s:%v\n", k, v)) + } } - builder.WriteString(text) } - return builder.String() + return strings.TrimSpace(builder.String()) } // NodeBuild 节点构建 diff --git a/service/prompt/prompt_compose_service.go b/service/prompt/prompt_compose_service.go index a6ab82e..d0032fb 100644 --- a/service/prompt/prompt_compose_service.go +++ b/service/prompt/prompt_compose_service.go @@ -16,6 +16,7 @@ import ( "gitea.com/red-future/common/beans" "gitea.com/red-future/common/utils" "github.com/gogf/gf/v2/frame/g" + "github.com/gogf/gf/v2/util/gconv" ) // ComposeMessages 核心拼接提示词主流程 @@ -157,14 +158,14 @@ func handleCallbackFailed(ctx context.Context, req *dto.CallbackReq, composeTask if composeTask.CallbackUrl != "" { composeTask.Status = public.ComposeStatusFailed composeTask.ErrorMessage = req.ErrorMsg - _ = gateway.SendCallback(ctx, composeTask) + _ = gateway.SendCallback(ctx, composeTask, 0) } return err } // handleCallbackSuccess 处理回调成功 func handleCallbackSuccess(ctx context.Context, req *dto.CallbackReq, composeTask *entity.ComposeTask) error { - // 1) 查模型配置 + // 1) 获取模型配置 model, err := gateway.GetModelConfig(ctx, &gateway.AsynchModel{ SQLBaseDO: beans.SQLBaseDO{Creator: composeTask.Creator}, ModelName: composeTask.ModelName, @@ -172,6 +173,10 @@ func handleCallbackSuccess(ctx context.Context, req *dto.CallbackReq, composeTas if err != nil { return fmt.Errorf("查询模型失败: %w", err) } + // 2) 根据运营商获取协议配置 + //protocol, err := dao.ProviderProtocol.Get(ctx, &entity.ProviderProtocol{ + // ProviderName: model.OperatorName, + //}) // 2) 解析结果 var messages map[string]any @@ -179,10 +184,7 @@ func handleCallbackSuccess(ctx context.Context, req *dto.CallbackReq, composeTas case public.BuildTypePrompt, public.BuildTypeNode: messages = ParseResult(req.Messages, model.ResponseBody) case public.BuildTypeStruct: - messages = map[string]any{ - "total_rounds": 1, - "rounds": []map[string]any{req.Messages}, - } + messages = ParseStructResult(req.Messages, model.ResponseBody) default: messages = req.Messages } @@ -201,80 +203,113 @@ func handleCallbackSuccess(ctx context.Context, req *dto.CallbackReq, composeTas if err != nil { return err } - // 5) 存储历史结果 - - // 6) 回调业务方 + // 5) 存储提示词结果作为历史请求 + var epicycleId int64 + payload := composeTask.RequestPayload + sessionId := gconv.String(payload["sessionId"]) + nodeId := gconv.String(payload["nodeId"]) + buildType := gconv.Int(payload["buildType"]) + if buildType == public.BuildTypePrompt && sessionId != "" && nodeId != "" { + epicycleId, err = dao.ComposeSession.Insert(ctx, &entity.ComposeSession{ + NodeId: nodeId, + SessionId: sessionId, + RequestContent: messages, + }) + } + // 6) 拼接历史内容 + // 7) 回调业务方 if composeTask.CallbackUrl != "" { composeTask.Status = public.ComposeStatusSuccess composeTask.Messages = messages - _ = gateway.SendCallback(ctx, composeTask) + _ = gateway.SendCallback(ctx, composeTask, epicycleId) } return nil } -// ParseResult 解析回调结果 +// ParseResult 解析结果 func ParseResult(raw map[string]any, responseBody string) map[string]any { - // responseBody 为空,直接返回原始数据 if responseBody == "" { return raw } - // 按 responseBody 路径取值 - contentStr, ok := raw[responseBody].(string) - if !ok || contentStr == "" { + contentVal := raw[responseBody] + if contentVal == nil { return raw } - if roundsArray := tryParseAsMapArray(contentStr); roundsArray != nil { - return map[string]any{ - "total_rounds": len(roundsArray), - "rounds": roundsArray, + // 已经是数组 + if arr, ok := contentVal.([]any); ok { + rounds := gconv.Maps(arr) + if len(rounds) > 0 { + return map[string]any{"total_rounds": len(rounds), "rounds": rounds} } + return raw } - if singleRound := tryParseAsMap(contentStr); singleRound != nil { - return map[string]any{ - "total_rounds": 1, - "rounds": []map[string]any{singleRound}, - } + // 是字符串 + contentStr := gconv.String(contentVal) + if contentStr == "" { + return raw + } + + // 尝试解析为数组 + var arr []map[string]any + if err := json.Unmarshal([]byte(contentStr), &arr); err == nil && len(arr) > 0 { + return map[string]any{"total_rounds": len(arr), "rounds": arr} + } + + // 尝试解析为单对象 + var obj map[string]any + if err := json.Unmarshal([]byte(contentStr), &obj); err == nil && len(obj) > 0 { + return map[string]any{"total_rounds": 1, "rounds": []map[string]any{obj}} } return map[string]any{"content": contentStr} } -func tryParseAsMapArray(jsonStr string) []map[string]any { - var arr []map[string]any - if err := json.Unmarshal([]byte(jsonStr), &arr); err != nil { - return nil +func ParseStructResult(raw map[string]any, responseBody string) map[string]any { + // 如果外层已有 rounds,直接返回 + if _, ok := raw["rounds"]; ok { + return raw } - if len(arr) == 0 { - return nil - } - return arr -} -func tryParseAsMap(jsonStr string) map[string]any { - var obj map[string]any - if err := json.Unmarshal([]byte(jsonStr), &obj); err != nil { - return nil - } - if len(obj) == 0 { - return nil - } - return obj -} - -func ParseNodeResult(raw map[string]any) map[string]any { - contentStr, ok := raw["content"].(string) - if ok && contentStr != "" { - var inner map[string]any - if err := json.Unmarshal([]byte(contentStr), &inner); err == nil { - return map[string]any{ - "total_rounds": 1, - "rounds": []map[string]any{inner}, - } + contentVal := raw[responseBody] + if contentVal == nil { + return map[string]any{ + "total_rounds": 1, + "rounds": []map[string]any{raw}, } } + + contentStr := gconv.String(contentVal) + if contentStr == "" || contentStr == "0" { + return map[string]any{ + "total_rounds": 1, + "rounds": []map[string]any{raw}, + } + } + + // 尝试解析为数组 + var arr []map[string]any + if err := json.Unmarshal([]byte(contentStr), &arr); err == nil && len(arr) > 0 { + // 数组的每个元素包一层 content + var rounds []map[string]any + for _, item := range arr { + rounds = append(rounds, map[string]any{"content": item}) + } + return map[string]any{ + "total_rounds": len(rounds), + "rounds": rounds, + } + } + + // 尝试解析为单个对象 + var parsed map[string]any + if err := json.Unmarshal([]byte(contentStr), &parsed); err == nil { + raw[responseBody] = parsed + } + + // 兜底:包标准结构 return map[string]any{ "total_rounds": 1, "rounds": []map[string]any{raw}, diff --git a/service/prompt/prompt_files_handle_service.go b/service/prompt/prompt_files_handle_service.go index 6e514de..010e99d 100644 --- a/service/prompt/prompt_files_handle_service.go +++ b/service/prompt/prompt_files_handle_service.go @@ -22,26 +22,25 @@ const ( bytesPerMB = 1024 * 1024 ) -// ExtractFileTexts 从 ConsultItem 列表中提取文件内容 -func ExtractFileTexts(ctx context.Context, consult []dto.ConsultItem) map[string]string { +// ExtractFileTexts 从 ConsultItem 列表中提取文件内容,返回拼接文本 +func ExtractFileTexts(ctx context.Context, consult []dto.ConsultItem) string { urls := make([]string, 0, len(consult)) for _, item := range consult { if item.Url != "" { urls = append(urls, item.Url) } } - return FetchFileTexts(ctx, urls) + return FetchFileTextsAsString(ctx, urls) } -// FetchFileTexts 从 URL 列表获取文件内容,支持 zip 内文件 -func FetchFileTexts(ctx context.Context, urls []string) map[string]string { - result := make(map[string]string) - +// FetchFileTextsAsString 从 URL 列表获取文件内容,拼接为字符串 +func FetchFileTextsAsString(ctx context.Context, urls []string) string { if len(urls) == 0 { - return result + return "" } client := createHTTPClient(ctx, "userFiles.httpTimeoutSec", 8) + var builder strings.Builder for _, rawURL := range urls { url := util.SanitizeURL(rawURL) @@ -50,23 +49,19 @@ func FetchFileTexts(ctx context.Context, urls []string) map[string]string { } if util.IsZipExtension(url) { - mergeMap(result, fetchZipFileTexts(ctx, client, url)) + for _, text := range fetchZipFileTexts(ctx, client, url) { + builder.WriteString(text) + builder.WriteString("\n") + } continue } if text := fetchAndCleanFileContent(ctx, client, url); text != "" { - result[url] = text + builder.WriteString(fmt.Sprintf("【文件:%s】\n%s\n", url, text)) } } - return result -} - -// mergeMap 合并 map -func mergeMap(dst, src map[string]string) { - for k, v := range src { - dst[k] = v - } + return builder.String() } // fetchAndCleanFileContent 获取并清理文件内容