package task

import (
	"context"
	"errors"
	"fmt"
	"time"

	"model-gateway/common/util"
	"model-gateway/dao"
	"model-gateway/model/dto"
	"model-gateway/model/entity"
	"model-gateway/service/queue"

	"gitea.com/red-future/common/beans"
	"gitea.com/red-future/common/utils"
	"github.com/gogf/gf/v2/database/gdb"
	"github.com/gogf/gf/v2/frame/g"
	"github.com/gogf/gf/v2/os/gtime"
	"github.com/gogf/gf/v2/util/gconv"
	"github.com/google/uuid"
)

// Task is the package-level task service instance.
var Task = &taskService{}

// taskService groups the task-related operations. It carries no state.
type taskService struct{}

// Create registers a new task for req.ModelName and kicks off its execution
// in a background goroutine.
//
// It looks up the model for the current user, rejects the request when the
// model is missing or disabled, optionally reserves a queue slot, stores the
// task row, writes an operation log entry (best effort) and finally tries to
// claim the freshly inserted task so it can be handed to AsyncWorker.
//
// It returns the generated task ID on success. An error is returned when the
// user info or model cannot be loaded, the model is unavailable, the queue is
// full, the insert fails, or claiming the task fails.
func (s *taskService) Create(ctx context.Context, req *dto.CreateTaskReq) (res *dto.CreateTaskRes, err error) {
	startAt := time.Now()
	taskID := uuid.NewString()

	// 1) Load the model configuration for the current user.
	userInfo, err := utils.GetUserInfo(ctx)
	if err != nil {
		return nil, err
	}
	model, err := dao.Model.Get(ctx, &entity.AsynchModel{
		SQLBaseDO: beans.SQLBaseDO{
			TenantId: userInfo.TenantId,
			Creator:  userInfo.UserName,
		},
		ModelName: req.ModelName,
	})
	if err != nil {
		return nil, err
	}
	// A nil Enabled pointer is accepted; only an explicit value other than 1 rejects.
	if model == nil || (model.Enabled != nil && *model.Enabled != 1) {
		return nil, errors.New("模型不存在或未启用")
	}

	// 2) Queue cap: reserve a slot through the queue package. A non-positive
	// limit skips the reservation entirely.
	limit := queue.GetRuntimeQueueLimit(ctx, req.ModelName, model.MaxConcurrency*2)
	if limit > 0 {
		ok, err := queue.AcquireQueueSlot(ctx, req.ModelName, taskID, limit, model.TimeoutSeconds)
		if err != nil {
			return nil, err
		}
		if !ok {
			return nil, errors.New("任务排队已满,请稍后再试")
		}
	}

	// 3) Insert the task record.
	if model.IsAsync != nil && *model.IsAsync == 1 {
		// Async model: inject the callback URL into the payload before it is stored.
		req.RequestPayload = util.InjectCallbackURL(ctx, req.RequestPayload, model.CallbackUrl)
	}
	storedPayload := map[string]any{
		"headers": util.ParseHeadMsgHeaders(model.HeadMsg),
		"body":    req.RequestPayload,
	}
	_, err = dao.Task.Insert(ctx, &entity.AsynchTask{
		ModelName:      req.ModelName,
		TaskID:         taskID,
		State:          0,
		BizName:        req.BizName,
		CallbackURL:    req.CallbackUrl,
		ModelKey:       model.ApiKey,
		InputRef:       req.InputRef,
		RequestPayload: storedPayload,
		EpicycleId:     req.EpicycleId,
	})
	if err != nil {
		// Insert failed: give the queue slot back.
		queue.ReleaseQueueSlot(ctx, req.ModelName, taskID)
		return nil, err
	}

	// 4) Write the operation log. This must not affect the main flow, so the
	// insert result is deliberately ignored.
	ip := ""
	ua := ""
	apiPath := "/task/createTask"
	httpMethod := "POST"
	if r := g.RequestFromCtx(ctx); r != nil {
		// NOTE(review): the logged IP comes from util.GetLocalIP(), not from
		// the request itself — confirm this is intended.
		ip = util.GetLocalIP()
		ua = r.UserAgent()
		apiPath = r.URL.Path
		httpMethod = r.Method
	}
	_, _ = dao.OpLog.Insert(ctx, &entity.LogsModelOp{
		IP:             ip,
		UserAgent:      ua,
		APIPath:        apiPath,
		HttpMethod:     httpMethod,
		BizName:        req.BizName,
		ModelName:      req.ModelName,
		TaskID:         taskID,
		OpType:         "createTask",
		Success:        1,
		ErrorMsg:       "",
		CostMs:         time.Since(startAt).Milliseconds(),
		RequestPayload: storedPayload,
		ResponsePayload: gdb.Map{
			"taskId": taskID,
		},
	})

	// 5) Claim the task that was just inserted.
	task, err := dao.Task.ClaimPendingByTaskIDGlobal(ctx, taskID)
	if err != nil {
		return nil, err
	}
	if task == nil {
		// The row was inserted above but nothing was claimed here (presumably
		// it was already picked up elsewhere — TODO confirm). The task still
		// exists, so report its ID instead of returning (nil, nil).
		return &dto.CreateTaskRes{TaskID: taskID}, nil
	}

	// 6) Try to execute the task right away in the background.
	go AsyncWorker.handleOne(util.AsyncCtx(ctx), task, model, req)

	return &dto.CreateTaskRes{TaskID: taskID}, nil
}

