// Package prompt builds a provider-neutral prompt intermediate representation
// (PromptIR) and compiles it into provider-specific request bodies according to
// protocol settings stored in the database.
package prompt

import (
	"context"
	"encoding/json"
	"fmt"
	"strings"

	"prompts-core/common/util"
	"prompts-core/dao"
	"prompts-core/model/entity"
)

// PromptIR is the unified intermediate representation of a prompt, split into
// system, history and user segments.
type PromptIR struct {
	System  []Segment `json:"system"`
	History []Segment `json:"history"`
	User    []Segment `json:"user"`
}

// Segment is a single message fragment.
type Segment struct {
	Type    string `json:"type"`
	Content string `json:"content"`
	Role    string `json:"role,omitempty"`
}

// ProviderProtocol is the compile-time protocol configuration, parsed from the
// provider protocol DB row (the JSON columns are decoded by parseProtocol).
type ProviderProtocol struct {
	TargetField          string            `json:"target_field"`
	MergeOrder           []string          `json:"merge_order"`
	RoleMapping          map[string]string `json:"role_mapping"`
	ContentMapping       ContentMapping    `json:"content_mapping"`
	RequestTemplate      map[string]any    `json:"request_template"`
	SystemPromptTemplate string            `json:"system_prompt_template"`
}

// ContentMapping describes where a message's content is placed in the
// provider payload.
type ContentMapping struct {
	Type  string `json:"type"`
	Field string `json:"field"`
}

// defaultMergeOrder is the segment order used by ToMessages and by Compile when
// a protocol does not configure merge_order.
var defaultMergeOrder = []string{"system", "history", "user"}

// defaultContentField is the message key content stays under when a protocol
// does not configure content_mapping.field.
const defaultContentField = "content"

// NewPromptIR returns an empty PromptIR with non-nil segment slices.
func NewPromptIR() *PromptIR {
	return &PromptIR{
		System:  make([]Segment, 0),
		History: make([]Segment, 0),
		User:    make([]Segment, 0),
	}
}

// String returns the full prompt as "Role: content" lines (used for token
// counting). System segments are labelled "System", history segments use their
// own Role, and user segments are labelled "User".
func (ir *PromptIR) String() string {
	var builder strings.Builder
	for _, seg := range ir.System {
		builder.WriteString("System: ")
		builder.WriteString(seg.Content)
		builder.WriteString("\n")
	}
	for _, seg := range ir.History {
		builder.WriteString(seg.Role)
		builder.WriteString(": ")
		builder.WriteString(seg.Content)
		builder.WriteString("\n")
	}
	for _, seg := range ir.User {
		builder.WriteString("User: ")
		builder.WriteString(seg.Content)
		builder.WriteString("\n")
	}
	return builder.String()
}

// GetTotalContent returns the content of every segment (system, history, user
// in that order), one per line and without role labels, for more precise token
// counting than String.
func (ir *PromptIR) GetTotalContent() string {
	var builder strings.Builder
	for _, seg := range ir.System {
		builder.WriteString(seg.Content)
		builder.WriteString("\n")
	}
	for _, seg := range ir.History {
		builder.WriteString(seg.Content)
		builder.WriteString("\n")
	}
	for _, seg := range ir.User {
		builder.WriteString(seg.Content)
		builder.WriteString("\n")
	}
	return builder.String()
}

// AddSystem appends a text system segment; empty content is ignored.
// It returns ir for chaining.
func (ir *PromptIR) AddSystem(content string) *PromptIR {
	if content != "" {
		ir.System = append(ir.System, Segment{Type: "text", Content: content})
	}
	return ir
}

// AddUser appends a text user segment; empty content is ignored.
// It returns ir for chaining.
func (ir *PromptIR) AddUser(content string) *PromptIR {
	if content != "" {
		ir.User = append(ir.User, Segment{Type: "text", Content: content})
	}
	return ir
}

// AddHistory appends a text history segment with the given role; empty
// content is ignored. It returns ir for chaining.
func (ir *PromptIR) AddHistory(role, content string) *PromptIR {
	if content != "" {
		ir.History = append(ir.History, Segment{Type: "text", Content: content, Role: role})
	}
	return ir
}

// ToMessages converts the IR to an OpenAI-compatible messages list (the MVP
// default): system messages first, then history with its own roles, then user
// messages. It returns nil when the IR has no segments.
func (ir *PromptIR) ToMessages() []map[string]any {
	return mergeByOrder(ir, defaultMergeOrder)
}

// GetProtocolByProvider loads the protocol configuration for providerName.
// It returns (nil, nil) when no matching row exists, and a wrapped error when
// the lookup itself fails.
func GetProtocolByProvider(ctx context.Context, providerName string) (*ProviderProtocol, error) {
	// The local is named row (not entity) so it does not shadow the entity package.
	row, err := dao.ProviderProtocol.Get(ctx, &entity.ProviderProtocol{
		ProviderName: providerName,
		Status:       1, // NOTE(review): presumably 1 = enabled; confirm against the status enum
	})
	if err != nil {
		return nil, fmt.Errorf("get provider protocol %q: %w", providerName, err)
	}
	if row == nil {
		return nil, nil
	}
	return parseProtocol(row), nil
}

// parseProtocol converts a DB entity into the protocol configuration used by
// Compile, decoding its JSON columns.
func parseProtocol(e *entity.ProviderProtocol) *ProviderProtocol {
	p := &ProviderProtocol{
		TargetField:          e.TargetField,
		SystemPromptTemplate: e.SystemPromptTemplate,
	}
	// NOTE(review): any result of ParseJSONFieldFromGvar is not checked here;
	// if it reports decode errors, a malformed column silently leaves the field
	// at its zero value — confirm the helper's contract.
	util.ParseJSONFieldFromGvar(e.MergeOrder, &p.MergeOrder)
	util.ParseJSONFieldFromGvar(e.RoleMapping, &p.RoleMapping)
	util.ParseJSONFieldFromGvar(e.ContentMapping, &p.ContentMapping)
	util.ParseJSONFieldFromGvar(e.RequestTemplate, &p.RequestTemplate)
	return p
}

