第一次提交

This commit is contained in:
2026-04-29 15:54:14 +08:00
parent 50d2eadbd1
commit e81df5ce5a
51 changed files with 4571 additions and 0 deletions

194
service/auto_tune.go Normal file
View File

@@ -0,0 +1,194 @@
package service
import (
"context"
"fmt"
"math"
"model-asynch/consts/public"
"model-asynch/model/entity"
"gitea.com/red-future/common/db/gfdb"
"github.com/gogf/gf/v2/frame/g"
)
// AutoTuneResult 单次调参结果(按 model_name
type AutoTuneResult struct {
ModelName string `json:"modelName"` // 模型名称asynch_models.model_name
Samples int `json:"samples"` // 统计样本数(窗口内 state=2/3 且 started_at/finished_at 非空的任务数量)
P90Exec float64 `json:"p90ExecSeconds"` // 执行耗时 P90口径finished_at - started_at
CapMaxConcurrency int `json:"capMaxConcurrency"` // 配置上限asynch_models.max_concurrencycap不会被动态调参覆盖
OldMaxConcurrency int `json:"oldMaxConcurrency"` // 调参前运行时值Redis若无则等于 cap
NewMaxConcurrency int `json:"newMaxConcurrency"` // 本次计算出的运行时值(将写入 Redis受 ±50% 约束且不超过 cap
CapQueueLimit int `json:"capQueueLimit"` // 配置上限asynch_models.queue_limitcap不会被动态调参覆盖
OldQueueLimit int `json:"oldQueueLimit"` // 调参前运行时值Redis若无则等于 cap
NewQueueLimit int `json:"newQueueLimit"` // 本次计算出的运行时值(将写入 Redis受 ±50% 约束且不超过 cap
ExpectedSeconds int `json:"expectedSeconds"` // 模型预计执行时间asynch_models.expected_seconds用于 queue_limit 计算绑定)
}
// AutoTune 由上层定时任务通过接口触发:
// - 统计指定时间窗口内该模型任务的执行耗时finished_at - started_at取 P90
// - 基于吞吐与 P90 执行耗时估算 max_concurrency 的运行时值(不超过 cap
// - queue_limit 与 expected_seconds 绑定(允许排队时间 = expected_seconds * 2生成运行时值不超过 cap
// - 单次调整幅度限制 ±50%,写入 Redis带 TTL
func AutoTune(ctx context.Context, windowSeconds int) ([]AutoTuneResult, error) {
if windowSeconds <= 0 {
windowSeconds = 3600
}
// 1) 读取模型配置cap按 model_name 聚合去重(如果表里有多租户重复数据,取较大上限)
var modelRows []*entity.AsynchModel
if err := gfdb.DB(ctx).Model(ctx, public.TableNameModel).
Where("deleted_at IS NULL").
Where(entity.AsynchModelCol.Enabled, 1).
Scan(&modelRows); err != nil {
return nil, err
}
modelMap := make(map[string]*entity.AsynchModel)
for _, m := range modelRows {
if m == nil || m.ModelName == "" {
continue
}
cur := modelMap[m.ModelName]
if cur == nil {
modelMap[m.ModelName] = m
continue
}
// 取更大的 cap
if m.MaxConcurrency > cur.MaxConcurrency {
cur.MaxConcurrency = m.MaxConcurrency
}
if m.QueueLimit > cur.QueueLimit {
cur.QueueLimit = m.QueueLimit
}
if m.ExpectedSeconds > cur.ExpectedSeconds {
cur.ExpectedSeconds = m.ExpectedSeconds
}
}
if len(modelMap) == 0 {
return []AutoTuneResult{}, nil
}
// 2) 统计指定窗口:按 model_name 计算 cnt 和 P90 执行耗时
type statRow struct {
ModelName string
Cnt int
P90Exec float64
}
var stats []statRow
sql := fmt.Sprintf(`
SELECT model_name,
COUNT(1) AS cnt,
COALESCE(percentile_cont(0.9) WITHIN GROUP (ORDER BY EXTRACT(EPOCH FROM (finished_at - started_at))), 0) AS p90_exec
FROM %s
WHERE deleted_at IS NULL
AND state IN (2,3)
AND started_at IS NOT NULL
AND finished_at IS NOT NULL
AND finished_at >= (NOW() - (? || ' seconds')::interval)
GROUP BY model_name`, public.TableNameTask)
r, err := gfdb.DB(ctx).GetAll(ctx, sql, windowSeconds)
if err != nil {
return nil, err
}
_ = r.Structs(&stats)
statMap := make(map[string]statRow, len(stats))
for _, s := range stats {
statMap[s.ModelName] = s
}
// 3) 调参计算
const utilization = 0.8
const maxChangeRatio = 0.5 // ±50%
const queueFactor = 2.0 // 与 expected_seconds 绑定W_target = expected_seconds * 2
out := make([]AutoTuneResult, 0, len(modelMap))
for modelName, m := range modelMap {
s := statMap[modelName]
capMax := m.MaxConcurrency
capQueue := m.QueueLimit
oldMax := GetRuntimeMaxConcurrency(ctx, modelName, capMax)
oldQueue := GetRuntimeQueueLimit(ctx, modelName, capQueue)
// 默认:无样本则不调整
if s.Cnt <= 0 || s.P90Exec <= 0 {
out = append(out, AutoTuneResult{
ModelName: modelName,
Samples: s.Cnt,
P90Exec: s.P90Exec,
CapMaxConcurrency: capMax,
OldMaxConcurrency: oldMax,
NewMaxConcurrency: oldMax,
CapQueueLimit: capQueue,
OldQueueLimit: oldQueue,
NewQueueLimit: oldQueue,
ExpectedSeconds: m.ExpectedSeconds,
})
continue
}
// arrival_rate ≈ 完成数/3600
arrivalRate := float64(s.Cnt) / 3600.0
// desiredMax = ceil(arrivalRate * p90 / utilization)
desiredMax := int(math.Ceil(arrivalRate * s.P90Exec / utilization))
if desiredMax < 1 {
desiredMax = 1
}
// 单次变化幅度限制
minMax := int(math.Floor(float64(oldMax) * (1 - maxChangeRatio)))
maxMax := int(math.Ceil(float64(oldMax) * (1 + maxChangeRatio)))
if minMax < 1 {
minMax = 1
}
newMax := clampInt(desiredMax, minMax, maxMax)
if capMax > 0 {
newMax = clampInt(newMax, 1, capMax)
}
setRuntimeInt(ctx, runtimeMaxConcurrencyKey(modelName), newMax)
// queue_limitW_target = expected_seconds * queueFactor
exp := m.ExpectedSeconds
if exp <= 0 {
exp = 60
}
wTarget := float64(exp) * queueFactor
desiredQueue := int(math.Ceil(arrivalRate*wTarget)) + newMax
if desiredQueue < newMax {
desiredQueue = newMax
}
newQueue := oldQueue
if capQueue > 0 {
minQ := int(math.Floor(float64(oldQueue) * (1 - maxChangeRatio)))
maxQ := int(math.Ceil(float64(oldQueue) * (1 + maxChangeRatio)))
if minQ < newMax {
minQ = newMax
}
if maxQ < minQ {
maxQ = minQ
}
newQueue = clampInt(desiredQueue, minQ, maxQ)
newQueue = clampInt(newQueue, newMax, capQueue)
setRuntimeInt(ctx, runtimeQueueLimitKey(modelName), newQueue)
}
out = append(out, AutoTuneResult{
ModelName: modelName,
Samples: s.Cnt,
P90Exec: s.P90Exec,
CapMaxConcurrency: capMax,
OldMaxConcurrency: oldMax,
NewMaxConcurrency: newMax,
CapQueueLimit: capQueue,
OldQueueLimit: oldQueue,
NewQueueLimit: newQueue,
ExpectedSeconds: m.ExpectedSeconds,
})
}
g.Log().Infof(ctx, "[auto_tune] done models=%d windowSeconds=%d", len(out), windowSeconds)
return out, nil
}

