feat: 重构节点上下文与并发执行逻辑
重构GetNodeContextContent返回类型为切片,修复并发竞态与协程泄漏问题;回调改用OSS文件获取结果;调整节点输入上传时序
This commit is contained in:
@@ -258,16 +258,15 @@ func waitGatewayResult(ctx context.Context, taskId string) (map[string]any, erro
|
||||
if task.State == 3 || !g.IsEmpty(task.ErrorMsg) {
|
||||
return nil, fmt.Errorf("模型执行失败:%s", task.ErrorMsg)
|
||||
}
|
||||
if g.IsEmpty(task.Messages) {
|
||||
if g.IsEmpty(task.OssFile) {
|
||||
return nil, fmt.Errorf("模型返回结果为空")
|
||||
}
|
||||
// 获取远程文件内容
|
||||
//file, err := GetFileBytesFromURL(ctx, task.OssFile)
|
||||
//if err != nil {
|
||||
// return nil, err
|
||||
//}
|
||||
//task.Messages = gconv.Map(file)
|
||||
return task.Messages, nil
|
||||
file, err := GetFileBytesFromURL(ctx, task.OssFile)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return gconv.Map(file), nil
|
||||
}
|
||||
|
||||
// updateTokenCount updates the token count in node execution
|
||||
@@ -296,6 +295,9 @@ func GetModelResult(ctx context.Context, sessionId string, nodeInput *flowDto.No
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if composeResult.Status != "success" {
|
||||
return nil, fmt.Errorf("模型提示词构建错误")
|
||||
}
|
||||
|
||||
modelInfo, err := GetModelInfo(ctx, &flowDto.GetModelInfoReq{ModelName: nodeInput.Config.ModelConfig.ModelName})
|
||||
if err != nil {
|
||||
@@ -345,31 +347,41 @@ func GetModelResult(ctx context.Context, sessionId string, nodeInput *flowDto.No
|
||||
taskIdList := make([]string, len(composeResult.Messages.Rounds))
|
||||
|
||||
for idx, item := range composeResult.Messages.Rounds {
|
||||
var taskId string
|
||||
taskId, err = createGatewayTaskOnly(ctx, composeResult.EpicycleId, nodeInput.Config.ModelConfig.ModelName, item)
|
||||
taskId, err := createGatewayTaskOnly(ctx, composeResult.EpicycleId, nodeInput.Config.ModelConfig.ModelName, item)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
taskIdList[idx] = taskId
|
||||
}
|
||||
|
||||
// 全局共享子上下文,实现一处报错全部终止
|
||||
subCtx, globalCancel := context.WithCancel(ctx)
|
||||
defer globalCancel() // 函数退出兜底释放
|
||||
|
||||
var wg sync.WaitGroup
|
||||
errChan := make(chan error, len(taskIdList))
|
||||
|
||||
// 加互斥锁保护结果map
|
||||
var mu sync.Mutex
|
||||
|
||||
for idx, taskId := range taskIdList {
|
||||
wg.Add(1)
|
||||
|
||||
go func(idx int, taskId string) {
|
||||
defer wg.Done()
|
||||
|
||||
var taskResult map[string]any
|
||||
taskResult, err = waitGatewayResult(ctx, taskId)
|
||||
taskResult, err := waitGatewayResult(subCtx, taskId)
|
||||
if err != nil {
|
||||
errChan <- err
|
||||
globalCancel() // 全局取消,所有协程收到ctx取消信号快速退出
|
||||
return
|
||||
}
|
||||
|
||||
// 加锁写入map,解决并发竞态
|
||||
mu.Lock()
|
||||
mapTaskResult[idx] = taskResult
|
||||
mu.Unlock()
|
||||
|
||||
updateTokenCount(ctx, nodeInput.NodeExecutionId, modelInfo.Model.ResponseTokenField, taskResult)
|
||||
}(idx, taskId)
|
||||
}
|
||||
@@ -377,8 +389,15 @@ func GetModelResult(ctx context.Context, sessionId string, nodeInput *flowDto.No
|
||||
wg.Wait()
|
||||
close(errChan)
|
||||
|
||||
if len(errChan) > 0 {
|
||||
return nil, <-errChan
|
||||
// 收集全部错误,而非只读一条
|
||||
var errs []error
|
||||
for len(errChan) > 0 {
|
||||
errs = append(errs, <-errChan)
|
||||
}
|
||||
|
||||
if len(errs) > 0 {
|
||||
// 返回第一个错误;如需汇总所有错误可拼接
|
||||
return nil, errs[0]
|
||||
}
|
||||
}
|
||||
} else {
|
||||
@@ -490,16 +509,17 @@ func VideoConcat(ctx context.Context, videoUrls []string) (r any, err error) {
|
||||
}
|
||||
|
||||
func GetFileBytesFromURL(ctx context.Context, fileUrl string) ([]byte, error) {
|
||||
newS := strings.ReplaceAll(fileUrl, "http://cdn.redpowerfuture.com", g.Cfg().MustGet(ctx, "filePrefix").String())
|
||||
// 使用 GoFrame 客户端(自带超时、追踪、日志等能力)
|
||||
resp, err := g.Client().Get(ctx, fileUrl)
|
||||
resp, err := g.Client().Get(ctx, newS)
|
||||
if err != nil {
|
||||
return nil, gerror.Wrapf(err, "failed to request url: %s", fileUrl)
|
||||
return nil, gerror.Wrapf(err, "failed to request url: %s", newS)
|
||||
}
|
||||
defer resp.Close()
|
||||
|
||||
// 校验状态码
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, gerror.Newf("request failed with status code: %d, url: %s", resp.StatusCode, fileUrl)
|
||||
return nil, gerror.Newf("request failed with status code: %d, url: %s", resp.StatusCode, newS)
|
||||
}
|
||||
|
||||
// 读取全部内容
|
||||
|
||||
Reference in New Issue
Block a user