feat: 重构节点上下文与并发执行逻辑

重构GetNodeContextContent返回类型为切片,修复并发竞态与协程泄漏问题;回调改用OSS文件获取结果;调整节点输入上传时序
This commit is contained in:
2026-06-18 14:24:48 +08:00
parent fba7d032ae
commit 4df45069e0
5 changed files with 287 additions and 162 deletions

View File

@@ -664,18 +664,9 @@ func registerNodeToGraph(graph *compose.Graph[any, any], flowNode entity.FlowNod
// 执行节点
_, err = lambda(ctx, realInput)
durationMs := time.Since(startTime).Milliseconds()
// 上传OSS每条独立上传
ossResult1, err := Upload(ctx, &dto.UploadFileBytesReq{
FileBytes: gconv.Bytes(gconv.String(realInput)),
FileName: fmt.Sprintf("nodeInput:%v.txt", time.Now().UnixMilli()),
})
if err != nil {
return nil, err
}
updateReq := &nodeDto.UpdateNodeExecutionReq{
Id: nodeExecutionId,
OutputParamsPath: ossResult1.FileURL,
DurationMs: durationMs,
Id: nodeExecutionId,
DurationMs: durationMs,
}
if err != nil {
// 执行失败,更新状态
@@ -689,7 +680,15 @@ func registerNodeToGraph(graph *compose.Graph[any, any], flowNode entity.FlowNod
})
return nil, err
}
// 上传OSS每条独立上传
ossResult1, err := Upload(ctx, &dto.UploadFileBytesReq{
FileBytes: gconv.Bytes(gconv.String(realInput)),
FileName: fmt.Sprintf("nodeInput:%v.txt", time.Now().UnixMilli()),
})
if err != nil {
return nil, err
}
updateReq.OutputParamsPath = ossResult1.FileURL
// 执行成功,更新状态
updateReq.Status = node.NodeExecutionStatusSuccess.Code()
_, _ = nodeDao.NodeExecutionDao.Update(ctx, updateReq)

View File

