Files
model-gateway/dao/model_dao.go

233 lines
6.4 KiB
Go
Raw Normal View History

2026-04-29 15:54:14 +08:00
package dao
import (
"context"
2026-05-12 13:45:08 +08:00
"fmt"
2026-04-29 15:54:14 +08:00
"model-gateway/consts/public"
"model-gateway/model/dto"
"model-gateway/model/entity"
2026-04-29 15:54:14 +08:00
"gitea.com/red-future/common/db/gfdb"
"gitea.com/red-future/common/utils"
2026-05-12 13:45:08 +08:00
"github.com/gogf/gf/v2/frame/g"
2026-04-29 15:54:14 +08:00
"github.com/gogf/gf/v2/util/gconv"
)
// modelDao groups the data-access methods for the model table.
// It is stateless; all methods obtain their DB handle from the context.
type modelDao struct{}

// Model is the package-level instance through which callers use modelDao.
var Model = new(modelDao)
func (d *modelDao) Insert(ctx context.Context, req *dto.CreateModelReq) (id int64, err error) {
asyncModel := new(entity.AsynchModel)
err = gconv.Struct(req, &asyncModel)
if err != nil {
return
}
r, err := gfdb.DB(ctx).Model(ctx, public.TableNameModel).Data(asyncModel).Insert()
2026-04-29 15:54:14 +08:00
if err != nil {
return 0, err
}
return r.LastInsertId()
}
2026-05-12 13:45:08 +08:00
func (d *modelDao) Update(ctx context.Context, m *dto.UpdateModelReq) (rows int64, err error) {
2026-04-29 15:54:14 +08:00
// 触发 gfdb 的 updateHook 自动填充 updater需要显式带 updater 字段
r, err := gfdb.DB(ctx).Model(ctx, public.TableNameModel).
2026-05-12 13:45:08 +08:00
OmitEmpty().
Where(entity.AsynchModelCol.Id, m.ID).
Data(m).
Update()
if err != nil {
return 0, err
}
return r.RowsAffected()
}
func (d *modelDao) DeleteByID(ctx context.Context, id string) (rows int64, err error) {
2026-04-29 15:54:14 +08:00
r, err := gfdb.DB(ctx).Model(ctx, public.TableNameModel).
Where(entity.AsynchModelCol.Id, id).
Delete()
if err != nil {
return 0, err
}
return r.RowsAffected()
}
// GetByModelName fetches a single model whose model_name equals modelName.
// It returns (nil, nil) when no row matches.
func (d *modelDao) GetByModelName(ctx context.Context, modelName string) (m *entity.AsynchModel, err error) {
	query := gfdb.DB(ctx).Model(ctx, public.TableNameModel).
		Where(entity.AsynchModelCol.ModelName, modelName)

	record, err := query.One()
	switch {
	case err != nil:
		return nil, err
	case record.IsEmpty():
		return nil, nil
	}

	err = record.Struct(&m)
	return m, err
}
2026-05-12 13:45:08 +08:00
func (d *modelDao) Get(ctx context.Context, id int64) (m *entity.AsynchModel, err error) {
2026-04-29 15:54:14 +08:00
r, err := gfdb.DB(ctx).Model(ctx, public.TableNameModel).
NoTenantId(ctx).
2026-04-29 15:54:14 +08:00
Where(entity.AsynchModelCol.Id, id).
One()
if err != nil {
return nil, err
}
if r.IsEmpty() {
return nil, nil
}
err = r.Struct(&m)
return
}
// Count returns how many model rows match req.Creator and req.ID.
//
// The query uses OmitEmpty. Assuming gfdb follows GoFrame semantics, an
// empty Creator or ID is left out of the WHERE clause rather than matched
// literally, so each filter is optional.
func (d *modelDao) Count(ctx context.Context, req *dto.GetModelReq) (count int, err error) {
	return gfdb.DB(ctx).
		Model(ctx, public.TableNameModel).
		OmitEmpty().
		Where(entity.AsynchModelCol.Creator, req.Creator).
		Where(entity.AsynchModelCol.Id, req.ID).
		Count()
}
func (d *modelDao) List(ctx context.Context, pageNum, pageSize int, modelNameLike string, modelType int, isPrivate int) (list []*entity.AsynchModel, total int64, err error) {
2026-05-12 13:45:08 +08:00
model := gfdb.DB(ctx).Model(ctx, public.TableNameModel).
OrderDesc(entity.AsynchModelCol.CreatedAt)
2026-04-29 15:54:14 +08:00
if modelNameLike != "" {
model = model.WhereLike(entity.AsynchModelCol.ModelName, "%"+modelNameLike+"%")
}
2026-05-12 13:45:08 +08:00
if modelType != 0 {
model = model.Where(entity.AsynchModelCol.ModelType, modelType)
2026-05-12 13:45:08 +08:00
}
if isPrivate != 0 {
model = model.Where(entity.AsynchModelCol.IsPrivate, isPrivate)
}
2026-04-29 15:54:14 +08:00
if pageNum > 0 && pageSize > 0 {
model = model.Page(pageNum, pageSize)
}
r, totalInt, err := model.AllAndCount(false)
if err != nil {
return nil, 0, err
}
total = gconv.Int64(totalInt)
err = r.Structs(&list)
return
}
2026-05-12 13:45:08 +08:00
// ListByCreatorAndPlatform lists the models visible to a regular user. These
// are platform-public models (tenant_id = 1, the super-administrator tenant)
// plus the models the user created (creator = creator). Soft-deleted rows
// (deleted_at set) are excluded.
//
// modelNameLike, when non-empty, is matched as a substring of model_name.
// Results are ordered by created_at descending. Paging is applied only when
// both pageNum and pageSize are positive. total is always the unpaged count.
func (d *modelDao) ListByCreatorAndPlatform(ctx context.Context, creator string, pageNum, pageSize int, modelNameLike string) (list []*entity.AsynchModel, total int64, err error) {
	// Shared WHERE clause for the count and list queries. tenant_id 1 denotes
	// the super administrator.
	whereSQL := "deleted_at IS NULL AND (tenant_id = 1 OR creator = ?)"
	args := []any{creator}
	if modelNameLike != "" {
		whereSQL += " AND model_name LIKE ?"
		args = append(args, "%"+modelNameLike+"%")
	}

	// Alias the aggregate so the result key does not depend on how the
	// database names an unaliased COUNT(1) column.
	countSQL := fmt.Sprintf("SELECT COUNT(1) AS total FROM %s WHERE %s", public.TableNameModel, whereSQL)
	countResult, err := gfdb.DB(ctx).GetAll(ctx, countSQL, args...)
	if err != nil {
		return nil, 0, err
	}
	if len(countResult) > 0 {
		total = gconv.Int64(countResult[0]["total"])
	}
	if total == 0 {
		// Nothing matches, so the list query would return no rows.
		return nil, 0, nil
	}

	querySQL := fmt.Sprintf("SELECT * FROM %s WHERE %s ORDER BY created_at DESC", public.TableNameModel, whereSQL)
	if pageNum > 0 && pageSize > 0 {
		offset := (pageNum - 1) * pageSize
		querySQL += fmt.Sprintf(" LIMIT %d OFFSET %d", pageSize, offset)
	}
	r, err := gfdb.DB(ctx).GetAll(ctx, querySQL, args...)
	if err != nil {
		return nil, 0, err
	}
	if err = r.Structs(&list); err != nil {
		return nil, 0, err
	}
	return list, total, nil
}
func (d *modelDao) GetByCreatorAndPlatform(ctx context.Context, req *dto.ListModelReq) (list []*entity.AsynchModel, total int, err error) {
// 基础 SQL
sql := `
SELECT DISTINCT ON (model_name) *
FROM asynch_models
WHERE deleted_at IS NULL
AND (? = '' OR model_name LIKE ?)
AND (? = 0 OR model_type = ?)
`
args := []any{
req.ModelName, "%" + req.ModelName + "%",
req.ModelType, req.ModelType,
}
if !g.IsEmpty(req.IsPrivate) {
sql += ` AND is_private = ? `
args = append(args, req.IsPrivate)
}
if req.IsOwner != nil && *req.IsOwner == 0 {
sql += ` AND creator = ? AND is_owner = ? `
args = append(args, req.Creator)
args = append(args, req.IsOwner)
} else if req.IsOwner != nil && *req.IsOwner == 1 {
if req.Enabled != nil && *req.Enabled == 1 {
sql += ` AND ((creator = ? AND is_owner = ? AND enabled=1) OR (is_owner = 0 AND enabled=1)) `
} else if req.Enabled != nil && *req.Enabled == 0 {
sql += ` AND ((creator = ? AND is_owner = ? AND enabled=0) OR (is_owner = 0 AND enabled=1)) `
} else {
sql += ` AND ((creator = ? AND is_owner = ?) OR (is_owner = 0 AND enabled=1)) `
}
args = append(args, req.Creator)
args = append(args, req.IsOwner)
}
// 最后拼接排序
sql += ` ORDER BY model_name, is_owner DESC, created_at DESC`
r, err := gfdb.DB(ctx).GetAll(ctx, sql, args...)
if err != nil {
return nil, 0, err
}
2026-05-12 13:45:08 +08:00
err = r.Structs(&list)
2026-05-12 13:45:08 +08:00
if err != nil {
return nil, 0, err
2026-05-12 13:45:08 +08:00
}
total = len(list)
2026-05-12 13:45:08 +08:00
return
}
func (d *modelDao) GetByIsChatModel(ctx context.Context) (m *entity.AsynchModel, err error) {
userInfo, err := utils.GetUserInfo(ctx)
if err != nil {
return nil, err
}
2026-05-12 13:45:08 +08:00
r, err := gfdb.DB(ctx).Model(ctx, public.TableNameModel).
Where(entity.AsynchModelCol.IsChatModel, 1).
Where(entity.AsynchModelCol.Creator, userInfo.UserName).
2026-05-12 13:45:08 +08:00
One()
if err != nil {
return nil, err
}
if r.IsEmpty() {
return nil, nil
}
err = r.Struct(&m)
return
}
2026-04-29 15:54:14 +08:00
// ListAll is used for grouped display. It returns every model, newest first
// (created_at descending). No type filtering happens here; the service layer
// splits the result by type.
func (d *modelDao) ListAll(ctx context.Context) (list []*entity.AsynchModel, err error) {
	query := gfdb.DB(ctx).Model(ctx, public.TableNameModel).
		OrderDesc(entity.AsynchModelCol.CreatedAt)

	rows, err := query.All()
	if err != nil {
		return nil, err
	}
	err = rows.Structs(&list)
	return list, err
}