refactor(service): 重构服务模块结构并优化模型配置

This commit is contained in:
2026-05-29 17:54:19 +08:00
parent e487b4bb5e
commit d409b84b58
24 changed files with 943 additions and 1158 deletions

View File

@@ -0,0 +1,269 @@
package task
import (
"context"
"errors"
"model-gateway/common/util"
"model-gateway/service/queue"
"time"
"model-gateway/dao"
"model-gateway/model/dto"
"model-gateway/model/entity"
"github.com/gogf/gf/v2/database/gdb"
"github.com/gogf/gf/v2/frame/g"
"github.com/gogf/gf/v2/os/gtime"
"github.com/google/uuid"
)
var Task = &taskService{}
type taskService struct{}
// Create 创建任务
func (s *taskService) Create(ctx context.Context, req *dto.CreateTaskReq) (res *dto.CreateTaskRes, err error) {
startAt := time.Now()
taskID := uuid.NewString()
// 1) 检查模型配置
m, err := dao.Model.Get(ctx, &entity.AsynchModel{
ModelName: req.ModelName,
})
if err != nil {
return nil, err
}
if m == nil || (m.Enabled != nil && *m.Enabled != 1) {
return nil, errors.New("模型不存在或未启用")
}
// 2) 排队上限严格控制Redis 原子闸门)
limit := queue.GetRuntimeQueueLimit(ctx, req.ModelName, m.QueueLimit)
if limit > 0 {
ok, err := queue.AcquireQueueSlot(ctx, req.ModelName, taskID, limit, m.ExpectedSeconds)
if err != nil {
return nil, err
}
if !ok {
return nil, errors.New("任务排队已满,请稍后再试")
}
}
// 3) 插入任务记录
storedPayload := map[string]any{
"payload": req.RequestPayload,
"headers": util.ForwardHeaders(ctx),
}
_, err = dao.Task.Insert(ctx, &entity.AsynchTask{
ModelName: req.ModelName,
TaskID: taskID,
State: 0,
BizName: req.BizName,
CallbackURL: req.CallbackUrl,
ModelKey: m.ApiKey,
InputRef: req.InputRef,
RequestPayload: storedPayload,
EpicycleId: req.EpicycleId,
})
if err != nil {
// 入库失败:回滚闸门占位
queue.ReleaseQueueSlot(ctx, req.ModelName, taskID)
return nil, err
}
// 4) 写操作日志(不影响主流程,失败忽略)
ip := ""
ua := ""
apiPath := "/task/createTask"
httpMethod := "POST"
if r := g.RequestFromCtx(ctx); r != nil {
ip = util.GetLocalIP()
ua = r.UserAgent()
apiPath = r.URL.Path
httpMethod = r.Method
}
_, _ = dao.OpLog.Insert(ctx, &entity.LogsModelOp{
IP: ip,
UserAgent: ua,
APIPath: apiPath,
HttpMethod: httpMethod,
BizName: req.BizName,
ModelName: req.ModelName,
TaskID: taskID,
OpType: "createTask",
Success: 1,
ErrorMsg: "",
CostMs: time.Since(startAt).Milliseconds(),
RequestPayload: storedPayload,
ResponsePayload: gdb.Map{
"taskId": taskID,
},
})
// 5) 创建成功后立即异步尝试执行当前任务,并仅在任务仍处于 pending(state=0) 时做定向轮询。
// 一旦任务进入 running/success/failed/downloaded就停止轮询避免一直空转。
go s.pollAndRunUntilPicked(util.AsyncCtx(ctx), taskID, req)
return &dto.CreateTaskRes{TaskID: taskID}, nil
}
// pollAndRunUntilPicked 定向轮询执行刚创建的任务
// - 目标:尽快把刚创建的任务拉起来执行
// - 只在任务仍为 pending(state=0) 时继续尝试抢占
// - 一旦任务进入 running(1) / success(2) / failed(3) / downloaded(4),立即停止
// - 不会无限轮询runWork 仍负责处理积压队列和未处理到的任务
func (s *taskService) pollAndRunUntilPicked(ctx context.Context, taskID string, req *dto.CreateTaskReq) {
interval := g.Cfg().MustGet(ctx, "asynch.worker.intervalSeconds", 5).Int()
pollTimeout := g.Cfg().MustGet(ctx, "asynch.worker.pollTimeoutSeconds", 300).Int()
pollCtx, cancel := context.WithTimeout(ctx, time.Duration(pollTimeout)*time.Second)
defer cancel()
ticker := time.NewTicker(time.Duration(interval) * time.Second)
defer ticker.Stop()
g.Log().Infof(ctx, "[任务自动执行][开始] taskId=%s 轮询间隔=%ds 超时=%ds", taskID, interval, pollTimeout)
tryRun := func() bool {
t, err := dao.Task.Get(ctx, &entity.AsynchTask{
TaskID: taskID,
})
if err != nil {
g.Log().Warningf(ctx, "[任务自动执行][停止] taskId=%s 原因=查询失败 err=%v", taskID, err)
return true
}
if t == nil {
g.Log().Warningf(ctx, "[任务自动执行][停止] taskId=%s 原因=任务不存在", taskID)
return true
}
switch t.State {
case 0:
//RunByTaskID 尝试执行任务
if err = AsyncWorker.RunByTaskID(ctx, taskID, req); err != nil {
g.Log().Warningf(ctx, "[任务自动执行][重试] taskId=%s 状态=待处理 err=%v", taskID, err)
} else {
g.Log().Infof(ctx, "[任务自动执行][已触发] taskId=%s 状态=待处理", taskID)
}
return false
case 1:
g.Log().Infof(ctx, "[任务自动执行][停止] taskId=%s 原因=执行中", taskID)
return true
case 2, 3, 4:
g.Log().Infof(ctx, "[任务自动执行][停止] taskId=%s 原因=终态 状态=%d", taskID, t.State)
return true
default:
g.Log().Infof(ctx, "[任务自动执行][停止] taskId=%s 原因=未知状态 状态=%d", taskID, t.State)
return true
}
}
// 立即尝试一次
if stop := tryRun(); stop {
return
}
for {
select {
case <-pollCtx.Done():
g.Log().Infof(ctx, "[任务自动执行][停止] taskId=%s 原因=轮询超时", taskID)
return
case <-ticker.C:
if stop := tryRun(); stop {
return
}
}
}
}
// GetResult 获取任务结果
func (s *taskService) GetResult(ctx context.Context, taskID string) (res *dto.GetTaskResultRes, err error) {
t, err := dao.Task.Get(ctx, &entity.AsynchTask{
TaskID: taskID,
})
if err != nil {
return nil, err
}
if t == nil {
return nil, errors.New("任务不存在")
}
return &dto.GetTaskResultRes{
OssFile: t.OssFile,
State: t.State,
}, nil
}
// GetBatch 批量查询任务;将成功(state=2)的任务更新为已下载(state=4),并写入过期时间
func (s *taskService) GetBatch(ctx context.Context, req *dto.GetTaskBatchReq) (res *dto.GetTaskBatchRes, err error) {
if req == nil || len(req.TaskIDs) == 0 {
return &dto.GetTaskBatchRes{List: []dto.GetTaskBatchItem{}}, nil
}
// 1) 先查当前租户下的任务列表
list, err := dao.Task.ListByTaskIDs(ctx, req.TaskIDs)
if err != nil {
return nil, err
}
// 2) 对成功(state=2)的任务:标记为已下载(state=4)并写入 expire_at
now := time.Now()
for _, t := range list {
if t == nil {
continue
}
if t.State != 2 {
continue
}
// 按模型配置决定保留时间
m, err := dao.Model.Get(ctx, &entity.AsynchModel{
ModelName: t.ModelName,
})
if err != nil {
return nil, err
}
retainSeconds := 86400
if m != nil && m.AutoCleanSeconds > 0 {
retainSeconds = m.AutoCleanSeconds
}
expireAt := gtime.New(now.Add(time.Duration(retainSeconds) * time.Second))
_ = dao.Task.MarkDownloadedByID(ctx, t.Id, expireAt)
// 为了本次返回一致性,内存里也更新
t.State = 4
t.ExpireAt = expireAt
}
// 3) 组装返回
items := make([]dto.GetTaskBatchItem, 0, len(list))
for _, t := range list {
if t == nil {
continue
}
items = append(items, dto.GetTaskBatchItem{
TaskID: t.TaskID,
State: t.State,
OssFile: t.OssFile,
})
}
return &dto.GetTaskBatchRes{List: items}, nil
}
// List 获取任务列表
func (s *taskService) List(ctx context.Context, req *dto.ListTaskReq) (res *dto.ListTaskRes, err error) {
pageNum, pageSize := 1, 10
if req != nil {
if req.PageNum > 0 {
pageNum = req.PageNum
}
if req.PageSize > 0 {
pageSize = req.PageSize
}
}
modelName := ""
taskID := ""
var state *int
if req != nil {
modelName = req.ModelName
taskID = req.TaskID
state = req.State
}
list, total, err := dao.Task.List(ctx, pageNum, pageSize, modelName, taskID, state)
if err != nil {
return nil, err
}
return &dto.ListTaskRes{List: list, Total: total}, nil
}

