feat(task): 添加构建模型名称字段支持动态校验
- 在ModelGatewayTask实体中新增BuildModelName字段 - 修改ParseAndValidate函数参数,支持传入requiredFields - 在异步工作器中实现构建模型的动态字段校验逻辑 - 添加prompt构建服务中传递构建模型名称的功能 - 实现构建模型不存在时的兜底机制
This commit is contained in:
@@ -20,7 +20,7 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
// ParseAndValidate 解析模型响应,并返回标准格式
|
// ParseAndValidate 解析模型响应,并返回标准格式
|
||||||
func ParseAndValidate(raw map[string]any, model *entity.ModelGatewayModel) (map[string]any, error) {
|
func ParseAndValidate(raw map[string]any, requiredFields []string) (map[string]any, error) {
|
||||||
contentStr := gconv.String(raw[entity.ResponseBody])
|
contentStr := gconv.String(raw[entity.ResponseBody])
|
||||||
if strings.TrimSpace(contentStr) == "" {
|
if strings.TrimSpace(contentStr) == "" {
|
||||||
return raw, fmt.Errorf("字段 %s 为空", entity.ResponseBody)
|
return raw, fmt.Errorf("字段 %s 为空", entity.ResponseBody)
|
||||||
@@ -41,7 +41,7 @@ func ParseAndValidate(raw map[string]any, model *entity.ModelGatewayModel) (map[
|
|||||||
return raw, fmt.Errorf("解析后数组为空")
|
return raw, fmt.Errorf("解析后数组为空")
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, field := range model.RequiredFields {
|
for _, field := range requiredFields {
|
||||||
for i, r := range arr {
|
for i, r := range arr {
|
||||||
round, _ := r.(map[string]any)
|
round, _ := r.(map[string]any)
|
||||||
if round != nil && gjson.New(round).Get(field).IsNil() {
|
if round != nil && gjson.New(round).Get(field).IsNil() {
|
||||||
|
|||||||
@@ -13,6 +13,7 @@ type CreateTaskReq struct {
|
|||||||
RequestPayload map[string]any `p:"requestPayload" json:"requestPayload" dc:"请求负载(透传给模型服务)"`
|
RequestPayload map[string]any `p:"requestPayload" json:"requestPayload" dc:"请求负载(透传给模型服务)"`
|
||||||
EpicycleId int64 `json:"epicycleId" dc:"轮次ID"`
|
EpicycleId int64 `json:"epicycleId" dc:"轮次ID"`
|
||||||
BuildType int64 `json:"buildType" dc:"构建类型:1-提示词构建 2-节点构建"`
|
BuildType int64 `json:"buildType" dc:"构建类型:1-提示词构建 2-节点构建"`
|
||||||
|
BuildModelName string `json:"buildModelName" json:"buildModelName" dc:"构建模型名称"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type CreateTaskRes struct {
|
type CreateTaskRes struct {
|
||||||
|
|||||||
@@ -21,6 +21,7 @@ type modelGatewayTaskCol struct {
|
|||||||
TmpFile string
|
TmpFile string
|
||||||
RequestPayload string
|
RequestPayload string
|
||||||
EpicycleId string
|
EpicycleId string
|
||||||
|
BuildModelName string
|
||||||
}
|
}
|
||||||
|
|
||||||
var ModelGatewayTaskCol = modelGatewayTaskCol{
|
var ModelGatewayTaskCol = modelGatewayTaskCol{
|
||||||
@@ -40,6 +41,7 @@ var ModelGatewayTaskCol = modelGatewayTaskCol{
|
|||||||
TmpFile: "tmp_file",
|
TmpFile: "tmp_file",
|
||||||
RequestPayload: "request_payload",
|
RequestPayload: "request_payload",
|
||||||
EpicycleId: "epicycle_id",
|
EpicycleId: "epicycle_id",
|
||||||
|
BuildModelName: "build_model_name",
|
||||||
}
|
}
|
||||||
|
|
||||||
// ModelGatewayTask 模型网关任务
|
// ModelGatewayTask 模型网关任务
|
||||||
@@ -60,6 +62,7 @@ type ModelGatewayTask struct {
|
|||||||
TmpFile string `orm:"tmp_file" json:"tmpFile"`
|
TmpFile string `orm:"tmp_file" json:"tmpFile"`
|
||||||
RequestPayload *RequestPayload `orm:"request_payload" json:"requestPayload"`
|
RequestPayload *RequestPayload `orm:"request_payload" json:"requestPayload"`
|
||||||
EpicycleId int64 `orm:"epicycle_id" json:"epicycleId"`
|
EpicycleId int64 `orm:"epicycle_id" json:"epicycleId"`
|
||||||
|
BuildModelName string `orm:"build_model_name" json:"buildModelName"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// ResultFile OSS 结果文件
|
// ResultFile OSS 结果文件
|
||||||
|
|||||||
@@ -211,6 +211,24 @@ func (w *asyncWorker) callModel(ctx context.Context, task *entity.ModelGatewayTa
|
|||||||
|
|
||||||
// parseAndRetry 解析模型返回结果,并重试
|
// parseAndRetry 解析模型返回结果,并重试
|
||||||
func (w *asyncWorker) parseAndRetry(ctx context.Context, body map[string]any, task *entity.ModelGatewayTask, model *entity.ModelGatewayModel, req *dto.CreateTaskReq, maxRetry int, startTime time.Time) (map[string]any, error) {
|
func (w *asyncWorker) parseAndRetry(ctx context.Context, body map[string]any, task *entity.ModelGatewayTask, model *entity.ModelGatewayModel, req *dto.CreateTaskReq, maxRetry int, startTime time.Time) (map[string]any, error) {
|
||||||
|
// 0) 如果指定了构建模型,查出校验字段
|
||||||
|
var requiredFields []string
|
||||||
|
if task.BuildModelName != "" {
|
||||||
|
buildModel, _ := dao.ModelGatewayModels.Get(ctx, &entity.ModelGatewayModel{
|
||||||
|
SQLBaseDO: beans.SQLBaseDO{
|
||||||
|
TenantId: model.TenantId,
|
||||||
|
Creator: model.Creator,
|
||||||
|
},
|
||||||
|
ModelName: req.ModelName,
|
||||||
|
})
|
||||||
|
if buildModel != nil {
|
||||||
|
requiredFields = buildModel.RequiredFields
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if len(requiredFields) == 0 {
|
||||||
|
requiredFields = model.RequiredFields // 兜底用当前模型的
|
||||||
|
}
|
||||||
|
|
||||||
var lastErr error
|
var lastErr error
|
||||||
for attempt := 0; attempt <= maxRetry; attempt++ {
|
for attempt := 0; attempt <= maxRetry; attempt++ {
|
||||||
if attempt > 0 {
|
if attempt > 0 {
|
||||||
@@ -236,11 +254,11 @@ func (w *asyncWorker) parseAndRetry(ctx context.Context, body map[string]any, ta
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// 3) 解析 + 校验
|
// 3) 解析 + 校验(用 buildModel 的 RequiredFields)
|
||||||
var parsed map[string]any
|
var parsed map[string]any
|
||||||
switch req.BuildType {
|
switch req.BuildType {
|
||||||
case public.BuildTypePrompt, public.BuildTypeNode:
|
case public.BuildTypePrompt, public.BuildTypeNode:
|
||||||
parsed, err = util.ParseAndValidate(mapped, model)
|
parsed, err = util.ParseAndValidate(mapped, requiredFields)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
return parsed, nil
|
return parsed, nil
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user