From 1c1db7e30cb20b921ebeda3b723679bd3f3b7aaa Mon Sep 17 00:00:00 2001 From: WangLiZhao <1838393649@qq.com> Date: Wed, 10 Jun 2026 10:16:58 +0800 Subject: [PATCH] =?UTF-8?q?feat(prompt):=20=E5=AE=9E=E7=8E=B0=E5=8E=86?= =?UTF-8?q?=E5=8F=B2=E6=B6=88=E6=81=AF=E6=B3=A8=E5=85=A5=E5=8A=9F=E8=83=BD?= =?UTF-8?q?=E5=92=8C=E5=8D=8F=E8=AE=AE=E9=85=8D=E7=BD=AE=E4=BC=98=E5=8C=96?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 在 handleCallbackSuccess 函数中新增获取协议配置逻辑 - 实现历史消息获取并在 rounds 中注入历史消息 - 添加 InjectHistory 函数实现按协议顺序合并历史消息 - 在 GetPromptText 接口中集成历史消息注入测试 - 更新 ProviderProtocol 实体中的 MergeOrder 类型为 []string - 新增 Capabilities 字段支持最大 token 配置 - 修改 renderTemplate 函数接收协议对象参数 - 优化会话历史存储逻辑,提取用户消息内容进行记录 - 移除无用的注释代码 handleCallbackSuccess 处理回调成功 --- model/entity/prompts_provider_protocol.go | 20 ++-- service/prompt/prompt_compose_service.go | 132 ++++++++++++++++++---- service/prompt/prompt_ir_service.go | 14 ++- 3 files changed, 134 insertions(+), 32 deletions(-) diff --git a/model/entity/prompts_provider_protocol.go b/model/entity/prompts_provider_protocol.go index bbf4488..865f715 100644 --- a/model/entity/prompts_provider_protocol.go +++ b/model/entity/prompts_provider_protocol.go @@ -6,16 +6,16 @@ import "gitea.com/red-future/common/beans" type ProviderProtocol struct { beans.SQLBaseDO `orm:",inherit"` // 业务字段 - ProviderName string `orm:"provider_name" json:"providerName"` - TargetField string `orm:"target_field" json:"targetField"` - MergeOrder any `orm:"merge_order" json:"mergeOrder"` - RoleMapping any `orm:"role_mapping" json:"roleMapping"` - ContentMapping any `orm:"content_mapping" json:"contentMapping"` - Capabilities any `orm:"capabilities" json:"capabilities"` - RequestTemplate any `orm:"request_template" json:"requestTemplate"` - SystemPromptTemplate string `orm:"system_prompt_template" json:"systemPromptTemplate"` - Status int `orm:"status" json:"status"` - Remark string `orm:"remark" json:"remark"` + ProviderName string `orm:"provider_name" json:"providerName"` + TargetField string `orm:"target_field" json:"targetField"` + MergeOrder []string `orm:"merge_order" json:"mergeOrder"` + RoleMapping map[string]any `orm:"role_mapping" json:"roleMapping"` + ContentMapping map[string]any `orm:"content_mapping" json:"contentMapping"` + Capabilities map[string]any `orm:"capabilities" json:"capabilities"` + RequestTemplate map[string]any `orm:"request_template" json:"requestTemplate"` + SystemPromptTemplate string `orm:"system_prompt_template" json:"systemPromptTemplate"` + Status int `orm:"status" json:"status"` + Remark string `orm:"remark" json:"remark"` } // providerProtocolCol 列名 diff --git a/service/prompt/prompt_compose_service.go b/service/prompt/prompt_compose_service.go index 791d239..2a2c6e6 100644 --- a/service/prompt/prompt_compose_service.go +++ b/service/prompt/prompt_compose_service.go @@ -164,7 +164,6 @@ func handleCallbackFailed(ctx context.Context, req *dto.CallbackReq, composeTask return err } -// handleCallbackSuccess 处理回调成功 func handleCallbackSuccess(ctx context.Context, req *dto.CallbackReq, composeTask *entity.ComposeTask) error { // 1) 获取模型配置 model, err := gateway.GetModelConfig(ctx, &gateway.AsynchModel{ @@ -175,9 +174,35 @@ func handleCallbackSuccess(ctx context.Context, req *dto.CallbackReq, composeTas return fmt.Errorf("查询模型失败: %w", err) } - // 2) 合并附加结构 + // 2) 获取协议配置 + protocol, _ := dao.ProviderProtocol.Get(ctx, &entity.ProviderProtocol{ + ProviderName: model.OperatorName, + Status: 1, + }) + + // 3) 获取历史消息 + payload := composeTask.RequestPayload + sessionId := gconv.String(payload["sessionId"]) + nodeId := gconv.String(payload["nodeId"]) + var history []dto.FlatMessage + if sessionId != "" && nodeId != "" { + h, _ := session.GetHistoryMessages(ctx, &dto.GetHistoryMessagesReq{ + SessionId: sessionId, + NodeId: nodeId, + }) + if h != nil { + history = h.Messages + } + } + + // 4) 合并附加结构 messages := util.MergeConsult(composeTask.RequestPayload, req.Messages, model.ExtendMapping) - // 3) 更新数据库 + // 5) 注入历史到 rounds 中 + if protocol != nil && len(history) > 0 { + messages = InjectHistory(messages, history, protocol) + } + + // 6) 更新数据库 _, err = dao.ComposeTask.Update(ctx, &entity.ComposeTask{ TaskId: req.TaskId, Status: public.ComposeStatusSuccess, @@ -189,33 +214,26 @@ func handleCallbackSuccess(ctx context.Context, req *dto.CallbackReq, composeTas if err != nil { return err } - //var userHistoryMsg map[string]any + + // 7) 存储历史 var epicycleId int64 - payload := composeTask.RequestPayload - sessionId := gconv.String(payload["sessionId"]) - nodeId := gconv.String(payload["nodeId"]) - buildType := gconv.Int(payload["buildType"]) - if buildType == public.BuildTypePrompt && sessionId != "" && nodeId != "" { - // 4) 获取历史内容并拼接 - _, _ = session.GetHistoryMessages(ctx, &dto.GetHistoryMessagesReq{ - SessionId: sessionId, - NodeId: nodeId, - }) - // 5) 存储提示词结果作为历史请求 - if userMsg := util.ExtractUserText(messages); userMsg != nil { - epicycleId, err = dao.ComposeSession.Insert(ctx, &entity.ComposeSession{ + if sessionId != "" && nodeId != "" { + if userMsg := util.ExtractUserText(req.Messages); userMsg != nil { + epicycleId, _ = dao.ComposeSession.Insert(ctx, &entity.ComposeSession{ NodeId: nodeId, SessionId: sessionId, RequestContent: userMsg, }) } } - // 6) 回调业务方 + + // 8) 回调业务方 if composeTask.CallbackUrl != "" { composeTask.Status = public.ComposeStatusSuccess composeTask.ResultJson = messages _ = gateway.SendCallback(ctx, composeTask, epicycleId) } + return nil } @@ -257,7 +275,16 @@ func parseMessagesForResponse(messages any) any { } func GetPromptText(ctx context.Context, req *dto.GetPromptTextReq) (*dto.GetPromptTextRes, error) { + // 1) 获取协议配置 + protocol, err := dao.ProviderProtocol.Get(ctx, &entity.ProviderProtocol{ + ProviderName: "火山引擎", + Status: 1, + }) + if err != nil { + return nil, err + } + // 2) 获取历史消息 history, err := session.GetHistoryMessages(ctx, &dto.GetHistoryMessagesReq{ SessionId: "88888888", NodeId: "node1", @@ -265,7 +292,74 @@ func GetPromptText(ctx context.Context, req *dto.GetPromptTextReq) (*dto.GetProm if err != nil { return nil, err } + + // 3) 模拟roundsData数据 + task, err := dao.ComposeTask.Get(ctx, &entity.ComposeTask{ + TaskId: "0e1872f0-0e73-42f1-9aa8-63d317300ffc", + }) + if err != nil { + return nil, err + } + fmt.Println("[打印数据]", task.ResultJson) + fmt.Println("[打印历史]", history.Messages) + fmt.Println("[打印协议]", protocol) return &dto.GetPromptTextRes{ - Messages: history.Messages, + Messages: InjectHistory(task.ResultJson, history.Messages, protocol), }, nil } + +func InjectHistory(roundsData map[string]any, history []dto.FlatMessage, protocol *entity.ProviderProtocol) map[string]any { + if protocol == nil || len(history) == 0 { + return roundsData + } + // 1) 提取第一轮的 messages + rounds := roundsData["rounds"].([]any) + firstRound := rounds[0].(map[string]any) + original := firstRound["messages"].([]any) + + // 2) 按 merge_order 拼接 + result := make([]any, 0, len(original)+len(history)) + + for _, part := range protocol.MergeOrder { + switch part { + case "system": + for _, m := range original { + msg := m.(map[string]any) + if gconv.String(msg["role"]) == "system" { + result = append(result, msg) + } + } + case "history": + if gconv.Bool(protocol.Capabilities["support_history"]) { + for _, msg := range history { + result = append(result, map[string]any{ + "role": msg.Role, + "content": msg.Content, // 纯字符串,不转换 + }) + } + } + case "user": + for _, m := range original { + msg := m.(map[string]any) + if gconv.String(msg["role"]) == "user" { + result = append(result, msg) + } + } + } + } + + // 3) 角色映射 + if len(protocol.RoleMapping) > 0 { + for _, m := range result { + msg := m.(map[string]any) + role := gconv.String(msg["role"]) + if mapped, ok := protocol.RoleMapping[role]; ok { + msg["role"] = mapped + } + } + } + + // 4) 直接修改原对象 + firstRound["messages"] = result + return roundsData +} diff --git a/service/prompt/prompt_ir_service.go b/service/prompt/prompt_ir_service.go index a0eaf8a..adf3cd0 100644 --- a/service/prompt/prompt_ir_service.go +++ b/service/prompt/prompt_ir_service.go @@ -10,6 +10,8 @@ import ( "prompts-core/dao" "prompts-core/model/entity" + + "github.com/gogf/gf/v2/util/gconv" ) // PromptIR 统一 Prompt 中间表示 @@ -34,6 +36,7 @@ type ProviderProtocol struct { ContentMapping ContentMapping `json:"content_mapping"` RequestTemplate map[string]any `json:"request_template"` SystemPromptTemplate string `json:"system_prompt_template"` + Capabilities map[string]any `json:"capabilities"` } // ContentMapping 内容字段映射 @@ -175,6 +178,7 @@ func parseProtocol(e *entity.ProviderProtocol) *ProviderProtocol { util.ParseJSONFieldFromGvar(e.RoleMapping, &p.RoleMapping) util.ParseJSONFieldFromGvar(e.ContentMapping, &p.ContentMapping) util.ParseJSONFieldFromGvar(e.RequestTemplate, &p.RequestTemplate) + util.ParseJSONFieldFromGvar(e.Capabilities, &p.Capabilities) return p } @@ -265,7 +269,7 @@ func mapContent(messages []map[string]any, cm ContentMapping) []map[string]any { // buildRequest 按 target_field 和 request_template 构建请求体 func buildRequest(messages []map[string]any, p *ProviderProtocol, chatModel *gateway.AsynchModel) map[string]any { if len(p.RequestTemplate) > 0 { - return renderTemplate(p.RequestTemplate, messages, chatModel) + return renderTemplate(p, messages, chatModel) } return map[string]any{ @@ -274,8 +278,8 @@ func buildRequest(messages []map[string]any, p *ProviderProtocol, chatModel *gat } // renderTemplate 简单的 {{key}} 模板替换 -func renderTemplate(tmpl map[string]any, messages []map[string]any, chatModel *gateway.AsynchModel) map[string]any { - b, _ := json.Marshal(tmpl) +func renderTemplate(p *ProviderProtocol, messages []map[string]any, chatModel *gateway.AsynchModel) map[string]any { + b, _ := json.Marshal(p.RequestTemplate) str := string(b) if chatModel != nil { @@ -288,5 +292,9 @@ func renderTemplate(tmpl map[string]any, messages []map[string]any, chatModel *g var result map[string]any _ = json.Unmarshal([]byte(str), &result) + if maxTokens := gconv.Int(p.Capabilities["max_tokens"]); maxTokens > 0 { + result["max_tokens"] = maxTokens + } + return result }