2026-04-23 13:53:09 +08:00
|
|
|
|
package service
|
|
|
|
|
|
|
|
|
|
|
|
import (
|
|
|
|
|
|
"context"
|
|
|
|
|
|
"fmt"
|
|
|
|
|
|
"strings"
|
|
|
|
|
|
"sync"
|
|
|
|
|
|
"time"
|
|
|
|
|
|
|
|
|
|
|
|
"model-asynch/dao"
|
|
|
|
|
|
"model-asynch/model/entity"
|
|
|
|
|
|
|
|
|
|
|
|
"github.com/gogf/gf/v2/frame/g"
|
|
|
|
|
|
"github.com/gogf/gf/v2/os/grpool"
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
// AsyncWorker is the package-level worker instance. Its zero value is usable:
// Start creates the goroutine pool on first use.
var AsyncWorker = new(asyncWorker)
|
|
|
|
|
|
|
|
|
|
|
|
type asyncWorker struct {
|
|
|
|
|
|
mu sync.Mutex
|
|
|
|
|
|
pool *grpool.Pool
|
|
|
|
|
|
closed bool
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
func (w *asyncWorker) Start(ctx context.Context) {
|
|
|
|
|
|
if !g.Cfg().MustGet(ctx, "asynch.worker.enabled", true).Bool() {
|
|
|
|
|
|
g.Log().Warningf(ctx, "[worker] asynch.worker.enabled=false,worker 未启动")
|
|
|
|
|
|
return
|
|
|
|
|
|
}
|
|
|
|
|
|
w.mu.Lock()
|
|
|
|
|
|
defer w.mu.Unlock()
|
|
|
|
|
|
if w.pool != nil && !w.pool.IsClosed() {
|
|
|
|
|
|
return
|
|
|
|
|
|
}
|
|
|
|
|
|
limit := g.Cfg().MustGet(ctx, "asynch.worker.goroutines", 4).Int()
|
|
|
|
|
|
if limit <= 0 {
|
|
|
|
|
|
limit = 1
|
|
|
|
|
|
}
|
|
|
|
|
|
w.pool = grpool.New(limit)
|
|
|
|
|
|
w.closed = false
|
|
|
|
|
|
go w.pollLoop(ctx)
|
|
|
|
|
|
g.Log().Infof(ctx, "[worker] started, grpool limit=%d", limit)
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// Stop 关闭协程池,确保 Ctrl+C 能完整退出。
|
|
|
|
|
|
func (w *asyncWorker) Stop(ctx context.Context) {
|
|
|
|
|
|
w.mu.Lock()
|
|
|
|
|
|
defer w.mu.Unlock()
|
|
|
|
|
|
if w.pool != nil && !w.pool.IsClosed() {
|
|
|
|
|
|
w.pool.Close()
|
|
|
|
|
|
w.closed = true
|
|
|
|
|
|
g.Log().Infof(ctx, "[worker] stopped")
|
|
|
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
func (w *asyncWorker) pollLoop(ctx context.Context) {
|
|
|
|
|
|
pollIntervalStr := g.Cfg().MustGet(ctx, "asynch.worker.pollInterval", "1s").String()
|
|
|
|
|
|
pollInterval, _ := time.ParseDuration(pollIntervalStr)
|
|
|
|
|
|
if pollInterval <= 0 {
|
|
|
|
|
|
pollInterval = time.Second
|
|
|
|
|
|
}
|
|
|
|
|
|
batchSize := g.Cfg().MustGet(ctx, "asynch.worker.batchSize", 5).Int()
|
|
|
|
|
|
if batchSize <= 0 {
|
|
|
|
|
|
batchSize = 1
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
g.Log().Infof(ctx, "[worker] poll loop started, poll=%s batch=%d", pollInterval, batchSize)
|
|
|
|
|
|
ticker := time.NewTicker(pollInterval)
|
|
|
|
|
|
defer ticker.Stop()
|
|
|
|
|
|
|
|
|
|
|
|
for {
|
|
|
|
|
|
select {
|
|
|
|
|
|
case <-ctx.Done():
|
|
|
|
|
|
w.Stop(ctx)
|
|
|
|
|
|
return
|
|
|
|
|
|
case <-ticker.C:
|
|
|
|
|
|
tasks, err := dao.Task.ClaimPendingGlobal(ctx, batchSize)
|
|
|
|
|
|
if err != nil {
|
|
|
|
|
|
g.Log().Errorf(ctx, "[worker] claim pending error: %v", err)
|
|
|
|
|
|
continue
|
|
|
|
|
|
}
|
|
|
|
|
|
if len(tasks) == 0 {
|
|
|
|
|
|
continue
|
|
|
|
|
|
}
|
|
|
|
|
|
for _, t := range tasks {
|
|
|
|
|
|
task := t // 防止闭包捕获循环变量
|
|
|
|
|
|
w.mu.Lock()
|
|
|
|
|
|
p := w.pool
|
|
|
|
|
|
w.mu.Unlock()
|
|
|
|
|
|
if p == nil || p.IsClosed() {
|
|
|
|
|
|
// 池已关闭,回滚任务
|
|
|
|
|
|
_ = w.rollbackToPending(ctx, task.Id)
|
|
|
|
|
|
continue
|
|
|
|
|
|
}
|
|
|
|
|
|
_ = p.AddWithRecover(ctx, func(ctx context.Context) {
|
|
|
|
|
|
w.handleOne(ctx, task)
|
|
|
|
|
|
}, func(ctx context.Context, err error) {
|
|
|
|
|
|
if err != nil {
|
|
|
|
|
|
_ = dao.Task.UpdateFailedGlobal(ctx, task.Id, fmt.Sprintf("worker panic: %v", err))
|
|
|
|
|
|
}
|
|
|
|
|
|
})
|
|
|
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
func (w *asyncWorker) handleOne(ctx context.Context, t *entity.AsynchTask) {
|
|
|
|
|
|
// 从任务入库的 request_payload 里恢复 payload + headers,给 OSS 上传透传鉴权用
|
|
|
|
|
|
payload, headers := parseStoredPayload(t.RequestPayload)
|
|
|
|
|
|
if len(headers) > 0 {
|
|
|
|
|
|
ctx = setTaskHeadersToCtx(ctx, headers)
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// 1) 拉取模型配置
|
|
|
|
|
|
m, err := dao.Model.GetByModelNameForTenant(ctx, t.TenantId, t.ModelName)
|
|
|
|
|
|
if err != nil {
|
|
|
|
|
|
_ = dao.Task.UpdateFailedGlobal(ctx, t.Id, err.Error())
|
|
|
|
|
|
return
|
|
|
|
|
|
}
|
|
|
|
|
|
if m == nil || m.Enabled != 1 {
|
|
|
|
|
|
_ = dao.Task.UpdateFailedGlobal(ctx, t.Id, "模型不存在或未启用")
|
|
|
|
|
|
return
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// 2) 分布式并发限制(按 tenant+model)
|
|
|
|
|
|
semKey := fmt.Sprintf("asynch:sem:%d:%s", t.TenantId, t.ModelName)
|
|
|
|
|
|
leaseSeconds := int64(3600) // 兜底1小时
|
|
|
|
|
|
acquired, err := acquireSemaphore(ctx, semKey, m.MaxConcurrency, leaseSeconds)
|
|
|
|
|
|
if err != nil {
|
|
|
|
|
|
_ = dao.Task.UpdateFailedGlobal(ctx, t.Id, err.Error())
|
|
|
|
|
|
return
|
|
|
|
|
|
}
|
|
|
|
|
|
if !acquired {
|
|
|
|
|
|
// 并发满了:放回排队(重新置回 state=0),下一轮再抢占
|
|
|
|
|
|
_ = w.rollbackToPending(ctx, t.Id)
|
|
|
|
|
|
return
|
|
|
|
|
|
}
|
|
|
|
|
|
defer func() {
|
|
|
|
|
|
_ = releaseSemaphore(ctx, semKey)
|
|
|
|
|
|
}()
|
|
|
|
|
|
|
|
|
|
|
|
// 3) 调用模型服务
|
|
|
|
|
|
if payload == nil {
|
|
|
|
|
|
payload = map[string]any{
|
|
|
|
|
|
"taskId": t.TaskID,
|
|
|
|
|
|
"inputRef": t.InputRef,
|
|
|
|
|
|
}
|
|
|
|
|
|
}
|
2026-04-25 10:42:21 +08:00
|
|
|
|
var (
|
|
|
|
|
|
data []byte
|
|
|
|
|
|
contentType string
|
|
|
|
|
|
ext string
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
// phase=1 表示模型已成功但 OSS 上传失败:优先从临时文件加载,避免重复跑模型
|
|
|
|
|
|
if t.Phase == 1 && strings.TrimSpace(t.TmpFile) != "" {
|
|
|
|
|
|
data, err = loadTmpResult(t.TmpFile)
|
|
|
|
|
|
if err == nil && len(data) > 0 {
|
|
|
|
|
|
contentType, ext = DetectFileType(data)
|
|
|
|
|
|
} else {
|
|
|
|
|
|
// 临时文件不可用:回退重新调用模型
|
|
|
|
|
|
data = nil
|
|
|
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
if data == nil {
|
2026-04-27 10:42:42 +08:00
|
|
|
|
// 统计:仅在真正请求模型时 +1(OSS 重试不计入)
|
|
|
|
|
|
_ = dao.Stat.IncRequestCount(ctx, time.Now(), int64(t.TenantId), t.Creator, t.ModelName)
|
|
|
|
|
|
|
2026-04-25 10:42:21 +08:00
|
|
|
|
data, err = InvokeModel(ctx, m, payload)
|
|
|
|
|
|
if err != nil {
|
|
|
|
|
|
_ = dao.Task.UpdateFailedGlobal(ctx, t.Id, err.Error())
|
|
|
|
|
|
return
|
|
|
|
|
|
}
|
|
|
|
|
|
contentType, ext = DetectFileType(data)
|
|
|
|
|
|
// 将模型输出写入临时文件,后续若 OSS 失败可只重试 OSS
|
|
|
|
|
|
tmpPath, err := saveTmpResult(t.TaskID, data, ext)
|
|
|
|
|
|
if err == nil && tmpPath != "" {
|
|
|
|
|
|
t.TmpFile = tmpPath
|
|
|
|
|
|
t.Phase = 1
|
|
|
|
|
|
_ = dao.Task.UpdateTmpAfterModelGlobal(ctx, t.Id, tmpPath)
|
|
|
|
|
|
}
|
2026-04-23 13:53:09 +08:00
|
|
|
|
}
|
|
|
|
|
|
|
2026-04-25 10:42:21 +08:00
|
|
|
|
// 4) 存储 OSS
|
2026-04-23 13:53:09 +08:00
|
|
|
|
ossURL, err := Storage.UploadByTask(ctx, t, data, ext, contentType)
|
|
|
|
|
|
if err != nil {
|
2026-04-25 10:42:21 +08:00
|
|
|
|
// OSS 阶段失败:保留临时文件,下一轮仅重试 OSS
|
|
|
|
|
|
_ = dao.Task.UpdateFailedKeepTmpGlobal(ctx, t.Id, err.Error())
|
2026-04-23 13:53:09 +08:00
|
|
|
|
return
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// 5) 更新任务状态成功
|
|
|
|
|
|
// 注意:expire_at 的计算改为“已下载(state=4)后开始计时”,因此成功(state=2)不写 expire_at。
|
|
|
|
|
|
fileType := strings.TrimPrefix(ext, ".")
|
|
|
|
|
|
if fileType == "" {
|
|
|
|
|
|
fileType = contentType
|
|
|
|
|
|
}
|
|
|
|
|
|
if err := dao.Task.UpdateSuccessGlobal(ctx, t.Id, ossURL, fileType, int64(len(data)), nil); err != nil {
|
|
|
|
|
|
g.Log().Errorf(ctx, "[worker] update success failed: %v", err)
|
|
|
|
|
|
return
|
|
|
|
|
|
}
|
2026-04-25 10:42:21 +08:00
|
|
|
|
// 成功后清理临时文件
|
|
|
|
|
|
deleteTmpResult(t.TmpFile)
|
2026-04-23 13:53:09 +08:00
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
func (w *asyncWorker) rollbackToPending(ctx context.Context, id int64) error {
|
|
|
|
|
|
return dao.Task.RollbackToPendingGlobal(ctx, id)
|
|
|
|
|
|
}
|