feat(prompt): 实现历史消息注入功能和协议配置优化

- 在 handleCallbackSuccess 函数中新增获取协议配置逻辑
- 实现历史消息获取并在 rounds 中注入历史消息
- 添加 InjectHistory 函数实现按协议顺序合并历史消息
- 在 GetPromptText 接口中集成历史消息注入测试
- 更新 ProviderProtocol 实体中的 MergeOrder 类型为 []string
- 新增 Capabilities 字段支持最大 token 配置
- 修改 renderTemplate 函数接收协议对象参数
- 优化会话历史存储逻辑,提取用户消息内容进行记录
- 移除无用的注释代码 handleCallbackSuccess 处理回调成功
This commit is contained in:
2026-06-10 10:16:58 +08:00
parent 78114f99c7
commit 1c1db7e30c
3 changed files with 134 additions and 32 deletions

View File

@@ -6,16 +6,16 @@ import "gitea.com/red-future/common/beans"
type ProviderProtocol struct {
beans.SQLBaseDO `orm:",inherit"`
// 业务字段
ProviderName string `orm:"provider_name" json:"providerName"`
TargetField string `orm:"target_field" json:"targetField"`
MergeOrder any `orm:"merge_order" json:"mergeOrder"`
RoleMapping any `orm:"role_mapping" json:"roleMapping"`
ContentMapping any `orm:"content_mapping" json:"contentMapping"`
Capabilities any `orm:"capabilities" json:"capabilities"`
RequestTemplate any `orm:"request_template" json:"requestTemplate"`
SystemPromptTemplate string `orm:"system_prompt_template" json:"systemPromptTemplate"`
Status int `orm:"status" json:"status"`
Remark string `orm:"remark" json:"remark"`
ProviderName string `orm:"provider_name" json:"providerName"`
TargetField string `orm:"target_field" json:"targetField"`
MergeOrder []string `orm:"merge_order" json:"mergeOrder"`
RoleMapping map[string]any `orm:"role_mapping" json:"roleMapping"`
ContentMapping map[string]any `orm:"content_mapping" json:"contentMapping"`
Capabilities map[string]any `orm:"capabilities" json:"capabilities"`
RequestTemplate map[string]any `orm:"request_template" json:"requestTemplate"`
SystemPromptTemplate string `orm:"system_prompt_template" json:"systemPromptTemplate"`
Status int `orm:"status" json:"status"`
Remark string `orm:"remark" json:"remark"`
}
// providerProtocolCol 列名

View File

