package prompt import ( "context" "encoding/json" "fmt" "prompts-core/common/util" "strings" "prompts-core/dao" "prompts-core/model/entity" ) // PromptIR 统一 Prompt 中间表示 type PromptIR struct { System []Segment `json:"system"` History []Segment `json:"history"` User []Segment `json:"user"` } // Segment 消息片段 type Segment struct { Type string `json:"type"` // text/image Content string `json:"content"` Role string `json:"role,omitempty"` } // NewPromptIR 创建空 PromptIR func NewPromptIR() *PromptIR { return &PromptIR{ System: make([]Segment, 0), History: make([]Segment, 0), User: make([]Segment, 0), } } // AddSystem 添加系统提示 func (ir *PromptIR) AddSystem(content string) *PromptIR { if content != "" { ir.System = append(ir.System, Segment{Type: "text", Content: content}) } return ir } // AddUser 添加用户消息 func (ir *PromptIR) AddUser(content string) *PromptIR { if content != "" { ir.User = append(ir.User, Segment{Type: "text", Content: content}) } return ir } // AddHistory 添加历史消息 func (ir *PromptIR) AddHistory(role, content string) *PromptIR { if content != "" { ir.History = append(ir.History, Segment{Type: "text", Content: content, Role: role}) } return ir } // ToMessages 转换为 OpenAI 兼容的 messages 格式(MVP 默认) func (ir *PromptIR) ToMessages() []map[string]any { var messages []map[string]any // 1. 系统消息 for _, seg := range ir.System { messages = append(messages, map[string]any{ "role": "system", "content": seg.Content, }) } // 2. 历史消息 for _, seg := range ir.History { messages = append(messages, map[string]any{ "role": seg.Role, "content": seg.Content, }) } // 3. 用户消息 for _, seg := range ir.User { messages = append(messages, map[string]any{ "role": "user", "content": seg.Content, }) } return messages } // GetProtocolByProvider 根据 provider_name 获取协议配置 func GetProtocolByProvider(ctx context.Context, providerName string) (*ProviderProtocol, error) { entity, err := dao.ProviderProtocol.Get(ctx, &entity.ProviderProtocol{ ProviderName: providerName, Status: 1, }) if err != nil || entity == nil { return nil, err } entity.MergeOrder = util.ParseJSONField(entity.MergeOrder) entity.RoleMapping = util.ParseJSONField(entity.RoleMapping) entity.ContentMapping = util.ParseJSONField(entity.ContentMapping) entity.RequestTemplate = util.ParseJSONField(entity.RequestTemplate) entity.ContentMapping = util.ParseJSONField(entity.ContentMapping) return parseProtocol(entity), nil } // parseProtocol 将 DB entity 转为编译用协议配置 func parseProtocol(e *entity.ProviderProtocol) *ProviderProtocol { p := &ProviderProtocol{ TargetField: e.TargetField, } // MergeOrder: any → []string if e.MergeOrder != nil { b, _ := json.Marshal(e.MergeOrder) json.Unmarshal(b, &p.MergeOrder) } // RoleMapping: any → map[string]string if e.RoleMapping != nil { b, _ := json.Marshal(e.RoleMapping) json.Unmarshal(b, &p.RoleMapping) } // ContentMapping: any → ContentMapping if e.ContentMapping != nil { b, _ := json.Marshal(e.ContentMapping) json.Unmarshal(b, &p.ContentMapping) } // RequestTemplate: any → map[string]any if e.RequestTemplate != nil { b, _ := json.Marshal(e.RequestTemplate) json.Unmarshal(b, &p.RequestTemplate) } fmt.Printf("parseProtocol: %+v\n", p) return p } // ProviderProtocol 协议编译配置(从 DB JSONB 字段解析) type ProviderProtocol struct { TargetField string `json:"target_field"` MergeOrder []string `json:"merge_order"` RoleMapping map[string]string `json:"role_mapping"` ContentMapping ContentMapping `json:"content_mapping"` RequestTemplate map[string]any `json:"request_template"` } // ContentMapping 内容字段映射 type ContentMapping struct { Type string `json:"type"` // direct/parts Field string `json:"field"` // content/text } // Compile 将 PromptIR 按协议配置编译为 Provider Request func Compile(ir *PromptIR, p *ProviderProtocol, chatModel *entity.AsynchModel) (map[string]any, error) { if ir == nil || p == nil { return nil, fmt.Errorf("ir and protocol are required") } // 1. 按 merge_order 拼接消息 messages := mergeByOrder(ir, p.MergeOrder) // 2. 角色映射 messages = mapRoles(messages, p.RoleMapping) // 3. 内容字段映射 messages = mapContent(messages, p.ContentMapping) // 4. 按 target_field + request_template 构建请求体 return buildRequest(messages, p, chatModel), nil } // mergeByOrder 按协议配置顺序拼接消息 func mergeByOrder(ir *PromptIR, order []string) []map[string]any { var messages []map[string]any for _, part := range order { switch part { case "system": for _, seg := range ir.System { messages = append(messages, map[string]any{ "role": "system", "content": seg.Content, }) } case "history": for _, seg := range ir.History { messages = append(messages, map[string]any{ "role": seg.Role, "content": seg.Content, }) } case "user": for _, seg := range ir.User { messages = append(messages, map[string]any{ "role": "user", "content": seg.Content, }) } } } return messages } // mapRoles 角色映射 func mapRoles(messages []map[string]any, mapping map[string]string) []map[string]any { if len(mapping) == 0 { return messages } for i, msg := range messages { role, ok := msg["role"].(string) if !ok { continue } if mapped, exists := mapping[role]; exists { messages[i]["role"] = mapped } } return messages } // mapContent 内容字段映射 func mapContent(messages []map[string]any, cm ContentMapping) []map[string]any { for _, msg := range messages { content := msg["content"] delete(msg, "content") switch cm.Type { case "parts": // Gemini 格式: {"parts": [{"text": "..."}]} msg["parts"] = []map[string]any{ {cm.Field: content}, } default: // direct: {"content": "..."} msg[cm.Field] = content } } return messages } // buildRequest 按 target_field 和 request_template 构建请求体 func buildRequest(messages []map[string]any, p *ProviderProtocol, chatModel *entity.AsynchModel) map[string]any { if len(p.RequestTemplate) > 0 { return renderTemplate(p.RequestTemplate, messages, chatModel) } return map[string]any{ p.TargetField: messages, } } // renderTemplate 简单的 {{key}} 模板替换 func renderTemplate(tmpl map[string]any, messages []map[string]any, chatModel *entity.AsynchModel) map[string]any { b, _ := json.Marshal(tmpl) str := string(b) // 替换 {{model}} str = strings.ReplaceAll(str, `"{{model}}"`, `"`+chatModel.ModelName+`"`) // 替换 {{messages}} msgBytes, _ := json.Marshal(messages) str = strings.ReplaceAll(str, `"{{messages}}"`, string(msgBytes)) var result map[string]any json.Unmarshal([]byte(str), &result) return result }