Files
model-gateway/service/worker.go

247 lines
6.9 KiB
Go
Raw Normal View History

2026-04-29 15:54:14 +08:00
package service
import (
"context"
"fmt"
"strings"
"time"
2026-05-12 13:45:08 +08:00
"unicode/utf8"
2026-04-29 15:54:14 +08:00
"model-gateway/dao"
"model-gateway/model/entity"
2026-04-29 15:54:14 +08:00
"github.com/gogf/gf/v2/frame/g"
"github.com/gogf/gf/v2/os/grpool"
2026-05-12 13:45:08 +08:00
"github.com/tidwall/gjson"
2026-04-29 15:54:14 +08:00
)
// asyncWorker has no fields; it only groups the methods that claim and
// process asynchronous tasks.
type asyncWorker struct{}

// AsyncWorker is the shared, package-level worker instance.
var AsyncWorker = new(asyncWorker)
// RunOnce 由上层定时任务触发:一次性抢占并处理一批任务
// - batchSize: 本次抢占数量
// - goroutines: 本次并发数(协程池大小)
func (w *asyncWorker) RunOnce(ctx context.Context, batchSize, goroutines int) (claimed int, err error) {
if batchSize <= 0 {
batchSize = 10
}
if goroutines <= 0 {
goroutines = 1
}
tasks, err := dao.Task.ClaimPendingGlobal(ctx, batchSize)
if err != nil {
return 0, err
}
if len(tasks) == 0 {
return 0, nil
}
pool := grpool.New(goroutines)
defer pool.Close()
claimed = len(tasks)
done := make(chan struct{}, claimed)
for _, t := range tasks {
task := t
_ = pool.AddWithRecover(ctx, func(ctx context.Context) {
2026-05-12 13:45:08 +08:00
w.handleOne(ctx, task, 0)
2026-04-29 15:54:14 +08:00
done <- struct{}{}
}, func(ctx context.Context, e error) {
if e != nil {
_ = dao.Task.UpdateFailedGlobal(ctx, task.Id, fmt.Sprintf("worker panic: %v", e))
ReleaseQueueSlot(ctx, task.ModelName, task.TaskID)
}
done <- struct{}{}
})
}
for i := 0; i < claimed; i++ {
<-done
}
return claimed, nil
}
2026-05-12 13:45:08 +08:00
// RunByTaskID tries to execute a single task right after it has been created:
//   - only the pending task identified by taskID is claimed;
//   - if another worker already took it, or it is no longer pending, the
//     method returns nil without doing anything.
//
// epicycleId is forwarded to handleOne; a non-zero value additionally triggers
// the prompts callback when the task succeeds.
//
// This is meant as an asynchronous attempt, so callers presumably run it in its
// own goroutine, where an unrecovered panic would take down the whole process.
// Mirroring the panic handling RunOnce installs for pooled jobs, a panic while
// handling the task marks it failed, releases its queue slot, and is returned
// as an error.
func (w *asyncWorker) RunByTaskID(ctx context.Context, taskID string, epicycleId int64) (err error) {
	task, err := dao.Task.ClaimPendingByTaskIDGlobal(ctx, taskID)
	if err != nil {
		return err
	}
	if task == nil {
		return nil
	}
	defer func() {
		if r := recover(); r != nil {
			msg := fmt.Sprintf("worker panic: %v", r)
			g.Log().Errorf(ctx, "[worker] task %s: %s", task.TaskID, msg)
			_ = dao.Task.UpdateFailedGlobal(ctx, task.Id, msg)
			ReleaseQueueSlot(ctx, task.ModelName, task.TaskID)
			err = fmt.Errorf("run task %s: %s", taskID, msg)
		}
	}()
	w.handleOne(ctx, task, epicycleId)
	return nil
}
func (w *asyncWorker) handleOne(ctx context.Context, t *entity.AsynchTask, epicycleId int64) {
// 从任务入库的 request_payload 里恢复 payload + headers
2026-04-29 15:54:14 +08:00
payload, headers := parseStoredPayload(t.RequestPayload)
if len(headers) > 0 {
ctx = setTaskHeadersToCtx(ctx, headers)
}
// 1) 拉取模型配置
m, err := dao.Model.GetByModelNameForTenant(ctx, t.TenantId, t.ModelName)
if err != nil {
_ = dao.Task.UpdateFailedGlobal(ctx, t.Id, err.Error())
ReleaseQueueSlot(ctx, t.ModelName, t.TaskID)
2026-05-12 13:45:08 +08:00
// ============ 失败回调 ============
t.State = 3
t.ErrorMsg = err.Error()
go triggerCallback(context.WithoutCancel(ctx), t)
// ================================
2026-04-29 15:54:14 +08:00
return
}
if m == nil || (m.Enabled != nil && *m.Enabled != 1) {
2026-05-12 13:45:08 +08:00
errMsg := "模型不存在或未启用"
_ = dao.Task.UpdateFailedGlobal(ctx, t.Id, errMsg)
2026-04-29 15:54:14 +08:00
ReleaseQueueSlot(ctx, t.ModelName, t.TaskID)
2026-05-12 13:45:08 +08:00
// ============ 失败回调 ============
t.State = 3
t.ErrorMsg = errMsg
go triggerCallback(context.WithoutCancel(ctx), t)
// ================================
2026-04-29 15:54:14 +08:00
return
}
2026-05-12 13:45:08 +08:00
// 2) 分布式并发限制
2026-04-29 15:54:14 +08:00
semKey := fmt.Sprintf("asynch:sem:%s", t.ModelName)
2026-05-12 13:45:08 +08:00
leaseSeconds := int64(3600)
2026-04-29 15:54:14 +08:00
maxC := GetRuntimeMaxConcurrency(ctx, t.ModelName, m.MaxConcurrency)
acquired, err := acquireSemaphore(ctx, semKey, maxC, leaseSeconds)
if err != nil {
_ = dao.Task.UpdateFailedGlobal(ctx, t.Id, err.Error())
ReleaseQueueSlot(ctx, t.ModelName, t.TaskID)
2026-05-12 13:45:08 +08:00
// ============ 失败回调 ============
t.State = 3
t.ErrorMsg = err.Error()
go triggerCallback(context.WithoutCancel(ctx), t)
// ================================
2026-04-29 15:54:14 +08:00
return
}
if !acquired {
2026-05-12 13:45:08 +08:00
// 并发满了:放回排队,不回调(不是失败)
2026-04-29 15:54:14 +08:00
_ = w.rollbackToPending(ctx, t.Id)
return
}
defer func() {
_ = releaseSemaphore(ctx, semKey)
}()
// 3) 调用模型服务
if payload == nil {
payload = map[string]any{
"taskId": t.TaskID,
"inputRef": t.InputRef,
}
}
var (
data []byte
contentType string
ext string
2026-05-12 13:45:08 +08:00
textResult string
2026-04-29 15:54:14 +08:00
)
2026-05-12 13:45:08 +08:00
// phase=1 表示模型已成功但 OSS 上传失败:优先从临时文件加载
2026-04-29 15:54:14 +08:00
if t.Phase == 1 && strings.TrimSpace(t.TmpFile) != "" {
data, err = loadTmpResult(t.TmpFile)
if err == nil && len(data) > 0 {
contentType, ext = DetectFileType(data)
} else {
data = nil
}
}
if data == nil {
2026-05-12 13:45:08 +08:00
// 统计
2026-04-29 15:54:14 +08:00
_ = dao.Stat.IncRequestCount(ctx, time.Now(), int64(t.TenantId), t.Creator, t.ModelName)
2026-05-12 13:45:08 +08:00
// 核心调用
2026-04-29 15:54:14 +08:00
data, err = InvokeModel(ctx, m, payload, t.ModelKey)
if err != nil {
_ = dao.Task.UpdateFailedGlobal(ctx, t.Id, err.Error())
ReleaseQueueSlot(ctx, t.ModelName, t.TaskID)
2026-05-12 13:45:08 +08:00
// ============ 失败回调 ============
t.State = 3
t.ErrorMsg = err.Error()
go triggerCallback(context.WithoutCancel(ctx), t)
// ================================
2026-04-29 15:54:14 +08:00
return
}
contentType, ext = DetectFileType(data)
2026-05-12 13:45:08 +08:00
if utf8.Valid(data) && (strings.HasPrefix(contentType, "text/") || contentType == "application/json") {
textResult = string(data)
}
2026-04-29 15:54:14 +08:00
tmpPath, err := saveTmpResult(t.TaskID, data, ext)
if err == nil && tmpPath != "" {
t.TmpFile = tmpPath
t.Phase = 1
_ = dao.Task.UpdateTmpAfterModelGlobal(ctx, t.Id, tmpPath)
}
}
// 4) 存储 OSS
ossURL, err := Storage.UploadByTask(ctx, t, data, ext, contentType)
if err != nil {
// OSS 阶段失败:保留临时文件,下一轮仅重试 OSS
_ = dao.Task.UpdateFailedKeepTmpGlobal(ctx, t.Id, err.Error())
ReleaseQueueSlot(ctx, t.ModelName, t.TaskID)
2026-05-12 13:45:08 +08:00
// ============ OSS失败不回调还会重试 ============
// 注意OSS失败保留临时文件下次重试所以这里不触发最终回调
// 如果已经重试多次还没成功,需要在任务超时或超过最大重试次数时才回调失败
2026-04-29 15:54:14 +08:00
return
}
// 5) 更新任务状态成功
fileType := strings.TrimPrefix(ext, ".")
if fileType == "" {
fileType = contentType
}
2026-05-12 13:45:08 +08:00
if err := dao.Task.UpdateSuccessGlobal(
ctx,
t.Id,
ossURL,
fileType,
textResult,
int64(len(data)),
nil,
GetExpendTokens(m.TokenMapping, textResult),
); err != nil {
2026-04-29 15:54:14 +08:00
g.Log().Errorf(ctx, "[worker] update success failed: %v", err)
return
}
2026-05-12 13:45:08 +08:00
// 成功/失败均不再占用 queue_limit
2026-04-29 15:54:14 +08:00
ReleaseQueueSlot(ctx, t.ModelName, t.TaskID)
2026-05-12 13:45:08 +08:00
// 6) 成功回调
2026-04-29 15:54:14 +08:00
t.State = 2
t.OssFile = ossURL
t.FileType = fileType
2026-05-12 13:45:08 +08:00
t.TextResult = textResult
g.Log().Infof(ctx, "[CALLBACK][DISPATCH] taskId=%s bizName=%s callbackUrl=%s", t.TaskID, t.BizName, t.CallbackURL)
go triggerCallback(context.WithoutCancel(ctx), t)
// ============ 如果有 epicycleId也触发业务回调 ============
if epicycleId != 0 {
go triggerPromptsCallback(context.WithoutCancel(ctx), t, epicycleId)
}
2026-04-29 15:54:14 +08:00
// 成功后清理临时文件
deleteTmpResult(t.TmpFile)
}
// rollbackToPending asks the dao layer to move the task identified by id back
// to the pending state, returning whatever error the dao call reports.
func (w *asyncWorker) rollbackToPending(ctx context.Context, id int64) error {
	if err := dao.Task.RollbackToPendingGlobal(ctx, id); err != nil {
		return err
	}
	return nil
}
2026-05-12 13:45:08 +08:00
// GetExpendTokens extracts the number of consumed tokens from textResult,
// interpreting tokenMapping as a gjson path into it. When the path does not
// resolve to a value, it falls back to len(textResult), the result's length
// in bytes.
func GetExpendTokens(tokenMapping string, textResult string) int {
	if value := gjson.Get(textResult, tokenMapping); value.Exists() {
		return int(value.Int())
	}
	return len(textResult)
}