refactor(model): 重构模型实体和数据访问层
This commit is contained in:
190
dao/model_dao.go
190
dao/model_dao.go
@@ -2,14 +2,11 @@ package dao
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"model-gateway/consts/public"
|
||||
"model-gateway/model/dto"
|
||||
"model-gateway/model/entity"
|
||||
|
||||
"gitea.com/red-future/common/db/gfdb"
|
||||
"gitea.com/red-future/common/utils"
|
||||
"github.com/gogf/gf/v2/frame/g"
|
||||
"github.com/gogf/gf/v2/util/gconv"
|
||||
)
|
||||
@@ -18,157 +15,80 @@ var Model = &modelDao{}
|
||||
|
||||
type modelDao struct{}
|
||||
|
||||
func (d *modelDao) Insert(ctx context.Context, req *dto.CreateModelReq) (id int64, err error) {
|
||||
asyncModel := new(entity.AsynchModel)
|
||||
err = gconv.Struct(req, &asyncModel)
|
||||
// Insert 插入
|
||||
func (d *modelDao) Insert(ctx context.Context, req *entity.AsynchModel) (id int64, err error) {
|
||||
m := new(entity.AsynchModel)
|
||||
err = gconv.Struct(req, &m)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
r, err := gfdb.DB(ctx).Model(ctx, public.TableNameModel).Data(asyncModel).Insert()
|
||||
r, err := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameModel).
|
||||
Insert(m)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
return
|
||||
}
|
||||
return r.LastInsertId()
|
||||
}
|
||||
|
||||
func (d *modelDao) Update(ctx context.Context, m *dto.UpdateModelReq) (rows int64, err error) {
|
||||
// 触发 gfdb 的 updateHook 自动填充 updater,需要显式带 updater 字段
|
||||
r, err := gfdb.DB(ctx).Model(ctx, public.TableNameModel).
|
||||
// Update 更新
|
||||
func (d *modelDao) Update(ctx context.Context, req *entity.AsynchModel) (rows int64, err error) {
|
||||
r, err := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameModel).
|
||||
OmitEmpty().
|
||||
Where(entity.AsynchModelCol.Id, m.ID).
|
||||
Data(m).
|
||||
Data(&req).
|
||||
Where(entity.AsynchModelCol.Id, req.Id).
|
||||
Update()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
return
|
||||
}
|
||||
return r.RowsAffected()
|
||||
}
|
||||
|
||||
func (d *modelDao) DeleteByID(ctx context.Context, id string) (rows int64, err error) {
|
||||
r, err := gfdb.DB(ctx).Model(ctx, public.TableNameModel).
|
||||
Where(entity.AsynchModelCol.Id, id).
|
||||
// Delete 删除
|
||||
func (d *modelDao) Delete(ctx context.Context, req *entity.AsynchModel) (rows int64, err error) {
|
||||
r, err := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameModel).
|
||||
OmitEmpty().
|
||||
Where(entity.AsynchModelCol.Id, req.Id).
|
||||
Delete()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
return
|
||||
}
|
||||
return r.RowsAffected()
|
||||
}
|
||||
|
||||
func (d *modelDao) GetByModelName(ctx context.Context, modelName string) (m *entity.AsynchModel, err error) {
|
||||
r, err := gfdb.DB(ctx).Model(ctx, public.TableNameModel).
|
||||
Where(entity.AsynchModelCol.ModelName, modelName).
|
||||
One()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if r.IsEmpty() {
|
||||
return nil, nil
|
||||
}
|
||||
err = r.Struct(&m)
|
||||
return
|
||||
}
|
||||
|
||||
func (d *modelDao) Get(ctx context.Context, id int64) (m *entity.AsynchModel, err error) {
|
||||
r, err := gfdb.DB(ctx).Model(ctx, public.TableNameModel).
|
||||
NoTenantId(ctx).
|
||||
Where(entity.AsynchModelCol.Id, id).
|
||||
One()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if r.IsEmpty() {
|
||||
return nil, nil
|
||||
}
|
||||
err = r.Struct(&m)
|
||||
return
|
||||
}
|
||||
|
||||
func (d *modelDao) Count(ctx context.Context, req *dto.GetModelReq) (count int, err error) {
|
||||
count, err = gfdb.DB(ctx).Model(ctx, public.TableNameModel).OmitEmpty().
|
||||
// Get 按ID获取(带租户隔离,只查当前租户)
|
||||
func (d *modelDao) Get(ctx context.Context, req *entity.AsynchModel, fields ...string) (m *entity.AsynchModel, err error) {
|
||||
r, err := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameModel).
|
||||
OmitEmpty().
|
||||
Where(entity.AsynchModelCol.Id, req.Id).
|
||||
Where(entity.AsynchModelCol.Creator, req.Creator).
|
||||
Where(entity.AsynchModelCol.Id, req.ID).Count()
|
||||
return
|
||||
}
|
||||
|
||||
func (d *modelDao) List(ctx context.Context, pageNum, pageSize int, modelNameLike string, modelType int, isPrivate int) (list []*entity.AsynchModel, total int64, err error) {
|
||||
model := gfdb.DB(ctx).Model(ctx, public.TableNameModel).
|
||||
OrderDesc(entity.AsynchModelCol.CreatedAt)
|
||||
if modelNameLike != "" {
|
||||
model = model.WhereLike(entity.AsynchModelCol.ModelName, "%"+modelNameLike+"%")
|
||||
}
|
||||
if modelType != 0 {
|
||||
model = model.Where(entity.AsynchModelCol.ModelType, modelType)
|
||||
}
|
||||
if isPrivate != 0 {
|
||||
model = model.Where(entity.AsynchModelCol.IsPrivate, isPrivate)
|
||||
}
|
||||
if pageNum > 0 && pageSize > 0 {
|
||||
model = model.Page(pageNum, pageSize)
|
||||
}
|
||||
r, totalInt, err := model.AllAndCount(false)
|
||||
Where(entity.AsynchModelCol.IsChatModel, req.IsChatModel).
|
||||
Where(entity.AsynchModelCol.ModelName, req.ModelName).
|
||||
Fields(fields).One()
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
total = gconv.Int64(totalInt)
|
||||
err = r.Structs(&list)
|
||||
return
|
||||
}
|
||||
|
||||
func (d *modelDao) GetByIsChatModel(ctx context.Context) (m *entity.AsynchModel, err error) {
|
||||
userInfo, err := utils.GetUserInfo(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
r, err := gfdb.DB(ctx).Model(ctx, public.TableNameModel).
|
||||
Where(entity.AsynchModelCol.IsChatModel, 1).
|
||||
Where(entity.AsynchModelCol.Creator, userInfo.UserName).
|
||||
One()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if r.IsEmpty() {
|
||||
return nil, nil
|
||||
return
|
||||
}
|
||||
err = r.Struct(&m)
|
||||
return
|
||||
}
|
||||
|
||||
// ListByCreatorAndPlatform 普通用户:平台公共(tenant_id=0) + 自己创建的(creator=xxx)
|
||||
func (d *modelDao) ListByCreatorAndPlatform(ctx context.Context, creator string, pageNum, pageSize int, modelNameLike string) (list []*entity.AsynchModel, total int64, err error) {
|
||||
// 构建 Where 条件
|
||||
whereSQL := "deleted_at IS NULL AND (tenant_id = 1 OR creator = ?)" //1 代表超级管理员
|
||||
args := []any{creator}
|
||||
|
||||
if modelNameLike != "" {
|
||||
whereSQL += " AND model_name LIKE ?"
|
||||
args = append(args, "%"+modelNameLike+"%")
|
||||
}
|
||||
|
||||
// 查总数
|
||||
countSQL := fmt.Sprintf("SELECT COUNT(1) FROM %s WHERE %s", public.TableNameModel, whereSQL)
|
||||
countResult, err := gfdb.DB(ctx).GetAll(ctx, countSQL, args...)
|
||||
// GetByAcrossTenant 按ID获取(跨租户,查所有租户)
|
||||
func (d *modelDao) GetByAcrossTenant(ctx context.Context, req *entity.AsynchModel, fields ...string) (m *entity.AsynchModel, err error) {
|
||||
r, err := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameModel).
|
||||
NoTenantId(ctx).
|
||||
OmitEmpty().
|
||||
Where(entity.AsynchModelCol.Id, req.Id).
|
||||
Where(entity.AsynchModelCol.Creator, req.Creator).
|
||||
Where(entity.AsynchModelCol.IsChatModel, req.IsChatModel).
|
||||
Where(entity.AsynchModelCol.ModelName, req.ModelName).
|
||||
Fields(fields).One()
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
return
|
||||
}
|
||||
if len(countResult) > 0 {
|
||||
total = gconv.Int64(countResult[0]["count"])
|
||||
}
|
||||
|
||||
// 查列表
|
||||
querySQL := fmt.Sprintf("SELECT * FROM %s WHERE %s ORDER BY created_at DESC", public.TableNameModel, whereSQL)
|
||||
if pageNum > 0 && pageSize > 0 {
|
||||
offset := (pageNum - 1) * pageSize
|
||||
querySQL += fmt.Sprintf(" LIMIT %d OFFSET %d", pageSize, offset)
|
||||
}
|
||||
|
||||
r, err := gfdb.DB(ctx).GetAll(ctx, querySQL, args...)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
err = r.Structs(&list)
|
||||
err = r.Struct(&m)
|
||||
return
|
||||
}
|
||||
|
||||
// GetByCreatorAndPlatform 按创建者、平台获取
|
||||
func (d *modelDao) GetByCreatorAndPlatform(ctx context.Context, req *dto.ListModelReq) (list []*entity.AsynchModel, total int, err error) {
|
||||
// 基础 SQL
|
||||
sql := `
|
||||
@@ -212,7 +132,7 @@ WHERE deleted_at IS NULL
|
||||
// 最后拼接排序
|
||||
sql += ` ORDER BY model_name, is_owner DESC, created_at DESC`
|
||||
|
||||
r, err := gfdb.DB(ctx).GetAll(ctx, sql, args...)
|
||||
r, err := gfdb.DB(ctx, public.DbNameModelGateway).GetAll(ctx, sql, args...)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
@@ -226,14 +146,24 @@ WHERE deleted_at IS NULL
|
||||
return
|
||||
}
|
||||
|
||||
// ListAll 用于分组展示:查询全部模型(不按类型过滤,类型拆分在 service 层处理)
|
||||
func (d *modelDao) ListAll(ctx context.Context) (list []*entity.AsynchModel, err error) {
|
||||
r, err := gfdb.DB(ctx).Model(ctx, public.TableNameModel).
|
||||
OrderDesc(entity.AsynchModelCol.CreatedAt).
|
||||
All()
|
||||
// GetByModelNameForTenant 后台任务使用:按 tenant_id + model_name 查询,不依赖 gfdb Hook/Trace/用户上下文
|
||||
func (d *modelDao) GetByModelNameForTenant(ctx context.Context, tenantId uint64, modelName string) (m *entity.AsynchModel, err error) {
|
||||
r, err := gfdb.DB(ctx, public.DbNameModelGateway).GetAll(ctx,
|
||||
"SELECT * FROM "+public.TableNameModel+" WHERE tenant_id=? AND model_name=? AND deleted_at IS NULL LIMIT 1",
|
||||
tenantId, modelName,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
err = r.Structs(&list)
|
||||
return
|
||||
if r.IsEmpty() {
|
||||
return nil, nil
|
||||
}
|
||||
var list []*entity.AsynchModel
|
||||
if err := r.Structs(&list); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(list) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
return list[0], nil
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user