2026-04-29 15:54:14 +08:00
|
|
|
|
package service
|
|
|
|
|
|
|
|
|
|
|
|
import (
|
|
|
|
|
|
"context"
|
2026-05-21 10:41:37 +08:00
|
|
|
|
"errors"
|
2026-04-29 15:54:14 +08:00
|
|
|
|
"fmt"
|
2026-05-21 10:41:37 +08:00
|
|
|
|
"model-gateway/common/util"
|
|
|
|
|
|
"model-gateway/model/dto"
|
|
|
|
|
|
"model-gateway/service/gateway"
|
|
|
|
|
|
"os"
|
|
|
|
|
|
"path/filepath"
|
2026-04-29 15:54:14 +08:00
|
|
|
|
"strings"
|
|
|
|
|
|
"time"
|
2026-05-12 13:45:08 +08:00
|
|
|
|
"unicode/utf8"
|
2026-04-29 15:54:14 +08:00
|
|
|
|
|
2026-05-15 14:56:26 +08:00
|
|
|
|
"model-gateway/dao"
|
|
|
|
|
|
"model-gateway/model/entity"
|
2026-04-29 15:54:14 +08:00
|
|
|
|
|
|
|
|
|
|
"github.com/gogf/gf/v2/frame/g"
|
|
|
|
|
|
"github.com/gogf/gf/v2/os/grpool"
|
2026-05-12 13:45:08 +08:00
|
|
|
|
"github.com/tidwall/gjson"
|
2026-04-29 15:54:14 +08:00
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
// asyncWorker groups the asynchronous task-processing methods. It holds no
// state of its own.
type asyncWorker struct{}

// AsyncWorker is the package-level worker instance shared by callers.
var AsyncWorker = new(asyncWorker)
|
|
|
|
|
|
|
|
|
|
|
|
// RunOnce 由上层定时任务触发:一次性抢占并处理一批任务
|
|
|
|
|
|
// - batchSize: 本次抢占数量
|
|
|
|
|
|
// - goroutines: 本次并发数(协程池大小)
|
2026-05-21 10:41:37 +08:00
|
|
|
|
func (w *asyncWorker) RunOnce(ctx context.Context, req *dto.RunWorkReq) (res *dto.RunWorkRes, err error) {
|
|
|
|
|
|
if req.BatchSize <= 0 {
|
|
|
|
|
|
req.BatchSize = 10
|
2026-04-29 15:54:14 +08:00
|
|
|
|
}
|
2026-05-21 10:41:37 +08:00
|
|
|
|
if req.Goroutines <= 0 {
|
|
|
|
|
|
req.Goroutines = 1
|
2026-04-29 15:54:14 +08:00
|
|
|
|
}
|
2026-05-21 10:41:37 +08:00
|
|
|
|
tasks, err := dao.Task.ClaimPendingGlobal(ctx, req.BatchSize)
|
2026-04-29 15:54:14 +08:00
|
|
|
|
if err != nil {
|
2026-05-21 10:41:37 +08:00
|
|
|
|
return nil, err
|
2026-04-29 15:54:14 +08:00
|
|
|
|
}
|
|
|
|
|
|
if len(tasks) == 0 {
|
2026-05-21 10:41:37 +08:00
|
|
|
|
return nil, errors.New("no task to run")
|
2026-04-29 15:54:14 +08:00
|
|
|
|
}
|
2026-05-21 10:41:37 +08:00
|
|
|
|
pool := grpool.New(req.Goroutines)
|
2026-04-29 15:54:14 +08:00
|
|
|
|
defer pool.Close()
|
2026-05-21 10:41:37 +08:00
|
|
|
|
claimed := len(tasks)
|
2026-04-29 15:54:14 +08:00
|
|
|
|
done := make(chan struct{}, claimed)
|
|
|
|
|
|
for _, t := range tasks {
|
|
|
|
|
|
task := t
|
|
|
|
|
|
_ = pool.AddWithRecover(ctx, func(ctx context.Context) {
|
2026-05-12 13:45:08 +08:00
|
|
|
|
w.handleOne(ctx, task, 0)
|
2026-04-29 15:54:14 +08:00
|
|
|
|
done <- struct{}{}
|
|
|
|
|
|
}, func(ctx context.Context, e error) {
|
|
|
|
|
|
if e != nil {
|
|
|
|
|
|
_ = dao.Task.UpdateFailedGlobal(ctx, task.Id, fmt.Sprintf("worker panic: %v", e))
|
|
|
|
|
|
ReleaseQueueSlot(ctx, task.ModelName, task.TaskID)
|
|
|
|
|
|
}
|
|
|
|
|
|
done <- struct{}{}
|
|
|
|
|
|
})
|
|
|
|
|
|
}
|
|
|
|
|
|
for i := 0; i < claimed; i++ {
|
|
|
|
|
|
<-done
|
|
|
|
|
|
}
|
2026-05-21 10:41:37 +08:00
|
|
|
|
return &dto.RunWorkRes{
|
|
|
|
|
|
Claimed: claimed,
|
|
|
|
|
|
}, nil
|
2026-04-29 15:54:14 +08:00
|
|
|
|
}
|
|
|
|
|
|
|
2026-05-12 13:45:08 +08:00
|
|
|
|
// RunByTaskID 创建任务后立即异步尝试执行当前任务:
|
|
|
|
|
|
// - 只定向抢占当前 taskId 对应的 pending 任务
|
|
|
|
|
|
// - 若任务已被其它 worker 抢走/已不在 pending,则直接返回
|
|
|
|
|
|
func (w *asyncWorker) RunByTaskID(ctx context.Context, taskID string, epicycleId int64) error {
|
|
|
|
|
|
task, err := dao.Task.ClaimPendingByTaskIDGlobal(ctx, taskID)
|
|
|
|
|
|
if err != nil {
|
|
|
|
|
|
return err
|
|
|
|
|
|
}
|
|
|
|
|
|
if task == nil {
|
|
|
|
|
|
return nil
|
|
|
|
|
|
}
|
|
|
|
|
|
w.handleOne(ctx, task, epicycleId)
|
|
|
|
|
|
return nil
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
func (w *asyncWorker) handleOne(ctx context.Context, t *entity.AsynchTask, epicycleId int64) {
|
|
|
|
|
|
// 从任务入库的 request_payload 里恢复 payload + headers
|
2026-05-21 10:41:37 +08:00
|
|
|
|
payload, headers := util.ParseStoredPayload(t.RequestPayload)
|
2026-04-29 15:54:14 +08:00
|
|
|
|
if len(headers) > 0 {
|
2026-05-21 10:41:37 +08:00
|
|
|
|
ctx = util.SetTaskHeadersToCtx(ctx, headers)
|
2026-04-29 15:54:14 +08:00
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// 1) 拉取模型配置
|
|
|
|
|
|
m, err := dao.Model.GetByModelNameForTenant(ctx, t.TenantId, t.ModelName)
|
|
|
|
|
|
if err != nil {
|
|
|
|
|
|
_ = dao.Task.UpdateFailedGlobal(ctx, t.Id, err.Error())
|
|
|
|
|
|
ReleaseQueueSlot(ctx, t.ModelName, t.TaskID)
|
2026-05-12 13:45:08 +08:00
|
|
|
|
// ============ 失败回调 ============
|
|
|
|
|
|
t.State = 3
|
|
|
|
|
|
t.ErrorMsg = err.Error()
|
2026-05-21 10:41:37 +08:00
|
|
|
|
go gateway.TriggerCallback(context.WithoutCancel(ctx), t)
|
2026-05-12 13:45:08 +08:00
|
|
|
|
// ================================
|
2026-04-29 15:54:14 +08:00
|
|
|
|
return
|
|
|
|
|
|
}
|
2026-05-15 14:56:26 +08:00
|
|
|
|
if m == nil || (m.Enabled != nil && *m.Enabled != 1) {
|
2026-05-12 13:45:08 +08:00
|
|
|
|
errMsg := "模型不存在或未启用"
|
|
|
|
|
|
_ = dao.Task.UpdateFailedGlobal(ctx, t.Id, errMsg)
|
2026-04-29 15:54:14 +08:00
|
|
|
|
ReleaseQueueSlot(ctx, t.ModelName, t.TaskID)
|
2026-05-12 13:45:08 +08:00
|
|
|
|
// ============ 失败回调 ============
|
|
|
|
|
|
t.State = 3
|
|
|
|
|
|
t.ErrorMsg = errMsg
|
2026-05-21 10:41:37 +08:00
|
|
|
|
go gateway.TriggerCallback(context.WithoutCancel(ctx), t)
|
2026-05-12 13:45:08 +08:00
|
|
|
|
// ================================
|
2026-04-29 15:54:14 +08:00
|
|
|
|
return
|
|
|
|
|
|
}
|
|
|
|
|
|
|
2026-05-12 13:45:08 +08:00
|
|
|
|
// 2) 分布式并发限制
|
2026-04-29 15:54:14 +08:00
|
|
|
|
semKey := fmt.Sprintf("asynch:sem:%s", t.ModelName)
|
2026-05-12 13:45:08 +08:00
|
|
|
|
leaseSeconds := int64(3600)
|
2026-04-29 15:54:14 +08:00
|
|
|
|
maxC := GetRuntimeMaxConcurrency(ctx, t.ModelName, m.MaxConcurrency)
|
|
|
|
|
|
acquired, err := acquireSemaphore(ctx, semKey, maxC, leaseSeconds)
|
|
|
|
|
|
if err != nil {
|
|
|
|
|
|
_ = dao.Task.UpdateFailedGlobal(ctx, t.Id, err.Error())
|
|
|
|
|
|
ReleaseQueueSlot(ctx, t.ModelName, t.TaskID)
|
2026-05-12 13:45:08 +08:00
|
|
|
|
// ============ 失败回调 ============
|
|
|
|
|
|
t.State = 3
|
|
|
|
|
|
t.ErrorMsg = err.Error()
|
2026-05-21 10:41:37 +08:00
|
|
|
|
go gateway.TriggerCallback(context.WithoutCancel(ctx), t)
|
2026-05-12 13:45:08 +08:00
|
|
|
|
// ================================
|
2026-04-29 15:54:14 +08:00
|
|
|
|
return
|
|
|
|
|
|
}
|
|
|
|
|
|
if !acquired {
|
2026-05-12 13:45:08 +08:00
|
|
|
|
// 并发满了:放回排队,不回调(不是失败)
|
2026-04-29 15:54:14 +08:00
|
|
|
|
_ = w.rollbackToPending(ctx, t.Id)
|
|
|
|
|
|
return
|
|
|
|
|
|
}
|
|
|
|
|
|
defer func() {
|
|
|
|
|
|
_ = releaseSemaphore(ctx, semKey)
|
|
|
|
|
|
}()
|
|
|
|
|
|
|
|
|
|
|
|
// 3) 调用模型服务
|
|
|
|
|
|
if payload == nil {
|
|
|
|
|
|
payload = map[string]any{
|
|
|
|
|
|
"taskId": t.TaskID,
|
|
|
|
|
|
"inputRef": t.InputRef,
|
|
|
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
var (
|
|
|
|
|
|
data []byte
|
|
|
|
|
|
contentType string
|
|
|
|
|
|
ext string
|
2026-05-12 13:45:08 +08:00
|
|
|
|
textResult string
|
2026-04-29 15:54:14 +08:00
|
|
|
|
)
|
|
|
|
|
|
|
2026-05-12 13:45:08 +08:00
|
|
|
|
// phase=1 表示模型已成功但 OSS 上传失败:优先从临时文件加载
|
2026-04-29 15:54:14 +08:00
|
|
|
|
if t.Phase == 1 && strings.TrimSpace(t.TmpFile) != "" {
|
2026-05-21 10:41:37 +08:00
|
|
|
|
data, err = os.ReadFile(t.TmpFile)
|
2026-04-29 15:54:14 +08:00
|
|
|
|
if err == nil && len(data) > 0 {
|
2026-05-21 10:41:37 +08:00
|
|
|
|
contentType, ext = util.DetectFileType(data)
|
2026-04-29 15:54:14 +08:00
|
|
|
|
} else {
|
|
|
|
|
|
data = nil
|
|
|
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
if data == nil {
|
2026-05-12 13:45:08 +08:00
|
|
|
|
// 统计
|
2026-04-29 15:54:14 +08:00
|
|
|
|
_ = dao.Stat.IncRequestCount(ctx, time.Now(), int64(t.TenantId), t.Creator, t.ModelName)
|
2026-05-12 13:45:08 +08:00
|
|
|
|
// 核心调用
|
2026-04-29 15:54:14 +08:00
|
|
|
|
data, err = InvokeModel(ctx, m, payload, t.ModelKey)
|
|
|
|
|
|
if err != nil {
|
|
|
|
|
|
_ = dao.Task.UpdateFailedGlobal(ctx, t.Id, err.Error())
|
|
|
|
|
|
ReleaseQueueSlot(ctx, t.ModelName, t.TaskID)
|
2026-05-12 13:45:08 +08:00
|
|
|
|
// ============ 失败回调 ============
|
|
|
|
|
|
t.State = 3
|
|
|
|
|
|
t.ErrorMsg = err.Error()
|
2026-05-21 10:41:37 +08:00
|
|
|
|
go gateway.TriggerCallback(context.WithoutCancel(ctx), t)
|
2026-05-12 13:45:08 +08:00
|
|
|
|
// ================================
|
2026-04-29 15:54:14 +08:00
|
|
|
|
return
|
|
|
|
|
|
}
|
2026-05-21 10:41:37 +08:00
|
|
|
|
contentType, ext = util.DetectFileType(data)
|
2026-05-12 13:45:08 +08:00
|
|
|
|
if utf8.Valid(data) && (strings.HasPrefix(contentType, "text/") || contentType == "application/json") {
|
|
|
|
|
|
textResult = string(data)
|
|
|
|
|
|
}
|
2026-04-29 15:54:14 +08:00
|
|
|
|
tmpPath, err := saveTmpResult(t.TaskID, data, ext)
|
|
|
|
|
|
if err == nil && tmpPath != "" {
|
|
|
|
|
|
t.TmpFile = tmpPath
|
|
|
|
|
|
t.Phase = 1
|
|
|
|
|
|
_ = dao.Task.UpdateTmpAfterModelGlobal(ctx, t.Id, tmpPath)
|
|
|
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// 4) 存储 OSS
|
2026-05-21 10:41:37 +08:00
|
|
|
|
ossURL, err := gateway.UploadByTask(ctx, t, data, ext, contentType)
|
2026-04-29 15:54:14 +08:00
|
|
|
|
if err != nil {
|
|
|
|
|
|
// OSS 阶段失败:保留临时文件,下一轮仅重试 OSS
|
|
|
|
|
|
_ = dao.Task.UpdateFailedKeepTmpGlobal(ctx, t.Id, err.Error())
|
|
|
|
|
|
ReleaseQueueSlot(ctx, t.ModelName, t.TaskID)
|
2026-05-12 13:45:08 +08:00
|
|
|
|
// ============ OSS失败不回调(还会重试) ============
|
|
|
|
|
|
// 注意:OSS失败保留临时文件,下次重试,所以这里不触发最终回调
|
|
|
|
|
|
// 如果已经重试多次还没成功,需要在任务超时或超过最大重试次数时才回调失败
|
2026-04-29 15:54:14 +08:00
|
|
|
|
return
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// 5) 更新任务状态成功
|
|
|
|
|
|
fileType := strings.TrimPrefix(ext, ".")
|
|
|
|
|
|
if fileType == "" {
|
|
|
|
|
|
fileType = contentType
|
|
|
|
|
|
}
|
2026-05-21 10:41:37 +08:00
|
|
|
|
if err = dao.Task.UpdateSuccessGlobal(
|
2026-05-12 13:45:08 +08:00
|
|
|
|
ctx,
|
|
|
|
|
|
t.Id,
|
|
|
|
|
|
ossURL,
|
|
|
|
|
|
fileType,
|
|
|
|
|
|
textResult,
|
|
|
|
|
|
int64(len(data)),
|
|
|
|
|
|
nil,
|
2026-05-21 10:41:37 +08:00
|
|
|
|
GetExpendTokens(m.ResponseTokenField, textResult),
|
2026-05-12 13:45:08 +08:00
|
|
|
|
); err != nil {
|
2026-04-29 15:54:14 +08:00
|
|
|
|
g.Log().Errorf(ctx, "[worker] update success failed: %v", err)
|
|
|
|
|
|
return
|
|
|
|
|
|
}
|
2026-05-12 13:45:08 +08:00
|
|
|
|
|
|
|
|
|
|
// 成功/失败均不再占用 queue_limit
|
2026-04-29 15:54:14 +08:00
|
|
|
|
ReleaseQueueSlot(ctx, t.ModelName, t.TaskID)
|
2026-05-12 13:45:08 +08:00
|
|
|
|
|
|
|
|
|
|
// 6) 成功回调
|
2026-04-29 15:54:14 +08:00
|
|
|
|
t.State = 2
|
|
|
|
|
|
t.OssFile = ossURL
|
|
|
|
|
|
t.FileType = fileType
|
2026-05-12 13:45:08 +08:00
|
|
|
|
t.TextResult = textResult
|
|
|
|
|
|
g.Log().Infof(ctx, "[CALLBACK][DISPATCH] taskId=%s bizName=%s callbackUrl=%s", t.TaskID, t.BizName, t.CallbackURL)
|
2026-05-21 10:41:37 +08:00
|
|
|
|
go gateway.TriggerCallback(context.WithoutCancel(ctx), t)
|
2026-05-12 13:45:08 +08:00
|
|
|
|
// ============ 如果有 epicycleId,也触发业务回调 ============
|
|
|
|
|
|
if epicycleId != 0 {
|
2026-05-21 10:41:37 +08:00
|
|
|
|
go gateway.TriggerPromptsCallback(context.WithoutCancel(ctx), t, epicycleId)
|
2026-05-12 13:45:08 +08:00
|
|
|
|
}
|
|
|
|
|
|
|
2026-04-29 15:54:14 +08:00
|
|
|
|
// 成功后清理临时文件
|
2026-05-21 10:41:37 +08:00
|
|
|
|
_ = os.Remove(t.TmpFile)
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// saveTmpResult writes the model output to a file under
// <os temp dir>/model-asynch so that a failed OSS upload can be retried
// without invoking the model again.
//
// The file is named taskID followed by ext. An empty ext becomes ".bin", and a
// leading dot is added when ext lacks one. It returns the full path of the
// written file, or an empty path and the error from creating the directory or
// writing the file.
func saveTmpResult(taskID string, data []byte, ext string) (string, error) {
	tmpDir := filepath.Join(os.TempDir(), "model-asynch")
	if mkErr := os.MkdirAll(tmpDir, 0o755); mkErr != nil {
		return "", mkErr
	}

	suffix := ext
	switch {
	case suffix == "":
		suffix = ".bin"
	case suffix[0] != '.':
		suffix = "." + suffix
	}

	fileName := fmt.Sprintf("%s%s", taskID, suffix)
	target := filepath.Join(tmpDir, fileName)
	if wrErr := os.WriteFile(target, data, 0o644); wrErr != nil {
		return "", wrErr
	}
	return target, nil
}
|
|
|
|
|
|
|
|
|
|
|
|
func (w *asyncWorker) rollbackToPending(ctx context.Context, id int64) error {
|
|
|
|
|
|
return dao.Task.RollbackToPendingGlobal(ctx, id)
|
|
|
|
|
|
}
|
2026-05-12 13:45:08 +08:00
|
|
|
|
|
|
|
|
|
|
// GetExpendTokens 根据映射路径从 textResult 中提取消耗 token 值
|
|
|
|
|
|
func GetExpendTokens(tokenMapping string, textResult string) int {
|
|
|
|
|
|
value := gjson.Get(textResult, tokenMapping)
|
|
|
|
|
|
if value.Exists() {
|
|
|
|
|
|
return int(value.Int())
|
|
|
|
|
|
}
|
2026-05-21 10:41:37 +08:00
|
|
|
|
return len(textResult)
|
2026-05-12 13:45:08 +08:00
|
|
|
|
}
|