refactor(prompts-core): 重构代码结构和优化工具函数

This commit is contained in:
2026-06-10 14:51:25 +08:00
parent 1c1db7e30c
commit b69e7386e2
10 changed files with 164 additions and 432 deletions

View File

@@ -2,9 +2,7 @@ package prompt
import (
"context"
"encoding/json"
"fmt"
"prompts-core/common/util"
"prompts-core/service/gateway"
"strings"
@@ -14,8 +12,8 @@ import (
"github.com/gogf/gf/v2/util/gconv"
)
// PromptIR 统一 Prompt 中间表示
type PromptIR struct {
// IR 统一 Prompt 中间表示
type IR struct {
System []Segment `json:"system"`
History []Segment `json:"history"`
User []Segment `json:"user"`
@@ -46,8 +44,8 @@ type ContentMapping struct {
}
// NewPromptIR 创建空 PromptIR
func NewPromptIR() *PromptIR {
return &PromptIR{
func NewPromptIR() *IR {
return &IR{
System: make([]Segment, 0),
History: make([]Segment, 0),
User: make([]Segment, 0),
@@ -55,7 +53,7 @@ func NewPromptIR() *PromptIR {
}
// String 返回 PromptIR 的完整内容字符串(用于 token 计算)
func (ir *PromptIR) String() string {
func (ir *IR) String() string {
var builder strings.Builder
for _, seg := range ir.System {
@@ -81,7 +79,7 @@ func (ir *PromptIR) String() string {
}
// GetTotalContent 获取所有内容的拼接字符串(更精确的 token 计算)
func (ir *PromptIR) GetTotalContent() string {
func (ir *IR) GetTotalContent() string {
var builder strings.Builder
for _, seg := range ir.System {
@@ -103,7 +101,7 @@ func (ir *PromptIR) GetTotalContent() string {
}
// AddSystem 添加系统提示
func (ir *PromptIR) AddSystem(content string) *PromptIR {
func (ir *IR) AddSystem(content string) *IR {
if content != "" {
ir.System = append(ir.System, Segment{Type: "text", Content: content})
}
@@ -111,7 +109,7 @@ func (ir *PromptIR) AddSystem(content string) *PromptIR {
}
// AddUser 添加用户消息
func (ir *PromptIR) AddUser(content string) *PromptIR {
func (ir *IR) AddUser(content string) *IR {
if content != "" {
ir.User = append(ir.User, Segment{Type: "text", Content: content})
}
@@ -119,7 +117,7 @@ func (ir *PromptIR) AddUser(content string) *PromptIR {
}
// AddHistory 添加历史消息
func (ir *PromptIR) AddHistory(role, content string) *PromptIR {
func (ir *IR) AddHistory(role, content string) *IR {
if content != "" {
ir.History = append(ir.History, Segment{Type: "text", Content: content, Role: role})
}
@@ -127,7 +125,7 @@ func (ir *PromptIR) AddHistory(role, content string) *PromptIR {
}
// ToMessages 转换为 OpenAI 兼容的 messages 格式MVP 默认)
func (ir *PromptIR) ToMessages() []map[string]any {
func (ir *IR) ToMessages() []map[string]any {
var messages []map[string]any
for _, seg := range ir.System {
@@ -168,22 +166,22 @@ func GetProtocolByProvider(ctx context.Context, providerName string) (*ProviderP
// parseProtocol 将 DB entity 转为编译用协议配置
func parseProtocol(e *entity.ProviderProtocol) *ProviderProtocol {
p := &ProviderProtocol{
return &ProviderProtocol{
TargetField: e.TargetField,
SystemPromptTemplate: e.SystemPromptTemplate,
MergeOrder: e.MergeOrder,
RoleMapping: gconv.MapStrStr(e.RoleMapping),
ContentMapping: ContentMapping{
Type: gconv.String(e.ContentMapping["type"]),
Field: gconv.String(e.ContentMapping["field"]),
},
RequestTemplate: e.RequestTemplate,
Capabilities: e.Capabilities,
}
// 使用通用解析方法处理各个字段
util.ParseJSONFieldFromGvar(e.MergeOrder, &p.MergeOrder)
util.ParseJSONFieldFromGvar(e.RoleMapping, &p.RoleMapping)
util.ParseJSONFieldFromGvar(e.ContentMapping, &p.ContentMapping)
util.ParseJSONFieldFromGvar(e.RequestTemplate, &p.RequestTemplate)
util.ParseJSONFieldFromGvar(e.Capabilities, &p.Capabilities)
return p
}
// Compile 将 PromptIR 按协议配置编译为 Provider Request
func Compile(ir *PromptIR, p *ProviderProtocol, chatModel *gateway.AsynchModel) (map[string]any, error) {
func Compile(ir *IR, p *ProviderProtocol, chatModel *gateway.AsynchModel) (map[string]any, error) {
if ir == nil || p == nil {
return nil, fmt.Errorf("ir and protocol are required")
}
@@ -195,35 +193,25 @@ func Compile(ir *PromptIR, p *ProviderProtocol, chatModel *gateway.AsynchModel)
}
// mergeByOrder 按协议配置顺序拼接消息
func mergeByOrder(ir *PromptIR, order []string) []map[string]any {
var messages []map[string]any
for _, part := range order {
switch part {
case "system":
for _, seg := range ir.System {
messages = append(messages, map[string]any{
"role": "system",
"content": seg.Content,
})
}
case "history":
for _, seg := range ir.History {
messages = append(messages, map[string]any{
"role": seg.Role,
"content": seg.Content,
})
}
case "user":
for _, seg := range ir.User {
messages = append(messages, map[string]any{
"role": "user",
"content": seg.Content,
})
}
}
func mergeByOrder(ir *IR, order []string) []map[string]any {
roleMap := map[string][]Segment{
"system": ir.System,
"history": ir.History,
"user": ir.User,
}
var messages []map[string]any
for _, part := range order {
for _, seg := range roleMap[part] {
msg := map[string]any{"content": seg.Content}
if part == "history" {
msg["role"] = seg.Role
} else {
msg["role"] = part
}
messages = append(messages, msg)
}
}
return messages
}
@@ -247,22 +235,22 @@ func mapRoles(messages []map[string]any, mapping map[string]string) []map[string
return messages
}
// mapContent 内容字段映射
func mapContent(messages []map[string]any, cm ContentMapping) []map[string]any {
for _, msg := range messages {
content := msg["content"]
delete(msg, "content")
switch cm.Type {
case "parts":
msg["parts"] = []map[string]any{
{cm.Field: content},
}
default:
msg[cm.Field] = content
}
if cm.Field == "" || cm.Field == "content" {
return messages
}
for i, msg := range messages {
if content, ok := msg["content"]; ok {
delete(msg, "content")
switch cm.Type {
case "parts":
messages[i]["parts"] = []map[string]any{{cm.Field: content}}
default:
messages[i][cm.Field] = content
}
}
}
return messages
}
@@ -277,20 +265,17 @@ func buildRequest(messages []map[string]any, p *ProviderProtocol, chatModel *gat
}
}
// renderTemplate 简单的 {{key}} 模板替换
// renderTemplate 模板渲染
func renderTemplate(p *ProviderProtocol, messages []map[string]any, chatModel *gateway.AsynchModel) map[string]any {
b, _ := json.Marshal(p.RequestTemplate)
str := string(b)
if chatModel != nil {
str = strings.ReplaceAll(str, `"{{model}}"`, `"`+chatModel.ModelName+`"`)
result := make(map[string]any, len(p.RequestTemplate)+1)
for k, v := range p.RequestTemplate {
result[k] = v
}
msgBytes, _ := json.Marshal(messages)
str = strings.ReplaceAll(str, `"{{messages}}"`, string(msgBytes))
var result map[string]any
_ = json.Unmarshal([]byte(str), &result)
if chatModel != nil {
result["model"] = chatModel.ModelName
}
result["messages"] = messages
if maxTokens := gconv.Int(p.Capabilities["max_tokens"]); maxTokens > 0 {
result["max_tokens"] = maxTokens