refactor(model): 重构模型实体和数据访问层

This commit is contained in:
2026-05-21 10:41:37 +08:00
parent a080a5536d
commit 170568e03e
35 changed files with 903 additions and 1072 deletions

View File

@@ -2,14 +2,11 @@ package dao
import (
"context"
"fmt"
"model-gateway/consts/public"
"model-gateway/model/dto"
"model-gateway/model/entity"
"gitea.com/red-future/common/db/gfdb"
"gitea.com/red-future/common/utils"
"github.com/gogf/gf/v2/frame/g"
"github.com/gogf/gf/v2/util/gconv"
)
@@ -18,157 +15,80 @@ var Model = &modelDao{}
type modelDao struct{}
func (d *modelDao) Insert(ctx context.Context, req *dto.CreateModelReq) (id int64, err error) {
asyncModel := new(entity.AsynchModel)
err = gconv.Struct(req, &asyncModel)
// Insert 插入
func (d *modelDao) Insert(ctx context.Context, req *entity.AsynchModel) (id int64, err error) {
m := new(entity.AsynchModel)
err = gconv.Struct(req, &m)
if err != nil {
return
}
r, err := gfdb.DB(ctx).Model(ctx, public.TableNameModel).Data(asyncModel).Insert()
r, err := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameModel).
Insert(m)
if err != nil {
return 0, err
return
}
return r.LastInsertId()
}
func (d *modelDao) Update(ctx context.Context, m *dto.UpdateModelReq) (rows int64, err error) {
// 触发 gfdb 的 updateHook 自动填充 updater需要显式带 updater 字段
r, err := gfdb.DB(ctx).Model(ctx, public.TableNameModel).
// Update 更新
func (d *modelDao) Update(ctx context.Context, req *entity.AsynchModel) (rows int64, err error) {
r, err := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameModel).
OmitEmpty().
Where(entity.AsynchModelCol.Id, m.ID).
Data(m).
Data(&req).
Where(entity.AsynchModelCol.Id, req.Id).
Update()
if err != nil {
return 0, err
return
}
return r.RowsAffected()
}
func (d *modelDao) DeleteByID(ctx context.Context, id string) (rows int64, err error) {
r, err := gfdb.DB(ctx).Model(ctx, public.TableNameModel).
Where(entity.AsynchModelCol.Id, id).
// Delete 删除
func (d *modelDao) Delete(ctx context.Context, req *entity.AsynchModel) (rows int64, err error) {
r, err := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameModel).
OmitEmpty().
Where(entity.AsynchModelCol.Id, req.Id).
Delete()
if err != nil {
return 0, err
return
}
return r.RowsAffected()
}
func (d *modelDao) GetByModelName(ctx context.Context, modelName string) (m *entity.AsynchModel, err error) {
r, err := gfdb.DB(ctx).Model(ctx, public.TableNameModel).
Where(entity.AsynchModelCol.ModelName, modelName).
One()
if err != nil {
return nil, err
}
if r.IsEmpty() {
return nil, nil
}
err = r.Struct(&m)
return
}
func (d *modelDao) Get(ctx context.Context, id int64) (m *entity.AsynchModel, err error) {
r, err := gfdb.DB(ctx).Model(ctx, public.TableNameModel).
NoTenantId(ctx).
Where(entity.AsynchModelCol.Id, id).
One()
if err != nil {
return nil, err
}
if r.IsEmpty() {
return nil, nil
}
err = r.Struct(&m)
return
}
func (d *modelDao) Count(ctx context.Context, req *dto.GetModelReq) (count int, err error) {
count, err = gfdb.DB(ctx).Model(ctx, public.TableNameModel).OmitEmpty().
// Get 按ID获取带租户隔离只查当前租户
func (d *modelDao) Get(ctx context.Context, req *entity.AsynchModel, fields ...string) (m *entity.AsynchModel, err error) {
r, err := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameModel).
OmitEmpty().
Where(entity.AsynchModelCol.Id, req.Id).
Where(entity.AsynchModelCol.Creator, req.Creator).
Where(entity.AsynchModelCol.Id, req.ID).Count()
return
}
func (d *modelDao) List(ctx context.Context, pageNum, pageSize int, modelNameLike string, modelType int, isPrivate int) (list []*entity.AsynchModel, total int64, err error) {
model := gfdb.DB(ctx).Model(ctx, public.TableNameModel).
OrderDesc(entity.AsynchModelCol.CreatedAt)
if modelNameLike != "" {
model = model.WhereLike(entity.AsynchModelCol.ModelName, "%"+modelNameLike+"%")
}
if modelType != 0 {
model = model.Where(entity.AsynchModelCol.ModelType, modelType)
}
if isPrivate != 0 {
model = model.Where(entity.AsynchModelCol.IsPrivate, isPrivate)
}
if pageNum > 0 && pageSize > 0 {
model = model.Page(pageNum, pageSize)
}
r, totalInt, err := model.AllAndCount(false)
Where(entity.AsynchModelCol.IsChatModel, req.IsChatModel).
Where(entity.AsynchModelCol.ModelName, req.ModelName).
Fields(fields).One()
if err != nil {
return nil, 0, err
}
total = gconv.Int64(totalInt)
err = r.Structs(&list)
return
}
func (d *modelDao) GetByIsChatModel(ctx context.Context) (m *entity.AsynchModel, err error) {
userInfo, err := utils.GetUserInfo(ctx)
if err != nil {
return nil, err
}
r, err := gfdb.DB(ctx).Model(ctx, public.TableNameModel).
Where(entity.AsynchModelCol.IsChatModel, 1).
Where(entity.AsynchModelCol.Creator, userInfo.UserName).
One()
if err != nil {
return nil, err
}
if r.IsEmpty() {
return nil, nil
return
}
err = r.Struct(&m)
return
}
// ListByCreatorAndPlatform 普通用户:平台公共(tenant_id=0) + 自己创建的(creator=xxx)
func (d *modelDao) ListByCreatorAndPlatform(ctx context.Context, creator string, pageNum, pageSize int, modelNameLike string) (list []*entity.AsynchModel, total int64, err error) {
// 构建 Where 条件
whereSQL := "deleted_at IS NULL AND (tenant_id = 1 OR creator = ?)" //1 代表超级管理员
args := []any{creator}
if modelNameLike != "" {
whereSQL += " AND model_name LIKE ?"
args = append(args, "%"+modelNameLike+"%")
}
// 查总数
countSQL := fmt.Sprintf("SELECT COUNT(1) FROM %s WHERE %s", public.TableNameModel, whereSQL)
countResult, err := gfdb.DB(ctx).GetAll(ctx, countSQL, args...)
// GetByAcrossTenant 按ID获取跨租户查所有租户
func (d *modelDao) GetByAcrossTenant(ctx context.Context, req *entity.AsynchModel, fields ...string) (m *entity.AsynchModel, err error) {
r, err := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameModel).
NoTenantId(ctx).
OmitEmpty().
Where(entity.AsynchModelCol.Id, req.Id).
Where(entity.AsynchModelCol.Creator, req.Creator).
Where(entity.AsynchModelCol.IsChatModel, req.IsChatModel).
Where(entity.AsynchModelCol.ModelName, req.ModelName).
Fields(fields).One()
if err != nil {
return nil, 0, err
return
}
if len(countResult) > 0 {
total = gconv.Int64(countResult[0]["count"])
}
// 查列表
querySQL := fmt.Sprintf("SELECT * FROM %s WHERE %s ORDER BY created_at DESC", public.TableNameModel, whereSQL)
if pageNum > 0 && pageSize > 0 {
offset := (pageNum - 1) * pageSize
querySQL += fmt.Sprintf(" LIMIT %d OFFSET %d", pageSize, offset)
}
r, err := gfdb.DB(ctx).GetAll(ctx, querySQL, args...)
if err != nil {
return nil, 0, err
}
err = r.Structs(&list)
err = r.Struct(&m)
return
}
// GetByCreatorAndPlatform 按创建者、平台获取
func (d *modelDao) GetByCreatorAndPlatform(ctx context.Context, req *dto.ListModelReq) (list []*entity.AsynchModel, total int, err error) {
// 基础 SQL
sql := `
@@ -212,7 +132,7 @@ WHERE deleted_at IS NULL
// 最后拼接排序
sql += ` ORDER BY model_name, is_owner DESC, created_at DESC`
r, err := gfdb.DB(ctx).GetAll(ctx, sql, args...)
r, err := gfdb.DB(ctx, public.DbNameModelGateway).GetAll(ctx, sql, args...)
if err != nil {
return nil, 0, err
}
@@ -226,14 +146,24 @@ WHERE deleted_at IS NULL
return
}
// ListAll 用于分组展示:查询全部模型(不按类型过滤,类型拆分在 service 层处理)
func (d *modelDao) ListAll(ctx context.Context) (list []*entity.AsynchModel, err error) {
r, err := gfdb.DB(ctx).Model(ctx, public.TableNameModel).
OrderDesc(entity.AsynchModelCol.CreatedAt).
All()
// GetByModelNameForTenant 后台任务使用:按 tenant_id + model_name 查询,不依赖 gfdb Hook/Trace/用户上下文
func (d *modelDao) GetByModelNameForTenant(ctx context.Context, tenantId uint64, modelName string) (m *entity.AsynchModel, err error) {
r, err := gfdb.DB(ctx, public.DbNameModelGateway).GetAll(ctx,
"SELECT * FROM "+public.TableNameModel+" WHERE tenant_id=? AND model_name=? AND deleted_at IS NULL LIMIT 1",
tenantId, modelName,
)
if err != nil {
return nil, err
}
err = r.Structs(&list)
return
if r.IsEmpty() {
return nil, nil
}
var list []*entity.AsynchModel
if err := r.Structs(&list); err != nil {
return nil, err
}
if len(list) == 0 {
return nil, nil
}
return list[0], nil
}

View File

@@ -1,32 +0,0 @@
package dao
import (
"context"
"model-gateway/consts/public"
"model-gateway/model/entity"
"gitea.com/red-future/common/db/gfdb"
)
// GetByModelNameForTenant 后台任务使用:按 tenant_id + model_name 查询,不依赖 gfdb Hook/Trace/用户上下文
func (d *modelDao) GetByModelNameForTenant(ctx context.Context, tenantId uint64, modelName string) (m *entity.AsynchModel, err error) {
r, err := gfdb.DB(ctx).GetAll(ctx,
"SELECT * FROM "+public.TableNameModel+" WHERE tenant_id=? AND model_name=? AND deleted_at IS NULL LIMIT 1",
tenantId, modelName,
)
if err != nil {
return nil, err
}
if r.IsEmpty() {
return nil, nil
}
var list []*entity.AsynchModel
if err := r.Structs(&list); err != nil {
return nil, err
}
if len(list) == 0 {
return nil, nil
}
return list[0], nil
}

View File

@@ -7,14 +7,22 @@ import (
"model-gateway/model/entity"
"gitea.com/red-future/common/db/gfdb"
"github.com/gogf/gf/v2/util/gconv"
)
type opLogDao struct{}
var OpLog = &opLogDao{}
func (d *opLogDao) Insert(ctx context.Context, log *entity.LogsModelOp) (id int64, err error) {
r, err := gfdb.DB(ctx).Model(ctx, public.TableNameOpLog).Data(log).Insert()
// Insert 插入
func (d *opLogDao) Insert(ctx context.Context, req *entity.LogsModelOp) (id int64, err error) {
m := new(entity.LogsModelOp)
err = gconv.Struct(req, &m)
if err != nil {
return
}
r, err := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameOpLog).
Insert(m)
if err != nil {
return 0, err
}

View File

@@ -25,7 +25,7 @@ ON CONFLICT (day, tenant_id, creator, model_name)
DO UPDATE SET request_count = %s.request_count + 1, updated_at = NOW()`,
public.TableNameStat, public.TableNameStat,
)
_, err := gfdb.DB(ctx).Exec(ctx, sql, gtime.New(day).Format("Y-m-d"), tenantId, creator, modelName)
_, err := gfdb.DB(ctx, public.DbNameModelGateway).Exec(ctx, sql, gtime.New(day).Format("Y-m-d"), tenantId, creator, modelName)
return err
}