// ModelTaskCallback handles a status notification for a task.
//
// It verifies that the task exists locally, then forwards the outcome through
// NotifyAsyncResult: "succeeded" forwards the video URL and usage, while
// "failed" and "expired" forward an error whose message is the status. Any
// other status is acknowledged without notifying.
//
// It returns an error when the task lookup fails or the task does not exist.
func (s *taskService) ModelTaskCallback(ctx context.Context, req *dto.ModelTaskCallbackReq) (*dto.ModelTaskCallbackRes, error) {
	g.Log().Infof(ctx, "[模型回调] 收到通知 taskID=%s status=%s", req.TaskID, req.Status)

	// 1. Look up the local task. A lookup failure is reported as-is rather
	// than being disguised as "task not found".
	task, err := dao.Task.Get(ctx, &entity.AsynchTask{
		TaskID: req.TaskID,
	})
	if err != nil {
		return nil, err
	}
	if task == nil {
		return nil, fmt.Errorf("任务不存在: %s", req.TaskID)
	}

	switch req.Status {
	case "succeeded":
		// 2. Success: forward video_url and usage.
		result := map[string]any{
			"video_url": req.Content["video_url"],
			"usage":     req.Usage,
		}
		NotifyAsyncResult(req.TaskID, result, nil)
	case "failed", "expired":
		// 3. Failure / expiry: the status text becomes the error message.
		// errors.New is used so the status is never interpreted as a format string.
		NotifyAsyncResult(req.TaskID, nil, errors.New(req.Status))
	}

	return &dto.ModelTaskCallbackRes{Success: true}, nil
}

// QueryPendingTasks polls the tasks returned by dao.Task.GetPendingAsyncTasks
// and reports their current status.
//
// req.Limit caps the number of tasks; when it is not positive, the
// "asynch.queryPending.limit" config value (default 10) is used. Tasks whose
// model cannot be loaded, whose model has no QueryConfig, or whose query fails
// are skipped. Tasks that reached a terminal status ("succeeded", "failed",
// "expired") are forwarded through NotifyAsyncResult.
//
// It returns an error only when the task list itself cannot be loaded.
func (s *taskService) QueryPendingTasks(ctx context.Context, req *dto.QueryPendingTasksReq) (*dto.QueryPendingTasksRes, error) {
	limit := req.Limit
	if limit <= 0 {
		limit = g.Cfg().MustGet(ctx, "asynch.queryPending.limit", 10).Int()
	}

	// 1. Load the in-progress async tasks.
	tasks, err := dao.Task.GetPendingAsyncTasks(ctx, limit)
	if err != nil {
		return nil, err
	}

	// 2. Query them one by one. The slice is pre-sized, so an empty result is
	// an empty list rather than nil.
	results := make([]dto.QueryTaskItem, 0, len(tasks))
	for _, t := range tasks {
		// Load the model configuration for this task.
		model, err := dao.Model.GetByModelNameForTenant(ctx, t.TenantId, t.ModelName)
		if err != nil {
			g.Log().Warningf(ctx, "[轮询] 获取模型配置失败 taskID=%s err=%v", t.TaskID, err)
			continue
		}
		if model == nil || model.QueryConfig == nil {
			continue
		}

		result, err := util.PullTaskResult(ctx, t.TaskID, model.QueryConfig)
		if err != nil {
			g.Log().Warningf(ctx, "[轮询] 查询失败 taskID=%s err=%v", t.TaskID, err)
			continue
		}

		status := gconv.String(result["status"])
		// Checked assertions: a missing or differently typed field yields a
		// nil map instead of a panic.
		content, _ := result["content"].(map[string]any)
		usage, _ := result["usage"].(map[string]any)

		results = append(results, dto.QueryTaskItem{
			TaskID:  t.TaskID,
			Status:  status,
			Content: content,
			Usage:   usage,
		})

		// Forward terminal states. Failures carry an error, the same way
		// ModelTaskCallback reports them.
		switch status {
		case "succeeded":
			NotifyAsyncResult(t.TaskID, content, nil)
		case "failed", "expired":
			NotifyAsyncResult(t.TaskID, content, errors.New(status))
		}
	}

	return &dto.QueryPendingTasksRes{
		Total:   len(results),
		Results: results,
	}, nil
}