88
service/callback.go Normal file
View File

@@ -0,0 +1,88 @@
package service
import (
"context"
"fmt"
"io"
"net/http"
"net/url"
"strings"
"time"
"model-asynch/model/entity"
"github.com/gogf/gf/v2/frame/g"
)
// triggerSuccessCallback 任务成功后的回调钩子:
// - 使用 GET 请求
// - 回调地址为 callbackUrl + "/" + bizName
// - query 参数task_id/state/oss_file/file_type
// 注意:回调失败不影响任务主流程,只记录日志。
func triggerSuccessCallback(ctx context.Context, t *entity.AsynchTask) {
if t == nil {
return
}
callbackURL := strings.TrimSpace(t.CallbackURL)
bizName := strings.TrimSpace(t.BizName)
if callbackURL == "" || bizName == "" {
return
}
u, err := url.Parse(callbackURL)
if err != nil {
g.Log().Warningf(ctx, "[callback] invalid callbackUrl=%s err=%v", callbackURL, err)
return
}
// 必须是可发起 HTTP 请求的绝对地址
if u.Scheme == "" || u.Host == "" {
g.Log().Warningf(ctx, "[callback] callbackUrl must be absolute http(s) url, got=%s", callbackURL)
return
}
// path 末尾拼接 bizName
bizSeg := url.PathEscape(bizName)
if strings.HasSuffix(u.Path, "/") || u.Path == "" {
u.Path = strings.TrimRight(u.Path, "/") + "/" + bizSeg
} else {
u.Path = u.Path + "/" + bizSeg
}
q := u.Query()
q.Set("task_id", t.TaskID)
q.Set("state", fmt.Sprintf("%d", t.State))
q.Set("oss_file", t.OssFile)
q.Set("file_type", t.FileType)
u.RawQuery = q.Encode()
req, err := http.NewRequestWithContext(ctx, http.MethodGet, u.String(), nil)
if err != nil {
g.Log().Warningf(ctx, "[callback] build request failed url=%s err=%v", u.String(), err)
return
}
// 透传必要头部(如 Authorization / X-User-Info
for k, v := range forwardHeaders(ctx) {
if v != "" {
req.Header.Set(k, v)
}
}
client := &http.Client{Timeout: 5 * time.Second}
resp, err := client.Do(req)
if err != nil {
g.Log().Warningf(ctx, "[callback] request failed url=%s err=%v", u.String(), err)
return
}
defer resp.Body.Close()
b, _ := io.ReadAll(resp.Body)
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
msg := string(b)
if len(msg) > 2000 {
msg = msg[:2000]
}
g.Log().Warningf(ctx, "[callback] non-2xx url=%s code=%d body=%s", u.String(), resp.StatusCode, msg)
return
}
g.Log().Infof(ctx, "[callback] success url=%s code=%d", u.String(), resp.StatusCode)
}

92
service/cleaner.go Normal file
View File

@@ -0,0 +1,92 @@
package service
import (
"context"
"time"
"model-asynch/dao"
"github.com/gogf/gf/v2/frame/g"
)
var Cleaner = &cleaner{}
type cleaner struct{}
// RunOnce 由上层定时任务触发:执行一次清理/重试
func (c *cleaner) RunOnce(ctx context.Context) {
// 1) 清理已下载(state=4)且过期的任务(硬删除 + OSS
expired, err := dao.Task.ListExpiredDownloadedGlobal(ctx, 200)
if err != nil {
g.Log().Errorf(ctx, "[cleaner] list expired(downloaded) error: %v", err)
} else {
for _, t := range expired {
deleteTmpResult(t.TmpFile)
_ = dao.Task.HardDeleteByIDGlobal(ctx, t.Id)
}
g.Log().Infof(ctx, "[cleaner] expired(downloaded) cleaned, count=%d", len(expired))
}
// 2) 超时任务标失败
list, err := dao.Task.ListTimeoutTasksGlobal(ctx, 200)
if err != nil {
g.Log().Errorf(ctx, "[cleaner] list timeout error: %v", err)
} else {
for _, t := range list {
_ = dao.Task.UpdateFailedGlobal(ctx, t.Id, "任务超时自动失败")
ReleaseQueueSlot(ctx, t.ModelName, t.TaskID)
}
g.Log().Infof(ctx, "[cleaner] timeout cleaned, count=%d", len(list))
}
// 3) 失败(state=3)的任务按模型配置 retry_times 重新入队(放到队尾)
retryable, err := dao.Task.ListFailedRetryableGlobal(ctx, 200)
if err != nil {
g.Log().Errorf(ctx, "[cleaner] list failed retryable error: %v", err)
} else {
for _, t := range retryable {
// 失败任务重新入队state=3 -> 0先严格占用 queue_limit slot占用失败则留在失败态下一轮再尝试
// 获取模型配置以得到 queue_limit / expected_seconds
m, err := dao.Model.GetByModelNameForTenant(ctx, t.TenantId, t.ModelName)
if err != nil || m == nil {
continue
}
limit := GetRuntimeQueueLimit(ctx, t.ModelName, m.QueueLimit)
if limit > 0 {
ok, _ := AcquireQueueSlot(ctx, t.ModelName, t.TaskID, limit, m.ExpectedSeconds)
if !ok {
continue
}
}
// retry_queue_max_seconds 控制失败重试的排队策略:
// - =0失败重试插队到队首
// - >0当任务从创建到现在的排队时长 >= maxSeconds则插队到队首否则仍放到队尾
now := time.Now()
enqueueAt := now
maxSeconds := t.RetryQueueMaxSeconds
if maxSeconds == 0 {
enqueueAt = now.Add(-100 * 365 * 24 * time.Hour)
} else if maxSeconds > 0 && t.CreatedAt != nil {
if now.Sub(t.CreatedAt.Time) >= time.Duration(maxSeconds)*time.Second {
enqueueAt = now.Add(-100 * 365 * 24 * time.Hour)
}
}
_ = dao.Task.RequeueForRetryGlobal(ctx, t.Id, enqueueAt)
}
g.Log().Infof(ctx, "[cleaner] failed retryable cleaned, count=%d", len(retryable))
}
// 4) 超过重试次数仍失败(state=3)的任务:硬删除
exhausted, err := dao.Task.ListFailedExhaustedGlobal(ctx, 200)
if err != nil {
g.Log().Errorf(ctx, "[cleaner] list failed exhausted error: %v", err)
} else {
for _, t := range exhausted {
deleteTmpResult(t.TmpFile)
// 重试耗尽硬删除:释放闸门占位(兜底,若此前已释放则幂等)
ReleaseQueueSlot(ctx, t.ModelName, t.TaskID)
_ = dao.Task.HardDeleteByIDGlobal(ctx, t.Id)
}
g.Log().Infof(ctx, "[cleaner] failed exhausted cleaned, count=%d", len(exhausted))
}
}

35
service/file_detect.go Normal file
View File