@@ -44,16 +44,18 @@ func JudgeLambda(ctx context.Context, input any) (string, error) {
// 1. 直接用你原来的方法(返回两个 map
inputMap, outputMap, modelMap := GetNodeContextContent(nodeInput.Global, nodeInput.Config)
var outputResult []node.NodeFormField
for _, valueAny := range inputMap {
if field, ok := valueAny.(node.NodeFormField); ok {
outputResult = append(outputResult, field)
}
}
for _, valueAny := range outputMap {
if field, ok := valueAny.(node.NodeFormField); ok {
outputResult = append(outputResult, field)
}
}
outputResult = append(outputResult, inputMap...)
outputResult = append(outputResult, outputMap...)
//for _, valueAny := range inputMap {
// if field, ok := valueAny.(node.NodeFormField); ok {
// outputResult = append(outputResult, field)
// }
//}
//for _, valueAny := range outputMap {
// if field, ok := valueAny.(node.NodeFormField); ok {
// outputResult = append(outputResult, field)
// }
//}
for _, valueAny := range modelMap {
if field, ok := valueAny.(node.NodeFormField); ok {
outputResult = append(outputResult, field)
@@ -123,62 +125,75 @@ func BatchModelLambda(ctx context.Context, input any) (any, error) {
}
}
}
// 结果按索引存放,保证顺序
// 结果按索引存放,切片不同下标并发写无竞争,不用锁
res := make([][]node.NodeFormField, len(reqMap))
var wg sync.WaitGroup
// 用一个通道标记是否完成
done := make(chan struct{})
// 错误只存一个
var execErr error
// 并发执行
subCtx, cancel := context.WithCancel(ctx)
defer cancel()
// 缓冲1错误通道仅接收第一个错误
errCh := make(chan error, 1)
// 并发执行任务
for idx, item := range reqMap {
wg.Add(1)
go func(idx int, userItem map[string]any) {
defer wg.Done()
// 上下文已取消则直接退出
select {
case <-subCtx.Done():
return
default:
}
singleUserFrom := []map[string]any{userItem}
output, err := TextNode(ctx, nodeInput, skillName, from, singleUserFrom)
output, err := TextNode(subCtx, nodeInput, skillName, from, singleUserFrom)
if err != nil {
// 并发安全赋值错误
if execErr == nil {
execErr = err
// 仅第一个错误写入通道
select {
case errCh <- err:
cancel() // 触发全局取消,其他协程快速退出
default:
}
return
}
// 直接按原索引写,顺序绝对正确
res[idx] = output
}(idx, item)
}
// 后台等待所有协程完成,然后关闭 done 通道
// 任务全部结束后关闭错误通道
go func() {
wg.Wait()
close(done)
close(errCh)
}()
// 等待全部完成
<-done
// 如果有错误,直接返回
if execErr != nil {
return nil, execErr
// ========== 修正后的等待逻辑 ==========
var execErr error
select {
// 优先捕获业务错误
case execErr = <-errCh:
if execErr != nil {
// 收到真实业务错误,等待剩余协程收尾后返回
wg.Wait()
return nil, execErr
}
// execErr == nil 代表通道关闭、无任何错误,走到下方返回完整结果
case <-subCtx.Done():
// 上下文被取消阻塞读完errCh确认是否存在业务错误
execErr = <-errCh
}
// 全局自增 i
// 拼接输出结果
var globalIndex int
var outputRes []node.NodeFormField
for _, items := range res {
for _, item := range items {
// 1. 拿到原来的 Field例如 "text_content:2:0"
oldField := item.Field
// 2. 找到最后一个 : 的位置
if idx := strings.LastIndex(oldField, ":"); idx != -1 {
// 3. 截断前面部分,拼接上新的 globalIndex
item.Field = oldField[:idx+1] + fmt.Sprint(globalIndex)
}
// Label 同理
oldLabel := item.Label
if idx := strings.LastIndex(oldLabel, ":"); idx != -1 {
item.Label = oldLabel[:idx+1] + fmt.Sprint(globalIndex)
@@ -437,11 +452,14 @@ func MergeLambda(ctx context.Context, input any) (res any, err error) {
// 1. 把所有节点输出拍平成 字段名->内容 的map
dataMap := make(map[string]node.NodeFormField)
_, outputMap, _ := GetNodeContextContent(nodeInput.Global, nodeInput.Config)
for _, valueAny := range outputMap {
field := node.NodeFormField{}
if field, ok = valueAny.(node.NodeFormField); ok {
dataMap[field.Field] = field
}
//for _, valueAny := range outputMap {
// field := node.NodeFormField{}
// if field, ok = valueAny.(node.NodeFormField); ok {
// dataMap[field.Field] = field
// }
//}
for _, field := range outputMap {
dataMap[field.Field] = field
}
// 2. 提取所有文案text_content_0,1,2...

View File

@@ -939,23 +939,17 @@ func HttpNode(ctx context.Context, nodeInput *flowDto.NodeExecutionInput) ([]nod
func BuildParam(nodeInput *flowDto.NodeExecutionInput) (skillName string, resultFrom []map[string]any, resultUserFrom []map[string]any) {
inputMap, outputMap, modelMap := GetNodeContextContent(nodeInput.Global, nodeInput.Config)
var outputResult []node.NodeFormField
for _, valueAny := range inputMap {
if field, ok := valueAny.(node.NodeFormField); ok {
outputResult = append(outputResult, field)
}
}
outputResult = append(outputResult, inputMap...)
resultUserFrom = []map[string]any{}
for _, valueAny := range outputMap {
if field, ok := valueAny.(node.NodeFormField); ok {
if !strings.Contains(field.Field, "text_url") && !strings.Contains(field.Field, "img_url") {
if strings.Contains(field.Field, "text_content") {
field.Value = StripHtmlTags(gconv.String(field.Value))
}
resultUserFrom = append(resultUserFrom, map[string]any{
field.Label: field.Value,
})
for _, field := range outputMap {
if !strings.Contains(field.Field, "text_url") && !strings.Contains(field.Field, "img_url") {
if strings.Contains(field.Field, "text_content") {
field.Value = StripHtmlTags(gconv.String(field.Value))
}
resultUserFrom = append(resultUserFrom, map[string]any{
field.Label: field.Value,
})
}
}
for _, valueAny := range modelMap {
@@ -963,18 +957,13 @@ func BuildParam(nodeInput *flowDto.NodeExecutionInput) (skillName string, result
outputResult = append(outputResult, field)
}
}
//if !nodeInput.Global.IsDialogue {
for _, item := range outputResult {
resultUserFrom = append(resultUserFrom, map[string]any{
item.Label: item.Value,
})
if !nodeInput.Global.IsDialogue {
for _, item := range outputResult {
resultUserFrom = append(resultUserFrom, map[string]any{
item.Label: item.Value,
})
}
}
for _, item := range nodeInput.Config.FormConfig {
resultUserFrom = append(resultUserFrom, map[string]any{
item.Label: item.Value,
})
}
//}
if !g.IsEmpty(nodeInput.Global.Desc) {
resultUserFrom = append(resultUserFrom, map[string]any{
"desc": nodeInput.Global.Desc,
@@ -998,48 +987,37 @@ func BuildParam(nodeInput *flowDto.NodeExecutionInput) (skillName string, result
return skillName, resultFrom, resultUserFrom
}
func GetNodeContextContent(execInput *flowDto.FlowExecutionInput, nodeEntity *entity.FlowNode) (map[string]any, map[string]any, map[string]any) {
input := make(map[string]any)
output := make(map[string]any)
func GetNodeContextContent(execInput *flowDto.FlowExecutionInput, nodeEntity *entity.FlowNode) ([]node.NodeFormField, []node.NodeFormField, map[string]any) {
var input []node.NodeFormField
var output []node.NodeFormField
model := make(map[string]any)
// 1. 有引用 → 取引用节点的字段值
if len(nodeEntity.InputSource) > 0 {
for _, source := range nodeEntity.InputSource {
refNodeID := source.NodeId
fields := source.Field
refNode, ok := execInput.ConfigMap[refNodeID]
refNode, ok := execInput.ConfigMap[source.NodeId]
if !ok {
continue
}
inputMap := buildInputMap(refNode)
outputMap := mergeOutput(refNode.OutputResult)
modelMap := mergeModel(refNode.ModelConfig)
if len(fields) > 0 {
if len(source.Field) > 0 {
// 取指定字段
for _, f := range fields {
if v, ok := inputMap[f]; ok {
input[f] = v
}
if v, ok := modelMap[f]; ok {
model[f] = v
}
for k, v := range outputMap {
if strings.Contains(k, f) {
model[k] = v
for _, f := range source.Field {
for _, v := range refNode.FormConfig {
if strings.Contains(v.Label, f) {
input = append(input, v)
}
}
}
} else {
// 取全部
if refNode.NodeCode != node.NodeTypeHttp {
for k, v := range inputMap {
input[k] = v
for _, v := range refNode.ModelConfig.ModelForm {
if g.IsEmpty(v.Value) {
continue
}
if strings.Contains(v.Label, f) {
model[f] = v
}
}
for _, v := range refNode.OutputResult {
if strings.Contains(v.Label, f) {
output = append(output, v)
}
}
}
for k, v := range modelMap {
model[k] = v
}
}
}
@@ -1047,34 +1025,145 @@ func GetNodeContextContent(execInput *flowDto.FlowExecutionInput, nodeEntity *en
return input, output, model
}
// buildInputMap 从 FormConfig 构造输入map
func buildInputMap(node *entity.FlowNode) map[string]any {
m := make(map[string]any)
for _, item := range node.FormConfig {
m[item.Label] = item
}
return m
}
// mergeOutput 合并节点输出 []map → 单map
func mergeOutput(output []node.NodeFormField) map[string]any {
m := make(map[string]any)
for _, item := range output {
m[item.Label] = item
}
return m
}
// mergeOutput 合并节点输出 []map → 单map
func mergeModel(output node.ModelItem) map[string]any {
m := make(map[string]any)
// 遍历 output.ModelForm 里的每一个 key 和原始值
for _, rawValue := range output.ModelForm {
if g.IsEmpty(rawValue.Value) {
continue
}
// 包装成 { "value": 原始值 }
m[rawValue.Label] = rawValue.Value
}
return m
}
//func BuildParam(nodeInput *flowDto.NodeExecutionInput) (skillName string, resultFrom []map[string]any, resultUserFrom []map[string]any) {
// inputMap, outputMap, modelMap := GetNodeContextContent(nodeInput.Global, nodeInput.Config)
// var outputResult []node.NodeFormField
// for _, valueAny := range inputMap {
// if field, ok := valueAny.(node.NodeFormField); ok {
// outputResult = append(outputResult, field)
// }
// }
//
// resultUserFrom = []map[string]any{}
// for _, valueAny := range outputMap {
// if field, ok := valueAny.(node.NodeFormField); ok {
// if !strings.Contains(field.Field, "text_url") && !strings.Contains(field.Field, "img_url") {
// if strings.Contains(field.Field, "text_content") {
// field.Value = StripHtmlTags(gconv.String(field.Value))
// }
// resultUserFrom = append(resultUserFrom, map[string]any{
// field.Label: field.Value,
// })
// }
// }
// }
// for _, valueAny := range modelMap {
// if field, ok := valueAny.(node.NodeFormField); ok {
// outputResult = append(outputResult, field)
// }
// }
// //if !nodeInput.Global.IsDialogue {
// for _, item := range outputResult {
// resultUserFrom = append(resultUserFrom, map[string]any{
// item.Label: item.Value,
// })
// }
// for _, item := range nodeInput.Config.FormConfig {
// resultUserFrom = append(resultUserFrom, map[string]any{
// item.Label: item.Value,
// })
// }
// //}
// if !g.IsEmpty(nodeInput.Global.Desc) {
// resultUserFrom = append(resultUserFrom, map[string]any{
// "desc": nodeInput.Global.Desc,
// })
// }
//
// resultFrom = []map[string]any{}
// for _, item := range nodeInput.Config.ModelConfig.ModelForm {
// if g.IsEmpty(item.Value) {
// continue
// }
// resultFrom = append(resultFrom, map[string]any{
// item.Label: item.Value,
// })
// }
// skillName = nodeInput.Config.SkillName
// if g.IsEmpty(nodeInput.Config.SkillName) {
// skillName = nodeInput.Global.SkillName
// }
//
// return skillName, resultFrom, resultUserFrom
//}
//
//func GetNodeContextContent(execInput *flowDto.FlowExecutionInput, nodeEntity *entity.FlowNode) (map[string]any, map[string]any, map[string]any) {
// input := make(map[string]any)
// output := make(map[string]any)
// model := make(map[string]any)
// // 1. 有引用 → 取引用节点的字段值
// if len(nodeEntity.InputSource) > 0 {
// for _, source := range nodeEntity.InputSource {
// refNodeID := source.NodeId
// fields := source.Field
//
// refNode, ok := execInput.ConfigMap[refNodeID]
// if !ok {
// continue
// }
//
// inputMap := buildInputMap(refNode)
// outputMap := mergeOutput(refNode.OutputResult)
// modelMap := mergeModel(refNode.ModelConfig)
// if len(fields) > 0 {
// // 取指定字段
// for _, f := range fields {
// if v, ok := inputMap[f]; ok {
// input[f] = v
// }
// if v, ok := modelMap[f]; ok {
// model[f] = v
// }
// for k, v := range outputMap {
// if strings.Contains(k, f) {
// model[k] = v
// }
// }
// }
// } else {
// // 取全部
// if refNode.NodeCode != node.NodeTypeHttp {
// for k, v := range inputMap {
// input[k] = v
// }
// }
// for k, v := range modelMap {
// model[k] = v
// }
// }
// }
// }
// return input, output, model
//}
//
//// buildInputMap 从 FormConfig 构造输入map
//func buildInputMap(node *entity.FlowNode) map[string]any {
// m := make(map[string]any)
// for _, item := range node.FormConfig {
// m[item.Label] = item
// }
// return m
//}
//
//// mergeOutput 合并节点输出 []map → 单map
//func mergeOutput(output []node.NodeFormField) map[string]any {
// m := make(map[string]any)
// for _, item := range output {
// m[item.Label] = item
// }
// return m
//}
//
//// mergeOutput 合并节点输出 []map → 单map
//func mergeModel(output node.ModelItem) map[string]any {
// m := make(map[string]any)
// // 遍历 output.ModelForm 里的每一个 key 和原始值
// for _, rawValue := range output.ModelForm {
// if g.IsEmpty(rawValue.Value) {
// continue
// }
// // 包装成 { "value": 原始值 }
// m[rawValue.Label] = rawValue.Value
// }
// return m
//}

View File

@@ -258,16 +258,15 @@ func waitGatewayResult(ctx context.Context, taskId string) (map[string]any, erro
if task.State == 3 || !g.IsEmpty(task.ErrorMsg) {
return nil, fmt.Errorf("模型执行失败:%s", task.ErrorMsg)
}
if g.IsEmpty(task.Messages) {
if g.IsEmpty(task.OssFile) {
return nil, fmt.Errorf("模型返回结果为空")
}
// 获取远程文件内容
//file, err := GetFileBytesFromURL(ctx, task.OssFile)
//if err != nil {
// return nil, err
//}
//task.Messages = gconv.Map(file)
return task.Messages, nil
file, err := GetFileBytesFromURL(ctx, task.OssFile)
if err != nil {
return nil, err
}
return gconv.Map(file), nil
}
// updateTokenCount updates the token count in node execution
@@ -296,6 +295,9 @@ func GetModelResult(ctx context.Context, sessionId string, nodeInput *flowDto.No
if err != nil {
return nil, err
}
if composeResult.Status != "success" {
return nil, fmt.Errorf("模型提示词构建错误")
}
modelInfo, err := GetModelInfo(ctx, &flowDto.GetModelInfoReq{ModelName: nodeInput.Config.ModelConfig.ModelName})
if err != nil {
@@ -345,31 +347,41 @@ func GetModelResult(ctx context.Context, sessionId string, nodeInput *flowDto.No
taskIdList := make([]string, len(composeResult.Messages.Rounds))
for idx, item := range composeResult.Messages.Rounds {
var taskId string
taskId, err = createGatewayTaskOnly(ctx, composeResult.EpicycleId, nodeInput.Config.ModelConfig.ModelName, item)
taskId, err := createGatewayTaskOnly(ctx, composeResult.EpicycleId, nodeInput.Config.ModelConfig.ModelName, item)
if err != nil {
return nil, err
}
taskIdList[idx] = taskId
}
// 全局共享子上下文,实现一处报错全部终止
subCtx, globalCancel := context.WithCancel(ctx)
defer globalCancel() // 函数退出兜底释放
var wg sync.WaitGroup
errChan := make(chan error, len(taskIdList))
// 加互斥锁保护结果map
var mu sync.Mutex
for idx, taskId := range taskIdList {
wg.Add(1)
go func(idx int, taskId string) {
defer wg.Done()
var taskResult map[string]any
taskResult, err = waitGatewayResult(ctx, taskId)
taskResult, err := waitGatewayResult(subCtx, taskId)
if err != nil {
errChan <- err
globalCancel() // 全局取消所有协程收到ctx取消信号快速退出
return
}
// 加锁写入map解决并发竞态
mu.Lock()
mapTaskResult[idx] = taskResult
mu.Unlock()
updateTokenCount(ctx, nodeInput.NodeExecutionId, modelInfo.Model.ResponseTokenField, taskResult)
}(idx, taskId)
}
@@ -377,8 +389,15 @@ func GetModelResult(ctx context.Context, sessionId string, nodeInput *flowDto.No
wg.Wait()
close(errChan)
if len(errChan) > 0 {
return nil, <-errChan
// 收集全部错误,而非只读一条
var errs []error
for len(errChan) > 0 {
errs = append(errs, <-errChan)
}
if len(errs) > 0 {
// 返回第一个错误;如需汇总所有错误可拼接
return nil, errs[0]
}
}
} else {
@@ -490,16 +509,17 @@ func VideoConcat(ctx context.Context, videoUrls []string) (r any, err error) {
}
func GetFileBytesFromURL(ctx context.Context, fileUrl string) ([]byte, error) {
newS := strings.ReplaceAll(fileUrl, "http://cdn.redpowerfuture.com", g.Cfg().MustGet(ctx, "filePrefix").String())
// 使用 GoFrame 客户端(自带超时、追踪、日志等能力)
resp, err := g.Client().Get(ctx, fileUrl)
resp, err := g.Client().Get(ctx, newS)
if err != nil {
return nil, gerror.Wrapf(err, "failed to request url: %s", fileUrl)
return nil, gerror.Wrapf(err, "failed to request url: %s", newS)
}
defer resp.Close()
// 校验状态码
if resp.StatusCode != http.StatusOK {
return nil, gerror.Newf("request failed with status code: %d, url: %s", resp.StatusCode, fileUrl)
return nil, gerror.Newf("request failed with status code: %d, url: %s", resp.StatusCode, newS)
}
// 读取全部内容