Files
ai-agent/workflow/dao/node/node_execution_dao.go

100 lines
3.4 KiB
Go
Raw Normal View History

package node
import (
"ai-agent/workflow/consts/public"
nodeDto "ai-agent/workflow/model/dto/node"
"ai-agent/workflow/model/entity"
"context"
"gitea.redpowerfuture.com/red-future/common/db/gfdb"
"github.com/gogf/gf/v2/database/gdb"
"github.com/gogf/gf/v2/frame/g"
"github.com/gogf/gf/v2/util/gconv"
)
// nodeExecutionDao groups the data-access methods for node execution records.
// It carries no fields of its own; callers use the package-level
// NodeExecutionDao value.
type nodeExecutionDao struct{}

// NodeExecutionDao is the shared accessor for node execution records.
var NodeExecutionDao = new(nodeExecutionDao)
// Insert creates a node execution record from req and returns the id reported
// by the driver (sql.Result.LastInsertId).
//
// req is converted into an entity.NodeExecution with gconv.Struct before the
// insert. Conversion and database errors are returned unchanged, with id 0.
func (d *nodeExecutionDao) Insert(ctx context.Context, req *nodeDto.CreateNodeExecutionReq) (id int64, err error) {
	// Hold the entity by value so a single address is handed to gconv and gdb
	// (the previous version passed a pointer-to-pointer to both).
	var nodeExecution entity.NodeExecution
	if err = gconv.Struct(req, &nodeExecution); err != nil {
		return 0, err
	}
	r, err := gfdb.DB(ctx, public.DbNameBlackDeacon).Model(ctx, public.TableNameNodeExecution).Insert(&nodeExecution)
	if err != nil {
		return 0, err
	}
	return r.LastInsertId()
}
// Update applies req to the node execution record identified by req.Id and
// returns the number of affected rows.
//
// Behaviour:
//   - Empty fields of req are skipped (OmitEmpty), so this is a partial update.
//   - CompletionTokens, PromptTokens and TotalTokens are accumulated
//     (column = column + value, via gdb.Counter) instead of overwritten.
//   - An empty req.Id updates nothing and returns (0, nil).
func (d *nodeExecutionDao) Update(ctx context.Context, req *nodeDto.UpdateNodeExecutionReq) (rows int64, err error) {
	// OmitEmpty also drops empty WHERE values. Without this guard an empty id
	// would strip the id condition from the UPDATE.
	if g.IsEmpty(req.Id) {
		return 0, nil
	}

	// Model.Data replaces any previously set data rather than merging it, so
	// plain fields and counters must be submitted together in a single map.
	// NOTE(review): gconv.Map keys follow the DTO's struct tags. gdb maps them
	// onto table columns, and an exact column-name key (the counters below)
	// wins over a differently spelled key for the same column. Confirm against
	// the DTO definition.
	data := gconv.Map(req)
	if data == nil {
		data = make(map[string]interface{})
	}
	addCounter := func(column string, value interface{}) {
		if !g.IsEmpty(value) {
			data[column] = &gdb.Counter{
				Field: column,
				Value: gconv.Float64(value),
			}
		}
	}
	addCounter(entity.NodeExecutionCol.CompletionTokens, req.CompletionTokens)
	addCounter(entity.NodeExecutionCol.PromptTokens, req.PromptTokens)
	addCounter(entity.NodeExecutionCol.TotalTokens, req.TotalTokens)

	r, err := gfdb.DB(ctx, public.DbNameBlackDeacon).Model(ctx, public.TableNameNodeExecution).OmitEmpty().
		Data(data).
		Where(entity.NodeExecutionCol.Id, req.Id).
		Update()
	if err != nil {
		return 0, err
	}
	return r.RowsAffected()
}
// Delete removes the node execution record whose id equals req.Id and returns
// the number of rows deleted. On error it returns (0, err).
func (d *nodeExecutionDao) Delete(ctx context.Context, req *nodeDto.DeleteNodeExecutionReq) (rows int64, err error) {
	byID := gfdb.DB(ctx, public.DbNameBlackDeacon).
		Model(ctx, public.TableNameNodeExecution).
		Where(entity.NodeExecutionCol.Id, req.Id)

	result, delErr := byID.Delete()
	if delErr != nil {
		return 0, delErr
	}
	return result.RowsAffected()
}
// Get returns the node execution record whose id equals req.Id.
//
// It returns (nil, nil) when req.Id is empty or no row matches. fields,
// if given, is passed to Fields to restrict the selected columns. The query is
// built with NoTenantId(ctx).
// NOTE(review): presumably this bypasses tenant scoping in gfdb; confirm.
func (d *nodeExecutionDao) Get(ctx context.Context, req *nodeDto.GetNodeExecutionReq, fields ...string) (res *entity.NodeExecution, err error) {
	// OmitEmpty also drops empty WHERE values. With an empty id the query would
	// lose its only condition and One() would return an arbitrary row.
	if g.IsEmpty(req.Id) {
		return nil, nil
	}
	r, err := gfdb.DB(ctx, public.DbNameBlackDeacon).Model(ctx, public.TableNameNodeExecution).NoTenantId(ctx).OmitEmpty().
		Where(entity.NodeExecutionCol.Id, req.Id).
		Fields(fields).One()
	if err != nil {
		return nil, err
	}
	if r.IsEmpty() {
		return nil, nil
	}
	// Do not hand back a partially populated entity on conversion failure.
	if err = r.Struct(&res); err != nil {
		return nil, err
	}
	return res, nil
}
// ListByFlowExecutionId returns the node execution records belonging to the
// flow execution req.FlowExecutionId, ordered by CreatedAt ascending, together
// with the total reported by AllAndCount(false).
//
// Parameters and behaviour:
//   - If req.Page is set, the query is paged with its PageNum/PageSize.
//   - fields, if given, is passed to Fields to restrict the selected columns.
//   - An empty req.FlowExecutionId yields (nil, 0, nil).
func (d *nodeExecutionDao) ListByFlowExecutionId(ctx context.Context, req *nodeDto.ListNodeExecutionByFlowReq, fields ...string) (res []*entity.NodeExecution, total int, err error) {
	// OmitEmpty also drops empty WHERE values. With an empty flow execution id
	// the query would have no condition and list every node execution.
	// NOTE(review): NoTenantId presumably disables tenant scoping as well; confirm.
	if g.IsEmpty(req.FlowExecutionId) {
		return nil, 0, nil
	}
	model := gfdb.DB(ctx, public.DbNameBlackDeacon).Model(ctx, public.TableNameNodeExecution).NoTenantId(ctx).Fields(fields).OmitEmpty()
	// These calls mutate model in place, matching the pattern used elsewhere
	// in this file.
	model.Where(entity.NodeExecutionCol.FlowExecutionId, req.FlowExecutionId)
	model.OrderAsc(entity.NodeExecutionCol.CreatedAt)
	if req.Page != nil {
		model.Page(int(req.Page.PageNum), int(req.Page.PageSize))
	}
	r, total, err := model.AllAndCount(false)
	if err != nil {
		return nil, 0, err
	}
	if err = r.Structs(&res); err != nil {
		return nil, 0, err
	}
	return res, total, nil
}