refactor(service): 重构模型网关服务结构
This commit is contained in:
201
dao/model_gateway_models_dao.go
Normal file
201
dao/model_gateway_models_dao.go
Normal file
@@ -0,0 +1,201 @@
|
||||
package dao
|
||||
|
||||
import (
|
||||
"context"
|
||||
"model-gateway/consts/public"
|
||||
"model-gateway/model/dto"
|
||||
"model-gateway/model/entity"
|
||||
"strconv"
|
||||
|
||||
"gitea.redpowerfuture.com/red-future/common/db/gfdb"
|
||||
"github.com/gogf/gf/v2/frame/g"
|
||||
)
|
||||
|
||||
var ModelGatewayModels = &modelGatewayModelsDao{}
|
||||
|
||||
type modelGatewayModelsDao struct{}
|
||||
|
||||
// Insert 插入
|
||||
func (d *modelGatewayModelsDao) Insert(ctx context.Context, req *entity.ModelGatewayModel) (int64, error) {
|
||||
r, err := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameModel).Insert(req)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return r.LastInsertId()
|
||||
}
|
||||
|
||||
// Update 更新
|
||||
func (d *modelGatewayModelsDao) Update(ctx context.Context, req *entity.ModelGatewayModel) (int64, error) {
|
||||
r, err := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameModel).
|
||||
OmitEmpty().
|
||||
Data(req).
|
||||
Where(entity.ModelGatewayModelCol.Id, req.Id).
|
||||
Update()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return r.RowsAffected()
|
||||
}
|
||||
|
||||
// Delete 删除
|
||||
func (d *modelGatewayModelsDao) Delete(ctx context.Context, req *entity.ModelGatewayModel) (int64, error) {
|
||||
r, err := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameModel).
|
||||
OmitEmpty().
|
||||
Where(entity.ModelGatewayModelCol.Id, req.Id).
|
||||
Delete()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return r.RowsAffected()
|
||||
}
|
||||
|
||||
// Get 获取模型
|
||||
func (d *modelGatewayModelsDao) Get(ctx context.Context, req *entity.ModelGatewayModel, fields ...string) (*entity.ModelGatewayModel, error) {
|
||||
r, err := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameModel).
|
||||
OmitEmpty().
|
||||
Where(entity.ModelGatewayModelCol.Id, req.Id).
|
||||
Where(entity.ModelGatewayModelCol.Creator, req.Creator).
|
||||
Where(entity.ModelGatewayModelCol.ModelName, req.ModelName).
|
||||
Fields(fields).One()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var m entity.ModelGatewayModel
|
||||
err = r.Struct(&m)
|
||||
return &m, err
|
||||
}
|
||||
|
||||
//// Get 按ID获取(带租户隔离,只查当前租户)
|
||||
//func (d *modelGatewayModelsDao) Get(ctx context.Context, req *entity.AsynchModel, fields ...string) (m *entity.AsynchModel, err error) {
|
||||
// var whereCondition strings.Builder
|
||||
// var queryParams []interface{}
|
||||
// if !g.IsEmpty(req.Id) {
|
||||
// whereCondition.WriteString(fmt.Sprintf(" AND %s = (?) ", entity.AsynchModelCol.Id))
|
||||
// queryParams = append(queryParams, req.Id)
|
||||
// }
|
||||
// if !g.IsEmpty(req.Creator) {
|
||||
// whereCondition.WriteString(fmt.Sprintf(" AND %s = (?) ", entity.AsynchModelCol.Creator))
|
||||
// queryParams = append(queryParams, req.Creator)
|
||||
// }
|
||||
// if !g.IsEmpty(req.IsChatModel) {
|
||||
// whereCondition.WriteString(fmt.Sprintf(" AND %s = (?) ", entity.AsynchModelCol.IsChatModel))
|
||||
// queryParams = append(queryParams, req.IsChatModel)
|
||||
// }
|
||||
// if !g.IsEmpty(req.ModelName) {
|
||||
// whereCondition.WriteString(fmt.Sprintf(" AND %s = (?) ", entity.AsynchModelCol.ModelName))
|
||||
// queryParams = append(queryParams, req.ModelName)
|
||||
// }
|
||||
// // 完整 SQL
|
||||
// sql := `SELECT * FROM "asynch_models" WHERE "deleted_at" IS NULL` + whereCondition.String()
|
||||
// r, err := gfdb.DB(ctx, public.DbNameModelGateway).GetAll(ctx, sql, queryParams...)
|
||||
// if err != nil {
|
||||
// return
|
||||
// }
|
||||
// var i []*entity.AsynchModel
|
||||
// if err = r.Structs(&i); err != nil {
|
||||
// return nil, err
|
||||
// }
|
||||
// for _, item := range i {
|
||||
// m = item
|
||||
// }
|
||||
// return
|
||||
//}
|
||||
|
||||
// GetByAcrossTenant 跨租户查询
|
||||
func (d *modelGatewayModelsDao) GetByAcrossTenant(ctx context.Context, req *entity.ModelGatewayModel, fields ...string) (*entity.ModelGatewayModel, error) {
|
||||
r, err := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameModel).
|
||||
NoTenantId(ctx).
|
||||
OmitEmpty().
|
||||
Where(entity.ModelGatewayModelCol.Id, req.Id).
|
||||
Where(entity.ModelGatewayModelCol.Creator, req.Creator).
|
||||
Where(entity.ModelGatewayModelCol.ModelName, req.ModelName).
|
||||
Fields(fields).One()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var m entity.ModelGatewayModel
|
||||
err = r.Struct(&m)
|
||||
return &m, err
|
||||
}
|
||||
|
||||
// GetByCreatorAndPlatform 按创建者、平台获取
|
||||
func (d *modelGatewayModelsDao) GetByCreatorAndPlatform(ctx context.Context, req *dto.ListModelReq) (list []*entity.ModelGatewayModel, total int, err error) {
|
||||
sql := `
|
||||
SELECT DISTINCT ON (model_name) *
|
||||
FROM asynch_models
|
||||
WHERE deleted_at IS NULL
|
||||
AND (? = '' OR model_name LIKE ?)
|
||||
`
|
||||
args := []any{
|
||||
req.ModelName, "%" + req.ModelName + "%",
|
||||
}
|
||||
|
||||
// modelType: 传 6 模糊匹配 6%
|
||||
if req.ModelType > 0 {
|
||||
prefix := strconv.Itoa(req.ModelType)[:1] // 截取第一位
|
||||
sql += ` AND model_type::text LIKE ? `
|
||||
args = append(args, prefix+"%")
|
||||
}
|
||||
|
||||
if !g.IsEmpty(req.IsPrivate) {
|
||||
sql += ` AND is_private = ? `
|
||||
args = append(args, req.IsPrivate)
|
||||
}
|
||||
|
||||
if req.IsOwner != nil && *req.IsOwner == 0 {
|
||||
if req.Enabled != nil && *req.Enabled == 1 {
|
||||
sql += ` AND creator = ? AND is_owner = ? AND enabled=1 `
|
||||
} else if req.Enabled != nil && *req.Enabled == 0 {
|
||||
sql += ` AND creator = ? AND is_owner = ? AND enabled=0 `
|
||||
} else {
|
||||
sql += ` AND creator = ? AND is_owner = ? `
|
||||
}
|
||||
args = append(args, req.Creator, req.IsOwner)
|
||||
} else if req.IsOwner != nil && *req.IsOwner == 1 {
|
||||
if req.Enabled != nil && *req.Enabled == 1 {
|
||||
sql += ` AND ((creator = ? AND is_owner = ? AND enabled=1) OR (is_owner = 0 AND enabled=1)) `
|
||||
} else if req.Enabled != nil && *req.Enabled == 0 {
|
||||
sql += ` AND ((creator = ? AND is_owner = ? AND enabled=0) OR (is_owner = 0 AND enabled=1)) `
|
||||
} else {
|
||||
sql += ` AND ((creator = ? AND is_owner = ?) OR (is_owner = 0 AND enabled=1)) `
|
||||
}
|
||||
args = append(args, req.Creator, req.IsOwner)
|
||||
}
|
||||
|
||||
sql += ` ORDER BY model_name, is_owner DESC, created_at DESC`
|
||||
|
||||
r, err := gfdb.DB(ctx, public.DbNameModelGateway).GetAll(ctx, sql, args...)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
err = r.Structs(&list)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
total = len(list)
|
||||
return
|
||||
}
|
||||
|
||||
// GetByModelNameForTenant 后台任务使用:按 tenant_id + model_name 查询,不依赖 gfdb Hook/Trace/用户上下文
|
||||
func (d *modelGatewayModelsDao) GetByModelNameForTenant(ctx context.Context, tenantId uint64, modelName string) (*entity.ModelGatewayModel, error) {
|
||||
r, err := gfdb.DB(ctx, public.DbNameModelGateway).GetAll(ctx,
|
||||
"SELECT * FROM "+public.TableNameModel+" WHERE tenant_id=? AND model_name=? AND deleted_at IS NULL LIMIT 1",
|
||||
tenantId, modelName,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if r.IsEmpty() {
|
||||
return nil, nil
|
||||
}
|
||||
var list []*entity.ModelGatewayModel
|
||||
if err := r.Structs(&list); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(list) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
return list[0], nil
|
||||
}
|
||||
Reference in New Issue
Block a user