@@ -0,0 +1,35 @@
package service
import (
"net/http"
"strings"
)
// DetectFileType 根据返回的二进制内容推断 contentType + 扩展名(尽量稳定)
func DetectFileType(data []byte) (contentType string, ext string) {
if len(data) == 0 {
return "application/octet-stream", ""
}
ct := http.DetectContentType(data)
switch ct {
case "audio/mpeg":
return ct, ".mp3"
case "audio/wave", "audio/wav", "audio/x-wav":
return ct, ".wav"
case "video/mp4":
return ct, ".mp4"
case "image/png":
return ct, ".png"
case "image/jpeg":
return ct, ".jpg"
case "application/pdf":
return ct, ".pdf"
default:
// 兜底:尝试从 ct 截取 subtype 作为后缀(例如 application/json
if parts := strings.Split(ct, "/"); len(parts) == 2 {
return ct, "." + parts[1]
}
return ct, ""
}
}

54
service/headers.go Normal file
View File

@@ -0,0 +1,54 @@
package service
import (
"context"
"gitea.com/red-future/common/utils"
"github.com/gogf/gf/v2/frame/g"
)
// asyncCtx 固化异步执行所需的 token/user避免请求结束后丢失仅在“同请求内起 goroutine”有用
// 本项目当前是“落库 + 后台 worker”模式因此还会把必要信息持久化到任务表的 request_payload 中。
func asyncCtx(ctx context.Context) context.Context {
asyncCtx := context.WithoutCancel(ctx)
if r := g.RequestFromCtx(ctx); r != nil {
if token := r.Header.Get("Authorization"); token != "" {
asyncCtx = context.WithValue(asyncCtx, "token", token)
}
if userInfo := r.Header.Get("X-User-Info"); userInfo != "" {
asyncCtx = context.WithValue(asyncCtx, "xUserInfo", userInfo)
}
}
if user, err := utils.GetUserInfo(ctx); err == nil && user != nil {
asyncCtx = context.WithValue(asyncCtx, "user", user)
}
return asyncCtx
}
// forwardHeaders 透传调用链路中必须的头信息(优先使用 ctx 里固化的 token / xUserInfo
func forwardHeaders(ctx context.Context) map[string]string {
headers := make(map[string]string)
if token, ok := ctx.Value("token").(string); ok && token != "" {
headers["Authorization"] = token
}
if x, ok := ctx.Value("xUserInfo").(string); ok && x != "" {
headers["X-User-Info"] = x
}
// 兜底:从请求头拿
if r := g.RequestFromCtx(ctx); r != nil {
if headers["Authorization"] == "" {
if token := r.Header.Get("Authorization"); token != "" {
headers["Authorization"] = token
}
}
if headers["X-User-Info"] == "" {
if userInfo := r.Header.Get("X-User-Info"); userInfo != "" {
headers["X-User-Info"] = userInfo
}
}
}
return headers
}

185
service/model_invoker.go Normal file
View File

@@ -0,0 +1,185 @@
package service
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"net/url"
"strings"
"time"
"model-asynch/model/entity"
)
// parseHeadMsgHeaders 支持多个 header 绑定,逗号分隔:
// 示例:
// - X-API-Key:qwen3-tts-key,operation:true,count:123
// - X-API-Key:"qwen3-tts-key",operation:"true"
//
// 说明:
// - HTTP Header 最终都是字符串,这里做的是“值的字符串化表达”。
// - 若 value 用双引号包裹,会去掉外层引号再注入,便于在配置中区分字符串/布尔/数字等表达(以及避免值中包含特殊字符时歧义)。
func parseHeadMsgHeaders(headMsg string) map[string]string {
headMsg = strings.TrimSpace(headMsg)
if headMsg == "" {
return nil
}
out := map[string]string{}
parts := strings.Split(headMsg, ",")
for _, p := range parts {
p = strings.TrimSpace(p)
if p == "" {
continue
}
// HeaderName:HeaderValue推荐 / HeaderName=HeaderValue兼容
if strings.Contains(p, ":") {
kv := strings.SplitN(p, ":", 2)
k := strings.TrimSpace(kv[0])
v := strings.TrimSpace(kv[1])
v = strings.Trim(v, "\"")
if k != "" && v != "" {
out[k] = v
}
continue
}
if strings.Contains(p, "=") {
kv := strings.SplitN(p, "=", 2)
k := strings.TrimSpace(kv[0])
v := strings.TrimSpace(kv[1])
v = strings.Trim(v, "\"")
if k != "" && v != "" {
out[k] = v
}
continue
}
}
if len(out) == 0 {
return nil
}
return out
}
func payloadToQuery(payload any) (url.Values, error) {
if payload == nil {
return url.Values{}, nil
}
// 统一转成 map[string]any
b, err := json.Marshal(payload)
if err != nil {
return nil, err
}
m := map[string]any{}
if err := json.Unmarshal(b, &m); err != nil {
return nil, err
}
q := url.Values{}
for k, v := range m {
if v == nil {
continue
}
// 复杂类型直接 json 字符串化
switch vv := v.(type) {
case string:
q.Set(k, vv)
case float64, bool, int, int64, uint64:
q.Set(k, fmt.Sprintf("%v", vv))
default:
bs, _ := json.Marshal(v)
q.Set(k, string(bs))
}
}
return q, nil
}
// InvokeModel 调用模型服务,返回二进制结果
// modelKey 用于覆盖/补充模型配置 head_msg例如每次请求携带不同的 X-API-Key
func InvokeModel(ctx context.Context, m *entity.AsynchModel, payload any, modelKey string) ([]byte, error) {
if m == nil || m.BaseURL == "" {
return nil, fmt.Errorf("模型配置不完整")
}
url := strings.TrimRight(m.BaseURL, "/") + "/" + strings.TrimLeft(m.Route, "/")
if strings.TrimSpace(m.Route) == "" {
url = strings.TrimRight(m.BaseURL, "/")
}
timeout := time.Duration(m.TimeoutSeconds) * time.Second
if timeout <= 0 {
timeout = 60 * time.Second
}
client := &http.Client{Timeout: timeout}
method := strings.ToUpper(strings.TrimSpace(m.HttpMethod))
if method == "" {
method = http.MethodPost
}
var (
req *http.Request
err error
)
switch method {
case http.MethodGet:
q, err := payloadToQuery(payload)
if err != nil {
return nil, err
}
if len(q) > 0 {
if strings.Contains(url, "?") {
url = url + "&" + q.Encode()
} else {
url = url + "?" + q.Encode()
}
}
req, err = http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
default:
bodyBytes, err := json.Marshal(payload)
if err != nil {
return nil, err
}
req, err = http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(bodyBytes))
}
if err != nil {
return nil, err
}
// 先注入模型配置 head_msg静态头部
for hk, hv := range parseHeadMsgHeaders(m.HeadMsg) {
req.Header.Set(hk, hv)
}
// 透传必要头部(如 Authorization / X-User-Info
for k, v := range forwardHeaders(ctx) {
if v != "" {
req.Header.Set(k, v)
}
}
// 最后注入动态 modelKey覆盖/补充静态 head_msg
for hk, hv := range parseHeadMsgHeaders(modelKey) {
req.Header.Set(hk, hv)
}
if method != http.MethodGet {
req.Header.Set("Content-Type", "application/json")
}
resp, err := client.Do(req)
if err != nil {
return nil, err
}
defer resp.Body.Close()
b, err := io.ReadAll(resp.Body)
if err != nil {
return nil, err
}
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
// 尽量把错误体带回去,方便排查
msg := string(b)
if len(msg) > 2000 {
msg = msg[:2000]
}
return nil, fmt.Errorf("模型服务返回非2xx: %d, body=%s", resp.StatusCode, msg)
}
return b, nil
}

125
service/model_service.go Normal file
View File

