2026-05-29 17:54:19 +08:00
|
|
|
|
package task
|
2026-04-29 15:54:14 +08:00
|
|
|
|
|
|
|
|
|
|
import (
|
|
|
|
|
|
"context"
|
|
|
|
|
|
"errors"
|
2026-06-02 20:26:45 +08:00
|
|
|
|
"fmt"
|
2026-05-21 10:41:37 +08:00
|
|
|
|
"model-gateway/common/util"
|
2026-06-08 18:01:53 +08:00
|
|
|
|
"model-gateway/consts/public"
|
2026-05-29 17:54:19 +08:00
|
|
|
|
"model-gateway/service/queue"
|
2026-04-29 15:54:14 +08:00
|
|
|
|
"time"
|
|
|
|
|
|
|
2026-05-15 14:56:26 +08:00
|
|
|
|
"model-gateway/dao"
|
|
|
|
|
|
"model-gateway/model/dto"
|
|
|
|
|
|
"model-gateway/model/entity"
|
2026-04-29 15:54:14 +08:00
|
|
|
|
|
2026-06-02 20:26:45 +08:00
|
|
|
|
"gitea.com/red-future/common/beans"
|
|
|
|
|
|
"gitea.com/red-future/common/utils"
|
2026-04-29 15:54:14 +08:00
|
|
|
|
"github.com/gogf/gf/v2/database/gdb"
|
|
|
|
|
|
"github.com/gogf/gf/v2/frame/g"
|
|
|
|
|
|
"github.com/gogf/gf/v2/os/gtime"
|
2026-06-02 20:26:45 +08:00
|
|
|
|
"github.com/gogf/gf/v2/util/gconv"
|
2026-04-29 15:54:14 +08:00
|
|
|
|
"github.com/google/uuid"
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
// taskService bundles the task-related service operations. It carries no
// fields, so a single shared instance is sufficient.
type taskService struct{}

// Task is the shared taskService instance exposed by this package.
var Task = &taskService{}
|
|
|
|
|
|
|
2026-05-29 17:54:19 +08:00
|
|
|
|
// Create 创建任务
|
2026-04-29 15:54:14 +08:00
|
|
|
|
func (s *taskService) Create(ctx context.Context, req *dto.CreateTaskReq) (res *dto.CreateTaskRes, err error) {
|
|
|
|
|
|
startAt := time.Now()
|
2026-05-29 17:54:19 +08:00
|
|
|
|
taskID := uuid.NewString()
|
2026-06-02 20:26:45 +08:00
|
|
|
|
// 1) 检查模型配置,并且获取模型
|
|
|
|
|
|
userInfo, err := utils.GetUserInfo(ctx)
|
|
|
|
|
|
if err != nil {
|
|
|
|
|
|
return nil, err
|
|
|
|
|
|
}
|
|
|
|
|
|
model, err := dao.Model.Get(ctx, &entity.AsynchModel{
|
|
|
|
|
|
SQLBaseDO: beans.SQLBaseDO{
|
|
|
|
|
|
TenantId: userInfo.TenantId,
|
|
|
|
|
|
Creator: userInfo.UserName,
|
|
|
|
|
|
},
|
2026-05-21 10:41:37 +08:00
|
|
|
|
ModelName: req.ModelName,
|
|
|
|
|
|
})
|
2026-04-29 15:54:14 +08:00
|
|
|
|
if err != nil {
|
|
|
|
|
|
return nil, err
|
|
|
|
|
|
}
|
2026-06-02 20:26:45 +08:00
|
|
|
|
if model == nil || (model.Enabled != nil && *model.Enabled != 1) {
|
2026-04-29 15:54:14 +08:00
|
|
|
|
return nil, errors.New("模型不存在或未启用")
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// 2) 排队上限(严格控制:Redis 原子闸门)
|
2026-06-02 20:26:45 +08:00
|
|
|
|
limit := queue.GetRuntimeQueueLimit(ctx, req.ModelName, model.MaxConcurrency*2)
|
2026-04-29 15:54:14 +08:00
|
|
|
|
if limit > 0 {
|
2026-06-02 20:26:45 +08:00
|
|
|
|
ok, err := queue.AcquireQueueSlot(ctx, req.ModelName, taskID, limit, model.TimeoutSeconds)
|
2026-04-29 15:54:14 +08:00
|
|
|
|
if err != nil {
|
|
|
|
|
|
return nil, err
|
|
|
|
|
|
}
|
|
|
|
|
|
if !ok {
|
|
|
|
|
|
return nil, errors.New("任务排队已满,请稍后再试")
|
|
|
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
|
2026-05-29 17:54:19 +08:00
|
|
|
|
// 3) 插入任务记录
|
2026-06-08 18:01:53 +08:00
|
|
|
|
if model.CallMode != nil && *model.CallMode == public.CallModeAsync {
|
2026-06-02 20:26:45 +08:00
|
|
|
|
// 异步调用:注入回调地址后提交,拿到 task_id 轮询
|
|
|
|
|
|
req.RequestPayload = util.InjectCallbackURL(ctx, req.RequestPayload, model.CallbackUrl)
|
|
|
|
|
|
}
|
2026-04-29 15:54:14 +08:00
|
|
|
|
storedPayload := map[string]any{
|
2026-06-02 20:26:45 +08:00
|
|
|
|
"headers": util.ParseHeadMsgHeaders(model.HeadMsg),
|
|
|
|
|
|
"body": req.RequestPayload,
|
2026-04-29 15:54:14 +08:00
|
|
|
|
}
|
2026-05-29 17:54:19 +08:00
|
|
|
|
_, err = dao.Task.Insert(ctx, &entity.AsynchTask{
|
2026-04-29 15:54:14 +08:00
|
|
|
|
ModelName: req.ModelName,
|
|
|
|
|
|
TaskID: taskID,
|
|
|
|
|
|
State: 0,
|
|
|
|
|
|
BizName: req.BizName,
|
|
|
|
|
|
CallbackURL: req.CallbackUrl,
|
2026-06-02 20:26:45 +08:00
|
|
|
|
ModelKey: model.ApiKey,
|
2026-04-29 15:54:14 +08:00
|
|
|
|
InputRef: req.InputRef,
|
|
|
|
|
|
RequestPayload: storedPayload,
|
2026-05-12 13:45:08 +08:00
|
|
|
|
EpicycleId: req.EpicycleId,
|
2026-05-29 17:54:19 +08:00
|
|
|
|
})
|
2026-06-02 20:26:45 +08:00
|
|
|
|
if err != nil { // 入库失败:回滚闸门占位
|
2026-05-29 17:54:19 +08:00
|
|
|
|
queue.ReleaseQueueSlot(ctx, req.ModelName, taskID)
|
2026-04-29 15:54:14 +08:00
|
|
|
|
return nil, err
|
|
|
|
|
|
}
|
|
|
|
|
|
|
2026-05-29 17:54:19 +08:00
|
|
|
|
// 4) 写操作日志(不影响主流程,失败忽略)
|
2026-04-29 15:54:14 +08:00
|
|
|
|
ip := ""
|
|
|
|
|
|
ua := ""
|
|
|
|
|
|
apiPath := "/task/createTask"
|
|
|
|
|
|
httpMethod := "POST"
|
|
|
|
|
|
if r := g.RequestFromCtx(ctx); r != nil {
|
2026-06-03 18:37:17 +08:00
|
|
|
|
ip = utils.GetLocalIP()
|
2026-04-29 15:54:14 +08:00
|
|
|
|
ua = r.UserAgent()
|
|
|
|
|
|
apiPath = r.URL.Path
|
|
|
|
|
|
httpMethod = r.Method
|
|
|
|
|
|
}
|
2026-05-12 13:45:08 +08:00
|
|
|
|
_, _ = dao.OpLog.Insert(ctx, &entity.LogsModelOp{
|
2026-04-29 15:54:14 +08:00
|
|
|
|
IP: ip,
|
|
|
|
|
|
UserAgent: ua,
|
|
|
|
|
|
APIPath: apiPath,
|
|
|
|
|
|
HttpMethod: httpMethod,
|
|
|
|
|
|
BizName: req.BizName,
|
|
|
|
|
|
ModelName: req.ModelName,
|
|
|
|
|
|
TaskID: taskID,
|
|
|
|
|
|
OpType: "createTask",
|
|
|
|
|
|
Success: 1,
|
|
|
|
|
|
ErrorMsg: "",
|
|
|
|
|
|
CostMs: time.Since(startAt).Milliseconds(),
|
|
|
|
|
|
RequestPayload: storedPayload,
|
|
|
|
|
|
ResponsePayload: gdb.Map{
|
|
|
|
|
|
"taskId": taskID,
|
|
|
|
|
|
},
|
|
|
|
|
|
})
|
2026-05-12 13:45:08 +08:00
|
|
|
|
|
2026-06-02 20:26:45 +08:00
|
|
|
|
// 5) 获取任务信息
|
|
|
|
|
|
task, err := dao.Task.ClaimPendingByTaskIDGlobal(ctx, taskID)
|
|
|
|
|
|
if err != nil {
|
|
|
|
|
|
return nil, err
|
|
|
|
|
|
}
|
|
|
|
|
|
if task == nil {
|
|
|
|
|
|
return nil, err
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// 5) 创建成功后立即异步尝试执行当前任务
|
|
|
|
|
|
go AsyncWorker.handleOne(util.AsyncCtx(ctx), task, model, req)
|
|
|
|
|
|
|
2026-04-29 15:54:14 +08:00
|
|
|
|
return &dto.CreateTaskRes{TaskID: taskID}, nil
|
|
|
|
|
|
}
|
|
|
|
|
|
|
2026-06-02 20:26:45 +08:00
|
|
|
|
func (s *taskService) ModelTaskCallback(ctx context.Context, req *dto.ModelTaskCallbackReq) (*dto.ModelTaskCallbackRes, error) {
|
|
|
|
|
|
g.Log().Infof(ctx, "[模型回调] 收到通知 taskID=%s status=%s", req.TaskID, req.Status)
|
|
|
|
|
|
// 1. 查本地任务
|
|
|
|
|
|
task, err := dao.Task.Get(ctx, &entity.AsynchTask{
|
|
|
|
|
|
TaskID: req.TaskID,
|
|
|
|
|
|
})
|
|
|
|
|
|
if err != nil || task == nil {
|
|
|
|
|
|
return nil, fmt.Errorf("任务不存在: %s", req.TaskID)
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// 2. 成功:取 video_url 和 usage
|
|
|
|
|
|
if req.Status == "succeeded" {
|
|
|
|
|
|
result := map[string]any{
|
|
|
|
|
|
"video_url": req.Content["video_url"],
|
|
|
|
|
|
"usage": req.Usage,
|
2026-05-12 13:45:08 +08:00
|
|
|
|
}
|
2026-06-02 20:26:45 +08:00
|
|
|
|
NotifyAsyncResult(req.TaskID, result, nil)
|
|
|
|
|
|
return &dto.ModelTaskCallbackRes{Success: true}, nil
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// 3. 失败/过期
|
|
|
|
|
|
if req.Status == "failed" || req.Status == "expired" {
|
|
|
|
|
|
NotifyAsyncResult(req.TaskID, nil, fmt.Errorf(req.Status))
|
|
|
|
|
|
return &dto.ModelTaskCallbackRes{Success: true}, nil
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
return &dto.ModelTaskCallbackRes{Success: true}, nil
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// QueryPendingTasks 批量轮询进行中的异步任务
|
|
|
|
|
|
func (s *taskService) QueryPendingTasks(ctx context.Context, req *dto.QueryPendingTasksReq) (*dto.QueryPendingTasksRes, error) {
|
|
|
|
|
|
limit := req.Limit
|
|
|
|
|
|
if limit <= 0 {
|
|
|
|
|
|
limit = g.Cfg().MustGet(ctx, "asynch.queryPending.limit", 10).Int()
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// 1. 查 state=1(执行中)的异步任务
|
|
|
|
|
|
tasks, err := dao.Task.GetPendingAsyncTasks(ctx, limit)
|
|
|
|
|
|
if err != nil {
|
|
|
|
|
|
return nil, err
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// 2. 逐个查询
|
|
|
|
|
|
var results []dto.QueryTaskItem
|
|
|
|
|
|
for _, t := range tasks {
|
|
|
|
|
|
// 拿到模型配置
|
|
|
|
|
|
model, err := dao.Model.GetByModelNameForTenant(ctx, t.TenantId, t.ModelName)
|
|
|
|
|
|
if err != nil || model == nil || model.QueryConfig == nil {
|
|
|
|
|
|
continue
|
|
|
|
|
|
}
|
2026-06-03 13:30:39 +08:00
|
|
|
|
result, err := util.PullTaskResult(ctx, nil, model.QueryConfig, model.HeadMsg)
|
2026-06-02 20:26:45 +08:00
|
|
|
|
if err != nil {
|
|
|
|
|
|
g.Log().Warningf(ctx, "[轮询] 查询失败 taskID=%s err=%v", t.TaskID, err)
|
|
|
|
|
|
continue
|
2026-05-12 13:45:08 +08:00
|
|
|
|
}
|
2026-05-29 17:54:19 +08:00
|
|
|
|
|
2026-06-02 20:26:45 +08:00
|
|
|
|
status := gconv.String(result["status"])
|
|
|
|
|
|
item := dto.QueryTaskItem{
|
|
|
|
|
|
TaskID: t.TaskID,
|
|
|
|
|
|
Status: status,
|
|
|
|
|
|
Content: result["content"].(map[string]any),
|
|
|
|
|
|
Usage: result["usage"].(map[string]any),
|
2026-05-12 13:45:08 +08:00
|
|
|
|
}
|
2026-06-02 20:26:45 +08:00
|
|
|
|
results = append(results, item)
|
|
|
|
|
|
|
|
|
|
|
|
// 如果任务完成,通知等待通道
|
|
|
|
|
|
if status == "succeeded" || status == "failed" || status == "expired" {
|
|
|
|
|
|
NotifyAsyncResult(t.TaskID, result["content"].(map[string]any), nil)
|
2026-05-12 13:45:08 +08:00
|
|
|
|
}
|
|
|
|
|
|
}
|
2026-06-02 20:26:45 +08:00
|
|
|
|
|
|
|
|
|
|
return &dto.QueryPendingTasksRes{
|
|
|
|
|
|
Total: len(results),
|
|
|
|
|
|
Results: results,
|
|
|
|
|
|
}, nil
|
2026-05-12 13:45:08 +08:00
|
|
|
|
}
|
|
|
|
|
|
|
2026-05-29 17:54:19 +08:00
|
|
|
|
// GetResult 获取任务结果
|
2026-04-29 15:54:14 +08:00
|
|
|
|
func (s *taskService) GetResult(ctx context.Context, taskID string) (res *dto.GetTaskResultRes, err error) {
|
2026-05-21 10:41:37 +08:00
|
|
|
|
t, err := dao.Task.Get(ctx, &entity.AsynchTask{
|
|
|
|
|
|
TaskID: taskID,
|
|
|
|
|
|
})
|
2026-04-29 15:54:14 +08:00
|
|
|
|
if err != nil {
|
|
|
|
|
|
return nil, err
|
|
|
|
|
|
}
|
|
|
|
|
|
if t == nil {
|
|
|
|
|
|
return nil, errors.New("任务不存在")
|
|
|
|
|
|
}
|
|
|
|
|
|
return &dto.GetTaskResultRes{
|
|
|
|
|
|
OssFile: t.OssFile,
|
|
|
|
|
|
State: t.State,
|
|
|
|
|
|
}, nil
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// GetBatch 批量查询任务;将成功(state=2)的任务更新为已下载(state=4),并写入过期时间
|
|
|
|
|
|
func (s *taskService) GetBatch(ctx context.Context, req *dto.GetTaskBatchReq) (res *dto.GetTaskBatchRes, err error) {
|
|
|
|
|
|
if req == nil || len(req.TaskIDs) == 0 {
|
|
|
|
|
|
return &dto.GetTaskBatchRes{List: []dto.GetTaskBatchItem{}}, nil
|
|
|
|
|
|
}
|
|
|
|
|
|
// 1) 先查当前租户下的任务列表
|
|
|
|
|
|
list, err := dao.Task.ListByTaskIDs(ctx, req.TaskIDs)
|
|
|
|
|
|
if err != nil {
|
|
|
|
|
|
return nil, err
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// 2) 对成功(state=2)的任务:标记为已下载(state=4)并写入 expire_at
|
|
|
|
|
|
now := time.Now()
|
|
|
|
|
|
for _, t := range list {
|
|
|
|
|
|
if t == nil {
|
|
|
|
|
|
continue
|
|
|
|
|
|
}
|
|
|
|
|
|
if t.State != 2 {
|
|
|
|
|
|
continue
|
|
|
|
|
|
}
|
|
|
|
|
|
// 按模型配置决定保留时间
|
2026-05-21 10:41:37 +08:00
|
|
|
|
m, err := dao.Model.Get(ctx, &entity.AsynchModel{
|
|
|
|
|
|
ModelName: t.ModelName,
|
|
|
|
|
|
})
|
2026-04-29 15:54:14 +08:00
|
|
|
|
if err != nil {
|
|
|
|
|
|
return nil, err
|
|
|
|
|
|
}
|
|
|
|
|
|
retainSeconds := 86400
|
|
|
|
|
|
if m != nil && m.AutoCleanSeconds > 0 {
|
|
|
|
|
|
retainSeconds = m.AutoCleanSeconds
|
|
|
|
|
|
}
|
|
|
|
|
|
expireAt := gtime.New(now.Add(time.Duration(retainSeconds) * time.Second))
|
|
|
|
|
|
_ = dao.Task.MarkDownloadedByID(ctx, t.Id, expireAt)
|
|
|
|
|
|
|
|
|
|
|
|
// 为了本次返回一致性,内存里也更新
|
|
|
|
|
|
t.State = 4
|
|
|
|
|
|
t.ExpireAt = expireAt
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// 3) 组装返回
|
|
|
|
|
|
items := make([]dto.GetTaskBatchItem, 0, len(list))
|
|
|
|
|
|
for _, t := range list {
|
|
|
|
|
|
if t == nil {
|
|
|
|
|
|
continue
|
|
|
|
|
|
}
|
|
|
|
|
|
items = append(items, dto.GetTaskBatchItem{
|
|
|
|
|
|
TaskID: t.TaskID,
|
|
|
|
|
|
State: t.State,
|
|
|
|
|
|
OssFile: t.OssFile,
|
|
|
|
|
|
})
|
|
|
|
|
|
}
|
|
|
|
|
|
return &dto.GetTaskBatchRes{List: items}, nil
|
|
|
|
|
|
}
|
|
|
|
|
|
|
2026-05-29 17:54:19 +08:00
|
|
|
|
// List 获取任务列表
|
2026-04-29 15:54:14 +08:00
|
|
|
|
func (s *taskService) List(ctx context.Context, req *dto.ListTaskReq) (res *dto.ListTaskRes, err error) {
|
|
|
|
|
|
pageNum, pageSize := 1, 10
|
2026-05-12 13:45:08 +08:00
|
|
|
|
if req != nil {
|
|
|
|
|
|
if req.PageNum > 0 {
|
|
|
|
|
|
pageNum = req.PageNum
|
2026-04-29 15:54:14 +08:00
|
|
|
|
}
|
2026-05-12 13:45:08 +08:00
|
|
|
|
if req.PageSize > 0 {
|
|
|
|
|
|
pageSize = req.PageSize
|
2026-04-29 15:54:14 +08:00
|
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
modelName := ""
|
|
|
|
|
|
taskID := ""
|
|
|
|
|
|
var state *int
|
|
|
|
|
|
if req != nil {
|
|
|
|
|
|
modelName = req.ModelName
|
|
|
|
|
|
taskID = req.TaskID
|
|
|
|
|
|
state = req.State
|
|
|
|
|
|
}
|
|
|
|
|
|
list, total, err := dao.Task.List(ctx, pageNum, pageSize, modelName, taskID, state)
|
|
|
|
|
|
if err != nil {
|
|
|
|
|
|
return nil, err
|
|
|
|
|
|
}
|
|
|
|
|
|
return &dto.ListTaskRes{List: list, Total: total}, nil
|
|
|
|
|
|
}
|