Files
prompts-core/common/util/mapping.go

197 lines
4.7 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package util
import (
"fmt"
"strings"
"github.com/gogf/gf/v2/encoding/gjson"
"github.com/gogf/gf/v2/util/gconv"
)
// ======================== Request mapping ========================

// ReverseMap builds a new nested structure that contains only the paths
// listed in mapping.
//
// For every path key in mapping, the value found at that same path in payload
// is written to the result. When payload has no value at the path, the value
// stored in mapping is used as a default; a nil default means the path is
// left out of the result.
//
// Parameters:
//   - mapping: path -> default value (nil for "no default"). Paths are passed
//     unchanged to gjson Get/Set.
//   - payload: the data the paths are looked up in.
//
// Returns the assembled structure as a plain map.
func ReverseMap(mapping map[string]any, payload map[string]any) map[string]any {
	// Wrap the payload once: it is only read inside the loop, so there is no
	// need to rebuild the wrapper for every path.
	source := gjson.New(payload)
	result := gjson.New("{}")
	for path, defaultValue := range mapping {
		found := source.Get(path)
		if !found.IsNil() {
			// Best effort: a path that cannot be set is skipped instead of
			// failing the whole mapping.
			_ = result.Set(path, found.Val())
		} else if defaultValue != nil {
			_ = result.Set(path, defaultValue)
		}
	}
	return result.Map()
}
// ======================== User text extraction ========================

// ExtractUserText gathers the text of every message whose role is "user" and
// returns it as one combined user message.
//
// The message list is read from "rounds.0.messages", falling back to
// "messages" when that path holds nothing. A message whose content is a
// string contributes that string as-is. A message whose content is a list
// contributes the non-empty "text" field of each object element. Content of
// any other type is ignored. All pieces are joined with "\n".
//
// The result always has the shape {"role": "user", "content": <joined text>};
// content is the empty string when no user text was found.
func ExtractUserText(messages map[string]any) map[string]any {
	root := gjson.New(messages)
	list := root.Get("rounds.0.messages")
	if list.IsNil() {
		list = root.Get("messages")
	}

	var collected []string
	for _, raw := range list.Array() {
		entry := gjson.New(raw)
		if entry.Get("role").String() == "user" {
			body := entry.Get("content").Val()
			if text, isString := body.(string); isString {
				// Plain string content is taken verbatim.
				collected = append(collected, text)
			} else if parts, isList := body.([]any); isList {
				// Structured content: keep the text of each object part.
				for _, part := range parts {
					fields, isObject := part.(map[string]any)
					if !isObject {
						continue
					}
					if text := gconv.String(fields["text"]); text != "" {
						collected = append(collected, text)
					}
				}
			}
		}
	}

	result := make(map[string]any, 2)
	result["role"] = "user"
	result["content"] = strings.Join(collected, "\n")
	return result
}
// ======================== Attachment merging ========================

// MergeConsult appends the request's "consult" attachments to the content
// list of every round in messages.
//
// Parameters:
//   - req: the request; its "consult" entry is read as a list of items, each
//     providing a "type" and a "url".
//   - messages: the message structure; its "rounds" list is iterated.
//   - extendMapping: configuration providing "target_content_path" (the path
//     of the content list relative to a round) and "attachment_templates"
//     (templates keyed by item type).
//
// For each round and each consult item, the template matching the item's type
// is turned into an attachment by buildAttachment and stored at the index
// equal to the current length of the target list. Items with an empty type,
// without a matching template, or for which no attachment can be built are
// skipped.
//
// Returns messages unchanged when any input is empty or the configuration is
// incomplete; otherwise returns the updated structure as a plain map.
func MergeConsult(req map[string]any, messages map[string]any, extendMapping map[string]any) map[string]any {
	switch {
	case len(req) == 0, len(messages) == 0, len(extendMapping) == 0:
		return messages
	}

	attachments := gconv.Interfaces(req["consult"])
	if len(attachments) == 0 {
		return messages
	}

	contentSuffix := gconv.String(extendMapping["target_content_path"])
	byType := gconv.Map(extendMapping["attachment_templates"])
	if contentSuffix == "" || len(byType) == 0 {
		return messages
	}

	doc := gjson.New(messages)
	roundCount := len(doc.Get("rounds").Array())
	for round := 0; round < roundCount; round++ {
		contentPath := fmt.Sprintf("rounds.%d.%s", round, contentSuffix)
		for _, entry := range attachments {
			entryJSON := gjson.New(entry)

			kind := entryJSON.Get("type").String()
			if kind == "" {
				continue
			}
			tmpl := gconv.Map(byType[kind])
			if len(tmpl) == 0 {
				continue
			}

			built := buildAttachment(tmpl, entryJSON.Get("url").String())
			if built == nil {
				continue
			}

			// Writing at index == current length places the attachment after
			// the existing elements. A failed Set is ignored (best effort).
			next := len(doc.Get(contentPath).Array())
			_ = doc.Set(fmt.Sprintf("%s.%d", contentPath, next), built)
		}
	}
	return doc.Map()
}
// buildAttachment creates one attachment object from a template and a URL.
//
// The template supplies "type" (the attachment kind) and "body" (a skeleton
// in which every empty string, at any map nesting depth, is a placeholder
// for the URL). The result has the shape
//
//	{"type": <type>, <type>: <body with placeholders replaced by url>}
//
// nil is returned when the template has no type or url is empty.
//
// The body is copied before it is filled. gconv.Map may hand back the
// template's own map when it already is a map[string]any (NOTE(review):
// confirm for the gconv version in use); filling that map in place would
// modify the caller's template and leave no empty placeholders for the next
// attachment, so later attachments would reuse the first URL.
func buildAttachment(tmpl map[string]any, url string) map[string]any {
	typ := gconv.String(tmpl["type"])
	if typ == "" || url == "" {
		return nil
	}
	body := cloneNestedMap(gconv.Map(tmpl["body"]))
	fillEmptyInPlace(body, url)
	return map[string]any{
		"type": typ,
		typ:    body,
	}
}

