feat(prompt): 实现历史消息注入功能和协议配置优化
- 在 handleCallbackSuccess 函数中新增获取协议配置逻辑 - 实现历史消息获取并在 rounds 中注入历史消息 - 添加 InjectHistory 函数实现按协议顺序合并历史消息 - 在 GetPromptText 接口中集成历史消息注入测试 - 更新 ProviderProtocol 实体中的 MergeOrder 类型为 []string - 新增 Capabilities 字段支持最大 token 配置 - 修改 renderTemplate 函数接收协议对象参数 - 优化会话历史存储逻辑,提取用户消息内容进行记录 - 移除无用的注释代码 handleCallbackSuccess 处理回调成功
This commit is contained in:
@@ -8,11 +8,11 @@ type ProviderProtocol struct {
|
|||||||
// 业务字段
|
// 业务字段
|
||||||
ProviderName string `orm:"provider_name" json:"providerName"`
|
ProviderName string `orm:"provider_name" json:"providerName"`
|
||||||
TargetField string `orm:"target_field" json:"targetField"`
|
TargetField string `orm:"target_field" json:"targetField"`
|
||||||
MergeOrder any `orm:"merge_order" json:"mergeOrder"`
|
MergeOrder []string `orm:"merge_order" json:"mergeOrder"`
|
||||||
RoleMapping any `orm:"role_mapping" json:"roleMapping"`
|
RoleMapping map[string]any `orm:"role_mapping" json:"roleMapping"`
|
||||||
ContentMapping any `orm:"content_mapping" json:"contentMapping"`
|
ContentMapping map[string]any `orm:"content_mapping" json:"contentMapping"`
|
||||||
Capabilities any `orm:"capabilities" json:"capabilities"`
|
Capabilities map[string]any `orm:"capabilities" json:"capabilities"`
|
||||||
RequestTemplate any `orm:"request_template" json:"requestTemplate"`
|
RequestTemplate map[string]any `orm:"request_template" json:"requestTemplate"`
|
||||||
SystemPromptTemplate string `orm:"system_prompt_template" json:"systemPromptTemplate"`
|
SystemPromptTemplate string `orm:"system_prompt_template" json:"systemPromptTemplate"`
|
||||||
Status int `orm:"status" json:"status"`
|
Status int `orm:"status" json:"status"`
|
||||||
Remark string `orm:"remark" json:"remark"`
|
Remark string `orm:"remark" json:"remark"`
|
||||||
|
|||||||
@@ -164,7 +164,6 @@ func handleCallbackFailed(ctx context.Context, req *dto.CallbackReq, composeTask
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// handleCallbackSuccess 处理回调成功
|
|
||||||
func handleCallbackSuccess(ctx context.Context, req *dto.CallbackReq, composeTask *entity.ComposeTask) error {
|
func handleCallbackSuccess(ctx context.Context, req *dto.CallbackReq, composeTask *entity.ComposeTask) error {
|
||||||
// 1) 获取模型配置
|
// 1) 获取模型配置
|
||||||
model, err := gateway.GetModelConfig(ctx, &gateway.AsynchModel{
|
model, err := gateway.GetModelConfig(ctx, &gateway.AsynchModel{
|
||||||
@@ -175,9 +174,35 @@ func handleCallbackSuccess(ctx context.Context, req *dto.CallbackReq, composeTas
|
|||||||
return fmt.Errorf("查询模型失败: %w", err)
|
return fmt.Errorf("查询模型失败: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// 2) 合并附加结构
|
// 2) 获取协议配置
|
||||||
|
protocol, _ := dao.ProviderProtocol.Get(ctx, &entity.ProviderProtocol{
|
||||||
|
ProviderName: model.OperatorName,
|
||||||
|
Status: 1,
|
||||||
|
})
|
||||||
|
|
||||||
|
// 3) 获取历史消息
|
||||||
|
payload := composeTask.RequestPayload
|
||||||
|
sessionId := gconv.String(payload["sessionId"])
|
||||||
|
nodeId := gconv.String(payload["nodeId"])
|
||||||
|
var history []dto.FlatMessage
|
||||||
|
if sessionId != "" && nodeId != "" {
|
||||||
|
h, _ := session.GetHistoryMessages(ctx, &dto.GetHistoryMessagesReq{
|
||||||
|
SessionId: sessionId,
|
||||||
|
NodeId: nodeId,
|
||||||
|
})
|
||||||
|
if h != nil {
|
||||||
|
history = h.Messages
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 4) 合并附加结构
|
||||||
messages := util.MergeConsult(composeTask.RequestPayload, req.Messages, model.ExtendMapping)
|
messages := util.MergeConsult(composeTask.RequestPayload, req.Messages, model.ExtendMapping)
|
||||||
// 3) 更新数据库
|
// 5) 注入历史到 rounds 中
|
||||||
|
if protocol != nil && len(history) > 0 {
|
||||||
|
messages = InjectHistory(messages, history, protocol)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 6) 更新数据库
|
||||||
_, err = dao.ComposeTask.Update(ctx, &entity.ComposeTask{
|
_, err = dao.ComposeTask.Update(ctx, &entity.ComposeTask{
|
||||||
TaskId: req.TaskId,
|
TaskId: req.TaskId,
|
||||||
Status: public.ComposeStatusSuccess,
|
Status: public.ComposeStatusSuccess,
|
||||||
@@ -189,33 +214,26 @@ func handleCallbackSuccess(ctx context.Context, req *dto.CallbackReq, composeTas
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
//var userHistoryMsg map[string]any
|
|
||||||
|
// 7) 存储历史
|
||||||
var epicycleId int64
|
var epicycleId int64
|
||||||
payload := composeTask.RequestPayload
|
if sessionId != "" && nodeId != "" {
|
||||||
sessionId := gconv.String(payload["sessionId"])
|
if userMsg := util.ExtractUserText(req.Messages); userMsg != nil {
|
||||||
nodeId := gconv.String(payload["nodeId"])
|
epicycleId, _ = dao.ComposeSession.Insert(ctx, &entity.ComposeSession{
|
||||||
buildType := gconv.Int(payload["buildType"])
|
|
||||||
if buildType == public.BuildTypePrompt && sessionId != "" && nodeId != "" {
|
|
||||||
// 4) 获取历史内容并拼接
|
|
||||||
_, _ = session.GetHistoryMessages(ctx, &dto.GetHistoryMessagesReq{
|
|
||||||
SessionId: sessionId,
|
|
||||||
NodeId: nodeId,
|
|
||||||
})
|
|
||||||
// 5) 存储提示词结果作为历史请求
|
|
||||||
if userMsg := util.ExtractUserText(messages); userMsg != nil {
|
|
||||||
epicycleId, err = dao.ComposeSession.Insert(ctx, &entity.ComposeSession{
|
|
||||||
NodeId: nodeId,
|
NodeId: nodeId,
|
||||||
SessionId: sessionId,
|
SessionId: sessionId,
|
||||||
RequestContent: userMsg,
|
RequestContent: userMsg,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// 6) 回调业务方
|
|
||||||
|
// 8) 回调业务方
|
||||||
if composeTask.CallbackUrl != "" {
|
if composeTask.CallbackUrl != "" {
|
||||||
composeTask.Status = public.ComposeStatusSuccess
|
composeTask.Status = public.ComposeStatusSuccess
|
||||||
composeTask.ResultJson = messages
|
composeTask.ResultJson = messages
|
||||||
_ = gateway.SendCallback(ctx, composeTask, epicycleId)
|
_ = gateway.SendCallback(ctx, composeTask, epicycleId)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -257,7 +275,16 @@ func parseMessagesForResponse(messages any) any {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func GetPromptText(ctx context.Context, req *dto.GetPromptTextReq) (*dto.GetPromptTextRes, error) {
|
func GetPromptText(ctx context.Context, req *dto.GetPromptTextReq) (*dto.GetPromptTextRes, error) {
|
||||||
|
// 1) 获取协议配置
|
||||||
|
protocol, err := dao.ProviderProtocol.Get(ctx, &entity.ProviderProtocol{
|
||||||
|
ProviderName: "火山引擎",
|
||||||
|
Status: 1,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// 2) 获取历史消息
|
||||||
history, err := session.GetHistoryMessages(ctx, &dto.GetHistoryMessagesReq{
|
history, err := session.GetHistoryMessages(ctx, &dto.GetHistoryMessagesReq{
|
||||||
SessionId: "88888888",
|
SessionId: "88888888",
|
||||||
NodeId: "node1",
|
NodeId: "node1",
|
||||||
@@ -265,7 +292,74 @@ func GetPromptText(ctx context.Context, req *dto.GetPromptTextReq) (*dto.GetProm
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 3) 模拟roundsData数据
|
||||||
|
task, err := dao.ComposeTask.Get(ctx, &entity.ComposeTask{
|
||||||
|
TaskId: "0e1872f0-0e73-42f1-9aa8-63d317300ffc",
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
fmt.Println("[打印数据]", task.ResultJson)
|
||||||
|
fmt.Println("[打印历史]", history.Messages)
|
||||||
|
fmt.Println("[打印协议]", protocol)
|
||||||
return &dto.GetPromptTextRes{
|
return &dto.GetPromptTextRes{
|
||||||
Messages: history.Messages,
|
Messages: InjectHistory(task.ResultJson, history.Messages, protocol),
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func InjectHistory(roundsData map[string]any, history []dto.FlatMessage, protocol *entity.ProviderProtocol) map[string]any {
|
||||||
|
if protocol == nil || len(history) == 0 {
|
||||||
|
return roundsData
|
||||||
|
}
|
||||||
|
// 1) 提取第一轮的 messages
|
||||||
|
rounds := roundsData["rounds"].([]any)
|
||||||
|
firstRound := rounds[0].(map[string]any)
|
||||||
|
original := firstRound["messages"].([]any)
|
||||||
|
|
||||||
|
// 2) 按 merge_order 拼接
|
||||||
|
result := make([]any, 0, len(original)+len(history))
|
||||||
|
|
||||||
|
for _, part := range protocol.MergeOrder {
|
||||||
|
switch part {
|
||||||
|
case "system":
|
||||||
|
for _, m := range original {
|
||||||
|
msg := m.(map[string]any)
|
||||||
|
if gconv.String(msg["role"]) == "system" {
|
||||||
|
result = append(result, msg)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
case "history":
|
||||||
|
if gconv.Bool(protocol.Capabilities["support_history"]) {
|
||||||
|
for _, msg := range history {
|
||||||
|
result = append(result, map[string]any{
|
||||||
|
"role": msg.Role,
|
||||||
|
"content": msg.Content, // 纯字符串,不转换
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
case "user":
|
||||||
|
for _, m := range original {
|
||||||
|
msg := m.(map[string]any)
|
||||||
|
if gconv.String(msg["role"]) == "user" {
|
||||||
|
result = append(result, msg)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 3) 角色映射
|
||||||
|
if len(protocol.RoleMapping) > 0 {
|
||||||
|
for _, m := range result {
|
||||||
|
msg := m.(map[string]any)
|
||||||
|
role := gconv.String(msg["role"])
|
||||||
|
if mapped, ok := protocol.RoleMapping[role]; ok {
|
||||||
|
msg["role"] = mapped
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 4) 直接修改原对象
|
||||||
|
firstRound["messages"] = result
|
||||||
|
return roundsData
|
||||||
|
}
|
||||||
|
|||||||
@@ -10,6 +10,8 @@ import (
|
|||||||
|
|
||||||
"prompts-core/dao"
|
"prompts-core/dao"
|
||||||
"prompts-core/model/entity"
|
"prompts-core/model/entity"
|
||||||
|
|
||||||
|
"github.com/gogf/gf/v2/util/gconv"
|
||||||
)
|
)
|
||||||
|
|
||||||
// PromptIR 统一 Prompt 中间表示
|
// PromptIR 统一 Prompt 中间表示
|
||||||
@@ -34,6 +36,7 @@ type ProviderProtocol struct {
|
|||||||
ContentMapping ContentMapping `json:"content_mapping"`
|
ContentMapping ContentMapping `json:"content_mapping"`
|
||||||
RequestTemplate map[string]any `json:"request_template"`
|
RequestTemplate map[string]any `json:"request_template"`
|
||||||
SystemPromptTemplate string `json:"system_prompt_template"`
|
SystemPromptTemplate string `json:"system_prompt_template"`
|
||||||
|
Capabilities map[string]any `json:"capabilities"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// ContentMapping 内容字段映射
|
// ContentMapping 内容字段映射
|
||||||
@@ -175,6 +178,7 @@ func parseProtocol(e *entity.ProviderProtocol) *ProviderProtocol {
|
|||||||
util.ParseJSONFieldFromGvar(e.RoleMapping, &p.RoleMapping)
|
util.ParseJSONFieldFromGvar(e.RoleMapping, &p.RoleMapping)
|
||||||
util.ParseJSONFieldFromGvar(e.ContentMapping, &p.ContentMapping)
|
util.ParseJSONFieldFromGvar(e.ContentMapping, &p.ContentMapping)
|
||||||
util.ParseJSONFieldFromGvar(e.RequestTemplate, &p.RequestTemplate)
|
util.ParseJSONFieldFromGvar(e.RequestTemplate, &p.RequestTemplate)
|
||||||
|
util.ParseJSONFieldFromGvar(e.Capabilities, &p.Capabilities)
|
||||||
return p
|
return p
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -265,7 +269,7 @@ func mapContent(messages []map[string]any, cm ContentMapping) []map[string]any {
|
|||||||
// buildRequest 按 target_field 和 request_template 构建请求体
|
// buildRequest 按 target_field 和 request_template 构建请求体
|
||||||
func buildRequest(messages []map[string]any, p *ProviderProtocol, chatModel *gateway.AsynchModel) map[string]any {
|
func buildRequest(messages []map[string]any, p *ProviderProtocol, chatModel *gateway.AsynchModel) map[string]any {
|
||||||
if len(p.RequestTemplate) > 0 {
|
if len(p.RequestTemplate) > 0 {
|
||||||
return renderTemplate(p.RequestTemplate, messages, chatModel)
|
return renderTemplate(p, messages, chatModel)
|
||||||
}
|
}
|
||||||
|
|
||||||
return map[string]any{
|
return map[string]any{
|
||||||
@@ -274,8 +278,8 @@ func buildRequest(messages []map[string]any, p *ProviderProtocol, chatModel *gat
|
|||||||
}
|
}
|
||||||
|
|
||||||
// renderTemplate 简单的 {{key}} 模板替换
|
// renderTemplate 简单的 {{key}} 模板替换
|
||||||
func renderTemplate(tmpl map[string]any, messages []map[string]any, chatModel *gateway.AsynchModel) map[string]any {
|
func renderTemplate(p *ProviderProtocol, messages []map[string]any, chatModel *gateway.AsynchModel) map[string]any {
|
||||||
b, _ := json.Marshal(tmpl)
|
b, _ := json.Marshal(p.RequestTemplate)
|
||||||
str := string(b)
|
str := string(b)
|
||||||
|
|
||||||
if chatModel != nil {
|
if chatModel != nil {
|
||||||
@@ -288,5 +292,9 @@ func renderTemplate(tmpl map[string]any, messages []map[string]any, chatModel *g
|
|||||||
var result map[string]any
|
var result map[string]any
|
||||||
_ = json.Unmarshal([]byte(str), &result)
|
_ = json.Unmarshal([]byte(str), &result)
|
||||||
|
|
||||||
|
if maxTokens := gconv.Int(p.Capabilities["max_tokens"]); maxTokens > 0 {
|
||||||
|
result["max_tokens"] = maxTokens
|
||||||
|
}
|
||||||
|
|
||||||
return result
|
return result
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user