From 525b391f09239099edd0290730f904a428d3df9d Mon Sep 17 00:00:00 2001 From: WangLiZhao <1838393649@qq.com> Date: Tue, 23 Jun 2026 14:55:50 +0800 Subject: [PATCH] =?UTF-8?q?feat(task):=20=E6=B7=BB=E5=8A=A0=E6=9E=84?= =?UTF-8?q?=E5=BB=BA=E6=A8=A1=E5=9E=8B=E5=90=8D=E7=A7=B0=E5=AD=97=E6=AE=B5?= =?UTF-8?q?=E6=94=AF=E6=8C=81=E5=8A=A8=E6=80=81=E6=A0=A1=E9=AA=8C?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 在ModelGatewayTask实体中新增BuildModelName字段 - 修改ParseAndValidate函数参数,支持传入requiredFields - 在异步工作器中实现构建模型的动态字段校验逻辑 - 添加prompt构建服务中传递构建模型名称的功能 - 实现构建模型不存在时的兜底机制 --- common/util/mapping.go | 4 ++-- model/dto/model_gateway_task_dto.go | 1 + model/entity/model_gateway_task.go | 3 +++ service/task/worker.go | 22 ++++++++++++++++++++-- 4 files changed, 26 insertions(+), 4 deletions(-) diff --git a/common/util/mapping.go b/common/util/mapping.go index 85547dd..ee61f21 100644 --- a/common/util/mapping.go +++ b/common/util/mapping.go @@ -20,7 +20,7 @@ import ( ) // ParseAndValidate 解析模型响应,并返回标准格式 -func ParseAndValidate(raw map[string]any, model *entity.ModelGatewayModel) (map[string]any, error) { +func ParseAndValidate(raw map[string]any, requiredFields []string) (map[string]any, error) { contentStr := gconv.String(raw[entity.ResponseBody]) if strings.TrimSpace(contentStr) == "" { return raw, fmt.Errorf("字段 %s 为空", entity.ResponseBody) @@ -41,7 +41,7 @@ func ParseAndValidate(raw map[string]any, model *entity.ModelGatewayModel) (map[ return raw, fmt.Errorf("解析后数组为空") } - for _, field := range model.RequiredFields { + for _, field := range requiredFields { for i, r := range arr { round, _ := r.(map[string]any) if round != nil && gjson.New(round).Get(field).IsNil() { diff --git a/model/dto/model_gateway_task_dto.go b/model/dto/model_gateway_task_dto.go index 8d7b0a1..710e39f 100644 --- a/model/dto/model_gateway_task_dto.go +++ b/model/dto/model_gateway_task_dto.go @@ -13,6 +13,7 @@ type CreateTaskReq struct { RequestPayload map[string]any `p:"requestPayload" json:"requestPayload" dc:"请求负载(透传给模型服务)"` EpicycleId int64 `json:"epicycleId" dc:"轮次ID"` BuildType int64 `json:"buildType" dc:"构建类型:1-提示词构建 2-节点构建"` + BuildModelName string `json:"buildModelName" json:"buildModelName" dc:"构建模型名称"` } type CreateTaskRes struct { diff --git a/model/entity/model_gateway_task.go b/model/entity/model_gateway_task.go index baaa13e..eabd320 100644 --- a/model/entity/model_gateway_task.go +++ b/model/entity/model_gateway_task.go @@ -21,6 +21,7 @@ type modelGatewayTaskCol struct { TmpFile string RequestPayload string EpicycleId string + BuildModelName string } var ModelGatewayTaskCol = modelGatewayTaskCol{ @@ -40,6 +41,7 @@ var ModelGatewayTaskCol = modelGatewayTaskCol{ TmpFile: "tmp_file", RequestPayload: "request_payload", EpicycleId: "epicycle_id", + BuildModelName: "build_model_name", } // ModelGatewayTask 模型网关任务 @@ -60,6 +62,7 @@ type ModelGatewayTask struct { TmpFile string `orm:"tmp_file" json:"tmpFile"` RequestPayload *RequestPayload `orm:"request_payload" json:"requestPayload"` EpicycleId int64 `orm:"epicycle_id" json:"epicycleId"` + BuildModelName string `orm:"build_model_name" json:"buildModelName"` } // ResultFile OSS 结果文件 diff --git a/service/task/worker.go b/service/task/worker.go index 6ed9db5..db52fe3 100644 --- a/service/task/worker.go +++ b/service/task/worker.go @@ -211,6 +211,24 @@ func (w *asyncWorker) callModel(ctx context.Context, task *entity.ModelGatewayTa // parseAndRetry 解析模型返回结果,并重试 func (w *asyncWorker) parseAndRetry(ctx context.Context, body map[string]any, task *entity.ModelGatewayTask, model *entity.ModelGatewayModel, req *dto.CreateTaskReq, maxRetry int, startTime time.Time) (map[string]any, error) { + // 0) 如果指定了构建模型,查出校验字段 + var requiredFields []string + if task.BuildModelName != "" { + buildModel, _ := dao.ModelGatewayModels.Get(ctx, &entity.ModelGatewayModel{ + SQLBaseDO: beans.SQLBaseDO{ + TenantId: model.TenantId, + Creator: model.Creator, + }, + ModelName: req.ModelName, + }) + if buildModel != nil { + requiredFields = buildModel.RequiredFields + } + } + if len(requiredFields) == 0 { + requiredFields = model.RequiredFields // 兜底用当前模型的 + } + var lastErr error for attempt := 0; attempt <= maxRetry; attempt++ { if attempt > 0 { @@ -236,11 +254,11 @@ func (w *asyncWorker) parseAndRetry(ctx context.Context, body map[string]any, ta }) } - // 3) 解析 + 校验 + // 3) 解析 + 校验(用 buildModel 的 RequiredFields) var parsed map[string]any switch req.BuildType { case public.BuildTypePrompt, public.BuildTypeNode: - parsed, err = util.ParseAndValidate(mapped, model) + parsed, err = util.ParseAndValidate(mapped, requiredFields) if err == nil { return parsed, nil }