refactor(model): 重构模型实体和数据访问层
This commit is contained in:
@@ -2,7 +2,13 @@ package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"model-gateway/common/util"
|
||||
"model-gateway/model/dto"
|
||||
"model-gateway/service/gateway"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
"unicode/utf8"
|
||||
@@ -23,24 +29,23 @@ type asyncWorker struct {
|
||||
// RunOnce 由上层定时任务触发:一次性抢占并处理一批任务
|
||||
// - batchSize: 本次抢占数量
|
||||
// - goroutines: 本次并发数(协程池大小)
|
||||
func (w *asyncWorker) RunOnce(ctx context.Context, batchSize, goroutines int) (claimed int, err error) {
|
||||
if batchSize <= 0 {
|
||||
batchSize = 10
|
||||
func (w *asyncWorker) RunOnce(ctx context.Context, req *dto.RunWorkReq) (res *dto.RunWorkRes, err error) {
|
||||
if req.BatchSize <= 0 {
|
||||
req.BatchSize = 10
|
||||
}
|
||||
if goroutines <= 0 {
|
||||
goroutines = 1
|
||||
if req.Goroutines <= 0 {
|
||||
req.Goroutines = 1
|
||||
}
|
||||
tasks, err := dao.Task.ClaimPendingGlobal(ctx, batchSize)
|
||||
tasks, err := dao.Task.ClaimPendingGlobal(ctx, req.BatchSize)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
return nil, err
|
||||
}
|
||||
if len(tasks) == 0 {
|
||||
return 0, nil
|
||||
return nil, errors.New("no task to run")
|
||||
}
|
||||
pool := grpool.New(goroutines)
|
||||
pool := grpool.New(req.Goroutines)
|
||||
defer pool.Close()
|
||||
|
||||
claimed = len(tasks)
|
||||
claimed := len(tasks)
|
||||
done := make(chan struct{}, claimed)
|
||||
for _, t := range tasks {
|
||||
task := t
|
||||
@@ -58,7 +63,9 @@ func (w *asyncWorker) RunOnce(ctx context.Context, batchSize, goroutines int) (c
|
||||
for i := 0; i < claimed; i++ {
|
||||
<-done
|
||||
}
|
||||
return claimed, nil
|
||||
return &dto.RunWorkRes{
|
||||
Claimed: claimed,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// RunByTaskID 创建任务后立即异步尝试执行当前任务:
|
||||
@@ -78,9 +85,9 @@ func (w *asyncWorker) RunByTaskID(ctx context.Context, taskID string, epicycleId
|
||||
|
||||
func (w *asyncWorker) handleOne(ctx context.Context, t *entity.AsynchTask, epicycleId int64) {
|
||||
// 从任务入库的 request_payload 里恢复 payload + headers
|
||||
payload, headers := parseStoredPayload(t.RequestPayload)
|
||||
payload, headers := util.ParseStoredPayload(t.RequestPayload)
|
||||
if len(headers) > 0 {
|
||||
ctx = setTaskHeadersToCtx(ctx, headers)
|
||||
ctx = util.SetTaskHeadersToCtx(ctx, headers)
|
||||
}
|
||||
|
||||
// 1) 拉取模型配置
|
||||
@@ -91,7 +98,7 @@ func (w *asyncWorker) handleOne(ctx context.Context, t *entity.AsynchTask, epicy
|
||||
// ============ 失败回调 ============
|
||||
t.State = 3
|
||||
t.ErrorMsg = err.Error()
|
||||
go triggerCallback(context.WithoutCancel(ctx), t)
|
||||
go gateway.TriggerCallback(context.WithoutCancel(ctx), t)
|
||||
// ================================
|
||||
return
|
||||
}
|
||||
@@ -102,7 +109,7 @@ func (w *asyncWorker) handleOne(ctx context.Context, t *entity.AsynchTask, epicy
|
||||
// ============ 失败回调 ============
|
||||
t.State = 3
|
||||
t.ErrorMsg = errMsg
|
||||
go triggerCallback(context.WithoutCancel(ctx), t)
|
||||
go gateway.TriggerCallback(context.WithoutCancel(ctx), t)
|
||||
// ================================
|
||||
return
|
||||
}
|
||||
@@ -118,7 +125,7 @@ func (w *asyncWorker) handleOne(ctx context.Context, t *entity.AsynchTask, epicy
|
||||
// ============ 失败回调 ============
|
||||
t.State = 3
|
||||
t.ErrorMsg = err.Error()
|
||||
go triggerCallback(context.WithoutCancel(ctx), t)
|
||||
go gateway.TriggerCallback(context.WithoutCancel(ctx), t)
|
||||
// ================================
|
||||
return
|
||||
}
|
||||
@@ -147,9 +154,9 @@ func (w *asyncWorker) handleOne(ctx context.Context, t *entity.AsynchTask, epicy
|
||||
|
||||
// phase=1 表示模型已成功但 OSS 上传失败:优先从临时文件加载
|
||||
if t.Phase == 1 && strings.TrimSpace(t.TmpFile) != "" {
|
||||
data, err = loadTmpResult(t.TmpFile)
|
||||
data, err = os.ReadFile(t.TmpFile)
|
||||
if err == nil && len(data) > 0 {
|
||||
contentType, ext = DetectFileType(data)
|
||||
contentType, ext = util.DetectFileType(data)
|
||||
} else {
|
||||
data = nil
|
||||
}
|
||||
@@ -165,11 +172,11 @@ func (w *asyncWorker) handleOne(ctx context.Context, t *entity.AsynchTask, epicy
|
||||
// ============ 失败回调 ============
|
||||
t.State = 3
|
||||
t.ErrorMsg = err.Error()
|
||||
go triggerCallback(context.WithoutCancel(ctx), t)
|
||||
go gateway.TriggerCallback(context.WithoutCancel(ctx), t)
|
||||
// ================================
|
||||
return
|
||||
}
|
||||
contentType, ext = DetectFileType(data)
|
||||
contentType, ext = util.DetectFileType(data)
|
||||
if utf8.Valid(data) && (strings.HasPrefix(contentType, "text/") || contentType == "application/json") {
|
||||
textResult = string(data)
|
||||
}
|
||||
@@ -182,7 +189,7 @@ func (w *asyncWorker) handleOne(ctx context.Context, t *entity.AsynchTask, epicy
|
||||
}
|
||||
|
||||
// 4) 存储 OSS
|
||||
ossURL, err := Storage.UploadByTask(ctx, t, data, ext, contentType)
|
||||
ossURL, err := gateway.UploadByTask(ctx, t, data, ext, contentType)
|
||||
if err != nil {
|
||||
// OSS 阶段失败:保留临时文件,下一轮仅重试 OSS
|
||||
_ = dao.Task.UpdateFailedKeepTmpGlobal(ctx, t.Id, err.Error())
|
||||
@@ -198,7 +205,7 @@ func (w *asyncWorker) handleOne(ctx context.Context, t *entity.AsynchTask, epicy
|
||||
if fileType == "" {
|
||||
fileType = contentType
|
||||
}
|
||||
if err := dao.Task.UpdateSuccessGlobal(
|
||||
if err = dao.Task.UpdateSuccessGlobal(
|
||||
ctx,
|
||||
t.Id,
|
||||
ossURL,
|
||||
@@ -206,7 +213,7 @@ func (w *asyncWorker) handleOne(ctx context.Context, t *entity.AsynchTask, epicy
|
||||
textResult,
|
||||
int64(len(data)),
|
||||
nil,
|
||||
GetExpendTokens(m.TokenMapping, textResult),
|
||||
GetExpendTokens(m.ResponseTokenField, textResult),
|
||||
); err != nil {
|
||||
g.Log().Errorf(ctx, "[worker] update success failed: %v", err)
|
||||
return
|
||||
@@ -221,14 +228,33 @@ func (w *asyncWorker) handleOne(ctx context.Context, t *entity.AsynchTask, epicy
|
||||
t.FileType = fileType
|
||||
t.TextResult = textResult
|
||||
g.Log().Infof(ctx, "[CALLBACK][DISPATCH] taskId=%s bizName=%s callbackUrl=%s", t.TaskID, t.BizName, t.CallbackURL)
|
||||
go triggerCallback(context.WithoutCancel(ctx), t)
|
||||
go gateway.TriggerCallback(context.WithoutCancel(ctx), t)
|
||||
// ============ 如果有 epicycleId,也触发业务回调 ============
|
||||
if epicycleId != 0 {
|
||||
go triggerPromptsCallback(context.WithoutCancel(ctx), t, epicycleId)
|
||||
go gateway.TriggerPromptsCallback(context.WithoutCancel(ctx), t, epicycleId)
|
||||
}
|
||||
|
||||
// 成功后清理临时文件
|
||||
deleteTmpResult(t.TmpFile)
|
||||
_ = os.Remove(t.TmpFile)
|
||||
}
|
||||
|
||||
// saveTmpResult 将模型输出写入临时文件,用于 OSS 上传失败后的“仅重试 OSS”。
|
||||
func saveTmpResult(taskID string, data []byte, ext string) (string, error) {
|
||||
dir := filepath.Join(os.TempDir(), "model-asynch")
|
||||
if err := os.MkdirAll(dir, 0o755); err != nil {
|
||||
return "", err
|
||||
}
|
||||
if ext == "" {
|
||||
ext = ".bin"
|
||||
}
|
||||
if ext[0] != '.' {
|
||||
ext = "." + ext
|
||||
}
|
||||
path := filepath.Join(dir, fmt.Sprintf("%s%s", taskID, ext))
|
||||
if err := os.WriteFile(path, data, 0o644); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return path, nil
|
||||
}
|
||||
|
||||
func (w *asyncWorker) rollbackToPending(ctx context.Context, id int64) error {
|
||||
@@ -240,7 +266,6 @@ func GetExpendTokens(tokenMapping string, textResult string) int {
|
||||
value := gjson.Get(textResult, tokenMapping)
|
||||
if value.Exists() {
|
||||
return int(value.Int())
|
||||
} else {
|
||||
return len(textResult)
|
||||
}
|
||||
return len(textResult)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user