package dao

import (
	"context"
	"fmt"

	"model-gateway/consts/public"
	"model-gateway/model/entity"

	"gitea.redpowerfuture.com/red-future/common/db/gfdb"
	"github.com/gogf/gf/v2/database/gdb"
	"github.com/gogf/gf/v2/util/gconv"
)

// ModelGatewayTask is the shared DAO for the model-gateway task table.
var ModelGatewayTask = &modelGatewayTaskDao{}

// modelGatewayTaskDao groups data-access methods for public.TableNameTask in
// the public.DbNameModelGateway database. It is stateless.
type modelGatewayTaskDao struct{}

// Insert inserts a new task row and returns the auto-generated primary key.
//
// req is copied into a fresh entity with gconv.Struct before insertion so the
// ORM never works on the caller's value. A nil req is rejected rather than
// being inserted as an all-zero row.
func (d *modelGatewayTaskDao) Insert(ctx context.Context, req *entity.ModelGatewayTask) (id int64, err error) {
	if req == nil {
		return 0, fmt.Errorf("insert task: nil request")
	}
	m := new(entity.ModelGatewayTask)
	if err = gconv.Struct(req, m); err != nil {
		return 0, err
	}
	r, err := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameTask).Insert(m)
	if err != nil {
		return 0, err
	}
	return r.LastInsertId()
}

// Update updates the task identified by req.Id and returns the number of
// affected rows.
//
// Only non-empty fields of req are written (OmitEmpty), so this method cannot
// reset a column to its zero value. req.Id is mandatory: OmitEmpty would
// otherwise silently drop the `WHERE id = ?` condition and widen the update
// to other rows.
func (d *modelGatewayTaskDao) Update(ctx context.Context, req *entity.ModelGatewayTask) (rows int64, err error) {
	if req == nil || req.Id == 0 {
		return 0, fmt.Errorf("update task: id is required")
	}
	r, err := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameTask).
		OmitEmpty().
		Data(req).
		Where(entity.ModelGatewayTaskCol.Id, req.Id).
		Update()
	if err != nil {
		return 0, err
	}
	return r.RowsAffected()
}

// Get fetches a single task by TaskID and/or Id. Empty keys are ignored
// (OmitEmpty), so callers may supply either one.
//
// It returns (nil, nil) when no row matches, and also when neither key is set.
// Without that guard, OmitEmpty would drop both conditions and an arbitrary
// row would be returned.
func (d *modelGatewayTaskDao) Get(ctx context.Context, req *entity.ModelGatewayTask) (m *entity.ModelGatewayTask, err error) {
	if req == nil || (req.TaskID == "" && req.Id == 0) {
		return nil, nil
	}
	r, err := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameTask).
		OmitEmpty().
		Where(entity.ModelGatewayTaskCol.TaskID, req.TaskID).
		Where(entity.ModelGatewayTaskCol.Id, req.Id).
		One()
	if err != nil {
		return nil, err
	}
	// m is a nil pointer here; an empty record leaves it nil.
	err = r.Struct(&m)
	return m, err
}

// List returns one page of tasks, newest first, plus the total match count.
//
// Non-empty fields of req act as equality filters (Creator, BizName, State,
// TaskID). ModelName, when non-empty, is a substring (LIKE) match. A nil req
// means "no filters". Paging is applied only when both pageNum and pageSize
// are positive; otherwise all matches are returned.
func (d *modelGatewayTaskDao) List(ctx context.Context, pageNum, pageSize int, req *entity.ModelGatewayTask) (list []*entity.ModelGatewayTask, total int64, err error) {
	if req == nil {
		req = new(entity.ModelGatewayTask)
	}
	model := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameTask).
		OmitEmpty().
		Where(entity.ModelGatewayTaskCol.Creator, req.Creator).
		Where(entity.ModelGatewayTaskCol.BizName, req.BizName).
		Where(entity.ModelGatewayTaskCol.State, req.State).
		Where(entity.ModelGatewayTaskCol.TaskID, req.TaskID).
		OrderDesc(entity.ModelGatewayTaskCol.CreatedAt)
	// The fuzzy match must use LIKE, and only when a name was given. The
	// wrapped "%%" pattern is never "empty", so OmitEmpty would not drop it.
	if req.ModelName != "" {
		model = model.Where(entity.ModelGatewayTaskCol.ModelName+" LIKE ?", "%"+req.ModelName+"%")
	}
	if pageNum > 0 && pageSize > 0 {
		model = model.Page(pageNum, pageSize)
	}
	r, totalInt, err := model.AllAndCount(false)
	if err != nil {
		return nil, 0, err
	}
	total = int64(totalInt)
	err = r.Structs(&list)
	return list, total, err
}

