Files
model-gateway/dao/task_dao.go

112 lines
3.2 KiB
Go

package dao
import (
"context"
"model-gateway/consts/public"
"model-gateway/model/entity"
"gitea.com/red-future/common/db/gfdb"
"github.com/gogf/gf/v2/database/gdb"
"github.com/gogf/gf/v2/os/gtime"
"github.com/gogf/gf/v2/util/gconv"
)
// taskDao groups the database accessors for the task table. The struct has
// no fields; every method receives its context explicitly.
type taskDao struct{}

// Task is the shared, package-level taskDao instance.
var Task = new(taskDao)
// Insert writes req as a new row in the task table and returns the last
// insert ID reported by the database driver.
//
// req is first copied into a fresh entity.AsynchTask with gconv.Struct, and
// that copy is what gets handed to the ORM. If the copy or the insert fails,
// the error is returned with id == 0.
func (d *taskDao) Insert(ctx context.Context, req *entity.AsynchTask) (id int64, err error) {
	row := new(entity.AsynchTask)
	if err = gconv.Struct(req, &row); err != nil {
		return 0, err
	}

	// Name the database explicitly, as every other method of this DAO does,
	// so inserts target the same database that the read/update paths query
	// instead of depending on whatever the default connection happens to be.
	res, err := gfdb.DB(ctx, public.DbNameModelGateway).
		Model(ctx, public.TableNameTask).
		Insert(row)
	if err != nil {
		return 0, err
	}
	return res.LastInsertId()
}
// Get looks up a single task whose task ID column equals req.TaskID.
//
// fields is forwarded to the ORM's Fields call to restrict the selected
// columns. NOTE(review): with no fields given, all columns are expected to be
// selected — confirm against the gdb version in use.
//
// A nil req yields (nil, nil) instead of panicking. When no row matches, m is
// expected to stay nil with a nil error (gdb's Record.Struct treatment of an
// empty record) — confirm against the gdb version in use.
func (d *taskDao) Get(ctx context.Context, req *entity.AsynchTask, fields ...string) (m *entity.AsynchTask, err error) {
	// Without a request there is no task ID to match on; dereferencing it
	// below would panic.
	if req == nil {
		return nil, nil
	}

	// OmitEmpty is deliberately NOT applied here: it would silently drop the
	// task ID condition whenever req.TaskID is empty, turning the lookup into
	// an unfiltered "first row of the table" query that returns an unrelated
	// task. Keeping the condition makes an empty ID simply match nothing.
	rec, err := gfdb.DB(ctx, public.DbNameModelGateway).
		Model(ctx, public.TableNameTask).
		Where(entity.AsynchTaskCol.TaskID, req.TaskID).
		Fields(fields).
		One()
	if err != nil {
		return nil, err
	}
	err = rec.Struct(&m)
	return m, err
}
// ListByTaskIDs fetches all tasks whose task ID is contained in taskIDs.
//
// An empty taskIDs slice returns (nil, nil) without querying the database.
// Per the original author's note, gfdb's tenant hook applies to this query,
// so only rows belonging to the current tenant are returned.
func (d *taskDao) ListByTaskIDs(ctx context.Context, taskIDs []string) (m []*entity.AsynchTask, err error) {
	// Nothing to look up: skip the round trip entirely.
	if len(taskIDs) == 0 {
		return nil, nil
	}

	query := gfdb.DB(ctx, public.DbNameModelGateway).
		Model(ctx, public.TableNameTask).
		OmitEmpty().
		WhereIn(entity.AsynchTaskCol.TaskID, taskIDs)

	rows, err := query.All()
	if err != nil {
		return nil, err
	}

	// Decode the raw result set into the named return slice.
	err = rows.Structs(&m)
	return m, err
}
// MarkDownloadedByID moves the task with the given id from the successful
// state (2) to the downloaded state (4) and stores expireAt as its expiry
// time.
//
// The update is conditional on the row currently being in state 2, so a task
// in any other state is left untouched. The updater column is overwritten
// with an empty string as part of the same update. Only the error from the
// update is returned; the affected-row result is discarded.
func (d *taskDao) MarkDownloadedByID(ctx context.Context, id int64, expireAt *gtime.Time) error {
	const (
		stateSucceeded  = 2 // only tasks currently in this state are updated
		stateDownloaded = 4 // state written by this call
	)

	model := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameTask)
	_, err := model.
		Where(entity.AsynchTaskCol.Id, id).
		Where(entity.AsynchTaskCol.State, stateSucceeded).
		Data(gdb.Map{
			entity.AsynchTaskCol.State:    stateDownloaded,
			entity.AsynchTaskCol.ExpireAt: expireAt,
			entity.AsynchTaskCol.Updater:  "",
		}).
		Update()
	return err
}
// List queries tasks that are not soft-deleted (deleted_at IS NULL), ordered
// by creation time descending, and returns the rows together with the count
// reported by AllAndCount.
//
// Optional filters, each adding one condition to the query:
//   - modelNameLike: "contains" match on the model name; skipped when empty.
//   - taskIDLike: "contains" match on the task ID; skipped when empty.
//   - state: exact match on the state column; skipped when nil.
//
// Paging is applied only when both pageNum and pageSize are positive;
// otherwise no page clause is added. Per the original author's note, gfdb's
// tenant hook applies to this query.
func (d *taskDao) List(ctx context.Context, pageNum, pageSize int, modelNameLike, taskIDLike string, state *int) (list []*entity.AsynchTask, total int64, err error) {
	query := gfdb.DB(ctx, public.DbNameModelGateway).
		Model(ctx, public.TableNameTask).
		Where("deleted_at IS NULL")

	if modelNameLike != "" {
		pattern := "%" + modelNameLike + "%"
		query = query.WhereLike(entity.AsynchTaskCol.ModelName, pattern)
	}
	if taskIDLike != "" {
		pattern := "%" + taskIDLike + "%"
		query = query.WhereLike(entity.AsynchTaskCol.TaskID, pattern)
	}
	if state != nil {
		query = query.Where(entity.AsynchTaskCol.State, *state)
	}

	query = query.OrderDesc(entity.AsynchTaskCol.CreatedAt)
	if pageNum > 0 && pageSize > 0 {
		query = query.Page(pageNum, pageSize)
	}

	rows, count, err := query.AllAndCount(false)
	if err != nil {
		return nil, 0, err
	}

	total = gconv.Int64(count)
	err = rows.Structs(&list)
	return list, total, err
}
// GetPendingAsyncTasks loads tasks that are in state 1 (described by the
// original author as "in progress") and not soft-deleted, passing limit to
// the query's Limit clause.
//
// The query specifies no ordering. The decoded slice is returned alongside
// any error from Scan.
func (d *taskDao) GetPendingAsyncTasks(ctx context.Context, limit int) ([]*entity.AsynchTask, error) {
	query := gfdb.DB(ctx, public.DbNameModelGateway).
		Model(ctx, public.TableNameTask).
		Where("state", 1).
		Where("deleted_at IS NULL").
		Limit(limit)

	var pending []*entity.AsynchTask
	if err := query.Scan(&pending); err != nil {
		return pending, err
	}
	return pending, nil
}