diff --git a/common/util/json.go b/common/util/json.go deleted file mode 100644 index 2be1df7..0000000 --- a/common/util/json.go +++ /dev/null @@ -1,81 +0,0 @@ -package util - -import ( - "fmt" - - "github.com/gogf/gf/v2/encoding/gjson" - "github.com/gogf/gf/v2/util/gconv" -) - -// MergeConsult 将 consult 附件合并到模型生成的 messages 结构中 -func MergeConsult(req map[string]any, messages map[string]any, extendMapping map[string]any) map[string]any { - if len(req) == 0 || len(messages) == 0 || len(extendMapping) == 0 { - return messages - } - - consult := gconv.Interfaces(req["consult"]) - if len(consult) == 0 { - return messages - } - - targetPath := gconv.String(extendMapping["target_content_path"]) - templates := gconv.Map(extendMapping["attachment_templates"]) - if targetPath == "" || len(templates) == 0 { - return messages - } - - msgJson := gjson.New(messages) - - // rounds 路径修正 - if !msgJson.Get("rounds.0").IsNil() { - targetPath = "rounds.0." + targetPath - } - - // 遍历追加 - for _, item := range consult { - itemJson := gjson.New(item) - itemType := itemJson.Get("type").String() - tmpl := gconv.Map(templates[itemType]) - if itemType == "" || len(tmpl) == 0 { - continue - } - - attachment := buildAttachment(tmpl, itemJson.Get("url").String()) - if attachment == nil { - continue - } - - idx := len(msgJson.Get(targetPath).Array()) - _ = msgJson.Set(fmt.Sprintf("%s.%d", targetPath, idx), attachment) - } - - return msgJson.Map() -} - -func buildAttachment(tmpl map[string]any, url string) map[string]any { - typ := gconv.String(tmpl["type"]) - if typ == "" || url == "" { - return nil - } - - body := gconv.Map(tmpl["body"]) - fillEmptyInPlace(body, url) - - return map[string]any{ - "type": typ, - typ: body, - } -} - -func fillEmptyInPlace(m map[string]any, value string) { - for k, v := range m { - switch vv := v.(type) { - case string: - if vv == "" { - m[k] = value - } - case map[string]any: - fillEmptyInPlace(vv, value) - } - } -} diff --git a/common/util/mapping.go b/common/util/mapping.go index b8ade11..a752c47 100644 --- a/common/util/mapping.go +++ b/common/util/mapping.go @@ -1,13 +1,16 @@ package util import ( + "fmt" "strings" "github.com/gogf/gf/v2/encoding/gjson" "github.com/gogf/gf/v2/util/gconv" ) -// ReverseMap 映射 payload 到 mapping +// ======================== 请求映射 ======================== + +// ReverseMap 将 payload 按 mapping 路径映射为嵌套结构 func ReverseMap(mapping map[string]any, payload map[string]any) map[string]any { jsonObj := gjson.New("{}") for path, defaultValue := range mapping { @@ -21,6 +24,8 @@ func ReverseMap(mapping map[string]any, payload map[string]any) map[string]any { return jsonObj.Map() } +// ======================== 用户文本提取 ======================== + // ExtractUserText 从 messages 中提取所有 user 文本 func ExtractUserText(messages map[string]any) map[string]any { msgJson := gjson.New(messages) @@ -29,6 +34,7 @@ func ExtractUserText(messages map[string]any) map[string]any { if msgs.IsNil() { msgs = msgJson.Get("messages") } + var texts []string for _, m := range msgs.Array() { msg := gjson.New(m) @@ -55,3 +61,128 @@ func ExtractUserText(messages map[string]any) map[string]any { "content": strings.Join(texts, "\n"), } } + +// ======================== 附件合并 ======================== + +// MergeConsult 将 consult 附件合并到每个 round 的 content 数组中 +func MergeConsult(req map[string]any, messages map[string]any, extendMapping map[string]any) map[string]any { + if len(req) == 0 || len(messages) == 0 || len(extendMapping) == 0 { + return messages + } + + consult := gconv.Interfaces(req["consult"]) + if len(consult) == 0 { + return messages + } + + targetPath := gconv.String(extendMapping["target_content_path"]) + templates := gconv.Map(extendMapping["attachment_templates"]) + if targetPath == "" || len(templates) == 0 { + return messages + } + + msgJson := gjson.New(messages) + + rounds := msgJson.Get("rounds").Array() + for i := range rounds { + roundPath := fmt.Sprintf("rounds.%d.%s", i, targetPath) + for _, item := range consult { + itemJson := gjson.New(item) + itemType := itemJson.Get("type").String() + tmpl := gconv.Map(templates[itemType]) + if itemType == "" || len(tmpl) == 0 { + continue + } + + attachment := buildAttachment(tmpl, itemJson.Get("url").String()) + if attachment == nil { + continue + } + + idx := len(msgJson.Get(roundPath).Array()) + _ = msgJson.Set(fmt.Sprintf("%s.%d", roundPath, idx), attachment) + } + } + + return msgJson.Map() +} + +// buildAttachment 根据模板和 url 生成附件对象 +func buildAttachment(tmpl map[string]any, url string) map[string]any { + typ := gconv.String(tmpl["type"]) + if typ == "" || url == "" { + return nil + } + + body := gconv.Map(tmpl["body"]) + fillEmptyInPlace(body, url) + + return map[string]any{ + "type": typ, + typ: body, + } +} + +// fillEmptyInPlace 递归填充空字符串 +func fillEmptyInPlace(m map[string]any, value string) { + for k, v := range m { + switch vv := v.(type) { + case string: + if vv == "" { + m[k] = value + } + case map[string]any: + fillEmptyInPlace(vv, value) + } + } +} + +// ======================== 系统提示词合并 ======================== + +// MergeSystemPrompt 将系统提示词和技能内容拼接到 system role 的 content 中 +func MergeSystemPrompt(messages map[string]any, prompt, skills string, requestMapping map[string]any) map[string]any { + var parts []string + if prompt != "" { + parts = append(parts, prompt) + } + if skills != "" { + parts = append(parts, skills) + } + if len(parts) == 0 { + return messages + } + + systemContent := strings.Join(parts, "\n") + systemPath := getSystemPromptPath(requestMapping) + if systemPath == "" { + return messages + } + + msgJson := gjson.New(messages) + + existing := msgJson.Get(systemPath).String() + if existing != "" { + systemContent = existing + "\n" + systemContent + } + _ = msgJson.Set(systemPath, systemContent) + + return msgJson.Map() +} + +// getSystemPromptPath 从 RequestMapping 中提取 system content 的路径 +func getSystemPromptPath(requestMapping map[string]any) string { + for key, val := range requestMapping { + if !strings.Contains(key, ".role") { + continue + } + if gconv.String(val) != "system" { + continue + } + prefix := strings.TrimSuffix(key, ".role") + contentKey := prefix + ".content" + if _, ok := requestMapping[contentKey]; ok { + return contentKey + } + } + return "" +} diff --git a/service/prompt/prompt_build_service.go b/service/prompt/prompt_build_service.go index 937d95b..284db31 100644 --- a/service/prompt/prompt_build_service.go +++ b/service/prompt/prompt_build_service.go @@ -20,7 +20,7 @@ func buildPromptTypeRequest(ctx context.Context, req *dto.ComposeMessagesReq, ai //1) 构建系统提示词 systemPrompt := promptBuildWithRounds(ctx, chatModel, aiModel) ir.AddSystem(systemPrompt) - userPrompt := buildUserPrompt(ctx, req, util.GetModelPrompt(ctx, aiModel.ModelType)) + userPrompt := buildUserPrompt(ctx, req) ir.AddUser(userPrompt) //2) 检查整体内容是否超出窗口 if !checkOverallContent(ir, aiModel) { @@ -40,7 +40,7 @@ func buildNodeTypeRequest(ctx context.Context, req *dto.ComposeMessagesReq, chat func buildStructTypeRequest(ctx context.Context, req *dto.ComposeMessagesReq, chatModel *gateway.AsynchModel, ir *IR) (map[string]any, error) { customPrompt := gjson.New(req.UserForm).Get("0.prompt").String() ir.AddSystem(customPrompt) - ir.AddUser(buildUserPrompt(ctx, req, "")) + ir.AddUser(buildUserPrompt(ctx, req)) return compileToProviderRequest(ctx, ir, chatModel, req, customPrompt) } @@ -81,9 +81,14 @@ func promptBuildWithRounds(ctx context.Context, chatModel *gateway.AsynchModel, if err != nil || providerProtocol == nil { return "" } - outputJSON := gjson.New(util.ReverseMap(aiModel.RequestMapping, map[string]any{})).MustToJsonString() - return fmt.Sprintf(providerProtocol.SystemPromptTemplate, outputJSON) + outputJSON := gjson.New(util.ReverseMap(aiModel.RequestMapping, map[string]any{ + "model": aiModel.ModelName, + })).MustToJsonString() + + return fmt.Sprintf(providerProtocol.SystemPromptTemplate, + outputJSON, //%s【输出结构】 + ) } // checkOverallContent 检查整体内容是否超出窗口 @@ -93,15 +98,8 @@ func checkOverallContent(ir *IR, model *gateway.AsynchModel) bool { } // buildUserPrompt 构建用户提示词 -func buildUserPrompt(ctx context.Context, req *dto.ComposeMessagesReq, prompt string) string { +func buildUserPrompt(ctx context.Context, req *dto.ComposeMessagesReq) string { var b strings.Builder - b.WriteString(fmt.Sprintf("目标模型:%s\n", req.ModelName)) - if prompt != "" { - b.WriteString(fmt.Sprintf("系统提示词:%s\n", prompt)) - } - if skills := SkillMdContent(ctx, req.SkillName); skills != "" { - b.WriteString(fmt.Sprintf("技能内容:\n%s\n", skills)) - } if formText := buildUserFormText(req.Form); formText != "" { b.WriteString(fmt.Sprintf("系统参数:\n%s\n", formText)) } diff --git a/service/prompt/prompt_compose_service.go b/service/prompt/prompt_compose_service.go index f2028f9..43f852c 100644 --- a/service/prompt/prompt_compose_service.go +++ b/service/prompt/prompt_compose_service.go @@ -232,15 +232,19 @@ func handleCallbackSuccess(ctx context.Context, req *dto.CallbackReq, composeTas }) } } + // 4) 合并系统提示词 + systemPrompt := util.GetModelPrompt(ctx, model.ModelType) + skillContent := SkillMdContent(ctx, composeTask.SkillName) + messages = util.MergeSystemPrompt(messages, systemPrompt, skillContent, model.RequestMapping) - // 4) 合并附加结构 + // 5) 合并附加结构 messages = util.MergeConsult(composeTask.RequestPayload, messages, model.ExtendMapping) - // 5) 注入历史 + // 6) 注入历史 if len(history) > 0 { messages = InjectHistory(messages, history, protocol) } - // 6) 更新数据库 + // 7) 更新数据库 _, err = dao.ComposeTask.Update(ctx, &entity.ComposeTask{ TaskId: req.TaskId, Status: public.ComposeStatusSuccess,