// GetResult returns the stored file reference and state of the task
// identified by taskID.
//
// It returns an error when the lookup fails or no such task exists.
func (s *taskService) GetResult(ctx context.Context, taskID string) (res *dto.GetTaskResultRes, err error) {
	found, err := dao.Task.Get(ctx, &entity.AsynchTask{TaskID: taskID})
	switch {
	case err != nil:
		return nil, err
	case found == nil:
		return nil, errors.New("任务不存在")
	}

	return &dto.GetTaskResultRes{
		OssFile: found.OssFile,
		State:   found.State,
	}, nil
}

// GetBatch looks up the tasks in req.TaskIDs. Every task in state 2 (success)
// is marked as downloaded (state 4) with an expiry time derived from its
// model's AutoCleanSeconds, defaulting to 86400 seconds.
//
// A nil request or an empty ID list yields an empty result. When marking a
// task fails, the failure is logged and that task is returned with its
// original state so the response matches what is stored.
//
// It returns an error when the task list or a model configuration cannot be
// loaded.
func (s *taskService) GetBatch(ctx context.Context, req *dto.GetTaskBatchReq) (res *dto.GetTaskBatchRes, err error) {
	if req == nil || len(req.TaskIDs) == 0 {
		return &dto.GetTaskBatchRes{List: []dto.GetTaskBatchItem{}}, nil
	}

	// 1) Load the requested tasks.
	list, err := dao.Task.ListByTaskIDs(ctx, req.TaskIDs)
	if err != nil {
		return nil, err
	}

	// 2) Successful tasks (state=2): mark as downloaded (state=4) and set expire_at.
	now := time.Now()
	// Retention per model name, so each model is looked up at most once per call.
	retainByModel := make(map[string]int)
	for _, t := range list {
		if t == nil || t.State != 2 {
			continue
		}

		// The retention time is decided by the model configuration.
		retainSeconds, cached := retainByModel[t.ModelName]
		if !cached {
			m, err := dao.Model.Get(ctx, &entity.AsynchModel{
				ModelName: t.ModelName,
			})
			if err != nil {
				return nil, err
			}
			retainSeconds = 86400
			if m != nil && m.AutoCleanSeconds > 0 {
				retainSeconds = m.AutoCleanSeconds
			}
			retainByModel[t.ModelName] = retainSeconds
		}

		expireAt := gtime.New(now.Add(time.Duration(retainSeconds) * time.Second))
		if err := dao.Task.MarkDownloadedByID(ctx, t.Id, expireAt); err != nil {
			// Best effort: keep serving the batch, but do not pretend the
			// stored state changed.
			g.Log().Warningf(ctx, "[批量查询] 标记已下载失败 taskID=%s err=%v", t.TaskID, err)
			continue
		}

		// Mirror the update in memory so this response is consistent.
		t.State = 4
		t.ExpireAt = expireAt
	}

	// 3) Build the response.
	items := make([]dto.GetTaskBatchItem, 0, len(list))
	for _, t := range list {
		if t == nil {
			continue
		}
		items = append(items, dto.GetTaskBatchItem{
			TaskID:  t.TaskID,
			State:   t.State,
			OssFile: t.OssFile,
		})
	}

	return &dto.GetTaskBatchRes{List: items}, nil
}

// List returns one page of tasks together with the total count.
//
// Paging defaults to page 1 with 10 items per page; positive values in req
// override the defaults. ModelName, TaskID and State from req are passed to
// dao.Task.List as filters. A nil req uses the defaults with no filters.
func (s *taskService) List(ctx context.Context, req *dto.ListTaskReq) (res *dto.ListTaskRes, err error) {
	var (
		pageNum   = 1
		pageSize  = 10
		modelName string
		taskID    string
		state     *int
	)
	if req != nil {
		if req.PageNum > 0 {
			pageNum = req.PageNum
		}
		if req.PageSize > 0 {
			pageSize = req.PageSize
		}
		modelName, taskID, state = req.ModelName, req.TaskID, req.State
	}

	list, total, err := dao.Task.List(ctx, pageNum, pageSize, modelName, taskID, state)
	if err != nil {
		return nil, err
	}

	return &dto.ListTaskRes{List: list, Total: total}, nil
}