refactor(service): 重构模型网关服务结构

This commit is contained in:
2026-06-11 17:58:49 +08:00
parent afd60caf56
commit 1c6c9bae14
34 changed files with 784 additions and 1223 deletions

View File

@@ -17,25 +17,27 @@ import (
"gitea.redpowerfuture.com/red-future/common/utils"
"github.com/gogf/gf/v2/database/gdb"
"github.com/gogf/gf/v2/frame/g"
"github.com/gogf/gf/v2/os/gtime"
"github.com/gogf/gf/v2/util/gconv"
"github.com/google/uuid"
)
var Task = &taskService{}
var ModelGatewayTask = &taskService{}
type taskService struct{}
// Create 创建任务
func (s *taskService) Create(ctx context.Context, req *dto.CreateTaskReq) (res *dto.CreateTaskRes, err error) {
startAt := time.Now()
taskID := uuid.NewString()
var (
startAt = time.Now()
taskID = uuid.NewString()
)
// 1) 检查模型配置,并且获取模型
userInfo, err := utils.GetUserInfo(ctx)
if err != nil {
return nil, err
}
model, err := dao.Model.Get(ctx, &entity.AsynchModel{
model, err := dao.ModelGatewayModels.Get(ctx, &entity.ModelGatewayModel{
SQLBaseDO: beans.SQLBaseDO{
TenantId: userInfo.TenantId,
Creator: userInfo.UserName,
@@ -66,19 +68,17 @@ func (s *taskService) Create(ctx context.Context, req *dto.CreateTaskReq) (res *
// 异步调用:注入回调地址后提交,拿到 task_id 轮询
req.RequestPayload = util.InjectCallbackURL(ctx, req.RequestPayload, model.CallbackUrl)
}
storedPayload := map[string]any{
"headers": util.ParseHeadMsgHeaders(model.HeadMsg),
"body": req.RequestPayload,
requestPayload := entity.RequestPayload{
Body: req.RequestPayload,
Headers: util.ParseHeadMsgHeaders(model.HeadMsg),
}
_, err = dao.Task.Insert(ctx, &entity.AsynchTask{
id, err := dao.ModelGatewayTask.Insert(ctx, &entity.ModelGatewayTask{
ModelName: req.ModelName,
TaskID: taskID,
State: 0,
State: public.TaskStatusPending,
BizName: req.BizName,
CallbackURL: req.CallbackUrl,
ModelKey: model.ApiKey,
InputRef: req.InputRef,
RequestPayload: storedPayload,
RequestPayload: &requestPayload,
EpicycleId: req.EpicycleId,
})
if err != nil { // 入库失败:回滚闸门占位
@@ -97,7 +97,7 @@ func (s *taskService) Create(ctx context.Context, req *dto.CreateTaskReq) (res *
apiPath = r.URL.Path
httpMethod = r.Method
}
_, _ = dao.OpLog.Insert(ctx, &entity.LogsModelOp{
_, _ = dao.ModelGatewayLogsOp.Insert(ctx, &entity.ModelGatewayLogsOp{
IP: ip,
UserAgent: ua,
APIPath: apiPath,
@@ -109,20 +109,17 @@ func (s *taskService) Create(ctx context.Context, req *dto.CreateTaskReq) (res *
Success: 1,
ErrorMsg: "",
CostMs: time.Since(startAt).Milliseconds(),
RequestPayload: storedPayload,
RequestPayload: &requestPayload,
ResponsePayload: gdb.Map{
"taskId": taskID,
},
})
// 5) 获取任务信息
task, err := dao.Task.ClaimPendingByTaskIDGlobal(ctx, taskID)
task, err := dao.ModelGatewayTask.ClaimByID(ctx, id)
if err != nil {
return nil, err
}
if task == nil {
return nil, err
}
// 5) 创建成功后立即异步尝试执行当前任务
go AsyncWorker.handleOne(util.AsyncCtx(ctx), task, model, req)
@@ -130,10 +127,96 @@ func (s *taskService) Create(ctx context.Context, req *dto.CreateTaskReq) (res *
return &dto.CreateTaskRes{TaskID: taskID}, nil
}
// GetResult 获取任务结果
func (s *taskService) GetResult(ctx context.Context, taskID string) (res *dto.GetTaskResultRes, err error) {
t, err := dao.ModelGatewayTask.Get(ctx, &entity.ModelGatewayTask{
TaskID: taskID,
})
if err != nil {
return nil, err
}
if t == nil {
return nil, errors.New("任务不存在")
}
return &dto.GetTaskResultRes{
OssFile: t.ResultFile.OssFile,
State: t.State,
}, nil
}
// GetBatch 批量查询任务;将成功(state=2)的任务更新为已下载(state=4),并写入过期时间
func (s *taskService) GetBatch(ctx context.Context, req *dto.GetTaskBatchReq) (res *dto.GetTaskBatchRes, err error) {
if req == nil || len(req.TaskIDs) == 0 {
return &dto.GetTaskBatchRes{List: []dto.GetTaskBatchItem{}}, nil
}
// 1) 先查当前租户下的任务列表
list, err := dao.ModelGatewayTask.ListByTaskIDs(ctx, req.TaskIDs)
if err != nil {
return nil, err
}
// 2) 对成功(state=2)的任务:标记为已下载(state=4)
for _, t := range list {
if t == nil {
continue
}
if t.State != public.BuildTypeNode {
continue
}
_ = dao.ModelGatewayTask.MarkDownloadedByID(ctx, t.Id)
// 为了本次返回一致性,内存里也更新
t.State = public.TaskStatusDownloaded
}
// 3) 组装返回
items := make([]dto.GetTaskBatchItem, 0, len(list))
for _, t := range list {
if t == nil {
continue
}
items = append(items, dto.GetTaskBatchItem{
TaskID: t.TaskID,
State: t.State,
OssFile: t.ResultFile.OssFile,
TextResult: t.TextResult,
})
}
return &dto.GetTaskBatchRes{List: items}, nil
}
// List 获取任务列表
func (s *taskService) List(ctx context.Context, req *dto.ListTaskReq) (*dto.ListTaskRes, error) {
if req.PageNum <= 0 {
req.PageNum = 1
}
if req.PageSize <= 0 {
req.PageSize = 10
}
user, err := utils.GetUserInfo(ctx)
if err != nil {
return nil, err
}
list, total, err := dao.ModelGatewayTask.List(ctx, req.PageNum, req.PageSize, &entity.ModelGatewayTask{
SQLBaseDO: beans.SQLBaseDO{
Creator: user.UserName,
},
ModelName: req.ModelName,
BizName: req.BizName,
State: req.State,
TaskID: req.TaskID,
})
if err != nil {
return nil, err
}
return &dto.ListTaskRes{List: list, Total: total}, nil
}
// ModelTaskCallback 模型异步任务的回调通知
func (s *taskService) ModelTaskCallback(ctx context.Context, req *dto.ModelTaskCallbackReq) (*dto.ModelTaskCallbackRes, error) {
g.Log().Infof(ctx, "[模型回调] 收到通知 taskID=%s status=%s", req.TaskID, req.Status)
// 1. 查本地任务
task, err := dao.Task.Get(ctx, &entity.AsynchTask{
task, err := dao.ModelGatewayTask.Get(ctx, &entity.ModelGatewayTask{
TaskID: req.TaskID,
})
if err != nil || task == nil {
@@ -167,7 +250,7 @@ func (s *taskService) QueryPendingTasks(ctx context.Context, req *dto.QueryPendi
}
// 1. 查 state=1执行中的异步任务
tasks, err := dao.Task.GetPendingAsyncTasks(ctx, limit)
tasks, err := dao.ModelGatewayTask.GetPendingAsyncTasks(ctx, limit)
if err != nil {
return nil, err
}
@@ -176,7 +259,7 @@ func (s *taskService) QueryPendingTasks(ctx context.Context, req *dto.QueryPendi
var results []dto.QueryTaskItem
for _, t := range tasks {
// 拿到模型配置
model, err := dao.Model.GetByModelNameForTenant(ctx, t.TenantId, t.ModelName)
model, err := dao.ModelGatewayModels.GetByModelNameForTenant(ctx, t.TenantId, t.ModelName)
if err != nil || model == nil || model.QueryConfig == nil {
continue
}
@@ -206,100 +289,3 @@ func (s *taskService) QueryPendingTasks(ctx context.Context, req *dto.QueryPendi
Results: results,
}, nil
}
// GetResult 获取任务结果
func (s *taskService) GetResult(ctx context.Context, taskID string) (res *dto.GetTaskResultRes, err error) {
t, err := dao.Task.Get(ctx, &entity.AsynchTask{
TaskID: taskID,
})
if err != nil {
return nil, err
}
if t == nil {
return nil, errors.New("任务不存在")
}
return &dto.GetTaskResultRes{
OssFile: t.OssFile,
State: t.State,
}, nil
}
// GetBatch 批量查询任务;将成功(state=2)的任务更新为已下载(state=4),并写入过期时间
func (s *taskService) GetBatch(ctx context.Context, req *dto.GetTaskBatchReq) (res *dto.GetTaskBatchRes, err error) {
if req == nil || len(req.TaskIDs) == 0 {
return &dto.GetTaskBatchRes{List: []dto.GetTaskBatchItem{}}, nil
}
// 1) 先查当前租户下的任务列表
list, err := dao.Task.ListByTaskIDs(ctx, req.TaskIDs)
if err != nil {
return nil, err
}
// 2) 对成功(state=2)的任务:标记为已下载(state=4)并写入 expire_at
now := time.Now()
for _, t := range list {
if t == nil {
continue
}
if t.State != 2 {
continue
}
// 按模型配置决定保留时间
m, err := dao.Model.Get(ctx, &entity.AsynchModel{
ModelName: t.ModelName,
})
if err != nil {
return nil, err
}
retainSeconds := 86400
if m != nil && m.AutoCleanSeconds > 0 {
retainSeconds = m.AutoCleanSeconds
}
expireAt := gtime.New(now.Add(time.Duration(retainSeconds) * time.Second))
_ = dao.Task.MarkDownloadedByID(ctx, t.Id, expireAt)
// 为了本次返回一致性,内存里也更新
t.State = 4
t.ExpireAt = expireAt
}
// 3) 组装返回
items := make([]dto.GetTaskBatchItem, 0, len(list))
for _, t := range list {
if t == nil {
continue
}
items = append(items, dto.GetTaskBatchItem{
TaskID: t.TaskID,
State: t.State,
OssFile: t.OssFile,
})
}
return &dto.GetTaskBatchRes{List: items}, nil
}
// List 获取任务列表
func (s *taskService) List(ctx context.Context, req *dto.ListTaskReq) (res *dto.ListTaskRes, err error) {
pageNum, pageSize := 1, 10
if req != nil {
if req.PageNum > 0 {
pageNum = req.PageNum
}
if req.PageSize > 0 {
pageSize = req.PageSize
}
}
modelName := ""
taskID := ""
var state *int
if req != nil {
modelName = req.ModelName
taskID = req.TaskID
state = req.State
}
list, total, err := dao.Task.List(ctx, pageNum, pageSize, modelName, taskID, state)
if err != nil {
return nil, err
}
return &dto.ListTaskRes{List: list, Total: total}, nil
}