View File

@@ -2,9 +2,6 @@ package dao
import (
"context"
"fmt"
"time"
"model-gateway/consts/public"
"model-gateway/model/entity"
@@ -18,40 +15,47 @@ var Task = &taskDao{}
type taskDao struct{}
func (d *taskDao) Insert(ctx context.Context, t *entity.AsynchTask) (id int64, err error) {
r, err := gfdb.DB(ctx).Model(ctx, public.TableNameTask).Data(t).Insert()
// Insert 插入
func (d *taskDao) Insert(ctx context.Context, req *entity.AsynchTask) (id int64, err error) {
m := new(entity.AsynchTask)
err = gconv.Struct(req, &m)
if err != nil {
return 0, err
return
}
r, err := gfdb.DB(ctx).Model(ctx, public.TableNameTask).
Insert(m)
if err != nil {
return
}
return r.LastInsertId()
}
func (d *taskDao) GetByTaskID(ctx context.Context, taskID string) (t *entity.AsynchTask, err error) {
r, err := gfdb.DB(ctx).Model(ctx, public.TableNameTask).
Where(entity.AsynchTaskCol.TaskID, taskID).
One()
// Get 获取
func (d *taskDao) Get(ctx context.Context, req *entity.AsynchTask, fields ...string) (m *entity.AsynchTask, err error) {
r, err := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameTask).
OmitEmpty().
Where(entity.AsynchTaskCol.TaskID, req.TaskID).
Fields(fields).One()
if err != nil {
return nil, err
return
}
if r.IsEmpty() {
return nil, nil
}
err = r.Struct(&t)
err = r.Struct(&m)
return
}
// ListByTaskIDs 批量查询任务(会受 gfdb 的租户 Hook 影响,只返回当前租户数据)
func (d *taskDao) ListByTaskIDs(ctx context.Context, taskIDs []string) (list []*entity.AsynchTask, err error) {
func (d *taskDao) ListByTaskIDs(ctx context.Context, taskIDs []string) (m []*entity.AsynchTask, err error) {
if len(taskIDs) == 0 {
return nil, nil
}
r, err := gfdb.DB(ctx).Model(ctx, public.TableNameTask).
r, err := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameTask).
OmitEmpty().
WhereIn(entity.AsynchTaskCol.TaskID, taskIDs).
All()
if err != nil {
return nil, err
}
err = r.Structs(&list)
err = r.Structs(&m)
return
}
@@ -62,7 +66,7 @@ func (d *taskDao) MarkDownloadedByID(ctx context.Context, id int64, expireAt *gt
entity.AsynchTaskCol.ExpireAt: expireAt,
entity.AsynchTaskCol.Updater: "",
}
_, err := gfdb.DB(ctx).Model(ctx, public.TableNameTask).
_, err := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameTask).
Where(entity.AsynchTaskCol.Id, id).
Where(entity.AsynchTaskCol.State, 2).
Data(data).
@@ -70,73 +74,6 @@ func (d *taskDao) MarkDownloadedByID(ctx context.Context, id int64, expireAt *gt
return err
}
func (d *taskDao) UpdateRunning(ctx context.Context, id int64) error {
now := gtime.Now()
data := gdb.Map{
entity.AsynchTaskCol.State: 1,
entity.AsynchTaskCol.StartedAt: now,
entity.AsynchTaskCol.Updater: "",
}
_, err := gfdb.DB(ctx).Model(ctx, public.TableNameTask).
Where(entity.AsynchTaskCol.Id, id).
Data(data).
Update()
return err
}
func (d *taskDao) UpdateSuccess(ctx context.Context, id int64, ossFile, fileType string, fileSize int64, expireAt *gtime.Time) error {
now := gtime.Now()
data := gdb.Map{
entity.AsynchTaskCol.State: 2,
entity.AsynchTaskCol.OssFile: ossFile,
entity.AsynchTaskCol.FileType: fileType,
entity.AsynchTaskCol.FileSize: fileSize,
entity.AsynchTaskCol.ErrorMsg: "",
entity.AsynchTaskCol.FinishedAt: now,
entity.AsynchTaskCol.ExpireAt: expireAt,
entity.AsynchTaskCol.Updater: "",
}
_, err := gfdb.DB(ctx).Model(ctx, public.TableNameTask).
Where(entity.AsynchTaskCol.Id, id).
Data(data).
Update()
return err
}
func (d *taskDao) UpdateFailed(ctx context.Context, id int64, errorMsg string) error {
now := gtime.Now()
data := gdb.Map{
entity.AsynchTaskCol.State: 3,
entity.AsynchTaskCol.ErrorMsg: errorMsg,
entity.AsynchTaskCol.FinishedAt: now,
entity.AsynchTaskCol.Updater: "",
}
_, err := gfdb.DB(ctx).Model(ctx, public.TableNameTask).
Where(entity.AsynchTaskCol.Id, id).
Data(data).
Update()
return err
}
func (d *taskDao) SoftDeleteByTaskID(ctx context.Context, taskID string) (rows int64, err error) {
r, err := gfdb.DB(ctx).Model(ctx, public.TableNameTask).
Where(entity.AsynchTaskCol.TaskID, taskID).
Delete()
if err != nil {
return 0, err
}
return r.RowsAffected()
}
// CountActiveByModel 统计某模型排队中/执行中的任务数,用于 queue_limit 限制(近似值)
func (d *taskDao) CountActiveByModel(ctx context.Context, modelName string) (int64, error) {
n, err := gfdb.DB(ctx).Model(ctx, public.TableNameTask).
Where(entity.AsynchTaskCol.ModelName, modelName).
WhereIn(entity.AsynchTaskCol.State, []int{0, 1}).
Count()
return int64(n), err
}
// List 任务分页查询(受 gfdb 租户 Hook 影响)
func (d *taskDao) List(ctx context.Context, pageNum, pageSize int, modelNameLike, taskIDLike string, state *int) (list []*entity.AsynchTask, total int64, err error) {
m := gfdb.DB(ctx).Model(ctx, public.TableNameTask).Where("deleted_at IS NULL")
@@ -161,90 +98,3 @@ func (d *taskDao) List(ctx context.Context, pageNum, pageSize int, modelNameLike
err = r.Structs(&list)
return
}
// ClaimPending 抢占 pending 任务state=0并在同一事务中更新为 runningstate=1
// 使用 PostgreSQL: FOR UPDATE SKIP LOCKED 避免多 worker 重复消费
func (d *taskDao) ClaimPending(ctx context.Context, batchSize int) (tasks []*entity.AsynchTask, err error) {
if batchSize <= 0 {
batchSize = 1
}
err = gfdb.DB(ctx).Transaction(ctx, func(ctx context.Context, tx gdb.TX) error {
sql := fmt.Sprintf(
`SELECT id, tenant_id, model_name, task_id, input_ref, request_payload
FROM %s
WHERE deleted_at IS NULL AND state = 0
ORDER BY created_at ASC
LIMIT %d
FOR UPDATE SKIP LOCKED`,
public.TableNameTask,
batchSize,
)
r, err := tx.GetAll(sql)
if err != nil {
return err
}
if r.IsEmpty() {
tasks = nil
return nil
}
if err := r.Structs(&tasks); err != nil {
return err
}
// 更新为 running
now := time.Now()
for _, t := range tasks {
// tx.Model 不走 gfdb Hook这里手动更新必要字段
_, err = tx.Exec(
fmt.Sprintf(`UPDATE %s SET state=1, started_at=?, updated_at=? WHERE id=?`, public.TableNameTask),
now, now, t.Id,
)
if err != nil {
return err
}
}
return nil
})
return
}
// ListExpiredSuccess 获取已成功且过期的任务
func (d *taskDao) ListExpiredSuccess(ctx context.Context, limit int) (list []*entity.AsynchTask, err error) {
if limit <= 0 {
limit = 100
}
r, err := gfdb.DB(ctx).Model(ctx, public.TableNameTask).
Where(entity.AsynchTaskCol.State, 2).
Where(entity.AsynchTaskCol.ExpireAt+" IS NOT NULL").
Where(entity.AsynchTaskCol.ExpireAt+" < ?", gtime.Now()).
Limit(limit).
All()
if err != nil {
return nil, err
}
err = r.Structs(&list)
return
}
// ListTimeoutTasks 获取超时的排队/执行中任务
func (d *taskDao) ListTimeoutTasks(ctx context.Context, timeout time.Duration, limit int) (list []*entity.AsynchTask, err error) {
if limit <= 0 {
limit = 100
}
deadline := gtime.New(time.Now().Add(-timeout))
r, err := gfdb.DB(ctx).Model(ctx, public.TableNameTask).
WhereIn(entity.AsynchTaskCol.State, []int{0, 1}).
Where(entity.AsynchTaskCol.UpdatedAt+" < ?", deadline).
Limit(limit).
All()
if err != nil {
return nil, err
}
err = r.Structs(&list)
return
}
// DebugPing 用于启动时检测数据库连通性(可选)
func (d *taskDao) DebugPing(ctx context.Context) error {
_, err := gfdb.DB(ctx).GetAll(ctx, "SELECT 1")
return err
}

View File

@@ -150,14 +150,6 @@ func (d *taskDao) UpdateTmpAfterModelGlobal(ctx context.Context, id int64, tmpFi
return err
}
func (d *taskDao) SoftDeleteByTaskIDGlobal(ctx context.Context, taskID string) error {
_, err := gfdb.DB(ctx).Exec(ctx,
fmt.Sprintf(`UPDATE %s SET deleted_at=NOW(), updated_at=NOW() WHERE task_id=? AND deleted_at IS NULL`, public.TableNameTask),
taskID,
)
return err
}
func (d *taskDao) RollbackToPendingGlobal(ctx context.Context, id int64) error {
_, err := gfdb.DB(ctx).Exec(ctx,
fmt.Sprintf(`UPDATE %s SET state=0, enqueue_at=NOW(), updated_at=NOW() WHERE id=? AND state=1`, public.TableNameTask),