Files
model-asynch/service/worker.go
2026-04-23 13:53:09 +08:00

178 lines
4.5 KiB
Go
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package service
import (
"context"
"fmt"
"strings"
"sync"
"time"
"model-asynch/dao"
"model-asynch/model/entity"
"github.com/gogf/gf/v2/frame/g"
"github.com/gogf/gf/v2/os/grpool"
)
// AsyncWorker is the package-level asyncWorker instance.
var AsyncWorker = &asyncWorker{}
// asyncWorker holds the state of the background task worker: the goroutine
// pool that runs tasks and the mutex that guards access to it.
type asyncWorker struct {
// mu is held by Start and Stop for their whole body, and by pollLoop while
// it reads pool.
mu sync.Mutex
// pool is the grpool goroutine pool created by Start and closed by Stop.
pool *grpool.Pool
// closed is set to false by Start and to true by Stop.
closed bool
}
// Start creates the goroutine pool and launches the polling loop.
//
// It does nothing when the config key "asynch.worker.enabled" is false
// (default true) or when a pool already exists and is not closed. The pool
// size comes from "asynch.worker.goroutines" (default 4); non-positive values
// are clamped to 1. ctx is handed to the polling loop, which stops the worker
// when ctx is cancelled.
func (w *asyncWorker) Start(ctx context.Context) {
	if !g.Cfg().MustGet(ctx, "asynch.worker.enabled", true).Bool() {
		// The separator keeps the config value and the sentence from running
		// together as "falseworker" in the log output.
		g.Log().Warningf(ctx, "[worker] asynch.worker.enabled=false, worker 未启动")
		return
	}

	w.mu.Lock()
	defer w.mu.Unlock()

	// Already running: keep the existing pool and polling loop.
	if w.pool != nil && !w.pool.IsClosed() {
		return
	}

	limit := g.Cfg().MustGet(ctx, "asynch.worker.goroutines", 4).Int()
	if limit <= 0 {
		limit = 1
	}

	w.pool = grpool.New(limit)
	w.closed = false

	// The loop is started while w.mu is still held; it can only observe
	// w.pool after this function returns and releases the lock.
	go w.pollLoop(ctx)
	g.Log().Infof(ctx, "[worker] started, grpool limit=%d", limit)
}
// Stop closes the goroutine pool so that the process can exit completely on
// Ctrl+C. It is a no-op when no pool exists or the pool is already closed.
func (w *asyncWorker) Stop(ctx context.Context) {
	w.mu.Lock()
	defer w.mu.Unlock()

	if w.pool == nil || w.pool.IsClosed() {
		return
	}

	w.pool.Close()
	w.closed = true
	g.Log().Infof(ctx, "[worker] stopped")
}
// pollLoop periodically claims pending tasks and submits them to the
// goroutine pool until ctx is cancelled or the pool is closed.
//
// The tick interval is read from "asynch.worker.pollInterval" (default "1s";
// unparsable or non-positive values fall back to one second) and the number
// of tasks claimed per tick from "asynch.worker.batchSize" (default 5;
// non-positive values are clamped to 1).
//
// When ctx is cancelled the loop calls Stop and returns. When the pool this
// loop was started with has been closed (for example by an external Stop),
// the loop returns instead of continuing to claim tasks it can no longer run.
func (w *asyncWorker) pollLoop(ctx context.Context) {
	// Bind this loop to the pool that is current when it starts. Start spawns
	// the loop while holding w.mu, so this blocks until the pool is published.
	// Using a fixed pool means a later Stop/Start cycle ends this loop rather
	// than leaving two loops polling side by side.
	w.mu.Lock()
	pool := w.pool
	w.mu.Unlock()
	if pool == nil {
		g.Log().Warningf(ctx, "[worker] poll loop started without a pool, exiting")
		return
	}

	pollIntervalStr := g.Cfg().MustGet(ctx, "asynch.worker.pollInterval", "1s").String()
	pollInterval, err := time.ParseDuration(pollIntervalStr)
	if err != nil {
		g.Log().Warningf(ctx, "[worker] invalid asynch.worker.pollInterval %q, falling back to 1s: %v", pollIntervalStr, err)
	}
	if pollInterval <= 0 {
		pollInterval = time.Second
	}

	batchSize := g.Cfg().MustGet(ctx, "asynch.worker.batchSize", 5).Int()
	if batchSize <= 0 {
		batchSize = 1
	}
	g.Log().Infof(ctx, "[worker] poll loop started, poll=%s batch=%d", pollInterval, batchSize)

	// requeue puts a claimed task back to pending and logs if that fails, so a
	// task that could not be dispatched is never silently left claimed.
	requeue := func(id int64, reason string) {
		if rbErr := w.rollbackToPending(ctx, id); rbErr != nil {
			g.Log().Errorf(ctx, "[worker] rollback task %d to pending failed (%s): %v", id, reason, rbErr)
		}
	}

	ticker := time.NewTicker(pollInterval)
	defer ticker.Stop()

	for {
		select {
		case <-ctx.Done():
			w.Stop(ctx)
			return
		case <-ticker.C:
		}

		// Do not claim work that can no longer be executed.
		if pool.IsClosed() {
			g.Log().Infof(ctx, "[worker] pool closed, poll loop exiting")
			return
		}

		tasks, err := dao.Task.ClaimPendingGlobal(ctx, batchSize)
		if err != nil {
			g.Log().Errorf(ctx, "[worker] claim pending error: %v", err)
			continue
		}

		for _, t := range tasks {
			task := t // per-iteration copy so the closures below do not share the loop variable

			// The pool may be closed while a batch is being dispatched.
			if pool.IsClosed() {
				requeue(task.Id, "pool closed")
				continue
			}

			addErr := pool.AddWithRecover(ctx, func(ctx context.Context) {
				w.handleOne(ctx, task)
			}, func(ctx context.Context, err error) {
				if err == nil {
					return
				}
				// A panic inside handleOne is recorded as a task failure.
				if upErr := dao.Task.UpdateFailedGlobal(ctx, task.Id, fmt.Sprintf("worker panic: %v", err)); upErr != nil {
					g.Log().Errorf(ctx, "[worker] mark task %d failed after panic error: %v", task.Id, upErr)
				}
			})
			if addErr != nil {
				// The job was never enqueued; release the claim so the task
				// can be picked up again.
				g.Log().Errorf(ctx, "[worker] enqueue task %d error: %v", task.Id, addErr)
				requeue(task.Id, "enqueue failed")
			}
		}
	}
}
// handleOne processes a single claimed task:
//
//  1. restores the payload and headers stored in t.RequestPayload,
//  2. loads the model configuration for the task's tenant and model name,
//  3. acquires a per-tenant+model semaphore slot (rolling the task back to
//     pending when none is available),
//  4. invokes the model, uploads the result through Storage.UploadByTask,
//  5. marks the task successful.
//
// Any error in steps 2-4 marks the task as failed with the error text.
// Failures to persist a status change are logged rather than discarded.
func (w *asyncWorker) handleOne(ctx context.Context, t *entity.AsynchTask) {
	// Restore payload + headers from the stored request payload; the headers
	// are attached to ctx so the upload step can pass authentication through.
	payload, headers := parseStoredPayload(t.RequestPayload)
	if len(headers) > 0 {
		ctx = setTaskHeadersToCtx(ctx, headers)
	}

	// markFailed records a failure on the task and logs when the status write
	// itself fails, so the task is not silently left in its claimed state.
	markFailed := func(reason string) {
		if upErr := dao.Task.UpdateFailedGlobal(ctx, t.Id, reason); upErr != nil {
			g.Log().Errorf(ctx, "[worker] mark task %v failed error: %v (reason: %s)", t.Id, upErr, reason)
		}
	}
	// requeue puts the task back to pending and logs when that fails.
	requeue := func() {
		if rbErr := w.rollbackToPending(ctx, t.Id); rbErr != nil {
			g.Log().Errorf(ctx, "[worker] rollback task %v to pending error: %v", t.Id, rbErr)
		}
	}

	// 1) Load the model configuration.
	m, err := dao.Model.GetByModelNameForTenant(ctx, t.TenantId, t.ModelName)
	if err != nil {
		markFailed(err.Error())
		return
	}
	if m == nil || m.Enabled != 1 {
		markFailed("模型不存在或未启用")
		return
	}

	// 2) Distributed concurrency limit, keyed by tenant + model.
	semKey := fmt.Sprintf("asynch:sem:%d:%s", t.TenantId, t.ModelName)
	leaseSeconds := int64(3600) // fallback lease of one hour
	acquired, err := acquireSemaphore(ctx, semKey, m.MaxConcurrency, leaseSeconds)
	if err != nil {
		markFailed(err.Error())
		return
	}
	if !acquired {
		// Concurrency is saturated: put the task back in the queue so a later
		// poll can claim it again.
		requeue()
		return
	}
	defer func() {
		// Best effort: the result is intentionally ignored here.
		// NOTE(review): presumably the lease passed to acquireSemaphore bounds
		// how long a slot can stay held if this release fails — confirm.
		_ = releaseSemaphore(ctx, semKey)
	}()

	// 3) Invoke the model service. Without a stored payload, fall back to a
	// minimal one built from the task itself.
	if payload == nil {
		payload = map[string]any{
			"taskId":   t.TaskID,
			"inputRef": t.InputRef,
		}
	}
	data, err := InvokeModel(ctx, m, payload)
	if err != nil {
		markFailed(err.Error())
		return
	}

	// 4) Store the result in OSS/MinIO.
	contentType, ext := DetectFileType(data)
	ossURL, err := Storage.UploadByTask(ctx, t, data, ext, contentType)
	if err != nil {
		markFailed(err.Error())
		return
	}

	// 5) Mark the task successful.
	// Note: expire_at is counted from the moment the result is downloaded
	// (state=4), so success (state=2) does not write expire_at (nil below).
	fileType := strings.TrimPrefix(ext, ".")
	if fileType == "" {
		fileType = contentType
	}
	if err := dao.Task.UpdateSuccessGlobal(ctx, t.Id, ossURL, fileType, int64(len(data)), nil); err != nil {
		g.Log().Errorf(ctx, "[worker] update success failed: %v", err)
	}
}
// rollbackToPending delegates to dao.Task.RollbackToPendingGlobal for the task
// with the given id and returns its error unchanged.
func (w *asyncWorker) rollbackToPending(ctx context.Context, id int64) error {
	err := dao.Task.RollbackToPendingGlobal(ctx, id)
	return err
}