diff --git a/controller/model_controller.go b/controller/model_controller.go index af6653f..acd9191 100644 --- a/controller/model_controller.go +++ b/controller/model_controller.go @@ -52,7 +52,7 @@ func (c *model) ListModel(ctx context.Context, req *dto.ListModelReq) (res *dto. pageSize = req.PageSize } } - list, total, err := service.Model.List(ctx, pageNum, pageSize, req.ModelName, req.ModelType) + list, total, err := service.Model.List(ctx, pageNum, pageSize, req) if err != nil { return nil, err } diff --git a/dao/model_dao.go b/dao/model_dao.go index 93f74ed..f97f263 100644 --- a/dao/model_dao.go +++ b/dao/model_dao.go @@ -90,7 +90,7 @@ func (d *modelDao) Get(ctx context.Context, id int64) (m *entity.AsynchModel, er return } -func (d *modelDao) List(ctx context.Context, pageNum, pageSize int, modelNameLike string, modelType int) (list []*entity.AsynchModel, total int64, err error) { +func (d *modelDao) List(ctx context.Context, pageNum, pageSize int, modelNameLike string, modelType int, isPrivate int) (list []*entity.AsynchModel, total int64, err error) { model := gfdb.DB(ctx).Model(ctx, public.TableNameModel). OrderDesc(entity.AsynchModelCol.CreatedAt) if modelNameLike != "" { @@ -99,6 +99,9 @@ func (d *modelDao) List(ctx context.Context, pageNum, pageSize int, modelNameLik if modelType != 0 { model = model.Where(entity.AsynchModelCol.ModelsType, modelType) } + if isPrivate != 0 { + model = model.Where(entity.AsynchModelCol.IsPrivate, isPrivate) + } if pageNum > 0 && pageSize > 0 { model = model.Page(pageNum, pageSize) } @@ -148,7 +151,7 @@ func (d *modelDao) ListByCreatorAndPlatform(ctx context.Context, creator string, return } -func (d *modelDao) GetByCreatorAndPlatform(ctx context.Context, creator string, modelNameLike string, modelType int) (list []*entity.AsynchModel, err error) { +func (d *modelDao) GetByCreatorAndPlatform(ctx context.Context, creator string, modelNameLike string, modelType int, isPrivate int) (list []*entity.AsynchModel, err error) { whereSQL := "deleted_at IS NULL AND (tenant_id = 1 OR creator = ?)" args := []any{creator} @@ -160,6 +163,10 @@ func (d *modelDao) GetByCreatorAndPlatform(ctx context.Context, creator string, whereSQL += " AND models_type = ?" args = append(args, modelType) } + if isPrivate != 0 { + whereSQL += " AND is_private = ?" + args = append(args, isPrivate) + } querySQL := fmt.Sprintf("SELECT * FROM %s WHERE %s ORDER BY created_at DESC", public.TableNameModel, whereSQL) diff --git a/model/dto/model_dto.go b/model/dto/model_dto.go index 9a1cb95..c8ed514 100644 --- a/model/dto/model_dto.go +++ b/model/dto/model_dto.go @@ -82,6 +82,7 @@ type ListModelReq struct { PageSize int `p:"pageSize" json:"pageSize" dc:"每页条数(默认10)"` ModelName string `p:"modelName" json:"modelName" dc:"模型名称(模糊查询,可选)"` ModelType int `p:"modelType" json:"modelType" dc:"模型类型"` + IsPrivate int `p:"isPrivate" json:"isPrivate" dc:"是否私有化 0-私有 1-公共"` } type ListModelRes struct { diff --git a/service/model_service.go b/service/model_service.go index ab29b4a..42f6efc 100644 --- a/service/model_service.go +++ b/service/model_service.go @@ -89,7 +89,7 @@ func (s *modelService) Get(ctx context.Context, id int64) (*entity.AsynchModel, return model, nil } -func (s *modelService) List(ctx context.Context, pageNum, pageSize int, modelNameLike string, modelType int) (list []*entity.AsynchModel, total int64, err error) { +func (s *modelService) List(ctx context.Context, pageNum, pageSize int, req *dto.ListModelReq) (list []*entity.AsynchModel, total int64, err error) { isSuperAdmin, err := IsSuperAdmin(ctx) if err != nil { return nil, 0, err @@ -103,9 +103,9 @@ func (s *modelService) List(ctx context.Context, pageNum, pageSize int, modelNam var count int64 if isSuperAdmin { - models, count, err = dao.Model.List(ctx, pageNum, pageSize, modelNameLike, modelType) + models, count, err = dao.Model.List(ctx, pageNum, pageSize, req.ModelName, req.ModelType, req.IsPrivate) } else { - models, count, err = s.getModelsWithDedup(ctx, user.UserName, pageNum, pageSize, modelNameLike, modelType) + models, count, err = s.getModelsWithDedup(ctx, user.UserName, pageNum, pageSize, req.ModelName, req.ModelType, req.IsPrivate) } if err != nil { return nil, 0, err @@ -122,9 +122,9 @@ func (s *modelService) List(ctx context.Context, pageNum, pageSize int, modelNam } // getModelsWithDedup 获取普通用户的模型列表并去重 -func (s *modelService) getModelsWithDedup(ctx context.Context, creator string, pageNum, pageSize int, modelNameLike string, modelType int) (list []*entity.AsynchModel, total int64, err error) { +func (s *modelService) getModelsWithDedup(ctx context.Context, creator string, pageNum, pageSize int, modelNameLike string, modelType int, isPrivate int) (list []*entity.AsynchModel, total int64, err error) { // 1. 查全量数据(不分页,便于去重) - allModels, err := dao.Model.GetByCreatorAndPlatform(ctx, creator, modelNameLike, modelType) + allModels, err := dao.Model.GetByCreatorAndPlatform(ctx, creator, modelNameLike, modelType, isPrivate) if err != nil { return nil, 0, err }