@@ -0,0 +1,125 @@
package service
import (
"context"
"errors"
"model-asynch/dao"
"model-asynch/model/dto"
"model-asynch/model/entity"
)
var Model = &modelService{}
type modelService struct{}
func (s *modelService) Create(ctx context.Context, req *dto.CreateModelReq) (res *dto.CreateModelRes, err error) {
m := &entity.AsynchModel{
ModelName: req.ModelName,
ModelsType: normalizeModelsType(req.ModelsType),
BaseURL: req.BaseURL,
Route: req.Route,
HttpMethod: req.HttpMethod,
HeadMsg: req.HeadMsg,
Form: req.Form,
Enabled: req.Enabled,
MaxConcurrency: req.MaxConcurrency,
QueueLimit: req.QueueLimit,
TimeoutSeconds: req.TimeoutSeconds,
ExpectedSeconds: req.ExpectedSeconds,
RetryTimes: req.RetryTimes,
RetryQueueMaxSecs: req.RetryQueueMaxSeconds,
AutoCleanSeconds: req.AutoCleanSeconds,
Remark: req.Remark,
}
if m.HttpMethod == "" {
m.HttpMethod = "POST"
}
if m.Enabled == 0 {
m.Enabled = 1
}
if m.MaxConcurrency <= 0 {
m.MaxConcurrency = 10
}
if m.QueueLimit <= 0 {
m.QueueLimit = 1000
}
if m.TimeoutSeconds <= 0 {
m.TimeoutSeconds = 60
}
if m.AutoCleanSeconds <= 0 {
m.AutoCleanSeconds = 86400
}
id, err := dao.Model.Insert(ctx, m)
if err != nil {
return nil, err
}
return &dto.CreateModelRes{ID: id}, nil
}
func (s *modelService) Update(ctx context.Context, req *dto.UpdateModelReq) error {
data := map[string]any{}
if req.BaseURL != "" {
data[entity.AsynchModelCol.BaseURL] = req.BaseURL
}
if req.Route != "" {
data[entity.AsynchModelCol.Route] = req.Route
}
if req.HttpMethod != nil && *req.HttpMethod != "" {
data[entity.AsynchModelCol.HttpMethod] = *req.HttpMethod
}
if req.HeadMsg != nil {
data[entity.AsynchModelCol.HeadMsg] = *req.HeadMsg
}
if req.Form != nil {
data[entity.AsynchModelCol.FormJSON] = req.Form
}
if req.ModelsType != nil {
data[entity.AsynchModelCol.ModelsType] = normalizeModelsType(*req.ModelsType)
}
if req.Enabled != nil {
data[entity.AsynchModelCol.Enabled] = *req.Enabled
}
if req.MaxConcurrency != nil {
data[entity.AsynchModelCol.MaxConcurrency] = *req.MaxConcurrency
}
if req.QueueLimit != nil {
data[entity.AsynchModelCol.QueueLimit] = *req.QueueLimit
}
if req.TimeoutSeconds != nil {
data[entity.AsynchModelCol.TimeoutSeconds] = *req.TimeoutSeconds
}
if req.ExpectedSeconds != nil {
data[entity.AsynchModelCol.ExpectedSeconds] = *req.ExpectedSeconds
}
if req.RetryTimes != nil {
data[entity.AsynchModelCol.RetryTimes] = *req.RetryTimes
}
if req.RetryQueueMaxSeconds != nil {
data[entity.AsynchModelCol.RetryQueueMaxSecs] = *req.RetryQueueMaxSeconds
}
if req.AutoCleanSeconds != nil {
data[entity.AsynchModelCol.AutoCleanSeconds] = *req.AutoCleanSeconds
}
if req.Remark != nil {
data[entity.AsynchModelCol.Remark] = *req.Remark
}
if len(data) == 0 {
return errors.New("无可更新字段")
}
_, err := dao.Model.UpdateByID(ctx, req.ID, data)
return err
}
func (s *modelService) Delete(ctx context.Context, id int64) error {
_, err := dao.Model.DeleteByID(ctx, id)
return err
}
func (s *modelService) Get(ctx context.Context, id int64) (*entity.AsynchModel, error) {
return dao.Model.GetByID(ctx, id)
}
func (s *modelService) List(ctx context.Context, pageNum, pageSize int, modelNameLike string) (list []*entity.AsynchModel, total int64, err error) {
return dao.Model.List(ctx, pageNum, pageSize, modelNameLike)
}

View File

@@ -0,0 +1,217 @@
package service
import (
"context"
"encoding/json"
"errors"
"strings"
"model-asynch/dao"
"model-asynch/model/dto"
"model-asynch/model/entity"
"github.com/gogf/gf/v2/container/gvar"
)
type modelTypeService struct{}
var ModelType = &modelTypeService{}
func normalizeFormValue(v any) any {
// 目标:对外永远返回 JSON 数组/对象,而不是字符串。
if v == nil {
return []any{}
}
switch t := v.(type) {
case string:
s := strings.TrimSpace(t)
if s == "" {
return []any{}
}
return normalizeFormValueFromJSONString(s)
case []byte:
if len(t) == 0 {
return []any{}
}
return normalizeFormValueFromJSONBytes(t)
case *gvar.Var:
// goframe 常见的 DB 返回类型
if t == nil {
return []any{}
}
b := t.Bytes()
if len(b) > 0 {
return normalizeFormValueFromJSONBytes(b)
}
s := strings.TrimSpace(t.String())
if s == "" {
return []any{}
}
return normalizeFormValueFromJSONString(s)
default:
// 尝试兼容其他“像 JSON 的值类型”(例如实现了 Bytes/String 的包装类型)
if vb, ok := v.(interface{ Bytes() []byte }); ok {
if b := vb.Bytes(); len(b) > 0 {
return normalizeFormValueFromJSONBytes(b)
}
}
if vs, ok := v.(interface{ String() string }); ok {
if s := strings.TrimSpace(vs.String()); s != "" {
return normalizeFormValueFromJSONString(s)
}
}
// 已经是 []any / map[string]any 等结构
return v
}
}
// 兼容“JSONB 里存了 JSON 字符串”的历史数据:
// 例如 form_json = '"[]"' 或 '"[{...}]"'(外层是字符串,内层才是数组/对象)
func normalizeFormValueFromJSONString(s string) any {
var out any
if err := json.Unmarshal([]byte(s), &out); err != nil || out == nil {
return []any{}
}
// 如果解出来还是 string且看起来是 JSON再解一层
if inner, ok := out.(string); ok {
inner = strings.TrimSpace(inner)
if inner == "" {
return []any{}
}
if strings.HasPrefix(inner, "[") || strings.HasPrefix(inner, "{") {
var out2 any
if err := json.Unmarshal([]byte(inner), &out2); err == nil && out2 != nil {
return out2
}
}
return []any{}
}
return out
}
func normalizeFormValueFromJSONBytes(b []byte) any {
var out any
if err := json.Unmarshal(b, &out); err != nil || out == nil {
return []any{}
}
// bytes 解出来也可能是 string同上
if inner, ok := out.(string); ok {
return normalizeFormValueFromJSONString(inner)
}
return out
}
func (s *modelTypeService) Create(ctx context.Context, req *dto.CreateModelTypeReq) (res *dto.CreateModelTypeRes, err error) {
t := &entity.AsynchModelType{
TypeID: req.TypeID,
TypeName: req.TypeName,
Remark: req.Remark,
}
id, err := dao.ModelType.Insert(ctx, t)
if err != nil {
return nil, err
}
return &dto.CreateModelTypeRes{ID: id}, nil
}
func (s *modelTypeService) Update(ctx context.Context, req *dto.UpdateModelTypeReq) error {
data := map[string]any{}
if req.TypeID != nil {
data[entity.AsynchModelTypeCol.TypeID] = *req.TypeID
}
if req.TypeName != nil {
data[entity.AsynchModelTypeCol.TypeName] = *req.TypeName
}
if req.Remark != nil {
data[entity.AsynchModelTypeCol.Remark] = *req.Remark
}
if len(data) == 0 {
return errors.New("无可更新字段")
}
_, err := dao.ModelType.UpdateByID(ctx, req.ID, data)
return err
}
func (s *modelTypeService) Delete(ctx context.Context, id int64) error {
_, err := dao.ModelType.DeleteByID(ctx, id)
return err
}
func (s *modelTypeService) Get(ctx context.Context, id int64) (*entity.AsynchModelType, error) {
return dao.ModelType.GetByID(ctx, id)
}
func (s *modelTypeService) List(ctx context.Context, pageNum, pageSize int, typeNameLike string) (list []*entity.AsynchModelType, total int64, err error) {
return dao.ModelType.List(ctx, pageNum, pageSize, typeNameLike)
}
// ListWithModels 按类型分组返回模型(返回数组,便于前端直接渲染)
func (s *modelTypeService) ListWithModels(ctx context.Context, req *dto.ListModelTypeWithModelsReq) (res []dto.ModelTypeWithModelsItem, err error) {
types, _, err := dao.ModelType.List(ctx, 1, 1000, "")
if err != nil {
return nil, err
}
// 过滤类型(按 typeId / typeName 模糊)
filterTypeID := 0
filterTypeName := ""
if req != nil {
filterTypeID = req.TypeID
filterTypeName = strings.TrimSpace(req.Type)
}
typeIDs := make([]int, 0, len(types))
typeNameMap := make(map[int]string, len(types))
for _, t := range types {
if t == nil {
continue
}
if filterTypeID > 0 && t.TypeID != filterTypeID {
continue
}
if filterTypeName != "" && !strings.Contains(t.TypeName, filterTypeName) {
continue
}
typeIDs = append(typeIDs, t.TypeID)
typeNameMap[t.TypeID] = t.TypeName
}
models, err := dao.Model.ListAll(ctx)
if err != nil {
return nil, err
}
itemsMap := map[int][]dto.ModelTypeModelItem{}
for _, m := range models {
if m == nil {
continue
}
form := normalizeFormValue(m.Form)
// 一个模型可能支持多个类型models_type="1,2,3"
for _, tid := range parseModelsTypeIDs(m.ModelsType) {
// 若请求过滤了类型,则只输出该类型
if filterTypeID > 0 && tid != filterTypeID {
continue
}
if filterTypeName != "" {
if _, ok := typeNameMap[tid]; !ok {
continue
}
}
itemsMap[tid] = append(itemsMap[tid], dto.ModelTypeModelItem{
ID: m.Id,
Name: m.ModelName,
Form: form,
})
}
}
out := make([]dto.ModelTypeWithModelsItem, 0, len(typeIDs))
for _, tid := range typeIDs {
items := itemsMap[tid]
if items == nil {
items = make([]dto.ModelTypeModelItem, 0)
}
out = append(out, dto.ModelTypeWithModelsItem{
TypeID: tid,
Type: typeNameMap[tid],
Items: items,
})
}
return out, nil
}

