feat(prompt): 重构提示词服务并添加模型类型子分类

This commit is contained in:
2026-05-22 09:49:46 +08:00
parent a34eb4ea61
commit 92092575bc
9 changed files with 353 additions and 540 deletions

View File

@@ -16,8 +16,8 @@ import (
)
// buildInferenceRequest 构建推理请求
func buildInferenceRequest(ctx context.Context, req *dto.ComposeMessagesReq, chatModel *entity.AsynchModel, targetModel *entity.AsynchModel, history []map[string]any) (map[string]any, error) {
processedReq, totalBatches, err := ProcessUserFormBatches(ctx, req, targetModel)
func buildInferenceRequest(ctx context.Context, req *dto.ComposeMessagesReq, chatModel *entity.AsynchModel, aiModel *entity.AsynchModel, history []map[string]any) (map[string]any, error) {
processedReq, totalBatches, err := ProcessUserFormBatches(ctx, req, aiModel)
if err != nil {
return nil, fmt.Errorf("处理用户表单分批失败: %w", err)
}
@@ -26,7 +26,7 @@ func buildInferenceRequest(ctx context.Context, req *dto.ComposeMessagesReq, cha
switch req.BuildType {
case public.BuildTypePrompt:
return buildPromptTypeRequest(ctx, processedReq, targetModel, chatModel, history, ir, totalBatches)
return buildPromptTypeRequest(ctx, processedReq, aiModel, chatModel, history, ir, totalBatches)
case public.BuildTypeNode:
return buildNodeTypeRequest(ctx, req, chatModel, ir)
default:
@@ -35,8 +35,8 @@ func buildInferenceRequest(ctx context.Context, req *dto.ComposeMessagesReq, cha
}
// buildPromptTypeRequest 构建提示词类型请求BuildType=1
func buildPromptTypeRequest(ctx context.Context, req *dto.ComposeMessagesReq, targetModel *entity.AsynchModel, chatModel *entity.AsynchModel, history []map[string]any, ir *PromptIR, totalBatches int) (map[string]any, error) {
systemPrompt := promptBuildWithRounds(ctx, req, targetModel, totalBatches)
func buildPromptTypeRequest(ctx context.Context, req *dto.ComposeMessagesReq, aiModel *entity.AsynchModel, chatModel *entity.AsynchModel, history []map[string]any, ir *PromptIR, totalBatches int) (map[string]any, error) {
systemPrompt := promptBuildWithRounds(ctx, req, aiModel, totalBatches)
ir.AddSystem(systemPrompt)
for _, msg := range history {
@@ -47,26 +47,30 @@ func buildPromptTypeRequest(ctx context.Context, req *dto.ComposeMessagesReq, ta
ir.AddHistory(role, gconv.String(msg["content"]))
}
userPrompt := buildUserPrompt(ctx, req, util.GetModelPrompt(ctx, targetModel.ModelType))
userPrompt := buildUserPrompt(ctx, req, util.GetModelPrompt(ctx, aiModel.ModelType))
ir.AddUser(userPrompt)
if !checkOverallContent(ir, targetModel) {
availableWindow := util.GetAvailableWindow(targetModel.TokenConfig)
if !checkOverallContent(ir, aiModel) {
availableWindow := util.GetAvailableWindow(aiModel.TokenConfig)
return nil, fmt.Errorf("整体内容超出模型窗口大小限制(可用窗口=%d tokens),请精简后重试", availableWindow)
}
return compileToProviderRequest(ctx, ir, targetModel.OperatorName, targetModel.ModelName, chatModel)
// 记录历史会话
_, _ = dao.ComposeSession.Insert(ctx, &entity.ComposeSession{
SessionId: req.SessionId,
RequestContent: ir.User,
})
return compileToProviderRequest(ctx, ir, chatModel)
}
// buildNodeTypeRequest 构建节点类型请求BuildType=2
func buildNodeTypeRequest(ctx context.Context, req *dto.ComposeMessagesReq, chatModel *entity.AsynchModel, ir *PromptIR) (map[string]any, error) {
ir.AddUser(NodeBuild(ctx, req))
return compileToProviderRequest(ctx, ir, req.ModelName, req.ModelName, chatModel)
return compileToProviderRequest(ctx, ir, chatModel)
}
// compileToProviderRequest 编译为 Provider 请求
func compileToProviderRequest(ctx context.Context, ir *PromptIR, providerName string, modelName string, chatModel *entity.AsynchModel) (map[string]any, error) {
protocol, err := GetProtocolByProvider(ctx, providerName)
func compileToProviderRequest(ctx context.Context, ir *PromptIR, chatModel *entity.AsynchModel) (map[string]any, error) {
protocol, err := GetProtocolByProvider(ctx, chatModel.OperatorName)
if err != nil {
return nil, fmt.Errorf("获取协议配置失败: %w", err)
}
@@ -79,7 +83,7 @@ func compileToProviderRequest(ctx context.Context, ir *PromptIR, providerName st
}
return map[string]any{
"modelName": modelName,
"modelName": chatModel.ModelName,
"bizName": "prompts-core",
"callbackUrl": util.GetCallbackURL(ctx, "/prompt/callback"),
"requestPayload": providerReq,