// cloneNestedMap returns a copy of m in which every nested map[string]any is
// copied as well; values of any other type are shared with m. A nil map
// stays nil.
func cloneNestedMap(m map[string]any) map[string]any {
	if m == nil {
		return nil
	}
	out := make(map[string]any, len(m))
	for k, v := range m {
		if nested, ok := v.(map[string]any); ok {
			out[k] = cloneNestedMap(nested)
			continue
		}
		out[k] = v
	}
	return out
}
// fillEmptyInPlace replaces every empty-string value in m with value and
// descends into nested map[string]any values to do the same there. The map
// is modified in place; values of any other type (slices included) are left
// untouched.
func fillEmptyInPlace(m map[string]any, value string) {
	for key := range m {
		if text, isString := m[key].(string); isString {
			if text == "" {
				// Overwriting an existing key while ranging is safe in Go.
				m[key] = value
			}
			continue
		}
		if child, isMap := m[key].(map[string]any); isMap {
			fillEmptyInPlace(child, value)
		}
	}
}
// ======================== System prompt merging ========================

// MergeSystemPrompt appends prompt text to the content of the system message
// inside messages.
//
// The non-empty values among prompt, userPrompt and skills are joined, in
// that order, with "\n". The location of the system content is taken from
// requestMapping via getSystemPromptPath; when messages has a "rounds.0"
// entry, the location is resolved inside that first round. Existing content
// at the location is kept and the new text is appended after a newline.
//
// Returns messages unchanged when there is nothing to add or no system
// content path can be determined; otherwise returns the updated structure as
// a plain map.
func MergeSystemPrompt(messages map[string]any, prompt, skills, userPrompt string, requestMapping map[string]any) map[string]any {
	sections := make([]string, 0, 3)
	for _, section := range []string{prompt, userPrompt, skills} {
		if section != "" {
			sections = append(sections, section)
		}
	}
	if len(sections) == 0 {
		return messages
	}

	contentPath := getSystemPromptPath(requestMapping)
	if contentPath == "" {
		return messages
	}

	doc := gjson.New(messages)
	// When rounds are present, the path is relative to the first round.
	if doc.Get("rounds.0").Val() != nil {
		contentPath = "rounds.0." + contentPath
	}

	merged := strings.Join(sections, "\n")
	if current := doc.Get(contentPath).String(); current != "" {
		merged = current + "\n" + merged
	}

	// A failed Set is ignored (best effort).
	_ = doc.Set(contentPath, merged)
	return doc.Map()
}
// getSystemPromptPath returns the mapping key that addresses the content of
// the system message.
//
// It looks for a key ending in ".role" whose value converts to "system" and
// for which the sibling key "<prefix>.content" is also present in
// requestMapping; that sibling key is returned. An empty string is returned
// when no such pair exists.
//
// Only keys that actually end in ".role" are considered, so a key that merely
// contains ".role" in the middle is not mistaken for a role entry. When
// several pairs qualify, the lexicographically smallest content key is
// returned so the result does not depend on map iteration order.
func getSystemPromptPath(requestMapping map[string]any) string {
	best := ""
	for key, val := range requestMapping {
		if !strings.HasSuffix(key, ".role") {
			continue
		}
		if gconv.String(val) != "system" {
			continue
		}
		contentKey := strings.TrimSuffix(key, ".role") + ".content"
		if _, ok := requestMapping[contentKey]; !ok {
			continue
		}
		if best == "" || contentKey < best {
			best = contentKey
		}
	}
	return best
}