View File

@@ -0,0 +1,52 @@
package service
import (
"sort"
"strconv"
"strings"
)
// normalizeModelsType 将 "1, 2,2,3" 归一化为 "1,2,3"
// - 去空格
// - 去重
// - 升序排序
func normalizeModelsType(v string) string {
ids := parseModelsTypeIDs(v)
if len(ids) == 0 {
return ""
}
parts := make([]string, 0, len(ids))
for _, id := range ids {
parts = append(parts, strconv.Itoa(id))
}
return strings.Join(parts, ",")
}
// parseModelsTypeIDs 解析 models_type 字段(支持 "1,2,3"),返回去重后的 int 列表(升序)。
func parseModelsTypeIDs(v string) []int {
v = strings.TrimSpace(v)
if v == "" {
return nil
}
raw := strings.Split(v, ",")
seen := map[int]struct{}{}
out := make([]int, 0, len(raw))
for _, s := range raw {
s = strings.TrimSpace(s)
if s == "" || s == "0" {
continue
}
id, err := strconv.Atoi(s)
if err != nil || id <= 0 {
continue
}
if _, ok := seen[id]; ok {
continue
}
seen[id] = struct{}{}
out = append(out, id)
}
sort.Ints(out)
return out
}

25
service/payload.go Normal file
View File

@@ -0,0 +1,25 @@
package service
import "github.com/gogf/gf/v2/util/gconv"
// parseStoredPayload 解析入库的 request_payload拆出模型调用 payload 与透传 headers
// 入库格式:{"payload": <any>, "headers": {"Authorization": "...", "X-User-Info":"..."}}
func parseStoredPayload(v any) (payload any, headers map[string]string) {
if v == nil {
return nil, nil
}
m := gconv.Map(v)
if len(m) == 0 {
return v, nil
}
if h, ok := m["headers"]; ok {
headers = gconv.MapStrStr(h)
}
if p, ok := m["payload"]; ok {
payload = p
} else {
payload = v
}
return
}

107
service/queue_gate.go Normal file
View File

@@ -0,0 +1,107 @@
package service
import (
"context"
"fmt"
"math"
"time"
"github.com/gogf/gf/v2/frame/g"
"github.com/gogf/gf/v2/util/gconv"
)
// ===== 严格 queue_limitRedis 原子闸门 =====
//
// 背景:原来的 queue_limit 通过“Count + Insert”做近似控制分布式并发创建时会短暂超限。
// 目标:以 Redis Lua 脚本实现原子校验 + 入队占位,做到严格不超限。
//
// 计数口径与原逻辑保持一致:只统计 state=0/1排队中/执行中)。
// - CreateTask 成功入库后占用 1 个 slot
// - 任务成功/失败state->2/3释放 slot
// - 失败任务重试state 3->0需要再次占用 slot若占位失败则暂不重试留在 state=3下次 cleaner 再尝试)
//
// 说明:为避免极端情况下“占位泄漏”导致永久占满,采用 ZSET + 过期时间的方式自动回收。
// 只要任务实际生命周期远小于 gateTTLSeconds就可保持严格。
const (
queueGateKeyPrefix = "asynch:qgate:" // asynch:qgate:{modelName}
)
// Lua清理过期 slot然后按 limit 做原子判定并占位
var queueGateAcquireLua = `
local key = KEYS[1]
local now = tonumber(ARGV[1])
local limit = tonumber(ARGV[2])
local expireAt = tonumber(ARGV[3])
local member = ARGV[4]
local keyTTL = tonumber(ARGV[5])
-- 先清理过期的占位
redis.call("ZREMRANGEBYSCORE", key, "-inf", now)
local current = tonumber(redis.call("ZCARD", key) or "0")
if current >= limit then
return 0
end
redis.call("ZADD", key, expireAt, member)
redis.call("EXPIRE", key, keyTTL)
return 1
`
// Lua释放 slot幂等
var queueGateReleaseLua = `
local key = KEYS[1]
local member = ARGV[1]
redis.call("ZREM", key, member)
return 1
`
func queueGateKey(modelName string) string {
return fmt.Sprintf("%s%s", queueGateKeyPrefix, modelName)
}
// calcGateTTLSeconds 计算闸门占位的“自动回收 TTL”
// 取 expectedSeconds 的倍数并做上下限,避免任务异常导致永久占位。
func calcGateTTLSeconds(expectedSeconds int) int {
// 默认至少 1 小时;最多 24 小时
minTTL := 3600
maxTTL := 24 * 3600
if expectedSeconds <= 0 {
return minTTL
}
ttl := int(math.Ceil(float64(expectedSeconds) * 10)) // 预计耗时 * 10 做兜底
if ttl < minTTL {
ttl = minTTL
}
if ttl > maxTTL {
ttl = maxTTL
}
return ttl
}
// AcquireQueueSlot 严格入队:原子占位(成功返回 true
func AcquireQueueSlot(ctx context.Context, modelName, taskId string, limit int, expectedSeconds int) (bool, error) {
if limit <= 0 {
return true, nil
}
key := queueGateKey(modelName)
now := time.Now().Unix()
ttl := calcGateTTLSeconds(expectedSeconds)
expireAt := now + int64(ttl)
// keyTTL 要略大于 member TTL避免 key 先过期导致计数丢失
keyTTL := ttl + 60
r, err := g.Redis().Do(ctx, "EVAL", queueGateAcquireLua, 1, key, now, limit, expireAt, taskId, keyTTL)
if err != nil {
return false, fmt.Errorf("queue gate acquire failed: %w", err)
}
return gconv.Int(r) == 1, nil
}
// ReleaseQueueSlot 释放占位(幂等)
func ReleaseQueueSlot(ctx context.Context, modelName, taskId string) {
if taskId == "" || modelName == "" {
return
}
key := queueGateKey(modelName)
_, _ = g.Redis().Do(ctx, "EVAL", queueGateReleaseLua, 1, key, taskId)
}

