gatway
This commit is contained in:
114
dao/model_dao.go
114
dao/model_dao.go
@@ -2,11 +2,14 @@ package dao
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"model-asynch/consts/public"
|
||||
"model-asynch/model/dto"
|
||||
"model-asynch/model/entity"
|
||||
|
||||
"gitea.com/red-future/common/db/gfdb"
|
||||
"github.com/gogf/gf/v2/frame/g"
|
||||
"github.com/gogf/gf/v2/util/gconv"
|
||||
)
|
||||
|
||||
@@ -22,12 +25,12 @@ func (d *modelDao) Insert(ctx context.Context, m *entity.AsynchModel) (id int64,
|
||||
return r.LastInsertId()
|
||||
}
|
||||
|
||||
func (d *modelDao) UpdateByID(ctx context.Context, id int64, data map[string]any) (rows int64, err error) {
|
||||
func (d *modelDao) Update(ctx context.Context, m *dto.UpdateModelReq) (rows int64, err error) {
|
||||
// 触发 gfdb 的 updateHook 自动填充 updater,需要显式带 updater 字段
|
||||
data[entity.AsynchModelCol.Updater] = ""
|
||||
r, err := gfdb.DB(ctx).Model(ctx, public.TableNameModel).
|
||||
Where(entity.AsynchModelCol.Id, id).
|
||||
Data(data).
|
||||
OmitEmpty().
|
||||
Where(entity.AsynchModelCol.Id, m.ID).
|
||||
Data(m).
|
||||
Update()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
@@ -35,7 +38,21 @@ func (d *modelDao) UpdateByID(ctx context.Context, id int64, data map[string]any
|
||||
return r.RowsAffected()
|
||||
}
|
||||
|
||||
func (d *modelDao) DeleteByID(ctx context.Context, id int64) (rows int64, err error) {
|
||||
func (d *modelDao) UpdateByID(ctx context.Context, m *dto.UpdateModelReq) (rows int64, err error) {
|
||||
// 专用于切换会话模型,只更新 is_chat_model 字段
|
||||
r, err := gfdb.DB(ctx).Model(ctx, public.TableNameModel).
|
||||
Where(entity.AsynchModelCol.Id, m.ID).
|
||||
Data(g.Map{
|
||||
"is_chat_model": m.IsChatModel,
|
||||
}).
|
||||
Update()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return r.RowsAffected()
|
||||
}
|
||||
|
||||
func (d *modelDao) DeleteByID(ctx context.Context, id string) (rows int64, err error) {
|
||||
r, err := gfdb.DB(ctx).Model(ctx, public.TableNameModel).
|
||||
Where(entity.AsynchModelCol.Id, id).
|
||||
Delete()
|
||||
@@ -59,7 +76,7 @@ func (d *modelDao) GetByModelName(ctx context.Context, modelName string) (m *ent
|
||||
return
|
||||
}
|
||||
|
||||
func (d *modelDao) GetByID(ctx context.Context, id int64) (m *entity.AsynchModel, err error) {
|
||||
func (d *modelDao) Get(ctx context.Context, id int64) (m *entity.AsynchModel, err error) {
|
||||
r, err := gfdb.DB(ctx).Model(ctx, public.TableNameModel).
|
||||
Where(entity.AsynchModelCol.Id, id).
|
||||
One()
|
||||
@@ -73,11 +90,15 @@ func (d *modelDao) GetByID(ctx context.Context, id int64) (m *entity.AsynchModel
|
||||
return
|
||||
}
|
||||
|
||||
func (d *modelDao) List(ctx context.Context, pageNum, pageSize int, modelNameLike string) (list []*entity.AsynchModel, total int64, err error) {
|
||||
model := gfdb.DB(ctx).Model(ctx, public.TableNameModel).Where("deleted_at IS NULL").OrderDesc(entity.AsynchModelCol.CreatedAt)
|
||||
func (d *modelDao) List(ctx context.Context, pageNum, pageSize int, modelNameLike string, modelType int) (list []*entity.AsynchModel, total int64, err error) {
|
||||
model := gfdb.DB(ctx).Model(ctx, public.TableNameModel).
|
||||
OrderDesc(entity.AsynchModelCol.CreatedAt)
|
||||
if modelNameLike != "" {
|
||||
model = model.WhereLike(entity.AsynchModelCol.ModelName, "%"+modelNameLike+"%")
|
||||
}
|
||||
if modelType != 0 {
|
||||
model = model.Where(entity.AsynchModelCol.ModelsType, modelType)
|
||||
}
|
||||
if pageNum > 0 && pageSize > 0 {
|
||||
model = model.Page(pageNum, pageSize)
|
||||
}
|
||||
@@ -90,10 +111,85 @@ func (d *modelDao) List(ctx context.Context, pageNum, pageSize int, modelNameLik
|
||||
return
|
||||
}
|
||||
|
||||
// ListByCreatorAndPlatform 普通用户:平台公共(tenant_id=0) + 自己创建的(creator=xxx)
|
||||
func (d *modelDao) ListByCreatorAndPlatform(ctx context.Context, creator string, pageNum, pageSize int, modelNameLike string) (list []*entity.AsynchModel, total int64, err error) {
|
||||
// 构建 Where 条件
|
||||
whereSQL := "deleted_at IS NULL AND (tenant_id = 1 OR creator = ?)" //1 代表超级管理员
|
||||
args := []any{creator}
|
||||
|
||||
if modelNameLike != "" {
|
||||
whereSQL += " AND model_name LIKE ?"
|
||||
args = append(args, "%"+modelNameLike+"%")
|
||||
}
|
||||
|
||||
// 查总数
|
||||
countSQL := fmt.Sprintf("SELECT COUNT(1) FROM %s WHERE %s", public.TableNameModel, whereSQL)
|
||||
countResult, err := gfdb.DB(ctx).GetAll(ctx, countSQL, args...)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
if len(countResult) > 0 {
|
||||
total = gconv.Int64(countResult[0]["count"])
|
||||
}
|
||||
|
||||
// 查列表
|
||||
querySQL := fmt.Sprintf("SELECT * FROM %s WHERE %s ORDER BY created_at DESC", public.TableNameModel, whereSQL)
|
||||
if pageNum > 0 && pageSize > 0 {
|
||||
offset := (pageNum - 1) * pageSize
|
||||
querySQL += fmt.Sprintf(" LIMIT %d OFFSET %d", pageSize, offset)
|
||||
}
|
||||
|
||||
r, err := gfdb.DB(ctx).GetAll(ctx, querySQL, args...)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
err = r.Structs(&list)
|
||||
return
|
||||
}
|
||||
|
||||
func (d *modelDao) GetByCreatorAndPlatform(ctx context.Context, creator string, modelNameLike string, modelType int) (list []*entity.AsynchModel, err error) {
|
||||
whereSQL := "deleted_at IS NULL AND (tenant_id = 1 OR creator = ?)"
|
||||
args := []any{creator}
|
||||
|
||||
if modelNameLike != "" {
|
||||
whereSQL += " AND model_name LIKE ?"
|
||||
args = append(args, "%"+modelNameLike+"%")
|
||||
}
|
||||
if modelType != 0 {
|
||||
whereSQL += " AND models_type = ?"
|
||||
args = append(args, modelType)
|
||||
}
|
||||
|
||||
querySQL := fmt.Sprintf("SELECT * FROM %s WHERE %s ORDER BY created_at DESC", public.TableNameModel, whereSQL)
|
||||
|
||||
r, err := gfdb.DB(ctx).GetAll(ctx, querySQL, args...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = r.Structs(&list)
|
||||
return
|
||||
}
|
||||
|
||||
func (d *modelDao) GetByIsChatModel(ctx context.Context, userName string) (m *entity.AsynchModel, err error) {
|
||||
r, err := gfdb.DB(ctx).Model(ctx, public.TableNameModel).
|
||||
Where(entity.AsynchModelCol.IsChatModel, 1).
|
||||
Where(entity.AsynchModelCol.Creator, userName).
|
||||
One()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if r.IsEmpty() {
|
||||
return nil, nil
|
||||
}
|
||||
err = r.Struct(&m)
|
||||
return
|
||||
}
|
||||
|
||||
// ListAll 用于分组展示:查询全部模型(不按类型过滤,类型拆分在 service 层处理)
|
||||
func (d *modelDao) ListAll(ctx context.Context) (list []*entity.AsynchModel, err error) {
|
||||
r, err := gfdb.DB(ctx).Model(ctx, public.TableNameModel).
|
||||
Where("deleted_at IS NULL").
|
||||
OrderDesc(entity.AsynchModelCol.CreatedAt).
|
||||
All()
|
||||
if err != nil {
|
||||
|
||||
@@ -1,74 +0,0 @@
|
||||
package dao
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"model-asynch/consts/public"
|
||||
"model-asynch/model/entity"
|
||||
|
||||
"gitea.com/red-future/common/db/gfdb"
|
||||
"github.com/gogf/gf/v2/database/gdb"
|
||||
"github.com/gogf/gf/v2/util/gconv"
|
||||
)
|
||||
|
||||
type modelTypeDao struct{}
|
||||
|
||||
var ModelType = &modelTypeDao{}
|
||||
|
||||
func (d *modelTypeDao) Insert(ctx context.Context, t *entity.AsynchModelType) (id int64, err error) {
|
||||
r, err := gfdb.DB(ctx).Model(ctx, public.TableNameModelType).Data(t).Insert()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return r.LastInsertId()
|
||||
}
|
||||
|
||||
func (d *modelTypeDao) UpdateByID(ctx context.Context, id int64, data gdb.Map) (rows int64, err error) {
|
||||
// 触发 gfdb 的 updateHook 自动填充 updater,需要显式带 updater 字段
|
||||
data[entity.AsynchModelTypeCol.Updater] = ""
|
||||
r, err := gfdb.DB(ctx).Model(ctx, public.TableNameModelType).Where(entity.AsynchModelTypeCol.Id, id).Data(data).Update()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
n, _ := r.RowsAffected()
|
||||
return n, nil
|
||||
}
|
||||
|
||||
func (d *modelTypeDao) DeleteByID(ctx context.Context, id int64) (rows int64, err error) {
|
||||
r, err := gfdb.DB(ctx).Model(ctx, public.TableNameModelType).Where(entity.AsynchModelTypeCol.Id, id).Delete()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
n, _ := r.RowsAffected()
|
||||
return n, nil
|
||||
}
|
||||
|
||||
func (d *modelTypeDao) GetByID(ctx context.Context, id int64) (*entity.AsynchModelType, error) {
|
||||
r, err := gfdb.DB(ctx).Model(ctx, public.TableNameModelType).Where(entity.AsynchModelTypeCol.Id, id).One()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if r.IsEmpty() {
|
||||
return nil, nil
|
||||
}
|
||||
var t *entity.AsynchModelType
|
||||
_ = r.Struct(&t)
|
||||
return t, nil
|
||||
}
|
||||
|
||||
func (d *modelTypeDao) List(ctx context.Context, pageNum, pageSize int, typeNameLike string) (list []*entity.AsynchModelType, total int64, err error) {
|
||||
m := gfdb.DB(ctx).Model(ctx, public.TableNameModelType).Where("deleted_at IS NULL").OrderAsc(entity.AsynchModelTypeCol.TypeID)
|
||||
if typeNameLike != "" {
|
||||
m = m.WhereLike(entity.AsynchModelTypeCol.TypeName, "%"+typeNameLike+"%")
|
||||
}
|
||||
if pageNum > 0 && pageSize > 0 {
|
||||
m = m.Page(pageNum, pageSize)
|
||||
}
|
||||
r, totalInt, err := m.AllAndCount(false)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
total = gconv.Int64(totalInt)
|
||||
err = r.Structs(&list)
|
||||
return
|
||||
}
|
||||
@@ -13,7 +13,7 @@ type opLogDao struct{}
|
||||
|
||||
var OpLog = &opLogDao{}
|
||||
|
||||
func (d *opLogDao) Insert(ctx context.Context, log *entity.AsynchOpLog) (id int64, err error) {
|
||||
func (d *opLogDao) Insert(ctx context.Context, log *entity.LogsModelOp) (id int64, err error) {
|
||||
r, err := gfdb.DB(ctx).Model(ctx, public.TableNameOpLog).Data(log).Insert()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
|
||||
@@ -29,7 +29,7 @@ DO UPDATE SET request_count = %s.request_count + 1, updated_at = NOW()`,
|
||||
return err
|
||||
}
|
||||
|
||||
func (d *statDao) List(ctx context.Context, pageNum, pageSize int, startDay, endDay string, tenantId *int64, creator, modelName string) (list []*entity.AsynchModelStat, total int64, err error) {
|
||||
func (d *statDao) List(ctx context.Context, pageNum, pageSize int, startDay, endDay string, tenantId *int64, creator, modelName string) (list []*entity.LogsModelStat, total int64, err error) {
|
||||
m := gfdb.DB(ctx).Model(ctx, public.TableNameStat).Where("1=1")
|
||||
if startDay != "" {
|
||||
m = m.Where("day >= ?", startDay)
|
||||
@@ -58,4 +58,3 @@ func (d *statDao) List(ctx context.Context, pageNum, pageSize int, startDay, end
|
||||
err = r.Structs(&list)
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
@@ -20,7 +20,7 @@ func (d *taskDao) ClaimPendingGlobal(ctx context.Context, batchSize int) (tasks
|
||||
}
|
||||
err = gfdb.DB(ctx).Transaction(ctx, func(ctx context.Context, tx gdb.TX) error {
|
||||
sql := fmt.Sprintf(
|
||||
`SELECT id, tenant_id, creator, model_name, task_id, model_key, input_ref, request_payload, phase, tmp_file
|
||||
`SELECT id, tenant_id, creator, model_name, task_id, biz_name, callback_url, model_key, input_ref, request_payload, phase, tmp_file
|
||||
FROM %s
|
||||
WHERE deleted_at IS NULL AND state = 0
|
||||
ORDER BY enqueue_at ASC
|
||||
@@ -55,13 +55,51 @@ func (d *taskDao) ClaimPendingGlobal(ctx context.Context, batchSize int) (tasks
|
||||
return
|
||||
}
|
||||
|
||||
func (d *taskDao) UpdateSuccessGlobal(ctx context.Context, id int64, ossFile, fileType string, fileSize int64, expireAt *gtime.Time) error {
|
||||
// ClaimPendingByTaskIDGlobal 按 task_id 定向抢占单个 pending 任务(不加 tenant 过滤)
|
||||
// 用于 createTask 创建成功后立即异步尝试执行当前任务,避免只依赖后续 runWork 扫描队列。
|
||||
func (d *taskDao) ClaimPendingByTaskIDGlobal(ctx context.Context, taskID string) (task *entity.AsynchTask, err error) {
|
||||
if taskID == "" {
|
||||
return nil, nil
|
||||
}
|
||||
err = gfdb.DB(ctx).Transaction(ctx, func(ctx context.Context, tx gdb.TX) error {
|
||||
sql := fmt.Sprintf(
|
||||
`SELECT id, tenant_id, creator, model_name, task_id, biz_name, callback_url, model_key, input_ref, request_payload, phase, tmp_file
|
||||
FROM %s
|
||||
WHERE deleted_at IS NULL AND state = 0 AND task_id = ?
|
||||
LIMIT 1
|
||||
FOR UPDATE SKIP LOCKED`,
|
||||
public.TableNameTask,
|
||||
)
|
||||
r, err := tx.GetOne(sql, taskID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if r.IsEmpty() {
|
||||
task = nil
|
||||
return nil
|
||||
}
|
||||
if err := r.Struct(&task); err != nil {
|
||||
return err
|
||||
}
|
||||
now := time.Now()
|
||||
_, err = tx.Exec(
|
||||
fmt.Sprintf(`UPDATE %s SET state=1, started_at=?, updated_at=? WHERE id=?`, public.TableNameTask),
|
||||
now, now, task.Id,
|
||||
)
|
||||
return err
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
func (d *taskDao) UpdateSuccessGlobal(ctx context.Context, id int64, ossFile, fileType, textResult string, fileSize int64, expireAt *gtime.Time, expendTokens int) error {
|
||||
now := gtime.Now()
|
||||
_, err := gfdb.DB(ctx).Exec(ctx,
|
||||
fmt.Sprintf(`UPDATE %s
|
||||
SET state=2,
|
||||
oss_file=?,
|
||||
file_type=?,
|
||||
text_result=?,
|
||||
expend_tokens=?,
|
||||
file_size=?,
|
||||
error_msg='',
|
||||
finished_at=?,
|
||||
@@ -71,7 +109,7 @@ SET state=2,
|
||||
tmp_file='',
|
||||
updated_at=?
|
||||
WHERE id=?`, public.TableNameTask),
|
||||
ossFile, fileType, fileSize, now, now, now, id,
|
||||
ossFile, fileType, textResult, expendTokens, fileSize, now, now, now, id,
|
||||
)
|
||||
return err
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user