refactor(prompt): 重构异步模型字段和提示词构建服务

This commit is contained in:
2026-05-21 10:53:58 +08:00
parent fee6528f93
commit 15f5761000
7 changed files with 266 additions and 151 deletions

View File

@@ -26,16 +26,16 @@ func buildInferenceRequest(ctx context.Context, req *dto.ComposeMessagesReq, cha
switch req.BuildType {
case public.BuildTypePrompt:
return buildPromptTypeRequest(ctx, processedReq, targetModel, history, ir, totalBatches)
return buildPromptTypeRequest(ctx, processedReq, targetModel, chatModel, history, ir, totalBatches)
case public.BuildTypeNode:
return buildNodeTypeRequest(ctx, req, ir)
return buildNodeTypeRequest(ctx, req, chatModel, ir)
default:
return nil, errors.New("不支持的构建类型")
}
}
// buildPromptTypeRequest 构建提示词类型请求BuildType=1
func buildPromptTypeRequest(ctx context.Context, req *dto.ComposeMessagesReq, targetModel *entity.AsynchModel, history []map[string]any, ir *PromptIR, totalBatches int) (map[string]any, error) {
func buildPromptTypeRequest(ctx context.Context, req *dto.ComposeMessagesReq, targetModel *entity.AsynchModel, chatModel *entity.AsynchModel, history []map[string]any, ir *PromptIR, totalBatches int) (map[string]any, error) {
systemPrompt := promptBuildWithRounds(ctx, req, targetModel, totalBatches)
ir.AddSystem(systemPrompt)
@@ -49,42 +49,23 @@ func buildPromptTypeRequest(ctx context.Context, req *dto.ComposeMessagesReq, ta
userPrompt := buildUserPrompt(ctx, req, util.GetModelPrompt(ctx, targetModel.ModelType))
ir.AddUser(userPrompt)
if !checkOverallContent(ir, targetModel) {
availableWindow := util.GetAvailableWindow(targetModel.TokenConfig)
return nil, fmt.Errorf("整体内容超出模型窗口大小限制(可用窗口=%d tokens),请精简后重试", availableWindow)
}
return compileToProviderRequest(ctx, ir, targetModel.OperatorName, targetModel)
return compileToProviderRequest(ctx, ir, targetModel.OperatorName, targetModel.ModelName, chatModel)
}
// buildNodeTypeRequest 构建节点类型请求BuildType=2
func buildNodeTypeRequest(ctx context.Context, req *dto.ComposeMessagesReq, ir *PromptIR) (map[string]any, error) {
func buildNodeTypeRequest(ctx context.Context, req *dto.ComposeMessagesReq, chatModel *entity.AsynchModel, ir *PromptIR) (map[string]any, error) {
ir.AddUser(NodeBuild(ctx, req))
protocol, err := GetProtocolByProvider(ctx, req.ModelName)
if err != nil {
return nil, fmt.Errorf("获取协议配置失败: %w", err)
}
if protocol == nil {
return nil, errors.New("协议配置不存在")
}
providerReq, err := Compile(ir, protocol, nil)
if err != nil {
return nil, fmt.Errorf("编译请求失败: %w", err)
}
return map[string]any{
"modelName": req.ModelName,
"bizName": "prompts-core",
"callbackUrl": "/prompt/callback",
"requestPayload": providerReq,
}, nil
return compileToProviderRequest(ctx, ir, req.ModelName, req.ModelName, chatModel)
}
// compileToProviderRequest 编译为 Provider 请求
func compileToProviderRequest(ctx context.Context, ir *PromptIR, providerName string, model *entity.AsynchModel) (map[string]any, error) {
func compileToProviderRequest(ctx context.Context, ir *PromptIR, providerName string, modelName string, chatModel *entity.AsynchModel) (map[string]any, error) {
protocol, err := GetProtocolByProvider(ctx, providerName)
if err != nil {
return nil, fmt.Errorf("获取协议配置失败: %w", err)
@@ -92,17 +73,15 @@ func compileToProviderRequest(ctx context.Context, ir *PromptIR, providerName st
if protocol == nil {
return nil, errors.New("协议配置不存在")
}
providerReq, err := Compile(ir, protocol, model)
providerReq, err := Compile(ir, protocol, chatModel)
if err != nil {
return nil, fmt.Errorf("编译请求失败: %w", err)
}
fmt.Println("providerReq打印:", util.MustMarshal(providerReq))
return map[string]any{
"modelName": model.ModelName,
"modelName": modelName,
"bizName": "prompts-core",
"callbackUrl": "/prompt/callback",
"callbackUrl": util.GetCallbackURL(ctx, "/prompt/callback"),
"requestPayload": providerReq,
}, nil
}