83
service/runtime_tune.go Normal file
View File

@@ -0,0 +1,83 @@
package service
import (
"context"
"github.com/gogf/gf/v2/frame/g"
"github.com/gogf/gf/v2/util/gconv"
)
// 运行时调参存储在 Redis不修改 asynch_models 中的 cap最大上限
// 上层每小时调用 /model/autoTune 写入运行时值Worker/CreateTask 读取运行时值生效。
const (
runtimeMaxCKeyPrefix = "asynch:runtime:max_concurrency:" // + model_name
runtimeQueueKeyPrefix = "asynch:runtime:queue_limit:" // + model_name
runtimeTTLSeconds = 2 * 3600 // 2小时避免一次调参失败导致立即回退
)
func runtimeMaxConcurrencyKey(modelName string) string {
return runtimeMaxCKeyPrefix + modelName
}
func runtimeQueueLimitKey(modelName string) string {
return runtimeQueueKeyPrefix + modelName
}
func getRuntimeInt(ctx context.Context, key string) (int, bool) {
v, err := g.Redis().Do(ctx, "GET", key)
if err != nil || v == nil {
return 0, false
}
iv := gconv.Int(v)
if iv <= 0 {
return 0, false
}
return iv, true
}
func setRuntimeInt(ctx context.Context, key string, val int) {
if val <= 0 {
return
}
// SETEX key ttl val
_, _ = g.Redis().Do(ctx, "SETEX", key, runtimeTTLSeconds, val)
}
// GetRuntimeMaxConcurrency 返回运行时并发上限(<= cap。若不存在运行时值则返回 cap。
func GetRuntimeMaxConcurrency(ctx context.Context, modelName string, cap int) int {
if cap <= 0 {
return cap
}
if v, ok := getRuntimeInt(ctx, runtimeMaxConcurrencyKey(modelName)); ok {
if v > cap {
return cap
}
return v
}
return cap
}
// GetRuntimeQueueLimit 返回运行时队列上限(<= cap。若不存在运行时值则返回 cap。
func GetRuntimeQueueLimit(ctx context.Context, modelName string, cap int) int {
if cap <= 0 {
return cap
}
if v, ok := getRuntimeInt(ctx, runtimeQueueLimitKey(modelName)); ok {
if v > cap {
return cap
}
return v
}
return cap
}
func clampInt(v, minV, maxV int) int {
if v < minV {
return minV
}
if v > maxV {
return maxV
}
return v
}

56
service/semaphore.go Normal file
View File

@@ -0,0 +1,56 @@
package service
import (
"context"
"fmt"
"github.com/gogf/gf/v2/frame/g"
"github.com/gogf/gf/v2/util/gconv"
)
var acquireLua = `
local current = tonumber(redis.call("GET", KEYS[1]) or "0")
local max = tonumber(ARGV[1])
local ttl = tonumber(ARGV[2])
if current >= max then
return 0
end
current = redis.call("INCR", KEYS[1])
if current == 1 then
redis.call("EXPIRE", KEYS[1], ttl)
end
if current > max then
redis.call("DECR", KEYS[1])
return 0
end
return 1
`
var releaseLua = `
local current = tonumber(redis.call("DECR", KEYS[1]) or "0")
if current <= 0 then
redis.call("DEL", KEYS[1])
end
return 1
`
func acquireSemaphore(ctx context.Context, key string, max int, ttlSeconds int64) (bool, error) {
if max <= 0 {
// 不限制
return true, nil
}
if ttlSeconds <= 0 {
ttlSeconds = 3600
}
r, err := g.Redis().Do(ctx, "EVAL", acquireLua, 1, key, max, ttlSeconds)
if err != nil {
return false, fmt.Errorf("获取并发令牌失败: %w", err)
}
return gconv.Int(r) == 1, nil
}
func releaseSemaphore(ctx context.Context, key string) error {
_, err := g.Redis().Do(ctx, "EVAL", releaseLua, 1, key)
return err
}

40
service/stat_service.go Normal file
View File

@@ -0,0 +1,40 @@
package service
import (
"context"
"model-asynch/dao"
"model-asynch/model/dto"
)
type statService struct{}
var Stat = &statService{}
func (s *statService) List(ctx context.Context, req *dto.ListModelStatReq) (res *dto.ListModelStatRes, err error) {
pageNum, pageSize := 1, 10
if req != nil && req.Page != nil {
if req.Page.PageNum > 0 {
pageNum = int(req.Page.PageNum)
}
if req.Page.PageSize > 0 {
pageSize = int(req.Page.PageSize)
}
}
startDay, endDay := "", ""
var tenantID *int64
creator, modelName := "", ""
if req != nil {
startDay = req.StartDay
endDay = req.EndDay
tenantID = req.TenantID
creator = req.Creator
modelName = req.ModelName
}
list, total, err := dao.Stat.List(ctx, pageNum, pageSize, startDay, endDay, tenantID, creator, modelName)
if err != nil {
return nil, err
}
return &dto.ListModelStatRes{List: list, Total: total}, nil
}

18
service/storage.go Normal file
View File

@@ -0,0 +1,18 @@
package service
import (
"context"
"errors"
"model-asynch/model/entity"
)
// StorageService 结果存储OSS/MinIO抽象
type StorageService interface {
UploadByTask(ctx context.Context, t *entity.AsynchTask, data []byte, fileExt string, contentType string) (ossURL string, err error)
}
// Storage 默认存储实现(优先对接你们的 oss 文件服务;必要时也可以切到 MinIO
var Storage StorageService = &ossStorage{}
var ErrStorageNotConfigured = errors.New("存储未配置")

82
service/storage_oss.go Normal file
View File

