feat: 重构异步模型字段并更新依赖
This commit is contained in:
@@ -28,13 +28,13 @@ func buildPromptTypeRequest(ctx context.Context, req *dto.ComposeMessagesReq, ai
|
||||
availableWindow := util.GetAvailableWindow(aiModel.TokenConfig)
|
||||
return nil, fmt.Errorf("整体内容超出模型窗口大小限制(可用窗口=%d tokens),请精简后重试", availableWindow)
|
||||
}
|
||||
return compileToProviderRequest(ctx, ir, chatModel)
|
||||
return compileToProviderRequest(ctx, ir, chatModel, req)
|
||||
}
|
||||
|
||||
// buildNodeTypeRequest 构建节点类型请求(BuildType=2)
|
||||
func buildNodeTypeRequest(ctx context.Context, req *dto.ComposeMessagesReq, chatModel *gateway.AsynchModel, ir *PromptIR) (map[string]any, error) {
|
||||
ir.AddUser(NodeBuild(ctx, req))
|
||||
return compileToProviderRequest(ctx, ir, chatModel)
|
||||
return compileToProviderRequest(ctx, ir, chatModel, req)
|
||||
}
|
||||
|
||||
// buildStructTypeRequest 构建结构体类型请求(BuildType=3)
|
||||
@@ -50,18 +50,20 @@ func buildStructTypeRequest(ctx context.Context, req *dto.ComposeMessagesReq, ch
|
||||
// 用户消息
|
||||
ir.AddSystem(customPrompt)
|
||||
ir.AddUser(buildUserPrompt(ctx, req, ""))
|
||||
return compileToProviderRequest(ctx, ir, chatModel, customPrompt)
|
||||
return compileToProviderRequest(ctx, ir, chatModel, req, customPrompt)
|
||||
}
|
||||
|
||||
// compileToProviderRequest 编译为 Provider 请求
|
||||
func compileToProviderRequest(ctx context.Context, ir *PromptIR, chatModel *gateway.AsynchModel, customPrompt ...string) (map[string]any, error) {
|
||||
func compileToProviderRequest(ctx context.Context, ir *PromptIR, chatModel *gateway.AsynchModel, req *dto.ComposeMessagesReq, customPrompt ...string) (map[string]any, error) {
|
||||
protocol, err := GetProtocolByProvider(ctx, chatModel.OperatorName)
|
||||
if err != nil || protocol == nil {
|
||||
return nil, fmt.Errorf("协议配置不存在或获取失败: %w", err)
|
||||
}
|
||||
// 如果传了自定义提示词,替换掉协议模板
|
||||
if len(customPrompt) > 0 && customPrompt[0] != "" {
|
||||
protocol.SystemPromptTemplate = customPrompt[0]
|
||||
protocol.SystemPromptTemplate = customPrompt[0] +
|
||||
"【核心铁律】" +
|
||||
"1.【技能内容skill相关】必须完整拼接到System提示词中,作为System提示词的组成部分,不得拆分到其他位置。"
|
||||
}
|
||||
providerReq, err := Compile(ir, protocol, chatModel)
|
||||
if err != nil {
|
||||
@@ -72,6 +74,7 @@ func compileToProviderRequest(ctx context.Context, ir *PromptIR, chatModel *gate
|
||||
"bizName": util.GetServerName(ctx),
|
||||
"callbackUrl": utils.GetCallbackURL(ctx, "/prompt/callback"),
|
||||
"requestPayload": providerReq,
|
||||
"buildType": req.BuildType,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -84,20 +87,12 @@ func promptBuildWithRounds(ctx context.Context, chatModel *gateway.AsynchModel,
|
||||
return ""
|
||||
}
|
||||
outputJSON := util.JSONPretty(util.ReverseMap(aiModel.RequestMapping, map[string]any{}))
|
||||
|
||||
return fmt.Sprintf(providerProtocol.SystemPromptTemplate,
|
||||
outputJSON, //【输出结构】 %s
|
||||
)
|
||||
}
|
||||
|
||||
// buildUserFormContent 构建用户表单内容字符串
|
||||
func buildUserFormContent(userForm []map[string]any) string {
|
||||
var builder strings.Builder
|
||||
for _, item := range userForm {
|
||||
builder.WriteString(fmt.Sprintf("%v\n", item))
|
||||
}
|
||||
return builder.String()
|
||||
}
|
||||
|
||||
// checkOverallContent 检查整体内容是否超出窗口
|
||||
func checkOverallContent(ir *PromptIR, model *gateway.AsynchModel) bool {
|
||||
fullContent := ir.String()
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"prompts-core/service/session"
|
||||
|
||||
"prompts-core/common/util"
|
||||
"prompts-core/consts/public"
|
||||
@@ -173,24 +174,10 @@ func handleCallbackSuccess(ctx context.Context, req *dto.CallbackReq, composeTas
|
||||
if err != nil {
|
||||
return fmt.Errorf("查询模型失败: %w", err)
|
||||
}
|
||||
// 2) 根据运营商获取协议配置
|
||||
//protocol, err := dao.ProviderProtocol.Get(ctx, &entity.ProviderProtocol{
|
||||
// ProviderName: model.OperatorName,
|
||||
//})
|
||||
|
||||
// 2) 解析结果
|
||||
var messages map[string]any
|
||||
switch composeTask.BuildType {
|
||||
case public.BuildTypePrompt, public.BuildTypeNode:
|
||||
messages = ParseResult(req.Messages, model.ResponseBody)
|
||||
case public.BuildTypeStruct:
|
||||
messages = ParseStructResult(req.Messages, model.ResponseBody)
|
||||
default:
|
||||
messages = req.Messages
|
||||
}
|
||||
// 3) 合并附加结构
|
||||
messages = util.MergeConsult(composeTask.RequestPayload, messages, model.ExtendMapping)
|
||||
// 4) 更新数据库
|
||||
// 2) 合并附加结构
|
||||
messages := util.MergeConsult(composeTask.RequestPayload, req.Messages, model.ExtendMapping)
|
||||
// 3) 更新数据库
|
||||
_, err = dao.ComposeTask.Update(ctx, &entity.ComposeTask{
|
||||
TaskId: req.TaskId,
|
||||
Status: public.ComposeStatusSuccess,
|
||||
@@ -203,21 +190,31 @@ func handleCallbackSuccess(ctx context.Context, req *dto.CallbackReq, composeTas
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// 5) 存储提示词结果作为历史请求
|
||||
//var userHistoryMsg map[string]any
|
||||
var epicycleId int64
|
||||
payload := composeTask.RequestPayload
|
||||
sessionId := gconv.String(payload["sessionId"])
|
||||
nodeId := gconv.String(payload["nodeId"])
|
||||
buildType := gconv.Int(payload["buildType"])
|
||||
if buildType == public.BuildTypePrompt && sessionId != "" && nodeId != "" {
|
||||
epicycleId, err = dao.ComposeSession.Insert(ctx, &entity.ComposeSession{
|
||||
NodeId: nodeId,
|
||||
SessionId: sessionId,
|
||||
RequestContent: messages,
|
||||
})
|
||||
// 4) 获取历史内容并拼接
|
||||
history, _ := session.GetHistoryMessages(ctx, sessionId, nodeId)
|
||||
for _, msg := range history {
|
||||
role := gconv.String(msg["role"])
|
||||
if role != "user" && role != "assistant" {
|
||||
continue
|
||||
}
|
||||
}
|
||||
// 5) 存储提示词结果作为历史请求
|
||||
if userMsg := util.ExtractUserText(messages); userMsg != nil {
|
||||
epicycleId, err = dao.ComposeSession.Insert(ctx, &entity.ComposeSession{
|
||||
NodeId: nodeId,
|
||||
SessionId: sessionId,
|
||||
RequestContent: userMsg,
|
||||
})
|
||||
}
|
||||
}
|
||||
// 6) 拼接历史内容
|
||||
// 7) 回调业务方
|
||||
// 6) 回调业务方
|
||||
if composeTask.CallbackUrl != "" {
|
||||
composeTask.Status = public.ComposeStatusSuccess
|
||||
composeTask.Messages = messages
|
||||
@@ -226,95 +223,6 @@ func handleCallbackSuccess(ctx context.Context, req *dto.CallbackReq, composeTas
|
||||
return nil
|
||||
}
|
||||
|
||||
// ParseResult 解析结果
|
||||
func ParseResult(raw map[string]any, responseBody string) map[string]any {
|
||||
if responseBody == "" {
|
||||
return raw
|
||||
}
|
||||
|
||||
contentVal := raw[responseBody]
|
||||
if contentVal == nil {
|
||||
return raw
|
||||
}
|
||||
|
||||
// 已经是数组
|
||||
if arr, ok := contentVal.([]any); ok {
|
||||
rounds := gconv.Maps(arr)
|
||||
if len(rounds) > 0 {
|
||||
return map[string]any{"total_rounds": len(rounds), "rounds": rounds}
|
||||
}
|
||||
return raw
|
||||
}
|
||||
|
||||
// 是字符串
|
||||
contentStr := gconv.String(contentVal)
|
||||
if contentStr == "" {
|
||||
return raw
|
||||
}
|
||||
|
||||
// 尝试解析为数组
|
||||
var arr []map[string]any
|
||||
if err := json.Unmarshal([]byte(contentStr), &arr); err == nil && len(arr) > 0 {
|
||||
return map[string]any{"total_rounds": len(arr), "rounds": arr}
|
||||
}
|
||||
|
||||
// 尝试解析为单对象
|
||||
var obj map[string]any
|
||||
if err := json.Unmarshal([]byte(contentStr), &obj); err == nil && len(obj) > 0 {
|
||||
return map[string]any{"total_rounds": 1, "rounds": []map[string]any{obj}}
|
||||
}
|
||||
|
||||
return map[string]any{"content": contentStr}
|
||||
}
|
||||
|
||||
func ParseStructResult(raw map[string]any, responseBody string) map[string]any {
|
||||
// 如果外层已有 rounds,直接返回
|
||||
if _, ok := raw["rounds"]; ok {
|
||||
return raw
|
||||
}
|
||||
|
||||
contentVal := raw[responseBody]
|
||||
|
||||
var rounds []map[string]any
|
||||
|
||||
// 是字符串,尝试解析
|
||||
contentStr := gconv.String(contentVal)
|
||||
if contentStr == "" || contentStr == "0" {
|
||||
rounds = append(rounds, map[string]any{responseBody: raw})
|
||||
return map[string]any{
|
||||
"total_rounds": 1,
|
||||
"rounds": rounds,
|
||||
}
|
||||
}
|
||||
|
||||
// 尝试解析为数组
|
||||
var arr []any
|
||||
if err := json.Unmarshal([]byte(contentStr), &arr); err == nil && len(arr) > 0 {
|
||||
rounds = append(rounds, map[string]any{responseBody: arr})
|
||||
return map[string]any{
|
||||
"total_rounds": len(rounds),
|
||||
"rounds": rounds,
|
||||
}
|
||||
}
|
||||
|
||||
// 尝试解析为单个对象
|
||||
var parsed any
|
||||
if err := json.Unmarshal([]byte(contentStr), &parsed); err == nil {
|
||||
rounds = append(rounds, map[string]any{responseBody: parsed})
|
||||
return map[string]any{
|
||||
"total_rounds": 1,
|
||||
"rounds": rounds,
|
||||
}
|
||||
}
|
||||
|
||||
// 兜底:原始字符串作为内容
|
||||
rounds = append(rounds, map[string]any{responseBody: contentStr})
|
||||
return map[string]any{
|
||||
"total_rounds": 1,
|
||||
"rounds": rounds,
|
||||
}
|
||||
}
|
||||
|
||||
// GetComposeTask 查询任务结果
|
||||
func GetComposeTask(ctx context.Context, taskID string) (*dto.GetComposeTaskRes, error) {
|
||||
record, err := dao.ComposeTask.Get(ctx, &entity.ComposeTask{
|
||||
@@ -351,3 +259,13 @@ func parseMessagesForResponse(messages any) any {
|
||||
|
||||
return messages
|
||||
}
|
||||
|
||||
func GetPromptText(ctx context.Context, req *dto.GetPromptTextReq) (*dto.GetPromptTextRes, error) {
|
||||
// 1) 获取基础数据
|
||||
|
||||
// 4) 模拟历史拼接
|
||||
history, _ := session.GetHistoryMessages(ctx, "88888888", "node1")
|
||||
return &dto.GetPromptTextRes{
|
||||
Messages: history,
|
||||
}, nil
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user