Files
model-gateway/dao/model_gateway_models_dao.go

202 lines
6.3 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package dao
import (
"context"
"model-gateway/consts/public"
"model-gateway/model/dto"
"model-gateway/model/entity"
"strconv"
"gitea.redpowerfuture.com/red-future/common/db/gfdb"
"github.com/gogf/gf/v2/frame/g"
)
// ModelGatewayModels is the package-level DAO instance through which the
// methods below are invoked.
var ModelGatewayModels = &modelGatewayModelsDao{}
// modelGatewayModelsDao groups the data-access methods for the model table
// (public.TableNameModel). It has no fields and therefore carries no state.
type modelGatewayModelsDao struct{}
// Insert stores req as a new row in the model table and returns the value
// reported by the driver's LastInsertId.
//
// Any error from the insert itself is returned with an id of 0.
func (d *modelGatewayModelsDao) Insert(ctx context.Context, req *entity.ModelGatewayModel) (int64, error) {
	db := gfdb.DB(ctx, public.DbNameModelGateway)
	result, err := db.Model(ctx, public.TableNameModel).Insert(req)
	if err != nil {
		return 0, err
	}
	return result.LastInsertId()
}
// Update writes the non-empty fields of req to the row whose id equals req.Id
// and returns the number of rows affected.
//
// OmitEmpty is applied to the data, so zero-valued fields of req are not
// written; a column therefore cannot be reset to its zero value through this
// method.
//
// When req.Id is the zero value the call is a no-op and returns (0, nil): no
// row is expected to carry a zero id, and running the statement would be
// dangerous (see the comment in the body).
func (d *modelGatewayModelsDao) Update(ctx context.Context, req *entity.ModelGatewayModel) (int64, error) {
	// OmitEmpty filters empty values out of the WHERE conditions as well as
	// out of the data. With a zero Id the only id condition below would be
	// dropped, leaving the UPDATE without its row filter. Compare against the
	// zero value of the field so the guard works whatever the Id type is.
	if req.Id == (entity.ModelGatewayModel{}).Id {
		return 0, nil
	}
	r, err := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameModel).
		OmitEmpty().
		Data(req).
		Where(entity.ModelGatewayModelCol.Id, req.Id).
		Update()
	if err != nil {
		return 0, err
	}
	return r.RowsAffected()
}
// Delete removes the row whose id equals req.Id and returns the number of
// rows affected.
//
// Only req.Id is read from req.
func (d *modelGatewayModelsDao) Delete(ctx context.Context, req *entity.ModelGatewayModel) (int64, error) {
	// Deliberately no OmitEmpty here: there is no data to filter, and on the
	// WHERE side it would drop the id condition whenever req.Id is the zero
	// value, leaving the DELETE without its row filter. Keeping the condition
	// unconditionally means a zero Id simply matches the (presumably
	// non-existent) row with that id.
	r, err := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameModel).
		Where(entity.ModelGatewayModelCol.Id, req.Id).
		Delete()
	if err != nil {
		return 0, err
	}
	return r.RowsAffected()
}
// Get fetches a single model row matching the Id, Creator and ModelName
// values carried by req.
//
// The three conditions are attached under OmitEmpty, so conditions whose
// value is empty are expected to be left out of the query (GoFrame OmitEmpty
// semantics; confirm against the gfdb wrapper). fields optionally restricts
// the selected columns.
//
// A query error yields (nil, err). Otherwise the record is converted into a
// fresh entity and that entity is returned together with whatever error the
// conversion produced, so the returned pointer is non-nil even when err is
// not nil.
func (d *modelGatewayModelsDao) Get(ctx context.Context, req *entity.ModelGatewayModel, fields ...string) (*entity.ModelGatewayModel, error) {
	db := gfdb.DB(ctx, public.DbNameModelGateway)
	record, err := db.Model(ctx, public.TableNameModel).
		OmitEmpty().
		Where(entity.ModelGatewayModelCol.Id, req.Id).
		Where(entity.ModelGatewayModelCol.Creator, req.Creator).
		Where(entity.ModelGatewayModelCol.ModelName, req.ModelName).
		Fields(fields).
		One()
	if err != nil {
		return nil, err
	}

	found := &entity.ModelGatewayModel{}
	if err = record.Struct(found); err != nil {
		return found, err
	}
	return found, nil
}
//// Get 按ID获取带租户隔离只查当前租户
//func (d *modelGatewayModelsDao) Get(ctx context.Context, req *entity.AsynchModel, fields ...string) (m *entity.AsynchModel, err error) {
// var whereCondition strings.Builder
// var queryParams []interface{}
// if !g.IsEmpty(req.Id) {
// whereCondition.WriteString(fmt.Sprintf(" AND %s = (?) ", entity.AsynchModelCol.Id))
// queryParams = append(queryParams, req.Id)
// }
// if !g.IsEmpty(req.Creator) {
// whereCondition.WriteString(fmt.Sprintf(" AND %s = (?) ", entity.AsynchModelCol.Creator))
// queryParams = append(queryParams, req.Creator)
// }
// if !g.IsEmpty(req.IsChatModel) {
// whereCondition.WriteString(fmt.Sprintf(" AND %s = (?) ", entity.AsynchModelCol.IsChatModel))
// queryParams = append(queryParams, req.IsChatModel)
// }
// if !g.IsEmpty(req.ModelName) {
// whereCondition.WriteString(fmt.Sprintf(" AND %s = (?) ", entity.AsynchModelCol.ModelName))
// queryParams = append(queryParams, req.ModelName)
// }
// // 完整 SQL
// sql := `SELECT * FROM "asynch_models" WHERE "deleted_at" IS NULL` + whereCondition.String()
// r, err := gfdb.DB(ctx, public.DbNameModelGateway).GetAll(ctx, sql, queryParams...)
// if err != nil {
// return
// }
// var i []*entity.AsynchModel
// if err = r.Structs(&i); err != nil {
// return nil, err
// }
// for _, item := range i {
// m = item
// }
// return
//}
// GetByAcrossTenant is the cross-tenant variant of the single-row lookup: it
// builds the same Id / Creator / ModelName query but calls NoTenantId on the
// model first (presumably disabling the wrapper's tenant filter; confirm
// against gfdb).
//
// The conditions are attached under OmitEmpty, and fields optionally
// restricts the selected columns.
//
// A query error yields (nil, err). Otherwise the converted entity is returned
// together with the conversion error, so the pointer is non-nil even when err
// is not nil.
func (d *modelGatewayModelsDao) GetByAcrossTenant(ctx context.Context, req *entity.ModelGatewayModel, fields ...string) (*entity.ModelGatewayModel, error) {
	record, err := gfdb.DB(ctx, public.DbNameModelGateway).
		Model(ctx, public.TableNameModel).
		NoTenantId(ctx).
		OmitEmpty().
		Where(entity.ModelGatewayModelCol.Id, req.Id).
		Where(entity.ModelGatewayModelCol.Creator, req.Creator).
		Where(entity.ModelGatewayModelCol.ModelName, req.ModelName).
		Fields(fields).
		One()
	if err != nil {
		return nil, err
	}

	result := new(entity.ModelGatewayModel)
	convErr := record.Struct(result)
	return result, convErr
}
// GetByCreatorAndPlatform lists models for a creator using a hand-written SQL
// statement, returning at most one row per model_name.
//
// Filters taken from req:
//   - ModelName: substring match; an empty name matches everything.
//   - ModelType: when > 0, only its first decimal digit is used as a prefix
//     match against model_type rendered as text.
//   - IsPrivate: exact match, applied only when the value is non-empty.
//   - IsOwner == 0: rows with the given creator and is_owner value,
//     optionally narrowed by Enabled (0 or 1).
//   - IsOwner == 1: rows with the given creator and is_owner value
//     (optionally narrowed by Enabled), plus every row with is_owner = 0 and
//     enabled = 1.
//   - IsOwner nil or any other value: no creator/owner/enabled filter at all.
//
// DISTINCT ON (model_name) together with the ORDER BY keeps, per model_name,
// the row with the highest is_owner and then the newest created_at.
//
// total is simply len(list); no pagination is performed here.
//
// NOTE(review): the table name is hard-coded as asynch_models while other
// methods use public.TableNameModel — confirm both refer to the same table.
func (d *modelGatewayModelsDao) GetByCreatorAndPlatform(ctx context.Context, req *dto.ListModelReq) (list []*entity.ModelGatewayModel, total int, err error) {
	query := `
SELECT DISTINCT ON (model_name) *
FROM asynch_models
WHERE deleted_at IS NULL
AND (? = '' OR model_name LIKE ?)
`
	params := []any{req.ModelName, "%" + req.ModelName + "%"}

	// Model type: e.g. passing 6 matches every model_type starting with "6".
	if req.ModelType > 0 {
		firstDigit := strconv.Itoa(req.ModelType)[:1] // keep the leading digit only
		query += ` AND model_type::text LIKE ? `
		params = append(params, firstDigit+"%")
	}

	if !g.IsEmpty(req.IsPrivate) {
		query += ` AND is_private = ? `
		params = append(params, req.IsPrivate)
	}

	// Enabled is tri-state: explicitly 1, explicitly 0, or anything else
	// (including nil), which means "do not narrow the creator's rows".
	wantEnabled := req.Enabled != nil && *req.Enabled == 1
	wantDisabled := req.Enabled != nil && *req.Enabled == 0

	switch {
	case req.IsOwner != nil && *req.IsOwner == 0:
		switch {
		case wantEnabled:
			query += ` AND creator = ? AND is_owner = ? AND enabled=1 `
		case wantDisabled:
			query += ` AND creator = ? AND is_owner = ? AND enabled=0 `
		default:
			query += ` AND creator = ? AND is_owner = ? `
		}
		params = append(params, req.Creator, req.IsOwner)
	case req.IsOwner != nil && *req.IsOwner == 1:
		// The second alternative (is_owner = 0 AND enabled=1) is the same in
		// all three variants; only the creator's own rows follow Enabled.
		switch {
		case wantEnabled:
			query += ` AND ((creator = ? AND is_owner = ? AND enabled=1) OR (is_owner = 0 AND enabled=1)) `
		case wantDisabled:
			query += ` AND ((creator = ? AND is_owner = ? AND enabled=0) OR (is_owner = 0 AND enabled=1)) `
		default:
			query += ` AND ((creator = ? AND is_owner = ?) OR (is_owner = 0 AND enabled=1)) `
		}
		params = append(params, req.Creator, req.IsOwner)
	}

	query += ` ORDER BY model_name, is_owner DESC, created_at DESC`

	rows, err := gfdb.DB(ctx, public.DbNameModelGateway).GetAll(ctx, query, params...)
	if err != nil {
		return nil, 0, err
	}
	if err = rows.Structs(&list); err != nil {
		return nil, 0, err
	}
	return list, len(list), nil
}
// GetByModelNameForTenant looks up one non-deleted model by tenant id and
// model name using a raw SQL statement in which tenant_id is filtered
// explicitly.
//
// Per the original author's note it is meant for background tasks and is not
// supposed to depend on the gfdb hook, tracing, or a user context.
//
// It returns (nil, nil) when no row matches, and (nil, err) when the query or
// the row conversion fails.
func (d *modelGatewayModelsDao) GetByModelNameForTenant(ctx context.Context, tenantId uint64, modelName string) (*entity.ModelGatewayModel, error) {
	query := "SELECT * FROM " + public.TableNameModel + " WHERE tenant_id=? AND model_name=? AND deleted_at IS NULL LIMIT 1"

	rows, err := gfdb.DB(ctx, public.DbNameModelGateway).GetAll(ctx, query, tenantId, modelName)
	switch {
	case err != nil:
		return nil, err
	case rows.IsEmpty():
		return nil, nil
	}

	var models []*entity.ModelGatewayModel
	if err = rows.Structs(&models); err != nil {
		return nil, err
	}
	if len(models) > 0 {
		return models[0], nil
	}
	return nil, nil
}