refactor(util): 重构映射工具函数并优化异步任务轮询逻辑

This commit is contained in:
2026-06-03 13:30:39 +08:00
parent c11a9ad5c8
commit 3fa2896fc3
9 changed files with 80 additions and 75 deletions

View File

@@ -5,6 +5,7 @@ import (
"errors"
"fmt"
"prompts-core/consts/public"
"prompts-core/service/gateway"
"strings"
"prompts-core/common/util"
@@ -29,7 +30,7 @@ type UserPromptPayload struct {
}
// buildInferenceRequest 构建推理请求
func buildInferenceRequest(ctx context.Context, req *dto.ComposeMessagesReq, chatModel *entity.AsynchModel, aiModel *entity.AsynchModel, history []map[string]any) (map[string]any, error) {
func buildInferenceRequest(ctx context.Context, req *dto.ComposeMessagesReq, chatModel *gateway.AsynchModel, aiModel *gateway.AsynchModel, history []map[string]any) (map[string]any, error) {
//1) 处理表单分批
processedReq, totalBatches, err := ProcessUserFormBatches(ctx, req, aiModel)
if err != nil {
@@ -47,7 +48,7 @@ func buildInferenceRequest(ctx context.Context, req *dto.ComposeMessagesReq, cha
}
// buildPromptTypeRequest 构建提示词类型请求BuildType=1
func buildPromptTypeRequest(ctx context.Context, req *dto.ComposeMessagesReq, aiModel *entity.AsynchModel, chatModel *entity.AsynchModel, history []map[string]any, ir *PromptIR, totalBatches int) (map[string]any, error) {
func buildPromptTypeRequest(ctx context.Context, req *dto.ComposeMessagesReq, aiModel *gateway.AsynchModel, chatModel *gateway.AsynchModel, history []map[string]any, ir *PromptIR, totalBatches int) (map[string]any, error) {
//1) 构建系统提示词
systemPrompt := promptBuildWithRounds(ctx, req, chatModel, aiModel, totalBatches)
ir.AddSystem(systemPrompt)
@@ -69,13 +70,13 @@ func buildPromptTypeRequest(ctx context.Context, req *dto.ComposeMessagesReq, ai
}
// buildNodeTypeRequest 构建节点类型请求BuildType=2
func buildNodeTypeRequest(ctx context.Context, req *dto.ComposeMessagesReq, chatModel *entity.AsynchModel, ir *PromptIR) (map[string]any, error) {
func buildNodeTypeRequest(ctx context.Context, req *dto.ComposeMessagesReq, chatModel *gateway.AsynchModel, ir *PromptIR) (map[string]any, error) {
ir.AddUser(NodeBuild(ctx, req))
return compileToProviderRequest(ctx, ir, chatModel)
}
// compileToProviderRequest 编译为 Provider 请求
func compileToProviderRequest(ctx context.Context, ir *PromptIR, chatModel *entity.AsynchModel) (map[string]any, error) {
func compileToProviderRequest(ctx context.Context, ir *PromptIR, chatModel *gateway.AsynchModel) (map[string]any, error) {
protocol, err := GetProtocolByProvider(ctx, chatModel.OperatorName)
if err != nil {
return nil, fmt.Errorf("获取协议配置失败: %w", err)
@@ -97,7 +98,7 @@ func compileToProviderRequest(ctx context.Context, ir *PromptIR, chatModel *enti
}
// promptBuildWithRounds 构建系统提示词
func promptBuildWithRounds(ctx context.Context, req *dto.ComposeMessagesReq, chatModel *entity.AsynchModel, aiModel *entity.AsynchModel, batches int) string {
func promptBuildWithRounds(ctx context.Context, req *dto.ComposeMessagesReq, chatModel *gateway.AsynchModel, aiModel *gateway.AsynchModel, batches int) string {
providerProtocol, err := dao.ProviderProtocol.Get(ctx, &entity.ProviderProtocol{
ProviderName: chatModel.OperatorName,
Status: 1,
@@ -144,7 +145,7 @@ func buildUserFormContent(userForm []map[string]any) string {
}
// checkOverallContent 检查整体内容是否超出窗口
func checkOverallContent(ir *PromptIR, model *entity.AsynchModel) bool {
func checkOverallContent(ir *PromptIR, model *gateway.AsynchModel) bool {
fullContent := ir.String()
return util.CountToken(fullContent, model.TokenConfig)
}

View File

@@ -42,14 +42,14 @@ func ComposeMessages(ctx context.Context, req *dto.ComposeMessagesReq) (*dto.Com
}
// GetModelMessage 获取模型信息
func GetModelMessage(ctx context.Context, req *dto.ComposeMessagesReq) (*entity.AsynchModel, *entity.AsynchModel, error) {
func GetModelMessage(ctx context.Context, req *dto.ComposeMessagesReq) (*gateway.AsynchModel, *gateway.AsynchModel, error) {
userInfo, err := utils.GetUserInfo(ctx)
if err != nil {
return nil, nil, fmt.Errorf("获取用户信息失败: %w", err)
}
chatModel, err := dao.Model.Get(ctx, &entity.AsynchModel{
chatModel, err := gateway.GetModelConfig(ctx, &gateway.AsynchModel{
SQLBaseDO: beans.SQLBaseDO{Creator: userInfo.UserName},
IsChatModel: new(1),
IsChatModel: 1,
})
if err != nil {
return nil, nil, err
@@ -57,8 +57,8 @@ func GetModelMessage(ctx context.Context, req *dto.ComposeMessagesReq) (*entity.
if chatModel == nil {
return nil, nil, errors.New("当前没有对话模型,请添加")
}
aiModels, err := dao.Model.Get(ctx, &entity.AsynchModel{
SQLBaseDO: beans.SQLBaseDO{Creator: userInfo.UserName},
aiModels, err := gateway.GetModelConfig(ctx, &gateway.AsynchModel{
SQLBaseDO: beans.SQLBaseDO{TenantId: userInfo.TenantId, Creator: userInfo.UserName},
ModelName: req.ModelName,
})
if err != nil {
@@ -72,7 +72,7 @@ func GetModelMessage(ctx context.Context, req *dto.ComposeMessagesReq) (*entity.
}
// validateUserForm 校验用户表单
func validateUserForm(req *dto.ComposeMessagesReq, model *entity.AsynchModel) error {
func validateUserForm(req *dto.ComposeMessagesReq, model *gateway.AsynchModel) error {
if len(req.UserForm) == 0 {
return nil
}
@@ -90,7 +90,7 @@ func validateUserForm(req *dto.ComposeMessagesReq, model *entity.AsynchModel) er
}
// handlePromptBuild 处理提示词构建BuildType=1
func handlePromptBuild(ctx context.Context, req *dto.ComposeMessagesReq, chatModel, aiModel *entity.AsynchModel) (*dto.ComposeMessagesRes, error) {
func handlePromptBuild(ctx context.Context, req *dto.ComposeMessagesReq, chatModel, aiModel *gateway.AsynchModel) (*dto.ComposeMessagesRes, error) {
// 获取历史会话
history, err := session.GetHistoryMessages(ctx, req.SessionId)
if err != nil {
@@ -123,7 +123,7 @@ func handlePromptBuild(ctx context.Context, req *dto.ComposeMessagesReq, chatMod
}
// handleNodeBuild 处理节点构建BuildType=2
func handleNodeBuild(ctx context.Context, req *dto.ComposeMessagesReq, chatModel, aiModel *entity.AsynchModel) (*dto.ComposeMessagesRes, error) {
func handleNodeBuild(ctx context.Context, req *dto.ComposeMessagesReq, chatModel, aiModel *gateway.AsynchModel) (*dto.ComposeMessagesRes, error) {
taskID, id, err := callInferenceModel(ctx, req, chatModel, aiModel, nil)
if err != nil {
return nil, fmt.Errorf("调用推理模型失败: %w", err)
@@ -148,7 +148,7 @@ func handleNodeBuild(ctx context.Context, req *dto.ComposeMessagesReq, chatModel
}
// callInferenceModel 调用推理模型
func callInferenceModel(ctx context.Context, req *dto.ComposeMessagesReq, chatModel *entity.AsynchModel, aiModel *entity.AsynchModel, history []map[string]any) (string, int64, error) {
func callInferenceModel(ctx context.Context, req *dto.ComposeMessagesReq, chatModel *gateway.AsynchModel, aiModel *gateway.AsynchModel, history []map[string]any) (string, int64, error) {
taskReq, err := buildInferenceRequest(ctx, req, chatModel, aiModel, history)
if err != nil {
return "", 0, fmt.Errorf("构建推理请求失败: %w", err)
@@ -186,7 +186,7 @@ func Callback(ctx context.Context, req *dto.CallbackReq) error {
if err != nil {
return fmt.Errorf("查询任务失败: %w", err)
}
model, err := dao.Model.Get(ctx, &entity.AsynchModel{
model, err := gateway.GetModelConfig(ctx, &gateway.AsynchModel{
SQLBaseDO: beans.SQLBaseDO{Creator: composeTask.Creator},
ModelName: composeTask.ModelName,
})

View File

@@ -5,6 +5,7 @@ import (
"encoding/json"
"fmt"
"prompts-core/common/util"
"prompts-core/service/gateway"
"strings"
"prompts-core/dao"
@@ -178,7 +179,7 @@ func parseProtocol(e *entity.ProviderProtocol) *ProviderProtocol {
}
// Compile 将 PromptIR 按协议配置编译为 Provider Request
func Compile(ir *PromptIR, p *ProviderProtocol, chatModel *entity.AsynchModel) (map[string]any, error) {
func Compile(ir *PromptIR, p *ProviderProtocol, chatModel *gateway.AsynchModel) (map[string]any, error) {
if ir == nil || p == nil {
return nil, fmt.Errorf("ir and protocol are required")
}
@@ -262,7 +263,7 @@ func mapContent(messages []map[string]any, cm ContentMapping) []map[string]any {
}
// buildRequest 按 target_field 和 request_template 构建请求体
func buildRequest(messages []map[string]any, p *ProviderProtocol, chatModel *entity.AsynchModel) map[string]any {
func buildRequest(messages []map[string]any, p *ProviderProtocol, chatModel *gateway.AsynchModel) map[string]any {
if len(p.RequestTemplate) > 0 {
return renderTemplate(p.RequestTemplate, messages, chatModel)
}
@@ -273,7 +274,7 @@ func buildRequest(messages []map[string]any, p *ProviderProtocol, chatModel *ent
}
// renderTemplate 简单的 {{key}} 模板替换
func renderTemplate(tmpl map[string]any, messages []map[string]any, chatModel *entity.AsynchModel) map[string]any {
func renderTemplate(tmpl map[string]any, messages []map[string]any, chatModel *gateway.AsynchModel) map[string]any {
b, _ := json.Marshal(tmpl)
str := string(b)

View File

@@ -3,17 +3,17 @@ package prompt
import (
"context"
"fmt"
"prompts-core/service/gateway"
"strings"
"github.com/gogf/gf/v2/frame/g"
"prompts-core/common/util"
"prompts-core/model/dto"
"prompts-core/model/entity"
)
// ProcessUserFormBatches 处理 UserForm 分批(按 token 大小拼接内容)
func ProcessUserFormBatches(ctx context.Context, req *dto.ComposeMessagesReq, model *entity.AsynchModel) (*dto.ComposeMessagesReq, int, error) {
func ProcessUserFormBatches(ctx context.Context, req *dto.ComposeMessagesReq, model *gateway.AsynchModel) (*dto.ComposeMessagesReq, int, error) {
if model.TokenConfig == nil || len(req.UserForm) == 0 {
return req, 1, nil
}