@@ -0,0 +1,82 @@
package service
import (
"bytes"
"context"
"fmt"
"mime/multipart"
"time"
"model-asynch/model/entity"
commonHttp "gitea.com/red-future/common/http"
"github.com/gogf/gf/v2/frame/g"
"github.com/gogf/gf/v2/util/gconv"
"github.com/gogf/gf/v2/util/guid"
)
// 对接你们的 oss 文件服务POST oss/file/uploadFile (multipart/form-data)
type ossStorage struct{}
type uploadFileResponse struct {
FileURL string `json:"fileURL"` // 文件 URL
FileSize int `json:"fileSize"` // 文件大小(字节)
FileName string `json:"fileName"` // 文件名
FileFormat string `json:"fileFormat"` // 文件格式
FileAddressPrefix string `json:"fileAddressPrefix"` // 文件地址前缀
}
func (s *ossStorage) UploadByTask(ctx context.Context, _ *entity.AsynchTask, data []byte, fileExt string, _ string) (ossURL string, err error) {
// multipart
body := &bytes.Buffer{}
writer := multipart.NewWriter(body)
ext := fileExt
if ext == "" {
ext = ".bin"
}
if ext[0] != '.' {
ext = "." + ext
}
filename := fmt.Sprintf("asynch_%d_%s%s", time.Now().Unix(), guid.S(), ext)
part, err := writer.CreateFormFile("file", filename)
if err != nil {
return "", err
}
if _, err := part.Write(data); err != nil {
return "", err
}
contentType := writer.FormDataContentType()
if err := writer.Close(); err != nil {
return "", err
}
headers := forwardHeaders(ctx)
headers["Content-Type"] = contentType
fullURL := "oss/file/uploadFile"
g.Log().Infof(ctx, "[OSS] upload start url=%s filename=%s size=%d", fullURL, filename, len(data))
var resp uploadFileResponse
if err := commonHttp.Post(ctx, fullURL, headers, &resp, body.Bytes()); err != nil {
return "", err
}
fmt.Println("打印结果 resp:", resp)
g.Log().Infof(ctx, "[OSS] upload success url=%s size=%d format=%s", resp.FileURL, resp.FileSize, resp.FileFormat)
return resp.FileURL, nil
}
// setTaskHeadersToCtx 把任务入库时保存的 header 信息注入 ctx给 worker 调 OSS 用
func setTaskHeadersToCtx(ctx context.Context, headers map[string]string) context.Context {
if headers == nil {
return ctx
}
if v := gconv.String(headers["Authorization"]); v != "" {
ctx = context.WithValue(ctx, "token", v)
}
if v := gconv.String(headers["X-User-Info"]); v != "" {
ctx = context.WithValue(ctx, "xUserInfo", v)
}
return ctx
}

192
service/task_service.go Normal file
View File

@@ -0,0 +1,192 @@
package service
import (
"context"
"errors"
"time"
"model-asynch/dao"
"model-asynch/model/dto"
"model-asynch/model/entity"
"github.com/gogf/gf/v2/database/gdb"
"github.com/gogf/gf/v2/frame/g"
"github.com/gogf/gf/v2/os/gtime"
"github.com/google/uuid"
)
var Task = &taskService{}
type taskService struct{}
func (s *taskService) Create(ctx context.Context, req *dto.CreateTaskReq) (res *dto.CreateTaskRes, err error) {
startAt := time.Now()
// 固化 token/user 等信息
ctx = asyncCtx(ctx)
// 1) 检查模型配置
m, err := dao.Model.GetByModelName(ctx, req.ModelName)
if err != nil {
return nil, err
}
if m == nil || m.Enabled != 1 {
return nil, errors.New("模型不存在或未启用")
}
taskID := uuid.NewString()
// 2) 排队上限严格控制Redis 原子闸门)
limit := GetRuntimeQueueLimit(ctx, req.ModelName, m.QueueLimit)
if limit > 0 {
ok, err := AcquireQueueSlot(ctx, req.ModelName, taskID, limit, m.ExpectedSeconds)
if err != nil {
return nil, err
}
if !ok {
return nil, errors.New("任务排队已满,请稍后再试")
}
}
// 将调用模型的 payload 与透传头信息一起存入 request_payload供后台 worker 使用
storedPayload := map[string]any{
"payload": req.RequestPayload,
"headers": forwardHeaders(ctx),
}
t := &entity.AsynchTask{
ModelName: req.ModelName,
TaskID: taskID,
State: 0,
BizName: req.BizName,
CallbackURL: req.CallbackUrl,
ModelKey: req.ModelKey,
InputRef: req.InputRef,
RequestPayload: storedPayload,
}
_, err = dao.Task.Insert(ctx, t)
if err != nil {
// 入库失败:回滚闸门占位
ReleaseQueueSlot(ctx, req.ModelName, taskID)
return nil, err
}
// 3) 写操作日志(尽量不影响主流程,失败忽略)
ip := ""
ua := ""
apiPath := "/task/createTask"
httpMethod := "POST"
if r := g.RequestFromCtx(ctx); r != nil {
ip = r.GetClientIp()
ua = r.UserAgent()
apiPath = r.URL.Path
httpMethod = r.Method
}
_, _ = dao.OpLog.Insert(ctx, &entity.AsynchOpLog{
IP: ip,
UserAgent: ua,
APIPath: apiPath,
HttpMethod: httpMethod,
BizName: req.BizName,
ModelName: req.ModelName,
TaskID: taskID,
OpType: "createTask",
Success: 1,
ErrorMsg: "",
CostMs: time.Since(startAt).Milliseconds(),
RequestPayload: storedPayload,
ResponsePayload: gdb.Map{
"taskId": taskID,
},
})
return &dto.CreateTaskRes{TaskID: taskID}, nil
}
func (s *taskService) GetResult(ctx context.Context, taskID string) (res *dto.GetTaskResultRes, err error) {
t, err := dao.Task.GetByTaskID(ctx, taskID)
if err != nil {
return nil, err
}
if t == nil {
return nil, errors.New("任务不存在")
}
return &dto.GetTaskResultRes{
OssFile: t.OssFile,
State: t.State,
}, nil
}
// GetBatch 批量查询任务;将成功(state=2)的任务更新为已下载(state=4),并写入过期时间
func (s *taskService) GetBatch(ctx context.Context, req *dto.GetTaskBatchReq) (res *dto.GetTaskBatchRes, err error) {
if req == nil || len(req.TaskIDs) == 0 {
return &dto.GetTaskBatchRes{List: []dto.GetTaskBatchItem{}}, nil
}
// 1) 先查当前租户下的任务列表
list, err := dao.Task.ListByTaskIDs(ctx, req.TaskIDs)
if err != nil {
return nil, err
}
// 2) 对成功(state=2)的任务:标记为已下载(state=4)并写入 expire_at
now := time.Now()
for _, t := range list {
if t == nil {
continue
}
if t.State != 2 {
continue
}
// 按模型配置决定保留时间
m, err := dao.Model.GetByModelName(ctx, t.ModelName)
if err != nil {
return nil, err
}
retainSeconds := 86400
if m != nil && m.AutoCleanSeconds > 0 {
retainSeconds = m.AutoCleanSeconds
}
expireAt := gtime.New(now.Add(time.Duration(retainSeconds) * time.Second))
_ = dao.Task.MarkDownloadedByID(ctx, t.Id, expireAt)
// 为了本次返回一致性,内存里也更新
t.State = 4
t.ExpireAt = expireAt
}
// 3) 组装返回
items := make([]dto.GetTaskBatchItem, 0, len(list))
for _, t := range list {
if t == nil {
continue
}
items = append(items, dto.GetTaskBatchItem{
TaskID: t.TaskID,
State: t.State,
OssFile: t.OssFile,
})
}
return &dto.GetTaskBatchRes{List: items}, nil
}
func (s *taskService) List(ctx context.Context, req *dto.ListTaskReq) (res *dto.ListTaskRes, err error) {
pageNum, pageSize := 1, 10
if req != nil && req.Page != nil {
if req.Page.PageNum > 0 {
pageNum = int(req.Page.PageNum)
}
if req.Page.PageSize > 0 {
pageSize = int(req.Page.PageSize)
}
}
modelName := ""
taskID := ""
var state *int
if req != nil {
modelName = req.ModelName
taskID = req.TaskID
state = req.State
}
list, total, err := dao.Task.List(ctx, pageNum, pageSize, modelName, taskID, state)
if err != nil {
return nil, err
}
return &dto.ListTaskRes{List: list, Total: total}, nil
}

