feat: 重构异步模型字段并更新依赖

This commit is contained in:
2026-06-08 18:01:54 +08:00
parent ee6677c1f8
commit e1461cf0f0
12 changed files with 219 additions and 335 deletions

View File

@@ -28,13 +28,13 @@ func buildPromptTypeRequest(ctx context.Context, req *dto.ComposeMessagesReq, ai
availableWindow := util.GetAvailableWindow(aiModel.TokenConfig)
return nil, fmt.Errorf("整体内容超出模型窗口大小限制(可用窗口=%d tokens),请精简后重试", availableWindow)
}
return compileToProviderRequest(ctx, ir, chatModel)
return compileToProviderRequest(ctx, ir, chatModel, req)
}
// buildNodeTypeRequest 构建节点类型请求BuildType=2
func buildNodeTypeRequest(ctx context.Context, req *dto.ComposeMessagesReq, chatModel *gateway.AsynchModel, ir *PromptIR) (map[string]any, error) {
ir.AddUser(NodeBuild(ctx, req))
return compileToProviderRequest(ctx, ir, chatModel)
return compileToProviderRequest(ctx, ir, chatModel, req)
}
// buildStructTypeRequest 构建结构体类型请求BuildType=3
@@ -50,18 +50,20 @@ func buildStructTypeRequest(ctx context.Context, req *dto.ComposeMessagesReq, ch
// 用户消息
ir.AddSystem(customPrompt)
ir.AddUser(buildUserPrompt(ctx, req, ""))
return compileToProviderRequest(ctx, ir, chatModel, customPrompt)
return compileToProviderRequest(ctx, ir, chatModel, req, customPrompt)
}
// compileToProviderRequest 编译为 Provider 请求
func compileToProviderRequest(ctx context.Context, ir *PromptIR, chatModel *gateway.AsynchModel, customPrompt ...string) (map[string]any, error) {
func compileToProviderRequest(ctx context.Context, ir *PromptIR, chatModel *gateway.AsynchModel, req *dto.ComposeMessagesReq, customPrompt ...string) (map[string]any, error) {
protocol, err := GetProtocolByProvider(ctx, chatModel.OperatorName)
if err != nil || protocol == nil {
return nil, fmt.Errorf("协议配置不存在或获取失败: %w", err)
}
// 如果传了自定义提示词,替换掉协议模板
if len(customPrompt) > 0 && customPrompt[0] != "" {
protocol.SystemPromptTemplate = customPrompt[0]
protocol.SystemPromptTemplate = customPrompt[0] +
"【核心铁律】" +
"1.【技能内容skill相关】必须完整拼接到System提示词中作为System提示词的组成部分不得拆分到其他位置。"
}
providerReq, err := Compile(ir, protocol, chatModel)
if err != nil {
@@ -72,6 +74,7 @@ func compileToProviderRequest(ctx context.Context, ir *PromptIR, chatModel *gate
"bizName": util.GetServerName(ctx),
"callbackUrl": utils.GetCallbackURL(ctx, "/prompt/callback"),
"requestPayload": providerReq,
"buildType": req.BuildType,
}, nil
}
@@ -84,20 +87,12 @@ func promptBuildWithRounds(ctx context.Context, chatModel *gateway.AsynchModel,
return ""
}
outputJSON := util.JSONPretty(util.ReverseMap(aiModel.RequestMapping, map[string]any{}))
return fmt.Sprintf(providerProtocol.SystemPromptTemplate,
outputJSON, //【输出结构】 %s
)
}
// buildUserFormContent 构建用户表单内容字符串
func buildUserFormContent(userForm []map[string]any) string {
var builder strings.Builder
for _, item := range userForm {
builder.WriteString(fmt.Sprintf("%v\n", item))
}
return builder.String()
}
// checkOverallContent 检查整体内容是否超出窗口
func checkOverallContent(ir *PromptIR, model *gateway.AsynchModel) bool {
fullContent := ir.String()

View File

@@ -5,6 +5,7 @@ import (
"encoding/json"
"errors"
"fmt"
"prompts-core/service/session"
"prompts-core/common/util"
"prompts-core/consts/public"
@@ -173,24 +174,10 @@ func handleCallbackSuccess(ctx context.Context, req *dto.CallbackReq, composeTas
if err != nil {
return fmt.Errorf("查询模型失败: %w", err)
}
// 2) 根据运营商获取协议配置
//protocol, err := dao.ProviderProtocol.Get(ctx, &entity.ProviderProtocol{
// ProviderName: model.OperatorName,
//})
// 2) 解析结果
var messages map[string]any
switch composeTask.BuildType {
case public.BuildTypePrompt, public.BuildTypeNode:
messages = ParseResult(req.Messages, model.ResponseBody)
case public.BuildTypeStruct:
messages = ParseStructResult(req.Messages, model.ResponseBody)
default:
messages = req.Messages
}
// 3) 合并附加结构
messages = util.MergeConsult(composeTask.RequestPayload, messages, model.ExtendMapping)
// 4) 更新数据库
// 2) 合并附加结构
messages := util.MergeConsult(composeTask.RequestPayload, req.Messages, model.ExtendMapping)
// 3) 更新数据库
_, err = dao.ComposeTask.Update(ctx, &entity.ComposeTask{
TaskId: req.TaskId,
Status: public.ComposeStatusSuccess,
@@ -203,21 +190,31 @@ func handleCallbackSuccess(ctx context.Context, req *dto.CallbackReq, composeTas
if err != nil {
return err
}
// 5) 存储提示词结果作为历史请求
//var userHistoryMsg map[string]any
var epicycleId int64
payload := composeTask.RequestPayload
sessionId := gconv.String(payload["sessionId"])
nodeId := gconv.String(payload["nodeId"])
buildType := gconv.Int(payload["buildType"])
if buildType == public.BuildTypePrompt && sessionId != "" && nodeId != "" {
epicycleId, err = dao.ComposeSession.Insert(ctx, &entity.ComposeSession{
NodeId: nodeId,
SessionId: sessionId,
RequestContent: messages,
})
// 4) 获取历史内容并拼接
history, _ := session.GetHistoryMessages(ctx, sessionId, nodeId)
for _, msg := range history {
role := gconv.String(msg["role"])
if role != "user" && role != "assistant" {
continue
}
}
// 5) 存储提示词结果作为历史请求
if userMsg := util.ExtractUserText(messages); userMsg != nil {
epicycleId, err = dao.ComposeSession.Insert(ctx, &entity.ComposeSession{
NodeId: nodeId,
SessionId: sessionId,
RequestContent: userMsg,
})
}
}
// 6) 拼接历史内容
// 7) 回调业务方
// 6) 回调业务方
if composeTask.CallbackUrl != "" {
composeTask.Status = public.ComposeStatusSuccess
composeTask.Messages = messages
@@ -226,95 +223,6 @@ func handleCallbackSuccess(ctx context.Context, req *dto.CallbackReq, composeTas
return nil
}
// ParseResult 解析结果
func ParseResult(raw map[string]any, responseBody string) map[string]any {
if responseBody == "" {
return raw
}
contentVal := raw[responseBody]
if contentVal == nil {
return raw
}
// 已经是数组
if arr, ok := contentVal.([]any); ok {
rounds := gconv.Maps(arr)
if len(rounds) > 0 {
return map[string]any{"total_rounds": len(rounds), "rounds": rounds}
}
return raw
}
// 是字符串
contentStr := gconv.String(contentVal)
if contentStr == "" {
return raw
}
// 尝试解析为数组
var arr []map[string]any
if err := json.Unmarshal([]byte(contentStr), &arr); err == nil && len(arr) > 0 {
return map[string]any{"total_rounds": len(arr), "rounds": arr}
}
// 尝试解析为单对象
var obj map[string]any
if err := json.Unmarshal([]byte(contentStr), &obj); err == nil && len(obj) > 0 {
return map[string]any{"total_rounds": 1, "rounds": []map[string]any{obj}}
}
return map[string]any{"content": contentStr}
}
func ParseStructResult(raw map[string]any, responseBody string) map[string]any {
// 如果外层已有 rounds直接返回
if _, ok := raw["rounds"]; ok {
return raw
}
contentVal := raw[responseBody]
var rounds []map[string]any
// 是字符串,尝试解析
contentStr := gconv.String(contentVal)
if contentStr == "" || contentStr == "0" {
rounds = append(rounds, map[string]any{responseBody: raw})
return map[string]any{
"total_rounds": 1,
"rounds": rounds,
}
}
// 尝试解析为数组
var arr []any
if err := json.Unmarshal([]byte(contentStr), &arr); err == nil && len(arr) > 0 {
rounds = append(rounds, map[string]any{responseBody: arr})
return map[string]any{
"total_rounds": len(rounds),
"rounds": rounds,
}
}
// 尝试解析为单个对象
var parsed any
if err := json.Unmarshal([]byte(contentStr), &parsed); err == nil {
rounds = append(rounds, map[string]any{responseBody: parsed})
return map[string]any{
"total_rounds": 1,
"rounds": rounds,
}
}
// 兜底:原始字符串作为内容
rounds = append(rounds, map[string]any{responseBody: contentStr})
return map[string]any{
"total_rounds": 1,
"rounds": rounds,
}
}
// GetComposeTask 查询任务结果
func GetComposeTask(ctx context.Context, taskID string) (*dto.GetComposeTaskRes, error) {
record, err := dao.ComposeTask.Get(ctx, &entity.ComposeTask{
@@ -351,3 +259,13 @@ func parseMessagesForResponse(messages any) any {
return messages
}
func GetPromptText(ctx context.Context, req *dto.GetPromptTextReq) (*dto.GetPromptTextRes, error) {
// 1) 获取基础数据
// 4) 模拟历史拼接
history, _ := session.GetHistoryMessages(ctx, "88888888", "node1")
return &dto.GetPromptTextRes{
Messages: history,
}, nil
}