418
service/task/worker.go Normal file
View File

@@ -0,0 +1,418 @@
package task
import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"model-gateway/common/util"
"model-gateway/model/dto"
"model-gateway/service/gateway"
"model-gateway/service/queue"
"net/http"
"os"
"path/filepath"
"strings"
"time"
"unicode/utf8"
"model-gateway/dao"
"model-gateway/model/entity"
"github.com/gogf/gf/v2/encoding/gjson"
"github.com/gogf/gf/v2/frame/g"
"github.com/gogf/gf/v2/os/grpool"
)
var AsyncWorker = &asyncWorker{}
type asyncWorker struct {
}
// RunOnce 由上层定时任务触发:一次性抢占并处理一批任务
// - batchSize: 本次抢占数量
// - goroutines: 本次并发数(协程池大小)
func (w *asyncWorker) RunOnce(ctx context.Context, req *dto.RunWorkReq) (res *dto.RunWorkRes, err error) {
if req.BatchSize <= 0 {
req.BatchSize = 10
}
if req.Goroutines <= 0 {
req.Goroutines = 1
}
tasks, err := dao.Task.ClaimPendingGlobal(ctx, req.BatchSize)
if err != nil {
return nil, err
}
if len(tasks) == 0 {
return nil, errors.New("no task to run")
}
pool := grpool.New(req.Goroutines)
defer pool.Close()
claimed := len(tasks)
done := make(chan struct{}, claimed)
for _, t := range tasks {
task := t
_ = pool.AddWithRecover(ctx, func(ctx context.Context) {
w.handleOne(ctx, task, &dto.CreateTaskReq{EpicycleId: 0})
done <- struct{}{}
}, func(ctx context.Context, e error) {
if e != nil {
task.ErrorMsg = fmt.Sprintf("worker panic: %v", e)
_ = dao.Task.UpdateFailedGlobal(ctx, task)
queue.ReleaseQueueSlot(ctx, task.ModelName, task.TaskID)
}
done <- struct{}{}
})
}
for i := 0; i < claimed; i++ {
<-done
}
return &dto.RunWorkRes{
Claimed: claimed,
}, nil
}
// RunByTaskID 创建任务后立即异步尝试执行当前任务:
// - 只定向抢占当前 taskId 对应的 pending 任务
// - 若任务已被其它 worker 抢走/已不在 pending则直接返回
func (w *asyncWorker) RunByTaskID(ctx context.Context, taskID string, req *dto.CreateTaskReq) error {
task, err := dao.Task.ClaimPendingByTaskIDGlobal(ctx, taskID)
if err != nil {
return err
}
if task == nil {
return nil
}
w.handleOne(ctx, task, req)
return nil
}
// handleOne 执行一次完整的任务
func (w *asyncWorker) handleOne(ctx context.Context, t *entity.AsynchTask, req *dto.CreateTaskReq) {
payload := util.ParseStoredPayload(t.RequestPayload)
maxRetry := 0 // 后面从 model 取
g.Log().Infof(ctx, "[执行任务][开始] taskId=%s model=%s", t.TaskID, t.ModelName)
// 1) 获取模型配置
model, err := dao.Model.GetByModelNameForTenant(ctx, t.TenantId, t.ModelName)
if err != nil || model == nil {
w.failTask(ctx, t, "模型不存在或未启用")
return
}
maxRetry = model.RetryTimes
// 2) 分布式并发控制
semKey := fmt.Sprintf("asynch:sem:%s", t.ModelName)
maxC := queue.GetRuntimeMaxConcurrency(ctx, t.ModelName, model.MaxConcurrency)
acquired, err := queue.AcquireSemaphore(ctx, semKey, maxC, 3600)
if err != nil {
w.failTask(ctx, t, err.Error())
return
}
if !acquired {
g.Log().Infof(ctx, "[执行任务][排队] 并发已满,放回队列 taskId=%s", t.TaskID)
_ = w.rollbackToPending(ctx, t.Id)
return
}
defer func() { _ = queue.ReleaseSemaphore(ctx, semKey) }()
// 3) request_payload 校验
if payload == nil {
w.failTask(ctx, t, "request_payload 为空")
return
}
// 4) 调用模型(不重试,失败直接回调)
textResult, err := w.callModel(ctx, t, model, payload)
if err != nil {
w.failTask(ctx, t, err.Error())
return
}
// 5) 上传 OSS可重试
var oss *gateway.UploadFileResponse
for attempt := 0; attempt <= maxRetry; attempt++ {
if attempt > 0 {
g.Log().Infof(ctx, "[执行任务][重试] OSS上传 第%d/%d次 taskId=%s", attempt, maxRetry, t.TaskID)
}
oss, err = w.uploadOSS(ctx, t)
if err == nil {
break
}
g.Log().Errorf(ctx, "[执行任务][失败] OSS上传失败 taskId=%s attempt=%d/%d err=%v",
t.TaskID, attempt, maxRetry, err)
if attempt == maxRetry {
_ = dao.Task.UpdateFailedKeepTmpGlobal(ctx, t.Id, err.Error())
w.failTask(ctx, t, fmt.Sprintf("OSS上传重试耗尽: %v", err))
return
}
}
// 6) 解析校验(可重试,失败重新调模型)
//if req.BuildType == 1 {
// for attempt := 0; attempt <= maxRetry; attempt++ {
// if attempt > 0 {
// g.Log().Infof(ctx, "[执行任务][重试] JSON解析 第%d/%d次 taskId=%s", attempt, maxRetry, t.TaskID)
// }
// // 6.1) 校验数据
// err = util.ValidatePromptResult(textResult, model)
// if err == nil {
// break
// }
// g.Log().Warningf(ctx, "[执行任务][解析失败] taskId=%s attempt=%d/%d err=%v",
// t.TaskID, attempt, maxRetry, err)
// if attempt == maxRetry {
// w.failTask(ctx, t, fmt.Sprintf("JSON解析重试耗尽: %v", err))
// return
// }
// // 6.2) 重新调模型
// newResult, modelErr := w.callModel(ctx, t, model, payload)
// if modelErr != nil {
// g.Log().Warningf(ctx, "[执行任务][重试] 重新调模型失败 taskId=%s attempt=%d/%d err=%v",
// t.TaskID, attempt, maxRetry, modelErr)
// continue
// }
// textResult = newResult
// }
//}
// 7) 成功回调
t.State = 2
t.OssFile = oss.FileAddressPrefix + oss.FileURL
t.FileType = oss.FileFormat
t.TextResult = textResult
t.FileSize = int64(oss.FileSize)
t.ExpendTokens = int64(GetExpendTokens(model.ResponseTokenField, textResult))
if err = dao.Task.UpdateSuccessGlobal(ctx, t); err != nil {
g.Log().Errorf(ctx, "[执行任务][失败] 更新数据库失败 taskId=%s err=%v", t.TaskID, err)
return
}
queue.ReleaseQueueSlot(ctx, t.ModelName, t.TaskID)
go gateway.TriggerCallback(context.WithoutCancel(ctx), t)
if req.EpicycleId != 0 {
go gateway.TriggerPromptsCallback(context.WithoutCancel(ctx), t, req.EpicycleId)
}
g.Log().Infof(ctx, "[执行任务][成功] taskId=%s fileType=%s textLen=%d callbackUrl=%s",
t.TaskID, oss.FileFormat, len(textResult), t.CallbackURL)
_ = os.Remove(t.TmpFile)
}
// 返回: ossURL(成功时有值), fileType, textResult(失败时是错误信息), retryable(是否可重试)
// callModel 调用模型 + 检测文件类型 + 保存临时文件
func (w *asyncWorker) callModel(ctx context.Context, task *entity.AsynchTask, model *entity.AsynchModel, payload map[string]any) (map[string]any, error) {
var data []byte
var contentType, ext, textResult string
var err error
if task.Phase == 1 && strings.TrimSpace(task.TmpFile) != "" {
data, err = os.ReadFile(task.TmpFile)
if err != nil || len(data) == 0 {
data = nil
}
}
if data == nil {
_ = dao.Stat.IncRequestCount(ctx, time.Now(), int64(task.TenantId), task.Creator, task.ModelName)
data, err = InvokeModel(ctx, model, payload, task.ModelKey)
if err != nil {
return nil, err
}
tmpPath, tmpErr := saveTmpResult(task.TaskID, data, ext)
if tmpErr == nil && tmpPath != "" {
task.TmpFile = tmpPath
task.Phase = 1
_ = dao.Task.UpdateTmpAfterModelGlobal(ctx, task.Id, tmpPath)
}
}
contentType, ext = util.DetectFileType(data)
if utf8.Valid(data) && (strings.HasPrefix(contentType, "text/") || contentType == "application/json") {
textResult = string(data)
}
return gjson.New(textResult).Map(), nil
}
// InvokeModel 调用模型服务,返回二进制结果
// modelKey 用于覆盖/补充模型配置 head_msg例如每次请求携带不同的 X-API-Key
func InvokeModel(ctx context.Context, model *entity.AsynchModel, payload map[string]any, modelKey string) ([]byte, error) {
// 1请求参数映射将标准 payload 按模型配置的 requestMapping 转为模型需要的格式
//mappedPayload := util.ReverseMap(model.RequestMapping, payload)
// 2构建请求 URL 和超时
baseURL := strings.TrimRight(model.BaseURL, "/")
timeout := time.Duration(model.TimeoutSeconds) * time.Second
client := &http.Client{Timeout: timeout}
method := strings.ToUpper(strings.TrimSpace(model.HttpMethod))
// 3构建 HTTP 请求
var req *http.Request
switch method {
case http.MethodGet:
q, err := util.PayloadToQuery(payload)
if err != nil {
return nil, err
}
if len(q) > 0 {
if strings.Contains(baseURL, "?") {
baseURL = baseURL + "&" + q.Encode()
} else {
baseURL = baseURL + "?" + q.Encode()
}
}
req, err = http.NewRequestWithContext(ctx, http.MethodGet, baseURL, nil)
default:
bodyBytes, err := json.Marshal(payload)
if err != nil {
return nil, err
}
req, err = http.NewRequestWithContext(ctx, http.MethodPost, baseURL, bytes.NewReader(bodyBytes))
}
// 4注入请求头先模型静态配置再动态 modelKey后者可覆盖前者
for hk, hv := range util.ParseHeadMsgHeaders(model.HeadMsg) {
req.Header.Set(hk, hv)
}
for hk, hv := range util.ParseHeadMsgHeaders(modelKey) {
req.Header.Set(hk, hv)
}
if method != http.MethodGet {
req.Header.Set("Content-Type", "application/json")
}
// 5发送请求
resp, err := client.Do(req)
if err != nil {
return nil, err
}
defer resp.Body.Close()
// 6读取响应体
b, err := io.ReadAll(resp.Body)
if err != nil {
return nil, err
}
// 7检查 HTTP 状态码
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
msg := string(b)
return nil, fmt.Errorf("模型服务返回非2xx: %d, body=%s", resp.StatusCode, msg)
}
// 8响应参数映射
mappedResponse, err := util.MapResponsePayload(model.ResponseMapping, b)
if err != nil {
g.Log().Warningf(ctx, "响应参数映射失败: %v返回原始数据", err)
return b, nil
}
return mappedResponse, nil
}
// // InvokeModel 调用模型服务,返回二进制结果
//
// func InvokeModel(ctx context.Context, m *entity.AsynchModel, payload any, modelKey string) ([]byte, error) {
// if m == nil || m.BaseURL == "" {
// return nil, fmt.Errorf("模型配置不完整")
// }
// // 请求参数映射
// mappedPayload, err := mapRequestPayload(m.RequestMapping, payload)
// if err != nil {
// return nil, fmt.Errorf("请求参数映射失败: %w", err)
// }
// // 合并请求头
// headers := util.ForwardHeaders(ctx)
// for hk, hv := range parseHeadMsgHeaders(m.HeadMsg) {
// headers[hk] = hv
// }
// for hk, hv := range parseHeadMsgHeaders(modelKey) {
// headers[hk] = hv
// }
//
// // 设置超时
// timeout := time.Duration(m.TimeoutSeconds) * time.Second
// if timeout <= 0 {
// timeout = 600 * time.Second
// }
// ctx, cancel := context.WithTimeout(ctx, timeout)
// defer cancel()
//
// invokeUrl := strings.TrimRight(m.BaseURL, "/")
// method := strings.ToUpper(strings.TrimSpace(m.HttpMethod))
// if method == "" {
// method = http.MethodPost
// }
//
// var respBytes []byte
//
// switch method {
// case http.MethodGet:
// err = commonHttp.Get(ctx, invokeUrl, headers, &respBytes, mappedPayload)
// default:
// err = commonHttp.Post(ctx, invokeUrl, headers, &respBytes, mappedPayload)
// }
// if err != nil {
// return nil, err
// }
// // 响应参数映射
// mappedResponse, err := mapResponsePayload(m.ResponseMapping, respBytes)
// if err != nil {
// g.Log().Warningf(ctx, "响应参数映射失败: %v返回原始数据", err)
// return respBytes, nil
// }
// return mappedResponse, nil
// }
// uploadOSS 从临时文件上传 OSS
func (w *asyncWorker) uploadOSS(ctx context.Context, t *entity.AsynchTask) (*gateway.UploadFileResponse, error) {
data, err := os.ReadFile(t.TmpFile)
if err != nil {
return nil, fmt.Errorf("读取临时文件失败: %w", err)
}
_, ext := util.DetectFileType(data)
return gateway.UploadByTask(ctx, data, ext)
}
// failTask 任务失败统一处理:更新数据库 + 释放排队 + 回调
func (w *asyncWorker) failTask(ctx context.Context, t *entity.AsynchTask, errMsg string) {
t.State = 3
t.ErrorMsg = errMsg
_ = dao.Task.UpdateFailedGlobal(ctx, t)
queue.ReleaseQueueSlot(ctx, t.ModelName, t.TaskID)
go gateway.TriggerCallback(context.WithoutCancel(ctx), t)
}
// saveTmpResult 将模型输出写入临时文件,用于 OSS 上传失败后的“仅重试 OSS”。
func saveTmpResult(taskID string, data []byte, ext string) (string, error) {
dir := filepath.Join(os.TempDir(), "model-asynch")
if err := os.MkdirAll(dir, 0o755); err != nil {
return "", err
}
if ext == "" {
ext = ".bin"
}
if ext[0] != '.' {
ext = "." + ext
}
path := filepath.Join(dir, fmt.Sprintf("%s%s", taskID, ext))
if err := os.WriteFile(path, data, 0o644); err != nil {
return "", err
}
return path, nil
}
func (w *asyncWorker) rollbackToPending(ctx context.Context, id int64) error {
return dao.Task.RollbackToPendingGlobal(ctx, id)
}
// GetExpendTokens 根据映射路径从 result 中提取消耗 token 值
func GetExpendTokens(responseTokenField string, result map[string]any) int {
val := gjson.New(result).Get(responseTokenField)
if val.IsNil() {
return 0
}
return val.Int()
}