Files
model-gateway/service/task_service.go

271 lines
7.3 KiB
Go
Raw Normal View History

2026-04-29 15:54:14 +08:00
package service
import (
"context"
"errors"
"model-gateway/common/util"
2026-04-29 15:54:14 +08:00
"time"
"model-gateway/dao"
"model-gateway/model/dto"
"model-gateway/model/entity"
2026-04-29 15:54:14 +08:00
"github.com/gogf/gf/v2/database/gdb"
"github.com/gogf/gf/v2/frame/g"
"github.com/gogf/gf/v2/os/gtime"
"github.com/google/uuid"
)
// Task is the package-level task service instance.
var Task = new(taskService)

// taskService groups the task operations (Create, GetResult, GetBatch, List).
// It has no fields.
type taskService struct{}
func (s *taskService) Create(ctx context.Context, req *dto.CreateTaskReq) (res *dto.CreateTaskRes, err error) {
startAt := time.Now()
// 固化 token/user 等信息
ctx = util.AsyncCtx(ctx)
2026-04-29 15:54:14 +08:00
// 1) 检查模型配置
m, err := dao.Model.Get(ctx, &entity.AsynchModel{
ModelName: req.ModelName,
})
2026-04-29 15:54:14 +08:00
if err != nil {
return nil, err
}
if m == nil || (m.Enabled != nil && *m.Enabled != 1) {
2026-04-29 15:54:14 +08:00
return nil, errors.New("模型不存在或未启用")
}
taskID := uuid.NewString()
// 2) 排队上限严格控制Redis 原子闸门)
limit := GetRuntimeQueueLimit(ctx, req.ModelName, m.QueueLimit)
if limit > 0 {
ok, err := AcquireQueueSlot(ctx, req.ModelName, taskID, limit, m.ExpectedSeconds)
if err != nil {
return nil, err
}
if !ok {
return nil, errors.New("任务排队已满,请稍后再试")
}
}
// 将调用模型的 payload 与透传头信息一起存入 request_payload供后台 worker 使用
storedPayload := map[string]any{
"payload": req.RequestPayload,
"headers": util.ForwardHeaders(ctx),
2026-04-29 15:54:14 +08:00
}
t := &entity.AsynchTask{
ModelName: req.ModelName,
TaskID: taskID,
State: 0,
BizName: req.BizName,
CallbackURL: req.CallbackUrl,
2026-05-12 13:45:08 +08:00
ModelKey: m.ApiKey,
2026-04-29 15:54:14 +08:00
InputRef: req.InputRef,
RequestPayload: storedPayload,
2026-05-12 13:45:08 +08:00
EpicycleId: req.EpicycleId,
2026-04-29 15:54:14 +08:00
}
_, err = dao.Task.Insert(ctx, t)
if err != nil {
// 入库失败:回滚闸门占位
ReleaseQueueSlot(ctx, req.ModelName, taskID)
return nil, err
}
// 3) 写操作日志(尽量不影响主流程,失败忽略)
ip := ""
ua := ""
apiPath := "/task/createTask"
httpMethod := "POST"
if r := g.RequestFromCtx(ctx); r != nil {
ip = r.GetClientIp()
ua = r.UserAgent()
apiPath = r.URL.Path
httpMethod = r.Method
}
2026-05-12 13:45:08 +08:00
_, _ = dao.OpLog.Insert(ctx, &entity.LogsModelOp{
2026-04-29 15:54:14 +08:00
IP: ip,
UserAgent: ua,
APIPath: apiPath,
HttpMethod: httpMethod,
BizName: req.BizName,
ModelName: req.ModelName,
TaskID: taskID,
OpType: "createTask",
Success: 1,
ErrorMsg: "",
CostMs: time.Since(startAt).Milliseconds(),
RequestPayload: storedPayload,
ResponsePayload: gdb.Map{
"taskId": taskID,
},
})
2026-05-12 13:45:08 +08:00
// 4) 创建成功后立即异步尝试执行当前任务,并仅在任务仍处于 pending(state=0) 时做定向轮询。
// 一旦任务进入 running/success/failed/downloaded就停止轮询避免一直空转。
go s.pollAndRunUntilPicked(context.WithoutCancel(ctx), taskID, req.EpicycleId)
2026-04-29 15:54:14 +08:00
return &dto.CreateTaskRes{TaskID: taskID}, nil
}
2026-05-12 13:45:08 +08:00
// pollAndRunUntilPicked 用于 createTask 创建后的“轻量级定向轮询”:
// - 目标:尽快把刚创建的任务拉起来执行
// - 只在任务仍为 pending(state=0) 时继续尝试抢占
// - 一旦任务进入 running(1) / success(2) / failed(3) / downloaded(4),立即停止
// - 这样不会无限轮询runWork 仍负责处理积压队列和未处理到的任务
func (s *taskService) pollAndRunUntilPicked(ctx context.Context, taskID string, epicycleId int64) {
if taskID == "" {
return
}
interval := g.Cfg().MustGet(ctx, "asynch.worker.intervalSeconds").Int()
if interval <= 0 {
interval = 5
}
g.Log().Infof(ctx, "[task-auto-run][start] taskId=%s interval=%ds", taskID, interval)
ticker := time.NewTicker(time.Duration(interval) * time.Second)
defer ticker.Stop()
tryRun := func() bool {
t, err := dao.Task.Get(ctx, &entity.AsynchTask{
TaskID: taskID,
})
2026-05-12 13:45:08 +08:00
if err != nil {
g.Log().Warningf(ctx, "[task-auto-run][stop] taskId=%s reason=query_failed err=%v", taskID, err)
return true
}
if t == nil {
g.Log().Warningf(ctx, "[task-auto-run][stop] taskId=%s reason=task_not_found", taskID)
return true
}
switch t.State {
case 0:
if err = AsyncWorker.RunByTaskID(ctx, taskID, epicycleId); err != nil {
2026-05-12 13:45:08 +08:00
g.Log().Warningf(ctx, "[task-auto-run][retry] taskId=%s state=0 err=%v", taskID, err)
} else {
g.Log().Infof(ctx, "[task-auto-run][triggered] taskId=%s state=0", taskID)
}
return false
case 1:
g.Log().Infof(ctx, "[task-auto-run][stop] taskId=%s reason=running", taskID)
return true
case 2, 3, 4:
g.Log().Infof(ctx, "[task-auto-run][stop] taskId=%s reason=terminal state=%d", taskID, t.State)
return true
default:
g.Log().Infof(ctx, "[task-auto-run][stop] taskId=%s reason=unknown_state state=%d", taskID, t.State)
return true
}
}
// 先立即尝试一次
if stop := tryRun(); stop {
return
}
for {
select {
case <-ctx.Done():
g.Log().Infof(ctx, "[task-auto-run][stop] taskId=%s reason=context_done", taskID)
return
case <-ticker.C:
if stop := tryRun(); stop {
return
}
}
}
}
2026-04-29 15:54:14 +08:00
func (s *taskService) GetResult(ctx context.Context, taskID string) (res *dto.GetTaskResultRes, err error) {
t, err := dao.Task.Get(ctx, &entity.AsynchTask{
TaskID: taskID,
})
2026-04-29 15:54:14 +08:00
if err != nil {
return nil, err
}
if t == nil {
return nil, errors.New("任务不存在")
}
return &dto.GetTaskResultRes{
OssFile: t.OssFile,
State: t.State,
}, nil
}
// GetBatch 批量查询任务;将成功(state=2)的任务更新为已下载(state=4),并写入过期时间
func (s *taskService) GetBatch(ctx context.Context, req *dto.GetTaskBatchReq) (res *dto.GetTaskBatchRes, err error) {
if req == nil || len(req.TaskIDs) == 0 {
return &dto.GetTaskBatchRes{List: []dto.GetTaskBatchItem{}}, nil
}
// 1) 先查当前租户下的任务列表
list, err := dao.Task.ListByTaskIDs(ctx, req.TaskIDs)
if err != nil {
return nil, err
}
// 2) 对成功(state=2)的任务:标记为已下载(state=4)并写入 expire_at
now := time.Now()
for _, t := range list {
if t == nil {
continue
}
if t.State != 2 {
continue
}
// 按模型配置决定保留时间
m, err := dao.Model.Get(ctx, &entity.AsynchModel{
ModelName: t.ModelName,
})
2026-04-29 15:54:14 +08:00
if err != nil {
return nil, err
}
retainSeconds := 86400
if m != nil && m.AutoCleanSeconds > 0 {
retainSeconds = m.AutoCleanSeconds
}
expireAt := gtime.New(now.Add(time.Duration(retainSeconds) * time.Second))
_ = dao.Task.MarkDownloadedByID(ctx, t.Id, expireAt)
// 为了本次返回一致性,内存里也更新
t.State = 4
t.ExpireAt = expireAt
}
// 3) 组装返回
items := make([]dto.GetTaskBatchItem, 0, len(list))
for _, t := range list {
if t == nil {
continue
}
items = append(items, dto.GetTaskBatchItem{
TaskID: t.TaskID,
State: t.State,
OssFile: t.OssFile,
})
}
return &dto.GetTaskBatchRes{List: items}, nil
}
func (s *taskService) List(ctx context.Context, req *dto.ListTaskReq) (res *dto.ListTaskRes, err error) {
pageNum, pageSize := 1, 10
2026-05-12 13:45:08 +08:00
if req != nil {
if req.PageNum > 0 {
pageNum = req.PageNum
2026-04-29 15:54:14 +08:00
}
2026-05-12 13:45:08 +08:00
if req.PageSize > 0 {
pageSize = req.PageSize
2026-04-29 15:54:14 +08:00
}
}
modelName := ""
taskID := ""
var state *int
if req != nil {
modelName = req.ModelName
taskID = req.TaskID
state = req.State
}
list, total, err := dao.Task.List(ctx, pageNum, pageSize, modelName, taskID, state)
if err != nil {
return nil, err
}
return &dto.ListTaskRes{List: list, Total: total}, nil
}