// Compile compiles ir into a provider request body according to protocol p.
// Messages are merged in p.MergeOrder, roles are renamed via p.RoleMapping,
// content is relocated via p.ContentMapping, and the result is placed either
// into p.RequestTemplate or under p.TargetField.
//
// It returns an error when ir or p is nil, or when the request template cannot
// be rendered into valid JSON.
func Compile(ir *PromptIR, p *ProviderProtocol, chatModel *entity.AsynchModel) (map[string]any, error) {
	if ir == nil || p == nil {
		return nil, fmt.Errorf("ir and protocol are required")
	}
	messages := mergeByOrder(ir, p.MergeOrder)
	messages = mapRoles(messages, p.RoleMapping)
	messages = mapContent(messages, p.ContentMapping)

	req := buildRequest(messages, p, chatModel)
	if req == nil {
		// buildRequest yields nil only when request_template rendering fails;
		// surface that instead of handing the caller an empty request.
		return nil, fmt.Errorf("render request template: template or messages could not be encoded as JSON")
	}
	return req, nil
}

// mergeByOrder flattens ir into role/content messages following order, whose
// entries are "system", "history" or "user"; unknown entries are skipped.
// An empty order falls back to defaultMergeOrder so an unconfigured protocol
// does not silently drop every message. It returns nil when nothing is emitted.
func mergeByOrder(ir *PromptIR, order []string) []map[string]any {
	if len(order) == 0 {
		order = defaultMergeOrder
	}
	var messages []map[string]any
	for _, part := range order {
		switch part {
		case "system":
			for _, seg := range ir.System {
				messages = append(messages, map[string]any{
					"role":    "system",
					"content": seg.Content,
				})
			}
		case "history":
			for _, seg := range ir.History {
				messages = append(messages, map[string]any{
					"role":    seg.Role,
					"content": seg.Content,
				})
			}
		case "user":
			for _, seg := range ir.User {
				messages = append(messages, map[string]any{
					"role":    "user",
					"content": seg.Content,
				})
			}
		}
	}
	return messages
}

// mapRoles renames each message's "role" using mapping, in place. Roles absent
// from mapping, and non-string roles, are left unchanged.
func mapRoles(messages []map[string]any, mapping map[string]string) []map[string]any {
	if len(mapping) == 0 {
		return messages
	}
	for _, msg := range messages {
		role, ok := msg["role"].(string)
		if !ok {
			continue
		}
		if mapped, exists := mapping[role]; exists {
			msg["role"] = mapped // maps are references, so this updates messages
		}
	}
	return messages
}

// mapContent moves each message's "content" value, in place, according to cm:
//   - Type "parts": content becomes msg["parts"] = [{cm.Field: content}]
//   - otherwise:    content is stored under msg[cm.Field]
//
// An empty cm.Field falls back to "content"; previously the content ended up
// under the empty key "" and was effectively lost from the payload.
func mapContent(messages []map[string]any, cm ContentMapping) []map[string]any {
	field := cm.Field
	if field == "" {
		field = defaultContentField
	}
	for _, msg := range messages {
		content := msg["content"]
		delete(msg, "content")
		switch cm.Type {
		case "parts":
			msg["parts"] = []map[string]any{
				{field: content},
			}
		default:
			msg[field] = content
		}
	}
	return messages
}

// buildRequest builds the request body: p.RequestTemplate is rendered when it
// is non-empty, otherwise messages are placed under p.TargetField.
// It returns nil only when template rendering fails.
func buildRequest(messages []map[string]any, p *ProviderProtocol, chatModel *entity.AsynchModel) map[string]any {
	if len(p.RequestTemplate) > 0 {
		return renderTemplate(p.RequestTemplate, messages, chatModel)
	}
	return map[string]any{
		p.TargetField: messages,
	}
}

// renderTemplate performs simple {{key}} substitution on tmpl via a JSON
// round-trip. It replaces any JSON string value that is exactly "{{model}}"
// with the model name (only when chatModel is non-nil), and any value that is
// exactly "{{messages}}" with the messages array.
//
// The round-trip also deep-copies tmpl, so the shared protocol template is
// never mutated. It returns nil if any encoding or decoding step fails.
func renderTemplate(tmpl map[string]any, messages []map[string]any, chatModel *entity.AsynchModel) map[string]any {
	tmplJSON, err := json.Marshal(tmpl)
	if err != nil {
		return nil
	}
	msgJSON, err := json.Marshal(messages)
	if err != nil {
		return nil
	}

	rendered := string(tmplJSON)
	// Model is substituted before messages so that user-supplied message text
	// equal to "{{model}}" is never treated as a placeholder.
	if chatModel != nil {
		// Encode the name as a JSON string so quotes/backslashes are escaped;
		// splicing it raw could corrupt the document. Marshaling a string
		// cannot fail (invalid UTF-8 is coerced), so the error is ignored.
		modelJSON, _ := json.Marshal(chatModel.ModelName)
		rendered = strings.ReplaceAll(rendered, `"{{model}}"`, string(modelJSON))
	}
	rendered = strings.ReplaceAll(rendered, `"{{messages}}"`, string(msgJSON))

	var result map[string]any
	if err := json.Unmarshal([]byte(rendered), &result); err != nil {
		return nil
	}
	return result
}