From 55eb4366398c23d459570d6e644cf36901494dfb Mon Sep 17 00:00:00 2001 From: WangLiZhao <1838393649@qq.com> Date: Fri, 29 May 2026 17:54:19 +0800 Subject: [PATCH] =?UTF-8?q?refactor(service):=20=E9=87=8D=E6=9E=84?= =?UTF-8?q?=E6=9C=8D=E5=8A=A1=E6=A8=A1=E5=9D=97=E7=BB=93=E6=9E=84=E5=B9=B6?= =?UTF-8?q?=E4=BC=98=E5=8C=96=E6=A8=A1=E5=9E=8B=E9=85=8D=E7=BD=AE?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- common/util/config.go | 5 + common/util/json.go | 31 ++--- common/util/mapping.go | 148 +++++++++++++++++++++++ dao/model_dao.go | 12 ++ model/dto/prompt_compose_dto.go | 4 +- service/prompt/prompt_build_service.go | 33 +++-- service/prompt/prompt_compose_service.go | 24 ++-- 7 files changed, 204 insertions(+), 53 deletions(-) create mode 100644 common/util/mapping.go diff --git a/common/util/config.go b/common/util/config.go index 4f27f97..5a7be25 100644 --- a/common/util/config.go +++ b/common/util/config.go @@ -8,6 +8,11 @@ import ( "github.com/gogf/gf/v2/util/gconv" ) +// GetServerName 获取服务名称 +func GetServerName(ctx context.Context) string { + return g.Cfg().MustGet(ctx, "server.name", "").String() +} + // GetServerPort 从配置获取服务端口 func GetServerPort(ctx context.Context) string { address := g.Cfg().MustGet(ctx, "server.address", ":8080").String() diff --git a/common/util/json.go b/common/util/json.go index 83d2615..3894ec7 100644 --- a/common/util/json.go +++ b/common/util/json.go @@ -2,51 +2,34 @@ package util import ( "encoding/json" - "fmt" "strconv" "github.com/gogf/gf/v2/container/gvar" - "github.com/gogf/gf/v2/encoding/gjson" + gfgjson "github.com/gogf/gf/v2/encoding/gjson" "github.com/gogf/gf/v2/util/gconv" - tGjson "github.com/tidwall/gjson" "github.com/tidwall/sjson" ) -// ParseOutput 解析模型输出为 JSON 格式 -func ParseOutput(text string) (map[string]any, error) { - j, err := gjson.LoadJson([]byte(text)) - if err != nil { - return nil, fmt.Errorf("解析模型输出失败: %w", err) - } - - return j.Map(), nil -} - // ConvertToMessages 将原始数据转换为消息列表 func ConvertToMessages(raw any) []map[string]any { if raw == nil { return nil } - j, err := gjson.LoadJson(gconv.Bytes(raw)) - if err != nil { - return nil + j := gfgjson.New(raw) + messages := j.Get("messages") + if !messages.IsNil() { + return gconv.Maps(messages.Val()) } - - if j.Contains("messages") { - return gconv.Maps(j.Get("messages").Array()) - } - return []map[string]any{j.Map()} } // FormToJSON 将表单数据转换为 JSON 字符串 -func FormToJSON(form map[string]any) string { +func FormToJSON(form []map[string]any) string { if form == nil { - return "{}" + return "[]" } - b, _ := json.Marshal(form) return string(b) } diff --git a/common/util/mapping.go b/common/util/mapping.go new file mode 100644 index 0000000..f423e41 --- /dev/null +++ b/common/util/mapping.go @@ -0,0 +1,148 @@ +package util + +import ( + "fmt" + "net/url" + "prompts-core/model/entity" + "strings" + + "github.com/gogf/gf/v2/encoding/gjson" + "github.com/gogf/gf/v2/frame/g" + "github.com/gogf/gf/v2/util/gconv" +) + +// ValidatePromptResult 校验模型返回结果的 JSON 结构完整性 +// 校验逻辑:只校验 requestMapping 中默认值为空的必填字段 +func ValidatePromptResult(raw map[string]any, model *entity.AsynchModel) error { + // 1) 获取校验配置,并取值 + requestMapping := model.RequestMapping + contentKey := "" + for k := range model.ResponseBody { + contentKey = k + break + } + contentStr, ok := raw[contentKey].(string) + if !ok || contentStr == "" { + return fmt.Errorf("%s 字段为空或不是字符串", contentKey) + } + + // 2) 解析 content 为 JSON 数组 + var rounds []map[string]any + if err := gjson.DecodeTo(contentStr, &rounds); err != nil { + return fmt.Errorf("解析 content JSON 数组失败: %w", err) + } + if len(rounds) == 0 { + return fmt.Errorf("content 数组为空") + } + + // 3) 逐条校验:只检查默认值为空的必填字段是否存在 + for i, round := range rounds { + for path, defaultValue := range requestMapping { + if !g.IsEmpty(defaultValue) { + continue + } + if gjson.New(round).Get(path).IsNil() { + return fmt.Errorf("rounds[%d] 缺少必填字段: %s", i, path) + } + } + } + return nil +} + +// ReverseMap 映射 payload 到 mapping +func ReverseMap(mapping map[string]any, payload map[string]any) map[string]any { + jsonObj := gjson.New("{}") + for path, defaultValue := range mapping { + val := gjson.New(payload).Get(path) + if !val.IsNil() { + _ = jsonObj.Set(path, val.Val()) + } else if defaultValue != nil { + _ = jsonObj.Set(path, defaultValue) + } + } + return jsonObj.Map() +} + +// MapResponsePayload 映射模型响应为标准格式 +func MapResponsePayload(mapping map[string]any, responseBytes []byte) ([]byte, error) { + if len(mapping) == 0 { + return responseBytes, nil + } + + responseJson := gjson.New(responseBytes) + resultJson := gjson.New("{}") + + for standardField, modelPath := range mapping { + path := gconv.String(modelPath) + if path == "" { + continue + } + val := responseJson.Get(path) + if val.IsNil() { + continue + } + resultJson.Set(standardField, val.Val()) + } + + return []byte(resultJson.String()), nil +} + +// ParseHeadMsgHeaders 支持多个 header 绑定,逗号分隔: +// 示例: +// - X-API-Key:qwen3-tts-key,operation:true,count:123 +// - X-API-Key:"qwen3-tts-key",operation:"true" +// +// 说明: +// - HTTP Header 最终都是字符串,这里做的是“值的字符串化表达”。 +// - 若 value 用双引号包裹,会去掉外层引号再注入,便于在配置中区分字符串/布尔/数字等表达(以及避免值中包含特殊字符时歧义)。 +func ParseHeadMsgHeaders(headMsg string) map[string]string { + headMsg = strings.TrimSpace(headMsg) + if headMsg == "" { + return nil + } + out := map[string]string{} + parts := strings.Split(headMsg, ",") + for _, p := range parts { + p = strings.TrimSpace(p) + if p == "" { + continue + } + // HeaderName:HeaderValue(推荐) / HeaderName=HeaderValue(兼容) + if strings.Contains(p, ":") { + kv := strings.SplitN(p, ":", 2) + k := strings.TrimSpace(kv[0]) + v := strings.TrimSpace(kv[1]) + v = strings.Trim(v, "\"") + if k != "" && v != "" { + out[k] = v + } + continue + } + if strings.Contains(p, "=") { + kv := strings.SplitN(p, "=", 2) + k := strings.TrimSpace(kv[0]) + v := strings.TrimSpace(kv[1]) + v = strings.Trim(v, "\"") + if k != "" && v != "" { + out[k] = v + } + continue + } + } + if len(out) == 0 { + return nil + } + return out +} + +// PayloadToQuery 将 payload 转为 url.Values +func PayloadToQuery(payload map[string]any) (url.Values, error) { + q := url.Values{} + for k, v := range payload { + if v == nil { + continue + } + q.Set(k, gconv.String(v)) + } + return q, nil +} diff --git a/dao/model_dao.go b/dao/model_dao.go index 42cd0a9..c3aa59a 100644 --- a/dao/model_dao.go +++ b/dao/model_dao.go @@ -26,3 +26,15 @@ func (d *modelDao) Get(ctx context.Context, req *entity.AsynchModel, fields ...s err = r.Struct(&m) return } + +// GetsByModelName 批量获取模型 +func (d *modelDao) GetsByModelName(ctx context.Context, creator string, modelNames []string, fields ...string) (list []*entity.AsynchModel, err error) { + err = gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameModel). + OmitEmpty(). + Where(entity.AsynchModelCol.Creator, creator). + WhereIn(entity.AsynchModelCol.ModelName, modelNames). + Fields(fields). + Scan(&list) + + return +} diff --git a/model/dto/prompt_compose_dto.go b/model/dto/prompt_compose_dto.go index c656c99..972662e 100644 --- a/model/dto/prompt_compose_dto.go +++ b/model/dto/prompt_compose_dto.go @@ -6,10 +6,10 @@ type ComposeMessagesReq struct { g.Meta `path:"/composeMessages" method:"post" tags:"提示词处理" summary:"拼接提示词" dc:"按 modelTypeId 读取 prompts_model_prompt.prompt_info 与 response_json_schema;form 作为系统表单,userForm 作为用户表单,结合 userFiles 调用 model-gateway,并直接返回最终 messages"` ModelName string `p:"modelName" json:"modelName" v:"required#modelName不能为空" dc:"实际请求的网关模型名称"` BuildType int `p:"buildType" json:"buildType" v:"required#buildType不能为空" dc:"构建类型"` //判断节点 - SessionId string `p:"sessionId" json:"sessionId" v:"required#sessionId不能为空" dc:"会话ID"` + SessionId string `p:"sessionId" json:"sessionId" dc:"会话ID"` //v:"required#sessionId不能为空" Cause string `p:"cause" json:"cause" v:"required-if:IsBuilder,false#原因不能为空" dc:"原因"` CallbackUrl string `p:"callbackUrl" json:"callbackUrl" dc:"回调地址"` - Form map[string]any `p:"form" json:"form" dc:"系统表单:form 下所有字段都作为系统提示词来源"` + Form []map[string]any `p:"form" json:"form" dc:"系统表单:form 下所有字段都作为系统提示词来源"` UserForm []map[string]any `p:"userForm" json:"userForm" dc:"用户表单:userForm 下所有字段都作为用户提示词来源;若与 form 含义接近则严格覆盖系统字段"` Consult []ConsultItem `json:"consult" dc:"附件列表(图片/视频/音频)"` SkillName string `p:"skillName" json:"skillName" dc:"技能名称"` diff --git a/service/prompt/prompt_build_service.go b/service/prompt/prompt_build_service.go index 8ed8ee6..d86a4ec 100644 --- a/service/prompt/prompt_build_service.go +++ b/service/prompt/prompt_build_service.go @@ -20,7 +20,7 @@ import ( type UserPromptPayload struct { Model string `json:"model"` PromptInfo string `json:"promptInfo"` - Form map[string]any `json:"form"` + Form any `json:"form"` UserForm any `json:"userForm"` Consult []dto.ConsultItem `json:"consult"` UserFilesText map[string]string `json:"userFilesText"` @@ -30,6 +30,7 @@ type UserPromptPayload struct { // buildInferenceRequest 构建推理请求 func buildInferenceRequest(ctx context.Context, req *dto.ComposeMessagesReq, chatModel *entity.AsynchModel, aiModel *entity.AsynchModel, history []map[string]any) (map[string]any, error) { + //1) 处理表单分批 processedReq, totalBatches, err := ProcessUserFormBatches(ctx, req, aiModel) if err != nil { return nil, fmt.Errorf("处理用户表单分批失败: %w", err) @@ -47,9 +48,10 @@ func buildInferenceRequest(ctx context.Context, req *dto.ComposeMessagesReq, cha // buildPromptTypeRequest 构建提示词类型请求(BuildType=1) func buildPromptTypeRequest(ctx context.Context, req *dto.ComposeMessagesReq, aiModel *entity.AsynchModel, chatModel *entity.AsynchModel, history []map[string]any, ir *PromptIR, totalBatches int) (map[string]any, error) { - systemPrompt := promptBuildWithRounds(ctx, req, aiModel, totalBatches) + //1) 构建系统提示词 + systemPrompt := promptBuildWithRounds(ctx, req, chatModel, aiModel, totalBatches) ir.AddSystem(systemPrompt) - + //2) 构建历史对话 for _, msg := range history { role := gconv.String(msg["role"]) if role != "user" && role != "assistant" { @@ -57,7 +59,6 @@ func buildPromptTypeRequest(ctx context.Context, req *dto.ComposeMessagesReq, ai } ir.AddHistory(role, gconv.String(msg["content"])) } - userPrompt := buildUserPrompt(ctx, req, util.GetModelPrompt(ctx, aiModel.ModelType)) ir.AddUser(userPrompt) if !checkOverallContent(ir, aiModel) { @@ -70,7 +71,6 @@ func buildPromptTypeRequest(ctx context.Context, req *dto.ComposeMessagesReq, ai // buildNodeTypeRequest 构建节点类型请求(BuildType=2) func buildNodeTypeRequest(ctx context.Context, req *dto.ComposeMessagesReq, chatModel *entity.AsynchModel, ir *PromptIR) (map[string]any, error) { ir.AddUser(NodeBuild(ctx, req)) - return compileToProviderRequest(ctx, ir, chatModel) } @@ -90,33 +90,33 @@ func compileToProviderRequest(ctx context.Context, ir *PromptIR, chatModel *enti return map[string]any{ "modelName": chatModel.ModelName, - "bizName": "prompts-core", + "bizName": util.GetServerName(ctx), "callbackUrl": util.GetCallbackURL(ctx, "/prompt/callback"), "requestPayload": providerReq, }, nil } -// promptBuildWithRounds 构建系统提示词(包含轮次信息) -func promptBuildWithRounds(ctx context.Context, req *dto.ComposeMessagesReq, model *entity.AsynchModel, totalRounds int) string { +// promptBuildWithRounds 构建系统提示词 +func promptBuildWithRounds(ctx context.Context, req *dto.ComposeMessagesReq, chatModel *entity.AsynchModel, aiModel *entity.AsynchModel, batches int) string { providerProtocol, err := dao.ProviderProtocol.Get(ctx, &entity.ProviderProtocol{ - ProviderName: model.OperatorName, + ProviderName: chatModel.OperatorName, Status: 1, }) if err != nil || providerProtocol == nil { return "" } - outputJSON := util.JSONPretty(model.RequestMapping) - maxWindowSize := util.GetMaxWindowSize(model.TokenConfig) - availableWindow := util.GetAvailableWindow(model.TokenConfig) - + outputJSON := util.JSONPretty(util.ReverseMap(aiModel.RequestMapping, map[string]any{})) + maxWindowSize := util.GetMaxWindowSize(chatModel.TokenConfig) + availableWindow := util.GetAvailableWindow(chatModel.TokenConfig) + formContent := buildUserFormContent(req.Form) userFormContent := buildUserFormContent(req.UserForm) formInfo := fmt.Sprintf(` 【系统表单(系统提示词/参数)】 %s 【用户表单全文(必须完整阅读,全部作为用户提示词来源)】 %s -`, util.FormToJSON(req.Form), userFormContent) +`, formContent, userFormContent) inputInfo := fmt.Sprintf(` 目标模型: %s @@ -129,11 +129,8 @@ func promptBuildWithRounds(ctx context.Context, req *dto.ComposeMessagesReq, mod req.ModelName, // %s 目标模型名称 maxWindowSize, // %d 最大窗口 availableWindow, // %d 可用窗口 - totalRounds, // %d 数组长度(多轮输出要求) - totalRounds, // %d 数组长度(结构铁律) outputJSON, // %s 输出结构 inputInfo, // %s 完整输入信息 - totalRounds, // %d 数组长度(最后一行) ) } @@ -157,7 +154,7 @@ func buildUserPrompt(ctx context.Context, req *dto.ComposeMessagesReq, prompt st payload := UserPromptPayload{ Model: req.ModelName, PromptInfo: prompt, - Form: req.Form, + Form: prepareUserFormPayload(req.Form), UserForm: prepareUserFormPayload(req.UserForm), Consult: req.Consult, UserFilesText: ExtractFileTexts(ctx, req.Consult), diff --git a/service/prompt/prompt_compose_service.go b/service/prompt/prompt_compose_service.go index eebec25..999c128 100644 --- a/service/prompt/prompt_compose_service.go +++ b/service/prompt/prompt_compose_service.go @@ -21,13 +21,16 @@ import ( // ComposeMessages 核心拼接提示词主流程 func ComposeMessages(ctx context.Context, req *dto.ComposeMessagesReq) (*dto.ComposeMessagesRes, error) { + //1) 获取模型信息 chatModel, aiModel, err := GetModelMessage(ctx, req) if err != nil { return nil, err } + //2) 校验用户表单 if err = validateUserForm(req, aiModel); err != nil { return nil, err } + //3) 处理不同类型 switch req.BuildType { case public.BuildTypePrompt: return handlePromptBuild(ctx, req, chatModel, aiModel) // 提示词构建 @@ -54,7 +57,7 @@ func GetModelMessage(ctx context.Context, req *dto.ComposeMessagesReq) (*entity. if chatModel == nil { return nil, nil, errors.New("当前没有对话模型,请添加") } - aiModel, err := dao.Model.Get(ctx, &entity.AsynchModel{ + aiModels, err := dao.Model.Get(ctx, &entity.AsynchModel{ SQLBaseDO: beans.SQLBaseDO{Creator: userInfo.UserName}, ModelName: req.ModelName, }) @@ -62,10 +65,10 @@ func GetModelMessage(ctx context.Context, req *dto.ComposeMessagesReq) (*entity. return nil, nil, err } - if aiModel == nil { + if aiModels == nil { return nil, nil, errors.New("需要构建的模型不存在") } - return chatModel, aiModel, nil + return chatModel, aiModels, nil } // validateUserForm 校验用户表单 @@ -150,12 +153,15 @@ func callInferenceModel(ctx context.Context, req *dto.ComposeMessagesReq, chatMo if err != nil { return "", 0, fmt.Errorf("构建推理请求失败: %w", err) } - id, err := dao.ComposeSession.Insert(ctx, &entity.ComposeSession{ - SessionId: req.SessionId, - RequestContent: util.GetUserMessage(taskReq), - }) - if err != nil { - return "", 0, fmt.Errorf("保存历史会话失败: %w", err) + id := int64(0) + if req.SessionId != "" { + id, err = dao.ComposeSession.Insert(ctx, &entity.ComposeSession{ + SessionId: req.SessionId, + RequestContent: util.GetUserMessage(taskReq), + }) + if err != nil { + return "", 0, fmt.Errorf("保存历史会话失败: %w", err) + } } taskID, err := gateway.CreateGatewayTask(ctx, taskReq) if err != nil {