Files
model-gateway/dao/model_dao.go

207 lines
6.1 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package dao
import (
"context"
"fmt"
"model-gateway/consts/public"
"model-gateway/model/dto"
"model-gateway/model/entity"
"strconv"
"strings"
"gitea.com/red-future/common/db/gfdb"
"github.com/gogf/gf/v2/frame/g"
"github.com/gogf/gf/v2/util/gconv"
)
// modelDao is a stateless receiver type that carries the data-access
// methods declared in this file.
type modelDao struct{}

// Model is the package-level modelDao instance.
var Model = new(modelDao)
// Insert stores a copy of req as a new row in the model table.
//
// req is first copied into a fresh entity with gconv.Struct, and that copy is
// what gets handed to the ORM's Insert call. The returned id is whatever the
// insert result reports from LastInsertId. The first error hit while copying
// or inserting is returned with a zero id.
func (d *modelDao) Insert(ctx context.Context, req *entity.AsynchModel) (id int64, err error) {
	row := &entity.AsynchModel{}
	if err = gconv.Struct(req, &row); err != nil {
		return 0, err
	}

	table := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameModel)
	result, err := table.Insert(row)
	if err != nil {
		return 0, err
	}
	return result.LastInsertId()
}
// Update writes the non-empty fields of req to the row whose ID equals req.Id
// and returns the number of affected rows reported by the update result.
//
// The statement is built with OmitEmpty, which in GoFrame filters empty values
// out of the WHERE arguments as well as out of the data. An empty req.Id would
// therefore drop the ID predicate and leave the UPDATE without its row filter,
// so a nil req or an empty ID is rejected here before any query is issued.
func (d *modelDao) Update(ctx context.Context, req *entity.AsynchModel) (rows int64, err error) {
	if req == nil || g.IsEmpty(req.Id) {
		return 0, fmt.Errorf("model dao: update requires a non-empty id")
	}

	// NOTE(review): &req is a pointer to a pointer; it is kept exactly as the
	// original passed it so the ORM receives the same argument.
	r, err := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameModel).
		OmitEmpty().
		Data(&req).
		Where(entity.AsynchModelCol.Id, req.Id).
		Update()
	if err != nil {
		return 0, err
	}
	return r.RowsAffected()
}
// Delete removes the row whose ID equals req.Id through the ORM's Delete call
// and returns the number of affected rows reported by the result.
//
// The statement is built with OmitEmpty, which in GoFrame also filters empty
// values out of the WHERE arguments. An empty req.Id would therefore drop the
// ID predicate and leave the DELETE without its row filter, so a nil req or an
// empty ID is rejected here before any query is issued.
func (d *modelDao) Delete(ctx context.Context, req *entity.AsynchModel) (rows int64, err error) {
	if req == nil || g.IsEmpty(req.Id) {
		return 0, fmt.Errorf("model dao: delete requires a non-empty id")
	}

	r, err := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameModel).
		OmitEmpty().
		Where(entity.AsynchModelCol.Id, req.Id).
		Delete()
	if err != nil {
		return 0, err
	}
	return r.RowsAffected()
}
// Get looks up a single model with a hand-built SQL query.
//
// Each of req.Id, req.Creator, req.IsChatModel and req.ModelName that is
// non-empty (per g.IsEmpty) adds one equality filter; empty fields add none.
// Rows with a non-NULL deleted_at are always excluded.
//
// The query has no ORDER BY or LIMIT. When several rows are decoded, the last
// one in the result set is returned; when none are decoded, m stays nil.
// The fields argument is accepted but unused: every column is selected.
//
// NOTE(review): the original comment describes this lookup as tenant-isolated
// (current tenant only), but the SQL built here contains no tenant predicate.
// Confirm whether gfdb applies one to raw GetAll queries.
func (d *modelDao) Get(ctx context.Context, req *entity.AsynchModel, fields ...string) (m *entity.AsynchModel, err error) {
	// Optional equality filters, listed in the order they are appended.
	filters := []struct {
		column any
		value  any
	}{
		{entity.AsynchModelCol.Id, req.Id},
		{entity.AsynchModelCol.Creator, req.Creator},
		{entity.AsynchModelCol.IsChatModel, req.IsChatModel},
		{entity.AsynchModelCol.ModelName, req.ModelName},
	}

	var (
		where strings.Builder
		binds []interface{}
	)
	for _, f := range filters {
		if g.IsEmpty(f.value) {
			continue
		}
		where.WriteString(fmt.Sprintf(" AND %s = (?) ", f.column))
		binds = append(binds, f.value)
	}

	// Full statement: fixed prefix plus whichever filters were collected.
	query := `SELECT * FROM "asynch_models" WHERE "deleted_at" IS NULL` + where.String()
	result, err := gfdb.DB(ctx, public.DbNameModelGateway).GetAll(ctx, query, binds...)
	if err != nil {
		return nil, err
	}

	var matches []*entity.AsynchModel
	if err = result.Structs(&matches); err != nil {
		return nil, err
	}
	// Keep the last decoded row, matching the overwrite-in-a-loop behaviour.
	if n := len(matches); n > 0 {
		m = matches[n-1]
	}
	return m, nil
}
// GetByAcrossTenant fetches one model through the ORM with NoTenantId applied,
// which the original comment describes as querying across all tenants.
//
// req.Id, req.Creator, req.IsChatModel and req.ModelName are each passed as a
// Where condition on a query built with OmitEmpty. fields is forwarded to
// Fields to choose the selected columns. The single record returned by One is
// decoded into m.
func (d *modelDao) GetByAcrossTenant(ctx context.Context, req *entity.AsynchModel, fields ...string) (m *entity.AsynchModel, err error) {
	cols := entity.AsynchModelCol

	query := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameModel).
		NoTenantId(ctx).
		OmitEmpty().
		Where(cols.Id, req.Id).
		Where(cols.Creator, req.Creator).
		Where(cols.IsChatModel, req.IsChatModel).
		Where(cols.ModelName, req.ModelName).
		Fields(fields)

	record, err := query.One()
	if err != nil {
		return nil, err
	}

	err = record.Struct(&m)
	return m, err
}
// GetByCreatorAndPlatform lists non-deleted models, one row per model_name
// (DISTINCT ON), ordered by model_name, is_owner DESC, created_at DESC.
//
// Filters taken from req:
//   - ModelName: substring match; an empty name matches everything.
//   - ModelType (> 0): only its first decimal digit is used, as a prefix
//     match on model_type::text (e.g. 6 matches "6%").
//   - IsPrivate: equality filter when non-empty per g.IsEmpty.
//   - IsOwner == 0: rows with the given creator and is_owner, optionally
//     narrowed by Enabled (0 or 1).
//   - IsOwner == 1: the same creator/is_owner rows (optionally narrowed by
//     Enabled), plus every row with is_owner = 0 AND enabled = 1.
//   - IsOwner nil or any other value: no ownership filter.
//
// total is the number of rows returned; there is no pagination.
//
// NOTE(review): the name mentions "platform", but no platform column is
// filtered in this query. Confirm whether that is intended.
func (d *modelDao) GetByCreatorAndPlatform(ctx context.Context, req *dto.ListModelReq) (list []*entity.AsynchModel, total int, err error) {
	var query strings.Builder
	query.WriteString(`
SELECT DISTINCT ON (model_name) *
FROM asynch_models
WHERE deleted_at IS NULL
AND (? = '' OR model_name LIKE ?)
`)
	binds := []any{req.ModelName, "%" + req.ModelName + "%"}

	// Model type: match on the leading digit only, e.g. 6 -> "6%".
	if req.ModelType > 0 {
		leading := strconv.Itoa(req.ModelType)[:1]
		query.WriteString(` AND model_type::text LIKE ? `)
		binds = append(binds, leading+"%")
	}

	if !g.IsEmpty(req.IsPrivate) {
		query.WriteString(` AND is_private = ? `)
		binds = append(binds, req.IsPrivate)
	}

	enabledOn := req.Enabled != nil && *req.Enabled == 1
	enabledOff := req.Enabled != nil && *req.Enabled == 0

	switch {
	case req.IsOwner == nil:
		// No ownership filter requested.
	case *req.IsOwner == 0:
		switch {
		case enabledOn:
			query.WriteString(` AND creator = ? AND is_owner = ? AND enabled=1 `)
		case enabledOff:
			query.WriteString(` AND creator = ? AND is_owner = ? AND enabled=0 `)
		default:
			query.WriteString(` AND creator = ? AND is_owner = ? `)
		}
		binds = append(binds, req.Creator, req.IsOwner)
	case *req.IsOwner == 1:
		// The second OR arm always adds rows with is_owner = 0 AND enabled = 1.
		switch {
		case enabledOn:
			query.WriteString(` AND ((creator = ? AND is_owner = ? AND enabled=1) OR (is_owner = 0 AND enabled=1)) `)
		case enabledOff:
			query.WriteString(` AND ((creator = ? AND is_owner = ? AND enabled=0) OR (is_owner = 0 AND enabled=1)) `)
		default:
			query.WriteString(` AND ((creator = ? AND is_owner = ?) OR (is_owner = 0 AND enabled=1)) `)
		}
		binds = append(binds, req.Creator, req.IsOwner)
	}

	query.WriteString(` ORDER BY model_name, is_owner DESC, created_at DESC`)

	result, err := gfdb.DB(ctx, public.DbNameModelGateway).GetAll(ctx, query.String(), binds...)
	if err != nil {
		return nil, 0, err
	}
	if err = result.Structs(&list); err != nil {
		return nil, 0, err
	}
	return list, len(list), nil
}
// GetByModelNameForTenant returns one non-deleted model matching the given
// tenant_id and model_name, using a raw LIMIT 1 query.
//
// Per the original comment it is meant for background tasks and does not rely
// on gfdb hooks, tracing or user context. It returns (nil, nil) when the
// result is empty or decodes to no rows.
func (d *modelDao) GetByModelNameForTenant(ctx context.Context, tenantId uint64, modelName string) (m *entity.AsynchModel, err error) {
	query := "SELECT * FROM " + public.TableNameModel + " WHERE tenant_id=? AND model_name=? AND deleted_at IS NULL LIMIT 1"

	result, err := gfdb.DB(ctx, public.DbNameModelGateway).GetAll(ctx, query, tenantId, modelName)
	switch {
	case err != nil:
		return nil, err
	case result.IsEmpty():
		return nil, nil
	}

	var found []*entity.AsynchModel
	if err = result.Structs(&found); err != nil {
		return nil, err
	}
	if len(found) == 0 {
		return nil, nil
	}
	return found[0], nil
}