38
service/tmp_store.go Normal file
View File

@@ -0,0 +1,38 @@
package service
import (
"fmt"
"os"
"path/filepath"
)
// saveTmpResult 将模型输出写入临时文件,用于 OSS 上传失败后的“仅重试 OSS”。
func saveTmpResult(taskID string, data []byte, ext string) (string, error) {
dir := filepath.Join(os.TempDir(), "model-asynch")
if err := os.MkdirAll(dir, 0o755); err != nil {
return "", err
}
if ext == "" {
ext = ".bin"
}
if ext[0] != '.' {
ext = "." + ext
}
path := filepath.Join(dir, fmt.Sprintf("%s%s", taskID, ext))
if err := os.WriteFile(path, data, 0o644); err != nil {
return "", err
}
return path, nil
}
func loadTmpResult(path string) ([]byte, error) {
return os.ReadFile(path)
}
func deleteTmpResult(path string) {
if path == "" {
return
}
_ = os.Remove(path)
}

176
service/worker.go Normal file
View File

@@ -0,0 +1,176 @@
package service
import (
"context"
"fmt"
"strings"
"time"
"model-asynch/dao"
"model-asynch/model/entity"
"github.com/gogf/gf/v2/frame/g"
"github.com/gogf/gf/v2/os/grpool"
)
var AsyncWorker = &asyncWorker{}
type asyncWorker struct {
}
// RunOnce 由上层定时任务触发:一次性抢占并处理一批任务
// - batchSize: 本次抢占数量
// - goroutines: 本次并发数(协程池大小)
func (w *asyncWorker) RunOnce(ctx context.Context, batchSize, goroutines int) (claimed int, err error) {
if batchSize <= 0 {
batchSize = 10
}
if goroutines <= 0 {
goroutines = 1
}
tasks, err := dao.Task.ClaimPendingGlobal(ctx, batchSize)
if err != nil {
return 0, err
}
if len(tasks) == 0 {
return 0, nil
}
pool := grpool.New(goroutines)
defer pool.Close()
claimed = len(tasks)
done := make(chan struct{}, claimed)
for _, t := range tasks {
task := t
_ = pool.AddWithRecover(ctx, func(ctx context.Context) {
w.handleOne(ctx, task)
done <- struct{}{}
}, func(ctx context.Context, e error) {
if e != nil {
_ = dao.Task.UpdateFailedGlobal(ctx, task.Id, fmt.Sprintf("worker panic: %v", e))
ReleaseQueueSlot(ctx, task.ModelName, task.TaskID)
}
done <- struct{}{}
})
}
for i := 0; i < claimed; i++ {
<-done
}
return claimed, nil
}
func (w *asyncWorker) handleOne(ctx context.Context, t *entity.AsynchTask) {
// 从任务入库的 request_payload 里恢复 payload + headers给 OSS 上传透传鉴权用
payload, headers := parseStoredPayload(t.RequestPayload)
if len(headers) > 0 {
ctx = setTaskHeadersToCtx(ctx, headers)
}
// 1) 拉取模型配置
m, err := dao.Model.GetByModelNameForTenant(ctx, t.TenantId, t.ModelName)
if err != nil {
_ = dao.Task.UpdateFailedGlobal(ctx, t.Id, err.Error())
ReleaseQueueSlot(ctx, t.ModelName, t.TaskID)
return
}
if m == nil || m.Enabled != 1 {
_ = dao.Task.UpdateFailedGlobal(ctx, t.Id, "模型不存在或未启用")
ReleaseQueueSlot(ctx, t.ModelName, t.TaskID)
return
}
// 2) 分布式并发限制(按 model_name 全局维度)
semKey := fmt.Sprintf("asynch:sem:%s", t.ModelName)
leaseSeconds := int64(3600) // 兜底1小时
maxC := GetRuntimeMaxConcurrency(ctx, t.ModelName, m.MaxConcurrency)
acquired, err := acquireSemaphore(ctx, semKey, maxC, leaseSeconds)
if err != nil {
_ = dao.Task.UpdateFailedGlobal(ctx, t.Id, err.Error())
ReleaseQueueSlot(ctx, t.ModelName, t.TaskID)
return
}
if !acquired {
// 并发满了:放回排队(重新置回 state=0下一轮再抢占
_ = w.rollbackToPending(ctx, t.Id)
return
}
defer func() {
_ = releaseSemaphore(ctx, semKey)
}()
// 3) 调用模型服务
if payload == nil {
payload = map[string]any{
"taskId": t.TaskID,
"inputRef": t.InputRef,
}
}
var (
data []byte
contentType string
ext string
)
// phase=1 表示模型已成功但 OSS 上传失败:优先从临时文件加载,避免重复跑模型
if t.Phase == 1 && strings.TrimSpace(t.TmpFile) != "" {
data, err = loadTmpResult(t.TmpFile)
if err == nil && len(data) > 0 {
contentType, ext = DetectFileType(data)
} else {
// 临时文件不可用:回退重新调用模型
data = nil
}
}
if data == nil {
// 统计:仅在真正请求模型时 +1OSS 重试不计入)
_ = dao.Stat.IncRequestCount(ctx, time.Now(), int64(t.TenantId), t.Creator, t.ModelName)
data, err = InvokeModel(ctx, m, payload, t.ModelKey)
if err != nil {
_ = dao.Task.UpdateFailedGlobal(ctx, t.Id, err.Error())
ReleaseQueueSlot(ctx, t.ModelName, t.TaskID)
return
}
contentType, ext = DetectFileType(data)
// 将模型输出写入临时文件,后续若 OSS 失败可只重试 OSS
tmpPath, err := saveTmpResult(t.TaskID, data, ext)
if err == nil && tmpPath != "" {
t.TmpFile = tmpPath
t.Phase = 1
_ = dao.Task.UpdateTmpAfterModelGlobal(ctx, t.Id, tmpPath)
}
}
// 4) 存储 OSS
ossURL, err := Storage.UploadByTask(ctx, t, data, ext, contentType)
if err != nil {
// OSS 阶段失败:保留临时文件,下一轮仅重试 OSS
_ = dao.Task.UpdateFailedKeepTmpGlobal(ctx, t.Id, err.Error())
ReleaseQueueSlot(ctx, t.ModelName, t.TaskID)
return
}
// 5) 更新任务状态成功
// 注意expire_at 的计算改为“已下载(state=4)后开始计时”,因此成功(state=2)不写 expire_at。
fileType := strings.TrimPrefix(ext, ".")
if fileType == "" {
fileType = contentType
}
if err := dao.Task.UpdateSuccessGlobal(ctx, t.Id, ossURL, fileType, int64(len(data)), nil); err != nil {
g.Log().Errorf(ctx, "[worker] update success failed: %v", err)
return
}
// 成功/失败均不再占用 queue_limitstate=0/1 才占用)
ReleaseQueueSlot(ctx, t.ModelName, t.TaskID)
// 6) 成功回调(不影响主流程)
t.State = 2
t.OssFile = ossURL
t.FileType = fileType
go triggerSuccessCallback(context.WithoutCancel(ctx), t)
// 成功后清理临时文件
deleteTmpResult(t.TmpFile)
}
func (w *asyncWorker) rollbackToPending(ctx context.Context, id int64) error {
return dao.Task.RollbackToPendingGlobal(ctx, id)
}