Files
ai-agent/workflow/service/flow/lambda_node.go
qhd 4df45069e0 feat: 重构节点上下文与并发执行逻辑
重构GetNodeContextContent返回类型为切片,修复并发竞态与协程泄漏问题;回调改用OSS文件获取结果;调整节点输入上传时序
2026-06-18 14:24:48 +08:00

672 lines
20 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package flow
import (
"ai-agent/workflow/consts/flow"
"ai-agent/workflow/consts/node"
"ai-agent/workflow/consts/public"
fileDao "ai-agent/workflow/dao/file"
flowDao "ai-agent/workflow/dao/flow"
"ai-agent/workflow/model/dto"
fileDto "ai-agent/workflow/model/dto/file"
flowDto "ai-agent/workflow/model/dto/flow"
"context"
"fmt"
"strconv"
"strings"
"sync"
"time"
"gitea.redpowerfuture.com/red-future/common/db/gfdb"
"gitea.redpowerfuture.com/red-future/common/utils"
"github.com/gogf/gf/v2/database/gdb"
"github.com/gogf/gf/v2/frame/g"
"github.com/gogf/gf/v2/util/gconv"
)
// StartLambda is the handler for the start node. It performs no work and
// hands its input to the next step unchanged.
func StartLambda(ctx context.Context, input any) (any, error) {
	output := input
	return output, nil
}
// FormLambda is the handler for the form node. It performs no work and hands
// its input to the next step unchanged.
func FormLambda(ctx context.Context, input any) (any, error) {
	output := input
	return output, nil
}
// IntentLambda is the handler for the intent node. It performs no work and
// hands its input to the next step unchanged.
func IntentLambda(ctx context.Context, input any) (any, error) {
	output := input
	return output, nil
}
// JudgeLambda is the branch-decision step. It builds a context prompt from
// the node's form configuration and the upstream node context, submits it
// together with the configured branch list through GetComposeResult, and
// returns the value read from the first response round as the routing target
// node ID.
//
// input must be a *flowDto.NodeExecutionInput; any other type yields an
// error. Errors from GetIsChatModel and GetComposeResult are returned
// unchanged. An error is also returned when the compose result has no task ID
// or, when a response key is configured, no response rounds.
func JudgeLambda(ctx context.Context, input any) (string, error) {
	nodeInput, ok := input.(*flowDto.NodeExecutionInput)
	if !ok {
		return "", fmt.Errorf("入参类型错误,期望 *flowDto.NodeExecutionInput实际 %T", input)
	}

	// Gather every field that may contribute to the decision context. The
	// input and output results are already typed slices; the model result
	// still carries untyped values, so only NodeFormField entries are kept.
	inputFields, outputFields, modelFields := GetNodeContextContent(nodeInput.Global, nodeInput.Config)
	var contextFields []node.NodeFormField
	contextFields = append(contextFields, inputFields...)
	contextFields = append(contextFields, outputFields...)
	for _, raw := range modelFields {
		if field, ok := raw.(node.NodeFormField); ok {
			contextFields = append(contextFields, field)
		}
	}

	// Build the ",label:value" context string. A builder avoids re-copying
	// the accumulated prompt on every field.
	var contextParts strings.Builder
	for _, v := range nodeInput.Config.FormConfig {
		fmt.Fprintf(&contextParts, ",%s:%s", v.Label, v.Value)
	}
	// Upstream context is only appended when the flow is not in dialogue mode.
	if !nodeInput.Global.IsDialogue {
		for _, v := range contextFields {
			fmt.Fprintf(&contextParts, ",%s:%s", v.Label, v.Value)
		}
	}
	if !g.IsEmpty(nodeInput.Global.Desc) {
		fmt.Fprintf(&contextParts, ",%s:%s", "描述", nodeInput.Global.Desc)
	}

	// Render the candidate branches as "id: name" lines.
	configMap := gconv.Map(nodeInput.Config.Config)
	ids := gconv.Strings(configMap["branch_ids"])
	branchIdNameMap := gconv.Map(configMap["branch_id_name_map"])
	branchIdNameLines := make([]string, 0, len(ids))
	for _, id := range ids {
		branchIdNameLines = append(branchIdNameLines, fmt.Sprintf("%s: %s", id, gconv.String(branchIdNameMap[id])))
	}

	getIsChatModel, err := GetIsChatModel(ctx)
	if err != nil {
		return "", err
	}
	composeResult, err := GetComposeResult(ctx, 2, getIsChatModel.Model.ModelName, "", "", []map[string]any{{"prompt": strings.Join(branchIdNameLines, "\n")}}, []map[string]any{{"prompt": contextParts.String()}}, nodeInput.Global.FileUrl, nodeInput.Global.SessionId, nodeInput.Config.Id, "判断节点")
	if err != nil {
		return "", err
	}
	if g.IsEmpty(composeResult.TaskId) {
		return "", fmt.Errorf("msg is empty")
	}

	// Guard the Rounds[0] access below: an empty response would otherwise
	// panic with an index-out-of-range error instead of failing the node.
	if len(getIsChatModel.Model.ResponseBody) > 0 && len(composeResult.Messages.Rounds) == 0 {
		return "", fmt.Errorf("model response contains no rounds")
	}
	// NOTE(review): when ResponseBody has more than one key, the value kept is
	// whichever key the iteration visits last — confirm a single key is expected.
	content := ""
	for key := range getIsChatModel.Model.ResponseBody {
		content = gconv.String(composeResult.Messages.Rounds[0][key])
	}
	fmt.Printf("JudgeLambda路由目标节点ID=%s\n", content)
	return content, nil
}
// BatchModelLambda splits the node's user inputs into one request per value
// and runs TextNode for every request concurrently, then stores the combined
// results in nodeInput.Config.OutputResult.
//
// input must be a *flowDto.NodeExecutionInput. The first TextNode error
// cancels the remaining requests and is returned once every worker has
// stopped. If ctx is cancelled, ctx.Err() is returned rather than a partial
// result. On success the same nodeInput is returned.
//
// NOTE(review): one goroutine is started per request with no upper bound —
// confirm the number of requests is small enough for the downstream service.
func BatchModelLambda(ctx context.Context, input any) (any, error) {
	nodeInput, ok := input.(*flowDto.NodeExecutionInput)
	if !ok {
		return nil, fmt.Errorf("入参类型错误")
	}
	skillName, from, userFrom := BuildParam(nodeInput)

	// Build one request per non-empty value of every configured input field;
	// slice values are expanded into one request per element.
	reqMap := make([]map[string]any, 0)
	for _, userItem := range userFrom {
		m := gconv.Map(userItem)
		for _, source := range nodeInput.Config.InputSource {
			for _, f := range source.Field {
				val := m[f]
				if g.IsEmpty(val) {
					continue
				}
				if !g.NewVar(val).IsSlice() {
					reqMap = append(reqMap, map[string]any{f: val})
					continue
				}
				for _, item := range gconv.SliceAny(val) {
					reqMap = append(reqMap, map[string]any{f: item})
				}
			}
		}
	}

	// Each worker writes only its own index, so res needs no lock.
	res := make([][]node.NodeFormField, len(reqMap))
	subCtx, cancel := context.WithCancel(ctx)
	defer cancel()

	var (
		wg       sync.WaitGroup
		errOnce  sync.Once
		firstErr error
	)
	for idx, item := range reqMap {
		wg.Add(1)
		go func(idx int, userItem map[string]any) {
			defer wg.Done()
			// Skip work that has not started once the batch is cancelled.
			if subCtx.Err() != nil {
				return
			}
			output, err := TextNode(subCtx, nodeInput, skillName, from, []map[string]any{userItem})
			if err != nil {
				// Keep only the first error and cancel the other workers.
				errOnce.Do(func() {
					firstErr = err
					cancel()
				})
				return
			}
			res[idx] = output
		}(idx, item)
	}

	// Wait for every worker before touching res or firstErr: this is what
	// makes the reads below race-free and guarantees the error is not lost.
	wg.Wait()
	if firstErr != nil {
		return nil, firstErr
	}
	// A cancelled parent context means some requests may have been skipped.
	if err := ctx.Err(); err != nil {
		return nil, err
	}

	// Flatten the results, rewriting the suffix after the last ':' of every
	// Field and Label to the index of the request that produced it.
	var outputRes []node.NodeFormField
	for globalIndex, items := range res {
		suffix := strconv.Itoa(globalIndex)
		for _, item := range items {
			if pos := strings.LastIndex(item.Field, ":"); pos != -1 {
				item.Field = item.Field[:pos+1] + suffix
			}
			if pos := strings.LastIndex(item.Label, ":"); pos != -1 {
				item.Label = item.Label[:pos+1] + suffix
			}
			outputRes = append(outputRes, item)
		}
	}
	nodeInput.Config.OutputResult = outputRes
	return nodeInput, nil
}
// TextModelLambda runs the text node: it derives the call parameters with
// BuildParam, invokes TextNode, and stores the returned fields in
// Config.OutputResult of the input.
//
// input must be a *flowDto.NodeExecutionInput; that same value is returned on
// success. A TextNode error is returned unchanged with a nil result.
func TextModelLambda(ctx context.Context, input any) (any, error) {
	execInput, isExecInput := input.(*flowDto.NodeExecutionInput)
	if !isExecInput {
		return nil, fmt.Errorf("入参类型错误")
	}

	skill, form, userForm := BuildParam(execInput)
	fields, nodeErr := TextNode(ctx, execInput, skill, form, userForm)
	if nodeErr != nil {
		return nil, nodeErr
	}

	execInput.Config.OutputResult = fields
	return execInput, nil
}
// ImageModelLambda runs the image node: it derives the call parameters with
// BuildParam, invokes ImgNode, and stores the returned fields in
// Config.OutputResult of the input.
//
// input must be a *flowDto.NodeExecutionInput; that same value is returned on
// success. An ImgNode error is returned unchanged with a nil result.
func ImageModelLambda(ctx context.Context, input any) (any, error) {
	execInput, isExecInput := input.(*flowDto.NodeExecutionInput)
	if !isExecInput {
		return nil, fmt.Errorf("入参类型错误")
	}

	skill, form, userForm := BuildParam(execInput)
	fields, nodeErr := ImgNode(ctx, execInput, skill, form, userForm)
	if nodeErr != nil {
		return nil, nodeErr
	}

	execInput.Config.OutputResult = fields
	return execInput, nil
}
// AudioModelLambda runs the audio node: it derives the call parameters with
// BuildParam, invokes AudioOptimizeNode, and stores the returned fields in
// Config.OutputResult of the input.
//
// input must be a *flowDto.NodeExecutionInput; that same value is returned on
// success. An AudioOptimizeNode error is returned unchanged with a nil result.
func AudioModelLambda(ctx context.Context, input any) (any, error) {
	execInput, isExecInput := input.(*flowDto.NodeExecutionInput)
	if !isExecInput {
		return nil, fmt.Errorf("入参类型错误")
	}

	skill, form, userForm := BuildParam(execInput)
	fields, nodeErr := AudioOptimizeNode(ctx, execInput, skill, form, userForm)
	if nodeErr != nil {
		return nil, nodeErr
	}

	execInput.Config.OutputResult = fields
	return execInput, nil
}
// VideoModelLambda runs the video node. It calls VideoOptimizeNode, collects
// the value of every returned field whose name contains "content", passes
// those URLs to VideoConcat, and publishes the resulting file URL as a single
// output field.
//
// When Config.IsSaveFile is set the field is "video_oss_url:content:0" and
// carries the file URL as returned; otherwise it is "concat_video_url:content:0"
// and the URL is prefixed with the value from utils.GetFileAddressPrefix.
//
// input must be a *flowDto.NodeExecutionInput; that same value is returned on
// success. Any failing step returns its error with a nil result.
func VideoModelLambda(ctx context.Context, input any) (any, error) {
	execInput, isExecInput := input.(*flowDto.NodeExecutionInput)
	if !isExecInput {
		return nil, fmt.Errorf("入参类型错误")
	}

	skill, form, userForm := BuildParam(execInput)
	segments, err := VideoOptimizeNode(ctx, execInput, skill, form, userForm)
	if err != nil {
		return nil, err
	}

	// Only fields whose name contains "content" are treated as video URLs.
	clipURLs := make([]string, 0)
	for _, segment := range segments {
		if !strings.Contains(segment.Field, "content") {
			continue
		}
		clipURLs = append(clipURLs, gconv.String(segment.Value))
	}
	if g.IsEmpty(clipURLs) {
		return nil, fmt.Errorf("视频合成失败:模型生成视频失败")
	}

	concatResult, err := VideoConcat(ctx, clipURLs)
	if err != nil {
		return nil, err
	}
	callback := &flowDto.VideoCallbackReq{}
	if convErr := gconv.Struct(concatResult, callback); convErr != nil {
		return nil, convErr
	}

	urlPrefix, err := utils.GetFileAddressPrefix(ctx)
	if err != nil {
		return nil, err
	}

	// Default to the prefixed URL; the save-file variant keeps the raw URL.
	fieldName := fmt.Sprintf("concat_video_url:content:%d", 0)
	fileURL := urlPrefix + callback.FileURL
	if execInput.Config.IsSaveFile {
		fieldName = fmt.Sprintf("video_oss_url:content:%d", 0)
		fileURL = callback.FileURL
	}

	execInput.Config.OutputResult = []node.NodeFormField{{
		Field: fieldName,
		Value: fileURL,
		Label: fieldName,
		Type:  "string",
	}}
	return execInput, nil
}
// HttpLambda runs the HTTP(S) node: it invokes HttpNode and stores the
// returned fields in Config.OutputResult of the input.
//
// input must be a *flowDto.NodeExecutionInput; that same value is returned on
// success. An HttpNode error is returned unchanged with a nil result.
func HttpLambda(ctx context.Context, input any) (any, error) {
	execInput, isExecInput := input.(*flowDto.NodeExecutionInput)
	if !isExecInput {
		return nil, fmt.Errorf("入参类型错误")
	}

	fields, nodeErr := HttpNode(ctx, execInput)
	if nodeErr != nil {
		return nil, nodeErr
	}

	execInput.Config.OutputResult = fields
	return execInput, nil
}
// DataConversionLambda runs the data-conversion node: it derives the call
// parameters with BuildParam, invokes DataConversionNode, and stores the
// returned fields in Config.OutputResult of the input.
//
// input must be a *flowDto.NodeExecutionInput; that same value is returned on
// success. A DataConversionNode error is returned unchanged with a nil result.
func DataConversionLambda(ctx context.Context, input any) (any, error) {
	execInput, isExecInput := input.(*flowDto.NodeExecutionInput)
	if !isExecInput {
		return nil, fmt.Errorf("入参类型错误")
	}

	skill, form, userForm := BuildParam(execInput)
	fields, nodeErr := DataConversionNode(ctx, execInput, skill, form, userForm)
	if nodeErr != nil {
		return nil, nodeErr
	}

	execInput.Config.OutputResult = fields
	return execInput, nil
}
// DataMergeLambda is the handler for the data-merge node. In its current form
// it only verifies that input is a *flowDto.NodeExecutionInput and returns it
// unchanged; the wait-and-merge implementation is disabled and kept below as
// commented-out code for reference.
func DataMergeLambda(ctx context.Context, input any) (res any, err error) {
	nodeInput, matched := input.(*flowDto.NodeExecutionInput)
	if !matched {
		return nil, fmt.Errorf("参数合并入参类型错误")
	}

	// ---- Disabled implementation (kept for reference) ----
	//
	// var nodeIds []string
	// for _, item := range nodeInput.Config.InputSource {
	// 	nodeIds = append(nodeIds, item.NodeId)
	// }
	//
	// // Report whether every input node has executed, and whether any failed.
	// checkAllExecuted := func() (allExecuted bool, hasFailed bool, failedNode string) {
	// 	executedCount := 0
	// 	for _, executedNode := range nodeInput.Global.ExecutedNodes {
	// 		// Is this one of the input nodes we depend on, and did it fail?
	// 		for _, targetId := range nodeIds {
	// 			if executedNode.NodeId == targetId {
	// 				if executedNode.Status == node.NodeExecutionStatusFailed.Code() {
	// 					return false, true, targetId
	// 				}
	// 				executedCount++
	// 				break
	// 			}
	// 		}
	// 	}
	// 	return executedCount == len(nodeIds), false, ""
	// }
	//
	// // Initial check.
	// allExecuted, hasFailed, failedNode := checkAllExecuted()
	// if hasFailed {
	// 	return nil, fmt.Errorf("输入节点[%s]执行失败", failedNode)
	// }
	//
	// // If not everything has executed yet, block until all inputs finish,
	// // the context is cancelled, or one of the inputs fails.
	// if !allExecuted {
	// 	// Poll every 500ms; the overall timeout is left to ctx.
	// 	ticker := time.NewTicker(500 * time.Millisecond)
	// 	defer ticker.Stop()
	//
	// 	for {
	// 		select {
	// 		case <-ctx.Done():
	// 			// A cancelled context means another node already failed.
	// 			return nil, ctx.Err()
	// 		case <-ticker.C:
	// 			// Re-check every input node.
	// 			allExecuted, hasFailed, failedNode := checkAllExecuted()
	// 			if hasFailed {
	// 				// One input node failed: stop immediately.
	// 				return nil, fmt.Errorf("输入节点[%s]执行失败", failedNode)
	// 			}
	// 			if allExecuted {
	// 				// Everything finished: leave the loop and continue.
	// 				goto allDone
	// 			}
	//
	// 			// Check again whether the context was cancelled meanwhile.
	// 			select {
	// 			case <-ctx.Done():
	// 				return nil, ctx.Err()
	// 			default:
	// 			}
	// 		}
	// 	}
	// }
	// allDone:
	//
	// // Final check: did every input node succeed?
	// _, hasFailed, failedNode = checkAllExecuted()
	// if hasFailed {
	// 	// One input node failed: stop immediately.
	// 	return nil, fmt.Errorf("输入节点[%s]执行失败", failedNode)
	// }
	//
	// // Index the executed nodes by ID for lookup while merging.
	// // NOTE(review): &en takes the address of the range variable; on Go
	// // versions before 1.22 every entry would alias the same variable —
	// // fix before re-enabling.
	// executedMap := make(map[string]*flowDto.ExecutedNode, len(nodeInput.Global.ExecutedNodes))
	// for _, en := range nodeInput.Global.ExecutedNodes {
	// 	executedMap[en.NodeId] = &en
	// }
	//
	// // Merge the output results of every input source node.
	// for _, inputSource := range nodeInput.Config.InputSource {
	// 	// Bail out early on every iteration if the context was cancelled.
	// 	select {
	// 	case <-ctx.Done():
	// 		return nil, ctx.Err()
	// 	default:
	// 	}
	// 	// Check once more that this node did not fail.
	// 	if en, ok := executedMap[inputSource.NodeId]; ok && en.Status == node.NodeExecutionStatusFailed.Code() {
	// 		return nil, fmt.Errorf("输入节点[%s]执行失败", inputSource.NodeId)
	// 	}
	// 	sourceNodeConfig := nodeInput.Global.ConfigMap[inputSource.NodeId]
	// 	if sourceNodeConfig != nil && len(sourceNodeConfig.OutputResult) > 0 {
	// 		nodeInput.Config.OutputResult = append(nodeInput.Config.OutputResult, sourceNodeConfig.OutputResult...)
	// 	}
	// }

	return nodeInput, nil
}
// MergeLambda combines upstream text and image outputs into one HTML record
// per entry and stores the records in Config.OutputResult of the input.
//
// Texts are read from the consecutive output fields "text_content:0",
// "text_content:1", ... and images from "img_url:0", "img_url:1", ...; each
// series stops at the first missing or empty value. Images are handed out to
// the texts in order, each text taking as many as its Expand value requests.
// With no texts, every image becomes its own entry. When Config.IsSaveFile is
// set, each HTML document is also uploaded and its URL is added as a second
// record.
//
// input must be a *flowDto.NodeExecutionInput; that same value is returned on
// success. Errors from utils.GetFileAddressPrefix and Upload are returned
// unchanged with a nil result.
func MergeLambda(ctx context.Context, input any) (res any, err error) {
	execInput, isExecInput := input.(*flowDto.NodeExecutionInput)
	if !isExecInput {
		return nil, fmt.Errorf("汇总节点入参类型错误")
	}

	// 1. Index the upstream output fields by field name.
	_, upstreamOutputs, _ := GetNodeContextContent(execInput.Global, execInput.Config)
	byField := make(map[string]node.NodeFormField)
	for _, f := range upstreamOutputs {
		byField[f.Field] = f
	}

	// 2. Collect the consecutive text entries.
	var texts []node.NodeFormField
	for n := 0; ; n++ {
		f, found := byField[fmt.Sprintf("text_content:%d", n)]
		if !found || f.Value == "" {
			break
		}
		texts = append(texts, f)
	}

	// 3. Collect the consecutive image URLs.
	var images []string
	for n := 0; ; n++ {
		f, found := byField[fmt.Sprintf("img_url:%d", n)]
		if !found || f.Value == "" {
			break
		}
		images = append(images, gconv.String(f.Value))
	}

	// 4. Hand the images out to the texts in order. A text only gets an entry
	// in the map when at least one image was assigned to it.
	imagesByText := make(map[int][]string)
	if len(texts) > 0 && len(images) > 0 {
		next := 0 // index of the next unassigned image
		for pos, text := range texts {
			if next >= len(images) {
				break
			}
			want := gconv.Int(text.Expand)
			if want <= 0 {
				continue
			}
			var picked []string
			for taken := 0; taken < want && next < len(images); taken++ {
				picked = append(picked, images[next])
				next++
			}
			if len(picked) > 0 {
				imagesByText[pos] = picked
			}
		}
	}

	// entry is one independent HTML document: optional text plus its images.
	type entry struct {
		text     string
		pictures []string
	}

	prefix, err := utils.GetFileAddressPrefix(ctx)
	if err != nil {
		return nil, err
	}

	var entries []entry
	switch {
	case len(texts) > 0:
		// One entry per text, carrying whatever images were assigned to it.
		for pos, text := range texts {
			entries = append(entries, entry{
				text:     prefix + gconv.String(text.Value),
				pictures: imagesByText[pos],
			})
		}
	case len(images) > 0:
		// No texts at all: every image becomes an image-only entry.
		for _, img := range images {
			entries = append(entries, entry{pictures: []string{img}})
		}
	}

	// 5. Render each entry to HTML and, if requested, upload it separately.
	var records []node.NodeFormField
	for pos, e := range entries {
		html := BuildHtml(e.text, e.pictures)
		records = append(records, node.NodeFormField{
			Field: fmt.Sprintf("item_html_%d", pos),
			Value: html,
			Label: fmt.Sprintf("条目%d HTML", pos+1),
			Type:  "textarea",
		})
		if !execInput.Config.IsSaveFile {
			continue
		}

		uploaded, uploadErr := Upload(ctx, &dto.UploadFileBytesReq{
			FileBytes: []byte(html),
			FileName:  fmt.Sprintf("item_%d_%d.html", pos, time.Now().UnixMilli()),
		})
		if uploadErr != nil {
			return nil, uploadErr
		}
		records = append(records, node.NodeFormField{
			Field: fmt.Sprintf("item_html_url_%d", pos),
			Value: uploaded.FileURL,
			Label: fmt.Sprintf("条目%d 地址", pos+1),
			Type:  "text",
		})
	}

	execInput.Config.OutputResult = records
	return execInput, nil
}
// SummaryLambda is the final summary step. It collects the file-URL outputs
// of every executed node, then, inside one database transaction, marks the
// flow execution as successful with those outputs as its output parameters
// and inserts temp-file rows for the URLs found in the execution record.
//
// input must be a *flowDto.NodeExecutionInput. The same value is returned
// together with the transaction error, which is nil on success.
func SummaryLambda(ctx context.Context, input any) (any, error) {
	execInput, ok := input.(*flowDto.NodeExecutionInput)
	if !ok {
		return nil, fmt.Errorf("汇总节点入参类型错误,实际是 %T", input)
	}

	// A field is summarized when its name contains any of these markers.
	urlFieldMarkers := []string{
		"http_file_url", "audio_oss_url", "video_oss_url",
		"item_html_url", "img_oss_url", "text_url",
	}
	isURLField := func(name string) bool {
		for _, marker := range urlFieldMarkers {
			if strings.Contains(name, marker) {
				return true
			}
		}
		return false
	}

	// Aggregate the matching outputs of every executed node. Each value is
	// stored in its own single-entry map keyed by the current millisecond
	// timestamp.
	var summaryResult []map[string]any
	for _, executedNode := range execInput.Global.ExecutedNodes {
		nodeConfig := execInput.Global.ConfigMap[executedNode.NodeId]
		if nodeConfig == nil {
			continue
		}
		for _, field := range nodeConfig.OutputResult {
			if !isURLField(field.Field) {
				continue
			}
			timeKey := strconv.FormatInt(time.Now().UnixMilli(), 10)
			summaryResult = append(summaryResult, map[string]any{timeKey: field.Value})
		}
	}
	g.Log().Info(ctx, fmt.Sprintf("结果汇总完成,汇总数据:%+v", summaryResult))

	err := gfdb.DB(ctx, public.DbNameBlackDeacon).Transaction(ctx, func(ctx context.Context, tx gdb.TX) error {
		flowInfo, err := flowDao.FlowExecutionDao.Get(ctx, &flowDto.GetFlowExecutionReq{
			SessionId: execInput.Global.SessionId,
		})
		if err != nil {
			return err
		}

		executionReq := flowDto.UpdateFlowExecutionReq{
			Id:           execInput.Global.ExecutionId,
			Status:       flow.FlowExecutionStatusSuccess.Code(),
			OutputParams: summaryResult,
		}
		// The update error must abort the transaction; previously it was
		// overwritten or dropped, so a failed update was still committed.
		if _, err = flowDao.FlowExecutionDao.Update(ctx, &executionReq); err != nil {
			return err
		}
		if flowInfo == nil {
			return nil
		}

		url, err := utils.GetFileAddressPrefix(ctx)
		if err != nil {
			return err
		}
		// NOTE(review): flowInfo was loaded before the update above, so
		// OutputParams here is the previously stored value, not summaryResult
		// — confirm this is intended.
		createFileTempReq := make([]*fileDto.CreateFileTempReq, 0, len(flowInfo.OutputParams))
		for _, fileUrl := range flowInfo.OutputParams {
			for _, v := range gconv.Map(fileUrl) {
				createReq := new(fileDto.CreateFileTempReq)
				createReq.BusinessId = flowInfo.SessionId
				createReq.FileUrl = url + gconv.String(v)
				createFileTempReq = append(createFileTempReq, createReq)
			}
		}
		if len(createFileTempReq) == 0 {
			return nil
		}
		_, err = fileDao.FileTempDao.BatchInsert(ctx, createFileTempReq)
		return err
	})
	return execInput, err
}
// CustomLambda is the handler for the custom node. It prints its input to
// standard output and hands it to the next step unchanged.
func CustomLambda(ctx context.Context, input any) (any, error) {
	fmt.Println("CustomLambda:", input)
	output := input
	return output, nil
}