// Delete deletes the task with the given req.Id and returns the number of
// affected rows. The original comment states this is a soft delete; that
// relies on the ORM's soft-delete handling for this table, not on this code.
func (d *modelGatewayTaskDao) Delete(ctx context.Context, req *entity.ModelGatewayTask) (rows int64, err error) {
	r, err := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameTask).
		Where(entity.ModelGatewayTaskCol.Id, req.Id).
		Delete()
	if err != nil {
		return 0, err
	}
	return r.RowsAffected()
}

// ListByTaskIDs fetches every task whose TaskID is in taskIDs. An empty
// input returns (nil, nil) without querying the database.
func (d *modelGatewayTaskDao) ListByTaskIDs(ctx context.Context, taskIDs []string) (list []*entity.ModelGatewayTask, err error) {
	if len(taskIDs) == 0 {
		return nil, nil
	}
	r, err := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameTask).
		WhereIn(entity.ModelGatewayTaskCol.TaskID, taskIDs).
		All()
	if err != nil {
		return nil, err
	}
	err = r.Structs(&list)
	return list, err
}

// MarkDownloadedByID marks a task as downloaded by moving it from state 2 to
// state 4. The update is conditional on the current state, so a task in any
// other state is left untouched (and no error is reported).
//
// NOTE(review): 2 and 4 are hard-coded; presumably they correspond to
// public.TaskStatus* constants — confirm and replace the literals.
func (d *modelGatewayTaskDao) MarkDownloadedByID(ctx context.Context, id int64) error {
	_, err := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameTask).
		Where(entity.ModelGatewayTaskCol.Id, id).
		Where(entity.ModelGatewayTaskCol.State, 2).
		Data(map[string]any{entity.ModelGatewayTaskCol.State: 4}).
		Update()
	return err
}

// GetPendingAsyncTasks returns up to limit tasks whose state is 1 (the
// original comment calls these in-progress async tasks).
//
// NOTE(review): state 1 is hard-coded and no ORDER BY is applied, so which
// rows are returned when more than limit match is up to the database.
func (d *modelGatewayTaskDao) GetPendingAsyncTasks(ctx context.Context, limit int) ([]*entity.ModelGatewayTask, error) {
	var tasks []*entity.ModelGatewayTask
	err := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameTask).
		Where(entity.ModelGatewayTaskCol.State, 1).
		Limit(limit).
		Scan(&tasks)
	return tasks, err
}

// ======================== Transactional claiming ========================

// ClaimByID atomically claims a pending task by primary key and returns it
// with its state set to public.TaskStatusRunning.
//
// Inside a single transaction it selects the row with a row lock
// (LockUpdate), requiring the state to be public.TaskStatusPending. It then
// switches the state to public.TaskStatusRunning. If the row does not exist
// or is not pending, an error is returned and nothing is changed.
func (d *modelGatewayTaskDao) ClaimByID(ctx context.Context, id int64) (*entity.ModelGatewayTask, error) {
	var task entity.ModelGatewayTask
	err := gfdb.DB(ctx, public.DbNameModelGateway).Transaction(ctx, func(ctx context.Context, tx gdb.TX) error {
		r, err := tx.Model(public.TableNameTask).
			Where(entity.ModelGatewayTaskCol.Id, id).
			Where(entity.ModelGatewayTaskCol.State, public.TaskStatusPending).
			Limit(1).
			LockUpdate().
			One()
		if err != nil {
			return err
		}
		if r.IsEmpty() {
			return fmt.Errorf("任务已被抢占或不存在: id=%d", id)
		}
		if err := r.Struct(&task); err != nil {
			return err
		}
		_, err = tx.Model(public.TableNameTask).
			Data(&entity.ModelGatewayTask{State: public.TaskStatusRunning}).
			Where(entity.ModelGatewayTaskCol.Id, id).
			OmitEmpty().
			Update()
		if err != nil {
			return err
		}
		// Reflect the claim in the returned value. The struct was read before
		// the UPDATE, so without this it would still report the pending state.
		task.State = public.TaskStatusRunning
		return nil
	})
	if err != nil {
		return nil, err
	}
	return &task, nil
}