@@ -164,7 +164,6 @@ func handleCallbackFailed(ctx context.Context, req *dto.CallbackReq, composeTask
return err
}
// handleCallbackSuccess 处理回调成功
func handleCallbackSuccess(ctx context.Context, req *dto.CallbackReq, composeTask *entity.ComposeTask) error {
// 1) 获取模型配置
model, err := gateway.GetModelConfig(ctx, &gateway.AsynchModel{
@@ -175,9 +174,35 @@ func handleCallbackSuccess(ctx context.Context, req *dto.CallbackReq, composeTas
return fmt.Errorf("查询模型失败: %w", err)
}
// 2) 合并附加结构
// 2) 获取协议配置
protocol, _ := dao.ProviderProtocol.Get(ctx, &entity.ProviderProtocol{
ProviderName: model.OperatorName,
Status: 1,
})
// 3) 获取历史消息
payload := composeTask.RequestPayload
sessionId := gconv.String(payload["sessionId"])
nodeId := gconv.String(payload["nodeId"])
var history []dto.FlatMessage
if sessionId != "" && nodeId != "" {
h, _ := session.GetHistoryMessages(ctx, &dto.GetHistoryMessagesReq{
SessionId: sessionId,
NodeId: nodeId,
})
if h != nil {
history = h.Messages
}
}
// 4) 合并附加结构
messages := util.MergeConsult(composeTask.RequestPayload, req.Messages, model.ExtendMapping)
// 3) 更新数据库
// 5) 注入历史到 rounds 中
if protocol != nil && len(history) > 0 {
messages = InjectHistory(messages, history, protocol)
}
// 6) 更新数据库
_, err = dao.ComposeTask.Update(ctx, &entity.ComposeTask{
TaskId: req.TaskId,
Status: public.ComposeStatusSuccess,
@@ -189,33 +214,26 @@ func handleCallbackSuccess(ctx context.Context, req *dto.CallbackReq, composeTas
if err != nil {
return err
}
//var userHistoryMsg map[string]any
// 7) 存储历史
var epicycleId int64
payload := composeTask.RequestPayload
sessionId := gconv.String(payload["sessionId"])
nodeId := gconv.String(payload["nodeId"])
buildType := gconv.Int(payload["buildType"])
if buildType == public.BuildTypePrompt && sessionId != "" && nodeId != "" {
// 4) 获取历史内容并拼接
_, _ = session.GetHistoryMessages(ctx, &dto.GetHistoryMessagesReq{
SessionId: sessionId,
NodeId: nodeId,
})
// 5) 存储提示词结果作为历史请求
if userMsg := util.ExtractUserText(messages); userMsg != nil {
epicycleId, err = dao.ComposeSession.Insert(ctx, &entity.ComposeSession{
if sessionId != "" && nodeId != "" {
if userMsg := util.ExtractUserText(req.Messages); userMsg != nil {
epicycleId, _ = dao.ComposeSession.Insert(ctx, &entity.ComposeSession{
NodeId: nodeId,
SessionId: sessionId,
RequestContent: userMsg,
})
}
}
// 6) 回调业务方
// 8) 回调业务方
if composeTask.CallbackUrl != "" {
composeTask.Status = public.ComposeStatusSuccess
composeTask.ResultJson = messages
_ = gateway.SendCallback(ctx, composeTask, epicycleId)
}
return nil
}
@@ -257,7 +275,16 @@ func parseMessagesForResponse(messages any) any {
}
func GetPromptText(ctx context.Context, req *dto.GetPromptTextReq) (*dto.GetPromptTextRes, error) {
// 1) 获取协议配置
protocol, err := dao.ProviderProtocol.Get(ctx, &entity.ProviderProtocol{
ProviderName: "火山引擎",
Status: 1,
})
if err != nil {
return nil, err
}
// 2) 获取历史消息
history, err := session.GetHistoryMessages(ctx, &dto.GetHistoryMessagesReq{
SessionId: "88888888",
NodeId: "node1",
@@ -265,7 +292,74 @@ func GetPromptText(ctx context.Context, req *dto.GetPromptTextReq) (*dto.GetProm
if err != nil {
return nil, err
}
// 3) 模拟roundsData数据
task, err := dao.ComposeTask.Get(ctx, &entity.ComposeTask{
TaskId: "0e1872f0-0e73-42f1-9aa8-63d317300ffc",
})
if err != nil {
return nil, err
}
fmt.Println("[打印数据]", task.ResultJson)
fmt.Println("[打印历史]", history.Messages)
fmt.Println("[打印协议]", protocol)
return &dto.GetPromptTextRes{
Messages: history.Messages,
Messages: InjectHistory(task.ResultJson, history.Messages, protocol),
}, nil
}
func InjectHistory(roundsData map[string]any, history []dto.FlatMessage, protocol *entity.ProviderProtocol) map[string]any {
if protocol == nil || len(history) == 0 {
return roundsData
}
// 1) 提取第一轮的 messages
rounds := roundsData["rounds"].([]any)
firstRound := rounds[0].(map[string]any)
original := firstRound["messages"].([]any)
// 2) 按 merge_order 拼接
result := make([]any, 0, len(original)+len(history))
for _, part := range protocol.MergeOrder {
switch part {
case "system":
for _, m := range original {
msg := m.(map[string]any)
if gconv.String(msg["role"]) == "system" {
result = append(result, msg)
}
}
case "history":
if gconv.Bool(protocol.Capabilities["support_history"]) {
for _, msg := range history {
result = append(result, map[string]any{
"role": msg.Role,
"content": msg.Content, // 纯字符串,不转换
})
}
}
case "user":
for _, m := range original {
msg := m.(map[string]any)
if gconv.String(msg["role"]) == "user" {
result = append(result, msg)
}
}
}
}
// 3) 角色映射
if len(protocol.RoleMapping) > 0 {
for _, m := range result {
msg := m.(map[string]any)
role := gconv.String(msg["role"])
if mapped, ok := protocol.RoleMapping[role]; ok {
msg["role"] = mapped
}
}
}
// 4) 直接修改原对象
firstRound["messages"] = result
return roundsData
}

View File

@@ -10,6 +10,8 @@ import (
"prompts-core/dao"
"prompts-core/model/entity"
"github.com/gogf/gf/v2/util/gconv"
)
// PromptIR 统一 Prompt 中间表示
@@ -34,6 +36,7 @@ type ProviderProtocol struct {
ContentMapping ContentMapping `json:"content_mapping"`
RequestTemplate map[string]any `json:"request_template"`
SystemPromptTemplate string `json:"system_prompt_template"`
Capabilities map[string]any `json:"capabilities"`
}
// ContentMapping 内容字段映射
@@ -175,6 +178,7 @@ func parseProtocol(e *entity.ProviderProtocol) *ProviderProtocol {
util.ParseJSONFieldFromGvar(e.RoleMapping, &p.RoleMapping)
util.ParseJSONFieldFromGvar(e.ContentMapping, &p.ContentMapping)
util.ParseJSONFieldFromGvar(e.RequestTemplate, &p.RequestTemplate)
util.ParseJSONFieldFromGvar(e.Capabilities, &p.Capabilities)
return p
}
@@ -265,7 +269,7 @@ func mapContent(messages []map[string]any, cm ContentMapping) []map[string]any {
// buildRequest 按 target_field 和 request_template 构建请求体
func buildRequest(messages []map[string]any, p *ProviderProtocol, chatModel *gateway.AsynchModel) map[string]any {
if len(p.RequestTemplate) > 0 {
return renderTemplate(p.RequestTemplate, messages, chatModel)
return renderTemplate(p, messages, chatModel)
}
return map[string]any{
@@ -274,8 +278,8 @@ func buildRequest(messages []map[string]any, p *ProviderProtocol, chatModel *gat
}
// renderTemplate 简单的 {{key}} 模板替换
func renderTemplate(tmpl map[string]any, messages []map[string]any, chatModel *gateway.AsynchModel) map[string]any {
b, _ := json.Marshal(tmpl)
func renderTemplate(p *ProviderProtocol, messages []map[string]any, chatModel *gateway.AsynchModel) map[string]any {
b, _ := json.Marshal(p.RequestTemplate)
str := string(b)
if chatModel != nil {
@@ -288,5 +292,9 @@ func renderTemplate(tmpl map[string]any, messages []map[string]any, chatModel *g
var result map[string]any
_ = json.Unmarshal([]byte(str), &result)
if maxTokens := gconv.Int(p.Capabilities["max_tokens"]); maxTokens > 0 {
result["max_tokens"] = maxTokens
}
return result
}