feat(prompt): 实现历史消息注入功能和协议配置优化
- 在 handleCallbackSuccess 函数中新增获取协议配置逻辑 - 实现历史消息获取并在 rounds 中注入历史消息 - 添加 InjectHistory 函数实现按协议顺序合并历史消息 - 在 GetPromptText 接口中集成历史消息注入测试 - 更新 ProviderProtocol 实体中的 MergeOrder 类型为 []string - 新增 Capabilities 字段支持最大 token 配置 - 修改 renderTemplate 函数接收协议对象参数 - 优化会话历史存储逻辑,提取用户消息内容进行记录 - 移除无用的注释代码 handleCallbackSuccess 处理回调成功
This commit is contained in:
@@ -6,16 +6,16 @@ import "gitea.com/red-future/common/beans"
|
||||
type ProviderProtocol struct {
|
||||
beans.SQLBaseDO `orm:",inherit"`
|
||||
// 业务字段
|
||||
ProviderName string `orm:"provider_name" json:"providerName"`
|
||||
TargetField string `orm:"target_field" json:"targetField"`
|
||||
MergeOrder any `orm:"merge_order" json:"mergeOrder"`
|
||||
RoleMapping any `orm:"role_mapping" json:"roleMapping"`
|
||||
ContentMapping any `orm:"content_mapping" json:"contentMapping"`
|
||||
Capabilities any `orm:"capabilities" json:"capabilities"`
|
||||
RequestTemplate any `orm:"request_template" json:"requestTemplate"`
|
||||
SystemPromptTemplate string `orm:"system_prompt_template" json:"systemPromptTemplate"`
|
||||
Status int `orm:"status" json:"status"`
|
||||
Remark string `orm:"remark" json:"remark"`
|
||||
ProviderName string `orm:"provider_name" json:"providerName"`
|
||||
TargetField string `orm:"target_field" json:"targetField"`
|
||||
MergeOrder []string `orm:"merge_order" json:"mergeOrder"`
|
||||
RoleMapping map[string]any `orm:"role_mapping" json:"roleMapping"`
|
||||
ContentMapping map[string]any `orm:"content_mapping" json:"contentMapping"`
|
||||
Capabilities map[string]any `orm:"capabilities" json:"capabilities"`
|
||||
RequestTemplate map[string]any `orm:"request_template" json:"requestTemplate"`
|
||||
SystemPromptTemplate string `orm:"system_prompt_template" json:"systemPromptTemplate"`
|
||||
Status int `orm:"status" json:"status"`
|
||||
Remark string `orm:"remark" json:"remark"`
|
||||
}
|
||||
|
||||
// providerProtocolCol 列名
|
||||
|
||||
@@ -164,7 +164,6 @@ func handleCallbackFailed(ctx context.Context, req *dto.CallbackReq, composeTask
|
||||
return err
|
||||
}
|
||||
|
||||
// handleCallbackSuccess 处理回调成功
|
||||
func handleCallbackSuccess(ctx context.Context, req *dto.CallbackReq, composeTask *entity.ComposeTask) error {
|
||||
// 1) 获取模型配置
|
||||
model, err := gateway.GetModelConfig(ctx, &gateway.AsynchModel{
|
||||
@@ -175,9 +174,35 @@ func handleCallbackSuccess(ctx context.Context, req *dto.CallbackReq, composeTas
|
||||
return fmt.Errorf("查询模型失败: %w", err)
|
||||
}
|
||||
|
||||
// 2) 合并附加结构
|
||||
// 2) 获取协议配置
|
||||
protocol, _ := dao.ProviderProtocol.Get(ctx, &entity.ProviderProtocol{
|
||||
ProviderName: model.OperatorName,
|
||||
Status: 1,
|
||||
})
|
||||
|
||||
// 3) 获取历史消息
|
||||
payload := composeTask.RequestPayload
|
||||
sessionId := gconv.String(payload["sessionId"])
|
||||
nodeId := gconv.String(payload["nodeId"])
|
||||
var history []dto.FlatMessage
|
||||
if sessionId != "" && nodeId != "" {
|
||||
h, _ := session.GetHistoryMessages(ctx, &dto.GetHistoryMessagesReq{
|
||||
SessionId: sessionId,
|
||||
NodeId: nodeId,
|
||||
})
|
||||
if h != nil {
|
||||
history = h.Messages
|
||||
}
|
||||
}
|
||||
|
||||
// 4) 合并附加结构
|
||||
messages := util.MergeConsult(composeTask.RequestPayload, req.Messages, model.ExtendMapping)
|
||||
// 3) 更新数据库
|
||||
// 5) 注入历史到 rounds 中
|
||||
if protocol != nil && len(history) > 0 {
|
||||
messages = InjectHistory(messages, history, protocol)
|
||||
}
|
||||
|
||||
// 6) 更新数据库
|
||||
_, err = dao.ComposeTask.Update(ctx, &entity.ComposeTask{
|
||||
TaskId: req.TaskId,
|
||||
Status: public.ComposeStatusSuccess,
|
||||
@@ -189,33 +214,26 @@ func handleCallbackSuccess(ctx context.Context, req *dto.CallbackReq, composeTas
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
//var userHistoryMsg map[string]any
|
||||
|
||||
// 7) 存储历史
|
||||
var epicycleId int64
|
||||
payload := composeTask.RequestPayload
|
||||
sessionId := gconv.String(payload["sessionId"])
|
||||
nodeId := gconv.String(payload["nodeId"])
|
||||
buildType := gconv.Int(payload["buildType"])
|
||||
if buildType == public.BuildTypePrompt && sessionId != "" && nodeId != "" {
|
||||
// 4) 获取历史内容并拼接
|
||||
_, _ = session.GetHistoryMessages(ctx, &dto.GetHistoryMessagesReq{
|
||||
SessionId: sessionId,
|
||||
NodeId: nodeId,
|
||||
})
|
||||
// 5) 存储提示词结果作为历史请求
|
||||
if userMsg := util.ExtractUserText(messages); userMsg != nil {
|
||||
epicycleId, err = dao.ComposeSession.Insert(ctx, &entity.ComposeSession{
|
||||
if sessionId != "" && nodeId != "" {
|
||||
if userMsg := util.ExtractUserText(req.Messages); userMsg != nil {
|
||||
epicycleId, _ = dao.ComposeSession.Insert(ctx, &entity.ComposeSession{
|
||||
NodeId: nodeId,
|
||||
SessionId: sessionId,
|
||||
RequestContent: userMsg,
|
||||
})
|
||||
}
|
||||
}
|
||||
// 6) 回调业务方
|
||||
|
||||
// 8) 回调业务方
|
||||
if composeTask.CallbackUrl != "" {
|
||||
composeTask.Status = public.ComposeStatusSuccess
|
||||
composeTask.ResultJson = messages
|
||||
_ = gateway.SendCallback(ctx, composeTask, epicycleId)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -257,7 +275,16 @@ func parseMessagesForResponse(messages any) any {
|
||||
}
|
||||
|
||||
func GetPromptText(ctx context.Context, req *dto.GetPromptTextReq) (*dto.GetPromptTextRes, error) {
|
||||
// 1) 获取协议配置
|
||||
protocol, err := dao.ProviderProtocol.Get(ctx, &entity.ProviderProtocol{
|
||||
ProviderName: "火山引擎",
|
||||
Status: 1,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 2) 获取历史消息
|
||||
history, err := session.GetHistoryMessages(ctx, &dto.GetHistoryMessagesReq{
|
||||
SessionId: "88888888",
|
||||
NodeId: "node1",
|
||||
@@ -265,7 +292,74 @@ func GetPromptText(ctx context.Context, req *dto.GetPromptTextReq) (*dto.GetProm
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 3) 模拟roundsData数据
|
||||
task, err := dao.ComposeTask.Get(ctx, &entity.ComposeTask{
|
||||
TaskId: "0e1872f0-0e73-42f1-9aa8-63d317300ffc",
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
fmt.Println("[打印数据]", task.ResultJson)
|
||||
fmt.Println("[打印历史]", history.Messages)
|
||||
fmt.Println("[打印协议]", protocol)
|
||||
return &dto.GetPromptTextRes{
|
||||
Messages: history.Messages,
|
||||
Messages: InjectHistory(task.ResultJson, history.Messages, protocol),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func InjectHistory(roundsData map[string]any, history []dto.FlatMessage, protocol *entity.ProviderProtocol) map[string]any {
|
||||
if protocol == nil || len(history) == 0 {
|
||||
return roundsData
|
||||
}
|
||||
// 1) 提取第一轮的 messages
|
||||
rounds := roundsData["rounds"].([]any)
|
||||
firstRound := rounds[0].(map[string]any)
|
||||
original := firstRound["messages"].([]any)
|
||||
|
||||
// 2) 按 merge_order 拼接
|
||||
result := make([]any, 0, len(original)+len(history))
|
||||
|
||||
for _, part := range protocol.MergeOrder {
|
||||
switch part {
|
||||
case "system":
|
||||
for _, m := range original {
|
||||
msg := m.(map[string]any)
|
||||
if gconv.String(msg["role"]) == "system" {
|
||||
result = append(result, msg)
|
||||
}
|
||||
}
|
||||
case "history":
|
||||
if gconv.Bool(protocol.Capabilities["support_history"]) {
|
||||
for _, msg := range history {
|
||||
result = append(result, map[string]any{
|
||||
"role": msg.Role,
|
||||
"content": msg.Content, // 纯字符串,不转换
|
||||
})
|
||||
}
|
||||
}
|
||||
case "user":
|
||||
for _, m := range original {
|
||||
msg := m.(map[string]any)
|
||||
if gconv.String(msg["role"]) == "user" {
|
||||
result = append(result, msg)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 3) 角色映射
|
||||
if len(protocol.RoleMapping) > 0 {
|
||||
for _, m := range result {
|
||||
msg := m.(map[string]any)
|
||||
role := gconv.String(msg["role"])
|
||||
if mapped, ok := protocol.RoleMapping[role]; ok {
|
||||
msg["role"] = mapped
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 4) 直接修改原对象
|
||||
firstRound["messages"] = result
|
||||
return roundsData
|
||||
}
|
||||
|
||||
@@ -10,6 +10,8 @@ import (
|
||||
|
||||
"prompts-core/dao"
|
||||
"prompts-core/model/entity"
|
||||
|
||||
"github.com/gogf/gf/v2/util/gconv"
|
||||
)
|
||||
|
||||
// PromptIR 统一 Prompt 中间表示
|
||||
@@ -34,6 +36,7 @@ type ProviderProtocol struct {
|
||||
ContentMapping ContentMapping `json:"content_mapping"`
|
||||
RequestTemplate map[string]any `json:"request_template"`
|
||||
SystemPromptTemplate string `json:"system_prompt_template"`
|
||||
Capabilities map[string]any `json:"capabilities"`
|
||||
}
|
||||
|
||||
// ContentMapping 内容字段映射
|
||||
@@ -175,6 +178,7 @@ func parseProtocol(e *entity.ProviderProtocol) *ProviderProtocol {
|
||||
util.ParseJSONFieldFromGvar(e.RoleMapping, &p.RoleMapping)
|
||||
util.ParseJSONFieldFromGvar(e.ContentMapping, &p.ContentMapping)
|
||||
util.ParseJSONFieldFromGvar(e.RequestTemplate, &p.RequestTemplate)
|
||||
util.ParseJSONFieldFromGvar(e.Capabilities, &p.Capabilities)
|
||||
return p
|
||||
}
|
||||
|
||||
@@ -265,7 +269,7 @@ func mapContent(messages []map[string]any, cm ContentMapping) []map[string]any {
|
||||
// buildRequest 按 target_field 和 request_template 构建请求体
|
||||
func buildRequest(messages []map[string]any, p *ProviderProtocol, chatModel *gateway.AsynchModel) map[string]any {
|
||||
if len(p.RequestTemplate) > 0 {
|
||||
return renderTemplate(p.RequestTemplate, messages, chatModel)
|
||||
return renderTemplate(p, messages, chatModel)
|
||||
}
|
||||
|
||||
return map[string]any{
|
||||
@@ -274,8 +278,8 @@ func buildRequest(messages []map[string]any, p *ProviderProtocol, chatModel *gat
|
||||
}
|
||||
|
||||
// renderTemplate 简单的 {{key}} 模板替换
|
||||
func renderTemplate(tmpl map[string]any, messages []map[string]any, chatModel *gateway.AsynchModel) map[string]any {
|
||||
b, _ := json.Marshal(tmpl)
|
||||
func renderTemplate(p *ProviderProtocol, messages []map[string]any, chatModel *gateway.AsynchModel) map[string]any {
|
||||
b, _ := json.Marshal(p.RequestTemplate)
|
||||
str := string(b)
|
||||
|
||||
if chatModel != nil {
|
||||
@@ -288,5 +292,9 @@ func renderTemplate(tmpl map[string]any, messages []map[string]any, chatModel *g
|
||||
var result map[string]any
|
||||
_ = json.Unmarshal([]byte(str), &result)
|
||||
|
||||
if maxTokens := gconv.Int(p.Capabilities["max_tokens"]); maxTokens > 0 {
|
||||
result["max_tokens"] = maxTokens
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user