diff --git a/workflow/model/dto/flow/flow_execution_dto.go b/workflow/model/dto/flow/flow_execution_dto.go index 6b8f628..2ffe70c 100644 --- a/workflow/model/dto/flow/flow_execution_dto.go +++ b/workflow/model/dto/flow/flow_execution_dto.go @@ -128,12 +128,11 @@ type ComposeCallbackReq struct { type ModelCallbackReq struct { g.Meta `path:"/modelCallback" method:"post" tags:"提示词处理" summary:"model-gateway 回调" dc:"model-gateway 成功后 GET 回调:callbackUrl/{bizName}"` - TaskId string `p:"task_id" json:"task_id" v:"required#task_id不能为空" dc:"网关任务ID"` - State int `p:"state" json:"state" dc:"网关任务状态"` - OssFile string `p:"oss_file" json:"oss_file" dc:"结果文件地址"` - FileType string `p:"file_type" json:"file_type" dc:"结果文件类型"` - Messages map[string]any `json:"messages"` - ErrorMsg string `json:"error_msg"` + TaskId string `p:"task_id" json:"task_id" v:"required#task_id不能为空" dc:"网关任务ID"` + State int `p:"state" json:"state" dc:"网关任务状态"` + OssFile string `p:"oss_file" json:"oss_file" dc:"结果文件地址"` + FileType string `p:"file_type" json:"file_type" dc:"结果文件类型"` + ErrorMsg string `json:"error_msg"` } type VideoCallbackReq struct { diff --git a/workflow/service/flow/flow_execution_service.go b/workflow/service/flow/flow_execution_service.go index 3bc103f..63d9a6c 100644 --- a/workflow/service/flow/flow_execution_service.go +++ b/workflow/service/flow/flow_execution_service.go @@ -664,18 +664,9 @@ func registerNodeToGraph(graph *compose.Graph[any, any], flowNode entity.FlowNod // 执行节点 _, err = lambda(ctx, realInput) durationMs := time.Since(startTime).Milliseconds() - // 上传OSS(每条独立上传) - ossResult1, err := Upload(ctx, &dto.UploadFileBytesReq{ - FileBytes: gconv.Bytes(gconv.String(realInput)), - FileName: fmt.Sprintf("nodeInput:%v.txt", time.Now().UnixMilli()), - }) - if err != nil { - return nil, err - } updateReq := &nodeDto.UpdateNodeExecutionReq{ - Id: nodeExecutionId, - OutputParamsPath: ossResult1.FileURL, - DurationMs: durationMs, + Id: nodeExecutionId, + DurationMs: durationMs, } if err != nil { // 执行失败,更新状态 @@ -689,7 +680,15 @@ func registerNodeToGraph(graph *compose.Graph[any, any], flowNode entity.FlowNod }) return nil, err } - + // 上传OSS(每条独立上传) + ossResult1, err := Upload(ctx, &dto.UploadFileBytesReq{ + FileBytes: gconv.Bytes(gconv.String(realInput)), + FileName: fmt.Sprintf("nodeInput:%v.txt", time.Now().UnixMilli()), + }) + if err != nil { + return nil, err + } + updateReq.OutputParamsPath = ossResult1.FileURL // 执行成功,更新状态 updateReq.Status = node.NodeExecutionStatusSuccess.Code() _, _ = nodeDao.NodeExecutionDao.Update(ctx, updateReq) diff --git a/workflow/service/flow/lambda_node.go b/workflow/service/flow/lambda_node.go index 93894a8..279b4cf 100644 --- a/workflow/service/flow/lambda_node.go +++ b/workflow/service/flow/lambda_node.go @@ -44,16 +44,18 @@ func JudgeLambda(ctx context.Context, input any) (string, error) { // 1. 直接用你原来的方法(返回两个 map) inputMap, outputMap, modelMap := GetNodeContextContent(nodeInput.Global, nodeInput.Config) var outputResult []node.NodeFormField - for _, valueAny := range inputMap { - if field, ok := valueAny.(node.NodeFormField); ok { - outputResult = append(outputResult, field) - } - } - for _, valueAny := range outputMap { - if field, ok := valueAny.(node.NodeFormField); ok { - outputResult = append(outputResult, field) - } - } + outputResult = append(outputResult, inputMap...) + outputResult = append(outputResult, outputMap...) + //for _, valueAny := range inputMap { + // if field, ok := valueAny.(node.NodeFormField); ok { + // outputResult = append(outputResult, field) + // } + //} + //for _, valueAny := range outputMap { + // if field, ok := valueAny.(node.NodeFormField); ok { + // outputResult = append(outputResult, field) + // } + //} for _, valueAny := range modelMap { if field, ok := valueAny.(node.NodeFormField); ok { outputResult = append(outputResult, field) @@ -123,62 +125,75 @@ func BatchModelLambda(ctx context.Context, input any) (any, error) { } } } - // 结果按索引存放,保证顺序 + // 结果按索引存放,切片不同下标并发写无竞争,不用锁 res := make([][]node.NodeFormField, len(reqMap)) var wg sync.WaitGroup - // 用一个通道标记是否完成 - done := make(chan struct{}) - // 错误只存一个 - var execErr error - // 并发执行 + subCtx, cancel := context.WithCancel(ctx) + defer cancel() + + // 缓冲1错误通道,仅接收第一个错误 + errCh := make(chan error, 1) + + // 并发执行任务 for idx, item := range reqMap { wg.Add(1) go func(idx int, userItem map[string]any) { defer wg.Done() + // 上下文已取消则直接退出 + select { + case <-subCtx.Done(): + return + default: + } + singleUserFrom := []map[string]any{userItem} - output, err := TextNode(ctx, nodeInput, skillName, from, singleUserFrom) + output, err := TextNode(subCtx, nodeInput, skillName, from, singleUserFrom) if err != nil { - // 并发安全赋值错误 - if execErr == nil { - execErr = err + // 仅第一个错误写入通道 + select { + case errCh <- err: + cancel() // 触发全局取消,其他协程快速退出 + default: } return } - - // 直接按原索引写,顺序绝对正确 res[idx] = output }(idx, item) } - // 后台等待所有协程完成,然后关闭 done 通道 + // 任务全部结束后关闭错误通道 go func() { wg.Wait() - close(done) + close(errCh) }() - // 等待全部完成 - <-done - - // 如果有错误,直接返回 - if execErr != nil { - return nil, execErr + // ========== 修正后的等待逻辑 ========== + var execErr error + select { + // 优先捕获业务错误 + case execErr = <-errCh: + if execErr != nil { + // 收到真实业务错误,等待剩余协程收尾后返回 + wg.Wait() + return nil, execErr + } + // execErr == nil 代表通道关闭、无任何错误,走到下方返回完整结果 + case <-subCtx.Done(): + // 上下文被取消,阻塞读完errCh,确认是否存在业务错误 + execErr = <-errCh } - // 全局自增 i + // 拼接输出结果 var globalIndex int var outputRes []node.NodeFormField for _, items := range res { for _, item := range items { - // 1. 拿到原来的 Field:例如 "text_content:2:0" oldField := item.Field - // 2. 找到最后一个 : 的位置 if idx := strings.LastIndex(oldField, ":"); idx != -1 { - // 3. 截断前面部分,拼接上新的 globalIndex item.Field = oldField[:idx+1] + fmt.Sprint(globalIndex) } - // Label 同理 oldLabel := item.Label if idx := strings.LastIndex(oldLabel, ":"); idx != -1 { item.Label = oldLabel[:idx+1] + fmt.Sprint(globalIndex) @@ -437,11 +452,14 @@ func MergeLambda(ctx context.Context, input any) (res any, err error) { // 1. 把所有节点输出拍平成 字段名->内容 的map dataMap := make(map[string]node.NodeFormField) _, outputMap, _ := GetNodeContextContent(nodeInput.Global, nodeInput.Config) - for _, valueAny := range outputMap { - field := node.NodeFormField{} - if field, ok = valueAny.(node.NodeFormField); ok { - dataMap[field.Field] = field - } + //for _, valueAny := range outputMap { + // field := node.NodeFormField{} + // if field, ok = valueAny.(node.NodeFormField); ok { + // dataMap[field.Field] = field + // } + //} + for _, field := range outputMap { + dataMap[field.Field] = field } // 2. 提取所有文案:text_content_0,1,2... diff --git a/workflow/service/flow/lambda_node_imp.go b/workflow/service/flow/lambda_node_imp.go index 55cc100..7fa3379 100644 --- a/workflow/service/flow/lambda_node_imp.go +++ b/workflow/service/flow/lambda_node_imp.go @@ -939,23 +939,17 @@ func HttpNode(ctx context.Context, nodeInput *flowDto.NodeExecutionInput) ([]nod func BuildParam(nodeInput *flowDto.NodeExecutionInput) (skillName string, resultFrom []map[string]any, resultUserFrom []map[string]any) { inputMap, outputMap, modelMap := GetNodeContextContent(nodeInput.Global, nodeInput.Config) var outputResult []node.NodeFormField - for _, valueAny := range inputMap { - if field, ok := valueAny.(node.NodeFormField); ok { - outputResult = append(outputResult, field) - } - } + outputResult = append(outputResult, inputMap...) resultUserFrom = []map[string]any{} - for _, valueAny := range outputMap { - if field, ok := valueAny.(node.NodeFormField); ok { - if !strings.Contains(field.Field, "text_url") && !strings.Contains(field.Field, "img_url") { - if strings.Contains(field.Field, "text_content") { - field.Value = StripHtmlTags(gconv.String(field.Value)) - } - resultUserFrom = append(resultUserFrom, map[string]any{ - field.Label: field.Value, - }) + for _, field := range outputMap { + if !strings.Contains(field.Field, "text_url") && !strings.Contains(field.Field, "img_url") { + if strings.Contains(field.Field, "text_content") { + field.Value = StripHtmlTags(gconv.String(field.Value)) } + resultUserFrom = append(resultUserFrom, map[string]any{ + field.Label: field.Value, + }) } } for _, valueAny := range modelMap { @@ -963,18 +957,13 @@ func BuildParam(nodeInput *flowDto.NodeExecutionInput) (skillName string, result outputResult = append(outputResult, field) } } - //if !nodeInput.Global.IsDialogue { - for _, item := range outputResult { - resultUserFrom = append(resultUserFrom, map[string]any{ - item.Label: item.Value, - }) + if !nodeInput.Global.IsDialogue { + for _, item := range outputResult { + resultUserFrom = append(resultUserFrom, map[string]any{ + item.Label: item.Value, + }) + } } - for _, item := range nodeInput.Config.FormConfig { - resultUserFrom = append(resultUserFrom, map[string]any{ - item.Label: item.Value, - }) - } - //} if !g.IsEmpty(nodeInput.Global.Desc) { resultUserFrom = append(resultUserFrom, map[string]any{ "desc": nodeInput.Global.Desc, @@ -998,48 +987,37 @@ func BuildParam(nodeInput *flowDto.NodeExecutionInput) (skillName string, result return skillName, resultFrom, resultUserFrom } -func GetNodeContextContent(execInput *flowDto.FlowExecutionInput, nodeEntity *entity.FlowNode) (map[string]any, map[string]any, map[string]any) { - input := make(map[string]any) - output := make(map[string]any) +func GetNodeContextContent(execInput *flowDto.FlowExecutionInput, nodeEntity *entity.FlowNode) ([]node.NodeFormField, []node.NodeFormField, map[string]any) { + var input []node.NodeFormField + var output []node.NodeFormField model := make(map[string]any) - // 1. 有引用 → 取引用节点的字段值 if len(nodeEntity.InputSource) > 0 { for _, source := range nodeEntity.InputSource { - refNodeID := source.NodeId - fields := source.Field - - refNode, ok := execInput.ConfigMap[refNodeID] + refNode, ok := execInput.ConfigMap[source.NodeId] if !ok { continue } - - inputMap := buildInputMap(refNode) - outputMap := mergeOutput(refNode.OutputResult) - modelMap := mergeModel(refNode.ModelConfig) - if len(fields) > 0 { + if len(source.Field) > 0 { // 取指定字段 - for _, f := range fields { - if v, ok := inputMap[f]; ok { - input[f] = v - } - if v, ok := modelMap[f]; ok { - model[f] = v - } - for k, v := range outputMap { - if strings.Contains(k, f) { - model[k] = v + for _, f := range source.Field { + for _, v := range refNode.FormConfig { + if strings.Contains(v.Label, f) { + input = append(input, v) } } - } - } else { - // 取全部 - if refNode.NodeCode != node.NodeTypeHttp { - for k, v := range inputMap { - input[k] = v + for _, v := range refNode.ModelConfig.ModelForm { + if g.IsEmpty(v.Value) { + continue + } + if strings.Contains(v.Label, f) { + model[f] = v + } + } + for _, v := range refNode.OutputResult { + if strings.Contains(v.Label, f) { + output = append(output, v) + } } - } - for k, v := range modelMap { - model[k] = v } } } @@ -1047,34 +1025,145 @@ func GetNodeContextContent(execInput *flowDto.FlowExecutionInput, nodeEntity *en return input, output, model } -// buildInputMap 从 FormConfig 构造输入map -func buildInputMap(node *entity.FlowNode) map[string]any { - m := make(map[string]any) - for _, item := range node.FormConfig { - m[item.Label] = item - } - return m -} - -// mergeOutput 合并节点输出 []map → 单map -func mergeOutput(output []node.NodeFormField) map[string]any { - m := make(map[string]any) - for _, item := range output { - m[item.Label] = item - } - return m -} - -// mergeOutput 合并节点输出 []map → 单map -func mergeModel(output node.ModelItem) map[string]any { - m := make(map[string]any) - // 遍历 output.ModelForm 里的每一个 key 和原始值 - for _, rawValue := range output.ModelForm { - if g.IsEmpty(rawValue.Value) { - continue - } - // 包装成 { "value": 原始值 } - m[rawValue.Label] = rawValue.Value - } - return m -} +//func BuildParam(nodeInput *flowDto.NodeExecutionInput) (skillName string, resultFrom []map[string]any, resultUserFrom []map[string]any) { +// inputMap, outputMap, modelMap := GetNodeContextContent(nodeInput.Global, nodeInput.Config) +// var outputResult []node.NodeFormField +// for _, valueAny := range inputMap { +// if field, ok := valueAny.(node.NodeFormField); ok { +// outputResult = append(outputResult, field) +// } +// } +// +// resultUserFrom = []map[string]any{} +// for _, valueAny := range outputMap { +// if field, ok := valueAny.(node.NodeFormField); ok { +// if !strings.Contains(field.Field, "text_url") && !strings.Contains(field.Field, "img_url") { +// if strings.Contains(field.Field, "text_content") { +// field.Value = StripHtmlTags(gconv.String(field.Value)) +// } +// resultUserFrom = append(resultUserFrom, map[string]any{ +// field.Label: field.Value, +// }) +// } +// } +// } +// for _, valueAny := range modelMap { +// if field, ok := valueAny.(node.NodeFormField); ok { +// outputResult = append(outputResult, field) +// } +// } +// //if !nodeInput.Global.IsDialogue { +// for _, item := range outputResult { +// resultUserFrom = append(resultUserFrom, map[string]any{ +// item.Label: item.Value, +// }) +// } +// for _, item := range nodeInput.Config.FormConfig { +// resultUserFrom = append(resultUserFrom, map[string]any{ +// item.Label: item.Value, +// }) +// } +// //} +// if !g.IsEmpty(nodeInput.Global.Desc) { +// resultUserFrom = append(resultUserFrom, map[string]any{ +// "desc": nodeInput.Global.Desc, +// }) +// } +// +// resultFrom = []map[string]any{} +// for _, item := range nodeInput.Config.ModelConfig.ModelForm { +// if g.IsEmpty(item.Value) { +// continue +// } +// resultFrom = append(resultFrom, map[string]any{ +// item.Label: item.Value, +// }) +// } +// skillName = nodeInput.Config.SkillName +// if g.IsEmpty(nodeInput.Config.SkillName) { +// skillName = nodeInput.Global.SkillName +// } +// +// return skillName, resultFrom, resultUserFrom +//} +// +//func GetNodeContextContent(execInput *flowDto.FlowExecutionInput, nodeEntity *entity.FlowNode) (map[string]any, map[string]any, map[string]any) { +// input := make(map[string]any) +// output := make(map[string]any) +// model := make(map[string]any) +// // 1. 有引用 → 取引用节点的字段值 +// if len(nodeEntity.InputSource) > 0 { +// for _, source := range nodeEntity.InputSource { +// refNodeID := source.NodeId +// fields := source.Field +// +// refNode, ok := execInput.ConfigMap[refNodeID] +// if !ok { +// continue +// } +// +// inputMap := buildInputMap(refNode) +// outputMap := mergeOutput(refNode.OutputResult) +// modelMap := mergeModel(refNode.ModelConfig) +// if len(fields) > 0 { +// // 取指定字段 +// for _, f := range fields { +// if v, ok := inputMap[f]; ok { +// input[f] = v +// } +// if v, ok := modelMap[f]; ok { +// model[f] = v +// } +// for k, v := range outputMap { +// if strings.Contains(k, f) { +// model[k] = v +// } +// } +// } +// } else { +// // 取全部 +// if refNode.NodeCode != node.NodeTypeHttp { +// for k, v := range inputMap { +// input[k] = v +// } +// } +// for k, v := range modelMap { +// model[k] = v +// } +// } +// } +// } +// return input, output, model +//} +// +//// buildInputMap 从 FormConfig 构造输入map +//func buildInputMap(node *entity.FlowNode) map[string]any { +// m := make(map[string]any) +// for _, item := range node.FormConfig { +// m[item.Label] = item +// } +// return m +//} +// +//// mergeOutput 合并节点输出 []map → 单map +//func mergeOutput(output []node.NodeFormField) map[string]any { +// m := make(map[string]any) +// for _, item := range output { +// m[item.Label] = item +// } +// return m +//} +// +//// mergeOutput 合并节点输出 []map → 单map +//func mergeModel(output node.ModelItem) map[string]any { +// m := make(map[string]any) +// // 遍历 output.ModelForm 里的每一个 key 和原始值 +// for _, rawValue := range output.ModelForm { +// if g.IsEmpty(rawValue.Value) { +// continue +// } +// // 包装成 { "value": 原始值 } +// m[rawValue.Label] = rawValue.Value +// } +// return m +//} diff --git a/workflow/service/flow/lambda_node_util.go b/workflow/service/flow/lambda_node_util.go index 63020a6..353f1be 100644 --- a/workflow/service/flow/lambda_node_util.go +++ b/workflow/service/flow/lambda_node_util.go @@ -258,16 +258,15 @@ func waitGatewayResult(ctx context.Context, taskId string) (map[string]any, erro if task.State == 3 || !g.IsEmpty(task.ErrorMsg) { return nil, fmt.Errorf("模型执行失败:%s", task.ErrorMsg) } - if g.IsEmpty(task.Messages) { + if g.IsEmpty(task.OssFile) { return nil, fmt.Errorf("模型返回结果为空") } // 获取远程文件内容 - //file, err := GetFileBytesFromURL(ctx, task.OssFile) - //if err != nil { - // return nil, err - //} - //task.Messages = gconv.Map(file) - return task.Messages, nil + file, err := GetFileBytesFromURL(ctx, task.OssFile) + if err != nil { + return nil, err + } + return gconv.Map(file), nil } // updateTokenCount updates the token count in node execution @@ -296,6 +295,9 @@ func GetModelResult(ctx context.Context, sessionId string, nodeInput *flowDto.No if err != nil { return nil, err } + if composeResult.Status != "success" { + return nil, fmt.Errorf("模型提示词构建错误") + } modelInfo, err := GetModelInfo(ctx, &flowDto.GetModelInfoReq{ModelName: nodeInput.Config.ModelConfig.ModelName}) if err != nil { @@ -345,31 +347,41 @@ func GetModelResult(ctx context.Context, sessionId string, nodeInput *flowDto.No taskIdList := make([]string, len(composeResult.Messages.Rounds)) for idx, item := range composeResult.Messages.Rounds { - var taskId string - taskId, err = createGatewayTaskOnly(ctx, composeResult.EpicycleId, nodeInput.Config.ModelConfig.ModelName, item) + taskId, err := createGatewayTaskOnly(ctx, composeResult.EpicycleId, nodeInput.Config.ModelConfig.ModelName, item) if err != nil { return nil, err } taskIdList[idx] = taskId } + // 全局共享子上下文,实现一处报错全部终止 + subCtx, globalCancel := context.WithCancel(ctx) + defer globalCancel() // 函数退出兜底释放 + var wg sync.WaitGroup errChan := make(chan error, len(taskIdList)) + // 加互斥锁保护结果map + var mu sync.Mutex + for idx, taskId := range taskIdList { wg.Add(1) go func(idx int, taskId string) { defer wg.Done() - var taskResult map[string]any - taskResult, err = waitGatewayResult(ctx, taskId) + taskResult, err := waitGatewayResult(subCtx, taskId) if err != nil { errChan <- err + globalCancel() // 全局取消,所有协程收到ctx取消信号快速退出 return } + // 加锁写入map,解决并发竞态 + mu.Lock() mapTaskResult[idx] = taskResult + mu.Unlock() + updateTokenCount(ctx, nodeInput.NodeExecutionId, modelInfo.Model.ResponseTokenField, taskResult) }(idx, taskId) } @@ -377,8 +389,15 @@ func GetModelResult(ctx context.Context, sessionId string, nodeInput *flowDto.No wg.Wait() close(errChan) - if len(errChan) > 0 { - return nil, <-errChan + // 收集全部错误,而非只读一条 + var errs []error + for len(errChan) > 0 { + errs = append(errs, <-errChan) + } + + if len(errs) > 0 { + // 返回第一个错误;如需汇总所有错误可拼接 + return nil, errs[0] } } } else { @@ -490,16 +509,17 @@ func VideoConcat(ctx context.Context, videoUrls []string) (r any, err error) { } func GetFileBytesFromURL(ctx context.Context, fileUrl string) ([]byte, error) { + newS := strings.ReplaceAll(fileUrl, "http://cdn.redpowerfuture.com", g.Cfg().MustGet(ctx, "filePrefix").String()) // 使用 GoFrame 客户端(自带超时、追踪、日志等能力) - resp, err := g.Client().Get(ctx, fileUrl) + resp, err := g.Client().Get(ctx, newS) if err != nil { - return nil, gerror.Wrapf(err, "failed to request url: %s", fileUrl) + return nil, gerror.Wrapf(err, "failed to request url: %s", newS) } defer resp.Close() // 校验状态码 if resp.StatusCode != http.StatusOK { - return nil, gerror.Newf("request failed with status code: %d, url: %s", resp.StatusCode, fileUrl) + return nil, gerror.Newf("request failed with status code: %d, url: %s", resp.StatusCode, newS) } // 读取全部内容